aiqtoolkit 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of aiqtoolkit has been flagged as potentially problematic.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +76 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +322 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +46 -0
- aiq/agent/react_agent/register.py +148 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +224 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +410 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +123 -0
- aiq/agent/tool_calling_agent/register.py +105 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +223 -0
- aiq/builder/component_utils.py +303 -0
- aiq/builder/context.py +227 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +120 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +297 -0
- aiq/builder/function_base.py +376 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +176 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +71 -0
- aiq/builder/workflow.py +143 -0
- aiq/builder/workflow_builder.py +757 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +231 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +36 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +39 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/info/list_mcp.py +126 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +38 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/start.py +250 -0
- aiq/cli/commands/uninstall.py +83 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +313 -0
- aiq/cli/entrypoint.py +133 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +408 -0
- aiq/cli/type_registry.py +879 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +588 -0
- aiq/data_models/common.py +143 -0
- aiq/data_models/component.py +46 -0
- aiq/data_models/component_ref.py +135 -0
- aiq/data_models/config.py +349 -0
- aiq/data_models/dataset_handler.py +122 -0
- aiq/data_models/discovery_metadata.py +286 -0
- aiq/data_models/embedder.py +26 -0
- aiq/data_models/evaluate.py +104 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +64 -0
- aiq/data_models/interactive.py +237 -0
- aiq/data_models/intermediate_step.py +269 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +26 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +26 -0
- aiq/data_models/profiler.py +53 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +58 -0
- aiq/embedder/openai_embedder.py +42 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +42 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +169 -0
- aiq/eval/evaluate.py +325 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/evaluator_model.py +44 -0
- aiq/eval/intermediate_step_adapter.py +93 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +138 -0
- aiq/eval/rag_evaluator/register.py +138 -0
- aiq/eval/register.py +23 -0
- aiq/eval/remote_workflow.py +128 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +118 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
- aiq/eval/tunable_rag_evaluator/evaluate.py +263 -0
- aiq/eval/tunable_rag_evaluator/register.py +50 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +131 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +107 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +150 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +103 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +607 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +161 -0
- aiq/front_ends/fastapi/main.py +70 -0
- aiq/front_ends/fastapi/message_handler.py +279 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +195 -0
- aiq/front_ends/fastapi/step_adaptor.py +320 -0
- aiq/front_ends/fastapi/websocket.py +148 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +52 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/nim_llm.py +45 -0
- aiq/llm/openai_llm.py +45 -0
- aiq/llm/register.py +22 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +112 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +58 -0
- aiq/observability/__init__.py +0 -0
- aiq/observability/async_otel_listener.py +429 -0
- aiq/observability/register.py +99 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +278 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +25 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +452 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +433 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +198 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +82 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +188 -0
- aiq/runtime/runner.py +176 -0
- aiq/runtime/session.py +140 -0
- aiq/runtime/user_metadata.py +131 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +188 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +83 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +4 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +25 -0
- aiq/tool/code_execution/register.py +70 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/mcp_client.py +220 -0
- aiq/tool/mcp/mcp_tool.py +95 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +79 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +37 -0
- aiq/tool/retriever.py +89 -0
- aiq/tool/server_tools.py +63 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/yaml_tools.py +119 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/optional_imports.py +142 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/type_converter.py +232 -0
- aiq/utils/type_utils.py +397 -0
- aiq/utils/url_utils.py +27 -0
- aiqtoolkit-1.1.0.dist-info/METADATA +331 -0
- aiqtoolkit-1.1.0.dist-info/RECORD +316 -0
- aiqtoolkit-1.1.0.dist-info/WHEEL +5 -0
- aiqtoolkit-1.1.0.dist-info/entry_points.txt +17 -0
- aiqtoolkit-1.1.0.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- aiqtoolkit-1.1.0.dist-info/licenses/LICENSE.md +201 -0
- aiqtoolkit-1.1.0.dist-info/top_level.txt +1 -0

aiq/eval/swe_bench_evaluator/evaluate.py
@@ -0,0 +1,215 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import shutil
from pathlib import Path

from aiq.data_models.swe_bench_model import SWEBenchInput
from aiq.data_models.swe_bench_model import SWEBenchOutput
from aiq.eval.evaluator.evaluator_model import EvalInput
from aiq.eval.evaluator.evaluator_model import EvalOutput

try:
    import swebench.harness.run_evaluation as swebench_eval
    from swebench.harness.constants import MAP_REPO_VERSION_TO_SPECS
except ImportError as exc:
    raise ImportError("Please install swebench to use this evaluator") from exc

logger = logging.getLogger(__name__)


class SweBenchEvaluator:

    def __init__(self, run_id: str, max_workers: int, output_dir: Path):

        self.run_id = run_id
        self.max_workers = max_workers
        self.output_dir = output_dir

        # metadata
        self._unsupported_repos = []
        self._swe_bench_inputs = []
        self._swe_bench_outputs = []
        self._model_name_or_path = "no_llm"

    def get_model_name_from_output(self, workflow_output: list[dict]) -> str | None:
        """Fetch the `model_name_or_path` from the first entry in the list."""
        return workflow_output[0].get("model_name_or_path") if workflow_output else None

    @staticmethod
    def empty_report_dir(report_dir: Path):
        """Remove the current contents of the report directory."""
        os.makedirs(report_dir, exist_ok=True)

        # Iterate through all files in the directory and remove them
        for item in report_dir.iterdir():
            if item.is_file():  # Remove files only
                item.unlink()
            elif item.is_dir():  # Remove subdirectories and their contents
                shutil.rmtree(item)

    @staticmethod
    def move_report_and_logs(swe_bench_report_file: str, logs_dir: str, report_dir: Path):
        """Temporary function to move the report and logs to the output directory."""
        try:
            shutil.move(swe_bench_report_file, report_dir)
        except Exception as e:
            logger.exception("Error moving report file: %s", e, exc_info=True)

        try:
            dest_logs_dir = os.path.join(report_dir, 'logs')
            shutil.move(logs_dir, dest_logs_dir)
        except Exception as e:
            logger.exception("Error moving logs directory: %s", e, exc_info=True)

    def is_repo_supported(self, repo: str, version: str) -> bool:
        """Check if the repo is supported by swebench"""

        try:
            _ = MAP_REPO_VERSION_TO_SPECS[repo][str(version)]
        except KeyError:
            self._unsupported_repos.append({repo, version})
            return False
        return True

    def process_eval_input(self, eval_input: EvalInput) -> tuple[Path, Path]:
        """Converts EvalInput into lists of SWEBenchInput and SWEBenchOutput models and applies filtering."""
        # Convert input_obj and output_obj JSON strings to SWEBenchInput and SWEBenchOutput models
        swebench_inputs = []
        swebench_outputs = []

        for item in eval_input.eval_input_items:
            try:
                swebench_input = SWEBenchInput.model_validate_json(item.input_obj)  # Convert input JSON to model
                swebench_input.version = str(swebench_input.version)  # Convert version to string
                swebench_inputs.append(swebench_input)

                if item.output_obj:  # Convert output JSON to model if available
                    swebench_output = SWEBenchOutput.model_validate_json(item.output_obj)
                    swebench_outputs.append(swebench_output)
                    # this is a bit of a hack to match the swe_bench harness
                    self._model_name_or_path = swebench_output.model_name_or_path

            except Exception as e:
                logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e, exc_info=True)

        # Filter out repos/version not supported by SWEBench
        supported_inputs = [
            swebench for swebench in swebench_inputs if self.is_repo_supported(swebench.repo, swebench.version)
        ]

        if not supported_inputs:
            logger.error("No supported instances; nothing to evaluate")
            return None, None

        if len(supported_inputs) < len(swebench_inputs):
            logger.warning("The following repos are not supported by SWEBench and were skipped:\n %s",
                           {s.repo
                            for s in swebench_inputs if s not in supported_inputs})

        # Write SWEBenchInput to file
        workflow_input_file = self.output_dir / "aiq_workflow_input.json"
        workflow_input_file.parent.mkdir(parents=True, exist_ok=True)
        Path(workflow_input_file).write_text(json.dumps([swebench.model_dump() for swebench in supported_inputs],
                                                        indent=2),
                                             encoding="utf-8")
        logger.info("Workflow input written to %s", workflow_input_file)

        # Filter SWEBenchOutput to include only instance_ids present in SWEBenchInput
        valid_instance_ids = {swebench.instance_id for swebench in supported_inputs}
        filtered_outputs = [output for output in swebench_outputs if output.instance_id in valid_instance_ids]

        if not filtered_outputs:
            logger.error("No supported outputs; nothing to evaluate")
            return None, None

        # Write SWEBenchOutput to file
        workflow_output_file = self.output_dir / "aiq_workflow_output.json"
        Path(workflow_output_file).write_text(json.dumps([output.model_dump() for output in filtered_outputs],
                                                         indent=2),
                                              encoding="utf-8")
        logger.info("Workflow output written to %s", workflow_output_file)

        self._swe_bench_inputs = supported_inputs
        self._swe_bench_outputs = filtered_outputs
        return workflow_input_file, workflow_output_file

    def build_eval_output(self):
        """Builds the EvalOutput object from the SWEBenchOutput models and the average score."""
        # WIP: Build a score based on eval run logs
        for swebench_output in self._swe_bench_outputs:
            yield {"id": swebench_output.instance_id, "score": "-", "reasoning": "-"}

    @staticmethod
    def compute_score(success_cnt: int, total_cnt: int) -> float:
        if total_cnt == 0:
            return 0.0
        score = success_cnt / total_cnt
        return min(max(score, 0.0), 1.0)

    async def evaluate(self, eval_input: EvalInput) -> EvalOutput:
        '''Run the swebench evaluation and store the report in the output directory'''

        # Process the EvalInput
        workflow_input_file, workflow_output_file = self.process_eval_input(eval_input)
        if not workflow_input_file or not workflow_output_file:
            # nothing to evaluate
            return EvalOutput(average_score=0.0, eval_output_items=[])

        report_dir = self.output_dir / "swe_bench_reports"
        self.empty_report_dir(report_dir)

        logger.info("Starting swe_bench run %s", self.run_id)
        swebench_eval.main(dataset_name=str(workflow_input_file),
                           split="dev",
                           instance_ids=[],
                           predictions_path=str(workflow_output_file),
                           max_workers=self.max_workers,
                           force_rebuild=False,
                           cache_level="env",
                           clean=False,
                           open_file_limit=4096,
                           run_id=self.run_id,
                           timeout=1800,
                           namespace=None,
                           rewrite_reports=False,
                           modal=False,
                           instance_image_tag='latest',
                           report_dir=str(report_dir))
        logger.info("Completed swe_bench run %s", self.run_id)

        swe_bench_report_file = f"{self._model_name_or_path}.{self.run_id}.json"

        # There is a bug in swebench that causes report_dir to be ignored. Copy the report to the output dir.
        self.move_report_and_logs(swe_bench_report_file=swe_bench_report_file, logs_dir="logs", report_dir=report_dir)
        logger.info("SWE_bench report and logs written to %s directory", report_dir)

        # read the swe_bench report file
        report_file = report_dir / swe_bench_report_file
        # if report file is not present, return empty EvalOutput
        avg_score = 0.0
        if report_file.exists():
            with open(report_file, "r", encoding="utf-8") as f:
                report = json.load(f)
            resolved_instances = report.get("resolved_instances", 0)
            total_instances = report.get("total_instances", 0)
            avg_score = self.compute_score(resolved_instances, total_instances)

        # Build the EvalOutput from self._swe_bench_outputs and avg_score
        eval_output_items = list(self.build_eval_output())
        return EvalOutput(average_score=avg_score, eval_output_items=eval_output_items)
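
The snippet below is not part of the package diff; it is a minimal standalone sketch of the scoring step above, showing how the resolved_instances and total_instances fields read from the SWE-bench report map to the returned average score (the report values here are made up).

# Standalone sketch of SweBenchEvaluator.compute_score and how evaluate() uses it.
def compute_score(success_cnt: int, total_cnt: int) -> float:
    if total_cnt == 0:
        return 0.0
    return min(max(success_cnt / total_cnt, 0.0), 1.0)

# Example report fragment; only these two fields are read by evaluate() above.
report = {"resolved_instances": 3, "total_instances": 4}
avg_score = compute_score(report.get("resolved_instances", 0), report.get("total_instances", 0))
print(avg_score)  # 0.75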
aiq/eval/swe_bench_evaluator/register.py
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import Field

from aiq.builder.builder import EvalBuilder
from aiq.builder.evaluator import EvaluatorInfo
from aiq.cli.register_workflow import register_evaluator
from aiq.data_models.evaluator import EvaluatorBaseConfig


class SweBenchEvaluatorConfig(EvaluatorBaseConfig, name="swe_bench"):
    """Code patch evaluation for SWE Bench problems."""

    run_id: str = Field(description="swe-bench test harness run identifier.")


@register_evaluator(config_type=SweBenchEvaluatorConfig)
async def register_swe_bench_evaluator(config: SweBenchEvaluatorConfig, builder: EvalBuilder):

    from .evaluate import SweBenchEvaluator
    _evaluator = SweBenchEvaluator(config.run_id, builder.get_max_concurrency(), builder.get_output_dir())

    yield EvaluatorInfo(config=config, evaluate_fn=_evaluator.evaluate, description="SWE Bench Evaluator")
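
register.py above is a complete example of the evaluator plugin pattern used across aiq/eval: an EvaluatorBaseConfig subclass keyed by a name, and an async generator decorated with @register_evaluator that yields an EvaluatorInfo wrapping the evaluate coroutine. A hedged sketch of a custom evaluator following that same pattern is shown below; the names MyEvaluatorConfig, my_eval, and threshold are illustrative and not part of the package.

# Hypothetical custom evaluator (illustrative names), mirroring the swe_bench registration above.
from pydantic import Field

from aiq.builder.builder import EvalBuilder
from aiq.builder.evaluator import EvaluatorInfo
from aiq.cli.register_workflow import register_evaluator
from aiq.data_models.evaluator import EvaluatorBaseConfig


class MyEvaluatorConfig(EvaluatorBaseConfig, name="my_eval"):
    """Illustrative config; `threshold` is a made-up field."""

    threshold: float = Field(default=0.5, description="Pass/fail cut-off.")


@register_evaluator(config_type=MyEvaluatorConfig)
async def register_my_evaluator(config: MyEvaluatorConfig, builder: EvalBuilder):

    async def evaluate_fn(eval_input):
        # Return an EvalOutput here, e.g. built from config.threshold and the input items.
        ...

    yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="My Evaluator")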
aiq/eval/trajectory_evaluator/__init__.py: file without changes
aiq/eval/trajectory_evaluator/evaluate.py
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging

from langchain.evaluation import TrajectoryEvalChain
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from tqdm import tqdm

from aiq.eval.evaluator.evaluator_model import EvalInput
from aiq.eval.evaluator.evaluator_model import EvalInputItem
from aiq.eval.evaluator.evaluator_model import EvalOutput
from aiq.eval.evaluator.evaluator_model import EvalOutputItem
from aiq.eval.utils.tqdm_position_registry import TqdmPositionRegistry

logger = logging.getLogger(__name__)


class TrajectoryEvaluator:

    def __init__(
        self,
        llm: BaseChatModel,
        tools: list[BaseTool] | None = None,
        max_concurrency: int = 8,
    ):

        self.llm = llm
        self.tools = tools
        self.max_concurrency = max_concurrency
        self.semaphore = asyncio.Semaphore(self.max_concurrency)
        # Initialize trajectory evaluation chain
        self.traj_eval_chain = TrajectoryEvalChain.from_llm(llm=self.llm,
                                                            tools=self.tools,
                                                            return_reasoning=True,
                                                            requires_reference=True)
        logger.debug("Trajectory evaluation chain initialized.")

    async def evaluate(self, eval_input: EvalInput) -> EvalOutput:
        """
        Evaluates the agent trajectories using the trajectory evaluation chain.
        """

        num_records = len(eval_input.eval_input_items)
        logger.info("Running trajectory evaluation with %d records", num_records)
        from aiq.data_models.intermediate_step import IntermediateStepType
        from aiq.eval.intermediate_step_adapter import IntermediateStepAdapter

        intermediate_step_adapter = IntermediateStepAdapter()
        event_filter = [IntermediateStepType.LLM_END, IntermediateStepType.TOOL_END]

        async def process_item(item: EvalInputItem) -> tuple[float, dict]:
            """
            Evaluate a single EvalInputItem asynchronously and return a tuple of:
            1. score
            2. reasoning for the score
            """
            question = item.input_obj
            generated_answer = item.output_obj
            agent_trajectory = intermediate_step_adapter.get_agent_actions(item.trajectory, event_filter)
            try:
                eval_result = await self.traj_eval_chain.aevaluate_agent_trajectory(
                    input=question,
                    agent_trajectory=agent_trajectory,
                    prediction=generated_answer,
                )
            except Exception as e:
                logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e, exc_info=True)
                return 0.0, f"Error evaluating trajectory: {e}"

            reasoning = {
                "reasoning": eval_result["reasoning"],
                "trajectory": [(action.model_dump(), output) for (action, output) in agent_trajectory]
            }
            return eval_result["score"], reasoning

        async def wrapped_process(item: EvalInputItem) -> tuple[float, dict]:
            async with self.semaphore:
                result = await process_item(item)
                pbar.update(1)
                return result

        # Execute all evaluations asynchronously
        try:
            tqdm_position = TqdmPositionRegistry.claim()
            pbar = tqdm(total=len(eval_input.eval_input_items), desc="Evaluating Trajectory", position=tqdm_position)
            results = await asyncio.gather(*[wrapped_process(item) for item in eval_input.eval_input_items])
        finally:
            pbar.close()
            TqdmPositionRegistry.release(tqdm_position)

        # Extract scores and reasonings
        sample_scores, sample_reasonings = zip(*results) if results else ([], [])

        # Compute average score
        avg_score = round(sum(sample_scores) / len(sample_scores), 2) if sample_scores else 0.0

        # Construct EvalOutputItems
        eval_output_items = [
            EvalOutputItem(id=item.id, score=score, reasoning=reasoning)
            for item, score, reasoning in zip(eval_input.eval_input_items, sample_scores, sample_reasonings)
        ]

        return EvalOutput(average_score=avg_score, eval_output_items=eval_output_items)
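
The agent_trajectory handed to TrajectoryEvalChain above comes from IntermediateStepAdapter.get_agent_actions, filtered to LLM_END and TOOL_END steps. The sketch below (not part of the diff) shows the shape this is assumed to take, based on LangChain's standard (AgentAction, observation) pairs and on how the evaluator serializes the trajectory into each item's reasoning; the tool name and observation are made-up values.

# Illustrative only: assumed shape of the trajectory consumed by the evaluator above.
from langchain_core.agents import AgentAction

agent_trajectory = [
    (
        AgentAction(tool="internet_search",                # hypothetical tool name
                    tool_input={"query": "current weather"},
                    log="Calling the search tool"),
        "It is 22 C and sunny.",                           # tool observation
    ),
]

# Mirrors how TrajectoryEvaluator stores the trajectory in each item's reasoning:
serialized = [(action.model_dump(), output) for (action, output) in agent_trajectory]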
aiq/eval/trajectory_evaluator/register.py
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import Field

from aiq.builder.builder import EvalBuilder
from aiq.builder.evaluator import EvaluatorInfo
from aiq.cli.register_workflow import register_evaluator
from aiq.data_models.evaluator import EvaluatorBaseConfig


class TrajectoryEvaluatorConfig(EvaluatorBaseConfig, name="trajectory"):
    """Agent Trajectory Evaluation."""

    llm_name: str = Field(description="LLM as a judge.")


@register_evaluator(config_type=TrajectoryEvaluatorConfig)
async def register_trajectory_evaluator(config: TrajectoryEvaluatorConfig, builder: EvalBuilder):
    from aiq.builder.framework_enum import LLMFrameworkEnum

    from .evaluate import TrajectoryEvaluator
    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    tools = builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    _evaluator = TrajectoryEvaluator(llm, tools, builder.get_max_concurrency())

    yield EvaluatorInfo(config=config, evaluate_fn=_evaluator.evaluate, description="Trajectory Evaluator")
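
Both LLM-as-judge evaluators in this diff (the trajectory evaluator above and the tunable RAG evaluator below) use the same concurrency recipe: an asyncio.Semaphore caps in-flight judge calls, asyncio.gather keeps results in input order, and the average score is the rounded mean. The standalone sketch below (not part of the diff) isolates that recipe, with score_item standing in for the judge-LLM call.

# Standalone sketch of the bounded-concurrency evaluation loop used above and below.
import asyncio


async def score_item(item: str) -> float:
    await asyncio.sleep(0.01)  # placeholder for the judge LLM round-trip
    return 1.0


async def evaluate_all(items: list[str], max_concurrency: int = 8) -> float:
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(item: str) -> float:
        async with semaphore:
            return await score_item(item)

    scores = await asyncio.gather(*[bounded(item) for item in items])
    # Same aggregation as the evaluators: average rounded to two decimals.
    return round(sum(scores) / len(scores), 2) if scores else 0.0


print(asyncio.run(evaluate_all(["q1", "q2", "q3"])))  # 1.0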
aiq/eval/tunable_rag_evaluator/__init__.py: file without changes
aiq/eval/tunable_rag_evaluator/evaluate.py
@@ -0,0 +1,263 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging

from langchain.output_parsers import ResponseSchema
from langchain.output_parsers import StructuredOutputParser
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.language_models import BaseChatModel
from tqdm import tqdm

from aiq.eval.evaluator.evaluator_model import EvalInput
from aiq.eval.evaluator.evaluator_model import EvalInputItem
from aiq.eval.evaluator.evaluator_model import EvalOutput
from aiq.eval.evaluator.evaluator_model import EvalOutputItem
from aiq.eval.utils.tqdm_position_registry import TqdmPositionRegistry

logger = logging.getLogger(__name__)

# pylint: disable=line-too-long
# flake8: noqa: E501


def evaluation_prompt(judge_llm_prompt: str,
                      question: str,
                      answer_description: str,
                      generated_answer: str,
                      format_instructions: str,
                      default_scoring: bool):
    """
    This function generates a prompt for the judge LLM to evaluate the generated answer.
    """

    DEFAULT_SCORING_INSTRUCTIONS = """
    The coverage score is a measure of how well the generated answer covers the critical aspects mentioned in the expected answer. A low coverage score indicates that the generated answer misses critical aspects of the expected answer. A middle coverage score indicates that the generated answer covers some of the must-haves of the expected answer but lacks other details. A high coverage score indicates that all of the expected aspects are present in the generated answer.
    The correctness score is a measure of how well the generated answer matches the expected answer. A low correctness score indicates that the generated answer is incorrect or does not match the expected answer. A middle correctness score indicates that the generated answer is correct but lacks some details. A high correctness score indicates that the generated answer is exactly the same as the expected answer.
    The relevance score is a measure of how well the generated answer is relevant to the question. A low relevance score indicates that the generated answer is not relevant to the question. A middle relevance score indicates that the generated answer is somewhat relevant to the question. A high relevance score indicates that the generated answer is exactly relevant to the question.
    The reasoning is a 1-2 sentence explanation for the scoring.
    """

    DEFAULT_EVAL_PROMPT = (f"You are an intelligent assistant that responds strictly in JSON format."
                           f"Judge based on the following scoring rubric: {DEFAULT_SCORING_INSTRUCTIONS}"
                           f"{judge_llm_prompt}\n"
                           f"{format_instructions}\n"
                           f"Here is the user's query: {question}"
                           f"Here is the description of the expected answer: {answer_description}"
                           f"Here is the generated answer: {generated_answer}")

    EVAL_PROMPT = (f"You are an intelligent assistant that responds strictly in JSON format. {judge_llm_prompt}\n"
                   f"{format_instructions}\n"
                   f"Here is the user's query: {question}"
                   f"Here is the description of the expected answer: {answer_description}"
                   f"Here is the generated answer: {generated_answer}")

    return EVAL_PROMPT if not default_scoring else DEFAULT_EVAL_PROMPT


class TunableRagEvaluator:
    '''Tunable RAG evaluator class with customizable LLM prompt for scoring.'''

    def __init__(self,
                 llm: BaseChatModel,
                 judge_llm_prompt: str,
                 max_concurrency: int,
                 default_scoring: bool,
                 default_score_weights: dict):
        self.llm = llm
        self.max_concurrency = max_concurrency
        self.judge_llm_prompt = judge_llm_prompt
        self.semaphore = asyncio.Semaphore(self.max_concurrency)
        self.default_scoring = default_scoring
        # Use user-provided weights if available; otherwise, set equal weights for each score
        self.default_score_weights = default_score_weights if default_score_weights else {
            "coverage": 1 / 3, "correctness": 1 / 3, "relevance": 1 / 3
        }

    async def evaluate(self, eval_input: EvalInput) -> EvalOutput:
        '''Evaluate function'''

        async def process_item(item):
            """Compute RAG evaluation for an individual item"""
            question = item.input_obj
            answer_description = item.expected_output_obj
            generated_answer = item.output_obj

            # Call judge LLM to generate score
            score = 0.0

            default_evaluation_schema = [
                ResponseSchema(
                    name="coverage_score",
                    description=
                    "Score for the coverage of all critical aspects mentioned in the expected answer. Ex. 0.5",
                    type="float"),
                ResponseSchema(
                    name="correctness_score",
                    description=
                    "Score for the accuracy of the generated answer compared to the expected answer. Ex. 0.5",
                    type="float"),
                ResponseSchema(name="relevance_score",
                               description="Score for the relevance of the generated answer to the question. Ex. 0.5",
                               type="float"),
                ResponseSchema(
                    name="reasoning",
                    description=
                    "1-2 summarized sentences of reasoning for the scores. Ex. 'The generated answer covers all critical aspects mentioned in the expected answer, is correct, and is relevant to the question.'",
                    type="string"),
            ]

            custom_evaluation_schema = [
                ResponseSchema(name="score", description="Score for the generated answer. Ex. 0.5", type="float"),
                ResponseSchema(
                    name="reasoning",
                    description=
                    "1-2 sentence reasoning for the score. Ex. 'The generated answer is exactly the same as the description of the expected answer.'",
                    type="string"),
            ]

            if self.default_scoring:
                evaluation_schema = default_evaluation_schema
            else:
                evaluation_schema = custom_evaluation_schema

            llm_input_response_parser = StructuredOutputParser.from_response_schemas(evaluation_schema)
            format_instructions = llm_input_response_parser.get_format_instructions()

            eval_prompt = evaluation_prompt(judge_llm_prompt=self.judge_llm_prompt,
                                            question=question,
                                            answer_description=answer_description,
                                            generated_answer=generated_answer,
                                            format_instructions=format_instructions,
                                            default_scoring=self.default_scoring)

            messages = [
                SystemMessage(content="You must respond only in JSON format."), HumanMessage(content=eval_prompt)
            ]

            response = await self.llm.ainvoke(messages)

            # Initialize default values to handle service errors
            coverage_score = 0.0
            correctness_score = 0.0
            relevance_score = 0.0
            reasoning = "Error in evaluator from parsing judge LLM response."

            try:
                parsed_response = llm_input_response_parser.parse(response.content)
                if self.default_scoring:
                    try:
                        coverage_score = parsed_response["coverage_score"]
                        correctness_score = parsed_response["correctness_score"]
                        relevance_score = parsed_response["relevance_score"]
                        reasoning = parsed_response["reasoning"]
                    except KeyError as e:
                        logger.error("Missing required keys in default scoring response: %s",
                                     ", ".join(str(arg) for arg in e.args))
                        reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"

                    coverage_weight = self.default_score_weights.get("coverage", 1 / 3)
                    correctness_weight = self.default_score_weights.get("correctness", 1 / 3)
                    relevance_weight = self.default_score_weights.get("relevance", 1 / 3)

                    # Calculate score
                    total_weight = coverage_weight + correctness_weight + relevance_weight
                    coverage_weight = coverage_weight / total_weight
                    correctness_weight = correctness_weight / total_weight
                    relevance_weight = relevance_weight / total_weight

                    if round(coverage_weight + correctness_weight + relevance_weight, 2) != 1:
                        logger.warning("The sum of the default score weights is not 1. The weights will be normalized.")
                        coverage_weight = coverage_weight / (coverage_weight + correctness_weight + relevance_weight)
                        correctness_weight = correctness_weight / (coverage_weight + correctness_weight +
                                                                   relevance_weight)
                        relevance_weight = relevance_weight / (coverage_weight + correctness_weight + relevance_weight)

                    score = (coverage_weight * coverage_score + correctness_weight * correctness_score +
                             relevance_weight * relevance_score)

                else:
                    try:
                        score = parsed_response["score"]
                        reasoning = parsed_response["reasoning"]
                    except KeyError as e:
                        logger.error("Missing required keys in custom scoring response: %s",
                                     ", ".join(str(arg) for arg in e.args))
                        reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
                        raise
            except (KeyError, ValueError) as e:
                logger.error("Error parsing judge LLM response: %s", e)
                score = 0.0
                reasoning = "Error in evaluator from parsing judge LLM response."

            if self.default_scoring:
                reasoning = {
                    "question": question,
                    "answer_description": answer_description,
                    "generated_answer": generated_answer,
                    "score_breakdown": {
                        "coverage_score": coverage_score,
                        "correctness_score": correctness_score,
                        "relevance_score": relevance_score,
                    },
                    "reasoning": reasoning,
                }
            else:
                reasoning = {
                    "question": question,
                    "answer_description": answer_description,
                    "generated_answer": generated_answer,
                    "reasoning": reasoning
                }

            return score, reasoning

        async def wrapped_process(item: EvalInputItem) -> tuple[float, dict]:
            """
            Process an item asynchronously and update the progress bar.
            Use the semaphore to limit the number of concurrent items.
            """
            async with self.semaphore:
                result = await process_item(item)
                # Update the progress bar
                pbar.update(1)
                return result

        try:
            # Claim a tqdm position to display the progress bar
            tqdm_position = TqdmPositionRegistry.claim()
            # Create a progress bar
            pbar = tqdm(total=len(eval_input.eval_input_items), desc="Evaluating RAG", position=tqdm_position)
            # Process items concurrently with a limit on concurrency
            results = await asyncio.gather(*[wrapped_process(item) for item in eval_input.eval_input_items])
        finally:
            pbar.close()
            TqdmPositionRegistry.release(tqdm_position)

        # Extract scores and reasonings
        sample_scores, sample_reasonings = zip(*results) if results else ([], [])

        # Compute average score
        avg_score = round(sum(sample_scores) / len(sample_scores), 2) if sample_scores else 0.0

        # Construct EvalOutputItems
        eval_output_items = [
            EvalOutputItem(id=item.id, score=score, reasoning=reasoning)
            for item, score, reasoning in zip(eval_input.eval_input_items, sample_scores, sample_reasonings)
        ]

        return EvalOutput(average_score=avg_score, eval_output_items=eval_output_items)
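
The default-scoring branch above collapses the three judge scores into one number using weights normalized to sum to 1. The worked example below (not part of the diff) shows that arithmetic; the weights and scores are made-up values, not package defaults (the package falls back to equal 1/3 weights when none are configured).

# Worked example of the default-scoring arithmetic in TunableRagEvaluator above.
weights = {"coverage": 0.5, "correctness": 0.3, "relevance": 0.2}   # made-up weights
scores = {"coverage": 0.8, "correctness": 0.6, "relevance": 1.0}    # made-up judge scores

total = sum(weights.values())
normalized = {name: w / total for name, w in weights.items()}       # already sums to 1 here

final_score = sum(normalized[name] * scores[name] for name in scores)
print(round(final_score, 2))  # 0.5*0.8 + 0.3*0.6 + 0.2*1.0 = 0.78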