eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
eval_framework/main.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import shutil
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Literal
|
|
9
|
+
|
|
10
|
+
import wandb
|
|
11
|
+
|
|
12
|
+
from eval_framework.evaluation_generator import EvaluationGenerator, Result
|
|
13
|
+
from eval_framework.llm.base import BaseLLM
|
|
14
|
+
from eval_framework.response_generator import ResponseGenerator
|
|
15
|
+
from eval_framework.result_processors.hf_uploader import HFUploader
|
|
16
|
+
from eval_framework.result_processors.result_processor import ResultsFileProcessor, generate_output_dir
|
|
17
|
+
from eval_framework.result_processors.wandb_uploader import WandbUploader
|
|
18
|
+
from eval_framework.tasks.eval_config import EvalConfig
|
|
19
|
+
from eval_framework.utils.constants import RED, RESET
|
|
20
|
+
from eval_framework.utils.logging import setup_logging
|
|
21
|
+
|
|
22
|
+
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def main(
    llm: BaseLLM,
    config: EvalConfig,
    should_preempt_callable: Callable[[], bool] | None = None,
    trial_id: int | None = None,
    *args: Any,
    resource_cleanup: bool = False,
    verbosity: int = 1,
) -> list[Result]:
    """Runs the entire evaluation process: responses generation and evaluation.

    Args:
        llm: Model whose responses are generated and evaluated.
        config: Evaluation configuration (task, few-shot/sample counts, W&B and upload settings).
        should_preempt_callable: Polled during response generation; returning True requests
            preemption. Defaults to a callable that never preempts.
        trial_id: Identifier under which preemption state is saved and later restored.
            Required for a run to be preempted and resumed.
        *args: Not used by this function.
        resource_cleanup: If True, drop the response generator and run the garbage collector
            before evaluation starts.
        verbosity: Log level passed to ``setup_logging``.

    Returns:
        The evaluation results, or an empty list when response generation was preempted.
    """
    if should_preempt_callable is None:
        should_preempt_callable = lambda: False  # noqa: E731

    # Resolve the output directory exactly once, before logging is configured, so the log file,
    # the generated responses and a resumed (preempted) run all share the same directory.
    # `is not None` so that trial 0 can resume too (saving/deleting below use the same test).
    preemption_data = _read_preemption_data(config, trial_id) if trial_id is not None else None
    if preemption_data is None:
        output_dir = generate_output_dir(llm.name, config)
        wandb_run_id = config.wandb_run_id  # defaults to none, if no run_id is provided then it starts a new one
    else:
        output_dir = preemption_data["output_dir"]
        # _read_preemption_data fills a missing id with ""; treat that as "no id" (new run).
        wandb_run_id = preemption_data.get("wandb_run_id") or None
    assert output_dir is not None

    # Set up centralized logging early
    setup_logging(output_dir=output_dir, log_level=verbosity, log_filename="evaluation.log")
    if preemption_data is not None:
        logger.info("Found preempted run restarting ...")

    logger.info(f"{RED}[ Running full evaluation process ------- ]{RESET}")
    logger.info(f"Evaluating {llm.name} on {config.task_name}")
    logger.info(f"Configuration: num_fewshot={config.num_fewshot}, num_samples={config.num_samples}")
    logger.info(f"Output directory: {output_dir}")

    file_processor = ResultsFileProcessor(output_dir)
    response_generator = ResponseGenerator(llm, config, file_processor)

    with wandb.init(
        entity=config.wandb_entity,
        project=config.wandb_project,
        group=llm.name[:127],
        job_type=config.task_name[:63],
        id=wandb_run_id,  # (potentially resuming run after preemption)
        config=response_generator._get_metadata(),
        resume="allow",
        mode=_wandb_mode(config.wandb_project),
        settings=wandb.Settings(disable_code=True),  # ("wandb-history" artifacts not needed)
    ) as run:
        artifact = getattr(llm, "artifact", None)
        if artifact is not None:
            wandb.use_artifact(artifact)
        # Comma-separated list of extra artifact references; blank entries are ignored.
        for additional_artifact in os.getenv("WANDB_ADDITIONAL_ARTIFACT_REFERENCES", "").split(","):
            reference = additional_artifact.strip()
            if reference:
                wandb.use_artifact(reference)

        _, preempted = response_generator.generate(should_preempt_callable)

        if preempted:
            logger.info("Response generation was preempted")
            assert trial_id is not None
            run.mark_preempting()
            # Persist where to resume from; the next call with this trial_id picks it up.
            _save_preemption_data(config, trial_id, output_dir, wandb_run_id=run.id)
            wandb.finish(exit_code=1)
            return []

        # Generation completed, so any saved preemption state is stale now.
        if trial_id is not None:
            _delete_preemption_file(config, trial_id)

        if resource_cleanup:
            del response_generator
            gc.collect()

        evaluator = EvaluationGenerator(config, file_processor)
        results = evaluator.run_eval()

        upload_success = False
        for uploader in [HFUploader(config), WandbUploader(config)]:
            upload_success |= uploader.upload(llm.name, config, output_dir)

        if config.delete_output_dir_after_upload and upload_success:
            logger.warning(f"Deleting output directory '{output_dir}' after successful upload(s)!")
            shutil.rmtree(output_dir, ignore_errors=True)
            if output_dir.exists() and any(output_dir.iterdir()):
                logger.warning("Could not delete output directory, some files remain.")

    return results
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _read_preemption_data(config: EvalConfig, trial_id: int) -> dict[str, Any] | None:
    """Load the saved state of a preempted trial, if there is one.

    Reads ``<config.output_dir>/preemption_trial_<trial_id>.json``.

    Returns:
        None when that file does not exist. Otherwise the decoded JSON object, with
        ``output_dir`` converted to a ``Path`` and ``wandb_run_id`` set to ``""`` if absent.
    """
    state_path = config.output_dir / f"preemption_trial_{trial_id}.json"
    if not state_path.is_file():
        return None

    with open(state_path, "rb") as handle:
        state = json.load(handle)

    state["output_dir"] = Path(state["output_dir"])
    state.setdefault("wandb_run_id", "")
    logger.info(f"Loaded preemption data from {state_path}")
    return state
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _save_preemption_data(config: EvalConfig, trial_id: int, output_dir: Path, wandb_run_id: str = "") -> None:
    """Persist the state needed to resume a preempted trial.

    Writes ``{"output_dir": ..., "wandb_run_id": ...}`` as JSON to
    ``<config.output_dir>/preemption_trial_<trial_id>.json``, the file that
    ``_read_preemption_data`` reads back.
    """
    preemption_file = config.output_dir / f"preemption_trial_{trial_id}.json"
    preemption_file.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling temp file, then rename it into place. This runs while the job is being
    # preempted; a half-written JSON file would make the resumed run fail while parsing it.
    tmp_file = preemption_file.with_name(preemption_file.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump({"output_dir": str(output_dir), "wandb_run_id": wandb_run_id}, f)
    tmp_file.replace(preemption_file)
    logger.info(f"Saved preemption data to {preemption_file}")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _delete_preemption_file(config: EvalConfig, trial_id: int) -> None:
    """Remove ``<config.output_dir>/preemption_trial_<trial_id>.json`` if it exists.

    Logs whether a file was actually deleted. A missing file is not an error.
    """
    preemption_file = config.output_dir / f"preemption_trial_{trial_id}.json"
    if preemption_file.is_file():
        preemption_file.unlink()
        logger.info(f"Deleted preemption file: {preemption_file}")
    else:
        logger.info(f"No preemption file found to delete: {preemption_file}")
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _wandb_mode(project: str | None) -> Literal["online", "disabled"] | None:
    """Choose the W&B run mode.

    Returns ``"disabled"`` when no project is given, when no W&B API key is found, or when the
    API-key lookup raises. Otherwise returns ``"online"``. Note that a missing key disables
    W&B logging entirely; it does not switch to offline mode.
    """
    if project is None:
        logger.warning("No WandB project specified, disabling logging.")
        return "disabled"

    try:
        api_key = wandb.api.api_key
    except Exception as e:
        logger.warning(f"Wandb login check failed: {e}. Disabling Wandb logging.")
        return "disabled"

    if api_key is None:
        # Implicit concatenation keeps source indentation out of the logged message.
        logger.warning(
            "No wandb API key found. Disabling Wandb logging. "
            "If you have a WandB account set the environment variable 'WANDB_API_KEY'"
        )
        return "disabled"

    logger.info("Wandb login detected. Using online mode.")
    return "online"
|
|
File without changes
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, ConfigDict
|
|
5
|
+
|
|
6
|
+
from eval_framework.shared.types import Error
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MetricResult(BaseModel):
    # One metric outcome for a single response, as returned by BaseMetric.calculate().

    # Reported name; multi-key metrics use "<NAME>/<key>".
    metric_name: str
    # The score, or None when it could not be computed (e.g. the response carried an error).
    value: float | None
    # Whether larger values of `value` are better.
    higher_is_better: bool
    # Optional LLM-judge prompt and response (presumably set by the llm_judge_* metrics).
    llm_judge_prompt: str | None = None
    llm_judge_response: str | None = None
    # Output captured when the metric executed generated code.
    code_execution_trace: str | None = None
    # Error details: set with value=None when scoring failed, or alongside a score
    # (e.g. a failed code assertion scored 0.0).
    error: Error | None = None

    # Unknown fields are rejected at validation time instead of being silently dropped.
    model_config = ConfigDict(extra="forbid")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class classproperty:
    """Read-only descriptor whose value is computed from the owning class.

    Like ``property``, but the wrapped function receives the class instead of the instance.
    Access through the class and through an instance both yield the same value.
    """

    def __init__(self, method: Any) -> None:
        # Function called with the owner class on every attribute access.
        self.method = method

    def __get__(self, instance: Any, cls: Any) -> Any:
        compute = self.method
        return compute(cls)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BaseMetric[Response](ABC):
|
|
29
|
+
NAME: str
|
|
30
|
+
KEYS: list[str] | None = None
|
|
31
|
+
|
|
32
|
+
@classproperty
|
|
33
|
+
def NAMES(cls) -> list[str]:
|
|
34
|
+
if cls.KEYS is None:
|
|
35
|
+
return [cls.NAME]
|
|
36
|
+
return [f"{cls.NAME}/{k}" for k in cls.KEYS]
|
|
37
|
+
|
|
38
|
+
@abstractmethod
|
|
39
|
+
def calculate(self, response: Response) -> list[MetricResult]:
|
|
40
|
+
raise NotImplementedError
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .accuracy_completion import AccuracyCompletion
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
2
|
+
from eval_framework.shared.types import Completion
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class AccuracyCompletion(BaseMetric[Completion]):
    """Exact-match accuracy: 1.0 if the completion equals any ground truth, else 0.0."""

    NAME = "Accuracy Completion"

    def calculate(self, response: Completion) -> list[MetricResult]:
        error = response.error
        if error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]

        completion = response.completion
        matched = any(completion == candidate for candidate in response.ground_truth_list)
        return [MetricResult(metric_name=self.NAME, value=float(matched), higher_is_better=True, error=error)]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
