agent-runtime-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. agent_runtime/__init__.py +84 -0
  2. agent_runtime/builder.py +317 -0
  3. agent_runtime/config/__init__.py +29 -0
  4. agent_runtime/config/definitions.py +144 -0
  5. agent_runtime/config/policies.py +63 -0
  6. agent_runtime/config/storage.py +117 -0
  7. agent_runtime/context.py +10 -0
  8. agent_runtime/definitions.py +33 -0
  9. agent_runtime/discovery.py +16 -0
  10. agent_runtime/exceptions.py +74 -0
  11. agent_runtime/mcp/__init__.py +28 -0
  12. agent_runtime/mcp/discovery.py +146 -0
  13. agent_runtime/mcp/metadata.py +68 -0
  14. agent_runtime/mcp/utils.py +52 -0
  15. agent_runtime/model_registry.py +40 -0
  16. agent_runtime/plugins/__init__.py +4 -0
  17. agent_runtime/plugins/base.py +90 -0
  18. agent_runtime/plugins/default.py +19 -0
  19. agent_runtime/plugins/instructions.py +38 -0
  20. agent_runtime/plugins/loader.py +59 -0
  21. agent_runtime/policies.py +15 -0
  22. agent_runtime/runtime.py +110 -0
  23. agent_runtime/runtime_engine/__init__.py +22 -0
  24. agent_runtime/runtime_engine/a2a_bridge.py +190 -0
  25. agent_runtime/runtime_engine/a2a_task_io.py +165 -0
  26. agent_runtime/runtime_engine/agent_build.py +315 -0
  27. agent_runtime/runtime_engine/context.py +469 -0
  28. agent_runtime/runtime_engine/loading.py +170 -0
  29. agent_runtime/runtime_engine/observability.py +154 -0
  30. agent_runtime/runtime_engine/policy_registry.py +98 -0
  31. agent_runtime/runtime_engine/protocol_tools.py +94 -0
  32. agent_runtime/runtime_engine/task_flow.py +897 -0
  33. agent_runtime/runtime_engine/tool_flow.py +332 -0
  34. agent_runtime/sdk_agent.py +548 -0
  35. agent_runtime/server/__init__.py +15 -0
  36. agent_runtime/server/app_factory.py +37 -0
  37. agent_runtime/server/bootstrap.py +48 -0
  38. agent_runtime/server/endpoint_utils.py +37 -0
  39. agent_runtime/server/management.py +107 -0
  40. agent_runtime/smol/__init__.py +4 -0
  41. agent_runtime/smol/agents.py +431 -0
  42. agent_runtime/smol/llm_models.py +212 -0
  43. agent_runtime/smol/memory.py +111 -0
  44. agent_runtime/smol/models.py +69 -0
  45. agent_runtime/standalone.py +57 -0
  46. agent_runtime/storage.py +5 -0
  47. agent_runtime/tools.py +5 -0
  48. agent_runtime_sdk-0.1.0.dist-info/METADATA +125 -0
  49. agent_runtime_sdk-0.1.0.dist-info/RECORD +51 -0
  50. agent_runtime_sdk-0.1.0.dist-info/WHEEL +5 -0
  51. agent_runtime_sdk-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,897 @@
