sdg-hub 0.1.0a2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +4 -0
- sdg_hub/_version.py +21 -0
- sdg_hub/blocks/__init__.py +6 -0
- sdg_hub/blocks/block.py +54 -0
- sdg_hub/blocks/filterblock.py +76 -0
- sdg_hub/blocks/iterblock.py +31 -0
- sdg_hub/blocks/llmblock.py +430 -0
- sdg_hub/blocks/rmblocks.py +194 -0
- sdg_hub/blocks/utilblocks.py +140 -0
- sdg_hub/configs/__init__.py +0 -0
- sdg_hub/configs/annotations/__init__.py +0 -0
- sdg_hub/configs/annotations/cot_reflection.yaml +34 -0
- sdg_hub/configs/annotations/detailed_description.yaml +10 -0
- sdg_hub/configs/annotations/detailed_description_icl.yaml +32 -0
- sdg_hub/configs/annotations/simple.yaml +10 -0
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +45 -0
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +35 -0
- sdg_hub/configs/knowledge/data_recipe/__init__.py +0 -0
- sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +3 -0
- sdg_hub/configs/knowledge/detailed_summary.yaml +17 -0
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +68 -0
- sdg_hub/configs/knowledge/evaluate_question.yaml +38 -0
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +85 -0
- sdg_hub/configs/knowledge/extractive_summary.yaml +17 -0
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +39 -0
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +56 -0
- sdg_hub/configs/knowledge/mcq_generation.yaml +83 -0
- sdg_hub/configs/knowledge/router.yaml +12 -0
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +34 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +40 -0
- sdg_hub/configs/skills/_A_.yaml +97 -0
- sdg_hub/configs/skills/_B_.yaml +36 -0
- sdg_hub/configs/skills/_C_.yaml +71 -0
- sdg_hub/configs/skills/_D_.yaml +85 -0
- sdg_hub/configs/skills/_E_.yaml +30 -0
- sdg_hub/configs/skills/_F_.yaml +45 -0
- sdg_hub/configs/skills/_G_.yaml +56 -0
- sdg_hub/configs/skills/_H_.yaml +80 -0
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +48 -0
- sdg_hub/configs/skills/annotation.yaml +36 -0
- sdg_hub/configs/skills/contexts.yaml +21 -0
- sdg_hub/configs/skills/critic.yaml +60 -0
- sdg_hub/configs/skills/data_recipe/__init__.py +0 -0
- sdg_hub/configs/skills/data_recipe/default_recipe.yaml +6 -0
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +44 -0
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +46 -0
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +54 -0
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +51 -0
- sdg_hub/configs/skills/freeform_questions.yaml +29 -0
- sdg_hub/configs/skills/freeform_responses.yaml +45 -0
- sdg_hub/configs/skills/grounded_questions.yaml +38 -0
- sdg_hub/configs/skills/grounded_responses.yaml +59 -0
- sdg_hub/configs/skills/judge.yaml +53 -0
- sdg_hub/configs/skills/planner.yaml +67 -0
- sdg_hub/configs/skills/respond.yaml +8 -0
- sdg_hub/configs/skills/revised_responder.yaml +78 -0
- sdg_hub/configs/skills/router.yaml +12 -0
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +27 -0
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +31 -0
- sdg_hub/flow.py +127 -0
- sdg_hub/flows/annotation/emotion/detailed_description.yaml +19 -0
- sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +19 -0
- sdg_hub/flows/annotation/emotion/simple.yaml +19 -0
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +13 -0
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +12 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +89 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +136 -0
- sdg_hub/flows/generation/skills/agentic_improve_skill.yaml +108 -0
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +80 -0
- sdg_hub/flows/generation/skills/synth_skills.yaml +59 -0
- sdg_hub/logger_config.py +20 -0
- sdg_hub/pipeline.py +66 -0
- sdg_hub/prompts.py +17 -0
- sdg_hub/py.typed +0 -0
- sdg_hub/registry.py +122 -0
- sdg_hub/sdg.py +164 -0
- sdg_hub/utils/__init__.py +5 -0
- sdg_hub/utils/chunking.py +73 -0
- sdg_hub/utils/datamixing.py +123 -0
- sdg_hub/utils/datautils.py +14 -0
- sdg_hub/utils/docprocessor.py +357 -0
- sdg_hub/utils/json.py +48 -0
- sdg_hub/utils/models.py +31 -0
- sdg_hub/utils/parse_and_convert.py +392 -0
- sdg_hub/utils/taxonomy.py +489 -0
- sdg_hub-0.1.0a2.dev0.dist-info/METADATA +154 -0
- sdg_hub-0.1.0a2.dev0.dist-info/RECORD +94 -0
- sdg_hub-0.1.0a2.dev0.dist-info/WHEEL +5 -0
- sdg_hub-0.1.0a2.dev0.dist-info/licenses/LICENSE +201 -0
- sdg_hub-0.1.0a2.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,392 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
# Standard
|
4
|
+
from enum import Enum
|
5
|
+
from typing import Any
|
6
|
+
import json
|
7
|
+
import os
|
8
|
+
import random
|
9
|
+
import re
|
10
|
+
import uuid
|
11
|
+
|
12
|
+
# Third Party
|
13
|
+
from datasets import Dataset
|
14
|
+
import yaml
|
15
|
+
|
16
|
+
# First Party
|
17
|
+
# pylint: disable=ungrouped-imports
|
18
|
+
from sdg_hub import utils
|
19
|
+
from sdg_hub.logger_config import setup_logger
|
20
|
+
from .datautils import safe_concatenate_datasets
|
21
|
+
|
22
|
+
# Module-level logger, created through the project's setup_logger helper.
logger = setup_logger(__name__)
|
23
|
+
|
24
|
+
|
25
|
+
class TaxonomyType(Enum):
    """The two kinds of taxonomy contributions handled by this module."""

    # Members are declared in this order; iteration over the enum follows it.
    KNOWLEDGE = "knowledge"
    SKILL = "skill"
