nvidia-nat-mcp 1.4.0a20251014__py3-none-any.whl → 1.5.0a20260115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-mcp might be problematic. Click here for more details.
- nat/meta/pypi.md +1 -1
- nat/plugins/mcp/__init__.py +1 -1
- nat/plugins/mcp/auth/__init__.py +1 -1
- nat/plugins/mcp/auth/auth_flow_handler.py +65 -1
- nat/plugins/mcp/auth/auth_provider.py +3 -2
- nat/plugins/mcp/auth/auth_provider_config.py +5 -2
- nat/plugins/mcp/auth/register.py +9 -1
- nat/plugins/mcp/auth/service_account/__init__.py +14 -0
- nat/plugins/mcp/auth/service_account/provider.py +136 -0
- nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
- nat/plugins/mcp/auth/service_account/token_client.py +156 -0
- nat/plugins/mcp/auth/token_storage.py +2 -2
- nat/plugins/mcp/cli/__init__.py +15 -0
- nat/plugins/mcp/cli/commands.py +1094 -0
- nat/plugins/mcp/client/__init__.py +15 -0
- nat/plugins/mcp/{client_base.py → client/client_base.py} +18 -10
- nat/plugins/mcp/{client_config.py → client/client_config.py} +24 -9
- nat/plugins/mcp/{client_impl.py → client/client_impl.py} +253 -62
- nat/plugins/mcp/exception_handler.py +1 -1
- nat/plugins/mcp/exceptions.py +1 -1
- nat/plugins/mcp/register.py +5 -4
- nat/plugins/mcp/server/__init__.py +15 -0
- nat/plugins/mcp/server/front_end_config.py +109 -0
- nat/plugins/mcp/server/front_end_plugin.py +155 -0
- nat/plugins/mcp/server/front_end_plugin_worker.py +415 -0
- nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
- nat/plugins/mcp/server/memory_profiler.py +320 -0
- nat/plugins/mcp/server/register_frontend.py +27 -0
- nat/plugins/mcp/server/tool_converter.py +290 -0
- nat/plugins/mcp/utils.py +153 -36
- {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/METADATA +5 -5
- nvidia_nat_mcp-1.5.0a20260115.dist-info/RECORD +37 -0
- nvidia_nat_mcp-1.5.0a20260115.dist-info/entry_points.txt +9 -0
- nat/plugins/mcp/tool.py +0 -138
- nvidia_nat_mcp-1.4.0a20251014.dist-info/RECORD +0 -23
- nvidia_nat_mcp-1.4.0a20251014.dist-info/entry_points.txt +0 -3
- {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat_mcp-1.4.0a20251014.dist-info → nvidia_nat_mcp-1.5.0a20260115.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Literal
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic import field_validator
|
|
21
|
+
from pydantic import model_validator
|
|
22
|
+
|
|
23
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
24
|
+
from nat.data_models.front_end import FrontEndBaseConfig
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
    """MCP front end configuration.

    A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.

    Validation performed on construction:
      * ``base_path`` (if set) must start with ``/`` and must not end with ``/``.
      * Risky security combinations (binding to a non-localhost host without ``server_auth``,
        or using the SSE transport) are reported as log warnings; they do not fail validation.
    """

    name: str = Field(default="NeMo Agent Toolkit MCP",
                      description="Name of the MCP server (default: NeMo Agent Toolkit MCP)")
    host: str = Field(default="localhost", description="Host to bind the server to (default: localhost)")
    port: int = Field(default=9901, description="Port to bind the server to (default: 9901)", ge=0, le=65535)
    debug: bool = Field(default=False, description="Enable debug mode (default: False)")
    log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
    tool_names: list[str] = Field(
        default_factory=list,
        # The two literals are implicitly concatenated; the trailing space keeps the sentences apart.
        description="The list of tools MCP server will expose (default: all tools). "
        "Tool names can be functions or function groups",
    )
    transport: Literal["sse", "streamable-http"] = Field(
        default="streamable-http",
        description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
    runner_class: str | None = Field(
        default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
    base_path: str | None = Field(default=None,
                                  description="Base path to mount the MCP server at (e.g., '/api/v1'). "
                                  "If specified, the server will be accessible at http://host:port{base_path}/mcp. "
                                  "If None, server runs at root path /mcp.")

    server_auth: OAuth2ResourceServerConfig | None = Field(
        default=None, description="OAuth 2.0 Resource Server configuration for token verification.")

    # Memory profiling configuration
    enable_memory_profiling: bool = Field(default=False,
                                          description="Enable memory profiling and diagnostics (default: False)")
    memory_profile_interval: int = Field(default=50,
                                         description="Log memory stats every N requests (default: 50)",
                                         ge=1)
    memory_profile_top_n: int = Field(default=10,
                                      description="Number of top memory allocations to log (default: 10)",
                                      ge=1,
                                      le=50)
    memory_profile_log_level: str = Field(default="DEBUG",
                                          description="Log level for memory profiling output (default: DEBUG)")

    @field_validator('base_path')
    @classmethod
    def validate_base_path(cls, v: str | None) -> str | None:
        """Validate that base_path starts with '/' and doesn't end with '/'.

        Args:
            v: The configured base path, or None to serve at the root.

        Returns:
            The unchanged value when valid.

        Raises:
            ValueError: If a non-None value does not start with '/' or ends with '/'.
        """
        if v is None:
            return v
        if not v.startswith('/'):
            raise ValueError("base_path must start with '/'")
        if v.endswith('/'):
            raise ValueError("base_path must not end with '/'")
        return v

    @model_validator(mode="after")
    def validate_security_configuration(self):
        """Validate security configuration to prevent accidental misconfigurations.

        Only emits warnings; the model is always returned unchanged.
        """
        # Hosts considered local-only; anything else may be reachable from other machines.
        localhost_hosts = {"localhost", "127.0.0.1", "::1"}
        is_local = self.host in localhost_hosts

        # Check if server is bound to a non-localhost interface without authentication
        if not is_local and self.server_auth is None:
            logger.warning(
                "MCP server is configured to bind to '%s' without authentication. "
                "This may expose your server to unauthorized access. "
                "Consider either: (1) binding to localhost for local-only access, "
                "or (2) configuring server_auth for production deployments on public interfaces.",
                self.host)

        # Check if SSE transport is used (which doesn't support authentication)
        if self.transport == "sse":
            if self.server_auth is not None:
                logger.warning("SSE transport does not support authentication. "
                               "The configured server_auth will be ignored. "
                               "For production use with authentication, use 'streamable-http' transport instead.")
            elif not is_local:
                logger.warning(
                    "SSE transport does not support authentication and is bound to '%s'. "
                    "This configuration is not recommended for production use. "
                    "For production deployments, use 'streamable-http' transport with server_auth configured.",
                    self.host)

        return self
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import typing
|
|
18
|
+
|
|
19
|
+
from nat.builder.front_end import FrontEndBase
|
|
20
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
21
|
+
from nat.plugins.mcp.server.front_end_config import MCPFrontEndConfig
|
|
22
|
+
from nat.plugins.mcp.server.front_end_plugin_worker import MCPFrontEndPluginWorkerBase
|
|
23
|
+
|
|
24
|
+
if typing.TYPE_CHECKING:
|
|
25
|
+
from mcp.server.fastmcp import FastMCP
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
    """MCP front end plugin implementation.

    Builds the configured workflow, delegates MCP server creation and route registration to a
    worker (the built-in worker or the class named by ``runner_class``), then serves the server
    over SSE or streamable-http, optionally mounted under ``base_path``.
    """

    def get_worker_class(self) -> type[MCPFrontEndPluginWorkerBase]:
        """Get the worker class for handling MCP routes."""
        from nat.plugins.mcp.server.front_end_plugin_worker import MCPFrontEndPluginWorker

        return MCPFrontEndPluginWorker

    @typing.final
    def get_worker_class_name(self) -> str:
        """Get the worker class name from configuration or default.

        Returns:
            ``runner_class`` if configured, otherwise the fully qualified name of
            :meth:`get_worker_class`'s result.
        """
        if self.front_end_config.runner_class:
            return self.front_end_config.runner_class

        worker_class = self.get_worker_class()
        return f"{worker_class.__module__}.{worker_class.__qualname__}"

    def _get_worker_instance(self):
        """Get an instance of the worker class.

        The class is imported dynamically when ``runner_class`` is configured, otherwise
        :meth:`get_worker_class` is used. The instance is constructed with the full config.

        Raises:
            ValueError: If ``runner_class`` is not a fully qualified ``module.ClassName`` path.
            ImportError: If the module named by ``runner_class`` cannot be imported.
            AttributeError: If that module does not define the named class.
        """
        runner_class = self.front_end_config.runner_class
        if runner_class:
            import importlib

            module_name, sep, class_name = runner_class.rpartition(".")
            if not sep or not module_name or not class_name:
                raise ValueError(
                    f"runner_class must be a fully qualified 'module.ClassName' path, got {runner_class!r}")
            module = importlib.import_module(module_name)
            try:
                worker_class = getattr(module, class_name)
            except AttributeError as e:
                raise AttributeError(f"Module '{module_name}' has no worker class '{class_name}' "
                                     f"(runner_class={runner_class!r})") from e
        else:
            worker_class = self.get_worker_class()

        return worker_class(self.full_config)

    def _server_base_url(self) -> str:
        """Return ``http://host:port`` for log output, bracketing IPv6 literals (e.g. ``[::1]``)."""
        host = self.front_end_config.host
        # A bare IPv6 literal contains ':' and must be wrapped in brackets inside a URL authority.
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"http://{host}:{self.front_end_config.port}"

    async def run(self) -> None:
        """Run the MCP server."""
        # Build the workflow and add routes using the worker
        async with WorkflowBuilder.from_config(config=self.full_config) as builder:

            # A single worker instance is used for the whole server lifetime, so any state it sets up
            # in create_mcp_server()/add_routes() is also visible to add_root_level_routes().
            worker = self._get_worker_instance()

            # Let the worker create the MCP server (allows plugins to customize)
            mcp = await worker.create_mcp_server()

            # Add routes through the worker (includes health endpoint and function registration)
            await worker.add_routes(mcp, builder)

            fe_config = self.front_end_config

            # Start the MCP server with configurable transport
            # streamable-http is the default, but users can choose sse if preferred
            try:
                # If base_path is configured, mount server at sub-path using FastAPI wrapper
                if fe_config.base_path:
                    if fe_config.transport == "sse":
                        logger.warning(
                            "base_path is configured but SSE transport does not support mounting at sub-paths. "
                            "Use streamable-http transport for base_path support.")
                        logger.info("Starting MCP server with SSE endpoint at /sse")
                        await mcp.run_sse_async()
                    else:
                        full_url = f"{self._server_base_url()}{fe_config.base_path}/mcp"
                        logger.info(
                            "Mounting MCP server at %s/mcp on %s:%s",
                            fe_config.base_path,
                            fe_config.host,
                            fe_config.port,
                        )
                        logger.info("MCP server URL: %s", full_url)
                        await self._run_with_mount(mcp, worker)
                # Standard behavior - run at root path
                elif fe_config.transport == "sse":
                    logger.info("Starting MCP server with SSE endpoint at /sse")
                    await mcp.run_sse_async()
                else:  # streamable-http
                    full_url = f"{self._server_base_url()}/mcp"
                    logger.info("MCP server URL: %s", full_url)
                    await mcp.run_streamable_http_async()
            except KeyboardInterrupt:
                logger.info("MCP server shutdown requested (Ctrl+C). Shutting down gracefully.")

    async def _run_with_mount(self, mcp: "FastMCP", worker: MCPFrontEndPluginWorkerBase | None = None) -> None:
        """Run MCP server mounted at configured base_path using FastAPI wrapper.

        Args:
            mcp: The FastMCP server instance to mount
            worker: The worker that created ``mcp``; it is reused for ``add_root_level_routes``.
                When omitted, a new worker instance is created.
        """
        import contextlib

        import uvicorn
        from fastapi import FastAPI

        @contextlib.asynccontextmanager
        async def lifespan(_app: FastAPI):
            """Manage MCP server session lifecycle.

            The wrapper app drives the session manager itself because the lifespan of the
            mounted MCP sub-application is not run by the wrapper.
            """
            logger.info("Starting MCP server session manager...")
            async with contextlib.AsyncExitStack() as stack:
                # Only startup is guarded so runtime/shutdown errors are not reported as startup failures.
                try:
                    # Initialize the MCP server's session manager
                    await stack.enter_async_context(mcp.session_manager.run())
                except Exception as e:
                    logger.error("Failed to start MCP server session manager: %s", e)
                    raise
                logger.info("MCP server session manager started successfully")
                yield
            logger.info("MCP server session manager stopped")

        # Create a FastAPI wrapper app with lifespan management
        app = FastAPI(
            title=self.front_end_config.name,
            description="MCP server mounted at custom base path",
            lifespan=lifespan,
        )

        # Mount the MCP server's ASGI app at the configured base_path
        app.mount(self.front_end_config.base_path, mcp.streamable_http_app())

        # Allow plugins to add routes to the wrapper app (e.g., OAuth discovery endpoints)
        if worker is None:
            worker = self._get_worker_instance()
        await worker.add_root_level_routes(app, mcp)

        # Configure and start uvicorn server
        config = uvicorn.Config(
            app,
            host=self.front_end_config.host,
            port=self.front_end_config.port,
            log_level=self.front_end_config.log_level.lower(),
        )
        server = uvicorn.Server(config)
        await server.serve()
|
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from abc import ABC
|
|
18
|
+
from abc import abstractmethod
|
|
19
|
+
from collections.abc import Mapping
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
from starlette.exceptions import HTTPException
|
|
24
|
+
from starlette.requests import Request
|
|
25
|
+
|
|
26
|
+
from mcp.server.fastmcp import FastMCP
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from fastapi import FastAPI
|
|
30
|
+
|
|
31
|
+
from nat.builder.function import Function
|
|
32
|
+
from nat.builder.function_base import FunctionBase
|
|
33
|
+
from nat.builder.workflow import Workflow
|
|
34
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
35
|
+
from nat.data_models.config import Config
|
|
36
|
+
from nat.plugins.mcp.server.front_end_config import MCPFrontEndConfig
|
|
37
|
+
from nat.plugins.mcp.server.memory_profiler import MemoryProfiler
|
|
38
|
+
from nat.runtime.session import SessionManager
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MCPFrontEndPluginWorkerBase(ABC):
    """Base class for MCP front end plugin workers.

    This abstract base class provides shared utilities and defines the contract
    for MCP worker implementations. Most users should inherit from
    MCPFrontEndPluginWorker instead of this class directly.
    """

    def __init__(self, config: Config):
        """Initialize the MCP worker with configuration.

        Args:
            config: The full NAT configuration
        """
        self.full_config = config
        self.front_end_config: MCPFrontEndConfig = config.general.front_end

        # Initialize memory profiler if enabled
        self.memory_profiler = MemoryProfiler(enabled=self.front_end_config.enable_memory_profiling,
                                              log_interval=self.front_end_config.memory_profile_interval,
                                              top_n=self.front_end_config.memory_profile_top_n,
                                              log_level=self.front_end_config.memory_profile_log_level)

    def _setup_health_endpoint(self, mcp: FastMCP):
        """Set up the HTTP health endpoint that exercises MCP ping handler.

        GET /health returns 200 with ``status: healthy`` when the ping handler succeeds,
        otherwise 503 with ``status: unhealthy`` and the error text.
        """

        @mcp.custom_route("/health", methods=["GET"])
        async def health_check(_request: Request):
            """HTTP health check using server's internal ping handler"""
            from starlette.responses import JSONResponse

            try:
                from mcp.types import PingRequest

                # Create a ping request
                ping_request = PingRequest(method="ping")

                # Call the ping handler directly (same one that responds to MCP pings).
                # NOTE: relies on the private FastMCP._mcp_server attribute.
                await mcp._mcp_server.request_handlers[PingRequest](ping_request)

                return JSONResponse({
                    "status": "healthy",
                    "error": None,
                    "server_name": mcp.name,
                })

            except Exception as e:
                return JSONResponse({
                    "status": "unhealthy",
                    "error": str(e),
                    "server_name": mcp.name,
                },
                                    status_code=503)

    @abstractmethod
    async def create_mcp_server(self) -> FastMCP:
        """Create and configure the MCP server instance.

        This is the main extension point. Plugins can return FastMCP or any subclass
        to customize server behavior (for example, add authentication, custom transports).

        Returns:
            FastMCP instance or a subclass with custom behavior
        """
        ...

    @abstractmethod
    async def add_routes(self, mcp: FastMCP, builder: WorkflowBuilder):
        """Add routes to the MCP server.

        Plugins must implement this method. Most plugins can call
        _default_add_routes() for standard behavior and then add
        custom enhancements.

        Args:
            mcp: The FastMCP server instance
            builder: The workflow builder instance
        """
        ...

    async def _default_add_routes(self, mcp: FastMCP, builder: WorkflowBuilder):
        """Default route registration logic - reusable by subclasses.

        This is a protected helper method that plugins can call to get
        standard route registration behavior. Plugins typically call this
        from their add_routes() implementation and then add custom features.

        This method:
        - Sets up the health endpoint
        - Builds the workflow and extracts all functions
        - Filters functions based on tool_names config
        - Registers each function as an MCP tool
        - Sets up debug endpoints for tool introspection

        Args:
            mcp: The FastMCP server instance
            builder: The workflow builder instance

        Raises:
            RuntimeError: If no functions remain to expose (none in the workflow, or all
                removed by the tool_names filter).
        """
        from nat.plugins.mcp.server.tool_converter import register_function_with_mcp

        # Set up the health endpoint
        self._setup_health_endpoint(mcp)

        # Build the default workflow
        workflow = await builder.build()

        # Get all functions from the workflow
        functions = await self._get_all_functions(workflow)

        # Filter functions based on tool_names if provided
        tool_names = self.front_end_config.tool_names
        if tool_names:
            logger.info("Filtering functions based on tool_names: %s", tool_names)
            functions = self._filter_functions(functions, tool_names)

        # Fail fast: there is nothing to expose, so don't build any session managers.
        if not functions:
            raise RuntimeError("No functions found in workflow. Please check your configuration.")

        # Create SessionManagers for each function
        # For regular functions, wrap them in a mini-workflow with that function as entry point
        # For workflows, use them directly
        session_managers: dict[str, SessionManager] = {}
        for function_name, function in functions.items():
            if isinstance(function, Workflow):
                logger.info("Function %s is a Workflow, using directly", function_name)
                entry_function = None
            else:
                logger.info("Function %s is a regular function, building entry workflow", function_name)
                entry_function = function_name
            session_managers[function_name] = await SessionManager.create(config=self.full_config,
                                                                          shared_builder=builder,
                                                                          entry_function=entry_function)

        # Register each function with MCP, passing SessionManager for observability
        for function_name, session_manager in session_managers.items():
            register_function_with_mcp(mcp,
                                       function_name,
                                       session_manager,
                                       self.memory_profiler,
                                       function=functions.get(function_name))

        # After registration, expose debug endpoints for tool/schema inspection
        # Extract the entry functions from session managers for debug endpoints
        debug_functions = {name: sm.workflow for name, sm in session_managers.items()}
        self._setup_debug_endpoints(mcp, debug_functions)

    @staticmethod
    def _filter_functions(functions: Mapping[str, Function], tool_names: list[str]) -> dict[str, Function]:
        """Keep only the functions selected by ``tool_names``.

        An entry in ``tool_names`` selects a function either by exact name or, treated as a
        function group name, every function whose name starts with ``"<group>."``. Entries
        that select nothing are reported with a warning so configuration typos are visible.

        Args:
            functions: All candidate functions keyed by name.
            tool_names: Function and/or function group names to expose.

        Returns:
            The selected functions, in the iteration order of ``functions``.
        """
        requested = set(tool_names)
        matched: set[str] = set()
        filtered: dict[str, Function] = {}
        for function_name, function in functions.items():
            if function_name in requested:
                # Treat current tool_names as function names, so check if the function name is in the list
                filtered[function_name] = function
                matched.add(function_name)
                continue
            # Treat tool_names as function group names, so check if the function name starts with the group name
            groups = [group_name for group_name in tool_names if function_name.startswith(f"{group_name}.")]
            if groups:
                filtered[function_name] = function
                matched.update(groups)
            else:
                logger.debug("Skipping function %s as it's not in tool_names", function_name)

        unmatched = [name for name in tool_names if name not in matched]
        if unmatched:
            logger.warning("The following tool_names did not match any function or function group: %s", unmatched)
        return filtered

    async def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
        """Get all functions from the workflow.

        Includes the workflow's functions, the accessible functions of each function group,
        and the workflow itself (keyed by its ``workflow_alias`` if set, otherwise its type).

        Args:
            workflow: The NAT workflow.

        Returns:
            Dict mapping function names to Function objects.
        """
        functions: dict[str, Function] = {}

        # Extract all functions from the workflow
        functions.update(workflow.functions)
        for function_group in workflow.function_groups.values():
            functions.update(await function_group.get_accessible_functions())

        if workflow.config.workflow.workflow_alias:
            functions[workflow.config.workflow.workflow_alias] = workflow
        else:
            functions[workflow.config.workflow.type] = workflow

        return functions

    async def add_root_level_routes(self, wrapper_app: "FastAPI", mcp: FastMCP) -> None:
        """Add routes to the wrapper FastAPI app (optional extension point).

        This method is called when base_path is configured and a wrapper
        FastAPI app is created to mount the MCP server. Plugins can override
        this to add routes to the wrapper app at the root level, outside the
        mounted MCP server path.

        Common use cases:
        - OAuth discovery endpoints (e.g., /.well-known/oauth-protected-resource)
        - Health checks at root level
        - Static file serving
        - Custom authentication/authorization endpoints

        Default implementation does nothing, making this an optional extension point.

        Args:
            wrapper_app: The FastAPI wrapper application that mounts the MCP server
            mcp: The FastMCP server instance (already mounted at base_path)
        """
        pass  # Default: no additional root-level routes

    def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
        """Set up HTTP debug endpoints for introspecting tools and schemas.

        Exposes:
        - GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
          selects a subset and returns details for those tools. Optional `detail` (true/false style values)
          controls whether input schemas are included; it defaults to true only when `name` is given.
        - GET /debug/memory/stats: Get current memory profiling statistics (read-only)
        """

        @mcp.custom_route("/debug/tools/list", methods=["GET"])
        async def list_tools(request: Request):
            """HTTP list tools endpoint."""

            from starlette.responses import JSONResponse

            from nat.plugins.mcp.server.tool_converter import get_function_description

            # Query params
            # Support repeated names and comma-separated lists. Names keep request order and
            # duplicates are dropped (dict.fromkeys keeps the first occurrence), so the
            # response order is deterministic.
            requested: list[str] = []
            for raw in request.query_params.getlist("name"):
                # Empty / whitespace-only entries (e.g. "a,,b") are ignored.
                requested.extend(p.strip() for p in raw.split(",") if p.strip())
            names = list(dict.fromkeys(requested))
            detail_raw = request.query_params.get("detail")

            def _parse_detail_param(detail_param: str | None, has_names: bool) -> bool:
                if detail_param is None:
                    return has_names
                v = detail_param.strip().lower()
                if v in ("0", "false", "no", "off"):
                    return False
                if v in ("1", "true", "yes", "on"):
                    return True
                # For invalid values, default based on whether names are present
                return has_names

            # Helper function to build the input schema info
            def _build_schema_info(fn: FunctionBase) -> dict[str, Any] | None:
                schema = getattr(fn, "input_schema", None)
                if schema is None:
                    return None

                # check if schema is a ChatRequest
                schema_name = getattr(schema, "__name__", "")
                schema_qualname = getattr(schema, "__qualname__", "")
                if "ChatRequest" in schema_name or "ChatRequest" in schema_qualname:
                    # Simplified interface used by MCP wrapper for ChatRequest
                    return {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string", "description": "User query string"
                            }
                        },
                        "required": ["query"],
                        "title": "ChatRequestQuery",
                    }

                # Pydantic models provide model_json_schema
                if hasattr(schema, "model_json_schema"):
                    return schema.model_json_schema()

                return None

            def _build_final_json(functions_to_include: Mapping[str, FunctionBase],
                                  include_schemas: bool = False) -> dict[str, Any]:
                tools = []
                for name, fn in functions_to_include.items():
                    list_entry: dict[str, Any] = {
                        "name": name, "description": get_function_description(fn), "is_workflow": hasattr(fn, "run")
                    }
                    if include_schemas:
                        list_entry["schema"] = _build_schema_info(fn)
                    tools.append(list_entry)

                return {
                    "count": len(tools),
                    "tools": tools,
                    "server_name": mcp.name,
                }

            if names:
                # Return selected tools
                try:
                    functions_to_include = {n: functions[n] for n in names}
                except KeyError as e:
                    raise HTTPException(status_code=404, detail=f"Tool \"{e.args[0]}\" not found.") from e
            else:
                functions_to_include = functions

            # Default for listing all: detail defaults to False unless explicitly set true
            return JSONResponse(
                _build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))

        # Memory profiling endpoint (read-only)
        @mcp.custom_route("/debug/memory/stats", methods=["GET"])
        async def get_memory_stats(_request: Request):
            """Get current memory profiling statistics."""
            from starlette.responses import JSONResponse

            stats = self.memory_profiler.get_stats()
            return JSONResponse(stats)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
    """Default MCP server worker implementation.

    Inherit from this class to create custom MCP workers that extend or modify
    server behavior. Override create_mcp_server() to use a different server type,
    and override add_routes() to add custom functionality.

    Example:
        class CustomWorker(MCPFrontEndPluginWorker):
            async def create_mcp_server(self):
                # Return custom MCP server instance
                return MyCustomFastMCP(...)

            async def add_routes(self, mcp, builder):
                # Get default routes
                await super().add_routes(mcp, builder)
                # Add custom features
                self._add_my_custom_features(mcp)
    """

    async def create_mcp_server(self) -> FastMCP:
        """Create default MCP server with optional authentication.

        When ``server_auth`` is configured, the server is given auth settings (issuer URL,
        required scopes, resource server URL) and an introspection-based token verifier.

        Returns:
            FastMCP instance configured with settings from NAT config
        """
        # Handle auth if configured
        auth_settings = None
        token_verifier = None

        server_auth = self.front_end_config.server_auth
        if server_auth:
            from pydantic import AnyHttpUrl

            from mcp.server.auth.settings import AuthSettings

            host = self.front_end_config.host
            # IPv6 literals (e.g. "::1") must be bracketed to form a valid URL authority.
            if ":" in host and not host.startswith("["):
                host = f"[{host}]"
            # NOTE(review): base_path is not included here even when the server is mounted under it;
            # confirm whether the advertised resource URL should reflect the mount point.
            server_url = f"http://{host}:{self.front_end_config.port}"
            auth_settings = AuthSettings(issuer_url=AnyHttpUrl(server_auth.issuer_url),
                                         required_scopes=server_auth.scopes,
                                         resource_server_url=AnyHttpUrl(server_url))

            # Create token verifier
            from nat.plugins.mcp.server.introspection_token_verifier import IntrospectionTokenVerifier

            token_verifier = IntrospectionTokenVerifier(server_auth)

        return FastMCP(name=self.front_end_config.name,
                       host=self.front_end_config.host,
                       port=self.front_end_config.port,
                       debug=self.front_end_config.debug,
                       auth=auth_settings,
                       token_verifier=token_verifier)

    async def add_routes(self, mcp: FastMCP, builder: WorkflowBuilder):
        """Add default routes to the MCP server.

        Args:
            mcp: The FastMCP server instance
            builder: The workflow builder instance
        """
        # Use the default implementation from base class to add the tools to the MCP server
        await self._default_add_routes(mcp, builder)
|