AstrBot 4.7.4__py3-none-any.whl → 4.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. astrbot/cli/__init__.py +1 -1
  2. astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -1
  3. astrbot/core/agent/tool.py +7 -2
  4. astrbot/core/astr_agent_run_util.py +15 -1
  5. astrbot/core/astr_agent_tool_exec.py +5 -1
  6. astrbot/core/config/astrbot_config.py +4 -0
  7. astrbot/core/config/default.py +116 -1
  8. astrbot/core/core_lifecycle.py +1 -1
  9. astrbot/core/db/__init__.py +32 -4
  10. astrbot/core/db/migration/migra_3_to_4.py +2 -0
  11. astrbot/core/db/migration/sqlite_v3.py +6 -4
  12. astrbot/core/db/po.py +16 -15
  13. astrbot/core/db/sqlite.py +56 -1
  14. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +2 -0
  15. astrbot/core/event_bus.py +6 -1
  16. astrbot/core/knowledge_base/retrieval/manager.py +5 -1
  17. astrbot/core/log.py +2 -1
  18. astrbot/core/message/components.py +9 -3
  19. astrbot/core/persona_mgr.py +2 -2
  20. astrbot/core/pipeline/content_safety_check/stage.py +1 -1
  21. astrbot/core/pipeline/context_utils.py +2 -1
  22. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +1 -1
  23. astrbot/core/pipeline/process_stage/method/star_request.py +1 -2
  24. astrbot/core/pipeline/process_stage/stage.py +1 -1
  25. astrbot/core/pipeline/respond/stage.py +4 -2
  26. astrbot/core/pipeline/result_decorate/stage.py +68 -21
  27. astrbot/core/pipeline/scheduler.py +5 -1
  28. astrbot/core/pipeline/waking_check/stage.py +10 -0
  29. astrbot/core/platform/astr_message_event.py +5 -3
  30. astrbot/core/platform/astrbot_message.py +2 -2
  31. astrbot/core/platform/manager.py +71 -9
  32. astrbot/core/platform/platform.py +109 -4
  33. astrbot/core/platform/platform_metadata.py +1 -1
  34. astrbot/core/platform/register.py +1 -0
  35. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +8 -6
  36. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +13 -8
  37. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +28 -22
  38. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +5 -2
  39. astrbot/core/platform/sources/discord/client.py +16 -4
  40. astrbot/core/platform/sources/discord/components.py +2 -2
  41. astrbot/core/platform/sources/discord/discord_platform_adapter.py +53 -26
  42. astrbot/core/platform/sources/discord/discord_platform_event.py +29 -8
  43. astrbot/core/platform/sources/lark/lark_adapter.py +178 -22
  44. astrbot/core/platform/sources/lark/lark_event.py +39 -4
  45. astrbot/core/platform/sources/lark/server.py +206 -0
  46. astrbot/core/platform/sources/misskey/misskey_adapter.py +3 -5
  47. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +64 -18
  48. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +14 -10
  49. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -11
  50. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +15 -2
  51. astrbot/core/platform/sources/satori/satori_adapter.py +1 -2
  52. astrbot/core/platform/sources/slack/client.py +58 -40
  53. astrbot/core/platform/sources/slack/slack_adapter.py +36 -16
  54. astrbot/core/platform/sources/slack/slack_event.py +11 -10
  55. astrbot/core/platform/sources/telegram/tg_adapter.py +2 -3
  56. astrbot/core/platform/sources/telegram/tg_event.py +23 -27
  57. astrbot/core/platform/sources/webchat/webchat_adapter.py +97 -31
  58. astrbot/core/platform/sources/webchat/webchat_event.py +35 -35
  59. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +27 -11
  60. astrbot/core/platform/sources/wecom/wecom_adapter.py +75 -36
  61. astrbot/core/platform/sources/wecom/wecom_event.py +3 -3
  62. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +26 -9
  63. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +3 -3
  64. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +27 -5
  65. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +81 -35
  66. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +11 -8
  67. astrbot/core/platform_message_history_mgr.py +3 -3
  68. astrbot/core/provider/func_tool_manager.py +3 -3
  69. astrbot/core/provider/manager.py +130 -74
  70. astrbot/core/provider/provider.py +12 -1
  71. astrbot/core/provider/sources/azure_tts_source.py +31 -9
  72. astrbot/core/provider/sources/bailian_rerank_source.py +4 -0
  73. astrbot/core/provider/sources/dashscope_tts.py +3 -2
  74. astrbot/core/provider/sources/edge_tts_source.py +1 -1
  75. astrbot/core/provider/sources/fishaudio_tts_api_source.py +5 -4
  76. astrbot/core/provider/sources/gemini_embedding_source.py +15 -5
  77. astrbot/core/provider/sources/gemini_source.py +12 -10
  78. astrbot/core/provider/sources/minimax_tts_api_source.py +4 -2
  79. astrbot/core/provider/sources/openai_embedding_source.py +2 -2
  80. astrbot/core/provider/sources/openai_source.py +4 -0
  81. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +5 -2
  82. astrbot/core/provider/sources/vllm_rerank_source.py +1 -0
  83. astrbot/core/provider/sources/whisper_api_source.py +44 -12
  84. astrbot/core/provider/sources/whisper_selfhosted_source.py +6 -2
  85. astrbot/core/provider/sources/xinference_rerank_source.py +10 -2
  86. astrbot/core/star/context.py +2 -2
  87. astrbot/core/star/register/star_handler.py +22 -5
  88. astrbot/core/star/star_handler.py +85 -4
  89. astrbot/core/updator.py +3 -3
  90. astrbot/core/utils/io.py +1 -1
  91. astrbot/core/utils/session_waiter.py +17 -10
  92. astrbot/core/utils/shared_preferences.py +32 -0
  93. astrbot/core/utils/t2i/__init__.py +2 -2
  94. astrbot/core/utils/t2i/local_strategy.py +25 -31
  95. astrbot/core/utils/tencent_record_helper.py +2 -2
  96. astrbot/core/utils/version_comparator.py +6 -3
  97. astrbot/core/utils/webhook_utils.py +66 -0
  98. astrbot/dashboard/routes/__init__.py +2 -0
  99. astrbot/dashboard/routes/chat.py +311 -76
  100. astrbot/dashboard/routes/config.py +14 -5
  101. astrbot/dashboard/routes/knowledge_base.py +254 -79
  102. astrbot/dashboard/routes/log.py +13 -8
  103. astrbot/dashboard/routes/platform.py +100 -0
  104. astrbot/dashboard/routes/plugin.py +108 -51
  105. astrbot/dashboard/routes/route.py +2 -0
  106. astrbot/dashboard/server.py +9 -4
  107. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/METADATA +50 -37
  108. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/RECORD +111 -108
  109. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/WHEEL +0 -0
  110. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/entry_points.txt +0 -0
  111. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -7,6 +7,7 @@ import asyncio