|
28
|
+
|
29
|
+
|
30
|
+
def _unescape(s):
|
31
|
+
return bytes(s, "utf-8").decode("utf-8").strip()
|
32
|
+
|
33
|
+
|
34
|
+
# This is a hack because the simple workflow returns a q/a pair as a single output.
|
35
|
+
# We could possibly try to ask for them separately, but it would cost twice the inference
|
36
|
+
# API calls. All of this is because the smallest models we use on small environments
|
37
|
+
# for testing and demos weren't good enough to follow the strict formatting instructions used
|
38
|
+
# in the full pipeline.
|
39
|
+
def _get_question(synth_example: dict):
|
40
|
+
if "question" in synth_example:
|
41
|
+
return synth_example["question"]
|
42
|
+
|
43
|
+
if not synth_example.get("output"):
|
44
|
+
raise utils.GenerateException(
|
45
|
+
f"Error: output not found in synth_example: {synth_example}"
|
46
|
+
)
|
47
|
+
|
48
|
+
parts = synth_example["output"].split("?", 1)
|
49
|
+
if len(parts) != 2:
|
50
|
+
logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
|
51
|
+
return parts[0].strip() + "?" if len(parts) == 2 else ""
|
52
|
+
|
53
|
+
|
54
|
+
# This is also a hack. See the comment above _get_question.
|
55
|
+
def _get_response(synth_example: dict):
|
56
|
+
if "response" in synth_example:
|
57
|
+
return synth_example["response"]
|
58
|
+
|
59
|
+
if "output" not in synth_example:
|
60
|
+
raise utils.GenerateException(
|
61
|
+
f"Error: output not found in synth_example: {synth_example}"
|
62
|
+
)
|
63
|
+
|
64
|
+
parts = synth_example["output"].split("?", 1)
|
65
|
+
if len(parts) != 2:
|
66
|
+
logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
|
67
|
+
return parts[1].strip() if len(parts) == 2 else parts[0].strip()
|
68
|
+
|
69
|
+
|
70
|
+
def _convert_to_hack_fmt(sample: dict, sys_prompt: str):
    """Add flat ``id``/``system``/``user``/``assistant`` fields to a sample.

    The user turn is the extracted question. When the sample has a
    ``context`` key, the context and a blank line are prepended to it. The
    assistant turn is the extracted response. ``sample`` is updated in place
    and also returned.

    Note: this flat layout is a stop-gap and should be removed once the
    format is standardized on ``messages``.
    """
    question = _unescape(_get_question(sample))
    answer = _unescape(_get_response(sample))

    if "context" in sample:
        user_turn = f"{sample['context']}\n\n{question}"
    else:
        user_turn = question

    sample.update(
        id=str(uuid.uuid4()),
        system=sys_prompt,
        user=user_turn,
        assistant=answer,
    )
    return sample
