nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +41 -21
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +46 -26
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +46 -11
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  53. nat/cli/commands/workflow/workflow_commands.py +9 -13
  54. nat/cli/entrypoint.py +8 -10
  55. nat/cli/register_workflow.py +38 -4
  56. nat/cli/type_registry.py +75 -6
  57. nat/control_flow/__init__.py +0 -0
  58. nat/control_flow/register.py +20 -0
  59. nat/control_flow/router_agent/__init__.py +0 -0
  60. nat/control_flow/router_agent/agent.py +329 -0
  61. nat/control_flow/router_agent/prompt.py +48 -0
  62. nat/control_flow/router_agent/register.py +91 -0
  63. nat/control_flow/sequential_executor.py +166 -0
  64. nat/data_models/agent.py +34 -0
  65. nat/data_models/api_server.py +10 -10
  66. nat/data_models/authentication.py +23 -9
  67. nat/data_models/common.py +1 -1
  68. nat/data_models/component.py +2 -0
  69. nat/data_models/component_ref.py +11 -0
  70. nat/data_models/config.py +41 -17
  71. nat/data_models/dataset_handler.py +1 -1
  72. nat/data_models/discovery_metadata.py +4 -4
  73. nat/data_models/evaluate.py +4 -1
  74. nat/data_models/function.py +34 -0
  75. nat/data_models/function_dependencies.py +14 -6
  76. nat/data_models/gated_field_mixin.py +242 -0
  77. nat/data_models/intermediate_step.py +3 -3
  78. nat/data_models/optimizable.py +119 -0
  79. nat/data_models/optimizer.py +149 -0
  80. nat/data_models/swe_bench_model.py +1 -1
  81. nat/data_models/temperature_mixin.py +44 -0
  82. nat/data_models/thinking_mixin.py +86 -0
  83. nat/data_models/top_p_mixin.py +44 -0
  84. nat/embedder/nim_embedder.py +1 -1
  85. nat/embedder/openai_embedder.py +1 -1
  86. nat/embedder/register.py +0 -1
  87. nat/eval/config.py +3 -1
  88. nat/eval/dataset_handler/dataset_handler.py +71 -7
  89. nat/eval/evaluate.py +86 -31
  90. nat/eval/evaluator/base_evaluator.py +1 -1
  91. nat/eval/evaluator/evaluator_model.py +13 -0
  92. nat/eval/intermediate_step_adapter.py +1 -1
  93. nat/eval/rag_evaluator/evaluate.py +2 -2
  94. nat/eval/rag_evaluator/register.py +3 -3
  95. nat/eval/register.py +4 -1
  96. nat/eval/remote_workflow.py +3 -3
  97. nat/eval/runtime_evaluator/__init__.py +14 -0
  98. nat/eval/runtime_evaluator/evaluate.py +123 -0
  99. nat/eval/runtime_evaluator/register.py +100 -0
  100. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  101. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  102. nat/eval/trajectory_evaluator/register.py +1 -1
  103. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  104. nat/eval/utils/eval_trace_ctx.py +89 -0
  105. nat/eval/utils/weave_eval.py +18 -9
  106. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  107. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  108. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  109. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  110. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  111. nat/experimental/test_time_compute/register.py +0 -1
  112. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  113. nat/front_ends/console/authentication_flow_handler.py +82 -30
  114. nat/front_ends/console/console_front_end_plugin.py +8 -5
  115. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  116. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  117. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  118. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  119. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  120. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  121. nat/front_ends/fastapi/job_store.py +518 -99
  122. nat/front_ends/fastapi/main.py +11 -19
  123. nat/front_ends/fastapi/message_handler.py +13 -14
  124. nat/front_ends/fastapi/message_validator.py +17 -19
  125. nat/front_ends/fastapi/response_helpers.py +4 -4
  126. nat/front_ends/fastapi/step_adaptor.py +2 -2
  127. nat/front_ends/fastapi/utils.py +57 -0
  128. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  129. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  130. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  131. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  132. nat/front_ends/mcp/tool_converter.py +44 -14
  133. nat/front_ends/register.py +0 -1
  134. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  135. nat/llm/aws_bedrock_llm.py +24 -12
  136. nat/llm/azure_openai_llm.py +13 -6
  137. nat/llm/litellm_llm.py +69 -0
  138. nat/llm/nim_llm.py +20 -8
  139. nat/llm/openai_llm.py +14 -6
  140. nat/llm/register.py +4 -1
  141. nat/llm/utils/env_config_value.py +2 -3
  142. nat/llm/utils/thinking.py +215 -0
  143. nat/meta/pypi.md +9 -9
  144. nat/object_store/register.py +0 -1
  145. nat/observability/exporter/base_exporter.py +3 -3
  146. nat/observability/exporter/file_exporter.py +1 -1
  147. nat/observability/exporter/processing_exporter.py +309 -81
  148. nat/observability/exporter/span_exporter.py +1 -1
  149. nat/observability/exporter_manager.py +7 -7
  150. nat/observability/mixin/file_mixin.py +7 -7
  151. nat/observability/mixin/redaction_config_mixin.py +42 -0
  152. nat/observability/mixin/tagging_config_mixin.py +62 -0
  153. nat/observability/mixin/type_introspection_mixin.py +420 -107
  154. nat/observability/processor/batching_processor.py +5 -7
  155. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  156. nat/observability/processor/processor.py +3 -0
  157. nat/observability/processor/processor_factory.py +70 -0
  158. nat/observability/processor/redaction/__init__.py +24 -0
  159. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  160. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  161. nat/observability/processor/redaction/redaction_processor.py +177 -0
  162. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  163. nat/observability/processor/span_tagging_processor.py +68 -0
  164. nat/observability/register.py +6 -4
  165. nat/profiler/calc/calc_runner.py +3 -4
  166. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  167. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  168. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  169. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  170. nat/profiler/data_frame_row.py +1 -1
  171. nat/profiler/decorators/framework_wrapper.py +62 -13
  172. nat/profiler/decorators/function_tracking.py +160 -3
  173. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  174. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  175. nat/profiler/inference_optimization/data_models.py +3 -3
  176. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  177. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  178. nat/profiler/parameter_optimization/__init__.py +0 -0
  179. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  180. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  181. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  182. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  183. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  184. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  185. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  186. nat/profiler/profile_runner.py +14 -9
  187. nat/profiler/utils.py +4 -2
  188. nat/registry_handlers/local/local_handler.py +2 -2
  189. nat/registry_handlers/package_utils.py +1 -2
  190. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  191. nat/registry_handlers/register.py +3 -4
  192. nat/registry_handlers/rest/rest_handler.py +12 -13
  193. nat/retriever/milvus/retriever.py +2 -2
  194. nat/retriever/nemo_retriever/retriever.py +1 -1
  195. nat/retriever/register.py +0 -1
  196. nat/runtime/loader.py +2 -2
  197. nat/runtime/runner.py +3 -2
  198. nat/runtime/session.py +43 -8
  199. nat/settings/global_settings.py +16 -5
  200. nat/tool/chat_completion.py +5 -2
  201. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  202. nat/tool/datetime_tools.py +49 -9
  203. nat/tool/document_search.py +2 -2
  204. nat/tool/github_tools.py +450 -0
  205. nat/tool/nvidia_rag.py +1 -1
  206. nat/tool/register.py +2 -9
  207. nat/tool/retriever.py +3 -2
  208. nat/utils/callable_utils.py +70 -0
  209. nat/utils/data_models/schema_validator.py +3 -3
  210. nat/utils/exception_handlers/automatic_retries.py +104 -51
  211. nat/utils/exception_handlers/schemas.py +1 -1
  212. nat/utils/io/yaml_tools.py +2 -2
  213. nat/utils/log_levels.py +25 -0
  214. nat/utils/reactive/base/observable_base.py +2 -2
  215. nat/utils/reactive/base/observer_base.py +1 -1
  216. nat/utils/reactive/observable.py +2 -2
  217. nat/utils/reactive/observer.py +4 -4
  218. nat/utils/reactive/subscription.py +1 -1
  219. nat/utils/settings/global_settings.py +6 -8
  220. nat/utils/type_converter.py +4 -3
  221. nat/utils/type_utils.py +9 -5
  222. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
  223. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
  224. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  225. nat/cli/commands/info/list_mcp.py +0 -304
  226. nat/tool/github_tools/create_github_commit.py +0 -133
  227. nat/tool/github_tools/create_github_issue.py +0 -87
  228. nat/tool/github_tools/create_github_pr.py +0 -106
  229. nat/tool/github_tools/get_github_file.py +0 -106
  230. nat/tool/github_tools/get_github_issue.py +0 -166
  231. nat/tool/github_tools/get_github_pr.py +0 -256
  232. nat/tool/github_tools/update_github_issue.py +0 -100
  233. nat/tool/mcp/exceptions.py +0 -142
  234. nat/tool/mcp/mcp_client.py +0 -255
  235. nat/tool/mcp/mcp_tool.py +0 -96
  236. nat/utils/exception_handlers/mcp.py +0 -211
  237. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  238. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  239. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  240. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  241. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  242. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
