nvidia-nat-mcp 1.3.0a20250917__py3-none-any.whl → 1.3.0a20250923__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,12 +20,9 @@ from abc import ABC
20
20
  from abc import abstractmethod
21
21
  from contextlib import AsyncExitStack
22
22
  from contextlib import asynccontextmanager
23
- from enum import Enum
24
- from typing import Any
23
+ from typing import AsyncGenerator
25
24
 
26
- from pydantic import BaseModel
27
- from pydantic import Field
28
- from pydantic import create_model
25
+ import httpx
29
26
 
30
27
  from mcp import ClientSession
31
28
  from mcp.client.sse import sse_client
@@ -33,104 +30,120 @@ from mcp.client.stdio import StdioServerParameters
33
30
  from mcp.client.stdio import stdio_client
34
31
  from mcp.client.streamable_http import streamablehttp_client
35
32
  from mcp.types import TextContent
33
+ from nat.authentication.interfaces import AuthProviderBase
34
+ from nat.data_models.authentication import AuthReason
35
+ from nat.data_models.authentication import AuthRequest
36
36
  from nat.plugins.mcp.exception_handler import mcp_exception_handler
37
37
  from nat.plugins.mcp.exceptions import MCPToolNotFoundError
38
+ from nat.plugins.mcp.utils import model_from_mcp_schema
38
39
  from nat.utils.type_utils import override
39
40
 
40
41
  logger = logging.getLogger(__name__)
41
42
 
42
43
 
43
- def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
44
+ class AuthAdapter(httpx.Auth):
44
45
  """
45
- Create a pydantic model from the input schema of the MCP tool
46
+ httpx.Auth adapter for authentication providers.
47
+ Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
46
48
  """
47
- _type_map = {
48
- "string": str,
49
- "number": float,
50
- "integer": int,
51
- "boolean": bool,
52
- "array": list,
53
- "null": None,
54
- "object": dict,
55
- }
56
-
57
- properties = mcp_input_schema.get("properties", {})
58
- required_fields = set(mcp_input_schema.get("required", []))
59
- schema_dict = {}
60
-
61
- def _generate_valid_classname(class_name: str):
62
- return class_name.replace('_', ' ').replace('-', ' ').title().replace(' ', '')
63
-
64
- def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
65
- json_type = field_properties.get("type", "string")
66
- enum_vals = field_properties.get("enum")
67
-
68
- if enum_vals:
69
- enum_name = f"{field_name.capitalize()}Enum"
70
- field_type = Enum(enum_name, {item: item for item in enum_vals})
71
-
72
- elif json_type == "object" and "properties" in field_properties:
73
- field_type = model_from_mcp_schema(name=field_name, mcp_input_schema=field_properties)
74
- elif json_type == "array" and "items" in field_properties:
75
- item_properties = field_properties.get("items", {})
76
- if item_properties.get("type") == "object":
77
- item_type = model_from_mcp_schema(name=field_name, mcp_input_schema=item_properties)
49
+
50
+ def __init__(self, auth_provider: AuthProviderBase, auth_for_tool_calls_only: bool = False):
51
+ self.auth_provider = auth_provider
52
+ self.auth_for_tool_calls_only = auth_for_tool_calls_only
53
+
54
+ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
55
+ """Add authentication headers to the request using NAT auth provider."""
56
+ # Check if we should only auth tool calls, Is this needed?
57
+ if self.auth_for_tool_calls_only and not self._is_tool_call_request(request):
58
+ # Skip auth for non-tool calls
59
+ yield request
60
+ return
61
+
62
+ try:
63
+ # Get fresh auth headers from the NAT auth provider
64
+ auth_headers = await self._get_auth_headers(reason=AuthReason.NORMAL)
65
+ request.headers.update(auth_headers)
66
+ except Exception as e:
67
+ logger.info("Failed to get auth headers: %s", e)
68
+ # Continue without auth headers if auth fails
69
+
70
+ response = yield request
71
+
72
+ # Handle 401 responses by retrying with fresh auth
73
+ if response.status_code == 401:
74
+ try:
75
+ # Get fresh auth headers with 401 context
76
+ auth_headers = await self._get_auth_headers(reason=AuthReason.RETRY_AFTER_401, response=response)
77
+ request.headers.update(auth_headers)
78
+ yield request # Retry the request
79
+ except Exception as e:
80
+ logger.info("Failed to refresh auth after 401: %s", e)
81
+ return
82
+
83
+ def _is_tool_call_request(self, request: httpx.Request) -> bool:
84
+ """Check if this is a tool call request based on the request body."""
85
+ try:
86
+ # Check if the request body contains a tool call
87
+ if request.content:
88
+ import json
89
+ body = json.loads(request.content.decode('utf-8'))
90
+ # Check if it's a JSON-RPC request with method "tools/call"
91
+ if (isinstance(body, dict) and body.get("method") == "tools/call"):
92
+ return True
93
+ except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
94
+ # If we can't parse the body, assume it's not a tool call
95
+ pass
96
+ return False
97
+
98
+ async def _get_auth_headers(self, reason: AuthReason, response: httpx.Response | None = None) -> dict[str, str]:
99
+ """Get authentication headers from the NAT auth provider."""
100
+ # Build auth request
101
+ www_authenticate = response.headers.get("WWW-Authenticate", None) if response else None
102
+ auth_request = AuthRequest(
103
+ reason=reason,
104
+ www_authenticate=www_authenticate,
105
+ )
106
+ try:
107
+ # Mutating the config is not thread-safe, so we need to lock here
108
+ # Is mutating the config the only way to pass the auth request to the auth provider? This needs
109
+ # to be re-visited.
110
+ self.auth_provider.config.auth_request = auth_request
111
+ auth_result = await self.auth_provider.authenticate()
112
+ # Check if we have BearerTokenCred
113
+ from nat.data_models.authentication import BearerTokenCred
114
+ if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
115
+ token = auth_result.credentials[0].token.get_secret_value()
116
+ return {"Authorization": f"Bearer {token}"}
78
117
  else:
