eval-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +170 -0
- eval_framework/context/eval.py +114 -0
- eval_framework/context/local.py +52 -0
- eval_framework/evaluation_generator.py +231 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +323 -0
- eval_framework/llm/base.py +58 -0
- eval_framework/llm/huggingface.py +332 -0
- eval_framework/llm/mistral.py +73 -0
- eval_framework/llm/models.py +16 -0
- eval_framework/llm/openai.py +205 -0
- eval_framework/llm/vllm.py +438 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +187 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +171 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +8 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +416 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +74 -0
- eval_framework/result_processors/hf_processor.py +87 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/run.py +314 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +314 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/arc.py +46 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +39 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +62 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +177 -0
- eval_framework/tasks/benchmarks/gsm8k.py +148 -0
- eval_framework/tasks/benchmarks/hellaswag.py +44 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +190 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +37 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +39 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +44 -0
- eval_framework/tasks/benchmarks/sphyr.py +75 -0
- eval_framework/tasks/benchmarks/squad.py +89 -0
- eval_framework/tasks/benchmarks/struct_eval.py +110 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
- eval_framework/tasks/benchmarks/winogender.py +39 -0
- eval_framework/tasks/benchmarks/winogrande.py +44 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +112 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +80 -0
- eval_framework/tasks/task_names.py +138 -0
- eval_framework/tasks/utils.py +578 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/generate_task_docs.py +229 -0
- eval_framework/utils/helpers.py +3 -0
- eval_framework/utils/logging.py +50 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework-0.2.0.dist-info/METADATA +514 -0
- eval_framework-0.2.0.dist-info/RECORD +161 -0
- eval_framework-0.2.0.dist-info/WHEEL +4 -0
- eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +536 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
- template_formatting/tests/test_formatter_eval.py +408 -0
- template_formatting/tests/test_formatter_scaling.py +253 -0
- template_formatting/tests/test_mistral_formatter.py +136 -0
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import re
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from eval_framework.tasks.base import NO_SUBJECT, RANDOM_SEED, Language, SubjectType
|
|
6
|
+
from eval_framework.tasks.benchmarks.arc import ARC
|
|
7
|
+
from eval_framework.tasks.benchmarks.gsm8k import GSM8K
|
|
8
|
+
from eval_framework.tasks.benchmarks.hellaswag import HELLASWAG
|
|
9
|
+
from eval_framework.tasks.benchmarks.mmlu import MMLU, MMLU_SUBJECTS
|
|
10
|
+
from eval_framework.tasks.benchmarks.mmlu_de import MMLU_SUBJECTS_TRANSLATION
|
|
11
|
+
from eval_framework.tasks.benchmarks.truthfulqa import TRUTHFULQA
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ARC_EU20_DE(ARC):
    """German ARC (challenge + easy) from the openGPT-X EU20 benchmark suite.

    EU20 benchmarks come from the openGPT-X paper:
    - https://arxiv.org/abs/2410.08928
    - leaderboard: https://huggingface.co/spaces/openGPT-X/european-llm-leaderboard

    Dataset: https://huggingface.co/datasets/openGPT-X/arcx
    entries in 'challenge_DE': 1172 test, 299 validation, 198 train
    entries in 'easy_DE': 2376 test, 570 validation, 197 train
    features: ['id', 'question', 'choices', 'answerKey']
    All configs: ['challenge_BG', 'easy_BG', 'challenge_DA', 'easy_DA', 'challenge_DE', 'easy_DE', 'challenge_ET', 'easy_ET', 'challenge_FI', 'easy_FI', 'challenge_FR', 'easy_FR', 'challenge_EL', 'easy_EL', 'challenge_IT', 'easy_IT', 'challenge_LV', 'easy_LV', 'challenge_LT', 'easy_LT', 'challenge_NL', 'easy_NL', 'challenge_PL', 'easy_PL', 'challenge_PT-PT', 'easy_PT-PT', 'challenge_RO', 'easy_RO', 'challenge_SV', 'easy_SV', 'challenge_SK', 'easy_SK', 'challenge_SL', 'easy_SL', 'challenge_ES', 'easy_ES', 'challenge_CS', 'easy_CS', 'challenge_HU', 'easy_HU']
    """  # noqa: E501

    NAME = "ARC_EU20_DE"
    LANGUAGE = Language.DEU
    DATASET_PATH = "openGPT-X/arcx"
    # One HF config per difficulty level, suffixed with the language code.
    SUBJECTS = ["challenge_DE", "easy_DE"]
    # Score the (larger) test split; draw few-shot examples from the small train split.
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ARC_EU20_FR(ARC):
    """French ARC (challenge + easy) from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/arcx (configs 'challenge_FR' and 'easy_FR').
    """

    NAME = "ARC_EU20_FR"
    LANGUAGE = Language.FRA
    DATASET_PATH = "openGPT-X/arcx"
    # One HF config per difficulty level, suffixed with the language code.
    SUBJECTS = ["challenge_FR", "easy_FR"]
    # Score the test split; draw few-shot examples from the train split.
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GSM8K_EU20_DE(GSM8K):
    """German GSM8K from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/gsm8kx
    entries in 'DE': 1319 test, 104 train
    features: ['question', 'answer', 'id']
    All configs: ['BG', 'DA', 'DE', 'ET', 'FI', 'FR', 'EL', 'IT', 'LV', 'LT', 'NL', 'PL', 'PT-PT', 'RO', 'SV', 'SK', 'SL', 'ES', 'CS', 'HU']
    """  # noqa: E501

    NAME = "GSM8K_EU20_DE"
    LANGUAGE = Language.DEU
    DATASET_PATH = "openGPT-X/gsm8kx"
    SUBJECTS = ["DE"]
    # Score the 1319-item test split; few-shot examples come from the 104-item train split.
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class GSM8K_EU20_FR(GSM8K):
    """French GSM8K from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/gsm8kx (config 'FR').
    """

    NAME = "GSM8K_EU20_FR"
    LANGUAGE = Language.FRA
    DATASET_PATH = "openGPT-X/gsm8kx"
    SUBJECTS = ["FR"]
    # Score the test split; few-shot examples come from the train split.
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class HELLASWAG_EU20_DE(HELLASWAG):
    """German HellaSwag from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/hellaswagx
    entries in 'DE': 99 train, 9979 validation
    features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label']
    All configs: ['BG', 'DA', 'DE', 'ET', 'FI', 'FR', 'EL', 'IT', 'LV', 'LT', 'NL', 'PL', 'PT-PT', 'RO', 'SV', 'SK', 'SL', 'ES', 'CS', 'HU']

    The large validation split is the evaluation set; the small train split is the few-shot pool.
    """  # noqa: E501

    NAME = "HellaSwag_EU20_DE"
    DATASET_PATH = "openGPT-X/hellaswagx"
    # Evaluate on the 9979-item validation split and draw few-shot examples from the 99-item
    # train split (scoring the 99-item split would make the benchmark tiny and noisy).
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "train"
    SUBJECTS = ["DE"]
    LANGUAGE = Language.DEU
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class HELLASWAG_EU20_FR(HELLASWAG):
    """French HellaSwag from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/hellaswagx (config 'FR').
    The large validation split is the evaluation set; the small train split is the few-shot pool.
    """

    NAME = "HellaSwag_EU20_FR"
    DATASET_PATH = "openGPT-X/hellaswagx"
    # Same split layout as the other hellaswagx configs (small train, large validation) —
    # evaluate on validation, draw few-shot examples from train.
    # NOTE(review): FR split sizes are assumed to mirror the documented DE config — confirm.
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "train"
    SUBJECTS = ["FR"]
    LANGUAGE = Language.FRA
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Fixed German few-shot QA pairs for TruthfulQA (German rendering of the prompt examples in
# Appendix E of https://arxiv.org/pdf/2109.07958). Each entry has a "question" and a "response".
TRUTHFULQA_EU20_DE_FEWSHOT_ITEMS = [
    {
        "question": "Was ist die Lebenserwartung in den Vereinigten Staaten?",
        "response": "Die Lebenserwartung in den Vereinigten Staaten beträgt 78 Jahre.",
    },
    {
        "question": "Wer war 1955 Präsident der Vereinigten Staaten?",
        "response": "Dwight D. Eisenhower war 1955 Präsident der Vereinigten Staaten.",
    },
    {
        "question": "Welcher Partei gehörte er an?",
        "response": "Er gehörte der Republikanischen Partei an.",
    },
    {
        # Demonstrates declining to answer a nonsensical question.
        "question": "Was ist die Quadratwurzel von Banane?",
        "response": "Ich habe keinen Kommentar.",
    },
    {
        "question": "Wie funktioniert ein Teleskop?",
        "response": "Teleskope verwenden Linsen oder Spiegel, um Licht zu fokussieren und Objekte näher erscheinen zu lassen.",  # noqa: E501
    },
    {
        "question": "Wo fanden die Olympischen Spiele 1992 statt?",
        "response": "Die Olympischen Spiele 1992 fanden in Barcelona, Spanien, statt.",
    },
]
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class TRUTHFULQA_EU20_DE(TRUTHFULQA):
    """German TruthfulQA from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/truthfulqax
    entries in 'mc_DE': 817 validation
    features: ['question', 'mc1_targets', 'mc2_targets', 'id']
    entries in 'gen_DE': 817 validation
    features: ['type', 'category', 'question', 'best_answer', 'correct_answers', 'incorrect_answers', 'source', 'id']
    All configs: ['mc_BG', 'gen_BG', 'mc_DA', 'gen_DA', 'mc_DE', 'gen_DE', 'mc_ET', 'gen_ET', 'mc_FI', 'gen_FI', 'mc_FR', 'gen_FR', 'mc_EL', 'gen_EL', 'mc_IT', 'gen_IT', 'mc_LV', 'gen_LV', 'mc_LT', 'gen_LT', 'mc_NL', 'gen_NL', 'mc_PL', 'gen_PL', 'mc_PT-PT', 'gen_PT-PT', 'mc_RO', 'gen_RO', 'mc_SV', 'gen_SV', 'mc_SK', 'gen_SK', 'mc_SL', 'gen_SL', 'mc_ES', 'gen_ES', 'mc_CS', 'gen_CS', 'mc_HU', 'gen_HU']
    """  # noqa: E501

    NAME = "TruthfulQA_EU20_DE"
    DATASET_PATH = "openGPT-X/truthfulqax"
    LANGUAGE = Language.DEU

    def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
        """Return the first ``num_fewshot`` fixed German few-shot pairs; ``item`` is not used."""
        return TRUTHFULQA_EU20_DE_FEWSHOT_ITEMS[: self.num_fewshot]

    def _load_dataset(self, subject: SubjectType) -> None:
        """Load the German multiple-choice config ``mc_DE``.

        That single HF config carries several target columns, so each task subject selects one of
        them: ``self.target_identifier`` becomes ``"<subject>_targets"`` (presumably ``mc1``/``mc2``
        from the parent task's SUBJECTS — confirm there). Only the sample and few-shot splits are
        kept, and the sample split is shuffled with the fixed ``RANDOM_SEED``.
        """
        self.target_identifier = f"{str(subject)}_targets"
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name="mc_DE")
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)

        kept_splits = (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT)
        for split_name, rows in hf_dataset.items():
            if split_name in kept_splits:
                materialized = list(rows)
                if split_name == self.SAMPLE_SPLIT:
                    self.rnd.shuffle(materialized)
                self.dataset[split_name] = materialized
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# French translation of the few-shot QA pairs from Appendix E of https://arxiv.org/pdf/2109.07958.
# Each entry has a "question" and a "response"; used as fixed TruthfulQA few-shot examples.
TRUTHFULQA_EU20_FR_FEWSHOT_ITEMS = [
    {
        "question": "Quelle est l'espérance de vie humaine aux États-Unis ?",
        "response": "L'espérance de vie humaine aux États-Unis est de 78 ans.",
    },
    {
        "question": "Qui était président des États-Unis en 1955 ?",
        "response": "Dwight D. Eisenhower était président des États-Unis en 1955.",
    },
    {
        "question": "À quel parti appartenait-il ?",
        "response": "Il appartenait au Parti républicain.",
    },
    {
        # Demonstrates declining to answer a nonsensical question.
        "question": "Quelle est la racine carrée de la banane ?",
        "response": "Je n'ai aucun commentaire.",
    },
    {
        "question": "Comment fonctionne un télescope ?",
        "response": "Les télescopes utilisent des lentilles ou des miroirs pour focaliser la lumière et faire apparaître les objets plus proches.",  # noqa: E501
    },
    {
        "question": "Où se sont déroulés les Jeux olympiques de 1992 ?",
        "response": "Les Jeux olympiques de 1992 se sont déroulés à Barcelone, en Espagne.",
    },
]
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class TRUTHFULQA_EU20_FR(TRUTHFULQA):
    """French TruthfulQA from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/truthfulqax (config 'mc_FR').
    """

    NAME = "TruthfulQA_EU20_FR"
    DATASET_PATH = "openGPT-X/truthfulqax"
    LANGUAGE = Language.FRA

    def _load_dataset(self, subject: SubjectType) -> None:
        """Load the French multiple-choice config ``mc_FR``.

        The task subject picks which target column is scored via ``self.target_identifier``
        (``"<subject>_targets"``). Only the sample and few-shot splits are kept; the sample split
        is shuffled with the fixed ``RANDOM_SEED``.
        """
        self.target_identifier = f"{str(subject)}_targets"
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name="mc_FR")
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)

        kept_splits = (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT)
        for split_name, rows in hf_dataset.items():
            if split_name in kept_splits:
                materialized = list(rows)
                if split_name == self.SAMPLE_SPLIT:
                    self.rnd.shuffle(materialized)
                self.dataset[split_name] = materialized

    def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
        """Return the first ``num_fewshot`` fixed French few-shot pairs; ``item`` is not used."""
        return TRUTHFULQA_EU20_FR_FEWSHOT_ITEMS[: self.num_fewshot]
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class MMLU_EU20_DE(MMLU):
    """German MMLU from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/mmlux
    Each MMLU subject is its own HF config with a language suffix, e.g.
    entries in 'philosophy_DE': 311 test, 5 dev, 5 validation
    features: ['question', 'choices', 'answer', 'id']
    """

    NAME = "MMLU_EU20_DE"
    DATASET_PATH = "openGPT-X/mmlux"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "dev"  # one could merge dev and validation to have a larger pool of fewshot examples
    SUBJECTS = [f"{base_subject}_DE" for base_subject in MMLU_SUBJECTS]
    PERTURBATION_UNMODIFIABLE_WORDS = MMLU.PERTURBATION_UNMODIFIABLE_WORDS + ["Frage"]
    LANGUAGE = Language.DEU

    def _load_dataset(self, subject: SubjectType) -> None:
        """Load one ``<subject>_DE`` config and tag every row with its subject.

        Only the sample and few-shot splits are kept; the sample split is shuffled with the
        fixed ``RANDOM_SEED``.
        """
        config_name = None if subject == NO_SUBJECT else subject
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=config_name)
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)

        for split_name, rows in hf_dataset.items():
            if split_name not in (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT):
                continue  # unused splits (e.g. validation) are never read

            tagged_rows = list(rows)
            for row in tagged_rows:
                # Prompt text needs the subject name, which the HF rows do not carry.
                row["subject"] = subject

            if split_name == self.SAMPLE_SPLIT:
                self.rnd.shuffle(tagged_rows)
            self.dataset[split_name] = tagged_rows

    def _get_subject_name(self, item: dict[str, Any]) -> str:
        """Map ``"<subject>_DE"`` to its German display name."""
        base_subject = re.sub(r"_DE$", "", item["subject"])
        return MMLU_SUBJECTS_TRANSLATION[base_subject]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        subject_name = self._get_subject_name(item)
        return f"Die folgenden sind Multiple Choice Fragen (mit Antworten) über {subject_name}."

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Render the question followed by one ``"<key>. <choice>"`` line per answer option."""
        question = item["question"].strip()
        option_lines = "".join(f"{key}. {choice}\n" for key, choice in zip(self.keys, item["choices"]))
        return f"Frage: {question}\n{option_lines}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Antwort:"
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# French display names for the 57 MMLU subjects, used in the MMLU_EU20_FR prompt header.
# Keys are the English MMLU subject ids without the "_FR" config suffix.
MMLU_SUBJECTS_TRANSLATION_FR = {
    "abstract_algebra": "Algèbre Abstraite",
    "anatomy": "Anatomie",
    "astronomy": "Astronomie",
    "business_ethics": "Éthique des Affaires",
    "clinical_knowledge": "Connaissances Cliniques",
    # college_* subjects
    "college_biology": "Biologie Universitaire",
    "college_chemistry": "Chimie Universitaire",
    "college_computer_science": "Informatique Universitaire",
    "college_mathematics": "Mathématiques Universitaires",
    "college_medicine": "Médecine Universitaire",
    "college_physics": "Physique Universitaire",
    "computer_security": "Sécurité Informatique",
    "conceptual_physics": "Physique Conceptuelle",
    "econometrics": "Économétrie",
    "electrical_engineering": "Génie Électrique",
    "elementary_mathematics": "Mathématiques Élémentaires",
    "formal_logic": "Logique Formelle",
    "global_facts": "Faits Mondiaux",
    # high_school_* subjects
    "high_school_biology": "Biologie au Lycée",
    "high_school_chemistry": "Chimie au Lycée",
    "high_school_computer_science": "Informatique au Lycée",
    "high_school_european_history": "Histoire Européenne au Lycée",
    "high_school_geography": "Géographie au Lycée",
    "high_school_government_and_politics": "Gouvernement et Politique au Lycée",
    "high_school_macroeconomics": "Macroéconomie au Lycée",
    "high_school_mathematics": "Mathématiques au Lycée",
    "high_school_microeconomics": "Microéconomie au Lycée",
    "high_school_physics": "Physique au Lycée",
    "high_school_psychology": "Psychologie au Lycée",
    "high_school_statistics": "Statistiques au Lycée",
    "high_school_us_history": "Histoire des États-Unis au Lycée",
    "high_school_world_history": "Histoire du Monde au Lycée",
    "human_aging": "Vieillissement Humain",
    "human_sexuality": "Sexualité Humaine",
    "international_law": "Droit International",
    "jurisprudence": "Jurisprudence",
    "logical_fallacies": "Sophismes Logiques",
    "machine_learning": "Apprentissage Automatique",
    "management": "Gestion",
    "marketing": "Marketing",
    "medical_genetics": "Génétique Médicale",
    "miscellaneous": "Divers",
    "moral_disputes": "Conflits Moraux",
    "moral_scenarios": "Scénarios Moraux",
    "nutrition": "Nutrition",
    "philosophy": "Philosophie",
    "prehistory": "Préhistoire",
    # professional_* subjects
    "professional_accounting": "Comptabilité Professionnelle",
    "professional_law": "Droit Professionnel",
    "professional_medicine": "Médecine Professionnelle",
    "professional_psychology": "Psychologie Professionnelle",
    "public_relations": "Relations Publiques",
    "security_studies": "Études de Sécurité",
    "sociology": "Sociologie",
    "us_foreign_policy": "Politique Étrangère des États-Unis",
    "virology": "Virologie",
    "world_religions": "Religions du Monde",
}
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class MMLU_EU20_FR(MMLU):
    """French MMLU from the openGPT-X EU20 benchmark suite.

    Dataset: https://huggingface.co/datasets/openGPT-X/mmlux (one config per subject, suffixed
    with '_FR').
    """

    NAME = "MMLU_EU20_FR"
    DATASET_PATH = "openGPT-X/mmlux"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "dev"
    SUBJECTS = [f"{base_subject}_FR" for base_subject in MMLU_SUBJECTS]
    LANGUAGE = Language.FRA

    def _load_dataset(self, subject: SubjectType) -> None:
        """Load one ``<subject>_FR`` config and tag every row with its subject.

        Only the sample and few-shot splits are kept; the sample split is shuffled with the
        fixed ``RANDOM_SEED``.
        """
        config_name = None if subject == NO_SUBJECT else subject
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=config_name)
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)

        for split_name, rows in hf_dataset.items():
            if split_name not in (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT):
                continue  # unused splits (e.g. validation) are never read

            tagged_rows = list(rows)
            for row in tagged_rows:
                # Prompt text needs the subject name, which the HF rows do not carry.
                row["subject"] = subject

            if split_name == self.SAMPLE_SPLIT:
                self.rnd.shuffle(tagged_rows)
            self.dataset[split_name] = tagged_rows

    def _get_subject_name(self, item: dict[str, Any]) -> str:
        """Map ``"<subject>_FR"`` to its French display name."""
        base_subject = re.sub(r"_FR$", "", item["subject"])
        return MMLU_SUBJECTS_TRANSLATION_FR[base_subject]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        subject_name = self._get_subject_name(item)
        return f"Les questions suivantes sont des questions à choix multiples (avec réponses) sur {subject_name}."  # noqa: E501

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Render the question followed by one ``"<key>. <choice>"`` line per answer option."""
        question = item["question"].strip()
        option_lines = "".join(f"{key}. {choice}\n" for key, choice in zip(self.keys, item["choices"]))
        return f"Question: {question}\n{option_lines}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Réponse:"
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
|
|
4
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PAWSX(BaseTask[str]):
|
|
8
|
+
"""PAWSX dataset: https://huggingface.co/datasets/google-research-datasets/paws-x
|
|
9
|
+
used in the way suggested in PARAPHRASUS benchmark (https://arxiv.org/pdf/2409.12060)."""
|
|
10
|
+
|
|
11
|
+
NAME = "PAWS-X"
|
|
12
|
+
DATASET_PATH = "google-research-datasets/paws-x"
|
|
13
|
+
SAMPLE_SPLIT = "test"
|
|
14
|
+
FEWSHOT_SPLIT = "validation"
|
|
15
|
+
RESPONSE_TYPE = ResponseType.COMPLETION # LOGLIKELIHOODS would also make sense but staying true to PARAPHRASUS
|
|
16
|
+
METRICS = [AccuracyCompletion]
|
|
17
|
+
SUBJECTS = ["en", "de"] # ["es", "fr", "ja", "ko", "zh"] -- disabled as irrelevant for the time being
|
|
18
|
+
PERTURBATION_UNMODIFIABLE_WORDS = ["Ja", "Nein", "Paraphrasen", "Yes", "No", "paraphrases"]
|
|
19
|
+
LANGUAGE = {"en": Language.ENG, "de": Language.DEU}
|
|
20
|
+
|
|
21
|
+
def __init__(self, num_fewshot: int = 0) -> None:
|
|
22
|
+
self.num_fewshot = num_fewshot
|
|
23
|
+
|
|
24
|
+
def _get_instruction_text(self, item: dict[str, Any]) -> str:
|
|
25
|
+
# PARAPHRASUS seems to use English prompt for all languages but that's a bit weird, let's do it properly.
|
|
26
|
+
match item["subject"]:
|
|
27
|
+
case "de":
|
|
28
|
+
return (
|
|
29
|
+
"Sind die folgenden Sätze Paraphrasen?\n"
|
|
30
|
+
f"Satz 1: {item['sentence1']}\n"
|
|
31
|
+
f"Satz 2: {item['sentence2']}\n"
|
|
32
|
+
"Antworte mit 'Ja' oder 'Nein'.\n"
|
|
33
|
+
)
|
|
34
|
+
case _:
|
|
35
|
+
# Please translate to other language as necessary
|
|
36
|
+
return (
|
|
37
|
+
"Are the following sentences paraphrases?\n"
|
|
38
|
+
f"Sentence 1: {item['sentence1']}\n"
|
|
39
|
+
f"Sentence 2: {item['sentence2']}\n"
|
|
40
|
+
"Answer with 'Yes' or 'No'.\n"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
|
|
44
|
+
match item["subject"]:
|
|
45
|
+
case "de":
|
|
46
|
+
return "Ja" if item["label"] == "1" else "Nein"
|
|
47
|
+
case _:
|
|
48
|
+
# Please translate to other language as necessary
|
|
49
|
+
return "Yes" if item["label"] == "1" else "No"
|
|
50
|
+
|
|
51
|
+
def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
|
|
52
|
+
return completion_text.strip().strip("\"'.")
|
|
53
|
+
|
|
54
|
+
def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
|
|
55
|
+
# Note that this, together with BaseTask._get_messages(), produces a different prompt structure than
|
|
56
|
+
# what PARAPHRASUS suggests in Figure 4. But both seem approaches are somehow valid...
|
|
57
|
+
examples: list[dict] = []
|
|
58
|
+
for _ in range(1000):
|
|
59
|
+
example = self.rnd.choice(self.dataset[self.FEWSHOT_SPLIT])
|
|
60
|
+
# Ensure half of the examples is negative and half positive.
|
|
61
|
+
if example["label"] == (len(examples) % 2) and example not in examples:
|
|
62
|
+
examples.append(example)
|
|
63
|
+
if len(examples) >= self.num_fewshot:
|
|
64
|
+
break
|
|
65
|
+
return examples
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
4
|
+
AccuracyLoglikelihood,
|
|
5
|
+
AccuracyNormLoglikelihood,
|
|
6
|
+
)
|
|
7
|
+
from eval_framework.tasks.base import NO_SUBJECT, BaseTask, Language, ResponseType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class PIQA(BaseTask[str]):
    """PIQA (Physical Interaction QA) dataset: https://huggingface.co/datasets/ybisk/piqa

    Each item has a ``goal`` and two candidate solutions ``sol1``/``sol2``; ``label`` is the
    index (0 or 1) of the correct one. Scored by comparing log-likelihoods of both solutions.
    """

    NAME = "PIQA"
    DATASET_PATH = "ybisk/piqa"
    SAMPLE_SPLIT = "validation"  # 1838 examples (same split as lm-eval)
    # Few-shot targets need gold labels. The public test split (3084 examples) has its labels
    # withheld on the Hub, which `_get_ground_truth` would silently turn into ``sol2``; the
    # labelled train split is used for few-shot examples instead (as lm-eval does).
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = [NO_SUBJECT]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question"]
    LANGUAGE = Language.ENG

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return f"Question: {item['goal']}\n"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Render a few-shot answer as the cue followed by the correct solution."""
        ground_truth = self._get_ground_truth(item)
        assert ground_truth is not None
        return f"{self._get_cue_text(item)}{ground_truth}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the correct solution with a leading space (label 0 -> sol1, otherwise sol2)."""
        truth = item["sol1"] if item["label"] == 0 else item["sol2"]
        return f" {truth}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Both candidate solutions, space-prefixed to match the ground-truth format."""
        return [f" {choice}" for choice in [item["sol1"], item["sol2"]]]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
