nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +41 -21
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +46 -26
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +46 -11
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  53. nat/cli/commands/workflow/workflow_commands.py +9 -13
  54. nat/cli/entrypoint.py +8 -10
  55. nat/cli/register_workflow.py +38 -4
  56. nat/cli/type_registry.py +75 -6
  57. nat/control_flow/__init__.py +0 -0
  58. nat/control_flow/register.py +20 -0
  59. nat/control_flow/router_agent/__init__.py +0 -0
  60. nat/control_flow/router_agent/agent.py +329 -0
  61. nat/control_flow/router_agent/prompt.py +48 -0
  62. nat/control_flow/router_agent/register.py +91 -0
  63. nat/control_flow/sequential_executor.py +166 -0
  64. nat/data_models/agent.py +34 -0
  65. nat/data_models/api_server.py +10 -10
  66. nat/data_models/authentication.py +23 -9
  67. nat/data_models/common.py +1 -1
  68. nat/data_models/component.py +2 -0
  69. nat/data_models/component_ref.py +11 -0
  70. nat/data_models/config.py +41 -17
  71. nat/data_models/dataset_handler.py +1 -1
  72. nat/data_models/discovery_metadata.py +4 -4
  73. nat/data_models/evaluate.py +4 -1
  74. nat/data_models/function.py +34 -0
  75. nat/data_models/function_dependencies.py +14 -6
  76. nat/data_models/gated_field_mixin.py +242 -0
  77. nat/data_models/intermediate_step.py +3 -3
  78. nat/data_models/optimizable.py +119 -0
  79. nat/data_models/optimizer.py +149 -0
  80. nat/data_models/swe_bench_model.py +1 -1
  81. nat/data_models/temperature_mixin.py +44 -0
  82. nat/data_models/thinking_mixin.py +86 -0
  83. nat/data_models/top_p_mixin.py +44 -0
  84. nat/embedder/nim_embedder.py +1 -1
  85. nat/embedder/openai_embedder.py +1 -1
  86. nat/embedder/register.py +0 -1
  87. nat/eval/config.py +3 -1
  88. nat/eval/dataset_handler/dataset_handler.py +71 -7
  89. nat/eval/evaluate.py +86 -31
  90. nat/eval/evaluator/base_evaluator.py +1 -1
  91. nat/eval/evaluator/evaluator_model.py +13 -0
  92. nat/eval/intermediate_step_adapter.py +1 -1
  93. nat/eval/rag_evaluator/evaluate.py +2 -2
  94. nat/eval/rag_evaluator/register.py +3 -3
  95. nat/eval/register.py +4 -1
  96. nat/eval/remote_workflow.py +3 -3
  97. nat/eval/runtime_evaluator/__init__.py +14 -0
  98. nat/eval/runtime_evaluator/evaluate.py +123 -0
  99. nat/eval/runtime_evaluator/register.py +100 -0
  100. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  101. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  102. nat/eval/trajectory_evaluator/register.py +1 -1
  103. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  104. nat/eval/utils/eval_trace_ctx.py +89 -0
  105. nat/eval/utils/weave_eval.py +18 -9
  106. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  107. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  108. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  109. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  110. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  111. nat/experimental/test_time_compute/register.py +0 -1
  112. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  113. nat/front_ends/console/authentication_flow_handler.py +82 -30
  114. nat/front_ends/console/console_front_end_plugin.py +8 -5
  115. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  116. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  117. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  118. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  119. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  120. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  121. nat/front_ends/fastapi/job_store.py +518 -99
  122. nat/front_ends/fastapi/main.py +11 -19
  123. nat/front_ends/fastapi/message_handler.py +13 -14
  124. nat/front_ends/fastapi/message_validator.py +17 -19
  125. nat/front_ends/fastapi/response_helpers.py +4 -4
  126. nat/front_ends/fastapi/step_adaptor.py +2 -2
  127. nat/front_ends/fastapi/utils.py +57 -0
  128. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  129. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  130. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  131. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  132. nat/front_ends/mcp/tool_converter.py +44 -14
  133. nat/front_ends/register.py +0 -1
  134. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  135. nat/llm/aws_bedrock_llm.py +24 -12
  136. nat/llm/azure_openai_llm.py +13 -6
  137. nat/llm/litellm_llm.py +69 -0
  138. nat/llm/nim_llm.py +20 -8
  139. nat/llm/openai_llm.py +14 -6
  140. nat/llm/register.py +4 -1
  141. nat/llm/utils/env_config_value.py +2 -3
  142. nat/llm/utils/thinking.py +215 -0
  143. nat/meta/pypi.md +9 -9
  144. nat/object_store/register.py +0 -1
  145. nat/observability/exporter/base_exporter.py +3 -3
  146. nat/observability/exporter/file_exporter.py +1 -1
  147. nat/observability/exporter/processing_exporter.py +309 -81
  148. nat/observability/exporter/span_exporter.py +1 -1
  149. nat/observability/exporter_manager.py +7 -7
  150. nat/observability/mixin/file_mixin.py +7 -7
  151. nat/observability/mixin/redaction_config_mixin.py +42 -0
  152. nat/observability/mixin/tagging_config_mixin.py +62 -0
  153. nat/observability/mixin/type_introspection_mixin.py +420 -107
  154. nat/observability/processor/batching_processor.py +5 -7
  155. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  156. nat/observability/processor/processor.py +3 -0
  157. nat/observability/processor/processor_factory.py +70 -0
  158. nat/observability/processor/redaction/__init__.py +24 -0
  159. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  160. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  161. nat/observability/processor/redaction/redaction_processor.py +177 -0
  162. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  163. nat/observability/processor/span_tagging_processor.py +68 -0
  164. nat/observability/register.py +6 -4
  165. nat/profiler/calc/calc_runner.py +3 -4
  166. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  167. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  168. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  169. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  170. nat/profiler/data_frame_row.py +1 -1
  171. nat/profiler/decorators/framework_wrapper.py +62 -13
  172. nat/profiler/decorators/function_tracking.py +160 -3
  173. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  174. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  175. nat/profiler/inference_optimization/data_models.py +3 -3
  176. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  177. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  178. nat/profiler/parameter_optimization/__init__.py +0 -0
  179. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  180. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  181. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  182. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  183. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  184. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  185. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  186. nat/profiler/profile_runner.py +14 -9
  187. nat/profiler/utils.py +4 -2
  188. nat/registry_handlers/local/local_handler.py +2 -2
  189. nat/registry_handlers/package_utils.py +1 -2
  190. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  191. nat/registry_handlers/register.py +3 -4
  192. nat/registry_handlers/rest/rest_handler.py +12 -13
  193. nat/retriever/milvus/retriever.py +2 -2
  194. nat/retriever/nemo_retriever/retriever.py +1 -1
  195. nat/retriever/register.py +0 -1
  196. nat/runtime/loader.py +2 -2
  197. nat/runtime/runner.py +3 -2
  198. nat/runtime/session.py +43 -8
  199. nat/settings/global_settings.py +16 -5
  200. nat/tool/chat_completion.py +5 -2
  201. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  202. nat/tool/datetime_tools.py +49 -9
  203. nat/tool/document_search.py +2 -2
  204. nat/tool/github_tools.py +450 -0
  205. nat/tool/nvidia_rag.py +1 -1
  206. nat/tool/register.py +2 -9
  207. nat/tool/retriever.py +3 -2
  208. nat/utils/callable_utils.py +70 -0
  209. nat/utils/data_models/schema_validator.py +3 -3
  210. nat/utils/exception_handlers/automatic_retries.py +104 -51
  211. nat/utils/exception_handlers/schemas.py +1 -1
  212. nat/utils/io/yaml_tools.py +2 -2
  213. nat/utils/log_levels.py +25 -0
  214. nat/utils/reactive/base/observable_base.py +2 -2
  215. nat/utils/reactive/base/observer_base.py +1 -1
  216. nat/utils/reactive/observable.py +2 -2
  217. nat/utils/reactive/observer.py +4 -4
  218. nat/utils/reactive/subscription.py +1 -1
  219. nat/utils/settings/global_settings.py +6 -8
  220. nat/utils/type_converter.py +4 -3
  221. nat/utils/type_utils.py +9 -5
  222. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
  223. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
  224. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  225. nat/cli/commands/info/list_mcp.py +0 -304
  226. nat/tool/github_tools/create_github_commit.py +0 -133
  227. nat/tool/github_tools/create_github_issue.py +0 -87
  228. nat/tool/github_tools/create_github_pr.py +0 -106
  229. nat/tool/github_tools/get_github_file.py +0 -106
  230. nat/tool/github_tools/get_github_issue.py +0 -166
  231. nat/tool/github_tools/get_github_pr.py +0 -256
  232. nat/tool/github_tools/update_github_issue.py +0 -100
  233. nat/tool/mcp/exceptions.py +0 -142
  234. nat/tool/mcp/mcp_client.py +0 -255
  235. nat/tool/mcp/mcp_tool.py +0 -96
  236. nat/utils/exception_handlers/mcp.py +0 -211
  237. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  238. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  239. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  240. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  241. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  242. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
