nvidia-nat-mcp 1.3.0rc6__py3-none-any.whl → 1.4.0a20251123__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-mcp might be problematic. Further details were linked from the original registry page.

@@ -18,6 +18,7 @@ from pydantic import HttpUrl
18
18
  from pydantic import model_validator
19
19
 
20
20
  from nat.authentication.interfaces import AuthProviderBaseConfig
21
+ from nat.data_models.common import OptionalSecretStr
21
22
 
22
23
 
23
24
  class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
@@ -36,7 +37,8 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
36
37
 
37
38
  # Client registration (manual registration vs DCR)
38
39
  client_id: str | None = Field(default=None, description="OAuth2 client ID for pre-registered clients")
39
- client_secret: str | None = Field(default=None, description="OAuth2 client secret for pre-registered clients")
40
+ client_secret: OptionalSecretStr = Field(default=None,
41
+ description="OAuth2 client secret for pre-registered clients")
40
42
  enable_dynamic_registration: bool = Field(default=True,
41
43
  description="Enable OAuth2 Dynamic Client Registration (RFC 7591)")
42
44
  client_name: str = Field(default="NAT MCP Client", description="OAuth2 client name for dynamic registration")
@@ -17,9 +17,17 @@ from nat.builder.builder import Builder
17
17
  from nat.cli.register_workflow import register_auth_provider
18
18
  from nat.plugins.mcp.auth.auth_provider import MCPOAuth2Provider
19
19
  from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
20
+ from nat.plugins.mcp.auth.service_account.provider import MCPServiceAccountProvider
21
+ from nat.plugins.mcp.auth.service_account.provider_config import MCPServiceAccountProviderConfig
20
22
 
21
23
 
22
24
@register_auth_provider(config_type=MCPOAuth2ProviderConfig)
async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder):
    """Yield an MCP OAuth2 auth provider built from its config (registered with NAT by the decorator)."""
    provider = MCPOAuth2Provider(authentication_provider, builder=builder)
    yield provider


@register_auth_provider(config_type=MCPServiceAccountProviderConfig)
async def mcp_service_account_provider(authentication_provider: MCPServiceAccountProviderConfig, builder: Builder):
    """Yield an MCP service-account auth provider built from its config (registered with NAT by the decorator)."""
    provider = MCPServiceAccountProvider(authentication_provider, builder=builder)
    yield provider
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,136 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import importlib
18
+ import logging
19
+ import typing
20
+
21
+ from pydantic import SecretStr
22
+
23
+ from nat.authentication.interfaces import AuthProviderBase
24
+ from nat.data_models.authentication import AuthResult
25
+ from nat.data_models.authentication import Credential
26
+ from nat.data_models.authentication import HeaderCred
27
+ from nat.plugins.mcp.auth.service_account.provider_config import MCPServiceAccountProviderConfig
28
+ from nat.plugins.mcp.auth.service_account.token_client import ServiceAccountTokenClient
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
class MCPServiceAccountProvider(AuthProviderBase[MCPServiceAccountProviderConfig]):
    """
    MCP service account authentication provider using OAuth2 client credentials.

    Provides headless authentication for MCP clients using service account credentials.
    Supports two authentication patterns:

    1. Single authentication: OAuth2 service account token only
    2. Dual authentication: OAuth2 service account token + service-specific token

    """

    def __init__(self, config: MCPServiceAccountProviderConfig, builder=None):
        """
        Create the provider and its OAuth2 token client.

        Args:
            config: Service account configuration (client credentials, token URL, scopes and
                optional service token settings).
            builder: Not used by this provider; accepted because the registration hook passes it.

        Raises:
            ValueError: If ``config.service_token.function`` is set but cannot be loaded.
        """
        super().__init__(config)

        self._token_client = ServiceAccountTokenClient(
            client_id=config.client_id,
            client_secret=config.client_secret,
            token_url=config.token_url,
            scopes=" ".join(config.scopes),  # Convert list to space-delimited string for OAuth2
            token_cache_buffer_seconds=config.token_cache_buffer_seconds,
        )

        # Resolve the dynamic service token function once, so a bad path fails at construction time
        self._service_token_function: typing.Callable | None = None
        if config.service_token and config.service_token.function:
            self._service_token_function = self._load_function(config.service_token.function)

        logger.info("Initialized MCP service account auth provider: "
                    "token_url=%s, scopes=%s, has_service_token=%s",
                    config.token_url,
                    config.scopes,
                    config.service_token is not None)

    def _load_function(self, function_path: str) -> typing.Callable:
        """
        Import and return a callable from a dotted path string (e.g., 'my_module.get_token').

        Raises:
            ValueError: If the path is not of the form ``module.attr``, the module or attribute
                cannot be imported, or the resolved attribute is not callable.
        """
        module_name, _, func_name = function_path.rpartition(".")
        if not module_name or not func_name:
            raise ValueError(f"Failed to load service token function '{function_path}': "
                             "expected a dotted path such as 'my_module.get_token'")

        try:
            module = importlib.import_module(module_name)
            func = getattr(module, func_name)
        except Exception as e:
            raise ValueError(f"Failed to load service token function '{function_path}': {e}") from e

        if not callable(func):
            raise ValueError(f"Failed to load service token function '{function_path}': "
                             f"'{func_name}' is not callable")

        logger.info("Loaded service token function: %s", function_path)
        return func

    async def _resolve_service_token(self) -> tuple[str, typing.Any]:
        """
        Determine the service-specific header name and token value.

        The static ``token`` from config takes precedence. Otherwise the loaded dynamic function
        (if any) is invoked with the configured kwargs.

        Returns:
            ``(header_name, token_value)``. A falsy ``token_value`` means no service header is sent.

        Raises:
            RuntimeError: If the dynamic service token function raises or returns a malformed tuple.
        """
        service_token = self.config.service_token

        if service_token.token:
            # Static token from config
            return service_token.header, service_token.token.get_secret_value()

        if self._service_token_function is None:
            return service_token.header, None

        try:
            # Pass configured kwargs to the function.
            # Function can access runtime context via AIQContext.get() if needed.
            result = self._service_token_function(**service_token.kwargs)

            # Await the call's result when it is a coroutine. This covers ``async def`` functions and
            # also callables such as objects with an ``async def __call__``, which
            # ``asyncio.iscoroutinefunction`` does not recognise.
            if asyncio.iscoroutine(result):
                result = await result

            # Function may return the token alone or a (header_name, token) pair
            if isinstance(result, tuple):
                service_header, service_token_value = result
            else:
                service_header, service_token_value = service_token.header, result

            logger.debug("Retrieved service token via dynamic function")
        except Exception as e:
            raise RuntimeError(f"Failed to get service token from function: {e}") from e

        return service_header, service_token_value

    async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
        """
        Authenticate using OAuth2 client credentials flow.

        Note: user_id is ignored for service accounts (non-session-specific).

        Returns:
            AuthResult with HeaderCred objects for service account authentication: an
            ``Authorization: Bearer`` header, plus the service-specific header when configured.

        Raises:
            RuntimeError: If the OAuth2 token or the dynamic service token cannot be obtained.
        """
        # Get OAuth2 access token (cached if still valid)
        access_token = await self._token_client.get_access_token()

        credentials: list[Credential] = [
            HeaderCred(name="Authorization", value=SecretStr(f"Bearer {access_token.get_secret_value()}"))
        ]

        # Add service-specific token if configured
        if self.config.service_token:
            service_header, service_token_value = await self._resolve_service_token()
            if service_token_value:
                credentials.append(HeaderCred(name=service_header, value=SecretStr(service_token_value)))

        return AuthResult(
            credentials=credentials,
            token_expires_at=self._token_client.token_expires_at,
            raw={},
        )