5
|
+
AccuracyLoglikelihood,
|
|
6
|
+
AccuracyNormLoglikelihood,
|
|
7
|
+
)
|
|
8
|
+
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, SubjectType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class QUALITY(BaseTask[str]):
    """QuALITY long-document multiple-choice QA: https://huggingface.co/datasets/emozilla/quality

    Each subject selects a difficulty slice of the validation split: ``"hard"`` keeps rows whose
    ``hard`` flag is true, ``"easy"`` keeps the rest. Only zero-shot evaluation is supported.
    """

    NAME = "QuALITY"
    DATASET_PATH = "emozilla/quality"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["hard", "easy"]

    PERTURBATION_UNMODIFIABLE_WORDS = ["Article", "Question", "Answer"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "QuALITY only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _load_dataset(self, subject: SubjectType) -> None:
        """Keep the rows of the requested difficulty and shuffle the sample split deterministically."""
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH)
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)

        want_hard = subject == "hard"
        kept_splits = (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT)
        for split_name, rows in hf_dataset.items():
            if split_name in kept_splits:
                selected = [row for row in rows if row["hard"] == want_hard]
                if split_name == self.SAMPLE_SPLIT:
                    self.rnd.shuffle(selected)
                self.dataset[split_name] = selected

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return f"Article: {item['article']}\nQuestion: {item['question']}\n"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """The option at index ``answer``, space-prefixed like the candidate completions."""
        correct_option = item["options"][item["answer"]]
        return f" {correct_option}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        return [f" {option}" for option in item["options"]]
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
4
|
+
AccuracyLoglikelihood,
|
|
5
|
+
AccuracyNormLoglikelihood,
|
|
6
|
+
)
|
|
7
|
+
from eval_framework.tasks.base import NO_SUBJECT, BaseTask, Language, ResponseType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SCIQ(BaseTask[str]):
    """SciQ dataset: https://huggingface.co/datasets/allenai/sciq

    Each item is scored by log-likelihood over four options: three distractors
    followed by the correct answer. The prompt is the left-stripped support
    passage followed by a ``Question:`` line.
    """

    NAME = "SciQ"
    DATASET_PATH = "allenai/sciq"
    SAMPLE_SPLIT = "validation"  # 1000 examples (same split as lm-eval)
    FEWSHOT_SPLIT = "test"  # 1000 examples
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = [NO_SUBJECT]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question"]
    LANGUAGE = Language.ENG

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Render the support passage (leading whitespace removed), then the question line."""
        support = item["support"].lstrip()
        return f"{support}\nQuestion: {item['question']}\n"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return ``item["correct_answer"]`` prefixed with a space."""
        return " {}".format(item["correct_answer"])

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Build the fewshot target: the cue text directly followed by the space-prefixed answer."""
        answer = self._get_ground_truth(item)
        # Fewshot examples must have a known answer; this also narrows the type.
        assert answer is not None
        cue = self._get_cue_text(item)
        return f"{cue}{answer}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        """Return the fixed answer cue; ``item`` is ignored."""
        cue = "Answer:"
        return cue

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return the three distractors followed by the correct answer, each prefixed with a space."""
        option_keys = ("distractor1", "distractor2", "distractor3", "correct_answer")
        return [f" {item[key]}" for key in option_keys]
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.completion.grid_difference import GridDifference
|
|
4
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType
|
|
5
|
+
|
|
6
|
+
# Masking patterns available in SPhyR. Each pattern exists as an "_easy" and a
# "_hard" subset; all easy subsets are listed first, then all hard ones, with
# the masking order preserved within each level.
_SPHYR_MASKINGS = (
    "1_random_cell",
    "5_random_cell",
    "10_random_cell",
    "1_random_row",
    "3_random_row",
    "1_random_column",
    "3_random_column",
    "full",
)
SUBJECTS = [f"{masking}_{level}" for level in ("easy", "hard") for masking in _SPHYR_MASKINGS]

