nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250) hide show
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +50 -22
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +54 -27
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +68 -17
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  53. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  54. nat/cli/commands/workflow/workflow_commands.py +62 -22
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/main.py +3 -0
  57. nat/cli/register_workflow.py +38 -4
  58. nat/cli/type_registry.py +75 -6
  59. nat/control_flow/__init__.py +0 -0
  60. nat/control_flow/register.py +20 -0
  61. nat/control_flow/router_agent/__init__.py +0 -0
  62. nat/control_flow/router_agent/agent.py +329 -0
  63. nat/control_flow/router_agent/prompt.py +48 -0
  64. nat/control_flow/router_agent/register.py +91 -0
  65. nat/control_flow/sequential_executor.py +166 -0
  66. nat/data_models/agent.py +34 -0
  67. nat/data_models/api_server.py +74 -66
  68. nat/data_models/authentication.py +23 -9
  69. nat/data_models/common.py +1 -1
  70. nat/data_models/component.py +2 -0
  71. nat/data_models/component_ref.py +11 -0
  72. nat/data_models/config.py +41 -17
  73. nat/data_models/dataset_handler.py +1 -1
  74. nat/data_models/discovery_metadata.py +4 -4
  75. nat/data_models/evaluate.py +4 -1
  76. nat/data_models/function.py +34 -0
  77. nat/data_models/function_dependencies.py +14 -6
  78. nat/data_models/gated_field_mixin.py +242 -0
  79. nat/data_models/intermediate_step.py +3 -3
  80. nat/data_models/optimizable.py +119 -0
  81. nat/data_models/optimizer.py +149 -0
  82. nat/data_models/span.py +41 -3
  83. nat/data_models/swe_bench_model.py +1 -1
  84. nat/data_models/temperature_mixin.py +44 -0
  85. nat/data_models/thinking_mixin.py +86 -0
  86. nat/data_models/top_p_mixin.py +44 -0
  87. nat/embedder/nim_embedder.py +1 -1
  88. nat/embedder/openai_embedder.py +1 -1
  89. nat/embedder/register.py +0 -1
  90. nat/eval/config.py +3 -1
  91. nat/eval/dataset_handler/dataset_handler.py +71 -7
  92. nat/eval/evaluate.py +86 -31
  93. nat/eval/evaluator/base_evaluator.py +1 -1
  94. nat/eval/evaluator/evaluator_model.py +13 -0
  95. nat/eval/intermediate_step_adapter.py +1 -1
  96. nat/eval/rag_evaluator/evaluate.py +2 -2
  97. nat/eval/rag_evaluator/register.py +3 -3
  98. nat/eval/register.py +4 -1
  99. nat/eval/remote_workflow.py +3 -3
  100. nat/eval/runtime_evaluator/__init__.py +14 -0
  101. nat/eval/runtime_evaluator/evaluate.py +123 -0
  102. nat/eval/runtime_evaluator/register.py +100 -0
  103. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  104. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  105. nat/eval/trajectory_evaluator/register.py +1 -1
  106. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  107. nat/eval/utils/eval_trace_ctx.py +89 -0
  108. nat/eval/utils/weave_eval.py +18 -9
  109. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  110. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  112. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  113. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  114. nat/experimental/test_time_compute/register.py +0 -1
  115. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  116. nat/front_ends/console/authentication_flow_handler.py +82 -30
  117. nat/front_ends/console/console_front_end_plugin.py +8 -5
  118. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  119. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  120. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  121. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  123. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
  124. nat/front_ends/fastapi/job_store.py +518 -99
  125. nat/front_ends/fastapi/main.py +11 -19
  126. nat/front_ends/fastapi/message_handler.py +13 -14
  127. nat/front_ends/fastapi/message_validator.py +19 -19
  128. nat/front_ends/fastapi/response_helpers.py +4 -4
  129. nat/front_ends/fastapi/step_adaptor.py +2 -2
  130. nat/front_ends/fastapi/utils.py +57 -0
  131. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  132. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  133. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  134. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  135. nat/front_ends/mcp/tool_converter.py +44 -14
  136. nat/front_ends/register.py +0 -1
  137. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  138. nat/llm/aws_bedrock_llm.py +24 -12
  139. nat/llm/azure_openai_llm.py +13 -6
  140. nat/llm/litellm_llm.py +69 -0
  141. nat/llm/nim_llm.py +20 -8
  142. nat/llm/openai_llm.py +14 -6
  143. nat/llm/register.py +4 -1
  144. nat/llm/utils/env_config_value.py +2 -3
  145. nat/llm/utils/thinking.py +215 -0
  146. nat/meta/pypi.md +9 -9
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +35 -15
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/forecasting/models/linear_model.py +1 -1
  178. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  179. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  180. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  181. nat/profiler/inference_optimization/data_models.py +3 -3
  182. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  183. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  184. nat/profiler/parameter_optimization/__init__.py +0 -0
  185. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  186. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  187. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  188. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  189. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  190. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  191. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  192. nat/profiler/profile_runner.py +14 -9
  193. nat/profiler/utils.py +4 -2
  194. nat/registry_handlers/local/local_handler.py +2 -2
  195. nat/registry_handlers/package_utils.py +1 -2
  196. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  197. nat/registry_handlers/register.py +3 -4
  198. nat/registry_handlers/rest/rest_handler.py +12 -13
  199. nat/retriever/milvus/retriever.py +2 -2
  200. nat/retriever/nemo_retriever/retriever.py +1 -1
  201. nat/retriever/register.py +0 -1
  202. nat/runtime/loader.py +2 -2
  203. nat/runtime/runner.py +106 -8
  204. nat/runtime/session.py +69 -8
  205. nat/settings/global_settings.py +16 -5
  206. nat/tool/chat_completion.py +5 -2
  207. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  208. nat/tool/datetime_tools.py +49 -9
  209. nat/tool/document_search.py +2 -2
  210. nat/tool/github_tools.py +450 -0
  211. nat/tool/memory_tools/get_memory_tool.py +1 -1
  212. nat/tool/nvidia_rag.py +1 -1
  213. nat/tool/register.py +2 -9
  214. nat/tool/retriever.py +3 -2
  215. nat/utils/callable_utils.py +70 -0
  216. nat/utils/data_models/schema_validator.py +3 -3
  217. nat/utils/decorators.py +210 -0
  218. nat/utils/exception_handlers/automatic_retries.py +104 -51
  219. nat/utils/exception_handlers/schemas.py +1 -1
  220. nat/utils/io/yaml_tools.py +2 -2
  221. nat/utils/log_levels.py +25 -0
  222. nat/utils/reactive/base/observable_base.py +2 -2
  223. nat/utils/reactive/base/observer_base.py +1 -1
  224. nat/utils/reactive/observable.py +2 -2
  225. nat/utils/reactive/observer.py +4 -4
  226. nat/utils/reactive/subscription.py +1 -1
  227. nat/utils/settings/global_settings.py +6 -8
  228. nat/utils/type_converter.py +4 -3
  229. nat/utils/type_utils.py +9 -5
  230. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
  231. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
  232. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
  233. nat/cli/commands/info/list_mcp.py +0 -304
  234. nat/tool/github_tools/create_github_commit.py +0 -133
  235. nat/tool/github_tools/create_github_issue.py +0 -87
  236. nat/tool/github_tools/create_github_pr.py +0 -106
  237. nat/tool/github_tools/get_github_file.py +0 -106
  238. nat/tool/github_tools/get_github_issue.py +0 -166
  239. nat/tool/github_tools/get_github_pr.py +0 -256
  240. nat/tool/github_tools/update_github_issue.py +0 -100
  241. nat/tool/mcp/exceptions.py +0 -142
  242. nat/tool/mcp/mcp_client.py +0 -255
  243. nat/tool/mcp/mcp_tool.py +0 -96
  244. nat/utils/exception_handlers/mcp.py +0 -211
  245. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  246. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  247. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
  248. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  249. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  250. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -82,7 +82,7 @@ async def register_ttc_tool_orchestration_function(
82
82
  function_map = {}
83
83
  for fn_ref in config.augmented_fns:
84
84
  # Retrieve the actual function from the builder
85
- fn_obj = builder.get_function(fn_ref)
85
+ fn_obj = await builder.get_function(fn_ref)
86
86
  function_map[fn_ref] = fn_obj
87
87
 
88
88
  # 2) Instantiate search, editing, scoring, selection strategies (if any)
@@ -148,13 +148,13 @@ async def register_ttc_tool_orchestration_function(
148
148
  result = await fn.acall_invoke(item.output)
149
149
  return item, result, None
150
150
  except Exception as e:
151
- logger.error(f"Error invoking function '{item.name}': {e}")
151
+ logger.exception(f"Error invoking function '{item.name}': {e}")
152
152
  return item, None, str(e)
153
153
 
154
154
  tasks = []
155
155
  for item in ttc_items:
156
156
  if item.name not in function_map:
157
- logger.error(f"Function '{item.name}' not found in function map.")
157
+ logger.error(f"Function '{item.name}' not found in function map.", exc_info=True)
158
158
  item.output = f"Error: Function '{item.name}' not found in function map. Check your input"
159
159
  else:
160
160
  fn = function_map[item.name]
@@ -80,7 +80,7 @@ async def register_ttc_tool_wrapper_function(
80
80
  raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
81
81
  "This error can be resolved by installing nvidia-nat-langchain.")
82
82
 
83
- augmented_function: Function = builder.get_function(config.augmented_fn)
83
+ augmented_function: Function = await builder.get_function(config.augmented_fn)
84
84
  input_llm: BaseChatModel = await builder.get_llm(config.input_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
85
85
 
86
86
  if not augmented_function.has_single_output:
@@ -17,9 +17,10 @@ from abc import ABC
17
17
  from abc import abstractmethod
18
18
 
19
19
  from nat.builder.builder import Builder
20
- from nat.experimental.test_time_compute.models.ttc_item import TTCItem
21
- from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum, PipelineTypeEnum
22
20
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
21
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
22
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
23
+ from nat.experimental.test_time_compute.models.ttc_item import TTCItem
23
24
 
24
25
 
25
26
  class StrategyBase(ABC):
@@ -45,11 +46,11 @@ class StrategyBase(ABC):
45
46
  items: list[TTCItem],
46
47
  original_prompt: str | None = None,
47
48
  agent_context: str | None = None,
48
- **kwargs) -> [TTCItem]:
49
+ **kwargs) -> list[TTCItem]:
49
50
  pass
50
51
 
51
52
  @abstractmethod
52
- def supported_pipeline_types(self) -> [PipelineTypeEnum]:
53
+ def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
53
54
  """Return the stage types supported by this selector."""
54
55
  pass
55
56
 
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
 
19
18
  from .editing import iterative_plan_refinement_editor
@@ -71,7 +71,7 @@ class LLMBasedOutputMergingSelector(StrategyBase):
71
71
  raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
72
72
  "This error can be resolved by installing nvidia-nat-langchain.")