7
7
  import os
8
8
  import re
9
9
  from datetime import datetime
10
+ from typing import cast
10
11
 
11
12
  from funasr_onnx import SenseVoiceSmall
12
13
  from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
32
33
  provider_settings: dict,
33
34
  ) -> None:
34
35
  super().__init__(provider_config, provider_settings)
35
- self.set_model(provider_config.get("stt_model"))
36
+ self.set_model(provider_config["stt_model"])
36
37
  self.model = None
37
38
  self.is_emotion = provider_config.get("is_emotion", False)
38
39
 
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
86
87
  loop = asyncio.get_event_loop()
87
88
  res = await loop.run_in_executor(
88
89
  None, # 使用默认的线程池
89
- lambda: self.model(audio_url, language="auto", use_itn=True),
90
+ lambda: cast(SenseVoiceSmall, self.model)(
91
+ audio_url, language="auto", use_itn=True
92
+ ),
90
93
  )
91
94
 
92
95
  # res = self.model(audio_url, language="auto", use_itn=True)
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
44
44
  }
45
45
  if top_n is not None:
46
46
  payload["top_n"] = top_n
47
+ assert self.client is not None
47
48
  async with self.client.post(
48
49
  f"{self.base_url}/v1/rerank",
49
50
  json=payload,
@@ -6,7 +6,10 @@ from openai import NOT_GIVEN, AsyncOpenAI
6
6
  from astrbot.core import logger
7
7
  from astrbot.core.utils.astrbot_path import get_astrbot_data_path
8
8
  from astrbot.core.utils.io import download_file
9
- from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
9
+ from astrbot.core.utils.tencent_record_helper import (
10
+ convert_to_pcm_wav,
11
+ tencent_silk_to_wav,
12
+ )
10
13
 
11
14
  from ..entities import ProviderType
12
15
  from ..provider import STTProvider
@@ -33,20 +36,30 @@ class ProviderOpenAIWhisperAPI(STTProvider):
33
36
  timeout=provider_config.get("timeout", NOT_GIVEN),
34
37
  )
