nvidia-nat 1.3.0a20250829__py3-none-any.whl → 1.3.0a20250831__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,13 +17,17 @@ import json
17
17
  import logging
18
18
  from inspect import Parameter
19
19
  from inspect import Signature
20
+ from typing import TYPE_CHECKING
20
21
 
21
22
  from mcp.server.fastmcp import FastMCP
22
23
  from pydantic import BaseModel
23
24
 
25
+ from nat.builder.context import ContextState
24
26
  from nat.builder.function import Function
25
27
  from nat.builder.function_base import FunctionBase
26
- from nat.builder.workflow import Workflow
28
+
29
+ if TYPE_CHECKING:
30
+ from nat.builder.workflow import Workflow
27
31
 
28
32
  logger = logging.getLogger(__name__)
29
33
 
@@ -33,14 +37,16 @@ def create_function_wrapper(
33
37
  function: FunctionBase,
34
38
  schema: type[BaseModel],
35
39
  is_workflow: bool = False,
40
+ workflow: 'Workflow | None' = None,
36
41
  ):
37
42
  """Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
38
43
 
39
44
  Args:
40
- function_name: The name of the function/tool
41
- function: The NAT Function object
42
- schema: The input schema of the function
43
- is_workflow: Whether the function is a Workflow
45
+ function_name (str): The name of the function/tool
46
+ function (FunctionBase): The NAT Function object
47
+ schema (type[BaseModel]): The input schema of the function
48
+ is_workflow (bool): Whether the function is a Workflow
49
+ workflow (Workflow | None): The parent workflow for observability context
44
50
 
45
51
  Returns:
46
52
  A wrapper function suitable for registration with MCP
@@ -101,6 +107,19 @@ def create_function_wrapper(
101
107
  await ctx.report_progress(0, 100)
102
108
 
103
109
  try:
110
+ # Helper function to wrap function calls with observability
111
+ async def call_with_observability(func_call):
112
+ # Use workflow's observability context (workflow should always be available)
113
+ if not workflow:
114
+ logger.error("Missing workflow context for function %s - observability will not be available",
115
+ function_name)
116
+ raise RuntimeError("Workflow context is required for observability")
117
+
118
+ logger.debug("Starting observability context for function %s", function_name)
119
+ context_state = ContextState.get()
120
+ async with workflow.exporter_manager.start(context_state=context_state):
121
+ return await func_call()
122
+
104
123
  # Special handling for ChatRequest
105
124
  if is_chat_request:
106
125
  from nat.data_models.api_server import ChatRequest
@@ -118,7 +137,7 @@ def create_function_wrapper(
118
137
  result = await runner.result(to_type=str)
119
138
  else:
120
139
  # Regular functions use ainvoke
121
- result = await function.ainvoke(chat_request, to_type=str)
140
+ result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
122
141
  else:
123
142
  # Regular handling
124
143
  # Handle complex input schema - if we extracted fields from a nested schema,
@@ -129,7 +148,7 @@ def create_function_wrapper(
129
148
  field_type = schema.model_fields[field_name].annotation
130
149
 
131
150
  # If it's a pydantic model, we need to create an instance
132
- if hasattr(field_type, "model_validate"):
151
+ if field_type and hasattr(field_type, "model_validate"):
133
152
  # Create the nested object
134
153
  nested_obj = field_type.model_validate(kwargs)
135
154
  # Call with the nested object
@@ -147,7 +166,7 @@ def create_function_wrapper(
147
166
  result = await runner.result(to_type=str)
148
167
  else:
149
168
  # Regular function call
150
- result = await function.acall_invoke(**kwargs)
169
+ result = await call_with_observability(lambda: function.acall_invoke(**kwargs))
151
170
 
152
171
  # Report completion
153
172
  if ctx:
@@ -170,7 +189,7 @@ def create_function_wrapper(
170
189
  wrapper = create_wrapper()
171
190
 
172
191
  # Set the signature on the wrapper function (WITHOUT ctx)
173
- wrapper.__signature__ = sig
192
+ wrapper.__signature__ = sig # type: ignore
174
193
  wrapper.__name__ = function_name
175
194
 
176
195
  # Return the wrapper with proper signature
@@ -183,8 +202,8 @@ def get_function_description(function: FunctionBase) -> str:
183
202
 
184
203
  The description is determined using the following precedence:
185
204
  1. If the function is a Workflow and has a 'description' attribute, use it.
186
- 2. If the Workflow's config has a 'topic', use it.
187
- 3. If the Workflow's config has a 'description', use it.
205
+ 2. If the Workflow's config has a 'description', use it.
206
+ 3. If the Workflow's config has a 'topic', use it.
188
207
  4. If the function is a regular Function, use its 'description' attribute.
189
208
 
190
209
  Args:
@@ -195,6 +214,9 @@ def get_function_description(function: FunctionBase) -> str:
195
214
  """