|
89
|
+
|
90
|
+
|
91
|
+
def _convert_to_messages(sample: dict, sys_prompt: str):
    """Add a chat-style ``messages`` list and a fresh ``id`` to a sample.

    ``messages`` holds three turns, in order:
      1. the system prompt;
      2. the extracted question (user);
      3. the extracted response (assistant).

    A ``context`` key, if present, is not added to the user turn, and no
    ``metadata`` field is written. ``sample`` is updated in place and also
    returned.
    """
    question = _unescape(_get_question(sample))
    answer = _unescape(_get_response(sample))

    turns = (("system", sys_prompt), ("user", question), ("assistant", answer))
    sample["id"] = str(uuid.uuid4())
    sample["messages"] = [{"content": text, "role": role} for role, text in turns]
    return sample
|
108
|
+
|
109
|
+
|
110
|
+
def create_auxiliary_dataset(generated_dataset: Dataset):
    """Build auxiliary training samples from non-base documents.

    Rows whose ``dataset_type`` is not ``"base_document"`` are de-duplicated
    on ``document``. Each surviving row becomes a user/assistant exchange:
      * user turn: the raw document, a blank line, then an instruction picked
        at random from the ``auxilary_instructions.yaml`` list for that row's
        ``dataset_type``;
      * assistant turn: the row's ``document`` text.

    Args:
        generated_dataset: generated knowledge rows. They are read through the
            ``raw_document``, ``document``, ``domain`` and ``dataset_type``
            columns.

    Returns:
        A dataset with ``messages``, ``metadata`` (a JSON string) and ``id``
        columns. Returns ``None`` when there is no ``dataset_type`` column or
        when the instructions file is not found.
    """
    if "dataset_type" not in generated_dataset.column_names:
        return None

    # The instruction templates are resolved relative to this module's directory.
    here = os.path.dirname(os.path.abspath(__file__))
    aux_inst_path = os.path.join(here, "../configs/knowledge/auxilary_instructions.yaml")
    if not os.path.isfile(aux_inst_path):
        logger.error(f"auxiliary instructions file not found at {aux_inst_path}")
        return None
    with open(aux_inst_path, "r", encoding="utf-8") as fp:
        auxiliary_inst = yaml.safe_load(fp)

    kept_columns = (
        "raw_document",
        "document_outline",
        "domain",
        "dataset_type",
        "document",
    )

    non_base = generated_dataset.filter(
        lambda row: row["dataset_type"] != "base_document"
    )
    deduped_frame = non_base.to_pandas().drop_duplicates(subset=["document"])
    deduped = Dataset.from_pandas(deduped_frame)
    # Keep only the columns used below; everything else is dropped.
    deduped = deduped.remove_columns(
        [name for name in deduped.column_names if name not in kept_columns]
    )
    deduped = deduped.rename_columns(
        {"raw_document": "context", "document": "response"}
    )

    def _to_auxiliary_sample(rec):
        """Turn one de-duplicated row into a messages/metadata/id record."""
        instruction = random.choice(auxiliary_inst[rec["dataset_type"]])
        return {
            "messages": [
                {"role": "user", "content": f"{rec['context']}\n\n{instruction}"},
                {"role": "assistant", "content": rec["response"]},
            ],
            "metadata": json.dumps(
                {
                    "dataset_type": rec["dataset_type"],
                    "raw_document": rec["context"],
                    "dataset": f"document_{rec['dataset_type']}",
                    "domain": rec["domain"],
                }
            ),
            "id": str(uuid.uuid4()),
        }

    return deduped.map(_to_auxiliary_sample, remove_columns=deduped.column_names)
