nvidia-nat 1.2.1rc1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257) hide show
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +27 -18
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +81 -50
  7. nat/agent/react_agent/register.py +59 -40
  8. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +327 -149
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +64 -46
  13. nat/agent/tool_calling_agent/agent.py +152 -29
  14. nat/agent/tool_calling_agent/register.py +61 -38
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +10 -6
  24. nat/builder/context.py +70 -18
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/intermediate_step_manager.py +6 -2
  32. nat/builder/user_interaction_manager.py +2 -2
  33. nat/builder/workflow.py +13 -1
  34. nat/builder/workflow_builder.py +327 -79
  35. nat/cli/cli_utils/config_override.py +2 -2
  36. nat/cli/commands/evaluate.py +1 -1
  37. nat/cli/commands/info/info.py +16 -6
  38. nat/cli/commands/info/list_channels.py +1 -1
  39. nat/cli/commands/info/list_components.py +7 -8
  40. nat/cli/commands/mcp/__init__.py +14 -0
  41. nat/cli/commands/mcp/mcp.py +986 -0
  42. nat/cli/commands/object_store/__init__.py +14 -0
  43. nat/cli/commands/object_store/object_store.py +227 -0
  44. nat/cli/commands/optimize.py +90 -0
  45. nat/cli/commands/registry/publish.py +2 -2
  46. nat/cli/commands/registry/pull.py +2 -2
  47. nat/cli/commands/registry/remove.py +2 -2
  48. nat/cli/commands/registry/search.py +15 -17
  49. nat/cli/commands/start.py +16 -5
  50. nat/cli/commands/uninstall.py +1 -1
  51. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  52. nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
  53. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  54. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  55. nat/cli/commands/workflow/workflow_commands.py +105 -19
  56. nat/cli/entrypoint.py +17 -11
  57. nat/cli/main.py +3 -0
  58. nat/cli/register_workflow.py +38 -4
  59. nat/cli/type_registry.py +79 -10
  60. nat/control_flow/__init__.py +0 -0
  61. nat/control_flow/register.py +20 -0
  62. nat/control_flow/router_agent/__init__.py +0 -0
  63. nat/control_flow/router_agent/agent.py +329 -0
  64. nat/control_flow/router_agent/prompt.py +48 -0
  65. nat/control_flow/router_agent/register.py +91 -0
  66. nat/control_flow/sequential_executor.py +166 -0
  67. nat/data_models/agent.py +34 -0
  68. nat/data_models/api_server.py +196 -67
  69. nat/data_models/authentication.py +23 -9
  70. nat/data_models/common.py +1 -1
  71. nat/data_models/component.py +2 -0
  72. nat/data_models/component_ref.py +11 -0
  73. nat/data_models/config.py +42 -18
  74. nat/data_models/dataset_handler.py +1 -1
  75. nat/data_models/discovery_metadata.py +4 -4
  76. nat/data_models/evaluate.py +4 -1
  77. nat/data_models/function.py +34 -0
  78. nat/data_models/function_dependencies.py +14 -6
  79. nat/data_models/gated_field_mixin.py +242 -0
  80. nat/data_models/intermediate_step.py +3 -3
  81. nat/data_models/optimizable.py +119 -0
  82. nat/data_models/optimizer.py +149 -0
  83. nat/data_models/span.py +41 -3
  84. nat/data_models/swe_bench_model.py +1 -1
  85. nat/data_models/temperature_mixin.py +44 -0
  86. nat/data_models/thinking_mixin.py +86 -0
  87. nat/data_models/top_p_mixin.py +44 -0
  88. nat/embedder/azure_openai_embedder.py +46 -0
  89. nat/embedder/nim_embedder.py +1 -1
  90. nat/embedder/openai_embedder.py +2 -3
  91. nat/embedder/register.py +1 -1
  92. nat/eval/config.py +3 -1
  93. nat/eval/dataset_handler/dataset_handler.py +71 -7
  94. nat/eval/evaluate.py +86 -31
  95. nat/eval/evaluator/base_evaluator.py +1 -1
  96. nat/eval/evaluator/evaluator_model.py +13 -0
  97. nat/eval/intermediate_step_adapter.py +1 -1
  98. nat/eval/rag_evaluator/evaluate.py +9 -6
  99. nat/eval/rag_evaluator/register.py +3 -3
  100. nat/eval/register.py +4 -1
  101. nat/eval/remote_workflow.py +3 -3
  102. nat/eval/runtime_evaluator/__init__.py +14 -0
  103. nat/eval/runtime_evaluator/evaluate.py +123 -0
  104. nat/eval/runtime_evaluator/register.py +100 -0
  105. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  106. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  107. nat/eval/trajectory_evaluator/register.py +1 -1
  108. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  109. nat/eval/utils/eval_trace_ctx.py +89 -0
  110. nat/eval/utils/weave_eval.py +18 -9
  111. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  112. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  113. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  114. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  115. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  116. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  117. nat/experimental/test_time_compute/register.py +0 -1
  118. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  119. nat/front_ends/console/authentication_flow_handler.py +82 -30
  120. nat/front_ends/console/console_front_end_plugin.py +19 -7
  121. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  122. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  123. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  124. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  125. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  126. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  127. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
  128. nat/front_ends/fastapi/job_store.py +518 -99
  129. nat/front_ends/fastapi/main.py +11 -19
  130. nat/front_ends/fastapi/message_handler.py +74 -50
  131. nat/front_ends/fastapi/message_validator.py +20 -21
  132. nat/front_ends/fastapi/response_helpers.py +4 -4
  133. nat/front_ends/fastapi/step_adaptor.py +2 -2
  134. nat/front_ends/fastapi/utils.py +57 -0
  135. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  136. nat/front_ends/mcp/mcp_front_end_config.py +47 -3
  137. nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
  138. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
  139. nat/front_ends/mcp/tool_converter.py +44 -14
  140. nat/front_ends/register.py +0 -1
  141. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  142. nat/llm/aws_bedrock_llm.py +24 -12
  143. nat/llm/azure_openai_llm.py +57 -0
  144. nat/llm/litellm_llm.py +69 -0
  145. nat/llm/nim_llm.py +20 -8
  146. nat/llm/openai_llm.py +14 -6
  147. nat/llm/register.py +5 -1
  148. nat/llm/utils/env_config_value.py +2 -3
  149. nat/llm/utils/thinking.py +215 -0
  150. nat/meta/pypi.md +9 -9
  151. nat/object_store/register.py +0 -1
  152. nat/observability/exporter/base_exporter.py +3 -3
  153. nat/observability/exporter/file_exporter.py +1 -1
  154. nat/observability/exporter/processing_exporter.py +309 -81
  155. nat/observability/exporter/span_exporter.py +35 -15
  156. nat/observability/exporter_manager.py +7 -7
  157. nat/observability/mixin/file_mixin.py +7 -7
  158. nat/observability/mixin/redaction_config_mixin.py +42 -0
  159. nat/observability/mixin/tagging_config_mixin.py +62 -0
  160. nat/observability/mixin/type_introspection_mixin.py +420 -107
  161. nat/observability/processor/batching_processor.py +5 -7
  162. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  163. nat/observability/processor/processor.py +3 -0
  164. nat/observability/processor/processor_factory.py +70 -0
  165. nat/observability/processor/redaction/__init__.py +24 -0
  166. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  167. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  168. nat/observability/processor/redaction/redaction_processor.py +177 -0
  169. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  170. nat/observability/processor/span_tagging_processor.py +68 -0
  171. nat/observability/register.py +22 -4
  172. nat/profiler/calc/calc_runner.py +3 -4
  173. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  174. nat/profiler/callbacks/langchain_callback_handler.py +14 -7
  175. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  176. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  177. nat/profiler/data_frame_row.py +1 -1
  178. nat/profiler/decorators/framework_wrapper.py +62 -13
  179. nat/profiler/decorators/function_tracking.py +160 -3
  180. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  181. nat/profiler/forecasting/models/linear_model.py +1 -1
  182. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  183. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  184. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  185. nat/profiler/inference_optimization/data_models.py +3 -3
  186. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  187. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  188. nat/profiler/parameter_optimization/__init__.py +0 -0
  189. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  190. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  191. nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
  192. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  193. nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
  194. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  195. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  196. nat/profiler/profile_runner.py +14 -9
  197. nat/profiler/utils.py +4 -2
  198. nat/registry_handlers/local/local_handler.py +2 -2
  199. nat/registry_handlers/package_utils.py +1 -2
  200. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  201. nat/registry_handlers/register.py +3 -4
  202. nat/registry_handlers/rest/rest_handler.py +12 -13
  203. nat/retriever/milvus/retriever.py +2 -2
  204. nat/retriever/nemo_retriever/retriever.py +1 -1
  205. nat/retriever/register.py +0 -1
  206. nat/runtime/loader.py +2 -2
  207. nat/runtime/runner.py +105 -8
  208. nat/runtime/session.py +69 -8
  209. nat/settings/global_settings.py +16 -5
  210. nat/tool/chat_completion.py +5 -2
  211. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  212. nat/tool/datetime_tools.py +49 -9
  213. nat/tool/document_search.py +2 -2
  214. nat/tool/github_tools.py +450 -0
  215. nat/tool/memory_tools/add_memory_tool.py +3 -3
  216. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  217. nat/tool/memory_tools/get_memory_tool.py +4 -4
  218. nat/tool/nvidia_rag.py +1 -1
  219. nat/tool/register.py +2 -9
  220. nat/tool/retriever.py +3 -2
  221. nat/utils/callable_utils.py +70 -0
  222. nat/utils/data_models/schema_validator.py +3 -3
  223. nat/utils/decorators.py +210 -0
  224. nat/utils/exception_handlers/automatic_retries.py +104 -51
  225. nat/utils/exception_handlers/schemas.py +1 -1
  226. nat/utils/io/yaml_tools.py +2 -2
  227. nat/utils/log_levels.py +25 -0
  228. nat/utils/reactive/base/observable_base.py +2 -2
  229. nat/utils/reactive/base/observer_base.py +1 -1
  230. nat/utils/reactive/observable.py +2 -2
  231. nat/utils/reactive/observer.py +4 -4
  232. nat/utils/reactive/subscription.py +1 -1
  233. nat/utils/settings/global_settings.py +6 -8
  234. nat/utils/type_converter.py +12 -3
  235. nat/utils/type_utils.py +9 -5
  236. nvidia_nat-1.3.0.dist-info/METADATA +195 -0
  237. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
  238. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
  239. nat/cli/commands/info/list_mcp.py +0 -304
  240. nat/tool/github_tools/create_github_commit.py +0 -133
  241. nat/tool/github_tools/create_github_issue.py +0 -87
  242. nat/tool/github_tools/create_github_pr.py +0 -106
  243. nat/tool/github_tools/get_github_file.py +0 -106
  244. nat/tool/github_tools/get_github_issue.py +0 -166
  245. nat/tool/github_tools/get_github_pr.py +0 -256
  246. nat/tool/github_tools/update_github_issue.py +0 -100
  247. nat/tool/mcp/exceptions.py +0 -142
  248. nat/tool/mcp/mcp_client.py +0 -255
  249. nat/tool/mcp/mcp_tool.py +0 -96
  250. nat/utils/exception_handlers/mcp.py +0 -211
  251. nvidia_nat-1.2.1rc1.dist-info/METADATA +0 -365
  252. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  253. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  254. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
  255. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  256. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
  257. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