79
- item_type = _type_map.get(item_properties.get("type", "string"), Any)
80
- field_type = list[item_type]
81
- elif isinstance(json_type, list):
82
- field_type = None
83
- for t in json_type:
84
- mapped = _type_map.get(t, Any)
85
- field_type = mapped if field_type is None else field_type | mapped
86
-
87
- return field_type, Field(
88
- default=field_properties.get("default", None if "null" in json_type else ...),
89
- description=field_properties.get("description", "")
90
- )
91
- else:
92
- field_type = _type_map.get(json_type, Any)
93
-
94
- # Determine the default value based on whether the field is required
95
- if field_name in required_fields:
96
- # Field is required - use explicit default if provided, otherwise make it required
97
- default_value = field_properties.get("default", ...)
98
- else:
99
- # Field is optional - use explicit default if provided, otherwise None
100
- default_value = field_properties.get("default", None)
101
- # Make the type optional if no default was provided
102
- if "default" not in field_properties:
103
- field_type = field_type | None
104
-
105
- nullable = field_properties.get("nullable", False)
106
- description = field_properties.get("description", "")
107
-
108
- field_type = field_type | None if nullable else field_type
109
-
110
- return field_type, Field(default=default_value, description=description)
111
-
112
- for field_name, field_props in properties.items():
113
- schema_dict[field_name] = _generate_field(field_name=field_name, field_properties=field_props)
114
- return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
118
+ logger.warning("Auth provider did not return BearerTokenCred")
119
+ return {}
120
+ except Exception as e:
121
+ logger.warning("Failed to get auth token: %s", e)
122
+ return {}
115
123
 
116
124
 
117
125
  class MCPBaseClient(ABC):
118
126
  """
119
- Base client for creating a session and connecting to an MCP server
127
+ Base client for creating a MCP transport session and connecting to an MCP server
120
128
 
121
129
  Args:
122
130
  transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
131
+ auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
123
132
  """
124
133
 
125
- def __init__(self, transport: str = 'streamable-http'):
134
+ def __init__(self, transport: str = 'streamable-http', auth_provider: AuthProviderBase | None = None):
126
135
  self._tools = None
127
136
  self._transport = transport.lower()
128
137
  if self._transport not in ['sse', 'stdio', 'streamable-http']:
129
138
  raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")
130
139
 
131
140
  self._exit_stack: AsyncExitStack | None = None
141
+ self._session: ClientSession | None = None # Main session
142
+ self._connection_established = False
143
+ self._initial_connection = False
132
144
 