@@ -0,0 +1,137 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+
18
+ from pydantic import BaseModel
19
+ from pydantic import Field
20
+ from pydantic import field_validator
21
+ from pydantic import model_validator
22
+
23
+ from nat.authentication.interfaces import AuthProviderBaseConfig
24
+ from nat.data_models.common import OptionalSecretStr
25
+ from nat.data_models.common import SerializableSecretStr
26
+
27
+
28
class ServiceTokenConfig(BaseModel):
    """
    Configuration for service-specific token in dual authentication patterns.

    Supports two modes:

    1. Static token: Provide token and header directly
    2. Dynamic function: Provide function path and optional kwargs

    The function will be called on every request and should have signature::

        async def get_service_token(**kwargs) -> str | tuple[str, str]

    If function returns ``tuple[str, str]``, it's interpreted as (header_name, token).
    If function returns ``str``, it's the token and header field is used for header name.

    The function can access runtime context via AIQContext.get() if needed.
    """

    # Static token approach
    token: OptionalSecretStr = Field(
        default=None,
        description="Static service token value (mutually exclusive with function)",
    )

    header: str = Field(
        default="X-Service-Account-Token",
        description="HTTP header name for service token (default: 'X-Service-Account-Token')",
    )

    # Dynamic function approach
    function: str | None = Field(
        default=None,
        description=("Python function path that returns service token dynamically (mutually exclusive with token). "
                     "Function signature: async def func(\\**kwargs) -> str | tuple[str, str]. "
                     "Access runtime context via AIQContext.get() if needed."),
    )

    kwargs: dict[str, typing.Any] = Field(
        default_factory=dict,
        description="Additional keyword arguments to pass to the custom function",
    )

    @model_validator(mode="after")
    def validate_token_or_function(self):
        """
        Ensure exactly one of ``token`` or ``function`` is provided.

        An empty token or empty function path counts as not provided. The service account provider
        only uses these fields when they are truthy. Without this rule, an empty value would pass
        validation and then silently result in no service header being sent.

        Raises:
            ValueError: If neither or both of ``token`` and ``function`` are provided.
        """
        has_token = self.token is not None and bool(self.token.get_secret_value())
        has_function = bool(self.function)

        if not has_token and not has_function:
            raise ValueError("Either 'token' or 'function' must be provided in service_token config")

        if has_token and has_function:
            raise ValueError("Cannot specify both 'token' and 'function' in service_token config. Choose one.")

        return self