@@ -13,19 +13,24 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import importlib
17
16
  import logging
18
17
  import os
18
+ import typing
19
19
 
20
20
  from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
21
+ from nat.front_ends.fastapi.utils import get_config_file_path
22
+ from nat.front_ends.fastapi.utils import import_class_from_string
21
23
  from nat.runtime.loader import load_config
22
24
 
25
+ if typing.TYPE_CHECKING:
26
+ from fastapi import FastAPI
27
+
23
28
  logger = logging.getLogger(__name__)
24
29
 
25
30
 
26
- def get_app():
31
+ def get_app() -> "FastAPI":
27
32
 
28
- config_file_path = os.getenv("NAT_CONFIG_FILE")
33
+ config_file_path = get_config_file_path()
29
34
  front_end_worker_full_name = os.getenv("NAT_FRONT_END_WORKER")
30
35
 
31
36
  if (not config_file_path):
@@ -36,28 +41,15 @@ def get_app():
36
41
 
37
42
  # Try to import the front end worker class
38
43
  try:
39
- # Split the package from the class
40
- front_end_worker_parts = front_end_worker_full_name.split(".")
41
-
42
- front_end_worker_module_name = ".".join(front_end_worker_parts[:-1])
43
- front_end_worker_class_name = front_end_worker_parts[-1]
44
-
45
- front_end_worker_module = importlib.import_module(front_end_worker_module_name)
46
-
47
- if not hasattr(front_end_worker_module, front_end_worker_class_name):
48
- raise ValueError(f"Front end worker {front_end_worker_full_name} not found.")
49
-
50
- front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = getattr(front_end_worker_module,
51
- front_end_worker_class_name)
44
+ front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = import_class_from_string(
45
+ front_end_worker_full_name)
52
46
 
