agent-runtime-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. agent_runtime/__init__.py +84 -0
  2. agent_runtime/builder.py +317 -0
  3. agent_runtime/config/__init__.py +29 -0
  4. agent_runtime/config/definitions.py +144 -0
  5. agent_runtime/config/policies.py +63 -0
  6. agent_runtime/config/storage.py +117 -0
  7. agent_runtime/context.py +10 -0
  8. agent_runtime/definitions.py +33 -0
  9. agent_runtime/discovery.py +16 -0
  10. agent_runtime/exceptions.py +74 -0
  11. agent_runtime/mcp/__init__.py +28 -0
  12. agent_runtime/mcp/discovery.py +146 -0
  13. agent_runtime/mcp/metadata.py +68 -0
  14. agent_runtime/mcp/utils.py +52 -0
  15. agent_runtime/model_registry.py +40 -0
  16. agent_runtime/plugins/__init__.py +4 -0
  17. agent_runtime/plugins/base.py +90 -0
  18. agent_runtime/plugins/default.py +19 -0
  19. agent_runtime/plugins/instructions.py +38 -0
  20. agent_runtime/plugins/loader.py +59 -0
  21. agent_runtime/policies.py +15 -0
  22. agent_runtime/runtime.py +110 -0
  23. agent_runtime/runtime_engine/__init__.py +22 -0
  24. agent_runtime/runtime_engine/a2a_bridge.py +190 -0
  25. agent_runtime/runtime_engine/a2a_task_io.py +165 -0
  26. agent_runtime/runtime_engine/agent_build.py +315 -0
  27. agent_runtime/runtime_engine/context.py +469 -0
  28. agent_runtime/runtime_engine/loading.py +170 -0
  29. agent_runtime/runtime_engine/observability.py +154 -0
  30. agent_runtime/runtime_engine/policy_registry.py +98 -0
  31. agent_runtime/runtime_engine/protocol_tools.py +94 -0
  32. agent_runtime/runtime_engine/task_flow.py +897 -0
  33. agent_runtime/runtime_engine/tool_flow.py +332 -0
  34. agent_runtime/sdk_agent.py +548 -0
  35. agent_runtime/server/__init__.py +15 -0
  36. agent_runtime/server/app_factory.py +37 -0
  37. agent_runtime/server/bootstrap.py +48 -0
  38. agent_runtime/server/endpoint_utils.py +37 -0
  39. agent_runtime/server/management.py +107 -0
  40. agent_runtime/smol/__init__.py +4 -0
  41. agent_runtime/smol/agents.py +431 -0
  42. agent_runtime/smol/llm_models.py +212 -0
  43. agent_runtime/smol/memory.py +111 -0
  44. agent_runtime/smol/models.py +69 -0
  45. agent_runtime/standalone.py +57 -0
  46. agent_runtime/storage.py +5 -0
  47. agent_runtime/tools.py +5 -0
  48. agent_runtime_sdk-0.1.0.dist-info/METADATA +125 -0
  49. agent_runtime_sdk-0.1.0.dist-info/RECORD +51 -0
  50. agent_runtime_sdk-0.1.0.dist-info/WHEEL +5 -0
  51. agent_runtime_sdk-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,332 @@