84
+
85
+
86
class MCPServiceAccountProviderConfig(AuthProviderBaseConfig, name="mcp_service_account"):
    """
    Configuration for MCP service account authentication using OAuth2 client credentials.

    Generic implementation supporting any OAuth2 client credentials flow.

    Supports two authentication patterns:
    1. Single authentication: OAuth2 service account token only
    2. Dual authentication: OAuth2 service account token + service-specific token

    Common use cases:
    - Headless/automated MCP workflows
    - CI/CD pipelines
    - Backend services without user interaction

    All values must be provided via configuration. Use ${ENV_VAR} syntax in YAML
    configs for environment variable substitution.
    """

    # Required: OAuth2 client credentials
    client_id: str = Field(description="OAuth2 client identifier")

    client_secret: SerializableSecretStr = Field(description="OAuth2 client secret")

    # Required: Token endpoint URL
    token_url: str = Field(description="OAuth2 token endpoint URL")

    # Required: OAuth2 scopes
    scopes: list[str] = Field(description="List of OAuth2 scopes (will be joined with spaces for OAuth2 request)")

    # Optional: Service-specific token configuration for dual authentication patterns
    service_token: ServiceTokenConfig | None = Field(
        default=None,
        description="Optional service token configuration for dual authentication patterns. "
        "Provide either a static token or a dynamic function that returns the token at runtime.",
    )

    # Token caching configuration. Must be non-negative: a negative buffer would make the cache
    # treat a token as valid after its reported expiry.
    token_cache_buffer_seconds: int = Field(default=300,
                                            ge=0,
                                            description="Seconds before token expiry to refresh (default: 300s/5min)")

    @field_validator("scopes", mode="before")
    @classmethod
    def validate_scopes(cls, v):
        """
        Accept both list[str] and space-delimited string formats for scopes.

        A string is split on whitespace. For a list or tuple, string entries are trimmed and blank
        entries dropped, so the space-joined scope string has no stray separators. Any other input
        is passed through to pydantic's ``list[str]`` validation.
        """
        if isinstance(v, str):
            # str.split() with no separator already drops empty pieces and surrounding whitespace
            return v.split()

        if isinstance(v, (list, tuple)):
            cleaned = []
            for scope in v:
                if isinstance(scope, str):
                    scope = scope.strip()
                    if not scope:
                        continue
                # Non-string entries are kept so pydantic reports them as type errors
                cleaned.append(scope)
            return cleaned

        return v