53
47
  if (not issubclass(front_end_worker_class, FastApiFrontEndPluginWorkerBase)):
54
48
  raise ValueError(
55
49
  f"Front end worker {front_end_worker_full_name} is not a subclass of FastApiFrontEndPluginWorker.")
56
50
 
57
51
  # Load the config
58
- abs_config_file_path = os.path.abspath(config_file_path)
59
-
60
- config = load_config(abs_config_file_path)
52
+ config = load_config(config_file_path)
61
53
 
62
54
  # Create an instance of the front end worker class
63
55
  front_end_worker = front_end_worker_class(config)
@@ -25,6 +25,7 @@ from pydantic import ValidationError
25
25
  from starlette.websockets import WebSocketDisconnect
26
26
 
27
27
  from nat.authentication.interfaces import FlowHandlerBase
28
+ from nat.data_models.api_server import ChatRequest
28
29
  from nat.data_models.api_server import ChatResponse
29
30
  from nat.data_models.api_server import ChatResponseChunk
30
31
  from nat.data_models.api_server import Error
@@ -33,6 +34,8 @@ from nat.data_models.api_server import ResponsePayloadOutput
33
34
  from nat.data_models.api_server import ResponseSerializable
34
35
  from nat.data_models.api_server import SystemResponseContent
35
36
  from nat.data_models.api_server import TextContent
