eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from eval_framework.metrics.base import MetricResult
|
|
2
|
+
from eval_framework.metrics.loglikelihood.base import BaseLoglikelihoodMetric
|
|
3
|
+
from eval_framework.shared.types import Loglikelihood
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TernaryScore(BaseLoglikelihoodMetric):
    """Ternary scoring of a multiple-choice loglikelihood response: correct / abstain / wrong.

    Based on Kalai et al. (2025) Why language models hallucinate. arXiv:2509.04664

    The choice with the highest value returned by ``_compute_probabilities`` is taken as the
    model's answer. It scores ``+lc`` if it matches a ground truth, ``0.0`` if it is the abstention
    ("IDK") option, and ``-lw`` otherwise. The abstention option is assumed to be the *last* key of
    ``response.loglikelihoods``.
    """

    NAME = "Ternary Score"

    def __init__(
        self,
        *,
        lc: float = 1.0,  # reward for a correct answer
        lw: float = 1.0,  # penalty for a wrong answer (applied with a negative sign)
        len_normalised: bool = True,
    ) -> None:
        """
        :param lc: non-negative reward added for a correct answer
        :param lw: non-negative penalty subtracted for a wrong answer
        :param len_normalised: forwarded to the base loglikelihood metric
        :raises ValueError: if ``lc`` or ``lw`` is negative (or NaN)
        """
        super().__init__(len_normalised=len_normalised)
        self._lc = float(lc)
        self._lw = float(lw)
        # `all(... >= 0)` (rather than `< 0` checks) also rejects NaN values.
        if not all(weight >= 0 for weight in (self._lc, self._lw)):
            raise ValueError(f"Invalid reward and penalty values: lc={self._lc}, lw={self._lw}. Require lc>=0, lw>=0.")

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        """Score a single loglikelihood response; propagates the response error if there is one."""
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        scored_choices, _probs = self._compute_probabilities(response.loglikelihoods)
        accepted_answers = self._gather_ground_truths(response)

        best_choice = max(scored_choices, key=scored_choices.get)  # type: ignore[arg-type]
        chosen = self._normalise_text(best_choice)
        # assumes the last option in the response is the "IDK" / abstention choice
        abstain_option = self._normalise_text([*response.loglikelihoods][-1])

        if chosen in accepted_answers:
            score = self._lc
        else:
            score = 0.0 if chosen == abstain_option else -self._lw

        return [MetricResult(metric_name=self.NAME, value=score, higher_is_better=True, error=response.error)]
