nvidia-nat-mcp 1.4.0a20260107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +32 -0
- nat/plugins/mcp/__init__.py +14 -0
- nat/plugins/mcp/auth/__init__.py +14 -0
- nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
- nat/plugins/mcp/auth/auth_provider.py +431 -0
- nat/plugins/mcp/auth/auth_provider_config.py +86 -0
- nat/plugins/mcp/auth/register.py +33 -0
- nat/plugins/mcp/auth/service_account/__init__.py +14 -0
- nat/plugins/mcp/auth/service_account/provider.py +136 -0
- nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
- nat/plugins/mcp/auth/service_account/token_client.py +156 -0
- nat/plugins/mcp/auth/token_storage.py +265 -0
- nat/plugins/mcp/cli/__init__.py +15 -0
- nat/plugins/mcp/cli/commands.py +1051 -0
- nat/plugins/mcp/client/__init__.py +15 -0
- nat/plugins/mcp/client/client_base.py +665 -0
- nat/plugins/mcp/client/client_config.py +146 -0
- nat/plugins/mcp/client/client_impl.py +782 -0
- nat/plugins/mcp/exception_handler.py +211 -0
- nat/plugins/mcp/exceptions.py +142 -0
- nat/plugins/mcp/register.py +23 -0
- nat/plugins/mcp/server/__init__.py +15 -0
- nat/plugins/mcp/server/front_end_config.py +109 -0
- nat/plugins/mcp/server/front_end_plugin.py +155 -0
- nat/plugins/mcp/server/front_end_plugin_worker.py +411 -0
- nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
- nat/plugins/mcp/server/memory_profiler.py +320 -0
- nat/plugins/mcp/server/register_frontend.py +27 -0
- nat/plugins/mcp/server/tool_converter.py +286 -0
- nat/plugins/mcp/utils.py +228 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/METADATA +55 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/RECORD +37 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/WHEEL +5 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/entry_points.txt +9 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import ssl
|
|
18
|
+
import sys
|
|
19
|
+
from collections.abc import Callable
|
|
20
|
+
from functools import wraps
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
import httpx
|
|
24
|
+
|
|
25
|
+
from nat.plugins.mcp.exceptions import MCPAuthenticationError
|
|
26
|
+
from nat.plugins.mcp.exceptions import MCPConnectionError
|
|
27
|
+
from nat.plugins.mcp.exceptions import MCPError
|
|
28
|
+
from nat.plugins.mcp.exceptions import MCPProtocolError
|
|
29
|
+
from nat.plugins.mcp.exceptions import MCPRequestError
|
|
30
|
+
from nat.plugins.mcp.exceptions import MCPSSLError
|
|
31
|
+
from nat.plugins.mcp.exceptions import MCPTimeoutError
|
|
32
|
+
from nat.plugins.mcp.exceptions import MCPToolNotFoundError
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def format_mcp_error(error: MCPError, include_traceback: bool = False) -> None:
    """Report an MCP error for CLI users: log it, then print its suggestions to stderr.

    The error itself is emitted through this module's logger at ERROR level. Each entry of
    ``error.suggestions`` is then printed to stderr on its own line, prefixed with an arrow.
    The error message is not printed separately; it only appears in the log record.

    Args:
        error (MCPError): MCPError instance containing message, url, category, suggestions, and original_exception
        include_traceback (bool, optional): Whether to attach ``error``'s traceback to the log record.
            Defaults to False.
    """
    # Pass the exception instance rather than ``True``: with ``True`` the logger reads sys.exc_info(),
    # which is empty when this is called outside an ``except`` block and would log "NoneType: None".
    logger.error("MCP operation failed: %s", error, exc_info=error if include_traceback else False)

    # Display user-friendly suggestions
    for suggestion in error.suggestions:
        print(f" → {suggestion}", file=sys.stderr)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _extract_url(args: tuple, kwargs: dict[str, Any], url_param: str, func_name: str) -> str:
