nvidia-nat-a2a 1.4.0a20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-a2a might be problematic. Click here for more details.

@@ -0,0 +1,296 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from collections.abc import AsyncGenerator
18
+ from typing import Any
19
+
20
+ from pydantic import BaseModel
21
+ from pydantic import Field
22
+
23
+ from nat.builder.function import FunctionGroup
24
+ from nat.builder.workflow_builder import Builder
25
+ from nat.cli.register_workflow import register_function_group
26
+ from nat.plugins.a2a.client.client_base import A2ABaseClient
27
+ from nat.plugins.a2a.client.client_config import A2AClientConfig
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ # Input models for helper functions
33
class GetTaskInput(BaseModel):
    """Input for get_task function."""
    # Required identifier of the task to look up.
    task_id: str = Field(..., description="The ID of the task to retrieve")
    # Optional; forwarded unchanged to the client's get_task call.
    history_length: int | None = Field(None, description="Number of history items to include")
37
+
38
+
39
class CancelTaskInput(BaseModel):
    """Input for cancel_task function."""
    # Required identifier of the task that should be cancelled.
    task_id: str = Field(..., description="The ID of the task to cancel")
42
+
43
+
44
class SendMessageInput(BaseModel):
    """Input for send_message function."""
    # Required natural-language payload sent to the remote agent.
    query: str = Field(..., description="The query to send to the agent")
    # Optional identifiers forwarded to the client's send_message / send_message_streaming.
    task_id: str | None = Field(None, description="Optional task ID for continuation")
    context_id: str | None = Field(None, description="Optional context ID for session management")
