nvidia-nat-mcp 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/plugins/mcp/auth/__init__.py +14 -0
- nat/plugins/mcp/auth/auth_provider.py +367 -0
- nat/plugins/mcp/auth/auth_provider_config.py +76 -0
- nat/plugins/mcp/auth/register.py +25 -0
- nat/plugins/mcp/client_base.py +118 -88
- nat/plugins/mcp/client_impl.py +90 -137
- nat/plugins/mcp/tool.py +41 -35
- nat/plugins/mcp/utils.py +95 -0
- {nvidia_nat_mcp-1.3.0a20250910.dist-info → nvidia_nat_mcp-1.3.0a20250922.dist-info}/METADATA +3 -3
- nvidia_nat_mcp-1.3.0a20250922.dist-info/RECORD +18 -0
- {nvidia_nat_mcp-1.3.0a20250910.dist-info → nvidia_nat_mcp-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
- nvidia_nat_mcp-1.3.0a20250910.dist-info/RECORD +0 -13
- {nvidia_nat_mcp-1.3.0a20250910.dist-info → nvidia_nat_mcp-1.3.0a20250922.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.3.0a20250910.dist-info → nvidia_nat_mcp-1.3.0a20250922.dist-info}/top_level.txt +0 -0
nat/plugins/mcp/client_base.py
CHANGED
@@ -20,12 +20,9 @@ from abc import ABC
|
|
20
20
|
from abc import abstractmethod
|
21
21
|
from contextlib import AsyncExitStack
|
22
22
|
from contextlib import asynccontextmanager
|
23
|
-
from
|
24
|
-
from typing import Any
|
23
|
+
from typing import AsyncGenerator
|
25
24
|
|
26
|
-
|
27
|
-
from pydantic import Field
|
28
|
-
from pydantic import create_model
|
25
|
+
import httpx
|
29
26
|
|
30
27
|
from mcp import ClientSession
|
31
28
|
from mcp.client.sse import sse_client
|
@@ -33,104 +30,120 @@ from mcp.client.stdio import StdioServerParameters
|
|
33
30
|
from mcp.client.stdio import stdio_client
|
34
31
|
from mcp.client.streamable_http import streamablehttp_client
|
35
32
|
from mcp.types import TextContent
|
33
|
+
from nat.authentication.interfaces import AuthProviderBase
|
34
|
+
from nat.data_models.authentication import AuthReason
|
35
|
+
from nat.data_models.authentication import AuthRequest
|
36
36
|
from nat.plugins.mcp.exception_handler import mcp_exception_handler
|
37
37
|
from nat.plugins.mcp.exceptions import MCPToolNotFoundError
|
38
|
+
from nat.plugins.mcp.utils import model_from_mcp_schema
|
38
39
|
from nat.utils.type_utils import override
|
39
40
|
|
40
41
|
logger = logging.getLogger(__name__)
|
41
42
|
|
42
43
|
|
43
|
-
|
44
|
+
class AuthAdapter(httpx.Auth):
|
44
45
|
"""
|
45
|
-
|
46
|
+
httpx.Auth adapter for authentication providers.
|
47
|
+
Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
|
46
48
|
"""
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
"
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
49
|
+
|
50
|
+
def __init__(self, auth_provider: AuthProviderBase, auth_for_tool_calls_only: bool = False):
|
51
|
+
self.auth_provider = auth_provider
|
52
|
+
self.auth_for_tool_calls_only = auth_for_tool_calls_only
|
53
|
+
|
54
|
+
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
55
|
+
"""Add authentication headers to the request using NAT auth provider."""
|
56
|
+
# Check if we should only auth tool calls, Is this needed?
|
57
|
+
if self.auth_for_tool_calls_only and not self._is_tool_call_request(request):
|
58
|
+
# Skip auth for non-tool calls
|
59
|
+
yield request
|
60
|
+
return
|
61
|
+
|
62
|
+
try:
|
63
|
+
# Get fresh auth headers from the NAT auth provider
|
64
|
+
auth_headers = await self._get_auth_headers(reason=AuthReason.NORMAL)
|
65
|
+
request.headers.update(auth_headers)
|
66
|
+
except Exception as e:
|
67
|
+
logger.info("Failed to get auth headers: %s", e)
|
68
|
+
# Continue without auth headers if auth fails
|
69
|
+
|
70
|
+
response = yield request
|
71
|
+
|
72
|
+
# Handle 401 responses by retrying with fresh auth
|
73
|
+
if response.status_code == 401:
|
74
|
+
try:
|
75
|
+
# Get fresh auth headers with 401 context
|
76
|
+
auth_headers = await self._get_auth_headers(reason=AuthReason.RETRY_AFTER_401, response=response)
|
77
|
+
request.headers.update(auth_headers)
|
78
|
+
yield request # Retry the request
|
79
|
+
except Exception as e:
|
80
|
+
logger.info("Failed to refresh auth after 401: %s", e)
|
81
|
+
return
|
82
|
+
|
83
|
+
def _is_tool_call_request(self, request: httpx.Request) -> bool:
|
84
|
+
"""Check if this is a tool call request based on the request body."""
|
85
|
+
try:
|
86
|
+
# Check if the request body contains a tool call
|
87
|
+
if request.content:
|
88
|
+
import json
|
89
|
+
body = json.loads(request.content.decode('utf-8'))
|
90
|
+
# Check if it's a JSON-RPC request with method "tools/call"
|
91
|
+
if (isinstance(body, dict) and body.get("method") == "tools/call"):
|
92
|
+
return True
|
93
|
+
except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
|
94
|
+
# If we can't parse the body, assume it's not a tool call
|
95
|
+
pass
|
96
|
+
return False
|
97
|
+
|
98
|
+
async def _get_auth_headers(self, reason: AuthReason, response: httpx.Response | None = None) -> dict[str, str]:
|
99
|
+
"""Get authentication headers from the NAT auth provider."""
|
100
|
+
# Build auth request
|
101
|
+
www_authenticate = response.headers.get("WWW-Authenticate", None) if response else None
|
102
|
+
auth_request = AuthRequest(
|
103
|
+
reason=reason,
|
104
|
+
www_authenticate=www_authenticate,
|
105
|
+
)
|
106
|
+
try:
|
107
|
+
# Mutating the config is not thread-safe, so we need to lock here
|
108
|
+
# Is mutating the config the only way to pass the auth request to the auth provider? This needs
|
109
|
+
# to be re-visited.
|
110
|
+
self.auth_provider.config.auth_request = auth_request
|
111
|
+
auth_result = await self.auth_provider.authenticate()
|
112
|
+
# Check if we have BearerTokenCred
|
113
|
+
from nat.data_models.authentication import BearerTokenCred
|
114
|
+
if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
|
115
|
+
token = auth_result.credentials[0].token.get_secret_value()
|
116
|
+
return {"Authorization": f"Bearer {token}"}
|
78
117
|
else:
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
mapped = _type_map.get(t, Any)
|
85
|
-
field_type = mapped if field_type is None else field_type | mapped
|
86
|
-
|
87
|
-
return field_type, Field(
|
88
|
-
default=field_properties.get("default", None if "null" in json_type else ...),
|
89
|
-
description=field_properties.get("description", "")
|
90
|
-
)
|
91
|
-
else:
|
92
|
-
field_type = _type_map.get(json_type, Any)
|
93
|
-
|
94
|
-
# Determine the default value based on whether the field is required
|
95
|
-
if field_name in required_fields:
|
96
|
-
# Field is required - use explicit default if provided, otherwise make it required
|
97
|
-
default_value = field_properties.get("default", ...)
|
98
|
-
else:
|
99
|
-
# Field is optional - use explicit default if provided, otherwise None
|
100
|
-
default_value = field_properties.get("default", None)
|
101
|
-
# Make the type optional if no default was provided
|
102
|
-
if "default" not in field_properties:
|
103
|
-
field_type = field_type | None
|
104
|
-
|
105
|
-
nullable = field_properties.get("nullable", False)
|
106
|
-
description = field_properties.get("description", "")
|
107
|
-
|
108
|
-
field_type = field_type | None if nullable else field_type
|
109
|
-
|
110
|
-
return field_type, Field(default=default_value, description=description)
|
111
|
-
|
112
|
-
for field_name, field_props in properties.items():
|
113
|
-
schema_dict[field_name] = _generate_field(field_name=field_name, field_properties=field_props)
|
114
|
-
return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
|
118
|
+
logger.warning("Auth provider did not return BearerTokenCred")
|
119
|
+
return {}
|
120
|
+
except Exception as e:
|
121
|
+
logger.warning("Failed to get auth token: %s", e)
|
122
|
+
return {}
|
115
123
|
|
116
124
|
|
117
125
|
class MCPBaseClient(ABC):
|
118
126
|
"""
|
119
|
-
Base client for creating a session and connecting to an MCP server
|
127
|
+
Base client for creating a MCP transport session and connecting to an MCP server
|
120
128
|
|
121
129
|
Args:
|
122
130
|
transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
|
131
|
+
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
|
123
132
|
"""
|
124
133
|
|
125
|
-
def __init__(self, transport: str = 'streamable-http'):
|
134
|
+
def __init__(self, transport: str = 'streamable-http', auth_provider: AuthProviderBase | None = None):
|
126
135
|
self._tools = None
|
127
136
|
self._transport = transport.lower()
|
128
137
|
if self._transport not in ['sse', 'stdio', 'streamable-http']:
|
129
138
|
raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")
|
130
139
|
|
131
140
|
self._exit_stack: AsyncExitStack | None = None
|
141
|
+
self._session: ClientSession | None = None # Main session
|
142
|
+
self._connection_established = False
|
143
|
+
self._initial_connection = False
|
132
144
|
|
133
|
-
|
145
|
+
# Convert auth provider to AuthAdapter
|
146
|
+
self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
|
134
147
|
|
135
148
|
@property
|
136
149
|
def transport(self) -> str:
|
@@ -142,15 +155,19 @@ class MCPBaseClient(ABC):
|
|
142
155
|
|
143
156
|
self._exit_stack = AsyncExitStack()
|
144
157
|
|
158
|
+
# Establish connection with httpx.Auth
|
145
159
|
self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
|
146
160
|
|
161
|
+
self._initial_connection = True
|
162
|
+
self._connection_established = True
|
163
|
+
|
147
164
|
return self
|
148
165
|
|
149
166
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
150
|
-
|
151
167
|
if not self._exit_stack:
|
152
168
|
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
153
169
|
|
170
|
+
# Close session
|
154
171
|
await self._exit_stack.aclose()
|
155
172
|
self._session = None
|
156
173
|
self._exit_stack = None
|
@@ -168,11 +185,12 @@ class MCPBaseClient(ABC):
|
|
168
185
|
"""
|
169
186
|
Establish a session with an MCP server within an async context
|
170
187
|
"""
|
171
|
-
|
188
|
+
yield
|
172
189
|
|
173
190
|
async def get_tools(self):
|
174
191
|
"""
|
175
192
|
Retrieve a dictionary of all tools served by the MCP server.
|
193
|
+
Uses unauthenticated session for discovery.
|
176
194
|
"""
|
177
195
|
|
178
196
|
if not self._session:
|
@@ -185,7 +203,8 @@ class MCPBaseClient(ABC):
|
|
185
203
|
MCPToolClient(session=self._session,
|
186
204
|
tool_name=tool.name,
|
187
205
|
tool_description=tool.description,
|
188
|
-
tool_input_schema=tool.inputSchema
|
206
|
+
tool_input_schema=tool.inputSchema,
|
207
|
+
parent_client=self)
|
189
208
|
for tool in response.tools
|
190
209
|
}
|
191
210
|
|
@@ -257,7 +276,9 @@ class MCPSSEClient(MCPBaseClient):
|
|
257
276
|
|
258
277
|
class MCPStdioClient(MCPBaseClient):
|
259
278
|
"""
|
260
|
-
Client for creating a session and connecting to an MCP server using stdio
|
279
|
+
Client for creating a session and connecting to an MCP server using stdio.
|
280
|
+
This is a local transport that spawns the MCP server process and communicates
|
281
|
+
with it over stdin/stdout.
|
261
282
|
|
262
283
|
Args:
|
263
284
|
command (str): The command to run
|
@@ -307,11 +328,11 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
307
328
|
|
308
329
|
Args:
|
309
330
|
url (str): The url of the MCP server
|
331
|
+
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
|
310
332
|
"""
|
311
333
|
|
312
|
-
def __init__(self, url: str):
|
313
|
-
super().__init__("streamable-http")
|
314
|
-
|
334
|
+
def __init__(self, url: str, auth_provider: AuthProviderBase | None = None):
|
335
|
+
super().__init__("streamable-http", auth_provider=auth_provider)
|
315
336
|
self._url = url
|
316
337
|
|
317
338
|
@property
|
@@ -323,11 +344,13 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
323
344
|
return f"streamable-http:{self._url}"
|
324
345
|
|
325
346
|
@asynccontextmanager
|
347
|
+
@override
|
326
348
|
async def connect_to_server(self):
|
327
349
|
"""
|
328
350
|
Establish a session with an MCP server via streamable-http within an async context
|
329
351
|
"""
|
330
|
-
|
352
|
+
# Use httpx.Auth for authentication
|
353
|
+
async with streamablehttp_client(url=self._url, auth=self._httpx_auth) as (read, write, _):
|
331
354
|
async with ClientSession(read, write) as session:
|
332
355
|
await session.initialize()
|
333
356
|
yield session
|
@@ -335,24 +358,28 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
335
358
|
|
336
359
|
class MCPToolClient:
|
337
360
|
"""
|
338
|
-
Client wrapper used to call an MCP tool.
|
361
|
+
Client wrapper used to call an MCP tool. This assumes that the MCP transport session
|
362
|
+
has already been setup.
|
339
363
|
|
340
364
|
Args:
|
341
|
-
|
365
|
+
session (ClientSession): The MCP client session
|
342
366
|
tool_name (str): The name of the tool to wrap
|
343
367
|
tool_description (str): The description of the tool provided by the MCP server.
|
344
368
|
tool_input_schema (dict): The input schema for the tool.
|
369
|
+
parent_client (MCPBaseClient): The parent MCP client for auth management.
|
345
370
|
"""
|
346
371
|
|
347
372
|
def __init__(self,
|
348
373
|
session: ClientSession,
|
349
374
|
tool_name: str,
|
350
375
|
tool_description: str | None,
|
351
|
-
tool_input_schema: dict | None = None
|
376
|
+
tool_input_schema: dict | None = None,
|
377
|
+
parent_client: "MCPBaseClient | None" = None):
|
352
378
|
self._session = session
|
353
379
|
self._tool_name = tool_name
|
354
380
|
self._tool_description = tool_description
|
355
381
|
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
|
382
|
+
self._parent_client = parent_client
|
356
383
|
|
357
384
|
@property
|
358
385
|
def name(self):
|
@@ -388,6 +415,9 @@ class MCPToolClient:
|
|
388
415
|
Args:
|
389
416
|
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
|
390
417
|
"""
|
418
|
+
if self._session is None:
|
419
|
+
raise RuntimeError("No session available for tool call")
|
420
|
+
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
391
421
|
result = await self._session.call_tool(self._tool_name, tool_args)
|
392
422
|
|
393
423
|
output = []
|
nat/plugins/mcp/client_impl.py
CHANGED
@@ -22,11 +22,11 @@ from pydantic import HttpUrl
|
|
22
22
|
from pydantic import model_validator
|
23
23
|
|
24
24
|
from nat.builder.builder import Builder
|
25
|
-
from nat.builder.
|
26
|
-
from nat.cli.register_workflow import
|
27
|
-
from nat.data_models.
|
28
|
-
from nat.
|
29
|
-
from nat.plugins.mcp.
|
25
|
+
from nat.builder.function import FunctionGroup
|
26
|
+
from nat.cli.register_workflow import register_function_group
|
27
|
+
from nat.data_models.component_ref import AuthenticationRef
|
28
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
29
|
+
from nat.plugins.mcp.tool import mcp_tool_function
|
30
30
|
|
31
31
|
logger = logging.getLogger(__name__)
|
32
32
|
|
@@ -54,6 +54,9 @@ class MCPServerConfig(BaseModel):
|
|
54
54
|
args: list[str] | None = Field(default=None, description="Arguments for the stdio command")
|
55
55
|
env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
|
56
56
|
|
57
|
+
# Authentication configuration
|
58
|
+
auth_provider: AuthenticationRef | None = Field(default=None, description="Reference to authentication provider")
|
59
|
+
|
57
60
|
@model_validator(mode="after")
|
58
61
|
def validate_model(self):
|
59
62
|
"""Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive."""
|
@@ -62,168 +65,118 @@ class MCPServerConfig(BaseModel):
|
|
62
65
|
raise ValueError("url should not be set when using stdio transport")
|
63
66
|
if not self.command:
|
64
67
|
raise ValueError("command is required when using stdio transport")
|
65
|
-
|
68
|
+
# Auth is not supported for stdio transport
|
69
|
+
if self.auth_provider is not None:
|
70
|
+
raise ValueError("Authentication is not supported for stdio transport")
|
71
|
+
elif self.transport == "sse":
|
72
|
+
if self.command is not None or self.args is not None or self.env is not None:
|
73
|
+
raise ValueError("command, args, and env should not be set when using sse transport")
|
74
|
+
if not self.url:
|
75
|
+
raise ValueError("url is required when using sse transport")
|
76
|
+
# Auth is not supported for SSE transport
|
77
|
+
if self.auth_provider is not None:
|
78
|
+
raise ValueError("Authentication is not supported for SSE transport.")
|
79
|
+
elif self.transport == "streamable-http":
|
66
80
|
if self.command is not None or self.args is not None or self.env is not None:
|
67
|
-
raise ValueError("command, args, and env should not be set when using
|
81
|
+
raise ValueError("command, args, and env should not be set when using streamable-http transport")
|
68
82
|
if not self.url:
|
69
|
-
raise ValueError("url is required when using
|
83
|
+
raise ValueError("url is required when using streamable-http transport")
|
84
|
+
|
70
85
|
return self
|
71
86
|
|
72
87
|
|
73
|
-
class MCPClientConfig(
|
88
|
+
class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
74
89
|
"""
|
75
90
|
Configuration for connecting to an MCP server as a client and exposing selected tools.
|
76
91
|
"""
|
77
92
|
server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
|
78
|
-
|
93
|
+
tool_overrides: dict[str, MCPToolOverrideConfig] | None = Field(
|
79
94
|
default=None,
|
80
|
-
description="""
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
95
|
+
description="""Optional tool name overrides and description changes.
|
96
|
+
Example:
|
97
|
+
tool_overrides:
|
98
|
+
calculator_add:
|
99
|
+
alias: "add_numbers"
|
100
|
+
description: "Add two numbers together"
|
101
|
+
calculator_multiply:
|
102
|
+
description: "Multiply two numbers" # alias defaults to original name
|
86
103
|
""")
|
87
104
|
|
88
105
|
|
89
|
-
|
90
|
-
|
91
|
-
Configuration for wrapping a single tool from an MCP server as a NeMo Agent toolkit function.
|
92
|
-
"""
|
93
|
-
client: MCPBaseClient = Field(..., description="MCP client to use for the tool")
|
94
|
-
tool_name: str = Field(..., description="Name of the tool to use")
|
95
|
-
tool_description: str | None = Field(default=None, description="Description of the tool")
|
96
|
-
|
97
|
-
model_config = {"arbitrary_types_allowed": True}
|
98
|
-
|
99
|
-
|
100
|
-
def _get_server_name_safe(client: MCPBaseClient) -> str:
|
101
|
-
# Avoid leaking env secrets from stdio client in logs.
|
102
|
-
if client.transport == "stdio":
|
103
|
-
safe_server = f"stdio: {client.command}"
|
104
|
-
else:
|
105
|
-
safe_server = f"{client.transport}: {client.url}"
|
106
|
-
|
107
|
-
return safe_server
|
108
|
-
|
109
|
-
|
110
|
-
@register_function(config_type=MCPSingleToolConfig)
|
111
|
-
async def mcp_single_tool(config: MCPSingleToolConfig, builder: Builder):
|
106
|
+
@register_function_group(config_type=MCPClientConfig)
|
107
|
+
async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
112
108
|
"""
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
logger.info("Configured to use tool: %s from MCP server at %s", tool.name, _get_server_name_safe(config.client))
|
121
|
-
|
122
|
-
def _convert_from_str(input_str: str) -> BaseModel:
|
123
|
-
return input_schema.model_validate_json(input_str)
|
124
|
-
|
125
|
-
@experimental(feature_name="mcp_client")
|
126
|
-
async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
|
127
|
-
try:
|
128
|
-
if tool_input:
|
129
|
-
return await tool.acall(tool_input.model_dump())
|
130
|
-
_ = input_schema.model_validate(kwargs)
|
131
|
-
return await tool.acall(kwargs)
|
132
|
-
except Exception as e:
|
133
|
-
return str(e)
|
134
|
-
|
135
|
-
fn = FunctionInfo.create(single_fn=_response_fn,
|
136
|
-
description=tool.description,
|
137
|
-
input_schema=input_schema,
|
138
|
-
converters=[_convert_from_str])
|
139
|
-
yield fn
|
140
|
-
|
141
|
-
|
142
|
-
@register_function(MCPClientConfig)
|
143
|
-
async def mcp_client_function_handler(config: MCPClientConfig, builder: Builder):
|
144
|
-
"""
|
145
|
-
Connect to an MCP server, discover tools, and register them as functions in the workflow.
|
146
|
-
|
147
|
-
Note:
|
148
|
-
- Uses builder's exit stack to manage client lifecycle
|
149
|
-
- Applies tool filters if provided
|
109
|
+
Connect to an MCP server and expose tools as a function group.
|
110
|
+
Args:
|
111
|
+
config: The configuration for the MCP client
|
112
|
+
_builder: The builder
|
113
|
+
Returns:
|
114
|
+
The function group
|
150
115
|
"""
|
151
116
|
from nat.plugins.mcp.client_base import MCPSSEClient
|
152
117
|
from nat.plugins.mcp.client_base import MCPStdioClient
|
153
118
|
from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
|
154
119
|
|
155
|
-
#
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
"streamable-http": lambda: MCPStreamableHTTPClient(str(config.server.url)),
|
160
|
-
}.get(config.server.transport)
|
120
|
+
# Resolve auth provider if specified
|
121
|
+
auth_provider = None
|
122
|
+
if config.server.auth_provider:
|
123
|
+
auth_provider = await _builder.get_auth_provider(config.server.auth_provider)
|
161
124
|
|
162
|
-
|
125
|
+
# Build the appropriate client
|
126
|
+
if config.server.transport == "stdio":
|
127
|
+
if not config.server.command:
|
128
|
+
raise ValueError("command is required for stdio transport")
|
129
|
+
client = MCPStdioClient(config.server.command, config.server.args, config.server.env)
|
130
|
+
elif config.server.transport == "sse":
|
131
|
+
client = MCPSSEClient(str(config.server.url))
|
132
|
+
elif config.server.transport == "streamable-http":
|
133
|
+
client = MCPStreamableHTTPClient(str(config.server.url), auth_provider=auth_provider)
|
134
|
+
else:
|
163
135
|
raise ValueError(f"Unsupported transport: {config.server.transport}")
|
164
136
|
|
165
|
-
|
166
|
-
|
137
|
+
logger.info("Configured to use MCP server at %s", client.server_name)
|
138
|
+
|
139
|
+
# Create the function group
|
140
|
+
group = FunctionGroup(config=config)
|
167
141
|
|
168
|
-
# client aenter connects to the server and stores the client in the exit stack
|
169
|
-
# so it's cleaned up when the workflow is done
|
170
142
|
async with client:
|
171
143
|
all_tools = await client.get_tools()
|
172
|
-
|
144
|
+
tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
|
173
145
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
client=client,
|
179
|
-
tool_name=tool_name,
|
180
|
-
tool_description=tool_cfg["description"],
|
181
|
-
))
|
146
|
+
# Add each tool as a function to the group
|
147
|
+
for tool_name, tool in all_tools.items():
|
148
|
+
# Get override if it exists
|
149
|
+
override = tool_overrides.get(tool_name)
|
182
150
|
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
return f"MCP client connected: {text}"
|
151
|
+
# Use override values or defaults
|
152
|
+
function_name = override.alias if override and override.alias else tool_name
|
153
|
+
description = override.description if override and override.description else tool.description
|
187
154
|
|
188
|
-
|
155
|
+
# Create the tool function
|
156
|
+
tool_fn = mcp_tool_function(tool)
|
189
157
|
|
158
|
+
# Add to group
|
159
|
+
logger.info("Adding tool %s to group", function_name)
|
160
|
+
group.add_function(name=function_name,
|
161
|
+
description=description,
|
162
|
+
fn=tool_fn.single_fn,
|
163
|
+
input_schema=tool_fn.input_schema,
|
164
|
+
converters=tool_fn.converters)
|
190
165
|
|
191
|
-
|
192
|
-
"""
|
193
|
-
Apply tool filtering and optional aliasing/description overrides.
|
166
|
+
yield group
|
194
167
|
|
168
|
+
|
169
|
+
def mcp_apply_tool_alias_and_description(
|
170
|
+
all_tools: dict, tool_overrides: dict[str, MCPToolOverrideConfig] | None) -> dict[str, MCPToolOverrideConfig]:
|
171
|
+
"""
|
172
|
+
Filter tool overrides to only include tools that exist in the MCP server.
|
173
|
+
Args:
|
174
|
+
all_tools: The tools from the MCP server
|
175
|
+
tool_overrides: The tool overrides to apply
|
195
176
|
Returns:
|
196
|
-
|
197
|
-
- function_name
|
198
|
-
- description
|
177
|
+
Dictionary of valid tool overrides
|
199
178
|
"""
|
200
|
-
if
|
201
|
-
return {
|
202
|
-
|
203
|
-
|
204
|
-
return {
|
205
|
-
name: {
|
206
|
-
"function_name": name, "description": all_tools[name].description
|
207
|
-
}
|
208
|
-
for name in tool_filter if name in all_tools
|
209
|
-
}
|
210
|
-
|
211
|
-
if isinstance(tool_filter, dict):
|
212
|
-
result = {}
|
213
|
-
for name, override in tool_filter.items():
|
214
|
-
tool = all_tools.get(name)
|
215
|
-
if not tool:
|
216
|
-
logger.warning("Tool '%s' specified in tool_filter not found in MCP server", name)
|
217
|
-
continue
|
218
|
-
|
219
|
-
if isinstance(override, MCPToolOverrideConfig):
|
220
|
-
result[name] = {
|
221
|
-
"function_name": override.alias or name, "description": override.description or tool.description
|
222
|
-
}
|
223
|
-
else:
|
224
|
-
logger.warning("Unsupported override type for '%s': %s", name, type(override))
|
225
|
-
result[name] = {"function_name": name, "description": tool.description}
|
226
|
-
return result
|
227
|
-
|
228
|
-
# Fallback for unsupported tool_filter types
|
229
|
-
raise ValueError(f"Unsupported tool_filter type: {type(tool_filter)}")
|
179
|
+
if not tool_overrides:
|
180
|
+
return {}
|
181
|
+
|
182
|
+
return {name: override for name, override in tool_overrides.items() if name in all_tools}
|