@@ -16,11 +16,15 @@
16
16
  import logging
17
17
  from abc import ABC
18
18
  from abc import abstractmethod
19
+ from collections.abc import Mapping
20
+ from typing import Any
19
21
 
20
22
  from mcp.server.fastmcp import FastMCP
23
+ from starlette.exceptions import HTTPException
21
24
  from starlette.requests import Request
22
25
 
23
26
  from nat.builder.function import Function
27
+ from nat.builder.function_base import FunctionBase
24
28
  from nat.builder.workflow import Workflow
25
29
  from nat.builder.workflow_builder import WorkflowBuilder
26
30
  from nat.data_models.config import Config
@@ -82,7 +86,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
82
86
  """
83
87
  pass
84
88
 
85
- def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
89
+ async def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
86
90
  """Get all functions from the workflow.
87
91
 
88
92
  Args:
@@ -94,13 +98,114 @@ class MCPFrontEndPluginWorkerBase(ABC):
94
98
  functions: dict[str, Function] = {}
95
99
 
96
100
  # Extract all functions from the workflow
97
- for function_name, function in workflow.functions.items():
98
- functions[function_name] = function
101
+ functions.update(workflow.functions)
102
+ for function_group in workflow.function_groups.values():
103
+ functions.update(await function_group.get_accessible_functions())
99
104
 
