agent-runtime-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime/__init__.py +84 -0
- agent_runtime/builder.py +317 -0
- agent_runtime/config/__init__.py +29 -0
- agent_runtime/config/definitions.py +144 -0
- agent_runtime/config/policies.py +63 -0
- agent_runtime/config/storage.py +117 -0
- agent_runtime/context.py +10 -0
- agent_runtime/definitions.py +33 -0
- agent_runtime/discovery.py +16 -0
- agent_runtime/exceptions.py +74 -0
- agent_runtime/mcp/__init__.py +28 -0
- agent_runtime/mcp/discovery.py +146 -0
- agent_runtime/mcp/metadata.py +68 -0
- agent_runtime/mcp/utils.py +52 -0
- agent_runtime/model_registry.py +40 -0
- agent_runtime/plugins/__init__.py +4 -0
- agent_runtime/plugins/base.py +90 -0
- agent_runtime/plugins/default.py +19 -0
- agent_runtime/plugins/instructions.py +38 -0
- agent_runtime/plugins/loader.py +59 -0
- agent_runtime/policies.py +15 -0
- agent_runtime/runtime.py +110 -0
- agent_runtime/runtime_engine/__init__.py +22 -0
- agent_runtime/runtime_engine/a2a_bridge.py +190 -0
- agent_runtime/runtime_engine/a2a_task_io.py +165 -0
- agent_runtime/runtime_engine/agent_build.py +315 -0
- agent_runtime/runtime_engine/context.py +469 -0
- agent_runtime/runtime_engine/loading.py +170 -0
- agent_runtime/runtime_engine/observability.py +154 -0
- agent_runtime/runtime_engine/policy_registry.py +98 -0
- agent_runtime/runtime_engine/protocol_tools.py +94 -0
- agent_runtime/runtime_engine/task_flow.py +897 -0
- agent_runtime/runtime_engine/tool_flow.py +332 -0
- agent_runtime/sdk_agent.py +548 -0
- agent_runtime/server/__init__.py +15 -0
- agent_runtime/server/app_factory.py +37 -0
- agent_runtime/server/bootstrap.py +48 -0
- agent_runtime/server/endpoint_utils.py +37 -0
- agent_runtime/server/management.py +107 -0
- agent_runtime/smol/__init__.py +4 -0
- agent_runtime/smol/agents.py +431 -0
- agent_runtime/smol/llm_models.py +212 -0
- agent_runtime/smol/memory.py +111 -0
- agent_runtime/smol/models.py +69 -0
- agent_runtime/standalone.py +57 -0
- agent_runtime/storage.py +5 -0
- agent_runtime/tools.py +5 -0
- agent_runtime_sdk-0.1.0.dist-info/METADATA +125 -0
- agent_runtime_sdk-0.1.0.dist-info/RECORD +51 -0
- agent_runtime_sdk-0.1.0.dist-info/WHEEL +5 -0
- agent_runtime_sdk-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
"""工具调用前的运行时治理逻辑。
|
|
2
|
+
|
|
3
|
+
注意:这里不真正执行工具,只在“工具即将执行”前做几件事:
|
|
4
|
+
|
|
5
|
+
- 归一化参数
|
|
6
|
+
- 检查是否缺少必填输入
|
|
7
|
+
- 必要时向用户追问
|
|
8
|
+
- 必要时向用户请求授权确认
|
|
9
|
+
- 调用 plugin 的 before_tool / after_tool 钩子
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import asyncio
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Any, Callable
|
|
18
|
+
|
|
19
|
+
from .context import (
|
|
20
|
+
AuthRequiredWaitState,
|
|
21
|
+
InputRequiredWaitState,
|
|
22
|
+
TaskContext,
|
|
23
|
+
TaskPhase,
|
|
24
|
+
WaitState,
|
|
25
|
+
current_task_id,
|
|
26
|
+
)
|
|
27
|
+
from ..config.definitions import ToolPolicy
|
|
28
|
+
from ..exceptions import TaskCancelledError, TaskWaitTimeoutError, UserCancelledError
|
|
29
|
+
from ..plugins import ToolExecutionContext
|
|
30
|
+
from ..config.policies import (
|
|
31
|
+
collect_missing_fields,
|
|
32
|
+
merge_input_fields,
|
|
33
|
+
parse_user_payload,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Module-level logger; every call in this module logs at DEBUG level.
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RuntimeToolFlow:
    """Governance hooks around tool execution, plus wait/interrupt handling.

    Tools are never executed here. The callbacks built by this class run
    *before* a tool call (argument normalization, required-input collection,
    user confirmation), and ``_after_tool_hook`` runs once a tool has produced
    output.

    The host class is expected to provide these attributes, none of which are
    defined here:
    ``task_pool`` (``.get(task_id)`` -> ``TaskContext`` or ``None``),
    ``plugin`` (``None`` or an object with ``normalize_args`` / ``before_tool``
    / ``after_tool``), ``definition`` (passed to plugins as
    ``agent_definition``), ``_tool_source_by_name`` (tool name -> MCP name)
    and ``_policy_registry`` (provides ``merged_policy``).
    """

    def _build_tool_callbacks(
        self, tools: list[Any], extra_tools: list[Any]
    ) -> dict[str, list[Callable]]:
        """Map tool names to the pre-execution callbacks registered for them.

        ``ask_user`` and ``ask_auth`` always get the explicit interrupt
        handlers. Every other tool from ``tools`` and ``extra_tools`` gets a
        governance callback from ``_build_tool_callback``; ``final_answer`` is
        the exception and is never governed. Tools without a ``name``
        attribute, or with an empty name, are skipped.
        """
        tool_callbacks: dict[str, list[Callable]] = {
            "ask_user": [self._interrupt_before_ask_user],
            "ask_auth": [self._interrupt_before_ask_auth],
        }
        for tool in [*tools, *extra_tools]:
            tool_name = getattr(tool, "name", "")
            if tool_name and tool_name not in {"ask_user", "ask_auth", "final_answer"}:
                tool_callbacks[tool_name] = [self._build_tool_callback(tool_name)]
        return tool_callbacks

    def _extract_tool_prompt(self, chat_message: Any, tool_name: str) -> str | None:
        """Return the prompt text of the first ``tool_name`` call in ``chat_message``.

        The arguments may be a dict or a string. A string holding a JSON
        object is decoded first, so the user sees the ``task`` / ``prompt`` /
        ``question`` field rather than raw JSON. Any other non-``None``
        argument value is returned as ``str(args)``.

        Returns ``None`` in these cases: there is no matching call; the
        arguments are ``None``; or the dict has none of the recognized keys
        (or only falsy values for them).
        """
        tool_call = self._find_tool_call(chat_message, tool_name)
        if tool_call is None:
            return None
        args = tool_call.function.arguments
        if isinstance(args, str):
            # Models often send tool arguments as a JSON-encoded string.
            try:
                decoded = json.loads(args)
            except (ValueError, RecursionError):
                decoded = None
            if isinstance(decoded, dict):
                args = decoded
        if isinstance(args, dict):
            return args.get("task") or args.get("prompt") or args.get("question")
        return str(args) if args is not None else None

    def _wait_for_event(self, task_id: str) -> None:
        """Block the calling thread until the task's asyncio event is set.

        Returns immediately if the task is unknown or has no ``event`` /
        ``loop``. ``event.wait()`` is scheduled on ``task_info.loop`` with
        ``run_coroutine_threadsafe`` and the future is waited on without a
        timeout. This must therefore run on a thread other than that loop's
        own thread; blocking on ``.result()`` there would deadlock.
        """
        task_info = self.task_pool.get(task_id)
        if not task_info:
            return
        event = task_info.event
        loop = task_info.loop
        if event and loop:
            future = asyncio.run_coroutine_threadsafe(event.wait(), loop)
            future.result()

    def _current_task(self) -> tuple[str | None, TaskContext | None]:
        """Return ``(task_id, task_info)`` for the task bound to the current context.

        Returns ``(None, None)`` when ``current_task_id`` has no value in this
        context. When the id is set but not present in ``task_pool``, the
        result is ``(task_id, None)``.
        """
        try:
            task_id = current_task_id.get()
        except LookupError:
            return None, None
        return task_id, self.task_pool.get(task_id)

    @staticmethod
    def _coerce_tool_arguments(raw_args: Any) -> dict[str, Any]:
        """Normalize raw tool-call arguments into a fresh dict.

        - A dict is shallow-copied.
        - A string is JSON-decoded. A JSON object is returned as-is; any other
          JSON value, or undecodable text, is wrapped as ``{"value": ...}``.
        - Every other type yields ``{}``.
        """
        if isinstance(raw_args, dict):
            return dict(raw_args)
        if isinstance(raw_args, str):
            try:
                parsed = json.loads(raw_args)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError; very deep nesting
                # raises RecursionError.
                return {"value": raw_args}
            return parsed if isinstance(parsed, dict) else {"value": parsed}
        return {}

    @staticmethod
    def _set_wait_item(task_info: TaskContext, wait_item: WaitState) -> None:
        """Record ``wait_item`` on the task with the matching waiting phase.

        ``InputRequiredWaitState`` maps to ``WAITING_INPUT``; any other wait
        state maps to ``WAITING_AUTH``.
        """
        if isinstance(wait_item, InputRequiredWaitState):
            phase = TaskPhase.WAITING_INPUT
        else:
            phase = TaskPhase.WAITING_AUTH
        task_info.set_wait(wait_item, phase)

    @staticmethod
    def _raise_if_wait_timed_out(task_info: TaskContext) -> None:
        """Raise ``TaskWaitTimeoutError`` if the task's wait was marked as timed out."""
        if task_info.timed_out:
            raise TaskWaitTimeoutError(task_info.timeout_reason or "任务等待超时")

    @staticmethod
    def _raise_if_task_stopped(task_info: TaskContext) -> None:
        """Raise if a stop was requested for the task.

        The task's own ``control.stop_error`` is raised when set; otherwise a
        ``TaskCancelledError`` is raised.
        """
        if task_info.control.stop_requested:
            raise task_info.control.stop_error or TaskCancelledError(
                "task cancelled while waiting"
            )

    @staticmethod
    def _raise_on_auth_denial(task_info: TaskContext) -> None:
        """Raise ``UserCancelledError`` if the user denied the pending auth request."""
        if task_info.auth_denied:
            raise UserCancelledError("user denied auth request")

    def _block_until_resumed(
        self, task_id: str, task_info: TaskContext, wait_item: WaitState
    ) -> None:
        """Publish ``wait_item``, block until the task is resumed, then re-check state.

        Raises:
            The task's stop error or ``TaskCancelledError``: a stop was
                requested during the wait.
            TaskWaitTimeoutError: the wait timed out.
        """
        self._set_wait_item(task_info, wait_item)
        self._wait_for_event(task_id)
        self._raise_if_task_stopped(task_info)
        self._raise_if_wait_timed_out(task_info)

    def _collect_required_input(
        self,
        *,
        task_id: str,
        task_info: TaskContext,
        tool_name: str,
        args: dict[str, Any],
        policy: ToolPolicy,
    ) -> dict[str, Any]:
        """Prompt the user until every field required by ``policy`` is present in ``args``.

        Each round does the following:
        1. Publish an ``InputRequiredWaitState``.
        2. Block until the task is resumed.
        3. Merge the parsed ``task_info.user_input`` into ``args``.
        4. Clear ``user_input``.

        Returns the merged arguments.

        NOTE(review): if ``_wait_for_event`` returns without blocking (unknown
        task, or no event/loop) and the input stays incomplete, this loop
        spins. Confirm that this cannot happen for governed tasks.
        """
        # The prompt does not change between rounds.
        prompt = policy.prompt or f"请补充工具 {tool_name} 所需信息"
        missing_fields = collect_missing_fields(policy, args)
        while missing_fields:
            logger.debug(
                "Task waiting for user input task_id=%s tool=%s missing_fields=%s",
                task_id,
                tool_name,
                [field.name for field in missing_fields],
            )
            self._block_until_resumed(
                task_id, task_info, InputRequiredWaitState(prompt=prompt)
            )
            resume_payload = parse_user_payload(task_info.user_input)
            # Consume the reply so a later wait does not re-read stale input.
            task_info.user_input = None
            args = merge_input_fields(args, policy, resume_payload)
            missing_fields = collect_missing_fields(policy, args)
        return args

    def _request_tool_confirmation(
        self,
        *,
        task_id: str,
        task_info: TaskContext,
        tool_name: str,
        args: dict[str, Any],
        policy: ToolPolicy,
    ) -> Any:
        """Ask the user to authorize the tool call when ``policy`` requires it.

        Returns ``args`` unchanged when no confirmation is needed.

        Otherwise it blocks on an ``AuthRequiredWaitState`` until resumed. If
        the user supplied a one-shot argument override, the override is
        returned instead of ``args``, but only when both of these hold:
        ``policy.allow_arg_override`` is set, and the override names no tool
        or this tool.

        Raises:
            UserCancelledError: the user denied the request.
            TaskWaitTimeoutError / stop errors: see ``_block_until_resumed``.
        """
        if not policy.requires_confirmation:
            return args

        # A wait is already outstanding on this task; do not stack another
        # confirmation prompt on top of it.
        # NOTE(review): this assumes the resume path clears ``wait_item``. If
        # it does not, a preceding input wait in ``_collect_required_input``
        # would make this skip confirmation. Confirm against TaskContext.
        if task_info.wait_item is not None:
            return args

        logger.debug(
            "Task waiting for auth confirmation task_id=%s tool=%s",
            task_id,
            tool_name,
        )
        self._block_until_resumed(
            task_id, task_info, AuthRequiredWaitState(tool_name=tool_name, args=args)
        )
        self._raise_on_auth_denial(task_info)

        override = task_info.tool_args_override
        if override is None:
            return args
        override_name = task_info.tool_name_override
        # The override is one-shot: consume it whether or not it is applied.
        task_info.tool_args_override = None
        task_info.tool_name_override = None
        if policy.allow_arg_override and (
            not override_name or override_name == tool_name
        ):
            return override
        return args

    def _prepare_tool_arguments(
        self,
        *,
        tool_name: str,
        mcp_name: str | None,
        args: Any,
        task_id: str,
        task_info: TaskContext,
        tool_call: Any,
    ) -> Any:
        """Run the full pre-execution governance pipeline and return the final arguments.

        Steps, in order:
        1. Stop check.
        2. Plugin ``normalize_args`` and ``before_tool``, if a plugin is
           configured.
        3. Policy merge.
        4. Required-input collection.
        5. User confirmation.

        Policy enforcement does not depend on a plugin being configured.
        Without a plugin, the registry policy is merged with ``decision=None``.
        """
        self._raise_if_task_stopped(task_info)

        plugin = self.plugin
        context = None
        decision = None
        if plugin is not None:
            context = ToolExecutionContext(
                agent_definition=self.definition,
                tool_name=tool_name,
                mcp_name=mcp_name,
                args=dict(args),
                task_info=task_info,
                raw_tool_call=tool_call,
            )
            normalized = plugin.normalize_args(context)
            if isinstance(normalized, dict):
                args = normalized
                context.args = dict(args)
            decision = plugin.before_tool(context)

        # NOTE(review): assumes the policy registry accepts decision=None, as
        # it must when a plugin's before_tool returns None. Confirm.
        policy = self._merged_policy(tool_name, decision, mcp_name)
        args = self._collect_required_input(
            task_id=task_id,
            task_info=task_info,
            tool_name=tool_name,
            args=args,
            policy=policy,
        )
        args = self._request_tool_confirmation(
            task_id=task_id,
            task_info=task_info,
            tool_name=tool_name,
            args=args,
            policy=policy,
        )
        if context is not None:
            # Keep the plugin-visible context in sync with the final arguments.
            context.args = (
                dict(args) if isinstance(args, dict) else self._coerce_tool_arguments(args)
            )
        return args

    def _build_tool_callback(self, tool_name: str):
        """Build the pre-execution callback for ``tool_name``.

        The callback does nothing if the message has no ``tool_name`` call or
        no task is bound to the current context. Otherwise it runs
        ``_prepare_tool_arguments`` and writes the result back into
        ``tool_call.function.arguments`` in place.
        """
        source_mcp = self._tool_source_by_name.get(tool_name)

        def _callback(chat_message, agent=None):
            tool_call = self._find_tool_call(chat_message, tool_name)
            if tool_call is None:
                return

            task_id, task_info = self._current_task()
            if task_id is None or task_info is None:
                return

            logger.debug(
                "Preparing tool call task_id=%s tool=%s mcp=%s",
                task_id,
                tool_name,
                source_mcp,
            )
            args = self._coerce_tool_arguments(tool_call.function.arguments)
            tool_call.function.arguments = self._prepare_tool_arguments(
                tool_name=tool_name,
                mcp_name=source_mcp,
                args=args,
                task_id=task_id,
                task_info=task_info,
                tool_call=tool_call,
            )

        return _callback

    def _merged_policy(
        self, tool_name: str, decision: Any, mcp_name: str | None = None
    ) -> ToolPolicy:
        """Delegate to the policy registry to combine configured policy with ``decision``."""
        return self._policy_registry.merged_policy(tool_name, decision, mcp_name)

    @staticmethod
    def _find_tool_call(chat_message: Any, tool_name: str):
        """Return the first tool call in ``chat_message`` named ``tool_name``, else ``None``."""
        for tool_call in chat_message.tool_calls or []:
            if tool_call.function.name == tool_name:
                return tool_call
        return None

    def _interrupt_before_ask_user(self, chat_message, agent=None):
        """Pause the task with an input-required wait before an explicit ``ask_user`` call.

        The prompt comes from the ``ask_user`` call's arguments, falling back
        to a default. The user's reply is not consumed here; presumably the
        ``ask_user`` tool reads it (verify).
        """
        task_id, task_info = self._current_task()
        if task_id is None or task_info is None:
            return
        logger.debug("Task entered explicit ask_user wait task_id=%s", task_id)
        self._block_until_resumed(
            task_id,
            task_info,
            InputRequiredWaitState(
                prompt=self._extract_tool_prompt(chat_message, "ask_user")
                or "需要您补充信息"
            ),
        )

    def _interrupt_before_ask_auth(self, chat_message, agent=None):
        """Pause the task with an auth-required wait before an explicit ``ask_auth`` call.

        Raises:
            UserCancelledError: the user denied the request.
        """
        task_id, task_info = self._current_task()
        if task_id is None or task_info is None:
            return
        logger.debug("Task entered explicit ask_auth wait task_id=%s", task_id)
        self._block_until_resumed(task_id, task_info, AuthRequiredWaitState())
        self._raise_on_auth_denial(task_info)

    def _after_tool_hook(self, tool_output):
        """Post-execution hook: clear active-tool state, then notify the plugin.

        Active-tool bookkeeping is task state, so it is cleared whether or not
        a plugin is configured. The plugin's ``after_tool`` receives the call's
        own arguments, coerced to a dict (``{}`` when unavailable), together
        with the observation and output.
        """
        tool_call = tool_output.tool_call
        tool_name = getattr(tool_call, "name", "") if tool_call else ""
        source_mcp = self._tool_source_by_name.get(tool_name)
        task_id, task_info = self._current_task()
        if task_id is not None:
            logger.debug(
                "Finished tool call task_id=%s tool=%s mcp=%s",
                task_id,
                tool_name,
                source_mcp,
            )
        if task_info is not None:
            task_info.clear_active_tool()
        if self.plugin is None:
            return
        self.plugin.after_tool(
            ToolExecutionContext(
                agent_definition=self.definition,
                tool_name=tool_name,
                mcp_name=source_mcp,
                args=self._coerce_tool_arguments(getattr(tool_call, "arguments", None)),
                observation=tool_output.observation,
                result=tool_output.output,
            )
        )
|