35
38
 
36
- self.set_model(provider_config.get("model"))
39
+ self.set_model(provider_config["model"])
37
40
 
38
- async def _is_silk_file(self, file_path):
41
+ async def _get_audio_format(self, file_path):
42
+ # 定义要检测的头部字节
39
43
  silk_header = b"SILK"
40
- with open(file_path, "rb") as f:
41
- file_header = f.read(8)
44
+ amr_header = b"#!AMR"
45
+
46
+ try:
47
+ with open(file_path, "rb") as f:
48
+ file_header = f.read(8)
49
+ except FileNotFoundError:
50
+ return None
42
51
 
43
52
  if silk_header in file_header:
44
- return True
45
- return False
53
+ return "silk"
54
+
55
+ if amr_header in file_header:
56
+ return "amr"
57
+ return None
46
58
 
47
59
  async def get_text(self, audio_url: str) -> str:
48
60
  """Only supports mp3, mp4, mpeg, m4a, wav, webm"""
49
61
  is_tencent = False
62
+ output_path = None
50
63
 
51
64
  if audio_url.startswith("http"):
52
65
  if "multimedia.nt.qq.com.cn" in audio_url:
@@ -62,16 +75,35 @@ class ProviderOpenAIWhisperAPI(STTProvider):
62
75
  raise FileNotFoundError(f"文件不存在: {audio_url}")
63
76
 
64
77
  if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
65
- is_silk = await self._is_silk_file(audio_url)
66
- if is_silk:
67
- logger.info("Converting silk file to wav ...")
78
+ file_format = await self._get_audio_format(audio_url)
79
+
80
+ # 判断是否需要转换
81
+ if file_format in ["silk", "amr"]:
68
82
  temp_dir = os.path.join(get_astrbot_data_path(), "temp")
69
83
  output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
70
- await tencent_silk_to_wav(audio_url, output_path)
84
+
85
+ if file_format == "silk":
86
+ logger.info(
87
+ "Converting silk file to wav using tencent_silk_to_wav..."
88
+ )
89
+ await tencent_silk_to_wav(audio_url, output_path)
90
+ elif file_format == "amr":
91
+ logger.info(
92
+ "Converting amr file to wav using convert_to_pcm_wav..."
93
+ )
94
+ await convert_to_pcm_wav(audio_url, output_path)
95
+
71
96
  audio_url = output_path
72
97
 
73
98
  result = await self.client.audio.transcriptions.create(
74
99
  model=self.model_name,
75
- file=open(audio_url, "rb"),
100
+ file=("audio.wav", open(audio_url, "rb")),
76
101
  )