73
73
 
74
- from typing import Callable
74
+ from collections.abc import Callable
75
75
 
76
76
  from pydantic import BaseModel
77
77
 
@@ -135,8 +135,6 @@ class LLMBasedOutputMergingSelector(StrategyBase):
135
135
  except Exception as e:
136
136
  logger.error(f"Error parsing merged output: {e}")
137
137
  raise ValueError("Failed to parse merged output.")
138
- else:
139
- merged_output = merged_output
140
138
 
141
139
  logger.info("Merged output: %s", str(merged_output))
142
140
 
@@ -14,13 +14,16 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import asyncio
17
+ import logging
17
18
  import secrets
18
19
  import webbrowser
19
20
  from dataclasses import dataclass
20
21
  from dataclasses import field
21
22
 
22
23
  import click
24
+ import httpx
23
25
  import pkce
26
+ from authlib.common.errors import AuthlibBaseError as OAuthError
24
27
  from authlib.integrations.httpx_client import AsyncOAuth2Client
25
28
  from fastapi import FastAPI
26
29
  from fastapi import Request
@@ -32,6 +35,8 @@ from nat.data_models.authentication import AuthFlowType
32
35
  from nat.data_models.authentication import AuthProviderBaseConfig
33
36
  from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
34
37
 
38
+ logger = logging.getLogger(__name__)
39
+
35
40
 
