nvidia-nat-a2a 1.4.0a20251207__py3-none-any.whl → 1.4.0a20251224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-a2a has been flagged as potentially problematic; see the package's registry page for details.
- nat/plugins/a2a/auth/__init__.py +15 -0
- nat/plugins/a2a/auth/credential_service.py +418 -0
- nat/plugins/a2a/client/client_base.py +33 -7
- nat/plugins/a2a/client/client_config.py +5 -2
- nat/plugins/a2a/client/client_impl.py +36 -8
- nat/plugins/a2a/server/front_end_config.py +27 -0
- nat/plugins/a2a/server/front_end_plugin.py +15 -2
- nat/plugins/a2a/server/front_end_plugin_worker.py +98 -2
- nat/plugins/a2a/server/oauth_middleware.py +121 -0
- {nvidia_nat_a2a-1.4.0a20251207.dist-info → nvidia_nat_a2a-1.4.0a20251224.dist-info}/METADATA +2 -2
- nvidia_nat_a2a-1.4.0a20251224.dist-info/RECORD +22 -0
- nvidia_nat_a2a-1.4.0a20251207.dist-info/RECORD +0 -19
- {nvidia_nat_a2a-1.4.0a20251207.dist-info → nvidia_nat_a2a-1.4.0a20251224.dist-info}/WHEEL +0 -0
- {nvidia_nat_a2a-1.4.0a20251207.dist-info → nvidia_nat_a2a-1.4.0a20251224.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_a2a-1.4.0a20251207.dist-info → nvidia_nat_a2a-1.4.0a20251224.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat_a2a-1.4.0a20251207.dist-info → nvidia_nat_a2a-1.4.0a20251224.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Authentication support for A2A clients."""
|
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Bridge NAT AuthProviderBase to A2A SDK CredentialService."""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import logging
|
|
19
|
+
|
|
20
|
+
from a2a.client import ClientCallContext
|
|
21
|
+
from a2a.client import CredentialService
|
|
22
|
+
from a2a.types import AgentCard
|
|
23
|
+
from a2a.types import APIKeySecurityScheme
|
|
24
|
+
from a2a.types import HTTPAuthSecurityScheme
|
|
25
|
+
from a2a.types import OAuth2SecurityScheme
|
|
26
|
+
from a2a.types import OpenIdConnectSecurityScheme
|
|
27
|
+
from a2a.types import SecurityScheme
|
|
28
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
29
|
+
from nat.builder.context import Context
|
|
30
|
+
from nat.data_models.authentication import AuthResult
|
|
31
|
+
from nat.data_models.authentication import BasicAuthCred
|
|
32
|
+
from nat.data_models.authentication import BearerTokenCred
|
|
33
|
+
from nat.data_models.authentication import CookieCred
|
|
34
|
+
from nat.data_models.authentication import HeaderCred
|
|
35
|
+
from nat.data_models.authentication import QueryCred
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class A2ACredentialService(CredentialService):
    """
    Adapts NAT AuthProviderBase to A2A SDK CredentialService interface.

    This class bridges NAT's authentication system with the A2A SDK's authentication
    mechanism, allowing A2A clients to use NAT's auth providers (API Key, OAuth2, etc.)
    to authenticate with A2A agents.

    The adapter:
    - Calls NAT auth provider to obtain credentials
    - Maps NAT credential types to A2A security scheme requirements
    - Re-authenticates when the cached credentials have expired
    - Caches credentials per user, so a result obtained for one user_id is never
      returned for a different user_id

    Args:
        auth_provider: NAT authentication provider instance
        agent_card: Agent card containing security scheme definitions

    Raises:
        ValueError: (from ``__init__``) if the agent card declares security schemes and
            the provider is compatible with none of them.
    """

    def __init__(
        self,
        auth_provider: AuthProviderBase,
        agent_card: AgentCard | None = None,
    ):
        self._auth_provider = auth_provider
        self._agent_card = agent_card
        # Most recent AuthResult together with the user_id it was obtained for.
        # The cache is only reused when the requesting user matches, so one user's
        # credentials can never be served to another user.
        self._cached_auth_result: AuthResult | None = None
        self._cached_user_id: str | None = None
        # Serializes provider calls so concurrent requests trigger one authentication.
        self._auth_lock = asyncio.Lock()

        # Validate provider compatibility with agent's security requirements (fail fast)
        self._validate_provider_compatibility()

    async def get_credentials(
        self,
        security_scheme_name: str,
        context: ClientCallContext | None,
    ) -> str | None:
        """
        Retrieve credentials for a security scheme.

        This method:
        1. Gets user_id from NAT context
        2. Authenticates via NAT auth provider (reusing an unexpired cached result
           obtained for the same user)
        3. Maps credentials to the requested security scheme

        Args:
            security_scheme_name: Name of the security scheme from AgentCard
            context: Client call context. Not read by this implementation; the user
                identity is taken from the NAT ``Context`` instead.

        Returns:
            Credential string or None if not available
        """
        # Get user_id from NAT context
        user_id = Context.get().user_id

        # Authenticate and get credentials from NAT provider
        auth_result = await self._authenticate(user_id)

        if not auth_result:
            logger.warning("Authentication failed, no credentials available")
            return None

        # Map NAT credentials to A2A format based on security scheme
        credential = self._extract_credential_for_scheme(auth_result, security_scheme_name)

        if credential:
            logger.debug(
                "Successfully retrieved credentials for scheme '%s'",
                security_scheme_name,
            )
        else:
            logger.warning(
                "No compatible credentials found for scheme '%s'",
                security_scheme_name,
            )

        return credential

    def _cached_result_for(self, user_id: str | None) -> AuthResult | None:
        """Return the cached AuthResult if it was obtained for ``user_id`` and has not expired."""
        auth_result = self._cached_auth_result
        if auth_result and self._cached_user_id == user_id and not auth_result.is_expired():
            return auth_result
        return None

    async def _authenticate(self, user_id: str | None) -> AuthResult | None:
        """
        Authenticate and get credentials from NAT auth provider.

        Re-authenticates when there is no cached result for this user or the cached
        result has expired. Uses a lock so that concurrent callers trigger at most
        one provider call at a time.

        Args:
            user_id: User identifier for authentication

        Returns:
            AuthResult with credentials or None on failure. Any ``Exception`` raised
            by the provider is logged and turned into ``None``.
        """
        try:
            # Fast path: check cache without lock (only valid for the same user)
            cached = self._cached_result_for(user_id)
            if cached is not None:
                return cached

            # Acquire lock to serialize authentication attempts
            async with self._auth_lock:
                # Double-check: another coroutine may have refreshed while we waited for lock
                cached = self._cached_result_for(user_id)
                if cached is not None:
                    logger.debug("Credentials were refreshed by another coroutine while waiting for lock")
                    return cached

                # Log if we're refreshing this user's expired credentials
                previous = self._cached_auth_result
                if previous and self._cached_user_id == user_id and previous.is_expired():
                    logger.info("Cached credentials expired, re-authenticating")

                # Call NAT auth provider (provider is responsible for token refresh/validity)
                auth_result = await self._auth_provider.authenticate(user_id=user_id)

                # Cache the result, tagged with its owner, while holding the lock
                self._cached_auth_result = auth_result
                self._cached_user_id = user_id

                # Warn if provider returned expired credentials (provider bug)
                if auth_result and auth_result.is_expired():
                    logger.warning("Auth provider returned already-expired credentials. "
                                   "This may indicate a bug in the auth provider's token refresh logic.")

                return auth_result

        except Exception as e:
            logger.error("Authentication failed: %s", e, exc_info=True)
            return None

    def _extract_credential_for_scheme(self, auth_result: AuthResult, security_scheme_name: str) -> str | None:
        """
        Extract appropriate credential based on security scheme type.

        Maps NAT credential types to A2A security scheme requirements:
        - BearerTokenCred -> OAuth2, OIDC, HTTP Bearer
        - HeaderCred -> API Key in header
        - QueryCred -> API Key in query
        - CookieCred -> API Key in cookie
        - BasicAuthCred -> HTTP Basic (returned as base64 of "username:password")

        The first compatible, non-empty credential in ``auth_result.credentials`` wins.

        Args:
            auth_result: Authentication result containing credentials
            security_scheme_name: Name of the security scheme

        Returns:
            Credential string or None
        """
        # Get scheme definition from agent card
        scheme_def = self._get_scheme_definition(security_scheme_name)

        # Try to match NAT credentials to security scheme
        for cred in auth_result.credentials:
            credential_value = None

            if isinstance(cred, BearerTokenCred) and self._is_bearer_compatible(scheme_def):
                credential_value = cred.token.get_secret_value()
            elif isinstance(cred, HeaderCred) and self._is_header_compatible(scheme_def, cred.name):
                credential_value = cred.value.get_secret_value()
            elif isinstance(cred, QueryCred) and self._is_query_compatible(scheme_def, cred.name):
                credential_value = cred.value.get_secret_value()
            elif isinstance(cred, CookieCred) and self._is_cookie_compatible(scheme_def, cred.name):
                credential_value = cred.value.get_secret_value()
            elif isinstance(cred, BasicAuthCred) and self._is_basic_compatible(scheme_def):
                # For HTTP Basic, encode username:password as base64
                import base64

                username = cred.username.get_secret_value()
                password = cred.password.get_secret_value()
                credentials = f"{username}:{password}"
                credential_value = base64.b64encode(credentials.encode()).decode()

            if credential_value:
                return credential_value

        return None

    def _get_scheme_definition(self, scheme_name: str) -> SecurityScheme | None:
        """
        Get security scheme definition from agent card.

        Args:
            scheme_name: Name of the security scheme

        Returns:
            SecurityScheme definition or None if there is no agent card, no schemes,
            or no scheme with that name
        """
        if not self._agent_card or not self._agent_card.security_schemes:
            return None
        return self._agent_card.security_schemes.get(scheme_name)

    def _validate_provider_compatibility(self) -> None:
        """
        Validate that the auth provider type is compatible with agent's security schemes.

        This performs early validation at connection time to fail fast if there's a
        configuration mismatch between the NAT auth provider and the A2A agent's
        security requirements. Compatibility is decided from the provider's class name
        (see ``_is_provider_compatible_with_scheme``).

        Raises:
            ValueError: If the provider is incompatible with all declared security schemes
        """
        if not self._agent_card or not self._agent_card.security_schemes:
            # No security schemes defined, nothing to validate
            logger.debug("No security schemes defined in agent card, skipping validation")
            return

        provider_type = type(self._auth_provider).__name__
        schemes = self._agent_card.security_schemes

        logger.info("Validating auth provider '%s' against agent security schemes: %s",
                    provider_type,
                    list(schemes.keys()))

        # Check if provider type is compatible with at least one security scheme
        compatible_schemes = []
        incompatible_schemes = []

        for scheme_name, scheme in schemes.items():
            if self._is_provider_compatible_with_scheme(scheme):
                compatible_schemes.append(scheme_name)
            else:
                incompatible_schemes.append((scheme_name, type(scheme.root).__name__))

        if not compatible_schemes:
            # Provider is not compatible with any security scheme
            scheme_details = ", ".join(f"{name} ({scheme_type})" for name, scheme_type in incompatible_schemes)
            raise ValueError(f"Auth provider '{provider_type}' is not compatible with agent's "
                             f"security requirements. Agent requires: {scheme_details}")

        logger.info("Auth provider '%s' is compatible with schemes: %s", provider_type, compatible_schemes)

    def _is_provider_compatible_with_scheme(self, scheme: SecurityScheme) -> bool:
        """
        Check if the current auth provider can satisfy a security scheme.

        The check is a heuristic on the provider's class name (substring match on
        "OAuth2", "APIKey", "HTTPBasic"/"BasicAuth").

        Args:
            scheme: Security scheme from agent card

        Returns:
            True if provider is compatible with the scheme
        """
        provider_type = type(self._auth_provider).__name__

        # OAuth2/OIDC schemes require OAuth2 providers
        if isinstance(scheme.root, OAuth2SecurityScheme | OpenIdConnectSecurityScheme):
            return "OAuth2" in provider_type

        # API Key schemes (can be in header, query, or cookie)
        if isinstance(scheme.root, APIKeySecurityScheme):
            return "APIKey" in provider_type

        # HTTP Auth schemes (Basic or Bearer)
        if isinstance(scheme.root, HTTPAuthSecurityScheme):
            scheme_lower = scheme.root.scheme.lower()
            if scheme_lower == "basic":
                return "HTTPBasic" in provider_type or "BasicAuth" in provider_type
            if scheme_lower == "bearer":
                # Bearer can be satisfied by OAuth2 or API Key providers
                return "OAuth2" in provider_type or "APIKey" in provider_type
            # Known scheme class, but an HTTP auth method we cannot satisfy (e.g. digest)
            logger.warning("Unsupported HTTP auth scheme: %s", scheme.root.scheme)
            return False

        # Unknown or unsupported scheme type
        logger.warning("Unknown security scheme type: %s", type(scheme.root).__name__)
        return False

    @staticmethod
    def _is_bearer_compatible(scheme_def: SecurityScheme | None) -> bool:
        """
        Check if security scheme accepts Bearer tokens.

        Bearer tokens are compatible with:
        - OAuth2SecurityScheme
        - OpenIdConnectSecurityScheme
        - HTTPAuthSecurityScheme with scheme='bearer' (case-insensitive)

        Args:
            scheme_def: Security scheme definition

        Returns:
            True if Bearer token is compatible
        """
        if not scheme_def:
            return False

        # Check for OAuth2 or OIDC schemes
        if isinstance(scheme_def.root, OAuth2SecurityScheme | OpenIdConnectSecurityScheme):
            return True

        # Check for HTTP Bearer scheme
        if isinstance(scheme_def.root, HTTPAuthSecurityScheme):
            return scheme_def.root.scheme.lower() == "bearer"

        return False

    @staticmethod
    def _is_header_compatible(scheme_def: SecurityScheme | None, header_name: str) -> bool:
        """
        Check if security scheme accepts header-based API keys.

        Args:
            scheme_def: Security scheme definition
            header_name: Name of the header containing the credential

        Returns:
            True if the scheme is an API key in a header whose name matches
            ``header_name`` case-insensitively (HTTP header names are case-insensitive)
        """
        if not scheme_def:
            return False

        # Check for API Key in header
        if isinstance(scheme_def.root, APIKeySecurityScheme):
            if scheme_def.root.in_ == "header":
                # Match header name (case-insensitive)
                return scheme_def.root.name.lower() == header_name.lower()

        return False

    @staticmethod
    def _is_query_compatible(scheme_def: SecurityScheme | None, param_name: str) -> bool:
        """
        Check if security scheme accepts query parameter API keys.

        Args:
            scheme_def: Security scheme definition
            param_name: Name of the query parameter

        Returns:
            True if the scheme is an API key in the query string with exactly ``param_name``
        """
        if not scheme_def:
            return False

        # Check for API Key in query
        if isinstance(scheme_def.root, APIKeySecurityScheme):
            if scheme_def.root.in_ == "query":
                return scheme_def.root.name == param_name

        return False

    @staticmethod
    def _is_cookie_compatible(scheme_def: SecurityScheme | None, cookie_name: str) -> bool:
        """
        Check if security scheme accepts cookie-based API keys.

        Args:
            scheme_def: Security scheme definition
            cookie_name: Name of the cookie

        Returns:
            True if the scheme is an API key in a cookie with exactly ``cookie_name``
        """
        if not scheme_def:
            return False

        # Check for API Key in cookie
        if isinstance(scheme_def.root, APIKeySecurityScheme):
            if scheme_def.root.in_ == "cookie":
                return scheme_def.root.name == cookie_name

        return False

    @staticmethod
    def _is_basic_compatible(scheme_def: SecurityScheme | None) -> bool:
        """
        Check if security scheme accepts HTTP Basic authentication.

        Args:
            scheme_def: Security scheme definition

        Returns:
            True if the scheme is HTTPAuthSecurityScheme with scheme='basic' (case-insensitive)
        """
        if not scheme_def:
            return False

        # Check for HTTP Basic scheme
        if isinstance(scheme_def.root, HTTPAuthSecurityScheme):
            return scheme_def.root.scheme.lower() == "basic"

        return False
|
|
@@ -18,6 +18,7 @@ from __future__ import annotations
|
|
|
18
18
|
import logging
|
|
19
19
|
from collections.abc import AsyncGenerator
|
|
20
20
|
from datetime import timedelta
|
|
21
|
+
from typing import TYPE_CHECKING
|
|
21
22
|
from uuid import uuid4
|
|
22
23
|
|
|
23
24
|
import httpx
|
|
@@ -34,6 +35,9 @@ from a2a.types import Role
|
|
|
34
35
|
from a2a.types import Task
|
|
35
36
|
from a2a.types import TextPart
|
|
36
37
|
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
40
|
+
|
|
37
41
|
logger = logging.getLogger(__name__)
|
|
38
42
|
|
|
39
43
|
|
|
@@ -43,20 +47,25 @@ class A2ABaseClient:
|
|
|
43
47
|
|
|
44
48
|
Args:
|
|
45
49
|
base_url: The base URL of the A2A agent
|
|
50
|
+
agent_card_path: Path to agent card (default: /.well-known/agent-card.json)
|
|
46
51
|
task_timeout: Timeout for task operations (default: 300 seconds)
|
|
52
|
+
streaming: Enable streaming responses (default: True)
|
|
53
|
+
auth_provider: Optional NAT authentication provider for securing requests
|
|
47
54
|
"""
|
|
48
55
|
|
|
49
56
|
def __init__(
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
57
|
+
self,
|
|
58
|
+
base_url: str,
|
|
59
|
+
agent_card_path: str = "/.well-known/agent-card.json",
|
|
60
|
+
task_timeout: timedelta = timedelta(seconds=300),
|
|
61
|
+
streaming: bool = True,
|
|
62
|
+
auth_provider: AuthProviderBase | None = None,
|
|
55
63
|
):
|
|
56
64
|
self._base_url = base_url
|
|
57
65
|
self._agent_card_path = agent_card_path
|
|
58
66
|
self._task_timeout = task_timeout
|
|
59
67
|
self._streaming = streaming
|
|
68
|
+
self._auth_provider = auth_provider
|
|
60
69
|
|
|
61
70
|
self._httpx_client: httpx.AsyncClient | None = None
|
|
62
71
|
self._client: Client | None = None
|
|
@@ -82,13 +91,30 @@ class A2ABaseClient:
|
|
|
82
91
|
if not self._agent_card:
|
|
83
92
|
raise RuntimeError("Agent card not resolved")
|
|
84
93
|
|
|
85
|
-
# 3)
|
|
94
|
+
# 3) Setup authentication interceptors if auth is configured
|
|
95
|
+
interceptors = []
|
|
96
|
+
if self._auth_provider:
|
|
97
|
+
try:
|
|
98
|
+
from a2a.client import AuthInterceptor
|
|
99
|
+
from nat.plugins.a2a.auth.credential_service import A2ACredentialService
|
|
100
|
+
|
|
101
|
+
credential_service = A2ACredentialService(
|
|
102
|
+
auth_provider=self._auth_provider,
|
|
103
|
+
agent_card=self._agent_card,
|
|
104
|
+
)
|
|
105
|
+
interceptors.append(AuthInterceptor(credential_service))
|
|
106
|
+
logger.info("Authentication configured for A2A client")
|
|
107
|
+
except ImportError as e:
|
|
108
|
+
logger.error("Failed to setup authentication: %s", e)
|
|
109
|
+
raise RuntimeError("Authentication requires a2a-sdk with AuthInterceptor support") from e
|
|
110
|
+
|
|
111
|
+
# 4) Create A2A client with interceptors
|
|
86
112
|
client_config = ClientConfig(
|
|
87
113
|
httpx_client=self._httpx_client,
|
|
88
114
|
streaming=self._streaming,
|
|
89
115
|
)
|
|
90
116
|
factory = ClientFactory(client_config)
|
|
91
|
-
self._client = factory.create(self._agent_card)
|
|
117
|
+
self._client = factory.create(self._agent_card, interceptors=interceptors)
|
|
92
118
|
|
|
93
119
|
logger.info("Connected to A2A agent at %s", self._base_url)
|
|
94
120
|
return self
|
|
@@ -65,5 +65,8 @@ class A2AClientConfig(FunctionGroupBaseConfig, name="a2a_client"):
|
|
|
65
65
|
description="Whether to enable streaming support for the A2A client",
|
|
66
66
|
)
|
|
67
67
|
|
|
68
|
-
auth_provider: str | AuthenticationRef | None = Field(
|
|
69
|
-
|
|
68
|
+
auth_provider: str | AuthenticationRef | None = Field(
|
|
69
|
+
default=None,
|
|
70
|
+
description="Reference to NAT authentication provider for authenticating with the A2A agent. "
|
|
71
|
+
"Supports OAuth2, API Key, HTTP Basic, and other NAT auth providers.",
|
|
72
|
+
)
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
17
|
from collections.abc import AsyncGenerator
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
18
19
|
from typing import Any
|
|
19
20
|
|
|
20
21
|
from pydantic import BaseModel
|
|
@@ -22,10 +23,13 @@ from pydantic import Field
|
|
|
22
23
|
|
|
23
24
|
from nat.builder.function import FunctionGroup
|
|
24
25
|
from nat.builder.workflow_builder import Builder
|
|
25
|
-
from nat.cli.register_workflow import
|
|
26
|
+
from nat.cli.register_workflow import register_per_user_function_group
|
|
26
27
|
from nat.plugins.a2a.client.client_base import A2ABaseClient
|
|
27
28
|
from nat.plugins.a2a.client.client_config import A2AClientConfig
|
|
28
29
|
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
32
|
+
|
|
29
33
|
logger = logging.getLogger(__name__)
|
|
30
34
|
|
|
31
35
|
|
|
@@ -66,13 +70,36 @@ class A2AClientFunctionGroup(FunctionGroup):
|
|
|
66
70
|
config: A2AClientConfig = self._config # type: ignore[assignment]
|
|
67
71
|
base_url = str(config.url)
|
|
68
72
|
|
|
73
|
+
# Get user_id from context (set by runtime for per-user function groups)
|
|
74
|
+
from nat.builder.context import Context
|
|
75
|
+
user_id = Context.get().user_id
|
|
76
|
+
if not user_id:
|
|
77
|
+
raise RuntimeError("User ID not found in context")
|
|
78
|
+
|
|
79
|
+
# Resolve auth provider if configured
|
|
80
|
+
auth_provider: AuthProviderBase | None = None
|
|
81
|
+
if config.auth_provider:
|
|
82
|
+
try:
|
|
83
|
+
auth_provider = await self._builder.get_auth_provider(config.auth_provider)
|
|
84
|
+
logger.info("Resolved authentication provider for A2A client")
|
|
85
|
+
except Exception as e:
|
|
86
|
+
logger.error("Failed to resolve auth provider '%s': %s", config.auth_provider, e)
|
|
87
|
+
raise RuntimeError(f"Failed to resolve auth provider: {e}") from e
|
|
88
|
+
|
|
69
89
|
# Create and initialize A2A client
|
|
70
|
-
self._client = A2ABaseClient(
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
90
|
+
self._client = A2ABaseClient(
|
|
91
|
+
base_url=base_url,
|
|
92
|
+
agent_card_path=config.agent_card_path,
|
|
93
|
+
task_timeout=config.task_timeout,
|
|
94
|
+
streaming=config.streaming,
|
|
95
|
+
auth_provider=auth_provider,
|
|
96
|
+
)
|
|
74
97
|
await self._client.__aenter__()
|
|
75
|
-
|
|
98
|
+
|
|
99
|
+
if auth_provider:
|
|
100
|
+
logger.info("Connected to A2A agent at %s with authentication (user_id: %s)", base_url, user_id)
|
|
101
|
+
else:
|
|
102
|
+
logger.info("Connected to A2A agent at %s (user_id: %s)", base_url, user_id)
|
|
76
103
|
|
|
77
104
|
# Discover agent card and register functions
|
|
78
105
|
self._register_functions()
|
|
@@ -281,11 +308,12 @@ class A2AClientFunctionGroup(FunctionGroup):
|
|
|
281
308
|
yield event
|
|
282
309
|
|
|
283
310
|
|
|
284
|
-
@
|
|
311
|
+
@register_per_user_function_group(config_type=A2AClientConfig)
|
|
285
312
|
async def a2a_client_function_group(config: A2AClientConfig, _builder: Builder):
|
|
286
313
|
"""
|
|
287
314
|
Connect to an A2A agent, discover agent card and publish the primary
|
|
288
|
-
agent function and helper functions.
|
|
315
|
+
agent function and helper functions. This function group is per-user,
|
|
316
|
+
meaning each user gets their own isolated instance.
|
|
289
317
|
|
|
290
318
|
This function group creates a three-level API:
|
|
291
319
|
- High-level: Agent function named after the agent (e.g., dice_agent)
|
|
@@ -17,7 +17,9 @@ import logging
|
|
|
17
17
|
|
|
18
18
|
from pydantic import BaseModel
|
|
19
19
|
from pydantic import Field
|
|
20
|
+
from pydantic import model_validator
|
|
20
21
|
|
|
22
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
21
23
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
22
24
|
|
|
23
25
|
logger = logging.getLogger(__name__)
|
|
@@ -102,3 +104,28 @@ class A2AFrontEndConfig(FrontEndBaseConfig, name="a2a"):
|
|
|
102
104
|
default=None,
|
|
103
105
|
description="Custom worker class for handling A2A routes (default: built-in worker)",
|
|
104
106
|
)
|
|
107
|
+
|
|
108
|
+
# OAuth2 Resource Server (for protecting this A2A agent)
|
|
109
|
+
server_auth: OAuth2ResourceServerConfig | None = Field(
|
|
110
|
+
default=None,
|
|
111
|
+
description=("OAuth 2.0 Resource Server configuration for token verification. "
|
|
112
|
+
"When configured, the A2A server will validate OAuth2 Bearer tokens on all requests "
|
|
113
|
+
"except public agent card discovery. Supports both JWT validation (via JWKS) and "
|
|
114
|
+
"opaque token validation (via RFC 7662 introspection)."),
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
@model_validator(mode="after")
def validate_security_configuration(self):
    """Warn when the server binds to a non-loopback interface without authentication.

    This validator never rejects the configuration. It only logs a warning, so that
    an unauthenticated server reachable beyond the local machine is not deployed by
    accident.

    A host counts as local when it is:
    - the hostname ``localhost``, compared case-insensitively because DNS names are
      case-insensitive; or
    - any IPv4/IPv6 loopback address, i.e. the whole 127.0.0.0/8 block or ``::1``.

    Returns:
        The validated config instance, unchanged.
    """
    import ipaddress

    host = str(self.host).strip().lower()
    try:
        is_local = ipaddress.ip_address(host).is_loopback
    except ValueError:
        # Not an IP literal; only the well-known loopback hostname is treated as local
        is_local = host == "localhost"

    # Check if server is bound to a non-localhost interface without authentication
    if not is_local and self.server_auth is None:
        logger.warning(
            "A2A server is configured to bind to '%s' without authentication. "
            "This may expose your server to unauthorized access. "
            "Consider either: (1) binding to localhost for local-only access, "
            "or (2) configuring server_auth for production deployments on public interfaces.",
            self.host,
        )

    return self
|
|
@@ -53,7 +53,7 @@ class A2AFrontEndPlugin(FrontEndBase[A2AFrontEndConfig]):
|
|
|
53
53
|
agent_card = await worker.create_agent_card(workflow)
|
|
54
54
|
|
|
55
55
|
# Create agent executor adapter
|
|
56
|
-
agent_executor = worker.create_agent_executor(workflow)
|
|
56
|
+
agent_executor = worker.create_agent_executor(workflow, builder)
|
|
57
57
|
|
|
58
58
|
# Create A2A server
|
|
59
59
|
a2a_server = worker.create_a2a_server(agent_card, agent_executor)
|
|
@@ -70,8 +70,21 @@ class A2AFrontEndPlugin(FrontEndBase[A2AFrontEndConfig]):
|
|
|
70
70
|
self.front_end_config.host,
|
|
71
71
|
self.front_end_config.port)
|
|
72
72
|
|
|
73
|
-
# Build the ASGI app
|
|
73
|
+
# Build the ASGI app
|
|
74
74
|
app = a2a_server.build()
|
|
75
|
+
|
|
76
|
+
# Add OAuth2 validation middleware if configured
|
|
77
|
+
if self.front_end_config.server_auth:
|
|
78
|
+
from nat.plugins.a2a.server.oauth_middleware import OAuth2ValidationMiddleware
|
|
79
|
+
|
|
80
|
+
app.add_middleware(OAuth2ValidationMiddleware, config=self.front_end_config.server_auth)
|
|
81
|
+
logger.info(
|
|
82
|
+
"OAuth2 token validation enabled for A2A server (issuer=%s, scopes=%s)",
|
|
83
|
+
self.front_end_config.server_auth.issuer_url,
|
|
84
|
+
self.front_end_config.server_auth.scopes,
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
# Run with uvicorn
|
|
75
88
|
config = uvicorn.Config(
|
|
76
89
|
app,
|
|
77
90
|
host=self.front_end_config.host,
|
|
@@ -25,8 +25,10 @@ from a2a.server.tasks import InMemoryTaskStore
|
|
|
25
25
|
from a2a.types import AgentCapabilities
|
|
26
26
|
from a2a.types import AgentCard
|
|
27
27
|
from a2a.types import AgentSkill
|
|
28
|
+
from a2a.types import SecurityScheme
|
|
28
29
|
from nat.builder.function import Function
|
|
29
30
|
from nat.builder.workflow import Workflow
|
|
31
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
30
32
|
from nat.data_models.config import Config
|
|
31
33
|
from nat.plugins.a2a.server.agent_executor_adapter import NATWorkflowAgentExecutor
|
|
32
34
|
from nat.plugins.a2a.server.front_end_config import A2AFrontEndConfig
|
|
@@ -72,6 +74,81 @@ class A2AFrontEndPluginWorker:
|
|
|
72
74
|
|
|
73
75
|
return functions
|
|
74
76
|
|
|
77
|
+
async def _generate_security_schemes(
        self, server_auth_config) -> tuple[dict[str, SecurityScheme], list[dict[str, list[str]]]]:
    """Build the OAuth2 security declarations advertised in the A2A agent card.

    A single ``"oauth2"`` scheme is produced, using the authorization-code flow.
    Its endpoint URLs come from ``self._resolve_oauth_endpoints``. Every configured
    scope is listed with a generic ``"Permission: <scope>"`` description and is
    also required by the returned security requirement.

    Args:
        server_auth_config: OAuth2ResourceServerConfig (reads ``scopes`` here;
            endpoint resolution reads further fields).

    Returns:
        Tuple of (security_schemes dict, security requirements list).
    """
    from a2a.types import AuthorizationCodeOAuthFlow
    from a2a.types import OAuth2SecurityScheme
    from a2a.types import OAuthFlows

    authorization_endpoint, token_endpoint = await self._resolve_oauth_endpoints(server_auth_config)

    required_scopes = server_auth_config.scopes
    scope_help = {}
    for scope_name in required_scopes:
        scope_help[scope_name] = f"Permission: {scope_name}"

    code_flow = AuthorizationCodeOAuthFlow(
        authorizationUrl=authorization_endpoint,
        tokenUrl=token_endpoint,
        scopes=scope_help,
    )
    oauth2_scheme = OAuth2SecurityScheme(
        type="oauth2",
        description="OAuth 2.0 authentication required to access this agent",
        flows=OAuthFlows(authorizationCode=code_flow),
    )

    schemes = {"oauth2": SecurityScheme(root=oauth2_scheme)}
    # The agent requires the "oauth2" scheme with every configured scope.
    requirements = [{"oauth2": required_scopes}]
    return schemes, requirements
|
|
115
|
+
|
|
116
|
+
async def _resolve_oauth_endpoints(self, server_auth_config) -> tuple[str, str]:
    """Resolve the OAuth2 authorization and token endpoint URLs.

    When ``server_auth_config.discovery_url`` is set, the discovery document is
    fetched (5 second timeout) and its ``authorization_endpoint`` and
    ``token_endpoint`` values are used. Discovery is best-effort. Any endpoint it
    does not supply is derived from the issuer using the
    ``{issuer_url}/oauth/authorize`` and ``{issuer_url}/oauth/token``
    convention. This covers three cases: no discovery URL is configured, the
    request or JSON parsing fails, or the document omits the field.

    Args:
        server_auth_config: OAuth2ResourceServerConfig; ``discovery_url`` and
            ``issuer_url`` are read.

    Returns:
        Tuple of (authorization_url, token_url).
    """
    discovered: dict = {}

    if server_auth_config.discovery_url:
        # Imported lazily so httpx is only needed when discovery is actually used.
        import httpx

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(server_auth_config.discovery_url, timeout=5.0)
                response.raise_for_status()
                metadata = response.json()
        except Exception as e:
            # Best-effort: a discovery failure must not prevent the server from starting.
            logger.warning("Failed to discover OAuth endpoints: %s", e)
        else:
            if isinstance(metadata, dict):
                discovered = metadata
            else:
                logger.warning("OAuth discovery document at %s is not a JSON object; ignoring it",
                               server_auth_config.discovery_url)

    auth_url = discovered.get("authorization_endpoint")
    token_url = discovered.get("token_endpoint")

    if auth_url and token_url:
        logger.info("Resolved OAuth endpoints via discovery: %s", server_auth_config.discovery_url)
        return auth_url, token_url

    if discovered:
        # The document was fetched but is incomplete; say so instead of silently guessing.
        logger.warning(
            "OAuth discovery document at %s lacks authorization_endpoint and/or token_endpoint; "
            "deriving the missing endpoint(s) from the issuer",
            server_auth_config.discovery_url,
        )

    # Fallback: derive from issuer URL (common convention), keeping any endpoint
    # that discovery did provide since it is authoritative.
    issuer = server_auth_config.issuer_url.rstrip("/")
    auth_url = auth_url or f"{issuer}/oauth/authorize"
    token_url = token_url or f"{issuer}/oauth/token"

    logger.info("Using derived OAuth endpoints from issuer: %s", issuer)
    return auth_url, token_url
|
|
151
|
+
|
|
75
152
|
async def create_agent_card(self, workflow: Workflow) -> AgentCard:
|
|
76
153
|
"""Build AgentCard from configuration and workflow functions.
|
|
77
154
|
|
|
@@ -112,6 +189,18 @@ class A2AFrontEndPluginWorker:
|
|
|
112
189
|
|
|
113
190
|
logger.info("Auto-generated %d skills from workflow functions", len(skills))
|
|
114
191
|
|
|
192
|
+
# Generate security schemes if server_auth is configured
|
|
193
|
+
security_schemes = None
|
|
194
|
+
security = None
|
|
195
|
+
|
|
196
|
+
if config.server_auth:
|
|
197
|
+
security_schemes, security = await self._generate_security_schemes(config.server_auth)
|
|
198
|
+
logger.info(
|
|
199
|
+
"Generated OAuth2 security schemes for agent (issuer=%s, scopes=%s)",
|
|
200
|
+
config.server_auth.issuer_url,
|
|
201
|
+
config.server_auth.scopes,
|
|
202
|
+
)
|
|
203
|
+
|
|
115
204
|
# Build agent card
|
|
116
205
|
agent_url = f"http://{config.host}:{config.port}/"
|
|
117
206
|
agent_card = AgentCard(
|
|
@@ -123,15 +212,19 @@ class A2AFrontEndPluginWorker:
|
|
|
123
212
|
default_output_modes=config.default_output_modes,
|
|
124
213
|
capabilities=capabilities,
|
|
125
214
|
skills=skills,
|
|
215
|
+
security_schemes=security_schemes,
|
|
216
|
+
security=security,
|
|
126
217
|
)
|
|
127
218
|
|
|
128
219
|
logger.info("Created AgentCard for: %s v%s", config.name, config.version)
|
|
129
220
|
logger.info("Agent URL: %s", agent_url)
|
|
130
221
|
logger.info("Skills: %d", len(skills))
|
|
222
|
+
if security_schemes:
|
|
223
|
+
logger.info("Security: OAuth2 authentication required")
|
|
131
224
|
|
|
132
225
|
return agent_card
|
|
133
226
|
|
|
134
|
-
def create_agent_executor(self, workflow: Workflow) -> NATWorkflowAgentExecutor:
|
|
227
|
+
def create_agent_executor(self, workflow: Workflow, builder: WorkflowBuilder) -> NATWorkflowAgentExecutor:
|
|
135
228
|
"""Create agent executor adapter for the workflow.
|
|
136
229
|
|
|
137
230
|
This creates a SessionManager to handle concurrent A2A task requests,
|
|
@@ -139,13 +232,16 @@ class A2AFrontEndPluginWorker:
|
|
|
139
232
|
|
|
140
233
|
Args:
|
|
141
234
|
workflow: The NAT workflow to expose
|
|
235
|
+
builder: The workflow builder used to create the workflow
|
|
142
236
|
|
|
143
237
|
Returns:
|
|
144
238
|
NATWorkflowAgentExecutor that wraps the workflow with a SessionManager
|
|
145
239
|
"""
|
|
146
240
|
# Create SessionManager to handle concurrent requests with proper limits
|
|
147
241
|
session_manager = SessionManager(
|
|
148
|
-
|
|
242
|
+
config=self.full_config,
|
|
243
|
+
shared_builder=builder,
|
|
244
|
+
shared_workflow=workflow,
|
|
149
245
|
max_concurrency=self.max_concurrency,
|
|
150
246
|
)
|
|
151
247
|
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""OAuth 2.0 token validation middleware for A2A servers."""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
20
|
+
from starlette.requests import Request
|
|
21
|
+
from starlette.responses import JSONResponse
|
|
22
|
+
|
|
23
|
+
from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
|
|
24
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class OAuth2ValidationMiddleware(BaseHTTPMiddleware):
    """OAuth2 Bearer token validation middleware for A2A servers.

    Requests are authenticated with NAT's ``BearerTokenValidator``. It is configured
    from an ``OAuth2ResourceServerConfig``: issuer, audience, scopes, JWKS URI,
    introspection endpoint, discovery URL and client credentials.

    Agent card discovery stays public so clients can learn how to authenticate.
    Both the current A2A path (``/.well-known/agent-card.json``) and the legacy
    path (``/.well-known/agent.json``) are exempt. Every other request must carry
    a valid Bearer token. Authentication failures get HTTP 401 with a
    ``WWW-Authenticate: Bearer`` challenge, as described by RFC 6750.
    """

    # Agent card discovery paths served without a token. The legacy path is kept
    # because the server may still expose the card there for older clients.
    _PUBLIC_PATHS: frozenset[str] = frozenset({
        "/.well-known/agent-card.json",
        "/.well-known/agent.json",
    })

    def __init__(self, app, config: OAuth2ResourceServerConfig):
        """Initialize OAuth2 validation middleware.

        Args:
            app: Starlette application
            config: OAuth2 resource server configuration
        """
        super().__init__(app)

        # Create validator using NAT's BearerTokenValidator
        self.validator = BearerTokenValidator(
            issuer=config.issuer_url,
            audience=config.audience,
            scopes=config.scopes,
            jwks_uri=config.jwks_uri,
            introspection_endpoint=config.introspection_endpoint,
            discovery_url=config.discovery_url,
            client_id=config.client_id,
            client_secret=config.client_secret.get_secret_value() if config.client_secret else None,
        )

        logger.info(
            "OAuth2 validation middleware initialized (issuer=%s, scopes=%s, audience=%s)",
            config.issuer_url,
            config.scopes,
            config.audience,
        )

    @staticmethod
    def _extract_bearer_token(auth_header: str) -> str | None:
        """Return the token from a ``Bearer <token>`` header value, or None if absent.

        The scheme name is matched case-insensitively (RFC 7235). A header with
        an empty token is treated as missing.
        """
        scheme, _, credentials = auth_header.partition(" ")
        token = credentials.strip()
        if scheme.lower() != "bearer" or not token:
            return None
        return token

    @staticmethod
    def _unauthorized(error: str, message: str, *, invalid_token: bool) -> JSONResponse:
        """Build a 401 JSON response carrying an RFC 6750 ``WWW-Authenticate`` challenge."""
        # RFC 6750 section 3.1: omit the error code when no credentials were
        # presented, and use error="invalid_token" when the token was rejected.
        challenge = 'Bearer error="invalid_token"' if invalid_token else "Bearer"
        return JSONResponse({"error": error, "message": message},
                            status_code=401,
                            headers={"WWW-Authenticate": challenge})

    async def dispatch(self, request: Request, call_next):
        """Validate OAuth2 Bearer token for all requests except agent card discovery.

        On success the validation result is attached to ``request.state``
        (``oauth_user``, ``oauth_scopes``, ``oauth_client_id``,
        ``oauth_token_info``) before the request is passed on.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response (a 401 error or the result from the next handler)
        """
        # Public: Agent card discovery (per A2A spec)
        if request.url.path in self._PUBLIC_PATHS:
            logger.debug("Public access to agent card discovery")
            return await call_next(request)

        token = self._extract_bearer_token(request.headers.get("Authorization", ""))
        if token is None:
            logger.warning("Missing or invalid Authorization header")
            return self._unauthorized("unauthorized", "Missing or invalid Bearer token", invalid_token=False)

        # Validate token using NAT's validator
        try:
            result = await self.validator.verify(token)
        except Exception as e:
            logger.error("Token validation error: %s", e)
            return self._unauthorized("invalid_token", "Token validation failed", invalid_token=True)

        # Expired or revoked tokens are reported as invalid_token (401), not forbidden.
        if not result.active:
            logger.warning("Token is not active")
            return self._unauthorized("invalid_token", "Token is not active", invalid_token=True)

        # Attach token info to request state for potential use by handlers
        request.state.oauth_user = result.subject
        request.state.oauth_scopes = result.scopes or []
        request.state.oauth_client_id = result.client_id
        request.state.oauth_token_info = result

        logger.debug(
            "Token validated successfully (user=%s, scopes=%s, client=%s)",
            result.subject,
            result.scopes,
            result.client_id,
        )

        return await call_next(request)
|
{nvidia_nat_a2a-1.4.0a20251207.dist-info → nvidia_nat_a2a-1.4.0a20251224.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nvidia-nat-a2a
|
|
3
|
-
Version: 1.4.
|
|
3
|
+
Version: 1.4.0a20251224
|
|
4
4
|
Summary: Subpackage for A2A Protocol integration in NeMo Agent Toolkit
|
|
5
5
|
Author: NVIDIA Corporation
|
|
6
6
|
Maintainer: NVIDIA Corporation
|
|
@@ -15,7 +15,7 @@ Classifier: Programming Language :: Python :: 3.13
|
|
|
15
15
|
Requires-Python: <3.14,>=3.11
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
17
|
License-File: LICENSE.md
|
|
18
|
-
Requires-Dist: nvidia-nat==v1.4.
|
|
18
|
+
Requires-Dist: nvidia-nat==v1.4.0a20251224
|
|
19
19
|
Requires-Dist: a2a-sdk~=0.3.20
|
|
20
20
|
Dynamic: license-file
|
|
21
21
|
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
nat/meta/pypi.md,sha256=YkfjzZntzheoaBie5ZovnAwB78xxVqk9sblkZRZcdLU,1661
|
|
2
|
+
nat/plugins/a2a/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
3
|
+
nat/plugins/a2a/register.py,sha256=pUN1hbJ38M8GbdNcA0qQzJ1S-ZC91GnRGk_8SO_kTVg,853
|
|
4
|
+
nat/plugins/a2a/auth/__init__.py,sha256=iQFx1YrjFcepS7k8jp93A0IVOkFeNx_I35M6dIngoJA,726
|
|
5
|
+
nat/plugins/a2a/auth/credential_service.py,sha256=-_VdDF4YESaAtY1ONUiOL5z4aGDJZYVuhyhI9BZhuyI,15967
|
|
6
|
+
nat/plugins/a2a/client/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
7
|
+
nat/plugins/a2a/client/client_base.py,sha256=xShDZDFKa4R2XsY3yBMvM-eDaf_0cdE48XJzQ4WcEOw,13366
|
|
8
|
+
nat/plugins/a2a/client/client_config.py,sha256=KwWjymDg9GUfSYcIaBhcxph4Hu6IeTe414hrNUUo-6g,2875
|
|
9
|
+
nat/plugins/a2a/client/client_impl.py,sha256=CGAjiHr6EyWcnlSipmT8ixgjD4s8VbPRBPOZy2q_Sm0,12958
|
|
10
|
+
nat/plugins/a2a/server/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
11
|
+
nat/plugins/a2a/server/agent_executor_adapter.py,sha256=wvGXOb3FcV0_pYRv-yr-QzozjzXM909D49Dxm9199xI,7015
|
|
12
|
+
nat/plugins/a2a/server/front_end_config.py,sha256=Lg-qjDmC4fwrwnHNtSRl54pMpdwVnO06xhgbLt-aEZY,4902
|
|
13
|
+
nat/plugins/a2a/server/front_end_plugin.py,sha256=fX3Lagkd48snSiNo2IMTRpR-40WHUWQidpjKu8uQChY,4896
|
|
14
|
+
nat/plugins/a2a/server/front_end_plugin_worker.py,sha256=Ehdv6lyUcrWkfMq7YomD4NYFAusrtQ2JYj2HnkIqGhY,11696
|
|
15
|
+
nat/plugins/a2a/server/oauth_middleware.py,sha256=NvvIJSPB8wRui2eQlxr6AaNhN0JxdUQ1Ajr8Dnk0rnY,4751
|
|
16
|
+
nat/plugins/a2a/server/register_frontend.py,sha256=4TmpBcZF4x71c2xnWuketsygqHmU7D2hKA2bzO34TpU,1480
|
|
17
|
+
nvidia_nat_a2a-1.4.0a20251224.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
18
|
+
nvidia_nat_a2a-1.4.0a20251224.dist-info/METADATA,sha256=7KIbUZYZ3X7iQhe1bvuUrLvSyQndAHLAXxmVxtzrSms,2438
|
|
19
|
+
nvidia_nat_a2a-1.4.0a20251224.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
20
|
+
nvidia_nat_a2a-1.4.0a20251224.dist-info/entry_points.txt,sha256=Lacvy6nXpDTv8dh8vKJ_QE8TobliVdhgABuw25t8fBg,145
|
|
21
|
+
nvidia_nat_a2a-1.4.0a20251224.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
|
22
|
+
nvidia_nat_a2a-1.4.0a20251224.dist-info/RECORD,,
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
nat/meta/pypi.md,sha256=YkfjzZntzheoaBie5ZovnAwB78xxVqk9sblkZRZcdLU,1661
|
|
2
|
-
nat/plugins/a2a/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
3
|
-
nat/plugins/a2a/register.py,sha256=pUN1hbJ38M8GbdNcA0qQzJ1S-ZC91GnRGk_8SO_kTVg,853
|
|
4
|
-
nat/plugins/a2a/client/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
5
|
-
nat/plugins/a2a/client/client_base.py,sha256=s0HTS0ebuwoc_fL4Z_laRe-OIZ5iooSG_FzcgZca8_E,12070
|
|
6
|
-
nat/plugins/a2a/client/client_config.py,sha256=SWu46fAa25IYc3Lhq9w9nIt5xkCtdBuuPy74pd5vPPk,2788
|
|
7
|
-
nat/plugins/a2a/client/client_impl.py,sha256=cc_rYyPq86_8R12MjesMHaZrYG9lDnip2125veQ1fEY,11775
|
|
8
|
-
nat/plugins/a2a/server/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
9
|
-
nat/plugins/a2a/server/agent_executor_adapter.py,sha256=wvGXOb3FcV0_pYRv-yr-QzozjzXM909D49Dxm9199xI,7015
|
|
10
|
-
nat/plugins/a2a/server/front_end_config.py,sha256=Qnjbx6n67Xy3sZ6rkAYZaKk-WfBDrVX5OzZSWxU6fIg,3423
|
|
11
|
-
nat/plugins/a2a/server/front_end_plugin.py,sha256=euhh5LXkZpyC5HaUaJKFJH3BIF6jS2ti3NNXVQ71bgI,4255
|
|
12
|
-
nat/plugins/a2a/server/front_end_plugin_worker.py,sha256=Ih00L9DtZZYQvN_RRw4a5M_StTFhZ4JGHNniw6glOzY,7757
|
|
13
|
-
nat/plugins/a2a/server/register_frontend.py,sha256=4TmpBcZF4x71c2xnWuketsygqHmU7D2hKA2bzO34TpU,1480
|
|
14
|
-
nvidia_nat_a2a-1.4.0a20251207.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
15
|
-
nvidia_nat_a2a-1.4.0a20251207.dist-info/METADATA,sha256=6qhxCj7OS1n7csjZtbbByAd86SYtWVoydK0nFnvDfRc,2438
|
|
16
|
-
nvidia_nat_a2a-1.4.0a20251207.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
17
|
-
nvidia_nat_a2a-1.4.0a20251207.dist-info/entry_points.txt,sha256=Lacvy6nXpDTv8dh8vKJ_QE8TobliVdhgABuw25t8fBg,145
|
|
18
|
-
nvidia_nat_a2a-1.4.0a20251207.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
|
19
|
-
nvidia_nat_a2a-1.4.0a20251207.dist-info/RECORD,,
|
|
File without changes
|
{nvidia_nat_a2a-1.4.0a20251207.dist-info → nvidia_nat_a2a-1.4.0a20251224.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
|
File without changes
|
{nvidia_nat_a2a-1.4.0a20251207.dist-info → nvidia_nat_a2a-1.4.0a20251224.dist-info}/top_level.txt
RENAMED
|
File without changes
|