AstrBot 4.3.3__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. astrbot/core/agent/mcp_client.py +18 -4
  2. astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
  3. astrbot/core/astr_agent_context.py +1 -0
  4. astrbot/core/astrbot_config_mgr.py +23 -51
  5. astrbot/core/config/default.py +139 -14
  6. astrbot/core/conversation_mgr.py +36 -1
  7. astrbot/core/core_lifecycle.py +24 -5
  8. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  9. astrbot/core/db/vec_db/base.py +33 -2
  10. astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
  11. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
  12. astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
  13. astrbot/core/file_token_service.py +6 -1
  14. astrbot/core/initial_loader.py +6 -3
  15. astrbot/core/knowledge_base/chunking/__init__.py +11 -0
  16. astrbot/core/knowledge_base/chunking/base.py +24 -0
  17. astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
  18. astrbot/core/knowledge_base/chunking/recursive.py +155 -0
  19. astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
  20. astrbot/core/knowledge_base/kb_helper.py +348 -0
  21. astrbot/core/knowledge_base/kb_mgr.py +287 -0
  22. astrbot/core/knowledge_base/models.py +114 -0
  23. astrbot/core/knowledge_base/parsers/__init__.py +15 -0
  24. astrbot/core/knowledge_base/parsers/base.py +50 -0
  25. astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
  26. astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
  27. astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
  28. astrbot/core/knowledge_base/parsers/util.py +13 -0
  29. astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
  30. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  31. astrbot/core/knowledge_base/retrieval/manager.py +273 -0
  32. astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
  33. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
  34. astrbot/core/pipeline/process_stage/method/llm_request.py +61 -21
  35. astrbot/core/pipeline/process_stage/utils.py +80 -0
  36. astrbot/core/pipeline/scheduler.py +1 -1
  37. astrbot/core/platform/astr_message_event.py +8 -7
  38. astrbot/core/platform/manager.py +4 -0
  39. astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
  40. astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
  41. astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
  42. astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
  43. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  44. astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
  45. astrbot/core/platform/sources/satori/satori_event.py +270 -77
  46. astrbot/core/platform/sources/webchat/webchat_adapter.py +0 -1
  47. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +289 -0
  48. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +17 -0
  49. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +20 -0
  50. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +445 -0
  51. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +378 -0
  52. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +149 -0
  53. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +148 -0
  54. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +166 -0
  55. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +199 -0
  56. astrbot/core/provider/manager.py +14 -9
  57. astrbot/core/provider/provider.py +67 -0
  58. astrbot/core/provider/sources/anthropic_source.py +4 -4
  59. astrbot/core/provider/sources/dashscope_source.py +10 -9
  60. astrbot/core/provider/sources/dify_source.py +6 -8
  61. astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
  62. astrbot/core/provider/sources/openai_embedding_source.py +1 -2
  63. astrbot/core/provider/sources/openai_source.py +18 -15
  64. astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
  65. astrbot/core/star/context.py +3 -0
  66. astrbot/core/star/star.py +6 -0
  67. astrbot/core/star/star_manager.py +13 -7
  68. astrbot/core/umop_config_router.py +81 -0
  69. astrbot/core/updator.py +1 -1
  70. astrbot/core/utils/io.py +23 -12
  71. astrbot/dashboard/routes/__init__.py +2 -0
  72. astrbot/dashboard/routes/config.py +137 -9
  73. astrbot/dashboard/routes/knowledge_base.py +1065 -0
  74. astrbot/dashboard/routes/plugin.py +24 -5
  75. astrbot/dashboard/routes/tools.py +14 -0
  76. astrbot/dashboard/routes/update.py +1 -1
  77. astrbot/dashboard/server.py +6 -0
  78. astrbot/dashboard/utils.py +161 -0
  79. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/METADATA +91 -55
  80. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/RECORD +83 -50
  81. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
  82. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
  83. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,8 @@ import asyncio
6
6
  import copy
7
7
  import json
8
8
  import traceback
9
- from typing import AsyncGenerator, Union
9
+ from datetime import timedelta
10
+ from collections.abc import AsyncGenerator
10
11
  from astrbot.core.conversation_mgr import Conversation
11
12
  from astrbot.core import logger
12
13
  from astrbot.core.message.components import Image
@@ -32,6 +33,7 @@ from astrbot.core.star.star_handler import EventType
32
33
  from astrbot.core.utils.metrics import Metric