1
+ """任务生命周期状态机。
2
+
3
+ 这个文件处理一条任务从开始到结束的全过程:
4
+
5
+ - 启动任务
6
+ - 进入 input_required / auth_required 等待态
7
+ - 用户恢复任务
8
+ - 超时与取消
9
+ - 最终完成或失败
10
+
11
+ 它是 runtime 里最像“调度器”的一层。
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ from concurrent.futures import ThreadPoolExecutor
18
+ from contextvars import copy_context
19
+ from functools import partial
20
+ import logging
21
+ import os
22
+ from typing import Any
23
+
24
+ from .a2a_task_io import (
25
+ emit_wait_state,
26
+ publish_error,
27
+ publish_message_completion,
28
+ publish_result,
29
+ )
30
+ from .context import (
31
+ TaskContext,
32
+ TaskPhase,
33
+ TaskUpdaterProtocol,
34
+ WAIT_TYPE_AUTH_REQUIRED,
35
+ WAIT_TYPE_INPUT_REQUIRED,
36
+ wait_state_payload,
37
+ wait_state_type,
38
+ )
39
+ from .observability import (
40
+ A2ATaskObservation,
41
+ a2a_task_observation,
42
+ update_observation_output,
43
+ )
44
+ from ..exceptions import (
45
+ TaskCancelledError,
46
+ TaskExecutionTimeoutError,
47
+ TaskWaitTimeoutError,
48
+ UserCancelledError,
49
+ )
50
+
51
+
52
# Module-level logger used by every RuntimeTaskFlow method for lifecycle events.
logger: logging.Logger = logging.getLogger(name=__name__)
53
+
54
+
55
class RuntimeTaskFlow:
    """Task lifecycle scheduler: start, wait/resume, timeouts, cancel, finalize.

    Drives one task from start to finish:

    - start the task (the agent runs in a dedicated single-worker thread pool);
    - enter the ``input_required`` / ``auth_required`` wait states;
    - resume the task when the user answers;
    - wait timeouts, overall execution timeouts and cancellation;
    - final completion or failure.

    This is the most scheduler-like layer of the runtime.

    Attributes read here but not defined in this class, presumably supplied
    by the class this is combined with (TODO confirm): ``self._agent``
    (exposes ``agent_id``), ``self.task_pool`` (mapping task_id ->
    TaskContext), ``self._failed_task_ids`` (set-like), ``self.build_agent``
    and ``self.format_result``.
    """

    # Keys that carry an approve/deny decision in an auth resume payload.
    # Order matters: _extract_auth_decision uses the first key present.
    _AUTH_DECISION_KEYS = (
        "approve",
        "approved",
        "confirm",
        "confirmed",
        "accept",
        "accepted",
        "allow",
        "allowed",
        "authorize",
        "authorized",
        "grant",
        "granted",
        "decision",
    )
    _AUTH_DECISION_KEY_SET = frozenset(_AUTH_DECISION_KEYS)

    # Keys that may carry overridden tool arguments, in priority order.
    _AUTH_ARG_KEYS = ("args", "tool_args", "arguments")

    # Stripped, lower-cased free-text answers that count as approval.
    _AUTH_APPROVE_WORDS = frozenset(
        {
            "true",
            "1",
            "yes",
            "y",
            "ok",
            "approve",
            "approved",
            "confirm",
            "confirmed",
            "accept",
            "accepted",
            "allow",
            "allowed",
            "authorize",
            "authorized",
            "grant",
            "granted",
            "同意",
            "确认",
            "批准",
            "允许",
            "可以",
            "是",
            "好的",
            "继续",
            "通过",
        }
    )
    # Stripped, lower-cased free-text answers that count as denial.
    _AUTH_DENY_WORDS = frozenset(
        {
            "false",
            "0",
            "no",
            "n",
            "deny",
            "denied",
            "reject",
            "rejected",
            "cancel",
            "cancelled",
            "取消",
            "拒绝",
            "不同意",
            "否",
            "不行",
            "停止",
        }
    )

    @staticmethod
    def _format_timeout_seconds(value: float) -> str:
        """Render a timeout for messages: ``30.0`` -> ``"30"``, ``1.5`` -> ``"1.5"``."""
        if float(value).is_integer():
            return str(int(value))
        return f"{value:g}"

    @staticmethod
    def _read_timeout_seconds(env_name: str, default: float) -> float:
        """Read a timeout in seconds from the environment variable *env_name*.

        Returns *default* when the variable is unset. A value that does not
        parse as a float, or parses as NaN, is logged and replaced by
        *default* instead of raising ``ValueError`` in the middle of timeout
        scheduling (NaN is also not a usable ``call_later`` delay). Values
        ``<= 0`` are returned unchanged; the schedulers treat them as
        "timeout disabled".
        """
        import math

        raw = os.getenv(env_name)
        if raw is None:
            return default
        try:
            value = float(raw)
        except ValueError:
            value = math.nan
        if math.isnan(value):
            logger.warning(
                "Invalid timeout %s=%r, falling back to %s seconds",
                env_name,
                raw,
                default,
            )
            return default
        return value

    def pop_failed_task_id(self, task_id: str) -> bool:
        """Remove *task_id* from ``self._failed_task_ids``.

        Returns True if it was recorded there (``_delete_failed_task_from_store``
        records it when the event queue is still open), False otherwise.
        """
        if task_id not in self._failed_task_ids:
            return False
        self._failed_task_ids.remove(task_id)
        return True

    @classmethod
    def _normalize_auth_decision(cls, value: object) -> bool | None:
        """Map a user's auth answer to True (approve), False (deny) or None.

        Booleans are returned as-is, the ints 1/0 map to True/False, and any
        other value is compared, stripped and lower-cased, against the
        approve/deny word sets. Unrecognised answers (including None) yield None.
        """
        if isinstance(value, bool):
            return value
        if value is None:
            return None
        if isinstance(value, int):
            if value == 1:
                return True
            if value == 0:
                return False
        text = str(value).strip().lower()
        if text in cls._AUTH_APPROVE_WORDS:
            return True
        if text in cls._AUTH_DENY_WORDS:
            return False
        return None

    @classmethod
    def _extract_auth_decision(cls, parsed_dict: dict[str, Any]) -> bool | None:
        """Return the normalized decision stored under the first decision key present."""
        for key in cls._AUTH_DECISION_KEYS:
            if key in parsed_dict:
                return cls._normalize_auth_decision(parsed_dict.get(key))
        return None

    @classmethod
    def _auth_payload(cls, parsed_dict: dict[str, Any]) -> dict[str, Any]:
        """Pick the dict that actually carries the auth answer.

        Returns a nested ``data`` or ``payload`` dict (checked in that order)
        when it holds a decision key or an argument key; otherwise returns
        *parsed_dict* itself.
        """
        for key in ("data", "payload"):
            value = parsed_dict.get(key)
            if not isinstance(value, dict):
                continue
            if cls._AUTH_DECISION_KEY_SET.intersection(value) or any(
                arg_key in value for arg_key in cls._AUTH_ARG_KEYS
            ):
                return value
        return parsed_dict

    @classmethod
    def _extract_auth_override(
        cls,
        parsed_dict: dict[str, Any] | None,
    ) -> tuple[str | None, Any]:
        """Return ``(tool_name, tool_args)`` overrides from an auth answer.

        The tool name comes from ``tool_name`` or ``name``. The arguments come
        from the first of ``args`` / ``tool_args`` / ``arguments`` present.
        When none is present and the dict has no decision key, the whole dict
        is treated as the arguments. Returns ``(None, None)`` for None input.
        """
        if parsed_dict is None:
            return None, None

        tool_name = parsed_dict.get("tool_name") or parsed_dict.get("name")
        for arg_key in cls._AUTH_ARG_KEYS:
            if arg_key in parsed_dict:
                return tool_name, parsed_dict.get(arg_key)
        if not cls._AUTH_DECISION_KEY_SET.intersection(parsed_dict):
            return tool_name, parsed_dict
        return tool_name, None

    @staticmethod
    def _cancel_wait_timeout(task_info: TaskContext) -> None:
        """Cancel and forget the pending wait-state timer, if any."""
        handle = task_info.timeout_handle
        if handle is not None:
            handle.cancel()
        task_info.timeout_handle = None

    @staticmethod
    def _cancel_task_timeout(task_info: TaskContext) -> None:
        """Cancel and forget the pending overall execution timer, if any."""
        handle = task_info.control.task_timeout_handle
        if handle is not None:
            handle.cancel()
        task_info.control.task_timeout_handle = None

    def _cancel_all_timeouts(self, task_info: TaskContext) -> None:
        """Cancel both the wait-state timer and the execution timer."""
        self._cancel_wait_timeout(task_info)
        self._cancel_task_timeout(task_info)

    def _clear_wait_state(
        self, task_info: TaskContext, *, clear_resume_payload: bool = True
    ) -> None:
        """Cancel the wait timer and clear the pending wait item.

        With ``clear_resume_payload=False`` only the wait item is cleared
        (``clear_wait_item``), so data such as ``user_input`` set for the
        resume is kept; otherwise ``clear_wait`` is used.
        """
        self._cancel_wait_timeout(task_info)
        if clear_resume_payload:
            task_info.clear_wait()
        else:
            task_info.clear_wait_item()

    @staticmethod
    def _shutdown_task_executor(task_info: TaskContext) -> None:
        """Detach and shut down the task's worker executor without blocking.

        ``cancel_futures`` only exists on Python 3.9+, hence the TypeError
        fallback.
        """
        executor = task_info.task_executor
        if executor is None:
            return
        task_info.task_executor = None
        try:
            executor.shutdown(wait=False, cancel_futures=True)
        except TypeError:
            executor.shutdown(wait=False)

    @staticmethod
    def _release_event(task_info: TaskContext) -> None:
        """Set the current task event and install a fresh, unset one for the next wait."""
        event = task_info.event
        if event is not None:
            event.set()
        task_info.event = asyncio.Event()

    @staticmethod
    def _interrupt_agent(task_info: TaskContext) -> None:
        """Best-effort interruption of the running agent.

        Sets ``agent.interrupt_switch`` and cancels the agent future if it is
        still pending. Failures are logged at debug level and otherwise
        ignored, because interruption must never abort the stop path.
        """
        agent = task_info.agent
        if agent is not None:
            try:
                agent.interrupt_switch = True
            except Exception:
                logger.debug("Failed to set interrupt_switch on agent", exc_info=True)

        agent_task = task_info.agent_task
        if agent_task is not None and not agent_task.done():
            try:
                agent_task.cancel()
            except Exception:
                logger.debug("Failed to cancel agent future", exc_info=True)

    def _abort_active_tool_call(self, task_info: TaskContext) -> None:
        """Cancel the in-flight tool call, if any, and always clear it from the context."""
        active_tool = task_info.active_tool
        if active_tool is None:
            return
        try:
            active_tool.cancel()
        except Exception:
            logger.warning(
                "Failed to cancel active tool call agent_id=%s tool=%s mcp=%s",
                self._agent.agent_id,
                active_tool.tool_name,
                active_tool.mcp_name,
                exc_info=True,
            )
        finally:
            task_info.clear_active_tool()

    @staticmethod
    def _phase_for_error(error: Exception) -> TaskPhase:
        """Map a terminal error to TIMED_OUT, CANCELLED or (otherwise) FAILED."""
        if isinstance(error, (TaskWaitTimeoutError, TaskExecutionTimeoutError)):
            return TaskPhase.TIMED_OUT
        if isinstance(error, (UserCancelledError, TaskCancelledError)):
            return TaskPhase.CANCELLED
        return TaskPhase.FAILED

    def _spawn_stop_task(self, loop: asyncio.AbstractEventLoop, coro: Any) -> None:
        """Run the stop coroutine *coro* as a task on *loop* and keep it alive.

        The event loop only keeps weak references to tasks, so a bare
        ``loop.create_task`` result may be garbage-collected before it
        finishes. The task is held in ``self._pending_stop_tasks`` until it
        is done, and a failure is logged instead of being left as an
        unretrieved task exception.
        """
        pending = getattr(self, "_pending_stop_tasks", None)
        if pending is None:
            pending = set()
            self._pending_stop_tasks = pending
        task = loop.create_task(coro)
        pending.add(task)

        def on_done(done: asyncio.Task) -> None:
            pending.discard(done)
            if done.cancelled():
                return
            exc = done.exception()
            if exc is not None:
                logger.error(
                    "Background task stop failed agent_id=%s",
                    self._agent.agent_id,
                    exc_info=exc,
                )

        task.add_done_callback(on_done)

    async def _start_new_task(
        self,
        task_id: str,
        initial_text: str,
        request_headers: dict[str, str],
        main_loop: asyncio.AbstractEventLoop,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        context_id: str | None = None,
        task_store=None,
    ) -> None:
        """Register a new task and start its agent in a dedicated worker thread.

        Marks the task ``submitted`` through *task_updater* and registers its
        TaskContext in ``self.task_pool``. It then builds the agent; if that
        fails, the pool entry is removed and the error re-raised.

        On success, ``agent.run`` is launched via ``_run_agent_with_cleanup``
        on a single-worker ThreadPoolExecutor, under a copy of the current
        contextvars. Finally the task is switched to RUNNING and the optional
        overall execution timeout is armed.
        """
        from a2a.types import TaskState

        await task_updater.update_status(TaskState.submitted)
        task_info = TaskContext(
            event=asyncio.Event(),
            loop=main_loop,
            updater=task_updater,
            phase=TaskPhase.SUBMITTED,
        )
        task_observation = A2ATaskObservation(
            agent_id=self._agent.agent_id,
            task_id=task_id,
            context_id=context_id,
            request_headers=request_headers,
            task_input=initial_text,
        )
        self.task_pool[task_id] = task_info

        try:
            agent = self.build_agent(mcp_headers=request_headers)
        except Exception:
            self.task_pool.pop(task_id, None)
            raise

        logger.info(
            "Starting new task agent_id=%s task_id=%s has_text=%s",
            self._agent.agent_id,
            task_id,
            bool(initial_text),
        )
        task_executor = ThreadPoolExecutor(
            max_workers=1,
            thread_name_prefix=f"runtime-task-{task_id[:8]}",
        )
        run_context = copy_context()
        agent_task = main_loop.run_in_executor(
            task_executor,
            partial(
                run_context.run,
                self._run_agent_with_cleanup,
                agent,
                initial_text,
                task_observation,
            ),
        )
        # Touch the outcome so asyncio never reports "exception was never
        # retrieved"; _run_task_cycle awaits the future for the real result.
        agent_task.add_done_callback(
            lambda task: None if task.cancelled() else task.exception()
        )

        task_info.agent = agent
        task_info.agent_task = agent_task
        task_info.task_executor = task_executor
        task_info.set_phase(TaskPhase.RUNNING)
        self._schedule_task_timeout(
            task_id,
            task_info,
            task_updater,
            task_store=task_store,
        )

    def _schedule_task_timeout(
        self,
        task_id: str,
        task_info: TaskContext,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """(Re)arm the overall execution timeout for *task_id*.

        The limit is read from ``MCP_AGENT_TASK_TIMEOUT_SECONDS`` (default 0,
        meaning disabled). When it fires on an unfinalized task, the task is
        stopped through ``_request_task_stop`` with TaskExecutionTimeoutError.
        """
        self._cancel_task_timeout(task_info)
        task_timeout = self._read_timeout_seconds("MCP_AGENT_TASK_TIMEOUT_SECONDS", 0.0)
        if task_timeout <= 0:
            return

        loop = task_info.loop
        if not loop or loop.is_closed():
            return

        def on_timeout() -> None:
            current_info = self.task_pool.get(task_id)
            if current_info is None or current_info.finalized:
                return

            reason = (
                f"任务执行超过 {self._format_timeout_seconds(task_timeout)} 秒,已自动取消"
            )
            logger.warning(
                "Task execution timeout agent_id=%s task_id=%s reason=%s",
                self._agent.agent_id,
                task_id,
                reason,
            )
            current_info.control.task_timeout_handle = None
            self._spawn_stop_task(
                loop,
                self._request_task_stop(
                    task_id,
                    task_updater,
                    error=TaskExecutionTimeoutError(reason),
                    reason=reason,
                    phase=TaskPhase.TIMED_OUT,
                    task_store=task_store,
                ),
            )

        task_info.control.task_timeout_handle = loop.call_later(task_timeout, on_timeout)

    def _schedule_wait_timeout(
        self,
        task_id: str,
        task_info: TaskContext,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """(Re)arm the timeout for a task waiting on user input or authorization.

        The limit is read from ``MCP_AGENT_WAIT_TIMEOUT_SECONDS`` (default
        1800; ``<= 0`` disables it). The timer only acts if the task is still
        waiting, is not finalized, and has not already timed out. It then stops
        the task with TaskWaitTimeoutError.
        """
        self._cancel_wait_timeout(task_info)
        wait_timeout = self._read_timeout_seconds("MCP_AGENT_WAIT_TIMEOUT_SECONDS", 1800.0)
        if wait_timeout <= 0:
            return

        loop = task_info.loop
        if not loop or loop.is_closed():
            return

        def on_timeout() -> None:
            current_info = self.task_pool.get(task_id)
            if (
                current_info is None
                or current_info.finalized
                or current_info.wait_item is None
                or current_info.timed_out
            ):
                return
            reason = (
                "任务等待用户输入或授权超过 "
                f"{self._format_timeout_seconds(wait_timeout)} 秒,已自动失败"
            )
            logger.warning(
                "Task wait timeout agent_id=%s task_id=%s reason=%s",
                self._agent.agent_id,
                task_id,
                reason,
            )
            current_info.timeout_handle = None
            self._spawn_stop_task(
                loop,
                self._request_task_stop(
                    task_id,
                    task_updater,
                    error=TaskWaitTimeoutError(reason),
                    reason=reason,
                    phase=TaskPhase.TIMED_OUT,
                    task_store=task_store,
                ),
            )

        task_info.timeout_handle = loop.call_later(wait_timeout, on_timeout)

    async def _request_task_stop(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        error: Exception,
        reason: str,
        phase: TaskPhase,
        task_store=None,
    ) -> None:
        """Single entry point for stopping a task.

        Every stop path funnels through here: user cancellation or denial,
        wait timeout, and overall execution timeout.

        Steps: record the stop on the TaskContext and switch to CANCELLING;
        cancel the timers and clear the wait state; abort the active tool call
        and interrupt the agent; wake anything waiting on the task event; and
        finally publish *error* via ``_finalize_task``.

        No-op if the task is unknown, already finalized, or already has a
        recorded stop error. *phase* only feeds the log line and the
        ``timed_out`` flag; the final phase is derived from *error* in
        ``_finalize_task``.
        """
        task_info = self.task_pool.get(task_id)
        if not task_info or task_info.finalized:
            return
        if task_info.control.stop_requested and task_info.control.stop_error is not None:
            return

        logger.info(
            "Stopping task agent_id=%s task_id=%s phase=%s reason=%s",
            self._agent.agent_id,
            task_id,
            phase.value,
            reason,
        )
        task_info.request_stop(
            error=error,
            reason=reason,
            timed_out=phase == TaskPhase.TIMED_OUT,
        )
        task_info.set_phase(TaskPhase.CANCELLING)
        self._cancel_all_timeouts(task_info)
        self._clear_wait_state(task_info)
        self._abort_active_tool_call(task_info)
        self._interrupt_agent(task_info)
        self._release_event(task_info)
        await self._finalize_task(
            task_id,
            task_updater,
            error=error,
            task_store=task_store,
        )

    async def _delete_failed_task_from_store(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """Remove a failed task from *task_store*, or defer that while the queue is open.

        If *task_updater* exposes an ``event_queue`` that is not closed, the
        task_id is only recorded in ``self._failed_task_ids`` (consumed via
        ``pop_failed_task_id``), presumably so deletion happens after the
        queue drains (TODO confirm with the consumer). Otherwise the task is
        deleted directly; a failing delete is logged and swallowed.
        """
        if task_store is None:
            return

        event_queue = getattr(task_updater, "event_queue", None)
        if event_queue is not None and not event_queue.is_closed():
            self._failed_task_ids.add(task_id)
            return

        try:
            await task_store.delete(task_id)
        except Exception:
            logger.warning(
                "Failed to delete failed task directly from task_store task_id=%s",
                task_id,
                exc_info=True,
            )

    def _run_agent_with_cleanup(
        self,
        agent: Any,
        task_text: str,
        observation: A2ATaskObservation | None = None,
    ):
        """Run ``agent.run(task=task_text)`` (in the worker thread) and clean up.

        The run is wrapped in an observation span that records the result as
        output. Whatever happens, every client in ``agent._runtime_mcp_clients``
        is disconnected (failures are logged) and that list is reset.
        """
        try:
            with a2a_task_observation(observation) as span:
                result = agent.run(task=task_text)
                update_observation_output(span, result)
                return result
        finally:
            mcp_clients = list(getattr(agent, "_runtime_mcp_clients", []) or [])
            for client in mcp_clients:
                try:
                    client.disconnect()
                except Exception:
                    logger.warning(
                        "Failed to disconnect MCP client during task cleanup agent_id=%s",
                        self._agent.agent_id,
                        exc_info=True,
                    )
            try:
                agent._runtime_mcp_clients = []
            except Exception:
                # Best-effort reset; the clients are already disconnected.
                pass

    async def _handle_wait_resume(
        self,
        task_id: str,
        task_info: TaskContext,
        task_updater: TaskUpdaterProtocol | Any,
        user_input: str | dict | bool | None,
        *,
        task_store=None,
    ) -> bool:
        """Try to resume a waiting task with *user_input*; return True if consumed.

        - input_required: the input is stored in ``task_info.user_input``
          (kept across the wait-state clear) and the task resumes.
        - auth_required: the input is read as an approve/deny decision. It may
          be a plain value or a dict, optionally nested under
          ``data``/``payload``, and may carry tool name/argument overrides.
          Approval resumes the task with those overrides. Denial stops the
          task with UserCancelledError.

        Returns False when nothing is waiting, the input is None, the wait type
        is unknown, or the auth answer cannot be classified. The caller then
        re-emits the wait state.
        """
        wait_item = task_info.wait_item
        if wait_item is None or user_input is None:
            return False

        wait_type = wait_state_type(wait_item)
        if wait_type == WAIT_TYPE_INPUT_REQUIRED:
            task_info.user_input = user_input
            task_info.set_phase(TaskPhase.RUNNING)
            logger.debug("Resuming task with user input task_id=%s", task_id)
            self._release_event(task_info)
            self._clear_wait_state(task_info, clear_resume_payload=False)
            return True

        if wait_type != WAIT_TYPE_AUTH_REQUIRED:
            return False

        override_tool_name = None
        override_args = None
        if isinstance(user_input, dict):
            auth_payload = self._auth_payload(user_input)
            approved = self._extract_auth_decision(auth_payload)
            override_tool_name, override_args = self._extract_auth_override(auth_payload)
        else:
            approved = self._normalize_auth_decision(user_input)

        if approved is None:
            return False
        if approved:
            task_info.auth_denied = False
            task_info.tool_name_override = override_tool_name
            task_info.tool_args_override = override_args
            task_info.set_phase(TaskPhase.RUNNING)
            logger.debug(
                "Resuming task with auth approved task_id=%s tool=%s",
                task_id,
                override_tool_name,
            )
            self._release_event(task_info)
            # NOTE(review): this clears with clear_resume_payload=True right
            # after setting the tool overrides; confirm TaskContext.clear_wait()
            # does not reset tool_name_override / tool_args_override.
            self._clear_wait_state(task_info)
            return True

        logger.info(
            "Task auth denied agent_id=%s task_id=%s",
            self._agent.agent_id,
            task_id,
        )
        task_info.auth_denied = True
        self._release_event(task_info)
        self._clear_wait_state(task_info)
        await self._request_task_stop(
            task_id,
            task_updater,
            error=UserCancelledError("user denied auth request"),
            reason="user denied auth request",
            phase=TaskPhase.CANCELLED,
            task_store=task_store,
        )
        return True

    async def _listen_for_events(self, task_id: str):
        """Poll every 0.2 s and yield the task's pending wait item while one exists.

        Ends when the task leaves the pool, is finalized, or its agent future
        is done. The same item is yielded again on each poll while it stays
        pending.
        """
        while True:
            task_info = self.task_pool.get(task_id)
            if not task_info or task_info.finalized:
                return
            agent_task = task_info.agent_task
            if agent_task and agent_task.done():
                return
            if task_info.wait_item is not None:
                yield task_info.wait_item
            await asyncio.sleep(0.2)

    async def _emit_wait_state(
        self,
        task_id: str,
        task_info: TaskContext,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """Wait for the task's next wait item and publish it once.

        For input_required / auth_required this sets WAITING_INPUT /
        WAITING_AUTH, arms the wait timeout, emits the wait state through
        *task_updater*, marks it emitted, and returns. Items of any other
        wait type are not emitted; polling continues until
        ``_listen_for_events`` ends.
        """
        async for pending_item in self._listen_for_events(task_id):
            payload = wait_state_payload(pending_item)
            wait_type = payload["type"]
            wait_data = payload["data"]
            if wait_type == WAIT_TYPE_INPUT_REQUIRED:
                task_info.set_phase(TaskPhase.WAITING_INPUT)
                logger.info(
                    "Task entered input_required agent_id=%s task_id=%s",
                    self._agent.agent_id,
                    task_id,
                )
            elif wait_type == WAIT_TYPE_AUTH_REQUIRED:
                task_info.set_phase(TaskPhase.WAITING_AUTH)
                logger.info(
                    "Task entered auth_required agent_id=%s task_id=%s tool=%s",
                    self._agent.agent_id,
                    task_id,
                    wait_data.get("tool_name"),
                )
            else:
                continue
            self._schedule_wait_timeout(
                task_id,
                task_info,
                task_updater,
                task_store=task_store,
            )
            await emit_wait_state(task_updater, payload)
            task_info.wait_item_emitted = True
            return

    async def _process_task_request(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        user_input: str | dict | bool | None = None,
        *,
        task_store=None,
    ) -> None:
        """Handle one incoming request for an existing task.

        Does nothing if the task is unknown, or if its pending wait item was
        already announced and no new input arrived. Input for a pending wait
        item is routed to ``_handle_wait_resume``. Otherwise (or when the input
        was not consumed) this blocks in ``_emit_wait_state`` until a wait
        state is published or the agent run ends.
        """
        task_info = self.task_pool.get(task_id)
        if not task_info:
            return

        wait_item = task_info.wait_item
        if wait_item and user_input is None and task_info.wait_item_emitted:
            return

        if wait_item and user_input is not None:
            handled = await self._handle_wait_resume(
                task_id,
                task_info,
                task_updater,
                user_input,
                task_store=task_store,
            )
            if handled:
                return

        await self._emit_wait_state(
            task_id,
            task_info,
            task_updater,
            task_store=task_store,
        )

    async def _run_task_cycle(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        user_input: str | dict | bool | None = None,
        *,
        task_store=None,
    ) -> None:
        """Drive one execution cycle of a task.

        The A2A bridge only forwards requests here; waiting, resuming,
        completion and failure are all settled in this method.

        After handling the request, it polls every 0.1 s until one of three
        things happens. If the agent reaches a wait state, that state is
        published. If a stop is requested, the method returns. If the agent
        future completes, the task is finalized with the result or the error;
        a recorded stop error takes precedence in both cases.
        """
        await self._process_task_request(
            task_id,
            task_updater,
            user_input,
            task_store=task_store,
        )

        task_info = self.task_pool.get(task_id)
        if not task_info or task_info.finalized:
            return
        if task_info.control.stop_requested:
            return

        agent_task = task_info.agent_task
        if not agent_task:
            return

        while True:
            if task_info.wait_item is not None:
                await self._process_task_request(
                    task_id,
                    task_updater,
                    None,
                    task_store=task_store,
                )
                return
            if task_info.control.stop_requested:
                return
            if agent_task.done():
                break
            await asyncio.sleep(0.1)

        try:
            result = await agent_task
        except asyncio.CancelledError:
            if task_info.finalized:
                return
            error = task_info.control.stop_error or TaskCancelledError(
                "task cancelled by runtime"
            )
            await self._finalize_task(
                task_id,
                task_updater,
                error=error,
                task_store=task_store,
            )
        except Exception as exc:
            if task_info.finalized:
                return
            error = task_info.control.stop_error or exc
            await self._finalize_task(
                task_id,
                task_updater,
                error=error,
                task_store=task_store,
            )
        else:
            if task_info.finalized:
                return
            if (
                task_info.control.stop_requested
                and task_info.control.stop_error is not None
            ):
                await self._finalize_task(
                    task_id,
                    task_updater,
                    error=task_info.control.stop_error,
                    task_store=task_store,
                )
                return
            await self._finalize_task(
                task_id,
                task_updater,
                result=result,
                task_store=task_store,
            )

    async def _cancel_task(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        *,
        task_store=None,
    ) -> None:
        """Stop a task on client request, going through the common finalize path."""
        task_info = self.task_pool.get(task_id)
        if not task_info or task_info.finalized:
            return

        logger.info(
            "Cancelling task agent_id=%s task_id=%s",
            self._agent.agent_id,
            task_id,
        )
        await self._request_task_stop(
            task_id,
            task_updater,
            error=TaskCancelledError("task cancelled by client"),
            reason="task cancelled by client",
            phase=TaskPhase.CANCELLED,
            task_store=task_store,
        )

    async def _finalize_task(
        self,
        task_id: str,
        task_updater: TaskUpdaterProtocol | Any,
        result: Any | None = None,
        error: Exception | None = None,
        message_text: str | None = None,
        *,
        task_store=None,
    ) -> None:
        """Publish the task's terminal state exactly once and release its resources.

        Precedence is *message_text*, then *error*, then *result*:

        - *message_text*: published as a plain completion message.
        - *error*: the phase comes from ``_phase_for_error``, the error is
          logged by severity and published, and the task is removed (or
          scheduled for removal) from *task_store*.
        - *result*: formatted by ``self.format_result`` and published.

        Cleanup (timers, active tool call, worker executor, FINALIZED phase,
        removal from ``self.task_pool``) runs in ``finally``. A failing publish
        therefore cannot leak the worker thread, leave timers armed, or pin
        the task in the pool. The publishing error still propagates.
        """
        task_info = self.task_pool.get(task_id)
        if not task_info or task_info.finalized:
            return
        task_info.finalized = True

        try:
            if message_text is not None:
                task_info.set_phase(TaskPhase.COMPLETED)
                logger.debug(
                    "Task finalized with message agent_id=%s task_id=%s",
                    self._agent.agent_id,
                    task_id,
                )
                await publish_message_completion(task_updater, message_text)
            elif error is not None:
                task_info.set_phase(self._phase_for_error(error))
                if isinstance(error, TaskWaitTimeoutError):
                    logger.warning(
                        "Task timed out agent_id=%s task_id=%s error=%s: %s",
                        self._agent.agent_id,
                        task_id,
                        type(error).__name__,
                        error,
                    )
                elif isinstance(error, TaskExecutionTimeoutError):
                    logger.warning(
                        "Task execution timed out agent_id=%s task_id=%s error=%s: %s",
                        self._agent.agent_id,
                        task_id,
                        type(error).__name__,
                        error,
                    )
                elif isinstance(error, (UserCancelledError, TaskCancelledError)):
                    logger.info(
                        "Task cancelled agent_id=%s task_id=%s error=%s: %s",
                        self._agent.agent_id,
                        task_id,
                        type(error).__name__,
                        error,
                    )
                else:
                    logger.error(
                        "Task failed agent_id=%s task_id=%s error=%s: %s",
                        self._agent.agent_id,
                        task_id,
                        type(error).__name__,
                        error,
                    )
                await publish_error(task_updater, error)
                await self._delete_failed_task_from_store(
                    task_id,
                    task_updater,
                    task_store=task_store,
                )
            else:
                task_info.set_phase(TaskPhase.COMPLETED)
                logger.info(
                    "Task completed agent_id=%s task_id=%s",
                    self._agent.agent_id,
                    task_id,
                )
                final_result = self.format_result(result)
                await publish_result(task_updater, final_result)
        finally:
            self._cancel_all_timeouts(task_info)
            self._abort_active_tool_call(task_info)
            self._shutdown_task_executor(task_info)
            task_info.set_phase(TaskPhase.FINALIZED)
            self.task_pool.pop(task_id, None)