AstrBot 4.11.4__py3-none-any.whl → 4.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. astrbot/cli/__init__.py +1 -1
  2. astrbot/core/agent/runners/tool_loop_agent_runner.py +10 -8
  3. astrbot/core/config/default.py +66 -2
  4. astrbot/core/db/__init__.py +84 -2
  5. astrbot/core/db/po.py +65 -0
  6. astrbot/core/db/sqlite.py +225 -4
  7. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +103 -49
  8. astrbot/core/pipeline/process_stage/utils.py +40 -0
  9. astrbot/core/platform/astr_message_event.py +23 -4
  10. astrbot/core/platform/sources/discord/discord_platform_adapter.py +2 -0
  11. astrbot/core/platform/sources/telegram/tg_adapter.py +2 -0
  12. astrbot/core/platform/sources/webchat/webchat_adapter.py +3 -2
  13. astrbot/core/platform/sources/webchat/webchat_event.py +17 -4
  14. astrbot/core/provider/sources/anthropic_source.py +44 -0
  15. astrbot/core/sandbox/booters/base.py +31 -0
  16. astrbot/core/sandbox/booters/boxlite.py +186 -0
  17. astrbot/core/sandbox/booters/shipyard.py +67 -0
  18. astrbot/core/sandbox/olayer/__init__.py +5 -0
  19. astrbot/core/sandbox/olayer/filesystem.py +33 -0
  20. astrbot/core/sandbox/olayer/python.py +19 -0
  21. astrbot/core/sandbox/olayer/shell.py +21 -0
  22. astrbot/core/sandbox/sandbox_client.py +52 -0
  23. astrbot/core/sandbox/tools/__init__.py +10 -0
  24. astrbot/core/sandbox/tools/fs.py +188 -0
  25. astrbot/core/sandbox/tools/python.py +74 -0
  26. astrbot/core/sandbox/tools/shell.py +55 -0
  27. astrbot/core/star/context.py +162 -44
  28. astrbot/dashboard/routes/__init__.py +2 -0
  29. astrbot/dashboard/routes/chat.py +40 -12
  30. astrbot/dashboard/routes/chatui_project.py +245 -0
  31. astrbot/dashboard/routes/session_management.py +545 -0
  32. astrbot/dashboard/server.py +1 -0
  33. {astrbot-4.11.4.dist-info → astrbot-4.12.1.dist-info}/METADATA +2 -1
  34. {astrbot-4.11.4.dist-info → astrbot-4.12.1.dist-info}/RECORD +37 -28
  35. astrbot/builtin_stars/python_interpreter/main.py +0 -536
  36. astrbot/builtin_stars/python_interpreter/metadata.yaml +0 -4
  37. astrbot/builtin_stars/python_interpreter/requirements.txt +0 -1
  38. astrbot/builtin_stars/python_interpreter/shared/api.py +0 -22
  39. {astrbot-4.11.4.dist-info → astrbot-4.12.1.dist-info}/WHEEL +0 -0
  40. {astrbot-4.11.4.dist-info → astrbot-4.12.1.dist-info}/entry_points.txt +0 -0
  41. {astrbot-4.11.4.dist-info → astrbot-4.12.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,55 @@
1
+ import json
2
+ from dataclasses import dataclass, field
3
+
4
+ from astrbot.api import FunctionTool
5
+ from astrbot.core.agent.run_context import ContextWrapper
6
+ from astrbot.core.agent.tool import ToolExecResult
7
+ from astrbot.core.astr_agent_context import AstrAgentContext
8
+
9
+ from ..sandbox_client import get_booter
10
+
11
+
12
+ @dataclass
13
+ class ExecuteShellTool(FunctionTool):
14
+ name: str = "astrbot_execute_shell"
15
+ description: str = "Execute a command in the shell."
16
+ parameters: dict = field(
17
+ default_factory=lambda: {
18
+ "type": "object",
19
+ "properties": {
20
+ "command": {
21
+ "type": "string",
22
+ "description": "The bash command to execute. Equal to 'cd {working_dir} && {your_command}'.",
23
+ },
24
+ "background": {
25
+ "type": "boolean",
26
+ "description": "Whether to run the command in the background.",
27
+ "default": False,
28
+ },
29
+ "env": {
30
+ "type": "object",
31
+ "description": "Optional environment variables to set for the file creation process.",
32
+ "additionalProperties": {"type": "string"},
33
+ "default": {},
34
+ },
35
+ },
36
+ "required": ["command"],
37
+ }
38
+ )
39
+
40
+ async def call(
41
+ self,
42
+ context: ContextWrapper[AstrAgentContext],
43
+ command: str,
44
+ background: bool = False,
45
+ env: dict = {},
46
+ ) -> ToolExecResult:
47
+ sb = await get_booter(
48
+ context.context.context,
49
+ context.context.event.unified_msg_origin,
50
+ )
51
+ try:
52
+ result = await sb.shell.exec(command, background=background, env=env)
53
+ return json.dumps(result)
54
+ except Exception as e:
55
+ return f"Error executing command: {str(e)}"
@@ -49,7 +49,7 @@ class Context:
49
49
 
50
50
  registered_web_apis: list = []
51
51
 
52
- # back compatibility
52
+ # 向后兼容的变量
53
53
  _register_tasks: list[Awaitable] = []
54
54
  _star_manager = None
55
55
 
@@ -73,12 +73,19 @@ class Context:
73
73
  self._db = db
74
74
  """AstrBot 数据库"""
75
75
  self.provider_manager = provider_manager
76
+ """模型提供商管理器"""
76
77
  self.platform_manager = platform_manager
78
+ """平台适配器管理器"""
77
79
  self.conversation_manager = conversation_manager
80
+ """会话管理器"""
78
81
  self.message_history_manager = message_history_manager
82
+ """平台消息历史管理器"""
79
83
  self.persona_manager = persona_manager
84
+ """人格角色设定管理器"""
80
85
  self.astrbot_config_mgr = astrbot_config_mgr
86
+ """配置文件管理器(非webui)"""
81
87
  self.kb_manager = knowledge_base_manager
88
+ """知识库管理器"""
82
89
 
83
90
  async def llm_generate(
84
91
  self,
@@ -226,14 +233,16 @@ class Context:
226
233
  return llm_resp
227
234
 
228
235
  async def get_current_chat_provider_id(self, umo: str) -> str:
229
- """Get the ID of the currently used chat provider.
236
+ """获取当前使用的聊天模型 Provider ID
230
237
 
231
238
  Args:
232
- umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used.
239
+ umo: unified_message_origin。消息会话来源 ID。
233
240
 
234
- Raises:
235
- ProviderNotFoundError: If the specified chat provider is not found
241
+ Returns:
242
+ 指定消息会话来源当前使用的聊天模型 Provider ID。
236
243
 
244
+ Raises:
245
+ ProviderNotFoundError: 未找到。
237
246
  """
238
247
  prov = self.get_using_provider(umo)
239
248
  if not prov:
@@ -255,20 +264,27 @@ class Context:
255
264
  return self.provider_manager.llm_tools
256
265
 
257
266
  def activate_llm_tool(self, name: str) -> bool:
258
- """激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
267
+ """激活一个已经注册的函数调用工具。
268
+
269
+ Args:
270
+ name: 工具名称。
259
271
 
260
272
  Returns:
261
- 如果没找到,会返回 False
273
+ 如果成功激活返回 True,如果没找到工具返回 False
262
274
 
275
+ Note:
276
+ 注册的工具默认是激活状态。
263
277
  """
264
278
  return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)
265
279
 
266
280
  def deactivate_llm_tool(self, name: str) -> bool:
267
281
  """停用一个已经注册的函数调用工具。
268
282
 
269
- Returns:
270
- 如果没找到,会返回 False
283
+ Args:
284
+ name: 工具名称。
271
285
 
286
+ Returns:
287
+ 如果成功停用返回 True,如果没找到工具返回 False。
272
288
  """
273
289
  return self.provider_manager.llm_tools.deactivate_llm_tool(name)
274
290
 
@@ -278,7 +294,17 @@ class Context:
278
294
  ) -> (
279
295
  Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
280
296
  ):