102
+
103
+ # remove temp file
104
+ if output_path and os.path.exists(output_path):
105
+ try:
106
+ os.remove(audio_url)
107
+ except Exception as e:
108
+ logger.error(f"Failed to remove temp file {audio_url}: {e}")
77
109
  return result.text
@@ -1,6 +1,7 @@
1
1
  import asyncio
2
2
  import os
3
3
  import uuid
4
+ from typing import cast
4
5
 
5
6
  import whisper
6
7
 
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
26
27
  provider_settings: dict,
27
28
  ) -> None:
28
29
  super().__init__(provider_config, provider_settings)
29
- self.set_model(provider_config.get("model"))
30
+ self.set_model(provider_config["model"])
30
31
  self.model = None
31
32
 
32
33
  async def initialize(self):
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
75
76
  await tencent_silk_to_wav(audio_url, output_path)
76
77
  audio_url = output_path
77
78
 
79
+ if not self.model:
80
+ raise RuntimeError("Whisper 模型未初始化")
81
+
78
82
  result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
79
- return result["text"]
83
+ return cast(str, result["text"])
@@ -1,6 +1,11 @@
1
+ from typing import cast
2
+
1
3
  from xinference_client.client.restful.async_restful_client import (
2
4
  AsyncClient as Client,
3
5
  )
6
+ from xinference_client.client.restful.async_restful_client import (
7
+ AsyncRESTfulRerankModelHandle,
8
+ )
4
9
 
5
10
  from astrbot import logger
6
11
 
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
29
34
  False,
30
35
  )
31
36
  self.client = None
32
- self.model = None
37
+ self.model: AsyncRESTfulRerankModelHandle | None = None
33
38
  self.model_uid = None
34
39
 
35
40
  async def initialize(self):
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
65
70
  return
66
71
 
67
72
  if self.model_uid:
68
- self.model = await self.client.get_model(self.model_uid)
73
+ self.model = cast(
74
+ AsyncRESTfulRerankModelHandle,
75
+ await self.client.get_model(self.model_uid),
76
+ )
69
77
 
70
78
  except Exception as e:
71
79
  logger.error(f"Failed to initialize Xinference model: {e}")
@@ -285,7 +285,7 @@ class Context:
285
285
  """获取所有用于 Embedding 任务的 Provider。"""
286
286
  return self.provider_manager.embedding_provider_insts
287
287
 
