nvidia-nat-mcp 1.4.0a20260107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nat/meta/pypi.md +32 -0
  2. nat/plugins/mcp/__init__.py +14 -0
  3. nat/plugins/mcp/auth/__init__.py +14 -0
  4. nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
  5. nat/plugins/mcp/auth/auth_provider.py +431 -0
  6. nat/plugins/mcp/auth/auth_provider_config.py +86 -0
  7. nat/plugins/mcp/auth/register.py +33 -0
  8. nat/plugins/mcp/auth/service_account/__init__.py +14 -0
  9. nat/plugins/mcp/auth/service_account/provider.py +136 -0
  10. nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
  11. nat/plugins/mcp/auth/service_account/token_client.py +156 -0
  12. nat/plugins/mcp/auth/token_storage.py +265 -0
  13. nat/plugins/mcp/cli/__init__.py +15 -0
  14. nat/plugins/mcp/cli/commands.py +1051 -0
  15. nat/plugins/mcp/client/__init__.py +15 -0
  16. nat/plugins/mcp/client/client_base.py +665 -0
  17. nat/plugins/mcp/client/client_config.py +146 -0
  18. nat/plugins/mcp/client/client_impl.py +782 -0
  19. nat/plugins/mcp/exception_handler.py +211 -0
  20. nat/plugins/mcp/exceptions.py +142 -0
  21. nat/plugins/mcp/register.py +23 -0
  22. nat/plugins/mcp/server/__init__.py +15 -0
  23. nat/plugins/mcp/server/front_end_config.py +109 -0
  24. nat/plugins/mcp/server/front_end_plugin.py +155 -0
  25. nat/plugins/mcp/server/front_end_plugin_worker.py +411 -0
  26. nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
  27. nat/plugins/mcp/server/memory_profiler.py +320 -0
  28. nat/plugins/mcp/server/register_frontend.py +27 -0
  29. nat/plugins/mcp/server/tool_converter.py +286 -0
  30. nat/plugins/mcp/utils.py +228 -0
  31. nvidia_nat_mcp-1.4.0a20260107.dist-info/METADATA +55 -0
  32. nvidia_nat_mcp-1.4.0a20260107.dist-info/RECORD +37 -0
  33. nvidia_nat_mcp-1.4.0a20260107.dist-info/WHEEL +5 -0
  34. nvidia_nat_mcp-1.4.0a20260107.dist-info/entry_points.txt +9 -0
  35. nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  36. nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE.md +201 -0
  37. nvidia_nat_mcp-1.4.0a20260107.dist-info/top_level.txt +1 -0
