nvidia-nat-mcp 1.3.0a20251005__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nat/meta/pypi.md CHANGED
@@ -19,9 +19,9 @@ limitations under the License.
19
19
 
20
20
 
21
21
  # NVIDIA NeMo Agent Toolkit MCP Subpackage
22
- Subpackage for MCP client integration in NeMo Agent toolkit.
22
+ Subpackage for MCP integration in NeMo Agent toolkit.
23
23
 
24
- This package provides MCP (Model Context Protocol) client functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions.
24
+ This package provides MCP (Model Context Protocol) functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions.
25
25
 
26
26
  ## Features
27
27
 
@@ -23,6 +23,7 @@ import httpx
23
23
  from pydantic import BaseModel
24
24
  from pydantic import Field
25
25
  from pydantic import HttpUrl
26
+ from pydantic import TypeAdapter
26
27
 
27
28
  from mcp.shared.auth import OAuthClientInformationFull
28
29
  from mcp.shared.auth import OAuthClientMetadata
@@ -65,7 +66,6 @@ class DiscoverOAuth2Endpoints:
65
66
  def __init__(self, config: MCPOAuth2ProviderConfig):
66
67
  self.config = config
67
68
  self._cached_endpoints: OAuth2Endpoints | None = None
68
- self._authenticated_servers: dict[str, AuthResult] = {}
69
69
 
70
70
  self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler()
71
71
 
@@ -192,11 +192,13 @@ class DiscoverOAuth2Endpoints:
192
192
  continue
193
193
  if meta.authorization_endpoint and meta.token_endpoint:
194
194
  logger.info("Discovered OAuth2 endpoints from %s", url)
195
- # this is bit of a hack to get the scopes supported by the auth server
195
+ # Convert AnyHttpUrl to HttpUrl using TypeAdapter
196
+ http_url_adapter = TypeAdapter(HttpUrl)
196
197
  return OAuth2Endpoints(
197
- authorization_url=str(meta.authorization_endpoint),
198
- token_url=str(meta.token_endpoint),
199
- registration_url=str(meta.registration_endpoint) if meta.registration_endpoint else None,
198
+ authorization_url=http_url_adapter.validate_python(str(meta.authorization_endpoint)),
199
+ token_url=http_url_adapter.validate_python(str(meta.token_endpoint)),
200
+ registration_url=http_url_adapter.validate_python(str(meta.registration_endpoint))
201
+ if meta.registration_endpoint else None,
200
202
  scopes=meta.scopes_supported,
201
203
  )
202
204
  except Exception as e:
@@ -283,8 +285,9 @@ class DynamicClientRegistration:
283
285
  class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
284
286
  """MCP OAuth2 authentication provider that delegates to NAT framework."""
285
287
 
286
- def __init__(self, config: MCPOAuth2ProviderConfig):
288
+ def __init__(self, config: MCPOAuth2ProviderConfig, builder=None):
287
289
  super().__init__(config)
290
+ self._builder = builder
288
291
 
289
292
  # Discovery
290
293
  self._discoverer = DiscoverOAuth2Endpoints(config)
@@ -300,6 +303,19 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
300
303
 
301
304
  self._auth_callback = None
302
305
 
306
+ # Initialize token storage
307
+ self._token_storage = None
308
+ self._token_storage_object_store_name = None
309
+
310
+ if self.config.token_storage_object_store:
311
+ # Store object store name, will be resolved later when builder context is available
312
+ self._token_storage_object_store_name = self.config.token_storage_object_store
313
+ logger.info(f"Configured to use object store '{self._token_storage_object_store_name}' for token storage")
314
+ else:
315
+ # Default: use in-memory token storage
316
+ from .token_storage import InMemoryTokenStorage
317
+ self._token_storage = InMemoryTokenStorage()
318
+
303
319
  def _set_custom_auth_callback(self,
304
320
  auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
305
321
  Awaitable[AuthenticatedContext]]):
@@ -308,7 +324,7 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
308
324
  logger.info("Using custom authentication callback")
309
325
  self._auth_callback = auth_callback
310
326
  if self._auth_code_provider:
311
- self._auth_code_provider._set_custom_auth_callback(self._auth_callback)
327
+ self._auth_code_provider._set_custom_auth_callback(self._auth_callback) # type: ignore[arg-type]
312
328
 
313
329
  async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
314
330
  """
@@ -374,6 +390,22 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
374
390
  endpoints = self._cached_endpoints
375
391
  credentials = self._cached_credentials
376
392
 
393
+ # Resolve object store reference if needed
394
+ if self._token_storage_object_store_name and not self._token_storage:
395
+ try:
396
+ if not self._builder:
397
+ raise RuntimeError("Builder not available for resolving object store")
398
+ object_store = await self._builder.get_object_store_client(self._token_storage_object_store_name)
399
+ from .token_storage import ObjectStoreTokenStorage
400
+ self._token_storage = ObjectStoreTokenStorage(object_store)
401
+ logger.info(f"Initialized token storage with object store '{self._token_storage_object_store_name}'")
402
+ except Exception as e:
403
+ logger.warning(
404
+ f"Failed to resolve object store '{self._token_storage_object_store_name}' for token storage: {e}. "
405
+ "Falling back to in-memory storage.")
406
+ from .token_storage import InMemoryTokenStorage
407
+ self._token_storage = InMemoryTokenStorage()
408
+
377
409
  # Build the OAuth2 provider if not already built
378
410
  if self._auth_code_provider is None:
379
411
  scopes = self._effective_scopes
@@ -387,12 +419,12 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
387
419
  scopes=scopes,
388
420
  use_pkce=bool(self.config.use_pkce),
389
421
  authorization_kwargs={"resource": str(self.config.server_url)})
390
- self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config)
422
+ self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config, token_storage=self._token_storage)
391
423
 
392
424
  # Use MCP-specific authentication method if available
393
425
  if hasattr(self._auth_code_provider, "_set_custom_auth_callback"):
394
- self._auth_code_provider._set_custom_auth_callback(self._auth_callback
395
- or self._flow_handler.authenticate)
426
+ callback = self._auth_callback or self._flow_handler.authenticate
427
+ self._auth_code_provider._set_custom_auth_callback(callback) # type: ignore[arg-type]
396
428
 
397
429
  # Auth code provider is responsible for per-user cache + refresh
398
430
  return await self._auth_code_provider.authenticate(user_id=user_id)
@@ -53,6 +53,11 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
53
53
  default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
54
54
  allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")
55
55
 
56
+ # Token storage configuration
57
+ token_storage_object_store: str | None = Field(
58
+ default=None,
59
+ description="Reference to object store for secure token storage. If None, uses in-memory storage.")
60
+
56
61
  @model_validator(mode="after")
57
62
  def validate_auth_config(self):
58
63
  """Validate authentication configuration for MCP-specific options."""
@@ -22,4 +22,4 @@ from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
22
22
  @register_auth_provider(config_type=MCPOAuth2ProviderConfig)
23
23
  async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder):
24
24
  """Register MCP OAuth2 authentication provider with NAT system."""
25
- yield MCPOAuth2Provider(authentication_provider)
25
+ yield MCPOAuth2Provider(authentication_provider, builder=builder)
@@ -0,0 +1,265 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import hashlib
17
+ import json
18
+ import logging
19
+ from abc import ABC
20
+ from abc import abstractmethod
21
+
22
+ from nat.data_models.authentication import AuthResult
23
+ from nat.data_models.authentication import BasicAuthCred
24
+ from nat.data_models.authentication import BearerTokenCred
25
+ from nat.data_models.authentication import CookieCred
26
+ from nat.data_models.authentication import HeaderCred
27
+ from nat.data_models.authentication import QueryCred
28
+ from nat.data_models.object_store import NoSuchKeyError
29
+ from nat.object_store.interfaces import ObjectStore
30
+ from nat.object_store.models import ObjectStoreItem
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class TokenStorageBase(ABC):
36
+ """
37
+ Abstract base class for token storage implementations.
38
+
39
+ Token storage implementations handle the secure persistence of authentication
40
+ tokens for MCP OAuth2 flows. Implementations can use various backends such as
41
+ object stores, databases, or in-memory storage.
42
+ """
43
+
44
+ @abstractmethod
45
+ async def store(self, user_id: str, auth_result: AuthResult) -> None:
46
+ """
47
+ Store an authentication result for a user.
48
+
49
+ Args:
50
+ user_id: The unique identifier for the user
51
+ auth_result: The authentication result to store
52
+ """
53
+ pass
54
+
55
+ @abstractmethod
56
+ async def retrieve(self, user_id: str) -> AuthResult | None:
57
+ """
58
+ Retrieve an authentication result for a user.
59
+
60
+ Args:
61
+ user_id: The unique identifier for the user
62
+
63
+ Returns:
64
+ The authentication result if found, None otherwise
65
+ """
66
+ pass
67
+
68
+ @abstractmethod
69
+ async def delete(self, user_id: str) -> None:
70
+ """
71
+ Delete an authentication result for a user.
72
+
73
+ Args:
74
+ user_id: The unique identifier for the user
75
+ """
76
+ pass
77
+
78
+ @abstractmethod
79
+ async def clear_all(self) -> None:
80
+ """
81
+ Clear all stored authentication results.
82
+ """
83
+ pass
84
+
85
+
86
+ class ObjectStoreTokenStorage(TokenStorageBase):
87
+ """
88
+ Token storage implementation backed by a NeMo Agent toolkit object store.
89
+
90
+ This implementation uses the object store infrastructure to persist tokens,
91
+ which provides encryption at rest, access controls, and persistence across
92
+ restarts when using backends like S3, MySQL, or Redis.
93
+ """
94
+
95
+ def __init__(self, object_store: ObjectStore):
96
+ """
97
+ Initialize the object store token storage.
98
+
99
+ Args:
100
+ object_store: The object store instance to use for token persistence
101
+ """
102
+ self._object_store = object_store
103
+
104
+ def _get_key(self, user_id: str) -> str:
105
+ """
106
+ Generate the object store key for a user's token.
107
+
108
+ Uses SHA256 hash to ensure the key is S3-compatible and doesn't
109
+ contain special characters like "://" that are invalid in object keys.
110
+
111
+ Args:
112
+ user_id: The user identifier
113
+
114
+ Returns:
115
+ The object store key
116
+ """
117
+ # Hash the user_id to create an S3-safe key
118
+ user_hash = hashlib.sha256(user_id.encode('utf-8')).hexdigest()
119
+ return f"tokens/{user_hash}"
120
+
121
+ async def store(self, user_id: str, auth_result: AuthResult) -> None:
122
+ """
123
+ Store an authentication result in the object store.
124
+
125
+ Args:
126
+ user_id: The unique identifier for the user
127
+ auth_result: The authentication result to store
128
+ """
129
+ key = self._get_key(user_id)
130
+
131
+ # Serialize the AuthResult to JSON with secrets exposed
132
+ # SecretStr values are masked by default, so we need to expose them manually
133
+ # Create a serializable dict with exposed secrets
134
+ auth_dict = auth_result.model_dump(mode='json')
135
+ # Manually expose SecretStr values in credentials
136
+ for i, cred_obj in enumerate(auth_result.credentials):
137
+ if isinstance(cred_obj, BearerTokenCred):
138
+ auth_dict['credentials'][i]['token'] = cred_obj.token.get_secret_value()
139
+ elif isinstance(cred_obj, BasicAuthCred):
140
+ auth_dict['credentials'][i]['username'] = cred_obj.username.get_secret_value()
141
+ auth_dict['credentials'][i]['password'] = cred_obj.password.get_secret_value()
142
+ elif isinstance(cred_obj, HeaderCred | QueryCred | CookieCred):
143
+ auth_dict['credentials'][i]['value'] = cred_obj.value.get_secret_value()
144
+
145
+ data = json.dumps(auth_dict).encode('utf-8')
146
+
147
+ # Prepare metadata
148
+ metadata = {}
149
+ if auth_result.token_expires_at:
150
+ metadata["expires_at"] = auth_result.token_expires_at.isoformat()
151
+
152
+ # Create the object store item
153
+ item = ObjectStoreItem(data=data, content_type="application/json", metadata=metadata if metadata else None)
154
+
155
+ # Store using upsert to handle both new and existing tokens
156
+ await self._object_store.upsert_object(key, item)
157
+
158
+ async def retrieve(self, user_id: str) -> AuthResult | None:
159
+ """
160
+ Retrieve an authentication result from the object store.
161
+
162
+ Args:
163
+ user_id: The unique identifier for the user
164
+
165
+ Returns:
166
+ The authentication result if found, None otherwise
167
+ """
168
+ key = self._get_key(user_id)
169
+
170
+ try:
171
+ item = await self._object_store.get_object(key)
172
+ # Deserialize the AuthResult from JSON
173
+ auth_result = AuthResult.model_validate_json(item.data)
174
+ return auth_result
175
+ except NoSuchKeyError:
176
+ return None
177
+ except Exception as e:
178
+ logger.error(f"Error deserializing token for user {user_id}: {e}", exc_info=True)
179
+ return None
180
+
181
+ async def delete(self, user_id: str) -> None:
182
+ """
183
+ Delete an authentication result from the object store.
184
+
185
+ Args:
186
+ user_id: The unique identifier for the user
187
+ """
188
+ key = self._get_key(user_id)
189
+
190
+ try:
191
+ await self._object_store.delete_object(key)
192
+ except NoSuchKeyError:
193
+ # Token doesn't exist, which is fine for delete operations
194
+ pass
195
+
196
+ async def clear_all(self) -> None:
197
+ """
198
+ Clear all stored authentication results.
199
+
200
+ Note: This implementation does not support clearing all tokens as the
201
+ object store interface doesn't provide a list operation. Individual
202
+ tokens must be deleted explicitly.
203
+ """
204
+ logger.warning("clear_all() is not supported for ObjectStoreTokenStorage")
205
+
206
+
207
+ class InMemoryTokenStorage(TokenStorageBase):
208
+ """
209
+ In-memory token storage using NeMo Agent toolkit's built-in object store.
210
+
211
+ This implementation uses the in-memory object store for token persistence,
212
+ which provides a secure default option that doesn't require external storage
213
+ configuration. Tokens are stored in memory and cleared when the process exits.
214
+ """
215
+
216
+ def __init__(self):
217
+ """
218
+ Initialize the in-memory token storage.
219
+ """
220
+ from nat.object_store.in_memory_object_store import InMemoryObjectStore
221
+
222
+ # Create a dedicated in-memory object store for tokens
223
+ self._object_store = InMemoryObjectStore()
224
+
225
+ # Wrap with ObjectStoreTokenStorage for the actual implementation
226
+ self._storage = ObjectStoreTokenStorage(self._object_store)
227
+ logger.debug("Initialized in-memory token storage")
228
+
229
+ async def store(self, user_id: str, auth_result: AuthResult) -> None:
230
+ """
231
+ Store an authentication result in memory.
232
+
233
+ Args:
234
+ user_id: The unique identifier for the user
235
+ auth_result: The authentication result to store
236
+ """
237
+ await self._storage.store(user_id, auth_result)
238
+
239
+ async def retrieve(self, user_id: str) -> AuthResult | None:
240
+ """
241
+ Retrieve an authentication result from memory.
242
+
243
+ Args:
244
+ user_id: The unique identifier for the user
245
+
246
+ Returns:
247
+ The authentication result if found, None otherwise
248
+ """
249
+ return await self._storage.retrieve(user_id)
250
+
251
+ async def delete(self, user_id: str) -> None:
252
+ """
253
+ Delete an authentication result from memory.
254
+
255
+ Args:
256
+ user_id: The unique identifier for the user
257
+ """
258
+ await self._storage.delete(user_id)
259
+
260
+ async def clear_all(self) -> None:
261
+ """
262
+ Clear all stored authentication results from memory.
263
+ """
264
+ # For in-memory storage, we can access the internal storage
265
+ self._object_store._store.clear()
@@ -16,7 +16,6 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import asyncio
19
- import json
20
19
  import logging
21
20
  from abc import ABC
22
21
  from abc import abstractmethod
@@ -55,8 +54,9 @@ class AuthAdapter(httpx.Auth):
55
54
  Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
56
55
  """