@@ -13,19 +13,24 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import importlib
17
16
  import logging
18
17
  import os
18
+ import typing
19
19
 
20
20
  from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
21
+ from nat.front_ends.fastapi.utils import get_config_file_path
22
+ from nat.front_ends.fastapi.utils import import_class_from_string
21
23
  from nat.runtime.loader import load_config
22
24
 
25
+ if typing.TYPE_CHECKING:
26
+ from fastapi import FastAPI
27
+
23
28
  logger = logging.getLogger(__name__)
24
29
 
25
30
 
26
- def get_app():
31
+ def get_app() -> "FastAPI":
27
32
 
28
- config_file_path = os.getenv("NAT_CONFIG_FILE")
33
+ config_file_path = get_config_file_path()
29
34
  front_end_worker_full_name = os.getenv("NAT_FRONT_END_WORKER")
30
35
 
31
36
  if (not config_file_path):
@@ -36,28 +41,15 @@ def get_app():
36
41
 
37
42
  # Try to import the front end worker class
38
43
  try:
39
- # Split the package from the class
40
- front_end_worker_parts = front_end_worker_full_name.split(".")
41
-
42
- front_end_worker_module_name = ".".join(front_end_worker_parts[:-1])
43
- front_end_worker_class_name = front_end_worker_parts[-1]
44
-
45
- front_end_worker_module = importlib.import_module(front_end_worker_module_name)
46
-
47
- if not hasattr(front_end_worker_module, front_end_worker_class_name):
48
- raise ValueError(f"Front end worker {front_end_worker_full_name} not found.")
49
-
50
- front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = getattr(front_end_worker_module,
51
- front_end_worker_class_name)
44
+ front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = import_class_from_string(
45
+ front_end_worker_full_name)
52
46
 