36
41
  # --------------------------------------------------------------------------- #
37
42
  # Helpers #
@@ -87,17 +92,53 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
87
92
  """
88
93
  Separated for easy overriding in tests (to inject ASGITransport).
89
94
  """
90
- client = AsyncOAuth2Client(
91
- client_id=cfg.client_id,
92
- client_secret=cfg.client_secret,
93
- redirect_uri=cfg.redirect_uri,
94
- scope=" ".join(cfg.scopes) if cfg.scopes else None,
95
- token_endpoint=cfg.token_url,
96
- token_endpoint_auth_method=cfg.token_endpoint_auth_method,
97
- code_challenge_method="S256" if cfg.use_pkce else None,
98
- )
99
- self._oauth_client = client
100
- return client
95
+ try:
96
+ client = AsyncOAuth2Client(
97
+ client_id=cfg.client_id,
98
+ client_secret=cfg.client_secret,
99
+ redirect_uri=cfg.redirect_uri,
100
+ scope=" ".join(cfg.scopes) if cfg.scopes else None,
101
+ token_endpoint=cfg.token_url,
102
+ token_endpoint_auth_method=cfg.token_endpoint_auth_method,
103
+ code_challenge_method="S256" if cfg.use_pkce else None,
104
+ )
105
+ self._oauth_client = client
106
+ return client
107
+ except (OAuthError, ValueError, TypeError) as e:
108
+ raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
109
+ except Exception as e:
110
+ raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
111
+
112
+ def _create_authorization_url(self,
113
+ client: AsyncOAuth2Client,
114
+ config: OAuth2AuthCodeFlowProviderConfig,
115
+ state: str,
116
+ verifier: str | None = None,
117
+ challenge: str | None = None) -> str:
118
+ """
119
+ Create OAuth authorization URL with proper error handling.
120
+
121
+ Args:
122
+ client: The OAuth2 client instance
123
+ config: OAuth2 configuration
124
+ state: OAuth state parameter
125
+ verifier: PKCE verifier (if using PKCE)
126
+ challenge: PKCE challenge (if using PKCE)
127
+
128
+ Returns:
129
+ The authorization URL
130
+ """
131
+ try:
132
+ auth_url, _ = client.create_authorization_url(
133
+ config.authorization_url,
134
+ state=state,
135
+ code_verifier=verifier if config.use_pkce else None,
136
+ code_challenge=challenge if config.use_pkce else None,
137
+ **(config.authorization_kwargs or {})
138
+ )
139
+ return auth_url
140
+ except (OAuthError, ValueError, TypeError) as e:
141
+ raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
101
142
 