33
34
  from ...context import PipelineContext, call_event_hook, call_handler
34
35
  from ..stage import Stage
36
+ from ..utils import inject_kb_context
35
37
  from astrbot.core.provider.register import llm_tools
36
38
  from astrbot.core.star.star_handler import star_map
37
39
  from astrbot.core.astr_agent_context import AstrAgentContext
@@ -43,7 +45,7 @@ except (ModuleNotFoundError, ImportError):
43
45
 
44
46
 
45
47
  AgentContextWrapper = ContextWrapper[AstrAgentContext]
46
- AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
48
+ AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
47
49
 
48
50
 
49
51
  class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@@ -101,7 +103,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
101
103
 
102
104
  request = ProviderRequest(
103
105
  prompt=input_,
104
- system_prompt=tool.description,
106
+ system_prompt=tool.description or "",
105
107
  image_urls=[], # 暂时不传递原始 agent 的上下文
106
108
  contexts=[], # 暂时不传递原始 agent 的上下文
107
109
  func_tool=toolset,
@@ -185,21 +187,33 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
185
187
  handler=awaitable,
186
188
  **tool_args,
187
189
  )
188
- async for resp in wrapper:
189
- if resp is not None:
190
- if isinstance(resp, mcp.types.CallToolResult):
191
- yield resp
190
+ # async for resp in wrapper:
191
+ while True:
192
+ try:
193
+ resp = await asyncio.wait_for(
194
+ anext(wrapper),
195
+ timeout=run_context.context.tool_call_timeout,
196
+ )
197
+ if resp is not None:
198
+ if isinstance(resp, mcp.types.CallToolResult):
199
+ yield resp
200
+ else:
201
+ text_content = mcp.types.TextContent(
202
+ type="text",
203
+ text=str(resp),
204
+ )
205
+ yield mcp.types.CallToolResult(content=[text_content])
192
206
  else:
193
- text_content = mcp.types.TextContent(
194
- type="text",
195
- text=str(resp),
196
- )
197
- yield mcp.types.CallToolResult(content=[text_content])
198
- else:
199
- # NOTE: Tool 在这里直接请求发送消息给用户
200
- # TODO: 是否需要判断 event.get_result() 是否为空?
201
- # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
202
- yield None
207
+ # NOTE: Tool 在这里直接请求发送消息给用户
208
+ # TODO: 是否需要判断 event.get_result() 是否为空?
209
+ # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
210
+ yield None
211
+ except asyncio.TimeoutError:
212
+ raise Exception(
213
+ f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
214
+ )
215
+ except StopAsyncIteration:
216
+ break
203
217
 
204
218
  @classmethod
205
219
  async def _execute_mcp(
@@ -217,13 +231,16 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
217
231
  res = await session.call_tool(
218
232
  name=tool.name,
219
233
  arguments=tool_args,
234
+ read_timeout_seconds=timedelta(
235
+ seconds=run_context.context.tool_call_timeout
236
+ ),
220
237
  )
221
238
  if not res:
222
239
  return
223
240
  yield res
224
241
 
225
242
 
226
- class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
243
+ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
227
244
  async def on_agent_done(self, run_context, llm_response):
228
245
  # 执行事件钩子
229
246
  await call_event_hook(
@@ -307,6 +324,7 @@ class LLMRequestSubStage(Stage):
307
324
  )
308
325
  self.streaming_response: bool = settings["streaming_response"]
309
326
  self.max_step: int = settings.get("max_agent_step", 30)
327
+ self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
310
328
  if isinstance(self.max_step, bool): # workaround: #2622
311
329
  self.max_step = 30
312
330
  self.show_tool_use: bool = settings.get("show_tool_use_status", True)
@@ -320,7 +338,7 @@ class LLMRequestSubStage(Stage):
320
338
 
321
339
  self.conv_manager = ctx.plugin_manager.context.conversation_manager
322
340
 
323
- def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
341
+ def _select_provider(self, event: AstrMessageEvent):
324
342
  """选择使用的 LLM 提供商"""
325
343
  sel_provider = event.get_extra("selected_provider")
326
344
  _ctx = self.ctx.plugin_manager.context
@@ -350,7 +368,7 @@ class LLMRequestSubStage(Stage):
350
368
 
351
369
  async def process(
352
370
  self, event: AstrMessageEvent, _nested: bool = False
353
- ) -> Union[None, AsyncGenerator[None, None]]:
371
+ ) -> None | AsyncGenerator[None, None]:
354
372
  req: ProviderRequest | None = None
