nvidia-nat 1.2.1rc1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +27 -18
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +81 -50
  7. nat/agent/react_agent/register.py +59 -40
  8. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +327 -149
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +64 -46
  13. nat/agent/tool_calling_agent/agent.py +152 -29
  14. nat/agent/tool_calling_agent/register.py +61 -38
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +10 -6
  24. nat/builder/context.py +70 -18
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/intermediate_step_manager.py +6 -2
  32. nat/builder/user_interaction_manager.py +2 -2
  33. nat/builder/workflow.py +13 -1
  34. nat/builder/workflow_builder.py +327 -79
  35. nat/cli/cli_utils/config_override.py +2 -2
  36. nat/cli/commands/evaluate.py +1 -1
  37. nat/cli/commands/info/info.py +16 -6
  38. nat/cli/commands/info/list_channels.py +1 -1
  39. nat/cli/commands/info/list_components.py +7 -8
  40. nat/cli/commands/mcp/__init__.py +14 -0
  41. nat/cli/commands/mcp/mcp.py +986 -0
  42. nat/cli/commands/object_store/__init__.py +14 -0
  43. nat/cli/commands/object_store/object_store.py +227 -0
  44. nat/cli/commands/optimize.py +90 -0
  45. nat/cli/commands/registry/publish.py +2 -2
  46. nat/cli/commands/registry/pull.py +2 -2
  47. nat/cli/commands/registry/remove.py +2 -2
  48. nat/cli/commands/registry/search.py +15 -17
  49. nat/cli/commands/start.py +16 -5
  50. nat/cli/commands/uninstall.py +1 -1
  51. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  52. nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
  53. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  54. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  55. nat/cli/commands/workflow/workflow_commands.py +105 -19
  56. nat/cli/entrypoint.py +17 -11
  57. nat/cli/main.py +3 -0
  58. nat/cli/register_workflow.py +38 -4
  59. nat/cli/type_registry.py +79 -10
  60. nat/control_flow/__init__.py +0 -0
  61. nat/control_flow/register.py +20 -0
  62. nat/control_flow/router_agent/__init__.py +0 -0
  63. nat/control_flow/router_agent/agent.py +329 -0
  64. nat/control_flow/router_agent/prompt.py +48 -0
  65. nat/control_flow/router_agent/register.py +91 -0
  66. nat/control_flow/sequential_executor.py +166 -0
  67. nat/data_models/agent.py +34 -0
  68. nat/data_models/api_server.py +196 -67
  69. nat/data_models/authentication.py +23 -9
  70. nat/data_models/common.py +1 -1
  71. nat/data_models/component.py +2 -0
  72. nat/data_models/component_ref.py +11 -0
  73. nat/data_models/config.py +42 -18
  74. nat/data_models/dataset_handler.py +1 -1
  75. nat/data_models/discovery_metadata.py +4 -4
  76. nat/data_models/evaluate.py +4 -1
  77. nat/data_models/function.py +34 -0
  78. nat/data_models/function_dependencies.py +14 -6
  79. nat/data_models/gated_field_mixin.py +242 -0
  80. nat/data_models/intermediate_step.py +3 -3
  81. nat/data_models/optimizable.py +119 -0
  82. nat/data_models/optimizer.py +149 -0
  83. nat/data_models/span.py +41 -3
  84. nat/data_models/swe_bench_model.py +1 -1
  85. nat/data_models/temperature_mixin.py +44 -0
  86. nat/data_models/thinking_mixin.py +86 -0
  87. nat/data_models/top_p_mixin.py +44 -0
  88. nat/embedder/azure_openai_embedder.py +46 -0
  89. nat/embedder/nim_embedder.py +1 -1
  90. nat/embedder/openai_embedder.py +2 -3
  91. nat/embedder/register.py +1 -1
  92. nat/eval/config.py +3 -1
  93. nat/eval/dataset_handler/dataset_handler.py +71 -7
  94. nat/eval/evaluate.py +86 -31
  95. nat/eval/evaluator/base_evaluator.py +1 -1
  96. nat/eval/evaluator/evaluator_model.py +13 -0
  97. nat/eval/intermediate_step_adapter.py +1 -1
  98. nat/eval/rag_evaluator/evaluate.py +9 -6
  99. nat/eval/rag_evaluator/register.py +3 -3
  100. nat/eval/register.py +4 -1
  101. nat/eval/remote_workflow.py +3 -3
  102. nat/eval/runtime_evaluator/__init__.py +14 -0
  103. nat/eval/runtime_evaluator/evaluate.py +123 -0
  104. nat/eval/runtime_evaluator/register.py +100 -0
  105. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  106. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  107. nat/eval/trajectory_evaluator/register.py +1 -1
  108. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  109. nat/eval/utils/eval_trace_ctx.py +89 -0
  110. nat/eval/utils/weave_eval.py +18 -9
  111. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  112. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  113. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  114. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  115. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  116. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  117. nat/experimental/test_time_compute/register.py +0 -1
  118. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  119. nat/front_ends/console/authentication_flow_handler.py +82 -30
  120. nat/front_ends/console/console_front_end_plugin.py +19 -7
  121. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  122. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  123. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  124. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  125. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  126. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  127. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
  128. nat/front_ends/fastapi/job_store.py +518 -99
  129. nat/front_ends/fastapi/main.py +11 -19
  130. nat/front_ends/fastapi/message_handler.py +74 -50
  131. nat/front_ends/fastapi/message_validator.py +20 -21
  132. nat/front_ends/fastapi/response_helpers.py +4 -4
  133. nat/front_ends/fastapi/step_adaptor.py +2 -2
  134. nat/front_ends/fastapi/utils.py +57 -0
  135. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  136. nat/front_ends/mcp/mcp_front_end_config.py +47 -3
  137. nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
  138. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
  139. nat/front_ends/mcp/tool_converter.py +44 -14
  140. nat/front_ends/register.py +0 -1
  141. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  142. nat/llm/aws_bedrock_llm.py +24 -12
  143. nat/llm/azure_openai_llm.py +57 -0
  144. nat/llm/litellm_llm.py +69 -0
  145. nat/llm/nim_llm.py +20 -8
  146. nat/llm/openai_llm.py +14 -6
  147. nat/llm/register.py +5 -1
  148. nat/llm/utils/env_config_value.py +2 -3
  149. nat/llm/utils/thinking.py +215 -0
  150. nat/meta/pypi.md +9 -9
  151. nat/object_store/register.py +0 -1
  152. nat/observability/exporter/base_exporter.py +3 -3
  153. nat/observability/exporter/file_exporter.py +1 -1
  154. nat/observability/exporter/processing_exporter.py +309 -81
  155. nat/observability/exporter/span_exporter.py +35 -15
  156. nat/observability/exporter_manager.py +7 -7
  157. nat/observability/mixin/file_mixin.py +7 -7
  158. nat/observability/mixin/redaction_config_mixin.py +42 -0
  159. nat/observability/mixin/tagging_config_mixin.py +62 -0
  160. nat/observability/mixin/type_introspection_mixin.py +420 -107
  161. nat/observability/processor/batching_processor.py +5 -7
  162. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  163. nat/observability/processor/processor.py +3 -0
  164. nat/observability/processor/processor_factory.py +70 -0
  165. nat/observability/processor/redaction/__init__.py +24 -0
  166. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  167. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  168. nat/observability/processor/redaction/redaction_processor.py +177 -0
  169. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  170. nat/observability/processor/span_tagging_processor.py +68 -0
  171. nat/observability/register.py +22 -4
  172. nat/profiler/calc/calc_runner.py +3 -4
  173. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  174. nat/profiler/callbacks/langchain_callback_handler.py +14 -7
  175. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  176. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  177. nat/profiler/data_frame_row.py +1 -1
  178. nat/profiler/decorators/framework_wrapper.py +62 -13
  179. nat/profiler/decorators/function_tracking.py +160 -3
  180. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  181. nat/profiler/forecasting/models/linear_model.py +1 -1
  182. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  183. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  184. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  185. nat/profiler/inference_optimization/data_models.py +3 -3
  186. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  187. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  188. nat/profiler/parameter_optimization/__init__.py +0 -0
  189. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  190. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  191. nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
  192. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  193. nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
  194. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  195. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  196. nat/profiler/profile_runner.py +14 -9
  197. nat/profiler/utils.py +4 -2
  198. nat/registry_handlers/local/local_handler.py +2 -2
  199. nat/registry_handlers/package_utils.py +1 -2
  200. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  201. nat/registry_handlers/register.py +3 -4
  202. nat/registry_handlers/rest/rest_handler.py +12 -13
  203. nat/retriever/milvus/retriever.py +2 -2
  204. nat/retriever/nemo_retriever/retriever.py +1 -1
  205. nat/retriever/register.py +0 -1
  206. nat/runtime/loader.py +2 -2
  207. nat/runtime/runner.py +105 -8
  208. nat/runtime/session.py +69 -8
  209. nat/settings/global_settings.py +16 -5
  210. nat/tool/chat_completion.py +5 -2
  211. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  212. nat/tool/datetime_tools.py +49 -9
  213. nat/tool/document_search.py +2 -2
  214. nat/tool/github_tools.py +450 -0
  215. nat/tool/memory_tools/add_memory_tool.py +3 -3
  216. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  217. nat/tool/memory_tools/get_memory_tool.py +4 -4
  218. nat/tool/nvidia_rag.py +1 -1
  219. nat/tool/register.py +2 -9
  220. nat/tool/retriever.py +3 -2
  221. nat/utils/callable_utils.py +70 -0
  222. nat/utils/data_models/schema_validator.py +3 -3
  223. nat/utils/decorators.py +210 -0
  224. nat/utils/exception_handlers/automatic_retries.py +104 -51
  225. nat/utils/exception_handlers/schemas.py +1 -1
  226. nat/utils/io/yaml_tools.py +2 -2
  227. nat/utils/log_levels.py +25 -0
  228. nat/utils/reactive/base/observable_base.py +2 -2
  229. nat/utils/reactive/base/observer_base.py +1 -1
  230. nat/utils/reactive/observable.py +2 -2
  231. nat/utils/reactive/observer.py +4 -4
  232. nat/utils/reactive/subscription.py +1 -1
  233. nat/utils/settings/global_settings.py +6 -8
  234. nat/utils/type_converter.py +12 -3
  235. nat/utils/type_utils.py +9 -5
  236. nvidia_nat-1.3.0.dist-info/METADATA +195 -0
  237. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
  238. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
  239. nat/cli/commands/info/list_mcp.py +0 -304
  240. nat/tool/github_tools/create_github_commit.py +0 -133
  241. nat/tool/github_tools/create_github_issue.py +0 -87
  242. nat/tool/github_tools/create_github_pr.py +0 -106
  243. nat/tool/github_tools/get_github_file.py +0 -106
  244. nat/tool/github_tools/get_github_issue.py +0 -166
  245. nat/tool/github_tools/get_github_pr.py +0 -256
  246. nat/tool/github_tools/update_github_issue.py +0 -100
  247. nat/tool/mcp/exceptions.py +0 -142
  248. nat/tool/mcp/mcp_client.py +0 -255
  249. nat/tool/mcp/mcp_tool.py +0 -96
  250. nat/utils/exception_handlers/mcp.py +0 -211
  251. nvidia_nat-1.2.1rc1.dist-info/METADATA +0 -365
  252. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  253. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  254. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
  255. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  256. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
  257. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
nat/front_ends/mcp/mcp_front_end_config.py

@@ -13,15 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+from typing import Literal
+
 from pydantic import Field
+from pydantic import model_validator

+from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
 from nat.data_models.front_end import FrontEndBaseConfig

+logger = logging.getLogger(__name__)
+

 class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
     """MCP front end configuration.

-    A simple MCP (Modular Communication Protocol) front end for NeMo Agent toolkit.
+    A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
     """

     name: str = Field(default="NeMo Agent Toolkit MCP",
@@ -30,7 +37,44 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
30
37
  port: int = Field(default=9901, description="Port to bind the server to (default: 9901)", ge=0, le=65535)
31
38
  debug: bool = Field(default=False, description="Enable debug mode (default: False)")
32
39
  log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
33
- tool_names: list[str] = Field(default_factory=list,
34
- description="The list of tools MCP server will expose (default: all tools)")
40
+ tool_names: list[str] = Field(
41
+ default_factory=list,
42
+ description="The list of tools MCP server will expose (default: all tools)."
43
+ "Tool names can be functions or function groups",
44
+ )
45
+ transport: Literal["sse", "streamable-http"] = Field(
46
+ default="streamable-http",
47
+ description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
35
48
  runner_class: str | None = Field(
36
49
  default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
50
+
51
+ server_auth: OAuth2ResourceServerConfig | None = Field(
52
+ default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
53
+
54
+ @model_validator(mode="after")
55
+ def validate_security_configuration(self):
56
+ """Validate security configuration to prevent accidental misconfigurations."""
57
+ # Check if server is bound to a non-localhost interface without authentication
58
+ localhost_hosts = {"localhost", "127.0.0.1", "::1"}
59
+ if self.host not in localhost_hosts and self.server_auth is None:
60
+ logger.warning(
61
+ "MCP server is configured to bind to '%s' without authentication. "
62
+ "This may expose your server to unauthorized access. "
63
+ "Consider either: (1) binding to localhost for local-only access, "
64
+ "or (2) configuring server_auth for production deployments on public interfaces.",
65
+ self.host)
66
+
67
+ # Check if SSE transport is used (which doesn't support authentication)
68
+ if self.transport == "sse":
69
+ if self.server_auth is not None:
70
+ logger.warning("SSE transport does not support authentication. "
71
+ "The configured server_auth will be ignored. "
72
+ "For production use with authentication, use 'streamable-http' transport instead.")
73
+ elif self.host not in localhost_hosts:
74
+ logger.warning(
75
+ "SSE transport does not support authentication and is bound to '%s'. "
76
+ "This configuration is not recommended for production use. "
77
+ "For production deployments, use 'streamable-http' transport with server_auth configured.",
78
+ self.host)
79
+
80
+ return self
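For reference, a minimal sketch (not part of the diff) of how the new fields might be exercised. It assumes only the fields visible in the hunks above (host, port, transport, server_auth) and that host accepts a plain string, as the validator suggests:

    from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig

    # Binding to a non-localhost interface without server_auth triggers the
    # warning emitted by validate_security_configuration at construction time.
    config = MCPFrontEndConfig(host="0.0.0.0", port=9901, transport="streamable-http")
    assert config.transport == "streamable-http"
    assert config.server_auth is None  # the security warning was logged above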
nat/front_ends/mcp/mcp_front_end_plugin.py

@@ -16,6 +16,7 @@
 import logging
 import typing

+from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
 from nat.builder.front_end import FrontEndBase
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
@@ -55,27 +56,61 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):

         return worker_class(self.full_config)

+    async def _create_token_verifier(self, token_verifier_config: OAuth2ResourceServerConfig):
+        """Create a token verifier based on configuration."""
+        from nat.front_ends.mcp.introspection_token_verifier import IntrospectionTokenVerifier
+
+        if not self.front_end_config.server_auth:
+            return None
+
+        return IntrospectionTokenVerifier(token_verifier_config)
+
     async def run(self) -> None:
         """Run the MCP server."""
         # Import FastMCP
         from mcp.server.fastmcp import FastMCP

-        # Create an MCP server with the configured parameters
-        mcp = FastMCP(
-            self.front_end_config.name,
-            host=self.front_end_config.host,
-            port=self.front_end_config.port,
-            debug=self.front_end_config.debug,
-            log_level=self.front_end_config.log_level,
-        )
-
-        # Get the worker instance and set up routes
-        worker = self._get_worker_instance()
+        # Create auth settings and token verifier if auth is required
+        auth_settings = None
+        token_verifier = None

         # Build the workflow and add routes using the worker
         async with WorkflowBuilder.from_config(config=self.full_config) as builder:
+
+            if self.front_end_config.server_auth:
+                from mcp.server.auth.settings import AuthSettings
+                from pydantic import AnyHttpUrl
+
+                server_url = f"http://{self.front_end_config.host}:{self.front_end_config.port}"
+
+                auth_settings = AuthSettings(issuer_url=AnyHttpUrl(self.front_end_config.server_auth.issuer_url),
+                                             required_scopes=self.front_end_config.server_auth.scopes,
+                                             resource_server_url=AnyHttpUrl(server_url))
+
+                token_verifier = await self._create_token_verifier(self.front_end_config.server_auth)
+
+            # Create an MCP server with the configured parameters
+            mcp = FastMCP(name=self.front_end_config.name,
+                          host=self.front_end_config.host,
+                          port=self.front_end_config.port,
+                          debug=self.front_end_config.debug,
+                          auth=auth_settings,
+                          token_verifier=token_verifier)
+
+            # Get the worker instance and set up routes
+            worker = self._get_worker_instance()
+
             # Add routes through the worker (includes health endpoint and function registration)
             await worker.add_routes(mcp, builder)

-            # Start the MCP server
-            await mcp.run_sse_async()
+            # Start the MCP server with configurable transport
+            # streamable-http is the default, but users can choose sse if preferred
+            try:
+                if self.front_end_config.transport == "sse":
+                    logger.info("Starting MCP server with SSE endpoint at /sse")
+                    await mcp.run_sse_async()
+                else:  # streamable-http
+                    logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
+                    await mcp.run_streamable_http_async()
+            except KeyboardInterrupt:
+                logger.info("MCP server shutdown requested (Ctrl+C). Shutting down gracefully.")
nat/front_ends/mcp/mcp_front_end_plugin_worker.py

@@ -16,11 +16,15 @@
 import logging
 from abc import ABC
 from abc import abstractmethod
+from collections.abc import Mapping
+from typing import Any

 from mcp.server.fastmcp import FastMCP
+from starlette.exceptions import HTTPException
 from starlette.requests import Request

 from nat.builder.function import Function
+from nat.builder.function_base import FunctionBase
 from nat.builder.workflow import Workflow
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.config import Config
@@ -82,7 +86,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
         """
         pass

-    def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
+    async def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
         """Get all functions from the workflow.

         Args:
@@ -94,13 +98,114 @@
         functions: dict[str, Function] = {}

         # Extract all functions from the workflow
-        for function_name, function in workflow.functions.items():
-            functions[function_name] = function
+        functions.update(workflow.functions)
+        for function_group in workflow.function_groups.values():
+            functions.update(await function_group.get_accessible_functions())

-        functions[workflow.config.workflow.type] = workflow
+        if workflow.config.workflow.workflow_alias:
+            functions[workflow.config.workflow.workflow_alias] = workflow
+        else:
+            functions[workflow.config.workflow.type] = workflow

         return functions

+    def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
+        """Set up HTTP debug endpoints for introspecting tools and schemas.
+
+        Exposes:
+        - GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
+          selects a subset and returns details for those tools.
+        """
+
+        @mcp.custom_route("/debug/tools/list", methods=["GET"])
+        async def list_tools(request: Request):
+            """HTTP list tools endpoint."""
+
+            from starlette.responses import JSONResponse
+
+            from nat.front_ends.mcp.tool_converter import get_function_description
+
+            # Query params
+            # Support repeated names and comma-separated lists
+            names_param_list = set(request.query_params.getlist("name"))
+            names: list[str] = []
+            for raw in names_param_list:
+                # if p.strip() is empty, it won't be included in the list!
+                parts = [p.strip() for p in raw.split(",") if p.strip()]
+                names.extend(parts)
+            detail_raw = request.query_params.get("detail")
+
+            def _parse_detail_param(detail_param: str | None, has_names: bool) -> bool:
+                if detail_param is None:
+                    if has_names:
+                        return True
+                    return False
+                v = detail_param.strip().lower()
+                if v in ("0", "false", "no", "off"):
+                    return False
+                if v in ("1", "true", "yes", "on"):
+                    return True
+                # For invalid values, default based on whether names are present
+                return has_names
+
+            # Helper function to build the input schema info
+            def _build_schema_info(fn: FunctionBase) -> dict[str, Any] | None:
+                schema = getattr(fn, "input_schema", None)
+                if schema is None:
+                    return None
+
+                # check if schema is a ChatRequest
+                schema_name = getattr(schema, "__name__", "")
+                schema_qualname = getattr(schema, "__qualname__", "")
+                if "ChatRequest" in schema_name or "ChatRequest" in schema_qualname:
+                    # Simplified interface used by MCP wrapper for ChatRequest
+                    return {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string", "description": "User query string"
+                            }
+                        },
+                        "required": ["query"],
+                        "title": "ChatRequestQuery",
+                    }
+
+                # Pydantic models provide model_json_schema
+                if schema is not None and hasattr(schema, "model_json_schema"):
+                    return schema.model_json_schema()
+
+                return None
+
+            def _build_final_json(functions_to_include: Mapping[str, FunctionBase],
+                                  include_schemas: bool = False) -> dict[str, Any]:
+                tools = []
+                for name, fn in functions_to_include.items():
+                    list_entry: dict[str, Any] = {
+                        "name": name, "description": get_function_description(fn), "is_workflow": hasattr(fn, "run")
+                    }
+                    if include_schemas:
+                        list_entry["schema"] = _build_schema_info(fn)
+                    tools.append(list_entry)
+
+                return {
+                    "count": len(tools),
+                    "tools": tools,
+                    "server_name": mcp.name,
+                }
+
+            if names:
+                # Return selected tools
+                try:
+                    functions_to_include = {n: functions[n] for n in names}
+                except KeyError as e:
+                    raise HTTPException(status_code=404, detail=f"Tool \"{e.args[0]}\" not found.") from e
+            else:
+                functions_to_include = functions
+
+            # Default for listing all: detail defaults to False unless explicitly set true
+            return JSONResponse(
+                _build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))
+

 class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
     """Default MCP front end plugin worker implementation."""
@@ -118,10 +223,10 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
         self._setup_health_endpoint(mcp)

         # Build the workflow and register all functions with MCP
-        workflow = builder.build()
+        workflow = await builder.build()

         # Get all functions from the workflow
-        functions = self._get_all_functions(workflow)
+        functions = await self._get_all_functions(workflow)

         # Filter functions based on tool_names if provided
         if self.front_end_config.tool_names:
@@ -129,15 +234,22 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
            filtered_functions: dict[str, Function] = {}
            for function_name, function in functions.items():
                if function_name in self.front_end_config.tool_names:
+                    # Treat current tool_names as function names, so check if the function name is in the list
+                    filtered_functions[function_name] = function
+                elif any(function_name.startswith(f"{group_name}.") for group_name in self.front_end_config.tool_names):
+                    # Treat tool_names as function group names, so check if the function name starts with the group name
                    filtered_functions[function_name] = function
                else:
                    logger.debug("Skipping function %s as it's not in tool_names", function_name)
            functions = filtered_functions

-        # Register each function with MCP
+        # Register each function with MCP, passing workflow context for observability
         for function_name, function in functions.items():
-            register_function_with_mcp(mcp, function_name, function)
+            register_function_with_mcp(mcp, function_name, function, workflow)

         # Add a simple fallback function if no functions were found
         if not functions:
             raise RuntimeError("No functions found in workflow. Please check your configuration.")
+
+        # After registration, expose debug endpoints for tool/schema inspection
+        self._setup_debug_endpoints(mcp, functions)
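The new debug route can be exercised with any HTTP client once the server is up. A standard-library sketch, assuming the default host/port from MCPFrontEndConfig and using only the query parameters shown above (name, detail); the tool names are placeholders:

    import json
    import urllib.request

    # List all tools (detail defaults to False when no names are given).
    with urllib.request.urlopen("http://localhost:9901/debug/tools/list") as resp:
        print(json.load(resp)["count"])

    # Select specific tools; detail defaults to True when names are present.
    url = "http://localhost:9901/debug/tools/list?name=tool_a,tool_b&detail=true"
    with urllib.request.urlopen(url) as resp:
        for tool in json.load(resp)["tools"]:
            print(tool["name"], tool.get("schema"))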
nat/front_ends/mcp/tool_converter.py

@@ -17,13 +17,17 @@ import json
 import logging
 from inspect import Parameter
 from inspect import Signature
+from typing import TYPE_CHECKING

 from mcp.server.fastmcp import FastMCP
 from pydantic import BaseModel

+from nat.builder.context import ContextState
 from nat.builder.function import Function
 from nat.builder.function_base import FunctionBase
-from nat.builder.workflow import Workflow
+
+if TYPE_CHECKING:
+    from nat.builder.workflow import Workflow

 logger = logging.getLogger(__name__)

@@ -33,14 +37,16 @@ def create_function_wrapper(
     function: FunctionBase,
     schema: type[BaseModel],
     is_workflow: bool = False,
+    workflow: 'Workflow | None' = None,
 ):
     """Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.

     Args:
-        function_name: The name of the function/tool
-        function: The NAT Function object
-        schema: The input schema of the function
-        is_workflow: Whether the function is a Workflow
+        function_name (str): The name of the function/tool
+        function (FunctionBase): The NAT Function object
+        schema (type[BaseModel]): The input schema of the function
+        is_workflow (bool): Whether the function is a Workflow
+        workflow (Workflow | None): The parent workflow for observability context

     Returns:
         A wrapper function suitable for registration with MCP
@@ -101,6 +107,19 @@
                await ctx.report_progress(0, 100)

            try:
+                # Helper function to wrap function calls with observability
+                async def call_with_observability(func_call):
+                    # Use workflow's observability context (workflow should always be available)
+                    if not workflow:
+                        logger.error("Missing workflow context for function %s - observability will not be available",
+                                     function_name)
+                        raise RuntimeError("Workflow context is required for observability")
+
+                    logger.debug("Starting observability context for function %s", function_name)
+                    context_state = ContextState.get()
+                    async with workflow.exporter_manager.start(context_state=context_state):
+                        return await func_call()
+
                # Special handling for ChatRequest
                if is_chat_request:
                    from nat.data_models.api_server import ChatRequest
@@ -118,7 +137,7 @@
                        result = await runner.result(to_type=str)
                    else:
                        # Regular functions use ainvoke
-                        result = await function.ainvoke(chat_request, to_type=str)
+                        result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
                else:
                    # Regular handling
                    # Handle complex input schema - if we extracted fields from a nested schema,
@@ -129,7 +148,7 @@
                        field_type = schema.model_fields[field_name].annotation

                        # If it's a pydantic model, we need to create an instance
-                        if hasattr(field_type, "model_validate"):
+                        if field_type and hasattr(field_type, "model_validate"):
                            # Create the nested object
                            nested_obj = field_type.model_validate(kwargs)
                            # Call with the nested object
@@ -147,7 +166,7 @@
                        result = await runner.result(to_type=str)
                    else:
                        # Regular function call
-                        result = await function.acall_invoke(**kwargs)
+                        result = await call_with_observability(lambda: function.acall_invoke(**kwargs))

                # Report completion
                if ctx:
@@ -156,7 +175,7 @@
                # Handle different result types for proper formatting
                if isinstance(result, str):
                    return result
-                if isinstance(result, (dict, list)):
+                if isinstance(result, dict | list):
                    return json.dumps(result, default=str)
                return str(result)
            except Exception as e:
@@ -170,7 +189,7 @@
    wrapper = create_wrapper()

    # Set the signature on the wrapper function (WITHOUT ctx)
-    wrapper.__signature__ = sig
+    wrapper.__signature__ = sig  # type: ignore
    wrapper.__name__ = function_name

    # Return the wrapper with proper signature
@@ -183,8 +202,8 @@ def get_function_description(function: FunctionBase) -> str:

    The description is determined using the following precedence:
    1. If the function is a Workflow and has a 'description' attribute, use it.
-    2. If the Workflow's config has a 'topic', use it.
-    3. If the Workflow's config has a 'description', use it.
+    2. If the Workflow's config has a 'description', use it.
+    3. If the Workflow's config has a 'topic', use it.
    4. If the function is a regular Function, use its 'description' attribute.

    Args:
@@ -195,6 +214,9 @@ def get_function_description(function: FunctionBase) -> str:
    """
    function_description = ""

+    # Import here to avoid circular imports
+    from nat.builder.workflow import Workflow
+
    if isinstance(function, Workflow):
        config = function.config

@@ -207,6 +229,9 @@ def get_function_description(function: FunctionBase) -> str:
        # Try to get anything that might be a description
        elif hasattr(config, "topic") and config.topic:
            function_description = config.topic
+        # Try to get description from the workflow config
+        elif hasattr(config, "workflow") and hasattr(config.workflow, "description") and config.workflow.description:
+            function_description = config.workflow.description

    elif isinstance(function, Function):
        function_description = function.description
@@ -214,13 +239,17 @@
    return function_description


-def register_function_with_mcp(mcp: FastMCP, function_name: str, function: FunctionBase) -> None:
+def register_function_with_mcp(mcp: FastMCP,
+                               function_name: str,
+                               function: FunctionBase,
+                               workflow: 'Workflow | None' = None) -> None:
    """Register a NAT Function as an MCP tool.

    Args:
        mcp: The FastMCP instance
        function_name: The name to register the function under
        function: The NAT Function to register
+        workflow: The parent workflow for observability context (if available)
    """
    logger.info("Registering function %s with MCP", function_name)

@@ -229,6 +258,7 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
    logger.info("Function %s has input schema: %s", function_name, input_schema)

    # Check if we're dealing with a Workflow
+    from nat.builder.workflow import Workflow
    is_workflow = isinstance(function, Workflow)
    if is_workflow:
        logger.info("Function %s is a Workflow", function_name)
@@ -237,5 +267,5 @@ def register_function_with_mcp(mcp: FastMCP, function_name: str, function: Funct
    function_description = get_function_description(function)

    # Create and register the wrapper function with MCP
-    wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow)
+    wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
    mcp.tool(name=function_name, description=function_description)(wrapper_func)
nat/front_ends/register.py

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file

nat/front_ends/simple_base/simple_front_end_plugin_base.py

@@ -35,6 +35,8 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):

     async def run(self):

+        await self.pre_run()
+
         # Must yield the workflow function otherwise it cleans up
         async with WorkflowBuilder.from_config(config=self.full_config) as builder:

@@ -45,7 +47,7 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):

             click.echo(stream.getvalue())

-            workflow = builder.build()
+            workflow = await builder.build()
             session_manager = SessionManager(workflow)
             await self.run_workflow(session_manager)

nat/llm/aws_bedrock_llm.py

@@ -21,27 +21,39 @@ from nat.builder.builder import Builder
 from nat.builder.llm import LLMProviderInfo
 from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
+from nat.data_models.optimizable import OptimizableField
+from nat.data_models.optimizable import OptimizableMixin
+from nat.data_models.optimizable import SearchSpace
 from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.thinking_mixin import ThinkingMixin
+from nat.data_models.top_p_mixin import TopPMixin


-class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):
+class AWSBedrockModelConfig(LLMBaseConfig,
+                            RetryMixin,
+                            OptimizableMixin,
+                            TemperatureMixin,
+                            TopPMixin,
+                            ThinkingMixin,
+                            name="aws_bedrock"):
     """An AWS Bedrock llm provider to be used with an LLM client."""

-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")

     # Completion parameters
     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                             serialization_alias="model",
                             description="The model name for the hosted AWS Bedrock.")
-    temperature: float = Field(default=0.0, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
-    max_tokens: int | None = Field(default=1024,
-                                   gt=0,
-                                   description="Maximum number of tokens to generate."
-                                   "This field is ONLY required when using AWS Bedrock with Langchain.")
-    context_size: int | None = Field(default=1024,
-                                     gt=0,
-                                     description="Maximum number of tokens to generate."
-                                     "This field is ONLY required when using AWS Bedrock with LlamaIndex.")
+    max_tokens: int = OptimizableField(default=300,
+                                       description="Maximum number of tokens to generate.",
+                                       space=SearchSpace(high=2176, low=128, step=512))
+    context_size: int | None = Field(
+        default=1024,
+        gt=0,
+        description="The maximum number of tokens available for input. This is only required for LlamaIndex. "
+        "This field is ignored for LangChain/LangGraph.",
+    )

     # Client parameters
     region_name: str | None = Field(default="None", description="AWS region to use.")
@@ -52,6 +64,6 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, name="aws_bedrock"):


 @register_llm_provider(config_type=AWSBedrockModelConfig)
-async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, builder: Builder):
+async def aws_bedrock_model(llm_config: AWSBedrockModelConfig, _builder: Builder):

     yield LLMProviderInfo(config=llm_config, description="A AWS Bedrock model for use with an LLM client.")
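A hedged sketch of constructing the revised config; it uses only the fields visible in the hunks above (the model id is illustrative), assumes the new mixin fields all have defaults, and leaves the temperature/top_p/thinking settings contributed by the mixins untouched:

    from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig

    # max_tokens is now an OptimizableField (default 300) with a SearchSpace of
    # 128-2176 in steps of 512, so the parameter optimizer can tune it.
    config = AWSBedrockModelConfig(model_name="anthropic.claude-3-haiku", max_tokens=1024)
    print(config.max_tokens)    # 1024
    print(config.context_size)  # 1024 (only consumed by LlamaIndex)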
nat/llm/azure_openai_llm.py (new file)

@@ -0,0 +1,57 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import AliasChoices
+from pydantic import ConfigDict
+from pydantic import Field
+
+from nat.builder.builder import Builder
+from nat.builder.llm import LLMProviderInfo
+from nat.cli.register_workflow import register_llm_provider
+from nat.data_models.llm import LLMBaseConfig
+from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.thinking_mixin import ThinkingMixin
+from nat.data_models.top_p_mixin import TopPMixin
+
+
+class AzureOpenAIModelConfig(
+        LLMBaseConfig,
+        RetryMixin,
+        TemperatureMixin,
+        TopPMixin,
+        ThinkingMixin,
+        name="azure_openai",
+):
+    """An Azure OpenAI LLM provider to be used with an LLM client."""
+
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")
+
+    api_key: str | None = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
+    api_version: str = Field(default="2025-04-01-preview", description="Azure OpenAI API version.")
+    azure_endpoint: str | None = Field(validation_alias=AliasChoices("azure_endpoint", "base_url"),
+                                       serialization_alias="azure_endpoint",
+                                       default=None,
+                                       description="Base URL for the hosted model.")
+    azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
+                                  serialization_alias="azure_deployment",
+                                  description="The Azure OpenAI hosted model/deployment name.")
+    seed: int | None = Field(default=None, description="Random seed to set for generation.")
+
+
+@register_llm_provider(config_type=AzureOpenAIModelConfig)
+async def azure_openai_llm(config: AzureOpenAIModelConfig, _builder: Builder):
+
+    yield LLMProviderInfo(config=config, description="An Azure OpenAI model for use with an LLM client.")
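A short sketch of the alias behaviour defined above (field values are placeholders): model_name or model can be passed in place of azure_deployment, and base_url in place of azure_endpoint:

    from nat.llm.azure_openai_llm import AzureOpenAIModelConfig

    config = AzureOpenAIModelConfig(model_name="my-gpt-4o-deployment",
                                    base_url="https://example.openai.azure.com",
                                    api_key="<azure-openai-key>")
    print(config.azure_deployment)  # "my-gpt-4o-deployment"
    print(config.azure_endpoint)    # "https://example.openai.azure.com"
    print(config.api_version)       # "2025-04-01-preview"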