100
- functions[workflow.config.workflow.type] = workflow
105
+ if workflow.config.workflow.workflow_alias:
106
+ functions[workflow.config.workflow.workflow_alias] = workflow
107
+ else:
108
+ functions[workflow.config.workflow.type] = workflow
101
109
 
102
110
  return functions
103
111
 
112
+ def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
113
+ """Set up HTTP debug endpoints for introspecting tools and schemas.
114
+
115
+ Exposes:
116
+ - GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
117
+ selects a subset and returns details for those tools.
118
+ """
119
+
120
+ @mcp.custom_route("/debug/tools/list", methods=["GET"])
121
+ async def list_tools(request: Request):
122
+ """HTTP list tools endpoint."""
123
+
124
+ from starlette.responses import JSONResponse
125
+
126
+ from nat.front_ends.mcp.tool_converter import get_function_description
127
+
128
+ # Query params
129
+ # Support repeated names and comma-separated lists
130
+ names_param_list = set(request.query_params.getlist("name"))
131
+ names: list[str] = []
132
+ for raw in names_param_list:
133
+ # if p.strip() is empty, it won't be included in the list!
134
+ parts = [p.strip() for p in raw.split(",") if p.strip()]
135
+ names.extend(parts)
136
+ detail_raw = request.query_params.get("detail")
137
+
138
+ def _parse_detail_param(detail_param: str | None, has_names: bool) -> bool:
139
+ if detail_param is None:
140
+ if has_names:
141
+ return True
142
+ return False
143
+ v = detail_param.strip().lower()
144
+ if v in ("0", "false", "no", "off"):
145
+ return False
146
+ if v in ("1", "true", "yes", "on"):
147
+ return True
148
+ # For invalid values, default based on whether names are present
149
+ return has_names
150
+
151
+ # Helper function to build the input schema info
152
+ def _build_schema_info(fn: FunctionBase) -> dict[str, Any] | None:
153
+ schema = getattr(fn, "input_schema", None)
154
+ if schema is None:
155
+ return None
156
+
157
+ # check if schema is a ChatRequest
158
+ schema_name = getattr(schema, "__name__", "")
159
+ schema_qualname = getattr(schema, "__qualname__", "")
160
+ if "ChatRequest" in schema_name or "ChatRequest" in schema_qualname:
161
+ # Simplified interface used by MCP wrapper for ChatRequest
162
+ return {
163
+ "type": "object",
164
+ "properties": {
165
+ "query": {
166
+ "type": "string", "description": "User query string"
167
+ }
168
+ },
169
+ "required": ["query"],
170
+ "title": "ChatRequestQuery",
171
+ }
172
+
173
+ # Pydantic models provide model_json_schema
174
+ if schema is not None and hasattr(schema, "model_json_schema"):
175
+ return schema.model_json_schema()
176
+
177
+ return None
178
+
179
+ def _build_final_json(functions_to_include: Mapping[str, FunctionBase],
180
+ include_schemas: bool = False) -> dict[str, Any]:
181
+ tools = []
182
+ for name, fn in functions_to_include.items():
183
+ list_entry: dict[str, Any] = {
184
+ "name": name, "description": get_function_description(fn), "is_workflow": hasattr(fn, "run")
185
+ }
186
+ if include_schemas:
187
+ list_entry["schema"] = _build_schema_info(fn)
188
+ tools.append(list_entry)
189
+
190
+ return {
191
+ "count": len(tools),
192
+ "tools": tools,
193
+ "server_name": mcp.name,
194
+ }
195
+
196
+ if names:
197
+ # Return selected tools
198
+ try:
199
+ functions_to_include = {n: functions[n] for n in names}
200
+ except KeyError as e:
201
+ raise HTTPException(status_code=404, detail=f"Tool \"{e.args[0]}\" not found.") from e
202
+ else:
203
+ functions_to_include = functions
204
+
205
+ # Default for listing all: detail defaults to False unless explicitly set true
206
+ return JSONResponse(
207
+ _build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))
208
+
104
209
 
