nvidia-nat-a2a 1.5.0a20251229__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-a2a might be problematic. Click here for more details.

@@ -0,0 +1,354 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ from collections.abc import AsyncGenerator
20
+ from datetime import timedelta
21
+ from typing import TYPE_CHECKING
22
+ from uuid import uuid4
23
+
24
+ import httpx
25
+
26
+ from a2a.client import A2ACardResolver
27
+ from a2a.client import Client
28
+ from a2a.client import ClientConfig
29
+ from a2a.client import ClientEvent
30
+ from a2a.client import ClientFactory
31
+ from a2a.types import AgentCard
32
+ from a2a.types import Message
33
+ from a2a.types import Part
34
+ from a2a.types import Role
35
+ from a2a.types import Task
36
+ from a2a.types import TextPart
37
+
38
+ if TYPE_CHECKING:
39
+ from nat.authentication.interfaces import AuthProviderBase
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
class A2ABaseClient:
    """
    Minimal A2A client for connecting to an A2A agent.

    Intended to be used as an async context manager: entering creates the HTTP
    client, fetches the agent card and builds the A2A client; exiting closes
    both clients and clears the cached agent card.

    Args:
        base_url: The base URL of the A2A agent
        agent_card_path: Path to agent card (default: /.well-known/agent-card.json)
        task_timeout: Timeout for task operations (default: 300 seconds)
        streaming: Enable streaming responses (default: True)
        auth_provider: Optional NAT authentication provider for securing requests
    """

    def __init__(
        self,
        base_url: str,
        agent_card_path: str = "/.well-known/agent-card.json",
        task_timeout: timedelta = timedelta(seconds=300),
        streaming: bool = True,
        auth_provider: AuthProviderBase | None = None,
    ):
        self._base_url = base_url
        self._agent_card_path = agent_card_path
        self._task_timeout = task_timeout
        self._streaming = streaming
        self._auth_provider = auth_provider

        # Populated by __aenter__ and cleared again by __aexit__.
        self._httpx_client: httpx.AsyncClient | None = None
        self._client: Client | None = None
        self._agent_card: AgentCard | None = None

    @property
    def base_url(self) -> str:
        """The base URL this client was constructed with."""
        return self._base_url

    @property
    def agent_card(self) -> AgentCard | None:
        """The resolved agent card, or None while the client is not connected."""
        return self._agent_card

    async def __aenter__(self):
        """
        Connect to the agent: create the HTTP client, resolve the agent card,
        configure authentication (when an auth provider was given) and build
        the A2A client.

        Returns:
            This client instance.

        Raises:
            RuntimeError: If the client is already initialized, the agent card
                cannot be resolved, or authentication support is unavailable.
        """
        if self._httpx_client is not None or self._client is not None:
            raise RuntimeError("A2ABaseClient already initialized")

        # 1) Create httpx client explicitly
        self._httpx_client = httpx.AsyncClient(timeout=httpx.Timeout(self._task_timeout.total_seconds()))

        try:
            # 2) Resolve agent card
            await self._resolve_agent_card()
            if not self._agent_card:
                raise RuntimeError("Agent card not resolved")

            # 3) Setup authentication interceptors if auth is configured
            interceptors = []
            if self._auth_provider:
                try:
                    from a2a.client import AuthInterceptor
                    from nat.plugins.a2a.auth.credential_service import A2ACredentialService

                    credential_service = A2ACredentialService(
                        auth_provider=self._auth_provider,
                        agent_card=self._agent_card,
                    )
                    interceptors.append(AuthInterceptor(credential_service))
                    logger.info("Authentication configured for A2A client")
                except ImportError as e:
                    logger.error("Failed to setup authentication: %s", e)
                    raise RuntimeError("Authentication requires a2a-sdk with AuthInterceptor support") from e

            # 4) Create A2A client with interceptors
            client_config = ClientConfig(
                httpx_client=self._httpx_client,
                streaming=self._streaming,
            )
            factory = ClientFactory(client_config)
            self._client = factory.create(self._agent_card, interceptors=interceptors)
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so release the
            # HTTP client here and reset state; otherwise the connection pool
            # leaks and a retry would fail with "already initialized".
            await self._close()
            raise

        logger.info("Connected to A2A agent at %s", self._base_url)
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the A2A and HTTP clients and reset the connection state."""
        await self._close()

    async def _close(self):
        """Best-effort shutdown shared by __aexit__ and failed __aenter__ calls."""
        # Close A2A client first (if it exposes aclose)
        if self._client is not None:
            aclose = getattr(self._client, "aclose", None)
            if aclose is not None:
                try:
                    await aclose()
                except Exception:
                    logger.warning("Error while closing A2A client", exc_info=True)

        # Then close httpx client
        if self._httpx_client is not None:
            try:
                await self._httpx_client.aclose()
            except Exception:
                logger.warning("Error while closing HTTPX client", exc_info=True)

        self._httpx_client = None
        self._client = None
        self._agent_card = None

    async def _resolve_agent_card(self):
        """
        Fetch the agent card from the A2A agent and store it on the instance.

        Raises:
            RuntimeError: If the HTTP client is not initialized or the card
                cannot be fetched (the original error is chained as the cause).
        """
        if not self._httpx_client:
            raise RuntimeError("httpx_client is not initialized")

        try:
            resolver = A2ACardResolver(httpx_client=self._httpx_client,
                                       base_url=self._base_url,
                                       agent_card_path=self._agent_card_path)
            logger.info("Fetching agent card from: %s%s", self._base_url, self._agent_card_path)
            self._agent_card = await resolver.get_agent_card()
            logger.info("Successfully fetched public agent card")
            # TODO: add support for authenticated extended agent card
        except Exception as e:
            logger.error("Failed to fetch agent card: %s", e, exc_info=True)
            raise RuntimeError(f"Failed to fetch agent card from {self._base_url}") from e

    async def send_message(self,
                           message_text: str,
                           task_id: str | None = None,
                           context_id: str | None = None) -> AsyncGenerator[ClientEvent | Message, None]:
        """
        Send a message to the agent and stream response events.

        This is the low-level A2A protocol method that yields events as they arrive.
        For simpler usage, prefer the high-level agent function registered by this client.

        Args:
            message_text: The message text to send
            task_id: Optional task ID to continue an existing conversation
            context_id: Optional context ID for the conversation

        Yields:
            ClientEvent | Message: The agent's response events as they arrive.
                ClientEvent is a tuple of (Task, UpdateEvent | None)
                Message is a direct message response

        Raises:
            RuntimeError: If the client has not been initialized.
        """
        if not self._client:
            raise RuntimeError("A2ABaseClient not initialized")

        text_part = TextPart(text=message_text)
        parts: list[Part] = [Part(root=text_part)]
        message = Message(role=Role.user, parts=parts, message_id=uuid4().hex, task_id=task_id, context_id=context_id)

        async for response in self._client.send_message(message):
            yield response

    async def get_task(self, task_id: str, history_length: int | None = None) -> Task:
        """
        Get the status and details of a specific task.

        This is an A2A protocol operation for retrieving task information.

        Args:
            task_id: The unique identifier of the task
            history_length: Optional limit on the number of history messages to retrieve

        Returns:
            Task: The task object with current status and history

        Raises:
            RuntimeError: If the client has not been initialized.
        """
        if not self._client:
            raise RuntimeError("A2ABaseClient not initialized")

        from a2a.types import TaskQueryParams
        params = TaskQueryParams(id=task_id, history_length=history_length)
        return await self._client.get_task(params)

    async def cancel_task(self, task_id: str) -> Task:
        """
        Cancel a running task.

        This is an A2A protocol operation for canceling tasks.

        Args:
            task_id: The unique identifier of the task to cancel

        Returns:
            Task: The task object with updated status

        Raises:
            RuntimeError: If the client has not been initialized.
        """
        if not self._client:
            raise RuntimeError("A2ABaseClient not initialized")

        from a2a.types import TaskIdParams
        params = TaskIdParams(id=task_id)
        return await self._client.cancel_task(params)

    async def send_message_streaming(self,
                                     message_text: str,
                                     task_id: str | None = None,
                                     context_id: str | None = None) -> AsyncGenerator[ClientEvent | Message, None]:
        """
        Send a message to the agent and stream response events (alias for send_message).

        This method provides an explicit streaming interface that mirrors the A2A SDK pattern.
        It is functionally identical to send_message(), which already streams events.

        Args:
            message_text: The message text to send
            task_id: Optional task ID to continue an existing conversation
            context_id: Optional context ID for the conversation

        Yields:
            ClientEvent | Message: The agent's response events as they arrive.
        """
        async for event in self.send_message(message_text, task_id=task_id, context_id=context_id):
            yield event

    def extract_text_from_parts(self, parts: list) -> list[str]:
        """
        Extract text content from A2A message parts.

        Args:
            parts: List of A2A Part objects

        Returns:
            List of text strings extracted from the parts; parts without a
            ``text`` attribute are skipped.
        """
        text_parts = []
        for part in parts:
            # Handle Part wrapper (RootModel): unwrap ``root`` when present.
            part_content = getattr(part, 'root', part)

            # Extract text from TextPart
            if hasattr(part_content, 'text'):
                text_parts.append(part_content.text)

        return text_parts

    def _status_message_text(self, status_message) -> str:
        """
        Render a task status message as human-readable text.

        Args:
            status_message: The value of ``task.status.message``; may be None,
                a plain string, or an object exposing ``parts``.

        Returns:
            The joined text parts when available, ``str(status_message)`` when
            the message carries no text parts, or 'Unknown error' when empty.
        """
        if not status_message:
            return "Unknown error"

        parts = getattr(status_message, "parts", None)
        if parts:
            text_parts = self.extract_text_from_parts(parts)
            if text_parts:
                return " ".join(text_parts)

        return str(status_message)

    def extract_text_from_task(self, task) -> str:
        """
        Extract text response from an A2A Task object.

        This method understands the A2A protocol structure and extracts the final
        text response from a completed task, prioritizing artifacts over history.

        Args:
            task: A2A Task object

        Returns:
            Extracted text response or status message

        Priority order:
            1. Check task status (return error/progress if not completed)
            2. Extract from task.artifacts (structured output)
            3. Fallback to last agent message in task.history
        """
        from a2a.types import TaskState

        # Check task status
        status = task.status
        if status and status.state != TaskState.completed:
            # Task not completed - return status message or indicate in progress
            if status.state == TaskState.failed:
                # Render the status message as text instead of interpolating
                # the message object itself (which would print its repr).
                return f"Task failed: {self._status_message_text(status.message)}"
            return f"Task in progress (state: {status.state})"

        # Priority 1: Extract from artifacts (structured output)
        if task.artifacts:
            # Get text from all artifacts
            all_text = []
            for artifact in task.artifacts:
                if artifact.parts:
                    all_text.extend(self.extract_text_from_parts(artifact.parts))
            if all_text:
                return " ".join(all_text)

        # Priority 2: Fallback to history (conversational messages)
        if task.history:
            # Get the last agent message from history
            for msg in reversed(task.history):
                if msg.role.value == 'agent':  # Get last agent message
                    text_parts = self.extract_text_from_parts(msg.parts)
                    if text_parts:
                        return " ".join(text_parts)

        return "No response"

    def extract_text_from_events(self, events: list) -> str:
        """
        Extract text response from a list of A2A events.

        This is a convenience method that handles both Message and ClientEvent types.
        Only the last event in the list is inspected.

        Args:
            events: List of A2A events (ClientEvent or Message objects)

        Returns:
            Extracted text response, or 'No response' for an empty list
        """
        if not events:
            return "No response"

        from a2a.types import Message as A2AMessage

        # Get the last event
        last_event = events[-1]

        # If it's a Message, extract text from parts
        if isinstance(last_event, A2AMessage):
            text_parts = self.extract_text_from_parts(last_event.parts)
            return " ".join(text_parts) if text_parts else str(last_event)

        # If it's a ClientEvent (Task, TaskStatusUpdateEvent), extract from task
        if isinstance(last_event, tuple):
            task, _ = last_event
            return self.extract_text_from_task(task)

        return str(last_event)
@@ -0,0 +1,72 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from datetime import timedelta
17
+
18
+ from pydantic import Field
19
+ from pydantic import HttpUrl
20
+
21
+ from nat.data_models.component_ref import AuthenticationRef
22
+ from nat.data_models.function import FunctionGroupBaseConfig
23
+
24
+
25
class A2AClientConfig(FunctionGroupBaseConfig, name="a2a_client"):
    """Settings for the ``a2a_client`` function group.

    With this configuration a NAT workflow can connect to a remote A2A agent
    and publish the primary agent function together with helper functions.

    Attributes:
        url: Base URL of the A2A agent, e.g. https://agent.example.com (required)
        agent_card_path: Path to the agent card (default: /.well-known/agent-card.json)
        task_timeout: Maximum time to wait for task completion (default: 300 seconds)
        include_skills_in_description: Include skill details in high-level function description (default: True)
        streaming: Whether to enable streaming support for the A2A client (default: False)
        auth_provider: Optional reference to NAT auth provider for authentication
    """

    url: HttpUrl = Field(..., description="Base URL of the A2A agent")

    agent_card_path: str = Field(default="/.well-known/agent-card.json", description="Path to the agent card")

    task_timeout: timedelta = Field(default=timedelta(seconds=300),
                                    description="Maximum time to wait for task completion")

    include_skills_in_description: bool = Field(
        default=True,
        description=("Include skill details in the high-level agent function description. "
                     "Set to False for shorter descriptions (useful for token optimization). "
                     "Skills are always available via get_skills() helper."),
    )

    # streaming is disabled by default because of AIQ-2496
    streaming: bool = Field(default=False, description="Whether to enable streaming support for the A2A client")

    auth_provider: str | AuthenticationRef | None = Field(
        default=None,
        description=("Reference to NAT authentication provider for authenticating with the A2A agent. "
                     "Supports OAuth2, API Key, HTTP Basic, and other NAT auth providers."),
    )