ai2-olmo-eval 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ai2_olmo_eval-0.7.2.dist-info → ai2_olmo_eval-0.8.0.dist-info}/METADATA +1 -1
- {ai2_olmo_eval-0.7.2.dist-info → ai2_olmo_eval-0.8.0.dist-info}/RECORD +8 -8
- olmo_eval/metrics.py +112 -87
- olmo_eval/tasks.py +430 -2
- olmo_eval/version.py +2 -2
- {ai2_olmo_eval-0.7.2.dist-info → ai2_olmo_eval-0.8.0.dist-info}/WHEEL +0 -0
- {ai2_olmo_eval-0.7.2.dist-info → ai2_olmo_eval-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {ai2_olmo_eval-0.7.2.dist-info → ai2_olmo_eval-0.8.0.dist-info}/top_level.txt +0 -0
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
ai2_olmo_eval-0.
|
|
1
|
+
ai2_olmo_eval-0.8.0.dist-info/licenses/LICENSE,sha256=YvuKOpYh3COIF0yqq-nCMXtpS7mh1GyYvPVlW2j1G-M,11359
|
|
2
2
|
olmo_eval/__init__.py,sha256=49RxnAaJNk8U9XP3SF5MjyFIxLSkxH0vXQuZgnEOi44,283
|
|
3
|
-
olmo_eval/metrics.py,sha256=
|
|
4
|
-
olmo_eval/tasks.py,sha256=
|
|
3
|
+
olmo_eval/metrics.py,sha256=zc4IOZ8rUhxPyXVk6fOYzVKjJ4Lzq4tYeoyurxYWqY0,20034
|
|
4
|
+
olmo_eval/tasks.py,sha256=DF4-2MS5dkGgZSjNrRkjEoWShrAsGO7tiB6mqcTQnE8,93219
|
|
5
5
|
olmo_eval/tokenizer.py,sha256=PnkidE0nAtEA1QZjuQpE_bIwgAsHxodnaJRALAPqrJQ,5127
|
|
6
6
|
olmo_eval/util.py,sha256=ARmZmRQl8VOvnKQoUprb3cOunzcApeNhRdV4BMXZuvo,3856
|
|
7
|
-
olmo_eval/version.py,sha256=
|
|
7
|
+
olmo_eval/version.py,sha256=ucNFr1ahYQCmPHuM8Qq6kPbT7lmTnsZQuSxG1jpqphI,308
|
|
8
8
|
olmo_eval/hf_datasets/ai2_arc/ARC-Challenge/validation/data-00000-of-00001.arrow,sha256=TPWbMhBmticWjYp7TA3etcKbXbaoCDBWhxuqlD1bDJA,98080
|
|
9
9
|
olmo_eval/hf_datasets/ai2_arc/ARC-Challenge/validation/dataset_info.json,sha256=iZumP5Udu8LD7cbew3o7nNpnGu-o9jPaMxUrNDDNIVY,1795
|
|
10
10
|
olmo_eval/hf_datasets/ai2_arc/ARC-Challenge/validation/state.json,sha256=6Q1XhM-HMZcymuGAKBC_8RjMBKgJSaR_6lLUO9Z8XwE,255
|
|
@@ -716,7 +716,7 @@ olmo_eval/oe_eval_tasks/winogrande/val_rc_5shot/config.json,sha256=ySjEVqTOj5GwC
|
|
|
716
716
|
olmo_eval/oe_eval_tasks/winogrande/val_rc_5shot/requests.jsonl.gz,sha256=knTzcqigWCfdYLN1Pl0TfCm0Fi1lRASWAo_SC6KtXsc,115262
|
|
717
717
|
olmo_eval/tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json,sha256=yjXYcnpTO7Zjm_R4Gucrn9oA5paadiYM-ZZER5q_EXc,2114319
|
|
718
718
|
olmo_eval/tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json,sha256=mtM7Szmp-Dlzw_jEKgGUjdW4d6KKyaU1aVbE_07QtxQ,2115113
|
|
719
|
-
ai2_olmo_eval-0.
|
|
720
|
-
ai2_olmo_eval-0.
|
|
721
|
-
ai2_olmo_eval-0.
|
|
722
|
-
ai2_olmo_eval-0.
|
|
719
|
+
ai2_olmo_eval-0.8.0.dist-info/METADATA,sha256=TZmOipbol7scpsNfiFVximYmOONNlOg-J_bhbn0a-FE,14398
|
|
720
|
+
ai2_olmo_eval-0.8.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
721
|
+
ai2_olmo_eval-0.8.0.dist-info/top_level.txt,sha256=Pryk28JTb89-j624Uy1gRZiE0YXI3czgbNIfJCl9-x0,10
|
|
722
|
+
ai2_olmo_eval-0.8.0.dist-info/RECORD,,
|
olmo_eval/metrics.py
CHANGED
|
@@ -98,96 +98,121 @@ class ICLMetric(Metric):
|
|
|
98
98
|
batch["ctx_len"][idx] - 1 : batch["ctx_len"][idx] + batch["cont_len"][idx] - 1
|
|
99
99
|
]
|
|
100
100
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
/ batch["cont_byte_len"][idx]
|
|
135
|
-
* LOG_2_OF_E
|
|
136
|
-
)
|
|
137
|
-
|
|
138
|
-
log_likelihood_no_leading_space = torch.gather(
|
|
139
|
-
lm_cont_logits, 1, cont_tokens.unsqueeze(-1)
|
|
140
|
-
).sum()
|
|
141
|
-
celoss_no_leading_space = (
|
|
142
|
-
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
|
|
143
|
-
/ batch["cont_str_len_no_leading_space"][idx]
|
|
144
|
-
)
|
|
145
|
-
bpb_no_leading_space = (
|
|
146
|
-
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
|
|
147
|
-
/ batch["cont_byte_len_no_leading_space"][idx]
|
|
148
|
-
* LOG_2_OF_E
|
|
149
|
-
)
|
|
150
|
-
elif self.metric_type in ["len_norm", "ce_loss", "bpb"]:
|
|
151
|
-
log_likelihood = (
|
|
152
|
-
torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
|
|
153
|
-
/ batch["cont_str_len"][idx]
|
|
154
|
-
)
|
|
155
|
-
celoss = (
|
|
156
|
-
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
|
|
157
|
-
/ batch["cont_str_len"][idx]
|
|
158
|
-
)
|
|
159
|
-
bpb = (
|
|
160
|
-
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
|
|
161
|
-
/ batch["cont_byte_len"][idx]
|
|
162
|
-
* LOG_2_OF_E
|
|
163
|
-
)
|
|
101
|
+
if "choice_ids" in batch:
|
|
102
|
+
fast_mc = True
|
|
103
|
+
choice_ids = batch["choice_ids"][idx]
|
|
104
|
+
else:
|
|
105
|
+
fast_mc = False
|
|
106
|
+
choice_ids = cont_tokens
|
|
107
|
+
|
|
108
|
+
# For each choice token, calculate metrics and append as separate entries
|
|
109
|
+
for choice_idx, choice_token in enumerate(choice_ids):
|
|
110
|
+
if fast_mc:
|
|
111
|
+
_cont_id = choice_idx
|
|
112
|
+
_cont_tokens = choice_token.unsqueeze(-1)
|
|
113
|
+
else:
|
|
114
|
+
_cont_id = cont_id
|
|
115
|
+
_cont_tokens = cont_tokens
|
|
116
|
+
|
|
117
|
+
# Skip choices for Qs with less than the max choices (for questions w/ different nubmers of choices)
|
|
118
|
+
is_empty_choice = (choice_token.unsqueeze(-1).unsqueeze(-1) == -1).all().item()
|
|
119
|
+
if is_empty_choice:
|
|
120
|
+
continue
|
|
121
|
+
|
|
122
|
+
log_likelihood: torch.Tensor
|
|
123
|
+
celoss: torch.Tensor
|
|
124
|
+
bpb: torch.Tensor
|
|
125
|
+
log_likelihood_no_leading_space: torch.Tensor
|
|
126
|
+
celoss_no_leading_space: torch.Tensor
|
|
127
|
+
bpb_no_leading_space: torch.Tensor
|
|
128
|
+
if self.metric_type == "pmi_dc":
|
|
129
|
+
assert dc_lm_logits is not None
|
|
130
|
+
# get domain conditional continuation logits: [cont_len, vocab]
|
|
131
|
+
dc_lm_cont_logits = dc_lm_logits[idx][
|
|
132
|
+
batch["dc_len"][idx] - 1 : batch["dc_len"][idx] + batch["cont_len"][idx] - 1
|
|
133
|
+
]
|
|
164
134
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
135
|
+
# gather log-probs at continuation token indices but divide by domain conditional prob
|
|
136
|
+
log_likelihood = (
|
|
137
|
+
torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
138
|
+
/ torch.gather(dc_lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
139
|
+
)
|
|
140
|
+
celoss = -log_likelihood
|
|
141
|
+
bpb = -log_likelihood # the normalization factors cancel out
|
|
142
|
+
|
|
143
|
+
log_likelihood_no_leading_space = log_likelihood
|
|
144
|
+
celoss_no_leading_space = celoss
|
|
145
|
+
bpb_no_leading_space = bpb
|
|
146
|
+
elif self.metric_type == "acc" or self.metric_type == "f1":
|
|
147
|
+
# gather log-probs at continuation token indices
|
|
148
|
+
log_likelihood = torch.gather(
|
|
149
|
+
lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)
|
|
150
|
+
).sum()
|
|
151
|
+
celoss = (
|
|
152
|
+
-torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
153
|
+
/ batch["cont_str_len"][idx]
|
|
154
|
+
)
|
|
155
|
+
bpb = (
|
|
156
|
+
-torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
157
|
+
/ batch["cont_byte_len"][idx]
|
|
158
|
+
* LOG_2_OF_E
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
log_likelihood_no_leading_space = torch.gather(
|
|
162
|
+
lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)
|
|
163
|
+
).sum()
|
|
164
|
+
celoss_no_leading_space = (
|
|
165
|
+
-torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
166
|
+
/ batch["cont_str_len_no_leading_space"][idx]
|
|
167
|
+
)
|
|
168
|
+
bpb_no_leading_space = (
|
|
169
|
+
-torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
170
|
+
/ batch["cont_byte_len_no_leading_space"][idx]
|
|
171
|
+
* LOG_2_OF_E
|
|
172
|
+
)
|
|
173
|
+
elif self.metric_type in ["len_norm", "ce_loss", "bpb"]:
|
|
174
|
+
log_likelihood = (
|
|
175
|
+
torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
176
|
+
/ batch["cont_str_len"][idx]
|
|
177
|
+
)
|
|
178
|
+
celoss = (
|
|
179
|
+
-torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
180
|
+
/ batch["cont_str_len"][idx]
|
|
181
|
+
)
|
|
182
|
+
bpb = (
|
|
183
|
+
-torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
184
|
+
/ batch["cont_byte_len"][idx]
|
|
185
|
+
* LOG_2_OF_E
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
log_likelihood_no_leading_space = (
|
|
189
|
+
torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
190
|
+
/ batch["cont_str_len_no_leading_space"][idx]
|
|
191
|
+
)
|
|
192
|
+
celoss_no_leading_space = (
|
|
193
|
+
-torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
194
|
+
/ batch["cont_str_len_no_leading_space"][idx]
|
|
195
|
+
)
|
|
196
|
+
bpb_no_leading_space = (
|
|
197
|
+
-torch.gather(lm_cont_logits, 1, _cont_tokens.unsqueeze(-1)).sum()
|
|
198
|
+
/ batch["cont_byte_len_no_leading_space"][idx]
|
|
199
|
+
* LOG_2_OF_E
|
|
200
|
+
)
|
|
201
|
+
else:
|
|
202
|
+
raise ValueError(self.metric_type)
|
|
203
|
+
|
|
204
|
+
self.labels.append((doc_id, _cont_id, int(batch["label_id"][idx])))
|
|
205
|
+
self.loglikelihoods.append((doc_id, _cont_id, float(log_likelihood)))
|
|
206
|
+
self.celosses.append((doc_id, _cont_id, float(celoss)))
|
|
207
|
+
self.bpbs.append((doc_id, _cont_id, float(bpb)))
|
|
208
|
+
|
|
209
|
+
self.loglikelihoods_no_leading_space.append(
|
|
210
|
+
(doc_id, _cont_id, float(log_likelihood_no_leading_space))
|
|
172
211
|
)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
/ batch["cont_byte_len_no_leading_space"][idx]
|
|
176
|
-
* LOG_2_OF_E
|
|
212
|
+
self.celosses_no_leading_space.append(
|
|
213
|
+
(doc_id, _cont_id, float(celoss_no_leading_space))
|
|
177
214
|
)
|
|
178
|
-
|
|
179
|
-
raise ValueError(self.metric_type)
|
|
180
|
-
|
|
181
|
-
self.labels.append((doc_id, cont_id, int(batch["label_id"][idx])))
|
|
182
|
-
self.loglikelihoods.append((doc_id, cont_id, float(log_likelihood)))
|
|
183
|
-
self.celosses.append((doc_id, cont_id, float(celoss)))
|
|
184
|
-
self.bpbs.append((doc_id, cont_id, float(bpb)))
|
|
185
|
-
|
|
186
|
-
self.loglikelihoods_no_leading_space.append(
|
|
187
|
-
(doc_id, cont_id, float(log_likelihood_no_leading_space))
|
|
188
|
-
)
|
|
189
|
-
self.celosses_no_leading_space.append((doc_id, cont_id, float(celoss_no_leading_space)))
|
|
190
|
-
self.bpbs_no_leading_space.append((doc_id, cont_id, float(bpb_no_leading_space)))
|
|
215
|
+
self.bpbs_no_leading_space.append((doc_id, _cont_id, float(bpb_no_leading_space)))
|
|
191
216
|
|
|
192
217
|
def compute(self) -> Dict[str, torch.Tensor]:
|
|
193
218
|
# Task "suffix" -> tensor
|
olmo_eval/tasks.py
CHANGED
|
@@ -33,6 +33,7 @@ class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
|
|
|
33
33
|
dataset_name: Union[str, Sequence[str], None] = None,
|
|
34
34
|
model_ctx_len: int = 2048,
|
|
35
35
|
fixed_ctx_len: bool = False,
|
|
36
|
+
fast_mc: bool = False,
|
|
36
37
|
split="validation",
|
|
37
38
|
metric_type=None, # Override default metric type
|
|
38
39
|
prompts: Optional[List[Optional[str]]] = None, # List of prompt variants to use
|
|
@@ -44,6 +45,7 @@ class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
|
|
|
44
45
|
self.dataset_name = dataset_name
|
|
45
46
|
self.model_ctx_len = model_ctx_len
|
|
46
47
|
self.fixed_ctx_len = fixed_ctx_len
|
|
48
|
+
self.fast_mc = fast_mc
|
|
47
49
|
self.prompts = prompts or [None]
|
|
48
50
|
self.current_prompt: Optional[str] = None
|
|
49
51
|
if metric_type is not None:
|
|
@@ -76,6 +78,7 @@ class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
|
|
|
76
78
|
def prep_examples(self):
|
|
77
79
|
"""Append doc_ids to each example so that they are processed together in the metric"""
|
|
78
80
|
doc_id = 0
|
|
81
|
+
new_samples = []
|
|
79
82
|
for doc in self.dataset:
|
|
80
83
|
for prompt in self.prompts:
|
|
81
84
|
self.current_prompt = prompt
|
|
@@ -125,7 +128,7 @@ class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
|
|
|
125
128
|
dc_query = dc + continuation[:-1]
|
|
126
129
|
|
|
127
130
|
# form a sample
|
|
128
|
-
|
|
131
|
+
new_samples.append(
|
|
129
132
|
{
|
|
130
133
|
"doc_id": doc_id,
|
|
131
134
|
"cont_id": cont_id,
|
|
@@ -148,6 +151,56 @@ class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
|
|
|
148
151
|
|
|
149
152
|
doc_id += 1
|
|
150
153
|
|
|
154
|
+
# Fast MCQA:
|
|
155
|
+
# Only pass a single request, and group together all continuations as tokens
|
|
156
|
+
if self.fast_mc:
|
|
157
|
+
# Get unique doc IDs
|
|
158
|
+
unique_doc_ids = {
|
|
159
|
+
sample["doc_id"] for sample in new_samples if isinstance(sample["doc_id"], int)
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
# Create new samples list for fast MC
|
|
163
|
+
fast_mc_samples = []
|
|
164
|
+
|
|
165
|
+
# Process each unique document
|
|
166
|
+
for doc_id in unique_doc_ids:
|
|
167
|
+
# Get all samples for this doc_id
|
|
168
|
+
doc_samples = [s for s in new_samples if s["doc_id"] == doc_id]
|
|
169
|
+
|
|
170
|
+
# Sort by continuation ID
|
|
171
|
+
doc_samples.sort(
|
|
172
|
+
key=lambda x: float(x["cont_id"])
|
|
173
|
+
if isinstance(x["cont_id"], (int, float))
|
|
174
|
+
else 0.0
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
# Create new sample with distractor continuations
|
|
178
|
+
base_sample = doc_samples[0].copy()
|
|
179
|
+
choices = [s["continuation"] for s in doc_samples]
|
|
180
|
+
|
|
181
|
+
# Assert all continuations are length 1
|
|
182
|
+
for choice in choices:
|
|
183
|
+
if not isinstance(choice, (list, tuple)):
|
|
184
|
+
raise TypeError(
|
|
185
|
+
f"Expected continuation to be a list or tuple, got {type(choice)}"
|
|
186
|
+
)
|
|
187
|
+
assert len(choice) == 1, f"Expected continuation length 1, got {len(choice)}"
|
|
188
|
+
|
|
189
|
+
# Take first token of each continuation
|
|
190
|
+
choices = [
|
|
191
|
+
choice[0] if isinstance(choice, (list, tuple)) else choice for choice in choices
|
|
192
|
+
]
|
|
193
|
+
|
|
194
|
+
base_sample["choices"] = choices
|
|
195
|
+
base_sample["fast_mc"] = True
|
|
196
|
+
|
|
197
|
+
fast_mc_samples.append(base_sample)
|
|
198
|
+
|
|
199
|
+
# Add fast MC samples to main samples list
|
|
200
|
+
new_samples = fast_mc_samples
|
|
201
|
+
|
|
202
|
+
self.samples = new_samples
|
|
203
|
+
|
|
151
204
|
def pad_tokens_until_max(self, tokens, max_len=2048):
|
|
152
205
|
"""truncate from left if len(tokens) > model_ctx_len, max_len is not considered then
|
|
153
206
|
queries are already truncated at max length of model_ctx_len
|
|
@@ -214,6 +267,7 @@ class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
|
|
|
214
267
|
ctxs = []
|
|
215
268
|
continuations = []
|
|
216
269
|
ctx_lens = []
|
|
270
|
+
choice_ids = []
|
|
217
271
|
dc_lens = []
|
|
218
272
|
cont_lens = []
|
|
219
273
|
cont_str_lens = []
|
|
@@ -245,6 +299,8 @@ class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
|
|
|
245
299
|
cont_byte_lens.append(sample["cont_byte_len"])
|
|
246
300
|
cont_str_len_no_leading_space.append(sample["cont_str_len_no_leading_space"])
|
|
247
301
|
cont_byte_len_no_leading_space.append(sample["cont_byte_len_no_leading_space"])
|
|
302
|
+
if self.fast_mc:
|
|
303
|
+
choice_ids.append(sample["choices"])
|
|
248
304
|
|
|
249
305
|
queries.append(
|
|
250
306
|
torch.LongTensor(
|
|
@@ -281,6 +337,16 @@ class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
|
|
|
281
337
|
"label_id": torch.LongTensor(label_ids),
|
|
282
338
|
}
|
|
283
339
|
|
|
340
|
+
if self.fast_mc:
|
|
341
|
+
# Pad choice_ids with -1 (for Qs with different numbers of choices)
|
|
342
|
+
max_choices_len = max(len(choices) for choices in choice_ids)
|
|
343
|
+
padded_choice_ids = []
|
|
344
|
+
for choices in choice_ids:
|
|
345
|
+
padding = [-1] * (max_choices_len - len(choices))
|
|
346
|
+
padded_choice_ids.append(choices + padding)
|
|
347
|
+
choice_ids = padded_choice_ids
|
|
348
|
+
batch["choice_ids"] = torch.LongTensor(choice_ids)
|
|
349
|
+
|
|
284
350
|
return batch
|
|
285
351
|
|
|
286
352
|
def token_encode(self, string: str) -> List[int]:
|
|
@@ -1446,6 +1512,7 @@ class OEEvalTask(ICLMultiChoiceTaskDataset):
|
|
|
1446
1512
|
dataset_name: Union[str, Sequence[str], None] = None,
|
|
1447
1513
|
model_ctx_len: int = 2048,
|
|
1448
1514
|
fixed_ctx_len: bool = False,
|
|
1515
|
+
fast_mc: bool = False,
|
|
1449
1516
|
split=None,
|
|
1450
1517
|
metric_type=None,
|
|
1451
1518
|
prompts: Optional[List[Optional[str]]] = None, # List of prompt variants to use
|
|
@@ -1457,6 +1524,7 @@ class OEEvalTask(ICLMultiChoiceTaskDataset):
|
|
|
1457
1524
|
self.dataset_name = dataset_name
|
|
1458
1525
|
self.model_ctx_len = model_ctx_len
|
|
1459
1526
|
self.fixed_ctx_len = fixed_ctx_len
|
|
1527
|
+
self.fast_mc = fast_mc
|
|
1460
1528
|
self.log_instances = 0 # Set to > 0 to log the first few instances as a sanity check
|
|
1461
1529
|
|
|
1462
1530
|
self.samples: List[Dict[str, Any]] = []
|
|
@@ -1500,6 +1568,8 @@ class OEEvalTask(ICLMultiChoiceTaskDataset):
|
|
|
1500
1568
|
for requests in self.dataset:
|
|
1501
1569
|
current_doc_id_offset += max_doc_id
|
|
1502
1570
|
max_doc_id = 0 # Max doc id seen in this dataset
|
|
1571
|
+
|
|
1572
|
+
new_samples = []
|
|
1503
1573
|
for request in requests:
|
|
1504
1574
|
doc = request["doc"]
|
|
1505
1575
|
doc_id = request["doc_id"]
|
|
@@ -1571,7 +1641,7 @@ class OEEvalTask(ICLMultiChoiceTaskDataset):
|
|
|
1571
1641
|
dc_query = dc + continuation[:-1]
|
|
1572
1642
|
|
|
1573
1643
|
# form a sample
|
|
1574
|
-
|
|
1644
|
+
new_samples.append(
|
|
1575
1645
|
{
|
|
1576
1646
|
"doc_id": doc_id + current_doc_id_offset,
|
|
1577
1647
|
"cont_id": cont_id,
|
|
@@ -1592,6 +1662,46 @@ class OEEvalTask(ICLMultiChoiceTaskDataset):
|
|
|
1592
1662
|
}
|
|
1593
1663
|
)
|
|
1594
1664
|
|
|
1665
|
+
# Fast MCQA:
|
|
1666
|
+
# Only pass a single request, and group together all continuations as tokens
|
|
1667
|
+
if self.fast_mc:
|
|
1668
|
+
# Get unique doc IDs
|
|
1669
|
+
unique_doc_ids = set(sample["doc_id"] for sample in new_samples)
|
|
1670
|
+
|
|
1671
|
+
# Create new samples list for fast MC
|
|
1672
|
+
fast_mc_samples = []
|
|
1673
|
+
|
|
1674
|
+
# Process each unique document
|
|
1675
|
+
for doc_id in unique_doc_ids:
|
|
1676
|
+
# Get all samples for this doc_id
|
|
1677
|
+
doc_samples = [s for s in new_samples if s["doc_id"] == doc_id]
|
|
1678
|
+
|
|
1679
|
+
# Sort by continuation ID
|
|
1680
|
+
doc_samples.sort(key=lambda x: x["cont_id"])
|
|
1681
|
+
|
|
1682
|
+
# Create new sample with distractor continuations
|
|
1683
|
+
base_sample = doc_samples[0].copy()
|
|
1684
|
+
choices = [s["continuation"] for s in doc_samples]
|
|
1685
|
+
|
|
1686
|
+
# Assert all continuations are length 1
|
|
1687
|
+
for choice in choices:
|
|
1688
|
+
assert (
|
|
1689
|
+
len(choice) == 1
|
|
1690
|
+
), f"Expected continuation length 1, got {len(choice)}"
|
|
1691
|
+
|
|
1692
|
+
# Take first token of each continuation
|
|
1693
|
+
choices = [choice[0] for choice in choices]
|
|
1694
|
+
|
|
1695
|
+
base_sample["choices"] = choices
|
|
1696
|
+
base_sample["fast_mc"] = True
|
|
1697
|
+
|
|
1698
|
+
fast_mc_samples.append(base_sample)
|
|
1699
|
+
|
|
1700
|
+
# Add fast MC samples to main samples list
|
|
1701
|
+
new_samples = fast_mc_samples
|
|
1702
|
+
|
|
1703
|
+
self.samples = new_samples
|
|
1704
|
+
|
|
1595
1705
|
def doc_to_text(self, doc) -> str:
|
|
1596
1706
|
del doc
|
|
1597
1707
|
raise NotImplementedError
|
|
@@ -1768,6 +1878,24 @@ LABEL_TO_TASK_MAP_ORIG = {
|
|
|
1768
1878
|
OEEvalTask,
|
|
1769
1879
|
{"dataset_path": "copycolors", "dataset_name": "xl_10way", "metric_type": "acc"},
|
|
1770
1880
|
),
|
|
1881
|
+
"copycolors_10way_fast": (
|
|
1882
|
+
OEEvalTask,
|
|
1883
|
+
{
|
|
1884
|
+
"dataset_path": "copycolors",
|
|
1885
|
+
"dataset_name": "10way",
|
|
1886
|
+
"metric_type": "acc",
|
|
1887
|
+
"fast_mc": True,
|
|
1888
|
+
},
|
|
1889
|
+
),
|
|
1890
|
+
"copycolors_xl_10way_fast": (
|
|
1891
|
+
OEEvalTask,
|
|
1892
|
+
{
|
|
1893
|
+
"dataset_path": "copycolors",
|
|
1894
|
+
"dataset_name": "xl_10way",
|
|
1895
|
+
"metric_type": "acc",
|
|
1896
|
+
"fast_mc": True,
|
|
1897
|
+
},
|
|
1898
|
+
),
|
|
1771
1899
|
"csqa_mc_5shot": (
|
|
1772
1900
|
OEEvalTask,
|
|
1773
1901
|
{"dataset_path": "csqa", "dataset_name": "mc_5shot", "metric_type": "acc"},
|
|
@@ -1792,6 +1920,10 @@ LABEL_TO_TASK_MAP_ORIG = {
|
|
|
1792
1920
|
OEEvalTask,
|
|
1793
1921
|
{"dataset_path": "hellaswag", "dataset_name": "rc_5shot", "metric_type": "len_norm"},
|
|
1794
1922
|
),
|
|
1923
|
+
"hellaswag_bpb_5shot": (
|
|
1924
|
+
OEEvalTask,
|
|
1925
|
+
{"dataset_path": "hellaswag", "dataset_name": "rc_5shot", "metric_type": "bpb"},
|
|
1926
|
+
),
|
|
1795
1927
|
"openbookqa_mc_5shot": (
|
|
1796
1928
|
OEEvalTask,
|
|
1797
1929
|
{"dataset_path": "openbookqa", "dataset_name": "mc_5shot", "metric_type": "acc"},
|
|
@@ -2001,6 +2133,14 @@ LABEL_TO_TASK_MAP_LADDER = {
|
|
|
2001
2133
|
"metric_type": "len_norm",
|
|
2002
2134
|
},
|
|
2003
2135
|
),
|
|
2136
|
+
"arc_challenge_val_bpb_5shot": (
|
|
2137
|
+
OEEvalTask,
|
|
2138
|
+
{
|
|
2139
|
+
"dataset_path": "arc_challenge",
|
|
2140
|
+
"dataset_name": "val_rc_5shot",
|
|
2141
|
+
"metric_type": "bpb",
|
|
2142
|
+
},
|
|
2143
|
+
),
|
|
2004
2144
|
"arc_challenge_val_mc_5shot": (
|
|
2005
2145
|
OEEvalTask,
|
|
2006
2146
|
{"dataset_path": "arc_challenge", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
@@ -2013,114 +2153,299 @@ LABEL_TO_TASK_MAP_LADDER = {
|
|
|
2013
2153
|
"metric_type": "len_norm",
|
|
2014
2154
|
},
|
|
2015
2155
|
),
|
|
2156
|
+
"arc_challenge_test_bpb_5shot": (
|
|
2157
|
+
OEEvalTask,
|
|
2158
|
+
{
|
|
2159
|
+
"dataset_path": "arc_challenge",
|
|
2160
|
+
"dataset_name": "test_rc_5shot",
|
|
2161
|
+
"metric_type": "bpb",
|
|
2162
|
+
},
|
|
2163
|
+
),
|
|
2016
2164
|
"arc_challenge_test_mc_5shot": (
|
|
2017
2165
|
OEEvalTask,
|
|
2018
2166
|
{"dataset_path": "arc_challenge", "dataset_name": "test_mc_5shot", "metric_type": "acc"},
|
|
2019
2167
|
),
|
|
2168
|
+
"arc_challenge_test_mc_5shot_fast": (
|
|
2169
|
+
OEEvalTask,
|
|
2170
|
+
{
|
|
2171
|
+
"dataset_path": "arc_challenge",
|
|
2172
|
+
"dataset_name": "test_mc_5shot",
|
|
2173
|
+
"metric_type": "acc",
|
|
2174
|
+
"fast_mc": True,
|
|
2175
|
+
},
|
|
2176
|
+
),
|
|
2020
2177
|
"arc_easy_val_rc_5shot": (
|
|
2021
2178
|
OEEvalTask,
|
|
2022
2179
|
{"dataset_path": "arc_easy", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
|
|
2023
2180
|
),
|
|
2181
|
+
"arc_easy_val_bpb_5shot": (
|
|
2182
|
+
OEEvalTask,
|
|
2183
|
+
{"dataset_path": "arc_easy", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
|
|
2184
|
+
),
|
|
2024
2185
|
"arc_easy_val_mc_5shot": (
|
|
2025
2186
|
OEEvalTask,
|
|
2026
2187
|
{"dataset_path": "arc_easy", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
2027
2188
|
),
|
|
2189
|
+
"arc_easy_val_mc_5shot_fast": (
|
|
2190
|
+
OEEvalTask,
|
|
2191
|
+
{
|
|
2192
|
+
"dataset_path": "arc_easy",
|
|
2193
|
+
"dataset_name": "val_mc_5shot",
|
|
2194
|
+
"metric_type": "acc",
|
|
2195
|
+
"fast_mc": True,
|
|
2196
|
+
},
|
|
2197
|
+
),
|
|
2028
2198
|
"arc_easy_test_rc_5shot": (
|
|
2029
2199
|
OEEvalTask,
|
|
2030
2200
|
{"dataset_path": "arc_easy", "dataset_name": "test_rc_5shot", "metric_type": "len_norm"},
|
|
2031
2201
|
),
|
|
2202
|
+
"arc_easy_test_bpb_5shot": (
|
|
2203
|
+
OEEvalTask,
|
|
2204
|
+
{"dataset_path": "arc_easy", "dataset_name": "test_rc_5shot", "metric_type": "bpb"},
|
|
2205
|
+
),
|
|
2032
2206
|
"arc_easy_test_mc_5shot": (
|
|
2033
2207
|
OEEvalTask,
|
|
2034
2208
|
{"dataset_path": "arc_easy", "dataset_name": "test_mc_5shot", "metric_type": "acc"},
|
|
2035
2209
|
),
|
|
2210
|
+
"arc_easy_test_mc_5shot_fast": (
|
|
2211
|
+
OEEvalTask,
|
|
2212
|
+
{
|
|
2213
|
+
"dataset_path": "arc_easy",
|
|
2214
|
+
"dataset_name": "test_mc_5shot",
|
|
2215
|
+
"metric_type": "acc",
|
|
2216
|
+
"fast_mc": True,
|
|
2217
|
+
},
|
|
2218
|
+
),
|
|
2036
2219
|
"boolq_val_rc_5shot": (
|
|
2037
2220
|
OEEvalTask,
|
|
2038
2221
|
{"dataset_path": "boolq", "dataset_name": "val_rc_5shot", "metric_type": "acc"},
|
|
2039
2222
|
),
|
|
2223
|
+
"boolq_val_bpb_5shot": (
|
|
2224
|
+
OEEvalTask,
|
|
2225
|
+
{"dataset_path": "boolq", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
|
|
2226
|
+
),
|
|
2040
2227
|
"boolq_val_mc_5shot": (
|
|
2041
2228
|
OEEvalTask,
|
|
2042
2229
|
{"dataset_path": "boolq", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
2043
2230
|
),
|
|
2231
|
+
"boolq_val_mc_5shot_fast": (
|
|
2232
|
+
OEEvalTask,
|
|
2233
|
+
{
|
|
2234
|
+
"dataset_path": "boolq",
|
|
2235
|
+
"dataset_name": "val_mc_5shot",
|
|
2236
|
+
"metric_type": "acc",
|
|
2237
|
+
"fast_mc": True,
|
|
2238
|
+
},
|
|
2239
|
+
),
|
|
2044
2240
|
"csqa_val_rc_5shot": (
|
|
2045
2241
|
OEEvalTask,
|
|
2046
2242
|
{"dataset_path": "csqa", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
|
|
2047
2243
|
),
|
|
2244
|
+
"csqa_val_bpb_5shot": (
|
|
2245
|
+
OEEvalTask,
|
|
2246
|
+
{"dataset_path": "csqa", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
|
|
2247
|
+
),
|
|
2048
2248
|
"csqa_val_mc_5shot": (
|
|
2049
2249
|
OEEvalTask,
|
|
2050
2250
|
{"dataset_path": "csqa", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
2051
2251
|
),
|
|
2252
|
+
"csqa_val_mc_5shot_fast": (
|
|
2253
|
+
OEEvalTask,
|
|
2254
|
+
{
|
|
2255
|
+
"dataset_path": "csqa",
|
|
2256
|
+
"dataset_name": "val_mc_5shot",
|
|
2257
|
+
"metric_type": "acc",
|
|
2258
|
+
"fast_mc": True,
|
|
2259
|
+
},
|
|
2260
|
+
),
|
|
2052
2261
|
"hellaswag_val_rc_5shot": (
|
|
2053
2262
|
OEEvalTask,
|
|
2054
2263
|
{"dataset_path": "hellaswag", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
|
|
2055
2264
|
),
|
|
2265
|
+
"hellaswag_val_bpb_5shot": (
|
|
2266
|
+
OEEvalTask,
|
|
2267
|
+
{"dataset_path": "hellaswag", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
|
|
2268
|
+
),
|
|
2056
2269
|
"hellaswag_val_mc_5shot": (
|
|
2057
2270
|
OEEvalTask,
|
|
2058
2271
|
{"dataset_path": "hellaswag", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
2059
2272
|
),
|
|
2273
|
+
"hellaswag_val_mc_5shot_fast": (
|
|
2274
|
+
OEEvalTask,
|
|
2275
|
+
{
|
|
2276
|
+
"dataset_path": "hellaswag",
|
|
2277
|
+
"dataset_name": "val_mc_5shot",
|
|
2278
|
+
"metric_type": "acc",
|
|
2279
|
+
"fast_mc": True,
|
|
2280
|
+
},
|
|
2281
|
+
),
|
|
2060
2282
|
"openbookqa_val_rc_5shot": (
|
|
2061
2283
|
OEEvalTask,
|
|
2062
2284
|
{"dataset_path": "openbookqa", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
|
|
2063
2285
|
),
|
|
2286
|
+
"openbookqa_val_bpb_5shot": (
|
|
2287
|
+
OEEvalTask,
|
|
2288
|
+
{"dataset_path": "openbookqa", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
|
|
2289
|
+
),
|
|
2064
2290
|
"openbookqa_val_mc_5shot": (
|
|
2065
2291
|
OEEvalTask,
|
|
2066
2292
|
{"dataset_path": "openbookqa", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
2067
2293
|
),
|
|
2294
|
+
"openbookqa_val_mc_5shot_fast": (
|
|
2295
|
+
OEEvalTask,
|
|
2296
|
+
{
|
|
2297
|
+
"dataset_path": "openbookqa",
|
|
2298
|
+
"dataset_name": "val_mc_5shot",
|
|
2299
|
+
"metric_type": "acc",
|
|
2300
|
+
"fast_mc": True,
|
|
2301
|
+
},
|
|
2302
|
+
),
|
|
2068
2303
|
"openbookqa_test_rc_5shot": (
|
|
2069
2304
|
OEEvalTask,
|
|
2070
2305
|
{"dataset_path": "openbookqa", "dataset_name": "test_rc_5shot", "metric_type": "len_norm"},
|
|
2071
2306
|
),
|
|
2307
|
+
"openbookqa_test_bpb_5shot": (
|
|
2308
|
+
OEEvalTask,
|
|
2309
|
+
{"dataset_path": "openbookqa", "dataset_name": "test_rc_5shot", "metric_type": "bpb"},
|
|
2310
|
+
),
|
|
2072
2311
|
"openbookqa_test_mc_5shot": (
|
|
2073
2312
|
OEEvalTask,
|
|
2074
2313
|
{"dataset_path": "openbookqa", "dataset_name": "test_mc_5shot", "metric_type": "acc"},
|
|
2075
2314
|
),
|
|
2315
|
+
"openbookqa_test_mc_5shot_fast": (
|
|
2316
|
+
OEEvalTask,
|
|
2317
|
+
{
|
|
2318
|
+
"dataset_path": "openbookqa",
|
|
2319
|
+
"dataset_name": "test_mc_5shot",
|
|
2320
|
+
"metric_type": "acc",
|
|
2321
|
+
"fast_mc": True,
|
|
2322
|
+
},
|
|
2323
|
+
),
|
|
2076
2324
|
"piqa_val_rc_5shot": (
|
|
2077
2325
|
OEEvalTask,
|
|
2078
2326
|
{"dataset_path": "piqa", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
|
|
2079
2327
|
),
|
|
2328
|
+
"piqa_val_bpb_5shot": (
|
|
2329
|
+
OEEvalTask,
|
|
2330
|
+
{"dataset_path": "piqa", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
|
|
2331
|
+
),
|
|
2080
2332
|
"piqa_val_mc_5shot": (
|
|
2081
2333
|
OEEvalTask,
|
|
2082
2334
|
{"dataset_path": "piqa", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
2083
2335
|
),
|
|
2336
|
+
"piqa_val_mc_5shot_fast": (
|
|
2337
|
+
OEEvalTask,
|
|
2338
|
+
{
|
|
2339
|
+
"dataset_path": "piqa",
|
|
2340
|
+
"dataset_name": "val_mc_5shot",
|
|
2341
|
+
"metric_type": "acc",
|
|
2342
|
+
"fast_mc": True,
|
|
2343
|
+
},
|
|
2344
|
+
),
|
|
2084
2345
|
"socialiqa_val_rc_5shot": (
|
|
2085
2346
|
OEEvalTask,
|
|
2086
2347
|
{"dataset_path": "socialiqa", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
|
|
2087
2348
|
),
|
|
2349
|
+
"socialiqa_val_bpb_5shot": (
|
|
2350
|
+
OEEvalTask,
|
|
2351
|
+
{"dataset_path": "socialiqa", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
|
|
2352
|
+
),
|
|
2088
2353
|
"socialiqa_val_mc_5shot": (
|
|
2089
2354
|
OEEvalTask,
|
|
2090
2355
|
{"dataset_path": "socialiqa", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
2091
2356
|
),
|
|
2357
|
+
"socialiqa_val_mc_5shot_fast": (
|
|
2358
|
+
OEEvalTask,
|
|
2359
|
+
{
|
|
2360
|
+
"dataset_path": "socialiqa",
|
|
2361
|
+
"dataset_name": "val_mc_5shot",
|
|
2362
|
+
"metric_type": "acc",
|
|
2363
|
+
"fast_mc": True,
|
|
2364
|
+
},
|
|
2365
|
+
),
|
|
2092
2366
|
"winogrande_val_rc_5shot": (
|
|
2093
2367
|
OEEvalTask,
|
|
2094
2368
|
{"dataset_path": "winogrande", "dataset_name": "val_rc_5shot", "metric_type": "len_norm"},
|
|
2095
2369
|
),
|
|
2370
|
+
"winogrande_val_bpb_5shot": (
|
|
2371
|
+
OEEvalTask,
|
|
2372
|
+
{"dataset_path": "winogrande", "dataset_name": "val_rc_5shot", "metric_type": "bpb"},
|
|
2373
|
+
),
|
|
2096
2374
|
"winogrande_val_mc_5shot": (
|
|
2097
2375
|
OEEvalTask,
|
|
2098
2376
|
{"dataset_path": "winogrande", "dataset_name": "val_mc_5shot", "metric_type": "acc"},
|
|
2099
2377
|
),
|
|
2378
|
+
"winogrande_val_mc_5shot_fast": (
|
|
2379
|
+
OEEvalTask,
|
|
2380
|
+
{
|
|
2381
|
+
"dataset_path": "winogrande",
|
|
2382
|
+
"dataset_name": "val_mc_5shot",
|
|
2383
|
+
"metric_type": "acc",
|
|
2384
|
+
"fast_mc": True,
|
|
2385
|
+
},
|
|
2386
|
+
),
|
|
2100
2387
|
"mmlu_stem_val_rc_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}),
|
|
2101
2388
|
"mmlu_stem_val_rc_5shot": (MMLU, {"dataset_name": "stem", "prompt_variations": 2}),
|
|
2389
|
+
"mmlu_stem_val_bpb_5shot": (
|
|
2390
|
+
MMLU,
|
|
2391
|
+
{"dataset_name": "stem", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2392
|
+
),
|
|
2102
2393
|
"mmlu_stem_val_mc_5shot": (
|
|
2103
2394
|
MMLU,
|
|
2104
2395
|
{"dataset_name": "stem", "prompt_variations": 2, "mc_labels": True},
|
|
2105
2396
|
),
|
|
2397
|
+
"mmlu_stem_val_mc_5shot_fast": (
|
|
2398
|
+
MMLU,
|
|
2399
|
+
{"dataset_name": "stem", "prompt_variations": 2, "mc_labels": True, "fast_mc": True},
|
|
2400
|
+
),
|
|
2106
2401
|
"mmlu_stem_test_rc_var": (
|
|
2107
2402
|
MMLU,
|
|
2108
2403
|
{"dataset_name": "stem", "split": "test", "prompt_variations": 1},
|
|
2109
2404
|
),
|
|
2405
|
+
"mmlu_stem_test_bpb_var": (
|
|
2406
|
+
MMLU,
|
|
2407
|
+
{"dataset_name": "stem", "split": "test", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2408
|
+
),
|
|
2110
2409
|
"mmlu_stem_test_rc_5shot": (
|
|
2111
2410
|
MMLU,
|
|
2112
2411
|
{"dataset_name": "stem", "split": "test", "prompt_variations": 2},
|
|
2113
2412
|
),
|
|
2413
|
+
"mmlu_stem_test_bpb_5shot": (
|
|
2414
|
+
MMLU,
|
|
2415
|
+
{"dataset_name": "stem", "split": "test", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2416
|
+
),
|
|
2114
2417
|
"mmlu_stem_test_mc_5shot": (
|
|
2115
2418
|
MMLU,
|
|
2116
2419
|
{"dataset_name": "stem", "split": "test", "prompt_variations": 2, "mc_labels": True},
|
|
2117
2420
|
),
|
|
2421
|
+
"mmlu_stem_test_mc_5shot_fast": (
|
|
2422
|
+
MMLU,
|
|
2423
|
+
{
|
|
2424
|
+
"dataset_name": "stem",
|
|
2425
|
+
"split": "test",
|
|
2426
|
+
"prompt_variations": 2,
|
|
2427
|
+
"mc_labels": True,
|
|
2428
|
+
"fast_mc": True,
|
|
2429
|
+
},
|
|
2430
|
+
),
|
|
2118
2431
|
"mmlu_humanities_val_rc_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}),
|
|
2119
2432
|
"mmlu_humanities_val_rc_5shot": (MMLU, {"dataset_name": "humanities", "prompt_variations": 2}),
|
|
2433
|
+
"mmlu_humanities_val_bpb_var": (
|
|
2434
|
+
MMLU,
|
|
2435
|
+
{"dataset_name": "humanities", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2436
|
+
),
|
|
2437
|
+
"mmlu_humanities_val_bpb_5shot": (
|
|
2438
|
+
MMLU,
|
|
2439
|
+
{"dataset_name": "humanities", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2440
|
+
),
|
|
2120
2441
|
"mmlu_humanities_val_mc_5shot": (
|
|
2121
2442
|
MMLU,
|
|
2122
2443
|
{"dataset_name": "humanities", "prompt_variations": 2, "mc_labels": True},
|
|
2123
2444
|
),
|
|
2445
|
+
"mmlu_humanities_val_mc_5shot_fast": (
|
|
2446
|
+
MMLU,
|
|
2447
|
+
{"dataset_name": "humanities", "prompt_variations": 2, "mc_labels": True, "fast_mc": True},
|
|
2448
|
+
),
|
|
2124
2449
|
"mmlu_humanities_test_rc_var": (
|
|
2125
2450
|
MMLU,
|
|
2126
2451
|
{"dataset_name": "humanities", "split": "test", "prompt_variations": 1},
|
|
@@ -2129,10 +2454,38 @@ LABEL_TO_TASK_MAP_LADDER = {
|
|
|
2129
2454
|
MMLU,
|
|
2130
2455
|
{"dataset_name": "humanities", "split": "test", "prompt_variations": 2},
|
|
2131
2456
|
),
|
|
2457
|
+
"mmlu_humanities_test_bpb_var": (
|
|
2458
|
+
MMLU,
|
|
2459
|
+
{
|
|
2460
|
+
"dataset_name": "humanities",
|
|
2461
|
+
"split": "test",
|
|
2462
|
+
"prompt_variations": 2,
|
|
2463
|
+
"metric_type": "bpb",
|
|
2464
|
+
},
|
|
2465
|
+
),
|
|
2466
|
+
"mmlu_humanities_test_bpb_5shot": (
|
|
2467
|
+
MMLU,
|
|
2468
|
+
{
|
|
2469
|
+
"dataset_name": "humanities",
|
|
2470
|
+
"split": "test",
|
|
2471
|
+
"prompt_variations": 2,
|
|
2472
|
+
"metric_type": "bpb",
|
|
2473
|
+
},
|
|
2474
|
+
),
|
|
2132
2475
|
"mmlu_humanities_test_mc_5shot": (
|
|
2133
2476
|
MMLU,
|
|
2134
2477
|
{"dataset_name": "humanities", "split": "test", "prompt_variations": 2, "mc_labels": True},
|
|
2135
2478
|
),
|
|
2479
|
+
"mmlu_humanities_test_mc_5shot_fast": (
|
|
2480
|
+
MMLU,
|
|
2481
|
+
{
|
|
2482
|
+
"dataset_name": "humanities",
|
|
2483
|
+
"split": "test",
|
|
2484
|
+
"prompt_variations": 2,
|
|
2485
|
+
"mc_labels": True,
|
|
2486
|
+
"fast_mc": True,
|
|
2487
|
+
},
|
|
2488
|
+
),
|
|
2136
2489
|
"mmlu_social_sciences_val_rc_var": (
|
|
2137
2490
|
MMLU,
|
|
2138
2491
|
{"dataset_name": "social_sciences", "prompt_variations": 1},
|
|
@@ -2141,10 +2494,27 @@ LABEL_TO_TASK_MAP_LADDER = {
|
|
|
2141
2494
|
MMLU,
|
|
2142
2495
|
{"dataset_name": "social_sciences", "prompt_variations": 2},
|
|
2143
2496
|
),
|
|
2497
|
+
"mmlu_social_sciences_val_bpb_var": (
|
|
2498
|
+
MMLU,
|
|
2499
|
+
{"dataset_name": "social_sciences", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2500
|
+
),
|
|
2501
|
+
"mmlu_social_sciences_val_bpb_5shot": (
|
|
2502
|
+
MMLU,
|
|
2503
|
+
{"dataset_name": "social_sciences", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2504
|
+
),
|
|
2144
2505
|
"mmlu_social_sciences_val_mc_5shot": (
|
|
2145
2506
|
MMLU,
|
|
2146
2507
|
{"dataset_name": "social_sciences", "prompt_variations": 2, "mc_labels": True},
|
|
2147
2508
|
),
|
|
2509
|
+
"mmlu_social_sciences_val_mc_5shot_fast": (
|
|
2510
|
+
MMLU,
|
|
2511
|
+
{
|
|
2512
|
+
"dataset_name": "social_sciences",
|
|
2513
|
+
"prompt_variations": 2,
|
|
2514
|
+
"mc_labels": True,
|
|
2515
|
+
"fast_mc": True,
|
|
2516
|
+
},
|
|
2517
|
+
),
|
|
2148
2518
|
"mmlu_social_sciences_test_rc_var": (
|
|
2149
2519
|
MMLU,
|
|
2150
2520
|
{"dataset_name": "social_sciences", "split": "test", "prompt_variations": 1},
|
|
@@ -2153,6 +2523,24 @@ LABEL_TO_TASK_MAP_LADDER = {
|
|
|
2153
2523
|
MMLU,
|
|
2154
2524
|
{"dataset_name": "social_sciences", "split": "test", "prompt_variations": 2},
|
|
2155
2525
|
),
|
|
2526
|
+
"mmlu_social_sciences_test_bpb_var": (
|
|
2527
|
+
MMLU,
|
|
2528
|
+
{
|
|
2529
|
+
"dataset_name": "social_sciences",
|
|
2530
|
+
"split": "test",
|
|
2531
|
+
"prompt_variations": 2,
|
|
2532
|
+
"metric_type": "bpb",
|
|
2533
|
+
},
|
|
2534
|
+
),
|
|
2535
|
+
"mmlu_social_sciences_test_bpb_5shot": (
|
|
2536
|
+
MMLU,
|
|
2537
|
+
{
|
|
2538
|
+
"dataset_name": "social_sciences",
|
|
2539
|
+
"split": "test",
|
|
2540
|
+
"prompt_variations": 2,
|
|
2541
|
+
"metric_type": "bpb",
|
|
2542
|
+
},
|
|
2543
|
+
),
|
|
2156
2544
|
"mmlu_social_sciences_test_mc_5shot": (
|
|
2157
2545
|
MMLU,
|
|
2158
2546
|
{
|
|
@@ -2162,12 +2550,34 @@ LABEL_TO_TASK_MAP_LADDER = {
|
|
|
2162
2550
|
"mc_labels": True,
|
|
2163
2551
|
},
|
|
2164
2552
|
),
|
|
2553
|
+
"mmlu_social_sciences_test_mc_5shot_fast": (
|
|
2554
|
+
MMLU,
|
|
2555
|
+
{
|
|
2556
|
+
"dataset_name": "social_sciences",
|
|
2557
|
+
"split": "test",
|
|
2558
|
+
"prompt_variations": 2,
|
|
2559
|
+
"mc_labels": True,
|
|
2560
|
+
"fast_mc": True,
|
|
2561
|
+
},
|
|
2562
|
+
),
|
|
2165
2563
|
"mmlu_other_val_rc_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}),
|
|
2166
2564
|
"mmlu_other_val_rc_5shot": (MMLU, {"dataset_name": "other", "prompt_variations": 2}),
|
|
2565
|
+
"mmlu_other_val_bpb_var": (
|
|
2566
|
+
MMLU,
|
|
2567
|
+
{"dataset_name": "other", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2568
|
+
),
|
|
2569
|
+
"mmlu_other_val_bpb_5shot": (
|
|
2570
|
+
MMLU,
|
|
2571
|
+
{"dataset_name": "other", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2572
|
+
),
|
|
2167
2573
|
"mmlu_other_val_mc_5shot": (
|
|
2168
2574
|
MMLU,
|
|
2169
2575
|
{"dataset_name": "other", "prompt_variations": 2, "mc_labels": True},
|
|
2170
2576
|
),
|
|
2577
|
+
"mmlu_other_val_mc_5shot_fast": (
|
|
2578
|
+
MMLU,
|
|
2579
|
+
{"dataset_name": "other", "prompt_variations": 2, "mc_labels": True, "fast_mc": True},
|
|
2580
|
+
),
|
|
2171
2581
|
"mmlu_other_test_rc_var": (
|
|
2172
2582
|
MMLU,
|
|
2173
2583
|
{"dataset_name": "other", "split": "test", "prompt_variations": 1},
|
|
@@ -2176,10 +2586,28 @@ LABEL_TO_TASK_MAP_LADDER = {
|
|
|
2176
2586
|
MMLU,
|
|
2177
2587
|
{"dataset_name": "other", "split": "test", "prompt_variations": 2},
|
|
2178
2588
|
),
|
|
2589
|
+
"mmlu_other_test_bpb_var": (
|
|
2590
|
+
MMLU,
|
|
2591
|
+
{"dataset_name": "other", "split": "test", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2592
|
+
),
|
|
2593
|
+
"mmlu_other_test_bpb_5shot": (
|
|
2594
|
+
MMLU,
|
|
2595
|
+
{"dataset_name": "other", "split": "test", "prompt_variations": 2, "metric_type": "bpb"},
|
|
2596
|
+
),
|
|
2179
2597
|
"mmlu_other_test_mc_5shot": (
|
|
2180
2598
|
MMLU,
|
|
2181
2599
|
{"dataset_name": "other", "split": "test", "prompt_variations": 2, "mc_labels": True},
|
|
2182
2600
|
),
|
|
2601
|
+
"mmlu_other_test_mc_5shot_fast": (
|
|
2602
|
+
MMLU,
|
|
2603
|
+
{
|
|
2604
|
+
"dataset_name": "other",
|
|
2605
|
+
"split": "test",
|
|
2606
|
+
"prompt_variations": 2,
|
|
2607
|
+
"mc_labels": True,
|
|
2608
|
+
"fast_mc": True,
|
|
2609
|
+
},
|
|
2610
|
+
),
|
|
2183
2611
|
}
|
|
2184
2612
|
|
|
2185
2613
|
# Expanded tasks for BPB on some generative tasks
|
olmo_eval/version.py
CHANGED
|
File without changes
|
|
File without changes
|
|
File without changes
|