@@ -0,0 +1,156 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import base64
18
+ import logging
19
+ from datetime import datetime
20
+ from datetime import timedelta
21
+
22
+ import httpx
23
+ from pydantic import SecretStr
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class ServiceAccountTokenClient:
    """
    Generic OAuth2 client credentials token client for service accounts.

    Implements standard OAuth2 client credentials flow with token caching. A fetched token is
    reused until ``token_cache_buffer_seconds`` before its reported expiry.
    """

    # Token lifetime (seconds) assumed when the response omits ``expires_in`` or it is unusable
    _DEFAULT_EXPIRES_IN_SECONDS = 3600

    def __init__(
        self,
        client_id: str,
        client_secret: SecretStr,
        token_url: str,
        scopes: str,
        token_cache_buffer_seconds: int = 300,
    ):
        """
        Initialize service account token client.

        Args:
            client_id: OAuth2 client identifier
            client_secret: OAuth2 client secret (SecretStr)
            token_url: OAuth2 token endpoint URL
            scopes: Space-separated list of scopes (an empty string omits ``scope`` from the request)
            token_cache_buffer_seconds: Seconds before expiry to refresh (default: 5 min)
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scopes = scopes
        self.token_cache_buffer_seconds = token_cache_buffer_seconds

        # Token cache
        self._cached_token: SecretStr | None = None
        self._token_expires_at: datetime | None = None
        # Created lazily by _get_lock()
        self._lock: asyncio.Lock | None = None

    @property
    def token_expires_at(self) -> datetime | None:
        """Expiry of the cached token as a naive ``datetime.now()``-based time, or None before the first fetch."""
        return self._token_expires_at

    async def _get_lock(self) -> asyncio.Lock:
        """Lazy initialization of asyncio.Lock."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def _is_token_valid(self) -> bool:
        """Check if cached token is still valid (with buffer time)."""
        if not self._cached_token or not self._token_expires_at:
            return False
        buffer = timedelta(seconds=self.token_cache_buffer_seconds)
        return datetime.now() < (self._token_expires_at - buffer)

    async def get_access_token(self) -> SecretStr:
        """
        Get OAuth2 access token, using cache if valid.

        Returns:
            Access token as SecretStr

        Raises:
            RuntimeError: If token acquisition fails
        """
        # Fast path: check cache without lock
        if self._is_token_valid():
            logger.debug("Using cached service account token")
            assert self._cached_token is not None  # _is_token_valid() ensures this
            return self._cached_token

        # Slow path: acquire lock and refresh token
        lock = await self._get_lock()
        async with lock:
            # Double-check after acquiring lock: another task may have refreshed while we waited
            if self._is_token_valid():
                logger.debug("Using cached service account token (acquired during lock wait)")
                assert self._cached_token is not None  # _is_token_valid() ensures this
                return self._cached_token

            logger.info("Fetching new service account token")
            return await self._fetch_new_token()

    @classmethod
    def _parse_expiry(cls, raw_expires_in: object) -> tuple[float, datetime]:
        """
        Convert an ``expires_in`` response value into ``(lifetime_seconds, expiry_time)``.

        Numbers and numeric strings are accepted. A missing/null, non-numeric, NaN or overflowing
        value falls back to ``_DEFAULT_EXPIRES_IN_SECONDS``.
        """
        if raw_expires_in is not None:
            try:
                seconds = float(raw_expires_in)
                return seconds, datetime.now() + timedelta(seconds=seconds)
            except (TypeError, ValueError, OverflowError):
                logger.warning("Ignoring unusable expires_in value in token response: %r", raw_expires_in)

        seconds = float(cls._DEFAULT_EXPIRES_IN_SECONDS)
        return seconds, datetime.now() + timedelta(seconds=seconds)

    async def _fetch_new_token(self) -> SecretStr:
        """
        Fetch a new token from the OAuth2 token endpoint and cache it.

        Returns:
            New access token as SecretStr

        Raises:
            RuntimeError: If the request fails, times out, returns a non-200 status, or the response
                is not a JSON object containing ``access_token``.
        """
        # Encode credentials for Basic authentication
        credentials = f"{self.client_id}:{self.client_secret.get_secret_value()}"
        encoded_credentials = base64.b64encode(credentials.encode()).decode()

        headers = {"Authorization": f"Basic {encoded_credentials}", "Content-Type": "application/x-www-form-urlencoded"}

        data = {"grant_type": "client_credentials"}
        # ``scope`` is optional in the client credentials grant; omit it rather than send an empty value
        if self.scopes:
            data["scope"] = self.scopes

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(self.token_url, headers=headers, data=data)
        except httpx.TimeoutException as e:
            # TimeoutException subclasses RequestError, so it must be caught first
            raise RuntimeError(f"Service account token request timed out: {e}") from e
        except httpx.RequestError as e:
            raise RuntimeError(f"Service account token request failed: {e}") from e

        if response.status_code == 401:
            raise RuntimeError("Invalid service account credentials")
        if response.status_code == 429:
            raise RuntimeError("Service account rate limit exceeded")
        if response.status_code != 200:
            raise RuntimeError(f"Service account token request failed: {response.status_code} - {response.text}")

        try:
            token_data = response.json()
        except ValueError as e:
            raise RuntimeError(f"Service account token response is not valid JSON: {e}") from e

        access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
        if not access_token:
            raise RuntimeError("Access token not found in token response")

        # Compute the expiry before touching the cache so a bad value cannot leave it half-updated
        expires_in, expires_at = self._parse_expiry(token_data.get("expires_in"))
        self._cached_token = SecretStr(access_token)
        self._token_expires_at = expires_at

        logger.info("Service account token acquired (expires in %ss)", expires_in)
        return self._cached_token
@@ -112,14 +112,21 @@ class AuthAdapter(httpx.Auth):
112
112
  # Use the user_id passed to this AuthAdapter instance
113
113
  auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
114
114
 
115
- # Check if we have BearerTokenCred
115
+ # Build headers from credentials
116
116
  from nat.data_models.authentication import BearerTokenCred
117
- if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
118
- token = auth_result.credentials[0].token.get_secret_value()
119
- return {"Authorization": f"Bearer {token}"}
120
- else:
121
- logger.info("Auth provider did not return BearerTokenCred")
122
- return {}
117
+ from nat.data_models.authentication import HeaderCred
118
+ headers = {}
119
+
120
+ for cred in auth_result.credentials:
121
+ if isinstance(cred, BearerTokenCred):
122
+ # Standard Bearer token
123
+ token = cred.token.get_secret_value()
124
+ headers["Authorization"] = f"Bearer {token}"
125
+ elif isinstance(cred, HeaderCred):
126
+ # Generic header credential (supports custom formats and service accounts)
127
+ headers[cred.name] = cred.value.get_secret_value()
128
+
129
+ return headers
123
130
  except Exception as e:
124
131
  logger.warning("Failed to get auth token: %s", e)