105
210
  class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
106
211
  """Default MCP front end plugin worker implementation."""
@@ -118,10 +223,10 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
118
223
  self._setup_health_endpoint(mcp)
119
224
 
120
225
  # Build the workflow and register all functions with MCP
121
- workflow = builder.build()
226
+ workflow = await builder.build()
122
227
 
123
228
  # Get all functions from the workflow
124
- functions = self._get_all_functions(workflow)
229
+ functions = await self._get_all_functions(workflow)
125
230
 
126
231
  # Filter functions based on tool_names if provided
127
232
  if self.front_end_config.tool_names:
@@ -134,10 +239,13 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
134
239
  logger.debug("Skipping function %s as it's not in tool_names", function_name)
135
240
  functions = filtered_functions
136
241
 
137
- # Register each function with MCP
242
+ # Register each function with MCP, passing workflow context for observability
138
243
  for function_name, function in functions.items():
139
- register_function_with_mcp(mcp, function_name, function)
244
+ register_function_with_mcp(mcp, function_name, function, workflow)
140
245
 
141
246
  # Add a simple fallback function if no functions were found
142
247
  if not functions:
143
248
  raise RuntimeError("No functions found in workflow. Please check your configuration.")
249
+
250
+ # After registration, expose debug endpoints for tool/schema inspection
251
+ self._setup_debug_endpoints(mcp, functions)
@@ -17,13 +17,17 @@ import json
17
17
  import logging
18
18
  from inspect import Parameter
19
19
  from inspect import Signature
20
+ from typing import TYPE_CHECKING
20
21
 
21
22
  from mcp.server.fastmcp import FastMCP
22
23
  from pydantic import BaseModel
23
24
 
25
+ from nat.builder.context import ContextState
24
26
  from nat.builder.function import Function
25
27
  from nat.builder.function_base import FunctionBase
26
- from nat.builder.workflow import Workflow
28
+
29
+ if TYPE_CHECKING:
30
+ from nat.builder.workflow import Workflow
27
31
 
28
32
  logger = logging.getLogger(__name__)
29
33
 
@@ -33,14 +37,16 @@ def create_function_wrapper(
33
37
  function: FunctionBase,
34
38
  schema: type[BaseModel],
35
39
  is_workflow: bool = False,
40
+ workflow: 'Workflow | None' = None,
36
41
  ):