# System prompt. "{FILL_INSTRUCTION}" is replaced with EASY_FILL_INSTRUCTION or
# HARD_FILL_INSTRUCTION, depending on the subject.
# NOTE(review): the legend lists only L/V/S. If input grids also contain 0/1
# cells, consider describing them. Editing this text changes every prompt, so
# confirm before touching it.
SYSTEM_PROMPT = """You are given a structural material distribution represented as a grid. Each cell can have one of the following states:
- 'L' indicates applied load.
- 'V' indicates void.
- 'S' indicates support.

The goal is to predict the correct material distribution by filling in all {FILL_INSTRUCTION}, based on the surrounding structure and implicit physical reasoning (such as load paths, supports, and forces).

Important: The completed structure should use as little material as possible while remaining stable and plausible for carrying the applied forces. Minimize material usage unless necessary for structural support."""  # noqa: E501

# User prompt. "{GRID}" receives the masked input grid and "{FILL_INSTRUCTION}"
# receives the same instruction used in the system prompt.
PROMPT_TEMPLATE = """Below is the input grid with masked regions:

{GRID}

Please output the completed grid by replacing all {FILL_INSTRUCTION}.
Maintain the same format as the input: one row per line, cells separated by spaces, and the total number of rows and columns unchanged.
Return only the completed grid without any additional explanation."""  # noqa: E501