125
132
  return {}
@@ -109,6 +109,10 @@ class MCPFunctionGroup(FunctionGroup):
109
109
  self._shared_auth_provider: AuthProviderBase | None = None
110
110
  self._client_config: MCPClientConfig | None = None
111
111
 
112
+ # Auth provider config defaults (set when auth provider is assigned)
113
+ self._default_user_id: str | None = None
114
+ self._allow_default_user_id_for_tool_calls: bool = True
115
+
112
116
  # Use random session id for testing only
113
117
  self._use_random_session_id_for_testing: bool = False
114
118
 
@@ -176,9 +180,8 @@ class MCPFunctionGroup(FunctionGroup):
176
180
 
177
181
  if not session_id:
178
182
  # use default user id if allowed
179
- if self._shared_auth_provider and \
180
- self._shared_auth_provider.config.allow_default_user_id_for_tool_calls:
181
- session_id = self._shared_auth_provider.config.default_user_id
183
+ if self._shared_auth_provider and self._allow_default_user_id_for_tool_calls:
184
+ session_id = self._default_user_id
182
185
  return session_id
183
186
  except Exception:
184
187
  return None
@@ -266,8 +269,7 @@ class MCPFunctionGroup(FunctionGroup):
266
269
  # If the session_id equals the configured default_user_id use the base client
267
270
  # instead of creating a per-session client
268
271
  if self._shared_auth_provider:
269
- default_uid = self._shared_auth_provider.config.default_user_id
270
- if default_uid and session_id == default_uid:
272
+ if self._default_user_id and session_id == self._default_user_id:
271
273
  return self.mcp_client
272
274
 
273
275
  # Fast path: check if session already exists (reader lock for concurrent access)
@@ -435,8 +437,7 @@ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
435
437
  return "User not authorized to call the tool"
436
438
 
437
439
  # Check if this is the default user - if so, use base client directly
438
- if (not function_group._shared_auth_provider
439
- or session_id == function_group._shared_auth_provider.config.default_user_id):
440
+ if (not function_group._shared_auth_provider or session_id == function_group._default_user_id):
440
441
  # Use base client directly for default user
441
442
  client = function_group.mcp_client
442
443
  session_tool = await client.get_tool(tool.name)
@@ -507,7 +508,9 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
507
508
  reconnect_max_backoff=config.reconnect_max_backoff)
508
509
  elif config.server.transport == "streamable-http":
509
510
  # Use default_user_id for the base client
510
- base_user_id = auth_provider.config.default_user_id if auth_provider else None
511
+ # For interactive OAuth2: from config. For service accounts: defaults to server URL
512
+ base_user_id = getattr(auth_provider.config, 'default_user_id', str(
513
+ config.server.url)) if auth_provider else None
511
514
  client = MCPStreamableHTTPClient(str(config.server.url),
512
515
  auth_provider=auth_provider,
513
516
  user_id=base_user_id,
@@ -529,6 +532,18 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
529
532
  group._shared_auth_provider = auth_provider
530
533
  group._client_config = config
531
534
 
535
+ # Set auth provider config defaults
536
+ # For interactive OAuth2: use config values
537
+ # For service accounts: default_user_id = server URL, allow_default_user_id_for_tool_calls = True
538
+ if auth_provider:
539
+ group._default_user_id = getattr(auth_provider.config, 'default_user_id', str(config.server.url))
540
+ group._allow_default_user_id_for_tool_calls = getattr(auth_provider.config,
541
+ 'allow_default_user_id_for_tool_calls',
542
+ True)
543
+ else:
544
+ group._default_user_id = None
545
+ group._allow_default_user_id_for_tool_calls = True
546
+
532
547
  async with client:
533
548
  # Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
534
549
  # can reuse the already-established session instead of creating a new client per request.
nat/plugins/mcp/utils.py CHANGED
@@ -47,7 +47,7 @@ def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
47
47
  "integer": int,
48
48
  "boolean": bool,
49
49
  "array": list,
50
- "null": None,
50
+ "null": type(None),
51
51
  "object": dict,
52
52
  }
53
53
 
@@ -58,51 +58,168 @@ def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
58
58
  def _generate_valid_classname(class_name: str):
59
59
  return class_name.replace('_', ' ').replace('-', ' ').title().replace(' ', '')
60
60
 