37
42
  """Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
38
43
 
39
44
  Args:
40
- function_name: The name of the function/tool
41
- function: The NAT Function object
42
- schema: The input schema of the function
43
- is_workflow: Whether the function is a Workflow
45
+ function_name (str): The name of the function/tool
46
+ function (FunctionBase): The NAT Function object
47
+ schema (type[BaseModel]): The input schema of the function
48
+ is_workflow (bool): Whether the function is a Workflow
49
+ workflow (Workflow | None): The parent workflow for observability context
44
50
 
45
51
  Returns:
46
52
  A wrapper function suitable for registration with MCP
@@ -101,6 +107,19 @@ def create_function_wrapper(
101
107
  await ctx.report_progress(0, 100)
102
108
 
103
109
  try:
110
+ # Helper function to wrap function calls with observability
111
+ async def call_with_observability(func_call):
112
+ # Use workflow's observability context (workflow should always be available)
113
+ if not workflow:
114
+ logger.error("Missing workflow context for function %s - observability will not be available",
115
+ function_name)
116
+ raise RuntimeError("Workflow context is required for observability")
117
+
118
+ logger.debug("Starting observability context for function %s", function_name)
119
+ context_state = ContextState.get()
120
+ async with workflow.exporter_manager.start(context_state=context_state):
121
+ return await func_call()
122
+
104
123
  # Special handling for ChatRequest
105
124
  if is_chat_request:
106
125
  from nat.data_models.api_server import ChatRequest
@@ -118,7 +137,7 @@ def create_function_wrapper(
118
137
  result = await runner.result(to_type=str)
119
138
  else:
120
139
  # Regular functions use ainvoke
121
- result = await function.ainvoke(chat_request, to_type=str)
140
+ result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
122
141
  else:
123
142
  # Regular handling
124
143
  # Handle complex input schema - if we extracted fields from a nested schema,
@@ -129,7 +148,7 @@ def create_function_wrapper(
129
148
  field_type = schema.model_fields[field_name].annotation
130
149
 
131
150
  # If it's a pydantic model, we need to create an instance
132
- if hasattr(field_type, "model_validate"):
151
+ if field_type and hasattr(field_type, "model_validate"):
133
152
  # Create the nested object
134
153
  nested_obj = field_type.model_validate(kwargs)
135
154
  # Call with the nested object
@@ -147,7 +166,7 @@ def create_function_wrapper(
147
166
  result = await runner.result(to_type=str)
148
167
  else:
149
168
  # Regular function call
150
- result = await function.acall_invoke(**kwargs)
169
+ result = await call_with_observability(lambda: function.acall_invoke(**kwargs))
151
170
 
152
171
  # Report completion
153
172
  if ctx:
@@ -156,7 +175,7 @@ def create_function_wrapper(
156
175
  # Handle different result types for proper formatting
157
176
  if isinstance(result, str):
158
177
  return result
159
- if isinstance(result, (dict, list)):
178
+ if isinstance(result, dict | list):
160
179
  return json.dumps(result, default=str)
161
180
  return str(result)
162
181
  except Exception as e:
@@ -170,7 +189,7 @@ def create_function_wrapper(
170
189
  wrapper = create_wrapper()
171
190
 
172
191
  # Set the signature on the wrapper function (WITHOUT ctx)
173
- wrapper.__signature__ = sig
192
+ wrapper.__signature__ = sig # type: ignore
174
193
  wrapper.__name__ = function_name
175
194
 
176
195
  # Return the wrapper with proper signature
@@ -183,8 +202,8 @@ def get_function_description(function: FunctionBase) -> str:
183
202
 
184
203
  The description is determined using the following precedence:
185
204
  1. If the function is a Workflow and has a 'description' attribute, use it.
186
- 2. If the Workflow's config has a 'topic', use it.
187
- 3. If the Workflow's config has a 'description', use it.
205
+ 2. If the Workflow's config has a 'description', use it.
206
+ 3. If the Workflow's config has a 'topic', use it.
188
207
  4. If the function is a regular Function, use its 'description' attribute.
189
208
 
190
209
  Args:
@@ -195,6 +214,9 @@ def get_function_description(function: FunctionBase) -> str:
195
214
  """
196
215
  function_description = ""
197
216
 
217
+ # Import here to avoid circular imports
218
+ from nat.builder.workflow import Workflow
219
+
198
220
  if isinstance(function, Workflow):
199
221
  config = function.config
200
222
 
@@ -207,6 +229,9 @@ def get_function_description(function: FunctionBase) -> str:
207
229
  # Try to get anything that might be a description
208
230
  elif hasattr(config, "topic") and config.topic:
209
231
  function_description = config.topic
232
+ # Try to get description from the workflow config
233
+ elif hasattr(config, "workflow") and hasattr(config.workflow, "description") and config.workflow.description:
234
+ function_description = config.workflow.description
210
235
 
211
236
  elif isinstance(function, Function):
212
237
  function_description = function.description
@@ -214,13 +239,17 @@ def get_function_description(function: FunctionBase) -> str:
214
239
  return function_description
215
240
 
216
241
 
217
- def register_function_with_mcp(mcp: FastMCP, function_name: str, function: FunctionBase) -> None:
242
+ def register_function_with_mcp(mcp: FastMCP,
243
+ function_name: str,
244
+ function: FunctionBase,
245
+ workflow: 'Workflow | None' = None) -> None:
218
246
  """Register a NAT Function as an MCP tool.
219
247
 
220
248
  Args:
221
249
  mcp: The FastMCP instance
222
250
  function_name: The name to register the function under
223
251
  function: The NAT Function to register
252
+ workflow: The parent workflow for observability context (if available)
224
253
  """
225
254
  logger.info("Registering function %s with MCP", function_name)
226
255
 
@@ -229,6 +258,7 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
229
258
  logger.info("Function %s has input schema: %s", function_name, input_schema)
230
259
 