133
- self._session: ClientSession | None = None
145
+ # Convert auth provider to AuthAdapter
146
+ self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
134
147
 
135
148
  @property
136
149
  def transport(self) -> str:
@@ -142,15 +155,19 @@ class MCPBaseClient(ABC):
142
155
 
143
156
  self._exit_stack = AsyncExitStack()
144
157
 
158
+ # Establish connection with httpx.Auth
145
159
  self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
146
160
 
161
+ self._initial_connection = True
162
+ self._connection_established = True
163
+
147
164
  return self
148
165
 
149
166
  async def __aexit__(self, exc_type, exc_value, traceback):
150
-
151
167
  if not self._exit_stack:
152
168
  raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
153
169
 
170
+ # Close session
154
171
  await self._exit_stack.aclose()
155
172
  self._session = None
156
173
  self._exit_stack = None
@@ -168,11 +185,12 @@ class MCPBaseClient(ABC):
168
185
  """
169
186
  Establish a session with an MCP server within an async context
170
187
  """
171
- pass
188
+ yield
172
189
 
173
190
  async def get_tools(self):
174
191
  """
175
192
  Retrieve a dictionary of all tools served by the MCP server.
193
+ Uses unauthenticated session for discovery.
176
194
  """
177
195
 
178
196
  if not self._session:
@@ -185,7 +203,8 @@ class MCPBaseClient(ABC):
185
203
  MCPToolClient(session=self._session,
186
204
  tool_name=tool.name,
187
205
  tool_description=tool.description,
188
- tool_input_schema=tool.inputSchema)
206
+ tool_input_schema=tool.inputSchema,
207
+ parent_client=self)
189
208
  for tool in response.tools
190
209
  }
191
210
 
@@ -257,7 +276,9 @@ class MCPSSEClient(MCPBaseClient):
257
276
 
258
277
  class MCPStdioClient(MCPBaseClient):
259
278
  """
260
- Client for creating a session and connecting to an MCP server using stdio
279
+ Client for creating a session and connecting to an MCP server using stdio.
280
+ This is a local transport that spawns the MCP server process and communicates
281
+ with it over stdin/stdout.
261
282
 
262
283
  Args:
263
284
  command (str): The command to run
@@ -307,11 +328,11 @@ class MCPStreamableHTTPClient(MCPBaseClient):
307
328
 
308
329
  Args:
309
330
  url (str): The url of the MCP server
331
+ auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
310
332
  """
311
333
 
312
- def __init__(self, url: str):
313
- super().__init__("streamable-http")
314
-
334
+ def __init__(self, url: str, auth_provider: AuthProviderBase | None = None):
335
+ super().__init__("streamable-http", auth_provider=auth_provider)
315
336
  self._url = url
316
337
 
317
338
  @property
@@ -323,11 +344,13 @@ class MCPStreamableHTTPClient(MCPBaseClient):
323
344
  return f"streamable-http:{self._url}"
324
345
 
325
346
  @asynccontextmanager
347
+ @override
326
348
  async def connect_to_server(self):
327
349
  """
328
350
  Establish a session with an MCP server via streamable-http within an async context
329
351
  """
330
- async with streamablehttp_client(url=self._url) as (read, write, get_session_id):
352
+ # Use httpx.Auth for authentication
353
+ async with streamablehttp_client(url=self._url, auth=self._httpx_auth) as (read, write, _):
331
354
  async with ClientSession(read, write) as session:
332
355
  await session.initialize()
333
356
  yield session
@@ -335,24 +358,28 @@ class MCPStreamableHTTPClient(MCPBaseClient):
335
358
 
336
359
  class MCPToolClient:
337
360
  """
338
- Client wrapper used to call an MCP tool.
361
+ Client wrapper used to call an MCP tool. This assumes that the MCP transport session
362
+ has already been setup.
339
363
 
340
364
  Args:
341
- connect_fn (callable): Function that returns an async context manager for connecting to the server
365
+ session (ClientSession): The MCP client session
342
366
  tool_name (str): The name of the tool to wrap
343
367
  tool_description (str): The description of the tool provided by the MCP server.
344
368
  tool_input_schema (dict): The input schema for the tool.
369
+ parent_client (MCPBaseClient): The parent MCP client for auth management.
345
370
  """
346
371
 
