agent-runtime-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime/__init__.py +84 -0
- agent_runtime/builder.py +317 -0
- agent_runtime/config/__init__.py +29 -0
- agent_runtime/config/definitions.py +144 -0
- agent_runtime/config/policies.py +63 -0
- agent_runtime/config/storage.py +117 -0
- agent_runtime/context.py +10 -0
- agent_runtime/definitions.py +33 -0
- agent_runtime/discovery.py +16 -0
- agent_runtime/exceptions.py +74 -0
- agent_runtime/mcp/__init__.py +28 -0
- agent_runtime/mcp/discovery.py +146 -0
- agent_runtime/mcp/metadata.py +68 -0
- agent_runtime/mcp/utils.py +52 -0
- agent_runtime/model_registry.py +40 -0
- agent_runtime/plugins/__init__.py +4 -0
- agent_runtime/plugins/base.py +90 -0
- agent_runtime/plugins/default.py +19 -0
- agent_runtime/plugins/instructions.py +38 -0
- agent_runtime/plugins/loader.py +59 -0
- agent_runtime/policies.py +15 -0
- agent_runtime/runtime.py +110 -0
- agent_runtime/runtime_engine/__init__.py +22 -0
- agent_runtime/runtime_engine/a2a_bridge.py +190 -0
- agent_runtime/runtime_engine/a2a_task_io.py +165 -0
- agent_runtime/runtime_engine/agent_build.py +315 -0
- agent_runtime/runtime_engine/context.py +469 -0
- agent_runtime/runtime_engine/loading.py +170 -0
- agent_runtime/runtime_engine/observability.py +154 -0
- agent_runtime/runtime_engine/policy_registry.py +98 -0
- agent_runtime/runtime_engine/protocol_tools.py +94 -0
- agent_runtime/runtime_engine/task_flow.py +897 -0
- agent_runtime/runtime_engine/tool_flow.py +332 -0
- agent_runtime/sdk_agent.py +548 -0
- agent_runtime/server/__init__.py +15 -0
- agent_runtime/server/app_factory.py +37 -0
- agent_runtime/server/bootstrap.py +48 -0
- agent_runtime/server/endpoint_utils.py +37 -0
- agent_runtime/server/management.py +107 -0
- agent_runtime/smol/__init__.py +4 -0
- agent_runtime/smol/agents.py +431 -0
- agent_runtime/smol/llm_models.py +212 -0
- agent_runtime/smol/memory.py +111 -0
- agent_runtime/smol/models.py +69 -0
- agent_runtime/standalone.py +57 -0
- agent_runtime/storage.py +5 -0
- agent_runtime/tools.py +5 -0
- agent_runtime_sdk-0.1.0.dist-info/METADATA +125 -0
- agent_runtime_sdk-0.1.0.dist-info/RECORD +51 -0
- agent_runtime_sdk-0.1.0.dist-info/WHEEL +5 -0
- agent_runtime_sdk-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .base import BaseAgentPlugin
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DefaultPlugin(BaseAgentPlugin):
    """Fallback plugin used when no custom plugin is configured.

    The framework's base prompt is injected by the runtime; this plugin only
    contributes the extra instructions found in its own config.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Store the optional ``instructions`` entry of ``config``.

        Args:
            config: Plugin configuration; ``None`` is treated as an empty
                mapping.
        """
        super().__init__(config)
        # ``dict.get`` only falls back to its default when the key is absent.
        # A key that is present with an explicit null would otherwise leave
        # ``None`` here and make ``build_instructions`` fail on ``.strip()``.
        raw_instructions = (config or {}).get("instructions")
        self._extra_instructions: str = (
            "" if raw_instructions is None else raw_instructions
        )

    def build_instructions(self, definition, tools: list) -> str:
        """Return the configured extra instructions with outer whitespace removed.

        ``definition`` and ``tools`` are accepted for interface compatibility
        and are not used by this implementation.
        """
        return self._extra_instructions.strip()
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def join_instruction_sections(*sections: str | None) -> str:
    """Join the non-empty instruction sections with a blank line between them.

    ``None`` entries are skipped. Every other entry is converted with
    ``str()`` and stripped; entries that end up empty are dropped.
    """
    stripped = (str(entry).strip() for entry in sections if entry is not None)
    return "\n\n".join(chunk for chunk in stripped if chunk)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def build_framework_instructions(definition: Any, tools: list[Any]) -> str:
    """Render the framework-level system prompt.

    The prompt names the agent (``definition.agent.name``), lists each tool
    with its optional description and input names, and appends the fixed
    usage rules about ``ask_user``, ``ask_auth`` and ``final_answer``.
    """

    def render_tool(tool: Any) -> str:
        # One markdown bullet per tool: name, then optional description,
        # then optional parameter names.
        bullet = f"- **{getattr(tool, 'name', '')}**"
        summary = getattr(tool, "description", "") or ""
        if summary:
            bullet += f": {summary}"
        params = getattr(tool, "inputs", None) or {}
        if params:
            bullet += f" (参数: {', '.join(params.keys())})"
        return bullet

    bullets = [render_tool(tool) for tool in tools]
    tool_section = "\n".join(bullets) if bullets else "(暂无工具)"

    pieces = [
        f"你是 {definition.agent.name},一个智能助手。\n\n",
        f"## 可用工具\n{tool_section}\n\n",
        "## 使用规则\n",
        "- 需要补充信息时,调用 ask_user\n",
        "- 需要用户授权确认时,调用 ask_auth\n",
        "- 最终回答必须通过 final_answer 输出\n",
        "- 不要猜测参数值,不确定就向用户确认\n",
    ]
    return "".join(pieces)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import importlib.util