231
260
  # Check if we're dealing with a Workflow
261
+ from nat.builder.workflow import Workflow
232
262
  is_workflow = isinstance(function, Workflow)
233
263
  if is_workflow:
234
264
  logger.info("Function %s is a Workflow", function_name)
@@ -237,5 +267,5 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
237
267
  function_description = get_function_description(function)
238
268
 
239
269
  # Create and register the wrapper function with MCP
240
- wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow)
270
+ wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
241
271
  mcp.tool(name=function_name, description=function_description)(wrapper_func)
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
  # isort:skip_file
19
18
 
@@ -35,6 +35,8 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
35
35
 
36
36
  async def run(self):
37
37
 
38
+ await self.pre_run()
39
+
38
40
  # Must yield the workflow function otherwise it cleans up
39
41
  async with WorkflowBuilder.from_config(config=self.full_config) as builder:
40
42
 
@@ -45,7 +47,7 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
45
47
 
46
48
  click.echo(stream.getvalue())
47
49
 
48
- workflow = builder.build()
50
+ workflow = await builder.build()
49
51
  session_manager = SessionManager(workflow)
50
52
  await self.run_workflow(session_manager)
51
53
 
@@ -21,27 +21,39 @@ from nat.builder.builder import Builder
21
21
  from nat.builder.llm import LLMProviderInfo
22
22
  from nat.cli.register_workflow import register_llm_provider
23
23
  from nat.data_models.llm import LLMBaseConfig
24
+ from nat.data_models.optimizable import OptimizableField
25
+ from nat.data_models.optimizable import OptimizableMixin
26
+ from nat.data_models.optimizable import SearchSpace
24
27
  from nat.data_models.retry_mixin import RetryMixin
28
+ from nat.data_models.temperature_mixin import TemperatureMixin
29
+ from nat.data_models.thinking_mixin import ThinkingMixin
30
+ from nat.data_models.top_p_mixin import TopPMixin
25
31
 
26
32
 
27
- class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
33
+ class AWSBedrockModelConfig(LLMBaseConfig,
34
+ RetryMixin,
35
+ OptimizableMixin,
36
+ TemperatureMixin,
37
+ TopPMixin,
38
+ ThinkingMixin,
39
+ name="aws_bedrock"):
28
40
  """An AWS Bedrock llm provider to be used with an LLM client."""
29
41
 
30
- model_config = ConfigDict(protected_namespaces=())
42
+ model_config = ConfigDict(protected_namespaces=(), extra="allow")
31
43
 
32
44
  # Completion parameters
33
45
  model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
34
46
  serialization_alias="model",
35
47
  description="The model name for the hosted AWS Bedrock.")
36
- temperature: float = Field(default=0.0, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
37
- max_tokens: int | None = Field(default=1024,
38
- gt=0,
39
- description="Maximum number of tokens to generate."
40
- "This field is ONLY required when using AWS Bedrock with Langchain.")
41
- context_size: int | None = Field(default=1024,
42
- gt=0,
43
- description="Maximum number of tokens to generate."
44
- "This field is ONLY required when using AWS Bedrock with LlamaIndex.")
48
+ max_tokens: int = OptimizableField(default=300,
49
+ description="Maximum number of tokens to generate.",
50
+ space=SearchSpace(high=2176, low=128, step=512))
51
+ context_size: int | None = Field(
52
+ default=1024,
53
+ gt=0,
54
+ description="The maximum number of tokens available for input. This is only required for LlamaIndex. "
55
+ "This field is ignored for LangChain/LangGraph.",
56
+ )
45
57
 
46
58
  # Client parameters
47
59
  region_name: str | None = Field(default="None", description="AWS region to use.")
@@ -52,6 +64,6 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
52
64
 
53
65
 
54
66
  @register_llm_provider(config_type=AWSBedrockModelConfig)
55
- async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, builder: Builder):
67
+ async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, _builder: Builder):
56
68
 
57
69
  yield LLMProviderInfo(config=llm_config, description="A AWS Bedrock model for use with an LLM client.")
@@ -22,9 +22,19 @@ from nat.builder.llm import LLMProviderInfo
22
22
  from nat.cli.register_workflow import register_llm_provider
23
23
  from nat.data_models.llm import LLMBaseConfig
24
24
  from nat.data_models.retry_mixin import RetryMixin