@@ -0,0 +1,136 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import importlib
18
+ import logging
19
+ import typing
20
+
21
+ from pydantic import SecretStr
22
+
23
+ from nat.authentication.interfaces import AuthProviderBase
24
+ from nat.data_models.authentication import AuthResult
25
+ from nat.data_models.authentication import Credential
26
+ from nat.data_models.authentication import HeaderCred
27
+ from nat.plugins.mcp.auth.service_account.provider_config import MCPServiceAccountProviderConfig
28
+ from nat.plugins.mcp.auth.service_account.token_client import ServiceAccountTokenClient
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
class MCPServiceAccountProvider(AuthProviderBase[MCPServiceAccountProviderConfig]):
    """
    MCP service account authentication provider using OAuth2 client credentials.

    Provides headless authentication for MCP clients using service account credentials.
    Supports two authentication patterns:

    1. Single authentication: OAuth2 service account token only
    2. Dual authentication: OAuth2 service account token + service-specific token

    """

    def __init__(self, config: MCPServiceAccountProviderConfig, builder=None):
        """
        Create the provider and its OAuth2 token client.

        Args:
            config: Service account provider configuration.
            builder: Accepted for interface compatibility; not used by this provider.

        Raises:
            ValueError: If ``config.service_token.function`` is set but cannot be loaded.
        """
        super().__init__(config)

        # Initialize token client
        self._token_client = ServiceAccountTokenClient(
            client_id=config.client_id,
            client_secret=config.client_secret,
            token_url=config.token_url,
            scopes=" ".join(config.scopes),  # OAuth2 expects a space-delimited scope string
            token_cache_buffer_seconds=config.token_cache_buffer_seconds,
        )

        # Load dynamic service token function if configured
        self._service_token_function: typing.Callable | None = None
        if config.service_token and config.service_token.function:
            self._service_token_function = self._load_function(config.service_token.function)

        logger.info("Initialized MCP service account auth provider: "
                    "token_url=%s, scopes=%s, has_service_token=%s",
                    config.token_url,
                    config.scopes,
                    config.service_token is not None)

    def _load_function(self, function_path: str) -> typing.Callable:
        """
        Import a callable from a dotted path such as ``'my_module.get_token'``.

        Args:
            function_path: ``module.attribute`` path of the function to load.

        Returns:
            The resolved callable.

        Raises:
            ValueError: If the path is not dotted, the module or attribute cannot be
                imported, or the resolved attribute is not callable.
        """
        try:
            module_name, _, func_name = function_path.rpartition(".")
            if not module_name or not func_name:
                raise ValueError("expected a dotted path such as 'my_module.get_token'")
            module = importlib.import_module(module_name)
            func = getattr(module, func_name)
            if not callable(func):
                raise TypeError(f"'{func_name}' in module '{module_name}' is not callable")
        except Exception as e:
            raise ValueError(f"Failed to load service token function '{function_path}': {e}") from e

        logger.info("Loaded service token function: %s", function_path)
        return func

    async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
        """
        Authenticate using OAuth2 client credentials flow.

        Note: user_id is ignored for service accounts (non-session-specific).

        Returns:
            AuthResult with an ``Authorization: Bearer`` HeaderCred and, when a service token
            is configured and yields a non-empty value, a second HeaderCred for it.

        Raises:
            RuntimeError: If the OAuth2 token cannot be obtained, or the dynamic service
                token function fails or returns a malformed tuple.
        """
        import inspect

        # Get OAuth2 access token (cached if still valid)
        access_token = await self._token_client.get_access_token()

        # Build credentials list using HeaderCred
        credentials: list[Credential] = [
            HeaderCred(name="Authorization", value=SecretStr(f"Bearer {access_token.get_secret_value()}"))
        ]

        service_token_config = self.config.service_token
        if service_token_config:
            service_header = service_token_config.header
            service_token_value = None

            if service_token_config.token:
                # Static token from config
                service_token_value = service_token_config.token.get_secret_value()

            elif self._service_token_function:
                # Dynamic token from function, called with the configured kwargs
                try:
                    result = self._service_token_function(**service_token_config.kwargs)
                    # Await based on the returned object rather than inspecting the function, so
                    # sync wrappers that return a coroutine (e.g. decorated async functions) work too.
                    if inspect.isawaitable(result):
                        result = await result

                    # Handle function return type: str or tuple[str, str]
                    if isinstance(result, tuple):
                        service_header, service_token_value = result
                    else:
                        service_token_value = result

                    logger.debug("Retrieved service token via dynamic function")

                except Exception as e:
                    raise RuntimeError(f"Failed to get service token from function: {e}") from e

            # An empty value adds no header.
            if service_token_value:
                credentials.append(HeaderCred(name=service_header, value=SecretStr(service_token_value)))

        # Return AuthResult with HeaderCred objects
        return AuthResult(
            credentials=credentials,
            token_expires_at=self._token_client.token_expires_at,
            raw={},
        )
