nvidia-nat-mcp 1.3.0a20250926__py3-none-any.whl → 1.3.0a20251111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,265 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import hashlib
17
+ import json
18
+ import logging
19
+ from abc import ABC
20
+ from abc import abstractmethod
21
+
22
+ from nat.data_models.authentication import AuthResult
23
+ from nat.data_models.authentication import BasicAuthCred
24
+ from nat.data_models.authentication import BearerTokenCred
25
+ from nat.data_models.authentication import CookieCred
26
+ from nat.data_models.authentication import HeaderCred
27
+ from nat.data_models.authentication import QueryCred
28
+ from nat.data_models.object_store import NoSuchKeyError
29
+ from nat.object_store.interfaces import ObjectStore
30
+ from nat.object_store.models import ObjectStoreItem
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class TokenStorageBase(ABC):
    """
    Interface for persisting per-user authentication results in MCP OAuth2 flows.

    Concrete subclasses choose where tokens live (an object store, a database,
    process memory, ...). Every operation is a coroutine keyed by a user id.
    """

    @abstractmethod
    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Persist ``auth_result`` so it can later be looked up by ``user_id``.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """

    @abstractmethod
    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Look up the authentication result previously stored for ``user_id``.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The stored authentication result, or None when nothing is stored
        """

    @abstractmethod
    async def delete(self, user_id: str) -> None:
        """
        Remove whatever authentication result is stored for ``user_id``.

        Args:
            user_id: The unique identifier for the user
        """

    @abstractmethod
    async def clear_all(self) -> None:
        """
        Remove every stored authentication result.
        """