|
|
56
|
+
"""Extract URL from function arguments using clean fallback chain.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
args: Function positional arguments
|
|
60
|
+
kwargs: Function keyword arguments
|
|
61
|
+
url_param (str): Parameter name containing the URL
|
|
62
|
+
func_name (str): Function name for logging
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
str: URL string or "unknown" if extraction fails
|
|
66
|
+
"""
|
|
67
|
+
# Try keyword arguments first
|
|
68
|
+
if url_param in kwargs:
|
|
69
|
+
return kwargs[url_param]
|
|
70
|
+
|
|
71
|
+
# Try self attribute (e.g., self.url)
|
|
72
|
+
if args and hasattr(args[0], url_param):
|
|
73
|
+
return getattr(args[0], url_param)
|
|
74
|
+
|
|
75
|
+
# Try common case: url as second parameter after self
|
|
76
|
+
if len(args) > 1 and url_param == "url":
|
|
77
|
+
return args[1]
|
|
78
|
+
|
|
79
|
+
# Fallback with warning
|
|
80
|
+
logger.warning("Could not extract URL for error handling in %s", func_name)
|
|
81
|
+
return "unknown"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _flatten_exceptions(exceptions):
    """Yield leaf exceptions depth-first, descending into nested ExceptionGroups."""
    for exc in exceptions:
        if isinstance(exc, ExceptionGroup):  # noqa: F821
            yield from _flatten_exceptions(exc.exceptions)
        else:
            yield exc


def extract_primary_exception(exceptions: list[Exception]) -> Exception:
    """Extract the most relevant exception from a group.

    Nested ExceptionGroups are flattened first. Task groups can nest groups, which would otherwise
    hide the real cause, such as a connect or auth error, behind an opaque group.

    The leaves are then searched in priority tiers; within a tier, the first match in order wins:
    connection errors, then timeouts (httpx or builtin ``TimeoutError``), then SSL errors.

    Args:
        exceptions (list[Exception]): List of exceptions from ExceptionGroup

    Returns:
        Exception: Most relevant leaf exception for user feedback. Falls back to the first leaf
        when nothing matches a tier.

    Raises:
        IndexError: If ``exceptions`` is empty (unchanged from the previous behavior).
    """
    leaves = list(_flatten_exceptions(exceptions))

    priority_tiers = (
        (httpx.ConnectError, ConnectionError),
        (httpx.TimeoutException, TimeoutError),
        (ssl.SSLError, ),
    )
    for tier in priority_tiers:
        for exc in leaves:
            if isinstance(exc, tier):
                return exc

    # Fall back to first leaf exception
    return leaves[0]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def convert_to_mcp_error(exception: Exception, url: str) -> MCPError:
|
|
115
|
+
"""Convert single exception to appropriate MCPError.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
exception (Exception): Single exception to convert
|
|
119
|
+
url (str): MCP server URL for context
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
MCPError: Appropriate MCPError subclass
|
|
123
|
+
"""
|
|
124
|
+
match exception:
|
|
125
|
+
case httpx.ConnectError() | ConnectionError():
|
|
126
|
+
return MCPConnectionError(url, exception)
|
|
127
|
+
case httpx.TimeoutException():
|
|
128
|
+
return MCPTimeoutError(url, exception)
|
|
129
|
+
case ssl.SSLError():
|
|
130
|
+
return MCPSSLError(url, exception)
|
|
131
|
+
case httpx.RequestError():
|
|
132
|
+
return MCPRequestError(url, exception)
|
|
133
|
+
case ValueError() if "Tool" in str(exception) and "not available" in str(exception):
|
|
134
|
+
# Extract tool name from error message if possible
|
|
135
|
+
tool_name = str(exception).split("Tool ")[1].split(" not available")[0] if "Tool " in str(
|
|
136
|
+
exception) else "unknown"
|
|
137
|
+
return MCPToolNotFoundError(tool_name, url, exception)
|
|
138
|
+
case _:
|
|
139
|
+
# Handle TaskGroup error message specifically
|
|
140
|
+
if "unhandled errors in a TaskGroup" in str(exception):
|
|
141
|
+
return MCPProtocolError(url, "Failed to connect to MCP server", exception)
|
|
142
|
+
if "unauthorized" in str(exception).lower() or "forbidden" in str(exception).lower():
|
|
143
|
+
return MCPAuthenticationError(url, exception)
|
|
144
|
+
return MCPError(f"Unexpected error: {exception}", url, original_exception=exception)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def handle_mcp_exceptions(url_param: str = "url") -> Callable[..., Any]:
    """Build a decorator that turns low-level failures of an async MCP call into MCPErrors.

    Behaviour of the decorated coroutine:

    - An ``MCPError`` it raises propagates untouched.
    - Any other ``Exception`` is converted with ``convert_to_mcp_error`` and re-raised, chained to the original.
    - For an ``ExceptionGroup``, the single most relevant member is chosen with
      ``extract_primary_exception`` before conversion.

    The server URL used for context comes from ``_extract_url``, which checks keyword arguments,
    then an attribute on the first positional argument, then the second positional argument.

    Args:
        url_param (str): Name of the parameter or attribute containing the MCP server URL

    Returns:
        Callable[..., Any]: A decorator for async functions.

    Example:
        .. code-block:: python

            @handle_mcp_exceptions("url")
            async def get_tools(self, url: str):
                ...

            @handle_mcp_exceptions("url")  # falls back to self.url
            async def get_tool(self):
                ...
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:

        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except MCPError:
                # Already structured; let it through unchanged.
                raise
            except Exception as exc:
                server_url = _extract_url(args, kwargs, url_param, func.__name__)
                culprit = (extract_primary_exception(list(exc.exceptions))
                           if isinstance(exc, ExceptionGroup)  # noqa: F821
                           else exc)
                raise convert_to_mcp_error(culprit, server_url) from exc

        return wrapper

    return decorator
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def mcp_exception_handler(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap an async method whose MCP server URL is stored on ``self.url``.

    Shorthand for ``handle_mcp_exceptions("url")(func)``.

    Args:
        func (Callable[..., Any]): The async method to wrap.

    Returns:
        Callable[..., Any]: The wrapped method, which raises MCPError subclasses on failure.
    """
    url_aware_decorator = handle_mcp_exceptions("url")
    return url_aware_decorator(func)
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from enum import Enum
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MCPErrorCategory(str, Enum):
    """Coarse classification attached to every MCPError.

    Members subclass ``str``, so each one compares equal to its lowercase string value.
    """

    # Used by MCPConnectionError.
    CONNECTION = "connection"
    # Used by MCPTimeoutError.
    TIMEOUT = "timeout"
    # Used by MCPSSLError.
    SSL = "ssl"
    # Used by MCPAuthenticationError.
    AUTHENTICATION = "authentication"
    # Used by MCPToolNotFoundError.
    TOOL_NOT_FOUND = "tool_not_found"
    # Used by MCPProtocolError and also by MCPRequestError.
    PROTOCOL = "protocol"
    # Default for a plain MCPError.
    UNKNOWN = "unknown"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MCPError(Exception):
    """Base exception for MCP-related errors.

    Beyond the message (available as ``str(error)``), it carries structured context
    for callers and for CLI display.

    Args:
        message: Human-readable description of the failure.
        url: MCP server URL the failing operation targeted.
        category: Coarse classification; defaults to ``MCPErrorCategory.UNKNOWN``.
        suggestions: Actionable hints for the user. Stored as a new list so that later mutation
            of the caller's sequence does not alter this error.
        original_exception: Lower-level exception that triggered this error, if any.
    """

    def __init__(self,
                 message: str,
                 url: str,
                 category: MCPErrorCategory = MCPErrorCategory.UNKNOWN,
                 suggestions: list[str] | None = None,
                 original_exception: Exception | None = None):
        super().__init__(message)
        self.url = url
        self.category = category
        # Defensive copy: the caller's list must not be shared with (or mutate) this error.
        self.suggestions = list(suggestions) if suggestions else []
        self.original_exception = original_exception
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class MCPConnectionError(MCPError):
    """Raised when no connection to the MCP server could be established.

    Category is ``CONNECTION``; the suggestions point at server availability and the URL/port.
    """

    def __init__(self, url: str, original_exception: Exception | None = None):
        summary = f"Unable to connect to MCP server at {url}"
        hints = [
            "Please ensure the MCP server is running and accessible",
            "Check if the URL and port are correct",
        ]
        super().__init__(
            summary,
            url=url,
            category=MCPErrorCategory.CONNECTION,
            suggestions=hints,
            original_exception=original_exception,
        )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class MCPTimeoutError(MCPError):
    """Raised when communication with the MCP server times out.

    Category is ``TIMEOUT``; the suggestions mention server load and network connectivity.
    """

    def __init__(self, url: str, original_exception: Exception | None = None):
        summary = f"Connection timed out to MCP server at {url}"
        hints = [
            "The server may be overloaded or network is slow",
            "Try again in a moment or check network connectivity",
        ]
        super().__init__(
            summary,
            url=url,
            category=MCPErrorCategory.TIMEOUT,
            suggestions=hints,
            original_exception=original_exception,
        )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class MCPSSLError(MCPError):
    """Raised on SSL/TLS failures while talking to the MCP server.

    Category is ``SSL``; the suggestions cover certificates and the HTTP/HTTPS choice.
    """

    def __init__(self, url: str, original_exception: Exception | None = None):
        summary = f"SSL/TLS error connecting to {url}"
        hints = [
            "Check if the server requires HTTPS or has valid certificates",
            "Try using HTTP instead of HTTPS if appropriate",
        ]
        super().__init__(
            summary,
            url=url,
            category=MCPErrorCategory.SSL,
            suggestions=hints,
            original_exception=original_exception,
        )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class MCPRequestError(MCPError):
    """Exception for MCP request errors.

    The message includes the underlying exception's text when it has any. The category is
    ``PROTOCOL`` because MCPErrorCategory has no dedicated request category.
    """

    def __init__(self, url: str, original_exception: Exception | None = None):
        message = f"Request failed to MCP server at {url}"
        # Append the cause only when it has text; otherwise the message would end in a dangling ": ".
        detail = str(original_exception) if original_exception is not None else ""
        if detail:
            message += f": {detail}"

        super().__init__(message,
                         url=url,
                         category=MCPErrorCategory.PROTOCOL,
                         suggestions=["Check the server URL format and network settings"],
                         original_exception=original_exception)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class MCPToolNotFoundError(MCPError):
    """Raised when a named tool is not available on the MCP server.

    Category is ``TOOL_NOT_FOUND``; ``tool_name`` is quoted inside the message.
    """

    def __init__(self, tool_name: str, url: str, original_exception: Exception | None = None):
        summary = f"Tool '{tool_name}' not available at {url}"
        hints = [
            "Use 'nat info mcp --detail' to see available tools",
            "Check that the tool name is spelled correctly",
        ]
        super().__init__(
            summary,
            url=url,
            category=MCPErrorCategory.TOOL_NOT_FOUND,
            suggestions=hints,
            original_exception=original_exception,
        )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class MCPAuthenticationError(MCPError):
    """Raised when the MCP server rejects the client's credentials.

    Category is ``AUTHENTICATION``; the suggestions ask the user to check their credentials.
    """

    def __init__(self, url: str, original_exception: Exception | None = None):
        summary = f"Authentication failed when connecting to MCP server at {url}"
        hints = [
            "Check if the server requires authentication credentials",
            "Verify that your credentials are correct and not expired",
        ]
        super().__init__(
            summary,
            url=url,
            category=MCPErrorCategory.AUTHENTICATION,
            suggestions=hints,
            original_exception=original_exception,
        )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class MCPProtocolError(MCPError):
    """Raised for MCP protocol-level problems.

    The caller-supplied ``message`` (default ``"Protocol error"``) is suffixed with the server URL.
    Category is ``PROTOCOL``.
    """

    def __init__(self, url: str, message: str = "Protocol error", original_exception: Exception | None = None):
        summary = f"{message} (MCP server at {url})"
        hints = [
            "Check that the MCP server is running and accessible at this URL",
            "Verify the server supports the expected MCP protocol version",
        ]
        super().__init__(
            summary,
            url=url,
            category=MCPErrorCategory.PROTOCOL,
            suggestions=hints,
            original_exception=original_exception,
        )
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
# isort:skip_file
|
|
18
|
+
|
|
19
|
+
# Register client components
|
|
20
|
+
from .client import client_impl
|
|
21
|
+
|
|
22
|
+
# Register server/frontend components
|
|
23
|
+
from .server import register_frontend
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""MCP server/frontend components."""
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Literal
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic import field_validator
|
|
21
|
+
from pydantic import model_validator
|
|
22
|
+
|
|
23
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
24
|
+
from nat.data_models.front_end import FrontEndBaseConfig
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
    """MCP front end configuration.

    A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.

    Security-related combinations (non-localhost binding without ``server_auth``, SSE transport)
    only produce warnings in ``validate_security_configuration``; they are not rejected.
    """

    name: str = Field(default="NeMo Agent Toolkit MCP",
                      description="Name of the MCP server (default: NeMo Agent Toolkit MCP)")
    host: str = Field(default="localhost", description="Host to bind the server to (default: localhost)")
    port: int = Field(default=9901, description="Port to bind the server to (default: 9901)", ge=0, le=65535)
    debug: bool = Field(default=False, description="Enable debug mode (default: False)")
    log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
    # The adjacent string literals below are concatenated, so the first must end with a space.
    tool_names: list[str] = Field(
        default_factory=list,
        description="The list of tools MCP server will expose (default: all tools). "
        "Tool names can be functions or function groups",
    )
    transport: Literal["sse", "streamable-http"] = Field(
        default="streamable-http",
        description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
    runner_class: str | None = Field(
        default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
    base_path: str | None = Field(default=None,
                                  description="Base path to mount the MCP server at (e.g., '/api/v1'). "
                                  "If specified, the server will be accessible at http://host:port{base_path}/mcp. "
                                  "If None, server runs at root path /mcp.")

    server_auth: OAuth2ResourceServerConfig | None = Field(
        default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))

    @field_validator('base_path')
    @classmethod
    def validate_base_path(cls, v: str | None) -> str | None:
        """Validate that base_path starts with '/' and doesn't end with '/'.

        ``None`` passes through unchanged. Note that ``"/"`` itself is rejected; use ``None`` for the root.

        Raises:
            ValueError: If ``v`` does not start with '/' or ends with '/'.
        """
        if v is not None:
            if not v.startswith('/'):
                raise ValueError("base_path must start with '/'")
            if v.endswith('/'):
                raise ValueError("base_path must not end with '/'")
        return v

    # Memory profiling configuration
    enable_memory_profiling: bool = Field(default=False,
                                          description="Enable memory profiling and diagnostics (default: False)")
    memory_profile_interval: int = Field(default=50,
                                         description="Log memory stats every N requests (default: 50)",
                                         ge=1)
    memory_profile_top_n: int = Field(default=10,
                                      description="Number of top memory allocations to log (default: 10)",
                                      ge=1,
                                      le=50)
    memory_profile_log_level: str = Field(default="DEBUG",
                                          description="Log level for memory profiling output (default: DEBUG)")

    @model_validator(mode="after")
    def validate_security_configuration(self):
        """Validate security configuration to prevent accidental misconfigurations.

        Only logs warnings; the model is always returned unchanged.
        """
        # Check if server is bound to a non-localhost interface without authentication
        localhost_hosts = {"localhost", "127.0.0.1", "::1"}
        if self.host not in localhost_hosts and self.server_auth is None:
            logger.warning(
                "MCP server is configured to bind to '%s' without authentication. "
                "This may expose your server to unauthorized access. "
                "Consider either: (1) binding to localhost for local-only access, "
                "or (2) configuring server_auth for production deployments on public interfaces.",
                self.host)

        # Check if SSE transport is used (which doesn't support authentication)
        if self.transport == "sse":
            if self.server_auth is not None:
                logger.warning("SSE transport does not support authentication. "
                               "The configured server_auth will be ignored. "
                               "For production use with authentication, use 'streamable-http' transport instead.")
            elif self.host not in localhost_hosts:
                logger.warning(
                    "SSE transport does not support authentication and is bound to '%s'. "
                    "This configuration is not recommended for production use. "
                    "For production deployments, use 'streamable-http' transport with server_auth configured.",
                    self.host)

        return self
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import typing
|
|
18
|
+
|
|
19
|
+
from nat.builder.front_end import FrontEndBase
|
|
20
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
21
|
+
from nat.plugins.mcp.server.front_end_config import MCPFrontEndConfig
|
|
22
|
+
from nat.plugins.mcp.server.front_end_plugin_worker import MCPFrontEndPluginWorkerBase
|
|
23
|
+
|
|
24
|
+
if typing.TYPE_CHECKING:
|
|
25
|
+
from mcp.server.fastmcp import FastMCP
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
    """MCP front end plugin implementation.

    ``run()`` does the following:

    1. Builds the workflow.
    2. Asks a worker (the built-in one, or the class named by ``runner_class``) to create and populate a
       FastMCP server.
    3. Serves that server over SSE or streamable-HTTP, optionally mounted under ``base_path``.
    """

    def get_worker_class(self) -> type[MCPFrontEndPluginWorkerBase]:
        """Get the default worker class for handling MCP routes.

        The import is deferred to call time. Unlike ``get_worker_class_name``, this method is not final.
        """
        from nat.plugins.mcp.server.front_end_plugin_worker import MCPFrontEndPluginWorker

        return MCPFrontEndPluginWorker

    @typing.final
    def get_worker_class_name(self) -> str:
        """Get the worker class name from configuration or default."""
        if self.front_end_config.runner_class:
            return self.front_end_config.runner_class

        worker_class = self.get_worker_class()
        return f"{worker_class.__module__}.{worker_class.__qualname__}"

    def _get_worker_instance(self) -> MCPFrontEndPluginWorkerBase:
        """Create a new worker instance; a fresh instance is built on every call.

        If ``runner_class`` is configured, it must be a fully qualified ``module.ClassName`` path that is
        imported dynamically. Otherwise ``get_worker_class()`` is used.

        Raises:
            ValueError: If ``runner_class`` has no module part.
            AttributeError: If the module does not define the named class.
        """
        runner_class = self.front_end_config.runner_class
        if runner_class:
            import importlib

            module_name, _, class_name = runner_class.rpartition(".")
            if not module_name:
                raise ValueError(f"runner_class must be a fully qualified 'module.ClassName' path, got {runner_class!r}")
            module = importlib.import_module(module_name)
            try:
                worker_class = getattr(module, class_name)
            except AttributeError as exc:
                raise AttributeError(f"Module {module_name!r} has no attribute {class_name!r} "
                                     f"(from runner_class={runner_class!r})") from exc
        else:
            worker_class = self.get_worker_class()

        return worker_class(self.full_config)

    async def run(self) -> None:
        """Run the MCP server.

        SSE always runs at the root path. If ``base_path`` is also set, a warning is logged because
        SSE cannot be mounted at a sub-path. Streamable-HTTP is mounted under ``base_path`` when it is
        configured, and otherwise runs at ``/mcp``.
        """
        cfg = self.front_end_config

        # Build the workflow and add routes using the worker
        async with WorkflowBuilder.from_config(config=self.full_config) as builder:
            worker = self._get_worker_instance()

            # Let the worker create the MCP server (allows plugins to customize)
            mcp = await worker.create_mcp_server()

            # Add routes through the worker (includes health endpoint and function registration)
            await worker.add_routes(mcp, builder)

            try:
                if cfg.transport == "sse":
                    if cfg.base_path:
                        logger.warning(
                            "base_path is configured but SSE transport does not support mounting at sub-paths. "
                            "Use streamable-http transport for base_path support.")
                    logger.info("Starting MCP server with SSE endpoint at /sse")
                    await mcp.run_sse_async()
                elif cfg.base_path:
                    full_url = f"http://{cfg.host}:{cfg.port}{cfg.base_path}/mcp"
                    logger.info(
                        "Mounting MCP server at %s/mcp on %s:%s",
                        cfg.base_path,
                        cfg.host,
                        cfg.port,
                    )
                    logger.info("MCP server URL: %s", full_url)
                    # Reuse the worker that built ``mcp`` so root-level routes come from the same instance.
                    await self._run_with_mount(mcp, worker=worker)
                else:  # streamable-http at the root path
                    full_url = f"http://{cfg.host}:{cfg.port}/mcp"
                    logger.info("MCP server URL: %s", full_url)
                    await mcp.run_streamable_http_async()
            except KeyboardInterrupt:
                logger.info("MCP server shutdown requested (Ctrl+C). Shutting down gracefully.")

    async def _run_with_mount(self, mcp: "FastMCP", worker: MCPFrontEndPluginWorkerBase | None = None) -> None:
        """Run MCP server mounted at configured base_path using FastAPI wrapper.

        Args:
            mcp: The FastMCP server instance to mount
            worker: The worker that created ``mcp``. It is reused for ``add_root_level_routes`` so any
                state it gathered in ``add_routes`` is available. If omitted, a new worker is created.
        """
        import contextlib

        import uvicorn
        from fastapi import FastAPI

        @contextlib.asynccontextmanager
        async def lifespan(_app: FastAPI):
            """Manage MCP server session lifecycle."""
            logger.info("Starting MCP server session manager...")
            async with contextlib.AsyncExitStack() as stack:
                try:
                    # Initialize the MCP server's session manager
                    await stack.enter_async_context(mcp.session_manager.run())
                except Exception as e:
                    # Only startup failures are logged as such. Errors raised later at the yield propagate
                    # without being mislabelled as a startup failure.
                    logger.error("Failed to start MCP server session manager: %s", e)
                    raise
                logger.info("MCP server session manager started successfully")
                yield
            logger.info("MCP server session manager stopped")

        # Create a FastAPI wrapper app with lifespan management
        app = FastAPI(
            title=self.front_end_config.name,
            description="MCP server mounted at custom base path",
            lifespan=lifespan,
        )

        # Mount the MCP server's ASGI app at the configured base_path
        app.mount(self.front_end_config.base_path, mcp.streamable_http_app())

        # Allow plugins to add routes to the wrapper app (e.g., OAuth discovery endpoints)
        if worker is None:
            worker = self._get_worker_instance()
        await worker.add_root_level_routes(app, mcp)

        # Configure and start uvicorn server
        config = uvicorn.Config(
            app,
            host=self.front_end_config.host,
            port=self.front_end_config.port,
            log_level=self.front_end_config.log_level.lower(),
        )
        server = uvicorn.Server(config)
        await server.serve()
|