nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -13,19 +13,24 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import importlib
17
16
  import logging
18
17
  import os
18
+ import typing
19
19
 
20
20
  from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
21
+ from nat.front_ends.fastapi.utils import get_config_file_path
22
+ from nat.front_ends.fastapi.utils import import_class_from_string
21
23
  from nat.runtime.loader import load_config
22
24
 
25
+ if typing.TYPE_CHECKING:
26
+ from fastapi import FastAPI
27
+
23
28
  logger = logging.getLogger(__name__)
24
29
 
25
30
 
26
- def get_app():
31
+ def get_app() -> "FastAPI":
27
32
 
28
- config_file_path = os.getenv("NAT_CONFIG_FILE")
33
+ config_file_path = get_config_file_path()
29
34
  front_end_worker_full_name = os.getenv("NAT_FRONT_END_WORKER")
30
35
 
31
36
  if (not config_file_path):
@@ -36,28 +41,15 @@ def get_app():
36
41
 
37
42
  # Try to import the front end worker class
38
43
  try:
39
- # Split the package from the class
40
- front_end_worker_parts = front_end_worker_full_name.split(".")
41
-
42
- front_end_worker_module_name = ".".join(front_end_worker_parts[:-1])
43
- front_end_worker_class_name = front_end_worker_parts[-1]
44
-
45
- front_end_worker_module = importlib.import_module(front_end_worker_module_name)
46
-
47
- if not hasattr(front_end_worker_module, front_end_worker_class_name):
48
- raise ValueError(f"Front end worker {front_end_worker_full_name} not found.")
49
-
50
- front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = getattr(front_end_worker_module,
51
- front_end_worker_class_name)
44
+ front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = import_class_from_string(
45
+ front_end_worker_full_name)
52
46
 
53
47
  if (not issubclass(front_end_worker_class, FastApiFrontEndPluginWorkerBase)):
54
48
  raise ValueError(
55
49
  f"Front end worker {front_end_worker_full_name} is not a subclass of FastApiFrontEndPluginWorker.")
56
50
 
57
51
  # Load the config
58
- abs_config_file_path = os.path.abspath(config_file_path)
59
-
60
- config = load_config(abs_config_file_path)
52
+ config = load_config(config_file_path)
61
53
 
62
54
  # Create an instance of the front end worker class
63
55
  front_end_worker = front_end_worker_class(config)
@@ -25,6 +25,7 @@ from pydantic import ValidationError
25
25
  from starlette.websockets import WebSocketDisconnect
26
26
 
27
27
  from nat.authentication.interfaces import FlowHandlerBase
28
+ from nat.data_models.api_server import ChatRequest
28
29
  from nat.data_models.api_server import ChatResponse
29
30
  from nat.data_models.api_server import ChatResponseChunk
30
31
  from nat.data_models.api_server import Error
@@ -33,6 +34,8 @@ from nat.data_models.api_server import ResponsePayloadOutput
33
34
  from nat.data_models.api_server import ResponseSerializable
34
35
  from nat.data_models.api_server import SystemResponseContent
35
36
  from nat.data_models.api_server import TextContent
37
+ from nat.data_models.api_server import UserMessageContentRoleType
38
+ from nat.data_models.api_server import UserMessages
36
39
  from nat.data_models.api_server import WebSocketMessageStatus
37
40
  from nat.data_models.api_server import WebSocketMessageType
38
41
  from nat.data_models.api_server import WebSocketSystemInteractionMessage
@@ -64,12 +67,12 @@ class WebSocketMessageHandler:
64
67
  self._running_workflow_task: asyncio.Task | None = None
65
68
  self._message_parent_id: str = "default_id"
66
69
  self._conversation_id: str | None = None
67
- self._workflow_schema_type: str = None
68
- self._user_interaction_response: asyncio.Future[HumanResponse] | None = None
70
+ self._workflow_schema_type: str | None = None
71
+ self._user_interaction_response: asyncio.Future[TextContent] | None = None
69
72
 
