nvidia-nat-a2a 1.4.0a20251207__py3-none-any.whl → 1.4.0a20251231__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-a2a might be problematic. Click here for more details.

@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Authentication support for A2A clients."""
@@ -0,0 +1,418 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Bridge NAT AuthProviderBase to A2A SDK CredentialService."""
16
+
17
+ import asyncio
18
+ import logging
19
+
20
+ from a2a.client import ClientCallContext
21
+ from a2a.client import CredentialService
22
+ from a2a.types import AgentCard
23
+ from a2a.types import APIKeySecurityScheme
24
+ from a2a.types import HTTPAuthSecurityScheme
25
+ from a2a.types import OAuth2SecurityScheme
26
+ from a2a.types import OpenIdConnectSecurityScheme
27
+ from a2a.types import SecurityScheme
28
+ from nat.authentication.interfaces import AuthProviderBase
29
+ from nat.builder.context import Context
30
+ from nat.data_models.authentication import AuthResult
31
+ from nat.data_models.authentication import BasicAuthCred
32
+ from nat.data_models.authentication import BearerTokenCred
33
+ from nat.data_models.authentication import CookieCred
34
+ from nat.data_models.authentication import HeaderCred
35
+ from nat.data_models.authentication import QueryCred
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class A2ACredentialService(CredentialService):
    """
    Adapts NAT AuthProviderBase to A2A SDK CredentialService interface.

    This class bridges NAT's authentication system with the A2A SDK's authentication
    mechanism, allowing A2A clients to use NAT's auth providers (API Key, OAuth2, etc.)
    to authenticate with A2A agents.

    The adapter:
    - Calls the NAT auth provider to obtain credentials for the user in the NAT context
    - Maps NAT credential types to A2A security scheme requirements
    - Caches the most recent auth result and re-authenticates once it reports itself expired
    - Reuses a cached result only for the ``user_id`` it was obtained for, so one user's
      credentials are never returned for a different user

    Args:
        auth_provider: NAT authentication provider instance
        agent_card: Agent card containing security scheme definitions

    Raises:
        ValueError: If the agent card declares security schemes and the provider is
            compatible with none of them (raised from ``__init__``).
    """

    def __init__(
        self,
        auth_provider: AuthProviderBase,
        agent_card: AgentCard | None = None,
    ):
        self._auth_provider = auth_provider
        self._agent_card = agent_card
        # Most recent auth result and the user_id it was obtained for. They are always
        # assigned together with no await in between, so readers see a consistent pair.
        self._cached_auth_result: AuthResult | None = None
        self._cached_user_id: str | None = None
        # Serializes calls to the auth provider so concurrent requests don't all re-authenticate.
        self._auth_lock = asyncio.Lock()

        # Validate provider compatibility with agent's security requirements (fail fast)
        self._validate_provider_compatibility()

    async def get_credentials(
        self,
        security_scheme_name: str,
        context: ClientCallContext | None,
    ) -> str | None:
        """
        Retrieve credentials for a security scheme.

        This method:
        1. Gets user_id from NAT context
        2. Authenticates via NAT auth provider (reusing a still-valid cached result for that user)
        3. Maps credentials to the requested security scheme

        Args:
            security_scheme_name: Name of the security scheme from AgentCard
            context: Client call context supplied by the A2A SDK (currently not consulted;
                the user is taken from the NAT context instead)

        Returns:
            Credential string, or None if authentication failed or no credential
            matches the scheme.
        """
        # Get user_id from NAT context
        user_id = Context.get().user_id

        # Authenticate and get credentials from NAT provider
        auth_result = await self._authenticate(user_id)

        if not auth_result:
            logger.warning("Authentication failed, no credentials available")
            return None

        # Map NAT credentials to A2A format based on security scheme
        credential = self._extract_credential_for_scheme(auth_result, security_scheme_name)

        if credential:
            logger.debug(
                "Successfully retrieved credentials for scheme '%s'",
                security_scheme_name,
            )
        else:
            logger.warning(
                "No compatible credentials found for scheme '%s'",
                security_scheme_name,
            )

        return credential

    def _get_cached_auth_result(self, user_id: str | None) -> AuthResult | None:
        """Return the cached result if it was obtained for ``user_id`` and is not expired."""
        auth_result = self._cached_auth_result
        if not auth_result or self._cached_user_id != user_id:
            return None
        if auth_result.is_expired():
            return None
        return auth_result

    async def _authenticate(self, user_id: str | None) -> AuthResult | None:
        """
        Authenticate and get credentials from NAT auth provider.

        Returns the cached result when it belongs to ``user_id`` and has not expired;
        otherwise calls the provider under a lock so concurrent callers do not trigger
        duplicate authentication requests.

        Args:
            user_id: User identifier for authentication

        Returns:
            AuthResult with credentials, or None if the provider raised (the error is logged).
        """
        try:
            # Fast path: check cache without lock
            cached = self._get_cached_auth_result(user_id)
            if cached:
                return cached

            # Acquire lock to serialize authentication attempts
            async with self._auth_lock:
                # Double-check: another coroutine may have refreshed while we waited for lock
                cached = self._get_cached_auth_result(user_id)
                if cached:
                    logger.debug("Credentials were refreshed by another coroutine while waiting for lock")
                    return cached

                # Log if we're refreshing expired credentials for the same user
                previous = self._cached_auth_result
                if previous and self._cached_user_id == user_id and previous.is_expired():
                    logger.info("Cached credentials expired, re-authenticating")

                # Call NAT auth provider (provider is responsible for token refresh/validity)
                auth_result = await self._auth_provider.authenticate(user_id=user_id)

                # Cache the result together with its owner while holding the lock
                self._cached_auth_result = auth_result
                self._cached_user_id = user_id

                # Warn if provider returned expired credentials (provider bug)
                if auth_result and auth_result.is_expired():
                    logger.warning("Auth provider returned already-expired credentials. "
                                   "This may indicate a bug in the auth provider's token refresh logic.")

                return auth_result

        except Exception as e:
            logger.error("Authentication failed: %s", e, exc_info=True)
            return None

    def _extract_credential_for_scheme(self, auth_result: AuthResult, security_scheme_name: str) -> str | None:
        """
        Extract appropriate credential based on security scheme type.

        Maps NAT credential types to A2A security scheme requirements:
        - BearerTokenCred -> OAuth2, OIDC, HTTP Bearer
        - HeaderCred -> API Key in header
        - QueryCred -> API Key in query
        - CookieCred -> API Key in cookie
        - BasicAuthCred -> HTTP Basic (returned as base64 of "username:password")

        Args:
            auth_result: Authentication result containing credentials
            security_scheme_name: Name of the security scheme

        Returns:
            The first non-empty compatible credential value, or None.
        """
        import base64

        # Get scheme definition from agent card
        scheme_def = self._get_scheme_definition(security_scheme_name)
        if not scheme_def:
            # Every compatibility check rejects a missing scheme definition.
            return None

        # Try to match NAT credentials to security scheme, in provider order
        for cred in auth_result.credentials:
            credential_value: str | None = None

            if isinstance(cred, BearerTokenCred) and self._is_bearer_compatible(scheme_def):
                credential_value = cred.token.get_secret_value()
            elif isinstance(cred, HeaderCred) and self._is_header_compatible(scheme_def, cred.name):
                credential_value = cred.value.get_secret_value()
            elif isinstance(cred, QueryCred) and self._is_query_compatible(scheme_def, cred.name):
                credential_value = cred.value.get_secret_value()
            elif isinstance(cred, CookieCred) and self._is_cookie_compatible(scheme_def, cred.name):
                credential_value = cred.value.get_secret_value()
            elif isinstance(cred, BasicAuthCred) and self._is_basic_compatible(scheme_def):
                # For HTTP Basic, encode username:password as base64
                username = cred.username.get_secret_value()
                password = cred.password.get_secret_value()
                credential_value = base64.b64encode(f"{username}:{password}".encode()).decode()

            if credential_value:
                return credential_value

        return None

    def _get_scheme_definition(self, scheme_name: str) -> SecurityScheme | None:
        """
        Get security scheme definition from agent card.

        Args:
            scheme_name: Name of the security scheme

        Returns:
            SecurityScheme definition, or None when there is no agent card, it declares
            no schemes, or the name is not declared.
        """
        if not self._agent_card or not self._agent_card.security_schemes:
            return None
        return self._agent_card.security_schemes.get(scheme_name)

    def _validate_provider_compatibility(self) -> None:
        """
        Validate that the auth provider type is compatible with agent's security schemes.

        This performs early validation at connection time to fail fast if there's a
        configuration mismatch between the NAT auth provider and the A2A agent's
        security requirements. Passing requires compatibility with at least one scheme.

        Raises:
            ValueError: If the provider is incompatible with all declared security schemes
        """
        if not self._agent_card or not self._agent_card.security_schemes:
            # No security schemes defined, nothing to validate
            logger.debug("No security schemes defined in agent card, skipping validation")
            return

        provider_type = type(self._auth_provider).__name__
        schemes = self._agent_card.security_schemes

        logger.info("Validating auth provider '%s' against agent security schemes: %s",
                    provider_type,
                    list(schemes.keys()))

        compatible_schemes = []
        incompatible_schemes = []

        for scheme_name, scheme in schemes.items():
            if self._is_provider_compatible_with_scheme(scheme):
                compatible_schemes.append(scheme_name)
            else:
                incompatible_schemes.append((scheme_name, type(scheme.root).__name__))

        if not compatible_schemes:
            # Provider is not compatible with any security scheme
            scheme_details = ", ".join(f"{name} ({scheme_type})" for name, scheme_type in incompatible_schemes)
            raise ValueError(f"Auth provider '{provider_type}' is not compatible with agent's "
                             f"security requirements. Agent requires: {scheme_details}")

        logger.info("Auth provider '%s' is compatible with schemes: %s", provider_type, compatible_schemes)

    def _is_provider_compatible_with_scheme(self, scheme: SecurityScheme) -> bool:
        """
        Check if the current auth provider can satisfy a security scheme.

        Compatibility is decided by substrings of the provider's class name
        (``OAuth2``, ``APIKey``, ``HTTPBasic``/``BasicAuth``), so a provider whose
        class name contains none of these is treated as incompatible.

        Args:
            scheme: Security scheme from agent card

        Returns:
            True if provider is compatible with the scheme
        """
        provider_type = type(self._auth_provider).__name__
        root = scheme.root

        # OAuth2/OIDC schemes require OAuth2 providers
        if isinstance(root, OAuth2SecurityScheme | OpenIdConnectSecurityScheme):
            return "OAuth2" in provider_type

        # API Key schemes (can be in header, query, or cookie)
        if isinstance(root, APIKeySecurityScheme):
            return "APIKey" in provider_type

        # HTTP Auth schemes (Basic or Bearer)
        if isinstance(root, HTTPAuthSecurityScheme):
            scheme_lower = root.scheme.lower()
            if scheme_lower == "basic":
                return "HTTPBasic" in provider_type or "BasicAuth" in provider_type
            if scheme_lower == "bearer":
                # Bearer can be satisfied by OAuth2 or API Key providers
                return "OAuth2" in provider_type or "APIKey" in provider_type
            # A known scheme type with an HTTP auth scheme we cannot supply (e.g. digest)
            logger.warning("Unsupported HTTP auth scheme: %s", root.scheme)
            return False

        # Unknown or unsupported scheme type
        logger.warning("Unknown security scheme type: %s", type(root).__name__)
        return False

    @staticmethod
    def _is_bearer_compatible(scheme_def: SecurityScheme | None) -> bool:
        """
        Check if security scheme accepts Bearer tokens.

        Bearer tokens are compatible with:
        - OAuth2SecurityScheme
        - OpenIdConnectSecurityScheme
        - HTTPAuthSecurityScheme with scheme='bearer' (case-insensitive)

        Args:
            scheme_def: Security scheme definition

        Returns:
            True if Bearer token is compatible
        """
        if not scheme_def:
            return False

        # Check for OAuth2 or OIDC schemes
        if isinstance(scheme_def.root, OAuth2SecurityScheme | OpenIdConnectSecurityScheme):
            return True

        # Check for HTTP Bearer scheme
        if isinstance(scheme_def.root, HTTPAuthSecurityScheme):
            return scheme_def.root.scheme.lower() == "bearer"

        return False

    @staticmethod
    def _api_key_scheme_matches(
        scheme_def: SecurityScheme | None,
        location: str,
        name: str,
        *,
        case_sensitive: bool,
    ) -> bool:
        """Return True if ``scheme_def`` is an API-key scheme at ``location`` named ``name``."""
        if not scheme_def or not isinstance(scheme_def.root, APIKeySecurityScheme):
            return False
        if scheme_def.root.in_ != location:
            return False
        if case_sensitive:
            return scheme_def.root.name == name
        return scheme_def.root.name.lower() == name.lower()

    @staticmethod
    def _is_header_compatible(scheme_def: SecurityScheme | None, header_name: str) -> bool:
        """
        Check if security scheme accepts header-based API keys.

        Header names are compared case-insensitively (HTTP header names are
        case-insensitive).

        Args:
            scheme_def: Security scheme definition
            header_name: Name of the header containing the credential

        Returns:
            True if header credential is compatible
        """
        return A2ACredentialService._api_key_scheme_matches(scheme_def, "header", header_name, case_sensitive=False)

    @staticmethod
    def _is_query_compatible(scheme_def: SecurityScheme | None, param_name: str) -> bool:
        """
        Check if security scheme accepts query parameter API keys.

        Parameter names are compared case-sensitively.

        Args:
            scheme_def: Security scheme definition
            param_name: Name of the query parameter

        Returns:
            True if query credential is compatible
        """
        return A2ACredentialService._api_key_scheme_matches(scheme_def, "query", param_name, case_sensitive=True)

    @staticmethod
    def _is_cookie_compatible(scheme_def: SecurityScheme | None, cookie_name: str) -> bool:
        """
        Check if security scheme accepts cookie-based API keys.

        Cookie names are compared case-sensitively.

        Args:
            scheme_def: Security scheme definition
            cookie_name: Name of the cookie

        Returns:
            True if cookie credential is compatible
        """
        return A2ACredentialService._api_key_scheme_matches(scheme_def, "cookie", cookie_name, case_sensitive=True)

    @staticmethod
    def _is_basic_compatible(scheme_def: SecurityScheme | None) -> bool:
        """
        Check if security scheme accepts HTTP Basic authentication.

        Args:
            scheme_def: Security scheme definition

        Returns:
            True if the scheme is HTTPAuthSecurityScheme with scheme='basic' (case-insensitive)
        """
        if not scheme_def:
            return False

        # Check for HTTP Basic scheme
        if isinstance(scheme_def.root, HTTPAuthSecurityScheme):
            return scheme_def.root.scheme.lower() == "basic"

        return False