@@ -0,0 +1,137 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+
18
+ from pydantic import BaseModel
19
+ from pydantic import Field
20
+ from pydantic import field_validator
21
+ from pydantic import model_validator
22
+
23
+ from nat.authentication.interfaces import AuthProviderBaseConfig
24
+ from nat.data_models.common import OptionalSecretStr
25
+ from nat.data_models.common import SerializableSecretStr
26
+
27
+
28
class ServiceTokenConfig(BaseModel):
    """
    Configuration for service-specific token in dual authentication patterns.

    Supports two modes:

    1. Static token: Provide token and header directly
    2. Dynamic function: Provide function path and optional kwargs

    The function will be called on every request and should have signature::

        async def get_service_token(**kwargs) -> str | tuple[str, str]

    A plain synchronous function is accepted as well.

    If function returns ``tuple[str, str]``, it's interpreted as (header_name, token).
    If function returns ``str``, it's the token and header field is used for header name.

    The function can access runtime context via AIQContext.get() if needed.

    Blank values count as "not provided". Examples are an empty ``token`` (say, from an unset
    ``${ENV_VAR}``) or a whitespace-only ``function``. Without this, such values would pass
    validation and then silently produce no service header at request time.
    """

    # Static token approach
    token: OptionalSecretStr = Field(
        default=None,
        description="Static service token value (mutually exclusive with function)",
    )

    header: str = Field(
        default="X-Service-Account-Token",
        description="HTTP header name for service token (default: 'X-Service-Account-Token')",
    )

    # Dynamic function approach
    function: str | None = Field(
        default=None,
        description=("Python function path that returns service token dynamically (mutually exclusive with token). "
                     "Function signature: async def func(\\**kwargs) -> str | tuple[str, str]. "
                     "Access runtime context via AIQContext.get() if needed."),
    )

    kwargs: dict[str, typing.Any] = Field(
        default_factory=dict,
        description="Additional keyword arguments to pass to the custom function",
    )

    @field_validator("token", mode="after")
    @classmethod
    def normalize_token(cls, v):
        """Treat an empty static token as not provided."""
        if v is not None and not v.get_secret_value():
            return None
        return v

    @field_validator("header")
    @classmethod
    def validate_header(cls, v: str) -> str:
        """Strip surrounding whitespace and reject an empty header name."""
        stripped = v.strip()
        if not stripped:
            raise ValueError("'header' must be a non-empty HTTP header name")
        return stripped

    @field_validator("function", mode="before")
    @classmethod
    def normalize_function(cls, v):
        """Treat a blank function path as not provided and require a 'module.function' form."""
        if not isinstance(v, str):
            # None or a wrong type: leave it to the field's own validation.
            return v
        path = v.strip()
        if not path:
            return None
        module_name, _, func_name = path.rpartition(".")
        if not module_name or not func_name:
            raise ValueError(f"'function' must be a dotted path such as 'my_module.get_token', got {path!r}")
        return path

    @model_validator(mode="after")
    def validate_token_or_function(self):
        """Ensure either token or function is provided, but not both."""
        has_token = self.token is not None
        has_function = self.function is not None

        if not has_token and not has_function:
            raise ValueError("Either 'token' or 'function' must be provided in service_token config")

        if has_token and has_function:
            raise ValueError("Cannot specify both 'token' and 'function' in service_token config. Choose one.")

        return self
84
+
85
+
86
class MCPServiceAccountProviderConfig(AuthProviderBaseConfig, name="mcp_service_account"):
    """
    Configuration for MCP service account authentication using OAuth2 client credentials.

    Generic implementation supporting any OAuth2 client credentials flow.

    Supports two authentication patterns:
    1. Single authentication: OAuth2 service account token only
    2. Dual authentication: OAuth2 service account token + service-specific token

    Common use cases:
    - Headless/automated MCP workflows
    - CI/CD pipelines
    - Backend services without user interaction

    All values must be provided via configuration. Use ${ENV_VAR} syntax in YAML
    configs for environment variable substitution.
    """

    # Required: OAuth2 client credentials
    client_id: str = Field(description="OAuth2 client identifier")

    client_secret: SerializableSecretStr = Field(description="OAuth2 client secret")

    # Required: Token endpoint URL
    token_url: str = Field(description="OAuth2 token endpoint URL")

    # Required: OAuth2 scopes
    scopes: list[str] = Field(description="List of OAuth2 scopes (will be joined with spaces for OAuth2 request)")

    # Optional: Service-specific token configuration for dual authentication patterns
    service_token: ServiceTokenConfig | None = Field(
        default=None,
        description="Optional service token configuration for dual authentication patterns. "
        "Provide either a static token or a dynamic function that returns the token at runtime.",
    )

    # Token caching configuration. A negative buffer would keep using tokens past their expiry.
    token_cache_buffer_seconds: int = Field(default=300,
                                            ge=0,
                                            description="Seconds before token expiry to refresh (default: 300s/5min)")

    @field_validator("scopes", mode="before")
    @classmethod
    def validate_scopes(cls, v):
        """
        Normalize scopes to a list of individual scope tokens.

        Accepts either a space-delimited string or a list of strings. Whitespace is
        trimmed, blank entries are dropped, and list entries that contain whitespace
        are split. This way ``" ".join(scopes)`` yields a well-formed OAuth2 scope
        string with single spaces between tokens.
        """
        if isinstance(v, str):
            # str.split() without arguments already drops surrounding/duplicate whitespace.
            return v.split()
        if isinstance(v, list | tuple):
            normalized = []
            for item in v:
                if isinstance(item, str):
                    normalized.extend(item.split())
                else:
                    # Leave non-string entries for pydantic's type validation to report.
                    normalized.append(item)
            return normalized
        return v