102
143
  # --------------------------- HTTP Basic ------------------------------ #
103
144
  @staticmethod
@@ -131,13 +172,12 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
131
172
  flow_state.verifier = verifier
132
173
  flow_state.challenge = challenge
133
174
 
134
- auth_url, _ = client.create_authorization_url(
135
- cfg.authorization_url,
136
- state=state,
137
- code_verifier=flow_state.verifier if cfg.use_pkce else None,
138
- code_challenge=flow_state.challenge if cfg.use_pkce else None,
139
- **(cfg.authorization_kwargs or {})
140
- )
175
+ # Create authorization URL using helper function
176
+ auth_url = self._create_authorization_url(client=client,
177
+ config=cfg,
178
+ state=state,
179
+ verifier=flow_state.verifier,
180
+ challenge=flow_state.challenge)
141
181
 
142
182
  # Register flow + maybe spin up redirect handler
143
183
  async with self._server_lock:
@@ -149,14 +189,18 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
149
189
  self._flows[state] = flow_state
150
190
  self._active_flows += 1
151
191
 
152
- click.echo("Your browser has been opened for authentication.")
153
- webbrowser.open(auth_url)
192
+ try:
193
+ webbrowser.open(auth_url)
194
+ click.echo("Your browser has been opened for authentication.")
195
+ except Exception as e:
196
+ logger.error("Browser open failed: %s", e)
197
+ raise RuntimeError(f"Browser open failed: {e}") from e
154
198
 
