nvidia-nat-mcp 1.4.0a20260107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +32 -0
- nat/plugins/mcp/__init__.py +14 -0
- nat/plugins/mcp/auth/__init__.py +14 -0
- nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
- nat/plugins/mcp/auth/auth_provider.py +431 -0
- nat/plugins/mcp/auth/auth_provider_config.py +86 -0
- nat/plugins/mcp/auth/register.py +33 -0
- nat/plugins/mcp/auth/service_account/__init__.py +14 -0
- nat/plugins/mcp/auth/service_account/provider.py +136 -0
- nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
- nat/plugins/mcp/auth/service_account/token_client.py +156 -0
- nat/plugins/mcp/auth/token_storage.py +265 -0
- nat/plugins/mcp/cli/__init__.py +15 -0
- nat/plugins/mcp/cli/commands.py +1051 -0
- nat/plugins/mcp/client/__init__.py +15 -0
- nat/plugins/mcp/client/client_base.py +665 -0
- nat/plugins/mcp/client/client_config.py +146 -0
- nat/plugins/mcp/client/client_impl.py +782 -0
- nat/plugins/mcp/exception_handler.py +211 -0
- nat/plugins/mcp/exceptions.py +142 -0
- nat/plugins/mcp/register.py +23 -0
- nat/plugins/mcp/server/__init__.py +15 -0
- nat/plugins/mcp/server/front_end_config.py +109 -0
- nat/plugins/mcp/server/front_end_plugin.py +155 -0
- nat/plugins/mcp/server/front_end_plugin_worker.py +411 -0
- nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
- nat/plugins/mcp/server/memory_profiler.py +320 -0
- nat/plugins/mcp/server/register_frontend.py +27 -0
- nat/plugins/mcp/server/tool_converter.py +286 -0
- nat/plugins/mcp/utils.py +228 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/METADATA +55 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/RECORD +37 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/WHEEL +5 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/entry_points.txt +9 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import importlib
|
|
18
|
+
import logging
|
|
19
|
+
import typing
|
|
20
|
+
|
|
21
|
+
from pydantic import SecretStr
|
|
22
|
+
|
|
23
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
24
|
+
from nat.data_models.authentication import AuthResult
|
|
25
|
+
from nat.data_models.authentication import Credential
|
|
26
|
+
from nat.data_models.authentication import HeaderCred
|
|
27
|
+
from nat.plugins.mcp.auth.service_account.provider_config import MCPServiceAccountProviderConfig
|
|
28
|
+
from nat.plugins.mcp.auth.service_account.token_client import ServiceAccountTokenClient
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MCPServiceAccountProvider(AuthProviderBase[MCPServiceAccountProviderConfig]):
    """
    MCP service account authentication provider using OAuth2 client credentials.

    Provides headless authentication for MCP clients using service account credentials.
    Supports two authentication patterns:

    1. Single authentication: OAuth2 service account token only
    2. Dual authentication: OAuth2 service account token + service-specific token
    """

    def __init__(self, config: MCPServiceAccountProviderConfig, builder=None):
        """
        Create the token client and, when configured, resolve the dynamic service token function.

        Args:
            config: Provider configuration (client credentials, token URL, scopes and the
                optional service token settings).
            builder: Accepted but not used by this provider.

        Raises:
            ValueError: If ``config.service_token.function`` cannot be imported or is not callable.
        """
        super().__init__(config)

        # Initialize token client
        self._token_client = ServiceAccountTokenClient(
            client_id=config.client_id,
            client_secret=config.client_secret,
            token_url=config.token_url,
            scopes=" ".join(config.scopes),  # Convert list to space-delimited string for OAuth2
            token_cache_buffer_seconds=config.token_cache_buffer_seconds,
        )

        # Load dynamic service token function if configured
        self._service_token_function = None
        if config.service_token and config.service_token.function:
            self._service_token_function = self._load_function(config.service_token.function)

        logger.info("Initialized MCP service account auth provider: "
                    "token_url=%s, scopes=%s, has_service_token=%s",
                    config.token_url,
                    config.scopes,
                    config.service_token is not None)

    def _load_function(self, function_path: str) -> typing.Callable:
        """
        Load a Python function from a module path string (e.g., 'my_module.get_token').

        Args:
            function_path: Dotted path whose last component is the attribute to load.

        Returns:
            The callable found at ``function_path``.

        Raises:
            ValueError: If the path cannot be split, the module cannot be imported, the
                attribute does not exist, or the attribute is not callable.
        """
        try:
            module_name, func_name = function_path.rsplit(".", 1)
            module = importlib.import_module(module_name)
            func = getattr(module, func_name)
        except Exception as e:
            raise ValueError(f"Failed to load service token function '{function_path}': {e}") from e

        # Fail at load time rather than on the first authenticate() call.
        if not callable(func):
            raise ValueError(f"Failed to load service token function '{function_path}': "
                             f"'{func_name}' is not callable")

        logger.info("Loaded service token function: %s", function_path)
        return func

    async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
        """
        Authenticate using OAuth2 client credentials flow.

        Note: user_id is ignored for service accounts (non-session-specific).

        Returns:
            AuthResult with HeaderCred objects for service account authentication

        Raises:
            RuntimeError: If the dynamic service token function raises or returns a
                tuple that is not ``(header_name, token)``.
        """
        # Function-scope import so this block does not depend on the file's import list.
        import inspect

        # Get OAuth2 access token (cached if still valid)
        access_token = await self._token_client.get_access_token()

        # Build credentials list using HeaderCred
        credentials: list[Credential] = [
            HeaderCred(name="Authorization", value=SecretStr(f"Bearer {access_token.get_secret_value()}"))
        ]

        # Add service-specific token if configured
        if self.config.service_token:
            service_header = self.config.service_token.header
            service_token_value = None

            if self.config.service_token.token:
                # Static token from config
                service_token_value = self.config.service_token.token.get_secret_value()

            elif self._service_token_function:
                # Dynamic token from function, called with the configured kwargs
                try:
                    result = self._service_token_function(**self.config.service_token.kwargs)
                    # Await whatever awaitable comes back. This covers ``async def`` functions
                    # as well as sync callables that return a coroutine (functools.partial,
                    # objects with an async ``__call__``), which a coroutine-function check misses.
                    if inspect.isawaitable(result):
                        result = await result

                    # Handle function return type: str or tuple[str, str]
                    if isinstance(result, tuple):
                        service_header, service_token_value = result
                    else:
                        service_token_value = result

                    logger.debug("Retrieved service token via dynamic function")

                except Exception as e:
                    raise RuntimeError(f"Failed to get service token from function: {e}") from e

            # An empty/None token adds no second header.
            if service_token_value:
                credentials.append(HeaderCred(name=service_header, value=SecretStr(service_token_value)))

        # Return AuthResult with HeaderCred objects
        return AuthResult(
            credentials=credentials,
            token_expires_at=self._token_client.token_expires_at,
            raw={},
        )
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic import field_validator
|
|
21
|
+
from pydantic import model_validator
|
|
22
|
+
|
|
23
|
+
from nat.authentication.interfaces import AuthProviderBaseConfig
|
|
24
|
+
from nat.data_models.common import OptionalSecretStr
|
|
25
|
+
from nat.data_models.common import SerializableSecretStr
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ServiceTokenConfig(BaseModel):
    """
    Configuration for service-specific token in dual authentication patterns.

    Supports two modes:

    1. Static token: Provide token and header directly
    2. Dynamic function: Provide function path and optional kwargs

    The function will be called on every request and should have signature::

        async def get_service_token(**kwargs) -> str | tuple[str, str]

    If function returns ``tuple[str, str]``, it's interpreted as (header_name, token).
    If function returns ``str``, it's the token and header field is used for header name.

    The function can access runtime context via AIQContext.get() if needed.
    """

    # --- Static mode -------------------------------------------------------
    token: OptionalSecretStr = Field(
        default=None,
        description="Static service token value (mutually exclusive with function)",
    )

    # Header name used for the service token.
    header: str = Field(
        default="X-Service-Account-Token",
        description="HTTP header name for service token (default: 'X-Service-Account-Token')",
    )

    # --- Dynamic mode ------------------------------------------------------
    function: str | None = Field(
        default=None,
        description=("Python function path that returns service token dynamically (mutually exclusive with token). "
                     "Function signature: async def func(\\**kwargs) -> str | tuple[str, str]. "
                     "Access runtime context via AIQContext.get() if needed."),
    )

    # Extra keyword arguments passed to the dynamic function.
    kwargs: dict[str, typing.Any] = Field(
        default_factory=dict,
        description="Additional keyword arguments to pass to the custom function",
    )

    @model_validator(mode="after")
    def validate_token_or_function(self):
        """
        Require exactly one of ``token`` and ``function``.

        Raises:
            ValueError: If both are set, or if neither is set.
        """
        token_given = self.token is not None
        function_given = self.function is not None

        if token_given and function_given:
            raise ValueError("Cannot specify both 'token' and 'function' in service_token config. Choose one.")

        if not (token_given or function_given):
            raise ValueError("Either 'token' or 'function' must be provided in service_token config")

        return self
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class MCPServiceAccountProviderConfig(AuthProviderBaseConfig, name="mcp_service_account"):
    """
    Configuration for MCP service account authentication using OAuth2 client credentials.

    Generic implementation supporting any OAuth2 client credentials flow.

    Supports two authentication patterns:
    1. Single authentication: OAuth2 service account token only
    2. Dual authentication: OAuth2 service account token + service-specific token

    Common use cases:
    - Headless/automated MCP workflows
    - CI/CD pipelines
    - Backend services without user interaction

    All values must be provided via configuration. Use ${ENV_VAR} syntax in YAML
    configs for environment variable substitution.
    """

    # OAuth2 client credentials (no defaults, so both are required)
    client_id: str = Field(description="OAuth2 client identifier")
    client_secret: SerializableSecretStr = Field(description="OAuth2 client secret")

    # Token endpoint (required)
    token_url: str = Field(description="OAuth2 token endpoint URL")

    # Scopes (required); a space-delimited string is also accepted, see validate_scopes
    scopes: list[str] = Field(description="List of OAuth2 scopes (will be joined with spaces for OAuth2 request)")

    # Optional second credential for dual authentication
    service_token: ServiceTokenConfig | None = Field(
        default=None,
        description="Optional service token configuration for dual authentication patterns. "
        "Provide either a static token or a dynamic function that returns the token at runtime.",
    )

    # How long before expiry a cached token is considered stale
    token_cache_buffer_seconds: int = Field(
        default=300,
        description="Seconds before token expiry to refresh (default: 300s/5min)",
    )

    @field_validator("scopes", mode="before")
    @classmethod
    def validate_scopes(cls, v):
        """
        Normalize ``scopes`` so a space-delimited string becomes a list.

        A string is split on runs of whitespace; ``str.split()`` with no argument
        yields no empty items and no surrounding whitespace. Any other value is
        passed through unchanged for regular field validation.
        """
        return v.split() if isinstance(v, str) else v
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import base64
|
|
18
|
+
import logging
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
from datetime import timedelta
|
|
21
|
+
|
|
22
|
+
import httpx
|
|
23
|
+
from pydantic import SecretStr
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ServiceAccountTokenClient:
    """
    Generic OAuth2 client credentials token client for service accounts.

    Implements standard OAuth2 client credentials flow with token caching.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: SecretStr,
        token_url: str,
        scopes: str,
        token_cache_buffer_seconds: int = 300,
    ):
        """
        Initialize service account token client.

        Args:
            client_id: OAuth2 client identifier
            client_secret: OAuth2 client secret (SecretStr)
            token_url: OAuth2 token endpoint URL
            scopes: Space-separated list of scopes
            token_cache_buffer_seconds: Seconds before expiry to refresh (default: 5 min)
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scopes = scopes
        self.token_cache_buffer_seconds = token_cache_buffer_seconds

        # Token cache
        self._cached_token: SecretStr | None = None
        self._token_expires_at: datetime | None = None
        self._lock = None  # Will be initialized as asyncio.Lock when needed

    @property
    def token_expires_at(self) -> datetime | None:
        """Expiry time of the cached token, or None if no token has been fetched yet."""
        return self._token_expires_at

    async def _get_lock(self) -> asyncio.Lock:
        """Lazy initialization of asyncio.Lock."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def _is_token_valid(self) -> bool:
        """Check if cached token is still valid (with buffer time)."""
        if not self._cached_token or not self._token_expires_at:
            return False
        buffer = timedelta(seconds=self.token_cache_buffer_seconds)
        # NOTE(review): expiry is tracked with naive local datetimes; confirm consumers of
        # token_expires_at expect naive values before switching to timezone-aware ones.
        return datetime.now() < (self._token_expires_at - buffer)

    async def get_access_token(self) -> SecretStr:
        """
        Get OAuth2 access token, using cache if valid.

        Returns:
            Access token as SecretStr

        Raises:
            RuntimeError: If token acquisition fails
        """
        # Fast path: check cache without lock
        if self._is_token_valid():
            logger.debug("Using cached service account token")
            assert self._cached_token is not None  # _is_token_valid() ensures this
            return self._cached_token

        # Slow path: acquire lock and refresh token
        lock = await self._get_lock()
        async with lock:
            # Re-check: another task may have refreshed the token while this one waited.
            if self._is_token_valid():
                logger.debug("Using cached service account token (acquired during lock wait)")
                assert self._cached_token is not None  # _is_token_valid() ensures this
                return self._cached_token

            logger.info("Fetching new service account token")
            return await self._fetch_new_token()

    async def _fetch_new_token(self) -> SecretStr:
        """
        Fetch a new token from the OAuth2 token endpoint and cache it.

        Returns:
            New access token as SecretStr

        Raises:
            RuntimeError: If the request times out or fails, the endpoint answers with a
                non-200 status, the body is not valid JSON, or it carries no access token.
        """
        # Encode credentials for Basic authentication
        credentials = f"{self.client_id}:{self.client_secret.get_secret_value()}"
        encoded_credentials = base64.b64encode(credentials.encode()).decode()

        headers = {"Authorization": f"Basic {encoded_credentials}", "Content-Type": "application/x-www-form-urlencoded"}

        data = {"grant_type": "client_credentials"}
        # "scope" is optional in the client credentials grant; omit it instead of
        # sending an empty value when no scopes are configured.
        if self.scopes:
            data["scope"] = self.scopes

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(self.token_url, headers=headers, data=data)
        except httpx.TimeoutException as e:
            raise RuntimeError(f"Service account token request timed out: {e}") from e
        except httpx.RequestError as e:
            raise RuntimeError(f"Service account token request failed: {e}") from e

        if response.status_code == 401:
            raise RuntimeError("Invalid service account credentials")
        if response.status_code == 429:
            raise RuntimeError("Service account rate limit exceeded")
        if response.status_code != 200:
            raise RuntimeError(f"Service account token request failed: {response.status_code} - {response.text}")

        # A 200 with a non-JSON body must surface as RuntimeError, as documented above.
        try:
            token_data = response.json()
        except ValueError as e:
            raise RuntimeError(f"Service account token response is not valid JSON: {e}") from e

        access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
        if not access_token:
            raise RuntimeError("Access token not found in token response")

        # Some servers send expires_in as a string or null; coerce it and fall back to the
        # 3600s default instead of letting timedelta raise an uncaught error.
        expires_in = token_data.get("expires_in", 3600)
        try:
            expires_at = datetime.now() + timedelta(seconds=float(expires_in))
        except (TypeError, ValueError, OverflowError):
            logger.warning("Invalid 'expires_in' value %r in token response; assuming 3600s", expires_in)
            expires_in = 3600
            expires_at = datetime.now() + timedelta(seconds=expires_in)

        # Update both cache fields together, only once every value is known to be usable.
        self._cached_token = SecretStr(access_token)
        self._token_expires_at = expires_at

        logger.info("Service account token acquired (expires in %ss)", expires_in)
        return self._cached_token
|
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import hashlib
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
from abc import ABC
|
|
20
|
+
from abc import abstractmethod
|
|
21
|
+
|
|
22
|
+
from nat.data_models.authentication import AuthResult
|
|
23
|
+
from nat.data_models.authentication import BasicAuthCred
|
|
24
|
+
from nat.data_models.authentication import BearerTokenCred
|
|
25
|
+
from nat.data_models.authentication import CookieCred
|
|
26
|
+
from nat.data_models.authentication import HeaderCred
|
|
27
|
+
from nat.data_models.authentication import QueryCred
|
|
28
|
+
from nat.data_models.object_store import NoSuchKeyError
|
|
29
|
+
from nat.object_store.interfaces import ObjectStore
|
|
30
|
+
from nat.object_store.models import ObjectStoreItem
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TokenStorageBase(ABC):
    """
    Interface that every token storage backend implements.

    A token storage persists ``AuthResult`` objects keyed by a user identifier for
    MCP OAuth2 flows. Concrete subclasses decide where the data lives (object
    store, database, memory, ...).
    """

    @abstractmethod
    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Persist ``auth_result`` under ``user_id``.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """

    @abstractmethod
    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Look up the authentication result stored for ``user_id``.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The authentication result if found, None otherwise
        """

    @abstractmethod
    async def delete(self, user_id: str) -> None:
        """
        Remove the authentication result stored for ``user_id``.

        Args:
            user_id: The unique identifier for the user
        """

    @abstractmethod
    async def clear_all(self) -> None:
        """
        Remove every stored authentication result.
        """
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ObjectStoreTokenStorage(TokenStorageBase):
    """
    Token storage implementation backed by a NeMo Agent toolkit object store.

    Tokens are serialized to JSON and written through the ``ObjectStore`` interface,
    so durability and at-rest protection depend on the backend that is supplied
    (for example S3, MySQL, or Redis).
    """

    def __init__(self, object_store: ObjectStore):
        """
        Initialize the object store token storage.

        Args:
            object_store: The object store instance to use for token persistence
        """
        self._object_store = object_store

    def _get_key(self, user_id: str) -> str:
        """
        Generate the object store key for a user's token.

        Uses SHA256 hash to ensure the key is S3-compatible and doesn't
        contain special characters like "://" that are invalid in object keys.

        Args:
            user_id: The user identifier

        Returns:
            The object store key, ``tokens/<sha256 hex digest of user_id>``
        """
        # Hash the user_id to create an S3-safe key
        user_hash = hashlib.sha256(user_id.encode('utf-8')).hexdigest()
        return f"tokens/{user_hash}"

    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Store an authentication result in the object store.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        key = self._get_key(user_id)

        # SecretStr values are masked by the JSON dump, so dump first and then
        # overwrite the masked fields with the real secret values.
        auth_dict = auth_result.model_dump(mode='json')
        for i, cred_obj in enumerate(auth_result.credentials):
            if isinstance(cred_obj, BearerTokenCred):
                auth_dict['credentials'][i]['token'] = cred_obj.token.get_secret_value()
            elif isinstance(cred_obj, BasicAuthCred):
                auth_dict['credentials'][i]['username'] = cred_obj.username.get_secret_value()
                auth_dict['credentials'][i]['password'] = cred_obj.password.get_secret_value()
            elif isinstance(cred_obj, HeaderCred | QueryCred | CookieCred):
                auth_dict['credentials'][i]['value'] = cred_obj.value.get_secret_value()

        data = json.dumps(auth_dict).encode('utf-8')

        # Prepare metadata (only the expiry, when known)
        metadata = {}
        if auth_result.token_expires_at:
            metadata["expires_at"] = auth_result.token_expires_at.isoformat()

        # Create the object store item
        item = ObjectStoreItem(data=data, content_type="application/json", metadata=metadata if metadata else None)

        # Store using upsert to handle both new and existing tokens
        await self._object_store.upsert_object(key, item)

    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Retrieve an authentication result from the object store.

        Any failure other than a missing key is logged with its traceback and
        treated as "no token".

        Args:
            user_id: The unique identifier for the user

        Returns:
            The authentication result if found, None otherwise
        """
        key = self._get_key(user_id)

        try:
            item = await self._object_store.get_object(key)
            # Deserialize the AuthResult from JSON
            return AuthResult.model_validate_json(item.data)
        except NoSuchKeyError:
            return None
        except Exception as e:
            # logger.exception attaches the traceback; %-args defer formatting to the logger.
            logger.exception("Error deserializing token for user %s: %s", user_id, e)
            return None

    async def delete(self, user_id: str) -> None:
        """
        Delete an authentication result from the object store.

        Args:
            user_id: The unique identifier for the user
        """
        key = self._get_key(user_id)

        try:
            await self._object_store.delete_object(key)
        except NoSuchKeyError:
            # Token doesn't exist, which is fine for delete operations
            pass

    async def clear_all(self) -> None:
        """
        Clear all stored authentication results.

        Note: This implementation does not support clearing all tokens as the
        object store interface doesn't provide a list operation. Individual
        tokens must be deleted explicitly. Calling it only logs a warning.
        """
        logger.warning("clear_all() is not supported for ObjectStoreTokenStorage")
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class InMemoryTokenStorage(TokenStorageBase):
    """
    Token storage that keeps everything in a private in-memory object store.

    Needs no external storage configuration. All operations are delegated to an
    ``ObjectStoreTokenStorage`` wrapped around a dedicated ``InMemoryObjectStore``,
    so tokens live only as long as this object does.
    """

    def __init__(self):
        """
        Create the backing in-memory object store and the delegate storage.
        """
        from nat.object_store.in_memory_object_store import InMemoryObjectStore

        # Dedicated store instance used only for tokens
        backing_store = InMemoryObjectStore()
        self._object_store = backing_store

        # ObjectStoreTokenStorage supplies the real store/retrieve/delete logic
        self._storage = ObjectStoreTokenStorage(backing_store)
        logger.debug("Initialized in-memory token storage")

    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Keep ``auth_result`` in memory under ``user_id``.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        await self._storage.store(user_id, auth_result)

    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Fetch the in-memory authentication result for ``user_id``.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The authentication result if found, None otherwise
        """
        found = await self._storage.retrieve(user_id)
        return found

    async def delete(self, user_id: str) -> None:
        """
        Drop the in-memory authentication result for ``user_id``.

        Args:
            user_id: The unique identifier for the user
        """
        await self._storage.delete(user_id)

    async def clear_all(self) -> None:
        """
        Drop every authentication result held in memory.
        """
        # NOTE(review): reaches into the object store's private ``_store`` attribute;
        # presumably the mapping that holds the items. Verify against InMemoryObjectStore.
        self._object_store._store.clear()
|