49
+
50
+
51
class A2AClientFunctionGroup(FunctionGroup):
    """
    A FunctionGroup that exposes a remote A2A agent as NAT functions.

    Entering the group (``async with``) connects an ``A2ABaseClient`` to the configured
    agent, reads its agent card and registers a three-level API:

    - ``call``: high-level, LLM-friendly function returning only the response text
    - ``get_skills``, ``get_info``, ``get_task``, ``cancel_task``: metadata/utility helpers
    - ``send_message``, ``send_message_streaming``: low-level access to raw A2A events

    Exiting the group closes the client.
    """

    def __init__(self, config: A2AClientConfig, builder: Builder):
        """Store configuration; the A2A client itself is created in ``__aenter__``.

        Args:
            config: A2A client configuration (url, agent_card_path, task_timeout, streaming,
                include_skills_in_description).
            builder: The workflow builder this group was created by.
        """
        super().__init__(config=config)
        self._builder = builder
        self._client: A2ABaseClient | None = None
        self._include_skills_in_description = config.include_skills_in_description

    async def __aenter__(self):
        """Connect the A2A client and register the agent's functions.

        Returns:
            This function group, with the client connected and functions registered.

        Raises:
            RuntimeError: If the agent card is not available after connecting.
        """
        config: A2AClientConfig = self._config  # type: ignore[assignment]
        base_url = str(config.url)

        # Create and initialize A2A client. Publish it on ``self`` only once it has been
        # entered successfully, so a failed connection never leaves a half-initialized client.
        client = A2ABaseClient(base_url=base_url,
                               agent_card_path=config.agent_card_path,
                               task_timeout=config.task_timeout,
                               streaming=config.streaming)
        await client.__aenter__()
        self._client = client
        logger.info("Connected to A2A agent at %s", base_url)

        # Discover agent card and register functions. If this raises, ``async with`` will NOT
        # call our ``__aexit__`` (it is skipped when ``__aenter__`` raises), so the client
        # must be closed here or it would leak.
        try:
            self._register_functions()
        except BaseException as exc:
            self._client = None
            await client.__aexit__(type(exc), exc, exc.__traceback__)
            raise

        return self

    def _register_functions(self):
        """Retrieve agent card and register the three-level API: high-level, helpers, and low-level.

        Raises:
            RuntimeError: If the client is not initialized or the agent card is unavailable.
        """
        # Validate client is initialized
        if not self._client:
            raise RuntimeError("A2A client not initialized")

        # Get and validate agent card
        agent_card = self._client.agent_card
        if not agent_card:
            raise RuntimeError("Agent card not available")

        # Log agent information
        logger.info("Agent: %s v%s", agent_card.name, agent_card.version)
        if agent_card.skills:
            logger.info("Skills: %s", [skill.name for skill in agent_card.skills])

        # Register functions
        # LEVEL 1: High-level main function (LLM-friendly)
        self.add_function(
            name="call",
            fn=self._create_high_level_function(),
            description=self._format_main_function_description(agent_card),
        )

        # LEVEL 2: Standard helpers (metadata/utility)
        self.add_function(
            name="get_skills",
            fn=self._get_skills,
            description="Get the list of skills and capabilities available from this agent",
        )

        self.add_function(
            name="get_info",
            fn=self._get_agent_info,
            description="Get metadata about this agent (name, version, provider, capabilities)",
        )

        self.add_function(
            name="get_task",
            fn=self._wrap_get_task,
            description="Get the status and details of a specific task by task_id",
        )

        self.add_function(
            name="cancel_task",
            fn=self._wrap_cancel_task,
            description="Cancel a running task by task_id",
        )

        # LEVEL 3: Low-level protocol (advanced)
        self.add_function(
            name="send_message",
            fn=self._send_message_advanced,
            description=("Advanced: Send a message with full control over the A2A protocol. "
                         "Returns raw events as a list. For most use cases, prefer using the "
                         "high-level 'call()' function instead."),
        )

        self.add_function(
            name="send_message_streaming",
            fn=self._send_message_streaming,
            description=("Advanced: Send a message and stream response events as they arrive. "
                         "Yields raw events one by one. This is an async generator function. "
                         "For most use cases, prefer using the high-level 'call()' function instead."),
        )

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Clean up the A2A client (no-op if it was never connected or is already closed)."""
        if self._client:
            await self._client.__aexit__(exc_type, exc_value, traceback)
            self._client = None
            logger.info("Disconnected from A2A agent")

    def _format_main_function_description(self, agent_card) -> str:
        """Create description for the main agent function.

        Args:
            agent_card: The remote agent's card (uses ``description`` and ``skills``).

        Returns:
            Markdown-ish text: the agent description, then either a per-skill list (when
            ``include_skills_in_description`` is set) or a one-line skill count, then a usage hint.
        """
        description = f"{agent_card.description}\n\n"

        # Conditionally include skills based on configuration
        if self._include_skills_in_description and agent_card.skills:
            description += "**Capabilities:**\n"
            for skill in agent_card.skills:
                description += f"\n• **{skill.name}**: {skill.description}"
                if skill.examples:
                    examples = skill.examples[:2]  # Limit to 2 examples
                    description += f"\n Examples: {', '.join(examples)}"
            description += "\n\n"
        elif agent_card.skills:
            # Brief mention that skills are available
            description += f"**{len(agent_card.skills)} capabilities available.** "
            description += "Use get_skills() for detailed information.\n\n"

        description += "**Usage:** Send natural language queries to interact with this agent."

        return description

    def _create_high_level_function(self):
        """High-level function that simplifies the response."""

        async def high_level_fn(query: str, task_id: str | None = None, context_id: str | None = None) -> str:
            """
            Send a query to the agent and get a simple text response.

            This is the recommended method for LLM usage.
            For advanced use cases, use send_message() for raw events.
            """
            if not self._client:
                raise RuntimeError("A2A client not initialized")

            events = []
            async for event in self._client.send_message(query, task_id, context_id):
                events.append(event)

            # Extract and return just the text response using base client helper
            return self._client.extract_text_from_events(events)

        return high_level_fn

    async def _get_skills(self, params: dict | None = None) -> dict:
        """Helper function to list agent skills.

        Returns ``{"skills": []}`` when the client or agent card is unavailable; otherwise
        the agent name plus id/name/description/examples/tags for each skill.
        """
        if not self._client or not self._client.agent_card:
            return {"skills": []}

        agent_card = self._client.agent_card
        return {
            "agent": agent_card.name,
            "skills": [{
                "id": skill.id,
                "name": skill.name,
                "description": skill.description,
                "examples": skill.examples or [],
                "tags": skill.tags or []
            } for skill in agent_card.skills]
        }

    async def _get_agent_info(self, params: dict | None = None) -> dict:
        """Helper function to get agent metadata (empty dict if no agent card is available)."""
        if not self._client or not self._client.agent_card:
            return {}

        agent_card = self._client.agent_card
        return {
            "name": agent_card.name,
            "description": agent_card.description,
            "version": agent_card.version,
            "provider": agent_card.provider.model_dump() if agent_card.provider else None,
            "url": agent_card.url,
            "capabilities": {
                "streaming": agent_card.capabilities.streaming if agent_card.capabilities else False,
            },
            "num_skills": len(agent_card.skills)
        }

    async def _wrap_get_task(self, params: GetTaskInput) -> Any:
        """Wrapper for get_task that delegates to client_base."""
        if not self._client:
            raise RuntimeError("A2A client not initialized")
        return await self._client.get_task(params.task_id, params.history_length)

    async def _wrap_cancel_task(self, params: CancelTaskInput) -> Any:
        """Wrapper for cancel_task that delegates to client_base."""
        if not self._client:
            raise RuntimeError("A2A client not initialized")
        return await self._client.cancel_task(params.task_id)

    async def _send_message_advanced(self, params: SendMessageInput) -> list:
        """
        Send a message with full A2A protocol control.

        Returns: List of ClientEvent|Message objects containing:
            - Task information
            - Status updates
            - Artifact updates
            - Full message details
        """
        if not self._client:
            raise RuntimeError("A2A client not initialized")

        events = []
        async for event in self._client.send_message(params.query, params.task_id, params.context_id):
            events.append(event)
        return events

    async def _send_message_streaming(self, params: SendMessageInput) -> AsyncGenerator[Any, None]:
        """
        Send a message with full A2A protocol control and stream events.

        This is an async generator that yields events as they arrive from the agent.

        Yields: ClientEvent|Message objects containing:
            - Task information
            - Status updates
            - Artifact updates
            - Full message details
        """
        if not self._client:
            raise RuntimeError("A2A client not initialized")

        async for event in self._client.send_message_streaming(params.query, params.task_id, params.context_id):
            yield event
282
+
283
+
284
@register_function_group(config_type=A2AClientConfig)
async def a2a_client_function_group(config: A2AClientConfig, _builder: Builder):
    """
    Connect to an A2A agent, discover its agent card and publish the agent's functions.

    The yielded ``A2AClientFunctionGroup`` exposes a three-level API:

    - High-level: ``call`` - send a natural-language query and get the agent's text reply
    - Helpers: ``get_skills``, ``get_info``, ``get_task``, ``cancel_task``
    - Low-level: ``send_message`` (events as a list) and ``send_message_streaming``
      (events yielded as they arrive) for advanced usage

    The A2A client is connected when the group is entered and closed when the
    ``async with`` block exits after the ``yield``.

    Args:
        config: A2A client configuration.
        _builder: The workflow builder, passed through to the function group.

    Yields:
        The initialized ``A2AClientFunctionGroup``.
    """
    async with A2AClientFunctionGroup(config, _builder) as function_group:
        yield function_group
@@ -0,0 +1,23 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # flake8: noqa
17
+ # isort:skip_file
18
+
19
+ # Register client components
20
+ from .client import client_impl
21
+
22
+ # Register server/frontend components
23
+ from .server import register_frontend
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,172 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Adapter to bridge NAT workflows with A2A AgentExecutor interface.
16
+
17
+ This module implements a message-only A2A agent for Phase 1, providing stateless
18
+ request/response interactions without task lifecycle management.
19
+ """
20
+
21
+ import logging
22
+
23
+ from a2a.server.agent_execution import AgentExecutor
24
+ from a2a.server.agent_execution import RequestContext
25
+ from a2a.server.events import EventQueue
26
+ from a2a.types import InternalError
27
+ from a2a.types import InvalidParamsError
28
+ from a2a.types import UnsupportedOperationError
29
+ from a2a.utils import new_agent_text_message
30
+ from a2a.utils.errors import ServerError
31
+ from nat.runtime.session import SessionManager
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class NATWorkflowAgentExecutor(AgentExecutor):
    """Bridges the A2A ``AgentExecutor`` interface to a NAT workflow (message-only mode).

    Phase 1 behavior:
        - Stateless: every incoming message runs in its own fresh session.
        - Synchronous: the workflow result is sent back as a single agent message.
        - Message-only: responses are Message objects; no Task objects are produced.
        - Concurrency is delegated to the SessionManager (its semaphore, per its own contract).

    Note: Multi-turn conversations and task-based interactions are deferred to Phase 5.
    """

    def __init__(self, session_manager: SessionManager):
        """Keep a reference to the session manager that runs the workflow.

        Args:
            session_manager: The SessionManager that opens one session per request
                and applies its concurrency control.
        """
        self.session_manager = session_manager
        workflow_type = session_manager.workflow.config.workflow.type
        logger.info("Initialized NATWorkflowAgentExecutor (message-only) for workflow: %s", workflow_type)

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run the NAT workflow for one A2A message and enqueue the reply.

        Steps:
            1. Reject requests without a message or without user text.
            2. Run the workflow in a fresh session and coerce the result to ``str``.
            3. Enqueue the result as an agent text message carrying the request's
               context_id/task_id (these IDs are only echoed and logged; no state is kept).

        On workflow failure an error text message is enqueued first, then the failure
        is surfaced to the A2A server as an InternalError.

        Args:
            context: The A2A request context containing the user message.
            event_queue: Queue used to send the response message back to the client.

        Raises:
            ServerError: InvalidParamsError for bad input, InternalError for execution failures.
        """
        if self._validate_request(context):
            raise ServerError(error=InvalidParamsError())

        query = context.get_user_input()
        if not query:
            logger.error("No user input found in context")
            raise ServerError(error=InvalidParamsError())

        ctx_id, tsk_id = context.context_id, context.task_id
        logger.info("Processing message-only request (context_id=%s, task_id=%s): %s", ctx_id, tsk_id, query[:100])

        try:
            # One independent session per message (stateless).
            # TODO: Add support for user input callbacks and authentication in later phases
            async with self.session_manager.session() as session, session.run(query) as runner:
                answer = await runner.result(to_type=str)

            logger.info("Workflow completed successfully (context_id=%s, task_id=%s)", ctx_id, tsk_id)
            await event_queue.enqueue_event(new_agent_text_message(answer, context_id=ctx_id, task_id=tsk_id))
        except Exception as exc:
            logger.exception("Error executing NAT workflow (context_id=%s, task_id=%s): %s", ctx_id, tsk_id, exc)
            failure_text = f"An error occurred while processing your request: {str(exc)}"
            await event_queue.enqueue_event(new_agent_text_message(failure_text, context_id=ctx_id, task_id=tsk_id))
            raise ServerError(error=InternalError()) from exc

    def _validate_request(self, context: RequestContext) -> bool:
        """Report whether the request context is INVALID.

        Args:
            context: The request context to validate.

        Returns:
            True when validation fails (no message present), False when the request is acceptable.
        """
        if context.message:
            return False
        logger.error("Request context has no message")
        return True

    async def cancel(
        self,
        _context: RequestContext,
        _event_queue: EventQueue,
    ) -> None:
        """Reject cancellation: message-only agents have no tasks to cancel (Phase 1).

        Task cancellation is planned for Phase 5 together with long-running task support.

        Args:
            _context: The request context (unused).
            _event_queue: Event queue for updates (unused).

        Raises:
            ServerError: Always, wrapping UnsupportedOperationError.
        """
        logger.warning("Task cancellation requested but not supported in message-only mode (Phase 1)")
        raise ServerError(error=UnsupportedOperationError())
@@ -0,0 +1,104 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ from pydantic import BaseModel
19
+ from pydantic import Field
20
+
21
+ from nat.data_models.front_end import FrontEndBaseConfig
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class A2ACapabilitiesConfig(BaseModel):
    """A2A agent capabilities configuration."""

    # Whether the agent advertises streaming responses; on unless disabled.
    streaming: bool = Field(True, description="Enable streaming responses (default: True)")
    # Whether the agent advertises push notifications; off unless enabled.
    push_notifications: bool = Field(False, description="Enable push notifications (default: False)")
37
+
38
+
39
class A2AFrontEndConfig(FrontEndBaseConfig, name="a2a"):
    """A2A front end configuration.

    A front end that exposes NeMo Agent toolkit workflows as A2A-compliant remote agents.
    """

    # ---- Server settings ----------------------------------------------------
    # Interface the A2A server binds to.
    host: str = Field("localhost", description="Host to bind the server to (default: localhost)")
    # Bind port, validated to the range 0..65535.
    port: int = Field(10000, ge=0, le=65535, description="Port to bind the server to (default: 10000)")
    # Version string for the agent.
    version: str = Field("1.0.0", description="Version of the agent (default: 1.0.0)")
    # Log level name for the A2A server.
    log_level: str = Field("INFO", description="Log level for the A2A server (default: INFO)")

    # ---- Agent metadata -----------------------------------------------------
    name: str = Field(
        "NeMo Agent Toolkit A2A Agent",
        description="Name of the A2A agent (default: NeMo Agent Toolkit A2A Agent)",
    )
    description: str = Field(
        "An AI agent powered by NeMo Agent Toolkit exposed via A2A protocol",
        description="Description of what the agent does (default: generic description)",
    )

    # ---- A2A capabilities ---------------------------------------------------
    # A fresh capabilities object per config instance.
    capabilities: A2ACapabilitiesConfig = Field(
        description="Agent capabilities configuration",
        default_factory=A2ACapabilitiesConfig,
    )

    # ---- Concurrency control ------------------------------------------------
    # Per the field description, 0 or -1 means unlimited; values below -1 are rejected.
    max_concurrency: int = Field(
        8,
        ge=-1,
        description=("Maximum number of concurrent workflow executions (default: 8). "
                     "Controls how many A2A requests can execute workflows simultaneously. "
                     "Set to 0 or -1 for unlimited concurrency."),
    )

    # ---- Content modes ------------------------------------------------------
    # default_factory ensures each instance gets its own list.
    default_input_modes: list[str] = Field(
        description="Supported input content types (default: text, text/plain)",
        default_factory=lambda: ["text", "text/plain"],
    )
    default_output_modes: list[str] = Field(
        description="Supported output content types (default: text, text/plain)",
        default_factory=lambda: ["text", "text/plain"],
    )

    # ---- Optional customization ---------------------------------------------
    runner_class: str | None = Field(
        None,
        description="Custom worker class for handling A2A routes (default: built-in worker)",
    )