nvidia-nat 1.3.0a20250828__py3-none-any.whl → 1.3.0a20250830__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +6 -1
- nat/agent/react_agent/agent.py +46 -38
- nat/agent/react_agent/register.py +7 -2
- nat/agent/rewoo_agent/agent.py +16 -30
- nat/agent/rewoo_agent/register.py +3 -3
- nat/agent/tool_calling_agent/agent.py +9 -19
- nat/agent/tool_calling_agent/register.py +2 -2
- nat/builder/eval_builder.py +2 -2
- nat/builder/function.py +8 -8
- nat/builder/workflow.py +6 -2
- nat/builder/workflow_builder.py +21 -24
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_mcp.py +183 -47
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +1 -1
- nat/cli/commands/start.py +15 -3
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -4
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/thinking_mixin.py +27 -8
- nat/eval/evaluate.py +6 -6
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +1 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/swe_bench_evaluator/evaluate.py +5 -5
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +2 -2
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
- nat/front_ends/fastapi/message_handler.py +2 -2
- nat/front_ends/fastapi/message_validator.py +8 -10
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/mcp/mcp_front_end_config.py +5 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +8 -2
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +2 -2
- nat/front_ends/mcp/tool_converter.py +40 -13
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +8 -9
- nat/observability/exporter_manager.py +5 -5
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/processor/batching_processor.py +4 -6
- nat/observability/register.py +3 -1
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +5 -5
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +2 -2
- nat/profiler/profile_runner.py +1 -1
- nat/profiler/utils.py +1 -1
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -1
- nat/registry_handlers/pypi/pypi_handler.py +3 -3
- nat/registry_handlers/rest/rest_handler.py +4 -4
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/runtime/loader.py +1 -1
- nat/runtime/runner.py +2 -2
- nat/settings/global_settings.py +1 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/mcp/{mcp_client.py → mcp_client_base.py} +197 -46
- nat/tool/mcp/mcp_client_impl.py +229 -0
- nat/tool/mcp/mcp_tool.py +79 -42
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +1 -0
- nat/tool/retriever.py +3 -2
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/reactive/observer.py +2 -2
- nat/utils/settings/global_settings.py +2 -2
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/METADATA +3 -3
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/RECORD +82 -81
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/top_level.txt +0 -0
|
@@ -40,7 +40,7 @@ class IntermediateStepAdapter:
|
|
|
40
40
|
try:
|
|
41
41
|
validated_steps.append(IntermediateStep.model_validate(step_data))
|
|
42
42
|
except Exception as e:
|
|
43
|
-
logger.exception("Validation failed for step: %r, Error: %s", step_data, e
|
|
43
|
+
logger.exception("Validation failed for step: %r, Error: %s", step_data, e)
|
|
44
44
|
return validated_steps
|
|
45
45
|
|
|
46
46
|
def serialize_intermediate_steps(self, intermediate_steps: list[IntermediateStep]) -> list[dict]:
|
|
@@ -102,7 +102,7 @@ class RAGEvaluator:
|
|
|
102
102
|
"""Converts the ragas EvaluationResult to nat EvalOutput"""
|
|
103
103
|
|
|
104
104
|
if not results_dataset:
|
|
105
|
-
logger.error("Ragas evaluation failed with no results")
|
|
105
|
+
logger.error("Ragas evaluation failed with no results", exc_info=True)
|
|
106
106
|
return EvalOutput(average_score=0.0, eval_output_items=[])
|
|
107
107
|
|
|
108
108
|
scores: list[dict[str, float]] = results_dataset.scores
|
|
@@ -169,7 +169,7 @@ class RAGEvaluator:
|
|
|
169
169
|
_pbar=pbar)
|
|
170
170
|
except Exception as e:
|
|
171
171
|
# On exception we still continue with other evaluators. Log and return an avg_score of 0.0
|
|
172
|
-
logger.exception("Error evaluating ragas metric, Error: %s", e
|
|
172
|
+
logger.exception("Error evaluating ragas metric, Error: %s", e)
|
|
173
173
|
results_dataset = None
|
|
174
174
|
finally:
|
|
175
175
|
pbar.close()
|
|
@@ -104,7 +104,7 @@ async def register_ragas_evaluator(config: RagasEvaluatorConfig, builder: EvalBu
|
|
|
104
104
|
raise ValueError(message) from e
|
|
105
105
|
except AttributeError as e:
|
|
106
106
|
message = f"Ragas metric {metric_name} not found {e}."
|
|
107
|
-
logger.
|
|
107
|
+
logger.exception(message)
|
|
108
108
|
return None
|
|
109
109
|
|
|
110
110
|
async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
|
nat/eval/remote_workflow.py
CHANGED
|
@@ -74,7 +74,7 @@ class EvaluationRemoteWorkflowHandler:
|
|
|
74
74
|
if chunk_data.get("value"):
|
|
75
75
|
final_response = chunk_data.get("value")
|
|
76
76
|
except json.JSONDecodeError as e:
|
|
77
|
-
logger.
|
|
77
|
+
logger.exception("Failed to parse generate response chunk: %s", e)
|
|
78
78
|
continue
|
|
79
79
|
elif line.startswith(INTERMEDIATE_DATA_PREFIX):
|
|
80
80
|
# This is an intermediate step
|
|
@@ -90,12 +90,12 @@ class EvaluationRemoteWorkflowHandler:
|
|
|
90
90
|
payload=payload)
|
|
91
91
|
intermediate_steps.append(intermediate_step)
|
|
92
92
|
except (json.JSONDecodeError, ValidationError) as e:
|
|
93
|
-
logger.
|
|
93
|
+
logger.exception("Failed to parse intermediate step: %s", e)
|
|
94
94
|
continue
|
|
95
95
|
|
|
96
96
|
except aiohttp.ClientError as e:
|
|
97
97
|
# Handle connection or HTTP-related errors
|
|
98
|
-
logger.
|
|
98
|
+
logger.exception("Request failed for question %s: %s", question, e)
|
|
99
99
|
item.output_obj = None
|
|
100
100
|
item.trajectory = []
|
|
101
101
|
return
|
|
@@ -69,13 +69,13 @@ class SweBenchEvaluator:
|
|
|
69
69
|
try:
|
|
70
70
|
shutil.move(swe_bench_report_file, report_dir)
|
|
71
71
|
except Exception as e:
|
|
72
|
-
logger.exception("Error moving report file: %s", e
|
|
72
|
+
logger.exception("Error moving report file: %s", e)
|
|
73
73
|
|
|
74
74
|
try:
|
|
75
75
|
dest_logs_dir = os.path.join(report_dir, 'logs')
|
|
76
76
|
shutil.move(logs_dir, dest_logs_dir)
|
|
77
77
|
except Exception as e:
|
|
78
|
-
logger.exception("Error moving logs directory: %s", e
|
|
78
|
+
logger.exception("Error moving logs directory: %s", e)
|
|
79
79
|
|
|
80
80
|
def is_repo_supported(self, repo: str, version: str) -> bool:
|
|
81
81
|
"""Check if the repo is supported by swebench"""
|
|
@@ -106,7 +106,7 @@ class SweBenchEvaluator:
|
|
|
106
106
|
self._model_name_or_path = swebench_output.model_name_or_path
|
|
107
107
|
|
|
108
108
|
except Exception as e:
|
|
109
|
-
logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e
|
|
109
|
+
logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e)
|
|
110
110
|
|
|
111
111
|
# Filter out repos/version not supported by SWEBench
|
|
112
112
|
supported_inputs = [
|
|
@@ -114,7 +114,7 @@ class SweBenchEvaluator:
|
|
|
114
114
|
]
|
|
115
115
|
|
|
116
116
|
if not supported_inputs:
|
|
117
|
-
logger.
|
|
117
|
+
logger.exception("No supported instances; nothing to evaluate")
|
|
118
118
|
return None, None
|
|
119
119
|
|
|
120
120
|
if len(supported_inputs) < len(swebench_inputs):
|
|
@@ -135,7 +135,7 @@ class SweBenchEvaluator:
|
|
|
135
135
|
filtered_outputs = [output for output in swebench_outputs if output.instance_id in valid_instance_ids]
|
|
136
136
|
|
|
137
137
|
if not filtered_outputs:
|
|
138
|
-
logger.error("No supported outputs; nothing to evaluate")
|
|
138
|
+
logger.error("No supported outputs; nothing to evaluate", exc_info=True)
|
|
139
139
|
return None, None
|
|
140
140
|
|
|
141
141
|
# Write SWEBenchOutput to file
|
|
@@ -65,7 +65,7 @@ class TrajectoryEvaluator(BaseEvaluator):
|
|
|
65
65
|
prediction=generated_answer,
|
|
66
66
|
)
|
|
67
67
|
except Exception as e:
|
|
68
|
-
logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e
|
|
68
|
+
logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e)
|
|
69
69
|
return EvalOutputItem(id=item.id, score=0.0, reasoning=f"Error evaluating trajectory: {e}")
|
|
70
70
|
|
|
71
71
|
reasoning = {
|
|
@@ -182,8 +182,8 @@ class TunableRagEvaluator(BaseEvaluator):
|
|
|
182
182
|
relevance_score = parsed_response["relevance_score"]
|
|
183
183
|
reasoning = parsed_response["reasoning"]
|
|
184
184
|
except KeyError as e:
|
|
185
|
-
logger.
|
|
186
|
-
|
|
185
|
+
logger.exception("Missing required keys in default scoring response: %s",
|
|
186
|
+
", ".join(str(arg) for arg in e.args))
|
|
187
187
|
reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
|
|
188
188
|
|
|
189
189
|
coverage_weight = self.default_score_weights.get("coverage", 1 / 3)
|
|
@@ -215,7 +215,7 @@ class TunableRagEvaluator(BaseEvaluator):
|
|
|
215
215
|
reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
|
|
216
216
|
raise
|
|
217
217
|
except (KeyError, ValueError) as e:
|
|
218
|
-
logger.
|
|
218
|
+
logger.exception("Error parsing judge LLM response: %s", e)
|
|
219
219
|
score = 0.0
|
|
220
220
|
reasoning = "Error in evaluator from parsing judge LLM response."
|
|
221
221
|
|
|
@@ -148,13 +148,13 @@ async def register_ttc_tool_orchestration_function(
|
|
|
148
148
|
result = await fn.acall_invoke(item.output)
|
|
149
149
|
return item, result, None
|
|
150
150
|
except Exception as e:
|
|
151
|
-
logger.
|
|
151
|
+
logger.exception(f"Error invoking function '{item.name}': {e}")
|
|
152
152
|
return item, None, str(e)
|
|
153
153
|
|
|
154
154
|
tasks = []
|
|
155
155
|
for item in ttc_items:
|
|
156
156
|
if item.name not in function_map:
|
|
157
|
-
logger.error(f"Function '{item.name}' not found in function map.")
|
|
157
|
+
logger.error(f"Function '{item.name}' not found in function map.", exc_info=True)
|
|
158
158
|
item.output = f"Error: Function '{item.name}' not found in function map. Check your input"
|
|
159
159
|
else:
|
|
160
160
|
fn = function_map[item.name]
|
|
@@ -47,11 +47,11 @@ class _FastApiFrontEndController:
|
|
|
47
47
|
self._server_background_task = asyncio.create_task(self._server.serve())
|
|
48
48
|
except asyncio.CancelledError as e:
|
|
49
49
|
error_message = f"Task error occurred while starting API server: {str(e)}"
|
|
50
|
-
logger.error(error_message
|
|
50
|
+
logger.error(error_message)
|
|
51
51
|
raise RuntimeError(error_message) from e
|
|
52
52
|
except Exception as e:
|
|
53
53
|
error_message = f"Unexpected error occurred while starting API server: {str(e)}"
|
|
54
|
-
logger.
|
|
54
|
+
logger.exception(error_message)
|
|
55
55
|
raise RuntimeError(error_message) from e
|
|
56
56
|
|
|
57
57
|
async def stop_server(self) -> None:
|
|
@@ -63,6 +63,6 @@ class _FastApiFrontEndController:
|
|
|
63
63
|
self._server.should_exit = True
|
|
64
64
|
await self._server_background_task
|
|
65
65
|
except asyncio.CancelledError as e:
|
|
66
|
-
logger.
|
|
66
|
+
logger.exception("Server shutdown failed: %s", str(e))
|
|
67
67
|
except Exception as e:
|
|
68
|
-
logger.
|
|
68
|
+
logger.exception("Unexpected error occurred: %s", str(e))
|
|
@@ -113,4 +113,4 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
|
|
|
113
113
|
try:
|
|
114
114
|
os.remove(config_file_name)
|
|
115
115
|
except OSError as e:
|
|
116
|
-
logger.
|
|
116
|
+
logger.exception(f"Warning: Failed to delete temp file {config_file_name}: {e}")
|
|
@@ -215,7 +215,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
215
215
|
job_store.cleanup_expired_jobs()
|
|
216
216
|
logger.debug("Expired %s jobs cleaned up", name)
|
|
217
217
|
except Exception as e:
|
|
218
|
-
logger.
|
|
218
|
+
logger.exception("Error during %s job cleanup: %s", name, e)
|
|
219
219
|
await asyncio.sleep(sleep_time_sec)
|
|
220
220
|
|
|
221
221
|
async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
|
|
@@ -301,7 +301,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
301
301
|
|
|
302
302
|
job_store.update_status(job_id, "success", output_path=str(parent_dir))
|
|
303
303
|
except Exception as e:
|
|
304
|
-
logger.
|
|
304
|
+
logger.exception("Error in evaluation job %s: %s", job_id, str(e))
|
|
305
305
|
job_store.update_status(job_id, "failure", error=str(e))
|
|
306
306
|
|
|
307
307
|
async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
|
|
@@ -735,7 +735,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
735
735
|
result_type=result_type)
|
|
736
736
|
job_store.update_status(job_id, "success", output=result)
|
|
737
737
|
except Exception as e:
|
|
738
|
-
logger.
|
|
738
|
+
logger.exception("Error in evaluation job %s: %s", job_id, e)
|
|
739
739
|
job_store.update_status(job_id, "failure", error=str(e))
|
|
740
740
|
|
|
741
741
|
def _job_status_to_response(job: JobInfo) -> AsyncGenerationStatusResponse:
|
|
@@ -170,7 +170,7 @@ class WebSocketMessageHandler:
|
|
|
170
170
|
self._workflow_schema_type])).add_done_callback(_done_callback)
|
|
171
171
|
|
|
172
172
|
except ValueError as e:
|
|
173
|
-
logger.
|
|
173
|
+
logger.exception("User message content not found: %s", str(e))
|
|
174
174
|
await self.create_websocket_message(data_model=Error(code=ErrorTypes.INVALID_USER_MESSAGE_CONTENT,
|
|
175
175
|
message="User message content could not be found",
|
|
176
176
|
details=str(e)),
|
|
@@ -238,7 +238,7 @@ class WebSocketMessageHandler:
|
|
|
238
238
|
f"Message type could not be resolved by input data model: {data_model.model_dump_json()}")
|
|
239
239
|
|
|
240
240
|
except (ValidationError, TypeError, ValueError) as e:
|
|
241
|
-
logger.
|
|
241
|
+
logger.exception("A data vaidation error ocurred creating websocket message: %s", str(e))
|
|
242
242
|
message = await self._message_validator.create_system_response_token_message(
|
|
243
243
|
message_type=WebSocketMessageType.ERROR_MESSAGE,
|
|
244
244
|
conversation_id=self._conversation_id,
|
|
@@ -97,7 +97,7 @@ class MessageValidator:
|
|
|
97
97
|
return validated_message
|
|
98
98
|
|
|
99
99
|
except (ValidationError, TypeError, ValueError) as e:
|
|
100
|
-
logger.
|
|
100
|
+
logger.exception("A data validation error %s occurred for message: %s", str(e), str(message))
|
|
101
101
|
return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
|
|
102
102
|
content=Error(code=ErrorTypes.INVALID_MESSAGE,
|
|
103
103
|
message="Error validating message.",
|
|
@@ -119,7 +119,7 @@ class MessageValidator:
|
|
|
119
119
|
return schema
|
|
120
120
|
|
|
121
121
|
except (TypeError, ValueError) as e:
|
|
122
|
-
logger.
|
|
122
|
+
logger.exception("Error retrieving schema for message type '%s': %s", message_type, str(e))
|
|
123
123
|
return Error
|
|
124
124
|
|
|
125
125
|
async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
|
|
@@ -156,7 +156,7 @@ class MessageValidator:
|
|
|
156
156
|
return validated_message_content
|
|
157
157
|
|
|
158
158
|
except ValueError as e:
|
|
159
|
-
logger.
|
|
159
|
+
logger.exception("Input data could not be converted to validated message content: %s", str(e))
|
|
160
160
|
return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))
|
|
161
161
|
|
|
162
162
|
async def convert_text_content_to_human_response(self, text_content: TextContent,
|
|
@@ -191,7 +191,7 @@ class MessageValidator:
|
|
|
191
191
|
return human_response
|
|
192
192
|
|
|
193
193
|
except ValueError as e:
|
|
194
|
-
logger.
|
|
194
|
+
logger.exception("Error human response content not found: %s", str(e))
|
|
195
195
|
return HumanResponseText(text=str(e))
|
|
196
196
|
|
|
197
197
|
async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
|
|
@@ -218,9 +218,7 @@ class MessageValidator:
|
|
|
218
218
|
return validated_message_type
|
|
219
219
|
|
|
220
220
|
except ValueError as e:
|
|
221
|
-
logger.
|
|
222
|
-
str(e),
|
|
223
|
-
exc_info=True)
|
|
221
|
+
logger.exception("Error type not found converting data to validated websocket message content: %s", str(e))
|
|
224
222
|
return WebSocketMessageType.ERROR_MESSAGE
|
|
225
223
|
|
|
226
224
|
async def get_intermediate_step_parent_id(self, data_model: ResponseIntermediateStep) -> str:
|
|
@@ -269,7 +267,7 @@ class MessageValidator:
|
|
|
269
267
|
timestamp=timestamp)
|
|
270
268
|
|
|
271
269
|
except Exception as e:
|
|
272
|
-
logger.
|
|
270
|
+
logger.exception("Error creating system response token message: %s", str(e))
|
|
273
271
|
return None
|
|
274
272
|
|
|
275
273
|
async def create_system_intermediate_step_message(
|
|
@@ -308,7 +306,7 @@ class MessageValidator:
|
|
|
308
306
|
timestamp=timestamp)
|
|
309
307
|
|
|
310
308
|
except Exception as e:
|
|
311
|
-
logger.
|
|
309
|
+
logger.exception("Error creating system intermediate step message: %s", str(e))
|
|
312
310
|
return None
|
|
313
311
|
|
|
314
312
|
async def create_system_interaction_message(
|
|
@@ -348,5 +346,5 @@ class MessageValidator:
|
|
|
348
346
|
timestamp=timestamp)
|
|
349
347
|
|
|
350
348
|
except Exception as e:
|
|
351
|
-
logger.
|
|
349
|
+
logger.exception("Error creating system interaction message: %s", str(e))
|
|
352
350
|
return None
|
|
@@ -98,9 +98,9 @@ async def generate_streaming_response(payload: typing.Any,
|
|
|
98
98
|
yield item
|
|
99
99
|
else:
|
|
100
100
|
yield ResponsePayloadOutput(payload=item)
|
|
101
|
-
except Exception
|
|
101
|
+
except Exception:
|
|
102
102
|
# Handle exceptions here
|
|
103
|
-
raise
|
|
103
|
+
raise
|
|
104
104
|
finally:
|
|
105
105
|
await q.close()
|
|
106
106
|
|
|
@@ -165,9 +165,9 @@ async def generate_streaming_response_full(payload: typing.Any,
|
|
|
165
165
|
yield item
|
|
166
166
|
else:
|
|
167
167
|
yield ResponsePayloadOutput(payload=item)
|
|
168
|
-
except Exception
|
|
168
|
+
except Exception:
|
|
169
169
|
# Handle exceptions here
|
|
170
|
-
raise
|
|
170
|
+
raise
|
|
171
171
|
finally:
|
|
172
172
|
await q.close()
|
|
173
173
|
|
|
@@ -314,6 +314,6 @@ class StepAdaptor:
|
|
|
314
314
|
return self._handle_custom(payload, ancestry)
|
|
315
315
|
|
|
316
316
|
except Exception as e:
|
|
317
|
-
logger.
|
|
317
|
+
logger.exception("Error processing intermediate step: %s", e)
|
|
318
318
|
|
|
319
319
|
return None
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
16
18
|
from pydantic import Field
|
|
17
19
|
|
|
18
20
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
@@ -32,5 +34,8 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
|
32
34
|
log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
|
|
33
35
|
tool_names: list[str] = Field(default_factory=list,
|
|
34
36
|
description="The list of tools MCP server will expose (default: all tools)")
|
|
37
|
+
transport: Literal["sse", "streamable-http"] = Field(
|
|
38
|
+
default="streamable-http",
|
|
39
|
+
description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
|
|
35
40
|
runner_class: str | None = Field(
|
|
36
41
|
default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
|
|
@@ -77,5 +77,11 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
|
|
|
77
77
|
# Add routes through the worker (includes health endpoint and function registration)
|
|
78
78
|
await worker.add_routes(mcp, builder)
|
|
79
79
|
|
|
80
|
-
# Start the MCP server
|
|
81
|
-
|
|
80
|
+
# Start the MCP server with configurable transport
|
|
81
|
+
# streamable-http is the default, but users can choose sse if preferred
|
|
82
|
+
if self.front_end_config.transport == "sse":
|
|
83
|
+
logger.info("Starting MCP server with SSE endpoint at /sse")
|
|
84
|
+
await mcp.run_sse_async()
|
|
85
|
+
else: # streamable-http
|
|
86
|
+
logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
|
|
87
|
+
await mcp.run_streamable_http_async()
|
|
@@ -134,9 +134,9 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
|
134
134
|
logger.debug("Skipping function %s as it's not in tool_names", function_name)
|
|
135
135
|
functions = filtered_functions
|
|
136
136
|
|
|
137
|
-
# Register each function with MCP
|
|
137
|
+
# Register each function with MCP, passing workflow context for observability
|
|
138
138
|
for function_name, function in functions.items():
|
|
139
|
-
register_function_with_mcp(mcp, function_name, function)
|
|
139
|
+
register_function_with_mcp(mcp, function_name, function, workflow)
|
|
140
140
|
|
|
141
141
|
# Add a simple fallback function if no functions were found
|
|
142
142
|
if not functions:
|
|
@@ -17,13 +17,17 @@ import json
|
|
|
17
17
|
import logging
|
|
18
18
|
from inspect import Parameter
|
|
19
19
|
from inspect import Signature
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
20
21
|
|
|
21
22
|
from mcp.server.fastmcp import FastMCP
|
|
22
23
|
from pydantic import BaseModel
|
|
23
24
|
|
|
25
|
+
from nat.builder.context import ContextState
|
|
24
26
|
from nat.builder.function import Function
|
|
25
27
|
from nat.builder.function_base import FunctionBase
|
|
26
|
-
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from nat.builder.workflow import Workflow
|
|
27
31
|
|
|
28
32
|
logger = logging.getLogger(__name__)
|
|
29
33
|
|
|
@@ -33,14 +37,16 @@ def create_function_wrapper(
|
|
|
33
37
|
function: FunctionBase,
|
|
34
38
|
schema: type[BaseModel],
|
|
35
39
|
is_workflow: bool = False,
|
|
40
|
+
workflow: 'Workflow | None' = None,
|
|
36
41
|
):
|
|
37
42
|
"""Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
|
|
38
43
|
|
|
39
44
|
Args:
|
|
40
|
-
function_name: The name of the function/tool
|
|
41
|
-
function: The NAT Function object
|
|
42
|
-
schema: The input schema of the function
|
|
43
|
-
is_workflow: Whether the function is a Workflow
|
|
45
|
+
function_name (str): The name of the function/tool
|
|
46
|
+
function (FunctionBase): The NAT Function object
|
|
47
|
+
schema (type[BaseModel]): The input schema of the function
|
|
48
|
+
is_workflow (bool): Whether the function is a Workflow
|
|
49
|
+
workflow (Workflow | None): The parent workflow for observability context
|
|
44
50
|
|
|
45
51
|
Returns:
|
|
46
52
|
A wrapper function suitable for registration with MCP
|
|
@@ -101,6 +107,19 @@ def create_function_wrapper(
|
|
|
101
107
|
await ctx.report_progress(0, 100)
|
|
102
108
|
|
|
103
109
|
try:
|
|
110
|
+
# Helper function to wrap function calls with observability
|
|
111
|
+
async def call_with_observability(func_call):
|
|
112
|
+
# Use workflow's observability context (workflow should always be available)
|
|
113
|
+
if not workflow:
|
|
114
|
+
logger.error("Missing workflow context for function %s - observability will not be available",
|
|
115
|
+
function_name)
|
|
116
|
+
raise RuntimeError("Workflow context is required for observability")
|
|
117
|
+
|
|
118
|
+
logger.debug("Starting observability context for function %s", function_name)
|
|
119
|
+
context_state = ContextState.get()
|
|
120
|
+
async with workflow.exporter_manager.start(context_state=context_state):
|
|
121
|
+
return await func_call()
|
|
122
|
+
|
|
104
123
|
# Special handling for ChatRequest
|
|
105
124
|
if is_chat_request:
|
|
106
125
|
from nat.data_models.api_server import ChatRequest
|
|
@@ -118,7 +137,7 @@ def create_function_wrapper(
|
|
|
118
137
|
result = await runner.result(to_type=str)
|
|
119
138
|
else:
|
|
120
139
|
# Regular functions use ainvoke
|
|
121
|
-
result = await function.ainvoke(chat_request, to_type=str)
|
|
140
|
+
result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
|
|
122
141
|
else:
|
|
123
142
|
# Regular handling
|
|
124
143
|
# Handle complex input schema - if we extracted fields from a nested schema,
|
|
@@ -129,7 +148,7 @@ def create_function_wrapper(
|
|
|
129
148
|
field_type = schema.model_fields[field_name].annotation
|
|
130
149
|
|
|
131
150
|
# If it's a pydantic model, we need to create an instance
|
|
132
|
-
if hasattr(field_type, "model_validate"):
|
|
151
|
+
if field_type and hasattr(field_type, "model_validate"):
|
|
133
152
|
# Create the nested object
|
|
134
153
|
nested_obj = field_type.model_validate(kwargs)
|
|
135
154
|
# Call with the nested object
|
|
@@ -147,7 +166,7 @@ def create_function_wrapper(
|
|
|
147
166
|
result = await runner.result(to_type=str)
|
|
148
167
|
else:
|
|
149
168
|
# Regular function call
|
|
150
|
-
result = await function.acall_invoke(**kwargs)
|
|
169
|
+
result = await call_with_observability(lambda: function.acall_invoke(**kwargs))
|
|
151
170
|
|
|
152
171
|
# Report completion
|
|
153
172
|
if ctx:
|
|
@@ -170,7 +189,7 @@ def create_function_wrapper(
|
|
|
170
189
|
wrapper = create_wrapper()
|
|
171
190
|
|
|
172
191
|
# Set the signature on the wrapper function (WITHOUT ctx)
|
|
173
|
-
wrapper.__signature__ = sig
|
|
192
|
+
wrapper.__signature__ = sig # type: ignore
|
|
174
193
|
wrapper.__name__ = function_name
|
|
175
194
|
|
|
176
195
|
# Return the wrapper with proper signature
|
|
@@ -183,8 +202,8 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
183
202
|
|
|
184
203
|
The description is determined using the following precedence:
|
|
185
204
|
1. If the function is a Workflow and has a 'description' attribute, use it.
|
|
186
|
-
2. If the Workflow's config has a '
|
|
187
|
-
3. If the Workflow's config has a '
|
|
205
|
+
2. If the Workflow's config has a 'description', use it.
|
|
206
|
+
3. If the Workflow's config has a 'topic', use it.
|
|
188
207
|
4. If the function is a regular Function, use its 'description' attribute.
|
|
189
208
|
|
|
190
209
|
Args:
|
|
@@ -195,6 +214,9 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
195
214
|
"""
|
|
196
215
|
function_description = ""
|
|
197
216
|
|
|
217
|
+
# Import here to avoid circular imports
|
|
218
|
+
from nat.builder.workflow import Workflow
|
|
219
|
+
|
|
198
220
|
if isinstance(function, Workflow):
|
|
199
221
|
config = function.config
|
|
200
222
|
|
|
@@ -214,13 +236,17 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
214
236
|
return function_description
|
|
215
237
|
|
|
216
238
|
|
|
217
|
-
def register_function_with_mcp(mcp: FastMCP,
|
|
239
|
+
def register_function_with_mcp(mcp: FastMCP,
|
|
240
|
+
function_name: str,
|
|
241
|
+
function: FunctionBase,
|
|
242
|
+
workflow: 'Workflow | None' = None) -> None:
|
|
218
243
|
"""Register a NAT Function as an MCP tool.
|
|
219
244
|
|
|
220
245
|
Args:
|
|
221
246
|
mcp: The FastMCP instance
|
|
222
247
|
function_name: The name to register the function under
|
|
223
248
|
function: The NAT Function to register
|
|
249
|
+
workflow: The parent workflow for observability context (if available)
|
|
224
250
|
"""
|
|
225
251
|
logger.info("Registering function %s with MCP", function_name)
|
|
226
252
|
|
|
@@ -229,6 +255,7 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
|
|
|
229
255
|
logger.info("Function %s has input schema: %s", function_name, input_schema)
|
|
230
256
|
|
|
231
257
|
# Check if we're dealing with a Workflow
|
|
258
|
+
from nat.builder.workflow import Workflow
|
|
232
259
|
is_workflow = isinstance(function, Workflow)
|
|
233
260
|
if is_workflow:
|
|
234
261
|
logger.info("Function %s is a Workflow", function_name)
|
|
@@ -237,5 +264,5 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
|
|
|
237
264
|
function_description = get_function_description(function)
|
|
238
265
|
|
|
239
266
|
# Create and register the wrapper function with MCP
|
|
240
|
-
wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow)
|
|
267
|
+
wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
|
|
241
268
|
mcp.tool(name=function_name, description=function_description)(wrapper_func)
|
|
@@ -375,7 +375,7 @@ class BaseExporter(Exporter):
|
|
|
375
375
|
except asyncio.TimeoutError:
|
|
376
376
|
logger.warning("%s: Some tasks did not complete within %s seconds", self.name, timeout)
|
|
377
377
|
except Exception as e:
|
|
378
|
-
logger.
|
|
378
|
+
logger.exception("%s: Error while waiting for tasks: %s", self.name, e)
|
|
379
379
|
|
|
380
380
|
@override
|
|
381
381
|
async def stop(self):
|
|
@@ -175,7 +175,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
|
|
|
175
175
|
try:
|
|
176
176
|
processed_item = await processor.process(processed_item)
|
|
177
177
|
except Exception as e:
|
|
178
|
-
logger.
|
|
178
|
+
logger.exception("Error in processor %s: %s", processor.__class__.__name__, e)
|
|
179
179
|
# Continue with unprocessed item rather than failing
|
|
180
180
|
return processed_item
|
|
181
181
|
|
|
@@ -214,7 +214,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
|
|
|
214
214
|
try:
|
|
215
215
|
source_index = self._processors.index(source_processor)
|
|
216
216
|
except ValueError:
|
|
217
|
-
logger.
|
|
217
|
+
logger.exception("Source processor %s not found in pipeline", source_processor.__class__.__name__)
|
|
218
218
|
return
|
|
219
219
|
|
|
220
220
|
# Process through remaining processors (skip the source processor)
|
|
@@ -225,10 +225,9 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
|
|
|
225
225
|
await self._export_final_item(processed_item)
|
|
226
226
|
|
|
227
227
|
except Exception as e:
|
|
228
|
-
logger.
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
exc_info=True)
|
|
228
|
+
logger.exception("Failed to continue pipeline processing after %s: %s",
|
|
229
|
+
source_processor.__class__.__name__,
|
|
230
|
+
e)
|
|
232
231
|
|
|
233
232
|
async def _export_with_processing(self, item: PipelineInputT) -> None:
|
|
234
233
|
"""Export an item after processing it through the pipeline.
|
|
@@ -248,7 +247,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
|
|
|
248
247
|
await self._export_final_item(final_item, raise_on_invalid=True)
|
|
249
248
|
|
|
250
249
|
except Exception as e:
|
|
251
|
-
logger.error("Failed to export item '%s': %s", item, e
|
|
250
|
+
logger.error("Failed to export item '%s': %s", item, e)
|
|
252
251
|
raise
|
|
253
252
|
|
|
254
253
|
@override
|
|
@@ -293,7 +292,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
|
|
|
293
292
|
task.add_done_callback(self._tasks.discard)
|
|
294
293
|
|
|
295
294
|
except Exception as e:
|
|
296
|
-
logger.error("%s: Failed to create task: %s", self.name, e
|
|
295
|
+
logger.error("%s: Failed to create task: %s", self.name, e)
|
|
297
296
|
raise
|
|
298
297
|
|
|
299
298
|
@override
|
|
@@ -316,7 +315,7 @@ class ProcessingExporter(Generic[PipelineInputT, PipelineOutputT], BaseExporter,
|
|
|
316
315
|
await asyncio.gather(*shutdown_tasks, return_exceptions=True)
|
|
317
316
|
logger.debug("Successfully shut down %d processors", len(shutdown_tasks))
|
|
318
317
|
except Exception as e:
|
|
319
|
-
logger.
|
|
318
|
+
logger.exception("Error shutting down processors: %s", e)
|
|
320
319
|
|
|
321
320
|
# Call parent cleanup
|
|
322
321
|
await super()._cleanup()
|
|
@@ -177,7 +177,7 @@ class ExporterManager:
|
|
|
177
177
|
else:
|
|
178
178
|
logger.debug("Skipping cleanup for non-isolated exporter '%s'", name)
|
|
179
179
|
except Exception as e:
|
|
180
|
-
logger.
|
|
180
|
+
logger.exception("Error preparing cleanup for isolated exporter '%s': %s", name, e)
|
|
181
181
|
|
|
182
182
|
if cleanup_tasks:
|
|
183
183
|
# Run cleanup tasks concurrently with timeout
|
|
@@ -195,7 +195,7 @@ class ExporterManager:
|
|
|
195
195
|
logger.debug("Stopping isolated exporter '%s'", name)
|
|
196
196
|
await exporter.stop()
|
|
197
197
|
except Exception as e:
|
|
198
|
-
logger.
|
|
198
|
+
logger.exception("Error stopping isolated exporter '%s': %s", name, e)
|
|
199
199
|
|
|
200
200
|
@asynccontextmanager
|
|
201
201
|
async def start(self, context_state: ContextState | None = None):
|
|
@@ -251,7 +251,7 @@ class ExporterManager:
|
|
|
251
251
|
try:
|
|
252
252
|
await self._cleanup_isolated_exporters()
|
|
253
253
|
except Exception as e:
|
|
254
|
-
logger.
|
|
254
|
+
logger.exception("Error during isolated exporter cleanup: %s", e)
|
|
255
255
|
|
|
256
256
|
# Then stop the manager tasks
|
|
257
257
|
await self.stop()
|
|
@@ -275,7 +275,7 @@ class ExporterManager:
|
|
|
275
275
|
logger.info("Stopped exporter '%s'", name)
|
|
276
276
|
raise
|
|
277
277
|
except Exception as e:
|
|
278
|
-
logger.error("Failed to run exporter '%s': %s", name, str(e)
|
|
278
|
+
logger.error("Failed to run exporter '%s': %s", name, str(e))
|
|
279
279
|
# Re-raise the exception to ensure it's properly handled
|
|
280
280
|
raise
|
|
281
281
|
|
|
@@ -307,7 +307,7 @@ class ExporterManager:
|
|
|
307
307
|
except asyncio.CancelledError:
|
|
308
308
|
logger.debug("Exporter '%s' task cancelled", name)
|
|
309
309
|
except Exception as e:
|
|
310
|
-
logger.
|
|
310
|
+
logger.exception("Failed to stop exporter '%s': %s", name, str(e))
|
|
311
311
|
|
|
312
312
|
if stuck_tasks:
|
|
313
313
|
logger.warning("Exporters did not shut down in time: %s", ", ".join(stuck_tasks))
|