|
170
|
+
|
171
|
+
|
172
|
+
def generate_knowledge_qa_dataset(
    generated_dataset: Dataset, keep_context_separate=False
):
    """Convert generated knowledge rows into user/assistant QA samples.

    Each row is read through its ``document``, ``question``, ``response`` and
    ``domain`` fields. ``raw_document`` and ``dataset_type`` are copied into
    the metadata only when both are present.

    Args:
        generated_dataset: rows produced by a knowledge generation flow.
        keep_context_separate: controls where the document goes.
            If True, the user turn contains only the question, and the
            document is returned in a separate ``context`` column.
            If False, the document and a blank line are prepended to the
            question in the user turn.

    Returns:
        A dataset with ``messages``, ``metadata`` (a JSON string) and ``id``
        columns. It also has a ``context`` column when
        ``keep_context_separate`` is True. All input columns are removed.
    """

    def _build_row(rec):
        """Map one generated row to its QA training record."""
        document = rec["document"]
        question = rec["question"]
        answer = rec["response"]

        meta = {
            "sdg_document": document,
            "domain": rec["domain"],
            "dataset": "document_knowledge_qa",
        }
        if "raw_document" in rec and "dataset_type" in rec:
            meta["raw_document"] = rec["raw_document"]
            meta["dataset_type"] = rec["dataset_type"]
        meta_json = json.dumps(meta)

        if keep_context_separate:
            user_content = f"{question}"
        else:
            user_content = f"{document}\n\n{question}"

        row = {
            "messages": [
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": answer},
            ],
            "metadata": meta_json,
            "id": str(uuid.uuid4()),
        }
        if keep_context_separate:
            row["context"] = document
        return row

    return generated_dataset.map(
        _build_row, remove_columns=generated_dataset.column_names
    )
|
215
|
+
|
216
|
+
|
217
|
+
def build_raft_dataset(ds: Dataset, p, num_doc_in_context=4):
    """Prepend RAFT-style document context to each sample's user message.

    For every row, a shuffled set of documents is drawn from the dataset's
    unique ``context`` values:
      * with probability ``p``: the row's own (answer) document plus up to
        ``num_doc_in_context - 1`` distractor documents;
      * otherwise: up to ``num_doc_in_context`` distractor documents only.
    If the dataset has a single unique document, only that document is used.

    The chosen documents are prepended to the row's first ``user`` message.
    ``"_raft_p{p}"`` is appended to the ``dataset`` field of the row's JSON
    metadata, and the ``context`` column is removed.

    Args:
        ds: dataset with ``context``, ``messages`` and ``metadata`` (JSON
            string) columns, as produced by
            ``generate_knowledge_qa_dataset(..., keep_context_separate=True)``.
        p: probability of including the answer document in the context.
        num_doc_in_context: target number of documents per sample.

    Returns:
        The transformed dataset.

    Raises:
        IndexError: if a row has no message with role ``"user"``.
    """
    # De-duplicate while keeping first-occurrence order. A set's iteration
    # order depends on str hash randomization, so sampling from list(set(...))
    # was not reproducible across runs even with a seeded `random`.
    all_context = list(dict.fromkeys(ds["context"]))

    # Each row's answer document is in all_context, so every row sees the same
    # number of distractors. Report it once here rather than once per row.
    num_distractors = len(all_context) - 1
    if num_distractors == 0:
        logger.info("Only 1 unique document found. Turning off RAFT styling")
    elif 0 < num_distractors < num_doc_in_context:
        logger.info(
            f"Number of unique document is {num_distractors} which is less than {num_doc_in_context}. Using all the documents in the RAFT context"
        )

    def _pick_documents(rec, p):
        """Choose, shuffle and prepend the context documents for one row."""
        answer_document = rec["context"]
        distractors = [doc for doc in all_context if doc != answer_document]

        if distractors:
            if random.uniform(0, 1) < p:
                # Golden/answer document plus distractors.
                if len(distractors) >= num_doc_in_context - 1:
                    docs = random.sample(distractors, k=num_doc_in_context - 1)
                    docs.append(answer_document)
                else:
                    docs = distractors + [answer_document]
            elif len(distractors) >= num_doc_in_context:
                # Distractor documents only.
                docs = random.sample(distractors, k=num_doc_in_context)
            else:
                docs = distractors
        else:
            docs = [answer_document]

        random.shuffle(docs)
        context_block = "\n".join(f"Document:\n{doc}\n\n" for doc in docs)

        # Only the first user message receives the documents.
        user_idx = [
            idx for idx, msg in enumerate(rec["messages"]) if msg["role"] == "user"
        ][0]
        user_msg = rec["messages"][user_idx]
        user_msg["content"] = f"{context_block}\n\n{user_msg['content']}"

        metadata = json.loads(rec["metadata"])
        metadata["dataset"] += f"_raft_p{p}"
        rec["metadata"] = json.dumps(metadata)
        return rec

    return ds.map(_pick_documents, fn_kwargs={"p": p}, remove_columns=["context"])