@@ -18,6 +18,7 @@ from __future__ import annotations
18
18
  import logging
19
19
  from collections.abc import AsyncGenerator
20
20
  from datetime import timedelta
21
+ from typing import TYPE_CHECKING
21
22
  from uuid import uuid4
22
23
 
23
24
  import httpx
@@ -34,6 +35,9 @@ from a2a.types import Role
34
35
  from a2a.types import Task
35
36
  from a2a.types import TextPart
36
37
 
38
+ if TYPE_CHECKING:
39
+ from nat.authentication.interfaces import AuthProviderBase
40
+
37
41
  logger = logging.getLogger(__name__)
38
42
 
39
43
 
@@ -43,20 +47,25 @@ class A2ABaseClient:
43
47
 
44
48
  Args:
45
49
  base_url: The base URL of the A2A agent
50
+ agent_card_path: Path to agent card (default: /.well-known/agent-card.json)
46
51
  task_timeout: Timeout for task operations (default: 300 seconds)
52
+ streaming: Enable streaming responses (default: True)
53
+ auth_provider: Optional NAT authentication provider for securing requests
47
54
  """
48
55
 
49
56
  def __init__(
50
- self,
51
- base_url: str,
52
- agent_card_path: str = "/.well-known/agent-card.json",
53
- task_timeout: timedelta = timedelta(seconds=300),
54
- streaming: bool = True,
57
+ self,
58
+ base_url: str,
59
+ agent_card_path: str = "/.well-known/agent-card.json",
60
+ task_timeout: timedelta = timedelta(seconds=300),
61
+ streaming: bool = True,
62
+ auth_provider: AuthProviderBase | None = None,
55
63
  ):
56
64
  self._base_url = base_url
57
65
  self._agent_card_path = agent_card_path
58
66
  self._task_timeout = task_timeout
59
67
  self._streaming = streaming
68
+ self._auth_provider = auth_provider
60
69
 
61
70
  self._httpx_client: httpx.AsyncClient | None = None
62
71
  self._client: Client | None = None
@@ -82,13 +91,30 @@ class A2ABaseClient:
82
91
  if not self._agent_card:
83
92
  raise RuntimeError("Agent card not resolved")
84
93
 
85
- # 3) Create A2A client
94
+ # 3) Setup authentication interceptors if auth is configured
95
+ interceptors = []
96
+ if self._auth_provider:
97
+ try:
98
+ from a2a.client import AuthInterceptor
99
+ from nat.plugins.a2a.auth.credential_service import A2ACredentialService
100
+
101
+ credential_service = A2ACredentialService(
102
+ auth_provider=self._auth_provider,
103
+ agent_card=self._agent_card,
104
+ )
105
+ interceptors.append(AuthInterceptor(credential_service))
106
+ logger.info("Authentication configured for A2A client")
107
+ except ImportError as e:
108
+ logger.error("Failed to setup authentication: %s", e)
109
+ raise RuntimeError("Authentication requires a2a-sdk with AuthInterceptor support") from e
110
+
111
+ # 4) Create A2A client with interceptors
86
112
  client_config = ClientConfig(
87
113
  httpx_client=self._httpx_client,
88
114
  streaming=self._streaming,
89
115
  )
90
116
  factory = ClientFactory(client_config)
91
- self._client = factory.create(self._agent_card)
117
+ self._client = factory.create(self._agent_card, interceptors=interceptors)
92
118
 
93
119
  logger.info("Connected to A2A agent at %s", self._base_url)
94
120
  return self
@@ -65,5 +65,8 @@ class A2AClientConfig(FunctionGroupBaseConfig, name="a2a_client"):
65
65
  description="Whether to enable streaming support for the A2A client",
66
66
  )
67
67
 
68
- auth_provider: str | AuthenticationRef | None = Field(default=None,
69
- description="Reference to authentication provider")
68
+ auth_provider: str | AuthenticationRef | None = Field(
69
+ default=None,
70
+ description="Reference to NAT authentication provider for authenticating with the A2A agent. "
71
+ "Supports OAuth2, API Key, HTTP Basic, and other NAT auth providers.",
72
+ )
@@ -15,6 +15,7 @@
15
15
 
16
16
  import logging
17
17
  from collections.abc import AsyncGenerator
18
+ from typing import TYPE_CHECKING
18
19
  from typing import Any
19
20
 
20
21
  from pydantic import BaseModel
@@ -22,10 +23,13 @@ from pydantic import Field
22
23
 
23
24
  from nat.builder.function import FunctionGroup
24
25
  from nat.builder.workflow_builder import Builder
25
- from nat.cli.register_workflow import register_function_group
26
+ from nat.cli.register_workflow import register_per_user_function_group
26
27
  from nat.plugins.a2a.client.client_base import A2ABaseClient
27
28
  from nat.plugins.a2a.client.client_config import A2AClientConfig
28
29
 
30
+ if TYPE_CHECKING:
31
+ from nat.authentication.interfaces import AuthProviderBase
32
+
29
33
  logger = logging.getLogger(__name__)
30
34
 
31
35
 
@@ -66,13 +70,36 @@ class A2AClientFunctionGroup(FunctionGroup):
66
70
  config: A2AClientConfig = self._config # type: ignore[assignment]
67
71
  base_url = str(config.url)
68
72
 
73
+ # Get user_id from context (set by runtime for per-user function groups)
74
+ from nat.builder.context import Context
75
+ user_id = Context.get().user_id
76
+ if not user_id:
77
+ raise RuntimeError("User ID not found in context")
78
+
79
+ # Resolve auth provider if configured
80
+ auth_provider: AuthProviderBase | None = None
81
+ if config.auth_provider:
82
+ try:
83
+ auth_provider = await self._builder.get_auth_provider(config.auth_provider)
84
+ logger.info("Resolved authentication provider for A2A client")
85
+ except Exception as e:
86
+ logger.error("Failed to resolve auth provider '%s': %s", config.auth_provider, e)
87
+ raise RuntimeError(f"Failed to resolve auth provider: {e}") from e
88
+
69
89
  # Create and initialize A2A client
70
- self._client = A2ABaseClient(base_url=base_url,
71
- agent_card_path=config.agent_card_path,
72
- task_timeout=config.task_timeout,
73
- streaming=config.streaming)
90
+ self._client = A2ABaseClient(
91
+ base_url=base_url,
92
+ agent_card_path=config.agent_card_path,
93
+ task_timeout=config.task_timeout,
94
+ streaming=config.streaming,
95
+ auth_provider=auth_provider,
96
+ )
74
97
  await self._client.__aenter__()
75
- logger.info("Connected to A2A agent at %s", base_url)
98
+
99
+ if auth_provider:
100
+ logger.info("Connected to A2A agent at %s with authentication (user_id: %s)", base_url, user_id)
101
+ else:
102
+ logger.info("Connected to A2A agent at %s (user_id: %s)", base_url, user_id)
76
103
 
77
104
  # Discover agent card and register functions
78
105
  self._register_functions()
@@ -281,11 +308,12 @@ class A2AClientFunctionGroup(FunctionGroup):
281
308
  yield event
282
309
 
283
310
 
284
- @register_function_group(config_type=A2AClientConfig)
311
+ @register_per_user_function_group(config_type=A2AClientConfig)
285
312
  async def a2a_client_function_group(config: A2AClientConfig, _builder: Builder):
286
313
  """
