sdg-hub 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. sdg_hub/__init__.py +4 -0
  2. sdg_hub/_version.py +21 -0
  3. sdg_hub/blocks/__init__.py +6 -0
  4. sdg_hub/blocks/block.py +54 -0
  5. sdg_hub/blocks/filterblock.py +76 -0
  6. sdg_hub/blocks/iterblock.py +31 -0
  7. sdg_hub/blocks/llmblock.py +430 -0
  8. sdg_hub/blocks/rmblocks.py +194 -0
  9. sdg_hub/blocks/utilblocks.py +140 -0
  10. sdg_hub/configs/__init__.py +0 -0
  11. sdg_hub/configs/annotations/__init__.py +0 -0
  12. sdg_hub/configs/annotations/cot_reflection.yaml +34 -0
  13. sdg_hub/configs/annotations/detailed_description.yaml +10 -0
  14. sdg_hub/configs/annotations/detailed_description_icl.yaml +32 -0
  15. sdg_hub/configs/annotations/simple.yaml +10 -0
  16. sdg_hub/configs/knowledge/__init__.py +0 -0
  17. sdg_hub/configs/knowledge/atomic_facts.yaml +45 -0
  18. sdg_hub/configs/knowledge/auxilary_instructions.yaml +35 -0
  19. sdg_hub/configs/knowledge/data_recipe/__init__.py +0 -0
  20. sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +3 -0
  21. sdg_hub/configs/knowledge/detailed_summary.yaml +17 -0
  22. sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +68 -0
  23. sdg_hub/configs/knowledge/evaluate_question.yaml +38 -0
  24. sdg_hub/configs/knowledge/evaluate_relevancy.yaml +85 -0
  25. sdg_hub/configs/knowledge/extractive_summary.yaml +17 -0
  26. sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +39 -0
  27. sdg_hub/configs/knowledge/generate_questions_responses.yaml +56 -0
  28. sdg_hub/configs/knowledge/mcq_generation.yaml +83 -0
  29. sdg_hub/configs/knowledge/router.yaml +12 -0
  30. sdg_hub/configs/knowledge/simple_generate_qa.yaml +34 -0
  31. sdg_hub/configs/reasoning/dynamic_cot.yaml +40 -0
  32. sdg_hub/configs/skills/_A_.yaml +97 -0
  33. sdg_hub/configs/skills/_B_.yaml +36 -0
  34. sdg_hub/configs/skills/_C_.yaml +71 -0
  35. sdg_hub/configs/skills/_D_.yaml +85 -0
  36. sdg_hub/configs/skills/_E_.yaml +30 -0
  37. sdg_hub/configs/skills/_F_.yaml +45 -0
  38. sdg_hub/configs/skills/_G_.yaml +56 -0
  39. sdg_hub/configs/skills/_H_.yaml +80 -0
  40. sdg_hub/configs/skills/__init__.py +0 -0
  41. sdg_hub/configs/skills/analyzer.yaml +48 -0
  42. sdg_hub/configs/skills/annotation.yaml +36 -0
  43. sdg_hub/configs/skills/contexts.yaml +21 -0
  44. sdg_hub/configs/skills/critic.yaml +60 -0
  45. sdg_hub/configs/skills/data_recipe/__init__.py +0 -0
  46. sdg_hub/configs/skills/data_recipe/default_recipe.yaml +6 -0
  47. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +44 -0
  48. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +46 -0
  49. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +54 -0
  50. sdg_hub/configs/skills/evaluate_grounded_questions.yaml +51 -0
  51. sdg_hub/configs/skills/freeform_questions.yaml +29 -0
  52. sdg_hub/configs/skills/freeform_responses.yaml +45 -0
  53. sdg_hub/configs/skills/grounded_questions.yaml +38 -0
  54. sdg_hub/configs/skills/grounded_responses.yaml +59 -0
  55. sdg_hub/configs/skills/judge.yaml +53 -0
  56. sdg_hub/configs/skills/planner.yaml +67 -0
  57. sdg_hub/configs/skills/respond.yaml +8 -0
  58. sdg_hub/configs/skills/revised_responder.yaml +78 -0
  59. sdg_hub/configs/skills/router.yaml +12 -0
  60. sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +27 -0
  61. sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +31 -0
  62. sdg_hub/flow.py +127 -0
  63. sdg_hub/flows/annotation/emotion/detailed_description.yaml +19 -0
  64. sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +19 -0
  65. sdg_hub/flows/annotation/emotion/simple.yaml +19 -0
  66. sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +13 -0
  67. sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +12 -0
  68. sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +89 -0
  69. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +136 -0
  70. sdg_hub/flows/generation/skills/agentic_improve_skill.yaml +108 -0
  71. sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +12 -0
  72. sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +12 -0
  73. sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +80 -0
  74. sdg_hub/flows/generation/skills/synth_skills.yaml +59 -0
  75. sdg_hub/logger_config.py +20 -0
  76. sdg_hub/pipeline.py +66 -0
  77. sdg_hub/prompts.py +17 -0
  78. sdg_hub/py.typed +0 -0
  79. sdg_hub/registry.py +122 -0
  80. sdg_hub/sdg.py +164 -0
  81. sdg_hub/utils/__init__.py +5 -0
  82. sdg_hub/utils/chunking.py +73 -0
  83. sdg_hub/utils/datamixing.py +123 -0
  84. sdg_hub/utils/datautils.py +14 -0
  85. sdg_hub/utils/docprocessor.py +357 -0
  86. sdg_hub/utils/json.py +48 -0
  87. sdg_hub/utils/models.py +31 -0
  88. sdg_hub/utils/parse_and_convert.py +392 -0
  89. sdg_hub/utils/taxonomy.py +489 -0
  90. sdg_hub-0.1.0a1.dist-info/METADATA +154 -0
  91. sdg_hub-0.1.0a1.dist-info/RECORD +94 -0
  92. sdg_hub-0.1.0a1.dist-info/WHEEL +5 -0
  93. sdg_hub-0.1.0a1.dist-info/licenses/LICENSE +201 -0
  94. sdg_hub-0.1.0a1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,392 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Standard