355
373
 
356
374
  if not self.ctx.astrbot_config["provider_settings"]["enable"]:
@@ -365,6 +383,9 @@ class LLMRequestSubStage(Stage):
365
383
  provider = self._select_provider(event)
366
384
  if provider is None:
367
385
  return
386
+ if not isinstance(provider, Provider):
387
+ logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
388
+ return
368
389
 
369
390
  if event.get_extra("provider_request"):
370
391
  req = event.get_extra("provider_request")
@@ -399,6 +420,14 @@ class LLMRequestSubStage(Stage):
399
420
  if not req.prompt and not req.image_urls:
400
421
  return
401
422
 
423
+ # 应用知识库
424
+ try:
425
+ await inject_kb_context(
426
+ umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
427
+ )
428
+ except Exception as e:
429
+ logger.error(f"调用知识库时遇到问题: {e}")
430
+
402
431
  # 执行请求 LLM 前事件钩子。
403
432
  if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
404
433
  return
@@ -463,6 +492,9 @@ class LLMRequestSubStage(Stage):
463
492
  new_tool_set.add_tool(tool)
464
493
  req.func_tool = new_tool_set
465
494
 
495
+ # 备份 req.contexts
496
+ backup_contexts = copy.deepcopy(req.contexts)
497
+
466
498
  # run agent
467
499
  agent_runner = AgentRunner()
468
500
  logger.debug(
@@ -473,6 +505,7 @@ class LLMRequestSubStage(Stage):
473
505
  first_provider_request=req,
474
506
  curr_provider_request=req,
475
507
  streaming=self.streaming_response,
508
+ tool_call_timeout=self.tool_call_timeout,
476
509
  )