84
+
85
+
86
class ObjectStoreTokenStorage(TokenStorageBase):
    """
    Token storage implementation backed by a NeMo Agent toolkit object store.

    Each user's ``AuthResult`` is written as a JSON object under the key
    ``tokens/<sha256(user_id)>``. Secret credential values are revealed before
    serialization so they can be restored on retrieval, which means the stored
    bytes contain those secrets in plaintext. Encryption at rest, access control,
    and persistence across restarts therefore depend entirely on the configured
    object store backend (for example S3, MySQL, or Redis), not on this class.
    """

    def __init__(self, object_store: ObjectStore):
        """
        Initialize the object store token storage.

        Args:
            object_store: The object store instance to use for token persistence
        """
        self._object_store = object_store

    def _get_key(self, user_id: str) -> str:
        """
        Build the object store key for a user's token.

        The user id is hashed with SHA-256 so the key contains only hex digits.
        This keeps it valid for backends such as S3 even when the id contains
        characters like "://" that are not allowed in object keys.

        Args:
            user_id: The user identifier

        Returns:
            The object store key: "tokens/" followed by the hex digest
        """
        user_hash = hashlib.sha256(user_id.encode('utf-8')).hexdigest()
        return f"tokens/{user_hash}"

    @staticmethod
    def _serialize(auth_result: AuthResult) -> bytes:
        """
        Encode ``auth_result`` as UTF-8 JSON with secret credential values revealed.

        ``model_dump(mode='json')`` masks ``SecretStr`` values, so the real values of
        the known credential types are copied back into the dumped dict before encoding.
        """
        auth_dict = auth_result.model_dump(mode='json')
        # The dumped list mirrors auth_result.credentials element-for-element; strict=True
        # turns any unexpected length mismatch into an error instead of silently skipping.
        for dumped, cred_obj in zip(auth_dict['credentials'], auth_result.credentials, strict=True):
            if isinstance(cred_obj, BearerTokenCred):
                dumped['token'] = cred_obj.token.get_secret_value()
            elif isinstance(cred_obj, BasicAuthCred):
                dumped['username'] = cred_obj.username.get_secret_value()
                dumped['password'] = cred_obj.password.get_secret_value()
            elif isinstance(cred_obj, HeaderCred | QueryCred | CookieCred):
                dumped['value'] = cred_obj.value.get_secret_value()
        return json.dumps(auth_dict).encode('utf-8')

    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Store an authentication result in the object store, replacing any existing one.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        key = self._get_key(user_id)
        data = self._serialize(auth_result)

        # Record the expiry as object metadata when the result carries one.
        metadata = {}
        if auth_result.token_expires_at:
            metadata["expires_at"] = auth_result.token_expires_at.isoformat()

        item = ObjectStoreItem(data=data, content_type="application/json", metadata=metadata or None)

        # Upsert handles both first-time storage and token refreshes.
        await self._object_store.upsert_object(key, item)

    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Retrieve an authentication result from the object store.

        A missing key returns None silently. Any other failure, whether while
        fetching the object or while parsing its contents, is logged with a
        traceback and also returns None.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The authentication result if found and readable, None otherwise
        """
        key = self._get_key(user_id)

        try:
            item = await self._object_store.get_object(key)
        except NoSuchKeyError:
            return None
        except Exception:
            # Backend/transport failure: distinct from a corrupt payload, so log it as such.
            logger.exception("Error retrieving token for user %s", user_id)
            return None

        try:
            return AuthResult.model_validate_json(item.data)
        except Exception:
            logger.exception("Error deserializing token for user %s", user_id)
            return None

    async def delete(self, user_id: str) -> None:
        """
        Delete an authentication result from the object store.

        Deleting a token that does not exist is not an error.

        Args:
            user_id: The unique identifier for the user
        """
        key = self._get_key(user_id)

        try:
            await self._object_store.delete_object(key)
        except NoSuchKeyError:
            # Token doesn't exist, which is fine for delete operations
            pass

    async def clear_all(self) -> None:
        """
        Clear all stored authentication results (not supported; only logs a warning).

        The object store interface offers no list operation, so the keys to remove
        cannot be enumerated. Individual tokens must be deleted with ``delete``.
        """
        logger.warning("clear_all() is not supported for ObjectStoreTokenStorage")
205
+
206
+
207
class InMemoryTokenStorage(TokenStorageBase):
    """
    Process-local token storage built on NeMo Agent toolkit's in-memory object store.

    Each instance owns a private ``InMemoryObjectStore`` and delegates every
    operation to an ``ObjectStoreTokenStorage`` wrapped around it. Serialization
    and key hashing are therefore shared with the object-store implementation.
    No external storage configuration is needed. Tokens exist only in memory and
    are gone once the process exits.
    """

    def __init__(self):
        """
        Create the dedicated in-memory object store and the storage wrapper that uses it.
        """
        from nat.object_store.in_memory_object_store import InMemoryObjectStore

        backing_store = InMemoryObjectStore()
        self._object_store = backing_store
        self._storage = ObjectStoreTokenStorage(backing_store)
        logger.debug("Initialized in-memory token storage")

    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Save ``auth_result`` for ``user_id`` in process memory.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        await self._storage.store(user_id, auth_result)

    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Fetch the authentication result held in memory for ``user_id``.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The stored authentication result, or None when nothing is stored
        """
        result = await self._storage.retrieve(user_id)
        return result

    async def delete(self, user_id: str) -> None:
        """
        Drop the in-memory authentication result for ``user_id``.

        Args:
            user_id: The unique identifier for the user
        """
        await self._storage.delete(user_id)

    async def clear_all(self) -> None:
        """
        Remove every token held by this instance.

        This works, unlike the object-store variant, because the backing store is
        owned by this instance and its contents can be emptied directly.

        NOTE(review): relies on the private ``_store`` attribute of
        ``InMemoryObjectStore`` exposing ``clear()`` (presumably a dict). Confirm it
        stays in sync with that class, including any locking it may use.
        """
        internal = self._object_store._store
        internal.clear()
@@ -16,15 +16,16 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import asyncio
19
- import json
20
19
  import logging
21
20
  from abc import ABC
22
21
  from abc import abstractmethod
23
22
  from collections.abc import AsyncGenerator