4
+ from enum import Enum
5
+ from typing import Any
6
+ import json
7
+ import os
8
+ import random
9
+ import re
10
+ import uuid
11
+
12
+ # Third Party
13
+ from datasets import Dataset
14
+ import yaml
15
+
16
+ # First Party
17
+ # pylint: disable=ungrouped-imports
18
+ from sdg_hub import utils
19
+ from sdg_hub.logger_config import setup_logger
20
+ from .datautils import safe_concatenate_datasets
21
+
22
+ logger = setup_logger(__name__)
23
+
24
+
25
class TaxonomyType(Enum):
    """Kinds of taxonomy entries handled by the SDG pipelines: "knowledge" or "skill"."""

    KNOWLEDGE = "knowledge"
    SKILL = "skill"
28
+
29
+
30
+ def _unescape(s):
31
+ return bytes(s, "utf-8").decode("utf-8").strip()
32
+
33
+
34
+ # This is a hack because the simple workflow returns a q/a pair as a single output.
35
+ # We could possibly try to ask for them separately, but it would cost twice the inference
36
+ # API calls. All of this is because the smallest models we use on small environments
37
+ # for testing and demos weren't good enough to follow the strict formatting instructions used
38
+ # in the full pipeline.
39
+ def _get_question(synth_example: dict):
40
+ if "question" in synth_example:
41
+ return synth_example["question"]
42
+
43
+ if not synth_example.get("output"):
44
+ raise utils.GenerateException(
45
+ f"Error: output not found in synth_example: {synth_example}"
46
+ )
47
+
48
+ parts = synth_example["output"].split("?", 1)
49
+ if len(parts) != 2:
50
+ logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
51
+ return parts[0].strip() + "?" if len(parts) == 2 else ""
52
+
53
+
54
+ # This is also a hack. See the comment above _get_question.
55
+ def _get_response(synth_example: dict):
56
+ if "response" in synth_example:
57
+ return synth_example["response"]
58
+
59
+ if "output" not in synth_example:
60
+ raise utils.GenerateException(
61
+ f"Error: output not found in synth_example: {synth_example}"
62
+ )
63
+
64
+ parts = synth_example["output"].split("?", 1)
65
+ if len(parts) != 2:
66
+ logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
67
+ return parts[1].strip() if len(parts) == 2 else parts[0].strip()
68
+
69
+
70
def _convert_to_hack_fmt(sample: dict, sys_prompt: str):
    """
    Convert a sample dictionary to contain 'system', 'user', and 'assistant' columns.

    Note: We should remove this function in the future when we resolve this issue and
    standardize the format to messages.
    """
    # Build the user turn, prepending any available context.
    question = _unescape(_get_question(sample))
    answer = _unescape(_get_response(sample))
    if "context" in sample:
        question = f"{sample['context']}\n\n{question}"

    sample.update(
        {
            "id": str(uuid.uuid4()),
            "system": sys_prompt,
            "user": question,
            "assistant": answer,
        }
    )
    return sample
