nvidia-nat-mcp 1.4.0a20251103__py3-none-any.whl → 1.4.0a20251123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/plugins/mcp/auth/register.py +8 -0
- nat/plugins/mcp/auth/service_account/__init__.py +14 -0
- nat/plugins/mcp/auth/service_account/provider.py +136 -0
- nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
- nat/plugins/mcp/auth/service_account/token_client.py +156 -0
- nat/plugins/mcp/client_base.py +14 -7
- nat/plugins/mcp/client_impl.py +23 -8
- {nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/METADATA +2 -2
- {nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/RECORD +14 -10
- {nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/top_level.txt +0 -0
nat/plugins/mcp/auth/register.py
CHANGED
|
@@ -17,9 +17,17 @@ from nat.builder.builder import Builder
|
|
|
17
17
|
from nat.cli.register_workflow import register_auth_provider
|
|
18
18
|
from nat.plugins.mcp.auth.auth_provider import MCPOAuth2Provider
|
|
19
19
|
from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
|
|
20
|
+
from nat.plugins.mcp.auth.service_account.provider import MCPServiceAccountProvider
|
|
21
|
+
from nat.plugins.mcp.auth.service_account.provider_config import MCPServiceAccountProviderConfig
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
@register_auth_provider(config_type=MCPOAuth2ProviderConfig)
async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder):
    """Build and yield the MCP OAuth2 auth provider for the given configuration.

    Registered with NAT through ``register_auth_provider`` so that configurations of
    type ``MCPOAuth2ProviderConfig`` resolve to an ``MCPOAuth2Provider`` instance.
    """
    oauth2_provider = MCPOAuth2Provider(authentication_provider, builder=builder)
    yield oauth2_provider
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@register_auth_provider(config_type=MCPServiceAccountProviderConfig)
async def mcp_service_account_provider(authentication_provider: MCPServiceAccountProviderConfig, builder: Builder):
    """Build and yield the MCP service account auth provider for the given configuration.

    Registered with NAT through ``register_auth_provider`` so that configurations of
    type ``MCPServiceAccountProviderConfig`` resolve to an ``MCPServiceAccountProvider``.
    """
    service_account_provider = MCPServiceAccountProvider(authentication_provider, builder=builder)
    yield service_account_provider
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import importlib
|
|
18
|
+
import logging
|
|
19
|
+
import typing
|
|
20
|
+
|
|
21
|
+
from pydantic import SecretStr
|
|
22
|
+
|
|
23
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
24
|
+
from nat.data_models.authentication import AuthResult
|
|
25
|
+
from nat.data_models.authentication import Credential
|
|
26
|
+
from nat.data_models.authentication import HeaderCred
|
|
27
|
+
from nat.plugins.mcp.auth.service_account.provider_config import MCPServiceAccountProviderConfig
|
|
28
|
+
from nat.plugins.mcp.auth.service_account.token_client import ServiceAccountTokenClient
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MCPServiceAccountProvider(AuthProviderBase[MCPServiceAccountProviderConfig]):
    """
    MCP service account authentication provider using OAuth2 client credentials.

    Provides headless authentication for MCP clients using service account credentials.
    Supports two authentication patterns:

    1. Single authentication: OAuth2 service account token only
    2. Dual authentication: OAuth2 service account token + service-specific token

    The service-specific token is either a static value from the configuration or the
    result of a user-supplied function that is invoked on every ``authenticate`` call.
    """

    def __init__(self, config: MCPServiceAccountProviderConfig, builder=None):
        """
        Create the provider and its OAuth2 token client.

        Args:
            config: Service account configuration (client credentials, token URL, scopes,
                optional service token settings).
            builder: Accepted for signature parity with other providers; not used here.

        Raises:
            ValueError: If ``config.service_token.function`` is set but cannot be imported
                or does not resolve to a callable.
        """
        super().__init__(config)

        # Initialize token client
        self._token_client = ServiceAccountTokenClient(
            client_id=config.client_id,
            client_secret=config.client_secret,
            token_url=config.token_url,
            scopes=" ".join(config.scopes),  # Convert list to space-delimited string for OAuth2
            token_cache_buffer_seconds=config.token_cache_buffer_seconds,
        )

        # Load dynamic service token function if configured. Resolving it eagerly makes a
        # bad function path fail at start-up instead of on the first authenticated request.
        self._service_token_function = None
        if config.service_token and config.service_token.function:
            self._service_token_function = self._load_function(config.service_token.function)

        logger.info("Initialized MCP service account auth provider: "
                    "token_url=%s, scopes=%s, has_service_token=%s",
                    config.token_url,
                    config.scopes,
                    config.service_token is not None)

    def _load_function(self, function_path: str) -> typing.Callable:
        """
        Load a Python function from a module path string (e.g., 'my_module.get_token').

        Args:
            function_path: Dotted ``module.attribute`` path of the callable to load.

        Returns:
            The callable found at ``function_path``.

        Raises:
            ValueError: If the path is not dotted, the module cannot be imported, the
                attribute does not exist, or the attribute is not callable.
        """
        # NOTE: this imports an arbitrary module named by the configuration, so the
        # configuration must come from a trusted source.
        try:
            module_name, _, func_name = function_path.rpartition(".")
            if not module_name or not func_name:
                raise ValueError("expected a dotted path such as 'my_module.get_token'")

            module = importlib.import_module(module_name)
            func = getattr(module, func_name)

            # Reject non-callables here; otherwise the error would only surface later,
            # inside authenticate(), as an opaque "object is not callable" failure.
            if not callable(func):
                raise TypeError(f"'{func_name}' in module '{module_name}' is not callable")
        except Exception as e:
            raise ValueError(f"Failed to load service token function '{function_path}': {e}") from e

        logger.info("Loaded service token function: %s", function_path)
        return func

    async def _resolve_service_token(self) -> tuple[str, typing.Any]:
        """
        Resolve the service-specific header name and token value.

        A static token from the configuration takes precedence; otherwise the dynamic
        function (if one was loaded) is called with the configured kwargs.

        Returns:
            ``(header_name, token_value)``. ``token_value`` is ``None`` (or another falsy
            value) when no token could be produced.

        Raises:
            RuntimeError: If the dynamic function raises or returns a tuple that is not
                a ``(header_name, token)`` pair.
        """
        # Local import keeps this block self-contained (the module does not import inspect).
        import inspect

        service_cfg = self.config.service_token
        header_name = service_cfg.header

        if service_cfg.token:
            # Static token from config
            return header_name, service_cfg.token.get_secret_value()

        if not self._service_token_function:
            return header_name, None

        try:
            # Pass configured kwargs to the function.
            # Function can access runtime context via AIQContext.get() if needed.
            outcome = self._service_token_function(**service_cfg.kwargs)

            # Await whatever awaitable comes back rather than inspecting the function
            # itself: this also covers functools.partial objects, callable instances and
            # sync wrappers that return a coroutine, not only plain ``async def``.
            if inspect.isawaitable(outcome):
                outcome = await outcome

            # Handle function return type: str or tuple[str, str]
            if isinstance(outcome, tuple):
                header_name, token_value = outcome
            else:
                token_value = outcome

            logger.debug("Retrieved service token via dynamic function")
        except Exception as e:
            raise RuntimeError(f"Failed to get service token from function: {e}") from e

        return header_name, token_value

    async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
        """
        Authenticate using OAuth2 client credentials flow.

        Note: user_id is ignored for service accounts (non-session-specific), as are any
        extra keyword arguments.

        Returns:
            AuthResult with HeaderCred objects for service account authentication: always
            an ``Authorization: Bearer ...`` header, plus the service-specific header when
            a service token is configured and resolves to a non-empty value.

        Raises:
            RuntimeError: If the dynamic service token function fails. Errors raised by
                the token client while fetching the access token propagate unchanged.
        """
        # Get OAuth2 access token (cached if still valid)
        access_token = await self._token_client.get_access_token()

        # Build credentials list using HeaderCred
        credentials: list[Credential] = [
            HeaderCred(name="Authorization", value=SecretStr(f"Bearer {access_token.get_secret_value()}"))
        ]

        # Add service-specific token if configured
        if self.config.service_token:
            service_header, service_token_value = await self._resolve_service_token()

            if service_token_value:
                credentials.append(HeaderCred(name=service_header, value=SecretStr(service_token_value)))
            else:
                # Keep the original best-effort behaviour (send the request without the
                # header) but make the missing credential visible to operators.
                logger.warning("Service token is configured but no token value was produced; "
                               "request will be sent without the '%s' header",
                               service_header)

        # Return AuthResult with HeaderCred objects
        return AuthResult(
            credentials=credentials,
            token_expires_at=self._token_client.token_expires_at,
            raw={},
        )
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic import field_validator
|
|
21
|
+
from pydantic import model_validator
|
|
22
|
+
|
|
23
|
+
from nat.authentication.interfaces import AuthProviderBaseConfig
|
|
24
|
+
from nat.data_models.common import OptionalSecretStr
|
|
25
|
+
from nat.data_models.common import SerializableSecretStr
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ServiceTokenConfig(BaseModel):
    """
    Configuration for service-specific token in dual authentication patterns.

    Supports two modes:

    1. Static token: Provide token and header directly
    2. Dynamic function: Provide function path and optional kwargs

    The function will be called on every request and should have signature::

        async def get_service_token(**kwargs) -> str | tuple[str, str]

    If function returns ``tuple[str, str]``, it's interpreted as (header_name, token).
    If function returns ``str``, it's the token and header field is used for header name.

    The function can access runtime context via AIQContext.get() if needed.
    """

    # Static token approach
    token: OptionalSecretStr = Field(
        default=None,
        description="Static service token value (mutually exclusive with function)",
    )

    header: str = Field(
        default="X-Service-Account-Token",
        description="HTTP header name for service token (default: 'X-Service-Account-Token')",
    )

    # Dynamic function approach
    function: str | None = Field(
        default=None,
        description=("Python function path that returns service token dynamically (mutually exclusive with token). "
                     "Function signature: async def func(\\**kwargs) -> str | tuple[str, str]. "
                     "Access runtime context via AIQContext.get() if needed."),
    )

    kwargs: dict[str, typing.Any] = Field(
        default_factory=dict,
        description="Additional keyword arguments to pass to the custom function",
    )

    @model_validator(mode="after")
    def validate_token_or_function(self):
        """
        Ensure exactly one usable source (token or function) is provided.

        Raises:
            ValueError: If both ``token`` and ``function`` are set, or if neither is set
                to a non-empty value.
        """
        has_token = self.token is not None
        has_function = self.function is not None

        if has_token and has_function:
            raise ValueError("Cannot specify both 'token' and 'function' in service_token config. Choose one.")

        # An empty value (e.g. an unset environment variable substituted into the config)
        # is treated as missing: a blank token or function path cannot produce a header,
        # so accepting it would silently disable the service token.
        # Assumes OptionalSecretStr values expose get_secret_value() like pydantic SecretStr.
        token_usable = has_token and bool(self.token.get_secret_value())
        function_usable = has_function and bool(self.function.strip())

        if not token_usable and not function_usable:
            raise ValueError("Either 'token' or 'function' must be provided in service_token config")

        return self
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class MCPServiceAccountProviderConfig(AuthProviderBaseConfig, name="mcp_service_account"):
    """
    Configuration for MCP service account authentication using OAuth2 client credentials.

    Generic implementation supporting any OAuth2 client credentials flow.

    Supports two authentication patterns:

    1. Single authentication: OAuth2 service account token only
    2. Dual authentication: OAuth2 service account token + service-specific token

    Common use cases:

    - Headless/automated MCP workflows
    - CI/CD pipelines
    - Backend services without user interaction

    All values must be provided via configuration. Use ${ENV_VAR} syntax in YAML
    configs for environment variable substitution.
    """

    # Required: OAuth2 client credentials
    client_id: str = Field(description="OAuth2 client identifier")

    client_secret: SerializableSecretStr = Field(description="OAuth2 client secret")

    # Required: Token endpoint URL
    token_url: str = Field(description="OAuth2 token endpoint URL")

    # Required: OAuth2 scopes
    scopes: list[str] = Field(description="List of OAuth2 scopes (will be joined with spaces for OAuth2 request)")

    # Optional: Service-specific token configuration for dual authentication patterns
    service_token: ServiceTokenConfig | None = Field(
        default=None,
        description="Optional service token configuration for dual authentication patterns. "
        "Provide either a static token or a dynamic function that returns the token at runtime.",
    )

    # Token caching configuration.
    # ge=0: a negative buffer would push the refresh point past the real expiry, so an
    # already-expired token would still be treated as valid.
    token_cache_buffer_seconds: int = Field(default=300,
                                            ge=0,
                                            description="Seconds before token expiry to refresh (default: 300s/5min)")

    @field_validator("scopes", mode="before")
    @classmethod
    def validate_scopes(cls, v):
        """
        Accept both list[str] and space-delimited string formats for scopes.

        A string is split on runs of whitespace into a list; any other value is passed
        through unchanged for the regular ``list[str]`` validation.
        """
        if isinstance(v, str):
            # str.split() with no separator already discards leading/trailing/repeated
            # whitespace, so no per-item strip or emptiness filter is required.
            return v.split()
        return v
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import base64
|
|
18
|
+
import logging
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
from datetime import timedelta
|
|
21
|
+
|
|
22
|
+
import httpx
|
|
23
|
+
from pydantic import SecretStr
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ServiceAccountTokenClient:
    """
    Generic OAuth2 client credentials token client for service accounts.

    Implements standard OAuth2 client credentials flow with token caching: the token is
    fetched with HTTP Basic client authentication and reused until it is within
    ``token_cache_buffer_seconds`` of its expiry.
    """

    # Lifetime assumed when the token response omits ``expires_in`` or carries a value
    # that cannot be interpreted as a number of seconds.
    _DEFAULT_EXPIRES_IN_SECONDS = 3600

    def __init__(
        self,
        client_id: str,
        client_secret: SecretStr,
        token_url: str,
        scopes: str,
        token_cache_buffer_seconds: int = 300,
    ):
        """
        Initialize service account token client.

        Args:
            client_id: OAuth2 client identifier
            client_secret: OAuth2 client secret (SecretStr)
            token_url: OAuth2 token endpoint URL
            scopes: Space-separated list of scopes
            token_cache_buffer_seconds: Seconds before expiry to refresh (default: 5 min)
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scopes = scopes
        self.token_cache_buffer_seconds = token_cache_buffer_seconds

        # Token cache
        self._cached_token: SecretStr | None = None
        # NOTE(review): stored as a naive local-time datetime (datetime.now()); confirm
        # that consumers of ``token_expires_at`` do not compare it with timezone-aware values.
        self._token_expires_at: datetime | None = None
        self._lock: asyncio.Lock | None = None  # Will be initialized as asyncio.Lock when needed

    @property
    def token_expires_at(self) -> datetime | None:
        """Expiry time of the cached token, or ``None`` if no token has been fetched yet."""
        return self._token_expires_at

    async def _get_lock(self) -> asyncio.Lock:
        """Lazy initialization of asyncio.Lock."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def _is_token_valid(self) -> bool:
        """Check if cached token is still valid (with buffer time)."""
        if not self._cached_token or not self._token_expires_at:
            return False
        buffer = timedelta(seconds=self.token_cache_buffer_seconds)
        return datetime.now() < (self._token_expires_at - buffer)

    @classmethod
    def _coerce_expires_in(cls, value) -> int | float:
        """
        Convert the ``expires_in`` field of a token response to a number of seconds.

        Numbers are returned unchanged and numeric strings are parsed. Anything else
        (missing, null, boolean, unparsable) falls back to the default lifetime.
        """
        if isinstance(value, bool):
            return cls._DEFAULT_EXPIRES_IN_SECONDS
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, str):
            text = value.strip()
            # Try int first so integral lifetimes keep their integer form in log output.
            for parse in (int, float):
                try:
                    return parse(text)
                except ValueError:
                    continue
        return cls._DEFAULT_EXPIRES_IN_SECONDS

    async def get_access_token(self) -> SecretStr:
        """
        Get OAuth2 access token, using cache if valid.

        Returns:
            Access token as SecretStr

        Raises:
            RuntimeError: If token acquisition fails
        """
        # Fast path: check cache without lock
        if self._is_token_valid():
            logger.debug("Using cached service account token")
            assert self._cached_token is not None  # _is_token_valid() ensures this
            return self._cached_token

        # Slow path: acquire lock and refresh token
        lock = await self._get_lock()
        async with lock:
            # Re-check after acquiring the lock: another task may have refreshed the
            # token while this one was waiting.
            if self._is_token_valid():
                logger.debug("Using cached service account token (acquired during lock wait)")
                assert self._cached_token is not None  # _is_token_valid() ensures this
                return self._cached_token

            logger.info("Fetching new service account token")
            return await self._fetch_new_token()

    async def _fetch_new_token(self) -> SecretStr:
        """
        Fetch a new token from the OAuth2 token endpoint.

        The cache is updated only after the whole response has been validated, so a
        malformed response never leaves a token cached with a stale expiry.

        Returns:
            New access token as SecretStr

        Raises:
            RuntimeError: If the request times out or fails, the endpoint answers with a
                non-200 status, or the response body is not a JSON object containing
                ``access_token``.
        """
        # Encode credentials for Basic authentication
        credentials = f"{self.client_id}:{self.client_secret.get_secret_value()}"
        encoded_credentials = base64.b64encode(credentials.encode()).decode()

        headers = {"Authorization": f"Basic {encoded_credentials}", "Content-Type": "application/x-www-form-urlencoded"}

        data = {"grant_type": "client_credentials"}
        if self.scopes:
            # Only send ``scope`` when there is something to request; an empty value is
            # not a valid scope list and strict servers may reject it.
            data["scope"] = self.scopes

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(self.token_url, headers=headers, data=data)
        except httpx.TimeoutException as e:
            raise RuntimeError(f"Service account token request timed out: {e}") from e
        except httpx.RequestError as e:
            raise RuntimeError(f"Service account token request failed: {e}") from e

        if response.status_code == 401:
            raise RuntimeError("Invalid service account credentials")
        if response.status_code == 429:
            raise RuntimeError("Service account rate limit exceeded")
        if response.status_code != 200:
            raise RuntimeError(f"Service account token request failed: {response.status_code} - {response.text}")

        try:
            token_data = response.json()
        except ValueError as e:
            # A 200 response with a non-JSON body (e.g. an HTML page from a proxy) must
            # still surface as the documented RuntimeError.
            raise RuntimeError("Service account token response was not valid JSON") from e

        access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
        if not access_token:
            raise RuntimeError("Access token not found in token response")

        expires_in = self._coerce_expires_in(token_data.get("expires_in"))
        new_token = SecretStr(access_token)
        new_expiry = datetime.now() + timedelta(seconds=expires_in)

        # Cache the token (both fields together, now that nothing else can fail)
        self._cached_token = new_token
        self._token_expires_at = new_expiry

        if expires_in <= self.token_cache_buffer_seconds:
            # _is_token_valid() subtracts the buffer from the expiry, so such a token is
            # considered stale immediately and every call performs a new fetch.
            logger.warning(
                "Service account token lifetime (%ss) does not exceed token_cache_buffer_seconds (%ss); "
                "the token will be re-fetched on every call",
                expires_in,
                self.token_cache_buffer_seconds)

        logger.info("Service account token acquired (expires in %ss)", expires_in)
        return self._cached_token
|
nat/plugins/mcp/client_base.py
CHANGED
|
@@ -112,14 +112,21 @@ class AuthAdapter(httpx.Auth):
|
|
|
112
112
|
# Use the user_id passed to this AuthAdapter instance
|
|
113
113
|
auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
|
|
114
114
|
|
|
115
|
-
#
|
|
115
|
+
# Build headers from credentials
|
|
116
116
|
from nat.data_models.authentication import BearerTokenCred
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
117
|
+
from nat.data_models.authentication import HeaderCred
|
|
118
|
+
headers = {}
|
|
119
|
+
|
|
120
|
+
for cred in auth_result.credentials:
|
|
121
|
+
if isinstance(cred, BearerTokenCred):
|
|
122
|
+
# Standard Bearer token
|
|
123
|
+
token = cred.token.get_secret_value()
|
|
124
|
+
headers["Authorization"] = f"Bearer {token}"
|
|
125
|
+
elif isinstance(cred, HeaderCred):
|
|
126
|
+
# Generic header credential (supports custom formats and service accounts)
|
|
127
|
+
headers[cred.name] = cred.value.get_secret_value()
|
|
128
|
+
|
|
129
|
+
return headers
|
|
123
130
|
except Exception as e:
|
|
124
131
|
logger.warning("Failed to get auth token: %s", e)
|
|
125
132
|
return {}
|
nat/plugins/mcp/client_impl.py
CHANGED
|
@@ -109,6 +109,10 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
109
109
|
self._shared_auth_provider: AuthProviderBase | None = None
|
|
110
110
|
self._client_config: MCPClientConfig | None = None
|
|
111
111
|
|
|
112
|
+
# Auth provider config defaults (set when auth provider is assigned)
|
|
113
|
+
self._default_user_id: str | None = None
|
|
114
|
+
self._allow_default_user_id_for_tool_calls: bool = True
|
|
115
|
+
|
|
112
116
|
# Use random session id for testing only
|
|
113
117
|
self._use_random_session_id_for_testing: bool = False
|
|
114
118
|
|
|
@@ -176,9 +180,8 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
176
180
|
|
|
177
181
|
if not session_id:
|
|
178
182
|
# use default user id if allowed
|
|
179
|
-
if self._shared_auth_provider and
|
|
180
|
-
self.
|
|
181
|
-
session_id = self._shared_auth_provider.config.default_user_id
|
|
183
|
+
if self._shared_auth_provider and self._allow_default_user_id_for_tool_calls:
|
|
184
|
+
session_id = self._default_user_id
|
|
182
185
|
return session_id
|
|
183
186
|
except Exception:
|
|
184
187
|
return None
|
|
@@ -266,8 +269,7 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
266
269
|
# If the session_id equals the configured default_user_id use the base client
|
|
267
270
|
# instead of creating a per-session client
|
|
268
271
|
if self._shared_auth_provider:
|
|
269
|
-
|
|
270
|
-
if default_uid and session_id == default_uid:
|
|
272
|
+
if self._default_user_id and session_id == self._default_user_id:
|
|
271
273
|
return self.mcp_client
|
|
272
274
|
|
|
273
275
|
# Fast path: check if session already exists (reader lock for concurrent access)
|
|
@@ -435,8 +437,7 @@ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
|
|
|
435
437
|
return "User not authorized to call the tool"
|
|
436
438
|
|
|
437
439
|
# Check if this is the default user - if so, use base client directly
|
|
438
|
-
if (not function_group._shared_auth_provider
|
|
439
|
-
or session_id == function_group._shared_auth_provider.config.default_user_id):
|
|
440
|
+
if (not function_group._shared_auth_provider or session_id == function_group._default_user_id):
|
|
440
441
|
# Use base client directly for default user
|
|
441
442
|
client = function_group.mcp_client
|
|
442
443
|
session_tool = await client.get_tool(tool.name)
|
|
@@ -507,7 +508,9 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
|
507
508
|
reconnect_max_backoff=config.reconnect_max_backoff)
|
|
508
509
|
elif config.server.transport == "streamable-http":
|
|
509
510
|
# Use default_user_id for the base client
|
|
510
|
-
|
|
511
|
+
# For interactive OAuth2: from config. For service accounts: defaults to server URL
|
|
512
|
+
base_user_id = getattr(auth_provider.config, 'default_user_id', str(
|
|
513
|
+
config.server.url)) if auth_provider else None
|
|
511
514
|
client = MCPStreamableHTTPClient(str(config.server.url),
|
|
512
515
|
auth_provider=auth_provider,
|
|
513
516
|
user_id=base_user_id,
|
|
@@ -529,6 +532,18 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
|
529
532
|
group._shared_auth_provider = auth_provider
|
|
530
533
|
group._client_config = config
|
|
531
534
|
|
|
535
|
+
# Set auth provider config defaults
|
|
536
|
+
# For interactive OAuth2: use config values
|
|
537
|
+
# For service accounts: default_user_id = server URL, allow_default_user_id_for_tool_calls = True
|
|
538
|
+
if auth_provider:
|
|
539
|
+
group._default_user_id = getattr(auth_provider.config, 'default_user_id', str(config.server.url))
|
|
540
|
+
group._allow_default_user_id_for_tool_calls = getattr(auth_provider.config,
|
|
541
|
+
'allow_default_user_id_for_tool_calls',
|
|
542
|
+
True)
|
|
543
|
+
else:
|
|
544
|
+
group._default_user_id = None
|
|
545
|
+
group._allow_default_user_id_for_tool_calls = True
|
|
546
|
+
|
|
532
547
|
async with client:
|
|
533
548
|
# Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
|
|
534
549
|
# can reuse the already-established session instead of creating a new client per request.
|
{nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nvidia-nat-mcp
|
|
3
|
-
Version: 1.4.
|
|
3
|
+
Version: 1.4.0a20251123
|
|
4
4
|
Summary: Subpackage for MCP client integration in NeMo Agent toolkit
|
|
5
5
|
Author: NVIDIA Corporation
|
|
6
6
|
Maintainer: NVIDIA Corporation
|
|
@@ -16,7 +16,7 @@ Requires-Python: <3.14,>=3.11
|
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
17
|
License-File: LICENSE-3rd-party.txt
|
|
18
18
|
License-File: LICENSE.md
|
|
19
|
-
Requires-Dist: nvidia-nat==v1.4.
|
|
19
|
+
Requires-Dist: nvidia-nat==v1.4.0a20251123
|
|
20
20
|
Requires-Dist: aiorwlock~=1.5
|
|
21
21
|
Requires-Dist: mcp~=1.14
|
|
22
22
|
Dynamic: license-file
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
nat/meta/pypi.md,sha256=EYyJTCCEOWzuuz-uNaYJ_WBk55Jiig87wcUr9E4g0yw,1484
|
|
2
2
|
nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
3
|
-
nat/plugins/mcp/client_base.py,sha256=
|
|
3
|
+
nat/plugins/mcp/client_base.py,sha256=Jm7FX4V5125vmqX9unAG4q-eeqgTyzfHgw14sZqnTRM,26753
|
|
4
4
|
nat/plugins/mcp/client_config.py,sha256=l9tVUHe8WdFPJ9rXDg8dZkQi1dvHGYwoqQ8Glqg2LGs,6783
|
|
5
|
-
nat/plugins/mcp/client_impl.py,sha256=
|
|
5
|
+
nat/plugins/mcp/client_impl.py,sha256=BV7r39ijipAz9foAKtrLTairkaOjDH8Bnluw8sf72Ek,27902
|
|
6
6
|
nat/plugins/mcp/exception_handler.py,sha256=4JVdZDJL4LyumZEcMIEBK2LYC6djuSMzqUhQDZZ6dUo,7648
|
|
7
7
|
nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
|
|
8
8
|
nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
|
|
@@ -12,12 +12,16 @@ nat/plugins/mcp/auth/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aM
|
|
|
12
12
|
nat/plugins/mcp/auth/auth_flow_handler.py,sha256=v21IK3IKZ2TLEP6wO9r-sJQiilWPq7Ry40M96SAxQFA,9125
|
|
13
13
|
nat/plugins/mcp/auth/auth_provider.py,sha256=BgH66DlZgzhLDLO4cBERpHvNAmli5fMo_SCy11W9aBU,21251
|
|
14
14
|
nat/plugins/mcp/auth/auth_provider_config.py,sha256=ZdiUObYU_Oj8KDDZ-JqkS6kJgup5EBy2ZREUOBw_kkI,4143
|
|
15
|
-
nat/plugins/mcp/auth/register.py,sha256=
|
|
15
|
+
nat/plugins/mcp/auth/register.py,sha256=miNZNmNszGgMCOCADLpH0Nz1nugkDXqdVc7ooapBx-c,1754
|
|
16
16
|
nat/plugins/mcp/auth/token_storage.py,sha256=aS13ZvEJXcYzkZ0GSbrSor4i5bpjD5BkXHQw1iywC9k,9240
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
nvidia_nat_mcp-1.4.
|
|
22
|
-
nvidia_nat_mcp-1.4.
|
|
23
|
-
nvidia_nat_mcp-1.4.
|
|
17
|
+
nat/plugins/mcp/auth/service_account/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
|
18
|
+
nat/plugins/mcp/auth/service_account/provider.py,sha256=jJmt7PA_U3C59K5chSzmEIOd2bRxRKqXly9vHkTmsHQ,5915
|
|
19
|
+
nat/plugins/mcp/auth/service_account/provider_config.py,sha256=KhvQpROporbNKDFnx8np8b31QhdT3s5zYXd_IVTLgfE,5327
|
|
20
|
+
nat/plugins/mcp/auth/service_account/token_client.py,sha256=PRJ3u8Ts1tpuLuQKOyHDPOIAynSKxW9a928Rhh_4fhc,5981
|
|
21
|
+
nvidia_nat_mcp-1.4.0a20251123.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
|
|
22
|
+
nvidia_nat_mcp-1.4.0a20251123.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
23
|
+
nvidia_nat_mcp-1.4.0a20251123.dist-info/METADATA,sha256=52-egpVHr45VW2p1147j9s4ZSxnY1UO2tXgJDHR5FOI,2319
|
|
24
|
+
nvidia_nat_mcp-1.4.0a20251123.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
25
|
+
nvidia_nat_mcp-1.4.0a20251123.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
|
|
26
|
+
nvidia_nat_mcp-1.4.0a20251123.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
|
27
|
+
nvidia_nat_mcp-1.4.0a20251123.dist-info/RECORD,,
|
|
File without changes
|
{nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nvidia_nat_mcp-1.4.0a20251103.dist-info → nvidia_nat_mcp-1.4.0a20251123.dist-info}/top_level.txt
RENAMED
|
File without changes
|