eval-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +170 -0
- eval_framework/context/eval.py +114 -0
- eval_framework/context/local.py +52 -0
- eval_framework/evaluation_generator.py +231 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +323 -0
- eval_framework/llm/base.py +58 -0
- eval_framework/llm/huggingface.py +332 -0
- eval_framework/llm/mistral.py +73 -0
- eval_framework/llm/models.py +16 -0
- eval_framework/llm/openai.py +205 -0
- eval_framework/llm/vllm.py +438 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +187 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +171 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +8 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +416 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +74 -0
- eval_framework/result_processors/hf_processor.py +87 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/run.py +314 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +314 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/arc.py +46 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +39 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +62 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +177 -0
- eval_framework/tasks/benchmarks/gsm8k.py +148 -0
- eval_framework/tasks/benchmarks/hellaswag.py +44 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +190 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +37 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +39 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +44 -0
- eval_framework/tasks/benchmarks/sphyr.py +75 -0
- eval_framework/tasks/benchmarks/squad.py +89 -0
- eval_framework/tasks/benchmarks/struct_eval.py +110 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
- eval_framework/tasks/benchmarks/winogender.py +39 -0
- eval_framework/tasks/benchmarks/winogrande.py +44 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +112 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +80 -0
- eval_framework/tasks/task_names.py +138 -0
- eval_framework/tasks/utils.py +578 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/generate_task_docs.py +229 -0
- eval_framework/utils/helpers.py +3 -0
- eval_framework/utils/logging.py +50 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework-0.2.0.dist-info/METADATA +514 -0
- eval_framework-0.2.0.dist-info/RECORD +161 -0
- eval_framework-0.2.0.dist-info/WHEEL +4 -0
- eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +536 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
- template_formatting/tests/test_formatter_eval.py +408 -0
- template_formatting/tests/test_formatter_scaling.py +253 -0
- template_formatting/tests/test_mistral_formatter.py +136 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from abc import ABC
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import pycountry
|
|
6
|
+
import sacrebleu
|
|
7
|
+
|
|
8
|
+
from eval_framework.metrics.completion.bleu import LINEWISE_BLEU
|
|
9
|
+
from eval_framework.metrics.completion.chrf import LINEWISE_CHRF
|
|
10
|
+
from eval_framework.metrics.completion.ter import LINEWISE_TER
|
|
11
|
+
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, Sample
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class WMT(BaseTask[str], ABC):
    """Abstract base for WMT translation tasks backed by sacrebleu test sets.

    Subclasses set ``DATASET_PATH`` to the sacrebleu test-set name and use
    language pairs of the form ``"<src>-<tgt>"`` as subjects; each subject is
    passed to sacrebleu as ``langpair`` and split on ``"-"`` to build prompts.
    """

    NAME = "WMT"
    DATASET_PATH = ""
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "test"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [LINEWISE_BLEU, LINEWISE_CHRF, LINEWISE_TER]
    PERTURBATION_UNMODIFIABLE_WORDS = ["phrase"]

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Generation is cut at the end of a sentence line or as soon as the
        # model starts emitting another "<language> phrase:" cue.
        self.stop_sequences: list[str] = [".\n", " phrase: ", "phrase:", "phrase: ", " phrase:", "\n\n"]

    @staticmethod
    def _read_lines(path: str) -> list[str]:
        """Return the right-stripped lines of a sacrebleu test-set file, closing the handle."""
        with sacrebleu.smart_open(path) as handle:
            return [line.rstrip() for line in handle]

    def _load_dataset(self, subject: str | None) -> None:
        """Download the test set for ``subject`` and store it shuffled under the ``"test"`` split.

        The shuffle uses a ``random.Random`` seeded with ``RANDOM_SEED`` (kept on
        ``self.rnd``), so the sample order is reproducible.
        """
        src_file, ref_file, _, _, _ = sacrebleu.download_test_set(test_set=self.DATASET_PATH, langpair=subject)
        src_data = self._read_lines(src_file)
        ref_data = self._read_lines(ref_file)

        data_list = [{"source": src, "target": ref, "subject": subject} for src, ref in zip(src_data, ref_data)]
        self.rnd = random.Random(RANDOM_SEED)
        self.rnd.shuffle(data_list)
        self.dataset = {"test": data_list}

    def _code_to_language(self, code: str) -> str:
        """Translate a language code into its pycountry language name.

        Raises:
            ValueError: if pycountry has no language for ``code``.
        """
        # key is alpha_2 or alpha_3 depending on the code length
        key = f"alpha_{len(code)}"
        language = pycountry.languages.get(**{key: code})
        if language is None:
            raise ValueError(f"Unknown language code: {code!r}")
        return language.name

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build ``"<Source> phrase: <text>\\n<Target> phrase:"`` for one sample."""
        language_codes = item["subject"].split("-")
        src_lang = self._code_to_language(language_codes[0])
        tar_lang = self._code_to_language(language_codes[1])
        cue = f"{tar_lang} phrase:"

        return f"{src_lang} phrase: {item['source']}\n{cue}"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the reference translation; for a non-string target, its first element."""
        return item["target"] if isinstance(item["target"], str) else item["target"][0]

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Return the reference translation prefixed with a single space."""
        target = self._get_ground_truth(item)
        assert target is not None
        assert isinstance(target, str)
        return f" {target}"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Truncate the completion at each stop sequence in turn and strip whitespace."""
        for stop_sequence in self.stop_sequences:
            if stop_sequence in completion_text:
                completion_text = completion_text.split(stop_sequence)[0]
        return completion_text.strip()
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class WMT14(WMT):
    """WMT14 translation task covering English<->French in both directions."""

    NAME = "WMT14"
    DATASET_PATH = "wmt14"
    SUBJECTS = ["en-fr", "fr-en"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "en-fr": (Language.ENG, Language.FRA),
        "fr-en": (Language.FRA, Language.ENG),
    }
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class WMT16(WMT):
    """WMT16 translation task covering German<->English in both directions."""

    NAME = "WMT16"
    DATASET_PATH = "wmt16"
    SUBJECTS = ["de-en", "en-de"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "en-de": (Language.ENG, Language.DEU),
    }
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class WMT20(WMT):
    """WMT20 translation task for German<->English and German<->French."""

    NAME = "WMT20"
    DATASET_PATH = "wmt20"
    SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "de-fr": (Language.DEU, Language.FRA),
        "en-de": (Language.ENG, Language.DEU),
        "fr-de": (Language.FRA, Language.DEU),
    }
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class WMT_INSTRUCT(WMT):
    """Instruction-style WMT variant that prompts with "Please translate from X to Y: ..."."""

    PERTURBATION_UNMODIFIABLE_WORDS = ["Please", "translate"]
    COMPLETION_PREFIX = "This is the translation:"

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Stop as soon as the model begins a new instruction of its own.
        self.stop_sequences: list[str] = ["Please translate"]

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the translation instruction from the sample's language pair and source text."""
        src_lang, tar_lang = map(self._code_to_language, item["subject"].split("-"))
        return f"Please translate from {src_lang} to {tar_lang}: {item['source']}"

    # NOTE(review): another task in this package (ZeroSCROLLS QuALITY) overrides
    # `_get_cue_text`, not `_get_cue` -- confirm the base class really calls this name.
    def _get_cue(self, item: dict[str, Any]) -> str:
        """Return the fixed completion prefix used as the answer cue."""
        return self.COMPLETION_PREFIX

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Return the reference translation prefixed with a single space."""
        target = self._get_ground_truth(item)
        assert target is not None
        return f" {target}"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Drop the completion prefix, cut at the stop sequences, and strip whitespace.

        Whitespace is stripped *before* removing the prefix so that a completion
        starting with a space or newline still has the prefix removed, and again
        after truncation so no trailing whitespace precedes a removed stop sequence.
        """
        completion_text = completion_text.strip().removeprefix(self.COMPLETION_PREFIX)
        for stop_sequence in self.stop_sequences:
            if stop_sequence in completion_text:
                completion_text = completion_text.split(stop_sequence)[0]
        return completion_text.strip()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class WMT14_INSTRUCT(WMT_INSTRUCT):
    """Instruction-style WMT14 task covering English<->French in both directions."""

    NAME = "WMT14 Instruct"
    DATASET_PATH = "wmt14"
    SUBJECTS = ["en-fr", "fr-en"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "en-fr": (Language.ENG, Language.FRA),
        "fr-en": (Language.FRA, Language.ENG),
    }
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class WMT16_INSTRUCT(WMT_INSTRUCT):
    """Instruction-style WMT16 task covering German<->English in both directions."""

    NAME = "WMT16 Instruct"
    DATASET_PATH = "wmt16"
    SUBJECTS = ["de-en", "en-de"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "en-de": (Language.ENG, Language.DEU),
    }
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class WMT20_INSTRUCT(WMT_INSTRUCT):
    """Instruction-style WMT20 task for German<->English and German<->French."""

    NAME = "WMT20 Instruct"
    DATASET_PATH = "wmt20"
    SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "de-fr": (Language.DEU, Language.FRA),
        "en-de": (Language.ENG, Language.DEU),
        "fr-de": (Language.FRA, Language.DEU),
    }
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from eval_framework.metrics.completion.exponential_similarity import ExponentialSimilarity
|
|
5
|
+
from eval_framework.metrics.completion.f1 import F1
|
|
6
|
+
from eval_framework.metrics.completion.rouge_geometric_mean import ROUGE_GEOMETRIC_MEAN
|
|
7
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
8
|
+
AccuracyLoglikelihood,
|
|
9
|
+
)
|
|
10
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample
|
|
11
|
+
from eval_framework.tasks.utils import get_n_letters
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ZERO_SCROLLS_QUALITY(BaseTask[str]):
    """QuALITY subset of ZeroSCROLLS (https://huggingface.co/datasets/tau/zero_scrolls).

    Scored by log-likelihood over four single-letter answer options.
    """

    NAME = "ZeroSCROLLS QuALITY"
    DATASET_PATH = "tau/zero_scrolls"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood]
    SUBJECTS = ["quality"]

    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS QuALITY only supports zero fewshot examples"
        super().__init__(num_fewshot)
        # Answer option labels, one per choice.
        self.keys = get_n_letters(4)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index`` followed by a blank line."""
        end = item["query_end_index"]
        document = item["input"]
        return f"{document[:end]}\n\n"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        """Return the fixed answer cue."""
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the expected output with a leading space, matching the completion format."""
        expected = item["output"]
        return f" {expected}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return every answer label, each prefixed with a space."""
        return [f" {label}" for label in self.keys]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ZERO_SCROLLS_COMPLETION(BaseTask[str]):
    """Shared base for the completion-style ZeroSCROLLS tasks.

    Dataset: https://huggingface.co/datasets/tau/zero_scrolls
    """

    DATASET_PATH = "tau/zero_scrolls"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.COMPLETION
    # Declared for consistency with ZERO_SCROLLS_QUALITY, which reads the same
    # dataset and already declares English.
    LANGUAGE = Language.ENG

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the sample's reference output unchanged."""
        return item["output"]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ZERO_SCROLLS_GOV_REPORT(ZERO_SCROLLS_COMPLETION):
    """GovReport subset of ZeroSCROLLS, scored with the ROUGE geometric mean."""

    NAME = "ZeroSCROLLS GovReport"
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    SUBJECTS = ["gov_report"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Summary"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS GovReport only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index`` with the "Summary:" cue appended directly."""
        end = item["query_end_index"]
        document = item["input"]
        return f"{document[:end]}Summary:"
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ZERO_SCROLLS_QMSUM(ZERO_SCROLLS_COMPLETION):
    """QMSum subset of ZeroSCROLLS, scored with the ROUGE geometric mean."""

    NAME = "ZeroSCROLLS QMSum"
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    SUBJECTS = ["qmsum"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS QMSum only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line, and the "Answer:" cue."""
        end = item["query_end_index"]
        document = item["input"]
        return f"{document[:end]}\n\nAnswer:"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class ZERO_SCROLLS_SQUALITY(ZERO_SCROLLS_COMPLETION):
    """SQuALITY subset of ZeroSCROLLS, scored with the ROUGE geometric mean."""

    NAME = "ZeroSCROLLS SQuALITY"
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    SUBJECTS = ["squality"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS SQuALITY only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line, and the "Answer:" cue."""
        end = item["query_end_index"]
        document = item["input"]
        return f"{document[:end]}\n\nAnswer:"
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ZERO_SCROLLS_QASPER(ZERO_SCROLLS_COMPLETION):
    """Qasper subset of ZeroSCROLLS, scored with F1."""

    NAME = "ZeroSCROLLS Qasper"
    METRICS = [F1]
    SUBJECTS = ["qasper"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS Qasper only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line, and the "Answer:" cue."""
        end = item["query_end_index"]
        document = item["input"]
        return f"{document[:end]}\n\nAnswer:"
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class ZERO_SCROLLS_NARRATIVEQA(ZERO_SCROLLS_COMPLETION):
    """NarrativeQA subset of ZeroSCROLLS, scored with F1."""

    NAME = "ZeroSCROLLS NarrativeQA"
    METRICS = [F1]
    SUBJECTS = ["narrative_qa"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS NarrativeQA only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line, and the "Answer:" cue."""
        end = item["query_end_index"]
        document = item["input"]
        return f"{document[:end]}\n\nAnswer:"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class ZERO_SCROLLS_MUSIQUE(ZERO_SCROLLS_COMPLETION):
    """MuSiQue subset of ZeroSCROLLS, scored with F1."""

    NAME = "ZeroSCROLLS MuSiQue"
    METRICS = [F1]
    SUBJECTS = ["musique"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS MuSiQue only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line, and the "Answer:" cue."""
        end = item["query_end_index"]
        document = item["input"]
        return f"{document[:end]}\n\nAnswer:"
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class ZERO_SCROLLS_SPACE_DIGEST(ZERO_SCROLLS_COMPLETION):
    """SpaceDigest subset of ZeroSCROLLS, scored with ExponentialSimilarity.

    Both the model completion and the reference output are reduced to a bare
    number string by ``post_process_generated_completion`` before scoring.
    """

    NAME = "ZeroSCROLLS SpaceDigest"
    METRICS = [ExponentialSimilarity]
    SUBJECTS = ["space_digest"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    # Number-extraction patterns, tried in order; the first that matches wins.
    # Compiled once at class creation instead of being rebuilt on every call.
    # Two patterns of the original list ("<n> percentage" and
    # "is/about/... <n> percent") were dropped: every text they match is already
    # matched by the earlier "<n> percent" pattern, so they could never be reached.
    _PERCENTAGE_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"(\d+(?:\.\d+)?)%",  # Matches: 30%, 30.5%
            r"(\d+(?:\.\d+)?)\s*percent",  # Matches: 30 percent, 30.5 percentage
            r"percentage\s*(?:is|of|:)?\s*(\d+(?:\.\d+)?)",  # Matches: percentage is 30, percentage: 30.5
            # Matches: is 30 %, equals 30.5% (whitespace allowed before the % sign)
            r"(?:is|equals|equal to|about|approximately|around|roughly)\s*(\d+(?:\.\d+)?)\s*%",
            r"it'?s\s*(\d+(?:\.\d+)?)",  # Matches: it's 60, its 60
            r"that'?s\s*(\d+(?:\.\d+)?)",  # Matches: that's 60, thats 60
        )
    )

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS SpaceDigest only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Extract the number the completion most plausibly reports.

        Order of preference: a percentage-like phrase, the whole text if it is
        just a number, the first number anywhere in the text, and finally the
        stripped text itself when it contains no number.
        """
        # First, try to find patterns like "X%" or "X percent"
        for pattern in self._PERCENTAGE_PATTERNS:
            match = pattern.search(completion_text)
            if match:
                return match.group(1).strip()

        # If no percentage pattern is found, check if the entire text is just a number
        if re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*", completion_text):
            return completion_text.strip()

        # If not a standalone number, look for any number in the text
        # This is a fallback and might be less accurate
        number_match = re.search(r"(\d+(?:\.\d+)?)", completion_text)
        if number_match:
            return number_match.group(1).strip()

        # If no number is found, return the original text stripped
        return completion_text.strip()

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index`` with the "Answer:" cue appended directly."""
        query_end_index = item["query_end_index"]
        return f"{item['input'][:query_end_index]}Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the reference output normalised the same way as model completions."""
        return self.post_process_generated_completion(item["output"])
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Annotated, Any
|
|
5
|
+
|
|
6
|
+
from pydantic import AfterValidator, Field, field_serializer, field_validator, model_validator
|
|
7
|
+
|
|
8
|
+
from eval_framework.base_config import BaseConfig
|
|
9
|
+
from eval_framework.llm.base import BaseLLM
|
|
10
|
+
from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
|
|
11
|
+
from eval_framework.tasks.base import BaseTask
|
|
12
|
+
from eval_framework.tasks.perturbation import PerturbationConfig
|
|
13
|
+
from eval_framework.tasks.registry import get_task, validate_task_name
|
|
14
|
+
from eval_framework.utils.constants import ROOT_DIR
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class EvalConfig(BaseConfig):
    """Settings for one evaluation run: task, model under test, optional judge, and outputs."""

    output_dir: Path = ROOT_DIR
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_run_id: str | None = None
    hf_upload_dir: str | None = None
    hf_upload_repo: str | None = None
    num_fewshot: Annotated[int, Field(ge=0)] = 0
    num_samples: Annotated[int | None, Field(ge=1)] = 10  # Allows None or int
    max_tokens: int | None = None
    perturbation_config: PerturbationConfig | None = None
    task_name: Annotated[str, AfterValidator(validate_task_name)]
    task_subjects: list[str] | None = None
    hf_revision: str | None = None
    llm_class: type[BaseLLM]
    llm_args: dict[str, Any] = Field(default_factory=dict)
    llm_judge_class: type[BaseLLM] | None = None
    judge_model_args: dict[str, Any] = Field(default_factory=dict)
    batch_size: Annotated[int, Field(ge=1)] = 1
    description: str | None = None
    save_intermediate_results: bool = True
    save_logs: bool = True

    @property
    def task_class(self) -> type[BaseTask]:
        """Return the task class registered under ``task_name``."""
        return get_task(self.task_name)

    @field_serializer("output_dir")
    def serialize_output_dir(self, value: Path) -> str:
        """Serialize the output directory as a plain string."""
        return str(value)

    @field_validator("output_dir", mode="before")
    @classmethod
    def validate_output_dir(cls, value: str | Path) -> Path:
        """Accept the output directory as either a string or a ``Path``."""
        if isinstance(value, str):
            return Path(value)
        return value

    @field_validator("llm_args", mode="before")
    @classmethod
    def validate_llm_args(cls, value: dict[str, Any]) -> dict[str, Any]:
        """Turn string values (e.g. from the command line) into Python literals where possible.

        Nested dictionaries are processed recursively; strings that are not
        valid literals are kept as they are.
        """

        def convert_value(v: Any) -> Any:
            if isinstance(v, dict):
                # Recursively process nested dictionaries (like sampling_params)
                return {k: convert_value(nested_v) for k, nested_v in v.items()}
            elif isinstance(v, str):
                try:
                    # Try to evaluate as a Python literal (int, float, bool, None, list, dict, etc.)
                    return ast.literal_eval(v)
                except (ValueError, SyntaxError, TypeError):
                    # TypeError covers literals such as "{[1]: 2}" (unhashable key),
                    # which literal_eval rejects with a TypeError, not a ValueError.
                    return v  # keep as string if not a valid literal
            else:
                return v  # already proper type

        return convert_value(value)

    @field_validator("judge_model_args", mode="before")
    @classmethod
    def validate_judge_model_args(cls, value: dict[str, Any]) -> dict[str, Any]:
        """Convert numeric-looking string values to ``int`` or ``float``.

        Only strings are converted. Values that already have a type (bool, None,
        numbers, nested containers) pass through untouched -- previously ``True``
        became ``1`` and ``None`` or a nested dict raised an uncaught ``TypeError``.
        """
        typed_value: dict[str, Any] = {}
        for k, v in value.items():
            if isinstance(v, str):
                try:  # maybe this llm argument is actually a number?
                    v = float(v) if "." in v else int(v)
                except ValueError:
                    pass  # not a number: keep the string
            typed_value[k] = v
        return typed_value

    @model_validator(mode="after")
    def validate_llm_judge_defined(self) -> "EvalConfig":
        """Require ``llm_judge_class`` whenever the task uses an LLM-judge metric.

        Raises ``ValueError`` (reported by pydantic as a validation error) rather
        than using ``assert``, which would be skipped when Python runs with ``-O``.
        """
        task = get_task(self.task_name)
        needs_judge = any(issubclass(metric_class, BaseLLMJudgeMetric) for metric_class in task.METRICS)
        if needs_judge and self.llm_judge_class is None:
            raise ValueError("The LLM Judge must be defined for this evaluation task.")
        return self

    @field_serializer("llm_class")
    def serialize_llm_class(self, value: type[BaseLLM] | None) -> str | None:
        """Serialize the class as its bare class name (``__name__``, no module path)."""
        if value:
            return value.__name__
        return None

    @field_serializer("llm_judge_class")
    def serialize_llm_judge_class(self, value: type[BaseLLM] | None) -> str | None:
        """Serialize the class as its bare class name (``__name__``, no module path)."""
        if value:
            return value.__name__
        return None

    def model_json_dump(self) -> str:
        """Return the dumped model as a JSON string with keys sorted for stable output."""
        model_dump = self.model_dump()
        return json.dumps(model_dump, sort_keys=True)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Annotated, Any, TypeVar
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
from eval_framework.logger import logger
|
|
8
|
+
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Sample
|
|
9
|
+
from eval_framework.tasks.utils import Editor, HatPaperEditor
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PerturbationType(str, Enum):
    """Names of the available prompt perturbations.

    ``EDITOR`` selects the ``Editor``-based perturbation; every other member is
    handled by the ``HatPaperEditor``-based one.
    """

    # Editor methods
    EDITOR = "editor"

    # Hat paper methods
    PERMUTE = "permute"
    REPLACE = "replace"
    DELETE = "delete"
    UPPERCASE = "uppercase"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PerturbationConfig(BaseModel):
    """Options controlling how task prompts are perturbed; unknown fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Which perturbation to apply.
    type: PerturbationType = PerturbationType.EDITOR
    # Probability handed to the editor; must lie in [0.0, 1.0].
    probability: float = Field(default=0.1, ge=0.0, le=1.0)
    # Seed passed to the editor.
    seed: int = RANDOM_SEED
    # When set, the text is logged before and after perturbation.
    verbose: bool = False
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# NOTE(review): neither the lock nor the port is referenced in this module's
# visible code -- confirm whether they are still needed.
_DOCKER_LAUNCH_LOCK = threading.Lock()
_AUGMENTER_PORT = 0

# Type variable bound to any BaseTask specialisation.
SomeBaseTask = TypeVar("SomeBaseTask", bound=BaseTask[Any])
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def create_perturbation_class[T: BaseTask](base_class: type[T], perturbation_config: PerturbationConfig) -> type[T]:
|
|
37
|
+
# mypy seems to have trouble inferring the type
|
|
38
|
+
class EditorPerturbation(base_class): # type: ignore
|
|
39
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
40
|
+
super().__init__(*args, **kwargs)
|
|
41
|
+
self.perturbation_config = perturbation_config
|
|
42
|
+
self.editor = Editor(
|
|
43
|
+
language="de" if base_class.LANGUAGE == "German" else "en", seed=perturbation_config.seed
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
def _get_instruction_text(self, sample: Sample) -> str:
|
|
47
|
+
text = super()._get_instruction_text(sample)
|
|
48
|
+
if self.perturbation_config.verbose:
|
|
49
|
+
logger.info(f"Perturbating text: {text}")
|
|
50
|
+
result = self.editor(
|
|
51
|
+
text, self.perturbation_config.probability, getattr(self, "PERTURBATION_UNMODIFIABLE_WORDS", [])
|
|
52
|
+
)
|
|
53
|
+
if self.perturbation_config.verbose:
|
|
54
|
+
logger.info(f"Perturbed text: {result}")
|
|
55
|
+
return result
|
|
56
|
+
|
|
57
|
+
class HatPaperPerturbation(base_class): # type: ignore
|
|
58
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
59
|
+
super().__init__(*args, **kwargs)
|
|
60
|
+
self.perturbation_config = perturbation_config
|
|
61
|
+
self.editor = HatPaperEditor(seed=perturbation_config.seed)
|
|
62
|
+
|
|
63
|
+
def _get_instruction_text(self, sample: Sample) -> str:
|
|
64
|
+
text = super()._get_instruction_text(sample)
|
|
65
|
+
if self.perturbation_config.verbose:
|
|
66
|
+
logger.info(f"Perturbating text: {text}")
|
|
67
|
+
words = getattr(self, "PERTURBATION_UNMODIFIABLE_WORDS", [])
|
|
68
|
+
if self.perturbation_config.type == PerturbationType.PERMUTE:
|
|
69
|
+
result = self.editor.permute_chars_in_string(text, self.perturbation_config.probability, words)
|
|
70
|
+
elif self.perturbation_config.type == PerturbationType.REPLACE:
|
|
71
|
+
result = self.editor.replace_chars_in_string(text, self.perturbation_config.probability, words)
|
|
72
|
+
elif self.perturbation_config.type == PerturbationType.DELETE:
|
|
73
|
+
result = self.editor.delete_chars_in_string(text, self.perturbation_config.probability, words)
|
|
74
|
+
elif self.perturbation_config.type == PerturbationType.UPPERCASE:
|
|
75
|
+
result = self.editor.upper_case_string(text)
|
|
76
|
+
if self.perturbation_config.verbose:
|
|
77
|
+
logger.info(f"Perturbed text: {result}")
|
|
78
|
+
return result
|
|
79
|
+
|
|
80
|
+
if perturbation_config.type == PerturbationType.EDITOR:
|
|
81
|
+
return EditorPerturbation
|
|
82
|
+
else:
|
|
83
|
+
return HatPaperPerturbation
|