|
264
|
+
|
265
|
+
|
266
|
+
|
267
|
+
def _conv_pretrain(rec):
|
268
|
+
rec["messages"] = [
|
269
|
+
{
|
270
|
+
"role": "pretraining",
|
271
|
+
"content": f"<|user|>\n{rec['messages'][0]['content']}\n<|assistant|>\n{rec['messages'][1]['content']}",
|
272
|
+
}
|
273
|
+
]
|
274
|
+
return rec
|
275
|
+
|
276
|
+
|
277
|
+
def create_knowledge_regular_ds(generated_dataset: Dataset):
    """Build the phase 1.0 knowledge training dataset.

    The steps are:
      1. Build QA samples with the document kept in a separate ``context``
         column.
      2. Apply RAFT-style context with ``p=0.4``.
      3. Append the auxiliary dataset when one can be built.
    """
    qa_samples = generate_knowledge_qa_dataset(
        generated_dataset, keep_context_separate=True
    )
    raft_samples = build_raft_dataset(qa_samples, p=0.4)

    auxiliary_dataset = create_auxiliary_dataset(generated_dataset)
    if auxiliary_dataset is None:
        return raft_samples
    return safe_concatenate_datasets([raft_samples, auxiliary_dataset])
|
290
|
+
|
291
|
+
|
292
|
+
def create_knowledge_pretraining_ds(generated_dataset: Dataset):
    """Build the phase 0.7 knowledge dataset in ``pretraining`` format.

    QA samples are built with the document inlined into the user turn. They,
    and the auxiliary samples when available, are each collapsed by
    ``_conv_pretrain`` into a single ``pretraining`` message and then
    concatenated.
    """
    qa_samples = generate_knowledge_qa_dataset(
        generated_dataset, keep_context_separate=False
    ).map(_conv_pretrain)

    auxiliary_dataset = create_auxiliary_dataset(generated_dataset)
    if auxiliary_dataset is None:
        return qa_samples
    return safe_concatenate_datasets(
        [qa_samples, auxiliary_dataset.map(_conv_pretrain)]
    )
|
306
|
+
|
307
|
+
|
308
|
+
def post_process_mcq(ds: Dataset, is_mmlu_eval: bool = False) -> Dataset:
    """Drop malformed MCQ rows and tag the rest with ``dataset_type="mcq_qa"``.

    A row is kept only if both of these hold:
      * its ``mmlubench_answer`` contains ``")"``;
      * its ``mmlubench_question`` contains ``"A)"``.

    Args:
        ds: MCQ dataset generated by the MMLU-bench pipeline, with
            ``mmlubench_question`` and ``mmlubench_answer`` columns.
        is_mmlu_eval: if True, the result is also passed through
            ``format_mmlu_style`` for lm-eval-harness MMLU evaluation.

    Returns:
        Dataset: the filtered dataset with a new ``dataset_type`` column,
        formatted MMLU-style when ``is_mmlu_eval`` is True.
    """

    def _is_well_formed(row):
        """Check that the answer and the question carry option markers."""
        return ")" in row["mmlubench_answer"] and "A)" in row["mmlubench_question"]

    ds = ds.filter(_is_well_formed)
    ds = ds.add_column("dataset_type", ["mcq_qa"] * ds.num_rows)
    return format_mmlu_style(ds) if is_mmlu_eval else ds
|
324
|
+
|
325
|
+
|
326
|
+
def extract_options(text: str) -> list[Any]:
    """Return the text of each lettered option (``A) ...``, ``B) ...``) in an MCQ.

    An option marker is one capital letter at a word boundary followed by
    ``") "``. Everything after the marker, up to the end of that line, is
    captured.

    Args:
        text: question text including its answer choices.

    Returns:
        list[Any]: the captured option texts, in order of appearance.
    """
    option_re = re.compile(r"\b[A-Z]\) (.+)")
    return [match.group(1) for match in option_re.finditer(text)]