288
- def get_using_provider(self, umo: str | None = None) -> Provider | None:
288
+ def get_using_provider(self, umo: str | None = None) -> Provider:
289
289
  """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
290
290
 
291
291
  Args:
@@ -296,7 +296,7 @@ class Context:
296
296
  provider_type=ProviderType.CHAT_COMPLETION,
297
297
  umo=umo,
298
298
  )
299
- if prov and not isinstance(prov, Provider):
299
+ if not isinstance(prov, Provider):
300
300
  raise ValueError("返回的 Provider 不是 Provider 类型")
301
301
  return prov
302
302
 
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import re
4
- from collections.abc import Awaitable, Callable
4
+ from collections.abc import AsyncGenerator, Awaitable, Callable
5
5
  from typing import Any
6
6
 
7
7
  import docstring_parser
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
12
12
  from astrbot.core.agent.hooks import BaseAgentRunHooks
13
13
  from astrbot.core.agent.tool import FunctionTool
14
14
  from astrbot.core.astr_agent_context import AstrAgentContext
15
+ from astrbot.core.message.message_event_result import MessageEventResult
15
16
  from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
16
17
  from astrbot.core.provider.register import llm_tools
17
18
 
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
28
29
  from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
29
30
 
30
31
 
31
- def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
32
+ def get_handler_full_name(
33
+ awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
34
+ ) -> str:
32
35
  """获取 Handler 的全名"""
33
36
  return f"{awaitable.__module__}_{awaitable.__name__}"
34
37
 
35
38
 
36
39
  def get_handler_or_create(
37
- handler: Callable[..., Awaitable[Any]],
40
+ handler: Callable[
41
+ ...,
42
+ Awaitable[MessageEventResult | str | None]
43
+ | AsyncGenerator[MessageEventResult | str | None],
44
+ ],
38
45
  event_type: EventType,
39
46
  dont_add=False,
40
47
  **kwargs,
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
169
176
  for (
170
177
  sub_handle
171
178
  ) in parent_register_commandable.parent_group.sub_command_filters:
179
+ if isinstance(sub_handle, CommandGroupFilter):
180
+ continue
172
181
  # 所有符合fullname一致的子指令handle添加自定义过滤器。
173
182
  # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
174
183
  sub_handle_md = sub_handle.get_handler_md()
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
180
189
 
181
190
  else:
182
191
  # 裸指令
192
+ # 确保运行时是可调用的 handler,针对类型检查器添加忽略
193
+ assert isinstance(awaitable, Callable)
183
194
  handler_md = get_handler_or_create(
184
195
  awaitable,
185
196
  EventType.AdapterMessageEvent,
@@ -237,7 +248,7 @@ class RegisteringCommandable:
237
248
 
238
249
  group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
239
250
  command: Callable[..., Callable[..., None]] = register_command
240
- custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
251
+ custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
241
252
 
242
253
  def __init__(self, parent_group: CommandGroupFilter):
243
254
  self.parent_group = parent_group
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
412
423
  if kwargs.get("registering_agent"):
413
424
  registering_agent = kwargs["registering_agent"]
414
425
 
415
- def decorator(awaitable: Callable[..., Awaitable[Any]]):
426
+ def decorator(
427
+ awaitable: Callable[
428
+ ...,
429
+ AsyncGenerator[MessageEventResult | str | None]
430
+ | Awaitable[MessageEventResult | str | None],
431
+ ],
432
+ ):
416
433
  llm_tool_name = name_ if name_ else awaitable.__name__
417
434
  func_doc = awaitable.__doc__ or ""
418
435
  docstring = docstring_parser.parse(func_doc)
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import enum
4
- from collections.abc import Awaitable, Callable
4
+ from collections.abc import AsyncGenerator, Awaitable, Callable
5
5
  from dataclasses import dataclass, field
6
- from typing import Any, Generic, TypeVar
6
+ from typing import Any, Generic, Literal, TypeVar, overload
7
7
 
8
8
  from .filter import HandlerFilter
9
9
  from .star import star_map
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
29
29
  for handler in self._handlers:
30
30
  print(handler.handler_full_name)
31
31
 
32
+ @overload
33
+ def get_handlers_by_event_type(
34
+ self,
35
+ event_type: Literal[EventType.OnAstrBotLoadedEvent],
36
+ only_activated=True,
37
+ plugins_name: list[str] | None = None,
38
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
39
+
40
+ @overload
41
+ def get_handlers_by_event_type(
42
+ self,
43
+ event_type: Literal[EventType.OnPlatformLoadedEvent],
44
+ only_activated=True,
45
+ plugins_name: list[str] | None = None,
46
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
47
+
48
+ @overload
49
+ def get_handlers_by_event_type(
50
+ self,
51
+ event_type: Literal[EventType.AdapterMessageEvent],
52
+ only_activated=True,
53
+ plugins_name: list[str] | None = None,
54
+ ) -> list[
55
+ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
56
+ ]: ...
57
+
58
+ @overload
59
+ def get_handlers_by_event_type(
60
+ self,
61
+ event_type: Literal[EventType.OnLLMRequestEvent],
62
+ only_activated=True,
63
+ plugins_name: list[str] | None = None,
64
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
65
+
66
+ @overload
67
+ def get_handlers_by_event_type(
68
+ self,
69
+ event_type: Literal[EventType.OnLLMResponseEvent],
70
+ only_activated=True,
71
+ plugins_name: list[str] | None = None,
72
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
73
+
74
+ @overload
75
+ def get_handlers_by_event_type(
76
+ self,
77
+ event_type: Literal[EventType.OnDecoratingResultEvent],
78
+ only_activated=True,
79
+ plugins_name: list[str] | None = None,
80
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
81
+
82
+ @overload
83
+ def get_handlers_by_event_type(
84
+ self,
85
+ event_type: Literal[EventType.OnCallingFuncToolEvent],
86
+ only_activated=True,
87
+ plugins_name: list[str] | None = None,
88
+ ) -> list[
89
+ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
90
+ ]: ...
91
+
92
+ @overload
93
+ def get_handlers_by_event_type(
94
+ self,
95
+ event_type: Literal[EventType.OnAfterMessageSentEvent],
96
+ only_activated=True,
97
+ plugins_name: list[str] | None = None,
98
+ ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
99
+
100
+ @overload
101
+ def get_handlers_by_event_type(
102
+ self,
103
+ event_type: EventType,
104
+ only_activated=True,
105
+ plugins_name: list[str] | None = None,
106
+ ) -> list[
107
+ StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
108
+ ]: ...
109
+
32
110
  def get_handlers_by_event_type(
33
111
  self,
34
112
  event_type: EventType,
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
111
189
  OnAfterMessageSentEvent = enum.auto() # 发送消息后
112
190
 
113
191
 
192
+ H = TypeVar("H", bound=Callable[..., Any])
193
+
194
+
114
195
  @dataclass
115
- class StarHandlerMetadata:
196
+ class StarHandlerMetadata(Generic[H]):
116
197
  """描述一个 Star 所注册的某一个 Handler。"""
117
198
 
118
199
  event_type: EventType
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
127
208
  handler_module_path: str
128
209
  """Handler 所在的模块路径。"""
129
210
 
130
- handler: Callable[..., Awaitable[Any]]
211
+ handler: H
131
212
  """Handler 的函数对象,应当是一个异步函数"""
132
213
 
133
214
  event_filters: list[HandlerFilter]
astrbot/core/updator.py CHANGED
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
71
71
 
72
72
  async def check_update(
73
73
  self,
74
- url: str,
75
- current_version: str,
74
+ url: str | None,
75
+ current_version: str | None,
76
76
  consider_prerelease: bool = True,
77
- ) -> ReleaseInfo:
77
+ ) -> ReleaseInfo | None:
78
78
  """检查更新"""
79
79
  return await super().check_update(
80
80
  self.ASTRBOT_RELEASE_API,
astrbot/core/utils/io.py CHANGED
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
49
49
  return False
50
50
 
51
51
 
52
- def save_temp_img(img: Image.Image | str) -> str:
52
+ def save_temp_img(img: Image.Image | bytes) -> str:
53
53
  temp_dir = os.path.join(get_astrbot_data_path(), "temp")
54
54
  # 获得文件创建时间,清除超过 12 小时的
55
55
  try:
@@ -20,16 +20,16 @@ class SessionController:
20
20
 
21
21
  def __init__(self):
22
22
  self.future = asyncio.Future()
23
- self.current_event: asyncio.Event = None
23
+ self.current_event: asyncio.Event | None = None
24
24
  """当前正在等待的所用的异步事件"""
25
- self.ts: float = None
25
+ self.ts: float | None = None
26
26
  """上次保持(keep)开始时的时间"""
27
- self.timeout: float | int = None
27
+ self.timeout: float | int | None = None
28
28
  """上次保持(keep)开始时的超时时间"""
29
29
 
30
30
  self.history_chains: list[list[Comp.BaseMessageComponent]] = []
31
31
 
32
- def stop(self, error: Exception = None):
32
+ def stop(self, error: Exception | None = None):
33
33
  """立即结束这个会话"""
34
34
  if not self.future.done():
35
35
  if error:
@@ -53,6 +53,8 @@ class SessionController:
53
53
  self.stop()
54
54
  return
55
55
  else:
56
+ assert self.timeout is not None
57
+ assert self.ts is not None
56
58
  left_timeout = self.timeout - (new_ts - self.ts)
57
59
  timeout = left_timeout + timeout
58
60
  if timeout <= 0:
@@ -69,7 +71,7 @@ class SessionController:
69
71
 
70
72
  asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
71
73
 
72
- async def _holding(self, event: asyncio.Event, timeout: int):
74
+ async def _holding(self, event: asyncio.Event, timeout: float):
73
75
  """等待事件结束或超时"""
74
76
  try:
75
77
  await asyncio.wait_for(event.wait(), timeout)
@@ -108,7 +110,9 @@ class SessionWaiter:
108
110
  ):
109
111
  self.session_id = session_id
110
112
  self.session_filter = session_filter
111
- self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
113
+ self.handler: (
114
+ Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
115
+ ) = None # 处理函数
112
116
 
113
117
  self.session_controller = SessionController()
114
118
  self.record_history_chains = record_history_chains
@@ -119,7 +123,7 @@ class SessionWaiter:
119
123
 
120
124
  async def register_wait(
121
125
  self,
122
- handler: Callable[[str], Awaitable[Any]],
126
+ handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
123
127
  timeout: int = 30,
124
128
  ) -> Any:
125
129
  """等待外部输入并处理"""
@@ -137,7 +141,7 @@ class SessionWaiter:
137
141
  finally:
138
142
  self._cleanup()
139
143
 
140
- def _cleanup(self, error: Exception = None):
144
+ def _cleanup(self, error: Exception | None = None):
141
145
  """清理会话"""
142
146
  USER_SESSIONS.pop(self.session_id, None)
143
147
  try:
@@ -161,6 +165,7 @@ class SessionWaiter:
161
165
  )
162
166
  try:
163
167
  # TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
168
+ assert session.handler is not None
164
169
  await session.handler(session.session_controller, event)
165
170
  except Exception as e:
166
171
  session.session_controller.stop(e)
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
173
178
  :param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
174
179
  """