196
215
  function_description = ""
197
216
 
217
+ # Import here to avoid circular imports
218
+ from nat.builder.workflow import Workflow
219
+
198
220
  if isinstance(function, Workflow):
199
221
  config = function.config
200
222
 
@@ -214,13 +236,17 @@ def get_function_description(function: FunctionBase) -> str:
214
236
  return function_description
215
237
 
216
238
 
217
- def register_function_with_mcp(mcp: FastMCP, function_name: str, function: FunctionBase) -> None:
239
+ def register_function_with_mcp(mcp: FastMCP,
240
+ function_name: str,
241
+ function: FunctionBase,
242
+ workflow: 'Workflow | None' = None) -> None:
218
243
  """Register a NAT Function as an MCP tool.
219
244
 
220
245
  Args:
221
246
  mcp: The FastMCP instance
222
247
  function_name: The name to register the function under
223
248
  function: The NAT Function to register
249
+ workflow: The parent workflow for observability context (if available)
224
250
  """
225
251
  logger.info("Registering function %s with MCP", function_name)
226
252
 
@@ -229,6 +255,7 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
229
255
  logger.info("Function %s has input schema: %s", function_name, input_schema)
230
256
 
231
257
  # Check if we're dealing with a Workflow
258
+ from nat.builder.workflow import Workflow
232
259
  is_workflow = isinstance(function, Workflow)
233
260
  if is_workflow:
234
261
  logger.info("Function %s is a Workflow", function_name)
@@ -237,5 +264,5 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
237
264
  function_description = get_function_description(function)
238
265
 
239
266
  # Create and register the wrapper function with MCP
240
- wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow)
267
+ wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
241
268
  mcp.tool(name=function_name, description=function_description)(wrapper_func)
@@ -72,8 +72,10 @@ async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Bu
72
72
  """
73
73
  Build and return a StreamHandler for console-based logging.
74
74
  """
75
+ import sys
76
+
75
77
  level = getattr(logging, config.level.upper(), logging.INFO)
76
- handler = logging.StreamHandler()
78
+ handler = logging.StreamHandler(stream=sys.stdout)
77
79
  handler.setLevel(level)
78
80
  yield handler
79
81
 
@@ -16,12 +16,18 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import logging
19
+ from abc import ABC
20
+ from abc import abstractmethod
21
+ from contextlib import AsyncExitStack
19
22
  from contextlib import asynccontextmanager
20
23
  from enum import Enum
21
24
  from typing import Any
22
25
 
23
26
  from mcp import ClientSession
24
27
  from mcp.client.sse import sse_client
28
+ from mcp.client.stdio import StdioServerParameters
29
+ from mcp.client.stdio import stdio_client
30
+ from mcp.client.streamable_http import streamablehttp_client
25
31
  from mcp.types import TextContent
26
32
  from pydantic import BaseModel
27
33
  from pydantic import Field
@@ -29,6 +35,7 @@ from pydantic import create_model
29
35
 
30
36
  from nat.tool.mcp.exceptions import MCPToolNotFoundError
31
37
  from nat.utils.exception_handlers.mcp import mcp_exception_handler
38
+ from nat.utils.type_utils import override
32
39
 
33
40
  logger = logging.getLogger(__name__)
34
41
 
@@ -107,56 +114,78 @@ def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
107
114
  return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
108
115
 
109
116
 
110
- class MCPSSEClient:
117
+ class MCPBaseClient(ABC):
111
118
  """
112
- Client for creating a session and connecting to an MCP server using SSE
119
+ Base client for creating a session and connecting to an MCP server
113
120
 
114
121
  Args:
115
- url (str): The url of the MCP server
122
+ transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
116
123
  """
117
124
 
118
- def __init__(self, url: str):
119
- self.url = url
125
+ def __init__(self, transport: str = 'streamable-http'):
126
+ self._tools = None
127
+ self._transport = transport.lower()
128
+ if self._transport not in ['sse', 'stdio', 'streamable-http']:
129
+ raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")
120
130
 
