nvidia-nat-mcp 1.4.0a20260105__py3-none-any.whl → 1.4.0a20260117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-mcp might be problematic. Click here for more details.
- nat/plugins/mcp/auth/token_storage.py +1 -1
- nat/plugins/mcp/cli/__init__.py +15 -0
- nat/plugins/mcp/cli/commands.py +1055 -0
- nat/plugins/mcp/client/__init__.py +15 -0
- nat/plugins/mcp/{client_config.py → client/client_config.py} +23 -8
- nat/plugins/mcp/{client_impl.py → client/client_impl.py} +218 -50
- nat/plugins/mcp/register.py +4 -3
- nat/plugins/mcp/server/__init__.py +15 -0
- nat/plugins/mcp/server/front_end_config.py +109 -0
- nat/plugins/mcp/server/front_end_plugin.py +155 -0
- nat/plugins/mcp/server/front_end_plugin_worker.py +415 -0
- nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
- nat/plugins/mcp/server/memory_profiler.py +320 -0
- nat/plugins/mcp/server/register_frontend.py +27 -0
- nat/plugins/mcp/server/tool_converter.py +290 -0
- {nvidia_nat_mcp-1.4.0a20260105.dist-info → nvidia_nat_mcp-1.4.0a20260117.dist-info}/METADATA +3 -3
- nvidia_nat_mcp-1.4.0a20260117.dist-info/RECORD +37 -0
- nvidia_nat_mcp-1.4.0a20260117.dist-info/entry_points.txt +9 -0
- nat/plugins/mcp/tool.py +0 -138
- nvidia_nat_mcp-1.4.0a20260105.dist-info/RECORD +0 -27
- nvidia_nat_mcp-1.4.0a20260105.dist-info/entry_points.txt +0 -3
- /nat/plugins/mcp/{client_base.py → client/client_base.py} +0 -0
- {nvidia_nat_mcp-1.4.0a20260105.dist-info → nvidia_nat_mcp-1.4.0a20260117.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.4.0a20260105.dist-info → nvidia_nat_mcp-1.4.0a20260117.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat_mcp-1.4.0a20260105.dist-info → nvidia_nat_mcp-1.4.0a20260117.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat_mcp-1.4.0a20260105.dist-info → nvidia_nat_mcp-1.4.0a20260117.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""MCP client components."""
|
|
@@ -80,9 +80,9 @@ class MCPServerConfig(BaseModel):
|
|
|
80
80
|
return self
|
|
81
81
|
|
|
82
82
|
|
|
83
|
-
class
|
|
83
|
+
class MCPClientBaseConfig(FunctionGroupBaseConfig):
|
|
84
84
|
"""
|
|
85
|
-
|
|
85
|
+
Base configuration shared by MCP client variants.
|
|
86
86
|
"""
|
|
87
87
|
server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
|
|
88
88
|
tool_call_timeout: timedelta = Field(
|
|
@@ -114,6 +114,19 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
|
|
114
114
|
calculator_multiply:
|
|
115
115
|
description: "Multiply two numbers" # alias defaults to original name
|
|
116
116
|
""")
|
|
117
|
+
|
|
118
|
+
@model_validator(mode="after")
|
|
119
|
+
def _validate_reconnect_backoff(self) -> "MCPClientBaseConfig":
|
|
120
|
+
"""Validate reconnect backoff values."""
|
|
121
|
+
if self.reconnect_max_backoff < self.reconnect_initial_backoff:
|
|
122
|
+
raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
|
|
123
|
+
return self
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class MCPClientConfig(MCPClientBaseConfig, name="mcp_client"):
|
|
127
|
+
"""
|
|
128
|
+
Configuration for connecting to an MCP server as a client and exposing selected tools.
|
|
129
|
+
"""
|
|
117
130
|
session_aware_tools: bool = Field(default=True,
|
|
118
131
|
description="Session-aware tools are created if True. Defaults to True.")
|
|
119
132
|
max_sessions: int = Field(default=100,
|
|
@@ -123,9 +136,11 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
|
|
123
136
|
default=timedelta(hours=1),
|
|
124
137
|
description="Time after which inactive sessions are cleaned up. Defaults to 1 hour.")
|
|
125
138
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
139
|
+
|
|
140
|
+
class PerUserMCPClientConfig(MCPClientBaseConfig, name="per_user_mcp_client"):
|
|
141
|
+
"""
|
|
142
|
+
MCP Client configuration for per-user workflows that are registered with @register_per_user_function,
|
|
143
|
+
|
|
144
|
+
and each user gets their own MCP client instance.
|
|
145
|
+
"""
|
|
146
|
+
pass
|
|
@@ -26,16 +26,77 @@ from pydantic import BaseModel
|
|
|
26
26
|
|
|
27
27
|
from nat.authentication.interfaces import AuthProviderBase
|
|
28
28
|
from nat.builder.builder import Builder
|
|
29
|
+
from nat.builder.context import Context
|
|
29
30
|
from nat.builder.function import FunctionGroup
|
|
30
31
|
from nat.cli.register_workflow import register_function_group
|
|
31
|
-
from nat.
|
|
32
|
-
from nat.plugins.mcp.
|
|
33
|
-
from nat.plugins.mcp.client_config import
|
|
32
|
+
from nat.cli.register_workflow import register_per_user_function_group
|
|
33
|
+
from nat.plugins.mcp.client.client_base import MCPBaseClient
|
|
34
|
+
from nat.plugins.mcp.client.client_config import MCPClientConfig
|
|
35
|
+
from nat.plugins.mcp.client.client_config import MCPToolOverrideConfig
|
|
36
|
+
from nat.plugins.mcp.client.client_config import PerUserMCPClientConfig
|
|
34
37
|
from nat.plugins.mcp.utils import truncate_session_id
|
|
35
38
|
|
|
36
39
|
logger = logging.getLogger(__name__)
|
|
37
40
|
|
|
38
41
|
|
|
42
|
+
class PerUserMCPFunctionGroup(FunctionGroup):
|
|
43
|
+
"""
|
|
44
|
+
A specialized FunctionGroup for per-user MCP clients.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
def __init__(self, *args, **kwargs):
|
|
48
|
+
super().__init__(*args, **kwargs)
|
|
49
|
+
|
|
50
|
+
self.mcp_client: MCPBaseClient | None = None # Will be set to the actual MCP client instance
|
|
51
|
+
self.mcp_client_server_name: str | None = None
|
|
52
|
+
self.mcp_client_transport: str | None = None
|
|
53
|
+
self.user_id: str | None = None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def mcp_per_user_tool_function(tool, client: MCPBaseClient):
|
|
57
|
+
"""
|
|
58
|
+
Create a per-user NAT function for an MCP tool.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
tool: The MCP tool to create a function for
|
|
62
|
+
client: The MCP client to use for the function
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
The NAT function
|
|
66
|
+
"""
|
|
67
|
+
from nat.builder.function import FunctionInfo
|
|
68
|
+
|
|
69
|
+
def _convert_from_str(input_str: str) -> tool.input_schema:
|
|
70
|
+
return tool.input_schema.model_validate_json(input_str)
|
|
71
|
+
|
|
72
|
+
async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
|
|
73
|
+
try:
|
|
74
|
+
mcp_tool = await client.get_tool(tool.name)
|
|
75
|
+
|
|
76
|
+
if tool_input:
|
|
77
|
+
args = tool_input.model_dump(exclude_none=True, mode='json')
|
|
78
|
+
return await mcp_tool.acall(args)
|
|
79
|
+
|
|
80
|
+
# kwargs arrives with all optional fields set to None because NAT's framework
|
|
81
|
+
# converts the input dict to a Pydantic model (filling in all Field(default=None)),
|
|
82
|
+
# then dumps it back to a dict. We need to strip out these None values because
|
|
83
|
+
# many MCP servers (e.g., Kaggle) reject requests with excessive null fields.
|
|
84
|
+
# We re-validate here (yes, redundant) to leverage Pydantic's exclude_none with
|
|
85
|
+
# mode='json' for recursive None removal in nested models.
|
|
86
|
+
# Reference: function_info.py:_convert_input_pydantic
|
|
87
|
+
validated_input = mcp_tool.input_schema.model_validate(kwargs)
|
|
88
|
+
args = validated_input.model_dump(exclude_none=True, mode='json')
|
|
89
|
+
return await mcp_tool.acall(args)
|
|
90
|
+
except Exception as e:
|
|
91
|
+
logger.warning("Error calling tool %s", tool.name, exc_info=True)
|
|
92
|
+
return str(e)
|
|
93
|
+
|
|
94
|
+
return FunctionInfo.create(single_fn=_response_fn,
|
|
95
|
+
description=tool.description,
|
|
96
|
+
input_schema=tool.input_schema,
|
|
97
|
+
converters=[_convert_from_str])
|
|
98
|
+
|
|
99
|
+
|
|
39
100
|
@dataclass
|
|
40
101
|
class SessionData:
|
|
41
102
|
"""Container for all session-related data."""
|
|
@@ -91,9 +152,9 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
91
152
|
def __init__(self, *args, **kwargs):
|
|
92
153
|
super().__init__(*args, **kwargs)
|
|
93
154
|
# MCP client attributes with proper typing
|
|
94
|
-
self.
|
|
95
|
-
self.
|
|
96
|
-
self.
|
|
155
|
+
self.mcp_client: MCPBaseClient | None = None # Will be set to the actual MCP client instance
|
|
156
|
+
self.mcp_client_server_name: str | None = None
|
|
157
|
+
self.mcp_client_transport: str | None = None
|
|
97
158
|
|
|
98
159
|
# Session management - consolidated data structure
|
|
99
160
|
self._sessions: dict[str, SessionData] = {}
|
|
@@ -116,36 +177,6 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
116
177
|
# Use random session id for testing only
|
|
117
178
|
self._use_random_session_id_for_testing: bool = False
|
|
118
179
|
|
|
119
|
-
@property
|
|
120
|
-
def mcp_client(self):
|
|
121
|
-
"""Get the MCP client instance."""
|
|
122
|
-
return self._mcp_client
|
|
123
|
-
|
|
124
|
-
@mcp_client.setter
|
|
125
|
-
def mcp_client(self, client):
|
|
126
|
-
"""Set the MCP client instance."""
|
|
127
|
-
self._mcp_client = client
|
|
128
|
-
|
|
129
|
-
@property
|
|
130
|
-
def mcp_client_server_name(self) -> str | None:
|
|
131
|
-
"""Get the MCP client server name."""
|
|
132
|
-
return self._mcp_client_server_name
|
|
133
|
-
|
|
134
|
-
@mcp_client_server_name.setter
|
|
135
|
-
def mcp_client_server_name(self, server_name: str | None):
|
|
136
|
-
"""Set the MCP client server name."""
|
|
137
|
-
self._mcp_client_server_name = server_name
|
|
138
|
-
|
|
139
|
-
@property
|
|
140
|
-
def mcp_client_transport(self) -> str | None:
|
|
141
|
-
"""Get the MCP client transport type."""
|
|
142
|
-
return self._mcp_client_transport
|
|
143
|
-
|
|
144
|
-
@mcp_client_transport.setter
|
|
145
|
-
def mcp_client_transport(self, transport: str | None):
|
|
146
|
-
"""Set the MCP client transport type."""
|
|
147
|
-
self._mcp_client_transport = transport
|
|
148
|
-
|
|
149
180
|
@property
|
|
150
181
|
def session_count(self) -> int:
|
|
151
182
|
"""Current number of active sessions."""
|
|
@@ -258,7 +289,7 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
258
289
|
except Exception as e:
|
|
259
290
|
logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
|
|
260
291
|
|
|
261
|
-
async def _get_session_client(self, session_id: str) -> MCPBaseClient:
|
|
292
|
+
async def _get_session_client(self, session_id: str) -> MCPBaseClient | None:
|
|
262
293
|
"""Get the appropriate MCP client for the session."""
|
|
263
294
|
# Throttled cleanup on access
|
|
264
295
|
now = datetime.now()
|
|
@@ -348,7 +379,7 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
348
379
|
|
|
349
380
|
async def _create_session_client(self, session_id: str) -> tuple[MCPBaseClient, asyncio.Event, asyncio.Task]:
|
|
350
381
|
"""Create a new MCP client instance for the session."""
|
|
351
|
-
from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
|
|
382
|
+
from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
|
|
352
383
|
|
|
353
384
|
config = self._client_config
|
|
354
385
|
if not config:
|
|
@@ -440,9 +471,13 @@ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
|
|
|
440
471
|
if (not function_group._shared_auth_provider or session_id == function_group._default_user_id):
|
|
441
472
|
# Use base client directly for default user
|
|
442
473
|
client = function_group.mcp_client
|
|
474
|
+
if client is None:
|
|
475
|
+
return "Tool temporarily unavailable. Try again."
|
|
443
476
|
session_tool = await client.get_tool(tool.name)
|
|
444
477
|
else:
|
|
445
478
|
# Use session usage context to prevent cleanup during tool execution
|
|
479
|
+
if session_id is None:
|
|
480
|
+
return "Tool temporarily unavailable. Try again."
|
|
446
481
|
async with function_group._session_usage_context(session_id) as client:
|
|
447
482
|
if client is None:
|
|
448
483
|
return "Tool temporarily unavailable. Try again."
|
|
@@ -484,9 +519,9 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
|
484
519
|
Returns:
|
|
485
520
|
The function group
|
|
486
521
|
"""
|
|
487
|
-
from nat.plugins.mcp.client_base import MCPSSEClient
|
|
488
|
-
from nat.plugins.mcp.client_base import MCPStdioClient
|
|
489
|
-
from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
|
|
522
|
+
from nat.plugins.mcp.client.client_base import MCPSSEClient
|
|
523
|
+
from nat.plugins.mcp.client.client_base import MCPStdioClient
|
|
524
|
+
from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
|
|
490
525
|
|
|
491
526
|
# Resolve auth provider if specified
|
|
492
527
|
auth_provider = None
|
|
@@ -574,23 +609,16 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
|
574
609
|
# Create the tool function according to configuration
|
|
575
610
|
tool_fn = mcp_session_tool_function(tool, group)
|
|
576
611
|
|
|
577
|
-
# Normalize optional typing for linter/type-checker compatibility
|
|
578
|
-
single_fn = tool_fn.single_fn
|
|
579
|
-
if single_fn is None:
|
|
580
|
-
# Should not happen because FunctionInfo always sets a single_fn
|
|
581
|
-
logger.warning("Skipping tool %s because single_fn is None", function_name)
|
|
582
|
-
continue
|
|
583
|
-
|
|
584
612
|
input_schema = tool_fn.input_schema
|
|
585
613
|
# Convert NoneType sentinel to None for FunctionGroup.add_function signature
|
|
586
|
-
if input_schema is type(None):
|
|
614
|
+
if input_schema is type(None):
|
|
587
615
|
input_schema = None
|
|
588
616
|
|
|
589
617
|
# Add to group
|
|
590
618
|
logger.info("Adding tool %s to group", function_name)
|
|
591
619
|
group.add_function(name=function_name,
|
|
592
620
|
description=description,
|
|
593
|
-
fn=single_fn,
|
|
621
|
+
fn=tool_fn.single_fn,
|
|
594
622
|
input_schema=input_schema,
|
|
595
623
|
converters=tool_fn.converters)
|
|
596
624
|
|
|
@@ -612,3 +640,143 @@ def mcp_apply_tool_alias_and_description(
|
|
|
612
640
|
return {}
|
|
613
641
|
|
|
614
642
|
return {name: override for name, override in tool_overrides.items() if name in all_tools}
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
@register_per_user_function_group(config_type=PerUserMCPClientConfig)
|
|
646
|
+
async def per_user_mcp_client_function_group(config: PerUserMCPClientConfig, _builder: Builder):
|
|
647
|
+
"""
|
|
648
|
+
Connect to an MCP server and expose tools as a function group for per-user workflows.
|
|
649
|
+
|
|
650
|
+
Args:
|
|
651
|
+
config: The configuration for the MCP client
|
|
652
|
+
_builder: The builder
|
|
653
|
+
Returns:
|
|
654
|
+
The function group
|
|
655
|
+
"""
|
|
656
|
+
from nat.plugins.mcp.client.client_base import MCPSSEClient
|
|
657
|
+
from nat.plugins.mcp.client.client_base import MCPStdioClient
|
|
658
|
+
from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
|
|
659
|
+
|
|
660
|
+
# Resolve auth provider if specified
|
|
661
|
+
auth_provider = None
|
|
662
|
+
if config.server.auth_provider:
|
|
663
|
+
auth_provider = await _builder.get_auth_provider(config.server.auth_provider)
|
|
664
|
+
|
|
665
|
+
user_id = Context.get().user_id
|
|
666
|
+
|
|
667
|
+
# Build the appropriate client
|
|
668
|
+
if config.server.transport == "stdio":
|
|
669
|
+
if not config.server.command:
|
|
670
|
+
raise ValueError("command is required for stdio transport")
|
|
671
|
+
client = MCPStdioClient(config.server.command,
|
|
672
|
+
config.server.args,
|
|
673
|
+
config.server.env,
|
|
674
|
+
tool_call_timeout=config.tool_call_timeout,
|
|
675
|
+
auth_flow_timeout=config.auth_flow_timeout,
|
|
676
|
+
reconnect_enabled=config.reconnect_enabled,
|
|
677
|
+
reconnect_max_attempts=config.reconnect_max_attempts,
|
|
678
|
+
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
|
679
|
+
reconnect_max_backoff=config.reconnect_max_backoff)
|
|
680
|
+
elif config.server.transport == "sse":
|
|
681
|
+
client = MCPSSEClient(str(config.server.url),
|
|
682
|
+
tool_call_timeout=config.tool_call_timeout,
|
|
683
|
+
auth_flow_timeout=config.auth_flow_timeout,
|
|
684
|
+
reconnect_enabled=config.reconnect_enabled,
|
|
685
|
+
reconnect_max_attempts=config.reconnect_max_attempts,
|
|
686
|
+
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
|
687
|
+
reconnect_max_backoff=config.reconnect_max_backoff)
|
|
688
|
+
elif config.server.transport == "streamable-http":
|
|
689
|
+
client = MCPStreamableHTTPClient(str(config.server.url),
|
|
690
|
+
auth_provider=auth_provider,
|
|
691
|
+
user_id=user_id,
|
|
692
|
+
tool_call_timeout=config.tool_call_timeout,
|
|
693
|
+
auth_flow_timeout=config.auth_flow_timeout,
|
|
694
|
+
reconnect_enabled=config.reconnect_enabled,
|
|
695
|
+
reconnect_max_attempts=config.reconnect_max_attempts,
|
|
696
|
+
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
|
697
|
+
reconnect_max_backoff=config.reconnect_max_backoff)
|
|
698
|
+
else:
|
|
699
|
+
raise ValueError(f"Unsupported transport: {config.server.transport}")
|
|
700
|
+
|
|
701
|
+
logger.info("Per-user MCP client configured for server: %s (user: %s)", client.server_name, user_id)
|
|
702
|
+
|
|
703
|
+
group = PerUserMCPFunctionGroup(config=config)
|
|
704
|
+
|
|
705
|
+
# Use a lifetime task to ensure the client context is entered and exited in the same task.
|
|
706
|
+
# This avoids anyio's "Attempted to exit cancel scope in a different task" error.
|
|
707
|
+
ready = asyncio.Event()
|
|
708
|
+
stop_event = asyncio.Event()
|
|
709
|
+
|
|
710
|
+
async def _lifetime():
|
|
711
|
+
"""Lifetime task that owns the client's async context."""
|
|
712
|
+
try:
|
|
713
|
+
async with client:
|
|
714
|
+
ready.set()
|
|
715
|
+
await stop_event.wait()
|
|
716
|
+
except Exception:
|
|
717
|
+
ready.set() # Ensure we don't hang the waiter
|
|
718
|
+
raise
|
|
719
|
+
|
|
720
|
+
lifetime_task = asyncio.create_task(_lifetime(), name=f"mcp-per-user-{user_id}")
|
|
721
|
+
|
|
722
|
+
# Wait for client initialization
|
|
723
|
+
timeout = config.tool_call_timeout.total_seconds()
|
|
724
|
+
try:
|
|
725
|
+
await asyncio.wait_for(ready.wait(), timeout=timeout)
|
|
726
|
+
except TimeoutError:
|
|
727
|
+
lifetime_task.cancel()
|
|
728
|
+
try:
|
|
729
|
+
await lifetime_task
|
|
730
|
+
except asyncio.CancelledError:
|
|
731
|
+
pass
|
|
732
|
+
raise RuntimeError(f"Per-user MCP client initialization timed out after {timeout}s")
|
|
733
|
+
|
|
734
|
+
# Check if initialization failed
|
|
735
|
+
if lifetime_task.done():
|
|
736
|
+
try:
|
|
737
|
+
await lifetime_task
|
|
738
|
+
except Exception as e:
|
|
739
|
+
raise RuntimeError(f"Failed to initialize per-user MCP client: {e}") from e
|
|
740
|
+
|
|
741
|
+
try:
|
|
742
|
+
# Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
|
|
743
|
+
# can reuse the already-established session instead of creating a new client per request.
|
|
744
|
+
group.mcp_client = client
|
|
745
|
+
group.mcp_client_server_name = client.server_name
|
|
746
|
+
group.mcp_client_transport = client.transport
|
|
747
|
+
group.user_id = user_id
|
|
748
|
+
|
|
749
|
+
all_tools = await client.get_tools()
|
|
750
|
+
tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
|
|
751
|
+
|
|
752
|
+
# Add each tool as a function to the group
|
|
753
|
+
for tool_name, tool in all_tools.items():
|
|
754
|
+
# Get override if it exists
|
|
755
|
+
override = tool_overrides.get(tool_name)
|
|
756
|
+
|
|
757
|
+
# Use override values or defaults
|
|
758
|
+
function_name = override.alias if override and override.alias else tool_name
|
|
759
|
+
description = override.description if override and override.description else tool.description
|
|
760
|
+
|
|
761
|
+
# Create the tool function according to configuration
|
|
762
|
+
tool_fn = mcp_per_user_tool_function(tool, client)
|
|
763
|
+
|
|
764
|
+
input_schema = tool_fn.input_schema
|
|
765
|
+
# Convert NoneType sentinel to None for FunctionGroup.add_function signature
|
|
766
|
+
if input_schema is type(None):
|
|
767
|
+
input_schema = None
|
|
768
|
+
|
|
769
|
+
# Add to group
|
|
770
|
+
logger.info("Adding tool %s to group", function_name)
|
|
771
|
+
group.add_function(name=function_name,
|
|
772
|
+
description=description,
|
|
773
|
+
fn=tool_fn.single_fn,
|
|
774
|
+
input_schema=input_schema,
|
|
775
|
+
converters=tool_fn.converters)
|
|
776
|
+
|
|
777
|
+
yield group
|
|
778
|
+
finally:
|
|
779
|
+
# Signal the lifetime task to exit and wait for clean shutdown
|
|
780
|
+
stop_event.set()
|
|
781
|
+
if not lifetime_task.done():
|
|
782
|
+
await lifetime_task
|
nat/plugins/mcp/register.py
CHANGED
|
@@ -16,7 +16,8 @@
|
|
|
16
16
|
# flake8: noqa
|
|
17
17
|
# isort:skip_file
|
|
18
18
|
|
|
19
|
-
#
|
|
19
|
+
# Register client components
|
|
20
|
+
from .client import client_impl
|
|
20
21
|
|
|
21
|
-
|
|
22
|
-
from . import
|
|
22
|
+
# Register server/frontend components
|
|
23
|
+
from .server import register_frontend
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""MCP server/frontend components."""
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Literal
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic import field_validator
|
|
21
|
+
from pydantic import model_validator
|
|
22
|
+
|
|
23
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
24
|
+
from nat.data_models.front_end import FrontEndBaseConfig
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
30
|
+
"""MCP front end configuration.
|
|
31
|
+
|
|
32
|
+
A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
name: str = Field(default="NeMo Agent Toolkit MCP",
|
|
36
|
+
description="Name of the MCP server (default: NeMo Agent Toolkit MCP)")
|
|
37
|
+
host: str = Field(default="localhost", description="Host to bind the server to (default: localhost)")
|
|
38
|
+
port: int = Field(default=9901, description="Port to bind the server to (default: 9901)", ge=0, le=65535)
|
|
39
|
+
debug: bool = Field(default=False, description="Enable debug mode (default: False)")
|
|
40
|
+
log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
|
|
41
|
+
tool_names: list[str] = Field(
|
|
42
|
+
default_factory=list,
|
|
43
|
+
description="The list of tools MCP server will expose (default: all tools)."
|
|
44
|
+
"Tool names can be functions or function groups",
|
|
45
|
+
)
|
|
46
|
+
transport: Literal["sse", "streamable-http"] = Field(
|
|
47
|
+
default="streamable-http",
|
|
48
|
+
description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
|
|
49
|
+
runner_class: str | None = Field(
|
|
50
|
+
default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
|
|
51
|
+
base_path: str | None = Field(default=None,
|
|
52
|
+
description="Base path to mount the MCP server at (e.g., '/api/v1'). "
|
|
53
|
+
"If specified, the server will be accessible at http://host:port{base_path}/mcp. "
|
|
54
|
+
"If None, server runs at root path /mcp.")
|
|
55
|
+
|
|
56
|
+
server_auth: OAuth2ResourceServerConfig | None = Field(
|
|
57
|
+
default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
|
|
58
|
+
|
|
59
|
+
@field_validator('base_path')
|
|
60
|
+
@classmethod
|
|
61
|
+
def validate_base_path(cls, v: str | None) -> str | None:
|
|
62
|
+
"""Validate that base_path starts with '/' and doesn't end with '/'."""
|
|
63
|
+
if v is not None:
|
|
64
|
+
if not v.startswith('/'):
|
|
65
|
+
raise ValueError("base_path must start with '/'")
|
|
66
|
+
if v.endswith('/'):
|
|
67
|
+
raise ValueError("base_path must not end with '/'")
|
|
68
|
+
return v
|
|
69
|
+
|
|
70
|
+
# Memory profiling configuration
|
|
71
|
+
enable_memory_profiling: bool = Field(default=False,
|
|
72
|
+
description="Enable memory profiling and diagnostics (default: False)")
|
|
73
|
+
memory_profile_interval: int = Field(default=50,
|
|
74
|
+
description="Log memory stats every N requests (default: 50)",
|
|
75
|
+
ge=1)
|
|
76
|
+
memory_profile_top_n: int = Field(default=10,
|
|
77
|
+
description="Number of top memory allocations to log (default: 10)",
|
|
78
|
+
ge=1,
|
|
79
|
+
le=50)
|
|
80
|
+
memory_profile_log_level: str = Field(default="DEBUG",
|
|
81
|
+
description="Log level for memory profiling output (default: DEBUG)")
|
|
82
|
+
|
|
83
|
+
@model_validator(mode="after")
|
|
84
|
+
def validate_security_configuration(self):
|
|
85
|
+
"""Validate security configuration to prevent accidental misconfigurations."""
|
|
86
|
+
# Check if server is bound to a non-localhost interface without authentication
|
|
87
|
+
localhost_hosts = {"localhost", "127.0.0.1", "::1"}
|
|
88
|
+
if self.host not in localhost_hosts and self.server_auth is None:
|
|
89
|
+
logger.warning(
|
|
90
|
+
"MCP server is configured to bind to '%s' without authentication. "
|
|
91
|
+
"This may expose your server to unauthorized access. "
|
|
92
|
+
"Consider either: (1) binding to localhost for local-only access, "
|
|
93
|
+
"or (2) configuring server_auth for production deployments on public interfaces.",
|
|
94
|
+
self.host)
|
|
95
|
+
|
|
96
|
+
# Check if SSE transport is used (which doesn't support authentication)
|
|
97
|
+
if self.transport == "sse":
|
|
98
|
+
if self.server_auth is not None:
|
|
99
|
+
logger.warning("SSE transport does not support authentication. "
|
|
100
|
+
"The configured server_auth will be ignored. "
|
|
101
|
+
"For production use with authentication, use 'streamable-http' transport instead.")
|
|
102
|
+
elif self.host not in localhost_hosts:
|
|
103
|
+
logger.warning(
|
|
104
|
+
"SSE transport does not support authentication and is bound to '%s'. "
|
|
105
|
+
"This configuration is not recommended for production use. "
|
|
106
|
+
"For production deployments, use 'streamable-http' transport with server_auth configured.",
|
|
107
|
+
self.host)
|
|
108
|
+
|
|
109
|
+
return self
|