347
372
  def __init__(self,
348
373
  session: ClientSession,
349
374
  tool_name: str,
350
375
  tool_description: str | None,
351
- tool_input_schema: dict | None = None):
376
+ tool_input_schema: dict | None = None,
377
+ parent_client: "MCPBaseClient | None" = None):
352
378
  self._session = session
353
379
  self._tool_name = tool_name
354
380
  self._tool_description = tool_description
355
381
  self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
382
+ self._parent_client = parent_client
356
383
 
357
384
  @property
358
385
  def name(self):
@@ -388,6 +415,9 @@ class MCPToolClient:
388
415
  Args:
389
416
  tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
390
417
  """
418
+ if self._session is None:
419
+ raise RuntimeError("No session available for tool call")
420
+ logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
391
421
  result = await self._session.call_tool(self._tool_name, tool_args)
392
422
 
393
423
  output = []
@@ -22,11 +22,11 @@ from pydantic import HttpUrl
22
22
  from pydantic import model_validator
23
23
 
24
24
  from nat.builder.builder import Builder
25
- from nat.builder.function_info import FunctionInfo
26
- from nat.cli.register_workflow import register_function
27
- from nat.data_models.function import FunctionBaseConfig
28
- from nat.experimental.decorators.experimental_warning_decorator import experimental
29
- from nat.plugins.mcp.client_base import MCPBaseClient
25
+ from nat.builder.function import FunctionGroup
26
+ from nat.cli.register_workflow import register_function_group
27
+ from nat.data_models.component_ref import AuthenticationRef
28
+ from nat.data_models.function import FunctionGroupBaseConfig
29
+ from nat.plugins.mcp.tool import mcp_tool_function
30
30
 
31
31
  logger = logging.getLogger(__name__)
32
32
 
@@ -54,6 +54,9 @@ class MCPServerConfig(BaseModel):
54
54
  args: list[str] | None = Field(default=None, description="Arguments for the stdio command")
55
55
  env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
56
56
 
57
+ # Authentication configuration
58
+ auth_provider: AuthenticationRef | None = Field(default=None, description="Reference to authentication provider")
59
+
57
60
  @model_validator(mode="after")
58
61
  def validate_model(self):
59
62
  """Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive."""
@@ -62,168 +65,118 @@ class MCPServerConfig(BaseModel):
62
65
  raise ValueError("url should not be set when using stdio transport")
63
66
  if not self.command:
64
67
  raise ValueError("command is required when using stdio transport")
65
- elif self.transport in ("sse", "streamable-http"):
68
+ # Auth is not supported for stdio transport
69
+ if self.auth_provider is not None:
70
+ raise ValueError("Authentication is not supported for stdio transport")
71
+ elif self.transport == "sse":
72
+ if self.command is not None or self.args is not None or self.env is not None:
73
+ raise ValueError("command, args, and env should not be set when using sse transport")
74
+ if not self.url:
75
+ raise ValueError("url is required when using sse transport")
76
+ # Auth is not supported for SSE transport
77
+ if self.auth_provider is not None:
78
+ raise ValueError("Authentication is not supported for SSE transport.")
79
+ elif self.transport == "streamable-http":
66
80
  if self.command is not None or self.args is not None or self.env is not None:
67
- raise ValueError("command, args, and env should not be set when using sse or streamable-http transport")
81
+ raise ValueError("command, args, and env should not be set when using streamable-http transport")
68
82
  if not self.url:
69
- raise ValueError("url is required when using sse or streamable-http transport")
83
+ raise ValueError("url is required when using streamable-http transport")
84
+
70
85
  return self
71
86
 
72
87
 
73
- class MCPClientConfig(FunctionBaseConfig, name="mcp_client"):
88
+ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
74
89
  """
75
90
  Configuration for connecting to an MCP server as a client and exposing selected tools.
76
91
  """
77
92
  server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
78
- tool_filter: dict[str, MCPToolOverrideConfig] | list[str] | None = Field(
93
+ tool_overrides: dict[str, MCPToolOverrideConfig] | None = Field(
79
94
  default=None,
80
- description="""Filter or map tools to expose from the server (list or dict).
81
- Can be:
82
- - A list of tool names to expose: ['tool1', 'tool2']
83
- - A dict mapping tool names to override configs:
84
- {'tool1': {'alias': 'new_name', 'description': 'New desc'}}
85
- {'tool2': {'description': 'Override description only'}} # alias defaults to 'tool2'
95
+ description="""Optional tool name overrides and description changes.
96
+ Example:
97
+ tool_overrides:
98
+ calculator_add:
99
+ alias: "add_numbers"
100
+ description: "Add two numbers together"
101
+ calculator_multiply:
102
+ description: "Multiply two numbers" # alias defaults to original name
86
103
  """)
