sdg-hub 0.1.0a2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. sdg_hub/__init__.py +4 -0
  2. sdg_hub/_version.py +21 -0
  3. sdg_hub/blocks/__init__.py +6 -0
  4. sdg_hub/blocks/block.py +54 -0
  5. sdg_hub/blocks/filterblock.py +76 -0
  6. sdg_hub/blocks/iterblock.py +31 -0
  7. sdg_hub/blocks/llmblock.py +430 -0
  8. sdg_hub/blocks/rmblocks.py +194 -0
  9. sdg_hub/blocks/utilblocks.py +140 -0
  10. sdg_hub/configs/__init__.py +0 -0
  11. sdg_hub/configs/annotations/__init__.py +0 -0
  12. sdg_hub/configs/annotations/cot_reflection.yaml +34 -0
  13. sdg_hub/configs/annotations/detailed_description.yaml +10 -0
  14. sdg_hub/configs/annotations/detailed_description_icl.yaml +32 -0
  15. sdg_hub/configs/annotations/simple.yaml +10 -0
  16. sdg_hub/configs/knowledge/__init__.py +0 -0
  17. sdg_hub/configs/knowledge/atomic_facts.yaml +45 -0
  18. sdg_hub/configs/knowledge/auxilary_instructions.yaml +35 -0
  19. sdg_hub/configs/knowledge/data_recipe/__init__.py +0 -0
  20. sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +3 -0
  21. sdg_hub/configs/knowledge/detailed_summary.yaml +17 -0
  22. sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +68 -0
  23. sdg_hub/configs/knowledge/evaluate_question.yaml +38 -0
  24. sdg_hub/configs/knowledge/evaluate_relevancy.yaml +85 -0
  25. sdg_hub/configs/knowledge/extractive_summary.yaml +17 -0
  26. sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +39 -0
  27. sdg_hub/configs/knowledge/generate_questions_responses.yaml +56 -0
  28. sdg_hub/configs/knowledge/mcq_generation.yaml +83 -0
  29. sdg_hub/configs/knowledge/router.yaml +12 -0
  30. sdg_hub/configs/knowledge/simple_generate_qa.yaml +34 -0
  31. sdg_hub/configs/reasoning/dynamic_cot.yaml +40 -0
  32. sdg_hub/configs/skills/_A_.yaml +97 -0
  33. sdg_hub/configs/skills/_B_.yaml +36 -0
  34. sdg_hub/configs/skills/_C_.yaml +71 -0
  35. sdg_hub/configs/skills/_D_.yaml +85 -0
  36. sdg_hub/configs/skills/_E_.yaml +30 -0
  37. sdg_hub/configs/skills/_F_.yaml +45 -0
  38. sdg_hub/configs/skills/_G_.yaml +56 -0
  39. sdg_hub/configs/skills/_H_.yaml +80 -0
  40. sdg_hub/configs/skills/__init__.py +0 -0
  41. sdg_hub/configs/skills/analyzer.yaml +48 -0
  42. sdg_hub/configs/skills/annotation.yaml +36 -0
  43. sdg_hub/configs/skills/contexts.yaml +21 -0
  44. sdg_hub/configs/skills/critic.yaml +60 -0
  45. sdg_hub/configs/skills/data_recipe/__init__.py +0 -0
  46. sdg_hub/configs/skills/data_recipe/default_recipe.yaml +6 -0
  47. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +44 -0
  48. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +46 -0
  49. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +54 -0
  50. sdg_hub/configs/skills/evaluate_grounded_questions.yaml +51 -0
  51. sdg_hub/configs/skills/freeform_questions.yaml +29 -0
  52. sdg_hub/configs/skills/freeform_responses.yaml +45 -0
  53. sdg_hub/configs/skills/grounded_questions.yaml +38 -0
  54. sdg_hub/configs/skills/grounded_responses.yaml +59 -0
  55. sdg_hub/configs/skills/judge.yaml +53 -0
  56. sdg_hub/configs/skills/planner.yaml +67 -0
  57. sdg_hub/configs/skills/respond.yaml +8 -0
  58. sdg_hub/configs/skills/revised_responder.yaml +78 -0
  59. sdg_hub/configs/skills/router.yaml +12 -0
  60. sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +27 -0
  61. sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +31 -0
  62. sdg_hub/flow.py +127 -0
  63. sdg_hub/flows/annotation/emotion/detailed_description.yaml +19 -0
  64. sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +19 -0
  65. sdg_hub/flows/annotation/emotion/simple.yaml +19 -0
  66. sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +13 -0
  67. sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +12 -0
  68. sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +89 -0
  69. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +136 -0
  70. sdg_hub/flows/generation/skills/agentic_improve_skill.yaml +108 -0
  71. sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +12 -0
  72. sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +12 -0
  73. sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +80 -0
  74. sdg_hub/flows/generation/skills/synth_skills.yaml +59 -0
  75. sdg_hub/logger_config.py +20 -0
  76. sdg_hub/pipeline.py +66 -0
  77. sdg_hub/prompts.py +17 -0
  78. sdg_hub/py.typed +0 -0
  79. sdg_hub/registry.py +122 -0
  80. sdg_hub/sdg.py +164 -0
  81. sdg_hub/utils/__init__.py +5 -0
  82. sdg_hub/utils/chunking.py +73 -0
  83. sdg_hub/utils/datamixing.py +123 -0
  84. sdg_hub/utils/datautils.py +14 -0
  85. sdg_hub/utils/docprocessor.py +357 -0
  86. sdg_hub/utils/json.py +48 -0
  87. sdg_hub/utils/models.py +31 -0
  88. sdg_hub/utils/parse_and_convert.py +392 -0
  89. sdg_hub/utils/taxonomy.py +489 -0
  90. sdg_hub-0.1.0a2.dev0.dist-info/METADATA +154 -0
  91. sdg_hub-0.1.0a2.dev0.dist-info/RECORD +94 -0
  92. sdg_hub-0.1.0a2.dev0.dist-info/WHEEL +5 -0
  93. sdg_hub-0.1.0a2.dev0.dist-info/licenses/LICENSE +201 -0
  94. sdg_hub-0.1.0a2.dev0.dist-info/top_level.txt +1 -0
sdg_hub/flows/generation/skills/synth_skills.yaml ADDED
@@ -0,0 +1,59 @@
+ - block_type: LLMBlock
+   block_config:
+     block_name: gen_questions
+     config_path: configs/skills/freeform_questions.yaml
+     model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
+     output_cols:
+       - question
+     batch_kwargs:
+       num_samples: 30
+   drop_duplicates:
+     - question
+ - block_type: LLMBlock
+   block_config:
+     block_name: eval_questions
+     config_path: configs/skills/evaluate_freeform_questions.yaml
+     model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
+     output_cols:
+       - evaluation
+       - score
+ - block_type: FilterByValueBlock
+   block_config:
+     block_name: filter_questions
+     filter_column: score
+     filter_value: 1.0
+     operation: operator.eq
+     convert_dtype: float
+     batch_kwargs:
+       num_procs: 8
+   drop_columns:
+     - evaluation
+     - score
+     - num_samples
+ - block_type: LLMBlock
+   block_config:
+     block_name: gen_responses
+     config_path: configs/skills/freeform_responses.yaml
+     model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
+     output_cols:
+       - response
+ - block_type: LLMBlock
+   block_config:
+     block_name: evaluate_qa_pair
+     config_path: configs/skills/evaluate_freeform_pair.yaml
+     model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
+     output_cols:
+       - evaluation
+       - score
+ - block_type: FilterByValueBlock
+   block_config:
+     block_name: filter_qa_pair
+     filter_column: score
+     filter_value: 2.0
+     operation: operator.ge
+     convert_dtype: float
+     batch_kwargs:
+       num_procs: 8
+   drop_columns:
+     - evaluation
+     - score
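This flow chains free-form question generation, LLM-based question evaluation, score filtering, response generation, and question-answer pair evaluation. As a rough sketch of how the block_type names map onto classes in this wheel, the snippet below resolves them through the BlockRegistry defined in sdg_hub/registry.py. It assumes the block modules register themselves under these names on import, and it omits everything a real flow runner (sdg_hub/flow.py) adds, such as the LLM client, config_path resolution, and turning "operator.eq" into a callable:

    # Hypothetical inspection script; not part of the package.
    import yaml
    import sdg_hub.blocks  # assumed to register LLMBlock, FilterByValueBlock, etc. on import
    from sdg_hub.registry import BlockRegistry

    with open("sdg_hub/flows/generation/skills/synth_skills.yaml", encoding="utf-8") as fp:
        flow = yaml.safe_load(fp)

    registry = BlockRegistry.get_registry()
    for block in flow:
        block_cls = registry.get(block["block_type"])  # e.g. "LLMBlock" -> the LLMBlock class
        print(block["block_config"]["block_name"], "->", block_cls)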
sdg_hub/logger_config.py ADDED
@@ -0,0 +1,20 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # Standard
+ import os
+ import logging
+
+ # Third Party
+ from rich.logging import RichHandler
+
+
+ def setup_logger(name):
+     # Set up the logger
+     log_level = os.getenv("LOG_LEVEL", "INFO")
+     logging.basicConfig(
+         level=log_level,
+         format="%(message)s",
+         datefmt="[%X]",
+         handlers=[RichHandler()],
+     )
+     logger = logging.getLogger(name)
+     return logger
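setup_logger configures the root logger once via logging.basicConfig with a RichHandler and returns a named logger; the level comes from the LOG_LEVEL environment variable (default INFO). A minimal usage sketch:

    # Hypothetical usage; the LOG_LEVEL value is just for illustration.
    import os
    os.environ.setdefault("LOG_LEVEL", "DEBUG")

    from sdg_hub.logger_config import setup_logger

    logger = setup_logger(__name__)
    logger.debug("rich-formatted debug output")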
sdg_hub/pipeline.py ADDED
@@ -0,0 +1,66 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # Third Party
+ from datasets import Dataset
+
+ # Local
+ from .logger_config import setup_logger
+
+ logger = setup_logger(__name__)
+
+
+ class EmptyDatasetError(Exception):
+     pass
+
+
+ class Pipeline:
+     def __init__(self, chained_blocks: list) -> None:
+         """
+         Initialize the Pipeline class with a list of chained block definitions.
+         chained_blocks: the run config (Python or YAML) loaded into a list of block dictionaries
+         """
+         # pipeline config is the run configuration that consists of the pipeline steps
+         self.chained_blocks = chained_blocks
+
+     def _drop_duplicates(self, dataset, cols):
+         """
+         Drop duplicates from the dataset based on the columns provided.
+         """
+         df = dataset.to_pandas()
+         df = df.drop_duplicates(subset=cols).reset_index(drop=True)
+         return Dataset.from_pandas(df)
+
+     def generate(self, dataset) -> Dataset:
+         """
+         Generate the dataset by running the pipeline steps.
+         dataset: the input dataset
+         """
+         for block_prop in self.chained_blocks:
+             block_type = block_prop["block_type"]
+             block_config = block_prop["block_config"]
+             drop_columns = block_prop.get("drop_columns", [])
+             gen_kwargs = block_prop.get("gen_kwargs", {})
+             drop_duplicates_cols = block_prop.get("drop_duplicates", False)
+             block = block_type(**block_config)
+
+             logger.debug("------------------------------------\n")
+             logger.debug("Running block: %s", block_config["block_name"])
+             logger.debug("Input dataset: %s", dataset)
+
+             dataset = block.generate(dataset, **gen_kwargs)
+
+             if len(dataset) == 0:
+                 raise EmptyDatasetError(
+                     f"Pipeline stopped: Empty dataset after running block: {block_config['block_name']}"
+                 )
+
+             drop_columns_in_ds = [e for e in drop_columns if e in dataset.column_names]
+             if drop_columns:
+                 dataset = dataset.remove_columns(drop_columns_in_ds)
+
+             if drop_duplicates_cols:
+                 dataset = self._drop_duplicates(dataset, cols=drop_duplicates_cols)
+
+             logger.debug("Output dataset: %s", dataset)
+             logger.debug("------------------------------------\n\n")
+
+         return dataset
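Pipeline.generate threads a single datasets.Dataset through the configured blocks, forwarding gen_kwargs, dropping listed columns, and de-duplicating between steps. The sketch below drives it with a made-up block class; the real blocks under sdg_hub/blocks/ subclass a Block base whose constructor is not shown in this diff:

    # Hypothetical block and data; names and logic are illustrative only.
    from datasets import Dataset
    from sdg_hub.pipeline import Pipeline


    class AddGreetingBlock:  # stand-in for an LLMBlock-style block
        def __init__(self, block_name: str) -> None:
            self.block_name = block_name

        def generate(self, dataset: Dataset, prefix: str = "Hello") -> Dataset:
            # Pipeline forwards gen_kwargs to generate().
            return dataset.map(lambda x: {"greeting": f"{prefix}, {x['name']}!"})


    chained_blocks = [
        {
            "block_type": AddGreetingBlock,
            "block_config": {"block_name": "add_greeting"},
            "gen_kwargs": {"prefix": "Hi"},
            "drop_duplicates": ["greeting"],
        }
    ]

    ds = Dataset.from_list([{"name": "Ada"}, {"name": "Ada"}, {"name": "Grace"}])
    out = Pipeline(chained_blocks).generate(ds)
    print(out.to_list())  # duplicates on "greeting" are dropped -> 2 rows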
sdg_hub/prompts.py ADDED
@@ -0,0 +1,17 @@
+ # Local
+ from .registry import PromptRegistry
+
+
+ @PromptRegistry.register("blank")
+ def blank_chat_template():
+     return """{{ messages }}"""
+
+
+ @PromptRegistry.register("instructlab")
+ def instructlab_chat_template():
+     return """{% for message in messages %}{% if message['role'] == 'pretraining' %}{{ '<|pretrain|>' + message['content'] + '<|endoftext|>' + '<|/pretrain|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' + '\n' }}{% endif %}{% endfor %}"""
+
+
+ @PromptRegistry.register("mistralai")
+ def mistral_chat_template():
+     return """{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n<s>\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + '</s>'}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n"""
sdg_hub/py.typed ADDED
File without changes
sdg_hub/registry.py ADDED
@@ -0,0 +1,122 @@
+ # Standard
+ from typing import Union, List, Dict
+
+ # Third Party
+ from jinja2 import Template
+
+ # Local
+ from .logger_config import setup_logger
+
+ logger = setup_logger(__name__)
+
+
+ class BlockRegistry:
+     """Registry for block classes to avoid manual additions to block type map."""
+
+     _registry: Dict[str, type] = {}
+
+     @classmethod
+     def register(cls, block_name: str):
+         """
+         Decorator to register a block class under a specified name.
+
+         :param block_name: Name under which to register the block.
+         """
+
+         def decorator(block_class):
+             cls._registry[block_name] = block_class
+             logger.debug(
+                 f"Registered block '{block_name}' with class '{block_class.__name__}'"
+             )
+             return block_class
+
+         return decorator
+
+     @classmethod
+     def get_registry(cls):
+         """
+         Retrieve the current registry map of block types.
+
+         :return: Dictionary of registered block names and classes.
+         """
+         logger.debug("Fetching the block registry map.")
+         return cls._registry
+
+
+ class PromptRegistry:
+     """Registry for managing Jinja2 prompt templates."""
+
+     _registry: Dict[str, Template] = {}
+
+     @classmethod
+     def register(cls, name: str):
+         """Decorator to register a Jinja2 template function by name.
+
+         :param name: Name of the template to register.
+         :return: A decorator that registers the Jinja2 template function.
+         """
+
+         def decorator(func):
+             template_str = func()
+             cls._registry[name] = Template(template_str)
+             logger.debug(f"Registered prompt template '{name}'")
+             return func
+
+         return decorator
+
+     @classmethod
+     def get_template(cls, name: str) -> Template:
+         """Retrieve a Jinja2 template by name.
+
+         :param name: Name of the template to retrieve.
+         :return: The Jinja2 template instance.
+         """
+         if name not in cls._registry:
+             raise KeyError(f"Template '{name}' not found.")
+         logger.debug(f"Retrieving prompt template '{name}'")
+         return cls._registry[name]
+
+     @classmethod
+     def get_registry(cls):
+         """
+         Retrieve the current registry map of prompt templates.
+
+         :return: Dictionary of registered template names and Jinja2 templates.
+         """
+         logger.debug("Fetching the prompt registry map.")
+         return cls._registry
+
+     @classmethod
+     def render_template(
+         cls,
+         name: str,
+         messages: Union[str, List[Dict[str, str]]],
+         add_generation_prompt: bool = True,
+     ) -> str:
+         """Render the template with the provided messages or query.
+
+         :param name: Name of the template to render.
+         :param messages: Either a single query string or a list of messages (each as a dict with 'role' and 'content').
+         :param add_generation_prompt: Whether to add a generation prompt at the end.
+         :return: The rendered prompt as a string.
+         """
+
+         # Special handling for "blank" template
+         if name == "blank":
+             if not isinstance(messages, str):
+                 raise ValueError(
+                     "The 'blank' template can only be used with a single query string, not a list of messages."
+                 )
+             return messages  # Return the query as-is without templating
+
+         # Get the template
+         template = cls.get_template(name)
+
+         # If `messages` is a string, wrap it in a list with a default user role
+         if isinstance(messages, str):
+             messages = [{"role": "user", "content": messages}]
+
+         # Render the template with the `messages` list
+         return template.render(
+             messages=messages, add_generation_prompt=add_generation_prompt
+         )
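BlockRegistry maps the block_type strings used in flow YAMLs to classes, and PromptRegistry turns template-returning functions into Jinja2 templates. A sketch of registering a hypothetical block and prompt (names and logic are made up; real blocks subclass the Block base class in sdg_hub/blocks/block.py, which is not shown in this diff):

    # Hypothetical registrations; "WordCountBlock" and "plain" are invented names.
    from datasets import Dataset
    from sdg_hub.registry import BlockRegistry, PromptRegistry


    @BlockRegistry.register("WordCountBlock")
    class WordCountBlock:
        def __init__(self, block_name: str, column: str) -> None:
            self.block_name = block_name
            self.column = column

        def generate(self, dataset: Dataset, **gen_kwargs) -> Dataset:
            # Pipeline calls generate(dataset, **gen_kwargs) on every block.
            return dataset.map(lambda x: {"word_count": len(x[self.column].split())})


    @PromptRegistry.register("plain")
    def plain_chat_template():
        return "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"


    assert BlockRegistry.get_registry()["WordCountBlock"] is WordCountBlock
    print(PromptRegistry.render_template("plain", "How are leap years determined?"))
    # -> "user: How are leap years determined?\n"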
sdg_hub/sdg.py ADDED
@@ -0,0 +1,164 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # Standard
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from typing import List
+ import traceback
+ import uuid
+
+ # Third Party
+ from datasets import Dataset, load_dataset
+ from datasets.data_files import EmptyDatasetError
+ from tqdm import tqdm
+
+ # Local
+ from .logger_config import setup_logger
+ from .pipeline import Pipeline
+ from .utils.datautils import safe_concatenate_datasets
+
+
+ logger = setup_logger(__name__)
+
+
+ class SDG:
+     def __init__(
+         self, pipelines: List[Pipeline], num_workers=1, batch_size=None, save_freq=None
+     ) -> None:
+         self.pipelines = pipelines
+         self.num_workers = num_workers
+         self.batch_size = batch_size
+         self.save_freq = save_freq
+
+     def _split_dataset(self, dataset: Dataset, batch_size: int) -> List[Dataset]:
+         """Split the dataset into smaller batches."""
+         total_size = len(dataset)
+         num_batches = (total_size + batch_size - 1) // batch_size
+
+         batches = [
+             (i * batch_size, min((i + 1) * batch_size, total_size))
+             for i in tqdm(range(num_batches))
+         ]
+
+         return batches
+
+     def _get_missing_data(self, seed_data, generated_data):
+         # Get the common columns between the two datasets
+         common_columns = list(
+             set(seed_data.column_names) & set(generated_data.column_names)
+         )
+
+         # Extract the relevant data based on common columns
+         seed_data_common = seed_data.select_columns(common_columns)
+         generated_data_common = generated_data.select_columns(common_columns)
+
+         # Convert to Pandas DataFrames for easier comparison
+         seed_df = seed_data_common.to_pandas()
+         generated_df = generated_data_common.to_pandas()
+
+         # Identify missing rows
+         missing_df = seed_df[
+             ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))
+         ]
+
+         # Convert back to Dataset
+         missing_data = Dataset.from_pandas(missing_df, preserve_index=False)
+
+         return missing_data
+
+     def _save_intermediate_checkpoint(self, dataset, checkpoint_dir):
+         checkpoint_id = uuid.uuid4().hex
+         checkpoint_file = f"{checkpoint_dir}/data_checkpoint_{checkpoint_id}.jsonl"
+         logger.info(f"Saving checkpoint to {checkpoint_file}")
+         dataset.to_json(checkpoint_file, orient="records", lines=True)
+
+     @staticmethod
+     def _generate_data(pipelines, input_split, ds, i=None):
+         logger.info(f"Processing split {i}")
+         input_split = ds.select(range(input_split[0], input_split[1]))
+         try:
+             for pipeline in pipelines:
+                 input_split = pipeline.generate(input_split)
+             return input_split
+         except Exception as e:
+             logger.error(f"Error processing split {i}: {e}")
+             traceback.print_exc()
+             return None
+
+     def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
+         # check if checkpoint_dir exists
+         pre_generated_data = []
+         if checkpoint_dir is not None:
+             try:
+                 # check if there are any existing checkpoints
+                 pre_generated_data = load_dataset(
+                     "json", data_dir=checkpoint_dir, split="train"
+                 )
+                 logger.info(
+                     f"Loading existing checkpoints from {checkpoint_dir}, with {pre_generated_data.num_rows} rows"
+                 )
+                 seed_data = self._get_missing_data(dataset, pre_generated_data)
+                 if seed_data.num_rows == 0:
+                     logger.info(
+                         f"All seed data has been generated, no missing rows found, returning data from {checkpoint_dir}"
+                     )
+                     return pre_generated_data
+                 logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")
+
+             except EmptyDatasetError:
+                 logger.info(
+                     f"No existing checkpoints found in {checkpoint_dir}, generating from scratch"
+                 )
+                 seed_data = dataset
+
+         else:
+             seed_data = dataset
+
+         if not self.batch_size:
+             # If batch size is not provided, generate the dataset in a single pass
+             generated_dataset = seed_data
+             # generated_dataset starts as seed_data and is updated by each pipeline in turn
+             for pipeline in self.pipelines:
+                 generated_dataset = pipeline.generate(generated_dataset)
+             return generated_dataset
+
+         logger.info("Splitting the dataset into smaller batches")
+         input_splits = (
+             self._split_dataset(seed_data, self.batch_size)
+             if self.batch_size
+             else [seed_data]
+         )
+         logger.info(
+             f"Generating dataset with {len(input_splits)} splits, batch size {self.batch_size}, and {self.num_workers} workers"
+         )
+
+         generated_data = [pre_generated_data] if pre_generated_data else []
+         last_saved_split_index = 0  # To track the last saved split
+
+         with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+             futures = [
+                 executor.submit(
+                     self._generate_data, self.pipelines, input_split, seed_data, i
+                 )
+                 for i, input_split in enumerate(input_splits)
+             ]
+
+             for i, future in enumerate(tqdm(as_completed(futures), total=len(futures))):
+                 generated_data_split = future.result()  # Ensure each future completes
+
+                 if generated_data_split:
+                     generated_data.append(generated_data_split)
+                     logger.info(f"Finished future processing split {i} \n\n")
+                     if self.save_freq and (i + 1) % self.save_freq == 0:
+                         # Save only the new splits since the last checkpoint
+                         new_splits = generated_data[last_saved_split_index : i + 1]
+                         checkpoint_dataset = safe_concatenate_datasets(new_splits)
+                         # check if checkpoint_dataset is not None
+                         if checkpoint_dataset:
+                             self._save_intermediate_checkpoint(
+                                 checkpoint_dataset, checkpoint_dir
+                             )
+
+                         last_saved_split_index = i + 1
+
+         generated_dataset = safe_concatenate_datasets(generated_data)
+
+         return generated_dataset
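SDG fans a seed dataset out over batched splits, runs every Pipeline on each split in a thread pool, periodically checkpoints results to JSONL, and resumes from an existing checkpoint_dir by regenerating only the missing rows. A minimal end-to-end sketch with a no-op pipeline and made-up seed rows:

    # Hypothetical run; a real setup would pass block lists built as in the earlier sketches.
    import os
    from datasets import Dataset
    from sdg_hub.pipeline import Pipeline
    from sdg_hub.sdg import SDG

    pipelines = [Pipeline([])]  # empty block list -> pass-through, enough to exercise batching
    seed = Dataset.from_list([{"task": f"task-{i}"} for i in range(100)])  # made-up seed rows

    os.makedirs("checkpoints", exist_ok=True)
    sdg = SDG(
        pipelines=pipelines,
        num_workers=4,   # splits processed concurrently in a ThreadPoolExecutor
        batch_size=16,   # rows per split; leave as None to run the whole dataset in one pass
        save_freq=2,     # write a JSONL checkpoint every 2 completed splits
    )
    generated = sdg.generate(seed, checkpoint_dir="checkpoints")  # resumes from checkpoints when present
    print(generated.num_rows)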
sdg_hub/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ # This is part of the public API, and used by instructlab
+ class GenerateException(Exception):
+     """An exception raised during the generate step."""
sdg_hub/utils/chunking.py ADDED
@@ -0,0 +1,73 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Standard
+ from typing import List
+ import logging
+ import re
+
+ # Third Party
+ from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+
+ _DEFAULT_CHUNK_OVERLAP = 100
+
+ logger = logging.getLogger(__name__)
+
+
+ def _num_tokens_from_words(num_words) -> int:
+     return int(num_words * 1.3)  # 1 word ~ 1.3 tokens
+
+
+ def _num_chars_from_tokens(num_tokens) -> int:
+     return int(num_tokens * 4)  # 1 token ~ 4 English characters
+
+
+ def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]:
+     """
+     Iterates over the documents and splits them into chunks based on the word count provided by the user.
+     Args:
+         documents (list): List of documents retrieved from git (can also consist of a single document).
+         server_ctx_size (int): Context window size of server.
+         chunk_word_count (int): Maximum number of words to chunk a document.
+     Returns:
+         List[str]: List of chunked documents.
+     """
+
+     # Checks for input type error
+     if isinstance(documents, str):
+         documents = [documents]
+
+     elif not isinstance(documents, list):
+         raise TypeError(
+             "Expected: documents to be a list, but got {}".format(type(documents))
+         )
+
+     no_tokens_per_doc = _num_tokens_from_words(chunk_word_count)
+     if no_tokens_per_doc > int(server_ctx_size - 1024):
+         raise ValueError(
+             "Error: {}".format(
+                 str(
+                     f"Given word count ({chunk_word_count}) per doc will exceed the server context window size ({server_ctx_size})"
+                 )
+             )
+         )
+     # Placeholder for params
+     content = []
+     chunk_size = _num_chars_from_tokens(no_tokens_per_doc)
+     chunk_overlap = _DEFAULT_CHUNK_OVERLAP
+
+     # Using Markdown as default; document-specific chunking will be implemented in a separate PR.
+     text_splitter = RecursiveCharacterTextSplitter.from_language(
+         language=Language.MARKDOWN,
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap,
+     )
+
+     # Determine file type for heuristics, default with markdown
+     for docs in documents:
+         # Use regex to remove unnecessary dashes in front of pipe characters in a markdown table.
+         docs = re.sub(r"-{2,}\|", "-|", docs)
+         # Remove unnecessary spaces in front of pipe characters in a markdown table.
+         docs = re.sub(r"\ +\|", " |", docs)
+         temp = text_splitter.create_documents([docs])
+         content.extend([item.page_content for item in temp])
+     return content
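chunk_document sizes chunks from a word budget using the ~1.3 tokens per word and ~4 characters per token heuristics above, and rejects budgets that would not fit the server context window (minus 1024 tokens of headroom). A small example with made-up document text and sizes:

    # Hypothetical call; the document text and sizes are illustrative.
    from sdg_hub.utils.chunking import chunk_document

    markdown_doc = "# Photosynthesis\n\nPlants convert light into chemical energy...\n"
    chunks = chunk_document(
        documents=[markdown_doc],
        server_ctx_size=4096,    # model context window in tokens
        chunk_word_count=1000,   # ~1300 tokens -> ~5200 characters per chunk
    )
    print(len(chunks), chunks[0][:80])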
sdg_hub/utils/datamixing.py ADDED
@@ -0,0 +1,123 @@
+ # Standard
+ import json
+ import os
+
+ # Third Party
+ from datasets import Dataset, load_dataset
+ import yaml
+
+ # First Party
+ from sdg_hub.logger_config import setup_logger
+ from .datautils import safe_concatenate_datasets
+
+
+ LOGGER = setup_logger(__name__)
+ ALLOWED_COLS = ["id", "messages", "metadata"]
+
+
+ def adjust_train_sample_size(ds: Dataset, num_samples: int):
+     LOGGER.info(f"Rebalancing dataset to have {num_samples} samples ...")
+     df = ds.to_pandas()
+     df = df.sample(n=num_samples, random_state=42, replace=True).reset_index(drop=True)
+     return Dataset.from_pandas(df)
+
+
+ def load_ds(path, sampling_size):
+     if path.endswith(".jsonl"):
+         LOGGER.info(f"Loading dataset from {path} ...")
+         dataset = load_dataset("json", data_files=path, split="train")
+     else:
+         LOGGER.info(f"Loading dataset from HF {path} ...")
+         dataset = load_dataset(path, split="train")
+     LOGGER.info(f"Dataset columns: {dataset.column_names}")
+     LOGGER.info(f"Dataset loaded with {len(dataset)} samples")
+
+     if sampling_size != 1.0:
+         if isinstance(sampling_size, int):
+             num_samples = sampling_size
+         else:
+             num_samples = int(len(dataset) * sampling_size)
+         dataset = adjust_train_sample_size(dataset, num_samples)
+
+     # move any column that is not in ALLOWED_COLS to metadata
+     def move_unallowed_cols_to_metadata(example):
+         metadata = example.get("metadata", {})
+         if isinstance(metadata, str):
+             metadata = json.loads(metadata)
+         for col in dataset.column_names:
+             if col not in ALLOWED_COLS:
+                 metadata[col] = example[col]
+                 example.pop(col)
+         example["metadata"] = json.dumps(metadata)
+         return example
+
+     dataset = dataset.map(move_unallowed_cols_to_metadata, num_proc=8)
+
+     # check if the metadata column is a string; if not, convert it using json.dumps
+     if not isinstance(dataset["metadata"][0], str):
+         dataset = dataset.map(
+             lambda x: {"metadata": json.dumps(x["metadata"])}, num_proc=8
+         )
+
+     return dataset
+
+
+ def add_system_message(sample: dict, sys_prompt: str) -> dict:
+     # check if the messages have role system
+     has_system = False
+     for msg in sample["messages"]:
+         if msg["role"] == "system":
+             has_system = True
+             msg["content"] = sys_prompt
+
+     if not has_system:
+         sample["messages"].insert(0, {"role": "system", "content": sys_prompt})
+
+     return sample
+
+
+ class Recipe:
+     def __init__(self, recipe_path):
+         self.recipe_path = recipe_path
+         self.recipe = self._load_recipe()
+         self.sys_prompt = self.recipe.get("sys_prompt", "")
+         self.dataset_added = False
+
+     def _load_recipe(self):
+         with open(self.recipe_path, encoding="utf-8") as fp:
+             return yaml.safe_load(fp)
+
+     def add_dataset(self, path, sampling_size=1.0):
+         self.dataset_added = True
+         self.recipe["datasets"].append({"path": path, "sampling_size": sampling_size})
+
+     def save_recipe(self, output_path):
+         # check if directory exists
+         output_dir = os.path.dirname(output_path)
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         with open(output_path, "w", encoding="utf-8") as fp:
+             yaml.dump(self.recipe, fp)
+
+     def save_mixed_dataset(self, output_path):
+         if not self.dataset_added:
+             LOGGER.error("No dataset added to the recipe")
+
+         mixed_ds = [
+             load_ds(dataset["path"], dataset["sampling_size"])
+             for dataset in self.recipe["datasets"]
+         ]
+
+         mixed_ds = safe_concatenate_datasets(mixed_ds)
+         mixed_ds = mixed_ds.map(
+             add_system_message, fn_kwargs={"sys_prompt": self.sys_prompt}, num_proc=8
+         )
+
+         # assert that the dataset only has the allowed columns
+         assert set(mixed_ds.column_names) == set(
+             ALLOWED_COLS
+         ), "Dataset has invalid columns"
+
+         mixed_ds.to_json(output_path, orient="records", lines=True)
+         LOGGER.info(f"Mixed Dataset saved to {output_path}")
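Recipe loads a data-mixing YAML, lets callers append datasets with a sampling_size (a fraction or an absolute row count, per load_ds), and writes out a mixed JSONL restricted to the id/messages/metadata columns with the recipe's sys_prompt injected. A sketch with a made-up recipe file and dataset paths:

    # Hypothetical recipe and paths; the YAML keys mirror what Recipe reads above.
    import yaml
    from sdg_hub.utils.datamixing import Recipe

    recipe_yaml = {
        "sys_prompt": "You are a helpful assistant.",
        "datasets": [],  # entries of the form {"path": ..., "sampling_size": ...}
    }
    with open("my_recipe.yaml", "w", encoding="utf-8") as fp:
        yaml.dump(recipe_yaml, fp)

    recipe = Recipe("my_recipe.yaml")
    recipe.add_dataset("generated_skills.jsonl", sampling_size=0.5)  # keep ~50% of rows
    recipe.save_recipe("output/my_recipe.yaml")
    recipe.save_mixed_dataset("output/mixed_train.jsonl")  # source data needs id/messages/metadata columns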
sdg_hub/utils/datautils.py ADDED
@@ -0,0 +1,14 @@
+ # Third Party
+ from datasets import concatenate_datasets
+
+
+ def safe_concatenate_datasets(datasets: list):
+     """
+     Concatenate datasets safely, ignoring any datasets that are None or empty.
+     """
+     filtered_datasets = [ds for ds in datasets if ds is not None and ds.num_rows > 0]
+
+     if not filtered_datasets:
+         return None
+
+     return concatenate_datasets(filtered_datasets)
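safe_concatenate_datasets drops None and zero-row entries before concatenating, and returns None if nothing is left, which is how SDG tolerates failed splits. For example:

    # Hypothetical inputs showing the None/empty handling.
    from datasets import Dataset
    from sdg_hub.utils.datautils import safe_concatenate_datasets

    parts = [
        Dataset.from_list([{"question": "q1"}]),
        None,                   # e.g. a failed split from SDG._generate_data
        Dataset.from_list([]),  # empty splits are skipped as well
        Dataset.from_list([{"question": "q2"}]),
    ]
    merged = safe_concatenate_datasets(parts)
    print(merged.num_rows)                    # 2
    print(safe_concatenate_datasets([None]))  # None when nothing survives filtering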