287
314
  Connect to an A2A agent, discover agent card and publish the primary
288
- agent function and helper functions.
315
+ agent function and helper functions. This function group is per-user,
316
+ meaning each user gets their own isolated instance.
289
317
 
290
318
  This function group creates a three-level API:
291
319
  - High-level: Agent function named after the agent (e.g., dice_agent)
@@ -17,7 +17,9 @@ import logging
17
17
 
18
18
  from pydantic import BaseModel
19
19
  from pydantic import Field
20
+ from pydantic import model_validator
20
21
 
22
+ from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
21
23
  from nat.data_models.front_end import FrontEndBaseConfig
22
24
 
23
25
  logger = logging.getLogger(__name__)
@@ -102,3 +104,28 @@ class A2AFrontEndConfig(FrontEndBaseConfig, name="a2a"):
102
104
  default=None,
103
105
  description="Custom worker class for handling A2A routes (default: built-in worker)",
104
106
  )
107
+
108
+ # OAuth2 Resource Server (for protecting this A2A agent)
109
+ server_auth: OAuth2ResourceServerConfig | None = Field(
110
+ default=None,
111
+ description=("OAuth 2.0 Resource Server configuration for token verification. "
112
+ "When configured, the A2A server will validate OAuth2 Bearer tokens on all requests "
113
+ "except public agent card discovery. Supports both JWT validation (via JWKS) and "
114
+ "opaque token validation (via RFC 7662 introspection)."),
115
+ )
116
+
117
@model_validator(mode="after")
def validate_security_configuration(self):
    """Warn when the A2A server is exposed beyond loopback without authentication.

    Binding to a non-loopback interface with no ``server_auth`` configured is allowed,
    but a warning is logged because requests would not be authenticated. ``localhost``
    (any case) and any loopback IP address count as local-only. Loopback IPs include
    the whole ``127.0.0.0/8`` range and ``::1``, optionally written in ``[::1]``
    bracket form.

    Returns:
        The validated config instance (``mode="after"`` model validators return ``self``).
    """
    import ipaddress

    host = str(self.host).strip().lower()
    if host.startswith("[") and host.endswith("]"):
        host = host[1:-1]  # bracketed IPv6 literal, e.g. "[::1]"

    if host == "localhost":
        is_loopback = True
    else:
        try:
            is_loopback = ipaddress.ip_address(host).is_loopback
        except ValueError:
            # A hostname other than localhost: we cannot tell where it resolves,
            # so treat it as potentially exposed.
            is_loopback = False

    # Check if server is bound to a non-localhost interface without authentication
    if not is_loopback and self.server_auth is None:
        logger.warning(
            "A2A server is configured to bind to '%s' without authentication. "
            "This may expose your server to unauthorized access. "
            "Consider either: (1) binding to localhost for local-only access, "
            "or (2) configuring server_auth for production deployments on public interfaces.",
            self.host,
        )

    return self