|
|
5
|
+
import logging
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from .base import BaseAgentPlugin
|
|
10
|
+
from ..exceptions import PluginLoadError
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _load_module(module_path: str):
    """Import a plugin module from a file path or a dotted module name.

    A value that ends in ``.py`` or contains ``/`` is treated as a file: it is
    resolved, executed under a generated ``dynamic_plugin_<stem>_<hash>`` name
    and registered in ``sys.modules`` under that name. Anything else is
    imported with ``importlib.import_module``.

    Args:
        module_path: File system path or dotted module name.

    Returns:
        The loaded module object.

    Raises:
        ImportError: If no import spec/loader can be built for the file, or
            the dotted module cannot be imported.
        Exception: Whatever the module body raises while executing.
    """
    if module_path.endswith(".py") or "/" in module_path:
        path = Path(module_path).resolve()
        module_name = f"dynamic_plugin_{path.stem}_{abs(hash(path))}"
        spec = importlib.util.spec_from_file_location(module_name, path)
        if spec is None or spec.loader is None:
            raise ImportError(f"unable to load plugin module from {path}")
        module = importlib.util.module_from_spec(spec)

        # The module has to be registered before execution so that code in
        # the plugin can resolve itself through sys.modules. If execution
        # fails, roll the registration back so a half-initialised module is
        # never left behind; a previously loaded good copy is restored.
        previous = sys.modules.get(module_name)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            if previous is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous
            raise
        return module

    # Make sure modules created after interpreter start-up are discoverable.
    importlib.invalidate_caches()
    return importlib.import_module(module_path)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def load_plugin(module_path: str, class_name: str, config: dict | None = None) -> BaseAgentPlugin:
    """Load, instantiate and validate a plugin class.

    Args:
        module_path: File path or dotted module name handed to
            ``_load_module``.
        class_name: Attribute name of the plugin class inside that module.
        config: Passed to the plugin constructor as ``config``; a falsy value
            is replaced by an empty dict.

    Returns:
        The plugin instance.

    Raises:
        PluginLoadError: If the module cannot be loaded, the class is
            missing, construction fails, or the instance is not a
            ``BaseAgentPlugin``.
    """
    # Step 1: import the module.
    try:
        loaded_module = _load_module(module_path)
    except Exception as load_failure:
        logger.exception("Failed to load plugin module path=%s", module_path)
        detail = f"failed to load plugin module '{module_path}': {load_failure}"
        raise PluginLoadError(detail) from load_failure

    # Step 2: look the class up by name.
    try:
        factory = getattr(loaded_module, class_name)
    except AttributeError as lookup_failure:
        logger.exception("Plugin class not found module=%s class=%s", module_path, class_name)
        detail = f"plugin class '{class_name}' not found in '{module_path}'"
        raise PluginLoadError(detail) from lookup_failure

    # Step 3: construct it.
    try:
        instance = factory(config=config or {})
    except Exception as init_failure:
        logger.exception("Failed to instantiate plugin module=%s class=%s", module_path, class_name)
        detail = f"failed to instantiate plugin '{module_path}:{class_name}': {init_failure}"
        raise PluginLoadError(detail) from init_failure

    # Step 4: enforce the base-class contract on the resulting object.
    if isinstance(instance, BaseAgentPlugin):
        logger.debug("Loaded plugin module=%s class=%s", module_path, class_name)
        return instance
    raise PluginLoadError(f"{module_path}:{class_name} must inherit BaseAgentPlugin")
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Public facade for tool-policy helper functions."""
|
|
2
|
+
|
|
3
|
+
from .config.policies import (
|
|
4
|
+
collect_missing_fields,
|
|
5
|
+
is_missing,
|
|
6
|
+
merge_input_fields,
|
|
7
|
+
parse_user_payload,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"collect_missing_fields",
|
|
12
|
+
"is_missing",
|
|
13
|
+
"merge_input_fields",
|
|
14
|
+
"parse_user_payload",
|
|
15
|
+
]
|
agent_runtime/runtime.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""框架的运行时门面。
|
|
4
|
+
|
|
5
|
+
这个文件只保留 runtime 的核心状态与 public API。
|
|
6
|
+
具体实现按职责拆到 runtime_engine/:
|
|
7
|
+
|
|
8
|
+
- loading.py: 插件加载与工具发现
|
|
9
|
+
- agent_build.py: 单次请求的 agent 构建
|
|
10
|
+
- tool_flow.py: 工具调用前的治理与中断
|
|
11
|
+
- task_flow.py: 任务生命周期与等待恢复
|
|
12
|
+
- a2a_bridge.py: A2A / ASGI 桥接
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from typing import Callable
|
|
16
|
+
|
|
17
|
+
from .config.definitions import AgentDefinition, ToolPolicy
|
|
18
|
+
from .mcp.discovery import default_discover_tools
|
|
19
|
+
from .mcp.metadata import DiscoveredTool
|
|
20
|
+
from .plugins import BaseAgentPlugin
|
|
21
|
+
from .runtime_engine.context import TaskPool
|
|
22
|
+
from .runtime_engine.a2a_bridge import RuntimeA2ABridge
|
|
23
|
+
from .runtime_engine.agent_build import RuntimeAgentBuild
|
|
24
|
+
from .runtime_engine.loading import RuntimeLoading
|
|
25
|
+
from .runtime_engine.observability import setup_smolagents_observability
|
|
26
|
+
from .runtime_engine.policy_registry import MCPPolicyRegistry
|
|
27
|
+
from .runtime_engine.task_flow import RuntimeTaskFlow
|
|
28
|
+
from .runtime_engine.tool_flow import RuntimeToolFlow
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ManagedAgentRuntime(
    RuntimeA2ABridge,
    RuntimeTaskFlow,
    RuntimeToolFlow,
    RuntimeAgentBuild,
    RuntimeLoading,
):
    """Runtime instance backing a single agent definition.

    This class is deliberately only a facade:

    - it holds the runtime-level state
    - it exposes the public methods
    - it gains its full behaviour from the responsibility-specific mixins in
      ``runtime_engine``
    """

    def __init__(
        self,
        definition: AgentDefinition,
        public_base_url: str,
        discoverer: Callable[..., list[DiscoveredTool]] | None = None,
    ):
        """Initialise the runtime state for ``definition``.

        Args:
            definition: The agent definition this runtime serves.
            public_base_url: Externally visible base URL; trailing slashes
                are stripped before it is stored.
            discoverer: Tool-discovery callable; ``default_discover_tools``
                is used when this is ``None`` (or otherwise falsy).
        """
        self.definition = definition
        self.public_base_url = public_base_url.rstrip("/")
        self.discoverer = discoverer or default_discover_tools
        self.task_pool = TaskPool()
        self.plugin: BaseAgentPlugin | None = None
        # Only the tool summaries from the discovery phase are cached here,
        # not executable tool instances.
        self.discovered_tools: list[DiscoveredTool] = []
        # A truthy load_error flips ``status`` to "error".
        # NOTE(review): presumably both fields are filled in by the loading
        # mixin; that code is not visible from this class.
        self.load_error: str | None = None
        self.load_exception: Exception | None = None
        # Lazily built ASGI application (see ``get_asgi_app`` on the bridge mixin).
        self._asgi_app = None
        self._failed_task_ids: set[str] = set()
        # Passed to MCPPolicyRegistry as its second argument.
        self._tool_source_by_name: dict[str, str] = {}
        setup_smolagents_observability()

    @property
    def status(self) -> str:
        """Return ``"error"`` when ``load_error`` is set, otherwise ``"ready"``."""
        return "error" if self.load_error else "ready"

    def public_url(self) -> str:
        """Return the public base URL (already stripped of trailing slashes)."""
        return self.public_base_url

    @property
    def _agent(self):
        """Shortcut for ``self.definition.agent``."""
        return self.definition.agent

    @property
    def _runtime_config(self):
        """Shortcut for ``self.definition.runtime``."""
        return self.definition.runtime

    @property
    def _mcps(self):
        """Shortcut for ``self.definition.mcps``."""
        return self.definition.mcps

    @property
    def _policy_registry(self) -> MCPPolicyRegistry:
        """Build a policy registry view over the current MCPs.

        A new ``MCPPolicyRegistry`` is constructed on every access from
        ``self._mcps`` and ``self._tool_source_by_name``.
        """
        return MCPPolicyRegistry(self._mcps, self._tool_source_by_name)

    def set_instructions(self, text: str) -> None:
        """Replace the agent's ``extra_instructions`` with ``text``."""
        self._agent.extra_instructions = text

    def get_instructions(self) -> str:
        """Return the agent's current ``extra_instructions``."""
        return self._agent.extra_instructions

    def set_tool_policy(
        self, tool_name: str, policy: ToolPolicy, mcp_name: str | None = None
    ) -> None:
        """Delegate to ``MCPPolicyRegistry.set_tool_policy``."""
        self._policy_registry.set_tool_policy(tool_name, policy, mcp_name)

    def get_tool_policy(
        self, tool_name: str, mcp_name: str | None = None
    ) -> ToolPolicy | None:
        """Delegate to ``MCPPolicyRegistry.get_tool_policy``."""
        return self._policy_registry.get_tool_policy(tool_name, mcp_name)

    def list_tool_policies(self, mcp_name: str | None = None) -> dict[str, ToolPolicy]:
        """Delegate to ``MCPPolicyRegistry.list_tool_policies``."""
        return self._policy_registry.list_tool_policies(mcp_name)

    def remove_tool_policy(self, tool_name: str, mcp_name: str | None = None) -> bool:
        """Delegate to ``MCPPolicyRegistry.remove_tool_policy``."""
        return self._policy_registry.remove_tool_policy(tool_name, mcp_name)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""ManagedAgentRuntime 的内部引擎模块