@@ -0,0 +1,156 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import base64
18
+ import logging
19
+ from datetime import datetime
20
+ from datetime import timedelta
21
+
22
+ import httpx
23
+ from pydantic import SecretStr
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class ServiceAccountTokenClient:
    """
    Generic OAuth2 client credentials token client for service accounts.

    Implements the OAuth2 client credentials grant, authenticating the client with HTTP
    Basic auth. The returned access token is cached in memory until shortly before it
    expires.
    """

    # Token lifetime (seconds) assumed when the token response has no usable ``expires_in``.
    _DEFAULT_EXPIRES_IN_SECONDS = 3600

    def __init__(
        self,
        client_id: str,
        client_secret: SecretStr,
        token_url: str,
        scopes: str,
        token_cache_buffer_seconds: int = 300,
    ):
        """
        Initialize service account token client.

        Args:
            client_id: OAuth2 client identifier
            client_secret: OAuth2 client secret (SecretStr)
            token_url: OAuth2 token endpoint URL
            scopes: Space-separated list of scopes; an empty string omits the ``scope`` parameter
            token_cache_buffer_seconds: Seconds before expiry to refresh (default: 5 min). Capped at
                half of the token lifetime so that short-lived tokens are still cached.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scopes = scopes
        self.token_cache_buffer_seconds = token_cache_buffer_seconds

        # Token cache
        self._cached_token: SecretStr | None = None
        # NOTE(review): expiry is naive local time (datetime.now()); confirm that consumers of
        # AuthResult.token_expires_at also compare against naive datetimes.
        self._token_expires_at: datetime | None = None
        # Moment after which the cached token is refreshed; set on every successful fetch.
        self._refresh_at: datetime | None = None
        self._lock: asyncio.Lock | None = None  # Created lazily by _get_lock()

    @property
    def token_expires_at(self) -> datetime | None:
        """Expiry time of the cached token, or None if no token has been fetched yet."""
        return self._token_expires_at

    async def _get_lock(self) -> asyncio.Lock:
        """Lazy initialization of asyncio.Lock."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def _is_token_valid(self) -> bool:
        """Check if a cached token exists and its refresh time has not been reached yet."""
        if not self._cached_token or not self._token_expires_at:
            return False
        refresh_at = self._refresh_at
        if refresh_at is None:
            # Cache populated without _fetch_new_token(): fall back to the plain buffer.
            refresh_at = self._token_expires_at - timedelta(seconds=self.token_cache_buffer_seconds)
        return datetime.now() < refresh_at

    async def get_access_token(self) -> SecretStr:
        """
        Get OAuth2 access token, using cache if valid.

        Returns:
            Access token as SecretStr

        Raises:
            RuntimeError: If token acquisition fails
        """
        # Fast path: check cache without lock
        if self._is_token_valid():
            logger.debug("Using cached service account token")
            assert self._cached_token is not None  # _is_token_valid() ensures this
            return self._cached_token

        # Slow path: acquire lock and refresh token
        lock = await self._get_lock()
        async with lock:
            # Double-check after acquiring lock: another task may have refreshed meanwhile
            if self._is_token_valid():
                logger.debug("Using cached service account token (acquired during lock wait)")
                assert self._cached_token is not None  # _is_token_valid() ensures this
                return self._cached_token

            logger.info("Fetching new service account token")
            return await self._fetch_new_token()

    @classmethod
    def _parse_expires_in(cls, raw_value) -> int | float:
        """
        Convert the ``expires_in`` field of a token response to seconds.

        Some servers send the value as a string. A missing, non-numeric, negative, or
        non-finite value falls back to ``_DEFAULT_EXPIRES_IN_SECONDS``.
        """
        if raw_value is None:
            return cls._DEFAULT_EXPIRES_IN_SECONDS
        try:
            seconds = float(raw_value)
        except (TypeError, ValueError):
            seconds = float("nan")
        # The chained comparison is False for NaN, infinity and negative values.
        if not 0 <= seconds < float("inf"):
            logger.warning("Ignoring invalid 'expires_in' in token response: %r", raw_value)
            return cls._DEFAULT_EXPIRES_IN_SECONDS
        return int(seconds) if seconds.is_integer() else seconds

    async def _fetch_new_token(self) -> SecretStr:
        """
        Fetch a new token from the OAuth2 token endpoint and cache it.

        Returns:
            New access token as SecretStr

        Raises:
            RuntimeError: If the request fails, times out, returns a non-200 status, or the
                response body is not a JSON object containing ``access_token``
        """
        # Encode credentials for Basic authentication
        credentials = f"{self.client_id}:{self.client_secret.get_secret_value()}"
        encoded_credentials = base64.b64encode(credentials.encode()).decode()

        headers = {"Authorization": f"Basic {encoded_credentials}", "Content-Type": "application/x-www-form-urlencoded"}

        data = {"grant_type": "client_credentials"}
        # "scope" is optional in the client credentials grant; omit it rather than send an empty value.
        if self.scopes:
            data["scope"] = self.scopes

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(self.token_url, headers=headers, data=data)
        except httpx.TimeoutException as e:
            raise RuntimeError(f"Service account token request timed out: {e}") from e
        except httpx.RequestError as e:
            raise RuntimeError(f"Service account token request failed: {e}") from e

        if response.status_code == 401:
            raise RuntimeError("Invalid service account credentials")
        if response.status_code == 429:
            raise RuntimeError("Service account rate limit exceeded")
        if response.status_code != 200:
            raise RuntimeError(f"Service account token request failed: {response.status_code} - {response.text}")

        try:
            token_data = response.json()
        except ValueError as e:
            raise RuntimeError(f"Service account token response is not valid JSON: {e}") from e
        if not isinstance(token_data, dict):
            raise RuntimeError("Service account token response is not a JSON object")

        access_token = token_data.get("access_token")
        if not isinstance(access_token, str) or not access_token:
            raise RuntimeError("Access token not found in token response")

        expires_in = self._parse_expires_in(token_data.get("expires_in"))
        now = datetime.now()

        # Cache the token
        self._cached_token = SecretStr(access_token)
        self._token_expires_at = now + timedelta(seconds=expires_in)
        # Refresh `token_cache_buffer_seconds` before expiry, but never earlier than halfway through
        # the lifetime. Otherwise a token whose lifetime is <= the buffer (e.g. 300s tokens with the
        # default 300s buffer) would be stale immediately and refetched on every call.
        refresh_after = max(expires_in - self.token_cache_buffer_seconds, expires_in / 2)
        self._refresh_at = now + timedelta(seconds=refresh_after)

        logger.info("Service account token acquired (expires in %ss)", expires_in)
        return self._cached_token
