nvidia-nat-a2a 1.5.0a20251229__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-a2a might be problematic. Click here for more details.
- nat/meta/pypi.md +36 -0
- nat/plugins/a2a/__init__.py +14 -0
- nat/plugins/a2a/auth/__init__.py +15 -0
- nat/plugins/a2a/auth/credential_service.py +418 -0
- nat/plugins/a2a/client/__init__.py +14 -0
- nat/plugins/a2a/client/client_base.py +354 -0
- nat/plugins/a2a/client/client_config.py +72 -0
- nat/plugins/a2a/client/client_impl.py +324 -0
- nat/plugins/a2a/register.py +23 -0
- nat/plugins/a2a/server/__init__.py +14 -0
- nat/plugins/a2a/server/agent_executor_adapter.py +172 -0
- nat/plugins/a2a/server/front_end_config.py +131 -0
- nat/plugins/a2a/server/front_end_plugin.py +122 -0
- nat/plugins/a2a/server/front_end_plugin_worker.py +306 -0
- nat/plugins/a2a/server/oauth_middleware.py +121 -0
- nat/plugins/a2a/server/register_frontend.py +37 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/METADATA +57 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/RECORD +22 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/WHEEL +5 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/entry_points.txt +5 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from collections.abc import AsyncGenerator
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
from pydantic import Field
|
|
23
|
+
|
|
24
|
+
from nat.builder.function import FunctionGroup
|
|
25
|
+
from nat.builder.workflow_builder import Builder
|
|
26
|
+
from nat.cli.register_workflow import register_per_user_function_group
|
|
27
|
+
from nat.plugins.a2a.client.client_base import A2ABaseClient
|
|
28
|
+
from nat.plugins.a2a.client.client_config import A2AClientConfig
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Input models for helper functions
|
|
37
|
+
class GetTaskInput(BaseModel):
    """Arguments accepted by the ``get_task`` helper function."""

    # Identifier of the task to look up (required).
    task_id: str = Field(description="The ID of the task to retrieve")
    # Optional cap on how many history items the agent should return with the task.
    history_length: int | None = Field(None, description="Number of history items to include")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class CancelTaskInput(BaseModel):
    """Arguments accepted by the ``cancel_task`` helper function."""

    # Identifier of the task to cancel (required).
    task_id: str = Field(description="The ID of the task to cancel")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class SendMessageInput(BaseModel):
    """Arguments accepted by the low-level ``send_message`` functions."""

    # Natural-language query forwarded to the remote agent (required).
    query: str = Field(description="The query to send to the agent")
    # Optional identifiers that let the caller continue an existing task / context.
    task_id: str | None = Field(None, description="Optional task ID for continuation")
    context_id: str | None = Field(None, description="Optional context ID for session management")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class A2AClientFunctionGroup(FunctionGroup):
    """
    A minimal FunctionGroup for A2A agents.

    Entering the group (``async with``) connects to the remote A2A agent, reads its
    agent card and registers a three-level API:

    - ``call``: high-level, LLM-friendly function that returns only the agent's text reply
    - ``get_skills``, ``get_info``, ``get_task``, ``cancel_task``: metadata and task helpers
    - ``send_message``, ``send_message_streaming``: low-level access to the raw A2A events

    Exiting the group closes the underlying A2A client.
    """

    def __init__(self, config: A2AClientConfig, builder: Builder):
        """Store configuration; no connection is made until ``__aenter__``.

        Args:
            config: A2A client configuration (agent URL, auth provider, timeouts, ...).
            builder: Workflow builder used to resolve the configured auth provider.
        """
        super().__init__(config=config)
        self._builder = builder
        # Set by __aenter__ once the client is connected; cleared again on exit.
        self._client: A2ABaseClient | None = None
        self._include_skills_in_description = config.include_skills_in_description

    def _require_client(self) -> A2ABaseClient:
        """Return the connected A2A client, or raise if the group has not been entered."""
        if not self._client:
            raise RuntimeError("A2A client not initialized")
        return self._client

    async def __aenter__(self):
        """Initialize the A2A client and register functions.

        Returns:
            This function group, connected and with all functions registered.

        Raises:
            RuntimeError: If no user ID is present in the current context, the configured
                auth provider cannot be resolved, or the agent card is not available.
        """
        config: A2AClientConfig = self._config  # type: ignore[assignment]
        base_url = str(config.url)

        # Get user_id from context (set by runtime for per-user function groups)
        from nat.builder.context import Context
        user_id = Context.get().user_id
        if not user_id:
            raise RuntimeError("User ID not found in context")

        # Resolve auth provider if configured
        auth_provider: AuthProviderBase | None = None
        if config.auth_provider:
            try:
                auth_provider = await self._builder.get_auth_provider(config.auth_provider)
                logger.info("Resolved authentication provider for A2A client")
            except Exception as e:
                logger.error("Failed to resolve auth provider '%s': %s", config.auth_provider, e)
                raise RuntimeError(f"Failed to resolve auth provider: {e}") from e

        # Create and initialize A2A client. Only publish it on self once it has
        # been entered successfully.
        client = A2ABaseClient(
            base_url=base_url,
            agent_card_path=config.agent_card_path,
            task_timeout=config.task_timeout,
            streaming=config.streaming,
            auth_provider=auth_provider,
        )
        await client.__aenter__()
        self._client = client

        # The client now holds open resources. `async with` only calls our
        # __aexit__ after a *successful* __aenter__, so if anything below fails
        # we must close the client ourselves before propagating the error.
        try:
            if auth_provider:
                logger.info("Connected to A2A agent at %s with authentication (user_id: %s)", base_url, user_id)
            else:
                logger.info("Connected to A2A agent at %s (user_id: %s)", base_url, user_id)

            # Discover agent card and register functions
            self._register_functions()
        except BaseException as exc:
            self._client = None
            try:
                await client.__aexit__(type(exc), exc, exc.__traceback__)
            except Exception:
                logger.exception("Error while closing A2A client after failed initialization")
            raise

        return self

    def _register_functions(self):
        """Retrieve agent card and register the three-level API: high-level, helpers, and low-level.

        Raises:
            RuntimeError: If the client is not initialized or the agent card is missing.
        """
        client = self._require_client()

        # Get and validate agent card
        agent_card = client.agent_card
        if not agent_card:
            raise RuntimeError("Agent card not available")

        # Log agent information
        logger.info("Agent: %s v%s", agent_card.name, agent_card.version)
        if agent_card.skills:
            logger.info("Skills: %s", [skill.name for skill in agent_card.skills])

        # Register functions
        # LEVEL 1: High-level main function (LLM-friendly)
        self.add_function(
            name="call",
            fn=self._create_high_level_function(),
            description=self._format_main_function_description(agent_card),
        )

        # LEVEL 2: Standard helpers (metadata/utility)
        self.add_function(
            name="get_skills",
            fn=self._get_skills,
            description="Get the list of skills and capabilities available from this agent",
        )

        self.add_function(
            name="get_info",
            fn=self._get_agent_info,
            description="Get metadata about this agent (name, version, provider, capabilities)",
        )

        self.add_function(
            name="get_task",
            fn=self._wrap_get_task,
            description="Get the status and details of a specific task by task_id",
        )

        self.add_function(
            name="cancel_task",
            fn=self._wrap_cancel_task,
            description="Cancel a running task by task_id",
        )

        # LEVEL 3: Low-level protocol (advanced)
        self.add_function(
            name="send_message",
            fn=self._send_message_advanced,
            description=("Advanced: Send a message with full control over the A2A protocol. "
                         "Returns raw events as a list. For most use cases, prefer using the "
                         "high-level 'call()' function instead."),
        )

        self.add_function(
            name="send_message_streaming",
            fn=self._send_message_streaming,
            description=("Advanced: Send a message and stream response events as they arrive. "
                         "Yields raw events one by one. This is an async generator function. "
                         "For most use cases, prefer using the high-level 'call()' function instead."),
        )

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Clean up the A2A client.

        The reference is cleared before closing so a failure while closing does not
        leave a stale client attached to this group. Exceptions are never suppressed.
        """
        client, self._client = self._client, None
        if client is None:
            return
        await client.__aexit__(exc_type, exc_value, traceback)
        logger.info("Disconnected from A2A agent")

    def _format_main_function_description(self, agent_card) -> str:
        """Create description for the main agent function.

        The agent card description is always included. Skills are either listed in
        full (name, description, up to two examples) when
        ``include_skills_in_description`` is enabled, or summarised as a count.
        """
        description = f"{agent_card.description}\n\n"

        # Conditionally include skills based on configuration
        if self._include_skills_in_description and agent_card.skills:
            description += "**Capabilities:**\n"
            for skill in agent_card.skills:
                description += f"\n• **{skill.name}**: {skill.description}"
                if skill.examples:
                    examples = skill.examples[:2]  # Limit to 2 examples
                    description += f"\n Examples: {', '.join(examples)}"
            description += "\n\n"
        elif agent_card.skills:
            # Brief mention that skills are available
            description += f"**{len(agent_card.skills)} capabilities available.** "
            description += "Use get_skills() for detailed information.\n\n"

        description += "**Usage:** Send natural language queries to interact with this agent."

        return description

    def _create_high_level_function(self):
        """Build the high-level ``call`` function, which returns only the text of the reply."""

        async def high_level_fn(query: str, task_id: str | None = None, context_id: str | None = None) -> str:
            """
            Send a query to the agent and get a simple text response.

            This is the recommended method for LLM usage.
            For advanced use cases, use send_message() for raw events.
            """
            client = self._require_client()

            # Collect every event first; text extraction needs the complete set.
            events = [event async for event in client.send_message(query, task_id, context_id)]

            # Extract and return just the text response using base client helper
            return client.extract_text_from_events(events)

        return high_level_fn

    async def _get_skills(self, params: dict | None = None) -> dict:
        """Helper function to list agent skills.

        Returns:
            ``{"agent": <name>, "skills": [...]}``, or ``{"skills": []}`` when no client
            or agent card is available.
        """
        if not self._client or not self._client.agent_card:
            return {"skills": []}

        agent_card = self._client.agent_card
        return {
            "agent":
                agent_card.name,
            "skills": [{
                "id": skill.id,
                "name": skill.name,
                "description": skill.description,
                "examples": skill.examples or [],
                "tags": skill.tags or []
            } for skill in agent_card.skills]
        }

    async def _get_agent_info(self, params: dict | None = None) -> dict:
        """Helper function to get agent metadata.

        Returns an empty dict when no client or agent card is available.
        """
        if not self._client or not self._client.agent_card:
            return {}

        agent_card = self._client.agent_card
        return {
            "name": agent_card.name,
            "description": agent_card.description,
            "version": agent_card.version,
            "provider": agent_card.provider.model_dump() if agent_card.provider else None,
            "url": agent_card.url,
            "capabilities": {
                "streaming": agent_card.capabilities.streaming if agent_card.capabilities else False,
            },
            "num_skills": len(agent_card.skills)
        }

    async def _wrap_get_task(self, params: GetTaskInput) -> Any:
        """Wrapper for get_task that delegates to client_base."""
        return await self._require_client().get_task(params.task_id, params.history_length)

    async def _wrap_cancel_task(self, params: CancelTaskInput) -> Any:
        """Wrapper for cancel_task that delegates to client_base."""
        return await self._require_client().cancel_task(params.task_id)

    async def _send_message_advanced(self, params: SendMessageInput) -> list:
        """
        Send a message with full A2A protocol control.

        Returns: List of ClientEvent|Message objects containing:
            - Task information
            - Status updates
            - Artifact updates
            - Full message details
        """
        client = self._require_client()
        return [event async for event in client.send_message(params.query, params.task_id, params.context_id)]

    async def _send_message_streaming(self, params: SendMessageInput) -> AsyncGenerator[Any, None]:
        """
        Send a message with full A2A protocol control and stream events.

        This is an async generator that yields events as they arrive from the agent.
        Because it is a generator, the "not initialized" check runs on first iteration.

        Yields: ClientEvent|Message objects containing:
            - Task information
            - Status updates
            - Artifact updates
            - Full message details
        """
        client = self._require_client()
        async for event in client.send_message_streaming(params.query, params.task_id, params.context_id):
            yield event
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
@register_per_user_function_group(config_type=A2AClientConfig)
async def a2a_client_function_group(config: A2AClientConfig, _builder: Builder):
    """
    Connect to an A2A agent, discover its agent card and publish the primary
    agent function and helper functions. This function group is per-user,
    meaning each user gets their own isolated instance.

    This function group creates a three-level API:
    - High-level: ``call``, which sends a query and returns the agent's text response
    - Helpers: get_skills, get_info, get_task, cancel_task
    - Low-level: send_message and send_message_streaming, which expose raw A2A events

    Args:
        config: A2A client configuration for the remote agent.
        _builder: Workflow builder, used to resolve the configured auth provider.

    Yields:
        The connected ``A2AClientFunctionGroup``; its A2A client is closed when the
        ``async with`` block below exits.
    """
    async with A2AClientFunctionGroup(config, _builder) as group:
        yield group
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
# isort:skip_file
|
|
18
|
+
|
|
19
|
+
# Register client components
|
|
20
|
+
from .client import client_impl
|
|
21
|
+
|
|
22
|
+
# Register server/frontend components
|
|
23
|
+
from .server import register_frontend
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Adapter to bridge NAT workflows with A2A AgentExecutor interface.
|
|
16
|
+
|
|
17
|
+
This module implements a message-only A2A agent for Phase 1, providing stateless
|
|
18
|
+
request/response interactions without task lifecycle management.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import logging
|
|
22
|
+
|
|
23
|
+
from a2a.server.agent_execution import AgentExecutor
|
|
24
|
+
from a2a.server.agent_execution import RequestContext
|
|
25
|
+
from a2a.server.events import EventQueue
|
|
26
|
+
from a2a.types import InternalError
|
|
27
|
+
from a2a.types import InvalidParamsError
|
|
28
|
+
from a2a.types import UnsupportedOperationError
|
|
29
|
+
from a2a.utils import new_agent_text_message
|
|
30
|
+
from a2a.utils.errors import ServerError
|
|
31
|
+
from nat.runtime.session import SessionManager
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class NATWorkflowAgentExecutor(AgentExecutor):
    """Adapts NAT workflows to A2A AgentExecutor interface as a message-only agent.

    This adapter implements Phase 1 support for A2A integration, providing stateless
    message-based interactions. Each request is handled independently without maintaining
    conversation state or task lifecycle.

    Key characteristics:
    - Stateless: Each message runs the workflow in its own new session
    - Request/response: Each request produces a single reply (no long-running tasks)
    - Message-only: Replies are Message objects, not Task objects
    - Concurrent: Uses SessionManager's semaphore for concurrency control

    Note: Multi-turn conversations and task-based interactions are deferred to Phase 5.
    """

    def __init__(self, session_manager: SessionManager):
        """Initialize the adapter with a NAT SessionManager.

        Args:
            session_manager: The SessionManager for handling workflow execution
                with concurrency control via semaphore
        """
        self.session_manager = session_manager
        # Log the configured workflow type so the served workflow is visible at startup.
        logger.info("Initialized NATWorkflowAgentExecutor (message-only) for workflow: %s",
                    session_manager.workflow.config.workflow.type)

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Execute the NAT workflow and return a message response.

        This is a message-only implementation (Phase 1):
        1. Extracts the user query from the A2A message
        2. Runs the NAT workflow in a fresh session (no conversation history)
        3. Enqueues the result as an agent text Message (not a Task)

        For Phase 1, each message is handled independently with no state preservation
        between requests. The context_id and task_id from the A2A request are only
        logged and echoed on the reply message; they are not used to look up state.

        Args:
            context: The A2A request context containing the user message
            event_queue: Queue for sending the response message back to the client

        Raises:
            ServerError: With InvalidParamsError if the request has no message or no user
                input; with InternalError if running the workflow or enqueuing the reply fails.
        """
        # Validate the request. Note: _validate_request returns True when the request is INVALID.
        error = self._validate_request(context)
        if error:
            raise ServerError(error=InvalidParamsError())

        # Extract query from the message; an empty query is rejected as invalid params.
        query = context.get_user_input()
        if not query:
            logger.error("No user input found in context")
            raise ServerError(error=InvalidParamsError())

        # Extract IDs for tracing (stored but not used for state management in Phase 1)
        context_id = context.context_id
        task_id = context.task_id

        # Only the first 100 characters of the query are logged.
        logger.info("Processing message-only request (context_id=%s, task_id=%s): %s", context_id, task_id, query[:100])

        try:
            # Run the NAT workflow using SessionManager for proper concurrency handling
            # Each message gets its own independent session (stateless)
            # TODO: Add support for user input callbacks and authentication in later phases
            async with self.session_manager.session() as session:
                async with session.run(query) as runner:
                    # Get the result as a string
                    response_text = await runner.result(to_type=str)

            logger.info("Workflow completed successfully (context_id=%s, task_id=%s)", context_id, task_id)

            # Create and send the response message (message-only pattern).
            # The reply carries the same context_id/task_id as the request.
            response_message = new_agent_text_message(
                response_text,
                context_id=context_id,
                task_id=task_id,
            )
            # This enqueue is inside the try as well, so a failure here also takes the error path below.
            await event_queue.enqueue_event(response_message)

        except Exception as e:
            logger.error("Error executing NAT workflow (context_id=%s, task_id=%s): %s",
                         context_id,
                         task_id,
                         e,
                         exc_info=True)

            # Send error message back to client.
            # NOTE(review): str(e) is forwarded verbatim to the remote caller and may expose
            # internal details (paths, hostnames, upstream errors) - confirm this is intended.
            error_message = new_agent_text_message(
                f"An error occurred while processing your request: {str(e)}",
                context_id=context_id,
                task_id=task_id,
            )
            await event_queue.enqueue_event(error_message)
            # The failure is reported twice: as the text message above and as a protocol-level InternalError.
            raise ServerError(error=InternalError()) from e

    def _validate_request(self, context: RequestContext) -> bool:
        """Validate the incoming request context.

        Args:
            context: The request context to validate

        Returns:
            True if validation fails, False if validation succeeds
            (inverted from the usual "is valid" convention).
        """
        # Basic validation - can be extended as needed
        if not context.message:
            logger.error("Request context has no message")
            return True

        return False

    async def cancel(
        self,
        _context: RequestContext,
        _event_queue: EventQueue,
    ) -> None:
        """Handle task cancellation requests.

        Not applicable for message-only agents in Phase 1. Cancellation is a task-based
        feature that will be implemented in Phase 5 along with long-running task support.

        Args:
            _context: The request context (unused in Phase 1)
            _event_queue: Event queue for sending updates (unused in Phase 1)

        Raises:
            ServerError: Always, wrapping UnsupportedOperationError
        """
        logger.warning("Task cancellation requested but not supported in message-only mode (Phase 1)")
        raise ServerError(error=UnsupportedOperationError())
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic import model_validator
|
|
21
|
+
|
|
22
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
23
|
+
from nat.data_models.front_end import FrontEndBaseConfig
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class A2ACapabilitiesConfig(BaseModel):
    """Capability flags for the A2A agent exposed by this front end."""

    # Streaming responses are on unless explicitly disabled.
    streaming: bool = Field(True, description="Enable streaming responses (default: True)")
    # Push notifications are off unless explicitly enabled.
    push_notifications: bool = Field(False, description="Enable push notifications (default: False)")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class A2AFrontEndConfig(FrontEndBaseConfig, name="a2a"):
    """A2A front end configuration.

    A front end that exposes NeMo Agent toolkit workflows as A2A-compliant remote agents.
    """

    # Server settings
    host: str = Field(
        default="localhost",
        description="Host to bind the server to (default: localhost)",
    )
    port: int = Field(
        default=10000,
        description="Port to bind the server to (default: 10000)",
        ge=0,
        le=65535,
    )
    version: str = Field(
        default="1.0.0",
        description="Version of the agent (default: 1.0.0)",
    )
    log_level: str = Field(
        default="INFO",
        description="Log level for the A2A server (default: INFO)",
    )

    # Agent metadata
    name: str = Field(
        default="NeMo Agent Toolkit A2A Agent",
        description="Name of the A2A agent (default: NeMo Agent Toolkit A2A Agent)",
    )
    description: str = Field(
        default="An AI agent powered by NeMo Agent Toolkit exposed via A2A protocol",
        description="Description of what the agent does (default: generic description)",
    )

    # A2A capabilities
    capabilities: A2ACapabilitiesConfig = Field(
        default_factory=A2ACapabilitiesConfig,
        description="Agent capabilities configuration",
    )

    # Concurrency control
    max_concurrency: int = Field(
        default=8,
        description="Maximum number of concurrent workflow executions (default: 8). "
        "Controls how many A2A requests can execute workflows simultaneously. "
        "Set to 0 or -1 for unlimited concurrency.",
        ge=-1,
    )

    # Content modes
    default_input_modes: list[str] = Field(
        default_factory=lambda: ["text", "text/plain"],
        description="Supported input content types (default: text, text/plain)",
    )
    default_output_modes: list[str] = Field(
        default_factory=lambda: ["text", "text/plain"],
        description="Supported output content types (default: text, text/plain)",
    )

    # Optional customization
    runner_class: str | None = Field(
        default=None,
        description="Custom worker class for handling A2A routes (default: built-in worker)",
    )

    # OAuth2 Resource Server (for protecting this A2A agent)
    server_auth: OAuth2ResourceServerConfig | None = Field(
        default=None,
        description=("OAuth 2.0 Resource Server configuration for token verification. "
                     "When configured, the A2A server will validate OAuth2 Bearer tokens on all requests "
                     "except public agent card discovery. Supports both JWT validation (via JWKS) and "
                     "opaque token validation (via RFC 7662 introspection)."),
    )

    @model_validator(mode="after")
    def validate_security_configuration(self):
        """Validate security configuration to prevent accidental misconfigurations.

        Logs a warning (it does not fail validation) when the server would bind to a
        non-loopback interface without ``server_auth``. ``localhost`` (any case) and
        every loopback IP address (e.g. ``127.0.0.1``, ``127.0.0.2``, ``::1``, ``[::1]``)
        count as local-only.
        """
        import ipaddress

        # Accept bracketed IPv6 literals such as "[::1]" as well as bare ones.
        candidate = self.host.strip().strip("[]")
        if candidate.lower() == "localhost":
            is_loopback = True
        else:
            try:
                is_loopback = ipaddress.ip_address(candidate).is_loopback
            except ValueError:
                # Not an IP literal (a hostname other than localhost): treat as non-local.
                is_loopback = False

        if not is_loopback and self.server_auth is None:
            logger.warning(
                "A2A server is configured to bind to '%s' without authentication. "
                "This may expose your server to unauthorized access. "
                "Consider either: (1) binding to localhost for local-only access, "
                "or (2) configuring server_auth for production deployments on public interfaces.",
                self.host,
            )

        return self
|