61
- def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
62
- json_type = field_properties.get("type", "string")
63
- enum_vals = field_properties.get("enum")
64
-
61
+ def _resolve_schema_type(schema: dict[str, Any], name: str) -> Any:
62
+ """
63
+ Recursively resolve a JSON schema to a Python type.
64
+ Handles nested anyOf/oneOf, arrays, objects, enums, and primitive types.
65
+ """
66
+ # Check for anyOf/oneOf first
67
+ any_of = schema.get("anyOf")
68
+ one_of = schema.get("oneOf")
69
+
70
+ if any_of or one_of:
71
+ union_schemas = any_of if any_of else one_of
72
+ resolved_type: Any = None
73
+
74
+ if union_schemas:
75
+ for sub_schema in union_schemas:
76
+ mapped = _resolve_schema_type(sub_schema, name)
77
+ if resolved_type is None:
78
+ resolved_type = mapped
79
+ elif mapped is not type(None):
80
+ # Don't add None here, handle separately
81
+ resolved_type = resolved_type | mapped
82
+ else:
83
+ # If we encounter null, combine with None at the end
84
+ resolved_type = resolved_type | None if resolved_type else type(None)
85
+
86
+ return resolved_type if resolved_type is not None else Any
87
+
88
+ # Handle enum values
89
+ enum_vals = schema.get("enum")
65
90
  if enum_vals:
66
- enum_name = f"{field_name.capitalize()}Enum"
67
- field_type = Enum(enum_name, {item: item for item in enum_vals})
68
-
69
- elif json_type == "object" and "properties" in field_properties:
70
- field_type = model_from_mcp_schema(name=field_name, mcp_input_schema=field_properties)
71
- elif json_type == "array" and "items" in field_properties:
72
- item_properties = field_properties.get("items", {})
73
- if item_properties.get("type") == "object":
74
- item_type = model_from_mcp_schema(name=field_name, mcp_input_schema=item_properties)
91
+ # Check if enum contains null
92
+ has_null = any(val is None or val == "null" for val in enum_vals)
93
+ # Filter out None/null values from enum
94
+ non_null_vals = [v for v in enum_vals if v is not None and v != "null"]
95
+
96
+ if non_null_vals:
97
+ enum_name = f"{name.capitalize()}Enum"
98
+ enum_type: Any = Enum(enum_name, {item: item for item in non_null_vals})
99
+ # If enum had null, make it a union with None
100
+ return enum_type | None if has_null else enum_type
101
+ elif has_null:
102
+ # Enum only contains null
103
+ return type(None)
75
104
  else:
76
- item_type = _type_map.get(item_properties.get("type", "string"), Any)
77
- field_type = list[item_type]
78
- elif isinstance(json_type, list):
79
- field_type = None
80
- for t in json_type:
81
- mapped = _type_map.get(t, Any)
82
- field_type = mapped if field_type is None else field_type | mapped
83
-
84
- return field_type, Field(
85
- default=field_properties.get("default", None if "null" in json_type else ...),
86
- description=field_properties.get("description", "")
87
- )
88
- else:
89
- field_type = _type_map.get(json_type, Any)
105
+ # Empty enum (shouldn't happen but handle gracefully)
106
+ return Any
107
+
108
+ schema_type = schema.get("type")
109
+
110
+ # Handle type as list (e.g., ["string", "integer", "null"])
111
+ if isinstance(schema_type, list):
112
+ list_type: Any = None
113
+ for t in schema_type:
114
+ if t == "array":
115
+ # Incorporate the mapped type of items
116
+ item_schema = schema.get("items", {})
117
+ if item_schema:
118
+ item_type = _resolve_schema_type(item_schema, name)
119
+ mapped = list[item_type]
120
+ else:
121
+ mapped = _type_map.get(t, Any)
122
+ elif t == "object":
123
+ # Incorporate the mapped type from properties
124
+ if "properties" in schema:
125
+ mapped = model_from_mcp_schema(name=name, mcp_input_schema=schema)
126
+ else:
127
+ mapped = _type_map.get(t, Any)
128
+ else:
129
+ mapped = _type_map.get(t, Any)
130
+
131
+ list_type = mapped if list_type is None else list_type | mapped
132
+ return list_type if list_type is not None else Any
133
+
134
+ # Handle null type
135
+ if schema_type == "null":
136
+ return type(None)
137
+
138
+ # Handle object type
139
+ if schema_type == "object" and "properties" in schema:
140
+ return model_from_mcp_schema(name=name, mcp_input_schema=schema)
141
+
142
+ # Handle array type
143
+ if schema_type == "array" and "items" in schema:
144
+ item_schema = schema.get("items", {})
145
+ # Recursively resolve item type (handles nested anyOf/oneOf)
146
+ item_type = _resolve_schema_type(item_schema, name)
147
+ return list[item_type]
148
+
149
+ # Handle primitive types
150
+ if schema_type is not None:
151
+ return _type_map.get(schema_type, Any)
152
+
153
+ return Any
154
+
155
+ def _has_null_in_type(field_properties: dict[str, Any]) -> bool:
156
+ """Check if a schema contains null as a valid type."""
157
+ # Check anyOf/oneOf for null
158
+ any_of = field_properties.get("anyOf")
159
+ one_of = field_properties.get("oneOf")
160
+ if any_of or one_of:
161
+ union_schemas = any_of if any_of else one_of
162
+ if union_schemas:
163
+ for schema in union_schemas:
164
+ if schema.get("type") == "null":
165
+ return True
166
+
167
+ # Check type list for null
168
+ json_type = field_properties.get("type")
169
+ if isinstance(json_type, list) and "null" in json_type:
170
+ return True
171
+
172
+ # Check enum for null (Python None or string "null")
173
+ enum_vals = field_properties.get("enum")
174
+ if enum_vals:
175
+ for val in enum_vals:
176
+ if val is None or val == "null":
177
+ return True
178
+
179
+ # Check const for null (Python None or string "null")
180
+ if "const" in field_properties:
181
+ const_val = field_properties.get("const")
182
+ if const_val is None or const_val == "null":
183
+ return True
184
+
185
+ return False
186
+
187
+ def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
188
+ """
189
+ Generate a Pydantic field from JSON schema properties.
190
+ Uses _resolve_schema_type for type resolution and handles field-specific logic.
191
+ """
192
+ # Resolve the field type using the unified resolver
193
+ field_type = _resolve_schema_type(field_properties, field_name)
194
+
195
+ # Check if the type includes null
196
+ has_null = _has_null_in_type(field_properties)
90
197
 
