nvidia-nat-mcp 1.4.0a20260117__py3-none-any.whl → 1.5.0a20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-mcp might be problematic. Click here for more details.
- nat/meta/pypi.md +1 -1
- nat/plugins/mcp/__init__.py +1 -1
- nat/plugins/mcp/auth/__init__.py +1 -1
- nat/plugins/mcp/auth/auth_flow_handler.py +1 -1
- nat/plugins/mcp/auth/auth_provider.py +1 -1
- nat/plugins/mcp/auth/auth_provider_config.py +1 -1
- nat/plugins/mcp/auth/register.py +1 -1
- nat/plugins/mcp/auth/service_account/__init__.py +1 -1
- nat/plugins/mcp/auth/service_account/provider.py +1 -1
- nat/plugins/mcp/auth/service_account/provider_config.py +1 -1
- nat/plugins/mcp/auth/service_account/token_client.py +1 -1
- nat/plugins/mcp/auth/token_storage.py +2 -2
- nat/plugins/mcp/{client/client_base.py → client_base.py} +1 -1
- nat/plugins/mcp/{client/client_config.py → client_config.py} +9 -24
- nat/plugins/mcp/{client/client_impl.py → client_impl.py} +51 -219
- nat/plugins/mcp/exception_handler.py +1 -1
- nat/plugins/mcp/exceptions.py +1 -1
- nat/plugins/mcp/register.py +4 -5
- nat/plugins/mcp/tool.py +138 -0
- nat/plugins/mcp/utils.py +1 -1
- {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20260103.dist-info}/METADATA +5 -5
- nvidia_nat_mcp-1.5.0a20260103.dist-info/RECORD +27 -0
- nvidia_nat_mcp-1.5.0a20260103.dist-info/entry_points.txt +3 -0
- nat/plugins/mcp/cli/__init__.py +0 -15
- nat/plugins/mcp/cli/commands.py +0 -1055
- nat/plugins/mcp/client/__init__.py +0 -15
- nat/plugins/mcp/server/__init__.py +0 -15
- nat/plugins/mcp/server/front_end_config.py +0 -109
- nat/plugins/mcp/server/front_end_plugin.py +0 -155
- nat/plugins/mcp/server/front_end_plugin_worker.py +0 -415
- nat/plugins/mcp/server/introspection_token_verifier.py +0 -72
- nat/plugins/mcp/server/memory_profiler.py +0 -320
- nat/plugins/mcp/server/register_frontend.py +0 -27
- nat/plugins/mcp/server/tool_converter.py +0 -290
- nvidia_nat_mcp-1.4.0a20260117.dist-info/RECORD +0 -37
- nvidia_nat_mcp-1.4.0a20260117.dist-info/entry_points.txt +0 -9
- {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20260103.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20260103.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20260103.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat_mcp-1.4.0a20260117.dist-info → nvidia_nat_mcp-1.5.0a20260103.dist-info}/top_level.txt +0 -0
nat/meta/pypi.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
<!--
|
|
2
|
-
SPDX-FileCopyrightText: Copyright (c) 2025
|
|
2
|
+
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
3
|
SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
5
|
Licensed under the Apache License, Version 2.0 (the "License");
|
nat/plugins/mcp/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
nat/plugins/mcp/auth/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
nat/plugins/mcp/auth/register.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -206,7 +206,7 @@ class ObjectStoreTokenStorage(TokenStorageBase):
|
|
|
206
206
|
|
|
207
207
|
class InMemoryTokenStorage(TokenStorageBase):
|
|
208
208
|
"""
|
|
209
|
-
In-memory token storage using
|
|
209
|
+
In-memory token storage using NeMo Agent toolkit's built-in object store.
|
|
210
210
|
|
|
211
211
|
This implementation uses the in-memory object store for token persistence,
|
|
212
212
|
which provides a secure default option that doesn't require external storage
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -80,9 +80,9 @@ class MCPServerConfig(BaseModel):
|
|
|
80
80
|
return self
|
|
81
81
|
|
|
82
82
|
|
|
83
|
-
class
|
|
83
|
+
class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
|
84
84
|
"""
|
|
85
|
-
|
|
85
|
+
Configuration for connecting to an MCP server as a client and exposing selected tools.
|
|
86
86
|
"""
|
|
87
87
|
server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
|
|
88
88
|
tool_call_timeout: timedelta = Field(
|
|
@@ -114,19 +114,6 @@ class MCPClientBaseConfig(FunctionGroupBaseConfig):
|
|
|
114
114
|
calculator_multiply:
|
|
115
115
|
description: "Multiply two numbers" # alias defaults to original name
|
|
116
116
|
""")
|
|
117
|
-
|
|
118
|
-
@model_validator(mode="after")
|
|
119
|
-
def _validate_reconnect_backoff(self) -> "MCPClientBaseConfig":
|
|
120
|
-
"""Validate reconnect backoff values."""
|
|
121
|
-
if self.reconnect_max_backoff < self.reconnect_initial_backoff:
|
|
122
|
-
raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
|
|
123
|
-
return self
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
class MCPClientConfig(MCPClientBaseConfig, name="mcp_client"):
|
|
127
|
-
"""
|
|
128
|
-
Configuration for connecting to an MCP server as a client and exposing selected tools.
|
|
129
|
-
"""
|
|
130
117
|
session_aware_tools: bool = Field(default=True,
|
|
131
118
|
description="Session-aware tools are created if True. Defaults to True.")
|
|
132
119
|
max_sessions: int = Field(default=100,
|
|
@@ -136,11 +123,9 @@ class MCPClientConfig(MCPClientBaseConfig, name="mcp_client"):
|
|
|
136
123
|
default=timedelta(hours=1),
|
|
137
124
|
description="Time after which inactive sessions are cleaned up. Defaults to 1 hour.")
|
|
138
125
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
"""
|
|
146
|
-
pass
|
|
126
|
+
@model_validator(mode="after")
|
|
127
|
+
def _validate_reconnect_backoff(self) -> "MCPClientConfig":
|
|
128
|
+
"""Validate reconnect backoff values."""
|
|
129
|
+
if self.reconnect_max_backoff < self.reconnect_initial_backoff:
|
|
130
|
+
raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
|
|
131
|
+
return self
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -26,77 +26,16 @@ from pydantic import BaseModel
|
|
|
26
26
|
|
|
27
27
|
from nat.authentication.interfaces import AuthProviderBase
|
|
28
28
|
from nat.builder.builder import Builder
|
|
29
|
-
from nat.builder.context import Context
|
|
30
29
|
from nat.builder.function import FunctionGroup
|
|
31
30
|
from nat.cli.register_workflow import register_function_group
|
|
32
|
-
from nat.
|
|
33
|
-
from nat.plugins.mcp.
|
|
34
|
-
from nat.plugins.mcp.
|
|
35
|
-
from nat.plugins.mcp.client.client_config import MCPToolOverrideConfig
|
|
36
|
-
from nat.plugins.mcp.client.client_config import PerUserMCPClientConfig
|
|
31
|
+
from nat.plugins.mcp.client_base import MCPBaseClient
|
|
32
|
+
from nat.plugins.mcp.client_config import MCPClientConfig
|
|
33
|
+
from nat.plugins.mcp.client_config import MCPToolOverrideConfig
|
|
37
34
|
from nat.plugins.mcp.utils import truncate_session_id
|
|
38
35
|
|
|
39
36
|
logger = logging.getLogger(__name__)
|
|
40
37
|
|
|
41
38
|
|
|
42
|
-
class PerUserMCPFunctionGroup(FunctionGroup):
|
|
43
|
-
"""
|
|
44
|
-
A specialized FunctionGroup for per-user MCP clients.
|
|
45
|
-
"""
|
|
46
|
-
|
|
47
|
-
def __init__(self, *args, **kwargs):
|
|
48
|
-
super().__init__(*args, **kwargs)
|
|
49
|
-
|
|
50
|
-
self.mcp_client: MCPBaseClient | None = None # Will be set to the actual MCP client instance
|
|
51
|
-
self.mcp_client_server_name: str | None = None
|
|
52
|
-
self.mcp_client_transport: str | None = None
|
|
53
|
-
self.user_id: str | None = None
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def mcp_per_user_tool_function(tool, client: MCPBaseClient):
|
|
57
|
-
"""
|
|
58
|
-
Create a per-user NAT function for an MCP tool.
|
|
59
|
-
|
|
60
|
-
Args:
|
|
61
|
-
tool: The MCP tool to create a function for
|
|
62
|
-
client: The MCP client to use for the function
|
|
63
|
-
|
|
64
|
-
Returns:
|
|
65
|
-
The NAT function
|
|
66
|
-
"""
|
|
67
|
-
from nat.builder.function import FunctionInfo
|
|
68
|
-
|
|
69
|
-
def _convert_from_str(input_str: str) -> tool.input_schema:
|
|
70
|
-
return tool.input_schema.model_validate_json(input_str)
|
|
71
|
-
|
|
72
|
-
async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
|
|
73
|
-
try:
|
|
74
|
-
mcp_tool = await client.get_tool(tool.name)
|
|
75
|
-
|
|
76
|
-
if tool_input:
|
|
77
|
-
args = tool_input.model_dump(exclude_none=True, mode='json')
|
|
78
|
-
return await mcp_tool.acall(args)
|
|
79
|
-
|
|
80
|
-
# kwargs arrives with all optional fields set to None because NAT's framework
|
|
81
|
-
# converts the input dict to a Pydantic model (filling in all Field(default=None)),
|
|
82
|
-
# then dumps it back to a dict. We need to strip out these None values because
|
|
83
|
-
# many MCP servers (e.g., Kaggle) reject requests with excessive null fields.
|
|
84
|
-
# We re-validate here (yes, redundant) to leverage Pydantic's exclude_none with
|
|
85
|
-
# mode='json' for recursive None removal in nested models.
|
|
86
|
-
# Reference: function_info.py:_convert_input_pydantic
|
|
87
|
-
validated_input = mcp_tool.input_schema.model_validate(kwargs)
|
|
88
|
-
args = validated_input.model_dump(exclude_none=True, mode='json')
|
|
89
|
-
return await mcp_tool.acall(args)
|
|
90
|
-
except Exception as e:
|
|
91
|
-
logger.warning("Error calling tool %s", tool.name, exc_info=True)
|
|
92
|
-
return str(e)
|
|
93
|
-
|
|
94
|
-
return FunctionInfo.create(single_fn=_response_fn,
|
|
95
|
-
description=tool.description,
|
|
96
|
-
input_schema=tool.input_schema,
|
|
97
|
-
converters=[_convert_from_str])
|
|
98
|
-
|
|
99
|
-
|
|
100
39
|
@dataclass
|
|
101
40
|
class SessionData:
|
|
102
41
|
"""Container for all session-related data."""
|
|
@@ -152,9 +91,9 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
152
91
|
def __init__(self, *args, **kwargs):
|
|
153
92
|
super().__init__(*args, **kwargs)
|
|
154
93
|
# MCP client attributes with proper typing
|
|
155
|
-
self.
|
|
156
|
-
self.
|
|
157
|
-
self.
|
|
94
|
+
self._mcp_client = None # Will be set to the actual MCP client instance
|
|
95
|
+
self._mcp_client_server_name: str | None = None
|
|
96
|
+
self._mcp_client_transport: str | None = None
|
|
158
97
|
|
|
159
98
|
# Session management - consolidated data structure
|
|
160
99
|
self._sessions: dict[str, SessionData] = {}
|
|
@@ -177,6 +116,36 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
177
116
|
# Use random session id for testing only
|
|
178
117
|
self._use_random_session_id_for_testing: bool = False
|
|
179
118
|
|
|
119
|
+
@property
|
|
120
|
+
def mcp_client(self):
|
|
121
|
+
"""Get the MCP client instance."""
|
|
122
|
+
return self._mcp_client
|
|
123
|
+
|
|
124
|
+
@mcp_client.setter
|
|
125
|
+
def mcp_client(self, client):
|
|
126
|
+
"""Set the MCP client instance."""
|
|
127
|
+
self._mcp_client = client
|
|
128
|
+
|
|
129
|
+
@property
|
|
130
|
+
def mcp_client_server_name(self) -> str | None:
|
|
131
|
+
"""Get the MCP client server name."""
|
|
132
|
+
return self._mcp_client_server_name
|
|
133
|
+
|
|
134
|
+
@mcp_client_server_name.setter
|
|
135
|
+
def mcp_client_server_name(self, server_name: str | None):
|
|
136
|
+
"""Set the MCP client server name."""
|
|
137
|
+
self._mcp_client_server_name = server_name
|
|
138
|
+
|
|
139
|
+
@property
|
|
140
|
+
def mcp_client_transport(self) -> str | None:
|
|
141
|
+
"""Get the MCP client transport type."""
|
|
142
|
+
return self._mcp_client_transport
|
|
143
|
+
|
|
144
|
+
@mcp_client_transport.setter
|
|
145
|
+
def mcp_client_transport(self, transport: str | None):
|
|
146
|
+
"""Set the MCP client transport type."""
|
|
147
|
+
self._mcp_client_transport = transport
|
|
148
|
+
|
|
180
149
|
@property
|
|
181
150
|
def session_count(self) -> int:
|
|
182
151
|
"""Current number of active sessions."""
|
|
@@ -289,7 +258,7 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
289
258
|
except Exception as e:
|
|
290
259
|
logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
|
|
291
260
|
|
|
292
|
-
async def _get_session_client(self, session_id: str) -> MCPBaseClient
|
|
261
|
+
async def _get_session_client(self, session_id: str) -> MCPBaseClient:
|
|
293
262
|
"""Get the appropriate MCP client for the session."""
|
|
294
263
|
# Throttled cleanup on access
|
|
295
264
|
now = datetime.now()
|
|
@@ -379,7 +348,7 @@ class MCPFunctionGroup(FunctionGroup):
|
|
|
379
348
|
|
|
380
349
|
async def _create_session_client(self, session_id: str) -> tuple[MCPBaseClient, asyncio.Event, asyncio.Task]:
|
|
381
350
|
"""Create a new MCP client instance for the session."""
|
|
382
|
-
from nat.plugins.mcp.
|
|
351
|
+
from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
|
|
383
352
|
|
|
384
353
|
config = self._client_config
|
|
385
354
|
if not config:
|
|
@@ -471,13 +440,9 @@ def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):
|
|
|
471
440
|
if (not function_group._shared_auth_provider or session_id == function_group._default_user_id):
|
|
472
441
|
# Use base client directly for default user
|
|
473
442
|
client = function_group.mcp_client
|
|
474
|
-
if client is None:
|
|
475
|
-
return "Tool temporarily unavailable. Try again."
|
|
476
443
|
session_tool = await client.get_tool(tool.name)
|
|
477
444
|
else:
|
|
478
445
|
# Use session usage context to prevent cleanup during tool execution
|
|
479
|
-
if session_id is None:
|
|
480
|
-
return "Tool temporarily unavailable. Try again."
|
|
481
446
|
async with function_group._session_usage_context(session_id) as client:
|
|
482
447
|
if client is None:
|
|
483
448
|
return "Tool temporarily unavailable. Try again."
|
|
@@ -519,9 +484,9 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
|
519
484
|
Returns:
|
|
520
485
|
The function group
|
|
521
486
|
"""
|
|
522
|
-
from nat.plugins.mcp.
|
|
523
|
-
from nat.plugins.mcp.
|
|
524
|
-
from nat.plugins.mcp.
|
|
487
|
+
from nat.plugins.mcp.client_base import MCPSSEClient
|
|
488
|
+
from nat.plugins.mcp.client_base import MCPStdioClient
|
|
489
|
+
from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
|
|
525
490
|
|
|
526
491
|
# Resolve auth provider if specified
|
|
527
492
|
auth_provider = None
|
|
@@ -609,16 +574,23 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
|
609
574
|
# Create the tool function according to configuration
|
|
610
575
|
tool_fn = mcp_session_tool_function(tool, group)
|
|
611
576
|
|
|
577
|
+
# Normalize optional typing for linter/type-checker compatibility
|
|
578
|
+
single_fn = tool_fn.single_fn
|
|
579
|
+
if single_fn is None:
|
|
580
|
+
# Should not happen because FunctionInfo always sets a single_fn
|
|
581
|
+
logger.warning("Skipping tool %s because single_fn is None", function_name)
|
|
582
|
+
continue
|
|
583
|
+
|
|
612
584
|
input_schema = tool_fn.input_schema
|
|
613
585
|
# Convert NoneType sentinel to None for FunctionGroup.add_function signature
|
|
614
|
-
if input_schema is type(None):
|
|
586
|
+
if input_schema is type(None): # noqa: E721
|
|
615
587
|
input_schema = None
|
|
616
588
|
|
|
617
589
|
# Add to group
|
|
618
590
|
logger.info("Adding tool %s to group", function_name)
|
|
619
591
|
group.add_function(name=function_name,
|
|
620
592
|
description=description,
|
|
621
|
-
fn=
|
|
593
|
+
fn=single_fn,
|
|
622
594
|
input_schema=input_schema,
|
|
623
595
|
converters=tool_fn.converters)
|
|
624
596
|
|
|
@@ -640,143 +612,3 @@ def mcp_apply_tool_alias_and_description(
|
|
|
640
612
|
return {}
|
|
641
613
|
|
|
642
614
|
return {name: override for name, override in tool_overrides.items() if name in all_tools}
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
@register_per_user_function_group(config_type=PerUserMCPClientConfig)
|
|
646
|
-
async def per_user_mcp_client_function_group(config: PerUserMCPClientConfig, _builder: Builder):
|
|
647
|
-
"""
|
|
648
|
-
Connect to an MCP server and expose tools as a function group for per-user workflows.
|
|
649
|
-
|
|
650
|
-
Args:
|
|
651
|
-
config: The configuration for the MCP client
|
|
652
|
-
_builder: The builder
|
|
653
|
-
Returns:
|
|
654
|
-
The function group
|
|
655
|
-
"""
|
|
656
|
-
from nat.plugins.mcp.client.client_base import MCPSSEClient
|
|
657
|
-
from nat.plugins.mcp.client.client_base import MCPStdioClient
|
|
658
|
-
from nat.plugins.mcp.client.client_base import MCPStreamableHTTPClient
|
|
659
|
-
|
|
660
|
-
# Resolve auth provider if specified
|
|
661
|
-
auth_provider = None
|
|
662
|
-
if config.server.auth_provider:
|
|
663
|
-
auth_provider = await _builder.get_auth_provider(config.server.auth_provider)
|
|
664
|
-
|
|
665
|
-
user_id = Context.get().user_id
|
|
666
|
-
|
|
667
|
-
# Build the appropriate client
|
|
668
|
-
if config.server.transport == "stdio":
|
|
669
|
-
if not config.server.command:
|
|
670
|
-
raise ValueError("command is required for stdio transport")
|
|
671
|
-
client = MCPStdioClient(config.server.command,
|
|
672
|
-
config.server.args,
|
|
673
|
-
config.server.env,
|
|
674
|
-
tool_call_timeout=config.tool_call_timeout,
|
|
675
|
-
auth_flow_timeout=config.auth_flow_timeout,
|
|
676
|
-
reconnect_enabled=config.reconnect_enabled,
|
|
677
|
-
reconnect_max_attempts=config.reconnect_max_attempts,
|
|
678
|
-
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
|
679
|
-
reconnect_max_backoff=config.reconnect_max_backoff)
|
|
680
|
-
elif config.server.transport == "sse":
|
|
681
|
-
client = MCPSSEClient(str(config.server.url),
|
|
682
|
-
tool_call_timeout=config.tool_call_timeout,
|
|
683
|
-
auth_flow_timeout=config.auth_flow_timeout,
|
|
684
|
-
reconnect_enabled=config.reconnect_enabled,
|
|
685
|
-
reconnect_max_attempts=config.reconnect_max_attempts,
|
|
686
|
-
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
|
687
|
-
reconnect_max_backoff=config.reconnect_max_backoff)
|
|
688
|
-
elif config.server.transport == "streamable-http":
|
|
689
|
-
client = MCPStreamableHTTPClient(str(config.server.url),
|
|
690
|
-
auth_provider=auth_provider,
|
|
691
|
-
user_id=user_id,
|
|
692
|
-
tool_call_timeout=config.tool_call_timeout,
|
|
693
|
-
auth_flow_timeout=config.auth_flow_timeout,
|
|
694
|
-
reconnect_enabled=config.reconnect_enabled,
|
|
695
|
-
reconnect_max_attempts=config.reconnect_max_attempts,
|
|
696
|
-
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
|
697
|
-
reconnect_max_backoff=config.reconnect_max_backoff)
|
|
698
|
-
else:
|
|
699
|
-
raise ValueError(f"Unsupported transport: {config.server.transport}")
|
|
700
|
-
|
|
701
|
-
logger.info("Per-user MCP client configured for server: %s (user: %s)", client.server_name, user_id)
|
|
702
|
-
|
|
703
|
-
group = PerUserMCPFunctionGroup(config=config)
|
|
704
|
-
|
|
705
|
-
# Use a lifetime task to ensure the client context is entered and exited in the same task.
|
|
706
|
-
# This avoids anyio's "Attempted to exit cancel scope in a different task" error.
|
|
707
|
-
ready = asyncio.Event()
|
|
708
|
-
stop_event = asyncio.Event()
|
|
709
|
-
|
|
710
|
-
async def _lifetime():
|
|
711
|
-
"""Lifetime task that owns the client's async context."""
|
|
712
|
-
try:
|
|
713
|
-
async with client:
|
|
714
|
-
ready.set()
|
|
715
|
-
await stop_event.wait()
|
|
716
|
-
except Exception:
|
|
717
|
-
ready.set() # Ensure we don't hang the waiter
|
|
718
|
-
raise
|
|
719
|
-
|
|
720
|
-
lifetime_task = asyncio.create_task(_lifetime(), name=f"mcp-per-user-{user_id}")
|
|
721
|
-
|
|
722
|
-
# Wait for client initialization
|
|
723
|
-
timeout = config.tool_call_timeout.total_seconds()
|
|
724
|
-
try:
|
|
725
|
-
await asyncio.wait_for(ready.wait(), timeout=timeout)
|
|
726
|
-
except TimeoutError:
|
|
727
|
-
lifetime_task.cancel()
|
|
728
|
-
try:
|
|
729
|
-
await lifetime_task
|
|
730
|
-
except asyncio.CancelledError:
|
|
731
|
-
pass
|
|
732
|
-
raise RuntimeError(f"Per-user MCP client initialization timed out after {timeout}s")
|
|
733
|
-
|
|
734
|
-
# Check if initialization failed
|
|
735
|
-
if lifetime_task.done():
|
|
736
|
-
try:
|
|
737
|
-
await lifetime_task
|
|
738
|
-
except Exception as e:
|
|
739
|
-
raise RuntimeError(f"Failed to initialize per-user MCP client: {e}") from e
|
|
740
|
-
|
|
741
|
-
try:
|
|
742
|
-
# Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
|
|
743
|
-
# can reuse the already-established session instead of creating a new client per request.
|
|
744
|
-
group.mcp_client = client
|
|
745
|
-
group.mcp_client_server_name = client.server_name
|
|
746
|
-
group.mcp_client_transport = client.transport
|
|
747
|
-
group.user_id = user_id
|
|
748
|
-
|
|
749
|
-
all_tools = await client.get_tools()
|
|
750
|
-
tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
|
|
751
|
-
|
|
752
|
-
# Add each tool as a function to the group
|
|
753
|
-
for tool_name, tool in all_tools.items():
|
|
754
|
-
# Get override if it exists
|
|
755
|
-
override = tool_overrides.get(tool_name)
|
|
756
|
-
|
|
757
|
-
# Use override values or defaults
|
|
758
|
-
function_name = override.alias if override and override.alias else tool_name
|
|
759
|
-
description = override.description if override and override.description else tool.description
|
|
760
|
-
|
|
761
|
-
# Create the tool function according to configuration
|
|
762
|
-
tool_fn = mcp_per_user_tool_function(tool, client)
|
|
763
|
-
|
|
764
|
-
input_schema = tool_fn.input_schema
|
|
765
|
-
# Convert NoneType sentinel to None for FunctionGroup.add_function signature
|
|
766
|
-
if input_schema is type(None):
|
|
767
|
-
input_schema = None
|
|
768
|
-
|
|
769
|
-
# Add to group
|
|
770
|
-
logger.info("Adding tool %s to group", function_name)
|
|
771
|
-
group.add_function(name=function_name,
|
|
772
|
-
description=description,
|
|
773
|
-
fn=tool_fn.single_fn,
|
|
774
|
-
input_schema=input_schema,
|
|
775
|
-
converters=tool_fn.converters)
|
|
776
|
-
|
|
777
|
-
yield group
|
|
778
|
-
finally:
|
|
779
|
-
# Signal the lifetime task to exit and wait for clean shutdown
|
|
780
|
-
stop_event.set()
|
|
781
|
-
if not lifetime_task.done():
|
|
782
|
-
await lifetime_task
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
nat/plugins/mcp/exceptions.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
nat/plugins/mcp/register.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,8 +16,7 @@
|
|
|
16
16
|
# flake8: noqa
|
|
17
17
|
# isort:skip_file
|
|
18
18
|
|
|
19
|
-
#
|
|
20
|
-
from .client import client_impl
|
|
19
|
+
# Import any providers which need to be automatically registered here
|
|
21
20
|
|
|
22
|
-
|
|
23
|
-
from .
|
|
21
|
+
from . import client_impl
|
|
22
|
+
from . import tool
|