|
eval_framework/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import time
|
|
3
|
+
import traceback
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from datetime import UTC, datetime
|
|
6
|
+
from functools import partial
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from eval_framework.tasks.registry import get_task
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from determined._info import get_cluster_info
|
|
13
|
+
except ImportError:
|
|
14
|
+
get_cluster_info = None # type: ignore[assignment]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
|
|
19
|
+
from eval_framework import __version__ as eval_framework_version
|
|
20
|
+
from eval_framework.llm.base import BaseLLM
|
|
21
|
+
from eval_framework.result_processors.result_processor import ResultsFileProcessor
|
|
22
|
+
from eval_framework.shared.types import (
|
|
23
|
+
Completion,
|
|
24
|
+
Error,
|
|
25
|
+
Loglikelihood,
|
|
26
|
+
RawLoglikelihood,
|
|
27
|
+
)
|
|
28
|
+
from eval_framework.tasks.base import Language, ResponseType, Sample
|
|
29
|
+
from eval_framework.tasks.eval_config import EvalConfig
|
|
30
|
+
from eval_framework.tasks.perturbation import create_perturbation_class
|
|
31
|
+
from eval_framework.tasks.utils import raise_errors
|
|
32
|
+
from eval_framework.utils.constants import RED, RESET
|
|
33
|
+
from eval_framework.utils.tqdm_handler import get_disable_bar_flag, safe_tqdm_write
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def map_language_to_value(
|
|
39
|
+
language: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None,
|
|
40
|
+
) -> str | dict[str, str] | dict[str, tuple[str, str]] | None:
|
|
41
|
+
if language is None:
|
|
42
|
+
return None
|
|
43
|
+
elif isinstance(language, Language):
|
|
44
|
+
return language.value
|
|
45
|
+
elif isinstance(language, dict):
|
|
46
|
+
if isinstance(list(language.values())[0], Language):
|
|
47
|
+
return {k: v.value for k, v in language.items()} # type: ignore[union-attr]
|
|
48
|
+
else:
|
|
49
|
+
return {k: (v[0].value, v[1].value) for k, v in language.items()} # type: ignore[index]
|
|
50
|
+
else:
|
|
51
|
+
raise ValueError(f"Invalid language: {language}")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ResponseGenerator:
|
|
55
|
+
def __init__(self, llm: BaseLLM, config: EvalConfig, result_processor: ResultsFileProcessor) -> None:
|
|
56
|
+
self.few_shot = config.num_fewshot
|
|
57
|
+
self.task_name = config.task_name
|
|
58
|
+
self.llm = llm
|
|
59
|
+
self.config = config
|
|
60
|
+
self.result_processor = result_processor
|
|
61
|
+
self.num_samples = config.num_samples
|
|
62
|
+
self.save_intermediate_results = config.save_intermediate_results
|
|
63
|
+
|
|
64
|
+
task_class = get_task(config.task_name)
|
|
65
|
+
|
|
66
|
+
if config.perturbation_config is not None:
|
|
67
|
+
perturbation_task_class = create_perturbation_class(task_class, config.perturbation_config)
|
|
68
|
+
self.task = perturbation_task_class.with_overwrite(
|
|
69
|
+
self.few_shot, custom_subjects=self.config.task_subjects, custom_hf_revision=self.config.hf_revision
|
|
70
|
+
)
|
|
71
|
+
else:
|
|
72
|
+
self.task = task_class.with_overwrite(
|
|
73
|
+
self.few_shot, custom_subjects=self.config.task_subjects, custom_hf_revision=self.config.hf_revision
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
self.response_type = task_class.RESPONSE_TYPE
|
|
77
|
+
|
|
78
|
+
def _llm_task_param_precedence(self) -> tuple[list[str] | None, int | None]:
|
|
79
|
+
"""
|
|
80
|
+
sets the stop_sequences and max_tokens values to be used in the completion generation.
|
|
81
|
+
Max token and stop sequence values have an order of precedence:
|
|
82
|
+
|
|
83
|
+
LLM attributes take precedence over task attributes, and therefore overload them.
|
|
84
|
+
:return: stop_sequences, max_tokens
|
|
85
|
+
"""
|
|
86
|
+
llm_stop_sequences = getattr(self.llm, "stop_sequences", None)
|
|
87
|
+
llm_max_tokens = getattr(self.llm, "max_tokens", None)
|
|
88
|
+
task_stop_sequences = getattr(self.task, "stop_sequences", None)
|
|
89
|
+
task_max_tokens = self.config.max_tokens or getattr(self.task, "max_tokens", None)
|
|
90
|
+
# if both task and model define a max_token, the smaller value is used
|
|
91
|
+
max_tokens = min([x for x in [llm_max_tokens, task_max_tokens] if x is not None], default=None)
|
|
92
|
+
logger.info(f"Set max_tokens to {max_tokens}")
|
|
93
|
+
# if both task and model define stop sequences, those are merged into one list
|
|
94
|
+
stop_sequences_merged = (llm_stop_sequences or []) + (task_stop_sequences or [])
|
|
95
|
+
stop_sequences = sorted(list(set(stop_sequences_merged))) if stop_sequences_merged else None
|
|
96
|
+
logger.info(f"Set stop_sequences to {stop_sequences}")
|
|
97
|
+
return stop_sequences, max_tokens
|
|
98
|
+
|
|
99
|
+
def _generate_loglikelihoods(self, samples: list[Sample]) -> list[Loglikelihood]:
|
|
100
|
+
"""
|
|
101
|
+
Generate log likelihoods when a sample is run against the model.
|
|
102
|
+
:param sample: sample to run the task against
|
|
103
|
+
:return: loglikelihoods
|
|
104
|
+
"""
|
|
105
|
+
raw_loglikelihoods: list[RawLoglikelihood]
|
|
106
|
+
try:
|
|
107
|
+
raw_loglikelihoods = self.llm.logprobs(samples)
|
|
108
|
+
except Exception as e:
|
|
109
|
+
if raise_errors():
|
|
110
|
+
raise e
|
|
111
|
+
logger.info(f"Error: {e.__class__.__name__} {e}")
|
|
112
|
+
assert len(samples) == 1, "LLMs not handling errors are not supported in batch mode"
|
|
113
|
+
raw_loglikelihoods = [
|
|
114
|
+
RawLoglikelihood(
|
|
115
|
+
prompt="",
|
|
116
|
+
prompt_sequence_positions=0,
|
|
117
|
+
loglikelihoods={},
|
|
118
|
+
loglikelihoods_sequence_positions={},
|
|
119
|
+
raw_loglikelihood_error=Error(
|
|
120
|
+
error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc()
|
|
121
|
+
),
|
|
122
|
+
)
|
|
123
|
+
for _ in range(len(samples))
|
|
124
|
+
]
|
|
125
|
+
|
|
126
|
+
loglikelihood_list = []
|
|
127
|
+
for idx, sample in enumerate(samples):
|
|
128
|
+
raw_loglikelihood = raw_loglikelihoods[idx]
|
|
129
|
+
assert sample.ground_truth is not None
|
|
130
|
+
loglikelihood_list.append(
|
|
131
|
+
Loglikelihood(
|
|
132
|
+
id=sample.id,
|
|
133
|
+
subject=sample.subject,
|
|
134
|
+
ground_truth=sample.ground_truth,
|
|
135
|
+
prompt=raw_loglikelihood.prompt,
|
|
136
|
+
prompt_sequence_positions=raw_loglikelihood.prompt_sequence_positions,
|
|
137
|
+
concat_compression=raw_loglikelihood.concat_compression,
|
|
138
|
+
loglikelihoods=raw_loglikelihood.loglikelihoods,
|
|
139
|
+
loglikelihoods_sequence_positions=raw_loglikelihood.loglikelihoods_sequence_positions,
|
|
140
|
+
error=raw_loglikelihood.raw_loglikelihood_error,
|
|
141
|
+
)
|
|
142
|
+
)
|
|
143
|
+
return loglikelihood_list
|
|
144
|
+
|
|
145
|
+
def _generative_output_type_selector(self) -> Callable[[list[Sample]], list[Completion] | list[Loglikelihood]]:
|
|
146
|
+
"""
|
|
147
|
+
Selects the generative output type based on the response type.
|
|
148
|
+
:return: function to generate responses
|
|
149
|
+
"""
|
|
150
|
+
match self.response_type:
|
|
151
|
+
case ResponseType.COMPLETION:
|
|
152
|
+
stop_sequences, max_tokens = self._llm_task_param_precedence()
|
|
153
|
+
return partial(
|
|
154
|
+
self.task.generate_completions, self.llm, stop_sequences=stop_sequences, max_tokens=max_tokens
|
|
155
|
+
) # type: ignore[call-arg]
|
|
156
|
+
case ResponseType.LOGLIKELIHOODS:
|
|
157
|
+
return self._generate_loglikelihoods
|
|
158
|
+
case _:
|
|
159
|
+
raise KeyError(f"Task type {self.task} not supported")
|
|
160
|
+
|
|
161
|
+
def _run_task_against_model(
|
|
162
|
+
self, should_preempt_callable: Callable[[], bool]
|
|
163
|
+
) -> tuple[list[Completion | Loglikelihood], bool]:
|
|
164
|
+
"""
|
|
165
|
+
Runs the task against the model and generates responses.
|
|
166
|
+
:param should_preempt_callable: function to check if preempt is called
|
|
167
|
+
:return: list of responses, preempted
|
|
168
|
+
"""
|
|
169
|
+
logger.info(f"{RED}[ Running task {self.task.NAME} against model ------------ ]{RESET}")
|
|
170
|
+
self.start_time, monotonic_start = time.time(), time.monotonic()
|
|
171
|
+
run_fn = self._generative_output_type_selector()
|
|
172
|
+
self._verify_loaded_metadata_compatibility()
|
|
173
|
+
responses = self.result_processor.load_responses() # load responses if present
|
|
174
|
+
subject_response_id_mapping = self._map_subject_response_ids(responses)
|
|
175
|
+
self.result_processor.save_metadata(self._get_metadata())
|
|
176
|
+
responses, preempted = self._curate_responses(
|
|
177
|
+
responses, subject_response_id_mapping, run_fn, should_preempt_callable
|
|
178
|
+
)
|
|
179
|
+
self.end_time, monotonic_end = time.time(), time.monotonic()
|
|
180
|
+
self.total_time = monotonic_end - monotonic_start
|
|
181
|
+
self.result_processor.save_metadata(self._get_metadata()) # overwrite with updated timing
|
|
182
|
+
|
|
183
|
+
return responses, preempted
|
|
184
|
+
|
|
185
|
+
def _map_subject_response_ids(self, responses: list[Completion | Loglikelihood]) -> dict[str, set[int]]:
|
|
186
|
+
"""
|
|
187
|
+
Maps subject to response id
|
|
188
|
+
:param responses: list of responses
|
|
189
|
+
:return: mapping of subject to response id
|
|
190
|
+
"""
|
|
191
|
+
subject_response_id_mapping = {}
|
|
192
|
+
if responses:
|
|
193
|
+
response_subjects = {resp.subject for resp in responses}
|
|
194
|
+
subject_response_id_mapping = {
|
|
195
|
+
response_subject: set([resp.id for resp in responses if resp.subject == response_subject])
|
|
196
|
+
for response_subject in response_subjects
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
return subject_response_id_mapping
|
|
200
|
+
|
|
201
|
+
def _curate_responses(
|
|
202
|
+
self,
|
|
203
|
+
responses: list[Completion | Loglikelihood],
|
|
204
|
+
subject_response_id_mapping: dict[str, set[int]],
|
|
205
|
+
generative_output_function: Callable[[list[Sample]], list[Completion] | list[Loglikelihood]],
|
|
206
|
+
should_preempt_callable: Callable[[], bool],
|
|
207
|
+
) -> tuple[list[Completion | Loglikelihood], bool]:
|
|
208
|
+
"""
|
|
209
|
+
Generates responses for the task and saves them along with metadata.
|
|
210
|
+
:param responses: list of responses
|
|
211
|
+
:param subject_response_id_mapping: mapping of subject to response id
|
|
212
|
+
:param generative_output_function: function to generate responses
|
|
213
|
+
:param metadata: metadata dictionary
|
|
214
|
+
:param should_preempt_callable: function to check if preempt is called
|
|
215
|
+
:return: None
|
|
216
|
+
"""
|
|
217
|
+
|
|
218
|
+
def _process_batch(samples_batch: list[Sample]) -> None:
|
|
219
|
+
if not samples_batch:
|
|
220
|
+
return
|
|
221
|
+
if len(samples_batch) > 1:
|
|
222
|
+
log_msg = "Processing batch..."
|
|
223
|
+
logger.info(log_msg) # For log files
|
|
224
|
+
safe_tqdm_write(log_msg) # For console display with tqdm
|
|
225
|
+
|
|
226
|
+
responses_batch = generative_output_function(samples_batch)
|
|
227
|
+
responses.extend(responses_batch)
|
|
228
|
+
if self.save_intermediate_results:
|
|
229
|
+
for response in responses_batch:
|
|
230
|
+
self.result_processor.save_response(response)
|
|
231
|
+
|
|
232
|
+
# In order to enable parallelism we group samples in batches and send them in parallel to the `run_fn`.
|
|
233
|
+
# The BaseLLM class is then in charge of managing the parallelism (eg, using AsyncClient in API models).
|
|
234
|
+
# If samples_batch_size = 1, samples are run sequentially; in any case, we return here after finishing each
|
|
235
|
+
# individual batch to honor preemption requests and save cached results.
|
|
236
|
+
samples_batch_size = self.config.batch_size
|
|
237
|
+
|
|
238
|
+
# Calculate total samples for progress bar - use num_samples or iterate to count
|
|
239
|
+
total_num_samples = self.num_samples
|
|
240
|
+
if total_num_samples is None:
|
|
241
|
+
# Count samples by iterating (this might be expensive for large datasets)
|
|
242
|
+
total_num_samples = sum(1 for _ in self.task.iterate_samples(None))
|
|
243
|
+
|
|
244
|
+
samples_batch: list[Sample] = []
|
|
245
|
+
with tqdm(
|
|
246
|
+
total=total_num_samples, desc=f"Processing {self.response_type.value}", disable=get_disable_bar_flag()
|
|
247
|
+
) as pbar:
|
|
248
|
+
for i, sample in enumerate(self.task.iterate_samples(self.num_samples)):
|
|
249
|
+
subject = f" - Subject: {sample.subject}"
|
|
250
|
+
sample_index = i + 1
|
|
251
|
+
|
|
252
|
+
if sample.id in subject_response_id_mapping.get(sample.subject, []):
|
|
253
|
+
log_msg = (
|
|
254
|
+
f"Task: {self.response_type.value}{subject} - Sample: {sample_index} - skipping, already done."
|
|
255
|
+
)
|
|
256
|
+
logger.info(log_msg) # For log files
|
|
257
|
+
safe_tqdm_write(log_msg) # For console display with tqdm
|
|
258
|
+
pbar.update(1)
|
|
259
|
+
continue
|
|
260
|
+
|
|
261
|
+
log_msg = f"Task: {self.response_type.value}{subject} - Sample: {sample_index}/{total_num_samples}"
|
|
262
|
+
logger.info(log_msg) # For log files
|
|
263
|
+
safe_tqdm_write(log_msg) # For console display with tqdm
|
|
264
|
+
pbar.set_postfix_str(f"Sample {sample_index}/{total_num_samples}")
|
|
265
|
+
pbar.update(1)
|
|
266
|
+
|
|
267
|
+
samples_batch.append(sample)
|
|
268
|
+
|
|
269
|
+
if len(samples_batch) >= samples_batch_size:
|
|
270
|
+
_process_batch(samples_batch)
|
|
271
|
+
samples_batch = []
|
|
272
|
+
|
|
273
|
+
if should_preempt_callable():
|
|
274
|
+
log_msg = "Preempt"
|
|
275
|
+
logger.info(log_msg) # For log files
|
|
276
|
+
safe_tqdm_write(log_msg) # For console display with tqdm
|
|
277
|
+
if not self.save_intermediate_results:
|
|
278
|
+
self.result_processor.save_responses(responses)
|
|
279
|
+
return responses, True
|
|
280
|
+
|
|
281
|
+
_process_batch(samples_batch)
|
|
282
|
+
|
|
283
|
+
if not self.save_intermediate_results:
|
|
284
|
+
self.result_processor.save_responses(responses)
|
|
285
|
+
return responses, False
|
|
286
|
+
|
|
287
|
+
def _get_metadata(self) -> dict[str, Any]:
|
|
288
|
+
"""Prepares metadata dictionary from the configuration."""
|
|
289
|
+
all_metrics = getattr(self.task, "METRICS", None)
|
|
290
|
+
metadata = self.config.model_dump(mode="json")
|
|
291
|
+
metadata["llm_name"] = self.llm.name
|
|
292
|
+
metadata["task_name"] = self.task_name
|
|
293
|
+
language = getattr(self.task, "LANGUAGE", None)
|
|
294
|
+
metadata["language"] = map_language_to_value(language)
|
|
295
|
+
metadata["metrics"] = [m.NAME for m in all_metrics] if all_metrics is not None else []
|
|
296
|
+
metadata["primary_metrics"] = getattr(self.task, "PRIMARY_METRICS", None)
|
|
297
|
+
metadata["eval_framework_version"] = eval_framework_version
|
|
298
|
+
metadata["task_output_dir"] = str(self.result_processor.output_dir)
|
|
299
|
+
if hasattr(self, "total_time"):
|
|
300
|
+
metadata["start_time"] = str(datetime.fromtimestamp(self.start_time, UTC))
|
|
301
|
+
metadata["end_time"] = str(datetime.fromtimestamp(self.end_time, UTC))
|
|
302
|
+
metadata["total_time"] = self.total_time
|
|
303
|
+
|
|
304
|
+
# add task specific metadata
|
|
305
|
+
metadata["task_metadata"] = self.task.get_metadata()
|
|
306
|
+
|
|
307
|
+
try:
|
|
308
|
+
assert get_cluster_info is not None, "Determined cluster info not available"
|
|
309
|
+
info = get_cluster_info()
|
|
310
|
+
if info is not None:
|
|
311
|
+
metadata["determined_agent_id"] = info.agent_id
|
|
312
|
+
if info.task_type == "TRIAL":
|
|
313
|
+
metadata["determined_experiment_id"] = info.trial.experiment_id
|
|
314
|
+
metadata["determined_trial_id"] = info.trial.trial_id
|
|
315
|
+
except Exception as e:
|
|
316
|
+
logger.info(f"{e}; cluster info not available in local context")
|
|
317
|
+
|
|
318
|
+
return metadata
|
|
319
|
+
|
|
320
|
+
def _verify_loaded_metadata_compatibility(self) -> None:
|
|
321
|
+
if not (loaded_metadata := self.result_processor.load_metadata()):
|
|
322
|
+
return
|
|
323
|
+
current_metadata = self._get_metadata()
|
|
324
|
+
# check if crucial keys in metadata are the same as in the previous run
|
|
325
|
+
keys = [
|
|
326
|
+
"task_name",
|
|
327
|
+
"task_subjects",
|
|
328
|
+
"num_fewshot",
|
|
329
|
+
"num_samples",
|
|
330
|
+
"llm_name",
|
|
331
|
+
"llm_args",
|
|
332
|
+
"perturbation_config",
|
|
333
|
+
]
|
|
334
|
+
for key in keys:
|
|
335
|
+
if loaded_metadata[key] != current_metadata[key]:
|
|
336
|
+
raise ValueError(f"Existing metadata does not match current metadata for {key}.")
|
|
337
|
+
|
|
338
|
+
def __del__(self) -> None:
|
|
339
|
+
self.llm.__del__()
|
|
340
|
+
|
|
341
|
+
def generate(self, should_preempt_callable: Callable[[], bool]) -> tuple[list[Completion | Loglikelihood], bool]:
|
|
342
|
+
"""Generates responses and saves them along with metadata.
|
|
343
|
+
:param should_preempt_callable: function to check if preempt is called
|
|
344
|
+
:return: list of responses, preempted: whether the process was preempted or not
|
|
345
|
+
"""
|
|
346
|
+
logger.info(f"{RED}[ Running responses generation ---------- ]{RESET}")
|
|
347
|
+
logger.info(f"{RED}[ Will save into {self.result_processor.output_dir} ---------- ]{RESET}")
|
|
348
|
+
responses, preempted = self._run_task_against_model(should_preempt_callable)
|
|
349
|
+
logger.info("Completions generated and saved.")
|
|
350
|
+
|
|
351
|
+
return responses, preempted
|
|
File without changes
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from dotenv import load_dotenv
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
from eval_framework.shared.types import Completion, Error, Loglikelihood
|
|
8
|
+
from eval_framework.tasks.eval_config import EvalConfig
|
|
9
|
+
|
|
10
|
+
# NOTE(review): not referenced in this module; presumably the default name of the results root — confirm.
MAIN = "eval_framework_results"

# Import-time side effect: load variables from a `.env` file into the process environment (python-dotenv).
load_dotenv()


class Result(BaseModel):
    """A single metric value computed for a single response of a task run."""

    # Unknown fields are rejected instead of being silently ignored.
    model_config = ConfigDict(extra="forbid")
    id: int  # presumably the sample/response id — confirm
    subject: str
    num_fewshot: int
    llm_name: str
    task_name: str
    metric_class_name: str
    metric_name: str
    key: str | None  # NOTE(review): meaning not visible in this module — confirm
    value: float | None  # presumably None when the metric could not be computed (see `error`) — confirm
    higher_is_better: bool
    prompt: str
    response: str
    # Optional extras; presumably only populated by LLM-judge / code-execution metrics — confirm.
    llm_judge_prompt: str | None = None
    llm_judge_response: str | None = None
    code_execution_trace: str | None = None
    error: Error | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ResultProcessor(ABC):
    """Interface for persisting and reloading the artefacts of an evaluation run.

    The artefacts are run metadata, model responses (completions / loglikelihoods), per-response
    metric results and aggregated results. "save_*s" methods overwrite; "save_*" (singular)
    methods append a single item.
    """

    @abstractmethod
    def save_metadata(self, metadata: dict) -> None:
        """Save the run metadata."""

    @abstractmethod
    def load_metadata(self) -> dict:
        """Load the previously saved run metadata."""

    @abstractmethod
    def save_responses(self, responses: list[Completion | Loglikelihood]) -> None:
        """Save a list of response objects (overwrite a file)."""

    @abstractmethod
    def save_response(self, response: Completion | Loglikelihood) -> None:
        """Save a single response object (append into a file)."""

    @abstractmethod
    def load_responses(self) -> list[Completion | Loglikelihood]:
        """Load the previously saved list of response objects."""

    @abstractmethod
    def save_metrics_results(self, results: list[Result]) -> None:
        """Save the per-response metric results (overwrite a file)."""

    @abstractmethod
    def save_metrics_result(self, result: Result) -> None:
        """Save a single per-response metric result (append into a file)."""

    @abstractmethod
    def save_aggregated_results(self, result: dict[str, float | None]) -> None:
        """Save the aggregated results (a mapping of names to aggregated values, None if unavailable)."""

    @abstractmethod
    def load_metrics_results(self) -> list[Result]:
        """Load the previously saved per-response metric results (not the aggregated results)."""
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ResultsUploader(ABC):
    """Interface for publishing the results stored in an output directory to an external destination."""

    @abstractmethod
    def upload(self, llm_name: str, config: EvalConfig, output_dir: Path) -> bool:
        """Upload the relevant parts of ``output_dir`` to the desired destination.

        :param llm_name: name of the evaluated LLM
        :param config: evaluation configuration of the run
        :param output_dir: directory holding the run's result files
        :return: True if the upload was successful, False otherwise
        """
        ...
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Module for uploading the key result files of a run (aggregated results and metadata) to a HuggingFace dataset repository
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import wandb
|
|
10
|
+
from huggingface_hub import HfApi, login
|
|
11
|
+
|
|
12
|
+
from eval_framework.result_processors.base import ResultsUploader
|
|
13
|
+
from eval_framework.tasks.eval_config import EvalConfig
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)


class HFUploader(ResultsUploader):
    """Upload a run's ``aggregated_results.json`` and ``metadata.json`` to a HuggingFace dataset repo."""

    def __init__(self, config: EvalConfig) -> None:
        """Validate the HF upload configuration and log into HuggingFace.

        Uploading is disabled (``upload`` returns False) when ``hf_upload_dir`` is not configured or
        the login fails.

        :raises ValueError: if uploading is configured but ``output_dir`` or ``hf_upload_repo`` is missing
        """
        # Always define the attribute so `upload` works on every initialised instance, including
        # the "uploading disabled" early-return path below.
        self.hf_api: HfApi | None = None

        if not config.hf_upload_dir:
            logger.warning("Results will not be persisted in HuggingFace (`hf_upload_dir` not configured).")
            return
        if config.output_dir is None:
            raise ValueError("Output directory is not set in the configuration.")
        if not config.hf_upload_repo:
            raise ValueError("HuggingFace upload repository is not set in the configuration.")

        self.hf_api = HFUploader._login_into_hf()
        if self.hf_api is None:
            logger.error("Could not login into HuggingFace (check HF_TOKEN). Results not persisted in HuggingFace.")

    def upload(self, llm_name: str, config: EvalConfig, output_dir: Path) -> bool:
        """Upload the result files found in ``output_dir``.

        The files are stored under ``<hf_upload_dir>/<output_dir relative to config.output_dir>``
        in the dataset repo.

        :return: True if all selected files were uploaded; False if uploading is disabled or any upload failed
        """
        # getattr also covers instances created without running __init__.
        hf_api = getattr(self, "hf_api", None)
        if hf_api is None:
            return False
        assert config.hf_upload_repo and config.hf_upload_dir

        rel_upload_dir = output_dir.relative_to(config.output_dir)
        # NOTE: every "/" is removed from the configured directory, so it always maps to a single
        # top-level folder of the repo (e.g. "a/b" -> "ab").
        upload_dir = Path(config.hf_upload_dir.replace("/", "")) / rel_upload_dir
        # Paths inside the HF repo (and in URLs) are always "/"-separated, regardless of the local OS.
        upload_dir_in_repo = upload_dir.as_posix()
        logger.info(f"HuggingFace upload starting to: {upload_dir}")

        upload_counter = 0
        for fp in output_dir.iterdir():
            if fp.name not in ("aggregated_results.json", "metadata.json"):
                logger.info(f"Skipping {fp}.")
                continue
            try:
                hf_api.upload_file(
                    path_or_fileobj=str(fp),
                    path_in_repo=(upload_dir / fp.name).as_posix(),
                    repo_id=config.hf_upload_repo,
                    repo_type="dataset",
                )
                upload_counter += 1
            except Exception as e:
                logger.error("Problem during HF file upload: " + str(e))
                return False

        hf_url = f"https://huggingface.co/datasets/{config.hf_upload_repo}/tree/main/{upload_dir_in_repo}"
        logger.info(f"Uploaded {upload_counter} result files to {hf_url}.")

        if wandb.run is not None:
            try:
                wandb.run.notes = f"Results uploaded to HuggingFace: [{hf_url}]({hf_url})"
            except Exception as e:
                logger.warning(f"Failed to update wandb notes with HF URL: {e}")
        return True

    @classmethod
    def _login_into_hf(cls) -> HfApi | None:
        """Log into HuggingFace with the ``HF_TOKEN`` env var; return an API client, or None on failure."""
        try:
            login(token=os.environ.get("HF_TOKEN", ""))
            logger.debug("logged into HF")
            return HfApi()
        except Exception as e:
            # Best-effort: the caller reports the failure; keep the cause instead of swallowing it silently.
            logger.warning(f"HuggingFace login failed: {e.__class__.__name__}: {e}")
            return None
|