91
198
  # Determine the default value based on whether the field is required
199
+ default_value = field_properties.get("default")
200
+
92
201
  if field_name in required_fields:
93
- # Field is required - use explicit default if provided, otherwise make it required
94
- default_value = field_properties.get("default", ...)
202
+ # Field is required - use explicit default if provided, otherwise use ... to enforce presence
203
+ if default_value is None and "default" not in field_properties:
204
+ # Required field without explicit default: always use ... even if nullable
205
+ default_value = ...
206
+ # Make the field type nullable if it allows null
207
+ if has_null:
208
+ field_type = field_type | None
95
209
  else:
96
210
  # Field is optional - use explicit default if provided, otherwise None
97
- default_value = field_properties.get("default", None)
98
- # Make the type optional if no default was provided
99
- if "default" not in field_properties:
211
+ if default_value is None:
212
+ default_value = None
213
+ # Make the type optional if no default was provided and not already nullable
214
+ if "default" not in field_properties and not has_null:
100
215
  field_type = field_type | None
101
216
 
217
+ # Handle nullable property (less common, but still supported)
102
218
  nullable = field_properties.get("nullable", False)
103
- description = field_properties.get("description", "")
219
+ if nullable and not has_null:
220
+ field_type = field_type | None
104
221
 
105
- field_type = field_type | None if nullable else field_type
222
+ description = field_properties.get("description", "")
106
223
 
107
224
  return field_type, Field(default=default_value, description=description)
108
225
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-mcp
3
- Version: 1.3.0rc6
3
+ Version: 1.4.0a20251123
4
4
  Summary: Subpackage for MCP client integration in NeMo Agent toolkit
5
5
  Author: NVIDIA Corporation
6
6
  Maintainer: NVIDIA Corporation
@@ -16,7 +16,7 @@ Requires-Python: <3.14,>=3.11
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE-3rd-party.txt
18
18
  License-File: LICENSE.md
19
- Requires-Dist: nvidia-nat==v1.3.0-rc6
19
+ Requires-Dist: nvidia-nat==v1.4.0a20251123
20
20
  Requires-Dist: aiorwlock~=1.5
21
21
  Requires-Dist: mcp~=1.14
22
22
  Dynamic: license-file