87
104
 
88
105
 
89
- class MCPSingleToolConfig(FunctionBaseConfig, name="mcp_single_tool"):
90
- """
91
- Configuration for wrapping a single tool from an MCP server as a NeMo Agent toolkit function.
92
- """
93
- client: MCPBaseClient = Field(..., description="MCP client to use for the tool")
94
- tool_name: str = Field(..., description="Name of the tool to use")
95
- tool_description: str | None = Field(default=None, description="Description of the tool")
96
-
97
- model_config = {"arbitrary_types_allowed": True}
98
-
99
-
100
- def _get_server_name_safe(client: MCPBaseClient) -> str:
101
- # Avoid leaking env secrets from stdio client in logs.
102
- if client.transport == "stdio":
103
- safe_server = f"stdio: {client.command}"
104
- else:
105
- safe_server = f"{client.transport}: {client.url}"
106
-
107
- return safe_server
108
-
109
-
110
- @register_function(config_type=MCPSingleToolConfig)
111
- async def mcp_single_tool(config: MCPSingleToolConfig, builder: Builder):
106
+ @register_function_group(config_type=MCPClientConfig)
107
+ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
112
108
  """
113
- Wrap a single tool from an MCP server as a NeMo Agent toolkit function.
114
- """
115
- tool = await config.client.get_tool(config.tool_name)
116
- if config.tool_description:
117
- tool.set_description(description=config.tool_description)
118
- input_schema = tool.input_schema
119
-
120
- logger.info("Configured to use tool: %s from MCP server at %s", tool.name, _get_server_name_safe(config.client))
121
-
122
- def _convert_from_str(input_str: str) -> BaseModel:
123
- return input_schema.model_validate_json(input_str)
124
-
125
- @experimental(feature_name="mcp_client")
126
- async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
127
- try:
128
- if tool_input:
129
- return await tool.acall(tool_input.model_dump())
130
- _ = input_schema.model_validate(kwargs)
131
- return await tool.acall(kwargs)
132
- except Exception as e:
133
- return str(e)
134
-
135
- fn = FunctionInfo.create(single_fn=_response_fn,
136
- description=tool.description,
137
- input_schema=input_schema,
138
- converters=[_convert_from_str])
139
- yield fn
140
-
141
-
142
- @register_function(MCPClientConfig)
143
- async def mcp_client_function_handler(config: MCPClientConfig, builder: Builder):
144
- """
145
- Connect to an MCP server, discover tools, and register them as functions in the workflow.
146
-
147
- Note:
148
- - Uses builder's exit stack to manage client lifecycle
149
- - Applies tool filters if provided
109
+ Connect to an MCP server and expose tools as a function group.
110
+ Args:
111
+ config: The configuration for the MCP client
112
+ _builder: The builder
113
+ Returns:
114
+ The function group
150
115
  """
151
116
  from nat.plugins.mcp.client_base import MCPSSEClient
152
117
  from nat.plugins.mcp.client_base import MCPStdioClient
153
118
  from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
154
119
 
155
- # Build the appropriate client
156
- client_cls = {
157
- "stdio": lambda: MCPStdioClient(config.server.command, config.server.args, config.server.env),
158
- "sse": lambda: MCPSSEClient(str(config.server.url)),
159
- "streamable-http": lambda: MCPStreamableHTTPClient(str(config.server.url)),
160
- }.get(config.server.transport)
120
+ # Resolve auth provider if specified
121
+ auth_provider = None
122
+ if config.server.auth_provider:
123
+ auth_provider = await _builder.get_auth_provider(config.server.auth_provider)
161
124
 
162
- if not client_cls:
125
+ # Build the appropriate client
126
+ if config.server.transport == "stdio":
127
+ if not config.server.command:
128
+ raise ValueError("command is required for stdio transport")
129
+ client = MCPStdioClient(config.server.command, config.server.args, config.server.env)
130
+ elif config.server.transport == "sse":
131
+ client = MCPSSEClient(str(config.server.url))
132
+ elif config.server.transport == "streamable-http":
133
+ client = MCPStreamableHTTPClient(str(config.server.url), auth_provider=auth_provider)
134
+ else:
163
135
  raise ValueError(f"Unsupported transport: {config.server.transport}")
