AstrBot 4.8.0__py3-none-any.whl → 4.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/cli/__init__.py +1 -1
- astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -1
- astrbot/core/agent/tool.py +7 -2
- astrbot/core/astr_agent_tool_exec.py +5 -1
- astrbot/core/config/astrbot_config.py +4 -0
- astrbot/core/config/default.py +72 -1
- astrbot/core/config/i18n_utils.py +1 -0
- astrbot/core/core_lifecycle.py +1 -1
- astrbot/core/db/__init__.py +2 -3
- astrbot/core/db/migration/migra_3_to_4.py +2 -0
- astrbot/core/db/migration/sqlite_v3.py +6 -4
- astrbot/core/db/po.py +16 -15
- astrbot/core/db/sqlite.py +4 -3
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +2 -0
- astrbot/core/event_bus.py +6 -1
- astrbot/core/knowledge_base/retrieval/manager.py +5 -1
- astrbot/core/log.py +2 -1
- astrbot/core/message/components.py +9 -3
- astrbot/core/persona_mgr.py +2 -2
- astrbot/core/pipeline/content_safety_check/stage.py +1 -1
- astrbot/core/pipeline/context_utils.py +2 -1
- astrbot/core/pipeline/process_stage/method/star_request.py +1 -2
- astrbot/core/pipeline/process_stage/stage.py +1 -1
- astrbot/core/pipeline/respond/stage.py +8 -2
- astrbot/core/pipeline/result_decorate/stage.py +89 -22
- astrbot/core/pipeline/scheduler.py +5 -1
- astrbot/core/pipeline/waking_check/stage.py +10 -0
- astrbot/core/platform/astr_message_event.py +5 -3
- astrbot/core/platform/astrbot_message.py +2 -2
- astrbot/core/platform/manager.py +4 -0
- astrbot/core/platform/platform.py +11 -3
- astrbot/core/platform/platform_metadata.py +1 -1
- astrbot/core/platform/register.py +1 -0
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +8 -6
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +9 -5
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +24 -16
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +5 -2
- astrbot/core/platform/sources/discord/client.py +16 -4
- astrbot/core/platform/sources/discord/components.py +2 -2
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +52 -24
- astrbot/core/platform/sources/discord/discord_platform_event.py +29 -8
- astrbot/core/platform/sources/lark/lark_adapter.py +183 -20
- astrbot/core/platform/sources/lark/lark_event.py +39 -4
- astrbot/core/platform/sources/lark/server.py +206 -0
- astrbot/core/platform/sources/misskey/misskey_adapter.py +2 -3
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +62 -18
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +13 -7
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +5 -3
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/slack/client.py +9 -2
- astrbot/core/platform/sources/slack/slack_adapter.py +15 -9
- astrbot/core/platform/sources/slack/slack_event.py +8 -7
- astrbot/core/platform/sources/telegram/tg_adapter.py +1 -1
- astrbot/core/platform/sources/telegram/tg_event.py +23 -27
- astrbot/core/platform/sources/webchat/webchat_adapter.py +2 -2
- astrbot/core/platform/sources/webchat/webchat_event.py +2 -2
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +26 -9
- astrbot/core/platform/sources/wecom/wecom_adapter.py +25 -28
- astrbot/core/platform/sources/wecom/wecom_event.py +2 -2
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +3 -3
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +30 -25
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +10 -7
- astrbot/core/provider/func_tool_manager.py +3 -3
- astrbot/core/provider/manager.py +130 -74
- astrbot/core/provider/provider.py +12 -1
- astrbot/core/provider/sources/azure_tts_source.py +31 -9
- astrbot/core/provider/sources/bailian_rerank_source.py +4 -0
- astrbot/core/provider/sources/dashscope_tts.py +3 -2
- astrbot/core/provider/sources/edge_tts_source.py +1 -1
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +5 -4
- astrbot/core/provider/sources/gemini_embedding_source.py +15 -5
- astrbot/core/provider/sources/gemini_source.py +12 -10
- astrbot/core/provider/sources/minimax_tts_api_source.py +4 -2
- astrbot/core/provider/sources/openai_embedding_source.py +2 -2
- astrbot/core/provider/sources/openai_source.py +4 -0
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +5 -2
- astrbot/core/provider/sources/vllm_rerank_source.py +1 -0
- astrbot/core/provider/sources/whisper_api_source.py +1 -1
- astrbot/core/provider/sources/whisper_selfhosted_source.py +6 -2
- astrbot/core/provider/sources/xinference_rerank_source.py +10 -2
- astrbot/core/star/context.py +2 -2
- astrbot/core/star/register/star_handler.py +22 -5
- astrbot/core/star/star_handler.py +85 -4
- astrbot/core/updator.py +3 -3
- astrbot/core/utils/io.py +1 -1
- astrbot/core/utils/session_waiter.py +17 -10
- astrbot/core/utils/shared_preferences.py +32 -0
- astrbot/core/utils/t2i/__init__.py +2 -2
- astrbot/core/utils/t2i/local_strategy.py +25 -31
- astrbot/core/utils/tencent_record_helper.py +1 -1
- astrbot/core/utils/version_comparator.py +6 -3
- astrbot/core/utils/webhook_utils.py +19 -0
- astrbot/dashboard/routes/chat.py +14 -9
- astrbot/dashboard/routes/config.py +10 -20
- astrbot/dashboard/routes/conversation.py +91 -1
- astrbot/dashboard/routes/knowledge_base.py +253 -78
- astrbot/dashboard/routes/log.py +13 -8
- astrbot/dashboard/routes/platform.py +1 -1
- astrbot/dashboard/routes/plugin.py +113 -52
- astrbot/dashboard/routes/route.py +2 -0
- astrbot/dashboard/server.py +6 -3
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/METADATA +9 -1
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/RECORD +106 -105
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/WHEEL +0 -0
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,6 +4,7 @@ import json
|
|
|
4
4
|
import logging
|
|
5
5
|
import random
|
|
6
6
|
from collections.abc import AsyncGenerator
|
|
7
|
+
from typing import cast
|
|
7
8
|
|
|
8
9
|
from google import genai
|
|
9
10
|
from google.genai import types
|
|
@@ -126,17 +127,17 @@ class ProviderGoogleGenAI(Provider):
|
|
|
126
127
|
) -> types.GenerateContentConfig:
|
|
127
128
|
"""准备查询配置"""
|
|
128
129
|
if not modalities:
|
|
129
|
-
modalities = ["
|
|
130
|
+
modalities = ["TEXT"]
|
|
130
131
|
|
|
131
132
|
# 流式输出不支持图片模态
|
|
132
133
|
if (
|
|
133
134
|
self.provider_settings.get("streaming_response", False)
|
|
134
|
-
and "
|
|
135
|
+
and "IMAGE" in modalities
|
|
135
136
|
):
|
|
136
137
|
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
|
137
|
-
modalities = ["
|
|
138
|
+
modalities = ["TEXT"]
|
|
138
139
|
|
|
139
|
-
tool_list = []
|
|
140
|
+
tool_list: list[types.Tool] | None = []
|
|
140
141
|
model_name = self.get_model()
|
|
141
142
|
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
|
142
143
|
native_search = self.provider_config.get("gm_native_search", False)
|
|
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
213
214
|
logprobs=payloads.get("logprobs"),
|
|
214
215
|
seed=payloads.get("seed"),
|
|
215
216
|
response_modalities=modalities,
|
|
216
|
-
tools=tool_list,
|
|
217
|
+
tools=cast(types.ToolListUnion | None, tool_list),
|
|
217
218
|
safety_settings=self.safety_settings if self.safety_settings else None,
|
|
218
219
|
thinking_config=(
|
|
219
220
|
types.ThinkingConfig(
|
|
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
257
258
|
content_cls: type[types.Content],
|
|
258
259
|
) -> None:
|
|
259
260
|
if contents and isinstance(contents[-1], content_cls):
|
|
261
|
+
assert contents[-1].parts is not None
|
|
260
262
|
contents[-1].parts.extend(part)
|
|
261
263
|
else:
|
|
262
264
|
contents.append(content_cls(parts=part))
|
|
@@ -429,9 +431,9 @@ class ProviderGoogleGenAI(Provider):
|
|
|
429
431
|
None,
|
|
430
432
|
)
|
|
431
433
|
|
|
432
|
-
modalities = ["
|
|
434
|
+
modalities = ["TEXT"]
|
|
433
435
|
if self.provider_config.get("gm_resp_image_modal", False):
|
|
434
|
-
modalities.append("
|
|
436
|
+
modalities.append("IMAGE")
|
|
435
437
|
|
|
436
438
|
conversation = self._prepare_conversation(payloads)
|
|
437
439
|
temperature = payloads.get("temperature", 0.7)
|
|
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
448
450
|
)
|
|
449
451
|
result = await self.client.models.generate_content(
|
|
450
452
|
model=self.get_model(),
|
|
451
|
-
contents=conversation,
|
|
453
|
+
contents=cast(types.ContentListUnion, conversation),
|
|
452
454
|
config=config,
|
|
453
455
|
)
|
|
454
456
|
logger.debug(f"genai result: {result}")
|
|
@@ -488,7 +490,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
488
490
|
logger.warning(
|
|
489
491
|
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
|
|
490
492
|
)
|
|
491
|
-
modalities = ["
|
|
493
|
+
modalities = ["TEXT"]
|
|
492
494
|
else:
|
|
493
495
|
raise
|
|
494
496
|
continue
|
|
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
524
526
|
)
|
|
525
527
|
result = await self.client.models.generate_content_stream(
|
|
526
528
|
model=self.get_model(),
|
|
527
|
-
contents=conversation,
|
|
529
|
+
contents=cast(types.ContentListUnion, conversation),
|
|
528
530
|
config=config,
|
|
529
531
|
)
|
|
530
532
|
break
|
|
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
|
|
87
87
|
|
|
88
88
|
return json.dumps(dict_body)
|
|
89
89
|
|
|
90
|
-
async def _call_tts_stream(self, text: str) -> AsyncIterator[
|
|
90
|
+
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
|
|
91
91
|
"""进行流式请求"""
|
|
92
92
|
try:
|
|
93
93
|
async with (
|
|
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
|
|
117
117
|
data = json.loads(message[6:])
|
|
118
118
|
if "extra_info" in data:
|
|
119
119
|
continue
|
|
120
|
-
audio = data.get("data", {}).get(
|
|
120
|
+
audio: str | None = data.get("data", {}).get(
|
|
121
|
+
"audio"
|
|
122
|
+
)
|
|
121
123
|
if audio is not None:
|
|
122
124
|
yield audio
|
|
123
125
|
except json.JSONDecodeError:
|
|
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
30
30
|
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
|
31
31
|
return embedding.data[0].embedding
|
|
32
32
|
|
|
33
|
-
async def get_embeddings(self,
|
|
33
|
+
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
|
34
34
|
"""批量获取文本的嵌入"""
|
|
35
|
-
embeddings = await self.client.embeddings.create(input=
|
|
35
|
+
embeddings = await self.client.embeddings.create(input=text, model=self.model)
|
|
36
36
|
return [item.embedding for item in embeddings.data]
|
|
37
37
|
|
|
38
38
|
def get_dim(self) -> int:
|
|
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
284
284
|
if isinstance(tool_call, str):
|
|
285
285
|
# workaround for #1359
|
|
286
286
|
tool_call = json.loads(tool_call)
|
|
287
|
+
if tools is None:
|
|
288
|
+
# 工具集未提供
|
|
289
|
+
# Should be unreachable
|
|
290
|
+
raise Exception("工具集未提供")
|
|
287
291
|
for tool in tools.func_list:
|
|
288
292
|
if (
|
|
289
293
|
tool_call.type == "function"
|
|
@@ -7,6 +7,7 @@ import asyncio
|
|
|
7
7
|
import os
|
|
8
8
|
import re
|
|
9
9
|
from datetime import datetime
|
|
10
|
+
from typing import cast
|
|
10
11
|
|
|
11
12
|
from funasr_onnx import SenseVoiceSmall
|
|
12
13
|
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
|
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|
|
32
33
|
provider_settings: dict,
|
|
33
34
|
) -> None:
|
|
34
35
|
super().__init__(provider_config, provider_settings)
|
|
35
|
-
self.set_model(provider_config
|
|
36
|
+
self.set_model(provider_config["stt_model"])
|
|
36
37
|
self.model = None
|
|
37
38
|
self.is_emotion = provider_config.get("is_emotion", False)
|
|
38
39
|
|
|
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|
|
86
87
|
loop = asyncio.get_event_loop()
|
|
87
88
|
res = await loop.run_in_executor(
|
|
88
89
|
None, # 使用默认的线程池
|
|
89
|
-
lambda: self.model(
|
|
90
|
+
lambda: cast(SenseVoiceSmall, self.model)(
|
|
91
|
+
audio_url, language="auto", use_itn=True
|
|
92
|
+
),
|
|
90
93
|
)
|
|
91
94
|
|
|
92
95
|
# res = self.model(audio_url, language="auto", use_itn=True)
|
|
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
|
|
36
36
|
timeout=provider_config.get("timeout", NOT_GIVEN),
|
|
37
37
|
)
|
|
38
38
|
|
|
39
|
-
self.set_model(provider_config
|
|
39
|
+
self.set_model(provider_config["model"])
|
|
40
40
|
|
|
41
41
|
async def _get_audio_format(self, file_path):
|
|
42
42
|
# 定义要检测的头部字节
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import os
|
|
3
3
|
import uuid
|
|
4
|
+
from typing import cast
|
|
4
5
|
|
|
5
6
|
import whisper
|
|
6
7
|
|
|
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|
|
26
27
|
provider_settings: dict,
|
|
27
28
|
) -> None:
|
|
28
29
|
super().__init__(provider_config, provider_settings)
|
|
29
|
-
self.set_model(provider_config
|
|
30
|
+
self.set_model(provider_config["model"])
|
|
30
31
|
self.model = None
|
|
31
32
|
|
|
32
33
|
async def initialize(self):
|
|
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|
|
75
76
|
await tencent_silk_to_wav(audio_url, output_path)
|
|
76
77
|
audio_url = output_path
|
|
77
78
|
|
|
79
|
+
if not self.model:
|
|
80
|
+
raise RuntimeError("Whisper 模型未初始化")
|
|
81
|
+
|
|
78
82
|
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
|
79
|
-
return result["text"]
|
|
83
|
+
return cast(str, result["text"])
|
|
@@ -1,6 +1,11 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
1
3
|
from xinference_client.client.restful.async_restful_client import (
|
|
2
4
|
AsyncClient as Client,
|
|
3
5
|
)
|
|
6
|
+
from xinference_client.client.restful.async_restful_client import (
|
|
7
|
+
AsyncRESTfulRerankModelHandle,
|
|
8
|
+
)
|
|
4
9
|
|
|
5
10
|
from astrbot import logger
|
|
6
11
|
|
|
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
|
|
|
29
34
|
False,
|
|
30
35
|
)
|
|
31
36
|
self.client = None
|
|
32
|
-
self.model = None
|
|
37
|
+
self.model: AsyncRESTfulRerankModelHandle | None = None
|
|
33
38
|
self.model_uid = None
|
|
34
39
|
|
|
35
40
|
async def initialize(self):
|
|
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
|
|
|
65
70
|
return
|
|
66
71
|
|
|
67
72
|
if self.model_uid:
|
|
68
|
-
self.model =
|
|
73
|
+
self.model = cast(
|
|
74
|
+
AsyncRESTfulRerankModelHandle,
|
|
75
|
+
await self.client.get_model(self.model_uid),
|
|
76
|
+
)
|
|
69
77
|
|
|
70
78
|
except Exception as e:
|
|
71
79
|
logger.error(f"Failed to initialize Xinference model: {e}")
|
astrbot/core/star/context.py
CHANGED
|
@@ -285,7 +285,7 @@ class Context:
|
|
|
285
285
|
"""获取所有用于 Embedding 任务的 Provider。"""
|
|
286
286
|
return self.provider_manager.embedding_provider_insts
|
|
287
287
|
|
|
288
|
-
def get_using_provider(self, umo: str | None = None) -> Provider
|
|
288
|
+
def get_using_provider(self, umo: str | None = None) -> Provider:
|
|
289
289
|
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
|
290
290
|
|
|
291
291
|
Args:
|
|
@@ -296,7 +296,7 @@ class Context:
|
|
|
296
296
|
provider_type=ProviderType.CHAT_COMPLETION,
|
|
297
297
|
umo=umo,
|
|
298
298
|
)
|
|
299
|
-
if
|
|
299
|
+
if not isinstance(prov, Provider):
|
|
300
300
|
raise ValueError("返回的 Provider 不是 Provider 类型")
|
|
301
301
|
return prov
|
|
302
302
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import re
|
|
4
|
-
from collections.abc import Awaitable, Callable
|
|
4
|
+
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
7
|
import docstring_parser
|
|
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
|
|
|
12
12
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
|
13
13
|
from astrbot.core.agent.tool import FunctionTool
|
|
14
14
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
15
|
+
from astrbot.core.message.message_event_result import MessageEventResult
|
|
15
16
|
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
|
16
17
|
from astrbot.core.provider.register import llm_tools
|
|
17
18
|
|
|
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
|
|
|
28
29
|
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
def get_handler_full_name(
|
|
32
|
+
def get_handler_full_name(
|
|
33
|
+
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
|
34
|
+
) -> str:
|
|
32
35
|
"""获取 Handler 的全名"""
|
|
33
36
|
return f"{awaitable.__module__}_{awaitable.__name__}"
|
|
34
37
|
|
|
35
38
|
|
|
36
39
|
def get_handler_or_create(
|
|
37
|
-
handler: Callable[
|
|
40
|
+
handler: Callable[
|
|
41
|
+
...,
|
|
42
|
+
Awaitable[MessageEventResult | str | None]
|
|
43
|
+
| AsyncGenerator[MessageEventResult | str | None],
|
|
44
|
+
],
|
|
38
45
|
event_type: EventType,
|
|
39
46
|
dont_add=False,
|
|
40
47
|
**kwargs,
|
|
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|
|
169
176
|
for (
|
|
170
177
|
sub_handle
|
|
171
178
|
) in parent_register_commandable.parent_group.sub_command_filters:
|
|
179
|
+
if isinstance(sub_handle, CommandGroupFilter):
|
|
180
|
+
continue
|
|
172
181
|
# 所有符合fullname一致的子指令handle添加自定义过滤器。
|
|
173
182
|
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
|
|
174
183
|
sub_handle_md = sub_handle.get_handler_md()
|
|
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|
|
180
189
|
|
|
181
190
|
else:
|
|
182
191
|
# 裸指令
|
|
192
|
+
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
|
|
193
|
+
assert isinstance(awaitable, Callable)
|
|
183
194
|
handler_md = get_handler_or_create(
|
|
184
195
|
awaitable,
|
|
185
196
|
EventType.AdapterMessageEvent,
|
|
@@ -237,7 +248,7 @@ class RegisteringCommandable:
|
|
|
237
248
|
|
|
238
249
|
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
|
|
239
250
|
command: Callable[..., Callable[..., None]] = register_command
|
|
240
|
-
custom_filter: Callable[..., Callable[...,
|
|
251
|
+
custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
|
|
241
252
|
|
|
242
253
|
def __init__(self, parent_group: CommandGroupFilter):
|
|
243
254
|
self.parent_group = parent_group
|
|
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
|
|
412
423
|
if kwargs.get("registering_agent"):
|
|
413
424
|
registering_agent = kwargs["registering_agent"]
|
|
414
425
|
|
|
415
|
-
def decorator(
|
|
426
|
+
def decorator(
|
|
427
|
+
awaitable: Callable[
|
|
428
|
+
...,
|
|
429
|
+
AsyncGenerator[MessageEventResult | str | None]
|
|
430
|
+
| Awaitable[MessageEventResult | str | None],
|
|
431
|
+
],
|
|
432
|
+
):
|
|
416
433
|
llm_tool_name = name_ if name_ else awaitable.__name__
|
|
417
434
|
func_doc = awaitable.__doc__ or ""
|
|
418
435
|
docstring = docstring_parser.parse(func_doc)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import enum
|
|
4
|
-
from collections.abc import Awaitable, Callable
|
|
4
|
+
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
|
-
from typing import Any, Generic, TypeVar
|
|
6
|
+
from typing import Any, Generic, Literal, TypeVar, overload
|
|
7
7
|
|
|
8
8
|
from .filter import HandlerFilter
|
|
9
9
|
from .star import star_map
|
|
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
|
|
|
29
29
|
for handler in self._handlers:
|
|
30
30
|
print(handler.handler_full_name)
|
|
31
31
|
|
|
32
|
+
@overload
|
|
33
|
+
def get_handlers_by_event_type(
|
|
34
|
+
self,
|
|
35
|
+
event_type: Literal[EventType.OnAstrBotLoadedEvent],
|
|
36
|
+
only_activated=True,
|
|
37
|
+
plugins_name: list[str] | None = None,
|
|
38
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
39
|
+
|
|
40
|
+
@overload
|
|
41
|
+
def get_handlers_by_event_type(
|
|
42
|
+
self,
|
|
43
|
+
event_type: Literal[EventType.OnPlatformLoadedEvent],
|
|
44
|
+
only_activated=True,
|
|
45
|
+
plugins_name: list[str] | None = None,
|
|
46
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
47
|
+
|
|
48
|
+
@overload
|
|
49
|
+
def get_handlers_by_event_type(
|
|
50
|
+
self,
|
|
51
|
+
event_type: Literal[EventType.AdapterMessageEvent],
|
|
52
|
+
only_activated=True,
|
|
53
|
+
plugins_name: list[str] | None = None,
|
|
54
|
+
) -> list[
|
|
55
|
+
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
|
56
|
+
]: ...
|
|
57
|
+
|
|
58
|
+
@overload
|
|
59
|
+
def get_handlers_by_event_type(
|
|
60
|
+
self,
|
|
61
|
+
event_type: Literal[EventType.OnLLMRequestEvent],
|
|
62
|
+
only_activated=True,
|
|
63
|
+
plugins_name: list[str] | None = None,
|
|
64
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
65
|
+
|
|
66
|
+
@overload
|
|
67
|
+
def get_handlers_by_event_type(
|
|
68
|
+
self,
|
|
69
|
+
event_type: Literal[EventType.OnLLMResponseEvent],
|
|
70
|
+
only_activated=True,
|
|
71
|
+
plugins_name: list[str] | None = None,
|
|
72
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
73
|
+
|
|
74
|
+
@overload
|
|
75
|
+
def get_handlers_by_event_type(
|
|
76
|
+
self,
|
|
77
|
+
event_type: Literal[EventType.OnDecoratingResultEvent],
|
|
78
|
+
only_activated=True,
|
|
79
|
+
plugins_name: list[str] | None = None,
|
|
80
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
81
|
+
|
|
82
|
+
@overload
|
|
83
|
+
def get_handlers_by_event_type(
|
|
84
|
+
self,
|
|
85
|
+
event_type: Literal[EventType.OnCallingFuncToolEvent],
|
|
86
|
+
only_activated=True,
|
|
87
|
+
plugins_name: list[str] | None = None,
|
|
88
|
+
) -> list[
|
|
89
|
+
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
|
90
|
+
]: ...
|
|
91
|
+
|
|
92
|
+
@overload
|
|
93
|
+
def get_handlers_by_event_type(
|
|
94
|
+
self,
|
|
95
|
+
event_type: Literal[EventType.OnAfterMessageSentEvent],
|
|
96
|
+
only_activated=True,
|
|
97
|
+
plugins_name: list[str] | None = None,
|
|
98
|
+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
|
99
|
+
|
|
100
|
+
@overload
|
|
101
|
+
def get_handlers_by_event_type(
|
|
102
|
+
self,
|
|
103
|
+
event_type: EventType,
|
|
104
|
+
only_activated=True,
|
|
105
|
+
plugins_name: list[str] | None = None,
|
|
106
|
+
) -> list[
|
|
107
|
+
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
|
108
|
+
]: ...
|
|
109
|
+
|
|
32
110
|
def get_handlers_by_event_type(
|
|
33
111
|
self,
|
|
34
112
|
event_type: EventType,
|
|
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
|
|
|
111
189
|
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
|
112
190
|
|
|
113
191
|
|
|
192
|
+
H = TypeVar("H", bound=Callable[..., Any])
|
|
193
|
+
|
|
194
|
+
|
|
114
195
|
@dataclass
|
|
115
|
-
class StarHandlerMetadata:
|
|
196
|
+
class StarHandlerMetadata(Generic[H]):
|
|
116
197
|
"""描述一个 Star 所注册的某一个 Handler。"""
|
|
117
198
|
|
|
118
199
|
event_type: EventType
|
|
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
|
|
|
127
208
|
handler_module_path: str
|
|
128
209
|
"""Handler 所在的模块路径。"""
|
|
129
210
|
|
|
130
|
-
handler:
|
|
211
|
+
handler: H
|
|
131
212
|
"""Handler 的函数对象,应当是一个异步函数"""
|
|
132
213
|
|
|
133
214
|
event_filters: list[HandlerFilter]
|
astrbot/core/updator.py
CHANGED
|
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
|
|
|
71
71
|
|
|
72
72
|
async def check_update(
|
|
73
73
|
self,
|
|
74
|
-
url: str,
|
|
75
|
-
current_version: str,
|
|
74
|
+
url: str | None,
|
|
75
|
+
current_version: str | None,
|
|
76
76
|
consider_prerelease: bool = True,
|
|
77
|
-
) -> ReleaseInfo:
|
|
77
|
+
) -> ReleaseInfo | None:
|
|
78
78
|
"""检查更新"""
|
|
79
79
|
return await super().check_update(
|
|
80
80
|
self.ASTRBOT_RELEASE_API,
|
astrbot/core/utils/io.py
CHANGED
|
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
|
|
|
49
49
|
return False
|
|
50
50
|
|
|
51
51
|
|
|
52
|
-
def save_temp_img(img: Image.Image |
|
|
52
|
+
def save_temp_img(img: Image.Image | bytes) -> str:
|
|
53
53
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
54
54
|
# 获得文件创建时间,清除超过 12 小时的
|
|
55
55
|
try:
|
|
@@ -20,16 +20,16 @@ class SessionController:
|
|
|
20
20
|
|
|
21
21
|
def __init__(self):
|
|
22
22
|
self.future = asyncio.Future()
|
|
23
|
-
self.current_event: asyncio.Event = None
|
|
23
|
+
self.current_event: asyncio.Event | None = None
|
|
24
24
|
"""当前正在等待的所用的异步事件"""
|
|
25
|
-
self.ts: float = None
|
|
25
|
+
self.ts: float | None = None
|
|
26
26
|
"""上次保持(keep)开始时的时间"""
|
|
27
|
-
self.timeout: float | int = None
|
|
27
|
+
self.timeout: float | int | None = None
|
|
28
28
|
"""上次保持(keep)开始时的超时时间"""
|
|
29
29
|
|
|
30
30
|
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
|
31
31
|
|
|
32
|
-
def stop(self, error: Exception = None):
|
|
32
|
+
def stop(self, error: Exception | None = None):
|
|
33
33
|
"""立即结束这个会话"""
|
|
34
34
|
if not self.future.done():
|
|
35
35
|
if error:
|
|
@@ -53,6 +53,8 @@ class SessionController:
|
|
|
53
53
|
self.stop()
|
|
54
54
|
return
|
|
55
55
|
else:
|
|
56
|
+
assert self.timeout is not None
|
|
57
|
+
assert self.ts is not None
|
|
56
58
|
left_timeout = self.timeout - (new_ts - self.ts)
|
|
57
59
|
timeout = left_timeout + timeout
|
|
58
60
|
if timeout <= 0:
|
|
@@ -69,7 +71,7 @@ class SessionController:
|
|
|
69
71
|
|
|
70
72
|
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
|
71
73
|
|
|
72
|
-
async def _holding(self, event: asyncio.Event, timeout:
|
|
74
|
+
async def _holding(self, event: asyncio.Event, timeout: float):
|
|
73
75
|
"""等待事件结束或超时"""
|
|
74
76
|
try:
|
|
75
77
|
await asyncio.wait_for(event.wait(), timeout)
|
|
@@ -108,7 +110,9 @@ class SessionWaiter:
|
|
|
108
110
|
):
|
|
109
111
|
self.session_id = session_id
|
|
110
112
|
self.session_filter = session_filter
|
|
111
|
-
self.handler:
|
|
113
|
+
self.handler: (
|
|
114
|
+
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
|
|
115
|
+
) = None # 处理函数
|
|
112
116
|
|
|
113
117
|
self.session_controller = SessionController()
|
|
114
118
|
self.record_history_chains = record_history_chains
|
|
@@ -119,7 +123,7 @@ class SessionWaiter:
|
|
|
119
123
|
|
|
120
124
|
async def register_wait(
|
|
121
125
|
self,
|
|
122
|
-
handler: Callable[[
|
|
126
|
+
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
|
123
127
|
timeout: int = 30,
|
|
124
128
|
) -> Any:
|
|
125
129
|
"""等待外部输入并处理"""
|
|
@@ -137,7 +141,7 @@ class SessionWaiter:
|
|
|
137
141
|
finally:
|
|
138
142
|
self._cleanup()
|
|
139
143
|
|
|
140
|
-
def _cleanup(self, error: Exception = None):
|
|
144
|
+
def _cleanup(self, error: Exception | None = None):
|
|
141
145
|
"""清理会话"""
|
|
142
146
|
USER_SESSIONS.pop(self.session_id, None)
|
|
143
147
|
try:
|
|
@@ -161,6 +165,7 @@ class SessionWaiter:
|
|
|
161
165
|
)
|
|
162
166
|
try:
|
|
163
167
|
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
|
168
|
+
assert session.handler is not None
|
|
164
169
|
await session.handler(session.session_controller, event)
|
|
165
170
|
except Exception as e:
|
|
166
171
|
session.session_controller.stop(e)
|
|
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
|
|
173
178
|
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
|
174
179
|
"""
|
|
175
180
|
|
|
176
|
-
def decorator(
|
|
181
|
+
def decorator(
|
|
182
|
+
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
|
183
|
+
):
|
|
177
184
|
@functools.wraps(func)
|
|
178
185
|
async def wrapper(
|
|
179
186
|
event: AstrMessageEvent,
|
|
180
|
-
session_filter: SessionFilter = None,
|
|
187
|
+
session_filter: SessionFilter | None = None,
|
|
181
188
|
*args,
|
|
182
189
|
**kwargs,
|
|
183
190
|
):
|
|
@@ -53,6 +53,38 @@ class SharedPreferences:
|
|
|
53
53
|
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
|
54
54
|
return ret
|
|
55
55
|
|
|
56
|
+
@overload
|
|
57
|
+
async def session_get(
|
|
58
|
+
self,
|
|
59
|
+
umo: str,
|
|
60
|
+
key: str,
|
|
61
|
+
default: _VT = None,
|
|
62
|
+
) -> _VT: ...
|
|
63
|
+
|
|
64
|
+
@overload
|
|
65
|
+
async def session_get(
|
|
66
|
+
self,
|
|
67
|
+
umo: None,
|
|
68
|
+
key: str,
|
|
69
|
+
default: Any = None,
|
|
70
|
+
) -> list[Preference]: ...
|
|
71
|
+
|
|
72
|
+
@overload
|
|
73
|
+
async def session_get(
|
|
74
|
+
self,
|
|
75
|
+
umo: str,
|
|
76
|
+
key: None,
|
|
77
|
+
default: Any = None,
|
|
78
|
+
) -> list[Preference]: ...
|
|
79
|
+
|
|
80
|
+
@overload
|
|
81
|
+
async def session_get(
|
|
82
|
+
self,
|
|
83
|
+
umo: None,
|
|
84
|
+
key: None,
|
|
85
|
+
default: Any = None,
|
|
86
|
+
) -> list[Preference]: ...
|
|
87
|
+
|
|
56
88
|
async def session_get(
|
|
57
89
|
self,
|
|
58
90
|
umo: str | None,
|
|
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
|
|
|
3
3
|
|
|
4
4
|
class RenderStrategy(ABC):
|
|
5
5
|
@abstractmethod
|
|
6
|
-
def render(self, text: str, return_url: bool) -> str:
|
|
6
|
+
async def render(self, text: str, return_url: bool) -> str:
|
|
7
7
|
pass
|
|
8
8
|
|
|
9
9
|
@abstractmethod
|
|
10
|
-
def render_custom_template(
|
|
10
|
+
async def render_custom_template(
|
|
11
11
|
self,
|
|
12
12
|
tmpl_str: str,
|
|
13
13
|
tmpl_data: dict,
|