nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +50 -22
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +54 -27
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +68 -17
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  53. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  54. nat/cli/commands/workflow/workflow_commands.py +62 -22
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/main.py +3 -0
  57. nat/cli/register_workflow.py +38 -4
  58. nat/cli/type_registry.py +75 -6
  59. nat/control_flow/__init__.py +0 -0
  60. nat/control_flow/register.py +20 -0
  61. nat/control_flow/router_agent/__init__.py +0 -0
  62. nat/control_flow/router_agent/agent.py +329 -0
  63. nat/control_flow/router_agent/prompt.py +48 -0
  64. nat/control_flow/router_agent/register.py +91 -0
  65. nat/control_flow/sequential_executor.py +166 -0
  66. nat/data_models/agent.py +34 -0
  67. nat/data_models/api_server.py +74 -66
  68. nat/data_models/authentication.py +23 -9
  69. nat/data_models/common.py +1 -1
  70. nat/data_models/component.py +2 -0
  71. nat/data_models/component_ref.py +11 -0
  72. nat/data_models/config.py +41 -17
  73. nat/data_models/dataset_handler.py +1 -1
  74. nat/data_models/discovery_metadata.py +4 -4
  75. nat/data_models/evaluate.py +4 -1
  76. nat/data_models/function.py +34 -0
  77. nat/data_models/function_dependencies.py +14 -6
  78. nat/data_models/gated_field_mixin.py +242 -0
  79. nat/data_models/intermediate_step.py +3 -3
  80. nat/data_models/optimizable.py +119 -0
  81. nat/data_models/optimizer.py +149 -0
  82. nat/data_models/span.py +41 -3
  83. nat/data_models/swe_bench_model.py +1 -1
  84. nat/data_models/temperature_mixin.py +44 -0
  85. nat/data_models/thinking_mixin.py +86 -0
  86. nat/data_models/top_p_mixin.py +44 -0
  87. nat/embedder/nim_embedder.py +1 -1
  88. nat/embedder/openai_embedder.py +1 -1
  89. nat/embedder/register.py +0 -1
  90. nat/eval/config.py +3 -1
  91. nat/eval/dataset_handler/dataset_handler.py +71 -7
  92. nat/eval/evaluate.py +86 -31
  93. nat/eval/evaluator/base_evaluator.py +1 -1
  94. nat/eval/evaluator/evaluator_model.py +13 -0
  95. nat/eval/intermediate_step_adapter.py +1 -1
  96. nat/eval/rag_evaluator/evaluate.py +2 -2
  97. nat/eval/rag_evaluator/register.py +3 -3
  98. nat/eval/register.py +4 -1
  99. nat/eval/remote_workflow.py +3 -3
  100. nat/eval/runtime_evaluator/__init__.py +14 -0
  101. nat/eval/runtime_evaluator/evaluate.py +123 -0
  102. nat/eval/runtime_evaluator/register.py +100 -0
  103. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  104. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  105. nat/eval/trajectory_evaluator/register.py +1 -1
  106. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  107. nat/eval/utils/eval_trace_ctx.py +89 -0
  108. nat/eval/utils/weave_eval.py +18 -9
  109. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  110. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  112. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  113. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  114. nat/experimental/test_time_compute/register.py +0 -1
  115. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  116. nat/front_ends/console/authentication_flow_handler.py +82 -30
  117. nat/front_ends/console/console_front_end_plugin.py +8 -5
  118. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  119. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  120. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  121. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  123. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
  124. nat/front_ends/fastapi/job_store.py +518 -99
  125. nat/front_ends/fastapi/main.py +11 -19
  126. nat/front_ends/fastapi/message_handler.py +13 -14
  127. nat/front_ends/fastapi/message_validator.py +19 -19
  128. nat/front_ends/fastapi/response_helpers.py +4 -4
  129. nat/front_ends/fastapi/step_adaptor.py +2 -2
  130. nat/front_ends/fastapi/utils.py +57 -0
  131. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  132. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  133. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  134. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  135. nat/front_ends/mcp/tool_converter.py +44 -14
  136. nat/front_ends/register.py +0 -1
  137. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  138. nat/llm/aws_bedrock_llm.py +24 -12
  139. nat/llm/azure_openai_llm.py +13 -6
  140. nat/llm/litellm_llm.py +69 -0
  141. nat/llm/nim_llm.py +20 -8
  142. nat/llm/openai_llm.py +14 -6
  143. nat/llm/register.py +4 -1
  144. nat/llm/utils/env_config_value.py +2 -3
  145. nat/llm/utils/thinking.py +215 -0
  146. nat/meta/pypi.md +9 -9
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +35 -15
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/forecasting/models/linear_model.py +1 -1
  178. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  179. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  180. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  181. nat/profiler/inference_optimization/data_models.py +3 -3
  182. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  183. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  184. nat/profiler/parameter_optimization/__init__.py +0 -0
  185. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  186. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  187. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  188. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  189. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  190. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  191. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  192. nat/profiler/profile_runner.py +14 -9
  193. nat/profiler/utils.py +4 -2
  194. nat/registry_handlers/local/local_handler.py +2 -2
  195. nat/registry_handlers/package_utils.py +1 -2
  196. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  197. nat/registry_handlers/register.py +3 -4
  198. nat/registry_handlers/rest/rest_handler.py +12 -13
  199. nat/retriever/milvus/retriever.py +2 -2
  200. nat/retriever/nemo_retriever/retriever.py +1 -1
  201. nat/retriever/register.py +0 -1
  202. nat/runtime/loader.py +2 -2
  203. nat/runtime/runner.py +106 -8
  204. nat/runtime/session.py +69 -8
  205. nat/settings/global_settings.py +16 -5
  206. nat/tool/chat_completion.py +5 -2
  207. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  208. nat/tool/datetime_tools.py +49 -9
  209. nat/tool/document_search.py +2 -2
  210. nat/tool/github_tools.py +450 -0
  211. nat/tool/memory_tools/get_memory_tool.py +1 -1
  212. nat/tool/nvidia_rag.py +1 -1
  213. nat/tool/register.py +2 -9
  214. nat/tool/retriever.py +3 -2
  215. nat/utils/callable_utils.py +70 -0
  216. nat/utils/data_models/schema_validator.py +3 -3
  217. nat/utils/decorators.py +210 -0
  218. nat/utils/exception_handlers/automatic_retries.py +104 -51
  219. nat/utils/exception_handlers/schemas.py +1 -1
  220. nat/utils/io/yaml_tools.py +2 -2
  221. nat/utils/log_levels.py +25 -0
  222. nat/utils/reactive/base/observable_base.py +2 -2
  223. nat/utils/reactive/base/observer_base.py +1 -1
  224. nat/utils/reactive/observable.py +2 -2
  225. nat/utils/reactive/observer.py +4 -4
  226. nat/utils/reactive/subscription.py +1 -1
  227. nat/utils/settings/global_settings.py +6 -8
  228. nat/utils/type_converter.py +4 -3
  229. nat/utils/type_utils.py +9 -5
  230. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
  231. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
  232. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
  233. nat/cli/commands/info/list_mcp.py +0 -304
  234. nat/tool/github_tools/create_github_commit.py +0 -133
  235. nat/tool/github_tools/create_github_issue.py +0 -87
  236. nat/tool/github_tools/create_github_pr.py +0 -106
  237. nat/tool/github_tools/get_github_file.py +0 -106
  238. nat/tool/github_tools/get_github_issue.py +0 -166
  239. nat/tool/github_tools/get_github_pr.py +0 -256
  240. nat/tool/github_tools/update_github_issue.py +0 -100
  241. nat/tool/mcp/exceptions.py +0 -142
  242. nat/tool/mcp/mcp_client.py +0 -255
  243. nat/tool/mcp/mcp_tool.py +0 -96
  244. nat/utils/exception_handlers/mcp.py +0 -211
  245. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  246. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  247. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
  248. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  249. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  250. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
nat/data_models/swe_bench_model.py CHANGED
@@ -39,7 +39,7 @@ class SWEBenchInput(BaseModel):
 
     # Handle improperly formatted JSON strings for list fields
     @field_validator("FAIL_TO_PASS", "PASS_TO_PASS", mode="before")
-    def parse_list_fields(cls, value):  # pylint: disable=no-self-argument
+    def parse_list_fields(cls, value):
         if isinstance(value, str):
             # Attempt to parse the string as a list
             return json.loads(value)
nat/data_models/temperature_mixin.py ADDED
@@ -0,0 +1,44 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+
+from nat.data_models.gated_field_mixin import GatedFieldMixin
+from nat.data_models.optimizable import OptimizableField
+from nat.data_models.optimizable import SearchSpace
+
+
+class TemperatureMixin(
+        BaseModel,
+        GatedFieldMixin,
+        field_name="temperature",
+        default_if_supported=0.0,
+        keys=("model_name", "model", "azure_deployment"),
+        unsupported=(re.compile(r"gpt-?5", re.IGNORECASE), ),
+):
+    """
+    Mixin class for temperature configuration. Unsupported on models like gpt-5.
+
+    Attributes:
+        temperature: Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.
+    """
+    temperature: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.",
+        space=SearchSpace(high=0.9, low=0.1, step=0.2))
nat/data_models/thinking_mixin.py ADDED
@@ -0,0 +1,86 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.gated_field_mixin import GatedFieldMixin
+
+# Currently the control logic for thinking is only implemented for Nemotron models
+_NEMOTRON_REGEX = re.compile(r"^nvidia/(llama|nvidia).*nemotron", re.IGNORECASE)
+# The keys are the fields that are used to determine if the model supports thinking
+_MODEL_KEYS = ("model_name", "model", "azure_deployment")
+
+
+class ThinkingMixin(
+        BaseModel,
+        GatedFieldMixin,
+        field_name="thinking",
+        default_if_supported=None,
+        keys=_MODEL_KEYS,
+        supported=(_NEMOTRON_REGEX, ),
+):
+    """
+    Mixin class for thinking configuration. Only supported on Nemotron models.
+
+    Attributes:
+        thinking: Whether to enable thinking. Defaults to None when supported on the model.
+    """
+    thinking: bool | None = Field(
+        default=None,
+        description="Whether to enable thinking. Defaults to None when supported on the model.",
+    )
+
+    @property
+    def thinking_system_prompt(self) -> str | None:
+        """
+        Returns the system prompt to use for thinking.
+        For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
+        For Llama Nemotron v1.5, returns "/think" if enabled, else "/no_think".
+        For Llama Nemotron v1.0, returns "detailed thinking on" if enabled, else "detailed thinking off".
+        If thinking is not supported on the model, returns None.
+
+        Returns:
+            str | None: The system prompt to use for thinking.
+        """
+        if self.thinking is None:
+            return None
+
+        for key in _MODEL_KEYS:
+            model = getattr(self, key, None)
+            if not isinstance(model, str) or model is None:
+                continue
+
+            # Normalize name to reduce checks
+            model = model.lower().translate(str.maketrans("_.", "--"))
+
+            if model.startswith("nvidia/nvidia"):
+                return "/think" if self.thinking else "/no_think"
+
+            if model.startswith("nvidia/llama"):
+                if "v1-0" in model or "v1-1" in model:
+                    return f"detailed thinking {'on' if self.thinking else 'off'}"
+
+                if "v1-5" in model:
+                    # v1.5 models are updated to use the /think and /no_think system prompts
+                    return "/think" if self.thinking else "/no_think"
+
+                # Assume any other model is a newer model that uses the /think and /no_think system prompts
+                return "/think" if self.thinking else "/no_think"
+
+        # Unknown model
+        return None
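For reference, the prompt selection above can be exercised directly from a config model. A minimal sketch, assuming a hypothetical config that exposes a `model` field and that `ThinkingMixin` can be subclassed directly (the exact gating behaviour lives in nat/data_models/gated_field_mixin.py); the class and model names are illustrative only:

```python
from nat.data_models.thinking_mixin import ThinkingMixin


class ExampleNemotronConfig(ThinkingMixin):
    # Hypothetical config for illustration; real provider configs declare more fields.
    model: str = "nvidia/llama-nemotron-example-v1.5"


cfg = ExampleNemotronConfig(thinking=True)
print(cfg.thinking_system_prompt)  # "/think" (v1.5 names use the newer prompt style)

cfg = ExampleNemotronConfig(model="nvidia/llama-nemotron-example-v1.0", thinking=False)
print(cfg.thinking_system_prompt)  # "detailed thinking off" (v1.0 names use the verbose prompt)
```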
nat/data_models/top_p_mixin.py ADDED
@@ -0,0 +1,44 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+
+from nat.data_models.gated_field_mixin import GatedFieldMixin
+from nat.data_models.optimizable import OptimizableField
+from nat.data_models.optimizable import SearchSpace
+
+
+class TopPMixin(
+        BaseModel,
+        GatedFieldMixin,
+        field_name="top_p",
+        default_if_supported=1.0,
+        keys=("model_name", "model", "azure_deployment"),
+        unsupported=(re.compile(r"gpt-?5", re.IGNORECASE), ),
+):
+    """
+    Mixin class for top-p configuration. Unsupported on models like gpt-5.
+
+    Attributes:
+        top_p: Top-p for distribution sampling. Defaults to 1.0 when supported on the model.
+    """
+    top_p: float | None = OptimizableField(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="Top-p for distribution sampling. Defaults to 1.0 when supported on the model.",
+        space=SearchSpace(high=1.0, low=0.5, step=0.1))
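TemperatureMixin and TopPMixin follow the same pattern: each field is an OptimizableField with a SearchSpace for the parameter optimizer, gated by model name so that models such as gpt-5 variants are excluded. A minimal sketch of combining them in a provider-style config, assuming the mixins compose by plain subclassing; the class and model names below are illustrative only:

```python
from nat.data_models.temperature_mixin import TemperatureMixin
from nat.data_models.top_p_mixin import TopPMixin


class ExampleSamplingConfig(TemperatureMixin, TopPMixin):
    # Hypothetical config; the mixins look up the model under
    # "model_name", "model", or "azure_deployment".
    model_name: str = "example/served-model"


cfg = ExampleSamplingConfig(temperature=0.2, top_p=0.9)
print(cfg.temperature, cfg.top_p)  # 0.2 0.9

# Both fields are constrained to [0, 1]; values outside that range fail validation.
```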
nat/embedder/nim_embedder.py CHANGED
@@ -50,7 +50,7 @@ class NIMEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="nim"):
                                     description=("The truncation strategy if the input on the "
                                                  "server side if it's too large."))
 
-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")
 
 
 @register_embedder_provider(config_type=NIMEmbedderModelConfig)
nat/embedder/openai_embedder.py CHANGED
@@ -27,7 +27,7 @@ from nat.data_models.retry_mixin import RetryMixin
 class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
     """An OpenAI LLM provider to be used with an LLM client."""
 
-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")
 
     api_key: str | None = Field(default=None, description="OpenAI API key to interact with hosted model.")
     base_url: str | None = Field(default=None, description="Base url to the hosted model.")
nat/embedder/register.py CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file
 
nat/eval/config.py CHANGED
@@ -27,7 +27,7 @@ class EvaluationRunConfig(BaseModel):
     """
     Parameters used for a single evaluation run.
     """
-    config_file: Path
+    config_file: Path | BaseModel
     dataset: str | None = None  # dataset file path can be specified in the config file
     result_json_path: str = "$"
     skip_workflow: bool = False
@@ -44,6 +44,8 @@ class EvaluationRunConfig(BaseModel):
     # number of passes at each concurrency, if 0 the dataset is adjusted to a multiple of the
     # concurrency. The is only used if adjust_dataset_size is true
     num_passes: int = 0
+    # timeout for waiting for trace export tasks to complete
+    export_timeout: float = 60.0
 
 
 class EvaluationRunOutput(BaseModel):
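With `config_file: Path | BaseModel`, an evaluation run can now be driven either from a YAML path or from an already-constructed config object, and `export_timeout` caps how long the run waits for trace exporters to flush. A short sketch of building the run config; the path and values are illustrative, and the remaining fields are assumed to keep their defaults:

```python
from pathlib import Path

from nat.eval.config import EvaluationRunConfig

# File-based, as before
run_config = EvaluationRunConfig(config_file=Path("configs/eval_config.yml"))

# Or pass a validated config object directly and allow more time for trace export
# (config_obj is assumed to be a previously loaded workflow config model).
# run_config = EvaluationRunConfig(config_file=config_obj, export_timeout=120.0)
```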
nat/eval/dataset_handler/dataset_handler.py CHANGED
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
 import json
 import math
 from pathlib import Path
@@ -41,7 +42,8 @@ class DatasetHandler:
                  reps: int,
                  concurrency: int,
                  num_passes: int = 1,
-                 adjust_dataset_size: bool = False):
+                 adjust_dataset_size: bool = False,
+                 custom_pre_eval_process_function: str | None = None):
         from nat.eval.intermediate_step_adapter import IntermediateStepAdapter
 
         self.dataset_config = dataset_config
@@ -53,6 +55,9 @@ class DatasetHandler:
         self.num_passes = num_passes
         self.adjust_dataset_size = adjust_dataset_size
 
+        # Custom pre-evaluation process function
+        self.custom_pre_eval_process_function = custom_pre_eval_process_function
+
         # Helpers
         self.intermediate_step_adapter = IntermediateStepAdapter()
 
@@ -146,13 +151,12 @@ class DatasetHandler:
             # When num_passes is specified, always use concurrency * num_passes
             # This respects the user's intent for exact number of passes
             target_size = self.concurrency * self.num_passes
+        # When num_passes = 0, use the largest multiple of concurrency <= original_size
+        # If original_size < concurrency, we need at least concurrency rows
+        elif original_size >= self.concurrency:
+            target_size = (original_size // self.concurrency) * self.concurrency
         else:
-            # When num_passes = 0, use the largest multiple of concurrency <= original_size
-            # If original_size < concurrency, we need at least concurrency rows
-            if original_size >= self.concurrency:
-                target_size = (original_size // self.concurrency) * self.concurrency
-            else:
-                target_size = self.concurrency
+            target_size = self.concurrency
 
         if target_size == 0:
             raise ValueError("Input dataset too small for even one batch at given concurrency.")
@@ -331,6 +335,66 @@ class DatasetHandler:
         filtered_steps = self.intermediate_step_adapter.filter_intermediate_steps(intermediate_steps, event_filter)
         return self.intermediate_step_adapter.serialize_intermediate_steps(filtered_steps)
 
+    def pre_eval_process_eval_input(self, eval_input: EvalInput) -> EvalInput:
+        """
+        Pre-evaluation process the eval input using custom function if provided.
+
+        The custom pre-evaluation process function should have the signature:
+            def custom_pre_eval_process(item: EvalInputItem) -> EvalInputItem
+
+        The framework will iterate through all items and call this function on each one.
+
+        Args:
+            eval_input: The EvalInput object to pre-evaluation process
+
+        Returns:
+            The pre-evaluation processed EvalInput object
+        """
+        if self.custom_pre_eval_process_function:
+            try:
+                custom_function = self._load_custom_pre_eval_process_function()
+                processed_items = []
+
+                for item in eval_input.eval_input_items:
+                    processed_item = custom_function(item)
+                    if not isinstance(processed_item, EvalInputItem):
+                        raise TypeError(f"Custom pre-evaluation '{self.custom_pre_eval_process_function}' must return "
+                                        f"EvalInputItem, got {type(processed_item)}")
+                    processed_items.append(processed_item)
+
+                return EvalInput(eval_input_items=processed_items)
+            except Exception as e:
+                raise RuntimeError(f"Error calling custom pre-evaluation process function "
+                                   f"'{self.custom_pre_eval_process_function}': {e}") from e
+
+        return eval_input
+
+    def _load_custom_pre_eval_process_function(self):
+        """
+        Import and return the custom pre-evaluation process function using standard Python import path.
+
+        The function should process individual EvalInputItem objects.
+        """
+        # Split the function path to get module and function name
+        if "." not in self.custom_pre_eval_process_function:
+            raise ValueError(f"Invalid custom_pre_eval_process_function '{self.custom_pre_eval_process_function}'. "
+                             "Expected format: '<module_path>.<function_name>'")
+        module_path, function_name = self.custom_pre_eval_process_function.rsplit(".", 1)
+
+        # Import the module
+        module = importlib.import_module(module_path)
+
+        # Get the function from the module
+        if not hasattr(module, function_name):
+            raise AttributeError(f"Function '{function_name}' not found in module '{module_path}'")
+
+        custom_function = getattr(module, function_name)
+
+        if not callable(custom_function):
+            raise ValueError(f"'{self.custom_pre_eval_process_function}' is not callable")
+
+        return custom_function
+
     def publish_eval_input(self,
                            eval_input,
                            workflow_output_step_filter: list[IntermediateStepType] | None = None) -> str:
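The new hook is referenced by dotted path (`<module_path>.<function_name>`, wired through the eval output settings as `custom_pre_eval_process_function`, per the evaluate.py change below) and is called once per `EvalInputItem`. A minimal sketch of a conforming function; the module path and the field being rewritten are illustrative, so substitute the actual `EvalInputItem` field that carries the workflow output:

```python
# my_project/pre_eval.py  (hypothetical module; reference it in the config as
# "my_project.pre_eval.strip_whitespace")
from nat.eval.evaluator.evaluator_model import EvalInputItem


def strip_whitespace(item: EvalInputItem) -> EvalInputItem:
    """Normalize the generated answer before the evaluators score it."""
    answer = item.output_obj  # hypothetical field name, for illustration only
    if isinstance(answer, str):
        # copy_with_updates is added to EvalInputItem in this release (see evaluator_model.py below)
        return item.copy_with_updates(output_obj=answer.strip())
    return item
```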
nat/eval/evaluate.py CHANGED
@@ -42,7 +42,7 @@ from nat.runtime.session import SessionManager
 logger = logging.getLogger(__name__)
 
 
-class EvaluationRun:  # pylint: disable=too-many-public-methods
+class EvaluationRun:
     """
     Instantiated for each evaluation run and used to store data for that single run.
 
@@ -63,7 +63,16 @@ class EvaluationRun:
 
         # Helpers
         self.intermediate_step_adapter: IntermediateStepAdapter = IntermediateStepAdapter()
-        self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration()
+
+        # Create evaluation trace context
+        try:
+            from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
+            self.eval_trace_context = WeaveEvalTraceContext()
+        except Exception:
+            from nat.eval.utils.eval_trace_ctx import EvalTraceContext
+            self.eval_trace_context = EvalTraceContext()
+
+        self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration(self.eval_trace_context)
         # Metadata
         self.eval_input: EvalInput | None = None
         self.workflow_interrupted: bool = False
@@ -159,17 +168,17 @@ class EvaluationRun:
         intermediate_future = None
 
         try:
-
             # Start usage stats and intermediate steps collection in parallel
             intermediate_future = pull_intermediate()
             runner_result = runner.result()
             base_output = await runner_result
             intermediate_steps = await intermediate_future
         except NotImplementedError as e:
+            logger.error("Failed to run the workflow: %s", e)
             # raise original error
-            raise e
+            raise
         except Exception as e:
-            logger.exception("Failed to run the workflow: %s", e, exc_info=True)
+            logger.exception("Failed to run the workflow: %s", e)
             # stop processing if a workflow error occurs
             self.workflow_interrupted = True
 
@@ -308,9 +317,9 @@ class EvaluationRun:
                 logger.info("Deleting old job directory: %s", dir_to_delete)
                 shutil.rmtree(dir_to_delete)
             except Exception as e:
-                logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e, exc_info=True)
+                logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e)
 
-    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):  # pylint: disable=unused-argument # noqa: E501
+    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
         workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
         workflow_output_file.parent.mkdir(parents=True, exist_ok=True)
 
@@ -358,7 +367,7 @@ class EvaluationRun:
 
                 await self.weave_eval.alog_score(eval_output, evaluator_name)
             except Exception as e:
-                logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e, exc_info=True)
+                logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e)
 
     async def run_evaluators(self, evaluators: dict[str, Any]):
         """Run all configured evaluators asynchronously."""
@@ -371,7 +380,7 @@ class EvaluationRun:
         try:
             await asyncio.gather(*tasks)
         except Exception as e:
-            logger.exception("An error occurred while running evaluators: %s", e, exc_info=True)
+            logger.error("An error occurred while running evaluators: %s", e)
             raise
         finally:
             # Finish prediction loggers in Weave
@@ -401,6 +410,33 @@ class EvaluationRun:
 
         return workflow_type
 
+    async def wait_for_all_export_tasks_local(self, session_manager: SessionManager, timeout: float) -> None:
+        """Wait for all trace export tasks to complete for local workflows.
+
+        This only works for local workflows where we have direct access to the
+        SessionManager and its underlying workflow with exporter manager.
+        """
+        try:
+            workflow = session_manager.workflow
+            all_exporters = await workflow.get_all_exporters()
+            if not all_exporters:
+                logger.debug("No exporters to wait for")
+                return
+
+            logger.info("Waiting for export tasks from %d local exporters (timeout: %ds)", len(all_exporters), timeout)
+
+            for name, exporter in all_exporters.items():
+                try:
+                    await exporter.wait_for_tasks(timeout=timeout)
+                    logger.info("Export tasks completed for exporter: %s", name)
+                except Exception as e:
+                    logger.warning("Error waiting for export tasks from %s: %s", name, e)
+
+            logger.info("All local export task waiting completed")
+
+        except Exception as e:
+            logger.warning("Failed to wait for local export tasks: %s", e)
+
     async def run_and_evaluate(self,
                                session_manager: SessionManager | None = None,
                                job_id: str | None = None) -> EvaluationRunOutput:
@@ -413,10 +449,14 @@ class EvaluationRun:
         from nat.runtime.loader import load_config
 
         # Load and override the config
-        if self.config.override:
+        config = None
+        if isinstance(self.config.config_file, BaseModel):
+            config = self.config.config_file
+        elif self.config.override:
             config = self.apply_overrides()
         else:
             config = load_config(self.config.config_file)
+
         self.eval_config = config.eval
         workflow_alias = self._get_workflow_alias(config.workflow.type)
         logger.debug("Loaded %s evaluation configuration: %s", workflow_alias, self.eval_config)
@@ -442,44 +482,59 @@ class EvaluationRun:
         dataset_config = self.eval_config.general.dataset  # Currently only one dataset is supported
         if not dataset_config:
             logger.info("No dataset found, nothing to evaluate")
-            return EvaluationRunOutput(
-                workflow_output_file=self.workflow_output_file,
-                evaluator_output_files=self.evaluator_output_files,
-                workflow_interrupted=self.workflow_interrupted,
-            )
-
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=EvalInput(eval_input_items=[]),
+                                       evaluation_results=[],
+                                       usage_stats=UsageStats(),
+                                       profiler_results=ProfilerResults())
+
+        custom_pre_eval_process_function = self.eval_config.general.output.custom_pre_eval_process_function \
+            if self.eval_config.general.output else None
         dataset_handler = DatasetHandler(dataset_config=dataset_config,
                                          reps=self.config.reps,
                                          concurrency=self.eval_config.general.max_concurrency,
                                          num_passes=self.config.num_passes,
-                                         adjust_dataset_size=self.config.adjust_dataset_size)
+                                         adjust_dataset_size=self.config.adjust_dataset_size,
+                                         custom_pre_eval_process_function=custom_pre_eval_process_function)
         self.eval_input = dataset_handler.get_eval_input_from_dataset(self.config.dataset)
         if not self.eval_input.eval_input_items:
             logger.info("Dataset is empty. Nothing to evaluate.")
-            return EvaluationRunOutput(
-                workflow_output_file=self.workflow_output_file,
-                evaluator_output_files=self.evaluator_output_files,
-                workflow_interrupted=self.workflow_interrupted,
-            )
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=self.eval_input,
+                                       evaluation_results=self.evaluation_results,
+                                       usage_stats=self.usage_stats,
+                                       profiler_results=ProfilerResults())
 
         # Run workflow and evaluate
         async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
             # Initialize Weave integration
             self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
 
-            # Run workflow
-            if self.config.endpoint:
-                await self.run_workflow_remote()
-            else:
-                if not self.config.skip_workflow:
+            with self.eval_trace_context.evaluation_context():
+                # Run workflow
+                if self.config.endpoint:
+                    await self.run_workflow_remote()
+                elif not self.config.skip_workflow:
                     if session_manager is None:
-                        session_manager = SessionManager(eval_workflow.build(),
+                        workflow = await eval_workflow.build()
+                        session_manager = SessionManager(workflow,
                                                          max_concurrency=self.eval_config.general.max_concurrency)
                     await self.run_workflow_local(session_manager)
 
-            # Evaluate
-            evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
-            await self.run_evaluators(evaluators)
+                # Pre-evaluation process the workflow output
+                self.eval_input = dataset_handler.pre_eval_process_eval_input(self.eval_input)
+
+                # Evaluate
+                evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
+                await self.run_evaluators(evaluators)
+
+                # Wait for all trace export tasks to complete (local workflows only)
+                if session_manager and not self.config.endpoint:
+                    await self.wait_for_all_export_tasks_local(session_manager, timeout=self.config.export_timeout)
 
             # Profile the workflow
             profiler_results = await self.profile_workflow()
nat/eval/evaluator/base_evaluator.py CHANGED
@@ -71,7 +71,7 @@ class BaseEvaluator(ABC):
             TqdmPositionRegistry.release(tqdm_position)
 
         # Compute average if possible
-        numeric_scores = [item.score for item in output_items if isinstance(item.score, (int, float))]
+        numeric_scores = [item.score for item in output_items if isinstance(item.score, int | float)]
         avg_score = round(sum(numeric_scores) / len(numeric_scores), 2) if numeric_scores else None
 
         return EvalOutput(average_score=avg_score, eval_output_items=output_items)
nat/eval/evaluator/evaluator_model.py CHANGED
@@ -29,6 +29,19 @@ class EvalInputItem(BaseModel):
     trajectory: list[IntermediateStep] = []  # populated by the workflow
     full_dataset_entry: typing.Any
 
+    def copy_with_updates(self, **updates) -> "EvalInputItem":
+        """
+        Copy EvalInputItem with optional field updates.
+        """
+        # Get all current fields
+        item_data = self.model_dump()
+
+        # Apply any updates
+        item_data.update(updates)
+
+        # Create new item with all fields
+        return EvalInputItem(**item_data)
+
 
 class EvalInput(BaseModel):
     eval_input_items: list[EvalInputItem]
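copy_with_updates dumps the current item, applies the overrides, and re-validates into a fresh EvalInputItem, which is what the pre-evaluation hook above builds on. A short usage sketch; the trajectory override is just an example:

```python
# Rebuild an item with one field replaced; the original item is left untouched.
updated_item = item.copy_with_updates(trajectory=[])
assert updated_item is not item
```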
nat/eval/intermediate_step_adapter.py CHANGED
@@ -40,7 +40,7 @@ class IntermediateStepAdapter:
             try:
                 validated_steps.append(IntermediateStep.model_validate(step_data))
             except Exception as e:
-                logger.exception("Validation failed for step: %r, Error: %s", step_data, e, exc_info=True)
+                logger.exception("Validation failed for step: %r, Error: %s", step_data, e)
         return validated_steps
 
     def serialize_intermediate_steps(self, intermediate_steps: list[IntermediateStep]) -> list[dict]:
nat/eval/rag_evaluator/evaluate.py CHANGED
@@ -102,7 +102,7 @@ class RAGEvaluator:
         """Converts the ragas EvaluationResult to nat EvalOutput"""
 
         if not results_dataset:
-            logger.error("Ragas evaluation failed with no results")
+            logger.error("Ragas evaluation failed with no results", exc_info=True)
             return EvalOutput(average_score=0.0, eval_output_items=[])
 
         scores: list[dict[str, float]] = results_dataset.scores
@@ -169,7 +169,7 @@ class RAGEvaluator:
                                    _pbar=pbar)
         except Exception as e:
             # On exception we still continue with other evaluators. Log and return an avg_score of 0.0
-            logger.exception("Error evaluating ragas metric, Error: %s", e, exc_info=True)
+            logger.exception("Error evaluating ragas metric, Error: %s", e)
             results_dataset = None
         finally:
             pbar.close()
nat/eval/rag_evaluator/register.py CHANGED
@@ -73,7 +73,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return self.metric
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.keys()))  # pylint: disable=no-member
+            return next(iter(self.metric.keys()))
         return ""
 
     @property
@@ -82,7 +82,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return RagasMetricConfig()  # Default config when only a metric name is given
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.values()))  # pylint: disable=no-member
+            return next(iter(self.metric.values()))
         return RagasMetricConfig()  # Default config when an invalid type is provided
 
 
@@ -104,7 +104,7 @@ async def register_ragas_evaluator(config: RagasEvaluatorConfig, builder: EvalBu
             raise ValueError(message) from e
         except AttributeError as e:
             message = f"Ragas metric {metric_name} not found {e}."
-            logger.error(message)
+            logger.exception(message)
             return None
 
     async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
nat/eval/register.py CHANGED
@@ -14,10 +14,13 @@
 # limitations under the License.
 
 # flake8: noqa
-# pylint: disable=unused-import
 
 # Import evaluators which need to be automatically registered here
 from .rag_evaluator.register import register_ragas_evaluator
+from .runtime_evaluator.register import register_avg_llm_latency_evaluator
+from .runtime_evaluator.register import register_avg_num_llm_calls_evaluator
+from .runtime_evaluator.register import register_avg_tokens_per_llm_end_evaluator
+from .runtime_evaluator.register import register_avg_workflow_runtime_evaluator
 from .swe_bench_evaluator.register import register_swe_bench_evaluator
 from .trajectory_evaluator.register import register_trajectory_evaluator
 from .tunable_rag_evaluator.register import register_tunable_rag_evaluator