@@ -0,0 +1,265 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import hashlib
17
+ import json
18
+ import logging
19
+ from abc import ABC
20
+ from abc import abstractmethod
21
+
22
+ from nat.data_models.authentication import AuthResult
23
+ from nat.data_models.authentication import BasicAuthCred
24
+ from nat.data_models.authentication import BearerTokenCred
25
+ from nat.data_models.authentication import CookieCred
26
+ from nat.data_models.authentication import HeaderCred
27
+ from nat.data_models.authentication import QueryCred
28
+ from nat.data_models.object_store import NoSuchKeyError
29
+ from nat.object_store.interfaces import ObjectStore
30
+ from nat.object_store.models import ObjectStoreItem
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class TokenStorageBase(ABC):
    """
    Interface for persisting per-user authentication results.

    Concrete storages decide where the tokens produced by MCP OAuth2 flows are kept.
    Examples are an object store, a database, or process memory.
    """

    @abstractmethod
    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Persist ``auth_result`` under ``user_id``.

        Args:
            user_id: Identifier of the user the result belongs to
            auth_result: Authentication result to save
        """
        ...

    @abstractmethod
    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Look up the authentication result saved for ``user_id``.

        Args:
            user_id: Identifier of the user whose result is requested

        Returns:
            The saved authentication result, or None when nothing is stored
        """
        ...

    @abstractmethod
    async def delete(self, user_id: str) -> None:
        """
        Remove the authentication result saved for ``user_id``.

        Args:
            user_id: Identifier of the user whose result should be removed
        """
        ...

    @abstractmethod
    async def clear_all(self) -> None:
        """
        Remove every saved authentication result.
        """
        ...