121
- @asynccontextmanager
122
- async def connect_to_sse_server(self):
123
- """
124
- Establish a session with an MCP SSE server within an aync context
125
- """
126
- async with sse_client(url=self.url) as (read, write):
127
- async with ClientSession(read, write) as session:
128
- await session.initialize()
129
- yield session
131
+ self._exit_stack: AsyncExitStack | None = None
130
132
 
133
+ self._session: ClientSession | None = None
131
134
 
132
- class MCPBuilder(MCPSSEClient):
133
- """
134
- Builder class used to connect to an MCP Server and generate ToolClients
135
+ @property
136
+ def transport(self) -> str:
137
+ return self._transport
135
138
 
136
- Args:
137
- url (str): The url of the MCP server
138
- """
139
+ async def __aenter__(self):
140
+ if self._exit_stack:
141
+ raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")
139
142
 
140
- def __init__(self, url):
141
- super().__init__(url)
142
- self._tools = None
143
+ self._exit_stack = AsyncExitStack()
144
+
145
+ self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
146
+
147
+ return self
148
+
149
+ async def __aexit__(self, exc_type, exc_value, traceback):
150
+
151
+ if not self._exit_stack:
152
+ raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
153
+
154
+ await self._exit_stack.aclose()
155
+ self._session = None
156
+ self._exit_stack = None
157
+
158
+ @property
159
+ def server_name(self):
160
+ """
161
+ Provide server name for logging
162
+ """
163
+ return self._transport
164
+
165
+ @abstractmethod
166
+ @asynccontextmanager
167
+ async def connect_to_server(self):
168
+ """
169
+ Establish a session with an MCP server within an async context
170
+ """
171
+ pass
143
172
 
144
- @mcp_exception_handler
145
173
  async def get_tools(self):
146
174
  """
147
175
  Retrieve a dictionary of all tools served by the MCP server.
176
+ """
148
177
 
149
- Returns:
150
- Dict of tool name to MCPToolClient
178
+ if not self._session:
179
+ raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
151
180
 
152
- Raises:
153
- MCPError: If connection or tool retrieval fails
154
- """
155
- async with self.connect_to_sse_server() as session:
156
- response = await session.list_tools()
181
+ response = await self._session.list_tools()
157
182
 
158
183
  return {
159
- tool.name: MCPToolClient(self.url, tool.name, tool.description, tool_input_schema=tool.inputSchema)
184
+ tool.name:
185
+ MCPToolClient(session=self._session,
186
+ tool_name=tool.name,
187
+ tool_description=tool.description,
188
+ tool_input_schema=tool.inputSchema)
160
189
  for tool in response.tools
161
190
  }
162
191
 
@@ -172,9 +201,11 @@ class MCPBuilder(MCPSSEClient):
172
201
  MCPToolClient for the configured tool.
173
202
 
174
203
  Raises:
175
- MCPToolNotFoundError: If no tool is available with that name
176
- MCPError: If connection fails
204
+ MCPToolNotFoundError: If no tool is available with that name.
177
205
  """
206
+ if not self._exit_stack:
207
+ raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
208
+
178
209
  if not self._tools:
179
210
  self._tools = await self.get_tools()
180
211
 
@@ -185,27 +216,143 @@ class MCPBuilder(MCPSSEClient):
185
216
 
186
217
  @mcp_exception_handler
187
218
  async def call_tool(self, tool_name: str, tool_args: dict | None):