155
199
  # Wait for the redirect to land
156
200
  try:
157
201
  token = await asyncio.wait_for(flow_state.future, timeout=300)
158
- except asyncio.TimeoutError:
159
- raise RuntimeError("Authentication timed out (5 min).")
202
+ except TimeoutError as exc:
203
+ raise RuntimeError("Authentication timed out (5 min).") from exc
160
204
  finally:
161
205
  async with self._server_lock:
162
206
  self._flows.pop(state, None)
@@ -175,9 +219,9 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
175
219
  # --------------- redirect server / in‑process app -------------------- #
176
220
  async def _build_redirect_app(self) -> FastAPI:
177
221
  """
178
- * If cfg.run_redirect_local_server == True → start a uvicorn server (old behaviour).
179
- * Else → only build the FastAPI app and save it to `self._redirect_app`
180
- for in‑process testing with ASGITransport.
222
+ * If cfg.run_redirect_local_server == True → start a local server.
223
+ * Else → only build the redirect app and save it to `self._redirect_app`
224
+ for in‑process testing.
181
225
  """
182
226
  app = FastAPI()
183
227
 
@@ -195,8 +239,16 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
195
239
  state=state,
196
240
  )
197
241
  flow_state.future.set_result(token)
198
- except Exception as exc: # noqa: BLE001
199
- flow_state.future.set_exception(exc)
242
+ except OAuthError as e:
243
+ flow_state.future.set_exception(
244
+ RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
245
+ return "Authentication failed: Authorization server rejected the request. You may close this tab."
246
+ except httpx.HTTPError as e:
247
+ flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
248
+ return "Authentication failed: Network error occurred. You may close this tab."
249
+ except Exception as e:
250
+ flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
251
+ return "Authentication failed: An unexpected error occurred. You may close this tab."
200
252
  return "Authentication successful – you may close this tab."
201
253
 
202
254
  return app
@@ -213,7 +265,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
213
265
 
214
266
  asyncio.create_task(self._server_controller.start_server(host="localhost", port=8000))
215
267
 
216
- # Give uvicorn a moment to bind sockets before we return
268
+ # Give the server a moment to bind sockets before we return
217
269
  await asyncio.sleep(0.3)
218
270
  except Exception as exc: # noqa: BLE001
219
271
  raise RuntimeError(f"Failed to start redirect server: {exc}") from exc
@@ -227,7 +279,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
227
279
  @property
228
280
  def redirect_app(self) -> FastAPI | None:
229
281
  """
230
- In testmode (run_redirect_local_server=False) the in‑memory FastAPI
231
- app is exposed so you can mount it on `httpx.ASGITransport`.
282
+ In test mode (run_redirect_local_server=False) the in‑memory redirect
283
+ app is exposed for testing purposes.
232
284
  """
233
285
  return self._redirect_app
@@ -55,9 +55,10 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
55
55
  self.auth_flow_handler = ConsoleAuthenticationFlowHandler()
56
56
 
57
57
  async def pre_run(self):
58
-
59
- if (not self.front_end_config.input_query and not self.front_end_config.input_file):
60
- raise click.UsageError("Must specify either --input_query or --input_file")
58
+ if (self.front_end_config.input_query is not None and self.front_end_config.input_file is not None):
59
+ raise click.UsageError("Must specify either --input or --input_file, not both")
60
+ if (self.front_end_config.input_query is None and self.front_end_config.input_file is None):
61
+ raise click.UsageError("Must specify either --input or --input_file")
61
62
 
62
63
  async def run_workflow(self, session_manager: SessionManager):
63
64
 
@@ -80,12 +81,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
80
81
  input_list = list(self.front_end_config.input_query)
81
82
  logger.debug("Processing input: %s", self.front_end_config.input_query)
82
83
 
83
- runner_outputs = await asyncio.gather(*[run_single_query(query) for query in input_list])
84
+ # Make `return_exceptions=False` explicit; all exceptions are raised instead of being silenced
85
+ runner_outputs = await asyncio.gather(*[run_single_query(query) for query in input_list],
86
+ return_exceptions=False)
84
87
 
85
88
  elif (self.front_end_config.input_file):
86
89
 
87
90
  # Run the workflow
88
- with open(self.front_end_config.input_file, "r", encoding="utf-8") as f:
91
+ with open(self.front_end_config.input_file, encoding="utf-8") as f:
89
92
 
90
93
  async with session_manager.workflow.run(f) as runner:
91
94
  runner_outputs = await runner.result(to_type=str)
@@ -22,6 +22,7 @@ from dataclasses import dataclass
22
22
  from dataclasses import field
23
23
 
24
24
  import pkce
25
+ from authlib.common.errors import AuthlibBaseError as OAuthError
25
26
  from authlib.integrations.httpx_client import AsyncOAuth2Client
26
27
 
27
28
  from nat.authentication.interfaces import FlowHandlerBase
@@ -61,14 +62,50 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
61
62
 
62
63
  raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")
63
64
 
64
- def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig):
65
- return AsyncOAuth2Client(client_id=config.client_id,
66
- client_secret=config.client_secret,
67
- redirect_uri=config.redirect_uri,
68
- scope=" ".join(config.scopes) if config.scopes else None,
69
- token_endpoint=config.token_url,
70
- code_challenge_method='S256' if config.use_pkce else None,
71
- token_endpoint_auth_method=config.token_endpoint_auth_method)
65
+ def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
66
+ try:
67
+ return AsyncOAuth2Client(client_id=config.client_id,
68
+ client_secret=config.client_secret,
69
+ redirect_uri=config.redirect_uri,
70
+ scope=" ".join(config.scopes) if config.scopes else None,
71
+ token_endpoint=config.token_url,
72
+ code_challenge_method='S256' if config.use_pkce else None,
73
+ token_endpoint_auth_method=config.token_endpoint_auth_method)
74
+ except (OAuthError, ValueError, TypeError) as e:
75
+ raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
76
+ except Exception as e:
77
+ raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
78
+
79
+ def _create_authorization_url(self,
80
+ client: AsyncOAuth2Client,
81
+ config: OAuth2AuthCodeFlowProviderConfig,
82
+ state: str,
83
+ verifier: str = None,
84
+ challenge: str = None) -> str:
85
+ """
86
+ Create OAuth authorization URL with proper error handling.
87
+
88
+ Args:
89
+ client: The OAuth2 client instance
90
+ config: OAuth2 configuration
91
+ state: OAuth state parameter
92
+ verifier: PKCE verifier (if using PKCE)
93
+ challenge: PKCE challenge (if using PKCE)
94
+
95
+ Returns:
96
+ The authorization URL
97
+ """
98
+ try:
99
+ authorization_url, _ = client.create_authorization_url(
100
+ config.authorization_url,
101
+ state=state,
102
+ code_verifier=verifier if config.use_pkce else None,
103
+ code_challenge=challenge if config.use_pkce else None,
104
+ **(config.authorization_kwargs or {})
105
+ )
106
+ return authorization_url
107
+ except (OAuthError, ValueError, TypeError) as e:
108
+ raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
72
109
 