@@ -0,0 +1,27 @@
1
+ nat/meta/pypi.md,sha256=EYyJTCCEOWzuuz-uNaYJ_WBk55Jiig87wcUr9E4g0yw,1484
2
+ nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
3
+ nat/plugins/mcp/client_base.py,sha256=Jm7FX4V5125vmqX9unAG4q-eeqgTyzfHgw14sZqnTRM,26753
4
+ nat/plugins/mcp/client_config.py,sha256=l9tVUHe8WdFPJ9rXDg8dZkQi1dvHGYwoqQ8Glqg2LGs,6783
5
+ nat/plugins/mcp/client_impl.py,sha256=BV7r39ijipAz9foAKtrLTairkaOjDH8Bnluw8sf72Ek,27902
6
+ nat/plugins/mcp/exception_handler.py,sha256=4JVdZDJL4LyumZEcMIEBK2LYC6djuSMzqUhQDZZ6dUo,7648
7
+ nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
8
+ nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
9
+ nat/plugins/mcp/tool.py,sha256=xNfBIF__ugJKFEjkYEM417wWM1PpuTaCMGtSFmxHSuA,6089
10
+ nat/plugins/mcp/utils.py,sha256=dUIig7jeKz0ctb4o38jFGbe2uvM3DMR3PSJjfN_Lr5M,9111
11
+ nat/plugins/mcp/auth/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
12
+ nat/plugins/mcp/auth/auth_flow_handler.py,sha256=v21IK3IKZ2TLEP6wO9r-sJQiilWPq7Ry40M96SAxQFA,9125
13
+ nat/plugins/mcp/auth/auth_provider.py,sha256=BgH66DlZgzhLDLO4cBERpHvNAmli5fMo_SCy11W9aBU,21251
14
+ nat/plugins/mcp/auth/auth_provider_config.py,sha256=ZdiUObYU_Oj8KDDZ-JqkS6kJgup5EBy2ZREUOBw_kkI,4143
15
+ nat/plugins/mcp/auth/register.py,sha256=miNZNmNszGgMCOCADLpH0Nz1nugkDXqdVc7ooapBx-c,1754
16
+ nat/plugins/mcp/auth/token_storage.py,sha256=aS13ZvEJXcYzkZ0GSbrSor4i5bpjD5BkXHQw1iywC9k,9240
17
+ nat/plugins/mcp/auth/service_account/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
18
+ nat/plugins/mcp/auth/service_account/provider.py,sha256=jJmt7PA_U3C59K5chSzmEIOd2bRxRKqXly9vHkTmsHQ,5915
19
+ nat/plugins/mcp/auth/service_account/provider_config.py,sha256=KhvQpROporbNKDFnx8np8b31QhdT3s5zYXd_IVTLgfE,5327
20
+ nat/plugins/mcp/auth/service_account/token_client.py,sha256=PRJ3u8Ts1tpuLuQKOyHDPOIAynSKxW9a928Rhh_4fhc,5981
21
+ nvidia_nat_mcp-1.4.0a20251123.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
22
+ nvidia_nat_mcp-1.4.0a20251123.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
23
+ nvidia_nat_mcp-1.4.0a20251123.dist-info/METADATA,sha256=52-egpVHr45VW2p1147j9s4ZSxnY1UO2tXgJDHR5FOI,2319
24
+ nvidia_nat_mcp-1.4.0a20251123.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
+ nvidia_nat_mcp-1.4.0a20251123.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
26
+ nvidia_nat_mcp-1.4.0a20251123.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
27
+ nvidia_nat_mcp-1.4.0a20251123.dist-info/RECORD,,
@@ -1,23 +0,0 @@
1
- nat/meta/pypi.md,sha256=EYyJTCCEOWzuuz-uNaYJ_WBk55Jiig87wcUr9E4g0yw,1484
2
- nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
3
- nat/plugins/mcp/client_base.py,sha256=JIyO2ZJsVkQ1g5BOU2zKXGHg_0yxv16g7_YJAqdCXTA,26504
4
- nat/plugins/mcp/client_config.py,sha256=l9tVUHe8WdFPJ9rXDg8dZkQi1dvHGYwoqQ8Glqg2LGs,6783
5
- nat/plugins/mcp/client_impl.py,sha256=j7cKAUBKtZAY3mt5Mm8VqgqMhRZk7kzvUd1nwMU_h0o,27072
6
- nat/plugins/mcp/exception_handler.py,sha256=4JVdZDJL4LyumZEcMIEBK2LYC6djuSMzqUhQDZZ6dUo,7648
7
- nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
8
- nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
9
- nat/plugins/mcp/tool.py,sha256=xNfBIF__ugJKFEjkYEM417wWM1PpuTaCMGtSFmxHSuA,6089
10
- nat/plugins/mcp/utils.py,sha256=4kNF5FJRiDUn-3fQcsvwvWtG6tYG1y4jU7vpptp0fsA,4522
11
- nat/plugins/mcp/auth/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
12
- nat/plugins/mcp/auth/auth_flow_handler.py,sha256=v21IK3IKZ2TLEP6wO9r-sJQiilWPq7Ry40M96SAxQFA,9125
13
- nat/plugins/mcp/auth/auth_provider.py,sha256=BgH66DlZgzhLDLO4cBERpHvNAmli5fMo_SCy11W9aBU,21251
14
- nat/plugins/mcp/auth/auth_provider_config.py,sha256=b1AaXzOuAkygKXAWSxMKWg8wfW8k33tmUUq6Dk5Mmwk,4038
15
- nat/plugins/mcp/auth/register.py,sha256=L2x69NjJPS4s6CCE5myzWVrWn3e_ttHyojmGXvBipMg,1228
16
- nat/plugins/mcp/auth/token_storage.py,sha256=aS13ZvEJXcYzkZ0GSbrSor4i5bpjD5BkXHQw1iywC9k,9240
17
- nvidia_nat_mcp-1.3.0rc6.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
18
- nvidia_nat_mcp-1.3.0rc6.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
19
- nvidia_nat_mcp-1.3.0rc6.dist-info/METADATA,sha256=yTrTMrdDdfjXpxpGKrLbSI5xyZ_xBwk8lYrp0BO1c-A,2308
20
- nvidia_nat_mcp-1.3.0rc6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
21
- nvidia_nat_mcp-1.3.0rc6.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
22
- nvidia_nat_mcp-1.3.0rc6.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
23
- nvidia_nat_mcp-1.3.0rc6.dist-info/RECORD,,