53
47
  if (not issubclass(front_end_worker_class, FastApiFrontEndPluginWorkerBase)):
54
48
  raise ValueError(
55
49
  f"Front end worker {front_end_worker_full_name} is not a subclass of FastApiFrontEndPluginWorker.")
56
50
 
57
51
  # Load the config
58
- abs_config_file_path = os.path.abspath(config_file_path)
59
-
60
- config = load_config(abs_config_file_path)
52
+ config = load_config(config_file_path)
61
53
 
62
54
  # Create an instance of the front end worker class
63
55
  front_end_worker = front_end_worker_class(config)
@@ -86,7 +86,7 @@ class WebSocketMessageHandler:
86
86
 
87
87
  async def __aexit__(self, exc_type, exc_value, traceback) -> None:
88
88
 
89
- # TODO: Handle the exit # pylint: disable=fixme
89
+ # TODO: Handle the exit
90
90
  pass
91
91
 
92
92
  async def run(self) -> None:
@@ -107,10 +107,8 @@ class WebSocketMessageHandler:
107
107
 
108
108
  elif isinstance(
109
109
  validated_message,
110
- ( # noqa: E131
111
- WebSocketSystemResponseTokenMessage,
112
- WebSocketSystemIntermediateStepMessage,
113
- WebSocketSystemInteractionMessage)):
110
+ WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
111
+ | WebSocketSystemInteractionMessage):
114
112
  # These messages are already handled by self.create_websocket_message(data_model=value, …)
