nvidia-nat-mcp 1.4.0a20251014__py3-none-any.whl → 1.5.0a20260115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-mcp might be problematic. Click here for more details.

Files changed (40)
  1. nat/meta/pypi.md +1 -1
  2. nat/plugins/mcp/__init__.py +1 -1
  3. nat/plugins/mcp/auth/__init__.py +1 -1
  4. nat/plugins/mcp/auth/auth_flow_handler.py +65 -1
  5. nat/plugins/mcp/auth/auth_provider.py +3 -2
  6. nat/plugins/mcp/auth/auth_provider_config.py +5 -2
  7. nat/plugins/mcp/auth/register.py +9 -1
  8. nat/plugins/mcp/auth/service_account/__init__.py +14 -0
  9. nat/plugins/mcp/auth/service_account/provider.py +136 -0
  10. nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
  11. nat/plugins/mcp/auth/service_account/token_client.py +156 -0
  12. nat/plugins/mcp/auth/token_storage.py +2 -2
  13. nat/plugins/mcp/cli/__init__.py +15 -0
  14. nat/plugins/mcp/cli/commands.py +1094 -0
  15. nat/plugins/mcp/client/__init__.py +15 -0
  16. nat/plugins/mcp/{client_base.py → client/client_base.py} +18 -10
  17. nat/plugins/mcp/{client_config.py → client/client_config.py} +24 -9
  18. nat/plugins/mcp/{client_impl.py → client/client_impl.py} +253 -62
  19. nat/plugins/mcp/exception_handler.py +1 -1
  20. nat/plugins/mcp/exceptions.py +1 -1
  21. nat/plugins/mcp/register.py +5 -4
  22. nat/plugins/mcp/server/__init__.py +15 -0
  23. nat/plugins/mcp/server/front_end_config.py +109 -0
  24. nat/plugins/mcp/server/front_end_plugin.py +155 -0
  25. nat/plugins/mcp/server/front_end_plugin_worker.py +415 -0
  26. nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
  27. nat/plugins/mcp/server/memory_profiler.py +320 -0
  28. nat/plugins/mcp/server/register_frontend.py +27 -0
  29. nat/plugins/mcp/server/tool_converter.py +290 -0
  30. nat/plugins/mcp/utils.py +153 -36
  31. {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/METADATA +5 -5
  32. nvidia_nat_mcp-1.5.0a20260115.dist-info/RECORD +37 -0
  33. nvidia_nat_mcp-1.5.0a20260115.dist-info/entry_points.txt +9 -0
  34. nat/plugins/mcp/tool.py +0 -138
  35. nvidia_nat_mcp-1.4.0a20251014.dist-info/RECORD +0 -23
  36. nvidia_nat_mcp-1.4.0a20251014.dist-info/entry_points.txt +0 -3
  37. {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/WHEEL +0 -0
  38. {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  39. {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/licenses/LICENSE.md +0 -0
  40. {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MCP client components."""
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -112,14 +112,21 @@ class AuthAdapter(httpx.Auth):
112
112
  # Use the user_id passed to this AuthAdapter instance
113
113
  auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
114
114
 
115
- # Check if we have BearerTokenCred
115
+ # Build headers from credentials
116
116
  from nat.data_models.authentication import BearerTokenCred
117
- if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
118
- token = auth_result.credentials[0].token.get_secret_value()
119
- return {"Authorization": f"Bearer {token}"}
120
- else:
121
- logger.info("Auth provider did not return BearerTokenCred")
122
- return {}
117
+ from nat.data_models.authentication import HeaderCred
118
+ headers = {}
119
+
120
+ for cred in auth_result.credentials:
121
+ if isinstance(cred, BearerTokenCred):
122
+ # Standard Bearer token
123
+ token = cred.token.get_secret_value()
124
+ headers["Authorization"] = f"Bearer {token}"
125
+ elif isinstance(cred, HeaderCred):
126
+ # Generic header credential (supports custom formats and service accounts)
127
+ headers[cred.name] = cred.value.get_secret_value()
128
+
129
+ return headers
123
130
  except Exception as e:
124
131
  logger.warning("Failed to get auth token: %s", e)
125
132
  return {}
@@ -162,8 +169,9 @@ class MCPBaseClient(ABC):
162
169
 
163
170
  # Convert auth provider to AuthAdapter
164
171
  self._auth_provider = auth_provider
165
- # Use provided user_id or fall back to auth provider's default_user_id
166
- effective_user_id = user_id or (auth_provider.config.default_user_id if auth_provider else None)
172
+ # Use provided user_id or fall back to auth provider's default_user_id (if available)
173
+ effective_user_id = user_id or (getattr(auth_provider.config, 'default_user_id', None)
174
+ if auth_provider else None)
167
175
  self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None
168
176
 
169
177
  self._tool_call_timeout = tool_call_timeout
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -80,9 +80,9 @@ class MCPServerConfig(BaseModel):
80
80
  return self
81
81
 
82
82
 
83
- class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
83
+ class MCPClientBaseConfig(FunctionGroupBaseConfig):
84
84
  """
85
- Configuration for connecting to an MCP server as a client and exposing selected tools.
85
+ Base configuration shared by MCP client variants.
86
86
  """
87
87
  server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
88
88
  tool_call_timeout: timedelta = Field(
@@ -114,6 +114,19 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
114
114
  calculator_multiply:
115
115
  description: "Multiply two numbers" # alias defaults to original name
116
116
  """)
117
+
118
+ @model_validator(mode="after")
119
+ def _validate_reconnect_backoff(self) -> "MCPClientBaseConfig":
120
+ """Validate reconnect backoff values."""
121
+ if self.reconnect_max_backoff < self.reconnect_initial_backoff:
122
+ raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
123
+ return self
124
+
125
+
126
+ class MCPClientConfig(MCPClientBaseConfig, name="mcp_client"):
127
+ """
128
+ Configuration for connecting to an MCP server as a client and exposing selected tools.
129
+ """
117
130
  session_aware_tools: bool = Field(default=True,
118
131
  description="Session-aware tools are created if True. Defaults to True.")
119
132
  max_sessions: int = Field(default=100,
@@ -123,9 +136,11 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
123
136
  default=timedelta(hours=1),
124
137
  description="Time after which inactive sessions are cleaned up. Defaults to 1 hour.")
125
138
 
126
- @model_validator(mode="after")
127
- def _validate_reconnect_backoff(self) -> "MCPClientConfig":
128
- """Validate reconnect backoff values."""
129
- if self.reconnect_max_backoff < self.reconnect_initial_backoff:
130
- raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
131
- return self
139
+
140
class PerUserMCPClientConfig(MCPClientBaseConfig, name="per_user_mcp_client"):
    """
    MCP client configuration for per-user workflows registered with ``@register_per_user_function``.

    Each user gets their own MCP client instance. All settings are inherited from ``MCPClientBaseConfig``;
    unlike ``MCPClientConfig`` this variant adds no session-management options (``session_aware_tools``,
    ``max_sessions``, ``session_idle_timeout``), since every user already has a dedicated client.
    """
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,16 +26,77 @@ from pydantic import BaseModel
26
26
 
27
27
  from nat.authentication.interfaces import AuthProviderBase
28
28
  from nat.builder.builder import Builder
29
+ from nat.builder.context import Context
29
30
  from nat.builder.function import FunctionGroup
30
31
  from nat.cli.register_workflow import register_function_group
31
- from nat.plugins.mcp.client_base import MCPBaseClient
32
- from nat.plugins.mcp.client_config import MCPClientConfig
33
- from nat.plugins.mcp.client_config import MCPToolOverrideConfig
32
+ from nat.cli.register_workflow import register_per_user_function_group
33
+ from nat.plugins.mcp.client.client_base import MCPBaseClient
34
+ from nat.plugins.mcp.client.client_config import MCPClientConfig
35
+ from nat.plugins.mcp.client.client_config import MCPToolOverrideConfig
36
+ from nat.plugins.mcp.client.client_config import PerUserMCPClientConfig
34
37
  from nat.plugins.mcp.utils import truncate_session_id
35
38
 
36
39
  logger = logging.getLogger(__name__)
37
40
 
38
41
 
42
class PerUserMCPFunctionGroup(FunctionGroup):
    """
    Function group backing a per-user MCP client.

    In addition to the regular ``FunctionGroup`` state, it records:

    - the id of the user the group was built for;
    - the live MCP client;
    - the server name and transport reported by that client.

    Every one of these starts out as ``None``. ``per_user_mcp_client_function_group`` fills them in
    once the client is connected.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Which user this group serves.
        self.user_id: str | None = None

        # The live client plus the connection metadata it reports.
        self.mcp_client: MCPBaseClient | None = None
        self.mcp_client_server_name: str | None = None
        self.mcp_client_transport: str | None = None
54
+
55
+
56
+ def mcp_per_user_tool_function(tool, client: MCPBaseClient):
57
+ """
58
+ Create a per-user NAT function for an MCP tool.
59
+
60
+ Args:
61
+ tool: The MCP tool to create a function for
62
+ client: The MCP client to use for the function
63
+
64
+ Returns:
65
+ The NAT function
66
+ """
67
+ from nat.builder.function import FunctionInfo
68
+
69
+ def _convert_from_str(input_str: str) -> tool.input_schema:
70
+ return tool.input_schema.model_validate_json(input_str)
71
+
72
+ async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
73
+ try:
74
+ mcp_tool = await client.get_tool(tool.name)
75
+
76
+ if tool_input:
77
+ args = tool_input.model_dump(exclude_none=True, mode='json')
78
+ return await mcp_tool.acall(args)
79
+
80
+ # kwargs arrives with all optional fields set to None because NAT's framework
81
+ # converts the input dict to a Pydantic model (filling in all Field(default=None)),
82
+ # then dumps it back to a dict. We need to strip out these None values because
83
+ # many MCP servers (e.g., Kaggle) reject requests with excessive null fields.
84
+ # We re-validate here (yes, redundant) to leverage Pydantic's exclude_none with
85
+ # mode='json' for recursive None removal in nested models.
86
+ # Reference: function_info.py:_convert_input_pydantic
87
+ validated_input = mcp_tool.input_schema.model_validate(kwargs)
88
+ args = validated_input.model_dump(exclude_none=True, mode='json')
89
+ return await mcp_tool.acall(args)
90
+ except Exception as e:
91
+ logger.warning("Error calling tool %s", tool.name, exc_info=True)
92
+ return str(e)
93
+
94
+ return FunctionInfo.create(single_fn=_response_fn,
95
+ description=tool.description,
96
+ input_schema=tool.input_schema,
97
+ converters=[_convert_from_str])
98
+
99
+
39
100
  @dataclass
40
101
  class SessionData:
41
102
  """Container for all session-related data."""
@@ -91,9 +152,9 @@ class MCPFunctionGroup(FunctionGroup):
91
152
  def __init__(self, *args, **kwargs):
92
153
  super().__init__(*args, **kwargs)
93
154
  # MCP client attributes with proper typing
94
- self._mcp_client = None # Will be set to the actual MCP client instance
95
- self._mcp_client_server_name: str | None = None
96
- self._mcp_client_transport: str | None = None
155
+ self.mcp_client: MCPBaseClient | None = None # Will be set to the actual MCP client instance
156
+ self.mcp_client_server_name: str | None = None
157
+ self.mcp_client_transport: str | None = None
97
158
 
98
159
  # Session management - consolidated data structure
99
160
  self._sessions: dict[str, SessionData] = {}
@@ -109,39 +170,13 @@ class MCPFunctionGroup(FunctionGroup):
109
170
  self._shared_auth_provider: AuthProviderBase | None = None
110
171
  self._client_config: MCPClientConfig | None = None
111
172
 
173
+ # Auth provider config defaults (set when auth provider is assigned)
174
+ self._default_user_id: str | None = None
175
+ self._allow_default_user_id_for_tool_calls: bool = True
176
+
112
177
  # Use random session id for testing only
113
178
  self._use_random_session_id_for_testing: bool = False
114
179
 
115
- @property
116
- def mcp_client(self):
117
- """Get the MCP client instance."""
118
- return self._mcp_client
119
-
120
- @mcp_client.setter
121
- def mcp_client(self, client):
122
- """Set the MCP client instance."""
123
- self._mcp_client = client
124
-
125
- @property
126
- def mcp_client_server_name(self) -> str | None:
127
- """Get the MCP client server name."""
128
- return self._mcp_client_server_name
129
-
130
- @mcp_client_server_name.setter
131
- def mcp_client_server_name(self, server_name: str | None):
132
- """Set the MCP client server name."""
133
- self._mcp_client_server_name = server_name
134
-
135
- @property
136
- def mcp_client_transport(self) -> str | None:
137
- """Get the MCP client transport type."""
138
- return self._mcp_client_transport
139
-
140
- @mcp_client_transport.setter
141
- def mcp_client_transport(self, transport: str | None):
142
- """Set the MCP client transport type."""
143
- self._mcp_client_transport = transport
144
-
145
180
  @property
146
181
  def session_count(self) -> int:
147
182
  """Current number of active sessions."""
@@ -176,9 +211,8 @@ class MCPFunctionGroup(FunctionGroup):
176
211
 
177
212
  if not session_id:
178
213
  # use default user id if allowed
179
- if self._shared_auth_provider and \
180
- self._shared_auth_provider.config.allow_default_user_id_for_tool_calls:
181
- session_id = self._shared_auth_provider.config.default_user_id
214
+ if self._shared_auth_provider and self._allow_default_user_id_for_tool_calls:
215
+ session_id = self._default_user_id
182
216
  return session_id
183
217
  except Exception:
184
218
  return None
@@ -255,7 +289,7 @@ class MCPFunctionGroup(FunctionGroup):
255
289
  except Exception as e:
256
290
  logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
257
291
 
258
- async def _get_session_client(self, session_id: str) -> MCPBaseClient:
292
+ async def _get_session_client(self, session_id: str) -> MCPBaseClient | None:
259
293
  """Get the appropriate MCP client for the session."""
260
294
  # Throttled cleanup on access
261
295
  now = datetime.now()
@@ -266,8 +300,7 @@ class MCPFunctionGroup(FunctionGroup):
266
300
  # If the session_id equals the configured default_user_id use the base client
267
301
  # instead of creating a per-session client
268
302
  if self._shared_auth_provider:
269
- default_uid = self._shared_auth_provider.config.default_user_id
270
- if default_uid and session_id == default_uid:
303
+ if self._default_user_id and session_id == self._default_user_id:
271
304
  return self.mcp_client
272
305
 
273
306
  # Fast path: check if session already exists (reader lock for concurrent access)
@@ -346,7 +379,7 @@ class MCPFunctionGroup(FunctionGroup):
346
379
 
347
380
  async def _create_session_client(self, session_id: str) -> tuple[MCPBaseClient, asyncio.Event, asyncio.Task]:
348
381
  """Create a new MCP client instance for the session."""
349
- from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
382
+ from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
350
383
 
351
384
  config = self._client_config
352
385
  if not config:
@@ -435,13 +468,16 @@ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
435
468
  return "User not authorized to call the tool"
436
469
 
437
470
  # Check if this is the default user - if so, use base client directly
438
- if (not function_group._shared_auth_provider
439
- or session_id == function_group._shared_auth_provider.config.default_user_id):
471
+ if (not function_group._shared_auth_provider or session_id == function_group._default_user_id):
440
472
  # Use base client directly for default user
441
473
  client = function_group.mcp_client
474
+ if client is None:
475
+ return "Tool temporarily unavailable. Try again."
442
476
  session_tool = await client.get_tool(tool.name)
443
477
  else:
444
478
  # Use session usage context to prevent cleanup during tool execution
479
+ if session_id is None:
480
+ return "Tool temporarily unavailable. Try again."
445
481
  async with function_group._session_usage_context(session_id) as client:
446
482
  if client is None:
447
483
  return "Tool temporarily unavailable. Try again."
@@ -449,11 +485,19 @@ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
449
485
 
450
486
  # Preserve original calling convention
451
487
  if tool_input:
452
- args = tool_input.model_dump()
488
+ args = tool_input.model_dump(exclude_none=True, mode='json')
453
489
  return await session_tool.acall(args)
454
490
 
455
- _ = session_tool.input_schema.model_validate(kwargs)
456
- return await session_tool.acall(kwargs)
491
+ # kwargs arrives with all optional fields set to None because NAT's framework
492
+ # converts the input dict to a Pydantic model (filling in all Field(default=None)),
493
+ # then dumps it back to a dict. We need to strip out these None values because
494
+ # many MCP servers (e.g., Kaggle) reject requests with excessive null fields.
495
+ # We re-validate here (yes, redundant) to leverage Pydantic's exclude_none with
496
+ # mode='json' for recursive None removal in nested models.
497
+ # Reference: function_info.py:_convert_input_pydantic
498
+ validated_input = session_tool.input_schema.model_validate(kwargs)
499
+ args = validated_input.model_dump(exclude_none=True, mode='json')
500
+ return await session_tool.acall(args)
457
501
  except Exception as e:
458
502
  logger.warning("Error calling tool %s", tool.name, exc_info=True)
459
503
  return str(e)
@@ -475,9 +519,9 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
475
519
  Returns:
476
520
  The function group
477
521
  """
478
- from nat.plugins.mcp.client_base import MCPSSEClient
479
- from nat.plugins.mcp.client_base import MCPStdioClient
480
- from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
522
+ from nat.plugins.mcp.client.client_base import MCPSSEClient
523
+ from nat.plugins.mcp.client.client_base import MCPStdioClient
524
+ from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
481
525
 
482
526
  # Resolve auth provider if specified
483
527
  auth_provider = None
@@ -507,7 +551,9 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
507
551
  reconnect_max_backoff=config.reconnect_max_backoff)
508
552
  elif config.server.transport == "streamable-http":
509
553
  # Use default_user_id for the base client
510
- base_user_id = auth_provider.config.default_user_id if auth_provider else None
554
+ # For interactive OAuth2: from config. For service accounts: defaults to server URL
555
+ base_user_id = getattr(auth_provider.config, 'default_user_id', str(
556
+ config.server.url)) if auth_provider else None
511
557
  client = MCPStreamableHTTPClient(str(config.server.url),
512
558
  auth_provider=auth_provider,
513
559
  user_id=base_user_id,
@@ -529,6 +575,18 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
529
575
  group._shared_auth_provider = auth_provider
530
576
  group._client_config = config
531
577
 
578
+ # Set auth provider config defaults
579
+ # For interactive OAuth2: use config values
580
+ # For service accounts: default_user_id = server URL, allow_default_user_id_for_tool_calls = True
581
+ if auth_provider:
582
+ group._default_user_id = getattr(auth_provider.config, 'default_user_id', str(config.server.url))
583
+ group._allow_default_user_id_for_tool_calls = getattr(auth_provider.config,
584
+ 'allow_default_user_id_for_tool_calls',
585
+ True)
586
+ else:
587
+ group._default_user_id = None
588
+ group._allow_default_user_id_for_tool_calls = True
589
+
532
590
  async with client:
533
591
  # Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
534
592
  # can reuse the already-established session instead of creating a new client per request.
@@ -551,23 +609,16 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
551
609
  # Create the tool function according to configuration
552
610
  tool_fn = mcp_session_tool_function(tool, group)
553
611
 
554
- # Normalize optional typing for linter/type-checker compatibility
555
- single_fn = tool_fn.single_fn
556
- if single_fn is None:
557
- # Should not happen because FunctionInfo always sets a single_fn
558
- logger.warning("Skipping tool %s because single_fn is None", function_name)
559
- continue
560
-
561
612
  input_schema = tool_fn.input_schema
562
613
  # Convert NoneType sentinel to None for FunctionGroup.add_function signature
563
- if input_schema is type(None): # noqa: E721
614
+ if input_schema is type(None):
564
615
  input_schema = None
565
616
 
566
617
  # Add to group
567
618
  logger.info("Adding tool %s to group", function_name)
568
619
  group.add_function(name=function_name,
569
620
  description=description,
570
- fn=single_fn,
621
+ fn=tool_fn.single_fn,
571
622
  input_schema=input_schema,
572
623
  converters=tool_fn.converters)
573
624
 
@@ -589,3 +640,143 @@ def mcp_apply_tool_alias_and_description(
589
640
  return {}
590
641
 
591
642
  return {name: override for name, override in tool_overrides.items() if name in all_tools}
643
+
644
+
645
+ @register_per_user_function_group(config_type=PerUserMCPClientConfig)
646
+ async def per_user_mcp_client_function_group(config: PerUserMCPClientConfig, _builder: Builder):
647
+ """
648
+ Connect to an MCP server and expose tools as a function group for per-user workflows.
649
+
650
+ Args:
651
+ config: The configuration for the MCP client
652
+ _builder: The builder
653
+ Returns:
654
+ The function group
655
+ """
656
+ from nat.plugins.mcp.client.client_base import MCPSSEClient
657
+ from nat.plugins.mcp.client.client_base import MCPStdioClient
658
+ from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
659
+
660
+ # Resolve auth provider if specified
661
+ auth_provider = None
662
+ if config.server.auth_provider:
663
+ auth_provider = await _builder.get_auth_provider(config.server.auth_provider)
664
+
665
+ user_id = Context.get().user_id
666
+
667
+ # Build the appropriate client
668
+ if config.server.transport == "stdio":
669
+ if not config.server.command:
670
+ raise ValueError("command is required for stdio transport")
671
+ client = MCPStdioClient(config.server.command,
672
+ config.server.args,
673
+ config.server.env,
674
+ tool_call_timeout=config.tool_call_timeout,
675
+ auth_flow_timeout=config.auth_flow_timeout,
676
+ reconnect_enabled=config.reconnect_enabled,
677
+ reconnect_max_attempts=config.reconnect_max_attempts,
678
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
679
+ reconnect_max_backoff=config.reconnect_max_backoff)
680
+ elif config.server.transport == "sse":
681
+ client = MCPSSEClient(str(config.server.url),
682
+ tool_call_timeout=config.tool_call_timeout,
683
+ auth_flow_timeout=config.auth_flow_timeout,
684
+ reconnect_enabled=config.reconnect_enabled,
685
+ reconnect_max_attempts=config.reconnect_max_attempts,
686
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
687
+ reconnect_max_backoff=config.reconnect_max_backoff)
688
+ elif config.server.transport == "streamable-http":
689
+ client = MCPStreamableHTTPClient(str(config.server.url),
690
+ auth_provider=auth_provider,
691
+ user_id=user_id,
692
+ tool_call_timeout=config.tool_call_timeout,
693
+ auth_flow_timeout=config.auth_flow_timeout,
694
+ reconnect_enabled=config.reconnect_enabled,
695
+ reconnect_max_attempts=config.reconnect_max_attempts,
696
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
697
+ reconnect_max_backoff=config.reconnect_max_backoff)
698
+ else:
699
+ raise ValueError(f"Unsupported transport: {config.server.transport}")
700
+
701
+ logger.info("Per-user MCP client configured for server: %s (user: %s)", client.server_name, user_id)
702
+
703
+ group = PerUserMCPFunctionGroup(config=config)
704
+
705
+ # Use a lifetime task to ensure the client context is entered and exited in the same task.
706
+ # This avoids anyio's "Attempted to exit cancel scope in a different task" error.
707
+ ready = asyncio.Event()
708
+ stop_event = asyncio.Event()
709
+
710
+ async def _lifetime():
711
+ """Lifetime task that owns the client's async context."""
712
+ try:
713
+ async with client:
714
+ ready.set()
715
+ await stop_event.wait()
716
+ except Exception:
717
+ ready.set() # Ensure we don't hang the waiter
718
+ raise
719
+
720
+ lifetime_task = asyncio.create_task(_lifetime(), name=f"mcp-per-user-{user_id}")
721
+
722
+ # Wait for client initialization
723
+ timeout = config.tool_call_timeout.total_seconds()
724
+ try:
725
+ await asyncio.wait_for(ready.wait(), timeout=timeout)
726
+ except TimeoutError:
727
+ lifetime_task.cancel()
728
+ try:
729
+ await lifetime_task
730
+ except asyncio.CancelledError:
731
+ pass
732
+ raise RuntimeError(f"Per-user MCP client initialization timed out after {timeout}s")
733
+
734
+ # Check if initialization failed
735
+ if lifetime_task.done():
736
+ try:
737
+ await lifetime_task
738
+ except Exception as e:
739
+ raise RuntimeError(f"Failed to initialize per-user MCP client: {e}") from e
740
+
741
+ try:
742
+ # Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
743
+ # can reuse the already-established session instead of creating a new client per request.
744
+ group.mcp_client = client
745
+ group.mcp_client_server_name = client.server_name
746
+ group.mcp_client_transport = client.transport
747
+ group.user_id = user_id
748
+
749
+ all_tools = await client.get_tools()
750
+ tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
751
+
752
+ # Add each tool as a function to the group
753
+ for tool_name, tool in all_tools.items():
754
+ # Get override if it exists
755
+ override = tool_overrides.get(tool_name)
756
+
757
+ # Use override values or defaults
758
+ function_name = override.alias if override and override.alias else tool_name
759
+ description = override.description if override and override.description else tool.description
760
+
761
+ # Create the tool function according to configuration
762
+ tool_fn = mcp_per_user_tool_function(tool, client)
763
+
764
+ input_schema = tool_fn.input_schema
765
+ # Convert NoneType sentinel to None for FunctionGroup.add_function signature
766
+ if input_schema is type(None):
767
+ input_schema = None
768
+
769
+ # Add to group
770
+ logger.info("Adding tool %s to group", function_name)
771
+ group.add_function(name=function_name,
772
+ description=description,
773
+ fn=tool_fn.single_fn,
774
+ input_schema=input_schema,
775
+ converters=tool_fn.converters)
776
+
777
+ yield group
778
+ finally:
779
+ # Signal the lifetime task to exit and wait for clean shutdown
780
+ stop_event.set()
781
+ if not lifetime_task.done():
782
+ await lifetime_task
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +16,8 @@
16
16
  # flake8: noqa
17
17
  # isort:skip_file
18
18
 
19
- # Import any providers which need to be automatically registered here
19
+ # Register client components
20
+ from .client import client_impl
20
21
 
21
- from . import client_impl
22
- from . import tool
22
+ # Register server/frontend components
23
+ from .server import register_frontend
@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MCP server/frontend components."""