|
|
2
|
+
|
|
3
|
+
- `runtime.py` 只保留 runtime 的核心状态和 public API
|
|
4
|
+
- `runtime_engine/` 按职责拆开内部实现,避免所有逻辑都堆在一个大文件里
|
|
5
|
+
|
|
6
|
+
阅读顺序建议:
|
|
7
|
+
|
|
8
|
+
1. `loading.py`: reload 时加载 plugin,并发现 MCP 工具
|
|
9
|
+
2. `agent_build.py`: 单次请求到来时,把 model/tools/instructions 组装成 agent
|
|
10
|
+
3. `tool_flow.py`: 工具真正执行前,做缺参补齐和授权确认
|
|
11
|
+
4. `task_flow.py`: 管理任务的启动、等待、恢复、超时、取消、完成
|
|
12
|
+
5. `a2a_bridge.py`: 把 runtime 接到 A2A / ASGI
|
|
13
|
+
|
|
14
|
+
辅助模块:
|
|
15
|
+
|
|
16
|
+
- `context.py`: 任务上下文、contextvars、wait state
|
|
17
|
+
- `a2a_task_io.py`: A2A 请求解析和任务消息输出
|
|
18
|
+
- `protocol_tools.py`: ask_user / ask_auth / final_answer 这三个协议工具
|
|
19
|
+
- `policy_registry.py`: MCP 级工具策略的统一读写
|
|
20
|
+
|
|
21
|
+
这些模块只服务于 runtime 门面,不作为对外 public contract。
|
|
22
|
+
"""
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""A2A / ASGI 桥接层。
|
|
2
|
+
|
|
3
|
+
这个模块把 `ManagedAgentRuntime` 暴露成 A2A server:
|
|
4
|
+
|
|
5
|
+
- 构建 agent card
|
|
6
|
+
- 创建 A2A 应用
|
|
7
|
+
- 把 A2A 的 execute / cancel 生命周期转发到 runtime task flow
|
|
8
|
+
|
|
9
|
+
具体的 A2A 请求解析和任务消息组装,已经下沉到 `a2a_task_io.py`。
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import asyncio
|
|
15
|
+
import logging
|
|
16
|
+
|
|
17
|
+
from .a2a_task_io import (
|
|
18
|
+
extract_initial_text,
|
|
19
|
+
extract_request_headers,
|
|
20
|
+
extract_user_input,
|
|
21
|
+
publish_error,
|
|
22
|
+
)
|
|
23
|
+
from .context import current_task_id, current_task_pool
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RuntimeA2ABridge:
    """Bridge layer that connects the runtime to the A2A protocol and an ASGI app.

    This is a mixin: it reads ``self._agent``, ``self._asgi_app``,
    ``self.task_pool`` and calls ``self.public_url()``,
    ``self.pop_failed_task_id()``, ``self._start_new_task()``,
    ``self._run_task_cycle()``, ``self._finalize_task()`` and
    ``self._cancel_task()``, none of which are defined in this class.
    """

    def build_agent_card(self):
        """Build the A2A ``AgentCard`` from the agent definition.

        Skills are copied field by field from ``self._agent.a2a.skills``;
        streaming is always advertised as supported.
        """
        from a2a.types import AgentCapabilities, AgentCard, AgentSkill

        skills = [
            AgentSkill(
                id=skill.id,
                name=skill.name,
                description=skill.description,
                tags=skill.tags,
                examples=skill.examples,
            )
            for skill in self._agent.a2a.skills
        ]
        return AgentCard(
            name=self._agent.name,
            description=self._agent.description,
            url=self.public_url(),
            version=self._agent.version,
            default_input_modes=self._agent.a2a.default_input_modes,
            default_output_modes=self._agent.a2a.default_output_modes,
            capabilities=AgentCapabilities(streaming=True),
            skills=skills,
        )

    def get_asgi_app(self):
        """Return the ASGI application, building and caching it on first use.

        The agent card is built once at construction time, together with an
        in-memory task store shared by the request handler and the executor.
        """
        if self._asgi_app is None:
            from a2a.server.apps import A2AStarletteApplication
            from a2a.server.request_handlers import DefaultRequestHandler
            from a2a.server.tasks import InMemoryTaskStore

            # Captured by the nested handler class, where ``self`` refers to
            # the handler instead of the runtime.
            runtime = self
            task_store = InMemoryTaskStore()
            executor = self._build_executor(task_store=task_store)

            class CleanupRequestHandler(DefaultRequestHandler):
                async def _cleanup_producer(
                    self, producer_task: asyncio.Task, task_id: str
                ) -> None:
                    # Run the stock cleanup first, then drop the task from the
                    # store only if the runtime recorded it as failed.
                    await super()._cleanup_producer(producer_task, task_id)
                    if not runtime.pop_failed_task_id(task_id):
                        return
                    try:
                        await self.task_store.delete(task_id)
                    except Exception:
                        # Best effort: a failed delete is logged, not raised.
                        logger.warning(
                            "Failed to delete failed task from A2A task_store task_id=%s",
                            task_id,
                            exc_info=True,
                        )

            handler = CleanupRequestHandler(
                agent_executor=executor, task_store=task_store
            )
            self._asgi_app = A2AStarletteApplication(
                agent_card=self.build_agent_card(),
                http_handler=handler,
            ).build()
        return self._asgi_app

    def _build_executor(self, task_store=None):
        """Create the A2A ``AgentExecutor`` that forwards to this runtime.

        Args:
            task_store: Passed through unchanged to the runtime's task-flow
                methods.

        Returns:
            A new ``RuntimeExecutor`` instance bound to this runtime.
        """
        from a2a.server.agent_execution import AgentExecutor
        from a2a.server.tasks import TaskUpdater

        class RuntimeExecutor(AgentExecutor):
            async def execute(self, context, event_queue):
                """Start or resume the task identified by ``context.task_id``."""
                main_loop = asyncio.get_running_loop()
                updater = TaskUpdater(event_queue, context.task_id, context.context_id)

                if not context.task_id:
                    return

                # Publish the task id and pool for code further down the call
                # chain (presumably context variables; see ``.context``).
                current_task_id.set(context.task_id)
                current_task_pool.set(self_runtime.task_pool)

                try:
                    # No current task on the context means a new request.
                    if not context.current_task:
                        await self_runtime._start_new_task(
                            context.task_id,
                            extract_initial_text(context),
                            extract_request_headers(context),
                            main_loop,
                            updater,
                            context_id=context.context_id,
                            task_store=task_store,
                        )

                    # Point the pooled task at this request's updater and loop.
                    task_info = self_runtime.task_pool.get(context.task_id)
                    if task_info:
                        task_info.updater = updater
                        task_info.loop = main_loop

                    from a2a.types import TaskState

                    await updater.update_status(TaskState.working)

                    # Only a resumed task carries user input to feed back in.
                    user_input = (
                        extract_user_input(context, task_info)
                        if context.current_task
                        else None
                    )
                    await self_runtime._run_task_cycle(
                        context.task_id,
                        updater,
                        user_input,
                        task_store=task_store,
                    )
                except Exception as exc:
                    logger.exception(
                        "Unhandled task execution error agent_id=%s task_id=%s",
                        self_runtime._agent.agent_id,
                        context.task_id,
                    )
                    # Finalize through the runtime when the task is still
                    # tracked and not finalized; otherwise just publish the
                    # failure on the updater.
                    task_info = self_runtime.task_pool.get(context.task_id)
                    if task_info is not None and not task_info.finalized:
                        await self_runtime._finalize_task(
                            context.task_id,
                            updater,
                            error=exc,
                            task_store=task_store,
                        )
                    else:
                        await publish_error(updater, exc)

            async def cancel(self, context, event_queue):
                """Cancel the task unless it is unknown or already finalized."""
                if not context.task_id:
                    return

                task_info = self_runtime.task_pool.get(context.task_id)
                if not task_info or task_info.finalized:
                    return

                # Reuse the updater attached during execute() when present.
                updater = task_info.updater or TaskUpdater(
                    event_queue, context.task_id, context.context_id
                )
                try:
                    await self_runtime._cancel_task(
                        context.task_id,
                        updater,
                        task_store=task_store,
                    )
                except Exception as exc:
                    logger.exception(
                        "Unhandled task cancellation error agent_id=%s task_id=%s",
                        self_runtime._agent.agent_id,
                        context.task_id,
                    )
                    # Same fallback as in execute().
                    task_info = self_runtime.task_pool.get(context.task_id)
                    if task_info is not None and not task_info.finalized:
                        await self_runtime._finalize_task(
                            context.task_id,
                            updater,
                            error=exc,
                            task_store=task_store,
                        )
                    else:
                        await publish_error(updater, exc)

        # Bound after the class body on purpose: the methods above only look
        # ``self_runtime`` up when they run, by which time it is assigned.
        self_runtime = self
        return RuntimeExecutor()
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""A2A 任务传输适配层。
|
|
2
|
+
|
|
3
|
+
这个模块只处理 A2A 协议相关的任务 I/O:
|
|
4
|
+
|
|
5
|
+
- 从 A2A context 提取 headers、初始文本、resume 输入
|
|
6
|
+
- 把 runtime 的 wait state / 完成 / 失败结果转成 A2A TaskUpdater 调用
|
|
7
|
+
|
|
8
|
+
这样 `task_flow.py` 只管任务调度,不再直接依赖 A2A message 细节。
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from .context import (
|
|
16
|
+
TaskContext,
|
|
17
|
+
TaskUpdaterProtocol,
|
|
18
|
+
WAIT_TYPE_AUTH_REQUIRED,
|
|
19
|
+
WAIT_TYPE_INPUT_REQUIRED,
|
|
20
|
+
wait_state_type,
|
|
21
|
+
)
|
|
22
|
+
from ..exceptions import user_message_for_error
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def extract_request_headers(context: Any) -> dict[str, str]:
    """Pick the session id, access token and user e-mail out of the request.

    Headers are read from ``context.call_context.state["headers"]`` when
    available. The session id and access token are looked up under a
    lower-case key first and a mixed-case key second; every missing value
    defaults to an empty string.
    """
    call_context = context.call_context
    raw_headers = {}
    if call_context and "headers" in call_context.state:
        raw_headers = call_context.state.get("headers") or {}

    session_id = raw_headers.get("x-session-id") or raw_headers.get("X-Session-Id", "")
    access_token = raw_headers.get("accesstoken") or raw_headers.get("accessToken", "")
    user_email = raw_headers.get("user_email", "")

    picked = {}
    picked["X-Session-Id"] = session_id
    picked["accessToken"] = access_token
    picked["user_email"] = user_email
    return picked
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def extract_initial_text(context: Any) -> str:
    """Return the text of the incoming message, one text part per line.

    Only parts whose ``root.kind`` is ``"text"`` contribute; the joined
    result is stripped of leading and trailing whitespace.

    Args:
        context: Request context exposing ``message`` with a ``parts`` list.

    Returns:
        The joined text, or ``""`` when there is no message or no text part.
    """
    message = context.message
    # A request without a message is handled the same way the sibling
    # ``extract_user_input`` handles it: treat it as "no input" instead of
    # failing on ``None.parts``.
    if not message:
        return ""
    return "\n".join(
        part.root.text for part in message.parts if part.root.kind == "text"
    ).strip()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Sentinel meaning "no usable data part seen yet" in extract_user_input.
# A dedicated object is needed because a data part may legitimately be False.
_MISSING = object()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def extract_user_input(
    context: Any, task_info: TaskContext | None
) -> str | dict | bool | None:
    """Turn the message of a resumed request into the value fed back to the task.

    Text parts are collected and joined with newlines; the first data part
    whose payload is a ``dict`` or ``bool`` is remembered. Which one wins
    depends on what the task is waiting for:

    1. waiting for authorization and a data part exists -> the data part
    2. any text part exists -> the joined, stripped text
    3. waiting for input and the data part is a dict -> the first non-``None``
       value among ``answer`` / ``text`` / ``input`` / ``value``, as ``str``
    4. the data part is a dict -> that dict
    5. otherwise ``None`` (also when there is no message at all)
    """
    incoming = context.message
    if not incoming:
        return None

    collected_text: list[str] = []
    data_payload: Any = _MISSING
    for item in incoming.parts:
        inner = getattr(item, "root", None)
        if not inner:
            continue
        part_kind = getattr(inner, "kind", None)
        if part_kind == "text":
            raw_text = getattr(inner, "text", None)
            if raw_text is not None:
                collected_text.append(str(raw_text))
            continue
        # Only the first acceptable data part is kept.
        if part_kind != "data" or data_payload is not _MISSING:
            continue
        candidate = getattr(inner, "data", None)
        if isinstance(candidate, (dict, bool)):
            data_payload = candidate

    joined_text = "\n".join(collected_text).strip() if collected_text else None
    pending = wait_state_type(task_info.wait_item) if task_info else None

    if pending == WAIT_TYPE_AUTH_REQUIRED and data_payload is not _MISSING:
        return data_payload
    if joined_text is not None:
        return joined_text
    if not isinstance(data_payload, dict):
        return None
    if pending == WAIT_TYPE_INPUT_REQUIRED:
        for field in ("answer", "text", "input", "value"):
            answer = data_payload.get(field)
            if answer is not None:
                return str(answer)
    return data_payload
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
async def emit_wait_state(
    task_updater: TaskUpdaterProtocol | Any, payload: dict[str, Any]
) -> None:
    """Publish a runtime wait state through the A2A task updater.

    Args:
        task_updater: Updater used to build and send the status message.
        payload: Mapping with a ``type`` (one of the wait-type constants) and
            a ``data`` entry.

    Raises:
        ValueError: If ``payload["type"]`` is neither the input-required nor
            the auth-required wait type.
    """
    from a2a.types import Part, TaskState

    kind = payload["type"]
    details = payload["data"]

    if kind == WAIT_TYPE_INPUT_REQUIRED:
        # The prompt is optional; without one the state change carries no message.
        question = details.get("prompt")
        outgoing = None
        if question:
            outgoing = task_updater.new_agent_message(
                parts=[Part.model_validate({"text": question})],
                metadata={
                    "event": kind,
                    "state": TaskState.input_required.value,
                },
            )
        await task_updater.requires_input(message=outgoing)
    elif kind == WAIT_TYPE_AUTH_REQUIRED:
        # The whole wait payload travels as a data part; the same metadata is
        # attached to both the message and the status update.
        data_part = Part.model_validate({"kind": "data", "data": details})
        outgoing = task_updater.new_agent_message(
            parts=[data_part],
            metadata={
                "event": kind,
                "state": TaskState.auth_required.value,
            },
        )
        await task_updater.update_status(
            TaskState.auth_required,
            message=outgoing,
            metadata={
                "event": kind,
                "state": TaskState.auth_required.value,
            },
        )
    else:
        raise ValueError(f"unsupported wait state type: {kind}")
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
async def publish_message_completion(
    task_updater: TaskUpdaterProtocol | Any, message_text: str
) -> None:
    """Complete the task with a single text message.

    The message carries ``event="completed"`` metadata together with the
    completed task state value.
    """
    from a2a.types import Part, TaskState

    text_part = Part.model_validate({"text": message_text})
    completion_meta = {"event": "completed", "state": TaskState.completed.value}
    reply = task_updater.new_agent_message(parts=[text_part], metadata=completion_meta)
    await task_updater.complete(message=reply)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
async def publish_error(
    task_updater: TaskUpdaterProtocol | Any, error: Exception
) -> None:
    """Mark the task as failed with a user-facing rendering of ``error``.

    The text comes from ``user_message_for_error``; the message carries
    ``event="failed"`` metadata together with the failed task state value.
    """
    from a2a.types import Part, TaskState

    failure_part = Part.model_validate({"text": user_message_for_error(error)})
    failure_meta = {"event": "failed", "state": TaskState.failed.value}
    reply = task_updater.new_agent_message(parts=[failure_part], metadata=failure_meta)
    await task_updater.failed(message=reply)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
async def publish_result(
    task_updater: TaskUpdaterProtocol | Any, result: Any
) -> None:
    """Complete the task with ``result`` as its final message.

    A ``dict`` result is sent as a data part; anything else is converted with
    ``str()`` and sent as a text part.
    """
    from a2a.types import Part, TaskState

    part_spec = (
        {"kind": "data", "data": result}
        if isinstance(result, dict)
        else {"text": str(result)}
    )
    final_parts = [Part.model_validate(part_spec)]
    completion_meta = {"event": "completed", "state": TaskState.completed.value}
    reply = task_updater.new_agent_message(parts=final_parts, metadata=completion_meta)
    await task_updater.complete(message=reply)
|