sdg-hub 0.1.0a2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +4 -0
- sdg_hub/_version.py +21 -0
- sdg_hub/blocks/__init__.py +6 -0
- sdg_hub/blocks/block.py +54 -0
- sdg_hub/blocks/filterblock.py +76 -0
- sdg_hub/blocks/iterblock.py +31 -0
- sdg_hub/blocks/llmblock.py +430 -0
- sdg_hub/blocks/rmblocks.py +194 -0
- sdg_hub/blocks/utilblocks.py +140 -0
- sdg_hub/configs/__init__.py +0 -0
- sdg_hub/configs/annotations/__init__.py +0 -0
- sdg_hub/configs/annotations/cot_reflection.yaml +34 -0
- sdg_hub/configs/annotations/detailed_description.yaml +10 -0
- sdg_hub/configs/annotations/detailed_description_icl.yaml +32 -0
- sdg_hub/configs/annotations/simple.yaml +10 -0
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +45 -0
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +35 -0
- sdg_hub/configs/knowledge/data_recipe/__init__.py +0 -0
- sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +3 -0
- sdg_hub/configs/knowledge/detailed_summary.yaml +17 -0
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +68 -0
- sdg_hub/configs/knowledge/evaluate_question.yaml +38 -0
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +85 -0
- sdg_hub/configs/knowledge/extractive_summary.yaml +17 -0
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +39 -0
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +56 -0
- sdg_hub/configs/knowledge/mcq_generation.yaml +83 -0
- sdg_hub/configs/knowledge/router.yaml +12 -0
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +34 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +40 -0
- sdg_hub/configs/skills/_A_.yaml +97 -0
- sdg_hub/configs/skills/_B_.yaml +36 -0
- sdg_hub/configs/skills/_C_.yaml +71 -0
- sdg_hub/configs/skills/_D_.yaml +85 -0
- sdg_hub/configs/skills/_E_.yaml +30 -0
- sdg_hub/configs/skills/_F_.yaml +45 -0
- sdg_hub/configs/skills/_G_.yaml +56 -0
- sdg_hub/configs/skills/_H_.yaml +80 -0
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +48 -0
- sdg_hub/configs/skills/annotation.yaml +36 -0
- sdg_hub/configs/skills/contexts.yaml +21 -0
- sdg_hub/configs/skills/critic.yaml +60 -0
- sdg_hub/configs/skills/data_recipe/__init__.py +0 -0
- sdg_hub/configs/skills/data_recipe/default_recipe.yaml +6 -0
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +44 -0
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +46 -0
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +54 -0
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +51 -0
- sdg_hub/configs/skills/freeform_questions.yaml +29 -0
- sdg_hub/configs/skills/freeform_responses.yaml +45 -0
- sdg_hub/configs/skills/grounded_questions.yaml +38 -0
- sdg_hub/configs/skills/grounded_responses.yaml +59 -0
- sdg_hub/configs/skills/judge.yaml +53 -0
- sdg_hub/configs/skills/planner.yaml +67 -0
- sdg_hub/configs/skills/respond.yaml +8 -0
- sdg_hub/configs/skills/revised_responder.yaml +78 -0
- sdg_hub/configs/skills/router.yaml +12 -0
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +27 -0
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +31 -0
- sdg_hub/flow.py +127 -0
- sdg_hub/flows/annotation/emotion/detailed_description.yaml +19 -0
- sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +19 -0
- sdg_hub/flows/annotation/emotion/simple.yaml +19 -0
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +13 -0
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +12 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +89 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +136 -0
- sdg_hub/flows/generation/skills/agentic_improve_skill.yaml +108 -0
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +80 -0
- sdg_hub/flows/generation/skills/synth_skills.yaml +59 -0
- sdg_hub/logger_config.py +20 -0
- sdg_hub/pipeline.py +66 -0
- sdg_hub/prompts.py +17 -0
- sdg_hub/py.typed +0 -0
- sdg_hub/registry.py +122 -0
- sdg_hub/sdg.py +164 -0
- sdg_hub/utils/__init__.py +5 -0
- sdg_hub/utils/chunking.py +73 -0
- sdg_hub/utils/datamixing.py +123 -0
- sdg_hub/utils/datautils.py +14 -0
- sdg_hub/utils/docprocessor.py +357 -0
- sdg_hub/utils/json.py +48 -0
- sdg_hub/utils/models.py +31 -0
- sdg_hub/utils/parse_and_convert.py +392 -0
- sdg_hub/utils/taxonomy.py +489 -0
- sdg_hub-0.1.0a2.dev0.dist-info/METADATA +154 -0
- sdg_hub-0.1.0a2.dev0.dist-info/RECORD +94 -0
- sdg_hub-0.1.0a2.dev0.dist-info/WHEEL +5 -0
- sdg_hub-0.1.0a2.dev0.dist-info/licenses/LICENSE +201 -0
- sdg_hub-0.1.0a2.dev0.dist-info/top_level.txt +1 -0
sdg_hub/flows/generation/skills/synth_skills.yaml
ADDED
@@ -0,0 +1,59 @@
- block_type: LLMBlock
  block_config:
    block_name: gen_questions
    config_path: configs/skills/freeform_questions.yaml
    model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
    output_cols:
      - question
    batch_kwargs:
      num_samples: 30
  drop_duplicates:
    - question
- block_type: LLMBlock
  block_config:
    block_name: eval_questions
    config_path: configs/skills/evaluate_freeform_questions.yaml
    model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
    output_cols:
      - evaluation
      - score
- block_type: FilterByValueBlock
  block_config:
    block_name: filter_questions
    filter_column: score
    filter_value: 1.0
    operation: operator.eq
    convert_dtype: float
    batch_kwargs:
      num_procs: 8
  drop_columns:
    - evaluation
    - score
    - num_samples
- block_type: LLMBlock
  block_config:
    block_name: gen_responses
    config_path: configs/skills/freeform_responses.yaml
    model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
    output_cols:
      - response
- block_type: LLMBlock
  block_config:
    block_name: evaluate_qa_pair
    config_path: configs/skills/evaluate_freeform_pair.yaml
    model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
    output_cols:
      - evaluation
      - score
- block_type: FilterByValueBlock
  block_config:
    block_name: filter_qa_pair
    filter_column: score
    filter_value: 2.0
    operation: operator.ge
    convert_dtype: float
    batch_kwargs:
      num_procs: 8
  drop_columns:
    - evaluation
    - score
sdg_hub/logger_config.py
ADDED
@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
import os
import logging

# Third Party
from rich.logging import RichHandler


def setup_logger(name):
    # Set up the logger
    log_level = os.getenv("LOG_LEVEL", "INFO")
    logging.basicConfig(
        level=log_level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[RichHandler()],
    )
    logger = logging.getLogger(name)
    return logger
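For reference, a minimal usage sketch of the helper above (not part of the packaged files); the LOG_LEVEL value is only an illustration of the environment variable the function reads:

import os

os.environ["LOG_LEVEL"] = "DEBUG"  # optional; defaults to INFO

from sdg_hub.logger_config import setup_logger

logger = setup_logger(__name__)
logger.info("rich-formatted log line")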
sdg_hub/pipeline.py
ADDED
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# Third Party
from datasets import Dataset

# Local
from .logger_config import setup_logger

logger = setup_logger(__name__)


class EmptyDatasetError(Exception):
    pass


class Pipeline:
    def __init__(self, chained_blocks: list) -> None:
        """
        Initialize the Pipeline class with a configuration dictionary.
        config_dict: the run config py or yaml loaded into a dictionary
        """
        # pipeline config is the run configuration that consists of the pipeline steps
        self.chained_blocks = chained_blocks

    def _drop_duplicates(self, dataset, cols):
        """
        Drop duplicates from the dataset based on the columns provided.
        """
        df = dataset.to_pandas()
        df = df.drop_duplicates(subset=cols).reset_index(drop=True)
        return Dataset.from_pandas(df)

    def generate(self, dataset) -> Dataset:
        """
        Generate the dataset by running the pipeline steps.
        dataset: the input dataset
        """
        for block_prop in self.chained_blocks:
            block_type = block_prop["block_type"]
            block_config = block_prop["block_config"]
            drop_columns = block_prop.get("drop_columns", [])
            gen_kwargs = block_prop.get("gen_kwargs", {})
            drop_duplicates_cols = block_prop.get("drop_duplicates", False)
            block = block_type(**block_config)

            logger.debug("------------------------------------\n")
            logger.debug("Running block: %s", block_config["block_name"])
            logger.debug("Input dataset: %s", dataset)

            dataset = block.generate(dataset, **gen_kwargs)

            if len(dataset) == 0:
                raise EmptyDatasetError(
                    f"Pipeline stopped: Empty dataset after running block: {block_config['block_name']}"
                )

            drop_columns_in_ds = [e for e in drop_columns if e in dataset.column_names]
            if drop_columns:
                dataset = dataset.remove_columns(drop_columns_in_ds)

            if drop_duplicates_cols:
                dataset = self._drop_duplicates(dataset, cols=drop_duplicates_cols)

            logger.debug("Output dataset: %s", dataset)
            logger.debug("------------------------------------\n\n")

        return dataset
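A minimal sketch of the contract Pipeline expects from each chained block: "block_type" must be a class whose constructor accepts the "block_config" keys and which exposes generate(dataset, **gen_kwargs). The UppercaseBlock below is a hypothetical toy class, not part of the package (the real LLMBlock and FilterByValueBlock live in sdg_hub.blocks, whose code is not shown in this section):

from datasets import Dataset

from sdg_hub.pipeline import Pipeline


class UppercaseBlock:
    # Toy block used only to illustrate the interface Pipeline drives.
    def __init__(self, block_name, column):
        self.block_name = block_name
        self.column = column

    def generate(self, dataset, **gen_kwargs):
        return dataset.map(lambda x: {self.column: x[self.column].upper()})


chained_blocks = [
    {
        "block_type": UppercaseBlock,
        "block_config": {"block_name": "shout", "column": "question"},
        "drop_duplicates": ["question"],
    }
]

ds = Dataset.from_dict({"question": ["what is sdg?", "what is sdg?"]})
out = Pipeline(chained_blocks).generate(ds)
print(out["question"])  # ['WHAT IS SDG?'] after de-duplication

In the flow YAMLs shown earlier, "block_type" is a string name; resolving those names to classes (via the BlockRegistry defined below) is handled outside Pipeline, in sdg_hub/flow.py, whose contents are not included in this section.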
sdg_hub/prompts.py
ADDED
@@ -0,0 +1,17 @@
# Local
from .registry import PromptRegistry


@PromptRegistry.register("blank")
def blank_chat_template():
    return """{{ messages }}"""


@PromptRegistry.register("instructlab")
def instructlab_chat_template():
    return """{% for message in messages %}{% if message['role'] == 'pretraining' %}{{ '<|pretrain|>' + message['content'] + '<|endoftext|>' + '<|/pretrain|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' + '\n' }}{% endif %}{% endfor %}"""


@PromptRegistry.register("mistralai")
def mistral_chat_template():
    return """{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content'] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n\n<s>\n{%- for message in loop_messages %}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n    {%- endif %}\n    {%- if message['role'] == 'user' %}\n        {%- if loop.first and system_message is defined %}\n            {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n        {%- else %}\n            {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n        {%- endif %}\n    {%- elif message['role'] == 'assistant' %}\n        {{- ' ' + message['content'] + '</s>'}}\n    {%- else %}\n        {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n    {%- endif %}\n{%- endfor %}\n"""
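A short sketch of rendering one of these registered templates (not part of the package). Importing sdg_hub.prompts runs the decorators that populate PromptRegistry, which is defined in sdg_hub/registry.py below; a plain string is wrapped as a single user message by render_template:

import sdg_hub.prompts  # noqa: F401  # registers "blank", "instructlab", "mistralai"
from sdg_hub.registry import PromptRegistry

prompt = PromptRegistry.render_template(
    "instructlab",
    "List three uses of synthetic data.",
    add_generation_prompt=True,
)
print(prompt)  # "<|user|>\nList three uses of synthetic data.\n<|assistant|>\n"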
sdg_hub/py.typed
ADDED
File without changes
sdg_hub/registry.py
ADDED
@@ -0,0 +1,122 @@
# Standard
from typing import Union, List, Dict

# Third Party
from jinja2 import Template

# Local
from .logger_config import setup_logger

logger = setup_logger(__name__)


class BlockRegistry:
    """Registry for block classes to avoid manual additions to block type map."""

    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, block_name: str):
        """
        Decorator to register a block class under a specified name.

        :param block_name: Name under which to register the block.
        """

        def decorator(block_class):
            cls._registry[block_name] = block_class
            logger.debug(
                f"Registered block '{block_name}' with class '{block_class.__name__}'"
            )
            return block_class

        return decorator

    @classmethod
    def get_registry(cls):
        """
        Retrieve the current registry map of block types.

        :return: Dictionary of registered block names and classes.
        """
        logger.debug("Fetching the block registry map.")
        return cls._registry


class PromptRegistry:
    """Registry for managing Jinja2 prompt templates."""

    _registry: Dict[str, Template] = {}

    @classmethod
    def register(cls, name: str):
        """Decorator to register a Jinja2 template function by name.

        :param name: Name of the template to register.
        :return: A decorator that registers the Jinja2 template function.
        """

        def decorator(func):
            template_str = func()
            cls._registry[name] = Template(template_str)
            logger.debug(f"Registered prompt template '{name}'")
            return func

        return decorator

    @classmethod
    def get_template(cls, name: str) -> Template:
        """Retrieve a Jinja2 template by name.

        :param name: Name of the template to retrieve.
        :return: The Jinja2 template instance.
        """
        if name not in cls._registry:
            raise KeyError(f"Template '{name}' not found.")
        logger.debug(f"Retrieving prompt template '{name}'")
        return cls._registry[name]

    @classmethod
    def get_registry(cls):
        """
        Retrieve the current registry map of block types.

        :return: Dictionary of registered block names and classes.
        """
        logger.debug("Fetching the block registry map.")
        return cls._registry

    @classmethod
    def render_template(
        cls,
        name: str,
        messages: Union[str, List[Dict[str, str]]],
        add_generation_prompt: bool = True,
    ) -> str:
        """Render the template with the provided messages or query.

        :param name: Name of the template to render.
        :param messages: Either a single query string or a list of messages (each as a dict with 'role' and 'content').
        :param add_generation_prompt: Whether to add a generation prompt at the end.
        :return: The rendered prompt as a string.
        """

        # Special handling for "blank" template
        if name == "blank":
            if not isinstance(messages, str):
                raise ValueError(
                    "The 'blank' template can only be used with a single query string, not a list of messages."
                )
            return messages  # Return the query as-is without templating

        # Get the template
        template = cls.get_template(name)

        # If `messages` is a string, wrap it in a list with a default user role
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # Render the template with the `messages` list
        return template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
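A minimal sketch of registering a custom block with the decorator above (the EchoBlock class is hypothetical, used only to show the mechanism; it is not part of the package):

from sdg_hub.registry import BlockRegistry


@BlockRegistry.register("EchoBlock")
class EchoBlock:
    """Toy block used only to illustrate the registry decorator."""

    def __init__(self, block_name):
        self.block_name = block_name

    def generate(self, dataset, **kwargs):
        return dataset


assert BlockRegistry.get_registry()["EchoBlock"] is EchoBlock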
sdg_hub/sdg.py
ADDED
@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List
import traceback
import uuid

# Third Party
from datasets import Dataset, load_dataset
from datasets.data_files import EmptyDatasetError
from tqdm import tqdm

# Local
from .logger_config import setup_logger
from .pipeline import Pipeline
from .utils.datautils import safe_concatenate_datasets


logger = setup_logger(__name__)


class SDG:
    def __init__(
        self, pipelines: List[Pipeline], num_workers=1, batch_size=None, save_freq=None
    ) -> None:
        self.pipelines = pipelines
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.save_freq = save_freq

    def _split_dataset(self, dataset: Dataset, batch_size: int) -> List[Dataset]:
        """Split the dataset into smaller batches."""
        total_size = len(dataset)
        num_batches = (total_size + batch_size - 1) // batch_size

        batches = [
            (i * batch_size, min((i + 1) * batch_size, total_size))
            for i in tqdm(range(num_batches))
        ]

        return batches

    def _get_missing_data(self, seed_data, generated_data):
        # Get the common columns between the two datasets
        common_columns = list(
            set(seed_data.column_names) & set(generated_data.column_names)
        )

        # Extract the relevant data based on common columns
        seed_data_common = seed_data.select_columns(common_columns)
        generated_data_common = generated_data.select_columns(common_columns)

        # Convert to Pandas DataFrames for easier comparison
        seed_df = seed_data_common.to_pandas()
        generated_df = generated_data_common.to_pandas()

        # Identify missing rows
        missing_df = seed_df[
            ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))
        ]

        # Convert back to Dataset
        missing_data = Dataset.from_pandas(missing_df, preserve_index=False)

        return missing_data

    def _save_intermediate_checkpoint(self, dataset, checkpoint_dir):
        checkpoint_id = uuid.uuid4().hex
        checkpoint_file = f"{checkpoint_dir}/data_checkpoint_{checkpoint_id}.jsonl"
        logger.info(f"Saving checkpoint to {checkpoint_file}")
        dataset.to_json(checkpoint_file, orient="records", lines=True)

    @staticmethod
    def _generate_data(pipelines, input_split, ds, i=None):
        logger.info(f"Processing split {i}")
        input_split = ds.select(range(input_split[0], input_split[1]))
        try:
            for pipeline in pipelines:
                input_split = pipeline.generate(input_split)
            return input_split
        except Exception as e:
            logger.error(f"Error processing split {i}: {e}")
            traceback.print_exc()
            return None

    def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
        # check if checkpoint_dir exists
        pre_generated_data = []
        if checkpoint_dir is not None:
            try:
                # check if there are any existing checkpoints
                pre_generated_data = load_dataset(
                    "json", data_dir=checkpoint_dir, split="train"
                )
                logger.info(
                    f"Loading existing checkpoints from {checkpoint_dir}, with {pre_generated_data.num_rows} rows"
                )
                seed_data = self._get_missing_data(dataset, pre_generated_data)
                if seed_data.num_rows == 0:
                    logger.info(
                        f"All seed data has been generated, no missing rows found, returning data from {checkpoint_dir}"
                    )
                    return pre_generated_data
                logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")

            except EmptyDatasetError:
                logger.info(
                    f"No existing checkpoints found in {checkpoint_dir}, generating from scratch"
                )
                seed_data = dataset

        else:
            seed_data = dataset

        if not self.batch_size:
            # If batch size is not provided, generate the dataset in a single pass
            generated_dataset = seed_data
            # generated_data is initialized with seed_data, and it gets updated with each pipeline
            for pipeline in self.pipelines:
                generated_dataset = pipeline.generate(seed_data)
            return generated_dataset

        logger.info("Splitting the dataset into smaller batches")
        input_splits = (
            self._split_dataset(seed_data, self.batch_size)
            if self.batch_size
            else [seed_data]
        )
        logger.info(
            f"Generating dataset with {len(input_splits)} splits, batch size {self.batch_size}, and {self.num_workers} workers"
        )

        generated_data = [pre_generated_data] if pre_generated_data else []
        last_saved_split_index = 0  # To track the last saved split

        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            futures = [
                executor.submit(
                    self._generate_data, self.pipelines, input_split, seed_data, i
                )
                for i, input_split in enumerate(input_splits)
            ]

            for i, future in enumerate(tqdm(as_completed(futures), total=len(futures))):
                generated_data_split = future.result()  # Ensure each future completes

                if generated_data_split:
                    generated_data.append(generated_data_split)
                    logger.info(f"Finished future processing split {i} \n\n")
                    if self.save_freq and (i + 1) % self.save_freq == 0:
                        # Save only the new splits since the last checkpoint
                        new_splits = generated_data[last_saved_split_index : i + 1]
                        checkpoint_dataset = safe_concatenate_datasets(new_splits)
                        # check if checkpoint_dataset is not None
                        if checkpoint_dataset:
                            self._save_intermediate_checkpoint(
                                checkpoint_dataset, checkpoint_dir
                            )

                        last_saved_split_index = i + 1

        generated_dataset = safe_concatenate_datasets(generated_data)

        return generated_dataset
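A usage sketch for SDG (not part of the package). It assumes `pipeline` is a Pipeline built as in the earlier sketch and `seed_dataset` is a datasets.Dataset with the columns the blocks expect; checkpoints are written as JSONL files into the given directory every save_freq completed batches:

from sdg_hub.sdg import SDG

sdg = SDG(
    pipelines=[pipeline],   # one or more Pipeline objects, run in order on each batch
    num_workers=4,          # batches are processed in a ThreadPoolExecutor
    batch_size=8,           # omit (None) to run the whole dataset in a single pass
    save_freq=2,            # checkpoint every 2 completed batches
)
generated = sdg.generate(seed_dataset, checkpoint_dir="checkpoints/skills")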
sdg_hub/utils/chunking.py
ADDED
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from typing import List
import logging
import re

# Third Party
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

_DEFAULT_CHUNK_OVERLAP = 100

logger = logging.getLogger(__name__)


def _num_tokens_from_words(num_words) -> int:
    return int(num_words * 1.3)  # 1 word ~ 1.3 token


def _num_chars_from_tokens(num_tokens) -> int:
    return int(num_tokens * 4)  # 1 token ~ 4 English character


def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]:
    """
    Iterates over the documents and splits them into chunks based on the word count provided by the user.
    Args:
        documents (list): List of documents retrieved from git (can also consist of a single document).
        server_ctx_size (int): Context window size of server.
        chunk_word_count (int): Maximum number of words to chunk a document.
    Returns:
        List[str]: List of chunked documents.
    """

    # Checks for input type error
    if isinstance(documents, str):
        documents = [documents]

    elif not isinstance(documents, list):
        raise TypeError(
            "Expected: documents to be a list, but got {}".format(type(documents))
        )

    no_tokens_per_doc = _num_tokens_from_words(chunk_word_count)
    if no_tokens_per_doc > int(server_ctx_size - 1024):
        raise ValueError(
            "Error: {}".format(
                str(
                    f"Given word count ({chunk_word_count}) per doc will exceed the server context window size ({server_ctx_size})"
                )
            )
        )
    # Placeholder for params
    content = []
    chunk_size = _num_chars_from_tokens(no_tokens_per_doc)
    chunk_overlap = _DEFAULT_CHUNK_OVERLAP

    # Using Markdown as default, document-specific chunking will be implemented in seperate pr.
    text_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.MARKDOWN,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    # Determine file type for heuristics, default with markdown
    for docs in documents:
        # Use regex to remove unnecessary dashes in front of pipe characters in a markdown table.
        docs = re.sub(r"-{2,}\|", "-|", docs)
        # Remove unnecessary spaces in front of pipe characters in a markdown table.
        docs = re.sub(r"\ +\|", " |", docs)
        temp = text_splitter.create_documents([docs])
        content.extend([item.page_content for item in temp])
    return content
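A small sketch of calling the chunker above (not part of the package). With chunk_word_count=1000 the helper targets roughly 1300 tokens (about 5200 characters) per chunk, and the word budget must leave 1024 tokens of headroom inside server_ctx_size; the documents here are placeholders:

from sdg_hub.utils.chunking import chunk_document

docs = ["# Title\n\nSome long markdown document ...", "Another document ..."]
chunks = chunk_document(
    documents=docs,
    server_ctx_size=4096,
    chunk_word_count=1000,
)
print(len(chunks), chunks[0][:80])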
sdg_hub/utils/datamixing.py
ADDED
@@ -0,0 +1,123 @@
# Standard
import json
import os

# Third Party
from datasets import Dataset, load_dataset
import yaml

# First Party
from sdg_hub.logger_config import setup_logger
from .datautils import safe_concatenate_datasets


LOGGER = setup_logger(__name__)
ALLOWED_COLS = ["id", "messages", "metadata"]


def adjust_train_sample_size(ds: Dataset, num_samples: int):
    LOGGER.info(f"Rebalancing dataset to have {num_samples} samples ...")
    df = ds.to_pandas()
    df = df.sample(n=num_samples, random_state=42, replace=True).reset_index(drop=True)
    return Dataset.from_pandas(df)


def load_ds(path, sampling_size):
    if path.endswith(".jsonl"):
        LOGGER.info(f"Loading dataset from {path} ...")
        dataset = load_dataset("json", data_files=path, split="train")
    else:
        LOGGER.info(f"Loading dataset from HF {path} ...")
        dataset = load_dataset(path, split="train")
    LOGGER.info(f"Dataset columns: {dataset.column_names}")
    LOGGER.info(f"Dataset loaded with {len(dataset)} samples")

    if sampling_size != 1.0:
        if isinstance(sampling_size, int):
            num_samples = sampling_size
        else:
            num_samples = int(len(dataset) * sampling_size)
        dataset = adjust_train_sample_size(dataset, num_samples)

    # move any column that is not in ALLOWED_COLS to metadata
    def move_unallowed_cols_to_metadata(example):
        metadata = example.get("metadata", {})
        if isinstance(metadata, str):
            metadata = json.loads(metadata)
        for col in dataset.column_names:
            if col not in ALLOWED_COLS:
                metadata[col] = example[col]
                example.pop(col)
        example["metadata"] = json.dumps(metadata)
        return example

    dataset = dataset.map(move_unallowed_cols_to_metadata, num_proc=8)

    # check if metadata column is string if not convert it using json.dumps
    if not isinstance(dataset["metadata"][0], str):
        dataset = dataset.map(
            lambda x: {"metadata": json.dumps(x["metadata"])}, num_proc=8
        )

    return dataset


def add_system_message(sample: dict, sys_prompt: str) -> dict:
    # check if the messages have role system
    has_system = False
    for msg in sample["messages"]:
        if msg["role"] == "system":
            has_system = True
            msg["content"] = sys_prompt

    if not has_system:
        sample["messages"].insert(0, {"role": "system", "content": sys_prompt})

    return sample


class Recipe:
    def __init__(self, recipe_path):
        self.recipe_path = recipe_path
        self.recipe = self._load_recipe()
        self.sys_prompt = self.recipe.get("sys_prompt", "")
        self.dataset_added = False

    def _load_recipe(self):
        with open(self.recipe_path, encoding="utf-8") as fp:
            return yaml.safe_load(fp)

    def add_dataset(self, path, sampling_size=1.0):
        self.dataset_added = True
        self.recipe["datasets"].append({"path": path, "sampling_size": sampling_size})

    def save_recipe(self, output_path):
        # check if directory exists
        output_dir = os.path.dirname(output_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        with open(output_path, "w", encoding="utf-8") as fp:
            yaml.dump(self.recipe, fp)

    def save_mixed_dataset(self, output_path):
        if not self.dataset_added:
            LOGGER.error("No dataset added to the recipe")

        mixed_ds = [
            load_ds(dataset["path"], dataset["sampling_size"])
            for dataset in self.recipe["datasets"]
        ]

        mixed_ds = safe_concatenate_datasets(mixed_ds)
        mixed_ds = mixed_ds.map(
            add_system_message, fn_kwargs={"sys_prompt": self.sys_prompt}, num_proc=8
        )

        # assert that the dataset only has the allowed columns
        assert set(mixed_ds.column_names) == set(
            ALLOWED_COLS
        ), "Dataset has invalid columns"

        mixed_ds.to_json(output_path, orient="records", lines=True)
        LOGGER.info(f"Mixed Dataset saved to {output_path}")
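A sketch of the Recipe mixing helper above (not part of the package). The recipe YAML layout is inferred from the code: at minimum a datasets: list and an optional sys_prompt: key. my_recipe.yaml and the JSONL path are hypothetical placeholders:

from sdg_hub.utils.datamixing import Recipe

# my_recipe.yaml (assumed contents):
#   sys_prompt: "You are a helpful assistant."
#   datasets: []
recipe = Recipe("my_recipe.yaml")
recipe.add_dataset("generated/skills_train_msgs.jsonl", sampling_size=0.5)
recipe.save_recipe("output/mixed_recipe.yaml")
recipe.save_mixed_dataset("output/mixed_data.jsonl")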
sdg_hub/utils/datautils.py
ADDED
@@ -0,0 +1,14 @@
# Third Party
from datasets import concatenate_datasets


def safe_concatenate_datasets(datasets: list):
    """
    Concatenate datasets safely, ignoring any datasets that are None or empty.
    """
    filtered_datasets = [ds for ds in datasets if ds is not None and ds.num_rows > 0]

    if not filtered_datasets:
        return None

    return concatenate_datasets(filtered_datasets)
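A final sketch for the helper above (not part of the package), showing that None and zero-row entries are skipped:

from datasets import Dataset

from sdg_hub.utils.datautils import safe_concatenate_datasets

parts = [
    Dataset.from_dict({"question": ["q1"]}),
    None,                                   # skipped
    Dataset.from_dict({"question": []}),    # empty, skipped
    Dataset.from_dict({"question": ["q2"]}),
]
merged = safe_concatenate_datasets(parts)
print(merged["question"])  # ['q1', 'q2']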