164
136
 
165
- client = client_cls()
166
- logger.info("Configured to use MCP server at %s", _get_server_name_safe(client))
137
+ logger.info("Configured to use MCP server at %s", client.server_name)
138
+
139
+ # Create the function group
140
+ group = FunctionGroup(config=config)
167
141
 
168
- # client aenter connects to the server and stores the client in the exit stack
169
- # so it's cleaned up when the workflow is done
170
142
  async with client:
171
143
  all_tools = await client.get_tools()
172
- tool_configs = mcp_filter_tools(all_tools, config.tool_filter)
144
+ tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
173
145
 
174
- for tool_name, tool_cfg in tool_configs.items():
175
- await builder.add_function(
176
- tool_cfg["function_name"],
177
- MCPSingleToolConfig(
178
- client=client,
179
- tool_name=tool_name,
180
- tool_description=tool_cfg["description"],
181
- ))
146
+ # Add each tool as a function to the group
147
+ for tool_name, tool in all_tools.items():
148
+ # Get override if it exists
149
+ override = tool_overrides.get(tool_name)
182
150
 
183
- @experimental(feature_name="mcp_client")
184
- async def idle_fn(text: str) -> str:
185
- # This function is a placeholder and will be removed when function groups are used
186
- return f"MCP client connected: {text}"
151
+ # Use override values or defaults
152
+ function_name = override.alias if override and override.alias else tool_name
153
+ description = override.description if override and override.description else tool.description
187
154
 
188
- yield FunctionInfo.create(single_fn=idle_fn, description="MCP client")
155
+ # Create the tool function
156
+ tool_fn = mcp_tool_function(tool)
189
157
 
158
+ # Add to group
159
+ logger.info("Adding tool %s to group", function_name)
160
+ group.add_function(name=function_name,
161
+ description=description,
162
+ fn=tool_fn.single_fn,
163
+ input_schema=tool_fn.input_schema,
164
+ converters=tool_fn.converters)
190
165
 
191
- def mcp_filter_tools(all_tools: dict, tool_filter) -> dict[str, dict]:
192
- """
193
- Apply tool filtering and optional aliasing/description overrides.
166
+ yield group
194
167
 
168
+
169
+ def mcp_apply_tool_alias_and_description(
170
+ all_tools: dict, tool_overrides: dict[str, MCPToolOverrideConfig] | None) -> dict[str, MCPToolOverrideConfig]:
171
+ """
172
+ Filter tool overrides to only include tools that exist in the MCP server.
173
+ Args:
174
+ all_tools: The tools from the MCP server
175
+ tool_overrides: The tool overrides to apply
195
176
  Returns:
196
- Dict[str, dict] where each value has:
197
- - function_name
198
- - description
177
+ Dictionary of valid tool overrides
199
178
  """
200
- if tool_filter is None:
201
- return {name: {"function_name": name, "description": tool.description} for name, tool in all_tools.items()}
202
-
203
- if isinstance(tool_filter, list):
204
- return {
205
- name: {
206
- "function_name": name, "description": all_tools[name].description
207
- }
208
- for name in tool_filter if name in all_tools
209
- }
210
-
211
- if isinstance(tool_filter, dict):
212
- result = {}
213
- for name, override in tool_filter.items():
214
- tool = all_tools.get(name)
215
- if not tool:
216
- logger.warning("Tool '%s' specified in tool_filter not found in MCP server", name)
217
- continue
218
-
219
- if isinstance(override, MCPToolOverrideConfig):
220
- result[name] = {
221
- "function_name": override.alias or name, "description": override.description or tool.description
222
- }
223
- else:
224
- logger.warning("Unsupported override type for '%s': %s", name, type(override))
225
- result[name] = {"function_name": name, "description": tool.description}
226
- return result
227
-
228
- # Fallback for unsupported tool_filter types
229
- raise ValueError(f"Unsupported tool_filter type: {type(tool_filter)}")
179
+ if not tool_overrides:
180
+ return {}
181
+
182
+ return {name: override for name, override in tool_overrides.items() if name in all_tools}