nvidia-nat-a2a 1.5.0a20251229__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-a2a might be problematic. Click here for more details.
- nat/meta/pypi.md +36 -0
- nat/plugins/a2a/__init__.py +14 -0
- nat/plugins/a2a/auth/__init__.py +15 -0
- nat/plugins/a2a/auth/credential_service.py +418 -0
- nat/plugins/a2a/client/__init__.py +14 -0
- nat/plugins/a2a/client/client_base.py +354 -0
- nat/plugins/a2a/client/client_config.py +72 -0
- nat/plugins/a2a/client/client_impl.py +324 -0
- nat/plugins/a2a/register.py +23 -0
- nat/plugins/a2a/server/__init__.py +14 -0
- nat/plugins/a2a/server/agent_executor_adapter.py +172 -0
- nat/plugins/a2a/server/front_end_config.py +131 -0
- nat/plugins/a2a/server/front_end_plugin.py +122 -0
- nat/plugins/a2a/server/front_end_plugin_worker.py +306 -0
- nat/plugins/a2a/server/oauth_middleware.py +121 -0
- nat/plugins/a2a/server/register_frontend.py +37 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/METADATA +57 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/RECORD +22 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/WHEEL +5 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/entry_points.txt +5 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_a2a-1.5.0a20251229.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
from collections.abc import AsyncGenerator
|
|
20
|
+
from datetime import timedelta
|
|
21
|
+
from typing import TYPE_CHECKING
|
|
22
|
+
from uuid import uuid4
|
|
23
|
+
|
|
24
|
+
import httpx
|
|
25
|
+
|
|
26
|
+
from a2a.client import A2ACardResolver
|
|
27
|
+
from a2a.client import Client
|
|
28
|
+
from a2a.client import ClientConfig
|
|
29
|
+
from a2a.client import ClientEvent
|
|
30
|
+
from a2a.client import ClientFactory
|
|
31
|
+
from a2a.types import AgentCard
|
|
32
|
+
from a2a.types import Message
|
|
33
|
+
from a2a.types import Part
|
|
34
|
+
from a2a.types import Role
|
|
35
|
+
from a2a.types import Task
|
|
36
|
+
from a2a.types import TextPart
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
40
|
+
|
|
41
|
+
# Module-level logger used by A2ABaseClient to report connection, agent-card and authentication events.
logger = logging.getLogger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class A2ABaseClient:
    """
    Minimal A2A client for connecting to an A2A agent.

    Intended to be used as an async context manager. Entering it creates an
    ``httpx.AsyncClient``, fetches the agent card and builds the A2A SDK client.
    Exiting it closes both clients. If setup fails part-way through ``__aenter__``,
    everything opened so far is closed before the exception propagates, so the
    instance is left uninitialized and may be entered again.

    Args:
        base_url: The base URL of the A2A agent
        agent_card_path: Path to agent card (default: /.well-known/agent-card.json)
        task_timeout: Timeout applied to the underlying HTTP client (default: 300 seconds)
        streaming: Enable streaming responses (default: True)
        auth_provider: Optional NAT authentication provider for securing requests
    """

    def __init__(
        self,
        base_url: str,
        agent_card_path: str = "/.well-known/agent-card.json",
        task_timeout: timedelta = timedelta(seconds=300),
        streaming: bool = True,
        auth_provider: AuthProviderBase | None = None,
    ):
        self._base_url = base_url
        self._agent_card_path = agent_card_path
        self._task_timeout = task_timeout
        self._streaming = streaming
        self._auth_provider = auth_provider

        # Populated by __aenter__ and cleared again by __aexit__ / a failed __aenter__.
        self._httpx_client: httpx.AsyncClient | None = None
        self._client: Client | None = None
        self._agent_card: AgentCard | None = None

    @property
    def base_url(self) -> str:
        """The base URL of the A2A agent this client targets."""
        return self._base_url

    @property
    def agent_card(self) -> AgentCard | None:
        """The resolved agent card, or None when the client is not entered."""
        return self._agent_card

    async def __aenter__(self) -> A2ABaseClient:
        """
        Open the HTTP client, resolve the agent card and build the A2A client.

        Raises:
            RuntimeError: If the client is already initialized, the agent card cannot be
                fetched, or authentication is configured but unsupported by the installed SDK.
        """
        if self._httpx_client is not None or self._client is not None:
            raise RuntimeError("A2ABaseClient already initialized")

        # 1) Create httpx client explicitly; it is shared by the card resolver and the A2A client.
        self._httpx_client = httpx.AsyncClient(timeout=httpx.Timeout(self._task_timeout.total_seconds()))

        try:
            # 2) Resolve agent card
            await self._resolve_agent_card()
            if not self._agent_card:
                raise RuntimeError("Agent card not resolved")

            # 3) Setup authentication interceptors if auth is configured
            interceptors = self._build_interceptors(self._agent_card)

            # 4) Create A2A client with interceptors
            client_config = ClientConfig(
                httpx_client=self._httpx_client,
                streaming=self._streaming,
            )
            factory = ClientFactory(client_config)
            self._client = factory.create(self._agent_card, interceptors=interceptors)
        except BaseException:
            # Python does not call __aexit__ when __aenter__ raises, so release the httpx
            # client and any partial state here; otherwise the connection pool leaks and a
            # retry would fail with "already initialized".
            await self._close()
            raise

        logger.info("Connected to A2A agent at %s", self._base_url)
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Close the A2A and HTTP clients; exceptions from the `async with` body are not suppressed."""
        await self._close()

    async def _close(self) -> None:
        """Close the A2A client then the httpx client (logging close errors) and reset all state."""
        try:
            # Close A2A client first (if it exposes aclose)
            if self._client is not None:
                aclose = getattr(self._client, "aclose", None)
                if aclose is not None:
                    try:
                        await aclose()
                    except Exception:
                        logger.warning("Error while closing A2A client", exc_info=True)

            # Then close httpx client
            if self._httpx_client is not None:
                try:
                    await self._httpx_client.aclose()
                except Exception:
                    logger.warning("Error while closing HTTPX client", exc_info=True)
        finally:
            # Always reset, even if closing was interrupted, so the instance can be reused.
            self._httpx_client = None
            self._client = None
            self._agent_card = None

    def _build_interceptors(self, agent_card: AgentCard | None) -> list:
        """
        Build the A2A client interceptors.

        Returns an empty list when no auth provider is configured; otherwise a single
        ``AuthInterceptor`` backed by ``A2ACredentialService``.

        Raises:
            RuntimeError: If the installed a2a-sdk lacks ``AuthInterceptor`` support.
        """
        if not self._auth_provider:
            return []

        try:
            from a2a.client import AuthInterceptor

            from nat.plugins.a2a.auth.credential_service import A2ACredentialService

            credential_service = A2ACredentialService(
                auth_provider=self._auth_provider,
                agent_card=agent_card,
            )
            interceptors = [AuthInterceptor(credential_service)]
            logger.info("Authentication configured for A2A client")
            return interceptors
        except ImportError as e:
            logger.error("Failed to setup authentication: %s", e)
            raise RuntimeError("Authentication requires a2a-sdk with AuthInterceptor support") from e

    async def _resolve_agent_card(self):
        """
        Fetch the public agent card from the A2A agent into ``self._agent_card``.

        Raises:
            RuntimeError: If the httpx client is not initialized or the fetch fails.
        """
        if not self._httpx_client:
            raise RuntimeError("httpx_client is not initialized")

        try:
            resolver = A2ACardResolver(httpx_client=self._httpx_client,
                                       base_url=self._base_url,
                                       agent_card_path=self._agent_card_path)
            logger.info("Fetching agent card from: %s%s", self._base_url, self._agent_card_path)
            self._agent_card = await resolver.get_agent_card()
            logger.info("Successfully fetched public agent card")
            # TODO: add support for authenticated extended agent card
        except Exception as e:
            logger.error("Failed to fetch agent card: %s", e, exc_info=True)
            raise RuntimeError(f"Failed to fetch agent card from {self._base_url}") from e

    async def send_message(self,
                           message_text: str,
                           task_id: str | None = None,
                           context_id: str | None = None) -> AsyncGenerator[ClientEvent | Message, None]:
        """
        Send a message to the agent and stream response events.

        This is the low-level A2A protocol method that yields events as they arrive.
        For simpler usage, prefer the high-level agent function registered by this client.

        Args:
            message_text: The message text to send
            task_id: Optional task ID to continue an existing conversation
            context_id: Optional context ID for the conversation

        Yields:
            ClientEvent | Message: The agent's response events as they arrive.
                ClientEvent is a tuple of (Task, UpdateEvent | None)
                Message is a direct message response

        Raises:
            RuntimeError: If the client has not been entered.
        """
        if not self._client:
            raise RuntimeError("A2ABaseClient not initialized")

        text_part = TextPart(text=message_text)
        parts: list[Part] = [Part(root=text_part)]
        message = Message(role=Role.user, parts=parts, message_id=uuid4().hex, task_id=task_id, context_id=context_id)

        async for response in self._client.send_message(message):
            yield response

    async def get_task(self, task_id: str, history_length: int | None = None) -> Task:
        """
        Get the status and details of a specific task.

        This is an A2A protocol operation for retrieving task information.

        Args:
            task_id: The unique identifier of the task
            history_length: Optional limit on the number of history messages to retrieve

        Returns:
            Task: The task object with current status and history

        Raises:
            RuntimeError: If the client has not been entered.
        """
        if not self._client:
            raise RuntimeError("A2ABaseClient not initialized")

        from a2a.types import TaskQueryParams
        params = TaskQueryParams(id=task_id, history_length=history_length)
        return await self._client.get_task(params)

    async def cancel_task(self, task_id: str) -> Task:
        """
        Cancel a running task.

        This is an A2A protocol operation for canceling tasks.

        Args:
            task_id: The unique identifier of the task to cancel

        Returns:
            Task: The task object with updated status

        Raises:
            RuntimeError: If the client has not been entered.
        """
        if not self._client:
            raise RuntimeError("A2ABaseClient not initialized")

        from a2a.types import TaskIdParams
        params = TaskIdParams(id=task_id)
        return await self._client.cancel_task(params)

    async def send_message_streaming(self,
                                     message_text: str,
                                     task_id: str | None = None,
                                     context_id: str | None = None) -> AsyncGenerator[ClientEvent | Message, None]:
        """
        Send a message to the agent and stream response events (alias for send_message).

        This method provides an explicit streaming interface that mirrors the A2A SDK pattern.
        It is functionally identical to send_message(), which already streams events.

        Args:
            message_text: The message text to send
            task_id: Optional task ID to continue an existing conversation
            context_id: Optional context ID for the conversation

        Yields:
            ClientEvent | Message: The agent's response events as they arrive.
        """
        async for event in self.send_message(message_text, task_id=task_id, context_id=context_id):
            yield event

    def extract_text_from_parts(self, parts: list) -> list[str]:
        """
        Extract text content from A2A message parts.

        Parts wrapped in a ``Part`` RootModel are unwrapped via ``.root``; any part
        without a ``text`` attribute (e.g. file or data parts) is skipped.

        Args:
            parts: List of A2A Part objects

        Returns:
            List of text strings extracted from the parts, in order
        """
        text_parts = []
        for part in parts:
            # Handle Part wrapper (RootModel)
            part_content = part.root if hasattr(part, 'root') else part

            # Extract text from TextPart
            if hasattr(part_content, 'text'):
                text_parts.append(part_content.text)

        return text_parts

    def extract_text_from_task(self, task) -> str:
        """
        Extract text response from an A2A Task object.

        This method understands the A2A protocol structure and extracts the final
        text response from a completed task, prioritizing artifacts over history.

        Args:
            task: A2A Task object

        Returns:
            Extracted text response or status message

        Priority order:
            1. Check task status (return error/progress if not completed)
            2. Extract from task.artifacts (structured output)
            3. Fallback to last agent message in task.history
        """
        from a2a.types import TaskState

        # Check task status
        status = task.status
        if status and status.state != TaskState.completed:
            if status.state == TaskState.failed:
                # The status message is an A2A Message; report its text rather than its repr.
                message = status.message
                if not message:
                    detail = 'Unknown error'
                else:
                    message_text = self.extract_text_from_parts(getattr(message, 'parts', None) or [])
                    detail = " ".join(message_text) if message_text else str(message)
                return f"Task failed: {detail}"
            # Use the enum value so the text does not depend on the Python version's
            # f-string rendering of str-mixin enums.
            # NOTE(review): every other non-completed state (including canceled) lands here.
            state = getattr(status.state, 'value', status.state)
            return f"Task in progress (state: {state})"

        # Priority 1: Extract from artifacts (structured output)
        if task.artifacts:
            all_text = []
            for artifact in task.artifacts:
                if artifact.parts:
                    all_text.extend(self.extract_text_from_parts(artifact.parts))
            if all_text:
                return " ".join(all_text)

        # Priority 2: Fallback to history (conversational messages)
        if task.history:
            # Get the last agent message from history
            for msg in reversed(task.history):
                if msg.role.value == 'agent':
                    text_parts = self.extract_text_from_parts(msg.parts)
                    if text_parts:
                        return " ".join(text_parts)

        return "No response"

    def extract_text_from_events(self, events: list) -> str:
        """
        Extract text response from a list of A2A events.

        Only the last event is inspected: a Message yields its text parts, a
        ClientEvent tuple ``(Task, update)`` is delegated to extract_text_from_task.

        Args:
            events: List of A2A events (ClientEvent or Message objects)

        Returns:
            Extracted text response ("No response" for an empty list)
        """
        if not events:
            return "No response"

        last_event = events[-1]

        # If it's a Message, extract text from parts
        if isinstance(last_event, Message):
            text_parts = self.extract_text_from_parts(last_event.parts)
            return " ".join(text_parts) if text_parts else str(last_event)

        # If it's a ClientEvent (Task, TaskStatusUpdateEvent), extract from task
        if isinstance(last_event, tuple):
            task, _ = last_event
            return self.extract_text_from_task(task)

        return str(last_event)
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from datetime import timedelta
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
from pydantic import HttpUrl
|
|
20
|
+
|
|
21
|
+
from nat.data_models.component_ref import AuthenticationRef
|
|
22
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class A2AClientConfig(FunctionGroupBaseConfig, name="a2a_client"):
    """Settings for the ``a2a_client`` function group.

    A NAT workflow uses this group to reach a remote A2A agent; the group
    publishes the primary agent function together with helper functions.

    Attributes:
        url: Base URL of the remote A2A agent, e.g. https://agent.example.com.
        agent_card_path: Path to the agent card; defaults to /.well-known/agent-card.json.
        task_timeout: Upper bound on how long to wait for a task to finish; defaults to 300 seconds.
        include_skills_in_description: Whether skill details are folded into the high-level
            function description; defaults to True.
        streaming: Whether the A2A client uses streaming; defaults to False.
        auth_provider: Optional reference to a NAT auth provider used for authentication.
    """

    # Required field: no usable default exists for the agent location.
    url: HttpUrl = Field(
        default=...,
        description="Base URL of the A2A agent",
    )

    agent_card_path: str = Field(
        default="/.well-known/agent-card.json",
        description="Path to the agent card",
    )

    # Five minutes, i.e. 300 seconds.
    task_timeout: timedelta = Field(
        default=timedelta(minutes=5),
        description="Maximum time to wait for task completion",
    )

    include_skills_in_description: bool = Field(
        default=True,
        description=("Include skill details in the high-level agent function description. "
                     "Set to False for shorter descriptions (useful for token optimization). "
                     "Skills are always available via get_skills() helper."),
    )

    # Off by default: see AIQ-2496.
    streaming: bool = Field(
        default=False,
        description="Whether to enable streaming support for the A2A client",
    )

    auth_provider: str | AuthenticationRef | None = Field(
        default=None,
        description=("Reference to NAT authentication provider for authenticating with the A2A agent. "
                     "Supports OAuth2, API Key, HTTP Basic, and other NAT auth providers."),
    )
|