4
|
+
from eval_framework.shared.types import Completion
|
|
5
|
+
|
|
6
|
+
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AidanBenchMetric(BaseMetric[Completion]):
    """Number of unique and coherent responses produced before the AidanBench run stopped.

    Assumes ``response.messages`` holds the initial instruction, the accepted responses, and
    finally the response that ended the run (see the accounting in ``calculate``).
    """

    NAME = "AidanBench"
    # Declared so that NAMES ("<NAME>/<key>") matches the metric_name emitted below.
    KEYS = ["num_responses"]

    def calculate(self, response: Completion) -> list[MetricResult]:
        # subtract 2 to not count 1) initial instruction and 2) the latest model response, which caused the stop
        # i.e. was not (unique && coherent)
        messages = response.messages
        num_unique_responses = len(messages) - 2 if messages is not None else 0
        if num_unique_responses < 0:
            logger.warning(
                "Number of unique responses calculated as negative, setting to 0. "
                "Likely something went wrong during answer generation."
            )
            num_unique_responses = 0
        return [
            MetricResult(
                metric_name=f"{self.NAME}/num_responses",
                value=float(num_unique_responses),
                higher_is_better=True,
            )
        ]
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import sacrebleu
|
|
2
|
+
|
|
3
|
+
from eval_framework.exceptions import LogicError
|
|
4
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
5
|
+
from eval_framework.shared.types import Completion
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BLEU(BaseMetric[Completion]):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    n-grams in the candidate translation to n-grams in the reference text, where
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order.

    The completion is scored against every ground truth separately with
    ``sacrebleu.corpus_bleu`` and the best score (sacrebleu's 0-100 scale) is reported.

    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Raises:
        LogicError: If there are no ground truths, or one of them is empty.
    """

    NAME = "BLEU"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        ground_truths = response.ground_truth_list
        # Without this check an empty list would crash later in max() with a bare ValueError.
        if not ground_truths:
            raise LogicError("When calculating BLEU we need a ground truth.")

        scores = []
        for ground_truth in ground_truths:
            if ground_truth is None or ground_truth == "":
                raise LogicError("When calculating BLEU we need a ground truth.")
            # sacrebleu expects a list of hypotheses and a list of reference streams.
            scores.append(sacrebleu.corpus_bleu([response.completion], [[ground_truth]]).score)

        return [
            MetricResult(metric_name=self.NAME, value=float(max(scores)), higher_is_better=True, error=response.error)
        ]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class LINEWISE_BLEU(BaseMetric[Completion]):
    """Maximum Line-level BLEU score.

    Every non-empty line of the completion is scored against every ground truth with
    ``sacrebleu.corpus_bleu``; the best score is reported (0 if there is nothing to score).

    Raises:
        LogicError: If a ground truth is empty.
    """

    NAME = "Linewise BLEU"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        lines = [line for line in response.completion.split("\n") if line != ""]

        scores = []
        for ground_truth in response.ground_truth_list:
            # Validate once per ground truth, independent of how many lines the completion has.
            if ground_truth is None or ground_truth == "":
                raise LogicError("When calculating BLEU we need a ground truth.")
            for line in lines:
                scores.append(sacrebleu.corpus_bleu([line], [[ground_truth]]).score)

        return [
            MetricResult(
                metric_name=self.NAME, value=float(max(scores, default=0)), higher_is_better=True, error=response.error
            )
        ]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class ResponseToOriginalBLEU(BaseMetric[Completion]):
    """BLEU of the completion against the last user instruction, scaled to [0, 1]."""

    NAME = "Response to Original BLEU"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        hypotheses = [response.completion]
        references = [[response.last_user_instruction]]
        raw_score = sacrebleu.corpus_bleu(hypotheses, references).score
        # sacrebleu reports 0-100; dividing by 100 makes aggregation easier.
        scaled_score = raw_score / 100
        return [MetricResult(metric_name=self.NAME, value=scaled_score, higher_is_better=True, error=response.error)]
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import sacrebleu
|
|
2
|
+
|
|
3
|
+
from eval_framework.exceptions import LogicError
|
|
4
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
5
|
+
from eval_framework.shared.types import Completion
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CHRF(BaseMetric[Completion]):
    """chrF: character n-gram F-score for automatic evaluation of machine translation output.

    Computed with ``sacrebleu.corpus_chrf`` using its default settings. sacrebleu's default
    ``word_order`` is 0, so this is plain chrF; chrF++ (which adds word n-grams) would need
    ``word_order=2``. The completion is scored against every ground truth separately and the
    best score (sacrebleu's 0-100 scale) is reported.

    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Raises:
        LogicError: If there are no ground truths, or one of them is empty.
    """

    NAME = "chrF"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        ground_truths = response.ground_truth_list
        # Without this check an empty list would crash later in max() with a bare ValueError.
        if not ground_truths:
            raise LogicError("When calculating chrF we need a ground truth.")

        scores = []
        for ground_truth in ground_truths:
            if ground_truth is None or ground_truth == "":
                raise LogicError("When calculating chrF we need a ground truth.")
            # sacrebleu expects a list of hypotheses and a list of reference streams.
            scores.append(sacrebleu.corpus_chrf([response.completion], [[ground_truth]]).score)

        return [
            MetricResult(metric_name=self.NAME, value=float(max(scores)), higher_is_better=True, error=response.error)
        ]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class LINEWISE_CHRF(BaseMetric[Completion]):
    """Maximum Line-level chrF (Character n-gram F-score) score.

    Every non-empty line of the completion is scored against every ground truth with
    ``sacrebleu.corpus_chrf`` (default settings, i.e. chrF without word n-grams); the best
    score is reported (0 if there is nothing to score).

    Paper: https://aclanthology.org/W15-3049/

    Raises:
        LogicError: If a ground truth is empty.
    """

    NAME = "Linewise chrF"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        lines = [line for line in response.completion.split("\n") if line != ""]

        scores = []
        for ground_truth in response.ground_truth_list:
            # Validate once per ground truth, independent of how many lines the completion has.
            if ground_truth is None or ground_truth == "":
                raise LogicError("When calculating chrF we need a ground truth.")
            for line in lines:
                scores.append(sacrebleu.corpus_chrf([line], [[ground_truth]]).score)

        return [
            MetricResult(
                metric_name=self.NAME, value=float(max(scores, default=0)), higher_is_better=True, error=response.error
            )
        ]
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
2
|
+
from eval_framework.shared.types import Completion, Error
|
|
3
|
+
from eval_framework.tasks.utils import run_python_code
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CodeCompletionAssertion(BaseMetric[Completion]):
    """Scores a completion by executing it and checking that its output ends with ``True``.

    The completion is run via ``run_python_code(..., image="python:3.12-slim")``. The score is
    1.0 when the last whitespace-separated token of the output is exactly ``"True"``. Otherwise
    it is 0.0, and an ``Error`` describing the mismatch is attached. The raw output is always
    kept as ``code_execution_trace``.
    """

    NAME = "Code Completion Accuracy"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        code = response.completion
        output = run_python_code(code, image="python:3.12-slim")

        # str.split() without arguments already discards empty and whitespace-only tokens.
        output_parts = output.split()
        last_output = output_parts[-1] if output_parts else ""

        success = last_output == "True"
        error = (
            None
            if success
            else Error(
                error_class="CodeCompletionAssertionError",
                message=f"Expected 'True' but got '{last_output}'",
                traceback=output,
            )
        )

        return [
            MetricResult(
                metric_name=self.NAME,
                value=1.0 if success else 0.0,
                higher_is_better=True,
                error=error,
                code_execution_trace=output,
            )
        ]
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import traceback
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from typing import Self
|
|
4
|
+
|
|
5
|
+
from pydantic import Field
|
|
6
|
+
|
|
7
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
8
|
+
from eval_framework.shared.types import BaseMetricContext, Completion, Error, extract_context_metric
|
|
9
|
+
from eval_framework.tasks.utils import CallableSerializer, ExecutionResult, execute_python_code_with_tests
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CodeExecutionBaseContext(BaseMetricContext):
    # Fields shared by the serialized and the runtime code-execution contexts.
    run_env: str = Field(description="Name of docker image to run unit-tests inside")
    code_prompt: str = Field(description="Prompt to LLM for code generation")
    test_code: str = Field(description="Python code that contains logic for unit test execution")
    benchmark_timeout: int = Field(default=60, description="Time in seconds for the full test execution run")
    # Passed to execute_python_code_with_tests as `package_mapping`.
    package_downloads: dict[str, str | None] = Field(
        description="a dictionary listing the packages and their respective names in PyPI, "
        "to be installed into the LLM sandbox"
    )
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CodeExecutionPassAtOneContext(CodeExecutionBaseContext):
    # Serialized form of the context: both functions are stored as strings, to be decoded
    # with CallableSerializer.decode before use.
    snippet_merge_fn: str = Field(
        description="logic for merging LLM generated code with test execution code; "
        "this code will be passed into the sandbox to run the testing process. "
        "This code is serialized"
    )
    output_parse_fn: str = Field(
        description="logic for parsing the output of test code execution run within the LLM sandbox. "
        "This code is serialized"
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RealtimeCodeExectionContext(CodeExecutionBaseContext):
    # Runtime form of the context: the merge/parse functions are real callables (see from_context).
    snippet_merge_fn: Callable[[str, str], str] = Field(
        description="logic for merging LLM generated code with test execution code; "
        "this code will be passed into the sandbox to run the testing process. "
        "This code is deserialized"
    )
    output_parse_fn: Callable[[str], ExecutionResult] = Field(
        description="logic for parsing the output of test code execution run within the LLM sandbox. "
        "This code is deserialized"
    )

    @classmethod
    def from_context(cls, context: CodeExecutionPassAtOneContext) -> Self:
        """Build a runtime context from a serialized one.

        ``snippet_merge_fn`` and ``output_parse_fn`` are decoded with ``CallableSerializer.decode``;
        all other fields are copied unchanged.
        """
        return cls(
            run_env=context.run_env,
            code_prompt=context.code_prompt,
            test_code=context.test_code,
            benchmark_timeout=context.benchmark_timeout,
            snippet_merge_fn=CallableSerializer.decode(context.snippet_merge_fn),
            output_parse_fn=CallableSerializer.decode(context.output_parse_fn),
            package_downloads=context.package_downloads,
        )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class CodeExecutionPassAtOne(BaseMetric[Completion]):
    """pass@1 for generated code, obtained by running the completion against the task's unit tests.

    The task attaches a ``CodeExecutionPassAtOneContext`` (with serialized merge/parse functions)
    to the response. It is decoded into a ``RealtimeCodeExectionContext`` and the merged code is
    executed with ``execute_python_code_with_tests``.
    """

    NAME = "code-execution-pass@1"

    def __init__(self) -> None:
        self.k = 1
        # NOTE : this serializer should be the same class as initialized in the benchmark
        self.serializer = CallableSerializer()

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
        try:
            context = extract_context_metric(response, CodeExecutionPassAtOneContext)
            parsed_context = RealtimeCodeExectionContext.from_context(context)
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"Failed to rebuild parsing functions => {e}") from e

        n = 1  # we only support N=1 at the moment
        try:
            c, output = self._count_correct_samples(response.completion, parsed_context)
        except Exception as e:
            # Execution failures are reported as a metric error rather than aborting evaluation.
            error = Error(error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc())
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]

        pass_at_k_value = estimate_pass_at_k(n, c, self.k)
        return [
            MetricResult(
                metric_name=self.NAME,
                value=pass_at_k_value,
                higher_is_better=True,
                error=response.error,
                code_execution_trace=output,
            )
        ]

    def _count_correct_samples(self, completion: str, context: RealtimeCodeExectionContext) -> tuple[int, str]:
        """Run ``completion`` with the context's tests; return (1 if passed else 0, execution output)."""
        result = execute_python_code_with_tests(
            code=completion,
            test_code=context.test_code,
            package_mapping=context.package_downloads,
            merge_code_fn=context.snippet_merge_fn,
            image=context.run_env,
            timeout=context.benchmark_timeout,
            parse_output_fn=context.output_parse_fn,
        )
        return (1 if result.success else 0), result.output
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def estimate_pass_at_k(n: int, c: int, k: int) -> float:
    """
    Estimates pass@k for a single problem.

    Computes ``1 - C(n - c, k) / C(n, k)``: one minus the probability that k samples drawn without
    replacement from the n generated ones are all incorrect. The ratio is evaluated as a running product
    to avoid large binomial coefficients.

    Parameters:
        n (int): Total number of generated samples.
        c (int): Number of correct samples, with ``0 <= c <= n``.
        k (int): Number of attempts or samples considered, with ``0 <= k <= n``.

    Returns:
        float: The pass@k value, in ``[0.0, 1.0]``.

    Raises:
        ValueError: If ``c`` or ``k`` lies outside ``[0, n]``. The estimator is undefined there; for example,
            ``k > n`` with ``c == 0`` would otherwise be reported as a perfect score.
    """
    if not 0 <= c <= n or not 0 <= k <= n:
        raise ValueError(f"invalid pass@k arguments n={n}, c={c}, k={k}: need 0 <= c <= n and 0 <= k <= n")

    # Fewer than k incorrect samples exist, so every draw of k samples contains a correct one.
    if n - c < k:
        return 1.0

    # Probability that all k drawn samples are incorrect: C(n - c, k) / C(n, k).
    probability = 1.0
    for i in range(k):
        probability *= (n - c - i) / (n - i)

    return 1.0 - probability
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from comet import download_model, load_from_checkpoint
|
|
3
|
+
|
|
4
|
+
from eval_framework.exceptions import LogicError
|
|
5
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
6
|
+
from eval_framework.shared.types import Completion, UntemplatedPrompt
|
|
7
|
+
from eval_framework.utils.constants import ROOT_DIR
|
|
8
|
+
|
|
9
|
+
# Directory passed as ``saving_directory`` to comet's ``download_model`` for the XCOMET-XL checkpoint.
SAVING_DIR = ROOT_DIR / "comet_model"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class COMET(BaseMetric[Completion]):
    """COMET is a neural, multilingual framework for evaluating machine translation quality by leveraging cross-lingual
    pretrained language models to achieve state-of-the-art correlation with human judgments
    Note: this requires a Hugging Face token with access to the model: https://huggingface.co/Unbabel/XCOMET-XL
    Source: https://github.com/Unbabel/COMET
    Paper: https://arxiv.org/abs/2009.09025
    """

    NAME = "COMET"

    def __init__(self) -> None:
        """Download (into ``SAVING_DIR``) and load the XCOMET-XL checkpoint.

        Raises:
            RuntimeError: If no CUDA device is available; scoring always runs with ``gpus=1``.
        """
        # Check before the (large) checkpoint download/load so a missing GPU fails fast. The check is an
        # explicit raise because an assert would be stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError("COMET requires a GPU")
        checkpoint_path = download_model("Unbabel/XCOMET-XL", saving_directory=SAVING_DIR)
        self.model = load_from_checkpoint(checkpoint_path)

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score the completion against every reference and report the best ``system_score``.

        Raises:
            LogicError: If the untemplated source prompt is missing, or if there is no ground truth or an
                empty one.
        """
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        if (
            response.context is None
            or not isinstance(response.context, UntemplatedPrompt)
            or response.context.untemplated_prompt == ""
        ):
            raise LogicError("When calculating COMET we need an untemplated prompt.")

        ground_truths = response.ground_truth_list
        # Without any reference there is nothing to score (and max() below would fail on an empty list).
        if not ground_truths:
            raise LogicError("When calculating COMET we need a ground truth.")

        # Source and hypothesis are the same for every reference, so strip them once.
        source = response.context.untemplated_prompt.strip()
        hypothesis = response.completion.strip()

        scores = []
        for ground_truth in ground_truths:
            if ground_truth is None or ground_truth == "":
                raise LogicError("When calculating COMET we need a ground truth.")

            data = [
                {
                    "src": source,
                    "mt": hypothesis,
                    "ref": ground_truth.strip(),
                },
            ]
            with torch.no_grad():
                model_output = self.model.predict(data, gpus=1)
            scores.append(model_output.system_score)

        return [
            MetricResult(metric_name=self.NAME, value=float(max(scores)), higher_is_better=True, error=response.error)
        ]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
4
|
+
from eval_framework.shared.types import Completion
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ConcordanceIndex(BaseMetric[Completion]):
    """Position-wise agreement between a literal-sequence completion and its ground truth(s).

    The reported value is the highest ``calculate_concordance_index`` score over all non-``None`` ground
    truths, or 0.0 when there is none.
    """

    NAME = "ConcordanceIndex"

    def calculate(self, response: Completion) -> list[MetricResult]:
        # An upstream generation error is passed through with no value.
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        per_reference_scores = [
            calculate_concordance_index(reference, response.completion)
            for reference in response.ground_truth_list
            if reference is not None
        ]
        # No usable reference -> score 0.0; otherwise keep the best-matching reference's score.
        best_score = max(per_reference_scores, default=0.0)
        return [MetricResult(metric_name=self.NAME, value=best_score, higher_is_better=True, error=response.error)]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def calculate_concordance_index(
    ground_truth: str,
    completion: str,
) -> float:
    """Return the fraction of positions at which ``completion`` agrees with ``ground_truth``.

    Both strings are parsed with ``ast.literal_eval`` (Python literals such as ``"[1, 2, 3]"``) and compared
    element by element.

    Returns:
        The value depends on how the two parsed literals compare:

        - ``matches / length`` when both have the same non-zero length.
        - ``1.0`` when both are empty, since they agree completely.
        - ``0.0`` when the lengths differ.
        - ``0.0`` when the completion is not a parseable literal with a length.

    Raises:
        Exceptions from parsing or measuring ``ground_truth`` propagate (for example ValueError, SyntaxError
        or TypeError). The reference is benchmark data, so a malformed one is surfaced rather than scored.
    """
    ground_truth_arr = ast.literal_eval(ground_truth)

    try:
        # The completion is model output: anything that is not a sized literal scores 0 instead of crashing
        # the metric. literal_eval can also raise MemoryError/RecursionError on deeply nested input.
        completion_arr = ast.literal_eval(completion.strip())
        completion_len = len(completion_arr)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return 0.0

    if len(ground_truth_arr) != completion_len:
        return 0.0
    if completion_len == 0:
        # Two empty sequences agree completely; this also avoids a division by zero.
        return 1.0

    concordance_count = sum(1 for gt, c in zip(ground_truth_arr, completion_arr) if gt == c)
    return concordance_count / completion_len
|