115
113
  # No further processing is needed here.
116
114
  pass
@@ -119,11 +117,9 @@ class WebSocketMessageHandler:
119
117
  user_content = await self.process_user_message_content(validated_message)
120
118
  self._user_interaction_response.set_result(user_content)
121
119
  except (asyncio.CancelledError, WebSocketDisconnect):
122
- # TODO: Handle the disconnect # pylint: disable=fixme
120
+ # TODO: Handle the disconnect
123
121
  break
124
122
 
125
- return None
126
-
127
123
  async def process_user_message_content(
128
124
  self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
129
125
  """
@@ -162,18 +158,19 @@ class WebSocketMessageHandler:
162
158
 
163
159
  if isinstance(content, TextContent) and (self._running_workflow_task is None):
164
160
 
165
- def _done_callback(task: asyncio.Task): # pylint: disable=unused-argument
161
+ def _done_callback(task: asyncio.Task):
166
162
  self._running_workflow_task = None
167
163
 
168
164
  self._running_workflow_task = asyncio.create_task(
169
- self._run_workflow(content.text,
170
- self._conversation_id,
165
+ self._run_workflow(payload=content.text,
166
+ user_message_id=self._message_parent_id,
167
+ conversation_id=self._conversation_id,
171
168
  result_type=self._schema_output_mapping[self._workflow_schema_type],
172
169
  output_type=self._schema_output_mapping[
173
170
  self._workflow_schema_type])).add_done_callback(_done_callback)
174
171
 
175
172
  except ValueError as e:
176
- logger.error("User message content not found: %s", str(e), exc_info=True)
173
+ logger.exception("User message content not found: %s", str(e))
177
174
  await self.create_websocket_message(data_model=Error(code=ErrorTypes.INVALID_USER_MESSAGE_CONTENT,
178
175
  message="User message content could not be found",
179
176
  details=str(e)),
@@ -241,7 +238,7 @@ class WebSocketMessageHandler:
241
238
  f"Message type could not be resolved by input data model: {data_model.model_dump_json()}")
242
239
 
243
240
  except (ValidationError, TypeError, ValueError) as e:
244
- logger.error("A data vaidation error ocurred creating websocket message: %s", str(e), exc_info=True)
241
+ logger.exception("A data vaidation error ocurred creating websocket message: %s", str(e))
245
242
  message = await self._message_validator.create_system_response_token_message(
246
243
  message_type=WebSocketMessageType.ERROR_MESSAGE,
247
244
  conversation_id=self._conversation_id,
@@ -290,14 +287,16 @@ class WebSocketMessageHandler:
290
287
 
291
288
  async def _run_workflow(self,
292
289
  payload: typing.Any,
290
+ user_message_id: str | None = None,
293
291
  conversation_id: str | None = None,
294
292
  result_type: type | None = None,
295
293
  output_type: type | None = None) -> None:
296
294
 
297
295
  try:
298
296
  async with self._session_manager.session(
297
+ user_message_id=user_message_id,
299
298
  conversation_id=conversation_id,
300
- request=self._socket,
299
+ http_connection=self._socket,
301
300
  user_input_callback=self.human_interaction_callback,
302
301
  user_authentication_callback=(self._flow_handler.authenticate
303
302
  if self._flow_handler else None)) as session:
@@ -97,7 +97,7 @@ class MessageValidator:
97
97
  return validated_message
98
98
 
99
99
  except (ValidationError, TypeError, ValueError) as e:
100
- logger.error("A data validation error %s occurred for message: %s", str(e), str(message), exc_info=True)
100
+ logger.exception("A data validation error %s occurred for message: %s", str(e), str(message))
101
101
  return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
102
102
  content=Error(code=ErrorTypes.INVALID_MESSAGE,
103
103
  message="Error validating message.",
@@ -119,7 +119,7 @@ class MessageValidator:
119
119
  return schema
120
120
 
121
121
  except (TypeError, ValueError) as e:
122
- logger.error("Error retrieving schema for message type '%s': %s", message_type, str(e), exc_info=True)
122
+ logger.exception("Error retrieving schema for message type '%s': %s", message_type, str(e))
123
123
  return Error
124
124
 
125
125
  async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
@@ -139,7 +139,7 @@ class MessageValidator:
139
139
  text_content: str = str(data_model.payload)
140
140
  validated_message_content = SystemResponseContent(text=text_content)
141
141
 
142
- elif (isinstance(data_model, (ChatResponse, ChatResponseChunk))):
142
+ elif (isinstance(data_model, ChatResponse | ChatResponseChunk)):
143
143
  validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
144
144
 
145
145
  elif (isinstance(data_model, ResponseIntermediateStep)):
@@ -156,7 +156,7 @@ class MessageValidator:
156
156
  return validated_message_content
157
157
 
158
158
  except ValueError as e:
159
- logger.error("Input data could not be converted to validated message content: %s", str(e), exc_info=True)
159
+ logger.exception("Input data could not be converted to validated message content: %s", str(e))
160
160
  return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))
161
161
 
162
162
  async def convert_text_content_to_human_response(self, text_content: TextContent,
@@ -191,7 +191,7 @@ class MessageValidator:
191
191
  return human_response
192
192
 
193
193
  except ValueError as e:
194
- logger.error("Error human response content not found: %s", str(e), exc_info=True)
194
+ logger.exception("Error human response content not found: %s", str(e))
195
195
  return HumanResponseText(text=str(e))
196
196
 
197
197
  async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
@@ -204,7 +204,7 @@ class MessageValidator:
204
204
 
205
205
  validated_message_type: str = ""
206
206
  try:
207
- if (isinstance(data_model, (ResponsePayloadOutput, ChatResponse, ChatResponseChunk))):
207
+ if (isinstance(data_model, ResponsePayloadOutput | ChatResponse | ChatResponseChunk)):
208
208
  validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE
209
209
 
210
210
  elif (isinstance(data_model, ResponseIntermediateStep)):
@@ -218,9 +218,7 @@ class MessageValidator:
218
218
  return validated_message_type
219
219
 
220
220
  except ValueError as e:
221
- logger.error("Error type not found converting data to validated websocket message content: %s",
222
- str(e),
223
- exc_info=True)
221
+ logger.exception("Error type not found converting data to validated websocket message content: %s", str(e))
224
222
  return WebSocketMessageType.ERROR_MESSAGE
225
223
 
226
224
  async def get_intermediate_step_parent_id(self, data_model: ResponseIntermediateStep) -> str:
@@ -232,7 +230,7 @@ class MessageValidator:
232
230
  """
233
231
  return data_model.parent_id or "root"
234
232
 
235
- async def create_system_response_token_message( # pylint: disable=R0917:too-many-positional-arguments
233
+ async def create_system_response_token_message(
236
234
  self,
237
235
  message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
238
236
  WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
@@ -243,7 +241,7 @@ class MessageValidator:
243
241
  content: SystemResponseContent
244
242
  | Error = SystemResponseContent(),
245
243
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
246
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
244
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
247
245
  ) -> WebSocketSystemResponseTokenMessage | None:
248
246
  """
249
247
  Creates a system response token message with default values.
@@ -269,10 +267,10 @@ class MessageValidator:
269
267
  timestamp=timestamp)
270
268
 
271
269
  except Exception as e:
272
- logger.error("Error creating system response token message: %s", str(e), exc_info=True)
270
+ logger.exception("Error creating system response token message: %s", str(e))
273
271
  return None
274
272
 
275
- async def create_system_intermediate_step_message( # pylint: disable=R0917:too-many-positional-arguments
273
+ async def create_system_intermediate_step_message(
276
274
  self,
277
275
  message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
278
276
  WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
@@ -282,7 +280,7 @@ class MessageValidator:
282
280
  conversation_id: str | None = None,
283
281
  content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
284
282
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
285
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
283
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
286
284
  ) -> WebSocketSystemIntermediateStepMessage | None:
287
285
  """
288
286
  Creates a system intermediate step message with default values.
@@ -308,10 +306,10 @@ class MessageValidator:
308
306
  timestamp=timestamp)
309
307
 
310
308
  except Exception as e:
311
- logger.error("Error creating system intermediate step message: %s", str(e), exc_info=True)
309
+ logger.exception("Error creating system intermediate step message: %s", str(e))
312
310
  return None
313
311
 
314
- async def create_system_interaction_message( # pylint: disable=R0917:too-many-positional-arguments
312
+ async def create_system_interaction_message(
315
313
  self,
316
314
  *,
317
315
  message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
@@ -322,8 +320,8 @@ class MessageValidator:
322
320
  conversation_id: str | None = None,
323
321
  content: HumanPrompt,
324
322
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
325
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
326
- ) -> WebSocketSystemInteractionMessage | None: # noqa: E125 continuation line with same indent as next logical line
323
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
324
+ ) -> WebSocketSystemInteractionMessage | None:
327
325
  """
328
326
  Creates a system interaction message with default values.
329
327
 
@@ -348,5 +346,5 @@ class MessageValidator:
348
346
  timestamp=timestamp)
349
347
 
350
348
  except Exception as e:
351
- logger.error("Error creating system interaction message: %s", str(e), exc_info=True)
349
+ logger.exception("Error creating system interaction message: %s", str(e))
352
350
  return None
@@ -98,9 +98,9 @@ async def generate_streaming_response(payload: typing.Any,
98
98
  yield item
99
99
  else:
100
100
  yield ResponsePayloadOutput(payload=item)
101
- except Exception as e:
101
+ except Exception:
102
102
  # Handle exceptions here
103
- raise e
103
+ raise
104
104
  finally:
105
105
  await q.close()
106
106
 
@@ -165,9 +165,9 @@ async def generate_streaming_response_full(payload: typing.Any,
165
165
  yield item
166
166
  else:
167
167
  yield ResponsePayloadOutput(payload=item)
168
- except Exception as e:
168
+ except Exception:
169
169
  # Handle exceptions here
170
- raise e
170
+ raise
171
171
  finally:
172
172
  await q.close()
173
173
 
@@ -289,7 +289,7 @@ class StepAdaptor:
289
289
 
290
290
  return event
291
291
 
292
- def process(self, step: IntermediateStep) -> ResponseSerializable | None: # pylint: disable=R1710
292
+ def process(self, step: IntermediateStep) -> ResponseSerializable | None:
293
293
 
294
294
  # Track the chunk
295
295
  self._history.append(step)
@@ -314,6 +314,6 @@ class StepAdaptor:
314
314
  return self._handle_custom(payload, ancestry)
315
315
 
316
316
  except Exception as e:
317
- logger.error("Error processing intermediate step: %s", e, exc_info=True)
317
+ logger.exception("Error processing intermediate step: %s", e)
318
318
 
319
319
  return None
@@ -0,0 +1,57 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import os
18
+
19
+
20
+ def get_config_file_path() -> str:
21
+ """
22
+ Get the path to the NAT configuration file from the environment variable NAT_CONFIG_FILE.
23
+ Raises ValueError if the environment variable is not set.
24
+ """
25
+ config_file_path = os.getenv("NAT_CONFIG_FILE")
26
+ if (not config_file_path):
27
+ raise ValueError("Config file not found in environment variable NAT_CONFIG_FILE.")
28
+
29
+ return os.path.abspath(config_file_path)
30
+
31
+
32
+ def import_class_from_string(class_full_name: str) -> type:
33
+ """
34
+ Import a class from a string in the format 'module.submodule.ClassName'.
35
+ Raises ImportError if the class cannot be imported.
36
+ """
37
+ try:
38
+ class_name_parts = class_full_name.split(".")
39
+
40
+ module_name = ".".join(class_name_parts[:-1])
41
+ class_name = class_name_parts[-1]
42
+
43
+ module = importlib.import_module(module_name)
44
+
45
+ if not hasattr(module, class_name):
46
+ raise ValueError(f"Class '{class_full_name}' not found.")
47
+
48
+ return getattr(module, class_name)
49
+ except (ImportError, AttributeError) as e:
50
+ raise ImportError(f"Could not import {class_full_name}.") from e
51
+
52
+
53
+ def get_class_name(cls: type) -> str:
54
+ """
55
+ Get the full class name including the module.
56
+ """
57
+ return f"{cls.__module__}.{cls.__qualname__}"
@@ -0,0 +1,73 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
16
+
17
+ import logging
18
+
19
+ from mcp.server.auth.provider import AccessToken
20
+ from mcp.server.auth.provider import TokenVerifier
21
+
22
+ from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
23
+ from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class IntrospectionTokenVerifier(TokenVerifier):
29
+ """Token verifier that delegates token verification to BearerTokenValidator."""
30
+
31
+ def __init__(self, config: OAuth2ResourceServerConfig):
32
+ """Create IntrospectionTokenVerifier from OAuth2ResourceServerConfig.
33
+
34
+ Args:
35
+ config: OAuth2ResourceServerConfig
36
+ """
37
+ issuer = config.issuer_url
38
+ scopes = config.scopes or []
39
+ audience = config.audience
40
+ jwks_uri = config.jwks_uri
41
+ introspection_endpoint = config.introspection_endpoint
42
+ discovery_url = config.discovery_url
43
+ client_id = config.client_id
44
+ client_secret = config.client_secret
45
+
46
+ self._bearer_token_validator = BearerTokenValidator(
47
+ issuer=issuer,
48
+ audience=audience,
49
+ scopes=scopes,
50
+ jwks_uri=jwks_uri,
51
+ introspection_endpoint=introspection_endpoint,
52
+ discovery_url=discovery_url,
53
+ client_id=client_id,
54
+ client_secret=client_secret,
55
+ )
56
+
57
+ async def verify_token(self, token: str) -> AccessToken | None:
58
+ """Verify token by delegating to BearerTokenValidator.
59
+
60
+ Args:
61
+ token: The Bearer token to verify
62
+
63
+ Returns:
64
+ AccessToken | None: AccessToken if valid, None if invalid
65
+ """
66
+ validation_result = await self._bearer_token_validator.verify(token)
67
+
68
+ if validation_result.active:
69
+ return AccessToken(token=token,
70
+ expires_at=validation_result.expires_at,
71
+ scopes=validation_result.scopes or [],
72
+ client_id=validation_result.client_id or "")
73
+ return None
@@ -13,15 +13,18 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from typing import Literal
17
+
16
18
  from pydantic import Field
17
19
 
20
+ from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
18
21
  from nat.data_models.front_end import FrontEndBaseConfig
19
22
 
20
23
 
21
24
  class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
22
25
  """MCP front end configuration.
23
26
 
24
- A simple MCP (Modular Communication Protocol) front end for NeMo Agent toolkit.
27
+ A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
25
28
  """
26
29
 
27
30
  name: str = Field(default="NeMo Agent Toolkit MCP",
@@ -32,5 +35,11 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
32
35
  log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
33
36
  tool_names: list[str] = Field(default_factory=list,
34
37
  description="The list of tools MCP server will expose (default: all tools)")
38
+ transport: Literal["sse", "streamable-http"] = Field(
39
+ default="streamable-http",
40
+ description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
35
41
  runner_class: str | None = Field(
36
42
  default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
43
+
44
+ server_auth: OAuth2ResourceServerConfig | None = Field(
45
+ default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
@@ -16,6 +16,7 @@
16
16
  import logging
17
17
  import typing
18
18
 
19
+ from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
19
20
  from nat.builder.front_end import FrontEndBase
20
21
  from nat.builder.workflow_builder import WorkflowBuilder
21
22
  from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
@@ -55,27 +56,58 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
55
56
 
56
57
  return worker_class(self.full_config)
57
58
 
59
+ async def _create_token_verifier(self, token_verifier_config: OAuth2ResourceServerConfig):
60
+ """Create a token verifier based on configuration."""
61
+ from nat.front_ends.mcp.introspection_token_verifier import IntrospectionTokenVerifier
62
+
63
+ if not self.front_end_config.server_auth:
64
+ return None
65
+
66
+ return IntrospectionTokenVerifier(token_verifier_config)
67
+
58
68
  async def run(self) -> None:
59
69
  """Run the MCP server."""
60
70
  # Import FastMCP
61
71
  from mcp.server.fastmcp import FastMCP
62
72
 
63
- # Create an MCP server with the configured parameters
64
- mcp = FastMCP(
65
- self.front_end_config.name,
66
- host=self.front_end_config.host,
67
- port=self.front_end_config.port,
68
- debug=self.front_end_config.debug,
69
- log_level=self.front_end_config.log_level,
70
- )
71
-
72
- # Get the worker instance and set up routes
73
- worker = self._get_worker_instance()
73
+ # Create auth settings and token verifier if auth is required
74
+ auth_settings = None
75
+ token_verifier = None
74
76
 
75
77
  # Build the workflow and add routes using the worker
76
78
  async with WorkflowBuilder.from_config(config=self.full_config) as builder:
79
+
80
+ if self.front_end_config.server_auth:
81
+ from mcp.server.auth.settings import AuthSettings
82
+ from pydantic import AnyHttpUrl
83
+
84
+ server_url = f"http://{self.front_end_config.host}:{self.front_end_config.port}"
85
+
86
+ auth_settings = AuthSettings(issuer_url=AnyHttpUrl(self.front_end_config.server_auth.issuer_url),
87
+ required_scopes=self.front_end_config.server_auth.scopes,
88
+ resource_server_url=AnyHttpUrl(server_url))
89
+
90
+ token_verifier = await self._create_token_verifier(self.front_end_config.server_auth)
91
+
92
+ # Create an MCP server with the configured parameters
93
+ mcp = FastMCP(name=self.front_end_config.name,
94
+ host=self.front_end_config.host,
95
+ port=self.front_end_config.port,
96
+ debug=self.front_end_config.debug,
97
+ auth=auth_settings,
98
+ token_verifier=token_verifier)
99
+
100
+ # Get the worker instance and set up routes
101
+ worker = self._get_worker_instance()
102
+
77
103
  # Add routes through the worker (includes health endpoint and function registration)
78
104
  await worker.add_routes(mcp, builder)
79
105
 
80
- # Start the MCP server
81
- await mcp.run_sse_async()
106
+ # Start the MCP server with configurable transport
107
+ # streamable-http is the default, but users can choose sse if preferred
108
+ if self.front_end_config.transport == "sse":
109
+ logger.info("Starting MCP server with SSE endpoint at /sse")
110
+ await mcp.run_sse_async()
111
+ else: # streamable-http
112
+ logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
113
+ await mcp.run_streamable_http_async()