25
-
26
-
27
- class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
25
+ from nat.data_models.temperature_mixin import TemperatureMixin
26
+ from nat.data_models.thinking_mixin import ThinkingMixin
27
+ from nat.data_models.top_p_mixin import TopPMixin
28
+
29
+
30
+ class AzureOpenAIModelConfig(
31
+ LLMBaseConfig,
32
+ RetryMixin,
33
+ TemperatureMixin,
34
+ TopPMixin,
35
+ ThinkingMixin,
36
+ name="azure_openai",
37
+ ):
28
38
  """An Azure OpenAI LLM provider to be used with an LLM client."""
29
39
 
30
40
  model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -38,10 +48,7 @@ class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
38
48
  azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
39
49
  serialization_alias="azure_deployment",
40
50
  description="The Azure OpenAI hosted model/deployment name.")
41
- temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
42
- top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
43
51
  seed: int | None = Field(default=None, description="Random seed to set for generation.")
44
- max_retries: int = Field(default=10, description="The max number of retries for the request.")
45
52
 
46
53
 
47
54
  @register_llm_provider(config_type=AzureOpenAIModelConfig)
nat/llm/litellm_llm.py ADDED
@@ -0,0 +1,69 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections.abc import AsyncIterator
17
+
18
+ from pydantic import AliasChoices
19
+ from pydantic import ConfigDict
20
+ from pydantic import Field
21
+
22
+ from nat.builder.builder import Builder
23
+ from nat.builder.llm import LLMProviderInfo
24
+ from nat.cli.register_workflow import register_llm_provider
25
+ from nat.data_models.llm import LLMBaseConfig
26
+ from nat.data_models.retry_mixin import RetryMixin
27
+ from nat.data_models.temperature_mixin import TemperatureMixin
28
+ from nat.data_models.thinking_mixin import ThinkingMixin
29
+ from nat.data_models.top_p_mixin import TopPMixin
30
+
31
+
32
+ class LiteLlmModelConfig(
33
+ LLMBaseConfig,
34
+ RetryMixin,
35
+ TemperatureMixin,
36
+ TopPMixin,
37
+ ThinkingMixin,
38
+ name="litellm",
39
+ ):
40
+ """A LiteLlm provider to be used with an LLM client."""
41
+
42
+ model_config = ConfigDict(protected_namespaces=(), extra="allow")
43
+
44
+ api_key: str | None = Field(default=None, description="API key to interact with hosted model.")
45
+ base_url: str | None = Field(default=None,
46
+ description="Base url to the hosted model.",
47
+ validation_alias=AliasChoices("base_url", "api_base"),
48
+ serialization_alias="api_base")
49
+ model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
50
+ serialization_alias="model",
51
+ description="The LiteLlm hosted model name.")
52
+ seed: int | None = Field(default=None, description="Random seed to set for generation.")
53
+
54
+
55
@register_llm_provider(config_type=LiteLlmModelConfig)
async def litellm_model(
    config: LiteLlmModelConfig,
    _builder: Builder,
) -> AsyncIterator[LLMProviderInfo]:
    """Litellm model provider.

    Args:
        config (LiteLlmModelConfig): The LiteLlm model configuration.
        _builder (Builder): The NAT builder instance (unused here).

    Yields:
        LLMProviderInfo: Provider info wrapping the config for an LLM client.
    """
    yield LLMProviderInfo(config=config, description="A LiteLlm model for use with an LLM client.")
nat/llm/nim_llm.py CHANGED
@@ -22,25 +22,37 @@ from nat.builder.builder import Builder
22
22
  from nat.builder.llm import LLMProviderInfo
23
23
  from nat.cli.register_workflow import register_llm_provider
24
24
  from nat.data_models.llm import LLMBaseConfig
25
+ from nat.data_models.optimizable import OptimizableField
26
+ from nat.data_models.optimizable import OptimizableMixin
27
+ from nat.data_models.optimizable import SearchSpace
25
28
  from nat.data_models.retry_mixin import RetryMixin