188
- async with self.connect_to_sse_server() as session:
189
- result = await session.call_tool(tool_name, tool_args)
190
- return result
219
+ if not self._session:
220
+ raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
221
+
222
+ result = await self._session.call_tool(tool_name, tool_args)
223
+ return result
224
+
225
+
226
+ class MCPSSEClient(MCPBaseClient):
227
+ """
228
+ Client for creating a session and connecting to an MCP server using SSE
229
+
230
+ Args:
231
+ url (str): The url of the MCP server
232
+ """
233
+
234
+ def __init__(self, url: str):
235
+ super().__init__("sse")
236
+ self._url = url
237
+
238
+ @property
239
+ def url(self) -> str:
240
+ return self._url
241
+
242
+ @property
243
+ def server_name(self):
244
+ return f"sse:{self._url}"
245
+
246
+ @asynccontextmanager
247
+ @override
248
+ async def connect_to_server(self):
249
+ """
250
+ Establish a session with an MCP SSE server within an async context
251
+ """
252
+ async with sse_client(url=self._url) as (read, write):
253
+ async with ClientSession(read, write) as session:
254
+ await session.initialize()
255
+ yield session
256
+
257
+
258
+ class MCPStdioClient(MCPBaseClient):
259
+ """
260
+ Client for creating a session and connecting to an MCP server using stdio
261
+
262
+ Args:
263
+ command (str): The command to run
264
+ args (list[str] | None): Additional arguments for the command
265
+ env (dict[str, str] | None): Environment variables to set for the process
266
+ """
267
+
268
+ def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
269
+ super().__init__("stdio")
270
+ self._command = command
271
+ self._args = args
272
+ self._env = env
273
+
274
+ @property
275
+ def command(self) -> str:
276
+ return self._command
277
+
278
+ @property
279
+ def server_name(self):
280
+ return f"stdio:{self._command}"
281
+
282
+ @property
283
+ def args(self) -> list[str] | None:
284
+ return self._args
285
+
286
+ @property
287
+ def env(self) -> dict[str, str] | None:
288
+ return self._env
289
+
290
+ @asynccontextmanager
291
+ @override
292
+ async def connect_to_server(self):
293
+ """
294
+ Establish a session with an MCP server via stdio within an async context
295
+ """
296
+
297
+ server_params = StdioServerParameters(command=self._command, args=self._args or [], env=self._env)
298
+ async with stdio_client(server_params) as (read, write):
299
+ async with ClientSession(read, write) as session:
300
+ await session.initialize()
301
+ yield session
302
+
303
+
304
+ class MCPStreamableHTTPClient(MCPBaseClient):
305
+ """
306
+ Client for creating a session and connecting to an MCP server using streamable-http
307
+
308
+ Args:
309
+ url (str): The url of the MCP server
310
+ """
311
+
312
+ def __init__(self, url: str):
313
+ super().__init__("streamable-http")
314
+
315
+ self._url = url
316
+
317
+ @property
318
+ def url(self) -> str:
319
+ return self._url
191
320
 
321
+ @property
322
+ def server_name(self):
323
+ return f"streamable-http:{self._url}"
192
324
 
193
- class MCPToolClient(MCPSSEClient):
325
+ @asynccontextmanager
326
+ async def connect_to_server(self):
327
+ """
328
+ Establish a session with an MCP server via streamable-http within an async context
329
+ """
330
+ async with streamablehttp_client(url=self._url) as (read, write, get_session_id):
331
+ async with ClientSession(read, write) as session:
332
+ await session.initialize()
333
+ yield session
334
+
335
+
336
+ class MCPToolClient:
194
337
  """
195
338
  Client wrapper used to call an MCP tool.
196
339
 
197
340
  Args:
198
- url (str): The url of the MCP server
341
+ connect_fn (callable): Function that returns an async context manager for connecting to the server
199
342
  tool_name (str): The name of the tool to wrap
200
343
  tool_description (str): The description of the tool provided by the MCP server.
201
344
  tool_input_schema (dict): The input schema for the tool.
202
345
  """
203
346
 
204
- def __init__(self, url: str, tool_name: str, tool_description: str | None, tool_input_schema: dict | None = None):
205
- super().__init__(url)
347
+ def __init__(self,
348
+ session: ClientSession,
349
+ tool_name: str,
350
+ tool_description: str | None,
351
+ tool_input_schema: dict | None = None):
352
+ self._session = session
206
353
  self._tool_name = tool_name
207
354
  self._tool_description = tool_description
208
- self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None
355
+ self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
209
356
 
210
357
  @property
211
358
  def name(self):
@@ -234,7 +381,6 @@ class MCPToolClient(MCPSSEClient):
234
381
  """
235
382
  self._tool_description = description
236
383
 
237
- @mcp_exception_handler
238
384
  async def acall(self, tool_args: dict) -> str:
239
385
  """
240
386
  Call the MCP tool with the provided arguments.
@@ -242,14 +388,19 @@ class MCPToolClient(MCPSSEClient):
242
388
  Args:
243
389
  tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
244
390
  """
245
- async with self.connect_to_sse_server() as session:
246
- result = await session.call_tool(self._tool_name, tool_args)
391
+ result = await self._session.call_tool(self._tool_name, tool_args)
247
392
 
248
393
  output = []
394
+
249
395
  for res in result.content:
250
396
  if isinstance(res, TextContent):
251
397
  output.append(res.text)
252
398
  else:
253
399
  # Log non-text content for now
254
400
  logger.warning("Got not-text output from %s of type %s", self.name, type(res))
255
- return "\n".join(output)
401
+ result_str = "\n".join(output)
402
+
403
+ if result.isError:
404
+ raise RuntimeError(result_str)
405
+
406
+ return result_str