477
510
  await agent_runner.reset(
478
511
  provider=provider,
@@ -499,8 +532,10 @@ class LLMRequestSubStage(Stage):
499
532
  chain = (
500
533
  MessageChain().message(final_llm_resp.completion_text).chain
501
534
  )
502
- else:
535
+ elif final_llm_resp.result_chain:
503
536
  chain = final_llm_resp.result_chain.chain
537
+ else:
538
+ chain = MessageChain().chain
504
539
  event.set_result(
505
540
  MessageEventResult(
506
541
  chain=chain,
@@ -511,6 +546,9 @@ class LLMRequestSubStage(Stage):
511
546
  async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
512
547
  yield
513
548
 
549
+ # 恢复备份的 contexts
550
+ req.contexts = backup_contexts
551
+
514
552
  await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
515
553
 
516
554
  # 异步处理 WebChat 特殊情况
@@ -529,6 +567,8 @@ class LLMRequestSubStage(Stage):
529
567
  self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
530
568
  ):
531
569
  """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
570
+ if not req.conversation:
571
+ return
532
572
  conversation = await self.conv_manager.get_conversation(
533
573
  event.unified_msg_origin, req.conversation.cid
534
574
  )
@@ -0,0 +1,80 @@
1
+ from ..context import PipelineContext
2
+ from astrbot.core.provider.entities import ProviderRequest
3
+ from astrbot.api import logger, sp
4
+
5
+
6
async def inject_kb_context(
    umo: str,
    p_ctx: PipelineContext,
    req: ProviderRequest,
) -> None:
    """Inject retrieved knowledge-base context into the provider request.

    Which knowledge bases are queried is resolved in this order:

    1. Session-level config (``kb_config`` read via ``sp.session_get``).
       If it contains a ``kb_ids`` key, those IDs are resolved to KB names
       through the KB manager; an explicitly empty list disables retrieval
       for the session.
    2. Otherwise the global ``kb_names`` / ``kb_final_top_k`` settings from
       ``p_ctx.astrbot_config``.

    When retrieval returns a non-empty ``context_text``, it is prepended to
    ``req.system_prompt``. ``req`` is mutated in place; nothing is returned.

    Args:
        umo: Unified message origin of the event (the caller passes
            ``event.unified_msg_origin``); used as the key for the
            session-level config lookup.
        p_ctx: Pipeline context; provides the KB manager and global config.
        req: Provider request to enrich; ``req.prompt`` is the retrieval query.
    """
    # Without a text query there is nothing to retrieve against (e.g. an
    # image-only request). Skip the config lookup and retrieval round-trip.
    query = req.prompt
    if not query:
        return

    kb_mgr = p_ctx.plugin_manager.context.kb_manager

    # 1. Session-level configuration takes precedence.
    session_config = await sp.session_get(umo, "kb_config", default={})

    if session_config and "kb_ids" in session_config:
        kb_ids = session_config.get("kb_ids", [])

        # An explicitly empty list means "use no knowledge base" for this session.
        if not kb_ids:
            logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库")
            return

        top_k = session_config.get("top_k", 5)

        # Resolve kb_ids to kb_names, remembering any that cannot be found.
        kb_names = []
        invalid_kb_ids = []
        for kb_id in kb_ids:
            kb_helper = await kb_mgr.get_kb(kb_id)
            if kb_helper:
                kb_names.append(kb_helper.kb.kb_name)
            else:
                logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}")
                invalid_kb_ids.append(kb_id)

        if invalid_kb_ids:
            logger.warning(
                f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}"
            )

        # Every configured ID was invalid: nothing left to query.
        if not kb_names:
            return

        logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
    else:
        # 2. Fall back to the global configuration.
        kb_names = p_ctx.astrbot_config.get("kb_names", [])
        top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
        logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")

    # The fusion-stage candidate count always comes from the global config.
    top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)

    if not kb_names:
        return

    logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
    kb_context = await kb_mgr.retrieve(
        query=query,
        kb_names=kb_names,
        top_k_fusion=top_k_fusion,
        top_m_final=top_k,
    )

    if not kb_context:
        return

    formatted = kb_context.get("context_text", "")
    if formatted:
        results = kb_context.get("results", [])
        logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
        # Prepend the retrieved context, keeping any existing system prompt.
        req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"
@@ -74,7 +74,7 @@ class PipelineScheduler:
74
74
  await self._process_stages(event)
75
75
 
76
76
  # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
77
- if event.get_platform_name() == "webchat":
77
+ if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
78
78
  await event.send(None)
79
79
 
80
80
  logger.debug("pipeline 执行完毕。")
@@ -4,7 +4,7 @@ import re
4
4
  import hashlib
5
5
  import uuid
6
6
 
7
- from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
7
+ from typing import List, Union, Optional, AsyncGenerator, Any
8
8
 
9
9
  from astrbot import logger
10
10
  from astrbot.core.db.po import Conversation
@@ -26,8 +26,6 @@ from .astrbot_message import AstrBotMessage, Group
26
26
  from .platform_metadata import PlatformMetadata
27
27
  from .message_session import MessageSession, MessageSesion # noqa
28
28
 
29
- _VT = TypeVar("_VT")
30
-
31
29
 
32
30
  class AstrMessageEvent(abc.ABC):
33
31
  def __init__(
@@ -92,8 +90,10 @@ class AstrMessageEvent(abc.ABC):
92
90
  """
93
91
  return self.message_str
94
92
 
95
- def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
93
+ def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str:
96
94
  outline = ""
95
+ if not chain:
96
+ return outline
97
97
  for i in chain:
98
98
  if isinstance(i, Plain):
99
99
  outline += i.text
@@ -175,9 +175,7 @@ class AstrMessageEvent(abc.ABC):
175
175
  """
176
176
  self._extras[key] = value
177
177
 
178
- def get_extra(
179
- self, key: str | None = None, default: _VT = None
180
- ) -> dict[str, Any] | _VT:
178
+ def get_extra(self, key: str | None = None, default=None) -> Any:
181
179
  """
182
180
  获取额外的信息。
183
181
  """
@@ -265,6 +263,9 @@ class AstrMessageEvent(abc.ABC):
265
263
  """
266
264
  if isinstance(result, str):
267
265
  result = MessageEventResult().message(result)
266
+ # 兼容外部插件或调用方传入的 chain=None 的情况,确保为可迭代列表
267
+ if isinstance(result, MessageEventResult) and result.chain is None:
268
+ result.chain = []
268
269
  self._result = result
269
270
 
270
271
  def stop_event(self):
@@ -82,6 +82,10 @@ class PlatformManager:
82
82
  from .sources.wecom.wecom_adapter import (
83
83
  WecomPlatformAdapter, # noqa: F401
84
84
  )
85
+ case "wecom_ai_bot":
86
+ from .sources.wecom_ai_bot.wecomai_adapter import (
87
+ WecomAIBotAdapter, # noqa: F401
88
+ )
85
89
  case "weixin_official_account":
86
90
  from .sources.weixin_official_account.weixin_offacc_adapter import (
87
91
  WeixinOfficialAccountPlatformAdapter, # noqa: F401