37
+ from nat.data_models.api_server import UserMessageContentRoleType
38
+ from nat.data_models.api_server import UserMessages
36
39
  from nat.data_models.api_server import WebSocketMessageStatus
37
40
  from nat.data_models.api_server import WebSocketMessageType
38
41
  from nat.data_models.api_server import WebSocketSystemInteractionMessage
@@ -64,12 +67,12 @@ class WebSocketMessageHandler:
64
67
  self._running_workflow_task: asyncio.Task | None = None
65
68
  self._message_parent_id: str = "default_id"
66
69
  self._conversation_id: str | None = None
67
- self._workflow_schema_type: str = None
68
- self._user_interaction_response: asyncio.Future[HumanResponse] | None = None
70
+ self._workflow_schema_type: str | None = None
71
+ self._user_interaction_response: asyncio.Future[TextContent] | None = None
69
72
 
70
73
  self._flow_handler: FlowHandlerBase | None = None
71
74
 
72
- self._schema_output_mapping: dict[str, type[BaseModel] | None] = {
75
+ self._schema_output_mapping: dict[str, type[BaseModel] | type[None]] = {
73
76
  WorkflowSchemaType.GENERATE: self._session_manager.workflow.single_output_schema,
74
77
  WorkflowSchemaType.CHAT: ChatResponse,
75
78
  WorkflowSchemaType.CHAT_STREAM: ChatResponseChunk,
@@ -86,7 +89,7 @@ class WebSocketMessageHandler:
86
89
 
87
90
  async def __aexit__(self, exc_type, exc_value, traceback) -> None:
88
91
 
89
- # TODO: Handle the exit # pylint: disable=fixme
92
+ # TODO: Handle the exit
90
93
  pass
91
94
 
92
95
  async def run(self) -> None:
@@ -107,47 +110,65 @@ class WebSocketMessageHandler:
107
110
 
108
111
  elif isinstance(
109
112
  validated_message,
110
- ( # noqa: E131
111
- WebSocketSystemResponseTokenMessage,
112
- WebSocketSystemIntermediateStepMessage,
113
- WebSocketSystemInteractionMessage)):
113
+ WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
114
+ | WebSocketSystemInteractionMessage):
114
115
  # These messages are already handled by self.create_websocket_message(data_model=value, …)
115
116
  # No further processing is needed here.
116
117
  pass
117
118
 
118
119
  elif (isinstance(validated_message, WebSocketUserInteractionResponseMessage)):
119
- user_content = await self.process_user_message_content(validated_message)
120
+ user_content = await self._process_websocket_user_interaction_response_message(validated_message)
121
+ assert self._user_interaction_response is not None
120
122
  self._user_interaction_response.set_result(user_content)
121
123
  except (asyncio.CancelledError, WebSocketDisconnect):
122
- # TODO: Handle the disconnect # pylint: disable=fixme
124
+ # TODO: Handle the disconnect
123
125
  break
124
126
 
125
- return None
126
-
127
- async def process_user_message_content(
128
- self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
127
+ def _extract_last_user_message_content(self, messages: list[UserMessages]) -> TextContent:
129
128
  """
130
- Processes the contents of a user message.
129
+ Extracts the last user's TextContent from a list of messages.
131
130
 
132
- :param user_content: Incoming content data model.
133
- :return: A validated Pydantic user content model or None if not found.
134
- """
131
+ Args:
132
+ messages: List of UserMessages.
135
133
 
136
- for user_message in user_content.content.messages[::-1]:
137
- if (user_message.role == "user"):
134
+ Returns:
135
+ TextContent object from the last user message.
138
136
 
137
+ Raises:
138
+ ValueError: If no user text content is found.
139
+ """
140
+ for user_message in messages[::-1]:
141
+ if user_message.role == UserMessageContentRoleType.USER:
139
142
  for attachment in user_message.content:
140
-
141
143
  if isinstance(attachment, TextContent):
142
144
  return attachment
145
+ raise ValueError("No user text content found in messages.")
143
146
 
144
- return None
147
+ async def _process_websocket_user_interaction_response_message(
148
+ self, user_content: WebSocketUserInteractionResponseMessage) -> TextContent:
149
+ """
150
+ Processes a WebSocketUserInteractionResponseMessage.
151
+ """
152
+ return self._extract_last_user_message_content(user_content.content.messages)
153
+
154
+ async def _process_websocket_user_message(self, user_content: WebSocketUserMessage) -> ChatRequest | str:
155
+ """
156
+ Processes a WebSocketUserMessage based on schema type.
157
+ """
158
+ if self._workflow_schema_type in [WorkflowSchemaType.CHAT, WorkflowSchemaType.CHAT_STREAM]:
159
+ return ChatRequest(**user_content.content.model_dump(include={"messages"}))
160
+
161
+ elif self._workflow_schema_type in [WorkflowSchemaType.GENERATE, WorkflowSchemaType.GENERATE_STREAM]:
162
+ return self._extract_last_user_message_content(user_content.content.messages).text
163
+
164
+ raise ValueError("Unsupported workflow schema type for WebSocketUserMessage")
145
165
 
146
166
  async def process_workflow_request(self, user_message_as_validated_type: WebSocketUserMessage) -> None:
147
167
  """
148
168
  Process user messages and routes them appropriately.
149
169
 
150
- :param user_message_as_validated_type: A WebSocketUserMessage Data Model instance.
170
+ Args:
171
+ user_message_as_validated_type (WebSocketUserMessage): The validated user message to process.
151
172
  """
152
173
 
153
174
  try:
@@ -155,25 +176,23 @@ class WebSocketMessageHandler:
155
176
  self._workflow_schema_type = user_message_as_validated_type.schema_type
156
177
  self._conversation_id = user_message_as_validated_type.conversation_id
157
178
 
158
- content: BaseModel | None = await self.process_user_message_content(user_message_as_validated_type)
159
-
160
- if content is None:
161
- raise ValueError(f"User message content could not be found: {user_message_as_validated_type}")
179
+ message_content: typing.Any = await self._process_websocket_user_message(user_message_as_validated_type)
162
180
 
163
- if isinstance(content, TextContent) and (self._running_workflow_task is None):
181
+ if (self._running_workflow_task is None):
164
182
 
165
- def _done_callback(task: asyncio.Task): # pylint: disable=unused-argument
183
+ def _done_callback(_task: asyncio.Task):
166
184
  self._running_workflow_task = None
167
185
 
168
186
  self._running_workflow_task = asyncio.create_task(
169
- self._run_workflow(content.text,
170
- self._conversation_id,
187
+ self._run_workflow(payload=message_content,
188
+ user_message_id=self._message_parent_id,
189
+ conversation_id=self._conversation_id,
171
190
  result_type=self._schema_output_mapping[self._workflow_schema_type],
172
191
  output_type=self._schema_output_mapping[
173
192
  self._workflow_schema_type])).add_done_callback(_done_callback)
174
193
 
175
194
  except ValueError as e:
176
- logger.error("User message content not found: %s", str(e), exc_info=True)
195
+ logger.exception("User message content not found: %s", str(e))
177
196
  await self.create_websocket_message(data_model=Error(code=ErrorTypes.INVALID_USER_MESSAGE_CONTENT,
178
197
  message="User message content could not be found",
179
198
  details=str(e)),
@@ -183,13 +202,14 @@ class WebSocketMessageHandler:
183
202
  async def create_websocket_message(self,
184
203
  data_model: BaseModel,
185
204
  message_type: str | None = None,
186
- status: str = WebSocketMessageStatus.IN_PROGRESS) -> None:
205
+ status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS) -> None:
187
206
  """
188
207
  Creates a websocket message that will be ready for routing based on message type or data model.
189
208
 
190
- :param data_model: Message content model.
191
- :param message_type: Message content model.
192
- :param status: Message content model.
209
+ Args:
210
+ data_model (BaseModel): Message content model.
211
+ message_type (str | None): Message content model.
212
+ status (WebSocketMessageStatus): Message content model.
193
213
  """
194
214
  try:
195
215
  message: BaseModel | None = None
@@ -199,8 +219,8 @@ class WebSocketMessageHandler:
199
219
 
200
220
  message_schema: type[BaseModel] = await self._message_validator.get_message_schema_by_type(message_type)
201
221
 
202
- if 'id' in data_model.model_fields:
203
- message_id: str = data_model.id
222
+ if hasattr(data_model, 'id'):
223
+ message_id: str = str(getattr(data_model, 'id'))
204
224
  else:
205
225
  message_id = str(uuid.uuid4())
206
226
 
@@ -241,7 +261,7 @@ class WebSocketMessageHandler:
241
261
  f"Message type could not be resolved by input data model: {data_model.model_dump_json()}")
242
262
 
243
263
  except (ValidationError, TypeError, ValueError) as e:
244
- logger.error("A data vaidation error ocurred creating websocket message: %s", str(e), exc_info=True)
264
+ logger.exception("A data vaidation error ocurred creating websocket message: %s", str(e))
245
265
  message = await self._message_validator.create_system_response_token_message(
246
266
  message_type=WebSocketMessageType.ERROR_MESSAGE,
247
267
  conversation_id=self._conversation_id,
@@ -256,12 +276,15 @@ class WebSocketMessageHandler:
256
276
  Registered human interaction callback that processes human interactions and returns
257
277
  responses from websocket connection.
258
278
 
259
- :param prompt: Incoming interaction content data model.
260
- :return: A Text Content Base Pydantic model.
279
+ Args:
280
+ prompt: Incoming interaction content data model.
281
+
282
+ Returns:
283
+ A Text Content Base Pydantic model.
261
284
  """
262
285
 
263
286
  # First create a future from the loop for the human response
264
- human_response_future: asyncio.Future[HumanResponse] = asyncio.get_running_loop().create_future()
287
+ human_response_future: asyncio.Future[TextContent] = asyncio.get_running_loop().create_future()
265
288
 
266
289
  # Then add the future to the outstanding human prompts dictionary
267
290
  self._user_interaction_response = human_response_future
@@ -277,10 +300,10 @@ class WebSocketMessageHandler:
277
300
  return HumanResponseNotification()
278
301
 
279
302
  # Wait for the human response future to complete
280
- interaction_response: HumanResponse = await human_response_future
303
+ text_content: TextContent = await human_response_future
281
304
 
282
305
  interaction_response: HumanResponse = await self._message_validator.convert_text_content_to_human_response(
283
- interaction_response, prompt.content)
306
+ text_content, prompt.content)
284
307
 
285
308
  return interaction_response
286
309
 
@@ -290,17 +313,18 @@ class WebSocketMessageHandler:
290
313
 
291
314
  async def _run_workflow(self,
292
315
  payload: typing.Any,
316
+ user_message_id: str | None = None,
293
317
  conversation_id: str | None = None,
294
318
  result_type: type | None = None,
295
319
  output_type: type | None = None) -> None:
296
320
 
297
321
  try:
298
- async with self._session_manager.session(
299
- conversation_id=conversation_id,
300
- request=self._socket,
301
- user_input_callback=self.human_interaction_callback,
302
- user_authentication_callback=(self._flow_handler.authenticate
303
- if self._flow_handler else None)) as session:
322
+ auth_callback = self._flow_handler.authenticate if self._flow_handler else None
323
+ async with self._session_manager.session(user_message_id=user_message_id,
324
+ conversation_id=conversation_id,
325
+ http_connection=self._socket,
326
+ user_input_callback=self.human_interaction_callback,
327
+ user_authentication_callback=auth_callback) as session:
304
328
 
305
329
  async for value in generate_streaming_response(payload,
306
330
  session_manager=session,
@@ -97,7 +97,7 @@ class MessageValidator:
97
97
  return validated_message
98
98
 
99
99
  except (ValidationError, TypeError, ValueError) as e:
100
- logger.error("A data validation error %s occurred for message: %s", str(e), str(message), exc_info=True)
100
+ logger.exception("A data validation error %s occurred for message: %s", str(e), str(message))
101
101
  return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
102
102
  content=Error(code=ErrorTypes.INVALID_MESSAGE,
103
103
  message="Error validating message.",
@@ -119,7 +119,7 @@ class MessageValidator:
119
119
  return schema
120
120
 
121
121
  except (TypeError, ValueError) as e:
122
- logger.error("Error retrieving schema for message type '%s': %s", message_type, str(e), exc_info=True)
122
+ logger.exception("Error retrieving schema for message type '%s': %s", message_type, str(e))
123
123
  return Error
124
124
 
125
125
  async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
@@ -139,8 +139,10 @@ class MessageValidator:
139
139
  text_content: str = str(data_model.payload)
140
140
  validated_message_content = SystemResponseContent(text=text_content)
141
141
 
142
- elif (isinstance(data_model, (ChatResponse, ChatResponseChunk))):
142
+ elif isinstance(data_model, ChatResponse):
143
143
  validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
144
+ elif isinstance(data_model, ChatResponseChunk):
145
+ validated_message_content = SystemResponseContent(text=data_model.choices[0].delta.content)
144
146
 
145
147
  elif (isinstance(data_model, ResponseIntermediateStep)):
146
148
  validated_message_content = SystemIntermediateStepContent(name=data_model.name,
@@ -156,7 +158,7 @@ class MessageValidator:
156
158
  return validated_message_content
157
159
 
158
160
  except ValueError as e:
159
- logger.error("Input data could not be converted to validated message content: %s", str(e), exc_info=True)
161
+ logger.exception("Input data could not be converted to validated message content: %s", str(e))
160
162
  return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))
161
163
 
162
164
  async def convert_text_content_to_human_response(self, text_content: TextContent,
@@ -191,7 +193,7 @@ class MessageValidator:
191
193
  return human_response
192
194
 
193
195
  except ValueError as e:
194
- logger.error("Error human response content not found: %s", str(e), exc_info=True)
196
+ logger.exception("Error human response content not found: %s", str(e))
195
197
  return HumanResponseText(text=str(e))
196
198
 
197
199
  async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
@@ -204,7 +206,7 @@ class MessageValidator:
204
206
 
205
207
  validated_message_type: str = ""
206
208
  try:
207
- if (isinstance(data_model, (ResponsePayloadOutput, ChatResponse, ChatResponseChunk))):
209
+ if (isinstance(data_model, ResponsePayloadOutput | ChatResponse | ChatResponseChunk)):
208
210
  validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE
209
211
 
210
212
  elif (isinstance(data_model, ResponseIntermediateStep)):
@@ -218,9 +220,7 @@ class MessageValidator:
218
220
  return validated_message_type
219
221
 
220
222
  except ValueError as e:
221
- logger.error("Error type not found converting data to validated websocket message content: %s",
222
- str(e),
223
- exc_info=True)
223
+ logger.exception("Error type not found converting data to validated websocket message content: %s", str(e))
224
224
  return WebSocketMessageType.ERROR_MESSAGE
225
225
 
226
226
  async def get_intermediate_step_parent_id(self, data_model: ResponseIntermediateStep) -> str:
@@ -232,7 +232,7 @@ class MessageValidator:
232
232
  """
233
233
  return data_model.parent_id or "root"
234
234
 
235
- async def create_system_response_token_message( # pylint: disable=R0917:too-many-positional-arguments
235
+ async def create_system_response_token_message(
236
236
  self,
237
237
  message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
238
238
  WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
@@ -240,10 +240,9 @@ class MessageValidator:
240
240
  thread_id: str = "default",
241
241
  parent_id: str = "default",
242
242
  conversation_id: str | None = None,
243
- content: SystemResponseContent
244
- | Error = SystemResponseContent(),
243
+ content: SystemResponseContent | Error = SystemResponseContent(),
245
244
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
246
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
245
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
247
246
  ) -> WebSocketSystemResponseTokenMessage | None:
248
247
  """
249
248
  Creates a system response token message with default values.
@@ -269,10 +268,10 @@ class MessageValidator:
269
268
  timestamp=timestamp)
270
269
 
271
270
  except Exception as e:
272
- logger.error("Error creating system response token message: %s", str(e), exc_info=True)
271
+ logger.exception("Error creating system response token message: %s", str(e))
273
272
  return None
274
273
 
275
- async def create_system_intermediate_step_message( # pylint: disable=R0917:too-many-positional-arguments
274
+ async def create_system_intermediate_step_message(
276
275
  self,
277
276
  message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
278
277
  WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
@@ -282,7 +281,7 @@ class MessageValidator:
282
281
  conversation_id: str | None = None,
283
282
  content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
284
283
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
285
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
284
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
286
285
  ) -> WebSocketSystemIntermediateStepMessage | None:
287
286
  """
288
287
  Creates a system intermediate step message with default values.
@@ -308,10 +307,10 @@ class MessageValidator:
308
307
  timestamp=timestamp)
309
308
 
310
309
  except Exception as e:
311
- logger.error("Error creating system intermediate step message: %s", str(e), exc_info=True)
310
+ logger.exception("Error creating system intermediate step message: %s", str(e))
312
311
  return None
313
312
 
314
- async def create_system_interaction_message( # pylint: disable=R0917:too-many-positional-arguments
313
+ async def create_system_interaction_message(
315
314
  self,
316
315
  *,
317
316
  message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
@@ -322,8 +321,8 @@ class MessageValidator:
322
321
  conversation_id: str | None = None,
323
322
  content: HumanPrompt,
324
323
  status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
325
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
326
- ) -> WebSocketSystemInteractionMessage | None: # noqa: E125 continuation line with same indent as next logical line
324
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
325
+ ) -> WebSocketSystemInteractionMessage | None:
327
326
  """
328
327
  Creates a system interaction message with default values.
329
328
 
@@ -348,5 +347,5 @@ class MessageValidator:
348
347
  timestamp=timestamp)
349
348
 
350
349
  except Exception as e:
351
- logger.error("Error creating system interaction message: %s", str(e), exc_info=True)
350
+ logger.exception("Error creating system interaction message: %s", str(e))
352
351
  return None
@@ -98,9 +98,9 @@ async def generate_streaming_response(payload: typing.Any,
98
98
  yield item
99
99
  else:
100
100
  yield ResponsePayloadOutput(payload=item)
101
- except Exception as e:
101
+ except Exception:
102
102
  # Handle exceptions here
103
- raise e
103
+ raise
104
104
  finally:
105
105
  await q.close()
106
106
 
@@ -165,9 +165,9 @@ async def generate_streaming_response_full(payload: typing.Any,
165
165
  yield item
166
166
  else:
167
167
  yield ResponsePayloadOutput(payload=item)
168
- except Exception as e:
168
+ except Exception:
169
169
  # Handle exceptions here
170
- raise e
170
+ raise
171
171
  finally:
172
172
  await q.close()
173
173
 
@@ -289,7 +289,7 @@ class StepAdaptor:
289
289
 
290
290
  return event
291
291
 
292
- def process(self, step: IntermediateStep) -> ResponseSerializable | None: # pylint: disable=R1710
292
+ def process(self, step: IntermediateStep) -> ResponseSerializable | None:
293
293
 
294
294
  # Track the chunk
295
295
  self._history.append(step)
@@ -314,6 +314,6 @@ class StepAdaptor:
314
314
  return self._handle_custom(payload, ancestry)
315
315
 
316
316
  except Exception as e:
317
- logger.error("Error processing intermediate step: %s", e, exc_info=True)
317
+ logger.exception("Error processing intermediate step: %s", e)
318
318
 
319
319
  return None
@@ -0,0 +1,57 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import os
18
+
19
+
20
+ def get_config_file_path() -> str:
21
+ """
22
+ Get the path to the NAT configuration file from the environment variable NAT_CONFIG_FILE.
23
+ Raises ValueError if the environment variable is not set.
24
+ """
25
+ config_file_path = os.getenv("NAT_CONFIG_FILE")
26
+ if (not config_file_path):
27
+ raise ValueError("Config file not found in environment variable NAT_CONFIG_FILE.")
28
+
29
+ return os.path.abspath(config_file_path)
30
+
31
+
32
+ def import_class_from_string(class_full_name: str) -> type:
33
+ """
34
+ Import a class from a string in the format 'module.submodule.ClassName'.
35
+ Raises ImportError if the class cannot be imported.
36
+ """
37
+ try:
38
+ class_name_parts = class_full_name.split(".")
39
+
40
+ module_name = ".".join(class_name_parts[:-1])
41
+ class_name = class_name_parts[-1]
42
+
43
+ module = importlib.import_module(module_name)
44
+
45
+ if not hasattr(module, class_name):
46
+ raise ValueError(f"Class '{class_full_name}' not found.")
47
+
48
+ return getattr(module, class_name)
49
+ except (ImportError, AttributeError) as e:
50
+ raise ImportError(f"Could not import {class_full_name}.") from e
51
+
52
+
53
+ def get_class_name(cls: type) -> str:
54
+ """
55
+ Get the full class name including the module.
56
+ """
57
+ return f"{cls.__module__}.{cls.__qualname__}"
@@ -0,0 +1,73 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
16
+
17
+ import logging
18
+
19
+ from mcp.server.auth.provider import AccessToken
20
+ from mcp.server.auth.provider import TokenVerifier
21
+
22
+ from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
23
+ from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class IntrospectionTokenVerifier(TokenVerifier):
29
+ """Token verifier that delegates token verification to BearerTokenValidator."""
30
+
31
+ def __init__(self, config: OAuth2ResourceServerConfig):
32
+ """Create IntrospectionTokenVerifier from OAuth2ResourceServerConfig.
33
+
34
+ Args:
35
+ config: OAuth2ResourceServerConfig
36
+ """
37
+ issuer = config.issuer_url
38
+ scopes = config.scopes or []
39
+ audience = config.audience
40
+ jwks_uri = config.jwks_uri
41
+ introspection_endpoint = config.introspection_endpoint
42
+ discovery_url = config.discovery_url
43
+ client_id = config.client_id
44
+ client_secret = config.client_secret
45
+
46
+ self._bearer_token_validator = BearerTokenValidator(
47
+ issuer=issuer,
48
+ audience=audience,
49
+ scopes=scopes,
50
+ jwks_uri=jwks_uri,
51
+ introspection_endpoint=introspection_endpoint,
52
+ discovery_url=discovery_url,
53
+ client_id=client_id,
54
+ client_secret=client_secret,
55
+ )
56
+
57
+ async def verify_token(self, token: str) -> AccessToken | None:
58
+ """Verify token by delegating to BearerTokenValidator.
59
+
60
+ Args:
61
+ token: The Bearer token to verify
62
+
63
+ Returns:
64
+ AccessToken | None: AccessToken if valid, None if invalid
65
+ """
66
+ validation_result = await self._bearer_token_validator.verify(token)
67
+
68
+ if validation_result.active:
69
+ return AccessToken(token=token,
70
+ expires_at=validation_result.expires_at,
71
+ scopes=validation_result.scopes or [],
72
+ client_id=validation_result.client_id or "")
73
+ return None