70
73
  self._flow_handler: FlowHandlerBase | None = None
71
74
 
72
- self._schema_output_mapping: dict[str, type[BaseModel] | None] = {
75
+ self._schema_output_mapping: dict[str, type[BaseModel] | type[None]] = {
73
76
  WorkflowSchemaType.GENERATE: self._session_manager.workflow.single_output_schema,
74
77
  WorkflowSchemaType.CHAT: ChatResponse,
75
78
  WorkflowSchemaType.CHAT_STREAM: ChatResponseChunk,
@@ -105,45 +108,67 @@ class WebSocketMessageHandler:
105
108
  if (isinstance(validated_message, WebSocketUserMessage)):
106
109
  await self.process_workflow_request(validated_message)
107
110
 
108
- elif isinstance(validated_message,
109
- (WebSocketSystemResponseTokenMessage,
110
- WebSocketSystemIntermediateStepMessage,
111
- WebSocketSystemInteractionMessage)):
111
+ elif isinstance(
112
+ validated_message,
113
+ WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
114
+ | WebSocketSystemInteractionMessage):
112
115
  # These messages are already handled by self.create_websocket_message(data_model=value, …)
113
116
  # No further processing is needed here.
114
117
  pass
115
118
 
116
119
  elif (isinstance(validated_message, WebSocketUserInteractionResponseMessage)):
117
- user_content = await self.process_user_message_content(validated_message)
120
+ user_content = await self._process_websocket_user_interaction_response_message(validated_message)
121
+ assert self._user_interaction_response is not None
118
122
  self._user_interaction_response.set_result(user_content)
119
123
  except (asyncio.CancelledError, WebSocketDisconnect):
120
124
  # TODO: Handle the disconnect
121
125
  break
122
126
 
123
- async def process_user_message_content(
124
- self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
127
+ def _extract_last_user_message_content(self, messages: list[UserMessages]) -> TextContent:
125
128
  """
126
- Processes the contents of a user message.
129
+ Extracts the last user's TextContent from a list of messages.
127
130
 
128
- :param user_content: Incoming content data model.
129
- :return: A validated Pydantic user content model or None if not found.
130
- """
131
+ Args:
132
+ messages: List of UserMessages.
131
133
 
132
- for user_message in user_content.content.messages[::-1]:
133
- if (user_message.role == "user"):
134
+ Returns:
135
+ TextContent object from the last user message.
134
136
 
137
+ Raises:
138
+ ValueError: If no user text content is found.
139
+ """
140
+ for user_message in messages[::-1]:
141
+ if user_message.role == UserMessageContentRoleType.USER:
135
142
  for attachment in user_message.content:
136
-
137
143
  if isinstance(attachment, TextContent):
138
144
  return attachment
145
+ raise ValueError("No user text content found in messages.")
146
+
147
+ async def _process_websocket_user_interaction_response_message(
148
+ self, user_content: WebSocketUserInteractionResponseMessage) -> TextContent:
149
+ """
150
+ Processes a WebSocketUserInteractionResponseMessage.
151
+ """
152
+ return self._extract_last_user_message_content(user_content.content.messages)
139
153
 
140
- return None
154
+ async def _process_websocket_user_message(self, user_content: WebSocketUserMessage) -> ChatRequest | str:
155
+ """
156
+ Processes a WebSocketUserMessage based on schema type.
157
+ """
158
+ if self._workflow_schema_type in [WorkflowSchemaType.CHAT, WorkflowSchemaType.CHAT_STREAM]:
159
+ return ChatRequest(**user_content.content.model_dump(include={"messages"}))
160
+
161
+ elif self._workflow_schema_type in [WorkflowSchemaType.GENERATE, WorkflowSchemaType.GENERATE_STREAM]:
162
+ return self._extract_last_user_message_content(user_content.content.messages).text
163
+
164
+ raise ValueError("Unsupported workflow schema type for WebSocketUserMessage")
141
165
 
142
166
  async def process_workflow_request(self, user_message_as_validated_type: WebSocketUserMessage) -> None:
143
167
  """
144
168
  Process user messages and routes them appropriately.
145
169
 
146
- :param user_message_as_validated_type: A WebSocketUserMessage Data Model instance.
170
+ Args:
171
+ user_message_as_validated_type (WebSocketUserMessage): The validated user message to process.
147
172
  """
148
173
 
149
174
  try:
@@ -151,18 +176,15 @@ class WebSocketMessageHandler:
151
176
  self._workflow_schema_type = user_message_as_validated_type.schema_type
152
177
  self._conversation_id = user_message_as_validated_type.conversation_id
153
178
 
154
- content: BaseModel | None = await self.process_user_message_content(user_message_as_validated_type)
155
-
156
- if content is None:
157
- raise ValueError(f"User message content could not be found: {user_message_as_validated_type}")
179
+ message_content: typing.Any = await self._process_websocket_user_message(user_message_as_validated_type)
158
180
 
159
- if isinstance(content, TextContent) and (self._running_workflow_task is None):
181
+ if (self._running_workflow_task is None):
160
182
 
161
- def _done_callback(task: asyncio.Task):
183
+ def _done_callback(_task: asyncio.Task):
162
184
  self._running_workflow_task = None
163
185
 
164
186
  self._running_workflow_task = asyncio.create_task(
165
- self._run_workflow(payload=content.text,
187
+ self._run_workflow(payload=message_content,
166
188
  user_message_id=self._message_parent_id,
167
189
  conversation_id=self._conversation_id,
168
190
  result_type=self._schema_output_mapping[self._workflow_schema_type],
@@ -180,13 +202,14 @@ class WebSocketMessageHandler:
180
202
  async def create_websocket_message(self,
181
203
  data_model: BaseModel,
182
204
  message_type: str | None = None,
183
- status: str = WebSocketMessageStatus.IN_PROGRESS) -> None:
205
+ status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS) -> None:
184
206
  """
185
207
  Creates a websocket message that will be ready for routing based on message type or data model.
186
208
 
187
- :param data_model: Message content model.
188
- :param message_type: Message content model.
189
- :param status: Message content model.
209
+ Args:
210
+ data_model (BaseModel): Message content model.
211
+ message_type (str | None): Message content model.
212
+ status (WebSocketMessageStatus): Message content model.
190
213
  """
191
214
  try:
192
215
  message: BaseModel | None = None
@@ -196,8 +219,8 @@ class WebSocketMessageHandler:
196
219
 
197
220
  message_schema: type[BaseModel] = await self._message_validator.get_message_schema_by_type(message_type)
198
221
 
199
- if 'id' in data_model.model_fields:
200
- message_id: str = data_model.id
222
+ if hasattr(data_model, 'id'):
223
+ message_id: str = str(getattr(data_model, 'id'))
201
224
  else:
202
225
  message_id = str(uuid.uuid4())
203
226
 
@@ -253,12 +276,15 @@ class WebSocketMessageHandler:
253
276
  Registered human interaction callback that processes human interactions and returns
254
277
  responses from websocket connection.
255
278
 
256
- :param prompt: Incoming interaction content data model.
257
- :return: A Text Content Base Pydantic model.
279
+ Args:
280
+ prompt: Incoming interaction content data model.
281
+
282
+ Returns:
283
+ A Text Content Base Pydantic model.
258
284
  """
259
285
 
260
286
  # First create a future from the loop for the human response
261
- human_response_future: asyncio.Future[HumanResponse] = asyncio.get_running_loop().create_future()
287
+ human_response_future: asyncio.Future[TextContent] = asyncio.get_running_loop().create_future()
262
288
 
263
289
  # Then add the future to the outstanding human prompts dictionary
264
290
  self._user_interaction_response = human_response_future
@@ -274,10 +300,10 @@ class WebSocketMessageHandler:
274
300
  return HumanResponseNotification()
275
301
 
276
302
  # Wait for the human response future to complete
277
- interaction_response: HumanResponse = await human_response_future
303
+ text_content: TextContent = await human_response_future
278
304
 
279
305
  interaction_response: HumanResponse = await self._message_validator.convert_text_content_to_human_response(
280
- interaction_response, prompt.content)
306
+ text_content, prompt.content)
281
307
 
282
308
  return interaction_response
283
309
 
@@ -293,13 +319,12 @@ class WebSocketMessageHandler:
293
319
  output_type: type | None = None) -> None:
294
320
 
295
321
  try:
296
- async with self._session_manager.session(
297
- user_message_id=user_message_id,
298
- conversation_id=conversation_id,
299
- http_connection=self._socket,
300
- user_input_callback=self.human_interaction_callback,
301
- user_authentication_callback=(self._flow_handler.authenticate
302
- if self._flow_handler else None)) as session:
322
+ auth_callback = self._flow_handler.authenticate if self._flow_handler else None
323
+ async with self._session_manager.session(user_message_id=user_message_id,
324
+ conversation_id=conversation_id,
325
+ http_connection=self._socket,
326
+ user_input_callback=self.human_interaction_callback,
327
+ user_authentication_callback=auth_callback) as session:
303
328
 
304
329
  async for value in generate_streaming_response(payload,
305
330
  session_manager=session,
@@ -139,8 +139,10 @@ class MessageValidator:
139
139
  text_content: str = str(data_model.payload)
140
140
  validated_message_content = SystemResponseContent(text=text_content)
141
141
 
142
- elif (isinstance(data_model, (ChatResponse, ChatResponseChunk))):
142
+ elif isinstance(data_model, ChatResponse):
143
143
  validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
144
+ elif isinstance(data_model, ChatResponseChunk):
145
+ validated_message_content = SystemResponseContent(text=data_model.choices[0].delta.content)
144
146
 
145
147
  elif (isinstance(data_model, ResponseIntermediateStep)):
146
148
  validated_message_content = SystemIntermediateStepContent(name=data_model.name,
@@ -204,7 +206,7 @@ class MessageValidator:
204
206
 
205
207
  validated_message_type: str = ""
206
208
  try:
207
- if (isinstance(data_model, (ResponsePayloadOutput, ChatResponse, ChatResponseChunk))):
209
+ if (isinstance(data_model, ResponsePayloadOutput | ChatResponse | ChatResponseChunk)):
208
210
  validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE
209
211
 
210
212
  elif (isinstance(data_model, ResponseIntermediateStep)):
@@ -238,10 +240,9 @@ class MessageValidator:
238
240
  thread_id: str = "default",
239
241
  parent_id: str = "default",
240
242
  conversation_id: str | None = None,
241
- content: SystemResponseContent
242
- | Error = SystemResponseContent(),
243
+ content: SystemResponseContent | Error = SystemResponseContent(),
243
244
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
244
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
245
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
245
246
  ) -> WebSocketSystemResponseTokenMessage | None:
246
247
  """
247
248
  Creates a system response token message with default values.
@@ -280,7 +281,7 @@ class MessageValidator:
280
281
  conversation_id: str | None = None,
281
282
  content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
282
283
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
283
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
284
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
284
285
  ) -> WebSocketSystemIntermediateStepMessage | None:
285
286
  """
286
287
  Creates a system intermediate step message with default values.
@@ -320,7 +321,7 @@ class MessageValidator:
320
321
  conversation_id: str | None = None,
321
322
  content: HumanPrompt,
322
323
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
323
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
324
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
324
325
  ) -> WebSocketSystemInteractionMessage | None:
325
326
  """
326
327
  Creates a system interaction message with default values.
@@ -0,0 +1,57 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import os
18
+
19
+
20
+ def get_config_file_path() -> str:
21
+ """
22
+ Get the path to the NAT configuration file from the environment variable NAT_CONFIG_FILE.
23
+ Raises ValueError if the environment variable is not set.
24
+ """
25
+ config_file_path = os.getenv("NAT_CONFIG_FILE")
26
+ if (not config_file_path):
27
+ raise ValueError("Config file not found in environment variable NAT_CONFIG_FILE.")
28
+
29
+ return os.path.abspath(config_file_path)
30
+
31
+
32
+ def import_class_from_string(class_full_name: str) -> type:
33
+ """
34
+ Import a class from a string in the format 'module.submodule.ClassName'.
35
+ Raises ImportError if the class cannot be imported.
36
+ """
37
+ try:
38
+ class_name_parts = class_full_name.split(".")
39
+
40
+ module_name = ".".join(class_name_parts[:-1])
41
+ class_name = class_name_parts[-1]
42
+
43
+ module = importlib.import_module(module_name)
44
+
45
+ if not hasattr(module, class_name):
46
+ raise ValueError(f"Class '{class_full_name}' not found.")
47
+
48
+ return getattr(module, class_name)
49
+ except (ImportError, AttributeError) as e:
50
+ raise ImportError(f"Could not import {class_full_name}.") from e
51
+
52
+
53
+ def get_class_name(cls: type) -> str:
54
+ """
55
+ Get the full class name including the module.
56
+ """
57
+ return f"{cls.__module__}.{cls.__qualname__}"
@@ -0,0 +1,73 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
16
+
17
+ import logging
18
+
19
+ from mcp.server.auth.provider import AccessToken
20
+ from mcp.server.auth.provider import TokenVerifier
21
+
22
+ from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
23
+ from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class IntrospectionTokenVerifier(TokenVerifier):
    """TokenVerifier implementation that hands every check to a BearerTokenValidator."""

    def __init__(self, config: OAuth2ResourceServerConfig):
        """Build the underlying BearerTokenValidator from a resource-server config.

        Args:
            config: OAuth2ResourceServerConfig supplying the issuer URL, scopes, audience,
                JWKS URI, introspection endpoint, discovery URL and client credentials.
        """
        # A missing/empty scopes value is passed on as an empty list.
        self._bearer_token_validator = BearerTokenValidator(
            issuer=config.issuer_url,
            scopes=config.scopes or [],
            audience=config.audience,
            jwks_uri=config.jwks_uri,
            introspection_endpoint=config.introspection_endpoint,
            discovery_url=config.discovery_url,
            client_id=config.client_id,
            client_secret=config.client_secret,
        )

    async def verify_token(self, token: str) -> AccessToken | None:
        """Check a Bearer token through the wrapped BearerTokenValidator.

        Args:
            token: The Bearer token to verify

        Returns:
            AccessToken | None: An AccessToken when the validation result is active,
            otherwise None
        """
        outcome = await self._bearer_token_validator.verify(token)

        if not outcome.active:
            return None

        # Falsy scopes / client_id are replaced with an empty list / empty string.
        granted_scopes = outcome.scopes or []
        owner_client_id = outcome.client_id or ""
        return AccessToken(token=token,
                           expires_at=outcome.expires_at,
                           scopes=granted_scopes,
                           client_id=owner_client_id)
@@ -13,17 +13,23 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import logging
16
17
  from typing import Literal
17
18
 
18
19
  from pydantic import Field
20
+ from pydantic import field_validator
21
+ from pydantic import model_validator
19
22
 
23
+ from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
20
24
  from nat.data_models.front_end import FrontEndBaseConfig
21
25
 
26
+ logger = logging.getLogger(__name__)
27
+
22
28
 
23
29
  class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
24
30
  """MCP front end configuration.
25
31
 
26
- A simple MCP (Modular Communication Protocol) front end for NeMo Agent toolkit.
32
+ A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
27
33
  """
28
34
 
29
35
  name: str = Field(default="NeMo Agent Toolkit MCP",
@@ -32,10 +38,72 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
32
38
  port: int = Field(default=9901, description="Port to bind the server to (default: 9901)", ge=0, le=65535)
33
39
  debug: bool = Field(default=False, description="Enable debug mode (default: False)")
34
40
  log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
35
- tool_names: list[str] = Field(default_factory=list,
36
- description="The list of tools MCP server will expose (default: all tools)")
41
+ tool_names: list[str] = Field(
42
+ default_factory=list,
43
+ description="The list of tools MCP server will expose (default: all tools)."
44
+ "Tool names can be functions or function groups",
45
+ )
37
46
  transport: Literal["sse", "streamable-http"] = Field(
38
47
  default="streamable-http",
39
48
  description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
40
49
  runner_class: str | None = Field(
41
50
  default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
51
+ base_path: str | None = Field(default=None,
52
+ description="Base path to mount the MCP server at (e.g., '/api/v1'). "
53
+ "If specified, the server will be accessible at http://host:port{base_path}/mcp. "
54
+ "If None, server runs at root path /mcp.")
55
+
56
+ server_auth: OAuth2ResourceServerConfig | None = Field(
57
+ default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
58
+
59
+ @field_validator('base_path')
60
+ @classmethod
61
+ def validate_base_path(cls, v: str | None) -> str | None:
62
+ """Validate that base_path starts with '/' and doesn't end with '/'."""
63
+ if v is not None:
64
+ if not v.startswith('/'):
65
+ raise ValueError("base_path must start with '/'")
66
+ if v.endswith('/'):
67
+ raise ValueError("base_path must not end with '/'")
68
+ return v
69
+
70
+ # Memory profiling configuration
71
+ enable_memory_profiling: bool = Field(default=False,
72
+ description="Enable memory profiling and diagnostics (default: False)")
73
+ memory_profile_interval: int = Field(default=50,
74
+ description="Log memory stats every N requests (default: 50)",
75
+ ge=1)
76
+ memory_profile_top_n: int = Field(default=10,
77
+ description="Number of top memory allocations to log (default: 10)",
78
+ ge=1,
79
+ le=50)
80
+ memory_profile_log_level: str = Field(default="DEBUG",
81
+ description="Log level for memory profiling output (default: DEBUG)")
82
+
83
+ @model_validator(mode="after")
84
+ def validate_security_configuration(self):
85
+ """Validate security configuration to prevent accidental misconfigurations."""
86
+ # Check if server is bound to a non-localhost interface without authentication
87
+ localhost_hosts = {"localhost", "127.0.0.1", "::1"}
88
+ if self.host not in localhost_hosts and self.server_auth is None:
89
+ logger.warning(
90
+ "MCP server is configured to bind to '%s' without authentication. "
91
+ "This may expose your server to unauthorized access. "
92
+ "Consider either: (1) binding to localhost for local-only access, "
93
+ "or (2) configuring server_auth for production deployments on public interfaces.",
94
+ self.host)
95
+
96
+ # Check if SSE transport is used (which doesn't support authentication)
97
+ if self.transport == "sse":
98
+ if self.server_auth is not None:
99
+ logger.warning("SSE transport does not support authentication. "
100
+ "The configured server_auth will be ignored. "
101
+ "For production use with authentication, use 'streamable-http' transport instead.")
102
+ elif self.host not in localhost_hosts:
103
+ logger.warning(
104
+ "SSE transport does not support authentication and is bound to '%s'. "
105
+ "This configuration is not recommended for production use. "
106
+ "For production deployments, use 'streamable-http' transport with server_auth configured.",
107
+ self.host)
108
+
109
+ return self