@@ -53,7 +53,7 @@ class A2AFrontEndPlugin(FrontEndBase[A2AFrontEndConfig]):
53
53
  agent_card = await worker.create_agent_card(workflow)
54
54
 
55
55
  # Create agent executor adapter
56
- agent_executor = worker.create_agent_executor(workflow)
56
+ agent_executor = worker.create_agent_executor(workflow, builder)
57
57
 
58
58
  # Create A2A server
59
59
  a2a_server = worker.create_a2a_server(agent_card, agent_executor)
@@ -70,8 +70,21 @@ class A2AFrontEndPlugin(FrontEndBase[A2AFrontEndConfig]):
70
70
  self.front_end_config.host,
71
71
  self.front_end_config.port)
72
72
 
73
- # Build the ASGI app and run with uvicorn
73
+ # Build the ASGI app
74
74
  app = a2a_server.build()
75
+
76
+ # Add OAuth2 validation middleware if configured
77
+ if self.front_end_config.server_auth:
78
+ from nat.plugins.a2a.server.oauth_middleware import OAuth2ValidationMiddleware
79
+
80
+ app.add_middleware(OAuth2ValidationMiddleware, config=self.front_end_config.server_auth)
81
+ logger.info(
82
+ "OAuth2 token validation enabled for A2A server (issuer=%s, scopes=%s)",
83
+ self.front_end_config.server_auth.issuer_url,
84
+ self.front_end_config.server_auth.scopes,
85
+ )
86
+
87
+ # Run with uvicorn
75
88
  config = uvicorn.Config(
76
89
  app,
77
90
  host=self.front_end_config.host,
@@ -25,8 +25,10 @@ from a2a.server.tasks import InMemoryTaskStore
25
25
  from a2a.types import AgentCapabilities
26
26
  from a2a.types import AgentCard
27
27
  from a2a.types import AgentSkill
28
+ from a2a.types import SecurityScheme
28
29
  from nat.builder.function import Function
29
30
  from nat.builder.workflow import Workflow
31
+ from nat.builder.workflow_builder import WorkflowBuilder
30
32
  from nat.data_models.config import Config
31
33
  from nat.plugins.a2a.server.agent_executor_adapter import NATWorkflowAgentExecutor
32
34
  from nat.plugins.a2a.server.front_end_config import A2AFrontEndConfig
@@ -72,6 +74,81 @@ class A2AFrontEndPluginWorker:
72
74
 
73
75
  return functions
74
76
 
77
async def _generate_security_schemes(
        self, server_auth_config) -> tuple[dict[str, SecurityScheme], list[dict[str, list[str]]]]:
    """Build the agent card's OAuth2 security scheme and requirement from server auth settings.

    A single authorization-code OAuth2 scheme named ``"oauth2"`` is produced. Its
    endpoints come from ``_resolve_oauth_endpoints``, and each configured scope gets
    the description ``"Permission: <scope>"``.

    Args:
        server_auth_config: OAuth2ResourceServerConfig

    Returns:
        ``(security_schemes, security)``: the scheme mapping for ``AgentCard.security_schemes``
        and the requirement list ``[{"oauth2": <configured scopes>}]`` for ``AgentCard.security``.
    """
    from a2a.types import AuthorizationCodeOAuthFlow
    from a2a.types import OAuth2SecurityScheme
    from a2a.types import OAuthFlows

    authorization_url, token_endpoint = await self._resolve_oauth_endpoints(server_auth_config)
    required_scopes = server_auth_config.scopes

    code_flow = AuthorizationCodeOAuthFlow(
        authorizationUrl=authorization_url,
        tokenUrl=token_endpoint,
        scopes={scope_name: f"Permission: {scope_name}" for scope_name in required_scopes},
    )
    oauth_scheme = OAuth2SecurityScheme(
        type="oauth2",
        description="OAuth 2.0 authentication required to access this agent",
        flows=OAuthFlows(authorizationCode=code_flow),
    )

    schemes_by_name = {"oauth2": SecurityScheme(root=oauth_scheme)}
    requirements = [{"oauth2": required_scopes}]
    return schemes_by_name, requirements
115
+
116
async def _resolve_oauth_endpoints(self, server_auth_config) -> tuple[str, str]:
    """Resolve authorization and token URLs from OAuth2 configuration.

    When ``discovery_url`` is set, the discovery document is fetched (5 s timeout).
    Its ``authorization_endpoint`` and ``token_endpoint`` are used when both are
    present. Discovery is best-effort. On any failure, or when the document lacks
    either endpoint, a warning is logged. The URLs are then derived from
    ``issuer_url`` as ``<issuer>/oauth/authorize`` and ``<issuer>/oauth/token``.
    Those paths are a convention, not a standard.

    Args:
        server_auth_config: OAuth2ResourceServerConfig

    Returns:
        Tuple of (authorization_url, token_url)
    """
    discovery_url = server_auth_config.discovery_url

    # If discovery URL is provided, use OIDC discovery
    if discovery_url:
        import httpx

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(discovery_url, timeout=5.0)
                response.raise_for_status()
                metadata = response.json()
        except Exception as e:  # best-effort: fall back to derived endpoints below
            logger.warning("Failed to discover OAuth endpoints: %s", e)
        else:
            auth_url = token_url = None
            if isinstance(metadata, dict):
                auth_url = metadata.get("authorization_endpoint")
                token_url = metadata.get("token_endpoint")

            if auth_url and token_url:
                logger.info("Resolved OAuth endpoints via discovery: %s", discovery_url)
                return auth_url, token_url

            logger.warning(
                "Discovery document at %s does not provide both authorization_endpoint and "
                "token_endpoint; falling back to endpoints derived from issuer",
                discovery_url,
            )

    # Fallback: derive from issuer URL (common convention)
    issuer = server_auth_config.issuer_url.rstrip("/")
    auth_url = f"{issuer}/oauth/authorize"
    token_url = f"{issuer}/oauth/token"

    logger.info("Using derived OAuth endpoints from issuer: %s", issuer)
    return auth_url, token_url
151
+
75
152
  async def create_agent_card(self, workflow: Workflow) -> AgentCard:
76
153
  """Build AgentCard from configuration and workflow functions.
77
154
 
@@ -112,6 +189,18 @@ class A2AFrontEndPluginWorker:
112
189
 
113
190
  logger.info("Auto-generated %d skills from workflow functions", len(skills))
114
191
 
192
+ # Generate security schemes if server_auth is configured
193
+ security_schemes = None
194
+ security = None
195
+
196
+ if config.server_auth:
197
+ security_schemes, security = await self._generate_security_schemes(config.server_auth)
198
+ logger.info(
199
+ "Generated OAuth2 security schemes for agent (issuer=%s, scopes=%s)",
200
+ config.server_auth.issuer_url,
201
+ config.server_auth.scopes,
202
+ )
203
+
115
204
  # Build agent card
116
205
  agent_url = f"http://{config.host}:{config.port}/"
117
206
  agent_card = AgentCard(
@@ -123,15 +212,19 @@ class A2AFrontEndPluginWorker:
123
212
  default_output_modes=config.default_output_modes,
124
213
  capabilities=capabilities,
125
214
  skills=skills,
215
+ security_schemes=security_schemes,
216
+ security=security,
126
217
  )
127
218
 
128
219
  logger.info("Created AgentCard for: %s v%s", config.name, config.version)
129
220
  logger.info("Agent URL: %s", agent_url)
130
221
  logger.info("Skills: %d", len(skills))
222
+ if security_schemes:
223
+ logger.info("Security: OAuth2 authentication required")
131
224
 
132
225
  return agent_card
133
226
 
134
- def create_agent_executor(self, workflow: Workflow) -> NATWorkflowAgentExecutor:
227
+ def create_agent_executor(self, workflow: Workflow, builder: WorkflowBuilder) -> NATWorkflowAgentExecutor:
135
228
  """Create agent executor adapter for the workflow.
136
229
 
137
230
  This creates a SessionManager to handle concurrent A2A task requests,
@@ -139,13 +232,16 @@ class A2AFrontEndPluginWorker:
139
232
 
140
233
  Args:
141
234
  workflow: The NAT workflow to expose
235
+ builder: The workflow builder used to create the workflow
142
236
 
143
237
  Returns:
144
238
  NATWorkflowAgentExecutor that wraps the workflow with a SessionManager
145
239
  """
146
240
  # Create SessionManager to handle concurrent requests with proper limits
147
241
  session_manager = SessionManager(
148
- workflow=workflow,
242
+ config=self.full_config,
243
+ shared_builder=builder,
244
+ shared_workflow=workflow,
149
245
  max_concurrency=self.max_concurrency,
150
246
  )
151
247
 
@@ -0,0 +1,121 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """OAuth 2.0 token validation middleware for A2A servers."""
16
+
17
+ import logging
18
+
19
+ from starlette.middleware.base import BaseHTTPMiddleware
20
+ from starlette.requests import Request
21
+ from starlette.responses import JSONResponse
22
+
23
+ from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
24
+ from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
class OAuth2ValidationMiddleware(BaseHTTPMiddleware):
    """OAuth2 Bearer token validation middleware for A2A servers.

    Validates Bearer tokens using NAT's BearerTokenValidator which supports:
    - JWT validation via JWKS (RFC 7519)
    - Opaque token validation via introspection (RFC 7662)
    - OIDC discovery
    - Scope and audience enforcement

    Requests for the agent card discovery document (``/.well-known/agent-card.json``, and the
    legacy ``/.well-known/agent.json`` location) are passed through without authentication.
    Every other request must carry a valid ``Authorization: Bearer <token>`` header.

    Authentication failures are answered with HTTP 401 and a ``WWW-Authenticate`` challenge,
    as described in RFC 6750 section 3.
    """

    # Discovery endpoints that must stay reachable without credentials: the agent card is how a
    # client learns which security schemes it has to satisfy. The legacy location is kept public
    # for older clients that may still request it.
    PUBLIC_PATHS = frozenset({"/.well-known/agent-card.json", "/.well-known/agent.json"})

    def __init__(self, app, config: OAuth2ResourceServerConfig):
        """Initialize OAuth2 validation middleware.

        Args:
            app: Starlette application wrapped by this middleware.
            config: OAuth2 resource server configuration used to build the token validator.
        """
        super().__init__(app)

        # The validator needs the raw secret string, not the secret wrapper object.
        client_secret = config.client_secret.get_secret_value() if config.client_secret else None

        # Create validator using NAT's BearerTokenValidator
        self.validator = BearerTokenValidator(
            issuer=config.issuer_url,
            audience=config.audience,
            scopes=config.scopes,
            jwks_uri=config.jwks_uri,
            introspection_endpoint=config.introspection_endpoint,
            discovery_url=config.discovery_url,
            client_id=config.client_id,
            client_secret=client_secret,
        )

        logger.info(
            "OAuth2 validation middleware initialized (issuer=%s, scopes=%s, audience=%s)",
            config.issuer_url,
            config.scopes,
            config.audience,
        )

    @staticmethod
    def _extract_bearer_token(auth_header: str) -> str | None:
        """Return the token carried by a ``Bearer`` Authorization header value.

        The scheme name is compared case-insensitively (RFC 7235 section 2.1). Returns None when
        the header is empty, uses another scheme, or carries no token after the scheme name.
        """
        scheme, _, token = auth_header.strip().partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            return None
        return token

    @staticmethod
    def _unauthorized(error: str, message: str, challenge: str) -> JSONResponse:
        """Build a 401 JSON error response carrying the given ``WWW-Authenticate`` challenge."""
        return JSONResponse(
            {
                "error": error, "message": message
            },
            status_code=401,
            headers={"WWW-Authenticate": challenge},
        )

    async def dispatch(self, request: Request, call_next):
        """Validate the OAuth2 Bearer token for all requests except agent card discovery.

        On success, the validation result is exposed to downstream handlers through
        ``request.state.oauth_user``, ``oauth_scopes``, ``oauth_client_id`` and
        ``oauth_token_info``.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            The downstream response for public paths and valid tokens. Otherwise, a 401 JSON
            error response.
        """
        # Public: Agent card discovery (per A2A spec)
        if request.url.path in self.PUBLIC_PATHS:
            logger.debug("Public access to agent card discovery")
            return await call_next(request)

        token = self._extract_bearer_token(request.headers.get("Authorization", ""))
        if token is None:
            logger.warning("Missing or invalid Authorization header")
            # RFC 6750 3.1: omit the error code when the request carried no usable credentials.
            return self._unauthorized("unauthorized", "Missing or invalid Bearer token", "Bearer")

        # Validate token using NAT's validator. Any validator failure is treated as a rejected
        # token rather than propagated, so the client gets a proper auth challenge.
        try:
            result = await self.validator.verify(token)
        except Exception as e:
            logger.error("Token validation error: %s", e)
            return self._unauthorized("invalid_token", "Token validation failed", 'Bearer error="invalid_token"')

        # Check if token is active
        if not result.active:
            logger.warning("Token is not active")
            return self._unauthorized("invalid_token", "Token is not active", 'Bearer error="invalid_token"')

        # Attach token info to request state for potential use by handlers
        request.state.oauth_user = result.subject
        request.state.oauth_scopes = result.scopes or []
        request.state.oauth_client_id = result.client_id
        request.state.oauth_token_info = result

        logger.debug(
            "Token validated successfully (user=%s, scopes=%s, client=%s)",
            result.subject,
            result.scopes,
            result.client_id,
        )

        return await call_next(request)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-a2a
3
- Version: 1.4.0a20251207
3
+ Version: 1.4.0a20251231
4
4
  Summary: Subpackage for A2A Protocol integration in NeMo Agent Toolkit
5
5
  Author: NVIDIA Corporation
6
6
  Maintainer: NVIDIA Corporation
@@ -15,7 +15,7 @@ Classifier: Programming Language :: Python :: 3.13
15
15
  Requires-Python: <3.14,>=3.11
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE.md
18
- Requires-Dist: nvidia-nat==v1.4.0a20251207
18
+ Requires-Dist: nvidia-nat==v1.4.0a20251231
19
19
  Requires-Dist: a2a-sdk~=0.3.20
20
20
  Dynamic: license-file
21
21
 
@@ -0,0 +1,22 @@
1
+ nat/meta/pypi.md,sha256=YkfjzZntzheoaBie5ZovnAwB78xxVqk9sblkZRZcdLU,1661
2
+ nat/plugins/a2a/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
3
+ nat/plugins/a2a/register.py,sha256=pUN1hbJ38M8GbdNcA0qQzJ1S-ZC91GnRGk_8SO_kTVg,853
4
+ nat/plugins/a2a/auth/__init__.py,sha256=iQFx1YrjFcepS7k8jp93A0IVOkFeNx_I35M6dIngoJA,726
5
+ nat/plugins/a2a/auth/credential_service.py,sha256=-_VdDF4YESaAtY1ONUiOL5z4aGDJZYVuhyhI9BZhuyI,15967
6
+ nat/plugins/a2a/client/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
7
+ nat/plugins/a2a/client/client_base.py,sha256=xShDZDFKa4R2XsY3yBMvM-eDaf_0cdE48XJzQ4WcEOw,13366
8
+ nat/plugins/a2a/client/client_config.py,sha256=KwWjymDg9GUfSYcIaBhcxph4Hu6IeTe414hrNUUo-6g,2875
9
+ nat/plugins/a2a/client/client_impl.py,sha256=CGAjiHr6EyWcnlSipmT8ixgjD4s8VbPRBPOZy2q_Sm0,12958
10
+ nat/plugins/a2a/server/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
11
+ nat/plugins/a2a/server/agent_executor_adapter.py,sha256=wvGXOb3FcV0_pYRv-yr-QzozjzXM909D49Dxm9199xI,7015
12
+ nat/plugins/a2a/server/front_end_config.py,sha256=Lg-qjDmC4fwrwnHNtSRl54pMpdwVnO06xhgbLt-aEZY,4902
13
+ nat/plugins/a2a/server/front_end_plugin.py,sha256=fX3Lagkd48snSiNo2IMTRpR-40WHUWQidpjKu8uQChY,4896
14
+ nat/plugins/a2a/server/front_end_plugin_worker.py,sha256=Ehdv6lyUcrWkfMq7YomD4NYFAusrtQ2JYj2HnkIqGhY,11696
15
+ nat/plugins/a2a/server/oauth_middleware.py,sha256=NvvIJSPB8wRui2eQlxr6AaNhN0JxdUQ1Ajr8Dnk0rnY,4751
16
+ nat/plugins/a2a/server/register_frontend.py,sha256=4TmpBcZF4x71c2xnWuketsygqHmU7D2hKA2bzO34TpU,1480
17
+ nvidia_nat_a2a-1.4.0a20251231.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
18
+ nvidia_nat_a2a-1.4.0a20251231.dist-info/METADATA,sha256=U2rDLPsY0wUsArIBqHmm3srRc85nlXDjQHRUsKpYHzQ,2438
19
+ nvidia_nat_a2a-1.4.0a20251231.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
+ nvidia_nat_a2a-1.4.0a20251231.dist-info/entry_points.txt,sha256=Lacvy6nXpDTv8dh8vKJ_QE8TobliVdhgABuw25t8fBg,145
21
+ nvidia_nat_a2a-1.4.0a20251231.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
22
+ nvidia_nat_a2a-1.4.0a20251231.dist-info/RECORD,,
@@ -1,19 +0,0 @@
1
- nat/meta/pypi.md,sha256=YkfjzZntzheoaBie5ZovnAwB78xxVqk9sblkZRZcdLU,1661
2
- nat/plugins/a2a/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
3
- nat/plugins/a2a/register.py,sha256=pUN1hbJ38M8GbdNcA0qQzJ1S-ZC91GnRGk_8SO_kTVg,853
4
- nat/plugins/a2a/client/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
5
- nat/plugins/a2a/client/client_base.py,sha256=s0HTS0ebuwoc_fL4Z_laRe-OIZ5iooSG_FzcgZca8_E,12070
6
- nat/plugins/a2a/client/client_config.py,sha256=SWu46fAa25IYc3Lhq9w9nIt5xkCtdBuuPy74pd5vPPk,2788
7
- nat/plugins/a2a/client/client_impl.py,sha256=cc_rYyPq86_8R12MjesMHaZrYG9lDnip2125veQ1fEY,11775
8
- nat/plugins/a2a/server/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
9
- nat/plugins/a2a/server/agent_executor_adapter.py,sha256=wvGXOb3FcV0_pYRv-yr-QzozjzXM909D49Dxm9199xI,7015
10
- nat/plugins/a2a/server/front_end_config.py,sha256=Qnjbx6n67Xy3sZ6rkAYZaKk-WfBDrVX5OzZSWxU6fIg,3423
11
- nat/plugins/a2a/server/front_end_plugin.py,sha256=euhh5LXkZpyC5HaUaJKFJH3BIF6jS2ti3NNXVQ71bgI,4255
12
- nat/plugins/a2a/server/front_end_plugin_worker.py,sha256=Ih00L9DtZZYQvN_RRw4a5M_StTFhZ4JGHNniw6glOzY,7757
13
- nat/plugins/a2a/server/register_frontend.py,sha256=4TmpBcZF4x71c2xnWuketsygqHmU7D2hKA2bzO34TpU,1480
14
- nvidia_nat_a2a-1.4.0a20251207.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
15
- nvidia_nat_a2a-1.4.0a20251207.dist-info/METADATA,sha256=6qhxCj7OS1n7csjZtbbByAd86SYtWVoydK0nFnvDfRc,2438
16
- nvidia_nat_a2a-1.4.0a20251207.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
17
- nvidia_nat_a2a-1.4.0a20251207.dist-info/entry_points.txt,sha256=Lacvy6nXpDTv8dh8vKJ_QE8TobliVdhgABuw25t8fBg,145
18
- nvidia_nat_a2a-1.4.0a20251207.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
19
- nvidia_nat_a2a-1.4.0a20251207.dist-info/RECORD,,