175
180
 
176
- def decorator(func: Callable[[str], Awaitable[Any]]):
181
+ def decorator(
182
+ func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
183
+ ):
177
184
  @functools.wraps(func)
178
185
  async def wrapper(
179
186
  event: AstrMessageEvent,
180
- session_filter: SessionFilter = None,
187
+ session_filter: SessionFilter | None = None,
181
188
  *args,
182
189
  **kwargs,
183
190
  ):
@@ -53,6 +53,38 @@ class SharedPreferences:
53
53
  ret = await self.db_helper.get_preferences(scope, scope_id, key)
54
54
  return ret
55
55
 
56
+ @overload
57
+ async def session_get(
58
+ self,
59
+ umo: str,
60
+ key: str,
61
+ default: _VT = None,
62
+ ) -> _VT: ...
63
+
64
+ @overload
65
+ async def session_get(
66
+ self,
67
+ umo: None,
68
+ key: str,
69
+ default: Any = None,
70
+ ) -> list[Preference]: ...
71
+
72
+ @overload
73
+ async def session_get(
74
+ self,
75
+ umo: str,
76
+ key: None,
77
+ default: Any = None,
78
+ ) -> list[Preference]: ...
79
+
80
+ @overload
81
+ async def session_get(
82
+ self,
83
+ umo: None,
84
+ key: None,
85
+ default: Any = None,
86
+ ) -> list[Preference]: ...
87
+
56
88
  async def session_get(
57
89
  self,
58
90
  umo: str | None,
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
3
3
 
4
4
  class RenderStrategy(ABC):
5
5
  @abstractmethod
6
- def render(self, text: str, return_url: bool) -> str:
6
+ async def render(self, text: str, return_url: bool) -> str:
7
7
  pass
8
8
 
9
9
  @abstractmethod
10
- def render_custom_template(
10
+ async def render_custom_template(
11
11
  self,
12
12
  tmpl_str: str,
13
13
  tmpl_data: dict,