agent-runtime-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime/__init__.py +84 -0
- agent_runtime/builder.py +317 -0
- agent_runtime/config/__init__.py +29 -0
- agent_runtime/config/definitions.py +144 -0
- agent_runtime/config/policies.py +63 -0
- agent_runtime/config/storage.py +117 -0
- agent_runtime/context.py +10 -0
- agent_runtime/definitions.py +33 -0
- agent_runtime/discovery.py +16 -0
- agent_runtime/exceptions.py +74 -0
- agent_runtime/mcp/__init__.py +28 -0
- agent_runtime/mcp/discovery.py +146 -0
- agent_runtime/mcp/metadata.py +68 -0
- agent_runtime/mcp/utils.py +52 -0
- agent_runtime/model_registry.py +40 -0
- agent_runtime/plugins/__init__.py +4 -0
- agent_runtime/plugins/base.py +90 -0
- agent_runtime/plugins/default.py +19 -0
- agent_runtime/plugins/instructions.py +38 -0
- agent_runtime/plugins/loader.py +59 -0
- agent_runtime/policies.py +15 -0
- agent_runtime/runtime.py +110 -0
- agent_runtime/runtime_engine/__init__.py +22 -0
- agent_runtime/runtime_engine/a2a_bridge.py +190 -0
- agent_runtime/runtime_engine/a2a_task_io.py +165 -0
- agent_runtime/runtime_engine/agent_build.py +315 -0
- agent_runtime/runtime_engine/context.py +469 -0
- agent_runtime/runtime_engine/loading.py +170 -0
- agent_runtime/runtime_engine/observability.py +154 -0
- agent_runtime/runtime_engine/policy_registry.py +98 -0
- agent_runtime/runtime_engine/protocol_tools.py +94 -0
- agent_runtime/runtime_engine/task_flow.py +897 -0
- agent_runtime/runtime_engine/tool_flow.py +332 -0
- agent_runtime/sdk_agent.py +548 -0
- agent_runtime/server/__init__.py +15 -0
- agent_runtime/server/app_factory.py +37 -0
- agent_runtime/server/bootstrap.py +48 -0
- agent_runtime/server/endpoint_utils.py +37 -0
- agent_runtime/server/management.py +107 -0
- agent_runtime/smol/__init__.py +4 -0
- agent_runtime/smol/agents.py +431 -0
- agent_runtime/smol/llm_models.py +212 -0
- agent_runtime/smol/memory.py +111 -0
- agent_runtime/smol/models.py +69 -0
- agent_runtime/standalone.py +57 -0
- agent_runtime/storage.py +5 -0
- agent_runtime/tools.py +5 -0
- agent_runtime_sdk-0.1.0.dist-info/METADATA +125 -0
- agent_runtime_sdk-0.1.0.dist-info/RECORD +51 -0
- agent_runtime_sdk-0.1.0.dist-info/WHEEL +5 -0
- agent_runtime_sdk-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,897 @@
|
|
|
1
|
+
"""Task lifecycle state machine.

This module drives a single task from start to finish:

- starting the task
- entering the input_required / auth_required wait states
- resuming the task when the user replies
- timeouts and cancellation
- final completion or failure

It is the most "scheduler-like" layer of the runtime.
"""

from __future__ import annotations

import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import copy_context
from functools import partial
import logging
import math
import os
from typing import Any

from .a2a_task_io import (
    emit_wait_state,
    publish_error,
    publish_message_completion,
    publish_result,
)
from .context import (
    TaskContext,
    TaskPhase,
    TaskUpdaterProtocol,
    WAIT_TYPE_AUTH_REQUIRED,
    WAIT_TYPE_INPUT_REQUIRED,
    wait_state_payload,
    wait_state_type,
)
from .observability import (
    A2ATaskObservation,
    a2a_task_observation,
    update_observation_output,
)
from ..exceptions import (
    TaskCancelledError,
    TaskExecutionTimeoutError,
    TaskWaitTimeoutError,
    UserCancelledError,
)


# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)