281
- """通过 ID 获取对应的 LLM Provider。"""
297
+ """通过 ID 获取对应的 LLM Provider。
298
+
299
+ Args:
300
+ provider_id: 提供者 ID。
301
+
302
+ Returns:
303
+ 提供者实例,如果未找到则返回 None。
304
+
305
+ Note:
306
+ 如果提供者 ID 存在但未找到提供者,会记录警告日志。
307
+ """
282
308
  prov = self.provider_manager.inst_map.get(provider_id)
283
309
  if provider_id and not prov:
284
310
  logger.warning(
@@ -303,11 +329,20 @@ class Context:
303
329
  return self.provider_manager.embedding_provider_insts
304
330
 
305
331
  def get_using_provider(self, umo: str | None = None) -> Provider:
306
- """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
332
+ """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)
307
333
 
308
334
  Args:
309
- umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
335
+ umo: unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,
336
+ 则使用该会话偏好的提供商。
337
+
338
+ Returns:
339
+ 当前使用的文本生成提供者。
310
340
 
341
+ Raises:
342
+ ValueError: 返回的提供者不是 Provider 类型。
343
+
344
+ Note:
345
+ 通过 /provider 指令可以切换提供者。
311
346
  """
312
347
  prov = self.provider_manager.get_using_provider(
313
348
  provider_type=ProviderType.CHAT_COMPLETION,
@@ -321,8 +356,13 @@ class Context:
321
356
  """获取当前使用的用于 TTS 任务的 Provider。