|
339
|
+
|
340
|
+
|
341
|
+
def format_mmlu_style(ds: Dataset) -> Dataset:
    """Reshape an MCQ dataset into the columns lm-eval-harness MMLU tasks use.

    Adds or overwrites three columns:
      * ``answer``: the text of ``mmlubench_answer`` before its first ``")"``,
        encoded as a ``ClassLabel`` with the fixed names ``A``-``D``
        (A=0, B=1, C=2, D=3);
      * ``choices``: the options parsed from ``mmlubench_question`` by
        ``extract_options``;
      * ``question``: ``mmlubench_question`` up to its first ``"A)"``,
        stripped.

    It also renames ``domain`` to ``subject``. Rows without exactly four
    choices, or whose answer is not one of A-D, are dropped.

    Args:
        ds (Dataset): input dataset with ``mmlubench_question``,
            ``mmlubench_answer`` and ``domain`` columns.

    Returns:
        Dataset: the formatted dataset.
    """
    # Local import: the module header only imports Dataset from `datasets`.
    from datasets import ClassLabel  # pylint: disable=import-outside-toplevel

    letters = ["A", "B", "C", "D"]

    # When the marker is present, partition() gives the same result as slicing
    # at index(). When it is absent, partition() returns the whole string
    # instead of raising ValueError, and the filter below decides whether the
    # row is kept.
    ds = ds.map(lambda x: {"answer": x["mmlubench_answer"].partition(")")[0]})
    ds = ds.map(lambda x: {"choices": extract_options(x["mmlubench_question"])})
    ds = ds.map(
        lambda x: {"question": x["mmlubench_question"].partition("A)")[0].strip()}
    )
    ds = ds.rename_columns({"domain": "subject"})
    ds = ds.filter(lambda x: len(x["choices"]) == 4 and x["answer"] in letters)

    # class_encode_column() builds label names from the values actually
    # present. If no row had answer "A", it would encode "B" as 0, and the
    # answer index would no longer match the position in `choices`. Encoding
    # against the fixed A-D label set keeps the two aligned.
    answer_label = ClassLabel(names=letters)
    features = ds.features.copy()
    features["answer"] = answer_label
    ds = ds.map(
        lambda batch: {"answer": [answer_label.str2int(a) for a in batch["answer"]]},
        batched=True,
        features=features,
    )
    return ds
|
367
|
+
|
368
|
+
|
369
|
+
def create_mmlu_evaluation_dataset(generate_mcq_dataset: Dataset) -> Dataset:
    """Filter and format generated MCQs for lm-eval-harness MMLU evaluation.

    This is a thin wrapper around ``post_process_mcq(..., is_mmlu_eval=True)``.

    Args:
        generate_mcq_dataset (Dataset): MCQ dataset generated by SDG.

    Returns:
        Dataset: the MMLU-style MCQ dataset.
    """
    return post_process_mcq(generate_mcq_dataset, is_mmlu_eval=True)
|
379
|
+
|
380
|
+
|
381
|
+
def create_mmlu_evaluation_yaml(task_name, eval_data_file_path, yaml_file_path):
    """Write an lm-eval-harness task YAML for MMLU-style knowledge evaluation.

    The generated task:
      * reads its ``test`` split from ``eval_data_file_path``;
      * includes ``_default_mmlu_pr_template_yaml``;
      * belongs to the ``mmlu_pr`` group.

    Args:
        task_name: name of the lm-eval task.
        eval_data_file_path: data file used as the ``test`` split.
        yaml_file_path: destination path. It is opened in ``"w"`` mode, so an
            existing file is overwritten.
    """
    task_yaml = dict(
        task=task_name,
        dataset_kwargs={"data_files": {"test": eval_data_file_path}},
        include="_default_mmlu_pr_template_yaml",
        group="mmlu_pr",
    )
    with open(yaml_file_path, "w", encoding="utf-8") as out_file:
        yaml.dump(task_yaml, out_file, default_flow_style=False)
|