1
+ """工具调用前的运行时治理逻辑。
2
+
3
+ 注意:这里不真正执行工具,只在“工具即将执行”前做几件事:
4
+
5
+ - 归一化参数
6
+ - 检查是否缺少必填输入
7
+ - 必要时向用户追问
8
+ - 必要时向用户请求授权确认
9
+ - 调用 plugin 的 before_tool / after_tool 钩子
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import asyncio
15
+ import json
16
+ import logging
17
+ from typing import Any, Callable
18
+
19
+ from .context import (
20
+ AuthRequiredWaitState,
21
+ InputRequiredWaitState,
22
+ TaskContext,
23
+ TaskPhase,
24
+ WaitState,
25
+ current_task_id,
26
+ )
27
+ from ..config.definitions import ToolPolicy
28
+ from ..exceptions import TaskCancelledError, TaskWaitTimeoutError, UserCancelledError
29
+ from ..plugins import ToolExecutionContext
30
+ from ..config.policies import (
31
+ collect_missing_fields,
32
+ merge_input_fields,
33
+ parse_user_payload,
34
+ )
35
+
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class RuntimeToolFlow:
    """Pre/post tool-execution governance hooks and wait/interrupt logic.

    Nothing here executes a tool. The callbacks built by this class run just
    before a tool call: they coerce the arguments, ask the plugin to normalize
    them, collect missing required input, request confirmation when the merged
    policy demands it, and (after the tool ran) notify the plugin.

    The methods read the following attributes, none of which is defined in
    this class: ``self.task_pool``, ``self.plugin``, ``self.definition``,
    ``self._tool_source_by_name`` and ``self._policy_registry``.
    """

    # Tool names that never receive the generic pre-execution callback:
    # ``ask_user`` / ``ask_auth`` have dedicated interrupt callbacks and
    # ``final_answer`` is skipped entirely.
    _RESERVED_TOOL_NAMES = frozenset({"ask_user", "ask_auth", "final_answer"})

    def _build_tool_callbacks(
        self, tools: list[Any], extra_tools: list[Any]
    ) -> dict[str, list[Callable]]:
        """Map each tool name to the list of callbacks to run before it.

        Args:
            tools: Tool objects; only their ``name`` attribute is read.
            extra_tools: Additional tool objects, handled like ``tools``.

        Returns:
            A dict with the two fixed interrupt entries (``ask_user``,
            ``ask_auth``) plus one generic callback per named tool that is
            not reserved. Tools without a truthy ``name`` are ignored.
        """
        tool_callbacks: dict[str, list[Callable]] = {
            "ask_user": [self._interrupt_before_ask_user],
            "ask_auth": [self._interrupt_before_ask_auth],
        }
        for tool in [*tools, *extra_tools]:
            tool_name = getattr(tool, "name", "")
            if tool_name and tool_name not in self._RESERVED_TOOL_NAMES:
                tool_callbacks[tool_name] = [self._build_tool_callback(tool_name)]
        return tool_callbacks

    def _extract_tool_prompt(self, chat_message: Any, tool_name: str) -> str | None:
        """Return the prompt carried by the first call to ``tool_name``.

        For dict arguments the first truthy value among the ``task``,
        ``prompt`` and ``question`` keys is returned (``None`` if all are
        missing or falsy). Non-dict arguments are stringified, except
        ``None`` which yields ``None``. ``None`` is also returned when the
        message holds no call to ``tool_name``.
        """
        tool_call = self._find_tool_call(chat_message, tool_name)
        if tool_call is None:
            return None
        args = tool_call.function.arguments
        if isinstance(args, dict):
            return args.get("task") or args.get("prompt") or args.get("question")
        return str(args) if args is not None else None

    def _wait_for_event(self, task_id: str) -> None:
        """Block the calling thread until the task's event is set.

        ``event.wait()`` is scheduled on the task's loop with
        ``asyncio.run_coroutine_threadsafe`` and the resulting future is
        awaited without a timeout. Returns immediately when the task is not
        in ``self.task_pool`` or has no event / loop.
        """
        task_info = self.task_pool.get(task_id)
        if not task_info:
            return
        event = task_info.event
        loop = task_info.loop
        if event and loop:
            # NOTE(review): presumably this runs on a worker thread, not on
            # ``loop``'s own thread (blocking there would stall the loop) --
            # confirm against the caller.
            future = asyncio.run_coroutine_threadsafe(event.wait(), loop)
            future.result()

    def _current_task(self) -> tuple[str | None, TaskContext | None]:
        """Return ``(task_id, task_info)`` for the current context.

        ``(None, None)`` is returned when ``current_task_id`` has no value;
        ``task_info`` alone is ``None`` when the id is not in
        ``self.task_pool``.
        """
        try:
            task_id = current_task_id.get()
        except LookupError:
            return None, None
        return task_id, self.task_pool.get(task_id)

    @staticmethod
    def _coerce_tool_arguments(raw_args: Any) -> dict[str, Any]:
        """Turn raw tool-call arguments into a fresh dict.

        * dict -> shallow copy
        * str  -> parsed as JSON; a JSON object is returned as is, any other
          JSON value is wrapped as ``{"value": parsed}``, and text that is
          not valid JSON is wrapped as ``{"value": raw_args}``
        * anything else (including ``None``) -> ``{}``
        """
        if isinstance(raw_args, dict):
            return dict(raw_args)
        if isinstance(raw_args, str):
            try:
                parsed = json.loads(raw_args)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError; RecursionError covers
                # pathologically nested input. Either way keep the raw text.
                return {"value": raw_args}
            return parsed if isinstance(parsed, dict) else {"value": parsed}
        return {}

    @staticmethod
    def _set_wait_item(task_info: TaskContext, wait_item: WaitState) -> None:
        """Register ``wait_item`` on the task with the matching phase.

        ``InputRequiredWaitState`` maps to ``TaskPhase.WAITING_INPUT``; every
        other wait state maps to ``TaskPhase.WAITING_AUTH``.
        """
        if isinstance(wait_item, InputRequiredWaitState):
            phase = TaskPhase.WAITING_INPUT
        else:
            phase = TaskPhase.WAITING_AUTH
        task_info.set_wait(wait_item, phase)

    @staticmethod
    def _raise_if_wait_timed_out(task_info: TaskContext) -> None:
        """Raise ``TaskWaitTimeoutError`` when the task is flagged timed out."""
        if task_info.timed_out:
            raise TaskWaitTimeoutError(task_info.timeout_reason or "任务等待超时")

    @staticmethod
    def _raise_if_task_stopped(task_info: TaskContext) -> None:
        """Raise when a stop was requested for the task.

        Raises ``task_info.control.stop_error`` if one is set, otherwise a
        ``TaskCancelledError``.
        """
        if task_info.control.stop_requested:
            raise task_info.control.stop_error or TaskCancelledError(
                "task cancelled while waiting"
            )

    @staticmethod
    def _raise_if_auth_denied(task_info: TaskContext) -> None:
        """Raise ``UserCancelledError`` when the task is flagged as denied."""
        if task_info.auth_denied:
            # NOTE(review): the previous code re-assigned
            # ``task_info.auth_denied = True`` here, which had no effect.
            # If the flag was meant to be consumed (reset to False) like
            # ``user_input`` / ``tool_args_override`` are, do that where the
            # flag's other readers can be checked.
            raise UserCancelledError("user denied auth request")

    def _block_until_resumed(self, task_id: str, task_info: TaskContext) -> None:
        """Wait for the task event, then surface a stop or a timeout.

        Raises:
            TaskCancelledError (or the task's own ``stop_error``): a stop was
                requested while waiting.
            TaskWaitTimeoutError: the task was flagged as timed out.
        """
        self._wait_for_event(task_id)
        # Stop is checked before timeout, so a stop request wins when both
        # flags are set.
        self._raise_if_task_stopped(task_info)
        self._raise_if_wait_timed_out(task_info)

    def _collect_required_input(
        self,
        *,
        task_id: str,
        task_info: TaskContext,
        tool_name: str,
        args: dict[str, Any],
        policy: ToolPolicy,
    ) -> dict[str, Any]:
        """Keep asking the user until ``policy`` reports no missing fields.

        Each round registers an ``InputRequiredWaitState``, blocks until the
        task is resumed, merges the parsed ``task_info.user_input`` into
        ``args`` and clears ``task_info.user_input``.

        Returns:
            The arguments after all merges (unchanged when nothing was
            missing to begin with).

        Raises:
            TaskCancelledError / TaskWaitTimeoutError: see
                ``_block_until_resumed``.
        """
        missing_fields = collect_missing_fields(policy, args)
        while missing_fields:
            prompt = policy.prompt or f"请补充工具 {tool_name} 所需信息"
            logger.debug(
                "Task waiting for user input task_id=%s tool=%s missing_fields=%s",
                task_id,
                tool_name,
                [field.name for field in missing_fields],
            )
            self._set_wait_item(task_info, InputRequiredWaitState(prompt=prompt))
            # NOTE(review): ``_wait_for_event`` returns without blocking when
            # the task has no event/loop; in that case this loop re-iterates
            # immediately for as long as fields stay missing.
            self._block_until_resumed(task_id, task_info)
            resume_payload = parse_user_payload(task_info.user_input)
            task_info.user_input = None  # consume the reply
            args = merge_input_fields(args, policy, resume_payload)
            missing_fields = collect_missing_fields(policy, args)
        return args

    def _request_tool_confirmation(
        self,
        *,
        task_id: str,
        task_info: TaskContext,
        tool_name: str,
        args: dict[str, Any],
        policy: ToolPolicy,
    ) -> Any:
        """Ask the user to confirm the tool call when the policy requires it.

        Returns ``args`` untouched when no confirmation is required or when
        the task already has a ``wait_item``. Otherwise registers an
        ``AuthRequiredWaitState`` and blocks until the task is resumed.

        Returns:
            ``task_info.tool_args_override`` when one was supplied, the
            policy allows overrides and the override targets this tool (or
            names no tool); otherwise ``args``. Both override fields are
            cleared once read.

        Raises:
            UserCancelledError: the task is flagged ``auth_denied``.
            TaskCancelledError / TaskWaitTimeoutError: see
                ``_block_until_resumed``.
        """
        if not policy.requires_confirmation:
            return args

        if task_info.wait_item is not None:
            return args

        logger.debug(
            "Task waiting for auth confirmation task_id=%s tool=%s",
            task_id,
            tool_name,
        )
        self._set_wait_item(
            task_info, AuthRequiredWaitState(tool_name=tool_name, args=args)
        )
        self._block_until_resumed(task_id, task_info)
        self._raise_if_auth_denied(task_info)

        if task_info.tool_args_override is not None:
            override = task_info.tool_args_override
            override_name = task_info.tool_name_override
            # Consume the override even when it ends up being ignored.
            task_info.tool_args_override = None
            task_info.tool_name_override = None
            if policy.allow_arg_override and (
                not override_name or override_name == tool_name
            ):
                return override
        return args

    def _prepare_tool_arguments(
        self,
        *,
        tool_name: str,
        mcp_name: str | None,
        args: Any,
        task_id: str,
        task_info: TaskContext,
        tool_call: Any,
    ) -> Any:
        """Run the full pre-execution pipeline and return the final arguments.

        Without a plugin the arguments are returned as received. Otherwise,
        in order: stop check, ``plugin.normalize_args`` (its result replaces
        ``args`` only when it is a dict), ``plugin.before_tool``, policy
        merge, required-input collection and confirmation.

        Raises:
            TaskCancelledError, TaskWaitTimeoutError, UserCancelledError:
                propagated from the wait steps.
        """
        plugin = self.plugin
        if plugin is None:
            return args

        context = ToolExecutionContext(
            agent_definition=self.definition,
            tool_name=tool_name,
            mcp_name=mcp_name,
            args=dict(args),
            task_info=task_info,
            raw_tool_call=tool_call,
        )

        self._raise_if_task_stopped(task_info)

        normalized = plugin.normalize_args(context)
        if isinstance(normalized, dict):
            args = normalized
            context.args = dict(args)

        decision = plugin.before_tool(context)
        policy = self._merged_policy(tool_name, decision, mcp_name)
        args = self._collect_required_input(
            task_id=task_id,
            task_info=task_info,
            tool_name=tool_name,
            args=args,
            policy=policy,
        )
        args = self._request_tool_confirmation(
            task_id=task_id,
            task_info=task_info,
            tool_name=tool_name,
            args=args,
            policy=policy,
        )
        # A confirmation override may be a non-dict; the context always gets
        # a dict while the caller receives the value as returned above.
        context.args = (
            dict(args) if isinstance(args, dict) else self._coerce_tool_arguments(args)
        )
        return args

    def _build_tool_callback(self, tool_name: str):
        """Create the generic pre-execution callback for ``tool_name``.

        The returned callback rewrites ``tool_call.function.arguments`` in
        place with the prepared arguments. It does nothing when the message
        holds no call to the tool or when no current task can be resolved.
        """
        # Looked up once, when the callback is built.
        source_mcp = self._tool_source_by_name.get(tool_name)

        def _callback(chat_message, agent=None):
            tool_call = self._find_tool_call(chat_message, tool_name)
            if tool_call is None:
                return

            task_id, task_info = self._current_task()
            if task_id is None or task_info is None:
                return

            logger.debug(
                "Preparing tool call task_id=%s tool=%s mcp=%s",
                task_id,
                tool_name,
                source_mcp,
            )
            args = self._coerce_tool_arguments(tool_call.function.arguments)
            tool_call.function.arguments = self._prepare_tool_arguments(
                tool_name=tool_name,
                mcp_name=source_mcp,
                args=args,
                task_id=task_id,
                task_info=task_info,
                tool_call=tool_call,
            )

        return _callback

    def _merged_policy(
        self, tool_name: str, decision: Any, mcp_name: str | None = None
    ) -> ToolPolicy:
        """Delegate to ``self._policy_registry.merged_policy``."""
        return self._policy_registry.merged_policy(tool_name, decision, mcp_name)

    @staticmethod
    def _find_tool_call(chat_message: Any, tool_name: str):
        """Return the first tool call named ``tool_name``, or ``None``.

        A missing / empty ``chat_message.tool_calls`` yields ``None``.
        """
        for tool_call in chat_message.tool_calls or []:
            if tool_call.function.name == tool_name:
                return tool_call
        return None

    def _interrupt_before_ask_user(self, chat_message, agent=None):
        """Pause the current task for user input before ``ask_user`` runs.

        The prompt is taken from the ``ask_user`` call's arguments, with a
        fixed fallback text. Does nothing when no current task is resolved.

        Raises:
            TaskCancelledError / TaskWaitTimeoutError: see
                ``_block_until_resumed``.
        """
        task_id, task_info = self._current_task()
        if task_id is None or task_info is None:
            return
        logger.debug("Task entered explicit ask_user wait task_id=%s", task_id)
        self._set_wait_item(
            task_info,
            InputRequiredWaitState(
                prompt=self._extract_tool_prompt(chat_message, "ask_user")
                or "需要您补充信息"
            ),
        )
        self._block_until_resumed(task_id, task_info)

    def _interrupt_before_ask_auth(self, chat_message, agent=None):
        """Pause the current task for authorization before ``ask_auth`` runs.

        Does nothing when no current task is resolved.

        Raises:
            UserCancelledError: the task is flagged ``auth_denied``.
            TaskCancelledError / TaskWaitTimeoutError: see
                ``_block_until_resumed``.
        """
        task_id, task_info = self._current_task()
        if task_id is None or task_info is None:
            return
        logger.debug("Task entered explicit ask_auth wait task_id=%s", task_id)
        self._set_wait_item(task_info, AuthRequiredWaitState())
        self._block_until_resumed(task_id, task_info)
        self._raise_if_auth_denied(task_info)

    def _after_tool_hook(self, tool_output):
        """Notify the plugin that a tool finished.

        Clears the task's active tool when a current task is resolved, then
        calls ``plugin.after_tool`` with the tool's observation and output.
        Does nothing when no plugin is configured.
        """
        if self.plugin is None:
            return
        tool_name = (
            getattr(tool_output.tool_call, "name", "") if tool_output.tool_call else ""
        )
        source_mcp = self._tool_source_by_name.get(tool_name)
        task_id, task_info = self._current_task()
        if task_id is not None:
            logger.debug(
                "Finished tool call task_id=%s tool=%s mcp=%s",
                task_id,
                tool_name,
                source_mcp,
            )
        if task_info is not None:
            task_info.clear_active_tool()
        self.plugin.after_tool(
            ToolExecutionContext(
                agent_definition=self.definition,
                tool_name=tool_name,
                mcp_name=source_mcp,
                # The call arguments are not forwarded to the after-hook.
                args={},
                observation=tool_output.observation,
                result=tool_output.output,
            )
        )