322
357
 
323
358
  Args:
324
- umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
359
+ umo: unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
360
+
361
+ Returns:
362
+ 当前使用的 TTS 提供者,如果未设置则返回 None。
325
363
 
364
+ Raises:
365
+ ValueError: 返回的提供者不是 TTSProvider 类型。
326
366
  """
327
367
  prov = self.provider_manager.get_using_provider(
328
368
  provider_type=ProviderType.TEXT_TO_SPEECH,
@@ -336,8 +376,13 @@ class Context:
336
376
  """获取当前使用的用于 STT 任务的 Provider。
337
377
 
338
378
  Args:
339
- umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
379
+ umo: unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
380
+
381
+ Returns:
382
+ 当前使用的 STT 提供者,如果未设置则返回 None。
340
383
 
384
+ Raises:
385
+ ValueError: 返回的提供者不是 STTProvider 类型。
341
386
  """
342
387
  prov = self.provider_manager.get_using_provider(
343
388
  provider_type=ProviderType.SPEECH_TO_TEXT,
@@ -348,9 +393,19 @@ class Context:
348
393
  return prov
349
394
 
350
395
  def get_config(self, umo: str | None = None) -> AstrBotConfig:
351
- """获取 AstrBot 的配置。"""
396
+ """获取 AstrBot 的配置。
397
+
398
+ Args:
399
+ umo: unified_message_origin 值,用于获取特定会话的配置。
400
+
401
+ Returns:
402
+ AstrBot 配置对象。
403
+
404
+ Note:
405
+ 如果不提供 umo 参数,将返回默认配置。
406
+ """
352
407
  if not umo:
353
- # using default config
408
+ # 使用默认配置
354
409
  return self._config
355
410
  return self.astrbot_config_mgr.get_conf(umo)
356
411
 
@@ -361,14 +416,19 @@ class Context:
361
416
  ) -> bool:
362
417
  """根据 session(unified_msg_origin) 主动发送消息。
363
418
 
364
- @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
365
- @param message_chain: 消息链。
419
+ Args:
420
+ session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
421
+ message_chain: 消息链。
366
422
 
367
- @return: 是否找到匹配的平台。
423
+ Returns:
424
+ 是否找到匹配的平台。
368
425
 
369
- 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
426
+ Raises:
427
+ ValueError: session 字符串不合法时抛出。
370
428
 
371
- NOTE: qq_official(QQ 官方 API 平台) 不支持此方法
429
+ Note:
430
+ 当 session 为字符串时,会尝试解析为 MessageSession 对象。(类名为MessageSesion是因为历史遗留拼写错误)
431
+ qq_official(QQ 官方 API 平台) 不支持此方法。
372
432
  """
373
433
  if isinstance(session, str):
374
434
  try:
@@ -383,7 +443,14 @@ class Context:
383
443
  return False
384
444
 
385
445
  def add_llm_tools(self, *tools: FunctionTool) -> None:
386
- """添加 LLM 工具。"""
446
+ """添加 LLM 工具。
447
+
448
+ Args:
449
+ *tools: 要添加的函数工具对象。
450
+
451
+ Note:
452
+ 如果工具已存在,会替换已存在的工具。
453
+ """
387
454
  tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list}
388
455
  module_path = ""
389
456
  for tool in tools:
@@ -416,6 +483,17 @@ class Context:
416
483
  methods: list,
417
484
  desc: str,
418
485
  ):
486
+ """注册 Web API。
487
+
488
+ Args:
489
+ route: API 路由路径。
490
+ view_handler: 异步视图处理函数。
491
+ methods: HTTP 方法列表。
492
+ desc: API 描述。
493
+
494
+ Note:
495
+ 如果相同路由和方法已注册,会替换现有的 API。
496
+ """
419
497
  for idx, api in enumerate(self.registered_web_apis):
420
498
  if api[0] == route and methods == api[2]:
421
499
  self.registered_web_apis[idx] = (route, view_handler, methods, desc)
@@ -434,7 +512,14 @@ class Context:
434
512
  def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
435
513
  """获取指定类型的平台适配器。
