nvidia-nat 1.2.1rc1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +27 -18
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +81 -50
  7. nat/agent/react_agent/register.py +59 -40
  8. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +327 -149
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +64 -46
  13. nat/agent/tool_calling_agent/agent.py +152 -29
  14. nat/agent/tool_calling_agent/register.py +61 -38
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +10 -6
  24. nat/builder/context.py +70 -18
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/intermediate_step_manager.py +6 -2
  32. nat/builder/user_interaction_manager.py +2 -2
  33. nat/builder/workflow.py +13 -1
  34. nat/builder/workflow_builder.py +327 -79
  35. nat/cli/cli_utils/config_override.py +2 -2
  36. nat/cli/commands/evaluate.py +1 -1
  37. nat/cli/commands/info/info.py +16 -6
  38. nat/cli/commands/info/list_channels.py +1 -1
  39. nat/cli/commands/info/list_components.py +7 -8
  40. nat/cli/commands/mcp/__init__.py +14 -0
  41. nat/cli/commands/mcp/mcp.py +986 -0
  42. nat/cli/commands/object_store/__init__.py +14 -0
  43. nat/cli/commands/object_store/object_store.py +227 -0
  44. nat/cli/commands/optimize.py +90 -0
  45. nat/cli/commands/registry/publish.py +2 -2
  46. nat/cli/commands/registry/pull.py +2 -2
  47. nat/cli/commands/registry/remove.py +2 -2
  48. nat/cli/commands/registry/search.py +15 -17
  49. nat/cli/commands/start.py +16 -5
  50. nat/cli/commands/uninstall.py +1 -1
  51. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  52. nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
  53. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  54. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  55. nat/cli/commands/workflow/workflow_commands.py +105 -19
  56. nat/cli/entrypoint.py +17 -11
  57. nat/cli/main.py +3 -0
  58. nat/cli/register_workflow.py +38 -4
  59. nat/cli/type_registry.py +79 -10
  60. nat/control_flow/__init__.py +0 -0
  61. nat/control_flow/register.py +20 -0
  62. nat/control_flow/router_agent/__init__.py +0 -0
  63. nat/control_flow/router_agent/agent.py +329 -0
  64. nat/control_flow/router_agent/prompt.py +48 -0
  65. nat/control_flow/router_agent/register.py +91 -0
  66. nat/control_flow/sequential_executor.py +166 -0
  67. nat/data_models/agent.py +34 -0
  68. nat/data_models/api_server.py +196 -67
  69. nat/data_models/authentication.py +23 -9
  70. nat/data_models/common.py +1 -1
  71. nat/data_models/component.py +2 -0
  72. nat/data_models/component_ref.py +11 -0
  73. nat/data_models/config.py +42 -18
  74. nat/data_models/dataset_handler.py +1 -1
  75. nat/data_models/discovery_metadata.py +4 -4
  76. nat/data_models/evaluate.py +4 -1
  77. nat/data_models/function.py +34 -0
  78. nat/data_models/function_dependencies.py +14 -6
  79. nat/data_models/gated_field_mixin.py +242 -0
  80. nat/data_models/intermediate_step.py +3 -3
  81. nat/data_models/optimizable.py +119 -0
  82. nat/data_models/optimizer.py +149 -0
  83. nat/data_models/span.py +41 -3
  84. nat/data_models/swe_bench_model.py +1 -1
  85. nat/data_models/temperature_mixin.py +44 -0
  86. nat/data_models/thinking_mixin.py +86 -0
  87. nat/data_models/top_p_mixin.py +44 -0
  88. nat/embedder/azure_openai_embedder.py +46 -0
  89. nat/embedder/nim_embedder.py +1 -1
  90. nat/embedder/openai_embedder.py +2 -3
  91. nat/embedder/register.py +1 -1
  92. nat/eval/config.py +3 -1
  93. nat/eval/dataset_handler/dataset_handler.py +71 -7
  94. nat/eval/evaluate.py +86 -31
  95. nat/eval/evaluator/base_evaluator.py +1 -1
  96. nat/eval/evaluator/evaluator_model.py +13 -0
  97. nat/eval/intermediate_step_adapter.py +1 -1
  98. nat/eval/rag_evaluator/evaluate.py +9 -6
  99. nat/eval/rag_evaluator/register.py +3 -3
  100. nat/eval/register.py +4 -1
  101. nat/eval/remote_workflow.py +3 -3
  102. nat/eval/runtime_evaluator/__init__.py +14 -0
  103. nat/eval/runtime_evaluator/evaluate.py +123 -0
  104. nat/eval/runtime_evaluator/register.py +100 -0
  105. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  106. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  107. nat/eval/trajectory_evaluator/register.py +1 -1
  108. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  109. nat/eval/utils/eval_trace_ctx.py +89 -0
  110. nat/eval/utils/weave_eval.py +18 -9
  111. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  112. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  113. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  114. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  115. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  116. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  117. nat/experimental/test_time_compute/register.py +0 -1
  118. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  119. nat/front_ends/console/authentication_flow_handler.py +82 -30
  120. nat/front_ends/console/console_front_end_plugin.py +19 -7
  121. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  122. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  123. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  124. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  125. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  126. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  127. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
  128. nat/front_ends/fastapi/job_store.py +518 -99
  129. nat/front_ends/fastapi/main.py +11 -19
  130. nat/front_ends/fastapi/message_handler.py +74 -50
  131. nat/front_ends/fastapi/message_validator.py +20 -21
  132. nat/front_ends/fastapi/response_helpers.py +4 -4
  133. nat/front_ends/fastapi/step_adaptor.py +2 -2
  134. nat/front_ends/fastapi/utils.py +57 -0
  135. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  136. nat/front_ends/mcp/mcp_front_end_config.py +47 -3
  137. nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
  138. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
  139. nat/front_ends/mcp/tool_converter.py +44 -14
  140. nat/front_ends/register.py +0 -1
  141. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  142. nat/llm/aws_bedrock_llm.py +24 -12
  143. nat/llm/azure_openai_llm.py +57 -0
  144. nat/llm/litellm_llm.py +69 -0
  145. nat/llm/nim_llm.py +20 -8
  146. nat/llm/openai_llm.py +14 -6
  147. nat/llm/register.py +5 -1
  148. nat/llm/utils/env_config_value.py +2 -3
  149. nat/llm/utils/thinking.py +215 -0
  150. nat/meta/pypi.md +9 -9
  151. nat/object_store/register.py +0 -1
  152. nat/observability/exporter/base_exporter.py +3 -3
  153. nat/observability/exporter/file_exporter.py +1 -1
  154. nat/observability/exporter/processing_exporter.py +309 -81
  155. nat/observability/exporter/span_exporter.py +35 -15
  156. nat/observability/exporter_manager.py +7 -7
  157. nat/observability/mixin/file_mixin.py +7 -7
  158. nat/observability/mixin/redaction_config_mixin.py +42 -0
  159. nat/observability/mixin/tagging_config_mixin.py +62 -0
  160. nat/observability/mixin/type_introspection_mixin.py +420 -107
  161. nat/observability/processor/batching_processor.py +5 -7
  162. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  163. nat/observability/processor/processor.py +3 -0
  164. nat/observability/processor/processor_factory.py +70 -0
  165. nat/observability/processor/redaction/__init__.py +24 -0
  166. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  167. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  168. nat/observability/processor/redaction/redaction_processor.py +177 -0
  169. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  170. nat/observability/processor/span_tagging_processor.py +68 -0
  171. nat/observability/register.py +22 -4
  172. nat/profiler/calc/calc_runner.py +3 -4
  173. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  174. nat/profiler/callbacks/langchain_callback_handler.py +14 -7
  175. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  176. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  177. nat/profiler/data_frame_row.py +1 -1
  178. nat/profiler/decorators/framework_wrapper.py +62 -13
  179. nat/profiler/decorators/function_tracking.py +160 -3
  180. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  181. nat/profiler/forecasting/models/linear_model.py +1 -1
  182. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  183. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  184. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  185. nat/profiler/inference_optimization/data_models.py +3 -3
  186. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  187. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  188. nat/profiler/parameter_optimization/__init__.py +0 -0
  189. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  190. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  191. nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
  192. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  193. nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
  194. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  195. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  196. nat/profiler/profile_runner.py +14 -9
  197. nat/profiler/utils.py +4 -2
  198. nat/registry_handlers/local/local_handler.py +2 -2
  199. nat/registry_handlers/package_utils.py +1 -2
  200. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  201. nat/registry_handlers/register.py +3 -4
  202. nat/registry_handlers/rest/rest_handler.py +12 -13
  203. nat/retriever/milvus/retriever.py +2 -2
  204. nat/retriever/nemo_retriever/retriever.py +1 -1
  205. nat/retriever/register.py +0 -1
  206. nat/runtime/loader.py +2 -2
  207. nat/runtime/runner.py +105 -8
  208. nat/runtime/session.py +69 -8
  209. nat/settings/global_settings.py +16 -5
  210. nat/tool/chat_completion.py +5 -2
  211. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  212. nat/tool/datetime_tools.py +49 -9
  213. nat/tool/document_search.py +2 -2
  214. nat/tool/github_tools.py +450 -0
  215. nat/tool/memory_tools/add_memory_tool.py +3 -3
  216. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  217. nat/tool/memory_tools/get_memory_tool.py +4 -4
  218. nat/tool/nvidia_rag.py +1 -1
  219. nat/tool/register.py +2 -9
  220. nat/tool/retriever.py +3 -2
  221. nat/utils/callable_utils.py +70 -0
  222. nat/utils/data_models/schema_validator.py +3 -3
  223. nat/utils/decorators.py +210 -0
  224. nat/utils/exception_handlers/automatic_retries.py +104 -51
  225. nat/utils/exception_handlers/schemas.py +1 -1
  226. nat/utils/io/yaml_tools.py +2 -2
  227. nat/utils/log_levels.py +25 -0
  228. nat/utils/reactive/base/observable_base.py +2 -2
  229. nat/utils/reactive/base/observer_base.py +1 -1
  230. nat/utils/reactive/observable.py +2 -2
  231. nat/utils/reactive/observer.py +4 -4
  232. nat/utils/reactive/subscription.py +1 -1
  233. nat/utils/settings/global_settings.py +6 -8
  234. nat/utils/type_converter.py +12 -3
  235. nat/utils/type_utils.py +9 -5
  236. nvidia_nat-1.3.0.dist-info/METADATA +195 -0
  237. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
  238. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
  239. nat/cli/commands/info/list_mcp.py +0 -304
  240. nat/tool/github_tools/create_github_commit.py +0 -133
  241. nat/tool/github_tools/create_github_issue.py +0 -87
  242. nat/tool/github_tools/create_github_pr.py +0 -106
  243. nat/tool/github_tools/get_github_file.py +0 -106
  244. nat/tool/github_tools/get_github_issue.py +0 -166
  245. nat/tool/github_tools/get_github_pr.py +0 -256
  246. nat/tool/github_tools/update_github_issue.py +0 -100
  247. nat/tool/mcp/exceptions.py +0 -142
  248. nat/tool/mcp/mcp_client.py +0 -255
  249. nat/tool/mcp/mcp_tool.py +0 -96
  250. nat/utils/exception_handlers/mcp.py +0 -211
  251. nvidia_nat-1.2.1rc1.dist-info/METADATA +0 -365
  252. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  253. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  254. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
  255. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  256. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
  257. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
nat/eval/dataset_handler/dataset_handler.py CHANGED
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
 import json
 import math
 from pathlib import Path
@@ -41,7 +42,8 @@ class DatasetHandler:
                  reps: int,
                  concurrency: int,
                  num_passes: int = 1,
-                 adjust_dataset_size: bool = False):
+                 adjust_dataset_size: bool = False,
+                 custom_pre_eval_process_function: str | None = None):
         from nat.eval.intermediate_step_adapter import IntermediateStepAdapter
 
         self.dataset_config = dataset_config
@@ -53,6 +55,9 @@ class DatasetHandler:
         self.num_passes = num_passes
         self.adjust_dataset_size = adjust_dataset_size
 
+        # Custom pre-evaluation process function
+        self.custom_pre_eval_process_function = custom_pre_eval_process_function
+
         # Helpers
         self.intermediate_step_adapter = IntermediateStepAdapter()
 
@@ -146,13 +151,12 @@ class DatasetHandler:
             # When num_passes is specified, always use concurrency * num_passes
             # This respects the user's intent for exact number of passes
             target_size = self.concurrency * self.num_passes
+        # When num_passes = 0, use the largest multiple of concurrency <= original_size
+        # If original_size < concurrency, we need at least concurrency rows
+        elif original_size >= self.concurrency:
+            target_size = (original_size // self.concurrency) * self.concurrency
         else:
-            # When num_passes = 0, use the largest multiple of concurrency <= original_size
-            # If original_size < concurrency, we need at least concurrency rows
-            if original_size >= self.concurrency:
-                target_size = (original_size // self.concurrency) * self.concurrency
-            else:
-                target_size = self.concurrency
+            target_size = self.concurrency
 
         if target_size == 0:
             raise ValueError("Input dataset too small for even one batch at given concurrency.")
@@ -331,6 +335,66 @@ class DatasetHandler:
         filtered_steps = self.intermediate_step_adapter.filter_intermediate_steps(intermediate_steps, event_filter)
         return self.intermediate_step_adapter.serialize_intermediate_steps(filtered_steps)
 
+    def pre_eval_process_eval_input(self, eval_input: EvalInput) -> EvalInput:
+        """
+        Pre-evaluation process the eval input using custom function if provided.
+
+        The custom pre-evaluation process function should have the signature:
+            def custom_pre_eval_process(item: EvalInputItem) -> EvalInputItem
+
+        The framework will iterate through all items and call this function on each one.
+
+        Args:
+            eval_input: The EvalInput object to pre-evaluation process
+
+        Returns:
+            The pre-evaluation processed EvalInput object
+        """
+        if self.custom_pre_eval_process_function:
+            try:
+                custom_function = self._load_custom_pre_eval_process_function()
+                processed_items = []
+
+                for item in eval_input.eval_input_items:
+                    processed_item = custom_function(item)
+                    if not isinstance(processed_item, EvalInputItem):
+                        raise TypeError(f"Custom pre-evaluation '{self.custom_pre_eval_process_function}' must return "
+                                        f"EvalInputItem, got {type(processed_item)}")
+                    processed_items.append(processed_item)
+
+                return EvalInput(eval_input_items=processed_items)
+            except Exception as e:
+                raise RuntimeError(f"Error calling custom pre-evaluation process function "
+                                   f"'{self.custom_pre_eval_process_function}': {e}") from e
+
+        return eval_input
+
+    def _load_custom_pre_eval_process_function(self):
+        """
+        Import and return the custom pre-evaluation process function using standard Python import path.
+
+        The function should process individual EvalInputItem objects.
+        """
+        # Split the function path to get module and function name
+        if "." not in self.custom_pre_eval_process_function:
+            raise ValueError(f"Invalid custom_pre_eval_process_function '{self.custom_pre_eval_process_function}'. "
+                             "Expected format: '<module_path>.<function_name>'")
+        module_path, function_name = self.custom_pre_eval_process_function.rsplit(".", 1)
+
+        # Import the module
+        module = importlib.import_module(module_path)
+
+        # Get the function from the module
+        if not hasattr(module, function_name):
+            raise AttributeError(f"Function '{function_name}' not found in module '{module_path}'")
+
+        custom_function = getattr(module, function_name)
+
+        if not callable(custom_function):
+            raise ValueError(f"'{self.custom_pre_eval_process_function}' is not callable")
+
+        return custom_function
+
     def publish_eval_input(self,
                            eval_input,
                            workflow_output_step_filter: list[IntermediateStepType] | None = None) -> str:
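
A minimal sketch of what such a pre-evaluation hook could look like, assuming a project-local module my_project/eval_hooks.py (the module and the strip_reasoning helper are hypothetical; only the EvalInputItem -> EvalInputItem contract comes from the docstring above):

# my_project/eval_hooks.py -- hypothetical example, not part of this package
from nat.eval.evaluator.evaluator_model import EvalInputItem


def strip_reasoning(item: EvalInputItem) -> EvalInputItem:
    """Remove a '<think>...</think>' preamble from the generated answer before it is scored."""
    output = item.output_obj
    if isinstance(output, str) and "</think>" in output:
        # copy_with_updates() is the helper added to EvalInputItem in this release (see evaluator_model.py below)
        return item.copy_with_updates(output_obj=output.split("</think>", 1)[1].strip())
    return item

The run reads the dotted path from the new custom_pre_eval_process_function field on the eval output settings (see self.eval_config.general.output in evaluate.py below) and resolves it with importlib via _load_custom_pre_eval_process_function.
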
nat/eval/evaluate.py CHANGED
@@ -42,7 +42,7 @@ from nat.runtime.session import SessionManager
 logger = logging.getLogger(__name__)
 
 
-class EvaluationRun: # pylint: disable=too-many-public-methods
+class EvaluationRun:
     """
     Instantiated for each evaluation run and used to store data for that single run.
 
@@ -63,7 +63,16 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
 
         # Helpers
         self.intermediate_step_adapter: IntermediateStepAdapter = IntermediateStepAdapter()
-        self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration()
+
+        # Create evaluation trace context
+        try:
+            from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
+            self.eval_trace_context = WeaveEvalTraceContext()
+        except Exception:
+            from nat.eval.utils.eval_trace_ctx import EvalTraceContext
+            self.eval_trace_context = EvalTraceContext()
+
+        self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration(self.eval_trace_context)
         # Metadata
         self.eval_input: EvalInput | None = None
         self.workflow_interrupted: bool = False
@@ -159,17 +168,17 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         intermediate_future = None
 
         try:
-
             # Start usage stats and intermediate steps collection in parallel
             intermediate_future = pull_intermediate()
             runner_result = runner.result()
             base_output = await runner_result
             intermediate_steps = await intermediate_future
         except NotImplementedError as e:
+            logger.error("Failed to run the workflow: %s", e)
             # raise original error
-            raise e
+            raise
         except Exception as e:
-            logger.exception("Failed to run the workflow: %s", e, exc_info=True)
+            logger.exception("Failed to run the workflow: %s", e)
             # stop processing if a workflow error occurs
             self.workflow_interrupted = True
 
@@ -308,9 +317,9 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
                 logger.info("Deleting old job directory: %s", dir_to_delete)
                 shutil.rmtree(dir_to_delete)
             except Exception as e:
-                logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e, exc_info=True)
+                logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e)
 
-    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults): # pylint: disable=unused-argument # noqa: E501
+    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
         workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
         workflow_output_file.parent.mkdir(parents=True, exist_ok=True)
 
@@ -358,7 +367,7 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
 
             await self.weave_eval.alog_score(eval_output, evaluator_name)
         except Exception as e:
-            logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e, exc_info=True)
+            logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e)
 
     async def run_evaluators(self, evaluators: dict[str, Any]):
         """Run all configured evaluators asynchronously."""
@@ -371,7 +380,7 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         try:
             await asyncio.gather(*tasks)
         except Exception as e:
-            logger.exception("An error occurred while running evaluators: %s", e, exc_info=True)
+            logger.error("An error occurred while running evaluators: %s", e)
             raise
         finally:
             # Finish prediction loggers in Weave
@@ -401,6 +410,33 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
 
         return workflow_type
 
+    async def wait_for_all_export_tasks_local(self, session_manager: SessionManager, timeout: float) -> None:
+        """Wait for all trace export tasks to complete for local workflows.
+
+        This only works for local workflows where we have direct access to the
+        SessionManager and its underlying workflow with exporter manager.
+        """
+        try:
+            workflow = session_manager.workflow
+            all_exporters = await workflow.get_all_exporters()
+            if not all_exporters:
+                logger.debug("No exporters to wait for")
+                return
+
+            logger.info("Waiting for export tasks from %d local exporters (timeout: %ds)", len(all_exporters), timeout)
+
+            for name, exporter in all_exporters.items():
+                try:
+                    await exporter.wait_for_tasks(timeout=timeout)
+                    logger.info("Export tasks completed for exporter: %s", name)
+                except Exception as e:
+                    logger.warning("Error waiting for export tasks from %s: %s", name, e)
+
+            logger.info("All local export task waiting completed")
+
+        except Exception as e:
+            logger.warning("Failed to wait for local export tasks: %s", e)
+
     async def run_and_evaluate(self,
                                session_manager: SessionManager | None = None,
                                job_id: str | None = None) -> EvaluationRunOutput:
@@ -413,10 +449,14 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         from nat.runtime.loader import load_config
 
         # Load and override the config
-        if self.config.override:
+        config = None
+        if isinstance(self.config.config_file, BaseModel):
+            config = self.config.config_file
+        elif self.config.override:
             config = self.apply_overrides()
         else:
             config = load_config(self.config.config_file)
+
         self.eval_config = config.eval
         workflow_alias = self._get_workflow_alias(config.workflow.type)
         logger.debug("Loaded %s evaluation configuration: %s", workflow_alias, self.eval_config)
@@ -442,44 +482,59 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         dataset_config = self.eval_config.general.dataset  # Currently only one dataset is supported
         if not dataset_config:
             logger.info("No dataset found, nothing to evaluate")
-            return EvaluationRunOutput(
-                workflow_output_file=self.workflow_output_file,
-                evaluator_output_files=self.evaluator_output_files,
-                workflow_interrupted=self.workflow_interrupted,
-            )
-
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=EvalInput(eval_input_items=[]),
+                                       evaluation_results=[],
+                                       usage_stats=UsageStats(),
+                                       profiler_results=ProfilerResults())
+
+        custom_pre_eval_process_function = self.eval_config.general.output.custom_pre_eval_process_function \
+            if self.eval_config.general.output else None
         dataset_handler = DatasetHandler(dataset_config=dataset_config,
                                          reps=self.config.reps,
                                          concurrency=self.eval_config.general.max_concurrency,
                                          num_passes=self.config.num_passes,
-                                         adjust_dataset_size=self.config.adjust_dataset_size)
+                                         adjust_dataset_size=self.config.adjust_dataset_size,
+                                         custom_pre_eval_process_function=custom_pre_eval_process_function)
         self.eval_input = dataset_handler.get_eval_input_from_dataset(self.config.dataset)
         if not self.eval_input.eval_input_items:
             logger.info("Dataset is empty. Nothing to evaluate.")
-            return EvaluationRunOutput(
-                workflow_output_file=self.workflow_output_file,
-                evaluator_output_files=self.evaluator_output_files,
-                workflow_interrupted=self.workflow_interrupted,
-            )
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=self.eval_input,
+                                       evaluation_results=self.evaluation_results,
+                                       usage_stats=self.usage_stats,
+                                       profiler_results=ProfilerResults())
 
         # Run workflow and evaluate
         async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
             # Initialize Weave integration
             self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
 
-            # Run workflow
-            if self.config.endpoint:
-                await self.run_workflow_remote()
-            else:
-                if not self.config.skip_workflow:
+            with self.eval_trace_context.evaluation_context():
+                # Run workflow
+                if self.config.endpoint:
+                    await self.run_workflow_remote()
+                elif not self.config.skip_workflow:
                     if session_manager is None:
-                        session_manager = SessionManager(eval_workflow.build(),
+                        workflow = await eval_workflow.build()
+                        session_manager = SessionManager(workflow,
                                                          max_concurrency=self.eval_config.general.max_concurrency)
                     await self.run_workflow_local(session_manager)
 
-            # Evaluate
-            evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
-            await self.run_evaluators(evaluators)
+                # Pre-evaluation process the workflow output
+                self.eval_input = dataset_handler.pre_eval_process_eval_input(self.eval_input)
+
+                # Evaluate
+                evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
+                await self.run_evaluators(evaluators)
+
+                # Wait for all trace export tasks to complete (local workflows only)
+                if session_manager and not self.config.endpoint:
+                    await self.wait_for_all_export_tasks_local(session_manager, timeout=self.config.export_timeout)
 
         # Profile the workflow
         profiler_results = await self.profile_workflow()
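
Since config_file is now accepted as an already-validated pydantic model (the isinstance(self.config.config_file, BaseModel) branch above), an evaluation can be driven programmatically without writing a temporary config file. A rough sketch, assuming EvaluationRunConfig keeps its existing field names and that its config_file field admits the loaded Config object (neither is shown in this diff):

# Hypothetical programmatic invocation; the path and field choices are illustrative only.
import asyncio

from nat.eval.config import EvaluationRunConfig  # assumed location of the run config model
from nat.eval.evaluate import EvaluationRun
from nat.runtime.loader import load_config

config_obj = load_config("configs/eval_config.yml")       # pydantic Config instance, can be tweaked in memory
run_config = EvaluationRunConfig(config_file=config_obj)  # passed as an object rather than a path
output = asyncio.run(EvaluationRun(run_config).run_and_evaluate())
print(output.workflow_interrupted, len(output.eval_input.eval_input_items))

Note that EvaluationRunOutput itself is richer in this release: the early returns above now always populate eval_input, evaluation_results, usage_stats and profiler_results.
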
nat/eval/evaluator/base_evaluator.py CHANGED
@@ -71,7 +71,7 @@ class BaseEvaluator(ABC):
             TqdmPositionRegistry.release(tqdm_position)
 
         # Compute average if possible
-        numeric_scores = [item.score for item in output_items if isinstance(item.score, (int, float))]
+        numeric_scores = [item.score for item in output_items if isinstance(item.score, int | float)]
         avg_score = round(sum(numeric_scores) / len(numeric_scores), 2) if numeric_scores else None
 
         return EvalOutput(average_score=avg_score, eval_output_items=output_items)
nat/eval/evaluator/evaluator_model.py CHANGED
@@ -29,6 +29,19 @@ class EvalInputItem(BaseModel):
     trajectory: list[IntermediateStep] = []  # populated by the workflow
     full_dataset_entry: typing.Any
 
+    def copy_with_updates(self, **updates) -> "EvalInputItem":
+        """
+        Copy EvalInputItem with optional field updates.
+        """
+        # Get all current fields
+        item_data = self.model_dump()
+
+        # Apply any updates
+        item_data.update(updates)
+
+        # Create new item with all fields
+        return EvalInputItem(**item_data)
+
 
 class EvalInput(BaseModel):
     eval_input_items: list[EvalInputItem]
nat/eval/intermediate_step_adapter.py CHANGED
@@ -40,7 +40,7 @@ class IntermediateStepAdapter:
             try:
                 validated_steps.append(IntermediateStep.model_validate(step_data))
             except Exception as e:
-                logger.exception("Validation failed for step: %r, Error: %s", step_data, e, exc_info=True)
+                logger.exception("Validation failed for step: %r, Error: %s", step_data, e)
         return validated_steps
 
     def serialize_intermediate_steps(self, intermediate_steps: list[IntermediateStep]) -> list[dict]:
nat/eval/rag_evaluator/evaluate.py CHANGED
@@ -102,7 +102,7 @@ class RAGEvaluator:
         """Converts the ragas EvaluationResult to nat EvalOutput"""
 
         if not results_dataset:
-            logger.error("Ragas evaluation failed with no results")
+            logger.error("Ragas evaluation failed with no results", exc_info=True)
             return EvalOutput(average_score=0.0, eval_output_items=[])
 
         scores: list[dict[str, float]] = results_dataset.scores
@@ -116,11 +116,14 @@ class RAGEvaluator:
             """Convert NaN or None to 0.0 for safe arithmetic/serialization."""
             return 0.0 if v is None or (isinstance(v, float) and math.isnan(v)) else v
 
-        # Convert from list of dicts to dict of lists, coercing NaN/None to 0.0
+        # Keep original scores (preserving NaN/None) for output
+        original_scores_dict = {metric: [score.get(metric) for score in scores] for metric in scores[0]}
+
+        # Convert from list of dicts to dict of lists, coercing NaN/None to 0.0 for average calculation
         scores_dict = {metric: [_nan_to_zero(score.get(metric)) for score in scores] for metric in scores[0]}
         first_metric_name = list(scores_dict.keys())[0] if scores_dict else None
 
-        # Compute the average of each metric, guarding against empty lists
+        # Compute the average of each metric using cleaned scores (NaN/None -> 0.0)
         average_scores = {
             metric: (sum(values) / len(values) if values else 0.0)
             for metric, values in scores_dict.items()
@@ -137,11 +140,11 @@ class RAGEvaluator:
         else:
             ids = df["user_input"].tolist()  # Use "user_input" as ID fallback
 
-        # Construct EvalOutputItem list
+        # Construct EvalOutputItem list using original scores (preserving NaN/None)
         eval_output_items = [
             EvalOutputItem(
                 id=ids[i],
-                score=_nan_to_zero(getattr(row, first_metric_name, 0.0) if first_metric_name else 0.0),
+                score=original_scores_dict[first_metric_name][i] if first_metric_name else None,
                 reasoning={
                     key:
                     getattr(row, key, None)  # Use getattr to safely access attributes
@@ -169,7 +172,7 @@ class RAGEvaluator:
                                          _pbar=pbar)
         except Exception as e:
             # On exception we still continue with other evaluators. Log and return an avg_score of 0.0
-            logger.exception("Error evaluating ragas metric, Error: %s", e, exc_info=True)
+            logger.exception("Error evaluating ragas metric, Error: %s", e)
             results_dataset = None
         finally:
             pbar.close()
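
To make the new score handling concrete, here is a small self-contained sketch with made-up metric values that mirrors the two dict comprehensions above: NaN is coerced to 0.0 only for the average, while each EvalOutputItem now keeps the raw (possibly NaN or None) per-item score instead of a silent 0.0:

# Standalone illustration of the NaN handling above; the metric name and values are made up.
import math

scores = [{"faithfulness": 0.9}, {"faithfulness": float("nan")}]

def _nan_to_zero(v):
    return 0.0 if v is None or (isinstance(v, float) and math.isnan(v)) else v

original_scores_dict = {m: [s.get(m) for s in scores] for m in scores[0]}        # {'faithfulness': [0.9, nan]}
scores_dict = {m: [_nan_to_zero(s.get(m)) for s in scores] for m in scores[0]}   # {'faithfulness': [0.9, 0.0]}
average_scores = {m: sum(v) / len(v) if v else 0.0 for m, v in scores_dict.items()}  # {'faithfulness': 0.45}
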
nat/eval/rag_evaluator/register.py CHANGED
@@ -73,7 +73,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return self.metric
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.keys()))  # pylint: disable=no-member
+            return next(iter(self.metric.keys()))
         return ""
 
     @property
@@ -82,7 +82,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return RagasMetricConfig()  # Default config when only a metric name is given
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.values()))  # pylint: disable=no-member
+            return next(iter(self.metric.values()))
         return RagasMetricConfig()  # Default config when an invalid type is provided
 
 
@@ -104,7 +104,7 @@ async def register_ragas_evaluator(config: RagasEvaluatorConfig, builder: EvalBu
         raise ValueError(message) from e
     except AttributeError as e:
         message = f"Ragas metric {metric_name} not found {e}."
-        logger.error(message)
+        logger.exception(message)
         return None
 
     async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
nat/eval/register.py CHANGED
@@ -14,10 +14,13 @@
 # limitations under the License.
 
 # flake8: noqa
-# pylint: disable=unused-import
 
 # Import evaluators which need to be automatically registered here
 from .rag_evaluator.register import register_ragas_evaluator
+from .runtime_evaluator.register import register_avg_llm_latency_evaluator
+from .runtime_evaluator.register import register_avg_num_llm_calls_evaluator
+from .runtime_evaluator.register import register_avg_tokens_per_llm_end_evaluator
+from .runtime_evaluator.register import register_avg_workflow_runtime_evaluator
 from .swe_bench_evaluator.register import register_swe_bench_evaluator
 from .trajectory_evaluator.register import register_trajectory_evaluator
 from .tunable_rag_evaluator.register import register_tunable_rag_evaluator
nat/eval/remote_workflow.py CHANGED
@@ -74,7 +74,7 @@ class EvaluationRemoteWorkflowHandler:
                             if chunk_data.get("value"):
                                 final_response = chunk_data.get("value")
                         except json.JSONDecodeError as e:
-                            logger.error("Failed to parse generate response chunk: %s", e)
+                            logger.exception("Failed to parse generate response chunk: %s", e)
                             continue
                     elif line.startswith(INTERMEDIATE_DATA_PREFIX):
                         # This is an intermediate step
@@ -90,12 +90,12 @@ class EvaluationRemoteWorkflowHandler:
                                                    payload=payload)
                             intermediate_steps.append(intermediate_step)
                         except (json.JSONDecodeError, ValidationError) as e:
-                            logger.error("Failed to parse intermediate step: %s", e)
+                            logger.exception("Failed to parse intermediate step: %s", e)
                             continue
 
         except aiohttp.ClientError as e:
             # Handle connection or HTTP-related errors
-            logger.error("Request failed for question %s: %s", question, e)
+            logger.exception("Request failed for question %s: %s", question, e)
             item.output_obj = None
             item.trajectory = []
             return
nat/eval/runtime_evaluator/__init__.py ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
nat/eval/runtime_evaluator/evaluate.py ADDED
@@ -0,0 +1,123 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass
+
+from nat.data_models.intermediate_step import IntermediateStepType
+from nat.eval.evaluator.base_evaluator import BaseEvaluator
+from nat.eval.evaluator.evaluator_model import EvalInputItem
+from nat.eval.evaluator.evaluator_model import EvalOutputItem
+from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
+
+
+@dataclass
+class _CallTiming:
+    start_ts: float | None = None
+    end_ts: float | None = None
+
+    @property
+    def latency(self) -> float | None:
+        if self.start_ts is None or self.end_ts is None:
+            return None
+        return max(0.0, self.end_ts - self.start_ts)
+
+
+class AverageLLMLatencyEvaluator(BaseEvaluator):
+    """
+    Mean difference between connected LLM_START and LLM_END events (same UUID).
+    The score is the average latency in seconds for the item. Reasoning contains per-call latencies.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg LLM Latency")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        calls: dict[str, _CallTiming] = defaultdict(_CallTiming)
+
+        for step in (IntermediatePropertyAdaptor.from_intermediate_step(s) for s in item.trajectory):
+            if step.event_type == IntermediateStepType.LLM_START:
+                calls[step.UUID].start_ts = step.event_timestamp
+            elif step.event_type == IntermediateStepType.LLM_END:
+                calls[step.UUID].end_ts = step.event_timestamp
+
+        latencies = [ct.latency for ct in calls.values() if ct.latency is not None]
+        avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
+
+        reasoning = {
+            "num_llm_calls": len(latencies),
+            "latencies": latencies,
+        }
+        return EvalOutputItem(id=item.id, score=round(avg_latency, 4), reasoning=reasoning)
+
+
+class AverageWorkflowRuntimeEvaluator(BaseEvaluator):
+    """
+    Average workflow runtime per item: max(event_timestamp) - min(event_timestamp) across the trajectory.
+    The score is the runtime in seconds for the item.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Workflow Runtime")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        if not item.trajectory:
+            return EvalOutputItem(id=item.id, score=0.0, reasoning={"note": "no steps"})
+
+        timestamps = [s.event_timestamp for s in item.trajectory]
+        runtime = max(timestamps) - min(timestamps)
+        return EvalOutputItem(id=item.id, score=round(max(0.0, runtime), 4), reasoning={"steps": len(timestamps)})
+
+
+class AverageNumberOfLLMCallsEvaluator(BaseEvaluator):
+    """
+    Average number of LLM calls per item. The score is the count for the item.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg # LLM Calls")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        num_calls = sum(1 for s in item.trajectory if s.event_type == IntermediateStepType.LLM_END)
+        return EvalOutputItem(id=item.id, score=float(num_calls), reasoning={"num_llm_end": num_calls})
+
+
+class AverageTokensPerLLMEndEvaluator(BaseEvaluator):
+    """
+    Average total tokens per LLM_END event: sum of prompt and completion tokens if available.
+    The score is the average tokens per LLM_END for the item (0 if none).
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Tokens/LLM_END")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        totals: list[int] = []
+        for step in (IntermediatePropertyAdaptor.from_intermediate_step(s) for s in item.trajectory):
+            if step.event_type == IntermediateStepType.LLM_END:
+                total_tokens = step.token_usage.total_tokens
+                # If framework doesn't set total, compute from prompt+completion
+                if total_tokens == 0:
+                    total_tokens = step.token_usage.prompt_tokens + step.token_usage.completion_tokens
+                totals.append(total_tokens)
+
+        avg_tokens = (sum(totals) / len(totals)) if totals else 0.0
+        reasoning = {
+            "num_llm_end": len(totals),
+            "totals": totals,
+        }
+        return EvalOutputItem(id=item.id, score=round(avg_tokens, 2), reasoning=reasoning)
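
The four evaluators above all follow the same pattern: subclass BaseEvaluator, scan item.trajectory, and return an EvalOutputItem whose score carries the metric. A project-local metric can copy that shape; the sketch below is hypothetical (it is not part of this release and assumes IntermediateStepType also defines TOOL_END, which this diff does not show):

# Hypothetical evaluator following the runtime-evaluator pattern above.
from nat.data_models.intermediate_step import IntermediateStepType
from nat.eval.evaluator.base_evaluator import BaseEvaluator
from nat.eval.evaluator.evaluator_model import EvalInputItem
from nat.eval.evaluator.evaluator_model import EvalOutputItem


class AverageNumberOfToolCallsEvaluator(BaseEvaluator):
    """Number of TOOL_END events per item; the score is the count for the item."""

    def __init__(self, max_concurrency: int = 8):
        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg # Tool Calls")

    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:
        num_calls = sum(1 for s in item.trajectory if s.event_type == IntermediateStepType.TOOL_END)
        return EvalOutputItem(id=item.id, score=float(num_calls), reasoning={"num_tool_end": num_calls})

Like the built-in runtime evaluators, such a class would still need a config model and a register function (compare nat/eval/runtime_evaluator/register.py in the file list) before it could be referenced from an eval config.
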