89
+
90
+
91
def _convert_to_messages(sample: dict, sys_prompt: str):
    """
    Convert a sample dictionary to contain 'messages'
    and 'metadata' columns required for training.
    """
    # Extract the two conversational turns from the generated sample.
    question = _unescape(_get_question(sample))
    answer = _unescape(_get_response(sample))

    sample["id"] = str(uuid.uuid4())
    sample["messages"] = [
        {"content": sys_prompt, "role": "system"},
        {"content": question, "role": "user"},
        {"content": answer, "role": "assistant"},
    ]

    return sample
108
+
109
+
110
def create_auxiliary_dataset(generated_dataset: Dataset):
    """Build an auxiliary chat dataset from non-base generated documents.

    For every unique transformed document (rows whose ``dataset_type`` is not
    ``"base_document"``), pairs the raw document plus a randomly chosen
    instruction — loaded from ``configs/knowledge/auxilary_instructions.yaml``
    — with the transformed document as the assistant response.

    Returns:
        Dataset with ``messages``, ``metadata`` (JSON string) and ``id``
        columns, or ``None`` when the input has no ``dataset_type`` column
        or the instruction file is missing.
    """
    if "dataset_type" not in generated_dataset.column_names:
        return None

    # get module path of the current file
    module_dir = os.path.dirname(os.path.abspath(__file__))
    # NOTE: the "auxilary" spelling matches the shipped config file's name.
    aux_inst_path = os.path.join(module_dir, "../configs/knowledge/auxilary_instructions.yaml")
    if os.path.isfile(
        aux_inst_path
    ):
        with open(aux_inst_path, "r", encoding="utf-8") as fp:
            # assumes YAML maps each dataset_type to a list of instruction
            # strings — TODO confirm against the config file
            auxiliary_inst = yaml.safe_load(fp)
    else:
        logger.error(f"auxiliary instructions file not found at {aux_inst_path}")
        return None
    # Keep only derived document rows, dropping the originals.
    auxiliary_ds = generated_dataset.filter(
        lambda x: x["dataset_type"] != "base_document"
    )
    # Deduplicate on the transformed document text via a pandas round-trip.
    unique_document_auxiliary = auxiliary_ds.to_pandas().drop_duplicates(
        subset=["document"]
    )
    unique_document_auxiliary = Dataset.from_pandas(unique_document_auxiliary)
    # Keep only the columns needed to build the auxiliary samples.
    unique_document_auxiliary = unique_document_auxiliary.remove_columns(
        [
            col
            for col in unique_document_auxiliary.column_names
            if col
            not in [
                "raw_document",
                "document_outline",
                "domain",
                "dataset_type",
                "document",
            ]
        ]
    )
    # The raw document becomes the user context; the transformed one the response.
    unique_document_auxiliary = unique_document_auxiliary.rename_columns(
        {"raw_document": "context", "document": "response"}
    )

    def __create_auxiliary_ds(rec):
        # One randomly chosen instruction template per row's dataset type.
        instruction = random.choice(auxiliary_inst[rec["dataset_type"]])
        messages = [
            {"role": "user", "content": f"{rec['context']}\n\n{instruction}"},
            {"role": "assistant", "content": rec["response"]},
        ]
        metadata = json.dumps(
            {
                "dataset_type": rec["dataset_type"],
                "raw_document": rec["context"],
                "dataset": f"document_{rec['dataset_type']}",
                "domain": rec["domain"],
            }
        )
        return {"messages": messages, "metadata": metadata, "id": str(uuid.uuid4())}

    unique_document_auxiliary = unique_document_auxiliary.map(
        __create_auxiliary_ds, remove_columns=unique_document_auxiliary.column_names
    )
    return unique_document_auxiliary
170
+
171
+
172
def generate_knowledge_qa_dataset(
    generated_dataset: Dataset, keep_context_separate=False
):
    """Turn generated knowledge q/a rows into chat-style training rows.

    Each output row carries ``messages``, JSON-encoded ``metadata`` and a
    fresh ``id``. When *keep_context_separate* is True the source document is
    emitted in its own ``context`` column instead of being prepended to the
    user turn.
    """

    def _to_chat_row(rec):
        document = rec["document"]
        question = rec["question"]
        answer = rec["response"]

        meta = {
            "sdg_document": rec["document"],
            "domain": rec["domain"],
            "dataset": "document_knowledge_qa",
        }
        # Provenance fields are optional; carry them through when present.
        if "raw_document" in rec and "dataset_type" in rec:
            meta["raw_document"] = rec["raw_document"]
            meta["dataset_type"] = rec["dataset_type"]
        meta = json.dumps(meta)

        if keep_context_separate:
            return {
                "messages": [
                    {"role": "user", "content": f"{question}"},
                    {"role": "assistant", "content": answer},
                ],
                "metadata": meta,
                "id": str(uuid.uuid4()),
                "context": document,
            }
        return {
            "messages": [
                {"role": "user", "content": f"{document}\n\n{question}"},
                {"role": "assistant", "content": answer},
            ],
            "metadata": meta,
            "id": str(uuid.uuid4()),
        }

    return generated_dataset.map(
        _to_chat_row, remove_columns=generated_dataset.column_names
    )
