AstrBot 4.1.4__py3-none-any.whl → 4.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/agent.py +1 -1
- astrbot/core/agent/mcp_client.py +3 -1
- astrbot/core/agent/runners/tool_loop_agent_runner.py +6 -27
- astrbot/core/agent/tool.py +28 -17
- astrbot/core/config/default.py +50 -14
- astrbot/core/db/sqlite.py +16 -1
- astrbot/core/pipeline/content_safety_check/stage.py +1 -1
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +1 -1
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -1
- astrbot/core/pipeline/context_utils.py +4 -1
- astrbot/core/pipeline/process_stage/method/llm_request.py +23 -4
- astrbot/core/pipeline/process_stage/method/star_request.py +8 -6
- astrbot/core/platform/manager.py +4 -0
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -1
- astrbot/core/platform/sources/misskey/misskey_adapter.py +391 -0
- astrbot/core/platform/sources/misskey/misskey_api.py +404 -0
- astrbot/core/platform/sources/misskey/misskey_event.py +123 -0
- astrbot/core/platform/sources/misskey/misskey_utils.py +327 -0
- astrbot/core/platform/sources/satori/satori_adapter.py +290 -24
- astrbot/core/platform/sources/satori/satori_event.py +9 -0
- astrbot/core/platform/sources/telegram/tg_event.py +0 -1
- astrbot/core/provider/entities.py +13 -3
- astrbot/core/provider/func_tool_manager.py +4 -4
- astrbot/core/provider/manager.py +53 -24
- astrbot/core/star/context.py +31 -14
- astrbot/core/star/filter/command_group.py +4 -4
- astrbot/core/star/filter/platform_adapter_type.py +10 -5
- astrbot/core/star/register/star.py +3 -1
- astrbot/core/star/register/star_handler.py +65 -36
- astrbot/core/star/session_plugin_manager.py +3 -0
- astrbot/core/star/star_handler.py +4 -4
- astrbot/core/star/star_manager.py +10 -4
- astrbot/core/star/star_tools.py +6 -2
- astrbot/core/star/updator.py +3 -0
- {astrbot-4.1.4.dist-info → astrbot-4.1.6.dist-info}/METADATA +6 -7
- {astrbot-4.1.4.dist-info → astrbot-4.1.6.dist-info}/RECORD +39 -35
- {astrbot-4.1.4.dist-info → astrbot-4.1.6.dist-info}/WHEEL +0 -0
- {astrbot-4.1.4.dist-info → astrbot-4.1.6.dist-info}/entry_points.txt +0 -0
- {astrbot-4.1.4.dist-info → astrbot-4.1.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -17,7 +17,14 @@ from astrbot.api.platform import (
|
|
|
17
17
|
register_platform_adapter,
|
|
18
18
|
)
|
|
19
19
|
from astrbot.core.platform.astr_message_event import MessageSession
|
|
20
|
-
from astrbot.api.message_components import
|
|
20
|
+
from astrbot.api.message_components import (
|
|
21
|
+
Plain,
|
|
22
|
+
Image,
|
|
23
|
+
At,
|
|
24
|
+
File,
|
|
25
|
+
Record,
|
|
26
|
+
Reply,
|
|
27
|
+
)
|
|
21
28
|
from xml.etree import ElementTree as ET
|
|
22
29
|
|
|
23
30
|
|
|
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
|
|
|
38
45
|
)
|
|
39
46
|
self.token = self.config.get("satori_token", "")
|
|
40
47
|
self.endpoint = self.config.get(
|
|
41
|
-
"satori_endpoint", "ws://
|
|
48
|
+
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
|
|
42
49
|
)
|
|
43
50
|
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
|
|
44
51
|
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
|
|
45
52
|
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
|
|
46
53
|
|
|
54
|
+
self.metadata = PlatformMetadata(
|
|
55
|
+
name="satori",
|
|
56
|
+
description="Satori 通用协议适配器",
|
|
57
|
+
id=self.config["id"],
|
|
58
|
+
)
|
|
59
|
+
|
|
47
60
|
self.ws: Optional[ClientConnection] = None
|
|
48
61
|
self.session: Optional[ClientSession] = None
|
|
49
62
|
self.sequence = 0
|
|
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
|
|
|
63
76
|
await super().send_by_session(session, message_chain)
|
|
64
77
|
|
|
65
78
|
def meta(self) -> PlatformMetadata:
|
|
66
|
-
return
|
|
79
|
+
return self.metadata
|
|
67
80
|
|
|
68
81
|
def _is_websocket_closed(self, ws) -> bool:
|
|
69
82
|
"""检查WebSocket连接是否已关闭"""
|
|
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
|
|
|
312
325
|
|
|
313
326
|
abm.self_id = login.get("user", {}).get("id", "")
|
|
314
327
|
|
|
328
|
+
# 消息链
|
|
329
|
+
abm.message = []
|
|
330
|
+
|
|
315
331
|
content = message.get("content", "")
|
|
316
|
-
abm.message = await self.parse_satori_elements(content)
|
|
317
332
|
|
|
318
|
-
|
|
333
|
+
quote = message.get("quote")
|
|
334
|
+
content_for_parsing = content # 副本
|
|
335
|
+
|
|
336
|
+
# 提取<quote>标签
|
|
337
|
+
if "<quote" in content:
|
|
338
|
+
try:
|
|
339
|
+
quote_info = await self._extract_quote_element(content)
|
|
340
|
+
if quote_info:
|
|
341
|
+
quote = quote_info["quote"]
|
|
342
|
+
content_for_parsing = quote_info["content_without_quote"]
|
|
343
|
+
except Exception as e:
|
|
344
|
+
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
|
|
345
|
+
|
|
346
|
+
if quote:
|
|
347
|
+
# 引用消息
|
|
348
|
+
quote_abm = await self._convert_quote_message(quote)
|
|
349
|
+
if quote_abm:
|
|
350
|
+
sender_id = quote_abm.sender.user_id
|
|
351
|
+
if isinstance(sender_id, str) and sender_id.isdigit():
|
|
352
|
+
sender_id = int(sender_id)
|
|
353
|
+
elif not isinstance(sender_id, int):
|
|
354
|
+
sender_id = 0 # 默认值
|
|
355
|
+
|
|
356
|
+
reply_component = Reply(
|
|
357
|
+
id=quote_abm.message_id,
|
|
358
|
+
chain=quote_abm.message,
|
|
359
|
+
sender_id=quote_abm.sender.user_id,
|
|
360
|
+
sender_nickname=quote_abm.sender.nickname,
|
|
361
|
+
time=quote_abm.timestamp,
|
|
362
|
+
message_str=quote_abm.message_str,
|
|
363
|
+
text=quote_abm.message_str,
|
|
364
|
+
qq=sender_id,
|
|
365
|
+
)
|
|
366
|
+
abm.message.append(reply_component)
|
|
367
|
+
|
|
368
|
+
# 解析消息内容
|
|
369
|
+
content_elements = await self.parse_satori_elements(content_for_parsing)
|
|
370
|
+
abm.message.extend(content_elements)
|
|
371
|
+
|
|
319
372
|
abm.message_str = ""
|
|
320
|
-
for comp in
|
|
373
|
+
for comp in content_elements:
|
|
321
374
|
if isinstance(comp, Plain):
|
|
322
375
|
abm.message_str += comp.text
|
|
323
376
|
|
|
@@ -333,6 +386,163 @@ class SatoriPlatformAdapter(Platform):
|
|
|
333
386
|
logger.error(f"转换 Satori 消息失败: {e}")
|
|
334
387
|
return None
|
|
335
388
|
|
|
389
|
+
def _extract_namespace_prefixes(self, content: str) -> set:
|
|
390
|
+
"""提取XML内容中的命名空间前缀"""
|
|
391
|
+
prefixes = set()
|
|
392
|
+
|
|
393
|
+
# 查找所有标签
|
|
394
|
+
i = 0
|
|
395
|
+
while i < len(content):
|
|
396
|
+
# 查找开始标签
|
|
397
|
+
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
|
|
398
|
+
# 找到标签结束位置
|
|
399
|
+
tag_end = content.find(">", i)
|
|
400
|
+
if tag_end != -1:
|
|
401
|
+
# 提取标签内容
|
|
402
|
+
tag_content = content[i + 1 : tag_end]
|
|
403
|
+
# 检查是否有命名空间前缀
|
|
404
|
+
if ":" in tag_content and "xmlns:" not in tag_content:
|
|
405
|
+
# 分割标签名
|
|
406
|
+
parts = tag_content.split()
|
|
407
|
+
if parts:
|
|
408
|
+
tag_name = parts[0]
|
|
409
|
+
if ":" in tag_name:
|
|
410
|
+
prefix = tag_name.split(":")[0]
|
|
411
|
+
# 确保是有效的命名空间前缀
|
|
412
|
+
if (
|
|
413
|
+
prefix.isalnum()
|
|
414
|
+
or prefix.replace("_", "").isalnum()
|
|
415
|
+
):
|
|
416
|
+
prefixes.add(prefix)
|
|
417
|
+
i = tag_end + 1
|
|
418
|
+
else:
|
|
419
|
+
i += 1
|
|
420
|
+
# 查找结束标签
|
|
421
|
+
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
|
|
422
|
+
# 找到标签结束位置
|
|
423
|
+
tag_end = content.find(">", i)
|
|
424
|
+
if tag_end != -1:
|
|
425
|
+
# 提取标签内容
|
|
426
|
+
tag_content = content[i + 2 : tag_end]
|
|
427
|
+
# 检查是否有命名空间前缀
|
|
428
|
+
if ":" in tag_content:
|
|
429
|
+
prefix = tag_content.split(":")[0]
|
|
430
|
+
# 确保是有效的命名空间前缀
|
|
431
|
+
if prefix.isalnum() or prefix.replace("_", "").isalnum():
|
|
432
|
+
prefixes.add(prefix)
|
|
433
|
+
i = tag_end + 1
|
|
434
|
+
else:
|
|
435
|
+
i += 1
|
|
436
|
+
else:
|
|
437
|
+
i += 1
|
|
438
|
+
|
|
439
|
+
return prefixes
|
|
440
|
+
|
|
441
|
+
async def _extract_quote_element(self, content: str) -> Optional[dict]:
|
|
442
|
+
"""提取<quote>标签信息"""
|
|
443
|
+
try:
|
|
444
|
+
# 处理命名空间前缀问题
|
|
445
|
+
processed_content = content
|
|
446
|
+
if ":" in content and not content.startswith("<root"):
|
|
447
|
+
prefixes = self._extract_namespace_prefixes(content)
|
|
448
|
+
|
|
449
|
+
# 构建命名空间声明
|
|
450
|
+
ns_declarations = " ".join(
|
|
451
|
+
[
|
|
452
|
+
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
|
453
|
+
for prefix in prefixes
|
|
454
|
+
]
|
|
455
|
+
)
|
|
456
|
+
|
|
457
|
+
# 包装内容
|
|
458
|
+
processed_content = f"<root {ns_declarations}>{content}</root>"
|
|
459
|
+
elif not content.startswith("<root"):
|
|
460
|
+
processed_content = f"<root>{content}</root>"
|
|
461
|
+
else:
|
|
462
|
+
processed_content = content
|
|
463
|
+
|
|
464
|
+
root = ET.fromstring(processed_content)
|
|
465
|
+
|
|
466
|
+
# 查找<quote>标签
|
|
467
|
+
quote_element = None
|
|
468
|
+
for elem in root.iter():
|
|
469
|
+
tag_name = elem.tag
|
|
470
|
+
if "}" in tag_name:
|
|
471
|
+
tag_name = tag_name.split("}")[1]
|
|
472
|
+
if tag_name.lower() == "quote":
|
|
473
|
+
quote_element = elem
|
|
474
|
+
break
|
|
475
|
+
|
|
476
|
+
if quote_element is not None:
|
|
477
|
+
# 提取quote标签的属性
|
|
478
|
+
quote_id = quote_element.get("id", "")
|
|
479
|
+
|
|
480
|
+
# 提取<quote>标签内部的内容
|
|
481
|
+
inner_content = ""
|
|
482
|
+
if quote_element.text:
|
|
483
|
+
inner_content += quote_element.text
|
|
484
|
+
for child in quote_element:
|
|
485
|
+
inner_content += ET.tostring(
|
|
486
|
+
child, encoding="unicode", method="xml"
|
|
487
|
+
)
|
|
488
|
+
if child.tail:
|
|
489
|
+
inner_content += child.tail
|
|
490
|
+
|
|
491
|
+
# 构造移除了<quote>标签的内容
|
|
492
|
+
content_without_quote = content.replace(
|
|
493
|
+
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
|
|
494
|
+
)
|
|
495
|
+
|
|
496
|
+
return {
|
|
497
|
+
"quote": {"id": quote_id, "content": inner_content},
|
|
498
|
+
"content_without_quote": content_without_quote,
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
return None
|
|
502
|
+
except Exception as e:
|
|
503
|
+
logger.error(f"提取<quote>标签时发生错误: {e}")
|
|
504
|
+
return None
|
|
505
|
+
|
|
506
|
+
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
|
|
507
|
+
"""转换引用消息"""
|
|
508
|
+
try:
|
|
509
|
+
quote_abm = AstrBotMessage()
|
|
510
|
+
quote_abm.message_id = quote.get("id", "")
|
|
511
|
+
|
|
512
|
+
# 解析引用消息的发送者
|
|
513
|
+
quote_author = quote.get("author", {})
|
|
514
|
+
if quote_author:
|
|
515
|
+
quote_abm.sender = MessageMember(
|
|
516
|
+
user_id=quote_author.get("id", ""),
|
|
517
|
+
nickname=quote_author.get("nick", quote_author.get("name", "")),
|
|
518
|
+
)
|
|
519
|
+
else:
|
|
520
|
+
# 如果没有作者信息,使用默认值
|
|
521
|
+
quote_abm.sender = MessageMember(
|
|
522
|
+
user_id=quote.get("user_id", ""),
|
|
523
|
+
nickname="内容",
|
|
524
|
+
)
|
|
525
|
+
|
|
526
|
+
# 解析引用消息内容
|
|
527
|
+
quote_content = quote.get("content", "")
|
|
528
|
+
quote_abm.message = await self.parse_satori_elements(quote_content)
|
|
529
|
+
|
|
530
|
+
quote_abm.message_str = ""
|
|
531
|
+
for comp in quote_abm.message:
|
|
532
|
+
if isinstance(comp, Plain):
|
|
533
|
+
quote_abm.message_str += comp.text
|
|
534
|
+
|
|
535
|
+
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
|
|
536
|
+
|
|
537
|
+
# 如果没有任何内容,使用默认文本
|
|
538
|
+
if not quote_abm.message_str.strip():
|
|
539
|
+
quote_abm.message_str = "[引用消息]"
|
|
540
|
+
|
|
541
|
+
return quote_abm
|
|
542
|
+
except Exception as e:
|
|
543
|
+
logger.error(f"转换引用消息失败: {e}")
|
|
544
|
+
return None
|
|
545
|
+
|
|
336
546
|
async def parse_satori_elements(self, content: str) -> list:
|
|
337
547
|
"""解析 Satori 消息元素"""
|
|
338
548
|
elements = []
|
|
@@ -341,12 +551,35 @@ class SatoriPlatformAdapter(Platform):
|
|
|
341
551
|
return elements
|
|
342
552
|
|
|
343
553
|
try:
|
|
344
|
-
|
|
345
|
-
|
|
554
|
+
# 处理命名空间前缀问题
|
|
555
|
+
processed_content = content
|
|
556
|
+
if ":" in content and not content.startswith("<root"):
|
|
557
|
+
prefixes = self._extract_namespace_prefixes(content)
|
|
558
|
+
|
|
559
|
+
# 构建命名空间声明
|
|
560
|
+
ns_declarations = " ".join(
|
|
561
|
+
[
|
|
562
|
+
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
|
563
|
+
for prefix in prefixes
|
|
564
|
+
]
|
|
565
|
+
)
|
|
566
|
+
|
|
567
|
+
# 包装内容
|
|
568
|
+
processed_content = f"<root {ns_declarations}>{content}</root>"
|
|
569
|
+
elif not content.startswith("<root"):
|
|
570
|
+
processed_content = f"<root>{content}</root>"
|
|
571
|
+
else:
|
|
572
|
+
processed_content = content
|
|
573
|
+
|
|
574
|
+
root = ET.fromstring(processed_content)
|
|
346
575
|
await self._parse_xml_node(root, elements)
|
|
347
576
|
except ET.ParseError as e:
|
|
348
|
-
|
|
577
|
+
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
|
|
578
|
+
# 如果解析失败,将整个内容当作纯文本
|
|
579
|
+
if content.strip():
|
|
580
|
+
elements.append(Plain(text=content))
|
|
349
581
|
except Exception as e:
|
|
582
|
+
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
|
|
350
583
|
raise e
|
|
351
584
|
|
|
352
585
|
# 如果没有解析到任何元素,将整个内容当作纯文本
|
|
@@ -361,7 +594,12 @@ class SatoriPlatformAdapter(Platform):
|
|
|
361
594
|
elements.append(Plain(text=node.text))
|
|
362
595
|
|
|
363
596
|
for child in node:
|
|
364
|
-
|
|
597
|
+
# 获取标签名,去除命名空间前缀
|
|
598
|
+
tag_name = child.tag
|
|
599
|
+
if "}" in tag_name:
|
|
600
|
+
tag_name = tag_name.split("}")[1]
|
|
601
|
+
tag_name = tag_name.lower()
|
|
602
|
+
|
|
365
603
|
attrs = child.attrib
|
|
366
604
|
|
|
367
605
|
if tag_name == "at":
|
|
@@ -372,31 +610,59 @@ class SatoriPlatformAdapter(Platform):
|
|
|
372
610
|
src = attrs.get("src", "")
|
|
373
611
|
if not src:
|
|
374
612
|
continue
|
|
375
|
-
|
|
376
|
-
src = src.split(",")[1]
|
|
377
|
-
elements.append(Image.fromBase64(src))
|
|
378
|
-
elif src.startswith("http"):
|
|
379
|
-
elements.append(Image.fromURL(src))
|
|
380
|
-
else:
|
|
381
|
-
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
|
|
613
|
+
elements.append(Image(file=src))
|
|
382
614
|
|
|
383
615
|
elif tag_name == "file":
|
|
384
616
|
src = attrs.get("src", "")
|
|
385
617
|
name = attrs.get("name", "文件")
|
|
386
618
|
if src:
|
|
387
|
-
elements.append(File(
|
|
619
|
+
elements.append(File(name=name, file=src))
|
|
388
620
|
|
|
389
621
|
elif tag_name in ("audio", "record"):
|
|
390
622
|
src = attrs.get("src", "")
|
|
391
623
|
if not src:
|
|
392
624
|
continue
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
625
|
+
elements.append(Record(file=src))
|
|
626
|
+
|
|
627
|
+
elif tag_name == "quote":
|
|
628
|
+
# quote标签已经被特殊处理
|
|
629
|
+
pass
|
|
630
|
+
|
|
631
|
+
elif tag_name == "face":
|
|
632
|
+
face_id = attrs.get("id", "")
|
|
633
|
+
face_name = attrs.get("name", "")
|
|
634
|
+
face_type = attrs.get("type", "")
|
|
635
|
+
|
|
636
|
+
if face_name:
|
|
637
|
+
elements.append(Plain(text=f"[表情:{face_name}]"))
|
|
638
|
+
elif face_id and face_type:
|
|
639
|
+
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
|
|
640
|
+
elif face_id:
|
|
641
|
+
elements.append(Plain(text=f"[表情ID:{face_id}]"))
|
|
642
|
+
else:
|
|
643
|
+
elements.append(Plain(text="[表情]"))
|
|
644
|
+
|
|
645
|
+
elif tag_name == "ark":
|
|
646
|
+
# 作为纯文本添加到消息链中
|
|
647
|
+
data = attrs.get("data", "")
|
|
648
|
+
if data:
|
|
649
|
+
import html
|
|
650
|
+
|
|
651
|
+
decoded_data = html.unescape(data)
|
|
652
|
+
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
|
653
|
+
else:
|
|
654
|
+
elements.append(Plain(text="[ARK卡片]"))
|
|
655
|
+
|
|
656
|
+
elif tag_name == "json":
|
|
657
|
+
# JSON标签 视为ARK卡片消息
|
|
658
|
+
data = attrs.get("data", "")
|
|
659
|
+
if data:
|
|
660
|
+
import html
|
|
661
|
+
|
|
662
|
+
decoded_data = html.unescape(data)
|
|
663
|
+
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
|
398
664
|
else:
|
|
399
|
-
|
|
665
|
+
elements.append(Plain(text="[JSON卡片]"))
|
|
400
666
|
|
|
401
667
|
else:
|
|
402
668
|
# 未知标签,递归处理其内容
|
|
@@ -17,6 +17,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
|
|
17
17
|
session_id: str,
|
|
18
18
|
adapter: "SatoriPlatformAdapter",
|
|
19
19
|
):
|
|
20
|
+
# 更新平台元数据
|
|
21
|
+
if adapter and hasattr(adapter, "logins") and adapter.logins:
|
|
22
|
+
current_login = adapter.logins[0]
|
|
23
|
+
platform_name = current_login.get("platform", "satori")
|
|
24
|
+
user = current_login.get("user", {})
|
|
25
|
+
user_id = user.get("id", "") if user else ""
|
|
26
|
+
if not platform_meta.id and user_id:
|
|
27
|
+
platform_meta.id = f"{platform_name}({user_id})"
|
|
28
|
+
|
|
20
29
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
|
21
30
|
self.adapter = adapter
|
|
22
31
|
self.platform = None
|
|
@@ -218,7 +218,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|
|
218
218
|
try:
|
|
219
219
|
msg = await self.client.send_message(text=delta, **payload)
|
|
220
220
|
current_content = delta
|
|
221
|
-
delta = ""
|
|
222
221
|
except Exception as e:
|
|
223
222
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
|
224
223
|
message_id = msg.message_id
|
|
@@ -65,13 +65,16 @@ class AssistantMessageSegment:
|
|
|
65
65
|
role: str = "assistant"
|
|
66
66
|
|
|
67
67
|
def to_dict(self):
|
|
68
|
-
ret = {
|
|
68
|
+
ret: dict[str, str | list[dict]] = {
|
|
69
69
|
"role": self.role,
|
|
70
70
|
}
|
|
71
71
|
if self.content:
|
|
72
72
|
ret["content"] = self.content
|
|
73
73
|
if self.tool_calls:
|
|
74
|
-
|
|
74
|
+
tool_calls_dict = [
|
|
75
|
+
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
|
|
76
|
+
]
|
|
77
|
+
ret["tool_calls"] = tool_calls_dict
|
|
75
78
|
return ret
|
|
76
79
|
|
|
77
80
|
|
|
@@ -117,7 +120,14 @@ class ProviderRequest:
|
|
|
117
120
|
"""模型名称,为 None 时使用提供商的默认模型"""
|
|
118
121
|
|
|
119
122
|
def __repr__(self):
|
|
120
|
-
return
|
|
123
|
+
return (
|
|
124
|
+
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
|
|
125
|
+
f"image_count={len(self.image_urls or [])}, "
|
|
126
|
+
f"func_tool={self.func_tool}, "
|
|
127
|
+
f"contexts={self._print_friendly_context()}, "
|
|
128
|
+
f"system_prompt={self.system_prompt}, "
|
|
129
|
+
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
|
|
130
|
+
)
|
|
121
131
|
|
|
122
132
|
def __str__(self):
|
|
123
133
|
return self.__repr__()
|
|
@@ -4,7 +4,7 @@ import os
|
|
|
4
4
|
import asyncio
|
|
5
5
|
import aiohttp
|
|
6
6
|
|
|
7
|
-
from typing import Dict, List, Awaitable
|
|
7
|
+
from typing import Dict, List, Awaitable, Callable, Any
|
|
8
8
|
from astrbot import logger
|
|
9
9
|
from astrbot.core import sp
|
|
10
10
|
|
|
@@ -109,7 +109,7 @@ class FunctionToolManager:
|
|
|
109
109
|
name: str,
|
|
110
110
|
func_args: list,
|
|
111
111
|
desc: str,
|
|
112
|
-
handler: Awaitable,
|
|
112
|
+
handler: Callable[..., Awaitable[Any]],
|
|
113
113
|
) -> FuncTool:
|
|
114
114
|
params = {
|
|
115
115
|
"type": "object", # hard-coded here
|
|
@@ -132,7 +132,7 @@ class FunctionToolManager:
|
|
|
132
132
|
name: str,
|
|
133
133
|
func_args: list,
|
|
134
134
|
desc: str,
|
|
135
|
-
handler: Awaitable,
|
|
135
|
+
handler: Callable[..., Awaitable[Any]],
|
|
136
136
|
) -> None:
|
|
137
137
|
"""添加函数调用工具
|
|
138
138
|
|
|
@@ -220,7 +220,7 @@ class FunctionToolManager:
|
|
|
220
220
|
name: str,
|
|
221
221
|
cfg: dict,
|
|
222
222
|
event: asyncio.Event,
|
|
223
|
-
ready_future: asyncio.Future = None,
|
|
223
|
+
ready_future: asyncio.Future | None = None,
|
|
224
224
|
) -> None:
|
|
225
225
|
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
|
226
226
|
try:
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -7,7 +7,13 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
|
7
7
|
from astrbot.core.db import BaseDatabase
|
|
8
8
|
|
|
9
9
|
from .entities import ProviderType
|
|
10
|
-
from .provider import
|
|
10
|
+
from .provider import (
|
|
11
|
+
Provider,
|
|
12
|
+
STTProvider,
|
|
13
|
+
TTSProvider,
|
|
14
|
+
EmbeddingProvider,
|
|
15
|
+
RerankProvider,
|
|
16
|
+
)
|
|
11
17
|
from .register import llm_tools, provider_cls_map
|
|
12
18
|
from ..persona_mgr import PersonaManager
|
|
13
19
|
|
|
@@ -38,7 +44,12 @@ class ProviderManager:
|
|
|
38
44
|
"""加载的 Text To Speech Provider 的实例"""
|
|
39
45
|
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
|
40
46
|
"""加载的 Embedding Provider 的实例"""
|
|
41
|
-
self.
|
|
47
|
+
self.rerank_provider_insts: List[RerankProvider] = []
|
|
48
|
+
"""加载的 Rerank Provider 的实例"""
|
|
49
|
+
self.inst_map: dict[
|
|
50
|
+
str,
|
|
51
|
+
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
|
52
|
+
] = {}
|
|
42
53
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
|
43
54
|
self.llm_tools = llm_tools
|
|
44
55
|
|
|
@@ -87,19 +98,31 @@ class ProviderManager:
|
|
|
87
98
|
)
|
|
88
99
|
return
|
|
89
100
|
# 不启用提供商会话隔离模式的情况
|
|
90
|
-
|
|
91
|
-
|
|
101
|
+
|
|
102
|
+
prov = self.inst_map[provider_id]
|
|
103
|
+
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
|
104
|
+
prov, TTSProvider
|
|
105
|
+
):
|
|
106
|
+
self.curr_tts_provider_inst = prov
|
|
92
107
|
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
|
93
|
-
elif provider_type == ProviderType.SPEECH_TO_TEXT
|
|
108
|
+
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
|
109
|
+
prov, STTProvider
|
|
110
|
+
):
|
|
111
|
+
self.curr_stt_provider_inst = prov
|
|
94
112
|
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
|
95
|
-
elif provider_type == ProviderType.CHAT_COMPLETION
|
|
113
|
+
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
|
114
|
+
prov, Provider
|
|
115
|
+
):
|
|
116
|
+
self.curr_provider_inst = prov
|
|
96
117
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
|
97
118
|
|
|
98
119
|
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
|
99
120
|
"""根据提供商 ID 获取提供商实例"""
|
|
100
121
|
return self.inst_map.get(provider_id)
|
|
101
122
|
|
|
102
|
-
def get_using_provider(
|
|
123
|
+
def get_using_provider(
|
|
124
|
+
self, provider_type: ProviderType, umo=None
|
|
125
|
+
) -> Provider | STTProvider | TTSProvider | None:
|
|
103
126
|
"""获取正在使用的提供商实例。
|
|
104
127
|
|
|
105
128
|
Args:
|
|
@@ -303,12 +326,14 @@ class ProviderManager:
|
|
|
303
326
|
provider_metadata = provider_cls_map[provider_config["type"]]
|
|
304
327
|
try:
|
|
305
328
|
# 按任务实例化提供商
|
|
329
|
+
cls_type = provider_metadata.cls_type
|
|
330
|
+
if not cls_type:
|
|
331
|
+
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
|
332
|
+
return
|
|
306
333
|
|
|
307
334
|
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
|
308
335
|
# STT 任务
|
|
309
|
-
inst =
|
|
310
|
-
provider_config, self.provider_settings
|
|
311
|
-
)
|
|
336
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
312
337
|
|
|
313
338
|
if getattr(inst, "initialize", None):
|
|
314
339
|
await inst.initialize()
|
|
@@ -327,9 +352,7 @@ class ProviderManager:
|
|
|
327
352
|
|
|
328
353
|
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
|
329
354
|
# TTS 任务
|
|
330
|
-
inst =
|
|
331
|
-
provider_config, self.provider_settings
|
|
332
|
-
)
|
|
355
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
333
356
|
|
|
334
357
|
if getattr(inst, "initialize", None):
|
|
335
358
|
await inst.initialize()
|
|
@@ -345,7 +368,7 @@ class ProviderManager:
|
|
|
345
368
|
|
|
346
369
|
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
|
347
370
|
# 文本生成任务
|
|
348
|
-
inst =
|
|
371
|
+
inst = cls_type(
|
|
349
372
|
provider_config,
|
|
350
373
|
self.provider_settings,
|
|
351
374
|
self.selected_default_persona,
|
|
@@ -366,16 +389,16 @@ class ProviderManager:
|
|
|
366
389
|
if not self.curr_provider_inst:
|
|
367
390
|
self.curr_provider_inst = inst
|
|
368
391
|
|
|
369
|
-
elif provider_metadata.provider_type
|
|
370
|
-
|
|
371
|
-
ProviderType.RERANK,
|
|
372
|
-
]:
|
|
373
|
-
inst = provider_metadata.cls_type(
|
|
374
|
-
provider_config, self.provider_settings
|
|
375
|
-
)
|
|
392
|
+
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
|
393
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
376
394
|
if getattr(inst, "initialize", None):
|
|
377
395
|
await inst.initialize()
|
|
378
396
|
self.embedding_provider_insts.append(inst)
|
|
397
|
+
elif provider_metadata.provider_type == ProviderType.RERANK:
|
|
398
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
399
|
+
if getattr(inst, "initialize", None):
|
|
400
|
+
await inst.initialize()
|
|
401
|
+
self.rerank_provider_insts.append(inst)
|
|
379
402
|
|
|
380
403
|
self.inst_map[provider_config["id"]] = inst
|
|
381
404
|
except Exception as e:
|
|
@@ -430,11 +453,17 @@ class ProviderManager:
|
|
|
430
453
|
)
|
|
431
454
|
|
|
432
455
|
if self.inst_map[provider_id] in self.provider_insts:
|
|
433
|
-
self.
|
|
456
|
+
prov_inst = self.inst_map[provider_id]
|
|
457
|
+
if isinstance(prov_inst, Provider):
|
|
458
|
+
self.provider_insts.remove(prov_inst)
|
|
434
459
|
if self.inst_map[provider_id] in self.stt_provider_insts:
|
|
435
|
-
self.
|
|
460
|
+
prov_inst = self.inst_map[provider_id]
|
|
461
|
+
if isinstance(prov_inst, STTProvider):
|
|
462
|
+
self.stt_provider_insts.remove(prov_inst)
|
|
436
463
|
if self.inst_map[provider_id] in self.tts_provider_insts:
|
|
437
|
-
self.
|
|
464
|
+
prov_inst = self.inst_map[provider_id]
|
|
465
|
+
if isinstance(prov_inst, TTSProvider):
|
|
466
|
+
self.tts_provider_insts.remove(prov_inst)
|
|
438
467
|
|
|
439
468
|
if self.inst_map[provider_id] == self.curr_provider_inst:
|
|
440
469
|
self.curr_provider_inst = None
|