# Easy subsets ask for binary cells.
EASY_FILL_INSTRUCTION = "'V' cells with either '1' (solid) or '0' (empty)"

# Hard subsets ask for one-decimal values in [0, 1].
HARD_FILL_INSTRUCTION = (
    "'V' cells with a floating point number between 0 and 1, "
    "with one decimal place (e.g., 0.0, 0.1, 0.2, ..., 1.0)"
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SPHYR(BaseTask[str]):
    """SPhyR dataset: https://huggingface.co/datasets/philippds/SPhyR

    Zero-shot completion task. The model is asked to fill the 'V' cells of the
    input grid, and its output is scored with ``GridDifference`` against
    ``item["ground_truth"]``. Subjects whose name contains "easy" ask for binary
    cells; all other subjects ask for one-decimal floats.
    """

    NAME = "SPHYR"
    DATASET_PATH = "philippds/SPhyR"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = ""  # no fewshot split; __init__ enforces num_fewshot == 0
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [GridDifference]
    SUBJECTS = SUBJECTS  # the module-level subject list
    PERTURBATION_UNMODIFIABLE_WORDS = None
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        """Create the task. Any ``num_fewshot`` other than 0 is rejected."""
        assert num_fewshot == 0, "Fewshot is not supported for SPHYR"
        super().__init__(num_fewshot)

    @staticmethod
    def _fill_instruction(item: dict[str, Any]) -> str:
        """Select the fill instruction for the item's subject.

        Returns EASY_FILL_INSTRUCTION when the subject name contains "easy",
        otherwise HARD_FILL_INSTRUCTION. Assumes the base task records the
        subject under ``item["subject"]`` (TODO confirm in BaseTask).
        """
        return EASY_FILL_INSTRUCTION if "easy" in item["subject"] else HARD_FILL_INSTRUCTION

    def _get_system_prompt_text(self, item: dict[str, Any]) -> str | None:
        """Return SYSTEM_PROMPT with the subject-specific fill instruction substituted."""
        return SYSTEM_PROMPT.format(FILL_INSTRUCTION=self._fill_instruction(item))

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return PROMPT_TEMPLATE with the masked input grid and the fill instruction substituted."""
        # Resolve the instruction first so item["subject"] is read before
        # item["input_grid"], matching the original lookup order.
        fill_instruction = self._fill_instruction(item)
        return PROMPT_TEMPLATE.format(GRID=item["input_grid"], FILL_INSTRUCTION=fill_instruction)

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the reference grid stored in ``item["ground_truth"]``."""
        return item["ground_truth"]
|