26
-
27
-
28
- class NIMModelConfig(LLMBaseConfig, RetryMixin, name="nim"):
29
+ from nat.data_models.temperature_mixin import TemperatureMixin
30
+ from nat.data_models.thinking_mixin import ThinkingMixin
31
+ from nat.data_models.top_p_mixin import TopPMixin
32
+
33
+
34
class NIMModelConfig(LLMBaseConfig,
                     RetryMixin,
                     OptimizableMixin,
                     TemperatureMixin,
                     TopPMixin,
                     ThinkingMixin,
                     name="nim"):
    """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""

    # Allow extra fields so provider-specific kwargs pass through; disable pydantic's
    # protected "model_" namespace so the `model_name` field is accepted.
    model_config = ConfigDict(protected_namespaces=(), extra="allow")

    # Credential for the hosted NIM; None defers to environment-based auth.
    api_key: str | None = Field(default=None, description="NVIDIA API key to interact with hosted NIM.")
    # Endpoint for a self-hosted NIM; None uses the provider default.
    base_url: str | None = Field(default=None, description="Base url to the hosted NIM.")
    # Accepts "model_name" or "model" on input; serializes as "model".
    model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                            serialization_alias="model",
                            description="The model name for the hosted NIM.")
    # OptimizableField: max_tokens is tunable over the given search space by the optimizer.
    max_tokens: PositiveInt = OptimizableField(default=300,
                                               description="Maximum number of tokens to generate.",
                                               space=SearchSpace(high=2176, low=128, step=512))
41
53
 
42
54
 
43
55
@register_llm_provider(config_type=NIMModelConfig)
async def nim_model(llm_config: NIMModelConfig, _builder: Builder):
    """Yield provider info for a NIM-hosted model.

    Args:
        llm_config (NIMModelConfig): The NIM model configuration.
        _builder (Builder): The NAT builder instance (unused here).

    Yields:
        LLMProviderInfo: Provider info wrapping the config for an LLM client.
    """
    yield LLMProviderInfo(config=llm_config, description="A NIM model for use with an LLM client.")
nat/llm/openai_llm.py CHANGED
@@ -21,10 +21,20 @@ from nat.builder.builder import Builder
21
21
  from nat.builder.llm import LLMProviderInfo
22
22
  from nat.cli.register_workflow import register_llm_provider
23
23
  from nat.data_models.llm import LLMBaseConfig
24
+ from nat.data_models.optimizable import OptimizableMixin
24
25
  from nat.data_models.retry_mixin import RetryMixin
25
-
26
-
27
- class OpenAIModelConfig(LLMBaseConfig, RetryMixin, name="openai"):
26
+ from nat.data_models.temperature_mixin import TemperatureMixin
27
+ from nat.data_models.thinking_mixin import ThinkingMixin
28
+ from nat.data_models.top_p_mixin import TopPMixin
29
+
30
+
31
+ class OpenAIModelConfig(LLMBaseConfig,
32
+ RetryMixin,
33
+ OptimizableMixin,
34
+ TemperatureMixin,
35
+ TopPMixin,
36
+ ThinkingMixin,
37
+ name="openai"):
28
38
  """An OpenAI LLM provider to be used with an LLM client."""
29
39
 
30
40
  model_config = ConfigDict(protected_namespaces=(), extra="allow")
@@ -34,13 +44,11 @@ class OpenAIModelConfig(LLMBaseConfig, RetryMixin, name="openai"):
34
44
  model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
35
45
  serialization_alias="model",
36
46
  description="The OpenAI hosted model name.")
37
- temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
38
- top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
39
47
  seed: int | None = Field(default=None, description="Random seed to set for generation.")
40
48
  max_retries: int = Field(default=10, description="The max number of retries for the request.")
41
49
 
42
50
 
43
51
@register_llm_provider(config_type=OpenAIModelConfig)
async def openai_llm(config: OpenAIModelConfig, _builder: Builder):
    """Yield provider info for an OpenAI-hosted model.

    Args:
        config (OpenAIModelConfig): The OpenAI model configuration.
        _builder (Builder): The NAT builder instance (unused here).

    Yields:
        LLMProviderInfo: Provider info wrapping the config for an LLM client.
    """
    yield LLMProviderInfo(config=config, description="An OpenAI model for use with an LLM client.")
nat/llm/register.py CHANGED
@@ -13,12 +13,15 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
  # isort:skip_file
18
+ """Register LLM providers via import side effects.
19
19
 
20
+ This module is imported by the NeMo Agent Toolkit runtime to ensure providers are registered and discoverable.
21
+ """
20
22
  # Import any providers which need to be automatically registered here
21
23
  from . import aws_bedrock_llm
22
24
  from . import azure_openai_llm
25
+ from . import litellm_llm
23
26
  from . import nim_llm
24
27
  from . import openai_llm
@@ -72,9 +72,8 @@ class EnvConfigValue(ABC):
72
72
  f"{message} Try passing a value to the constructor, or setting the `{self.__class__._ENV_KEY}` "
73
73
  "environment variable.")
74
74
 
75
- else:
76
- if not self.__class__._ALLOW_NONE and value is None:
77
- raise ValueError("value must not be none")
75
+ elif not self.__class__._ALLOW_NONE and value is None:
76
+ raise ValueError("value must not be none")
78
77
 
79
78
  assert isinstance(value, str) or value is None
80
79