class RuntimeTaskFlow:
    """Schedule task lifecycles and drive wait/resume state transitions.

    This class does not define the following attributes itself; they are
    presumably provided by the composing runtime class (TODO confirm):

    * ``self.task_pool`` -- dict-like map of task id -> ``TaskContext``.
    * ``self._failed_task_ids`` -- set-like collection of failed task ids
      whose task-store deletion was deferred.
    * ``self._agent`` -- object exposing ``agent_id`` (used in log lines).
    * ``self.build_agent(mcp_headers=...)`` and ``self.format_result(result)``.
    """

    # Keys that may carry an explicit approve/deny decision in a dict reply to
    # an auth_required prompt, in lookup-priority order.
    _AUTH_DECISION_KEYS = (
        "approve",
        "approved",
        "confirm",
        "confirmed",
        "accept",
        "accepted",
        "allow",
        "allowed",
        "authorize",
        "authorized",
        "grant",
        "granted",
        "decision",
    )
    _AUTH_DECISION_KEY_SET = set(_AUTH_DECISION_KEYS)

    # Keys that may carry tool-argument overrides in an auth reply, in
    # lookup-priority order.
    _AUTH_ARG_KEYS = ("args", "tool_args", "arguments")

    # Strong references to fire-and-forget stop requests spawned from timer
    # callbacks. The event loop only keeps weak references to tasks, so an
    # unreferenced task may be garbage-collected before it finishes. The set
    # is intentionally shared across instances: it is only a keep-alive
    # registry, and entries remove themselves when done.
    _pending_stop_tasks: set = set()

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _format_timeout_seconds(value: float) -> str:
        """Render a timeout for user-facing messages (``30.0`` -> ``"30"``)."""
        if float(value).is_integer():
            return str(int(value))
        return f"{value:g}"

    @staticmethod
    def _timeout_from_env(name: str, default: str) -> float:
        """Read a timeout in seconds from environment variable *name*.

        Args:
            name: Environment variable to read.
            default: Value used when the variable is unset, unparsable or
                NaN. It must itself parse as a float.

        Returns:
            The timeout in seconds, or ``0.0`` when timing out is disabled
            (non-positive or infinite values).
        """
        raw = os.getenv(name, default)
        try:
            value = float(raw)
        except ValueError:
            value = math.nan
        if math.isnan(value):
            # Never raise here: callers run after the agent thread has been
            # started, so an exception would leak it. A NaN delay also has no
            # meaningful schedule on the event loop.
            logger.warning("Invalid %s=%r, falling back to %s", name, raw, default)
            value = float(default)
        if value <= 0 or math.isinf(value):
            return 0.0
        return value

    def pop_failed_task_id(self, task_id: str) -> bool:
        """Remove *task_id* from the deferred-deletion set.

        Returns:
            True if the id was present (and has now been removed), else False.
        """
        if task_id not in self._failed_task_ids:
            return False
        self._failed_task_ids.remove(task_id)
        return True

    # ------------------------------------------------------------------
    # Auth reply parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_auth_decision(value: object) -> bool | None:
        """Map a free-form approve/deny reply to True/False, or None.

        Accepts bools, the ints ``1``/``0``, and English or Chinese yes/no
        words (case-insensitive, surrounding whitespace ignored). Anything
        else yields ``None`` ("no decision").
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, int):
            if value == 1:
                return True
            if value == 0:
                return False
        if value is None:
            return None
        text = str(value).strip().lower()
        if text in {
            "true",
            "1",
            "yes",
            "y",
            "ok",
            "approve",
            "approved",
            "confirm",
            "confirmed",
            "accept",
            "accepted",
            "allow",
            "allowed",
            "authorize",
            "authorized",
            "grant",
            "granted",
            "同意",
            "确认",
            "批准",
            "允许",
            "可以",
            "是",
            "好的",
            "继续",
            "通过",
        }:
            return True
        if text in {
            "false",
            "0",
            "no",
            "n",
            "deny",
            "denied",
            "reject",
            "rejected",
            "cancel",
            "cancelled",
            "取消",
            "拒绝",
            "不同意",
            "否",
            "不行",
            "停止",
        }:
            return False
        return None

    @classmethod
    def _extract_auth_decision(cls, parsed_dict: dict[str, Any]) -> bool | None:
        """Return the decision under the first present decision key, or None.

        Only the first key found (in ``_AUTH_DECISION_KEYS`` order) is
        consulted, even if its value is not recognisable.
        """
        for key in cls._AUTH_DECISION_KEYS:
            if key in parsed_dict:
                return cls._normalize_auth_decision(parsed_dict.get(key))
        return None

    @classmethod
    def _auth_payload(cls, parsed_dict: dict[str, Any]) -> dict[str, Any]:
        """Pick the dict that carries the auth reply.

        Returns a nested ``data``/``payload`` dict when it contains a decision
        key or tool-argument key, otherwise *parsed_dict* itself.
        """
        for key in ("data", "payload"):
            value = parsed_dict.get(key)
            if not isinstance(value, dict):
                continue
            if cls._AUTH_DECISION_KEY_SET.intersection(value) or any(
                arg_key in value for arg_key in cls._AUTH_ARG_KEYS
            ):
                return value
        return parsed_dict

    @classmethod
    def _extract_auth_override(
        cls,
        parsed_dict: dict[str, Any] | None,
    ) -> tuple[str | None, Any]:
        """Return ``(tool_name, tool_args)`` overrides from an auth reply.

        Tool args come from the first present key of ``_AUTH_ARG_KEYS``. If
        none is present and the dict has no decision key, the whole dict is
        treated as the args override.
        """
        if parsed_dict is None:
            return None, None

        tool_name = parsed_dict.get("tool_name") or parsed_dict.get("name")
        for arg_key in cls._AUTH_ARG_KEYS:
            if arg_key in parsed_dict:
                return tool_name, parsed_dict.get(arg_key)
        if not cls._AUTH_DECISION_KEY_SET.intersection(parsed_dict):
            return tool_name, parsed_dict
        return tool_name, None

    @classmethod
    def _parse_auth_response(
        cls, user_input: str | dict | bool | None
    ) -> tuple[bool | None, str | None, Any]:
        """Interpret a reply to an auth_required prompt.

        Returns:
            ``(approved, override_tool_name, override_args)``. ``approved`` is
            None when no decision can be read from the reply.
        """
        if not isinstance(user_input, dict):
            return cls._normalize_auth_decision(user_input), None, None

        auth_payload = cls._auth_payload(user_input)
        approved = cls._extract_auth_decision(auth_payload)
        if approved is None and auth_payload is not user_input:
            # A nested "data"/"payload" dict may be selected only because it
            # carries tool args; the decision itself can still live at the top
            # level (e.g. {"approved": true, "data": {"args": {...}}}).
            approved = cls._extract_auth_decision(user_input)
        override_tool_name, override_args = cls._extract_auth_override(auth_payload)
        return approved, override_tool_name, override_args

    # ------------------------------------------------------------------
    # Timer / resource cleanup helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cancel_wait_timeout(task_info: TaskContext) -> None:
        """Cancel and forget the wait-state timeout handle, if any."""
        handle = task_info.timeout_handle
        if handle is not None:
            handle.cancel()
        task_info.timeout_handle = None

    @staticmethod
    def _cancel_task_timeout(task_info: TaskContext) -> None:
        """Cancel and forget the overall task-execution timeout handle, if any."""
        handle = task_info.control.task_timeout_handle
        if handle is not None:
            handle.cancel()
        task_info.control.task_timeout_handle = None

    def _cancel_all_timeouts(self, task_info: TaskContext) -> None:
        """Cancel both the wait timeout and the task-execution timeout."""
        self._cancel_wait_timeout(task_info)
        self._cancel_task_timeout(task_info)

    def _clear_wait_state(
        self, task_info: TaskContext, *, clear_resume_payload: bool = True
    ) -> None:
        """Cancel the wait timeout and clear the pending wait item.

        With ``clear_resume_payload=True`` this calls ``task_info.clear_wait()``,
        otherwise ``task_info.clear_wait_item()`` -- presumably the latter keeps
        the user's reply for the agent to read (TODO confirm in context.py).
        """
        self._cancel_wait_timeout(task_info)
        if clear_resume_payload:
            task_info.clear_wait()
        else:
            task_info.clear_wait_item()

    @staticmethod
    def _shutdown_task_executor(task_info: TaskContext) -> None:
        """Detach and shut down the task's worker executor without waiting."""
        executor = task_info.task_executor
        if executor is None:
            return
        task_info.task_executor = None
        try:
            executor.shutdown(wait=False, cancel_futures=True)
        except TypeError:
            # ``cancel_futures`` is not accepted on older Pythons.
            executor.shutdown(wait=False)

    @staticmethod
    def _release_event(task_info: TaskContext) -> None:
        """Set the current wake-up event and install a fresh one for the next wait."""
        event = task_info.event
        if event:
            event.set()
        task_info.event = asyncio.Event()

    @staticmethod
    def _interrupt_agent(task_info: TaskContext) -> None:
        """Best-effort: flag the agent to interrupt and cancel its future."""
        agent = task_info.agent
        if agent is not None:
            try:
                agent.interrupt_switch = True
            except Exception:
                logger.debug("Failed to set agent interrupt_switch", exc_info=True)

        agent_task = task_info.agent_task
        if agent_task is not None and not agent_task.done():
            try:
                agent_task.cancel()
            except Exception:
                logger.debug("Failed to cancel agent future", exc_info=True)

    def _abort_active_tool_call(self, task_info: TaskContext) -> None:
        """Cancel the in-flight tool call, if any, and always clear it."""
        active_tool = task_info.active_tool
        if active_tool is None:
            return
        try:
            active_tool.cancel()
        except Exception:
            logger.warning(
                "Failed to cancel active tool call agent_id=%s tool=%s mcp=%s",
                self._agent.agent_id,
                active_tool.tool_name,
                active_tool.mcp_name,
                exc_info=True,
            )
        finally:
            task_info.clear_active_tool()

    @staticmethod
    def _phase_for_error(error: Exception) -> TaskPhase:
        """Map a terminal error to the final TaskPhase (TIMED_OUT/CANCELLED/FAILED)."""
        if isinstance(error, (TaskWaitTimeoutError, TaskExecutionTimeoutError)):
            return TaskPhase.TIMED_OUT
        if isinstance(error, (UserCancelledError, TaskCancelledError)):
            return TaskPhase.CANCELLED
        return TaskPhase.FAILED

    def _spawn_stop_request(self, loop: asyncio.AbstractEventLoop, coro: Any) -> None:
        """Schedule *coro* on *loop* and keep a strong reference until it finishes."""
        task = loop.create_task(coro)
        self._pending_stop_tasks.add(task)
        task.add_done_callback(self._on_stop_request_done)

    @classmethod
    def _on_stop_request_done(cls, task: asyncio.Task) -> None:
        """Drop the keep-alive reference and log any failure of a stop request."""
        cls._pending_stop_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Background stop request failed", exc_info=exc)

    # ------------------------------------------------------------------
    # Task start and timeouts
    # ------------------------------------------------------------------

    async def _start_new_task(
        self,
        task_id: str,
        initial_text: str,
        request_headers: dict[str, str],
        main_loop: asyncio.AbstractEventLoop,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        context_id: str | None = None,
        task_store=None,
    ) -> None:
        """Register a task, build its agent and run it on a dedicated thread.

        Publishes ``submitted``, stores a new ``TaskContext`` in
        ``task_pool``, then runs ``agent.run`` in a single-worker executor
        (with a copy of the current contextvars) and arms the execution
        timeout. If ``build_agent`` raises, the task is removed from the pool
        and the exception propagates.
        """
        from a2a.types import TaskState

        await task_updater.update_status(TaskState.submitted)
        task_info = TaskContext(
            event=asyncio.Event(),
            loop=main_loop,
            updater=task_updater,
            phase=TaskPhase.SUBMITTED,
        )
        task_observation = A2ATaskObservation(
            agent_id=self._agent.agent_id,
            task_id=task_id,
            context_id=context_id,
            request_headers=request_headers,
            task_input=initial_text,
        )
        self.task_pool[task_id] = task_info

        try:
            agent = self.build_agent(mcp_headers=request_headers)
        except Exception:
            self.task_pool.pop(task_id, None)
            raise

        logger.info(
            "Starting new task agent_id=%s task_id=%s has_text=%s",
            self._agent.agent_id,
            task_id,
            bool(initial_text),
        )
        task_executor = ThreadPoolExecutor(
            max_workers=1,
            thread_name_prefix=f"runtime-task-{task_id[:8]}",
        )
        run_context = copy_context()
        agent_task = main_loop.run_in_executor(
            task_executor,
            partial(
                run_context.run,
                self._run_agent_with_cleanup,
                agent,
                initial_text,
                task_observation,
            ),
        )
        # Retrieve the outcome so asyncio does not warn about an unobserved
        # exception; the real handling happens in _run_task_cycle.
        agent_task.add_done_callback(
            lambda task: None if task.cancelled() else task.exception()
        )

        task_info.agent = agent
        task_info.agent_task = agent_task
        task_info.task_executor = task_executor
        task_info.set_phase(TaskPhase.RUNNING)
        self._schedule_task_timeout(
            task_id,
            task_info,
            task_updater,
            task_store=task_store,
        )

    def _schedule_task_timeout(
        self,
        task_id: str,
        task_info: TaskContext,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """(Re)arm the overall execution timeout.

        Reads ``MCP_AGENT_TASK_TIMEOUT_SECONDS`` (default ``0`` = disabled).
        When it fires, the task is stopped as TIMED_OUT with a
        ``TaskExecutionTimeoutError``.
        """
        self._cancel_task_timeout(task_info)
        task_timeout = self._timeout_from_env("MCP_AGENT_TASK_TIMEOUT_SECONDS", "0")
        if task_timeout <= 0:
            return

        loop = task_info.loop
        if not loop or loop.is_closed():
            return

        def on_timeout() -> None:
            current_info = self.task_pool.get(task_id)
            if current_info is None or current_info.finalized:
                return

            reason = (
                f"任务执行超过 {self._format_timeout_seconds(task_timeout)} 秒,已自动取消"
            )
            logger.warning(
                "Task execution timeout agent_id=%s task_id=%s reason=%s",
                self._agent.agent_id,
                task_id,
                reason,
            )
            current_info.control.task_timeout_handle = None
            self._spawn_stop_request(
                loop,
                self._request_task_stop(
                    task_id,
                    task_updater,
                    error=TaskExecutionTimeoutError(reason),
                    reason=reason,
                    phase=TaskPhase.TIMED_OUT,
                    task_store=task_store,
                ),
            )

        task_info.control.task_timeout_handle = loop.call_later(task_timeout, on_timeout)

    def _schedule_wait_timeout(
        self,
        task_id: str,
        task_info: TaskContext,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """(Re)arm the timeout for the current input/auth wait state.

        Reads ``MCP_AGENT_WAIT_TIMEOUT_SECONDS`` (default ``1800``; ``<= 0``
        disables it). When it fires while the task is still waiting, the task
        is stopped as TIMED_OUT with a ``TaskWaitTimeoutError``.
        """
        self._cancel_wait_timeout(task_info)
        wait_timeout = self._timeout_from_env("MCP_AGENT_WAIT_TIMEOUT_SECONDS", "1800")
        if wait_timeout <= 0:
            return

        loop = task_info.loop
        if not loop or loop.is_closed():
            return

        def on_timeout() -> None:
            current_info = self.task_pool.get(task_id)
            if (
                current_info is None
                or current_info.finalized
                or current_info.wait_item is None
                or current_info.timed_out
            ):
                return
            reason = (
                "任务等待用户输入或授权超过 "
                f"{self._format_timeout_seconds(wait_timeout)} 秒,已自动失败"
            )
            logger.warning(
                "Task wait timeout agent_id=%s task_id=%s reason=%s",
                self._agent.agent_id,
                task_id,
                reason,
            )
            current_info.timeout_handle = None
            self._spawn_stop_request(
                loop,
                self._request_task_stop(
                    task_id,
                    task_updater,
                    error=TaskWaitTimeoutError(reason),
                    reason=reason,
                    phase=TaskPhase.TIMED_OUT,
                    task_store=task_store,
                ),
            )

        task_info.timeout_handle = loop.call_later(wait_timeout, on_timeout)

    # ------------------------------------------------------------------
    # Stopping
    # ------------------------------------------------------------------

    async def _request_task_stop(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        error: Exception,
        reason: str,
        phase: TaskPhase,
        task_store=None,
    ) -> None:
        """Single entry point for stopping a task.

        All stop paths converge here:

        - user cancellation / auth denial
        - wait timeout
        - overall task timeout

        The first stop request with an error wins; later ones are ignored.
        ``phase`` is used for logging and to flag timeouts; the final phase is
        derived from *error* in ``_finalize_task``.
        """

        task_info = self.task_pool.get(task_id)
        if not task_info or task_info.finalized:
            return
        if task_info.control.stop_requested and task_info.control.stop_error is not None:
            return

        logger.info(
            "Stopping task agent_id=%s task_id=%s phase=%s reason=%s",
            self._agent.agent_id,
            task_id,
            phase.value,
            reason,
        )
        task_info.request_stop(
            error=error,
            reason=reason,
            timed_out=phase == TaskPhase.TIMED_OUT,
        )
        task_info.set_phase(TaskPhase.CANCELLING)
        self._cancel_all_timeouts(task_info)
        self._clear_wait_state(task_info)
        self._abort_active_tool_call(task_info)
        self._interrupt_agent(task_info)
        self._release_event(task_info)
        await self._finalize_task(
            task_id,
            task_updater,
            error=error,
            task_store=task_store,
        )

    async def _delete_failed_task_from_store(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """Remove a failed task from *task_store*, or defer the removal.

        If the updater's ``event_queue`` is still open, the id is recorded in
        ``_failed_task_ids`` instead (presumably drained via
        ``pop_failed_task_id`` once the queue closes -- TODO confirm).
        Otherwise the task is deleted directly; failures are only logged.
        """
        if task_store is None:
            return

        event_queue = getattr(task_updater, "event_queue", None)
        if event_queue is not None and not event_queue.is_closed():
            self._failed_task_ids.add(task_id)
            return

        try:
            await task_store.delete(task_id)
        except Exception:
            logger.warning(
                "Failed to delete failed task directly from task_store task_id=%s",
                task_id,
                exc_info=True,
            )

    def _run_agent_with_cleanup(
        self,
        agent: Any,
        task_text: str,
        observation: A2ATaskObservation | None = None,
    ):
        """Run ``agent.run`` on the worker thread, then disconnect MCP clients.

        The run is wrapped in an observation span. The agent's
        ``_runtime_mcp_clients`` are disconnected (best-effort) and reset
        whether or not the run succeeds.
        """
        try:
            with a2a_task_observation(observation) as span:
                result = agent.run(task=task_text)
                update_observation_output(span, result)
                return result
        finally:
            mcp_clients = list(getattr(agent, "_runtime_mcp_clients", []) or [])
            for client in mcp_clients:
                try:
                    client.disconnect()
                except Exception:
                    logger.warning(
                        "Failed to disconnect MCP client during task cleanup agent_id=%s",
                        self._agent.agent_id,
                        exc_info=True,
                    )
            try:
                agent._runtime_mcp_clients = []
            except Exception:
                # Best-effort: some agent objects may reject attribute writes.
                pass

    # ------------------------------------------------------------------
    # Wait / resume
    # ------------------------------------------------------------------

    async def _handle_wait_resume(
        self,
        task_id: str,
        task_info: TaskContext,
        task_updater: TaskUpdaterProtocol | Any,
        user_input: str | dict | bool | None,
        *,
        task_store=None,
    ) -> bool:
        """Apply a user reply to the pending wait state.

        Returns:
            True if the reply was consumed (task resumed or stopped), False if
            there is nothing to resume or the reply could not be interpreted
            (the caller then re-emits the wait state).
        """
        wait_item = task_info.wait_item
        if wait_item is None or user_input is None:
            return False

        wait_type = wait_state_type(wait_item)
        if wait_type == WAIT_TYPE_INPUT_REQUIRED:
            task_info.user_input = user_input
            task_info.set_phase(TaskPhase.RUNNING)
            logger.debug("Resuming task with user input task_id=%s", task_id)
            self._release_event(task_info)
            # Keep the resume payload so the agent can still read user_input.
            self._clear_wait_state(task_info, clear_resume_payload=False)
            return True

        if wait_type != WAIT_TYPE_AUTH_REQUIRED:
            return False

        approved, override_tool_name, override_args = self._parse_auth_response(
            user_input
        )
        if approved is None:
            return False
        if approved:
            task_info.auth_denied = False
            task_info.tool_name_override = override_tool_name
            task_info.tool_args_override = override_args
            task_info.set_phase(TaskPhase.RUNNING)
            logger.debug(
                "Resuming task with auth approved task_id=%s tool=%s",
                task_id,
                override_tool_name,
            )
            self._release_event(task_info)
            self._clear_wait_state(task_info)
            return True

        logger.info(
            "Task auth denied agent_id=%s task_id=%s",
            self._agent.agent_id,
            task_id,
        )
        task_info.auth_denied = True
        self._release_event(task_info)
        self._clear_wait_state(task_info)
        await self._request_task_stop(
            task_id,
            task_updater,
            error=UserCancelledError("user denied auth request"),
            reason="user denied auth request",
            phase=TaskPhase.CANCELLED,
            task_store=task_store,
        )
        return True

    async def _listen_for_events(self, task_id: str):
        """Poll every 0.2 s and yield the pending wait item while one exists.

        Stops when the task is gone/finalized or its agent future is done.
        The same wait item is yielded repeatedly while it stays pending.
        """
        while True:
            task_info = self.task_pool.get(task_id)
            if not task_info or task_info.finalized:
                return
            agent_task = task_info.agent_task
            if agent_task and agent_task.done():
                return
            if task_info.wait_item is not None:
                yield task_info.wait_item
            await asyncio.sleep(0.2)

    async def _emit_wait_state(
        self,
        task_id: str,
        task_info: TaskContext,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """Wait for the next input/auth wait item and publish it once.

        Sets the matching WAITING_* phase, arms the wait timeout, emits the
        wait state and marks it as emitted. Wait items of any other type are
        ignored and polling continues.
        """
        async for pending_item in self._listen_for_events(task_id):
            payload = wait_state_payload(pending_item)
            wait_type = payload["type"]
            wait_data = payload["data"]
            if wait_type == WAIT_TYPE_INPUT_REQUIRED:
                task_info.set_phase(TaskPhase.WAITING_INPUT)
                logger.info(
                    "Task entered input_required agent_id=%s task_id=%s",
                    self._agent.agent_id,
                    task_id,
                )
            elif wait_type == WAIT_TYPE_AUTH_REQUIRED:
                task_info.set_phase(TaskPhase.WAITING_AUTH)
                # Only used for logging; tolerate a missing data section.
                tool_name = wait_data.get("tool_name") if wait_data else None
                logger.info(
                    "Task entered auth_required agent_id=%s task_id=%s tool=%s",
                    self._agent.agent_id,
                    task_id,
                    tool_name,
                )
            else:
                continue

            self._schedule_wait_timeout(
                task_id,
                task_info,
                task_updater,
                task_store=task_store,
            )
            await emit_wait_state(task_updater, payload)
            task_info.wait_item_emitted = True
            return

    async def _process_task_request(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        user_input: str | dict | bool | None = None,
        *,
        task_store=None,
    ) -> None:
        """Resume a waiting task with *user_input*, or surface its next wait state.

        Returns immediately if the current wait item was already emitted and
        no reply was given. Otherwise it tries to apply the reply; if that does
        not consume it, it (re-)emits the wait state -- which blocks until a
        wait item appears or the agent finishes.
        """
        task_info = self.task_pool.get(task_id)
        if not task_info:
            return

        wait_item = task_info.wait_item
        if wait_item and user_input is None and task_info.wait_item_emitted:
            return

        if wait_item and user_input is not None:
            handled = await self._handle_wait_resume(
                task_id,
                task_info,
                task_updater,
                user_input,
                task_store=task_store,
            )
            if handled:
                return

        await self._emit_wait_state(
            task_id,
            task_info,
            task_updater,
            task_store=task_store,
        )

    async def _run_task_cycle(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        user_input: str | dict | bool | None = None,
        *,
        task_store=None,
    ) -> None:
        """Drive one execution cycle of a task.

        The A2A bridge only forwards requests here; waiting, resuming,
        completion and failure are all resolved in this method. The cycle
        ends when the task enters a wait state, a stop is requested, or the
        agent future completes (in which case the task is finalized).
        """

        await self._process_task_request(
            task_id,
            task_updater,
            user_input,
            task_store=task_store,
        )

        task_info = self.task_pool.get(task_id)
        if not task_info or task_info.finalized:
            return
        if task_info.control.stop_requested:
            return

        agent_task = task_info.agent_task
        if not agent_task:
            return

        while True:
            if task_info.wait_item is not None:
                await self._process_task_request(
                    task_id,
                    task_updater,
                    None,
                    task_store=task_store,
                )
                return
            if task_info.control.stop_requested:
                return
            if agent_task.done():
                break
            await asyncio.sleep(0.1)

        try:
            result = await agent_task
        except asyncio.CancelledError:
            if task_info.finalized:
                return
            error = task_info.control.stop_error or TaskCancelledError(
                "task cancelled by runtime"
            )
            await self._finalize_task(
                task_id,
                task_updater,
                error=error,
                task_store=task_store,
            )
        except Exception as exc:
            if task_info.finalized:
                return
            # A recorded stop reason takes precedence over the raw exception.
            error = task_info.control.stop_error or exc
            await self._finalize_task(
                task_id,
                task_updater,
                error=error,
                task_store=task_store,
            )
        else:
            if task_info.finalized:
                return
            if (
                task_info.control.stop_requested
                and task_info.control.stop_error is not None
            ):
                await self._finalize_task(
                    task_id,
                    task_updater,
                    error=task_info.control.stop_error,
                    task_store=task_store,
                )
                return
            await self._finalize_task(
                task_id,
                task_updater,
                result=result,
                task_store=task_store,
            )

    async def _cancel_task(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """Cancel a task on client request, going through the common stop path."""

        task_info = self.task_pool.get(task_id)
        if not task_info or task_info.finalized:
            return

        logger.info(
            "Cancelling task agent_id=%s task_id=%s",
            self._agent.agent_id,
            task_id,
        )
        await self._request_task_stop(
            task_id,
            task_updater,
            error=TaskCancelledError("task cancelled by client"),
            reason="task cancelled by client",
            phase=TaskPhase.CANCELLED,
            task_store=task_store,
        )

    # ------------------------------------------------------------------
    # Finalization
    # ------------------------------------------------------------------

    def _log_task_error(self, task_id: str, error: Exception) -> None:
        """Log a terminal error at a level matching its kind."""
        if isinstance(error, TaskWaitTimeoutError):
            logger.warning(
                "Task timed out agent_id=%s task_id=%s error=%s: %s",
                self._agent.agent_id,
                task_id,
                type(error).__name__,
                error,
            )
        elif isinstance(error, TaskExecutionTimeoutError):
            logger.warning(
                "Task execution timed out agent_id=%s task_id=%s error=%s: %s",
                self._agent.agent_id,
                task_id,
                type(error).__name__,
                error,
            )
        elif isinstance(error, (UserCancelledError, TaskCancelledError)):
            logger.info(
                "Task cancelled agent_id=%s task_id=%s error=%s: %s",
                self._agent.agent_id,
                task_id,
                type(error).__name__,
                error,
            )
        else:
            logger.error(
                "Task failed agent_id=%s task_id=%s error=%s: %s",
                self._agent.agent_id,
                task_id,
                type(error).__name__,
                error,
            )

    async def _finalize_task(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        result: Any | None = None,
        error: Exception | None = None,
        message_text: str | None = None,
        *,
        task_store=None,
    ) -> None:
        """Publish the task's terminal outcome and release its resources.

        Exactly one outcome is published: *message_text* if given, else
        *error*, else *result*. Runs at most once per task. Cleanup
        (timeouts, active tool call, executor, ``task_pool`` entry) always
        runs, even if publishing raises; the publishing error then propagates.
        """
        task_info = self.task_pool.get(task_id)
        if not task_info or task_info.finalized:
            return
        task_info.finalized = True

        try:
            if message_text is not None:
                task_info.set_phase(TaskPhase.COMPLETED)
                logger.debug(
                    "Task finalized with message agent_id=%s task_id=%s",
                    self._agent.agent_id,
                    task_id,
                )
                await publish_message_completion(task_updater, message_text)
            elif error is not None:
                task_info.set_phase(self._phase_for_error(error))
                self._log_task_error(task_id, error)
                await publish_error(task_updater, error)
                await self._delete_failed_task_from_store(
                    task_id,
                    task_updater,
                    task_store=task_store,
                )
            else:
                task_info.set_phase(TaskPhase.COMPLETED)
                logger.info(
                    "Task completed agent_id=%s task_id=%s",
                    self._agent.agent_id,
                    task_id,
                )
                final_result = self.format_result(result)
                await publish_result(task_updater, final_result)
        finally:
            # Without this finally, a publishing failure would leak the worker
            # thread and leave a finalized entry in task_pool forever.
            self._cancel_all_timeouts(task_info)
            self._abort_active_tool_call(task_info)
            self._shutdown_task_executor(task_info)
            task_info.set_phase(TaskPhase.FINALIZED)
            self.task_pool.pop(task_id, None)