436
514
 
437
- 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
515
+ Args:
516
+ platform_type: 平台类型或平台名称。
517
+
518
+ Returns:
519
+ 平台适配器实例,如果未找到则返回 None。
520
+
521
+ Note:
522
+ 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
438
523
  """
439
524
  for platform in self.platform_manager.platform_insts:
440
525
  name = platform.meta().name
@@ -451,22 +536,32 @@ class Context:
451
536
  """获取指定 ID 的平台适配器实例。
452
537
 
453
538
  Args:
454
- platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
539
+ platform_id: 平台适配器的唯一标识符。
455
540
 
456
541
  Returns:
457
- Platform: 平台适配器实例,如果未找到则返回 None。
542
+ 平台适配器实例,如果未找到则返回 None。
458
543
 
544
+ Note:
545
+ 可以通过 event.get_platform_id() 获取平台 ID。
459
546
  """
460
547
  for platform in self.platform_manager.platform_insts:
461
548
  if platform.meta().id == platform_id:
462
549
  return platform
463
550
 
464
551
  def get_db(self) -> BaseDatabase:
465
- """获取 AstrBot 数据库。"""
552
+ """获取 AstrBot 数据库。
553
+
554
+ Returns:
555
+ 数据库实例。
556
+ """
466
557
  return self._db
467
558
 
468
559
  def register_provider(self, provider: Provider):
469
- """注册一个 LLM Provider(Chat_Completion 类型)。"""
560
+ """注册一个 LLM Provider(Chat_Completion 类型)。
561
+
562
+ Args:
563
+ provider: 提供者实例。
564
+ """
470
565
  self.provider_manager.provider_insts.append(provider)
471
566
 
472
567
  def register_llm_tool(
@@ -478,12 +573,16 @@ class Context:
478
573
  ) -> None:
479
574
  """[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。
480
575
 
481
- @param name: 函数名
482
- @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
483
- @param desc: 函数描述
484
- @param func_obj: 异步处理函数。
485
-
486
- 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
576
+ Args:
577
+ name: 函数名。
578
+ func_args: 函数参数列表,格式为
579
+ [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]。
580
+ desc: 函数描述。
581
+ func_obj: 异步处理函数。
582
+
583
+ Note:
584
+ 异步处理函数会接收到额外的关键词参数:event: AstrMessageEvent, context: Context。
585
+ 该方法已弃用,请使用新的注册方式。
487
586
  """
488
587
  md = StarHandlerMetadata(
489
588
  event_type=EventType.OnLLMRequestEvent,
@@ -498,7 +597,15 @@ class Context:
498
597
  self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
499
598
 
500
599
  def unregister_llm_tool(self, name: str) -> None:
501
- """[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。"""
600
+ """[DEPRECATED]删除一个函数调用工具。
601
+
602
+ Args:
603
+ name: 工具名称。
604
+
605
+ Note:
606
+ 如果再要启用,需要重新注册。
607
+ 该方法已弃用。
608
+ """
502
609
  self.provider_manager.llm_tools.remove_func(name)
503
610
 
504
611
  def register_commands(
@@ -511,16 +618,19 @@ class Context:
511
618
  use_regex=False,
512
619
  ignore_prefix=False,
513
620
  ):
514
- """注册一个命令。
515
-
516
- [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
517
-
518
- @param star_name: 插件(Star)名称。
519
- @param command_name: 命令名称。
520
- @param desc: 命令描述。
521
- @param priority: 优先级。1-10。
522
- @param awaitable: 异步处理函数。
621
+ """[DEPRECATED]注册一个命令。
523
622
 
623
+ Args:
624
+ star_name: 插件(Star)名称。
625
+ command_name: 命令名称。
626
+ desc: 命令描述。
627
+ priority: 优先级。1-10。
628
+ awaitable: 异步处理函数。
629
+ use_regex: 是否使用正则表达式匹配命令。
630
+ ignore_prefix: 是否忽略命令前缀。
631
+
632
+ Note:
633
+ 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
524
634
  """
525
635
  md = StarHandlerMetadata(
526
636
  event_type=EventType.AdapterMessageEvent,
@@ -540,5 +650,13 @@ class Context:
540
650
  star_handlers_registry.append(md)
541
651
 
542
652
  def register_task(self, task: Awaitable, desc: str):
543
- """[DEPRECATED]注册一个异步任务。"""
653
+ """[DEPRECATED]注册一个异步任务。
654
+
655
+ Args:
656
+ task: 异步任务。
657
+ desc: 任务描述。
658
+
659
+ Note:
660
+ 该方法已弃用。
661
+ """
544
662
  self._register_tasks.append(task)
@@ -1,6 +1,7 @@
1
1
  from .auth import AuthRoute
2
2
  from .backup import BackupRoute
3
3
  from .chat import ChatRoute
4
+ from .chatui_project import ChatUIProjectRoute
4
5
  from .command import CommandRoute
5
6
  from .config import ConfigRoute
6
7
  from .conversation import ConversationRoute
@@ -20,6 +21,7 @@ __all__ = [
20
21
  "AuthRoute",
21
22
  "BackupRoute",
22
23
  "ChatRoute",
24
+ "ChatUIProjectRoute",
23
25
  "CommandRoute",
24
26
  "ConfigRoute",
25
27
  "ConversationRoute",
@@ -296,6 +296,8 @@ class ChatRoute(Route):
296
296
  # 构建用户消息段(包含 path 用于传递给 adapter)
297
297
  message_parts = await self._build_user_message_parts(message)
298
298
 
299
+ message_id = str(uuid.uuid4())
300
+
299
301
  async def stream():
300
302
  client_disconnected = False
301
303
  accumulated_parts = []
@@ -319,6 +321,13 @@ class ChatRoute(Route):
319
321
  if not result:
320
322
  continue
321
323
 
324
+ if (
325
+ "message_id" in result
326
+ and result["message_id"] != message_id
327
+ ):
328
+ logger.warning("webchat stream message_id mismatch")
329
+ continue
330
+
322
331
  result_text = result["data"]
323
332
  msg_type = result.get("type")
324
333
  streaming = result.get("streaming", False)
@@ -456,6 +465,7 @@ class ChatRoute(Route):
456
465
  "selected_provider": selected_provider,
457
466
  "selected_model": selected_model,
458
467
  "enable_streaming": enable_streaming,
468
+ "message_id": message_id,
459
469
  },
460
470
  ),
461
471
  )
@@ -618,9 +628,17 @@ class ChatRoute(Route):
618
628
  page_size=100, # 暂时返回前100个
619
629
  )
620
630
 
621
- # 转换为字典格式,并添加额外信息
631
+ # 转换为字典格式,并添加项目信息
632
+ # get_platform_sessions_by_creator 现在返回 list[dict] 包含 session 和项目字段
622
633
  sessions_data = []
623
- for session in sessions:
634
+ for item in sessions:
635
+ session = item["session"]
636
+ project_id = item["project_id"]
637
+
638
+ # 跳过属于项目的会话(在侧边栏对话列表中不显示)
639
+ if project_id is not None:
640
+ continue
641
+
624
642
  sessions_data.append(
625
643
  {
626
644
  "session_id": session.session_id,
@@ -645,6 +663,12 @@ class ChatRoute(Route):
645
663
  session = await self.db.get_platform_session_by_id(session_id)
646
664
  platform_id = session.platform_id if session else "webchat"
647
665
 
666
+ # 获取项目信息(如果会话属于某个项目)
667
+ username = g.get("username", "guest")
668
+ project_info = await self.db.get_project_by_session(
669
+ session_id=session_id, creator=username
670
+ )
671
+
648
672
  # Get platform message history using session_id
649
673
  history_ls = await self.platform_history_mgr.get(
650
674
  platform_id=platform_id,
@@ -655,16 +679,20 @@ class ChatRoute(Route):
655
679
 
656
680
  history_res = [history.model_dump() for history in history_ls]
657
681
 
658
- return (
659
- Response()
660
- .ok(
661
- data={
662
- "history": history_res,
663
- "is_running": self.running_convs.get(session_id, False),
664
- },
665
- )
666
- .__dict__
667
- )
682
+ response_data = {
683
+ "history": history_res,
684
+ "is_running": self.running_convs.get(session_id, False),
685
+ }
686
+
687
+ # 如果会话属于项目,添加项目信息
688
+ if project_info:
689
+ response_data["project"] = {
690
+ "project_id": project_info.project_id,
691
+ "title": project_info.title,
692
+ "emoji": project_info.emoji,
693
+ }
694
+
695
+ return Response().ok(data=response_data).__dict__
668
696
 
669
697
  async def update_session_display_name(self):
670
698
  """Update a Platform session's display name."""