215
+
216
+
217
def build_raft_dataset(ds: Dataset, p, num_doc_in_context=4):
    """Rewrite each row's user turn into RAFT style: prepend a bundle of
    documents (distractors, plus — with probability *p* — the golden/answer
    document) to the user instruction, and drop the ``context`` column.

    Args:
        ds (Dataset): rows with ``context``, ``messages`` and JSON ``metadata``
            columns (as produced by ``generate_knowledge_qa_dataset`` with
            ``keep_context_separate=True``).
        p: probability of including the golden document in the context bundle.
        num_doc_in_context: target number of documents per bundle.

    Returns:
        Dataset: same rows with the user message rewritten and ``metadata``'s
        ``dataset`` field suffixed with ``_raft_p{p}``.
    """
    all_context = list(set(ds["context"]))

    def _pick_documents(rec, p):
        answer_document = rec["context"]
        selected_docs = [e for e in all_context if e != answer_document]
        if len(selected_docs) > 0:
            if len(selected_docs) < num_doc_in_context:
                logger.info(
                    f"Number of unique document is {len(selected_docs)} which is less than {num_doc_in_context}. Using all the documents in the RAFT context"
                )
            if random.uniform(0, 1) < p:
                # golden/answer + distractor documents
                docs = (
                    random.sample(selected_docs, k=num_doc_in_context - 1) + [answer_document]
                    if len(selected_docs) >= (num_doc_in_context - 1)
                    else selected_docs + [answer_document]
                )
            else:
                # distractor documents only
                docs = (
                    random.sample(selected_docs, k=num_doc_in_context)
                    if len(selected_docs) >= num_doc_in_context
                    else selected_docs
                )
        else:
            logger.info("Only 1 unique document found. Turning off RAFT styling")
            docs = [answer_document]

        random.shuffle(docs)

        # FIX: original enumerated `docs` without ever using the index.
        docs = "\n".join(f"Document:\n{e}\n\n" for e in docs)
        # Locate the (first) user message and prepend the document bundle.
        user_idx, user_msg = [
            (idx, rec_msg)
            for idx, rec_msg in enumerate(rec["messages"])
            if rec_msg["role"] == "user"
        ][0]
        user_inst = user_msg["content"]
        rec["messages"][user_idx]["content"] = f"{docs}\n\n{user_inst}"
        # FIX: dropped the original's no-op `rec["messages"] = rec["messages"]`.
        metadata = json.loads(rec["metadata"])
        metadata["dataset"] += f"_raft_p{p}"
        rec["metadata"] = json.dumps(metadata)
        return rec

    return ds.map(_pick_documents, fn_kwargs={"p": p}, remove_columns=["context"])
264
+
265
+
266
+
267
+ def _conv_pretrain(rec):
268
+ rec["messages"] = [
269
+ {
270
+ "role": "pretraining",
271
+ "content": f"<|user|>\n{rec['messages'][0]['content']}\n<|assistant|>\n{rec['messages'][1]['content']}",
272
+ }
273
+ ]
274
+ return rec
275
+
276
+
277
def create_knowledge_regular_ds(generated_dataset: Dataset):
    """Build the Phase 1.0 (instruction-style) knowledge dataset:
    RAFT-styled q/a rows, concatenated with auxiliary document rows when
    available."""
    knowledge_ds = build_raft_dataset(
        generate_knowledge_qa_dataset(generated_dataset, keep_context_separate=True),
        p=0.4,
    )

    auxiliary_dataset = create_auxiliary_dataset(generated_dataset)
    if auxiliary_dataset is None:
        return knowledge_ds
    return safe_concatenate_datasets([knowledge_ds, auxiliary_dataset])
290
+
291
+
292
def create_knowledge_pretraining_ds(generated_dataset: Dataset):
    """Build the Phase 0.7 (pretraining-style) knowledge dataset: q/a rows
    collapsed into single 'pretraining' messages, concatenated with the
    similarly collapsed auxiliary rows when available."""
    knowledge_ds = generate_knowledge_qa_dataset(
        generated_dataset, keep_context_separate=False
    ).map(_conv_pretrain)

    auxiliary_dataset = create_auxiliary_dataset(generated_dataset)
    if auxiliary_dataset is None:
        return knowledge_ds
    return safe_concatenate_datasets(
        [knowledge_ds, auxiliary_dataset.map(_conv_pretrain)]
    )
