nvidia-nat 1.3.0a20250829__py3-none-any.whl → 1.3.0a20250830__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/agent.py +37 -25
- nat/agent/react_agent/register.py +6 -1
- nat/builder/workflow.py +6 -2
- nat/cli/commands/info/list_mcp.py +183 -47
- nat/cli/commands/start.py +14 -2
- nat/data_models/thinking_mixin.py +27 -8
- nat/front_ends/mcp/mcp_front_end_config.py +5 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +8 -2
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +2 -2
- nat/front_ends/mcp/tool_converter.py +40 -13
- nat/observability/register.py +3 -1
- nat/tool/mcp/{mcp_client.py → mcp_client_base.py} +197 -46
- nat/tool/mcp/mcp_client_impl.py +229 -0
- nat/tool/mcp/mcp_tool.py +79 -42
- nat/tool/register.py +1 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/METADATA +3 -3
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/RECORD +22 -21
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/top_level.txt +0 -0
|
@@ -17,13 +17,17 @@ import json
|
|
|
17
17
|
import logging
|
|
18
18
|
from inspect import Parameter
|
|
19
19
|
from inspect import Signature
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
20
21
|
|
|
21
22
|
from mcp.server.fastmcp import FastMCP
|
|
22
23
|
from pydantic import BaseModel
|
|
23
24
|
|
|
25
|
+
from nat.builder.context import ContextState
|
|
24
26
|
from nat.builder.function import Function
|
|
25
27
|
from nat.builder.function_base import FunctionBase
|
|
26
|
-
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from nat.builder.workflow import Workflow
|
|
27
31
|
|
|
28
32
|
logger = logging.getLogger(__name__)
|
|
29
33
|
|
|
@@ -33,14 +37,16 @@ def create_function_wrapper(
|
|
|
33
37
|
function: FunctionBase,
|
|
34
38
|
schema: type[BaseModel],
|
|
35
39
|
is_workflow: bool = False,
|
|
40
|
+
workflow: 'Workflow | None' = None,
|
|
36
41
|
):
|
|
37
42
|
"""Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
|
|
38
43
|
|
|
39
44
|
Args:
|
|
40
|
-
function_name: The name of the function/tool
|
|
41
|
-
function: The NAT Function object
|
|
42
|
-
schema: The input schema of the function
|
|
43
|
-
is_workflow: Whether the function is a Workflow
|
|
45
|
+
function_name (str): The name of the function/tool
|
|
46
|
+
function (FunctionBase): The NAT Function object
|
|
47
|
+
schema (type[BaseModel]): The input schema of the function
|
|
48
|
+
is_workflow (bool): Whether the function is a Workflow
|
|
49
|
+
workflow (Workflow | None): The parent workflow for observability context
|
|
44
50
|
|
|
45
51
|
Returns:
|
|
46
52
|
A wrapper function suitable for registration with MCP
|
|
@@ -101,6 +107,19 @@ def create_function_wrapper(
|
|
|
101
107
|
await ctx.report_progress(0, 100)
|
|
102
108
|
|
|
103
109
|
try:
|
|
110
|
+
# Helper function to wrap function calls with observability
|
|
111
|
+
async def call_with_observability(func_call):
|
|
112
|
+
# Use workflow's observability context (workflow should always be available)
|
|
113
|
+
if not workflow:
|
|
114
|
+
logger.error("Missing workflow context for function %s - observability will not be available",
|
|
115
|
+
function_name)
|
|
116
|
+
raise RuntimeError("Workflow context is required for observability")
|
|
117
|
+
|
|
118
|
+
logger.debug("Starting observability context for function %s", function_name)
|
|
119
|
+
context_state = ContextState.get()
|
|
120
|
+
async with workflow.exporter_manager.start(context_state=context_state):
|
|
121
|
+
return await func_call()
|
|
122
|
+
|
|
104
123
|
# Special handling for ChatRequest
|
|
105
124
|
if is_chat_request:
|
|
106
125
|
from nat.data_models.api_server import ChatRequest
|
|
@@ -118,7 +137,7 @@ def create_function_wrapper(
|
|
|
118
137
|
result = await runner.result(to_type=str)
|
|
119
138
|
else:
|
|
120
139
|
# Regular functions use ainvoke
|
|
121
|
-
result = await function.ainvoke(chat_request, to_type=str)
|
|
140
|
+
result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
|
|
122
141
|
else:
|
|
123
142
|
# Regular handling
|
|
124
143
|
# Handle complex input schema - if we extracted fields from a nested schema,
|
|
@@ -129,7 +148,7 @@ def create_function_wrapper(
|
|
|
129
148
|
field_type = schema.model_fields[field_name].annotation
|
|
130
149
|
|
|
131
150
|
# If it's a pydantic model, we need to create an instance
|
|
132
|
-
if hasattr(field_type, "model_validate"):
|
|
151
|
+
if field_type and hasattr(field_type, "model_validate"):
|
|
133
152
|
# Create the nested object
|
|
134
153
|
nested_obj = field_type.model_validate(kwargs)
|
|
135
154
|
# Call with the nested object
|
|
@@ -147,7 +166,7 @@ def create_function_wrapper(
|
|
|
147
166
|
result = await runner.result(to_type=str)
|
|
148
167
|
else:
|
|
149
168
|
# Regular function call
|
|
150
|
-
result = await function.acall_invoke(**kwargs)
|
|
169
|
+
result = await call_with_observability(lambda: function.acall_invoke(**kwargs))
|
|
151
170
|
|
|
152
171
|
# Report completion
|
|
153
172
|
if ctx:
|
|
@@ -170,7 +189,7 @@ def create_function_wrapper(
|
|
|
170
189
|
wrapper = create_wrapper()
|
|
171
190
|
|
|
172
191
|
# Set the signature on the wrapper function (WITHOUT ctx)
|
|
173
|
-
wrapper.__signature__ = sig
|
|
192
|
+
wrapper.__signature__ = sig # type: ignore
|
|
174
193
|
wrapper.__name__ = function_name
|
|
175
194
|
|
|
176
195
|
# Return the wrapper with proper signature
|
|
@@ -183,8 +202,8 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
183
202
|
|
|
184
203
|
The description is determined using the following precedence:
|
|
185
204
|
1. If the function is a Workflow and has a 'description' attribute, use it.
|
|
186
|
-
2. If the Workflow's config has a '
|
|
187
|
-
3. If the Workflow's config has a '
|
|
205
|
+
2. If the Workflow's config has a 'description', use it.
|
|
206
|
+
3. If the Workflow's config has a 'topic', use it.
|
|
188
207
|
4. If the function is a regular Function, use its 'description' attribute.
|
|
189
208
|
|
|
190
209
|
Args:
|
|
@@ -195,6 +214,9 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
195
214
|
"""
|
|
196
215
|
function_description = ""
|
|
197
216
|
|
|
217
|
+
# Import here to avoid circular imports
|
|
218
|
+
from nat.builder.workflow import Workflow
|
|
219
|
+
|
|
198
220
|
if isinstance(function, Workflow):
|
|
199
221
|
config = function.config
|
|
200
222
|
|
|
@@ -214,13 +236,17 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
214
236
|
return function_description
|
|
215
237
|
|
|
216
238
|
|
|
217
|
-
def register_function_with_mcp(mcp: FastMCP,
|
|
239
|
+
def register_function_with_mcp(mcp: FastMCP,
|
|
240
|
+
function_name: str,
|
|
241
|
+
function: FunctionBase,
|
|
242
|
+
workflow: 'Workflow | None' = None) -> None:
|
|
218
243
|
"""Register a NAT Function as an MCP tool.
|
|
219
244
|
|
|
220
245
|
Args:
|
|
221
246
|
mcp: The FastMCP instance
|
|
222
247
|
function_name: The name to register the function under
|
|
223
248
|
function: The NAT Function to register
|
|
249
|
+
workflow: The parent workflow for observability context (if available)
|
|
224
250
|
"""
|
|
225
251
|
logger.info("Registering function %s with MCP", function_name)
|
|
226
252
|
|
|
@@ -229,6 +255,7 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
|
|
|
229
255
|
logger.info("Function %s has input schema: %s", function_name, input_schema)
|
|
230
256
|
|
|
231
257
|
# Check if we're dealing with a Workflow
|
|
258
|
+
from nat.builder.workflow import Workflow
|
|
232
259
|
is_workflow = isinstance(function, Workflow)
|
|
233
260
|
if is_workflow:
|
|
234
261
|
logger.info("Function %s is a Workflow", function_name)
|
|
@@ -237,5 +264,5 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
|
|
|
237
264
|
function_description = get_function_description(function)
|
|
238
265
|
|
|
239
266
|
# Create and register the wrapper function with MCP
|
|
240
|
-
wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow)
|
|
267
|
+
wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
|
|
241
268
|
mcp.tool(name=function_name, description=function_description)(wrapper_func)
|
nat/observability/register.py
CHANGED
|
@@ -72,8 +72,10 @@ async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Bu
|
|
|
72
72
|
"""
|
|
73
73
|
Build and return a StreamHandler for console-based logging.
|
|
74
74
|
"""
|
|
75
|
+
import sys
|
|
76
|
+
|
|
75
77
|
level = getattr(logging, config.level.upper(), logging.INFO)
|
|
76
|
-
handler = logging.StreamHandler()
|
|
78
|
+
handler = logging.StreamHandler(stream=sys.stdout)
|
|
77
79
|
handler.setLevel(level)
|
|
78
80
|
yield handler
|
|
79
81
|
|
|
@@ -16,12 +16,18 @@
|
|
|
16
16
|
from __future__ import annotations
|
|
17
17
|
|
|
18
18
|
import logging
|
|
19
|
+
from abc import ABC
|
|
20
|
+
from abc import abstractmethod
|
|
21
|
+
from contextlib import AsyncExitStack
|
|
19
22
|
from contextlib import asynccontextmanager
|
|
20
23
|
from enum import Enum
|
|
21
24
|
from typing import Any
|
|
22
25
|
|
|
23
26
|
from mcp import ClientSession
|
|
24
27
|
from mcp.client.sse import sse_client
|
|
28
|
+
from mcp.client.stdio import StdioServerParameters
|
|
29
|
+
from mcp.client.stdio import stdio_client
|
|
30
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
25
31
|
from mcp.types import TextContent
|
|
26
32
|
from pydantic import BaseModel
|
|
27
33
|
from pydantic import Field
|
|
@@ -29,6 +35,7 @@ from pydantic import create_model
|
|
|
29
35
|
|
|
30
36
|
from nat.tool.mcp.exceptions import MCPToolNotFoundError
|
|
31
37
|
from nat.utils.exception_handlers.mcp import mcp_exception_handler
|
|
38
|
+
from nat.utils.type_utils import override
|
|
32
39
|
|
|
33
40
|
logger = logging.getLogger(__name__)
|
|
34
41
|
|
|
@@ -107,56 +114,78 @@ def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
|
|
|
107
114
|
return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
|
|
108
115
|
|
|
109
116
|
|
|
110
|
-
class
|
|
117
|
+
class MCPBaseClient(ABC):
|
|
111
118
|
"""
|
|
112
|
-
|
|
119
|
+
Base client for creating a session and connecting to an MCP server
|
|
113
120
|
|
|
114
121
|
Args:
|
|
115
|
-
|
|
122
|
+
transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
|
|
116
123
|
"""
|
|
117
124
|
|
|
118
|
-
def __init__(self,
|
|
119
|
-
self.
|
|
125
|
+
def __init__(self, transport: str = 'streamable-http'):
|
|
126
|
+
self._tools = None
|
|
127
|
+
self._transport = transport.lower()
|
|
128
|
+
if self._transport not in ['sse', 'stdio', 'streamable-http']:
|
|
129
|
+
raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")
|
|
120
130
|
|
|
121
|
-
|
|
122
|
-
async def connect_to_sse_server(self):
|
|
123
|
-
"""
|
|
124
|
-
Establish a session with an MCP SSE server within an aync context
|
|
125
|
-
"""
|
|
126
|
-
async with sse_client(url=self.url) as (read, write):
|
|
127
|
-
async with ClientSession(read, write) as session:
|
|
128
|
-
await session.initialize()
|
|
129
|
-
yield session
|
|
131
|
+
self._exit_stack: AsyncExitStack | None = None
|
|
130
132
|
|
|
133
|
+
self._session: ClientSession | None = None
|
|
131
134
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
+
@property
|
|
136
|
+
def transport(self) -> str:
|
|
137
|
+
return self._transport
|
|
135
138
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
+
async def __aenter__(self):
|
|
140
|
+
if self._exit_stack:
|
|
141
|
+
raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")
|
|
139
142
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
self.
|
|
143
|
+
self._exit_stack = AsyncExitStack()
|
|
144
|
+
|
|
145
|
+
self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
|
|
146
|
+
|
|
147
|
+
return self
|
|
148
|
+
|
|
149
|
+
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
150
|
+
|
|
151
|
+
if not self._exit_stack:
|
|
152
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
|
153
|
+
|
|
154
|
+
await self._exit_stack.aclose()
|
|
155
|
+
self._session = None
|
|
156
|
+
self._exit_stack = None
|
|
157
|
+
|
|
158
|
+
@property
|
|
159
|
+
def server_name(self):
|
|
160
|
+
"""
|
|
161
|
+
Provide server name for logging
|
|
162
|
+
"""
|
|
163
|
+
return self._transport
|
|
164
|
+
|
|
165
|
+
@abstractmethod
|
|
166
|
+
@asynccontextmanager
|
|
167
|
+
async def connect_to_server(self):
|
|
168
|
+
"""
|
|
169
|
+
Establish a session with an MCP server within an async context
|
|
170
|
+
"""
|
|
171
|
+
pass
|
|
143
172
|
|
|
144
|
-
@mcp_exception_handler
|
|
145
173
|
async def get_tools(self):
|
|
146
174
|
"""
|
|
147
175
|
Retrieve a dictionary of all tools served by the MCP server.
|
|
176
|
+
"""
|
|
148
177
|
|
|
149
|
-
|
|
150
|
-
|
|
178
|
+
if not self._session:
|
|
179
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
|
151
180
|
|
|
152
|
-
|
|
153
|
-
MCPError: If connection or tool retrieval fails
|
|
154
|
-
"""
|
|
155
|
-
async with self.connect_to_sse_server() as session:
|
|
156
|
-
response = await session.list_tools()
|
|
181
|
+
response = await self._session.list_tools()
|
|
157
182
|
|
|
158
183
|
return {
|
|
159
|
-
tool.name:
|
|
184
|
+
tool.name:
|
|
185
|
+
MCPToolClient(session=self._session,
|
|
186
|
+
tool_name=tool.name,
|
|
187
|
+
tool_description=tool.description,
|
|
188
|
+
tool_input_schema=tool.inputSchema)
|
|
160
189
|
for tool in response.tools
|
|
161
190
|
}
|
|
162
191
|
|
|
@@ -172,9 +201,11 @@ class MCPBuilder(MCPSSEClient):
|
|
|
172
201
|
MCPToolClient for the configured tool.
|
|
173
202
|
|
|
174
203
|
Raises:
|
|
175
|
-
MCPToolNotFoundError: If no tool is available with that name
|
|
176
|
-
MCPError: If connection fails
|
|
204
|
+
MCPToolNotFoundError: If no tool is available with that name.
|
|
177
205
|
"""
|
|
206
|
+
if not self._exit_stack:
|
|
207
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
|
208
|
+
|
|
178
209
|
if not self._tools:
|
|
179
210
|
self._tools = await self.get_tools()
|
|
180
211
|
|
|
@@ -185,27 +216,143 @@ class MCPBuilder(MCPSSEClient):
|
|
|
185
216
|
|
|
186
217
|
@mcp_exception_handler
|
|
187
218
|
async def call_tool(self, tool_name: str, tool_args: dict | None):
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
219
|
+
if not self._session:
|
|
220
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
|
221
|
+
|
|
222
|
+
result = await self._session.call_tool(tool_name, tool_args)
|
|
223
|
+
return result
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class MCPSSEClient(MCPBaseClient):
    """
    MCP client that connects to a server over the SSE transport.

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url: str):
        # The base class validates and stores the transport name.
        super().__init__("sse")
        self._url = url

    @property
    def url(self) -> str:
        """The server url this client was constructed with."""
        return self._url

    @property
    def server_name(self):
        """Label used for logging: the transport name followed by the server url."""
        return f"sse:{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open the SSE streams, wrap them in a ClientSession, initialize it and yield the session.

        Both the streams and the session are closed when the context exits.
        """
        async with sse_client(url=self._url) as (reader, writer), ClientSession(reader, writer) as mcp_session:
            await mcp_session.initialize()
            yield mcp_session
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class MCPStdioClient(MCPBaseClient):
    """
    MCP client that connects to a server over the stdio transport.

    Args:
        command (str): The command to run
        args (list[str] | None): Additional arguments for the command
        env (dict[str, str] | None): Environment variables to set for the process
    """

    def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
        # The base class validates and stores the transport name.
        super().__init__("stdio")
        self._command = command
        self._args = args
        self._env = env

    @property
    def command(self) -> str:
        """The command this client was constructed with."""
        return self._command

    @property
    def server_name(self):
        """Label used for logging: the transport name followed by the command."""
        return f"stdio:{self._command}"

    @property
    def args(self) -> list[str] | None:
        """The extra command arguments exactly as given to the constructor (may be None)."""
        return self._args

    @property
    def env(self) -> dict[str, str] | None:
        """The environment mapping exactly as given to the constructor (may be None)."""
        return self._env

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open the stdio streams, wrap them in a ClientSession, initialize it and yield the session.

        Both the streams and the session are closed when the context exits.
        """
        # A missing (or empty) argument list is passed on as an empty list.
        launch_args = self._args or []
        stdio_params = StdioServerParameters(command=self._command, args=launch_args, env=self._env)

        async with stdio_client(stdio_params) as (reader, writer), ClientSession(reader, writer) as mcp_session:
            await mcp_session.initialize()
            yield mcp_session
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class MCPStreamableHTTPClient(MCPBaseClient):
    """
    Client for creating a session and connecting to an MCP server using streamable-http

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url: str):
        # The base class validates and stores the transport name.
        super().__init__("streamable-http")

        self._url = url

    @property
    def url(self) -> str:
        """The server url this client was constructed with."""
        return self._url

    @property
    def server_name(self):
        """Label used for logging: the transport name followed by the server url."""
        return f"streamable-http:{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Establish a session with an MCP server via streamable-http within an async context

        Yields the ClientSession after it has been initialized; the streams and the
        session are closed when the context exits.
        """
        # The transport yields a third item besides the two streams; it is not needed here,
        # so it is bound to an underscore-prefixed name to mark it as intentionally unused.
        async with streamablehttp_client(url=self._url) as (read, write, _get_session_id):
            async with ClientSession(read, write) as session:
                await session.initialize()
                yield session
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class MCPToolClient:
|
|
194
337
|
"""
|
|
195
338
|
Client wrapper used to call an MCP tool.
|
|
196
339
|
|
|
197
340
|
Args:
|
|
198
|
-
|
|
341
|
+
session (ClientSession): The MCP client session used to call the tool
|
|
199
342
|
tool_name (str): The name of the tool to wrap
|
|
200
343
|
tool_description (str): The description of the tool provided by the MCP server.
|
|
201
344
|
tool_input_schema (dict): The input schema for the tool.
|
|
202
345
|
"""
|
|
203
346
|
|
|
204
|
-
def __init__(self,
|
|
205
|
-
|
|
347
|
+
def __init__(self,
|
|
348
|
+
session: ClientSession,
|
|
349
|
+
tool_name: str,
|
|
350
|
+
tool_description: str | None,
|
|
351
|
+
tool_input_schema: dict | None = None):
|
|
352
|
+
self._session = session
|
|
206
353
|
self._tool_name = tool_name
|
|
207
354
|
self._tool_description = tool_description
|
|
208
|
-
self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None
|
|
355
|
+
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
|
|
209
356
|
|
|
210
357
|
@property
|
|
211
358
|
def name(self):
|
|
@@ -234,7 +381,6 @@ class MCPToolClient(MCPSSEClient):
|
|
|
234
381
|
"""
|
|
235
382
|
self._tool_description = description
|
|
236
383
|
|
|
237
|
-
@mcp_exception_handler
|
|
238
384
|
async def acall(self, tool_args: dict) -> str:
|
|
239
385
|
"""
|
|
240
386
|
Call the MCP tool with the provided arguments.
|
|
@@ -242,14 +388,19 @@ class MCPToolClient(MCPSSEClient):
|
|
|
242
388
|
Args:
|
|
243
389
|
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
|
|
244
390
|
"""
|
|
245
|
-
|
|
246
|
-
result = await session.call_tool(self._tool_name, tool_args)
|
|
391
|
+
result = await self._session.call_tool(self._tool_name, tool_args)
|
|
247
392
|
|
|
248
393
|
output = []
|
|
394
|
+
|
|
249
395
|
for res in result.content:
|
|
250
396
|
if isinstance(res, TextContent):
|
|
251
397
|
output.append(res.text)
|
|
252
398
|
else:
|
|
253
399
|
# Log non-text content for now
|
|
254
400
|
logger.warning("Got not-text output from %s of type %s", self.name, type(res))
|
|
255
|
-
|
|
401
|
+
result_str = "\n".join(output)
|
|
402
|
+
|
|
403
|
+
if result.isError:
|
|
404
|
+
raise RuntimeError(result_str)
|
|
405
|
+
|
|
406
|
+
return result_str
|