nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,100 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import Field
17
+
18
+ from nat.builder.builder import EvalBuilder
19
+ from nat.builder.evaluator import EvaluatorInfo
20
+ from nat.cli.register_workflow import register_evaluator
21
+ from nat.data_models.evaluator import EvaluatorBaseConfig
22
+ from nat.eval.evaluator.evaluator_model import EvalInput
23
+ from nat.eval.evaluator.evaluator_model import EvalOutput
24
+
25
+
26
+ class AverageLLMLatencyConfig(EvaluatorBaseConfig, name="avg_llm_latency"):
27
+ """Mean difference between connected LLM_START and LLM_END events (same UUID)."""
28
+
29
+ max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
30
+
31
+
32
+ class AverageWorkflowRuntimeConfig(EvaluatorBaseConfig, name="avg_workflow_runtime"):
33
+ """Average workflow runtime per item (max timestamp - min timestamp)."""
34
+
35
+ max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
36
+
37
+
38
+ class AverageNumberOfLLMCallsConfig(EvaluatorBaseConfig, name="avg_num_llm_calls"):
39
+ """Average number of LLM calls per item (count of LLM_END)."""
40
+
41
+ max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
42
+
43
+
44
+ class AverageTokensPerLLMEndConfig(EvaluatorBaseConfig, name="avg_tokens_per_llm_end"):
45
+ """Average total tokens per LLM_END event (prompt + completion if available)."""
46
+
47
+ max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
48
+
49
+
50
+ @register_evaluator(config_type=AverageLLMLatencyConfig)
51
+ async def register_avg_llm_latency_evaluator(config: AverageLLMLatencyConfig, builder: EvalBuilder):
52
+ from .evaluate import AverageLLMLatencyEvaluator
53
+
54
+ evaluator = AverageLLMLatencyEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
55
+
56
+ async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
57
+ return await evaluator.evaluate(eval_input)
58
+
59
+ yield EvaluatorInfo(config=config,
60
+ evaluate_fn=evaluate_fn,
61
+ description="Average LLM latency (s) from LLM_START to LLM_END")
62
+
63
+
64
+ @register_evaluator(config_type=AverageWorkflowRuntimeConfig)
65
+ async def register_avg_workflow_runtime_evaluator(config: AverageWorkflowRuntimeConfig, builder: EvalBuilder):
66
+ from .evaluate import AverageWorkflowRuntimeEvaluator
67
+
68
+ evaluator = AverageWorkflowRuntimeEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
69
+
70
+ async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
71
+ return await evaluator.evaluate(eval_input)
72
+
73
+ yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Average workflow runtime (s)")
74
+
75
+
76
+ @register_evaluator(config_type=AverageNumberOfLLMCallsConfig)
77
+ async def register_avg_num_llm_calls_evaluator(config: AverageNumberOfLLMCallsConfig, builder: EvalBuilder):
78
+ from .evaluate import AverageNumberOfLLMCallsEvaluator
79
+
80
+ evaluator = AverageNumberOfLLMCallsEvaluator(
81
+ max_concurrency=config.max_concurrency or builder.get_max_concurrency())
82
+
83
+ async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
84
+ return await evaluator.evaluate(eval_input)
85
+
86
+ yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Average number of LLM calls")
87
+
88
+
89
+ @register_evaluator(config_type=AverageTokensPerLLMEndConfig)
90
+ async def register_avg_tokens_per_llm_end_evaluator(config: AverageTokensPerLLMEndConfig, builder: EvalBuilder):
91
+ from .evaluate import AverageTokensPerLLMEndEvaluator
92
+
93
+ evaluator = AverageTokensPerLLMEndEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
94
+
95
+ async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
96
+ return await evaluator.evaluate(eval_input)
97
+
98
+ yield EvaluatorInfo(config=config,
99
+ evaluate_fn=evaluate_fn,
100
+ description="Average total tokens per LLM_END (prompt + completion)")
@@ -204,7 +204,7 @@ class SweBenchEvaluator:
204
204
  # if report file is not present, return empty EvalOutput
205
205
  avg_score = 0.0
206
206
  if report_file.exists():
207
- with open(report_file, "r", encoding="utf-8") as f:
207
+ with open(report_file, encoding="utf-8") as f:
208
208
  report = json.load(f)
209
209
  resolved_instances = report.get("resolved_instances", 0)
210
210
  total_instances = report.get("total_instances", 0)
@@ -33,7 +33,7 @@ async def register_trajectory_evaluator(config: TrajectoryEvaluatorConfig, build
33
33
 
34
34
  from .evaluate import TrajectoryEvaluator
35
35
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
36
- tools = builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)
36
+ tools = await builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)
37
37
 
38
38
  _evaluator = TrajectoryEvaluator(llm, tools, builder.get_max_concurrency())
39
39
 
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
- from typing import Callable
17
+ from collections.abc import Callable
18
18
 
19
19
  from langchain.output_parsers import ResponseSchema
20
20
  from langchain.output_parsers import StructuredOutputParser
nat/eval/usage_stats.py CHANGED
@@ -21,6 +21,8 @@ from pydantic import BaseModel
21
21
  class UsageStatsLLM(BaseModel):
22
22
  prompt_tokens: int = 0
23
23
  completion_tokens: int = 0
24
+ cached_tokens: int = 0
25
+ reasoning_tokens: int = 0
24
26
  total_tokens: int = 0
25
27
 
26
28
 
@@ -24,6 +24,7 @@ import aioboto3
24
24
  from botocore.exceptions import NoCredentialsError
25
25
  from tqdm import tqdm
26
26
 
27
+ from nat.data_models.common import get_secret_value
27
28
  from nat.data_models.evaluate import EvalOutputConfig
28
29
 
29
30
  logger = logging.getLogger(__name__)
@@ -90,8 +91,8 @@ class OutputUploader:
90
91
  "s3",
91
92
  endpoint_url=endpoint_url,
92
93
  region_name=region_name,
93
- aws_access_key_id=self.s3_config.access_key,
94
- aws_secret_access_key=self.s3_config.secret_key,
94
+ aws_access_key_id=get_secret_value(self.s3_config.access_key),
95
+ aws_secret_access_key=get_secret_value(self.s3_config.secret_key),
95
96
  ) as s3_client:
96
97
  with tqdm(total=len(file_entries), desc="Uploading files to S3") as pbar:
97
98
  upload_tasks = [
@@ -82,7 +82,7 @@ class WeaveEvaluationIntegration:
82
82
  """Get the full dataset for Weave."""
83
83
  return [item.full_dataset_entry for item in eval_input.eval_input_items]
84
84
 
85
- def initialize_logger(self, workflow_alias: str, eval_input: EvalInput, config: Any):
85
+ def initialize_logger(self, workflow_alias: str, eval_input: EvalInput, config: Any, job_id: str | None = None):
86
86
  """Initialize the Weave evaluation logger."""
87
87
  if not self.client and not self.initialize_client():
88
88
  # lazy init the client
@@ -92,10 +92,16 @@ class WeaveEvaluationIntegration:
92
92
  weave_dataset = self._get_weave_dataset(eval_input)
93
93
  config_dict = config.model_dump(mode="json")
94
94
  config_dict["name"] = workflow_alias
95
+
96
+ # Include job_id in eval_attributes if provided
97
+ eval_attributes = {}
98
+ if job_id:
99
+ eval_attributes["job_id"] = job_id
100
+
95
101
  self.eval_logger = self.evaluation_logger_cls(model=config_dict,
96
102
  dataset=weave_dataset,
97
103
  name=workflow_alias,
98
- eval_attributes={})
104
+ eval_attributes=eval_attributes)
99
105
  self.pred_loggers = {}
100
106
 
101
107
  # Capture the current evaluation call for context propagation
@@ -136,9 +142,17 @@ class WeaveEvaluationIntegration:
136
142
  coros = []
137
143
  for eval_output_item in eval_output.eval_output_items:
138
144
  if eval_output_item.id in self.pred_loggers:
145
+ # Structure the score as a dict and include reasoning if available
146
+ score_value = {
147
+ "score": eval_output_item.score,
148
+ }
149
+
150
+ if eval_output_item.reasoning is not None:
151
+ score_value["reasoning"] = eval_output_item.reasoning
152
+
139
153
  coros.append(self.pred_loggers[eval_output_item.id].alog_score(
140
154
  scorer=evaluator_name,
141
- score=eval_output_item.score,
155
+ score=score_value,
142
156
  ))
143
157
 
144
158
  # Execute all coroutines concurrently
@@ -16,7 +16,12 @@
16
16
  import functools
17
17
  import inspect
18
18
  import logging
19
+ from collections.abc import AsyncGenerator
20
+ from collections.abc import Callable
21
+ from collections.abc import Generator
19
22
  from typing import Any
23
+ from typing import TypeVar
24
+ from typing import overload
20
25
 
21
26
  logger = logging.getLogger(__name__)
22
27
 
@@ -25,6 +30,9 @@ BASE_WARNING_MESSAGE = ("is experimental and the API may change in future releas
25
30
 
26
31
  _warning_issued = set()
27
32
 
33
+ # Type variables for overloads
34
+ F = TypeVar('F', bound=Callable[..., Any])
35
+
28
36
 
29
37
  def issue_experimental_warning(function_name: str,
30
38
  feature_name: str | None = None,
@@ -53,7 +61,20 @@ def issue_experimental_warning(function_name: str,
53
61
  _warning_issued.add(function_name)
54
62
 
55
63
 
56
- def experimental(func: Any = None, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None):
64
+ # Overloads for different function types
65
+ @overload
66
+ def experimental(func: F, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
67
+ """Overload for when a function is passed directly."""
68
+ ...
69
+
70
+
71
+ @overload
72
+ def experimental(*, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
73
+ """Overload for decorator factory usage (when called with parentheses)."""
74
+ ...
75
+
76
+
77
+ def experimental(func: Any = None, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Any:
57
78
  """
58
79
  Decorator that can wrap any type of function (sync, async, generator,
59
80
  async generator) and logs a warning that the function is experimental.
@@ -90,7 +111,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
90
111
  # ---------------------
91
112
 
92
113
  @functools.wraps(func)
93
- async def async_gen_wrapper(*args, **kwargs):
114
+ async def async_gen_wrapper(*args, **kwargs) -> AsyncGenerator[Any, Any]:
94
115
  issue_experimental_warning(function_name, feature_name, metadata)
95
116
  async for item in func(*args, **kwargs):
96
117
  yield item # yield the original item
@@ -102,7 +123,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
102
123
  # ASYNC FUNCTION
103
124
  # ---------------------
104
125
  @functools.wraps(func)
105
- async def async_wrapper(*args, **kwargs):
126
+ async def async_wrapper(*args, **kwargs) -> Any:
106
127
  issue_experimental_warning(function_name, feature_name, metadata)
107
128
  result = await func(*args, **kwargs)
108
129
  return result
@@ -114,15 +135,14 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
114
135
  # SYNC GENERATOR
115
136
  # ---------------------
116
137
  @functools.wraps(func)
117
- def sync_gen_wrapper(*args, **kwargs):
138
+ def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
118
139
  issue_experimental_warning(function_name, feature_name, metadata)
119
- for item in func(*args, **kwargs):
120
- yield item # yield the original item
140
+ yield from func(*args, **kwargs) # yield the original item
121
141
 
122
142
  return sync_gen_wrapper
123
143
 
124
144
  @functools.wraps(func)
125
- def sync_wrapper(*args, **kwargs):
145
+ def sync_wrapper(*args, **kwargs) -> Any:
126
146
  issue_experimental_warning(function_name, feature_name, metadata)
127
147
  result = func(*args, **kwargs)
128
148
  return result
@@ -46,7 +46,7 @@ async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig
46
46
 
47
47
  from pydantic import BaseModel
48
48
 
49
- executable_fn: Function = builder.get_function(name=config.augmented_fn)
49
+ executable_fn: Function = await builder.get_function(name=config.augmented_fn)
50
50
 
51
51
  if config.scorer:
52
52
  scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
@@ -86,7 +86,7 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
86
86
  "This error can be resolved by installing nvidia-nat-langchain.")
87
87
 
88
88
  # Get the augmented function's description
89
- augmented_function = builder.get_function(config.augmented_fn)
89
+ augmented_function = await builder.get_function(config.augmented_fn)
90
90
 
91
91
  # For now, we rely on runtime checking for type conversion
92
92
 
@@ -97,11 +97,15 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
97
97
  f"function without a description.")
98
98
 
99
99
  # Get the function dependencies of the augmented function
100
- function_used_tools = builder.get_function_dependencies(config.augmented_fn).functions
100
+ function_dependencies = builder.get_function_dependencies(config.augmented_fn)
101
+ function_used_tools = set(function_dependencies.functions)
102
+ for function_group in function_dependencies.function_groups:
103
+ function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
104
+
101
105
  tool_list = "Tool: Description\n"
102
106
 
103
107
  for tool in function_used_tools:
104
- tool_impl = builder.get_function(tool)
108
+ tool_impl = await builder.get_function(tool)
105
109
  tool_list += f"- {tool}: {tool_impl.description if hasattr(tool_impl, 'description') else ''}\n"
106
110
 
107
111
  # Draft the reasoning prompt for the augmented function
@@ -82,7 +82,7 @@ async def register_ttc_tool_orchestration_function(
82
82
  function_map = {}
83
83
  for fn_ref in config.augmented_fns:
84
84
  # Retrieve the actual function from the builder
85
- fn_obj = builder.get_function(fn_ref)
85
+ fn_obj = await builder.get_function(fn_ref)
86
86
  function_map[fn_ref] = fn_obj
87
87
 
88
88
  # 2) Instantiate search, editing, scoring, selection strategies (if any)
@@ -80,7 +80,7 @@ async def register_ttc_tool_wrapper_function(
80
80
  raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
81
81
  "This error can be resolved by installing nvidia-nat-langchain.")
82
82
 
83
- augmented_function: Function = builder.get_function(config.augmented_fn)
83
+ augmented_function: Function = await builder.get_function(config.augmented_fn)
84
84
  input_llm: BaseChatModel = await builder.get_llm(config.input_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
85
85
 
86
86
  if not augmented_function.has_single_output:
@@ -98,8 +98,8 @@ async def register_ttc_tool_wrapper_function(
98
98
 
99
99
  augmented_function_desc = config.tool_description
100
100
 
101
- fn_input_schema: BaseModel = augmented_function.input_schema
102
- fn_output_schema: BaseModel = augmented_function.single_output_schema
101
+ fn_input_schema: type[BaseModel] = augmented_function.input_schema
102
+ fn_output_schema: type[BaseModel] | type[None] = augmented_function.single_output_schema
103
103
 
104
104
  runnable_llm = input_llm.with_structured_output(schema=fn_input_schema)
105
105
 
@@ -46,11 +46,11 @@ class StrategyBase(ABC):
46
46
  items: list[TTCItem],
47
47
  original_prompt: str | None = None,
48
48
  agent_context: str | None = None,
49
- **kwargs) -> [TTCItem]:
49
+ **kwargs) -> list[TTCItem]:
50
50
  pass
51
51
 
52
52
  @abstractmethod
53
- def supported_pipeline_types(self) -> [PipelineTypeEnum]:
53
+ def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
54
54
  """Return the stage types supported by this selector."""
55
55
  pass
56
56
 
@@ -71,7 +71,7 @@ class LLMBasedOutputMergingSelector(StrategyBase):
71
71
  raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
72
72
  "This error can be resolved by installing nvidia-nat-langchain.")
73
73
 
74
- from typing import Callable
74
+ from collections.abc import Callable
75
75
 
76
76
  from pydantic import BaseModel
77
77
 
@@ -14,13 +14,16 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import asyncio
17
+ import logging
17
18
  import secrets
18
19
  import webbrowser
19
20
  from dataclasses import dataclass
20
21
  from dataclasses import field
21
22
 
22
23
  import click
24
+ import httpx
23
25
  import pkce
26
+ from authlib.common.errors import AuthlibBaseError as OAuthError
24
27
  from authlib.integrations.httpx_client import AsyncOAuth2Client
25
28
  from fastapi import FastAPI
26
29
  from fastapi import Request
@@ -32,6 +35,8 @@ from nat.data_models.authentication import AuthFlowType
32
35
  from nat.data_models.authentication import AuthProviderBaseConfig
33
36
  from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
34
37
 
38
+ logger = logging.getLogger(__name__)
39
+
35
40
 
36
41
  # --------------------------------------------------------------------------- #
37
42
  # Helpers #
@@ -87,17 +92,53 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
87
92
  """
88
93
  Separated for easy overriding in tests (to inject ASGITransport).
89
94
  """
90
- client = AsyncOAuth2Client(
91
- client_id=cfg.client_id,
92
- client_secret=cfg.client_secret,
93
- redirect_uri=cfg.redirect_uri,
94
- scope=" ".join(cfg.scopes) if cfg.scopes else None,
95
- token_endpoint=cfg.token_url,
96
- token_endpoint_auth_method=cfg.token_endpoint_auth_method,
97
- code_challenge_method="S256" if cfg.use_pkce else None,
98
- )
99
- self._oauth_client = client
100
- return client
95
+ try:
96
+ client = AsyncOAuth2Client(
97
+ client_id=cfg.client_id,
98
+ client_secret=cfg.client_secret.get_secret_value(),
99
+ redirect_uri=cfg.redirect_uri,
100
+ scope=" ".join(cfg.scopes) if cfg.scopes else None,
101
+ token_endpoint=cfg.token_url,
102
+ token_endpoint_auth_method=cfg.token_endpoint_auth_method,
103
+ code_challenge_method="S256" if cfg.use_pkce else None,
104
+ )
105
+ self._oauth_client = client
106
+ return client
107
+ except (OAuthError, ValueError, TypeError) as e:
108
+ raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
109
+ except Exception as e:
110
+ raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
111
+
112
+ def _create_authorization_url(self,
113
+ client: AsyncOAuth2Client,
114
+ config: OAuth2AuthCodeFlowProviderConfig,
115
+ state: str,
116
+ verifier: str | None = None,
117
+ challenge: str | None = None) -> str:
118
+ """
119
+ Create OAuth authorization URL with proper error handling.
120
+
121
+ Args:
122
+ client: The OAuth2 client instance
123
+ config: OAuth2 configuration
124
+ state: OAuth state parameter
125
+ verifier: PKCE verifier (if using PKCE)
126
+ challenge: PKCE challenge (if using PKCE)
127
+
128
+ Returns:
129
+ The authorization URL
130
+ """
131
+ try:
132
+ auth_url, _ = client.create_authorization_url(
133
+ config.authorization_url,
134
+ state=state,
135
+ code_verifier=verifier if config.use_pkce else None,
136
+ code_challenge=challenge if config.use_pkce else None,
137
+ **(config.authorization_kwargs or {})
138
+ )
139
+ return auth_url
140
+ except (OAuthError, ValueError, TypeError) as e:
141
+ raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
101
142
 
102
143
  # --------------------------- HTTP Basic ------------------------------ #
103
144
  @staticmethod
@@ -131,13 +172,12 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
131
172
  flow_state.verifier = verifier
132
173
  flow_state.challenge = challenge
133
174
 
134
- auth_url, _ = client.create_authorization_url(
135
- cfg.authorization_url,
136
- state=state,
137
- code_verifier=flow_state.verifier if cfg.use_pkce else None,
138
- code_challenge=flow_state.challenge if cfg.use_pkce else None,
139
- **(cfg.authorization_kwargs or {})
140
- )
175
+ # Create authorization URL using helper function
176
+ auth_url = self._create_authorization_url(client=client,
177
+ config=cfg,
178
+ state=state,
179
+ verifier=flow_state.verifier,
180
+ challenge=flow_state.challenge)
141
181
 
142
182
  # Register flow + maybe spin up redirect handler
143
183
  async with self._server_lock:
@@ -149,14 +189,18 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
149
189
  self._flows[state] = flow_state
150
190
  self._active_flows += 1
151
191
 
152
- click.echo("Your browser has been opened for authentication.")
153
- webbrowser.open(auth_url)
192
+ try:
193
+ webbrowser.open(auth_url)
194
+ click.echo("Your browser has been opened for authentication.")
195
+ except Exception as e:
196
+ logger.error("Browser open failed: %s", e)
197
+ raise RuntimeError(f"Browser open failed: {e}") from e
154
198
 
155
199
  # Wait for the redirect to land
156
200
  try:
157
201
  token = await asyncio.wait_for(flow_state.future, timeout=300)
158
- except asyncio.TimeoutError:
159
- raise RuntimeError("Authentication timed out (5 min).")
202
+ except TimeoutError as exc:
203
+ raise RuntimeError("Authentication timed out (5 min).") from exc
160
204
  finally:
161
205
  async with self._server_lock:
162
206
  self._flows.pop(state, None)
@@ -175,9 +219,9 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
175
219
  # --------------- redirect server / in‑process app -------------------- #
176
220
  async def _build_redirect_app(self) -> FastAPI:
177
221
  """
178
- * If cfg.run_redirect_local_server == True → start a uvicorn server (old behaviour).
179
- * Else → only build the FastAPI app and save it to `self._redirect_app`
180
- for in‑process testing with ASGITransport.
222
+ * If cfg.run_redirect_local_server == True → start a local server.
223
+ * Else → only build the redirect app and save it to `self._redirect_app`
224
+ for in‑process testing.
181
225
  """
182
226
  app = FastAPI()
183
227
 
@@ -195,8 +239,16 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
195
239
  state=state,
196
240
  )
197
241
  flow_state.future.set_result(token)
198
- except Exception as exc: # noqa: BLE001
199
- flow_state.future.set_exception(exc)
242
+ except OAuthError as e:
243
+ flow_state.future.set_exception(
244
+ RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
245
+ return "Authentication failed: Authorization server rejected the request. You may close this tab."
246
+ except httpx.HTTPError as e:
247
+ flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
248
+ return "Authentication failed: Network error occurred. You may close this tab."
249
+ except Exception as e:
250
+ flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
251
+ return "Authentication failed: An unexpected error occurred. You may close this tab."
200
252
  return "Authentication successful – you may close this tab."
201
253
 
202
254
  return app
@@ -213,7 +265,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
213
265
 
214
266
  asyncio.create_task(self._server_controller.start_server(host="localhost", port=8000))
215
267
 
216
- # Give uvicorn a moment to bind sockets before we return
268
+ # Give the server a moment to bind sockets before we return
217
269
  await asyncio.sleep(0.3)
218
270
  except Exception as exc: # noqa: BLE001
219
271
  raise RuntimeError(f"Failed to start redirect server: {exc}") from exc
@@ -227,7 +279,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
227
279
  @property
228
280
  def redirect_app(self) -> FastAPI | None:
229
281
  """
230
- In testmode (run_redirect_local_server=False) the in‑memory FastAPI
231
- app is exposed so you can mount it on `httpx.ASGITransport`.
282
+ In test mode (run_redirect_local_server=False) the in‑memory redirect
283
+ app is exposed for testing purposes.
232
284
  """
233
285
  return self._redirect_app
@@ -55,9 +55,10 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
55
55
  self.auth_flow_handler = ConsoleAuthenticationFlowHandler()
56
56
 
57
57
  async def pre_run(self):
58
-
59
- if (not self.front_end_config.input_query and not self.front_end_config.input_file):
60
- raise click.UsageError("Must specify either --input_query or --input_file")
58
+ if (self.front_end_config.input_query is not None and self.front_end_config.input_file is not None):
59
+ raise click.UsageError("Must specify either --input or --input_file, not both")
60
+ if (self.front_end_config.input_query is None and self.front_end_config.input_file is None):
61
+ raise click.UsageError("Must specify either --input or --input_file")
61
62
 
62
63
  async def run_workflow(self, session_manager: SessionManager):
63
64
 
@@ -80,17 +81,28 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
80
81
  input_list = list(self.front_end_config.input_query)
81
82
  logger.debug("Processing input: %s", self.front_end_config.input_query)
82
83
 
83
- runner_outputs = await asyncio.gather(*[run_single_query(query) for query in input_list])
84
+ # Make `return_exceptions=False` explicit; all exceptions are raised instead of being silenced
85
+ runner_outputs = await asyncio.gather(*[run_single_query(query) for query in input_list],
86
+ return_exceptions=False)
84
87
 
85
88
  elif (self.front_end_config.input_file):
86
89
 
87
90
  # Run the workflow
88
- with open(self.front_end_config.input_file, "r", encoding="utf-8") as f:
91
+ with open(self.front_end_config.input_file, encoding="utf-8") as f:
89
92
 
90
93
  async with session_manager.workflow.run(f) as runner:
91
94
  runner_outputs = await runner.result(to_type=str)
92
95
  else:
93
96
  assert False, "Should not reach here. Should have been caught by pre_run"
94
97
 
95
- # Print result
96
- logger.info(f"\n{'-' * 50}\n{Fore.GREEN}Workflow Result:\n%s{Fore.RESET}\n{'-' * 50}", runner_outputs)
98
+ line = f"{'-' * 50}"
99
+ prefix = f"{line}\n{Fore.GREEN}Workflow Result:\n"
100
+ suffix = f"{Fore.RESET}\n{line}"
101
+
102
+ logger.info(f"{prefix}%s{suffix}", runner_outputs)
103
+
104
+ # (handler is a stream handler) => (level > INFO)
105
+ effective_level_too_high = all(
106
+ type(h) is not logging.StreamHandler or h.level > logging.INFO for h in logging.getLogger().handlers)
107
+ if effective_level_too_high:
108
+ print(f"{prefix}{runner_outputs}{suffix}")
@@ -24,4 +24,4 @@ class HTTPAuthenticationFlowHandler(FlowHandlerBase):
24
24
  async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
25
25
 
26
26
  raise NotImplementedError(f"Authentication method '{method}' is not supported by the HTTP frontend."
27
- f" Do you have Websockets enabled?")
27
+ f" Do you have WebSockets enabled?")