73
110
  async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
74
111
 
@@ -82,21 +119,19 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
82
119
  flow_state.verifier = verifier
83
120
  flow_state.challenge = challenge
84
121
 
85
- authorization_url, _ = flow_state.client.create_authorization_url(
86
- config.authorization_url,
87
- state=state,
88
- code_verifier=flow_state.verifier if config.use_pkce else None,
89
- code_challenge=flow_state.challenge if config.use_pkce else None,
90
- **(config.authorization_kwargs or {})
91
- )
122
+ authorization_url = self._create_authorization_url(client=flow_state.client,
123
+ config=config,
124
+ state=state,
125
+ verifier=flow_state.verifier,
126
+ challenge=flow_state.challenge)
92
127
 
93
128
  await self._add_flow_cb(state, flow_state)
94
129
  await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
95
130
  )
96
131
  try:
97
132
  token = await asyncio.wait_for(flow_state.future, timeout=300)
98
- except asyncio.TimeoutError:
99
- raise RuntimeError("Authentication flow timed out after 5 minutes.")
133
+ except TimeoutError as exc:
134
+ raise RuntimeError("Authentication flow timed out after 5 minutes.") from exc
100
135
  finally:
101
136
 
102
137
  await self._remove_flow_cb(state)
@@ -0,0 +1,65 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+ from abc import ABC
18
+ from collections.abc import AsyncGenerator
19
+ from collections.abc import Generator
20
+ from contextlib import asynccontextmanager
21
+ from contextlib import contextmanager
22
+
23
+ if typing.TYPE_CHECKING:
24
+ from dask.distributed import Client
25
+
26
+
27
class DaskClientMixin(ABC):
    """
    Mixin that hands out Dask clients bound to a given scheduler address.

    ``dask`` is imported inside each method rather than at module import time (the module-level import is only done
    under ``TYPE_CHECKING``), so classes using this mixin remain importable when the optional ``dask`` dependency is
    not installed.
    """

    @asynccontextmanager
    async def client(self, address: str) -> AsyncGenerator["Client", None]:
        """
        Async context manager for obtaining a Dask client.

        Parameters
        ----------
        address : str
            Address of the Dask scheduler to connect to.

        Yields
        ------
        Client
            An asynchronous Dask client connected to the scheduler. The client is closed when the context manager
            exits, including when the managed body raises.
        """
        from dask.distributed import Client as _DaskClient

        # Created before ``try`` so the ``finally`` clause only runs once a client actually exists.
        dask_client = await _DaskClient(address=address, asynchronous=True)
        try:
            yield dask_client
        finally:
            await dask_client.close()

    @contextmanager
    def blocking_client(self, address: str) -> Generator["Client", None, None]:
        """
        Context manager for obtaining a blocking (synchronous) Dask client.

        Parameters
        ----------
        address : str
            Address of the Dask scheduler to connect to.

        Yields
        ------
        Client
            A blocking Dask client connected to the scheduler. The client is closed when the context manager exits,
            including when the managed body raises.
        """
        from dask.distributed import Client as _DaskClient

        # Created before ``try`` so the ``finally`` clause only runs once a client actually exists.
        dask_client = _DaskClient(address=address)
        try:
            yield dask_client
        finally:
            dask_client.close()
@@ -14,6 +14,8 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
+ import os
18
+ import sys
17
19
  import typing
18
20
  from datetime import datetime
19
21
  from pathlib import Path
@@ -31,6 +33,20 @@ logger = logging.getLogger(__name__)
31
33
  YAML_EXTENSIONS = (".yaml", ".yml")
32
34
 
33
35
 
36
+ def _is_reserved(path: Path) -> bool:
37
+ """
38
+ Check if a path is reserved in the current Python version and platform.
39
+
40
+ On Windows, this function checks if the path is reserved in the current Python version.
41
+ On other platforms, returns False
42
+ """
43
+ if sys.platform != "win32":
44
+ return False
45
+ if sys.version_info >= (3, 13):
46
+ return os.path.isreserved(path)
47
+ return path.is_reserved()
48
+
49
+
34
50
  class EvaluateRequest(BaseModel):
35
51
  """Request model for the evaluate endpoint."""
36
52
  config_file: str = Field(description="Path to the configuration file for evaluation")
@@ -51,7 +67,7 @@ class EvaluateRequest(BaseModel):
51
67
  f"Job ID '{job_id}' contains invalid characters. Only alphanumeric characters and underscores are"
52
68
  " allowed.")
53
69
 
54
- if job_id_path.is_reserved():
70
+ if _is_reserved(job_id_path):
55
71
  # reserved names is Windows specific
56
72
  raise ValueError(f"Job ID '{job_id}' is a reserved name. Please choose a different name.")
57
73
 
@@ -68,7 +84,7 @@ class EvaluateRequest(BaseModel):
68
84
  raise ValueError(f"Config file '{config_file}' must be a YAML file with one of the following extensions: "
69
85
  f"{', '.join(YAML_EXTENSIONS)}")
70
86
 
71
- if config_file_path.is_reserved():
87
+ if _is_reserved(config_file_path):
72
88
  # reserved names is Windows specific
73
89
  raise ValueError(f"Config file '{config_file}' is a reserved name. Please choose a different name.")
74
90
 
@@ -181,9 +197,24 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
181
197
  port: int = Field(default=8000, description="Port to bind the server to", ge=0, le=65535)
182
198
  reload: bool = Field(default=False, description="Enable auto-reload for development")
183
199
  workers: int = Field(default=1, description="Number of workers to run", ge=1)
184
- max_running_async_jobs: int = Field(default=10,
185
- description="Maximum number of async jobs to run concurrently",
186
- ge=1)
200
+ scheduler_address: str | None = Field(
201
+ default=None,
202
+ description=("Address of the Dask scheduler to use for async jobs. If None, a Dask local cluster is created. "
203
+ "Note: This requires the optional dask dependency to be installed."))
204
+ db_url: str | None = Field(
205
+ default=None,
206
+ description=
207
+ "SQLAlchemy database URL for storing async job metadata, if unset a temporary SQLite database is used.")
208
+ max_running_async_jobs: int = Field(
209
+ default=10,
210
+ description=(
211
+ "Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
212
+ "This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
213
+ ge=1)
214
+ dask_log_level: str = Field(
215
+ default="WARNING",
216
+ description="Logging level for Dask.",
217
+ )
187
218
  step_adaptor: StepAdaptorConfig = StepAdaptorConfig()
188
219
 
189
220
  workflow: typing.Annotated[EndpointBase, Field(description="Endpoint for the default workflow.")] = EndpointBase(
@@ -47,11 +47,11 @@ class _FastApiFrontEndController:
47
47
  self._server_background_task = asyncio.create_task(self._server.serve())
48
48
  except asyncio.CancelledError as e:
49
49
  error_message = f"Task error occurred while starting API server: {str(e)}"
50
- logger.error(error_message, exc_info=True)
50
+ logger.error(error_message)
51
51
  raise RuntimeError(error_message) from e
52
52
  except Exception as e:
53
53
  error_message = f"Unexpected error occurred while starting API server: {str(e)}"
54
- logger.error(error_message, exc_info=True)
54
+ logger.exception(error_message)
55
55
  raise RuntimeError(error_message) from e
56
56
 
57
57
  async def stop_server(self) -> None:
@@ -63,6 +63,6 @@ class _FastApiFrontEndController:
63
63
  self._server.should_exit = True
64
64
  await self._server_background_task
65
65
  except asyncio.CancelledError as e:
66
- logger.error("Server shutdown failed: %s", str(e), exc_info=True)
66
+ logger.exception("Server shutdown failed: %s", str(e))
67
67
  except Exception as e:
68
- logger.error("Unexpected error occurred: %s", str(e), exc_info=True)
68
+ logger.exception("Unexpected error occurred: %s", str(e))