23
+ from collections.abc import Callable
24
24
  from contextlib import AsyncExitStack
25
25
  from contextlib import asynccontextmanager
26
26
  from datetime import timedelta
27
27
 
28
+ import anyio
28
29
  import httpx
29
30
 
30
31
  from mcp import ClientSession
@@ -33,9 +34,9 @@ from mcp.client.stdio import StdioServerParameters
33
34
  from mcp.client.stdio import stdio_client
34
35
  from mcp.client.streamable_http import streamablehttp_client
35
36
  from mcp.types import TextContent
37
+ from nat.authentication.interfaces import AuthenticatedContext
38
+ from nat.authentication.interfaces import AuthFlowType
36
39
  from nat.authentication.interfaces import AuthProviderBase
37
- from nat.data_models.authentication import AuthReason
38
- from nat.data_models.authentication import AuthRequest
39
40
  from nat.plugins.mcp.exception_handler import convert_to_mcp_error
40
41
  from nat.plugins.mcp.exception_handler import format_mcp_error
41
42
  from nat.plugins.mcp.exception_handler import mcp_exception_handler
@@ -53,74 +54,71 @@ class AuthAdapter(httpx.Auth):
53
54
  Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
54
55
  """
55
56
 
56
def __init__(self, auth_provider: AuthProviderBase, user_id: str | None = None):
    """
    Adapt a NAT ``AuthProviderBase`` to the ``httpx.Auth`` interface.

    Args:
        auth_provider: Provider whose ``authenticate`` coroutine supplies credentials.
        user_id: Forwarded as ``user_id`` to ``auth_provider.authenticate`` so each
            session's cached tokens stay separate; may be None.
    """
    self.auth_provider = auth_provider
    self.user_id = user_id  # Session-specific user ID for cache isolation
    # Each adapter instance has its own lock, so authentication for one client does
    # not delay requests made through a different client's adapter.
    self._lock = anyio.Lock()
    # True only while the post-401 re-authentication below is in progress; the MCP
    # client reads it to tell authentication timeouts apart from ordinary failures.
    self.is_authenticating = False

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
    """
    Add authentication headers to the request using NAT auth provider.

    httpx drives this generator. The first yielded request is sent and its response
    is passed back in. On a 401 the provider is consulted again with that response,
    and the request is yielded a second time as a retry. Any other status ends the flow.

    NOTE(review): this layout was recovered from a diff rendering that lost
    indentation. It assumes the 401 handling sits inside ``async with self._lock``
    (so the lock is held until the flow finishes) and that the trailing ``return``
    is outside it. Confirm against the packaged source.
    """
    async with self._lock:
        try:
            # Get auth headers from the NAT auth provider:
            # 1. If discovery has not run yet, presumably no credentials come back and the
            #    request is sent without an Authorization header.
            # 2. If discovery is done, this will return the auth header from cache if the
            #    token is still valid.
            auth_headers = await self._get_auth_headers(request=request, response=None)
            request.headers.update(auth_headers)
        except Exception as e:
            logger.info("Failed to get auth headers: %s", e)
            # Continue without auth headers if auth fails

        response = yield request

        # Handle 401 responses by retrying with fresh auth
        if response.status_code == 401:
            try:
                # 401 can happen if:
                # 1. The request was sent without auth header
                # 2. The auth headers are invalid
                # 3. The auth headers are expired
                # 4. The auth headers are revoked
                # 5. Auth config on the MCP server has changed
                # In this case we attempt to re-run discovery and authentication

                # Signal that we're entering interactive auth flow
                self.is_authenticating = True
                logger.debug("Starting authentication flow due to 401 response")

                auth_headers = await self._get_auth_headers(request=request, response=response)
                request.headers.update(auth_headers)
                yield request  # Retry the request
            except Exception as e:
                logger.info("Failed to refresh auth after 401: %s", e)
                raise
            finally:
                # Signal that auth flow is complete; runs on success, on error, and when
                # httpx closes the generator early.
                self.is_authenticating = False
                logger.debug("Authentication flow completed")
    return
88
106
 
89
- def _is_tool_call_request(self, request: httpx.Request) -> bool:
90
- """Check if this is a tool call request based on the request body."""
91
- try:
92
- # Check if the request body contains a tool call
93
- if request.content:
94
- body = json.loads(request.content.decode('utf-8'))
95
- # Check if it's a JSON-RPC request with method "tools/call"
96
- if (isinstance(body, dict) and body.get("method") == "tools/call"):
97
- return True
98
- except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
99
- # If we can't parse the body, assume it's not a tool call
100
- pass
101
- return False
102
-
103
- async def _get_auth_headers(self, reason: AuthReason, response: httpx.Response | None = None) -> dict[str, str]:
107
+ async def _get_auth_headers(self,
108
+ request: httpx.Request | None = None,
109
+ response: httpx.Response | None = None) -> dict[str, str]:
104
110
  """Get authentication headers from the NAT auth provider."""
105
- # Build auth request
106
- www_authenticate = response.headers.get("WWW-Authenticate", None) if response else None
107
- auth_request = AuthRequest(
108
- reason=reason,
109
- www_authenticate=www_authenticate,
110
- )
111
111
  try:
112
- # Mutating the config is not thread-safe, so we need to lock here
113
- # Is mutating the config the only way to pass the auth request to the auth provider? This needs
114
- # to be re-visited.
115
- self.auth_provider.config.auth_request = auth_request
116
- auth_result = await self.auth_provider.authenticate()
112
+ # Use the user_id passed to this AuthAdapter instance
113
+ auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
114
+
117
115
  # Check if we have BearerTokenCred
118
116
  from nat.data_models.authentication import BearerTokenCred
119
117
  if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
120
118
  token = auth_result.credentials[0].token.get_secret_value()
121
119
  return {"Authorization": f"Bearer {token}"}
122
120
  else:
123
- logger.warning("Auth provider did not return BearerTokenCred")
121
+ logger.info("Auth provider did not return BearerTokenCred")
124
122
  return {}
125
123
  except Exception as e:
126
124
  logger.warning("Failed to get auth token: %s", e)
@@ -134,12 +132,20 @@ class MCPBaseClient(ABC):
134
132
  Args:
135
133
  transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
136
134
  auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
135
+ tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
136
+ auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
137
+ reconnect_enabled (bool): Whether to automatically reconnect on connection failures
138
+ reconnect_max_attempts (int): Maximum number of reconnection attempts
139
+ reconnect_initial_backoff (float): Initial backoff delay in seconds for reconnection attempts
140
+ reconnect_max_backoff (float): Maximum backoff delay in seconds for reconnection attempts
137
141
  """