84
+
85
+
86
class ObjectStoreTokenStorage(TokenStorageBase):
    """
    Token storage implementation backed by a NeMo Agent toolkit object store.

    With a durable backend such as S3, MySQL, or Redis, tokens survive process
    restarts. Encryption at rest and access control depend on how that backend is
    configured. Note that secret credential values are written in plain text inside
    the stored JSON payload.
    """

    def __init__(self, object_store: ObjectStore):
        """
        Initialize the object store token storage.

        Args:
            object_store: The object store instance to use for token persistence
        """
        self._object_store = object_store

    def _get_key(self, user_id: str) -> str:
        """
        Generate the object store key for a user's token.

        Uses SHA256 hash to ensure the key is S3-compatible and doesn't
        contain special characters like "://" that are invalid in object keys.

        Args:
            user_id: The user identifier

        Returns:
            The object store key
        """
        # Hash the user_id to create an S3-safe key
        user_hash = hashlib.sha256(user_id.encode('utf-8')).hexdigest()
        return f"tokens/{user_hash}"

    @staticmethod
    def _serialize_auth_result(auth_result: AuthResult) -> bytes:
        """Serialize ``auth_result`` to UTF-8 JSON with every secret credential field revealed."""
        auth_dict = auth_result.model_dump(mode='json')
        # model_dump() masks secret values (e.g. SecretStr becomes "**********"). Put the real value
        # back for every credential field holding a secret, whatever the credential type.
        # Otherwise the stored token would be unusable after retrieve().
        for cred_obj, cred_dict in zip(auth_result.credentials, auth_dict['credentials']):
            for field_name in cred_dict:
                value = getattr(cred_obj, field_name, None)
                if hasattr(value, "get_secret_value"):
                    cred_dict[field_name] = value.get_secret_value()
        return json.dumps(auth_dict).encode('utf-8')

    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Store an authentication result in the object store.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        key = self._get_key(user_id)
        data = self._serialize_auth_result(auth_result)

        # Record the expiry (when known) as object metadata next to the payload
        metadata = {}
        if auth_result.token_expires_at:
            metadata["expires_at"] = auth_result.token_expires_at.isoformat()

        # Create the object store item
        item = ObjectStoreItem(data=data, content_type="application/json", metadata=metadata if metadata else None)

        # Store using upsert to handle both new and existing tokens
        await self._object_store.upsert_object(key, item)

    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Retrieve an authentication result from the object store.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The authentication result if found and readable, None otherwise
        """
        key = self._get_key(user_id)

        try:
            item = await self._object_store.get_object(key)
            # Deserialize the AuthResult from JSON
            return AuthResult.model_validate_json(item.data)
        except NoSuchKeyError:
            return None
        except Exception:
            # Best effort: an unreadable or corrupt entry is treated like a missing token,
            # so the caller can re-authenticate instead of failing outright.
            logger.exception("Error retrieving or deserializing token for user %s", user_id)
            return None

    async def delete(self, user_id: str) -> None:
        """
        Delete an authentication result from the object store.

        Args:
            user_id: The unique identifier for the user
        """
        key = self._get_key(user_id)

        try:
            await self._object_store.delete_object(key)
        except NoSuchKeyError:
            # Token doesn't exist, which is fine for delete operations
            pass

    async def clear_all(self) -> None:
        """
        Clear all stored authentication results.

        Note: This implementation does not support clearing all tokens as the
        object store interface doesn't provide a list operation. Individual
        tokens must be deleted explicitly.
        """
        logger.warning("clear_all() is not supported for ObjectStoreTokenStorage")
205
+
206
+
207
class InMemoryTokenStorage(TokenStorageBase):
    """
    In-memory token storage using the built-in object store provided by the NeMo Agent toolkit.

    A dedicated in-memory object store holds the tokens, so no external storage needs
    to be configured. Tokens exist only in this process's memory and are lost when
    the process exits.
    """

    def __init__(self):
        """
        Initialize the in-memory token storage.
        """
        self._reset()
        logger.debug("Initialized in-memory token storage")

    def _reset(self) -> None:
        """Create a fresh, empty dedicated object store and the storage wrapper around it."""
        from nat.object_store.in_memory_object_store import InMemoryObjectStore

        # Create a dedicated in-memory object store for tokens
        self._object_store = InMemoryObjectStore()

        # Wrap with ObjectStoreTokenStorage for the actual implementation
        self._storage = ObjectStoreTokenStorage(self._object_store)

    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Store an authentication result in memory.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        await self._storage.store(user_id, auth_result)

    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Retrieve an authentication result from memory.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The authentication result if found, None otherwise
        """
        return await self._storage.retrieve(user_id)

    async def delete(self, user_id: str) -> None:
        """
        Delete an authentication result from memory.

        Args:
            user_id: The unique identifier for the user
        """
        await self._storage.delete(user_id)

    async def clear_all(self) -> None:
        """
        Clear all stored authentication results from memory.

        The store is dedicated to this instance, so it is simply replaced with a new empty
        one. This avoids reaching into InMemoryObjectStore's private internals.
        """
        self._reset()