AstrBot 4.7.4__py3-none-any.whl → 4.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/cli/__init__.py +1 -1
- astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -1
- astrbot/core/agent/tool.py +7 -2
- astrbot/core/astr_agent_run_util.py +15 -1
- astrbot/core/astr_agent_tool_exec.py +5 -1
- astrbot/core/config/astrbot_config.py +4 -0
- astrbot/core/config/default.py +116 -1
- astrbot/core/core_lifecycle.py +1 -1
- astrbot/core/db/__init__.py +32 -4
- astrbot/core/db/migration/migra_3_to_4.py +2 -0
- astrbot/core/db/migration/sqlite_v3.py +6 -4
- astrbot/core/db/po.py +16 -15
- astrbot/core/db/sqlite.py +56 -1
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +2 -0
- astrbot/core/event_bus.py +6 -1
- astrbot/core/knowledge_base/retrieval/manager.py +5 -1
- astrbot/core/log.py +2 -1
- astrbot/core/message/components.py +9 -3
- astrbot/core/persona_mgr.py +2 -2
- astrbot/core/pipeline/content_safety_check/stage.py +1 -1
- astrbot/core/pipeline/context_utils.py +2 -1
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +1 -1
- astrbot/core/pipeline/process_stage/method/star_request.py +1 -2
- astrbot/core/pipeline/process_stage/stage.py +1 -1
- astrbot/core/pipeline/respond/stage.py +4 -2
- astrbot/core/pipeline/result_decorate/stage.py +68 -21
- astrbot/core/pipeline/scheduler.py +5 -1
- astrbot/core/pipeline/waking_check/stage.py +10 -0
- astrbot/core/platform/astr_message_event.py +5 -3
- astrbot/core/platform/astrbot_message.py +2 -2
- astrbot/core/platform/manager.py +71 -9
- astrbot/core/platform/platform.py +109 -4
- astrbot/core/platform/platform_metadata.py +1 -1
- astrbot/core/platform/register.py +1 -0
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +8 -6
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +13 -8
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +28 -22
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +5 -2
- astrbot/core/platform/sources/discord/client.py +16 -4
- astrbot/core/platform/sources/discord/components.py +2 -2
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +53 -26
- astrbot/core/platform/sources/discord/discord_platform_event.py +29 -8
- astrbot/core/platform/sources/lark/lark_adapter.py +178 -22
- astrbot/core/platform/sources/lark/lark_event.py +39 -4
- astrbot/core/platform/sources/lark/server.py +206 -0
- astrbot/core/platform/sources/misskey/misskey_adapter.py +3 -5
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +64 -18
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +14 -10
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -11
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +15 -2
- astrbot/core/platform/sources/satori/satori_adapter.py +1 -2
- astrbot/core/platform/sources/slack/client.py +58 -40
- astrbot/core/platform/sources/slack/slack_adapter.py +36 -16
- astrbot/core/platform/sources/slack/slack_event.py +11 -10
- astrbot/core/platform/sources/telegram/tg_adapter.py +2 -3
- astrbot/core/platform/sources/telegram/tg_event.py +23 -27
- astrbot/core/platform/sources/webchat/webchat_adapter.py +97 -31
- astrbot/core/platform/sources/webchat/webchat_event.py +35 -35
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +27 -11
- astrbot/core/platform/sources/wecom/wecom_adapter.py +75 -36
- astrbot/core/platform/sources/wecom/wecom_event.py +3 -3
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +26 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +3 -3
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +27 -5
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +81 -35
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +11 -8
- astrbot/core/platform_message_history_mgr.py +3 -3
- astrbot/core/provider/func_tool_manager.py +3 -3
- astrbot/core/provider/manager.py +130 -74
- astrbot/core/provider/provider.py +12 -1
- astrbot/core/provider/sources/azure_tts_source.py +31 -9
- astrbot/core/provider/sources/bailian_rerank_source.py +4 -0
- astrbot/core/provider/sources/dashscope_tts.py +3 -2
- astrbot/core/provider/sources/edge_tts_source.py +1 -1
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +5 -4
- astrbot/core/provider/sources/gemini_embedding_source.py +15 -5
- astrbot/core/provider/sources/gemini_source.py +12 -10
- astrbot/core/provider/sources/minimax_tts_api_source.py +4 -2
- astrbot/core/provider/sources/openai_embedding_source.py +2 -2
- astrbot/core/provider/sources/openai_source.py +4 -0
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +5 -2
- astrbot/core/provider/sources/vllm_rerank_source.py +1 -0
- astrbot/core/provider/sources/whisper_api_source.py +44 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +6 -2
- astrbot/core/provider/sources/xinference_rerank_source.py +10 -2
- astrbot/core/star/context.py +2 -2
- astrbot/core/star/register/star_handler.py +22 -5
- astrbot/core/star/star_handler.py +85 -4
- astrbot/core/updator.py +3 -3
- astrbot/core/utils/io.py +1 -1
- astrbot/core/utils/session_waiter.py +17 -10
- astrbot/core/utils/shared_preferences.py +32 -0
- astrbot/core/utils/t2i/__init__.py +2 -2
- astrbot/core/utils/t2i/local_strategy.py +25 -31
- astrbot/core/utils/tencent_record_helper.py +2 -2
- astrbot/core/utils/version_comparator.py +6 -3
- astrbot/core/utils/webhook_utils.py +66 -0
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/chat.py +311 -76
- astrbot/dashboard/routes/config.py +14 -5
- astrbot/dashboard/routes/knowledge_base.py +254 -79
- astrbot/dashboard/routes/log.py +13 -8
- astrbot/dashboard/routes/platform.py +100 -0
- astrbot/dashboard/routes/plugin.py +108 -51
- astrbot/dashboard/routes/route.py +2 -0
- astrbot/dashboard/server.py +9 -4
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/METADATA +50 -37
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/RECORD +111 -108
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/WHEEL +0 -0
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -7,6 +7,7 @@ import asyncio
|
|
|
7
7
|
import os
|
|
8
8
|
import re
|
|
9
9
|
from datetime import datetime
|
|
10
|
+
from typing import cast
|
|
10
11
|
|
|
11
12
|
from funasr_onnx import SenseVoiceSmall
|
|
12
13
|
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
|
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|
|
32
33
|
provider_settings: dict,
|
|
33
34
|
) -> None:
|
|
34
35
|
super().__init__(provider_config, provider_settings)
|
|
35
|
-
self.set_model(provider_config
|
|
36
|
+
self.set_model(provider_config["stt_model"])
|
|
36
37
|
self.model = None
|
|
37
38
|
self.is_emotion = provider_config.get("is_emotion", False)
|
|
38
39
|
|
|
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|
|
86
87
|
loop = asyncio.get_event_loop()
|
|
87
88
|
res = await loop.run_in_executor(
|
|
88
89
|
None, # 使用默认的线程池
|
|
89
|
-
lambda: self.model(
|
|
90
|
+
lambda: cast(SenseVoiceSmall, self.model)(
|
|
91
|
+
audio_url, language="auto", use_itn=True
|
|
92
|
+
),
|
|
90
93
|
)
|
|
91
94
|
|
|
92
95
|
# res = self.model(audio_url, language="auto", use_itn=True)
|
|
@@ -6,7 +6,10 @@ from openai import NOT_GIVEN, AsyncOpenAI
|
|
|
6
6
|
from astrbot.core import logger
|
|
7
7
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
8
8
|
from astrbot.core.utils.io import download_file
|
|
9
|
-
from astrbot.core.utils.tencent_record_helper import
|
|
9
|
+
from astrbot.core.utils.tencent_record_helper import (
|
|
10
|
+
convert_to_pcm_wav,
|
|
11
|
+
tencent_silk_to_wav,
|
|
12
|
+
)
|
|
10
13
|
|
|
11
14
|
from ..entities import ProviderType
|
|
12
15
|
from ..provider import STTProvider
|
|
@@ -33,20 +36,30 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
|
|
33
36
|
timeout=provider_config.get("timeout", NOT_GIVEN),
|
|
34
37
|
)
|
|
35
38
|
|
|
36
|
-
self.set_model(provider_config
|
|
39
|
+
self.set_model(provider_config["model"])
|
|
37
40
|
|
|
38
|
-
async def
|
|
41
|
+
async def _get_audio_format(self, file_path):
|
|
42
|
+
# 定义要检测的头部字节
|
|
39
43
|
silk_header = b"SILK"
|
|
40
|
-
|
|
41
|
-
|
|
44
|
+
amr_header = b"#!AMR"
|
|
45
|
+
|
|
46
|
+
try:
|
|
47
|
+
with open(file_path, "rb") as f:
|
|
48
|
+
file_header = f.read(8)
|
|
49
|
+
except FileNotFoundError:
|
|
50
|
+
return None
|
|
42
51
|
|
|
43
52
|
if silk_header in file_header:
|
|
44
|
-
return
|
|
45
|
-
|
|
53
|
+
return "silk"
|
|
54
|
+
|
|
55
|
+
if amr_header in file_header:
|
|
56
|
+
return "amr"
|
|
57
|
+
return None
|
|
46
58
|
|
|
47
59
|
async def get_text(self, audio_url: str) -> str:
|
|
48
60
|
"""Only supports mp3, mp4, mpeg, m4a, wav, webm"""
|
|
49
61
|
is_tencent = False
|
|
62
|
+
output_path = None
|
|
50
63
|
|
|
51
64
|
if audio_url.startswith("http"):
|
|
52
65
|
if "multimedia.nt.qq.com.cn" in audio_url:
|
|
@@ -62,16 +75,35 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
|
|
62
75
|
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
|
63
76
|
|
|
64
77
|
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
78
|
+
file_format = await self._get_audio_format(audio_url)
|
|
79
|
+
|
|
80
|
+
# 判断是否需要转换
|
|
81
|
+
if file_format in ["silk", "amr"]:
|
|
68
82
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
69
83
|
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
|
70
|
-
|
|
84
|
+
|
|
85
|
+
if file_format == "silk":
|
|
86
|
+
logger.info(
|
|
87
|
+
"Converting silk file to wav using tencent_silk_to_wav..."
|
|
88
|
+
)
|
|
89
|
+
await tencent_silk_to_wav(audio_url, output_path)
|
|
90
|
+
elif file_format == "amr":
|
|
91
|
+
logger.info(
|
|
92
|
+
"Converting amr file to wav using convert_to_pcm_wav..."
|
|
93
|
+
)
|
|
94
|
+
await convert_to_pcm_wav(audio_url, output_path)
|
|
95
|
+
|
|
71
96
|
audio_url = output_path
|
|
72
97
|
|
|
73
98
|
result = await self.client.audio.transcriptions.create(
|
|
74
99
|
model=self.model_name,
|
|
75
|
-
file=open(audio_url, "rb"),
|
|
100
|
+
file=("audio.wav", open(audio_url, "rb")),
|
|
76
101
|
)
|
|
102
|
+
|
|
103
|
+
# remove temp file
|
|
104
|
+
if output_path and os.path.exists(output_path):
|
|
105
|
+
try:
|
|
106
|
+
os.remove(audio_url)
|
|
107
|
+
except Exception as e:
|
|
108
|
+
logger.error(f"Failed to remove temp file {audio_url}: {e}")
|
|
77
109
|
return result.text
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import os
|
|
3
3
|
import uuid
|
|
4
|
+
from typing import cast
|
|
4
5
|
|
|
5
6
|
import whisper
|
|
6
7
|
|
|
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|
|
26
27
|
provider_settings: dict,
|
|
27
28
|
) -> None:
|
|
28
29
|
super().__init__(provider_config, provider_settings)
|
|
29
|
-
self.set_model(provider_config
|
|
30
|
+
self.set_model(provider_config["model"])
|
|
30
31
|
self.model = None
|
|
31
32
|
|
|
32
33
|
async def initialize(self):
|
|
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|
|
75
76
|
await tencent_silk_to_wav(audio_url, output_path)
|
|
76
77
|
audio_url = output_path
|
|
77
78
|
|
|
79
|
+
if not self.model:
|
|
80
|
+
raise RuntimeError("Whisper 模型未初始化")
|
|
81
|
+
|
|
78
82
|
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
|
79
|
-
return result["text"]
|
|
83
|
+
return cast(str, result["text"])
|
|
@@ -1,6 +1,11 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
1
3
|
from xinference_client.client.restful.async_restful_client import (
|
|
2
4
|
AsyncClient as Client,
|
|
3
5
|
)
|
|
6
|
+
from xinference_client.client.restful.async_restful_client import (
|
|
7
|
+
AsyncRESTfulRerankModelHandle,
|
|
8
|
+
)
|
|
4
9
|
|
|
5
10
|
from astrbot import logger
|
|
6
11
|
|
|
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
|
|
|
29
34
|
False,
|
|
30
35
|
)
|
|
31
36
|
self.client = None
|
|
32
|
-
self.model = None
|
|
37
|
+
self.model: AsyncRESTfulRerankModelHandle | None = None
|
|
33
38
|
self.model_uid = None
|
|
34
39
|
|
|
35
40
|
async def initialize(self):
|
|
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
|
|
|
65
70
|
return
|
|
66
71
|
|
|
67
72
|
if self.model_uid:
|
|
68
|
-
self.model =
|
|
73
|
+
self.model = cast(
|
|
74
|
+
AsyncRESTfulRerankModelHandle,
|
|
75
|
+
await self.client.get_model(self.model_uid),
|
|
76
|
+
)
|
|
69
77
|
|
|
70
78
|
except Exception as e:
|
|
71
79
|
logger.error(f"Failed to initialize Xinference model: {e}")
|
astrbot/core/star/context.py
CHANGED
|
@@ -285,7 +285,7 @@ class Context:
|
|
|
285
285
|
"""获取所有用于 Embedding 任务的 Provider。"""
|
|
286
286
|
return self.provider_manager.embedding_provider_insts
|
|
287
287
|
|
|
288
|
-
def get_using_provider(self, umo: str | None = None) -> Provider
|
|
288
|
+
def get_using_provider(self, umo: str | None = None) -> Provider:
|
|
289
289
|
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
|
290
290
|
|
|
291
291
|
Args:
|
|
@@ -296,7 +296,7 @@ class Context:
|
|
|
296
296
|
provider_type=ProviderType.CHAT_COMPLETION,
|
|
297
297
|
umo=umo,
|
|
298
298
|
)
|
|
299
|
-
if
|
|
299
|
+
if not isinstance(prov, Provider):
|
|
300
300
|
raise ValueError("返回的 Provider 不是 Provider 类型")
|
|
301
301
|
return prov
|
|
302
302
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import re
|
|
4
|
-
from collections.abc import Awaitable, Callable
|
|
4
|
+
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
7
|
import docstring_parser
|
|
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
|
|
|
12
12
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
|
13
13
|
from astrbot.core.agent.tool import FunctionTool
|
|
14
14
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
15
|
+
from astrbot.core.message.message_event_result import MessageEventResult
|
|
15
16
|
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
|
16
17
|
from astrbot.core.provider.register import llm_tools
|
|
17
18
|
|
|
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
|
|
|
28
29
|
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
def get_handler_full_name(
|
|
32
|
+
def get_handler_full_name(
|
|
33
|
+
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
|
34
|
+
) -> str:
|
|
32
35
|
"""获取 Handler 的全名"""
|
|
33
36
|
return f"{awaitable.__module__}_{awaitable.__name__}"
|
|
34
37
|
|
|
35
38
|
|
|
36
39
|
def get_handler_or_create(
|
|
37
|
-
handler: Callable[
|
|
40
|
+
handler: Callable[
|
|
41
|
+
...,
|
|
42
|
+
Awaitable[MessageEventResult | str | None]
|
|
43
|
+
| AsyncGenerator[MessageEventResult | str | None],
|
|
44
|
+
],
|
|
38
45
|
event_type: EventType,
|
|
39
46
|
dont_add=False,
|
|
40
47
|
**kwargs,
|
|
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|
|
169
176
|
for (
|
|
170
177
|
sub_handle
|
|
171
178
|
) in parent_register_commandable.parent_group.sub_command_filters:
|
|
179
|
+
if isinstance(sub_handle, CommandGroupFilter):
|
|
180
|
+
continue
|
|
172
181
|
# 所有符合fullname一致的子指令handle添加自定义过滤器。
|
|
173
182
|
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
|
|
174
183
|
sub_handle_md = sub_handle.get_handler_md()
|
|
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|
|
180
189
|
|
|
181
190
|
else:
|
|
182
191
|
# 裸指令
|
|
192
|
+
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
|
|
193
|
+
assert isinstance(awaitable, Callable)
|
|
183
194
|
handler_md = get_handler_or_create(
|
|
184
195
|
awaitable,
|
|
185
196
|
EventType.AdapterMessageEvent,
|
|
@@ -237,7 +248,7 @@ class RegisteringCommandable:
|
|
|
237
248
|
|
|
238
249
|
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
|
|
239
250
|
command: Callable[..., Callable[..., None]] = register_command
|
|
240
|
-
custom_filter: Callable[..., Callable[...,
|
|
251
|
+
custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
|
|
241
252
|
|
|
242
253
|
def __init__(self, parent_group: CommandGroupFilter):
|
|
243
254
|
self.parent_group = parent_group
|
|
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
|
|
412
423
|
if kwargs.get("registering_agent"):
|
|
413
424
|
registering_agent = kwargs["registering_agent"]
|
|
414
425
|
|
|
415
|
-
def decorator(
|
|
426
|
+
def decorator(
|
|
427
|
+
awaitable: Callable[
|
|
428
|
+
...,
|
|
429
|
+
AsyncGenerator[MessageEventResult | str | None]
|
|
430
|
+
| Awaitable[MessageEventResult | str | None],
|
|
431
|
+
],
|
|
432
|
+
):
|
|
416
433
|
llm_tool_name = name_ if name_ else awaitable.__name__
|
|
417
434
|
func_doc = awaitable.__doc__ or ""
|
|
418
435
|
docstring = docstring_parser.parse(func_doc)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import enum
|
|
4
|
-
from collections.abc import Awaitable, Callable
|
|
4
|
+
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
|
-
from typing import Any, Generic, TypeVar
|
|
6
|
+
from typing import Any, Generic, Literal, TypeVar, overload
|
|
7
7
|
|
|
8
8
|
from .filter import HandlerFilter
|
|
9
9
|
from .star import star_map
|
|
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
|
|
|
29
29
|
for handler in self._handlers:
|
|
30
30
|
print(handler.handler_full_name)
|
|
31
31
|
|
|
32
|
+
@overload
|
|
33
|
+
def get_handlers_by_event_type(
|
|
34
|
+
self,
|
|
35
|
+
event_type: Literal[EventType.OnAstrBotLoadedEvent],
|
|
36
|
+
only_activated=True,
|
|
37
|
+
plugins_name: list[str] | None = None,
|
|
38
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
39
|
+
|
|
40
|
+
@overload
|
|
41
|
+
def get_handlers_by_event_type(
|
|
42
|
+
self,
|
|
43
|
+
event_type: Literal[EventType.OnPlatformLoadedEvent],
|
|
44
|
+
only_activated=True,
|
|
45
|
+
plugins_name: list[str] | None = None,
|
|
46
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
47
|
+
|
|
48
|
+
@overload
|
|
49
|
+
def get_handlers_by_event_type(
|
|
50
|
+
self,
|
|
51
|
+
event_type: Literal[EventType.AdapterMessageEvent],
|
|
52
|
+
only_activated=True,
|
|
53
|
+
plugins_name: list[str] | None = None,
|
|
54
|
+
) -> list[
|
|
55
|
+
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
|
56
|
+
]: ...
|
|
57
|
+
|
|
58
|
+
@overload
|
|
59
|
+
def get_handlers_by_event_type(
|
|
60
|
+
self,
|
|
61
|
+
event_type: Literal[EventType.OnLLMRequestEvent],
|
|
62
|
+
only_activated=True,
|
|
63
|
+
plugins_name: list[str] | None = None,
|
|
64
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
65
|
+
|
|
66
|
+
@overload
|
|
67
|
+
def get_handlers_by_event_type(
|
|
68
|
+
self,
|
|
69
|
+
event_type: Literal[EventType.OnLLMResponseEvent],
|
|
70
|
+
only_activated=True,
|
|
71
|
+
plugins_name: list[str] | None = None,
|
|
72
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
73
|
+
|
|
74
|
+
@overload
|
|
75
|
+
def get_handlers_by_event_type(
|
|
76
|
+
self,
|
|
77
|
+
event_type: Literal[EventType.OnDecoratingResultEvent],
|
|
78
|
+
only_activated=True,
|
|
79
|
+
plugins_name: list[str] | None = None,
|
|
80
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
81
|
+
|
|
82
|
+
@overload
|
|
83
|
+
def get_handlers_by_event_type(
|
|
84
|
+
self,
|
|
85
|
+
event_type: Literal[EventType.OnCallingFuncToolEvent],
|
|
86
|
+
only_activated=True,
|
|
87
|
+
plugins_name: list[str] | None = None,
|
|
88
|
+
) -> list[
|
|
89
|
+
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
|
90
|
+
]: ...
|
|
91
|
+
|
|
92
|
+
@overload
|
|
93
|
+
def get_handlers_by_event_type(
|
|
94
|
+
self,
|
|
95
|
+
event_type: Literal[EventType.OnAfterMessageSentEvent],
|
|
96
|
+
only_activated=True,
|
|
97
|
+
plugins_name: list[str] | None = None,
|
|
98
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
99
|
+
|
|
100
|
+
@overload
|
|
101
|
+
def get_handlers_by_event_type(
|
|
102
|
+
self,
|
|
103
|
+
event_type: EventType,
|
|
104
|
+
only_activated=True,
|
|
105
|
+
plugins_name: list[str] | None = None,
|
|
106
|
+
) -> list[
|
|
107
|
+
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
|
108
|
+
]: ...
|
|
109
|
+
|
|
32
110
|
def get_handlers_by_event_type(
|
|
33
111
|
self,
|
|
34
112
|
event_type: EventType,
|
|
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
|
|
|
111
189
|
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
|
112
190
|
|
|
113
191
|
|
|
192
|
+
H = TypeVar("H", bound=Callable[..., Any])
|
|
193
|
+
|
|
194
|
+
|
|
114
195
|
@dataclass
|
|
115
|
-
class StarHandlerMetadata:
|
|
196
|
+
class StarHandlerMetadata(Generic[H]):
|
|
116
197
|
"""描述一个 Star 所注册的某一个 Handler。"""
|
|
117
198
|
|
|
118
199
|
event_type: EventType
|
|
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
|
|
|
127
208
|
handler_module_path: str
|
|
128
209
|
"""Handler 所在的模块路径。"""
|
|
129
210
|
|
|
130
|
-
handler:
|
|
211
|
+
handler: H
|
|
131
212
|
"""Handler 的函数对象,应当是一个异步函数"""
|
|
132
213
|
|
|
133
214
|
event_filters: list[HandlerFilter]
|
astrbot/core/updator.py
CHANGED
|
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
|
|
|
71
71
|
|
|
72
72
|
async def check_update(
|
|
73
73
|
self,
|
|
74
|
-
url: str,
|
|
75
|
-
current_version: str,
|
|
74
|
+
url: str | None,
|
|
75
|
+
current_version: str | None,
|
|
76
76
|
consider_prerelease: bool = True,
|
|
77
|
-
) -> ReleaseInfo:
|
|
77
|
+
) -> ReleaseInfo | None:
|
|
78
78
|
"""检查更新"""
|
|
79
79
|
return await super().check_update(
|
|
80
80
|
self.ASTRBOT_RELEASE_API,
|
astrbot/core/utils/io.py
CHANGED
|
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
|
|
|
49
49
|
return False
|
|
50
50
|
|
|
51
51
|
|
|
52
|
-
def save_temp_img(img: Image.Image |
|
|
52
|
+
def save_temp_img(img: Image.Image | bytes) -> str:
|
|
53
53
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
54
54
|
# 获得文件创建时间,清除超过 12 小时的
|
|
55
55
|
try:
|
|
@@ -20,16 +20,16 @@ class SessionController:
|
|
|
20
20
|
|
|
21
21
|
def __init__(self):
|
|
22
22
|
self.future = asyncio.Future()
|
|
23
|
-
self.current_event: asyncio.Event = None
|
|
23
|
+
self.current_event: asyncio.Event | None = None
|
|
24
24
|
"""当前正在等待的所用的异步事件"""
|
|
25
|
-
self.ts: float = None
|
|
25
|
+
self.ts: float | None = None
|
|
26
26
|
"""上次保持(keep)开始时的时间"""
|
|
27
|
-
self.timeout: float | int = None
|
|
27
|
+
self.timeout: float | int | None = None
|
|
28
28
|
"""上次保持(keep)开始时的超时时间"""
|
|
29
29
|
|
|
30
30
|
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
|
31
31
|
|
|
32
|
-
def stop(self, error: Exception = None):
|
|
32
|
+
def stop(self, error: Exception | None = None):
|
|
33
33
|
"""立即结束这个会话"""
|
|
34
34
|
if not self.future.done():
|
|
35
35
|
if error:
|
|
@@ -53,6 +53,8 @@ class SessionController:
|
|
|
53
53
|
self.stop()
|
|
54
54
|
return
|
|
55
55
|
else:
|
|
56
|
+
assert self.timeout is not None
|
|
57
|
+
assert self.ts is not None
|
|
56
58
|
left_timeout = self.timeout - (new_ts - self.ts)
|
|
57
59
|
timeout = left_timeout + timeout
|
|
58
60
|
if timeout <= 0:
|
|
@@ -69,7 +71,7 @@ class SessionController:
|
|
|
69
71
|
|
|
70
72
|
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
|
71
73
|
|
|
72
|
-
async def _holding(self, event: asyncio.Event, timeout:
|
|
74
|
+
async def _holding(self, event: asyncio.Event, timeout: float):
|
|
73
75
|
"""等待事件结束或超时"""
|
|
74
76
|
try:
|
|
75
77
|
await asyncio.wait_for(event.wait(), timeout)
|
|
@@ -108,7 +110,9 @@ class SessionWaiter:
|
|
|
108
110
|
):
|
|
109
111
|
self.session_id = session_id
|
|
110
112
|
self.session_filter = session_filter
|
|
111
|
-
self.handler:
|
|
113
|
+
self.handler: (
|
|
114
|
+
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
|
|
115
|
+
) = None # 处理函数
|
|
112
116
|
|
|
113
117
|
self.session_controller = SessionController()
|
|
114
118
|
self.record_history_chains = record_history_chains
|
|
@@ -119,7 +123,7 @@ class SessionWaiter:
|
|
|
119
123
|
|
|
120
124
|
async def register_wait(
|
|
121
125
|
self,
|
|
122
|
-
handler: Callable[[
|
|
126
|
+
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
|
123
127
|
timeout: int = 30,
|
|
124
128
|
) -> Any:
|
|
125
129
|
"""等待外部输入并处理"""
|
|
@@ -137,7 +141,7 @@ class SessionWaiter:
|
|
|
137
141
|
finally:
|
|
138
142
|
self._cleanup()
|
|
139
143
|
|
|
140
|
-
def _cleanup(self, error: Exception = None):
|
|
144
|
+
def _cleanup(self, error: Exception | None = None):
|
|
141
145
|
"""清理会话"""
|
|
142
146
|
USER_SESSIONS.pop(self.session_id, None)
|
|
143
147
|
try:
|
|
@@ -161,6 +165,7 @@ class SessionWaiter:
|
|
|
161
165
|
)
|
|
162
166
|
try:
|
|
163
167
|
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
|
168
|
+
assert session.handler is not None
|
|
164
169
|
await session.handler(session.session_controller, event)
|
|
165
170
|
except Exception as e:
|
|
166
171
|
session.session_controller.stop(e)
|
|
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
|
|
173
178
|
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
|
174
179
|
"""
|
|
175
180
|
|
|
176
|
-
def decorator(
|
|
181
|
+
def decorator(
|
|
182
|
+
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
|
183
|
+
):
|
|
177
184
|
@functools.wraps(func)
|
|
178
185
|
async def wrapper(
|
|
179
186
|
event: AstrMessageEvent,
|
|
180
|
-
session_filter: SessionFilter = None,
|
|
187
|
+
session_filter: SessionFilter | None = None,
|
|
181
188
|
*args,
|
|
182
189
|
**kwargs,
|
|
183
190
|
):
|
|
@@ -53,6 +53,38 @@ class SharedPreferences:
|
|
|
53
53
|
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
|
54
54
|
return ret
|
|
55
55
|
|
|
56
|
+
@overload
|
|
57
|
+
async def session_get(
|
|
58
|
+
self,
|
|
59
|
+
umo: str,
|
|
60
|
+
key: str,
|
|
61
|
+
default: _VT = None,
|
|
62
|
+
) -> _VT: ...
|
|
63
|
+
|
|
64
|
+
@overload
|
|
65
|
+
async def session_get(
|
|
66
|
+
self,
|
|
67
|
+
umo: None,
|
|
68
|
+
key: str,
|
|
69
|
+
default: Any = None,
|
|
70
|
+
) -> list[Preference]: ...
|
|
71
|
+
|
|
72
|
+
@overload
|
|
73
|
+
async def session_get(
|
|
74
|
+
self,
|
|
75
|
+
umo: str,
|
|
76
|
+
key: None,
|
|
77
|
+
default: Any = None,
|
|
78
|
+
) -> list[Preference]: ...
|
|
79
|
+
|
|
80
|
+
@overload
|
|
81
|
+
async def session_get(
|
|
82
|
+
self,
|
|
83
|
+
umo: None,
|
|
84
|
+
key: None,
|
|
85
|
+
default: Any = None,
|
|
86
|
+
) -> list[Preference]: ...
|
|
87
|
+
|
|
56
88
|
async def session_get(
|
|
57
89
|
self,
|
|
58
90
|
umo: str | None,
|
|
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
|
|
|
3
3
|
|
|
4
4
|
class RenderStrategy(ABC):
|
|
5
5
|
@abstractmethod
|
|
6
|
-
def render(self, text: str, return_url: bool) -> str:
|
|
6
|
+
async def render(self, text: str, return_url: bool) -> str:
|
|
7
7
|
pass
|
|
8
8
|
|
|
9
9
|
@abstractmethod
|
|
10
|
-
def render_custom_template(
|
|
10
|
+
async def render_custom_template(
|
|
11
11
|
self,
|
|
12
12
|
tmpl_str: str,
|
|
13
13
|
tmpl_data: dict,
|