nvidia-nat-mcp 1.4.0a20260107-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nat/meta/pypi.md +32 -0
  2. nat/plugins/mcp/__init__.py +14 -0
  3. nat/plugins/mcp/auth/__init__.py +14 -0
  4. nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
  5. nat/plugins/mcp/auth/auth_provider.py +431 -0
  6. nat/plugins/mcp/auth/auth_provider_config.py +86 -0
  7. nat/plugins/mcp/auth/register.py +33 -0
  8. nat/plugins/mcp/auth/service_account/__init__.py +14 -0
  9. nat/plugins/mcp/auth/service_account/provider.py +136 -0
  10. nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
  11. nat/plugins/mcp/auth/service_account/token_client.py +156 -0
  12. nat/plugins/mcp/auth/token_storage.py +265 -0
  13. nat/plugins/mcp/cli/__init__.py +15 -0
  14. nat/plugins/mcp/cli/commands.py +1051 -0
  15. nat/plugins/mcp/client/__init__.py +15 -0
  16. nat/plugins/mcp/client/client_base.py +665 -0
  17. nat/plugins/mcp/client/client_config.py +146 -0
  18. nat/plugins/mcp/client/client_impl.py +782 -0
  19. nat/plugins/mcp/exception_handler.py +211 -0
  20. nat/plugins/mcp/exceptions.py +142 -0
  21. nat/plugins/mcp/register.py +23 -0
  22. nat/plugins/mcp/server/__init__.py +15 -0
  23. nat/plugins/mcp/server/front_end_config.py +109 -0
  24. nat/plugins/mcp/server/front_end_plugin.py +155 -0
  25. nat/plugins/mcp/server/front_end_plugin_worker.py +411 -0
  26. nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
  27. nat/plugins/mcp/server/memory_profiler.py +320 -0
  28. nat/plugins/mcp/server/register_frontend.py +27 -0
  29. nat/plugins/mcp/server/tool_converter.py +286 -0
  30. nat/plugins/mcp/utils.py +228 -0
  31. nvidia_nat_mcp-1.4.0a20260107.dist-info/METADATA +55 -0
  32. nvidia_nat_mcp-1.4.0a20260107.dist-info/RECORD +37 -0
  33. nvidia_nat_mcp-1.4.0a20260107.dist-info/WHEEL +5 -0
  34. nvidia_nat_mcp-1.4.0a20260107.dist-info/entry_points.txt +9 -0
  35. nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  36. nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE.md +201 -0
  37. nvidia_nat_mcp-1.4.0a20260107.dist-info/top_level.txt +1 -0
nat/plugins/mcp/client/client_impl.py
@@ -0,0 +1,782 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ from contextlib import asynccontextmanager
19
+ from dataclasses import dataclass
20
+ from dataclasses import field
21
+ from datetime import datetime
22
+ from datetime import timedelta
23
+
24
+ import aiorwlock
25
+ from pydantic import BaseModel
26
+
27
+ from nat.authentication.interfaces import AuthProviderBase
28
+ from nat.builder.builder import Builder
29
+ from nat.builder.context import Context
30
+ from nat.builder.function import FunctionGroup
31
+ from nat.cli.register_workflow import register_function_group
32
+ from nat.cli.register_workflow import register_per_user_function_group
33
+ from nat.plugins.mcp.client.client_base import MCPBaseClient
34
+ from nat.plugins.mcp.client.client_config import MCPClientConfig
35
+ from nat.plugins.mcp.client.client_config import MCPToolOverrideConfig
36
+ from nat.plugins.mcp.client.client_config import PerUserMCPClientConfig
37
+ from nat.plugins.mcp.utils import truncate_session_id
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ class PerUserMCPFunctionGroup(FunctionGroup):
43
+ """
44
+ A specialized FunctionGroup for per-user MCP clients.
45
+ """
46
+
47
+ def __init__(self, *args, **kwargs):
48
+ super().__init__(*args, **kwargs)
49
+
50
+ self.mcp_client: MCPBaseClient | None = None # Will be set to the actual MCP client instance
51
+ self.mcp_client_server_name: str | None = None
52
+ self.mcp_client_transport: str | None = None
53
+ self.user_id: str | None = None
54
+
55
+
56
+ def mcp_per_user_tool_function(tool, client: MCPBaseClient):
57
+ """
58
+ Create a per-user NAT function for an MCP tool.
59
+
60
+ Args:
61
+ tool: The MCP tool to create a function for
62
+ client: The MCP client to use for the function
63
+
64
+ Returns:
65
+ The NAT function
66
+ """
67
+ from nat.builder.function import FunctionInfo
68
+
69
+ def _convert_from_str(input_str: str) -> tool.input_schema:
70
+ return tool.input_schema.model_validate_json(input_str)
71
+
72
+ async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
73
+ try:
74
+ mcp_tool = await client.get_tool(tool.name)
75
+
76
+ if tool_input:
77
+ args = tool_input.model_dump(exclude_none=True, mode='json')
78
+ return await mcp_tool.acall(args)
79
+
80
+ # kwargs arrives with all optional fields set to None because NAT's framework
81
+ # converts the input dict to a Pydantic model (filling in all Field(default=None)),
82
+ # then dumps it back to a dict. We need to strip out these None values because
83
+ # many MCP servers (e.g., Kaggle) reject requests with excessive null fields.
84
+ # We re-validate here (yes, redundant) to leverage Pydantic's exclude_none with
85
+ # mode='json' for recursive None removal in nested models.
86
+ # Reference: function_info.py:_convert_input_pydantic
87
+ validated_input = mcp_tool.input_schema.model_validate(kwargs)
88
+ args = validated_input.model_dump(exclude_none=True, mode='json')
89
+ return await mcp_tool.acall(args)
90
+ except Exception as e:
91
+ logger.warning("Error calling tool %s", tool.name, exc_info=True)
92
+ return str(e)
93
+
94
+ return FunctionInfo.create(single_fn=_response_fn,
95
+ description=tool.description,
96
+ input_schema=tool.input_schema,
97
+ converters=[_convert_from_str])
98
+
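The None-stripping comment above is easier to see with a toy example. The models below are hypothetical stand-ins (not part of this package); they only demonstrate that `model_validate` followed by `model_dump(exclude_none=True, mode='json')` drops `None` fields recursively, which is what keeps servers such as Kaggle from receiving null-laden payloads.

```python
# Toy illustration of the exclude_none round-trip described above.
# SearchInput/Filters are hypothetical models, not part of nvidia-nat-mcp.
from pydantic import BaseModel


class Filters(BaseModel):
    min_size: int | None = None
    owner: str | None = None


class SearchInput(BaseModel):
    query: str
    page: int | None = None
    filters: Filters | None = None


kwargs = {"query": "datasets", "page": None, "filters": {"min_size": None, "owner": None}}
validated = SearchInput.model_validate(kwargs)
print(validated.model_dump(exclude_none=True, mode="json"))
# {'query': 'datasets', 'filters': {}}  -- None values removed, including inside nested models
```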
99
+
100
+ @dataclass
101
+ class SessionData:
102
+ """Container for all session-related data."""
103
+ client: MCPBaseClient
104
+ last_activity: datetime
105
+ ref_count: int = 0
106
+ lock: asyncio.Lock = field(default_factory=asyncio.Lock)
107
+
108
+ # lifetime task to respect task boundaries
109
+ stop_event: asyncio.Event = field(default_factory=asyncio.Event)
110
+ lifetime_task: asyncio.Task | None = None
111
+
112
+
113
+ class MCPFunctionGroup(FunctionGroup):
114
+ """
115
+ A specialized FunctionGroup for MCP clients that includes MCP-specific attributes
116
+ with session management.
117
+
118
+ Locking model (simple + safe; occasional 'temporarily unavailable' is acceptable).
119
+
120
+ RW semantics:
121
+ - Multiple readers may hold the reader lock concurrently.
122
+ - While any reader holds the lock, writers cannot proceed.
123
+ - While the writer holds the lock, no new readers can proceed.
124
+
125
+ Data:
126
+ - _sessions: dict[str, SessionData]; SessionData = {client, last_activity, ref_count, lock}.
127
+
128
+ Locks:
129
+ - _session_rwlock (aiorwlock.RWLock)
130
+ • Reader: very short sections — dict lookups, ref_count ++/--, touch last_activity.
131
+ • Writer: structural changes — create session entries, enforce limits, remove on cleanup.
132
+ - SessionData.lock (asyncio.Lock)
133
+ • Protects per-session ref_count only, taken only while holding RW *reader*.
134
+ • last_activity: written without session lock (timestamp races acceptable for cleanup heuristic).
135
+
136
+ Ordering & awaits:
137
+ - Always acquire RWLock (reader/writer) before SessionData.lock; never the reverse.
138
+ - Never await network I/O under the writer (client creation is the one intentional exception).
139
+ - Client close happens after releasing the writer.
140
+
141
+ Cleanup:
142
+ - Under writer: find inactive (ref_count == 0 and idle > max_age), pop from _sessions, stash clients.
143
+ - After writer: await client.__aexit__() for each stashed client.
144
+ - TOCTOU race: cleanup may read ref_count==0 then a usage increments it; accepted, yields None gracefully.
145
+
146
+ Invariants:
147
+ - ref_count > 0 prevents cleanup.
148
+ - Usage context increments ref_count before yielding and decrements on exit.
149
+ - If a session disappears between ensure/use, callers return "Tool temporarily unavailable".
150
+ """
151
+
152
+ def __init__(self, *args, **kwargs):
153
+ super().__init__(*args, **kwargs)
154
+ # MCP client attributes with proper typing
155
+ self.mcp_client: MCPBaseClient | None = None # Will be set to the actual MCP client instance
156
+ self.mcp_client_server_name: str | None = None
157
+ self.mcp_client_transport: str | None = None
158
+
159
+ # Session management - consolidated data structure
160
+ self._sessions: dict[str, SessionData] = {}
161
+
162
+ # Use RWLock for better concurrency: multiple readers (tool calls) can access
163
+ # existing sessions simultaneously, while writers (create/delete) get exclusive access
164
+ self._session_rwlock = aiorwlock.RWLock()
165
+ # Throttled cleanup control
166
+ self._last_cleanup_check: datetime = datetime.now()
167
+ self._cleanup_check_interval: timedelta = timedelta(minutes=5)
168
+
169
+ # Shared components for session client creation
170
+ self._shared_auth_provider: AuthProviderBase | None = None
171
+ self._client_config: MCPClientConfig | None = None
172
+
173
+ # Auth provider config defaults (set when auth provider is assigned)
174
+ self._default_user_id: str | None = None
175
+ self._allow_default_user_id_for_tool_calls: bool = True
176
+
177
+ # Use random session id for testing only
178
+ self._use_random_session_id_for_testing: bool = False
179
+
180
+ @property
181
+ def session_count(self) -> int:
182
+ """Current number of active sessions."""
183
+ return len(self._sessions)
184
+
185
+ @property
186
+ def session_limit(self) -> int:
187
+ """Maximum allowed sessions."""
188
+ return self._client_config.max_sessions if self._client_config else 100
189
+
190
+ def _get_random_session_id(self) -> str:
191
+ """Get a random session ID."""
192
+ import uuid
193
+ return str(uuid.uuid4())
194
+
195
+ def _get_session_id_from_context(self) -> str | None:
196
+ """Get the session ID from the current context."""
197
+ try:
198
+ from nat.builder.context import Context as _Ctx
199
+
200
+ # Get session id from context; authentication is done per-websocket session for tool calls
201
+ session_id = None
202
+ # get session id from cookies if session_aware_tools is enabled
203
+ if self._client_config and self._client_config.session_aware_tools:
204
+ cookies = getattr(_Ctx.get().metadata, "cookies", None)
205
+ if cookies:
206
+ if self._use_random_session_id_for_testing:
207
+ # This path is for testing only and should not be used in production
208
+ session_id = self._get_random_session_id()
209
+ else:
210
+ session_id = cookies.get("nat-session")
211
+
212
+ if not session_id:
213
+ # use default user id if allowed
214
+ if self._shared_auth_provider and self._allow_default_user_id_for_tool_calls:
215
+ session_id = self._default_user_id
216
+ return session_id
217
+ except Exception:
218
+ return None
219
+
220
+ async def cleanup_sessions(self, max_age: timedelta | None = None) -> int:
221
+ """
222
+ Manually trigger cleanup of inactive sessions.
223
+
224
+ Args:
225
+ max_age: Maximum age for sessions before cleanup. If None, uses configured timeout.
226
+
227
+ Returns:
228
+ Number of sessions cleaned up.
229
+ """
230
+ sessions_before = len(self._sessions)
231
+ await self._cleanup_inactive_sessions(max_age)
232
+ sessions_after = len(self._sessions)
233
+ return sessions_before - sessions_after
234
+
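Cleanup is normally throttled on access (see `_get_session_client` below), but `cleanup_sessions` can also be driven externally. A hypothetical background driver might look like the sketch below; the interval and `max_age` mirror the throttle interval and idle-timeout fallback used in this module.

```python
# Hypothetical background driver for cleanup_sessions(); not part of this module.
import asyncio
from datetime import timedelta


async def run_session_reaper(group: "MCPFunctionGroup"):
    while True:
        await asyncio.sleep(300)                                   # ~ _cleanup_check_interval
        removed = await group.cleanup_sessions(max_age=timedelta(hours=1))
        if removed:
            logger.info("Reaped %d idle MCP sessions", removed)
```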
235
+ async def _cleanup_inactive_sessions(self, max_age: timedelta | None = None):
236
+ """Remove clients for sessions inactive longer than max_age.
237
+
238
+ This method uses the RWLock writer to ensure thread-safe cleanup.
239
+ """
240
+ if max_age is None:
241
+ max_age = self._client_config.session_idle_timeout if self._client_config else timedelta(hours=1)
242
+
243
+ to_close: list[tuple[str, SessionData]] = []
244
+
245
+ async with self._session_rwlock.writer:
246
+ current_time = datetime.now()
247
+ inactive_sessions = []
248
+
249
+ for session_id, session_data in self._sessions.items():
250
+ # Skip cleanup if session is actively being used
251
+ if session_data.ref_count > 0:
252
+ continue
253
+
254
+ if current_time - session_data.last_activity > max_age:
255
+ inactive_sessions.append(session_id)
256
+
257
+ for session_id in inactive_sessions:
258
+ try:
259
+ logger.info("Cleaning up inactive session client: %s", truncate_session_id(session_id))
260
+ session_data = self._sessions[session_id]
261
+ # Close the client connection
262
+ if session_data:
263
+ to_close.append((session_id, session_data))
264
+ except Exception as e:
265
+ logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
266
+ finally:
267
+ # Always remove from tracking to prevent leaks, even if close failed
268
+ self._sessions.pop(session_id, None)
269
+ logger.info("Cleaned up session tracking for: %s", truncate_session_id(session_id))
270
+ logger.info(" Total sessions: %d", len(self._sessions))
271
+
272
+ # Close sessions outside the writer lock to avoid deadlock
273
+ for session_id, sdata in to_close:
274
+ try:
275
+ if sdata.stop_event and sdata.lifetime_task:
276
+ if not sdata.lifetime_task.done():
277
+ # Instead of directly exiting the task, set the stop event
278
+ # and wait for the task to exit. This ensures the cancel scope
279
+ # is entered and exited in the same task.
280
+ sdata.stop_event.set()
281
+ await sdata.lifetime_task # __aexit__ runs in that task
282
+ else:
283
+ logger.debug("Session client %s lifetime task already done", truncate_session_id(session_id))
284
+ else:
285
+ # add fallback to ensure we clean up the client
286
+ logger.warning("Session client %s lifetime task not found, cleaning up client",
287
+ truncate_session_id(session_id))
288
+ await sdata.client.__aexit__(None, None, None)
289
+ except Exception as e:
290
+ logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
291
+
292
+ async def _get_session_client(self, session_id: str) -> MCPBaseClient | None:
293
+ """Get the appropriate MCP client for the session."""
294
+ # Throttled cleanup on access
295
+ now = datetime.now()
296
+ if now - self._last_cleanup_check > self._cleanup_check_interval:
297
+ await self._cleanup_inactive_sessions()
298
+ self._last_cleanup_check = now
299
+
300
+ # If the session_id equals the configured default_user_id use the base client
301
+ # instead of creating a per-session client
302
+ if self._shared_auth_provider:
303
+ if self._default_user_id and session_id == self._default_user_id:
304
+ return self.mcp_client
305
+
306
+ # Fast path: check if session already exists (reader lock for concurrent access)
307
+ async with self._session_rwlock.reader:
308
+ if session_id in self._sessions:
309
+ # Update last activity for existing client
310
+ self._sessions[session_id].last_activity = datetime.now()
311
+ return self._sessions[session_id].client
312
+
313
+ # Check session limit before creating new client (outside writer lock to avoid deadlock)
314
+ if self._client_config and len(self._sessions) >= self._client_config.max_sessions:
315
+ # Try cleanup first to free up space
316
+ await self._cleanup_inactive_sessions()
317
+
318
+ # Slow path: create session with writer lock for exclusive access
319
+ async with self._session_rwlock.writer:
320
+ # Double-check after acquiring writer lock (another coroutine might have created it)
321
+ if session_id in self._sessions:
322
+ self._sessions[session_id].last_activity = datetime.now()
323
+ return self._sessions[session_id].client
324
+
325
+ # Re-check session limit inside writer lock
326
+ if self._client_config and len(self._sessions) >= self._client_config.max_sessions:
327
+ logger.warning("Session limit reached (%d), rejecting new session: %s",
328
+ self._client_config.max_sessions,
329
+ truncate_session_id(session_id))
330
+ raise RuntimeError(f"Tool unavailable: Maximum concurrent sessions "
331
+ f"({self._client_config.max_sessions}) exceeded.")
332
+
333
+ # Create session client lazily
334
+ logger.info("Creating new MCP client for session: %s", truncate_session_id(session_id))
335
+ session_client, stop_event, lifetime_task = await self._create_session_client(session_id)
336
+ session_data = SessionData(
337
+ client=session_client,
338
+ last_activity=datetime.now(),
339
+ ref_count=0,
340
+ stop_event=stop_event,
341
+ lifetime_task=lifetime_task,
342
+ )
343
+
344
+ # Cache the session data
345
+ self._sessions[session_id] = session_data
346
+ logger.info(" Total sessions: %d", len(self._sessions))
347
+ return session_client
348
+
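The method above is a double-checked get-or-create: a reader-locked fast path, then a writer-locked slow path that re-checks before creating. Stripped of the MCP details, the pattern looks like this (illustrative sketch only):

```python
# Generic shape of the fast-path / slow-path logic above (illustration only).
async def get_or_create(rwlock, cache: dict, key, factory):
    async with rwlock.reader:            # fast path: many concurrent lookups
        if key in cache:
            return cache[key]
    async with rwlock.writer:            # slow path: exclusive structural change
        if key in cache:                 # re-check: another coroutine may have created it
            return cache[key]
        cache[key] = await factory(key)  # the one intentional await under the writer
        return cache[key]
```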
349
+ @asynccontextmanager
350
+ async def _session_usage_context(self, session_id: str):
351
+ """Context manager to track active session usage and prevent cleanup."""
352
+ # Ensure session exists - create it if it doesn't
353
+ if session_id not in self._sessions:
354
+ # Create session client first
355
+ await self._get_session_client(session_id) # START read phase: bump ref_count under reader + session lock
356
+
357
+ async with self._session_rwlock.reader:
358
+ sdata = self._sessions.get(session_id)
359
+ if not sdata:
360
+ # this can happen if the session is cleaned up between the check and the lock
361
+ # this is rare and we can just return that the tool is temporarily unavailable
362
+ yield None
363
+ return
364
+ async with sdata.lock:
365
+ sdata.ref_count += 1
366
+ client = sdata.client # capture
367
+ # END read phase (release reader before long await)
368
+
369
+ try:
370
+ yield client
371
+ finally:
372
+ # Brief read phase to decrement ref_count and touch activity
373
+ async with self._session_rwlock.reader:
374
+ sdata = self._sessions.get(session_id)
375
+ if sdata:
376
+ async with sdata.lock:
377
+ sdata.ref_count -= 1
378
+ sdata.last_activity = datetime.now()
379
+
380
+ async def _create_session_client(self, session_id: str) -> tuple[MCPBaseClient, asyncio.Event, asyncio.Task]:
381
+ """Create a new MCP client instance for the session."""
382
+ from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
383
+
384
+ config = self._client_config
385
+ if not config:
386
+ raise RuntimeError("Client config not initialized")
387
+
388
+ if config.server.transport == "streamable-http":
389
+ client = MCPStreamableHTTPClient(
390
+ str(config.server.url),
391
+ auth_provider=self._shared_auth_provider,
392
+ user_id=session_id, # Pass session_id as user_id for cache isolation
393
+ tool_call_timeout=config.tool_call_timeout,
394
+ auth_flow_timeout=config.auth_flow_timeout,
395
+ reconnect_enabled=config.reconnect_enabled,
396
+ reconnect_max_attempts=config.reconnect_max_attempts,
397
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
398
+ reconnect_max_backoff=config.reconnect_max_backoff)
399
+ else:
400
+ # per-user sessions are only supported for streamable-http transport
401
+ raise ValueError(f"Unsupported transport: {config.server.transport}")
402
+
403
+ ready = asyncio.Event()
404
+ stop_event = asyncio.Event()
405
+
406
+ async def _lifetime():
407
+ """
408
+ Create a lifetime task to respect task boundaries and ensure the
409
+ cancel scope is entered and exited in the same task.
410
+ """
411
+ try:
412
+ async with client:
413
+ ready.set()
414
+ await stop_event.wait()
415
+ except Exception:
416
+ ready.set() # Ensure we don't hang the waiter
417
+ raise
418
+
419
+ task = asyncio.create_task(_lifetime(), name=f"mcp-session-{truncate_session_id(session_id)}")
420
+
421
+ # Wait for initialization with timeout to prevent infinite hangs
422
+ timeout = config.tool_call_timeout.total_seconds() if config else 300
423
+ try:
424
+ await asyncio.wait_for(ready.wait(), timeout=timeout)
425
+ except TimeoutError:
426
+ task.cancel()
427
+ try:
428
+ await task
429
+ except asyncio.CancelledError:
430
+ pass
431
+ logger.error("Session client initialization timed out after %ds for %s",
432
+ timeout,
433
+ truncate_session_id(session_id))
434
+ raise RuntimeError(f"Session client initialization timed out after {timeout}s")
435
+
436
+ # Check if initialization failed before ready was set
437
+ if task.done():
438
+ try:
439
+ await task # Re-raise exception if the task failed
440
+ except Exception as e:
441
+ logger.error("Failed to initialize session client for %s: %s", truncate_session_id(session_id), e)
442
+ raise RuntimeError(f"Failed to initialize session client: {e}") from e
443
+
444
+ logger.info("Created session client for session: %s", truncate_session_id(session_id))
445
+ # NOTE: caller will place client into SessionData and attach stop_event/task
446
+ return client, stop_event, task
447
+
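The lifetime-task pattern above exists because the MCP client's async context (and the anyio cancel scopes inside it) must be entered and exited by the same task. Shutdown therefore signals `stop_event` and awaits the task rather than calling `__aexit__` from another task; the condensed sketch below restates the hand-off that `_cleanup_inactive_sessions` performs, and is not a new API.

```python
# Condensed restatement of the shutdown hand-off (illustration only).
async def shut_down_session(sdata: SessionData) -> None:
    if sdata.lifetime_task and not sdata.lifetime_task.done():
        sdata.stop_event.set()            # ask the owning task to leave `async with client:`
        await sdata.lifetime_task         # client.__aexit__ runs inside that task
    else:
        # Fallback mirrors the cleanup path: close directly if no lifetime task exists.
        await sdata.client.__aexit__(None, None, None)
```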
448
+
449
+ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
450
+ """Create a session-aware NAT function for an MCP tool.
451
+
452
+ Routes each invocation to the appropriate per-session MCP client while
453
+ preserving the original tool input schema, converters, and description.
454
+ """
455
+ from nat.builder.function import FunctionInfo
456
+
457
+ def _convert_from_str(input_str: str) -> tool.input_schema:
458
+ return tool.input_schema.model_validate_json(input_str)
459
+
460
+ async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
461
+ """Response function for the session-aware tool."""
462
+ try:
463
+ # Route to the appropriate session client
464
+ session_id = function_group._get_session_id_from_context()
465
+
466
+ # If no session is available and default-user fallback is disabled, deny the call
467
+ if function_group._shared_auth_provider and session_id is None:
468
+ return "User not authorized to call the tool"
469
+
470
+ # Check if this is the default user - if so, use base client directly
471
+ if (not function_group._shared_auth_provider or session_id == function_group._default_user_id):
472
+ # Use base client directly for default user
473
+ client = function_group.mcp_client
474
+ if client is None:
475
+ return "Tool temporarily unavailable. Try again."
476
+ session_tool = await client.get_tool(tool.name)
477
+ else:
478
+ # Use session usage context to prevent cleanup during tool execution
479
+ if session_id is None:
480
+ return "Tool temporarily unavailable. Try again."
481
+ async with function_group._session_usage_context(session_id) as client:
482
+ if client is None:
483
+ return "Tool temporarily unavailable. Try again."
484
+ session_tool = await client.get_tool(tool.name)
485
+
486
+ # Preserve original calling convention
487
+ if tool_input:
488
+ args = tool_input.model_dump(exclude_none=True, mode='json')
489
+ return await session_tool.acall(args)
490
+
491
+ # kwargs arrives with all optional fields set to None because NAT's framework
492
+ # converts the input dict to a Pydantic model (filling in all Field(default=None)),
493
+ # then dumps it back to a dict. We need to strip out these None values because
494
+ # many MCP servers (e.g., Kaggle) reject requests with excessive null fields.
495
+ # We re-validate here (yes, redundant) to leverage Pydantic's exclude_none with
496
+ # mode='json' for recursive None removal in nested models.
497
+ # Reference: function_info.py:_convert_input_pydantic
498
+ validated_input = session_tool.input_schema.model_validate(kwargs)
499
+ args = validated_input.model_dump(exclude_none=True, mode='json')
500
+ return await session_tool.acall(args)
501
+ except Exception as e:
502
+ logger.warning("Error calling tool %s", tool.name, exc_info=True)
503
+ return str(e)
504
+
505
+ return FunctionInfo.create(single_fn=_response_fn,
506
+ description=tool.description,
507
+ input_schema=tool.input_schema,
508
+ converters=[_convert_from_str])
509
+
510
+
511
+ @register_function_group(config_type=MCPClientConfig)
512
+ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
513
+ """
514
+ Connect to an MCP server and expose tools as a function group.
515
+
516
+ Args:
517
+ config: The configuration for the MCP client
518
+ _builder: The builder
519
+ Returns:
520
+ The function group
521
+ """
522
+ from nat.plugins.mcp.client.client_base import MCPSSEClient
523
+ from nat.plugins.mcp.client.client_base import MCPStdioClient
524
+ from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
525
+
526
+ # Resolve auth provider if specified
527
+ auth_provider = None
528
+ if config.server.auth_provider:
529
+ auth_provider = await _builder.get_auth_provider(config.server.auth_provider)
530
+
531
+ # Build the appropriate client
532
+ if config.server.transport == "stdio":
533
+ if not config.server.command:
534
+ raise ValueError("command is required for stdio transport")
535
+ client = MCPStdioClient(config.server.command,
536
+ config.server.args,
537
+ config.server.env,
538
+ tool_call_timeout=config.tool_call_timeout,
539
+ auth_flow_timeout=config.auth_flow_timeout,
540
+ reconnect_enabled=config.reconnect_enabled,
541
+ reconnect_max_attempts=config.reconnect_max_attempts,
542
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
543
+ reconnect_max_backoff=config.reconnect_max_backoff)
544
+ elif config.server.transport == "sse":
545
+ client = MCPSSEClient(str(config.server.url),
546
+ tool_call_timeout=config.tool_call_timeout,
547
+ auth_flow_timeout=config.auth_flow_timeout,
548
+ reconnect_enabled=config.reconnect_enabled,
549
+ reconnect_max_attempts=config.reconnect_max_attempts,
550
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
551
+ reconnect_max_backoff=config.reconnect_max_backoff)
552
+ elif config.server.transport == "streamable-http":
553
+ # Use default_user_id for the base client
554
+ # For interactive OAuth2: from config. For service accounts: defaults to server URL
555
+ base_user_id = getattr(auth_provider.config, 'default_user_id', str(
556
+ config.server.url)) if auth_provider else None
557
+ client = MCPStreamableHTTPClient(str(config.server.url),
558
+ auth_provider=auth_provider,
559
+ user_id=base_user_id,
560
+ tool_call_timeout=config.tool_call_timeout,
561
+ auth_flow_timeout=config.auth_flow_timeout,
562
+ reconnect_enabled=config.reconnect_enabled,
563
+ reconnect_max_attempts=config.reconnect_max_attempts,
564
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
565
+ reconnect_max_backoff=config.reconnect_max_backoff)
566
+ else:
567
+ raise ValueError(f"Unsupported transport: {config.server.transport}")
568
+
569
+ logger.info("Configured to use MCP server at %s", client.server_name)
570
+
571
+ # Create the MCP function group
572
+ group = MCPFunctionGroup(config=config)
573
+
574
+ # Store shared components for session client creation
575
+ group._shared_auth_provider = auth_provider
576
+ group._client_config = config
577
+
578
+ # Set auth provider config defaults
579
+ # For interactive OAuth2: use config values
580
+ # For service accounts: default_user_id = server URL, allow_default_user_id_for_tool_calls = True
581
+ if auth_provider:
582
+ group._default_user_id = getattr(auth_provider.config, 'default_user_id', str(config.server.url))
583
+ group._allow_default_user_id_for_tool_calls = getattr(auth_provider.config,
584
+ 'allow_default_user_id_for_tool_calls',
585
+ True)
586
+ else:
587
+ group._default_user_id = None
588
+ group._allow_default_user_id_for_tool_calls = True
589
+
590
+ async with client:
591
+ # Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
592
+ # can reuse the already-established session instead of creating a new client per request.
593
+ group.mcp_client = client
594
+ group.mcp_client_server_name = client.server_name
595
+ group.mcp_client_transport = client.transport
596
+
597
+ all_tools = await client.get_tools()
598
+ tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
599
+
600
+ # Add each tool as a function to the group
601
+ for tool_name, tool in all_tools.items():
602
+ # Get override if it exists
603
+ override = tool_overrides.get(tool_name)
604
+
605
+ # Use override values or defaults
606
+ function_name = override.alias if override and override.alias else tool_name
607
+ description = override.description if override and override.description else tool.description
608
+
609
+ # Create the tool function according to configuration
610
+ tool_fn = mcp_session_tool_function(tool, group)
611
+
612
+ input_schema = tool_fn.input_schema
613
+ # Convert NoneType sentinel to None for FunctionGroup.add_function signature
614
+ if input_schema is type(None):
615
+ input_schema = None
616
+
617
+ # Add to group
618
+ logger.info("Adding tool %s to group", function_name)
619
+ group.add_function(name=function_name,
620
+ description=description,
621
+ fn=tool_fn.single_fn,
622
+ input_schema=input_schema,
623
+ converters=tool_fn.converters)
624
+
625
+ yield group
626
+
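For orientation, a configuration for this function group might be assembled roughly as follows. The field names (`server.transport`, `server.url`, `tool_overrides`, `max_sessions`, `session_idle_timeout`, `session_aware_tools`) are taken from how this module reads the config; the authoritative schema lives in `client_config.py`, so treat this as a sketch rather than a definitive constructor call, and the server URL is a placeholder.

```python
# Hedged sketch of an MCPClientConfig; the exact schema is defined in client_config.py.
from datetime import timedelta

from nat.plugins.mcp.client.client_config import MCPClientConfig
from nat.plugins.mcp.client.client_config import MCPToolOverrideConfig

config = MCPClientConfig(
    server={"transport": "streamable-http", "url": "https://example.com/mcp"},  # example URL
    tool_overrides={
        "search": MCPToolOverrideConfig(alias="web_search", description="Search the web"),
    },
    max_sessions=100,
    session_idle_timeout=timedelta(hours=1),
    session_aware_tools=True,
)
```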
627
+
628
+ def mcp_apply_tool_alias_and_description(
629
+ all_tools: dict, tool_overrides: dict[str, MCPToolOverrideConfig] | None) -> dict[str, MCPToolOverrideConfig]:
630
+ """
631
+ Filter tool overrides to only include tools that exist in the MCP server.
632
+
633
+ Args:
634
+ all_tools: The tools from the MCP server
635
+ tool_overrides: The tool overrides to apply
636
+ Returns:
637
+ Dictionary of valid tool overrides
638
+ """
639
+ if not tool_overrides:
640
+ return {}
641
+
642
+ return {name: override for name, override in tool_overrides.items() if name in all_tools}
643
+
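A small usage illustration of the filter above: overrides keyed by tools the server does not expose are dropped, everything else passes through unchanged. The tool objects are stand-ins, and the example assumes `alias` and `description` are optional override fields, as their usage above suggests.

```python
# Illustration only: overrides for unknown tools are silently discarded.
all_tools = {"search": object(), "fetch": object()}      # stand-ins for MCP tool objects
overrides = {
    "search": MCPToolOverrideConfig(alias="web_search"),
    "does_not_exist": MCPToolOverrideConfig(description="never applied"),
}
assert set(mcp_apply_tool_alias_and_description(all_tools, overrides)) == {"search"}
```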
644
+
645
+ @register_per_user_function_group(config_type=PerUserMCPClientConfig)
646
+ async def per_user_mcp_client_function_group(config: PerUserMCPClientConfig, _builder: Builder):
647
+ """
648
+ Connect to an MCP server and expose tools as a function group for per-user workflows.
649
+
650
+ Args:
651
+ config: The configuration for the MCP client
652
+ _builder: The builder
653
+ Returns:
654
+ The function group
655
+ """
656
+ from nat.plugins.mcp.client.client_base import MCPSSEClient
657
+ from nat.plugins.mcp.client.client_base import MCPStdioClient
658
+ from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
659
+
660
+ # Resolve auth provider if specified
661
+ auth_provider = None
662
+ if config.server.auth_provider:
663
+ auth_provider = await _builder.get_auth_provider(config.server.auth_provider)
664
+
665
+ user_id = Context.get().user_id
666
+
667
+ # Build the appropriate client
668
+ if config.server.transport == "stdio":
669
+ if not config.server.command:
670
+ raise ValueError("command is required for stdio transport")
671
+ client = MCPStdioClient(config.server.command,
672
+ config.server.args,
673
+ config.server.env,
674
+ tool_call_timeout=config.tool_call_timeout,
675
+ auth_flow_timeout=config.auth_flow_timeout,
676
+ reconnect_enabled=config.reconnect_enabled,
677
+ reconnect_max_attempts=config.reconnect_max_attempts,
678
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
679
+ reconnect_max_backoff=config.reconnect_max_backoff)
680
+ elif config.server.transport == "sse":
681
+ client = MCPSSEClient(str(config.server.url),
682
+ tool_call_timeout=config.tool_call_timeout,
683
+ auth_flow_timeout=config.auth_flow_timeout,
684
+ reconnect_enabled=config.reconnect_enabled,
685
+ reconnect_max_attempts=config.reconnect_max_attempts,
686
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
687
+ reconnect_max_backoff=config.reconnect_max_backoff)
688
+ elif config.server.transport == "streamable-http":
689
+ client = MCPStreamableHTTPClient(str(config.server.url),
690
+ auth_provider=auth_provider,
691
+ user_id=user_id,
692
+ tool_call_timeout=config.tool_call_timeout,
693
+ auth_flow_timeout=config.auth_flow_timeout,
694
+ reconnect_enabled=config.reconnect_enabled,
695
+ reconnect_max_attempts=config.reconnect_max_attempts,
696
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
697
+ reconnect_max_backoff=config.reconnect_max_backoff)
698
+ else:
699
+ raise ValueError(f"Unsupported transport: {config.server.transport}")
700
+
701
+ logger.info("Per-user MCP client configured for server: %s (user: %s)", client.server_name, user_id)
702
+
703
+ group = PerUserMCPFunctionGroup(config=config)
704
+
705
+ # Use a lifetime task to ensure the client context is entered and exited in the same task.
706
+ # This avoids anyio's "Attempted to exit cancel scope in a different task" error.
707
+ ready = asyncio.Event()
708
+ stop_event = asyncio.Event()
709
+
710
+ async def _lifetime():
711
+ """Lifetime task that owns the client's async context."""
712
+ try:
713
+ async with client:
714
+ ready.set()
715
+ await stop_event.wait()
716
+ except Exception:
717
+ ready.set() # Ensure we don't hang the waiter
718
+ raise
719
+
720
+ lifetime_task = asyncio.create_task(_lifetime(), name=f"mcp-per-user-{user_id}")
721
+
722
+ # Wait for client initialization
723
+ timeout = config.tool_call_timeout.total_seconds()
724
+ try:
725
+ await asyncio.wait_for(ready.wait(), timeout=timeout)
726
+ except TimeoutError:
727
+ lifetime_task.cancel()
728
+ try:
729
+ await lifetime_task
730
+ except asyncio.CancelledError:
731
+ pass
732
+ raise RuntimeError(f"Per-user MCP client initialization timed out after {timeout}s")
733
+
734
+ # Check if initialization failed
735
+ if lifetime_task.done():
736
+ try:
737
+ await lifetime_task
738
+ except Exception as e:
739
+ raise RuntimeError(f"Failed to initialize per-user MCP client: {e}") from e
740
+
741
+ try:
742
+ # Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
743
+ # can reuse the already-established session instead of creating a new client per request.
744
+ group.mcp_client = client
745
+ group.mcp_client_server_name = client.server_name
746
+ group.mcp_client_transport = client.transport
747
+ group.user_id = user_id
748
+
749
+ all_tools = await client.get_tools()
750
+ tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
751
+
752
+ # Add each tool as a function to the group
753
+ for tool_name, tool in all_tools.items():
754
+ # Get override if it exists
755
+ override = tool_overrides.get(tool_name)
756
+
757
+ # Use override values or defaults
758
+ function_name = override.alias if override and override.alias else tool_name
759
+ description = override.description if override and override.description else tool.description
760
+
761
+ # Create the tool function according to configuration
762
+ tool_fn = mcp_per_user_tool_function(tool, client)
763
+
764
+ input_schema = tool_fn.input_schema
765
+ # Convert NoneType sentinel to None for FunctionGroup.add_function signature
766
+ if input_schema is type(None):
767
+ input_schema = None
768
+
769
+ # Add to group
770
+ logger.info("Adding tool %s to group", function_name)
771
+ group.add_function(name=function_name,
772
+ description=description,
773
+ fn=tool_fn.single_fn,
774
+ input_schema=input_schema,
775
+ converters=tool_fn.converters)
776
+
777
+ yield group
778
+ finally:
779
+ # Signal the lifetime task to exit and wait for clean shutdown
780
+ stop_event.set()
781
+ if not lifetime_task.done():
782
+ await lifetime_task