306
+
307
+
308
def post_process_mcq(ds: Dataset, is_mmlu_eval: bool = False) -> Dataset:
    """Filter out malformed MCQ rows and tag the dataset type.

    Rows survive only when the answer contains a ")" (e.g. "A) ...") and the
    question lists lettered options. A 'dataset_type' column with the value
    "mcq_qa" is added to every surviving row.

    Args:
        ds (Dataset): MCQ dataset generated by the mmlu pipeline.
        is_mmlu_eval (bool, optional): when True, also reformat the result
            for lm-harness MMLU-style evaluation. Defaults to False.

    Returns:
        Dataset: filtered (and optionally MMLU-formatted) dataset.
    """
    cleaned = ds.filter(lambda rec: ")" in rec["mmlubench_answer"])
    cleaned = cleaned.filter(lambda rec: "A)" in rec["mmlubench_question"])
    cleaned = cleaned.add_column("dataset_type", ["mcq_qa"] * cleaned.num_rows)
    if is_mmlu_eval:
        return format_mmlu_style(cleaned)
    return cleaned
324
+
325
+
326
def extract_options(text: str) -> list[Any]:
    """Extract MCQ option texts from a question string.

    Matches every segment of the form "X) option text" where X is a capital
    letter, and returns just the captured option texts.

    Args:
        text (str): question including lettered choices.

    Returns:
        list[Any]: option strings, in order of appearance.
    """
    # Capture the text following a capital letter and a closing parenthesis.
    option_pattern = r"\b[A-Z]\) (.+)"
    return re.findall(option_pattern, text)
339
+
340
+
341
def format_mmlu_style(ds: Dataset) -> Dataset:
    """Format the dataset according to the lm-harness MMLU requirement.

    Produces 'answer' (class-encoded letter), 'choices' (four option texts),
    'question' (options stripped) and 'subject' columns, filtering out any
    row that does not yield exactly four choices and an A–D answer.

    Args:
        ds (Dataset): input dataset with 'mmlubench_answer',
            'mmlubench_question' and 'domain' columns.

    Returns:
        Dataset: formatted hf dataset.
    """

    def _letter_answer(rec):
        # Keep only the letter preceding ")", e.g. "B) ..." -> "B".
        return {"answer": rec["mmlubench_answer"][: rec["mmlubench_answer"].index(")")]}

    def _option_list(rec):
        return {"choices": extract_options(rec["mmlubench_question"])}

    def _bare_question(rec):
        # Drop the option list, which starts at the first "A)".
        return {
            "question": rec["mmlubench_question"][
                : rec["mmlubench_question"].index("A)")
            ].strip()
        }

    ds = ds.map(_letter_answer)
    ds = ds.map(_option_list)
    ds = ds.map(_bare_question)
    ds = ds.rename_columns({"domain": "subject"})
    # Keep only well-formed rows: four choices and a valid letter answer.
    ds = ds.filter(lambda rec: rec["choices"])
    ds = ds.filter(lambda rec: len(rec["choices"]) == 4)
    ds = ds.filter(lambda rec: rec["answer"] in ["A", "B", "C", "D"])
    return ds.class_encode_column("answer")
367
+
368
+
369
def create_mmlu_evaluation_dataset(generate_mcq_dataset: Dataset) -> Dataset:
    """Filter and format an SDG-generated MCQ dataset so it is compatible
    with lm-harness MMLU-style evaluation.

    Args:
        generate_mcq_dataset (Dataset): sdg generated mcq dataset.

    Returns:
        Dataset: MMLU-compatible MCQ dataset.
    """
    return post_process_mcq(generate_mcq_dataset, is_mmlu_eval=True)
379
+
380
+
381
def create_mmlu_evaluation_yaml(task_name, eval_data_file_path, yaml_file_path):
    """Write the lm-eval-harness task YAML used to evaluate knowledge with
    the MMLU-style metric.

    Args:
        task_name: name for the harness task entry.
        eval_data_file_path: path to the evaluation data file ('test' split).
        yaml_file_path: destination path for the generated YAML.
    """
    with open(yaml_file_path, "w", encoding="utf-8") as yaml_file:
        yaml.dump(
            {
                "task": task_name,
                "dataset_kwargs": {"data_files": {"test": eval_data_file_path}},
                "include": "_default_mmlu_pr_template_yaml",
                "group": "mmlu_pr",
            },
            yaml_file,
            default_flow_style=False,
        )