57
56
 
58
- def __init__(self, auth_provider: AuthProviderBase):
57
+ def __init__(self, auth_provider: AuthProviderBase, user_id: str | None = None):
59
58
  self.auth_provider = auth_provider
59
+ self.user_id = user_id # Session-specific user ID for cache isolation
60
60
  # each adapter instance has its own lock to avoid unnecessary delays for multiple clients
61
61
  self._lock = anyio.Lock()
62
62
  # Track whether we're currently in an interactive authentication flow
@@ -104,41 +104,13 @@ class AuthAdapter(httpx.Auth):
104
104
  logger.debug("Authentication flow completed")
105
105
  return
106
106
 
107
- def _get_session_id_from_tool_call_request(self, request: httpx.Request) -> tuple[str | None, bool]:
108
- """Check if this is a tool call request based on the request body.
109
- Return the session id if it exists and a boolean indicating if it is a tool call request
110
- """
111
- try:
112
- # Check if the request body contains a tool call
113
- if request.content:
114
- body = json.loads(request.content.decode('utf-8'))
115
- # Check if it's a JSON-RPC request with method "tools/call"
116
- if (isinstance(body, dict) and body.get("method") == "tools/call"):
117
- session_id = body.get("params").get("_meta").get("session_id")
118
- return session_id, True
119
- except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
120
- # If we can't parse the body, assume it's not a tool call
121
- pass
122
- return None, False
123
-
124
107
  async def _get_auth_headers(self,
125
108
  request: httpx.Request | None = None,
126
109
  response: httpx.Response | None = None) -> dict[str, str]:
127
110
  """Get authentication headers from the NAT auth provider."""
128
111
  try:
129
- session_id = None
130
- is_tool_call = False
131
- if request:
132
- session_id, is_tool_call = self._get_session_id_from_tool_call_request(request)
133
-
134
- if is_tool_call:
135
- # Tool call requests should use the session id
136
- user_id = session_id
137
- else:
138
- # Non-tool call requests should use the session id if it exists and fallback to default user id
139
- user_id = session_id or self.auth_provider.config.default_user_id
140
-
141
- auth_result = await self.auth_provider.authenticate(user_id=user_id, response=response)
112
+ # Use the user_id passed to this AuthAdapter instance
113
+ auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
142
114
 
143
115
  # Check if we have BearerTokenCred
144
116
  from nat.data_models.authentication import BearerTokenCred
@@ -171,6 +143,7 @@ class MCPBaseClient(ABC):
171
143
  def __init__(self,
172
144
  transport: str = 'streamable-http',
173
145
  auth_provider: AuthProviderBase | None = None,
146
+ user_id: str | None = None,
174
147
  tool_call_timeout: timedelta = timedelta(seconds=60),
175
148
  auth_flow_timeout: timedelta = timedelta(seconds=300),
176
149
  reconnect_enabled: bool = True,
@@ -189,7 +162,9 @@ class MCPBaseClient(ABC):
189
162
 
190
163
  # Convert auth provider to AuthAdapter
191
164
  self._auth_provider = auth_provider
192
- self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
165
+ # Use provided user_id or fall back to auth provider's default_user_id
166
+ effective_user_id = user_id or (auth_provider.config.default_user_id if auth_provider else None)
167
+ self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None
193
168
 
194
169
  self._tool_call_timeout = tool_call_timeout
195
170
  self._auth_flow_timeout = auth_flow_timeout
@@ -421,24 +396,6 @@ class MCPBaseClient(ABC):
421
396
  if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"):
422
397
  self._auth_provider._set_custom_auth_callback(auth_callback)
423
398
 
424
- @mcp_exception_handler
425
- async def call_tool_with_meta(self, tool_name: str, args: dict, session_id: str):
426
- from mcp.types import CallToolRequest
427
- from mcp.types import CallToolRequestParams
428
- from mcp.types import CallToolResult
429
- from mcp.types import ClientRequest
430
-
431
- if not self._session:
432
- raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
433
-
434
- async def _call_tool_with_meta():
435
- params = CallToolRequestParams(name=tool_name, arguments=args, **{"_meta": {"session_id": session_id}})
436
- req = ClientRequest(CallToolRequest(params=params))
437
- timeout = await self._get_tool_call_timeout()
438
- return await self._session.send_request(req, CallToolResult, request_read_timeout_seconds=timeout)
439
-
440
- return await self._with_reconnect(_call_tool_with_meta)
441
-
442
399
  @mcp_exception_handler
443
400
  async def call_tool(self, tool_name: str, tool_args: dict | None):
444
401
 
@@ -570,6 +527,7 @@ class MCPStreamableHTTPClient(MCPBaseClient):
570
527
  def __init__(self,
571
528
  url: str,
572
529
  auth_provider: AuthProviderBase | None = None,
530
+ user_id: str | None = None,
573
531
  tool_call_timeout: timedelta = timedelta(seconds=60),
574
532
  auth_flow_timeout: timedelta = timedelta(seconds=300),
575
533
  reconnect_enabled: bool = True,
@@ -578,6 +536,7 @@ class MCPStreamableHTTPClient(MCPBaseClient):
578
536
  reconnect_max_backoff: float = 50.0):
579
537
  super().__init__("streamable-http",
580
538
  auth_provider=auth_provider,
539
+ user_id=user_id,
581
540
  tool_call_timeout=tool_call_timeout,
582
541
  auth_flow_timeout=auth_flow_timeout,
583
542
  reconnect_enabled=reconnect_enabled,
@@ -662,35 +621,10 @@ class MCPToolClient:
662
621
  """
663
622
  self._tool_description = description
664
623
 
665
- def _get_session_id(self) -> str | None:
666
- """
667
- Get the session id from the context.
668
- """
669
- from nat.builder.context import Context as _Ctx
670
-
671
- # get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client
672
- # on first tool call
673
- auth_callback = _Ctx.get().user_auth_callback
674
- if auth_callback and self._parent_client:
675
- # set custom auth callback
676
- self._parent_client.set_user_auth_callback(auth_callback)
677
-
678
- # get session id from context, authentication is done per-websocket session for tool calls
679
- session_id = None
680
- cookies = getattr(_Ctx.get().metadata, "cookies", None)
681
- if cookies:
682
- session_id = cookies.get("nat-session")
683
-
684
- if not session_id:
685
- # use default user id if allowed
686
- if self._parent_client.auth_provider and \
687
- self._parent_client.auth_provider.config.allow_default_user_id_for_tool_calls:
688
- session_id = self._parent_client.auth_provider.config.default_user_id
689
- return session_id
690
-
691
624
  async def acall(self, tool_args: dict) -> str:
692
625
  """
693
626
  Call the MCP tool with the provided arguments.
627
+ Session context is now handled at the client level, eliminating the need for metadata injection.
694
628
 
695
629
  Args:
696
630
  tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
@@ -698,25 +632,10 @@ class MCPToolClient:
698
632
  if self._session is None:
699
633
  raise RuntimeError("No session available for tool call")
700
634
 
701
- # Extract context information
702
635
  try:
703
- session_id = self._get_session_id()
704
- except Exception:
705
- session_id = None
706
-
707
- try:
708
- # if auth is enabled and session id is not available return user is not authorized to call the tool
709
- if self._parent_client.auth_provider and not session_id:
710
- result_str = "User is not authorized to call the tool"
711
- mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
712
- raise mcp_error
713
-
714
- if session_id:
715
- logger.info("Calling tool %s with arguments %s for a user session", self._tool_name, tool_args)
716
- result = await self._parent_client.call_tool_with_meta(self._tool_name, tool_args, session_id)
717
- else:
718
- logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
719
- result = await self._parent_client.call_tool(self._tool_name, tool_args)
636
+ # Simple tool call - session context is already in the client instance
637
+ logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
638
+ result = await self._parent_client.call_tool(self._tool_name, tool_args)
720
639
 
721
640
  output = []
722
641
  for res in result.content: