nvidia-nat 1.3.0a20250828__py3-none-any.whl → 1.3.0a20250830__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as published in the public registry.
Files changed (82)
  1. nat/agent/base.py +6 -1
  2. nat/agent/react_agent/agent.py +46 -38
  3. nat/agent/react_agent/register.py +7 -2
  4. nat/agent/rewoo_agent/agent.py +16 -30
  5. nat/agent/rewoo_agent/register.py +3 -3
  6. nat/agent/tool_calling_agent/agent.py +9 -19
  7. nat/agent/tool_calling_agent/register.py +2 -2
  8. nat/builder/eval_builder.py +2 -2
  9. nat/builder/function.py +8 -8
  10. nat/builder/workflow.py +6 -2
  11. nat/builder/workflow_builder.py +21 -24
  12. nat/cli/cli_utils/config_override.py +1 -1
  13. nat/cli/commands/info/list_channels.py +1 -1
  14. nat/cli/commands/info/list_mcp.py +183 -47
  15. nat/cli/commands/registry/publish.py +2 -2
  16. nat/cli/commands/registry/pull.py +2 -2
  17. nat/cli/commands/registry/remove.py +2 -2
  18. nat/cli/commands/registry/search.py +1 -1
  19. nat/cli/commands/start.py +15 -3
  20. nat/cli/commands/uninstall.py +1 -1
  21. nat/cli/commands/workflow/workflow_commands.py +4 -4
  22. nat/data_models/discovery_metadata.py +4 -4
  23. nat/data_models/thinking_mixin.py +27 -8
  24. nat/eval/evaluate.py +6 -6
  25. nat/eval/intermediate_step_adapter.py +1 -1
  26. nat/eval/rag_evaluator/evaluate.py +2 -2
  27. nat/eval/rag_evaluator/register.py +1 -1
  28. nat/eval/remote_workflow.py +3 -3
  29. nat/eval/swe_bench_evaluator/evaluate.py +5 -5
  30. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  31. nat/eval/tunable_rag_evaluator/evaluate.py +3 -3
  32. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +2 -2
  33. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  34. nat/front_ends/fastapi/fastapi_front_end_plugin.py +1 -1
  35. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
  36. nat/front_ends/fastapi/message_handler.py +2 -2
  37. nat/front_ends/fastapi/message_validator.py +8 -10
  38. nat/front_ends/fastapi/response_helpers.py +4 -4
  39. nat/front_ends/fastapi/step_adaptor.py +1 -1
  40. nat/front_ends/mcp/mcp_front_end_config.py +5 -0
  41. nat/front_ends/mcp/mcp_front_end_plugin.py +8 -2
  42. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +2 -2
  43. nat/front_ends/mcp/tool_converter.py +40 -13
  44. nat/observability/exporter/base_exporter.py +1 -1
  45. nat/observability/exporter/processing_exporter.py +8 -9
  46. nat/observability/exporter_manager.py +5 -5
  47. nat/observability/mixin/file_mixin.py +7 -7
  48. nat/observability/processor/batching_processor.py +4 -6
  49. nat/observability/register.py +3 -1
  50. nat/profiler/calc/calc_runner.py +3 -4
  51. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  52. nat/profiler/callbacks/langchain_callback_handler.py +5 -5
  53. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  54. nat/profiler/callbacks/semantic_kernel_callback_handler.py +2 -2
  55. nat/profiler/profile_runner.py +1 -1
  56. nat/profiler/utils.py +1 -1
  57. nat/registry_handlers/local/local_handler.py +2 -2
  58. nat/registry_handlers/package_utils.py +1 -1
  59. nat/registry_handlers/pypi/pypi_handler.py +3 -3
  60. nat/registry_handlers/rest/rest_handler.py +4 -4
  61. nat/retriever/milvus/retriever.py +1 -1
  62. nat/retriever/nemo_retriever/retriever.py +1 -1
  63. nat/runtime/loader.py +1 -1
  64. nat/runtime/runner.py +2 -2
  65. nat/settings/global_settings.py +1 -1
  66. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  67. nat/tool/mcp/{mcp_client.py → mcp_client_base.py} +197 -46
  68. nat/tool/mcp/mcp_client_impl.py +229 -0
  69. nat/tool/mcp/mcp_tool.py +79 -42
  70. nat/tool/nvidia_rag.py +1 -1
  71. nat/tool/register.py +1 -0
  72. nat/tool/retriever.py +3 -2
  73. nat/utils/io/yaml_tools.py +1 -1
  74. nat/utils/reactive/observer.py +2 -2
  75. nat/utils/settings/global_settings.py +2 -2
  76. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/METADATA +3 -3
  77. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/RECORD +82 -81
  78. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/WHEEL +0 -0
  79. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/entry_points.txt +0 -0
  80. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  81. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE.md +0 -0
  82. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/top_level.txt +0 -0
@@ -40,7 +40,7 @@ class IntermediateStepAdapter:
  try:
  validated_steps.append(IntermediateStep.model_validate(step_data))
  except Exception as e:
- logger.exception("Validation failed for step: %r, Error: %s", step_data, e, exc_info=True)
+ logger.exception("Validation failed for step: %r, Error: %s", step_data, e)
  return validated_steps

  def serialize_intermediate_steps(self, intermediate_steps: list[IntermediateStep]) -> list[dict]:
@@ -102,7 +102,7 @@ class RAGEvaluator:
  """Converts the ragas EvaluationResult to nat EvalOutput"""

  if not results_dataset:
- logger.error("Ragas evaluation failed with no results")
+ logger.error("Ragas evaluation failed with no results", exc_info=True)
  return EvalOutput(average_score=0.0, eval_output_items=[])

  scores: list[dict[str, float]] = results_dataset.scores
@@ -169,7 +169,7 @@ class RAGEvaluator:
  _pbar=pbar)
  except Exception as e:
  # On exception we still continue with other evaluators. Log and return an avg_score of 0.0
- logger.exception("Error evaluating ragas metric, Error: %s", e, exc_info=True)
+ logger.exception("Error evaluating ragas metric, Error: %s", e)
  results_dataset = None
  finally:
  pbar.close()
@@ -104,7 +104,7 @@ async def register_ragas_evaluator(config: RagasEvaluatorConfig, builder: EvalBu
  raise ValueError(message) from e
  except AttributeError as e:
  message = f"Ragas metric {metric_name} not found {e}."
- logger.error(message)
+ logger.exception(message)
  return None

  async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
@@ -74,7 +74,7 @@ class EvaluationRemoteWorkflowHandler:
  if chunk_data.get("value"):
  final_response = chunk_data.get("value")
  except json.JSONDecodeError as e:
- logger.error("Failed to parse generate response chunk: %s", e)
+ logger.exception("Failed to parse generate response chunk: %s", e)
  continue
  elif line.startswith(INTERMEDIATE_DATA_PREFIX):
  # This is an intermediate step
@@ -90,12 +90,12 @@ class EvaluationRemoteWorkflowHandler:
  payload=payload)
  intermediate_steps.append(intermediate_step)
  except (json.JSONDecodeError, ValidationError) as e:
- logger.error("Failed to parse intermediate step: %s", e)
+ logger.exception("Failed to parse intermediate step: %s", e)
  continue

  except aiohttp.ClientError as e:
  # Handle connection or HTTP-related errors
- logger.error("Request failed for question %s: %s", question, e)
+ logger.exception("Request failed for question %s: %s", question, e)
  item.output_obj = None
  item.trajectory = []
  return
@@ -69,13 +69,13 @@ class SweBenchEvaluator:
  try:
  shutil.move(swe_bench_report_file, report_dir)
  except Exception as e:
- logger.exception("Error moving report file: %s", e, exc_info=True)
+ logger.exception("Error moving report file: %s", e)

  try:
  dest_logs_dir = os.path.join(report_dir, 'logs')
  shutil.move(logs_dir, dest_logs_dir)
  except Exception as e:
- logger.exception("Error moving logs directory: %s", e, exc_info=True)
+ logger.exception("Error moving logs directory: %s", e)

  def is_repo_supported(self, repo: str, version: str) -> bool:
  """Check if the repo is supported by swebench"""
@@ -106,7 +106,7 @@ class SweBenchEvaluator:
  self._model_name_or_path = swebench_output.model_name_or_path

  except Exception as e:
- logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e, exc_info=True)
+ logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e)

  # Filter out repos/version not supported by SWEBench
  supported_inputs = [
@@ -114,7 +114,7 @@ class SweBenchEvaluator:
  ]

  if not supported_inputs:
- logger.error("No supported instances; nothing to evaluate")
+ logger.exception("No supported instances; nothing to evaluate")
  return None, None

  if len(supported_inputs) < len(swebench_inputs):
@@ -135,7 +135,7 @@ class SweBenchEvaluator:
  filtered_outputs = [output for output in swebench_outputs if output.instance_id in valid_instance_ids]

  if not filtered_outputs:
- logger.error("No supported outputs; nothing to evaluate")
+ logger.error("No supported outputs; nothing to evaluate", exc_info=True)
  return None, None

  # Write SWEBenchOutput to file
@@ -65,7 +65,7 @@ class TrajectoryEvaluator(BaseEvaluator):
  prediction=generated_answer,
  )
  except Exception as e:
- logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e, exc_info=True)
+ logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e)
  return EvalOutputItem(id=item.id, score=0.0, reasoning=f"Error evaluating trajectory: {e}")

  reasoning = {
@@ -182,8 +182,8 @@ class TunableRagEvaluator(BaseEvaluator):
  relevance_score = parsed_response["relevance_score"]
  reasoning = parsed_response["reasoning"]
  except KeyError as e:
- logger.error("Missing required keys in default scoring response: %s",
- ", ".join(str(arg) for arg in e.args))
+ logger.exception("Missing required keys in default scoring response: %s",
+ ", ".join(str(arg) for arg in e.args))
  reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"

  coverage_weight = self.default_score_weights.get("coverage", 1 / 3)
@@ -215,7 +215,7 @@ class TunableRagEvaluator(BaseEvaluator):
  reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
  raise
  except (KeyError, ValueError) as e:
- logger.error("Error parsing judge LLM response: %s", e)
+ logger.exception("Error parsing judge LLM response: %s", e)
  score = 0.0
  reasoning = "Error in evaluator from parsing judge LLM response."

@@ -148,13 +148,13 @@ async def register_ttc_tool_orchestration_function(
  result = await fn.acall_invoke(item.output)
  return item, result, None
  except Exception as e:
- logger.error(f"Error invoking function '{item.name}': {e}")
+ logger.exception(f"Error invoking function '{item.name}': {e}")
  return item, None, str(e)

  tasks = []
  for item in ttc_items:
  if item.name not in function_map:
- logger.error(f"Function '{item.name}' not found in function map.")
+ logger.error(f"Function '{item.name}' not found in function map.", exc_info=True)
  item.output = f"Error: Function '{item.name}' not found in function map. Check your input"
  else:
  fn = function_map[item.name]
@@ -47,11 +47,11 @@ class _FastApiFrontEndController:
  self._server_background_task = asyncio.create_task(self._server.serve())
  except asyncio.CancelledError as e:
  error_message = f"Task error occurred while starting API server: {str(e)}"
- logger.error(error_message, exc_info=True)
+ logger.error(error_message)
  raise RuntimeError(error_message) from e
  except Exception as e:
  error_message = f"Unexpected error occurred while starting API server: {str(e)}"
- logger.error(error_message, exc_info=True)
+ logger.exception(error_message)
  raise RuntimeError(error_message) from e

  async def stop_server(self) -> None:
@@ -63,6 +63,6 @@ class _FastApiFrontEndController:
  self._server.should_exit = True
  await self._server_background_task
  except asyncio.CancelledError as e:
- logger.error("Server shutdown failed: %s", str(e), exc_info=True)
+ logger.exception("Server shutdown failed: %s", str(e))
  except Exception as e:
- logger.error("Unexpected error occurred: %s", str(e), exc_info=True)
+ logger.exception("Unexpected error occurred: %s", str(e))
@@ -113,4 +113,4 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
  try:
  os.remove(config_file_name)
  except OSError as e:
- logger.error(f"Warning: Failed to delete temp file {config_file_name}: {e}")
+ logger.exception(f"Warning: Failed to delete temp file {config_file_name}: {e}")
@@ -215,7 +215,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
  job_store.cleanup_expired_jobs()
  logger.debug("Expired %s jobs cleaned up", name)
  except Exception as e:
- logger.error("Error during %s job cleanup: %s", name, e)
+ logger.exception("Error during %s job cleanup: %s", name, e)
  await asyncio.sleep(sleep_time_sec)

  async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
@@ -301,7 +301,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):

  job_store.update_status(job_id, "success", output_path=str(parent_dir))
  except Exception as e:
- logger.error("Error in evaluation job %s: %s", job_id, str(e))
+ logger.exception("Error in evaluation job %s: %s", job_id, str(e))
  job_store.update_status(job_id, "failure", error=str(e))

  async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
@@ -735,7 +735,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
  result_type=result_type)
  job_store.update_status(job_id, "success", output=result)
  except Exception as e:
- logger.error("Error in evaluation job %s: %s", job_id, e)
+ logger.exception("Error in evaluation job %s: %s", job_id, e)
  job_store.update_status(job_id, "failure", error=str(e))

  def _job_status_to_response(job: JobInfo) -> AsyncGenerationStatusResponse:
@@ -170,7 +170,7 @@ class WebSocketMessageHandler:
  self._workflow_schema_type])).add_done_callback(_done_callback)

  except ValueError as e:
- logger.error("User message content not found: %s", str(e), exc_info=True)
+ logger.exception("User message content not found: %s", str(e))
  await self.create_websocket_message(data_model=Error(code=ErrorTypes.INVALID_USER_MESSAGE_CONTENT,
  message="User message content could not be found",
  details=str(e)),
@@ -238,7 +238,7 @@ class WebSocketMessageHandler:
  f"Message type could not be resolved by input data model: {data_model.model_dump_json()}")

  except (ValidationError, TypeError, ValueError) as e:
- logger.error("A data vaidation error ocurred creating websocket message: %s", str(e), exc_info=True)
+ logger.exception("A data vaidation error ocurred creating websocket message: %s", str(e))
  message = await self._message_validator.create_system_response_token_message(
  message_type=WebSocketMessageType.ERROR_MESSAGE,
  conversation_id=self._conversation_id,
@@ -97,7 +97,7 @@ class MessageValidator:
  return validated_message

  except (ValidationError, TypeError, ValueError) as e:
- logger.error("A data validation error %s occurred for message: %s", str(e), str(message), exc_info=True)
+ logger.exception("A data validation error %s occurred for message: %s", str(e), str(message))
  return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
  content=Error(code=ErrorTypes.INVALID_MESSAGE,
  message="Error validating message.",
@@ -119,7 +119,7 @@ class MessageValidator:
  return schema

  except (TypeError, ValueError) as e:
- logger.error("Error retrieving schema for message type '%s': %s", message_type, str(e), exc_info=True)
+ logger.exception("Error retrieving schema for message type '%s': %s", message_type, str(e))
  return Error

  async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
@@ -156,7 +156,7 @@ class MessageValidator:
  return validated_message_content

  except ValueError as e:
- logger.error("Input data could not be converted to validated message content: %s", str(e), exc_info=True)
+ logger.exception("Input data could not be converted to validated message content: %s", str(e))
  return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))

  async def convert_text_content_to_human_response(self, text_content: TextContent,
@@ -191,7 +191,7 @@ class MessageValidator:
  return human_response

  except ValueError as e:
- logger.error("Error human response content not found: %s", str(e), exc_info=True)
+ logger.exception("Error human response content not found: %s", str(e))
  return HumanResponseText(text=str(e))

  async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
@@ -218,9 +218,7 @@ class MessageValidator:
  return validated_message_type

  except ValueError as e:
- logger.error("Error type not found converting data to validated websocket message content: %s",
- str(e),
- exc_info=True)
+ logger.exception("Error type not found converting data to validated websocket message content: %s", str(e))
  return WebSocketMessageType.ERROR_MESSAGE

  async def get_intermediate_step_parent_id(self, data_model: ResponseIntermediateStep) -> str:
@@ -269,7 +267,7 @@ class MessageValidator:
  timestamp=timestamp)

  except Exception as e:
- logger.error("Error creating system response token message: %s", str(e), exc_info=True)
+ logger.exception("Error creating system response token message: %s", str(e))
  return None

  async def create_system_intermediate_step_message(
@@ -308,7 +306,7 @@ class MessageValidator:
  timestamp=timestamp)

  except Exception as e:
- logger.error("Error creating system intermediate step message: %s", str(e), exc_info=True)
+ logger.exception("Error creating system intermediate step message: %s", str(e))
  return None

  async def create_system_interaction_message(
@@ -348,5 +346,5 @@ class MessageValidator:
  timestamp=timestamp)

  except Exception as e:
- logger.error("Error creating system interaction message: %s", str(e), exc_info=True)
+ logger.exception("Error creating system interaction message: %s", str(e))
  return None
@@ -98,9 +98,9 @@ async def generate_streaming_response(payload: typing.Any,
  yield item
  else:
  yield ResponsePayloadOutput(payload=item)
- except Exception as e:
+ except Exception:
  # Handle exceptions here
- raise e
+ raise
  finally:
  await q.close()

@@ -165,9 +165,9 @@ async def generate_streaming_response_full(payload: typing.Any,
  yield item
  else:
  yield ResponsePayloadOutput(payload=item)
- except Exception as e:
+ except Exception:
  # Handle exceptions here
- raise e
+ raise
  finally:
  await q.close()

@@ -314,6 +314,6 @@ class StepAdaptor:
  return self._handle_custom(payload, ancestry)

  except Exception as e:
- logger.error("Error processing intermediate step: %s", e, exc_info=True)
+ logger.exception("Error processing intermediate step: %s", e)

  return None
@@ -13,6 +13,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from typing import Literal
+
  from pydantic import Field

  from nat.data_models.front_end import FrontEndBaseConfig
@@ -32,5 +34,8 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
  log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
  tool_names: list[str] = Field(default_factory=list,
  description="The list of tools MCP server will expose (default: all tools)")
+ transport: Literal["sse", "streamable-http"] = Field(
+ default="streamable-http",
+ description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
  runner_class: str | None = Field(
  default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
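
A minimal usage sketch (not part of the diff) of the new transport field, assuming the base config requires no additional fields; the tool name is hypothetical:

    from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig

    # Opt back into the legacy SSE transport; omitting `transport` keeps the
    # new streamable-http default shown above.
    config = MCPFrontEndConfig(transport="sse", tool_names=["my_tool"], log_level="DEBUG")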
@@ -77,5 +77,11 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
  # Add routes through the worker (includes health endpoint and function registration)
  await worker.add_routes(mcp, builder)

- # Start the MCP server
- await mcp.run_sse_async()
+ # Start the MCP server with configurable transport
+ # streamable-http is the default, but users can choose sse if preferred
+ if self.front_end_config.transport == "sse":
+ logger.info("Starting MCP server with SSE endpoint at /sse")
+ await mcp.run_sse_async()
+ else: # streamable-http
+ logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
+ await mcp.run_streamable_http_async()
@@ -134,9 +134,9 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
  logger.debug("Skipping function %s as it's not in tool_names", function_name)
  functions = filtered_functions

- # Register each function with MCP
+ # Register each function with MCP, passing workflow context for observability
  for function_name, function in functions.items():
- register_function_with_mcp(mcp, function_name, function)
+ register_function_with_mcp(mcp, function_name, function, workflow)

  # Add a simple fallback function if no functions were found
  if not functions:
@@ -17,13 +17,17 @@ import json
  import logging
  from inspect import Parameter
  from inspect import Signature
+ from typing import TYPE_CHECKING

  from mcp.server.fastmcp import FastMCP
  from pydantic import BaseModel

+ from nat.builder.context import ContextState
  from nat.builder.function import Function
  from nat.builder.function_base import FunctionBase
- from nat.builder.workflow import Workflow
+
+ if TYPE_CHECKING:
+ from nat.builder.workflow import Workflow

  logger = logging.getLogger(__name__)

@@ -33,14 +37,16 @@ def create_function_wrapper(
  function: FunctionBase,
  schema: type[BaseModel],
  is_workflow: bool = False,
+ workflow: 'Workflow | None' = None,
  ):
  """Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.

  Args:
- function_name: The name of the function/tool
- function: The NAT Function object
- schema: The input schema of the function
- is_workflow: Whether the function is a Workflow
+ function_name (str): The name of the function/tool
+ function (FunctionBase): The NAT Function object
+ schema (type[BaseModel]): The input schema of the function
+ is_workflow (bool): Whether the function is a Workflow
+ workflow (Workflow | None): The parent workflow for observability context

  Returns:
  A wrapper function suitable for registration with MCP
@@ -101,6 +107,19 @@ def create_function_wrapper(
  await ctx.report_progress(0, 100)

  try:
+ # Helper function to wrap function calls with observability
+ async def call_with_observability(func_call):
+ # Use workflow's observability context (workflow should always be available)
+ if not workflow:
+ logger.error("Missing workflow context for function %s - observability will not be available",
+ function_name)
+ raise RuntimeError("Workflow context is required for observability")
+
+ logger.debug("Starting observability context for function %s", function_name)
+ context_state = ContextState.get()
+ async with workflow.exporter_manager.start(context_state=context_state):
+ return await func_call()
+
  # Special handling for ChatRequest
  if is_chat_request:
  from nat.data_models.api_server import ChatRequest
@@ -118,7 +137,7 @@ def create_function_wrapper(
  result = await runner.result(to_type=str)
  else:
  # Regular functions use ainvoke
- result = await function.ainvoke(chat_request, to_type=str)
+ result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
  else:
  # Regular handling
  # Handle complex input schema - if we extracted fields from a nested schema,
@@ -129,7 +148,7 @@ def create_function_wrapper(
  field_type = schema.model_fields[field_name].annotation

  # If it's a pydantic model, we need to create an instance
- if hasattr(field_type, "model_validate"):
+ if field_type and hasattr(field_type, "model_validate"):
  # Create the nested object
  nested_obj = field_type.model_validate(kwargs)
  # Call with the nested object
@@ -147,7 +166,7 @@ def create_function_wrapper(
  result = await runner.result(to_type=str)
  else:
  # Regular function call
- result = await function.acall_invoke(**kwargs)
+ result = await call_with_observability(lambda: function.acall_invoke(**kwargs))

  # Report completion
  if ctx:
@@ -170,7 +189,7 @@ def create_function_wrapper(
  wrapper = create_wrapper()

  # Set the signature on the wrapper function (WITHOUT ctx)
- wrapper.__signature__ = sig
+ wrapper.__signature__ = sig # type: ignore
  wrapper.__name__ = function_name

  # Return the wrapper with proper signature
@@ -183,8 +202,8 @@ def get_function_description(function: FunctionBase) -> str:

  The description is determined using the following precedence:
  1. If the function is a Workflow and has a 'description' attribute, use it.
- 2. If the Workflow's config has a 'topic', use it.
- 3. If the Workflow's config has a 'description', use it.
+ 2. If the Workflow's config has a 'description', use it.
+ 3. If the Workflow's config has a 'topic', use it.
  4. If the function is a regular Function, use its 'description' attribute.

  Args:
@@ -195,6 +214,9 @@ def get_function_description(function: FunctionBase) -> str:
  """
  function_description = ""

+ # Import here to avoid circular imports
+ from nat.builder.workflow import Workflow
+
  if isinstance(function, Workflow):
  config = function.config

@@ -214,13 +236,17 @@ def get_function_description(function: FunctionBase) -> str:
  return function_description


- def register_function_with_mcp(mcp: FastMCP, function_name: str, function: FunctionBase) -> None:
+ def register_function_with_mcp(mcp: FastMCP,
+ function_name: str,
+ function: FunctionBase,
+ workflow: 'Workflow | None' = None) -> None:
  """Register a NAT Function as an MCP tool.

  Args:
  mcp: The FastMCP instance
  function_name: The name to register the function under
  function: The NAT Function to register
+ workflow: The parent workflow for observability context (if available)
  """
  logger.info("Registering function %s with MCP", function_name)

@@ -229,6 +255,7 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
  logger.info("Function %s has input schema: %s", function_name, input_schema)

  # Check if we're dealing with a Workflow
+ from nat.builder.workflow import Workflow
  is_workflow = isinstance(function, Workflow)
  if is_workflow:
  logger.info("Function %s is a Workflow", function_name)
@@ -237,5 +264,5 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
  function_description = get_function_description(function)

  # Create and register the wrapper function with MCP
- wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow)
+ wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
  mcp.tool(name=function_name, description=function_description)(wrapper_func)
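
A hypothetical call sketch (not from the diff) of the updated registration signature; `mcp`, `my_function`, and `workflow` are assumed to already exist, and the tool name is made up:

    # Passing the parent workflow lets the MCP wrapper start the workflow's
    # exporter manager around each tool invocation for observability.
    register_function_with_mcp(mcp, "my_tool", my_function, workflow)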
@@ -375,7 +375,7 @@ class BaseExporter(Exporter):
  except asyncio.TimeoutError:
  logger.warning("%s: Some tasks did not complete within %s seconds", self.name, timeout)
  except Exception as e:
- logger.error("%s: Error while waiting for tasks: %s", self.name, e, exc_info=True)
+ logger.exception("%s: Error while waiting for tasks: %s", self.name, e)

  @override
  async def stop(self):
@@ -175,7 +175,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
  try:
  processed_item = await processor.process(processed_item)
  except Exception as e:
- logger.error("Error in processor %s: %s", processor.__class__.__name__, e, exc_info=True)
+ logger.exception("Error in processor %s: %s", processor.__class__.__name__, e)
  # Continue with unprocessed item rather than failing
  return processed_item

@@ -214,7 +214,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
  try:
  source_index = self._processors.index(source_processor)
  except ValueError:
- logger.error("Source processor %s not found in pipeline", source_processor.__class__.__name__)
+ logger.exception("Source processor %s not found in pipeline", source_processor.__class__.__name__)
  return

  # Process through remaining processors (skip the source processor)
@@ -225,10 +225,9 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
  await self._export_final_item(processed_item)

  except Exception as e:
- logger.error("Failed to continue pipeline processing after %s: %s",
- source_processor.__class__.__name__,
- e,
- exc_info=True)
+ logger.exception("Failed to continue pipeline processing after %s: %s",
+ source_processor.__class__.__name__,
+ e)

  async def _export_with_processing(self, item: PipelineInputT) -> None:
  """Export an item after processing it through the pipeline.
@@ -248,7 +247,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
  await self._export_final_item(final_item, raise_on_invalid=True)

  except Exception as e:
- logger.error("Failed to export item '%s': %s", item, e, exc_info=True)
+ logger.error("Failed to export item '%s': %s", item, e)
  raise

  @override
@@ -293,7 +292,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
  task.add_done_callback(self._tasks.discard)

  except Exception as e:
- logger.error("%s: Failed to create task: %s", self.name, e, exc_info=True)
+ logger.error("%s: Failed to create task: %s", self.name, e)
  raise

  @override
@@ -316,7 +315,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
  await asyncio.gather(*shutdown_tasks, return_exceptions=True)
  logger.debug("Successfully shut down %d processors", len(shutdown_tasks))
  except Exception as e:
- logger.error("Error shutting down processors: %s", e, exc_info=True)
+ logger.exception("Error shutting down processors: %s", e)

  # Call parent cleanup
  await super()._cleanup()
@@ -177,7 +177,7 @@ class ExporterManager:
  else:
  logger.debug("Skipping cleanup for non-isolated exporter '%s'", name)
  except Exception as e:
- logger.error("Error preparing cleanup for isolated exporter '%s': %s", name, e)
+ logger.exception("Error preparing cleanup for isolated exporter '%s': %s", name, e)

  if cleanup_tasks:
  # Run cleanup tasks concurrently with timeout
@@ -195,7 +195,7 @@ class ExporterManager:
  logger.debug("Stopping isolated exporter '%s'", name)
  await exporter.stop()
  except Exception as e:
- logger.error("Error stopping isolated exporter '%s': %s", name, e)
+ logger.exception("Error stopping isolated exporter '%s': %s", name, e)

  @asynccontextmanager
  async def start(self, context_state: ContextState | None = None):
@@ -251,7 +251,7 @@ class ExporterManager:
  try:
  await self._cleanup_isolated_exporters()
  except Exception as e:
- logger.error("Error during isolated exporter cleanup: %s", e)
+ logger.exception("Error during isolated exporter cleanup: %s", e)

  # Then stop the manager tasks
  await self.stop()
@@ -275,7 +275,7 @@ class ExporterManager:
  logger.info("Stopped exporter '%s'", name)
  raise
  except Exception as e:
- logger.error("Failed to run exporter '%s': %s", name, str(e), exc_info=True)
+ logger.error("Failed to run exporter '%s': %s", name, str(e))
  # Re-raise the exception to ensure it's properly handled
  raise

@@ -307,7 +307,7 @@ class ExporterManager:
  except asyncio.CancelledError:
  logger.debug("Exporter '%s' task cancelled", name)
  except Exception as e:
- logger.error("Failed to stop exporter '%s': %s", name, str(e))
+ logger.exception("Failed to stop exporter '%s': %s", name, str(e))

  if stuck_tasks:
  logger.warning("Exporters did not shut down in time: %s", ", ".join(stuck_tasks))