AstrBot 4.5.1__py3-none-any.whl → 4.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +10 -11
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -36
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +7 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__main__.py +5 -5
- astrbot/cli/commands/__init__.py +3 -3
- astrbot/cli/commands/cmd_conf.py +19 -16
- astrbot/cli/commands/cmd_init.py +3 -2
- astrbot/cli/commands/cmd_plug.py +8 -10
- astrbot/cli/commands/cmd_run.py +5 -6
- astrbot/cli/utils/__init__.py +6 -6
- astrbot/cli/utils/basic.py +14 -14
- astrbot/cli/utils/plugin.py +24 -15
- astrbot/cli/utils/version_comparator.py +10 -12
- astrbot/core/__init__.py +8 -6
- astrbot/core/agent/agent.py +3 -2
- astrbot/core/agent/handoff.py +6 -2
- astrbot/core/agent/hooks.py +9 -6
- astrbot/core/agent/mcp_client.py +50 -15
- astrbot/core/agent/message.py +168 -0
- astrbot/core/agent/response.py +2 -1
- astrbot/core/agent/run_context.py +2 -3
- astrbot/core/agent/runners/base.py +10 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
- astrbot/core/agent/tool.py +60 -41
- astrbot/core/agent/tool_executor.py +9 -3
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astrbot_config_mgr.py +29 -9
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +28 -26
- astrbot/core/config/default.py +4 -6
- astrbot/core/conversation_mgr.py +105 -36
- astrbot/core/core_lifecycle.py +68 -54
- astrbot/core/db/__init__.py +33 -18
- astrbot/core/db/migration/helper.py +12 -10
- astrbot/core/db/migration/migra_3_to_4.py +53 -34
- astrbot/core/db/migration/migra_45_to_46.py +1 -1
- astrbot/core/db/migration/shared_preferences_v3.py +2 -1
- astrbot/core/db/migration/sqlite_v3.py +26 -23
- astrbot/core/db/po.py +27 -18
- astrbot/core/db/sqlite.py +74 -45
- astrbot/core/db/vec_db/base.py +10 -14
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
- astrbot/core/event_bus.py +8 -6
- astrbot/core/file_token_service.py +6 -5
- astrbot/core/initial_loader.py +7 -5
- astrbot/core/knowledge_base/chunking/__init__.py +1 -3
- astrbot/core/knowledge_base/chunking/base.py +1 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
- astrbot/core/knowledge_base/chunking/recursive.py +16 -10
- astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
- astrbot/core/knowledge_base/kb_helper.py +30 -17
- astrbot/core/knowledge_base/kb_mgr.py +6 -7
- astrbot/core/knowledge_base/models.py +10 -4
- astrbot/core/knowledge_base/parsers/__init__.py +3 -5
- astrbot/core/knowledge_base/parsers/base.py +1 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
- astrbot/core/knowledge_base/parsers/util.py +1 -1
- astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
- astrbot/core/knowledge_base/retrieval/manager.py +17 -14
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
- astrbot/core/log.py +21 -13
- astrbot/core/message/components.py +123 -217
- astrbot/core/message/message_event_result.py +24 -24
- astrbot/core/persona_mgr.py +20 -11
- astrbot/core/pipeline/__init__.py +7 -7
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +4 -1
- astrbot/core/pipeline/context_utils.py +77 -7
- astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
- astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
- astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
- astrbot/core/pipeline/process_stage/stage.py +13 -10
- astrbot/core/pipeline/process_stage/utils.py +6 -5
- astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
- astrbot/core/pipeline/respond/stage.py +23 -20
- astrbot/core/pipeline/result_decorate/stage.py +31 -23
- astrbot/core/pipeline/scheduler.py +12 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -8
- astrbot/core/pipeline/stage.py +10 -4
- astrbot/core/pipeline/waking_check/stage.py +24 -18
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +76 -110
- astrbot/core/platform/astrbot_message.py +11 -13
- astrbot/core/platform/manager.py +16 -15
- astrbot/core/platform/message_session.py +5 -3
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +4 -4
- astrbot/core/platform/register.py +8 -8
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +42 -27
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
- astrbot/core/platform/sources/discord/client.py +9 -6
- astrbot/core/platform/sources/discord/components.py +18 -14
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
- astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
- astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
- astrbot/core/platform/sources/lark/lark_event.py +21 -14
- astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
- astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
- astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
- astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
- astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
- astrbot/core/platform/sources/satori/satori_event.py +34 -25
- astrbot/core/platform/sources/slack/client.py +11 -9
- astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
- astrbot/core/platform/sources/slack/slack_event.py +34 -24
- astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
- astrbot/core/platform/sources/telegram/tg_event.py +32 -18
- astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
- astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
- astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
- astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
- astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
- astrbot/core/platform_message_history_mgr.py +5 -3
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +61 -75
- astrbot/core/provider/func_tool_manager.py +59 -55
- astrbot/core/provider/manager.py +32 -22
- astrbot/core/provider/provider.py +72 -46
- astrbot/core/provider/register.py +7 -7
- astrbot/core/provider/sources/anthropic_source.py +48 -30
- astrbot/core/provider/sources/azure_tts_source.py +17 -13
- astrbot/core/provider/sources/coze_api_client.py +27 -17
- astrbot/core/provider/sources/coze_source.py +104 -87
- astrbot/core/provider/sources/dashscope_source.py +18 -11
- astrbot/core/provider/sources/dashscope_tts.py +36 -23
- astrbot/core/provider/sources/dify_source.py +25 -20
- astrbot/core/provider/sources/edge_tts_source.py +21 -17
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
- astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
- astrbot/core/provider/sources/gemini_source.py +72 -58
- astrbot/core/provider/sources/gemini_tts_source.py +8 -6
- astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
- astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
- astrbot/core/provider/sources/openai_embedding_source.py +6 -8
- astrbot/core/provider/sources/openai_source.py +77 -69
- astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
- astrbot/core/provider/sources/volcengine_tts.py +38 -31
- astrbot/core/provider/sources/whisper_api_source.py +14 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
- astrbot/core/provider/sources/xinference_rerank_source.py +16 -8
- astrbot/core/provider/sources/xinference_stt_provider.py +35 -25
- astrbot/core/star/__init__.py +16 -11
- astrbot/core/star/config.py +10 -15
- astrbot/core/star/context.py +97 -75
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +30 -28
- astrbot/core/star/filter/command_group.py +27 -24
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +4 -2
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -19
- astrbot/core/star/register/star.py +6 -2
- astrbot/core/star/register/star_handler.py +96 -73
- astrbot/core/star/session_llm_manager.py +48 -14
- astrbot/core/star/session_plugin_manager.py +29 -15
- astrbot/core/star/star.py +1 -2
- astrbot/core/star/star_handler.py +13 -8
- astrbot/core/star/star_manager.py +151 -59
- astrbot/core/star/star_tools.py +44 -37
- astrbot/core/star/updator.py +10 -10
- astrbot/core/umop_config_router.py +10 -4
- astrbot/core/updator.py +13 -5
- astrbot/core/utils/astrbot_path.py +3 -5
- astrbot/core/utils/dify_api_client.py +33 -15
- astrbot/core/utils/io.py +66 -42
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +7 -7
- astrbot/core/utils/path_util.py +15 -16
- astrbot/core/utils/pip_installer.py +5 -5
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +45 -20
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/network_strategy.py +35 -26
- astrbot/core/utils/t2i/renderer.py +11 -5
- astrbot/core/utils/t2i/template_manager.py +14 -15
- astrbot/core/utils/tencent_record_helper.py +19 -13
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +43 -40
- astrbot/dashboard/routes/__init__.py +18 -18
- astrbot/dashboard/routes/auth.py +10 -8
- astrbot/dashboard/routes/chat.py +30 -21
- astrbot/dashboard/routes/config.py +92 -75
- astrbot/dashboard/routes/conversation.py +46 -39
- astrbot/dashboard/routes/file.py +4 -2
- astrbot/dashboard/routes/knowledge_base.py +47 -40
- astrbot/dashboard/routes/log.py +9 -4
- astrbot/dashboard/routes/persona.py +19 -16
- astrbot/dashboard/routes/plugin.py +69 -55
- astrbot/dashboard/routes/route.py +3 -1
- astrbot/dashboard/routes/session_management.py +130 -116
- astrbot/dashboard/routes/stat.py +34 -34
- astrbot/dashboard/routes/t2i.py +15 -12
- astrbot/dashboard/routes/tools.py +47 -52
- astrbot/dashboard/routes/update.py +32 -28
- astrbot/dashboard/server.py +30 -26
- astrbot/dashboard/utils.py +8 -4
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/METADATA +2 -1
- astrbot-4.5.2.dist-info/RECORD +261 -0
- astrbot-4.5.1.dist-info/RECORD +0 -260
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
astrbot/cli/utils/plugin.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
import shutil
|
|
2
2
|
import tempfile
|
|
3
|
-
|
|
4
|
-
import httpx
|
|
5
|
-
import yaml
|
|
6
3
|
from enum import Enum
|
|
7
4
|
from io import BytesIO
|
|
8
5
|
from pathlib import Path
|
|
9
6
|
from zipfile import ZipFile
|
|
10
7
|
|
|
11
8
|
import click
|
|
9
|
+
import httpx
|
|
10
|
+
import yaml
|
|
11
|
+
|
|
12
12
|
from .version_comparator import VersionComparator
|
|
13
13
|
|
|
14
14
|
|
|
@@ -32,7 +32,8 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
|
|
32
32
|
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
|
33
33
|
try:
|
|
34
34
|
with httpx.Client(
|
|
35
|
-
proxy=proxy if proxy else None,
|
|
35
|
+
proxy=proxy if proxy else None,
|
|
36
|
+
follow_redirects=True,
|
|
36
37
|
) as client:
|
|
37
38
|
resp = client.get(release_url)
|
|
38
39
|
resp.raise_for_status()
|
|
@@ -55,7 +56,8 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
|
|
55
56
|
|
|
56
57
|
# 下载并解压
|
|
57
58
|
with httpx.Client(
|
|
58
|
-
proxy=proxy if proxy else None,
|
|
59
|
+
proxy=proxy if proxy else None,
|
|
60
|
+
follow_redirects=True,
|
|
59
61
|
) as client:
|
|
60
62
|
resp = client.get(download_url)
|
|
61
63
|
if (
|
|
@@ -89,6 +91,7 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
|
|
|
89
91
|
|
|
90
92
|
Returns:
|
|
91
93
|
dict: 包含元数据的字典,如果读取失败则返回空字典
|
|
94
|
+
|
|
92
95
|
"""
|
|
93
96
|
yaml_path = plugin_dir / "metadata.yaml"
|
|
94
97
|
if yaml_path.exists():
|
|
@@ -107,6 +110,7 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|
|
107
110
|
|
|
108
111
|
Returns:
|
|
109
112
|
list: 包含插件信息的字典列表
|
|
113
|
+
|
|
110
114
|
"""
|
|
111
115
|
# 获取本地插件信息
|
|
112
116
|
result = []
|
|
@@ -133,7 +137,7 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|
|
133
137
|
"repo": str(metadata.get("repo", "")),
|
|
134
138
|
"status": PluginStatus.INSTALLED,
|
|
135
139
|
"local_path": str(plugin_dir),
|
|
136
|
-
}
|
|
140
|
+
},
|
|
137
141
|
)
|
|
138
142
|
|
|
139
143
|
# 获取在线插件列表
|
|
@@ -153,7 +157,7 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|
|
153
157
|
"repo": str(plugin_info.get("repo", "")),
|
|
154
158
|
"status": PluginStatus.NOT_INSTALLED,
|
|
155
159
|
"local_path": None,
|
|
156
|
-
}
|
|
160
|
+
},
|
|
157
161
|
)
|
|
158
162
|
except Exception as e:
|
|
159
163
|
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
|
@@ -168,7 +172,8 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|
|
168
172
|
)
|
|
169
173
|
if (
|
|
170
174
|
VersionComparator.compare_version(
|
|
171
|
-
local_plugin["version"],
|
|
175
|
+
local_plugin["version"],
|
|
176
|
+
online_plugin["version"],
|
|
172
177
|
)
|
|
173
178
|
< 0
|
|
174
179
|
):
|
|
@@ -186,7 +191,10 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|
|
186
191
|
|
|
187
192
|
|
|
188
193
|
def manage_plugin(
|
|
189
|
-
plugin: dict,
|
|
194
|
+
plugin: dict,
|
|
195
|
+
plugins_dir: Path,
|
|
196
|
+
is_update: bool = False,
|
|
197
|
+
proxy: str | None = None,
|
|
190
198
|
) -> None:
|
|
191
199
|
"""安装或更新插件
|
|
192
200
|
|
|
@@ -195,6 +203,7 @@ def manage_plugin(
|
|
|
195
203
|
plugins_dir (Path): 插件目录
|
|
196
204
|
is_update (bool, optional): 是否为更新操作. 默认为 False
|
|
197
205
|
proxy (str, optional): 代理服务器地址
|
|
206
|
+
|
|
198
207
|
"""
|
|
199
208
|
plugin_name = plugin["name"]
|
|
200
209
|
repo_url = plugin["repo"]
|
|
@@ -212,26 +221,26 @@ def manage_plugin(
|
|
|
212
221
|
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
|
|
213
222
|
|
|
214
223
|
# 备份现有插件
|
|
215
|
-
if is_update and backup_path.exists():
|
|
224
|
+
if is_update and backup_path is not None and backup_path.exists():
|
|
216
225
|
shutil.rmtree(backup_path)
|
|
217
|
-
if is_update:
|
|
226
|
+
if is_update and backup_path is not None:
|
|
218
227
|
shutil.copytree(target_path, backup_path)
|
|
219
228
|
|
|
220
229
|
try:
|
|
221
230
|
click.echo(
|
|
222
|
-
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
|
|
231
|
+
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...",
|
|
223
232
|
)
|
|
224
233
|
get_git_repo(repo_url, target_path, proxy)
|
|
225
234
|
|
|
226
235
|
# 更新成功,删除备份
|
|
227
|
-
if is_update and backup_path.exists():
|
|
236
|
+
if is_update and backup_path is not None and backup_path.exists():
|
|
228
237
|
shutil.rmtree(backup_path)
|
|
229
238
|
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
|
|
230
239
|
except Exception as e:
|
|
231
240
|
if target_path.exists():
|
|
232
241
|
shutil.rmtree(target_path, ignore_errors=True)
|
|
233
|
-
if is_update and backup_path.exists():
|
|
242
|
+
if is_update and backup_path is not None and backup_path.exists():
|
|
234
243
|
shutil.move(backup_path, target_path)
|
|
235
244
|
raise click.ClickException(
|
|
236
|
-
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
|
|
245
|
+
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
|
|
237
246
|
)
|
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
拷贝自 astrbot.core.utils.version_comparator
|
|
3
|
-
"""
|
|
1
|
+
"""拷贝自 astrbot.core.utils.version_comparator"""
|
|
4
2
|
|
|
5
3
|
import re
|
|
6
4
|
|
|
@@ -42,15 +40,15 @@ class VersionComparator:
|
|
|
42
40
|
for i in range(length):
|
|
43
41
|
if v1_parts[i] > v2_parts[i]:
|
|
44
42
|
return 1
|
|
45
|
-
|
|
43
|
+
if v1_parts[i] < v2_parts[i]:
|
|
46
44
|
return -1
|
|
47
45
|
|
|
48
46
|
# 比较预发布标签
|
|
49
47
|
if v1_prerelease is None and v2_prerelease is not None:
|
|
50
48
|
return 1 # 没有预发布标签的版本高于有预发布标签的版本
|
|
51
|
-
|
|
49
|
+
if v1_prerelease is not None and v2_prerelease is None:
|
|
52
50
|
return -1 # 有预发布标签的版本低于没有预发布标签的版本
|
|
53
|
-
|
|
51
|
+
if v1_prerelease is not None and v2_prerelease is not None:
|
|
54
52
|
len_pre = max(len(v1_prerelease), len(v2_prerelease))
|
|
55
53
|
for i in range(len_pre):
|
|
56
54
|
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
|
|
@@ -58,21 +56,21 @@ class VersionComparator:
|
|
|
58
56
|
|
|
59
57
|
if p1 is None and p2 is not None:
|
|
60
58
|
return -1
|
|
61
|
-
|
|
59
|
+
if p1 is not None and p2 is None:
|
|
62
60
|
return 1
|
|
63
|
-
|
|
61
|
+
if isinstance(p1, int) and isinstance(p2, str):
|
|
64
62
|
return -1
|
|
65
|
-
|
|
63
|
+
if isinstance(p1, str) and isinstance(p2, int):
|
|
66
64
|
return 1
|
|
67
|
-
|
|
65
|
+
if isinstance(p1, int) and isinstance(p2, int):
|
|
68
66
|
if p1 > p2:
|
|
69
67
|
return 1
|
|
70
|
-
|
|
68
|
+
if p1 < p2:
|
|
71
69
|
return -1
|
|
72
70
|
elif isinstance(p1, str) and isinstance(p2, str):
|
|
73
71
|
if p1 > p2:
|
|
74
72
|
return 1
|
|
75
|
-
|
|
73
|
+
if p1 < p2:
|
|
76
74
|
return -1
|
|
77
75
|
return 0 # 预发布标签完全相同
|
|
78
76
|
|
astrbot/core/__init__.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
import os
|
|
2
|
-
|
|
3
|
-
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
|
4
|
-
from astrbot.core.utils.shared_preferences import SharedPreferences
|
|
5
|
-
from astrbot.core.utils.pip_installer import PipInstaller
|
|
6
|
-
from astrbot.core.db.sqlite import SQLiteDatabase
|
|
7
|
-
from astrbot.core.config.default import DB_PATH
|
|
2
|
+
|
|
8
3
|
from astrbot.core.config import AstrBotConfig
|
|
4
|
+
from astrbot.core.config.default import DB_PATH
|
|
5
|
+
from astrbot.core.db.sqlite import SQLiteDatabase
|
|
9
6
|
from astrbot.core.file_token_service import FileTokenService
|
|
7
|
+
from astrbot.core.utils.pip_installer import PipInstaller
|
|
8
|
+
from astrbot.core.utils.shared_preferences import SharedPreferences
|
|
9
|
+
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
|
10
|
+
|
|
11
|
+
from .log import LogBroker, LogManager # noqa
|
|
10
12
|
from .utils.astrbot_path import get_astrbot_data_path
|
|
11
13
|
|
|
12
14
|
# 初始化数据存储文件夹
|
astrbot/core/agent/agent.py
CHANGED
astrbot/core/agent/handoff.py
CHANGED
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
from typing import Generic
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
from .agent import Agent
|
|
4
4
|
from .run_context import TContext
|
|
5
|
+
from .tool import FunctionTool
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class HandoffTool(FunctionTool, Generic[TContext]):
|
|
8
9
|
"""Handoff tool for delegating tasks to another agent."""
|
|
9
10
|
|
|
10
11
|
def __init__(
|
|
11
|
-
self,
|
|
12
|
+
self,
|
|
13
|
+
agent: Agent[TContext],
|
|
14
|
+
parameters: dict | None = None,
|
|
15
|
+
**kwargs,
|
|
12
16
|
):
|
|
13
17
|
self.agent = agent
|
|
14
18
|
super().__init__(
|
astrbot/core/agent/hooks.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
|
-
import mcp
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
from .run_context import ContextWrapper, TContext
|
|
4
1
|
from typing import Generic
|
|
5
|
-
|
|
2
|
+
|
|
3
|
+
import mcp
|
|
4
|
+
|
|
6
5
|
from astrbot.core.agent.tool import FunctionTool
|
|
6
|
+
from astrbot.core.provider.entities import LLMResponse
|
|
7
|
+
|
|
8
|
+
from .run_context import ContextWrapper, TContext
|
|
7
9
|
|
|
8
10
|
|
|
9
|
-
@dataclass
|
|
10
11
|
class BaseAgentRunHooks(Generic[TContext]):
|
|
11
12
|
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
|
|
12
13
|
async def on_tool_start(
|
|
@@ -23,5 +24,7 @@ class BaseAgentRunHooks(Generic[TContext]):
|
|
|
23
24
|
tool_result: mcp.types.CallToolResult | None,
|
|
24
25
|
): ...
|
|
25
26
|
async def on_agent_done(
|
|
26
|
-
self,
|
|
27
|
+
self,
|
|
28
|
+
run_context: ContextWrapper[TContext],
|
|
29
|
+
llm_response: LLMResponse,
|
|
27
30
|
): ...
|
astrbot/core/agent/mcp_client.py
CHANGED
|
@@ -1,11 +1,16 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
-
from datetime import timedelta
|
|
4
|
-
from typing import Optional
|
|
5
3
|
from contextlib import AsyncExitStack
|
|
4
|
+
from datetime import timedelta
|
|
5
|
+
from typing import Generic
|
|
6
|
+
|
|
6
7
|
from astrbot import logger
|
|
8
|
+
from astrbot.core.agent.run_context import ContextWrapper
|
|
7
9
|
from astrbot.core.utils.log_pipe import LogPipe
|
|
8
10
|
|
|
11
|
+
from .run_context import TContext
|
|
12
|
+
from .tool import FunctionTool
|
|
13
|
+
|
|
9
14
|
try:
|
|
10
15
|
import mcp
|
|
11
16
|
from mcp.client.sse import sse_client
|
|
@@ -16,13 +21,13 @@ try:
|
|
|
16
21
|
from mcp.client.streamable_http import streamablehttp_client
|
|
17
22
|
except (ModuleNotFoundError, ImportError):
|
|
18
23
|
logger.warning(
|
|
19
|
-
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
|
24
|
+
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
|
|
20
25
|
)
|
|
21
26
|
|
|
22
27
|
|
|
23
28
|
def _prepare_config(config: dict) -> dict:
|
|
24
29
|
"""准备配置,处理嵌套格式"""
|
|
25
|
-
if
|
|
30
|
+
if config.get("mcpServers"):
|
|
26
31
|
first_key = next(iter(config["mcpServers"]))
|
|
27
32
|
config = config["mcpServers"][first_key]
|
|
28
33
|
config.pop("active", None)
|
|
@@ -71,8 +76,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
71
76
|
) as response:
|
|
72
77
|
if response.status == 200:
|
|
73
78
|
return True, ""
|
|
74
|
-
|
|
75
|
-
return False, f"HTTP {response.status}: {response.reason}"
|
|
79
|
+
return False, f"HTTP {response.status}: {response.reason}"
|
|
76
80
|
else:
|
|
77
81
|
async with session.get(
|
|
78
82
|
url,
|
|
@@ -84,8 +88,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
84
88
|
) as response:
|
|
85
89
|
if response.status == 200:
|
|
86
90
|
return True, ""
|
|
87
|
-
|
|
88
|
-
return False, f"HTTP {response.status}: {response.reason}"
|
|
91
|
+
return False, f"HTTP {response.status}: {response.reason}"
|
|
89
92
|
|
|
90
93
|
except asyncio.TimeoutError:
|
|
91
94
|
return False, f"连接超时: {timeout}秒"
|
|
@@ -96,7 +99,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
96
99
|
class MCPClient:
|
|
97
100
|
def __init__(self):
|
|
98
101
|
# Initialize session and client objects
|
|
99
|
-
self.session:
|
|
102
|
+
self.session: mcp.ClientSession | None = None
|
|
100
103
|
self.exit_stack = AsyncExitStack()
|
|
101
104
|
|
|
102
105
|
self.name: str | None = None
|
|
@@ -115,6 +118,7 @@ class MCPClient:
|
|
|
115
118
|
|
|
116
119
|
Args:
|
|
117
120
|
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
|
121
|
+
|
|
118
122
|
"""
|
|
119
123
|
cfg = _prepare_config(mcp_server_config.copy())
|
|
120
124
|
|
|
@@ -144,7 +148,7 @@ class MCPClient:
|
|
|
144
148
|
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
|
145
149
|
)
|
|
146
150
|
streams = await self.exit_stack.enter_async_context(
|
|
147
|
-
self._streams_context
|
|
151
|
+
self._streams_context,
|
|
148
152
|
)
|
|
149
153
|
|
|
150
154
|
# Create a new client session
|
|
@@ -154,12 +158,12 @@ class MCPClient:
|
|
|
154
158
|
*streams,
|
|
155
159
|
read_timeout_seconds=read_timeout,
|
|
156
160
|
logging_callback=logging_callback, # type: ignore
|
|
157
|
-
)
|
|
161
|
+
),
|
|
158
162
|
)
|
|
159
163
|
else:
|
|
160
164
|
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
|
161
165
|
sse_read_timeout = timedelta(
|
|
162
|
-
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
|
166
|
+
seconds=cfg.get("sse_read_timeout", 60 * 5),
|
|
163
167
|
)
|
|
164
168
|
self._streams_context = streamablehttp_client(
|
|
165
169
|
url=cfg["url"],
|
|
@@ -169,7 +173,7 @@ class MCPClient:
|
|
|
169
173
|
terminate_on_close=cfg.get("terminate_on_close", True),
|
|
170
174
|
)
|
|
171
175
|
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
|
172
|
-
self._streams_context
|
|
176
|
+
self._streams_context,
|
|
173
177
|
)
|
|
174
178
|
|
|
175
179
|
# Create a new client session
|
|
@@ -180,7 +184,7 @@ class MCPClient:
|
|
|
180
184
|
write_stream=write_s,
|
|
181
185
|
read_timeout_seconds=read_timeout,
|
|
182
186
|
logging_callback=logging_callback, # type: ignore
|
|
183
|
-
)
|
|
187
|
+
),
|
|
184
188
|
)
|
|
185
189
|
|
|
186
190
|
else:
|
|
@@ -206,7 +210,7 @@ class MCPClient:
|
|
|
206
210
|
|
|
207
211
|
# Create a new client session
|
|
208
212
|
self.session = await self.exit_stack.enter_async_context(
|
|
209
|
-
mcp.ClientSession(*stdio_transport)
|
|
213
|
+
mcp.ClientSession(*stdio_transport),
|
|
210
214
|
)
|
|
211
215
|
await self.session.initialize()
|
|
212
216
|
|
|
@@ -222,3 +226,34 @@ class MCPClient:
|
|
|
222
226
|
"""Clean up resources"""
|
|
223
227
|
await self.exit_stack.aclose()
|
|
224
228
|
self.running_event.set() # Set the running event to indicate cleanup is done
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class MCPTool(FunctionTool, Generic[TContext]):
|
|
232
|
+
"""A function tool that calls an MCP service."""
|
|
233
|
+
|
|
234
|
+
def __init__(
|
|
235
|
+
self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
|
|
236
|
+
):
|
|
237
|
+
super().__init__(
|
|
238
|
+
name=mcp_tool.name,
|
|
239
|
+
description=mcp_tool.description or "",
|
|
240
|
+
parameters=mcp_tool.inputSchema,
|
|
241
|
+
)
|
|
242
|
+
self.mcp_tool = mcp_tool
|
|
243
|
+
self.mcp_client = mcp_client
|
|
244
|
+
self.mcp_server_name = mcp_server_name
|
|
245
|
+
|
|
246
|
+
async def call(
|
|
247
|
+
self, context: ContextWrapper[TContext], **kwargs
|
|
248
|
+
) -> mcp.types.CallToolResult:
|
|
249
|
+
session = self.mcp_client.session
|
|
250
|
+
if not session:
|
|
251
|
+
raise ValueError("MCP session is not available for MCP function tools.")
|
|
252
|
+
res = await session.call_tool(
|
|
253
|
+
name=self.mcp_tool.name,
|
|
254
|
+
arguments=kwargs,
|
|
255
|
+
read_timeout_seconds=timedelta(
|
|
256
|
+
seconds=context.tool_call_timeout,
|
|
257
|
+
),
|
|
258
|
+
)
|
|
259
|
+
return res
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
|
|
2
|
+
# License: Apache License 2.0
|
|
3
|
+
|
|
4
|
+
from typing import Any, ClassVar, Literal, cast
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, GetCoreSchemaHandler
|
|
7
|
+
from pydantic_core import core_schema
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ContentPart(BaseModel):
|
|
11
|
+
"""A part of the content in a message."""
|
|
12
|
+
|
|
13
|
+
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
|
|
14
|
+
|
|
15
|
+
type: str
|
|
16
|
+
|
|
17
|
+
def __init_subclass__(cls, **kwargs: Any) -> None:
|
|
18
|
+
super().__init_subclass__(**kwargs)
|
|
19
|
+
|
|
20
|
+
invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
|
|
21
|
+
|
|
22
|
+
type_value = getattr(cls, "type", None)
|
|
23
|
+
if type_value is None or not isinstance(type_value, str):
|
|
24
|
+
raise ValueError(invalid_subclass_error_msg)
|
|
25
|
+
|
|
26
|
+
cls.__content_part_registry[type_value] = cls
|
|
27
|
+
|
|
28
|
+
@classmethod
|
|
29
|
+
def __get_pydantic_core_schema__(
|
|
30
|
+
cls, source_type: Any, handler: GetCoreSchemaHandler
|
|
31
|
+
) -> core_schema.CoreSchema:
|
|
32
|
+
# If we're dealing with the base ContentPart class, use custom validation
|
|
33
|
+
if cls.__name__ == "ContentPart":
|
|
34
|
+
|
|
35
|
+
def validate_content_part(value: Any) -> Any:
|
|
36
|
+
# if it's already an instance of a ContentPart subclass, return it
|
|
37
|
+
if hasattr(value, "__class__") and issubclass(value.__class__, cls):
|
|
38
|
+
return value
|
|
39
|
+
|
|
40
|
+
# if it's a dict with a type field, dispatch to the appropriate subclass
|
|
41
|
+
if isinstance(value, dict) and "type" in value:
|
|
42
|
+
type_value: Any | None = cast(dict[str, Any], value).get("type")
|
|
43
|
+
if not isinstance(type_value, str):
|
|
44
|
+
raise ValueError(f"Cannot validate {value} as ContentPart")
|
|
45
|
+
target_class = cls.__content_part_registry[type_value]
|
|
46
|
+
return target_class.model_validate(value)
|
|
47
|
+
|
|
48
|
+
raise ValueError(f"Cannot validate {value} as ContentPart")
|
|
49
|
+
|
|
50
|
+
return core_schema.no_info_plain_validator_function(validate_content_part)
|
|
51
|
+
|
|
52
|
+
# for subclasses, use the default schema
|
|
53
|
+
return handler(source_type)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class TextPart(ContentPart):
|
|
57
|
+
"""
|
|
58
|
+
>>> TextPart(text="Hello, world!").model_dump()
|
|
59
|
+
{'type': 'text', 'text': 'Hello, world!'}
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
type: str = "text"
|
|
63
|
+
text: str
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ImageURLPart(ContentPart):
|
|
67
|
+
"""
|
|
68
|
+
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
|
|
69
|
+
{'type': 'image_url', 'image_url': 'http://example.com/image.jpg'}
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
class ImageURL(BaseModel):
|
|
73
|
+
url: str
|
|
74
|
+
"""The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
|
|
75
|
+
id: str | None = None
|
|
76
|
+
"""The ID of the image, to allow LLMs to distinguish different images."""
|
|
77
|
+
|
|
78
|
+
type: str = "image_url"
|
|
79
|
+
image_url: str
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class AudioURLPart(ContentPart):
|
|
83
|
+
"""
|
|
84
|
+
>>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
|
|
85
|
+
{'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
class AudioURL(BaseModel):
|
|
89
|
+
url: str
|
|
90
|
+
"""The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
|
|
91
|
+
id: str | None = None
|
|
92
|
+
"""The ID of the audio, to allow LLMs to distinguish different audios."""
|
|
93
|
+
|
|
94
|
+
type: str = "audio_url"
|
|
95
|
+
audio_url: AudioURL
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ToolCall(BaseModel):
|
|
99
|
+
"""
|
|
100
|
+
A tool call requested by the assistant.
|
|
101
|
+
|
|
102
|
+
>>> ToolCall(
|
|
103
|
+
... id="123",
|
|
104
|
+
... function=ToolCall.FunctionBody(
|
|
105
|
+
... name="function",
|
|
106
|
+
... arguments="{}"
|
|
107
|
+
... ),
|
|
108
|
+
... ).model_dump()
|
|
109
|
+
{'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
|
|
110
|
+
"""
|
|
111
|
+
|
|
112
|
+
class FunctionBody(BaseModel):
|
|
113
|
+
name: str
|
|
114
|
+
arguments: str | None
|
|
115
|
+
|
|
116
|
+
type: Literal["function"] = "function"
|
|
117
|
+
|
|
118
|
+
id: str
|
|
119
|
+
"""The ID of the tool call."""
|
|
120
|
+
function: FunctionBody
|
|
121
|
+
"""The function body of the tool call."""
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class ToolCallPart(BaseModel):
|
|
125
|
+
"""A part of the tool call."""
|
|
126
|
+
|
|
127
|
+
arguments_part: str | None = None
|
|
128
|
+
"""A part of the arguments of the tool call."""
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class Message(BaseModel):
|
|
132
|
+
"""A message in a conversation."""
|
|
133
|
+
|
|
134
|
+
role: Literal[
|
|
135
|
+
"system",
|
|
136
|
+
"user",
|
|
137
|
+
"assistant",
|
|
138
|
+
"tool",
|
|
139
|
+
]
|
|
140
|
+
|
|
141
|
+
content: str | list[ContentPart]
|
|
142
|
+
"""The content of the message."""
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class AssistantMessageSegment(Message):
|
|
146
|
+
"""A message segment from the assistant."""
|
|
147
|
+
|
|
148
|
+
role: Literal["assistant"] = "assistant"
|
|
149
|
+
tool_calls: list[ToolCall] | list[dict] | None = None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class ToolCallMessageSegment(Message):
|
|
153
|
+
"""A message segment representing a tool call."""
|
|
154
|
+
|
|
155
|
+
role: Literal["tool"] = "tool"
|
|
156
|
+
tool_call_id: str
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class UserMessageSegment(Message):
|
|
160
|
+
"""A message segment from the user."""
|
|
161
|
+
|
|
162
|
+
role: Literal["user"] = "user"
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class SystemMessageSegment(Message):
|
|
166
|
+
"""A message segment from the system."""
|
|
167
|
+
|
|
168
|
+
role: Literal["system"] = "system"
|
astrbot/core/agent/response.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from typing import Any, Generic
|
|
3
|
-
from typing_extensions import TypeVar
|
|
4
3
|
|
|
5
|
-
from
|
|
4
|
+
from typing_extensions import TypeVar
|
|
6
5
|
|
|
7
6
|
TContext = TypeVar("TContext", default=Any)
|
|
8
7
|
|
|
@@ -12,7 +11,7 @@ class ContextWrapper(Generic[TContext]):
|
|
|
12
11
|
"""A context for running an agent, which can be used to pass additional data or state."""
|
|
13
12
|
|
|
14
13
|
context: TContext
|
|
15
|
-
|
|
14
|
+
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
|
16
15
|
|
|
17
16
|
|
|
18
17
|
NoContext = ContextWrapper[None]
|
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
import abc
|
|
2
2
|
import typing as T
|
|
3
3
|
from enum import Enum, auto
|
|
4
|
-
|
|
5
|
-
from ..response import AgentResponse
|
|
6
|
-
from ..hooks import BaseAgentRunHooks
|
|
7
|
-
from ..tool_executor import BaseFunctionToolExecutor
|
|
4
|
+
|
|
8
5
|
from astrbot.core.provider import Provider
|
|
9
6
|
from astrbot.core.provider.entities import LLMResponse
|
|
10
7
|
|
|
8
|
+
from ..hooks import BaseAgentRunHooks
|
|
9
|
+
from ..response import AgentResponse
|
|
10
|
+
from ..run_context import ContextWrapper, TContext
|
|
11
|
+
from ..tool_executor import BaseFunctionToolExecutor
|
|
12
|
+
|
|
11
13
|
|
|
12
14
|
class AgentState(Enum):
|
|
13
15
|
"""Defines the state of the agent."""
|
|
@@ -28,31 +30,26 @@ class BaseAgentRunner(T.Generic[TContext]):
|
|
|
28
30
|
agent_hooks: BaseAgentRunHooks[TContext],
|
|
29
31
|
**kwargs: T.Any,
|
|
30
32
|
) -> None:
|
|
31
|
-
"""
|
|
32
|
-
Reset the agent to its initial state.
|
|
33
|
+
"""Reset the agent to its initial state.
|
|
33
34
|
This method should be called before starting a new run.
|
|
34
35
|
"""
|
|
35
36
|
...
|
|
36
37
|
|
|
37
38
|
@abc.abstractmethod
|
|
38
39
|
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
|
|
39
|
-
"""
|
|
40
|
-
Process a single step of the agent.
|
|
41
|
-
"""
|
|
40
|
+
"""Process a single step of the agent."""
|
|
42
41
|
...
|
|
43
42
|
|
|
44
43
|
@abc.abstractmethod
|
|
45
44
|
def done(self) -> bool:
|
|
46
|
-
"""
|
|
47
|
-
Check if the agent has completed its task.
|
|
45
|
+
"""Check if the agent has completed its task.
|
|
48
46
|
Returns True if the agent is done, False otherwise.
|
|
49
47
|
"""
|
|
50
48
|
...
|
|
51
49
|
|
|
52
50
|
@abc.abstractmethod
|
|
53
51
|
def get_final_llm_resp(self) -> LLMResponse | None:
|
|
54
|
-
"""
|
|
55
|
-
Get the final observation from the agent.
|
|
52
|
+
"""Get the final observation from the agent.
|
|
56
53
|
This method should be called after the agent is done.
|
|
57
54
|
"""
|
|
58
55
|
...
|