138
142
 
139
143
  def __init__(self,
140
144
  transport: str = 'streamable-http',
141
145
  auth_provider: AuthProviderBase | None = None,
142
- tool_call_timeout: timedelta = timedelta(seconds=5),
146
+ user_id: str | None = None,
147
+ tool_call_timeout: timedelta = timedelta(seconds=60),
148
+ auth_flow_timeout: timedelta = timedelta(seconds=300),
143
149
  reconnect_enabled: bool = True,
144
150
  reconnect_max_attempts: int = 2,
145
151
  reconnect_initial_backoff: float = 0.5,
@@ -155,9 +161,13 @@ class MCPBaseClient(ABC):
155
161
  self._initial_connection = False
156
162
 
157
163
  # Convert auth provider to AuthAdapter
158
- self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
164
+ self._auth_provider = auth_provider
165
+ # Use provided user_id or fall back to auth provider's default_user_id
166
+ effective_user_id = user_id or (auth_provider.config.default_user_id if auth_provider else None)
167
+ self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None
159
168
 
160
169
  self._tool_call_timeout = tool_call_timeout
170
+ self._auth_flow_timeout = auth_flow_timeout
161
171
 
162
172
  # Reconnect configuration
163
173
  self._reconnect_enabled = reconnect_enabled
@@ -166,6 +176,10 @@ class MCPBaseClient(ABC):
166
176
  self._reconnect_max_backoff = reconnect_max_backoff
167
177
  self._reconnect_lock: asyncio.Lock = asyncio.Lock()
168
178
 
179
@property
def auth_provider(self) -> AuthProviderBase | None:
    """The authentication provider given at construction, or None if none was supplied."""
    provider = self._auth_provider
    return provider
182
+
169
183
  @property
170
184
  def transport(self) -> str:
171
185
  return self._transport
@@ -248,12 +262,25 @@ class MCPBaseClient(ABC):
248
262
  async def _with_reconnect(self, coro):
249
263
  """
250
264
  Execute an awaited operation, reconnecting once on errors.
265
+ Does not reconnect if the error occurs during an active authentication flow.
251
266
  """
252
267
  try:
253
268
  return await coro()
254
269
  except Exception as e:
270
+ # Check if error happened during active authentication flow
271
+ if self._httpx_auth and self._httpx_auth.is_authenticating:
272
+ # Provide specific error message for authentication timeouts
273
+ if isinstance(e, TimeoutError):
274
+ logger.error("Timeout during user authentication flow - user may have abandoned authentication")
275
+ raise RuntimeError(
276
+ "Authentication timed out. User did not complete authentication in browser within "
277
+ f"{self._auth_flow_timeout.total_seconds()} seconds.") from e
278
+ else:
279
+ logger.error("Error during authentication flow: %s", e)
280
+ raise
281
+
282
+ # Normal error - attempt reconnect if enabled
255
283
  if self._reconnect_enabled:
256
- logger.warning("MCP Client operation failed. Attempting reconnect: %s", e)
257
284
  try:
258
285
  await self._reconnect()
259
286
  except Exception as reconnect_err:
@@ -262,7 +289,49 @@ class MCPBaseClient(ABC):
262
289
  return await coro()
263
290
  raise
264
291
 
265
- async def get_tools(self):
292
async def _has_cached_auth_token(self) -> bool:
    """
    Report whether a usable (non-expired) authentication token appears to be cached.

    Returns True when no auth provider is configured, since there is nothing to
    authenticate. Also returns True when the provider's ``_auth_code_provider``
    holds at least one entry in ``_authenticated_tokens`` whose ``is_expired()``
    is false. Every other case returns False, so the caller errs toward the longer
    authentication timeout. That includes missing attributes and any exception
    raised while inspecting them.

    NOTE(review): any non-expired entry counts, not specifically this client's
    user id. If entries are keyed per user, another user's token can select the
    short timeout for a user who still has to authenticate. Confirm intent.

    Returns:
        bool: True if we have a valid cached token, False if authentication may be needed
    """
    if not self._auth_provider:
        return True  # No auth needed

    try:
        code_provider = getattr(self._auth_provider, '_auth_code_provider', None)
        if not code_provider or not hasattr(code_provider, '_authenticated_tokens'):
            return False
        return any(not cached.is_expired() for cached in code_provider._authenticated_tokens.values())
    except Exception:
        # If we can't check, assume we need auth to be safe
        return False

async def _get_tool_call_timeout(self) -> timedelta:
    """
    Pick the read timeout for a tool call from the current authentication state.

    The regular tool-call timeout applies when no auth provider is configured or
    a cached token is available. Otherwise the call may trigger an interactive
    login, so the longer auth-flow timeout is used and a debug message is logged.

    Returns:
        timedelta: auth_flow_timeout if authentication may be needed, tool_call_timeout otherwise
    """
    if not self._auth_provider:
        return self._tool_call_timeout

    if await self._has_cached_auth_token():
        return self._tool_call_timeout

    extended = self._auth_flow_timeout
    logger.debug("Using extended timeout (%s) for potential interactive authentication", extended)
    return extended
332
+
333
+ @mcp_exception_handler
334
+ async def get_tools(self) -> dict[str, MCPToolClient]:
266
335
  """
267
336
  Retrieve a dictionary of all tools served by the MCP server.
268
337
  Uses unauthenticated session for discovery.
@@ -270,7 +339,16 @@ class MCPBaseClient(ABC):
270
339
 
271
340
  async def _get_tools():
272
341
  session = self._session
273
- return await session.list_tools()
342
+ try:
343
+ # Add timeout to the list_tools call.
344
+ # This is needed because MCP SDK does not support timeout for list_tools()
345
+ with anyio.fail_after(self._tool_call_timeout.total_seconds()):
346
+ tools = await session.list_tools()
347
+ except TimeoutError as e:
348
+ from nat.plugins.mcp.exceptions import MCPTimeoutError
349
+ raise MCPTimeoutError(self.server_name, e)
350
+
351
+ return tools
274
352
 
275
353
  try:
276
354
  response = await self._with_reconnect(_get_tools)
@@ -284,8 +362,7 @@ class MCPBaseClient(ABC):
284
362
  tool_name=tool.name,
285
363
  tool_description=tool.description,
286
364
  tool_input_schema=tool.inputSchema,
287
- parent_client=self,
288
- tool_call_timeout=self._tool_call_timeout)
365
+ parent_client=self)
289
366
  for tool in response.tools
290
367
  }
291
368
 
@@ -314,12 +391,18 @@ class MCPBaseClient(ABC):
314
391
  raise MCPToolNotFoundError(tool_name, self.server_name)
315
392
  return tool
316
393
 
394
def set_user_auth_callback(self, auth_callback: Callable[[AuthFlowType], AuthenticatedContext]):
    """
    Install ``auth_callback`` as the provider's custom user-authentication callback.

    Does nothing when no auth provider is configured, or when the provider does not
    expose ``_set_custom_auth_callback``.
    """
    provider = self._auth_provider
    if not provider or not hasattr(provider, "_set_custom_auth_callback"):
        return
    provider._set_custom_auth_callback(auth_callback)

@mcp_exception_handler
async def call_tool(self, tool_name: str, tool_args: dict | None):
    """
    Invoke ``tool_name`` on the MCP server with ``tool_args``.

    The session is captured first; then the read timeout is chosen per call by
    ``_get_tool_call_timeout``. The whole operation runs through ``_with_reconnect``,
    which may reconnect once and retry it after an error.
    """

    async def _invoke_tool():
        active_session = self._session
        read_timeout = await self._get_tool_call_timeout()
        return await active_session.call_tool(tool_name, tool_args, read_timeout_seconds=read_timeout)

    return await self._with_reconnect(_invoke_tool)
325
408
 
@@ -334,13 +417,15 @@ class MCPSSEClient(MCPBaseClient):
334
417
 
335
418
  def __init__(self,
336
419
  url: str,
337
- tool_call_timeout: timedelta = timedelta(seconds=5),
420
+ tool_call_timeout: timedelta = timedelta(seconds=60),
421
+ auth_flow_timeout: timedelta = timedelta(seconds=300),
338
422
  reconnect_enabled: bool = True,
339
423
  reconnect_max_attempts: int = 2,
340
424
  reconnect_initial_backoff: float = 0.5,
341
425
  reconnect_max_backoff: float = 50.0):
342
426
  super().__init__("sse",
343
427
  tool_call_timeout=tool_call_timeout,
428
+ auth_flow_timeout=auth_flow_timeout,
344
429
  reconnect_enabled=reconnect_enabled,
345
430
  reconnect_max_attempts=reconnect_max_attempts,
346
431
  reconnect_initial_backoff=reconnect_initial_backoff,
@@ -383,13 +468,15 @@ class MCPStdioClient(MCPBaseClient):
383
468
  command: str,
384
469
  args: list[str] | None = None,
385
470
  env: dict[str, str] | None = None,
386
- tool_call_timeout: timedelta = timedelta(seconds=5),
471
+ tool_call_timeout: timedelta = timedelta(seconds=60),
472
+ auth_flow_timeout: timedelta = timedelta(seconds=300),
387
473
  reconnect_enabled: bool = True,
388
474
  reconnect_max_attempts: int = 2,
389
475
  reconnect_initial_backoff: float = 0.5,
390
476
  reconnect_max_backoff: float = 50.0):
391
477
  super().__init__("stdio",
392
478
  tool_call_timeout=tool_call_timeout,
479
+ auth_flow_timeout=auth_flow_timeout,
393
480
  reconnect_enabled=reconnect_enabled,
394
481
  reconnect_max_attempts=reconnect_max_attempts,
395
482
  reconnect_initial_backoff=reconnect_initial_backoff,
@@ -440,14 +527,18 @@ class MCPStreamableHTTPClient(MCPBaseClient):
440
527
  def __init__(self,
441
528
  url: str,
442
529
  auth_provider: AuthProviderBase | None = None,
443
- tool_call_timeout: timedelta = timedelta(seconds=5),
530
+ user_id: str | None = None,
531
+ tool_call_timeout: timedelta = timedelta(seconds=60),
532
+ auth_flow_timeout: timedelta = timedelta(seconds=300),
444
533
  reconnect_enabled: bool = True,
445
534
  reconnect_max_attempts: int = 2,
446
535
  reconnect_initial_backoff: float = 0.5,
447
536
  reconnect_max_backoff: float = 50.0):
448
537
  super().__init__("streamable-http",
449
538
  auth_provider=auth_provider,
539
+ user_id=user_id,
450
540
  tool_call_timeout=tool_call_timeout,
541
+ auth_flow_timeout=auth_flow_timeout,
451
542
  reconnect_enabled=reconnect_enabled,
452
543
  reconnect_max_attempts=reconnect_max_attempts,
453
544
  reconnect_initial_backoff=reconnect_initial_backoff,
@@ -490,17 +581,15 @@ class MCPToolClient:
490
581
 
491
582
  def __init__(self,
492
583
  session: ClientSession,
493
- parent_client: "MCPBaseClient",
584
+ parent_client: MCPBaseClient,
494
585
  tool_name: str,
495
586
  tool_description: str | None,
496
- tool_input_schema: dict | None = None,
497
- tool_call_timeout: timedelta = timedelta(seconds=5)):
587
+ tool_input_schema: dict | None = None):
498
588
  self._session = session
499
589
  self._tool_name = tool_name
500
590
  self._tool_description = tool_description
501
591
  self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
502
592
  self._parent_client = parent_client
503
- self._tool_call_timeout = tool_call_timeout
504
593
 
505
594
  if self._parent_client is None:
506
595
  raise RuntimeError("MCPToolClient initialized without a parent client.")
@@ -535,12 +624,17 @@ class MCPToolClient:
535
624
  async def acall(self, tool_args: dict) -> str:
536
625
  """
537
626
  Call the MCP tool with the provided arguments.
627
+ Session context is now handled at the client level, eliminating the need for metadata injection.
538
628
 
539
629
  Args:
540
630
  tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
541
631
  """
542
- logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
632
+ if self._session is None:
633
+ raise RuntimeError("No session available for tool call")
634
+
543
635
  try:
636
+ # Simple tool call - session context is already in the client instance
637
+ logger.info("Calling tool %s", self._tool_name)
544
638
  result = await self._parent_client.call_tool(self._tool_name, tool_args)
545
639
 
546
640
  output = []
@@ -558,6 +652,6 @@ class MCPToolClient:
558
652
 
559
653
  except MCPError as e:
560
654
  format_mcp_error(e, include_traceback=False)
561
- result_str = "MCPToolClient tool call failed: %s" % e.original_exception
655
+ result_str = f"MCPToolClient tool call failed: {e.original_exception}"
562
656
 
563
657
  return result_str