nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
|
nat/data_models/swe_bench_model.py
CHANGED
|
@@ -39,7 +39,7 @@ class SWEBenchInput(BaseModel):
|
|
|
39
39
|
|
|
40
40
|
# Handle improperly formatted JSON strings for list fields
|
|
41
41
|
@field_validator("FAIL_TO_PASS", "PASS_TO_PASS", mode="before")
|
|
42
|
-
def parse_list_fields(cls, value):
|
|
42
|
+
def parse_list_fields(cls, value):
|
|
43
43
|
if isinstance(value, str):
|
|
44
44
|
# Attempt to parse the string as a list
|
|
45
45
|
return json.loads(value)
|
|
nat/data_models/temperature_mixin.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
from nat.data_models.gated_field_mixin import GatedFieldMixin
|
|
21
|
+
from nat.data_models.optimizable import OptimizableField
|
|
22
|
+
from nat.data_models.optimizable import SearchSpace
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TemperatureMixin(
|
|
26
|
+
BaseModel,
|
|
27
|
+
GatedFieldMixin,
|
|
28
|
+
field_name="temperature",
|
|
29
|
+
default_if_supported=0.0,
|
|
30
|
+
keys=("model_name", "model", "azure_deployment"),
|
|
31
|
+
unsupported=(re.compile(r"gpt-?5", re.IGNORECASE), ),
|
|
32
|
+
):
|
|
33
|
+
"""
|
|
34
|
+
Mixin class for temperature configuration. Unsupported on models like gpt-5.
|
|
35
|
+
|
|
36
|
+
Attributes:
|
|
37
|
+
temperature: Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.
|
|
38
|
+
"""
|
|
39
|
+
temperature: float | None = OptimizableField(
|
|
40
|
+
default=None,
|
|
41
|
+
ge=0.0,
|
|
42
|
+
le=1.0,
|
|
43
|
+
description="Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.",
|
|
44
|
+
space=SearchSpace(high=0.9, low=0.1, step=0.2))
|
|
nat/data_models/thinking_mixin.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from nat.data_models.gated_field_mixin import GatedFieldMixin
|
|
22
|
+
|
|
23
|
+
# Currently the control logic for thinking is only implemented for Nemotron models
|
|
24
|
+
_NEMOTRON_REGEX = re.compile(r"^nvidia/(llama|nvidia).*nemotron", re.IGNORECASE)
|
|
25
|
+
# The keys are the fields that are used to determine if the model supports thinking
|
|
26
|
+
_MODEL_KEYS = ("model_name", "model", "azure_deployment")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ThinkingMixin(
|
|
30
|
+
BaseModel,
|
|
31
|
+
GatedFieldMixin,
|
|
32
|
+
field_name="thinking",
|
|
33
|
+
default_if_supported=None,
|
|
34
|
+
keys=_MODEL_KEYS,
|
|
35
|
+
supported=(_NEMOTRON_REGEX, ),
|
|
36
|
+
):
|
|
37
|
+
"""
|
|
38
|
+
Mixin class for thinking configuration. Only supported on Nemotron models.
|
|
39
|
+
|
|
40
|
+
Attributes:
|
|
41
|
+
thinking: Whether to enable thinking. Defaults to None when supported on the model.
|
|
42
|
+
"""
|
|
43
|
+
thinking: bool | None = Field(
|
|
44
|
+
default=None,
|
|
45
|
+
description="Whether to enable thinking. Defaults to None when supported on the model.",
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
@property
|
|
49
|
+
def thinking_system_prompt(self) -> str | None:
|
|
50
|
+
"""
|
|
51
|
+
Returns the system prompt to use for thinking.
|
|
52
|
+
For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
|
|
53
|
+
For Llama Nemotron v1.5, returns "/think" if enabled, else "/no_think".
|
|
54
|
+
For Llama Nemotron v1.0, returns "detailed thinking on" if enabled, else "detailed thinking off".
|
|
55
|
+
If thinking is not supported on the model, returns None.
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
str | None: The system prompt to use for thinking.
|
|
59
|
+
"""
|
|
60
|
+
if self.thinking is None:
|
|
61
|
+
return None
|
|
62
|
+
|
|
63
|
+
for key in _MODEL_KEYS:
|
|
64
|
+
model = getattr(self, key, None)
|
|
65
|
+
if not isinstance(model, str) or model is None:
|
|
66
|
+
continue
|
|
67
|
+
|
|
68
|
+
# Normalize name to reduce checks
|
|
69
|
+
model = model.lower().translate(str.maketrans("_.", "--"))
|
|
70
|
+
|
|
71
|
+
if model.startswith("nvidia/nvidia"):
|
|
72
|
+
return "/think" if self.thinking else "/no_think"
|
|
73
|
+
|
|
74
|
+
if model.startswith("nvidia/llama"):
|
|
75
|
+
if "v1-0" in model or "v1-1" in model:
|
|
76
|
+
return f"detailed thinking {'on' if self.thinking else 'off'}"
|
|
77
|
+
|
|
78
|
+
if "v1-5" in model:
|
|
79
|
+
# v1.5 models are updated to use the /think and /no_think system prompts
|
|
80
|
+
return "/think" if self.thinking else "/no_think"
|
|
81
|
+
|
|
82
|
+
# Assume any other model is a newer model that uses the /think and /no_think system prompts
|
|
83
|
+
return "/think" if self.thinking else "/no_think"
|
|
84
|
+
|
|
85
|
+
# Unknown model
|
|
86
|
+
return None
|
|
nat/data_models/top_p_mixin.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
from nat.data_models.gated_field_mixin import GatedFieldMixin
|
|
21
|
+
from nat.data_models.optimizable import OptimizableField
|
|
22
|
+
from nat.data_models.optimizable import SearchSpace
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TopPMixin(
|
|
26
|
+
BaseModel,
|
|
27
|
+
GatedFieldMixin,
|
|
28
|
+
field_name="top_p",
|
|
29
|
+
default_if_supported=1.0,
|
|
30
|
+
keys=("model_name", "model", "azure_deployment"),
|
|
31
|
+
unsupported=(re.compile(r"gpt-?5", re.IGNORECASE), ),
|
|
32
|
+
):
|
|
33
|
+
"""
|
|
34
|
+
Mixin class for top-p configuration. Unsupported on models like gpt-5.
|
|
35
|
+
|
|
36
|
+
Attributes:
|
|
37
|
+
top_p: Top-p for distribution sampling. Defaults to 1.0 when supported on the model.
|
|
38
|
+
"""
|
|
39
|
+
top_p: float | None = OptimizableField(
|
|
40
|
+
default=None,
|
|
41
|
+
ge=0.0,
|
|
42
|
+
le=1.0,
|
|
43
|
+
description="Top-p for distribution sampling. Defaults to 1.0 when supported on the model.",
|
|
44
|
+
space=SearchSpace(high=1.0, low=0.5, step=0.1))
|
nat/embedder/nim_embedder.py
CHANGED
|
@@ -50,7 +50,7 @@ class NIMEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="nim"):
|
|
|
50
50
|
description=("The truncation strategy if the input on the "
|
|
51
51
|
"server side if it's too large."))
|
|
52
52
|
|
|
53
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
53
|
+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
@register_embedder_provider(config_type=NIMEmbedderModelConfig)
|
nat/embedder/openai_embedder.py
CHANGED
|
@@ -27,7 +27,7 @@ from nat.data_models.retry_mixin import RetryMixin
|
|
|
27
27
|
class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
|
|
28
28
|
"""An OpenAI LLM provider to be used with an LLM client."""
|
|
29
29
|
|
|
30
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
30
|
+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
31
31
|
|
|
32
32
|
api_key: str | None = Field(default=None, description="OpenAI API key to interact with hosted model.")
|
|
33
33
|
base_url: str | None = Field(default=None, description="Base url to the hosted model.")
|
nat/embedder/register.py
CHANGED
nat/eval/config.py
CHANGED
|
@@ -27,7 +27,7 @@ class EvaluationRunConfig(BaseModel):
|
|
|
27
27
|
"""
|
|
28
28
|
Parameters used for a single evaluation run.
|
|
29
29
|
"""
|
|
30
|
-
config_file: Path
|
|
30
|
+
config_file: Path | BaseModel
|
|
31
31
|
dataset: str | None = None # dataset file path can be specified in the config file
|
|
32
32
|
result_json_path: str = "$"
|
|
33
33
|
skip_workflow: bool = False
|
|
@@ -44,6 +44,8 @@ class EvaluationRunConfig(BaseModel):
|
|
|
44
44
|
# number of passes at each concurrency, if 0 the dataset is adjusted to a multiple of the
|
|
45
45
|
# concurrency. The is only used if adjust_dataset_size is true
|
|
46
46
|
num_passes: int = 0
|
|
47
|
+
# timeout for waiting for trace export tasks to complete
|
|
48
|
+
export_timeout: float = 60.0
|
|
47
49
|
|
|
48
50
|
|
|
49
51
|
class EvaluationRunOutput(BaseModel):
|
|
nat/eval/dataset_handler/dataset_handler.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import importlib
|
|
16
17
|
import json
|
|
17
18
|
import math
|
|
18
19
|
from pathlib import Path
|
|
@@ -41,7 +42,8 @@ class DatasetHandler:
|
|
|
41
42
|
reps: int,
|
|
42
43
|
concurrency: int,
|
|
43
44
|
num_passes: int = 1,
|
|
44
|
-
adjust_dataset_size: bool = False
|
|
45
|
+
adjust_dataset_size: bool = False,
|
|
46
|
+
custom_pre_eval_process_function: str | None = None):
|
|
45
47
|
from nat.eval.intermediate_step_adapter import IntermediateStepAdapter
|
|
46
48
|
|
|
47
49
|
self.dataset_config = dataset_config
|
|
@@ -53,6 +55,9 @@ class DatasetHandler:
|
|
|
53
55
|
self.num_passes = num_passes
|
|
54
56
|
self.adjust_dataset_size = adjust_dataset_size
|
|
55
57
|
|
|
58
|
+
# Custom pre-evaluation process function
|
|
59
|
+
self.custom_pre_eval_process_function = custom_pre_eval_process_function
|
|
60
|
+
|
|
56
61
|
# Helpers
|
|
57
62
|
self.intermediate_step_adapter = IntermediateStepAdapter()
|
|
58
63
|
|
|
@@ -146,13 +151,12 @@ class DatasetHandler:
|
|
|
146
151
|
# When num_passes is specified, always use concurrency * num_passes
|
|
147
152
|
# This respects the user's intent for exact number of passes
|
|
148
153
|
target_size = self.concurrency * self.num_passes
|
|
154
|
+
# When num_passes = 0, use the largest multiple of concurrency <= original_size
|
|
155
|
+
# If original_size < concurrency, we need at least concurrency rows
|
|
156
|
+
elif original_size >= self.concurrency:
|
|
157
|
+
target_size = (original_size // self.concurrency) * self.concurrency
|
|
149
158
|
else:
|
|
150
|
-
|
|
151
|
-
# If original_size < concurrency, we need at least concurrency rows
|
|
152
|
-
if original_size >= self.concurrency:
|
|
153
|
-
target_size = (original_size // self.concurrency) * self.concurrency
|
|
154
|
-
else:
|
|
155
|
-
target_size = self.concurrency
|
|
159
|
+
target_size = self.concurrency
|
|
156
160
|
|
|
157
161
|
if target_size == 0:
|
|
158
162
|
raise ValueError("Input dataset too small for even one batch at given concurrency.")
|
|
@@ -331,6 +335,66 @@ class DatasetHandler:
|
|
|
331
335
|
filtered_steps = self.intermediate_step_adapter.filter_intermediate_steps(intermediate_steps, event_filter)
|
|
332
336
|
return self.intermediate_step_adapter.serialize_intermediate_steps(filtered_steps)
|
|
333
337
|
|
|
338
|
+
def pre_eval_process_eval_input(self, eval_input: EvalInput) -> EvalInput:
|
|
339
|
+
"""
|
|
340
|
+
Pre-evaluation process the eval input using custom function if provided.
|
|
341
|
+
|
|
342
|
+
The custom pre-evaluation process function should have the signature:
|
|
343
|
+
def custom_pre_eval_process(item: EvalInputItem) -> EvalInputItem
|
|
344
|
+
|
|
345
|
+
The framework will iterate through all items and call this function on each one.
|
|
346
|
+
|
|
347
|
+
Args:
|
|
348
|
+
eval_input: The EvalInput object to pre-evaluation process
|
|
349
|
+
|
|
350
|
+
Returns:
|
|
351
|
+
The pre-evaluation processed EvalInput object
|
|
352
|
+
"""
|
|
353
|
+
if self.custom_pre_eval_process_function:
|
|
354
|
+
try:
|
|
355
|
+
custom_function = self._load_custom_pre_eval_process_function()
|
|
356
|
+
processed_items = []
|
|
357
|
+
|
|
358
|
+
for item in eval_input.eval_input_items:
|
|
359
|
+
processed_item = custom_function(item)
|
|
360
|
+
if not isinstance(processed_item, EvalInputItem):
|
|
361
|
+
raise TypeError(f"Custom pre-evaluation '{self.custom_pre_eval_process_function}' must return "
|
|
362
|
+
f"EvalInputItem, got {type(processed_item)}")
|
|
363
|
+
processed_items.append(processed_item)
|
|
364
|
+
|
|
365
|
+
return EvalInput(eval_input_items=processed_items)
|
|
366
|
+
except Exception as e:
|
|
367
|
+
raise RuntimeError(f"Error calling custom pre-evaluation process function "
|
|
368
|
+
f"'{self.custom_pre_eval_process_function}': {e}") from e
|
|
369
|
+
|
|
370
|
+
return eval_input
|
|
371
|
+
|
|
372
|
+
def _load_custom_pre_eval_process_function(self):
|
|
373
|
+
"""
|
|
374
|
+
Import and return the custom pre-evaluation process function using standard Python import path.
|
|
375
|
+
|
|
376
|
+
The function should process individual EvalInputItem objects.
|
|
377
|
+
"""
|
|
378
|
+
# Split the function path to get module and function name
|
|
379
|
+
if "." not in self.custom_pre_eval_process_function:
|
|
380
|
+
raise ValueError(f"Invalid custom_pre_eval_process_function '{self.custom_pre_eval_process_function}'. "
|
|
381
|
+
"Expected format: '<module_path>.<function_name>'")
|
|
382
|
+
module_path, function_name = self.custom_pre_eval_process_function.rsplit(".", 1)
|
|
383
|
+
|
|
384
|
+
# Import the module
|
|
385
|
+
module = importlib.import_module(module_path)
|
|
386
|
+
|
|
387
|
+
# Get the function from the module
|
|
388
|
+
if not hasattr(module, function_name):
|
|
389
|
+
raise AttributeError(f"Function '{function_name}' not found in module '{module_path}'")
|
|
390
|
+
|
|
391
|
+
custom_function = getattr(module, function_name)
|
|
392
|
+
|
|
393
|
+
if not callable(custom_function):
|
|
394
|
+
raise ValueError(f"'{self.custom_pre_eval_process_function}' is not callable")
|
|
395
|
+
|
|
396
|
+
return custom_function
|
|
397
|
+
|
|
334
398
|
def publish_eval_input(self,
|
|
335
399
|
eval_input,
|
|
336
400
|
workflow_output_step_filter: list[IntermediateStepType] | None = None) -> str:
|
nat/eval/evaluate.py
CHANGED
|
@@ -42,7 +42,7 @@ from nat.runtime.session import SessionManager
|
|
|
42
42
|
logger = logging.getLogger(__name__)
|
|
43
43
|
|
|
44
44
|
|
|
45
|
-
class EvaluationRun:
|
|
45
|
+
class EvaluationRun:
|
|
46
46
|
"""
|
|
47
47
|
Instantiated for each evaluation run and used to store data for that single run.
|
|
48
48
|
|
|
@@ -63,7 +63,16 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
|
|
|
63
63
|
|
|
64
64
|
# Helpers
|
|
65
65
|
self.intermediate_step_adapter: IntermediateStepAdapter = IntermediateStepAdapter()
|
|
66
|
-
|
|
66
|
+
|
|
67
|
+
# Create evaluation trace context
|
|
68
|
+
try:
|
|
69
|
+
from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
|
|
70
|
+
self.eval_trace_context = WeaveEvalTraceContext()
|
|
71
|
+
except Exception:
|
|
72
|
+
from nat.eval.utils.eval_trace_ctx import EvalTraceContext
|
|
73
|
+
self.eval_trace_context = EvalTraceContext()
|
|
74
|
+
|
|
75
|
+
self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration(self.eval_trace_context)
|
|
67
76
|
# Metadata
|
|
68
77
|
self.eval_input: EvalInput | None = None
|
|
69
78
|
self.workflow_interrupted: bool = False
|
|
@@ -159,17 +168,17 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
|
|
|
159
168
|
intermediate_future = None
|
|
160
169
|
|
|
161
170
|
try:
|
|
162
|
-
|
|
163
171
|
# Start usage stats and intermediate steps collection in parallel
|
|
164
172
|
intermediate_future = pull_intermediate()
|
|
165
173
|
runner_result = runner.result()
|
|
166
174
|
base_output = await runner_result
|
|
167
175
|
intermediate_steps = await intermediate_future
|
|
168
176
|
except NotImplementedError as e:
|
|
177
|
+
logger.error("Failed to run the workflow: %s", e)
|
|
169
178
|
# raise original error
|
|
170
|
-
raise
|
|
179
|
+
raise
|
|
171
180
|
except Exception as e:
|
|
172
|
-
logger.exception("Failed to run the workflow: %s", e
|
|
181
|
+
logger.exception("Failed to run the workflow: %s", e)
|
|
173
182
|
# stop processing if a workflow error occurs
|
|
174
183
|
self.workflow_interrupted = True
|
|
175
184
|
|
|
@@ -308,9 +317,9 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
|
|
|
308
317
|
logger.info("Deleting old job directory: %s", dir_to_delete)
|
|
309
318
|
shutil.rmtree(dir_to_delete)
|
|
310
319
|
except Exception as e:
|
|
311
|
-
logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e
|
|
320
|
+
logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e)
|
|
312
321
|
|
|
313
|
-
def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
|
|
322
|
+
def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
|
|
314
323
|
workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
|
|
315
324
|
workflow_output_file.parent.mkdir(parents=True, exist_ok=True)
|
|
316
325
|
|
|
@@ -358,7 +367,7 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
|
|
|
358
367
|
|
|
359
368
|
await self.weave_eval.alog_score(eval_output, evaluator_name)
|
|
360
369
|
except Exception as e:
|
|
361
|
-
logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e
|
|
370
|
+
logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e)
|
|
362
371
|
|
|
363
372
|
async def run_evaluators(self, evaluators: dict[str, Any]):
|
|
364
373
|
"""Run all configured evaluators asynchronously."""
|
|
@@ -371,7 +380,7 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
|
|
|
371
380
|
try:
|
|
372
381
|
await asyncio.gather(*tasks)
|
|
373
382
|
except Exception as e:
|
|
374
|
-
logger.
|
|
383
|
+
logger.error("An error occurred while running evaluators: %s", e)
|
|
375
384
|
raise
|
|
376
385
|
finally:
|
|
377
386
|
# Finish prediction loggers in Weave
|
|
@@ -401,6 +410,33 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
|
|
|
401
410
|
|
|
402
411
|
return workflow_type
|
|
403
412
|
|
|
413
|
+
async def wait_for_all_export_tasks_local(self, session_manager: SessionManager, timeout: float) -> None:
|
|
414
|
+
"""Wait for all trace export tasks to complete for local workflows.
|
|
415
|
+
|
|
416
|
+
This only works for local workflows where we have direct access to the
|
|
417
|
+
SessionManager and its underlying workflow with exporter manager.
|
|
418
|
+
"""
|
|
419
|
+
try:
|
|
420
|
+
workflow = session_manager.workflow
|
|
421
|
+
all_exporters = await workflow.get_all_exporters()
|
|
422
|
+
if not all_exporters:
|
|
423
|
+
logger.debug("No exporters to wait for")
|
|
424
|
+
return
|
|
425
|
+
|
|
426
|
+
logger.info("Waiting for export tasks from %d local exporters (timeout: %ds)", len(all_exporters), timeout)
|
|
427
|
+
|
|
428
|
+
for name, exporter in all_exporters.items():
|
|
429
|
+
try:
|
|
430
|
+
await exporter.wait_for_tasks(timeout=timeout)
|
|
431
|
+
logger.info("Export tasks completed for exporter: %s", name)
|
|
432
|
+
except Exception as e:
|
|
433
|
+
logger.warning("Error waiting for export tasks from %s: %s", name, e)
|
|
434
|
+
|
|
435
|
+
logger.info("All local export task waiting completed")
|
|
436
|
+
|
|
437
|
+
except Exception as e:
|
|
438
|
+
logger.warning("Failed to wait for local export tasks: %s", e)
|
|
439
|
+
|
|
404
440
|
async def run_and_evaluate(self,
|
|
405
441
|
session_manager: SessionManager | None = None,
|
|
406
442
|
job_id: str | None = None) -> EvaluationRunOutput:
|
|
@@ -413,10 +449,14 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
|
|
|
413
449
|
from nat.runtime.loader import load_config
|
|
414
450
|
|
|
415
451
|
# Load and override the config
|
|
416
|
-
|
|
452
|
+
config = None
|
|
453
|
+
if isinstance(self.config.config_file, BaseModel):
|
|
454
|
+
config = self.config.config_file
|
|
455
|
+
elif self.config.override:
|
|
417
456
|
config = self.apply_overrides()
|
|
418
457
|
else:
|
|
419
458
|
config = load_config(self.config.config_file)
|
|
459
|
+
|
|
420
460
|
self.eval_config = config.eval
|
|
421
461
|
workflow_alias = self._get_workflow_alias(config.workflow.type)
|
|
422
462
|
logger.debug("Loaded %s evaluation configuration: %s", workflow_alias, self.eval_config)
|
|
@@ -442,44 +482,59 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
|
|
|
442
482
|
dataset_config = self.eval_config.general.dataset # Currently only one dataset is supported
|
|
443
483
|
if not dataset_config:
|
|
444
484
|
logger.info("No dataset found, nothing to evaluate")
|
|
445
|
-
return EvaluationRunOutput(
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
485
|
+
return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
|
|
486
|
+
evaluator_output_files=self.evaluator_output_files,
|
|
487
|
+
workflow_interrupted=self.workflow_interrupted,
|
|
488
|
+
eval_input=EvalInput(eval_input_items=[]),
|
|
489
|
+
evaluation_results=[],
|
|
490
|
+
usage_stats=UsageStats(),
|
|
491
|
+
profiler_results=ProfilerResults())
|
|
492
|
+
|
|
493
|
+
custom_pre_eval_process_function = self.eval_config.general.output.custom_pre_eval_process_function \
|
|
494
|
+
if self.eval_config.general.output else None
|
|
451
495
|
dataset_handler = DatasetHandler(dataset_config=dataset_config,
|
|
452
496
|
reps=self.config.reps,
|
|
453
497
|
concurrency=self.eval_config.general.max_concurrency,
|
|
454
498
|
num_passes=self.config.num_passes,
|
|
455
|
-
adjust_dataset_size=self.config.adjust_dataset_size
|
|
499
|
+
adjust_dataset_size=self.config.adjust_dataset_size,
|
|
500
|
+
custom_pre_eval_process_function=custom_pre_eval_process_function)
|
|
456
501
|
self.eval_input = dataset_handler.get_eval_input_from_dataset(self.config.dataset)
|
|
457
502
|
if not self.eval_input.eval_input_items:
|
|
458
503
|
logger.info("Dataset is empty. Nothing to evaluate.")
|
|
459
|
-
return EvaluationRunOutput(
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
504
|
+
return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
|
|
505
|
+
evaluator_output_files=self.evaluator_output_files,
|
|
506
|
+
workflow_interrupted=self.workflow_interrupted,
|
|
507
|
+
eval_input=self.eval_input,
|
|
508
|
+
evaluation_results=self.evaluation_results,
|
|
509
|
+
usage_stats=self.usage_stats,
|
|
510
|
+
profiler_results=ProfilerResults())
|
|
464
511
|
|
|
465
512
|
# Run workflow and evaluate
|
|
466
513
|
async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
|
|
467
514
|
# Initialize Weave integration
|
|
468
515
|
self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
|
|
469
516
|
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
517
|
+
with self.eval_trace_context.evaluation_context():
|
|
518
|
+
# Run workflow
|
|
519
|
+
if self.config.endpoint:
|
|
520
|
+
await self.run_workflow_remote()
|
|
521
|
+
elif not self.config.skip_workflow:
|
|
475
522
|
if session_manager is None:
|
|
476
|
-
|
|
523
|
+
workflow = await eval_workflow.build()
|
|
524
|
+
session_manager = SessionManager(workflow,
|
|
477
525
|
max_concurrency=self.eval_config.general.max_concurrency)
|
|
478
526
|
await self.run_workflow_local(session_manager)
|
|
479
527
|
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
528
|
+
# Pre-evaluation process the workflow output
|
|
529
|
+
self.eval_input = dataset_handler.pre_eval_process_eval_input(self.eval_input)
|
|
530
|
+
|
|
531
|
+
# Evaluate
|
|
532
|
+
evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
|
|
533
|
+
await self.run_evaluators(evaluators)
|
|
534
|
+
|
|
535
|
+
# Wait for all trace export tasks to complete (local workflows only)
|
|
536
|
+
if session_manager and not self.config.endpoint:
|
|
537
|
+
await self.wait_for_all_export_tasks_local(session_manager, timeout=self.config.export_timeout)
|
|
483
538
|
|
|
484
539
|
# Profile the workflow
|
|
485
540
|
profiler_results = await self.profile_workflow()
|
|
@@ -71,7 +71,7 @@ class BaseEvaluator(ABC):
|
|
|
71
71
|
TqdmPositionRegistry.release(tqdm_position)
|
|
72
72
|
|
|
73
73
|
# Compute average if possible
|
|
74
|
-
numeric_scores = [item.score for item in output_items if isinstance(item.score,
|
|
74
|
+
numeric_scores = [item.score for item in output_items if isinstance(item.score, int | float)]
|
|
75
75
|
avg_score = round(sum(numeric_scores) / len(numeric_scores), 2) if numeric_scores else None
|
|
76
76
|
|
|
77
77
|
return EvalOutput(average_score=avg_score, eval_output_items=output_items)
|
|
@@ -29,6 +29,19 @@ class EvalInputItem(BaseModel):
|
|
|
29
29
|
trajectory: list[IntermediateStep] = [] # populated by the workflow
|
|
30
30
|
full_dataset_entry: typing.Any
|
|
31
31
|
|
|
32
|
+
def copy_with_updates(self, **updates) -> "EvalInputItem":
|
|
33
|
+
"""
|
|
34
|
+
Copy EvalInputItem with optional field updates.
|
|
35
|
+
"""
|
|
36
|
+
# Get all current fields
|
|
37
|
+
item_data = self.model_dump()
|
|
38
|
+
|
|
39
|
+
# Apply any updates
|
|
40
|
+
item_data.update(updates)
|
|
41
|
+
|
|
42
|
+
# Create new item with all fields
|
|
43
|
+
return EvalInputItem(**item_data)
|
|
44
|
+
|
|
32
45
|
|
|
33
46
|
class EvalInput(BaseModel):
|
|
34
47
|
eval_input_items: list[EvalInputItem]
|
|
@@ -40,7 +40,7 @@ class IntermediateStepAdapter:
|
|
|
40
40
|
try:
|
|
41
41
|
validated_steps.append(IntermediateStep.model_validate(step_data))
|
|
42
42
|
except Exception as e:
|
|
43
|
-
logger.exception("Validation failed for step: %r, Error: %s", step_data, e
|
|
43
|
+
logger.exception("Validation failed for step: %r, Error: %s", step_data, e)
|
|
44
44
|
return validated_steps
|
|
45
45
|
|
|
46
46
|
def serialize_intermediate_steps(self, intermediate_steps: list[IntermediateStep]) -> list[dict]:
|
|
@@ -102,7 +102,7 @@ class RAGEvaluator:
|
|
|
102
102
|
"""Converts the ragas EvaluationResult to nat EvalOutput"""
|
|
103
103
|
|
|
104
104
|
if not results_dataset:
|
|
105
|
-
logger.error("Ragas evaluation failed with no results")
|
|
105
|
+
logger.error("Ragas evaluation failed with no results", exc_info=True)
|
|
106
106
|
return EvalOutput(average_score=0.0, eval_output_items=[])
|
|
107
107
|
|
|
108
108
|
scores: list[dict[str, float]] = results_dataset.scores
|
|
@@ -169,7 +169,7 @@ class RAGEvaluator:
|
|
|
169
169
|
_pbar=pbar)
|
|
170
170
|
except Exception as e:
|
|
171
171
|
# On exception we still continue with other evaluators. Log and return an avg_score of 0.0
|
|
172
|
-
logger.exception("Error evaluating ragas metric, Error: %s", e
|
|
172
|
+
logger.exception("Error evaluating ragas metric, Error: %s", e)
|
|
173
173
|
results_dataset = None
|
|
174
174
|
finally:
|
|
175
175
|
pbar.close()
|
|
@@ -73,7 +73,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
|
|
|
73
73
|
if isinstance(self.metric, str):
|
|
74
74
|
return self.metric
|
|
75
75
|
if isinstance(self.metric, dict) and self.metric:
|
|
76
|
-
return next(iter(self.metric.keys()))
|
|
76
|
+
return next(iter(self.metric.keys()))
|
|
77
77
|
return ""
|
|
78
78
|
|
|
79
79
|
@property
|
|
@@ -82,7 +82,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
|
|
|
82
82
|
if isinstance(self.metric, str):
|
|
83
83
|
return RagasMetricConfig() # Default config when only a metric name is given
|
|
84
84
|
if isinstance(self.metric, dict) and self.metric:
|
|
85
|
-
return next(iter(self.metric.values()))
|
|
85
|
+
return next(iter(self.metric.values()))
|
|
86
86
|
return RagasMetricConfig() # Default config when an invalid type is provided
|
|
87
87
|
|
|
88
88
|
|
|
@@ -104,7 +104,7 @@ async def register_ragas_evaluator(config: RagasEvaluatorConfig, builder: EvalBu
|
|
|
104
104
|
raise ValueError(message) from e
|
|
105
105
|
except AttributeError as e:
|
|
106
106
|
message = f"Ragas metric {metric_name} not found {e}."
|
|
107
|
-
logger.
|
|
107
|
+
logger.exception(message)
|
|
108
108
|
return None
|
|
109
109
|
|
|
110
110
|
async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
|
nat/eval/register.py
CHANGED
|
@@ -14,10 +14,13 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
# flake8: noqa
|
|
17
|
-
# pylint: disable=unused-import
|
|
18
17
|
|
|
19
18
|
# Import evaluators which need to be automatically registered here
|
|
20
19
|
from .rag_evaluator.register import register_ragas_evaluator
|
|
20
|
+
from .runtime_evaluator.register import register_avg_llm_latency_evaluator
|
|
21
|
+
from .runtime_evaluator.register import register_avg_num_llm_calls_evaluator
|
|
22
|
+
from .runtime_evaluator.register import register_avg_tokens_per_llm_end_evaluator
|
|
23
|
+
from .runtime_evaluator.register import register_avg_workflow_runtime_evaluator
|
|
21
24
|
from .swe_bench_evaluator.register import register_swe_bench_evaluator
|
|
22
25
|
from .trajectory_evaluator.register import register_trajectory_evaluator
|
|
23
26
|
from .tunable_rag_evaluator.register import register_tunable_rag_evaluator
|