sdg-hub 0.1.0a3__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. sdg_hub/_version.py +2 -2
  2. sdg_hub/blocks/__init__.py +35 -5
  3. sdg_hub/blocks/block.py +58 -16
  4. sdg_hub/blocks/llmblock.py +149 -204
  5. sdg_hub/blocks/utilblocks.py +500 -43
  6. sdg_hub/checkpointer.py +139 -0
  7. sdg_hub/configs/annotations/detailed_annotations.yaml +28 -0
  8. sdg_hub/configs/annotations/simple_annotations.yaml +9 -0
  9. sdg_hub/configs/knowledge/atomic_facts.yaml +1 -0
  10. sdg_hub/configs/knowledge/detailed_summary.yaml +1 -0
  11. sdg_hub/configs/knowledge/extractive_summary.yaml +1 -0
  12. sdg_hub/configs/knowledge/generate_questions.yaml +82 -0
  13. sdg_hub/configs/knowledge/generate_responses.yaml +86 -0
  14. sdg_hub/configs/skills/contexts.yaml +18 -11
  15. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +79 -12
  16. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +60 -28
  17. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +95 -30
  18. sdg_hub/configs/skills/freeform_questions.yaml +21 -16
  19. sdg_hub/configs/skills/freeform_responses.yaml +19 -25
  20. sdg_hub/configs/skills/router.yaml +53 -6
  21. sdg_hub/flow.py +351 -21
  22. sdg_hub/flow_runner.py +216 -0
  23. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +26 -9
  24. sdg_hub/flows/generation/skills/{agentic_improve_skill.yaml → improve_responses.yaml} +26 -31
  25. sdg_hub/flows/generation/skills/synth_skills.yaml +4 -4
  26. sdg_hub/pipeline.py +67 -12
  27. sdg_hub/prompts.py +26 -0
  28. sdg_hub/sdg.py +128 -86
  29. sdg_hub/utils/config_validation.py +91 -0
  30. sdg_hub/utils/validation_result.py +10 -0
  31. sdg_hub-0.1.1.dist-info/METADATA +190 -0
  32. sdg_hub-0.1.1.dist-info/RECORD +86 -0
  33. {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/WHEEL +1 -1
  34. sdg_hub/blocks/filterblock.py +0 -76
  35. sdg_hub/blocks/iterblock.py +0 -31
  36. sdg_hub/blocks/rmblocks.py +0 -194
  37. sdg_hub/configs/annotations/simple.yaml +0 -10
  38. sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +0 -3
  39. sdg_hub/configs/skills/data_recipe/default_recipe.yaml +0 -6
  40. sdg_hub/flows/annotation/emotion/detailed_description.yaml +0 -19
  41. sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +0 -19
  42. sdg_hub/flows/annotation/emotion/simple.yaml +0 -19
  43. sdg_hub/utils/chunking.py +0 -73
  44. sdg_hub/utils/docprocessor.py +0 -357
  45. sdg_hub/utils/parse_and_convert.py +0 -392
  46. sdg_hub-0.1.0a3.dist-info/METADATA +0 -154
  47. sdg_hub-0.1.0a3.dist-info/RECORD +0 -90
  48. /sdg_hub/configs/{knowledge/data_recipe → reasoning}/__init__.py +0 -0
  49. /sdg_hub/configs/skills/{_G_.yaml → icl_examples/STEM.yaml} +0 -0
  50. /sdg_hub/configs/skills/{data_recipe → icl_examples}/__init__.py +0 -0
  51. /sdg_hub/configs/skills/{_A_.yaml → icl_examples/coding.yaml} +0 -0
  52. /sdg_hub/configs/skills/{_B_.yaml → icl_examples/extraction.yaml} +0 -0
  53. /sdg_hub/configs/skills/{_C_.yaml → icl_examples/humanities.yaml} +0 -0
  54. /sdg_hub/configs/skills/{_D_.yaml → icl_examples/math.yaml} +0 -0
  55. /sdg_hub/configs/skills/{_E_.yaml → icl_examples/reasoning.yaml} +0 -0
  56. /sdg_hub/configs/skills/{_F_.yaml → icl_examples/roleplay.yaml} +0 -0
  57. /sdg_hub/configs/skills/{_H_.yaml → icl_examples/writing.yaml} +0 -0
  58. {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/licenses/LICENSE +0 -0
  59. {sdg_hub-0.1.0a3.dist-info → sdg_hub-0.1.1.dist-info}/top_level.txt +0 -0
@@ -2,34 +2,34 @@
2
2
  block_config:
3
3
  block_name: router
4
4
  config_path: configs/skills/router.yaml
5
- model_id: skill-classifier-v3-clm
5
+ model_id: meta-llama/Llama-3.3-70B-Instruct
6
6
  output_cols:
7
7
  - route
8
8
  gen_kwargs:
9
9
  temperature: 0
10
- max_tokens: 1
10
+ max_tokens: 5
11
11
  extra_body:
12
- allowed_token_ids:
13
- - 32001
14
- - 32002
15
- - 32003
16
- - 32004
17
- - 32005
18
- - 32006
19
- - 32007
20
- - 32008
12
+ guided_choice:
13
+ - "coding"
14
+ - "extraction"
15
+ - "humanities"
16
+ - "math"
17
+ - "reasoning"
18
+ - "roleplay"
19
+ - "STEM"
20
+ - "writing"
21
21
  - block_type: SamplePopulatorBlock
22
22
  block_config:
23
23
  block_name: icl_populator
24
24
  config_paths:
25
- - configs/skills/_A_.yaml
26
- - configs/skills/_B_.yaml
27
- - configs/skills/_C_.yaml
28
- - configs/skills/_D_.yaml
29
- - configs/skills/_E_.yaml
30
- - configs/skills/_F_.yaml
31
- - configs/skills/_G_.yaml
32
- - configs/skills/_H_.yaml
25
+ - configs/skills/icl_examples/coding.yaml
26
+ - configs/skills/icl_examples/extraction.yaml
27
+ - configs/skills/icl_examples/humanities.yaml
28
+ - configs/skills/icl_examples/math.yaml
29
+ - configs/skills/icl_examples/reasoning.yaml
30
+ - configs/skills/icl_examples/roleplay.yaml
31
+ - configs/skills/icl_examples/STEM.yaml
32
+ - configs/skills/icl_examples/writing.yaml
33
33
  column_name: route
34
34
  batch_kwargs:
35
35
  num_procs: 8
@@ -37,8 +37,7 @@
37
37
  block_config:
38
38
  block_name: analyzer
39
39
  config_path: configs/skills/analyzer.yaml
40
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
41
- model_prompt: <s> [INST] {prompt} [/INST]
40
+ model_id: meta-llama/Llama-3.3-70B-Instruct
42
41
  output_cols:
43
42
  - analysis
44
43
  - rubric
@@ -46,24 +45,21 @@
46
45
  block_config:
47
46
  block_name: critic
48
47
  config_path: configs/skills/critic.yaml
49
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
50
- model_prompt: <s> [INST] {prompt} [/INST]
48
+ model_id: meta-llama/Llama-3.3-70B-Instruct
51
49
  output_cols:
52
50
  - critique
53
51
  - block_type: LLMBlock
54
52
  block_config:
55
53
  block_name: planner
56
54
  config_path: configs/skills/planner.yaml
57
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
58
- model_prompt: <s> [INST] {prompt} [/INST]
55
+ model_id: meta-llama/Llama-3.3-70B-Instruct
59
56
  output_cols:
60
57
  - plan
61
58
  - block_type: LLMBlock
62
59
  block_config:
63
60
  block_name: revised_responder
64
61
  config_path: configs/skills/revised_responder.yaml
65
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
66
- model_prompt: <s> [INST] {prompt} [/INST]
62
+ model_id: meta-llama/Llama-3.3-70B-Instruct
67
63
  output_cols:
68
64
  - revised_response
69
65
  drop_columns:
@@ -78,8 +74,7 @@
78
74
  block_config:
79
75
  block_name: judge
80
76
  config_path: configs/skills/judge.yaml
81
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
82
- model_prompt: <s> [INST] {prompt} [/INST]
77
+ model_id: meta-llama/Llama-3.3-70B-Instruct
83
78
  output_cols:
84
79
  - judgement
85
80
  - verdict
@@ -100,9 +95,9 @@
100
95
  Assistant A: "response"
101
96
  Assistant B: "revised_response"
102
97
  choice_col: verdict
103
- output_col: chosen_reponse
98
+ output_col: chosen_response
104
99
  batch_kwargs:
105
100
  num_procs: 8
106
101
  drop_columns:
107
102
  - judgemnent
108
- - verdict
103
+ - verdict
@@ -2,7 +2,7 @@
2
2
  block_config:
3
3
  block_name: gen_questions
4
4
  config_path: configs/skills/freeform_questions.yaml
5
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
5
+ model_id: meta-llama/Llama-3.3-70B-Instruct
6
6
  output_cols:
7
7
  - question
8
8
  batch_kwargs:
@@ -13,7 +13,7 @@
13
13
  block_config:
14
14
  block_name: eval_questions
15
15
  config_path: configs/skills/evaluate_freeform_questions.yaml
16
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
16
+ model_id: meta-llama/Llama-3.3-70B-Instruct
17
17
  output_cols:
18
18
  - evaluation
19
19
  - score
@@ -34,14 +34,14 @@
34
34
  block_config:
35
35
  block_name: gen_responses
36
36
  config_path: configs/skills/freeform_responses.yaml
37
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
37
+ model_id: meta-llama/Llama-3.3-70B-Instruct
38
38
  output_cols:
39
39
  - response
40
40
  - block_type: LLMBlock
41
41
  block_config:
42
42
  block_name: evaluate_qa_pair
43
43
  config_path: configs/skills/evaluate_freeform_pair.yaml
44
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
44
+ model_id: meta-llama/Llama-3.3-70B-Instruct
45
45
  output_cols:
46
46
  - evaluation
47
47
  - score
sdg_hub/pipeline.py CHANGED
@@ -1,6 +1,17 @@
1
+ """
2
+ Deprecated Pipeline class for data generation pipelines.
3
+
4
+ Use the Flow class directly for new code.
5
+ """
6
+
1
7
  # SPDX-License-Identifier: Apache-2.0
8
+ # Standard
9
+ import warnings
10
+ from typing import List, Dict, Any
11
+
2
12
  # Third Party
3
13
  from datasets import Dataset
14
+ from datasets.data_files import EmptyDatasetError
4
15
 
5
16
  # Local
6
17
  from .logger_config import setup_logger
@@ -8,31 +19,75 @@ from .logger_config import setup_logger
8
19
  logger = setup_logger(__name__)
9
20
 
10
21
 
11
- class EmptyDatasetError(Exception):
12
- pass
22
+ class Pipeline:
23
+ """A class representing a data generation pipeline.
13
24
 
25
+ This class is deprecated and will be removed in a future version.
26
+ Use the Flow class directly instead.
14
27
 
15
- class Pipeline:
16
- def __init__(self, chained_blocks: list) -> None:
28
+ Parameters
29
+ ----------
30
+ chained_blocks : List[Dict[str, Any]]
31
+ List of block configurations to execute in sequence.
32
+
33
+ Attributes
34
+ ----------
35
+ chained_blocks : List[Dict[str, Any]]
36
+ List of block configurations to execute in sequence.
37
+ """
38
+
39
+ def __init__(self, chained_blocks: List[Dict[str, Any]]) -> None:
17
40
  """
18
41
  Initialize the Pipeline class with a configuration dictionary.
19
- config_dict: the run config py or yaml loaded into a dictionary
42
+
43
+ DEPRECATED: This class is deprecated and will be removed in a future version.
44
+ Use the Flow class directly instead.
20
45
  """
46
+ warnings.warn(
47
+ "Pipeline class is deprecated and will be removed in a future version. "
48
+ "Use Flow class directly instead of wrapping it with Pipeline.",
49
+ DeprecationWarning,
50
+ stacklevel=2
51
+ )
21
52
  # pipeline config is the run configuration that consists of the pipeline steps
22
53
  self.chained_blocks = chained_blocks
23
54
 
24
- def _drop_duplicates(self, dataset, cols):
25
- """
26
- Drop duplicates from the dataset based on the columns provided.
55
+ def _drop_duplicates(self, dataset: Dataset, cols: List[str]) -> Dataset:
56
+ """Drop duplicates from the dataset based on the columns provided.
57
+
58
+ Parameters
59
+ ----------
60
+ dataset : Dataset
61
+ The input dataset.
62
+ cols : List[str]
63
+ Columns to consider for duplicate detection.
64
+
65
+ Returns
66
+ -------
67
+ Dataset
68
+ Dataset with duplicates removed.
27
69
  """
28
70
  df = dataset.to_pandas()
29
71
  df = df.drop_duplicates(subset=cols).reset_index(drop=True)
30
72
  return Dataset.from_pandas(df)
31
73
 
32
- def generate(self, dataset) -> Dataset:
33
- """
34
- Generate the dataset by running the pipeline steps.
35
- dataset: the input dataset
74
+ def generate(self, dataset: Dataset) -> Dataset:
75
+ """Generate the dataset by running the pipeline steps.
76
+
77
+ Parameters
78
+ ----------
79
+ dataset : Dataset
80
+ The input dataset to process.
81
+
82
+ Returns
83
+ -------
84
+ Dataset
85
+ The processed dataset.
86
+
87
+ Raises
88
+ ------
89
+ EmptyDatasetError
90
+ If a block produces an empty dataset.
36
91
  """
37
92
  for block_prop in self.chained_blocks:
38
93
  block_type = block_prop["block_type"]
sdg_hub/prompts.py CHANGED
@@ -15,3 +15,29 @@ def instructlab_chat_template():
15
15
  @PromptRegistry.register("mistralai")
16
16
  def mistral_chat_template():
17
17
  return """{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n<s>\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + '</s>'}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n"""
18
+
19
+
20
+ @PromptRegistry.register("meta-llama/Llama-3.3")
21
+ def meta_llama_chat_template():
22
+ return """{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"""
23
+
24
+
25
+ @PromptRegistry.register("microsoft/phi-4")
26
+ def microsoft_phi_chat_template():
27
+ return """{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>' + message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}"""
28
+
29
+ @PromptRegistry.register("nvidia/Llama-3_3-Nemotron-Super-49B-v1")
30
+ def nemotron_chat_template():
31
+ return """{{- bos_token }}
32
+ {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}detailed thinking on{{- "<|eot_id|>" }}
33
+ {%- for message in messages %}
34
+ {%- if message['role'] == 'assistant' and '</think>' in message['content'] %}
35
+ {%- set content = message['content'].split('</think>')[-1].lstrip() %}
36
+ {%- else %}
37
+ {%- set content = message['content'] %}
38
+ {%- endif %}
39
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + content | trim + '<|eot_id|>' }}
40
+ {%- endfor %}
41
+ {%- if add_generation_prompt %}
42
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
43
+ {%- endif %}"""
sdg_hub/sdg.py CHANGED
@@ -1,35 +1,83 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
+
3
+ """Synthetic Data Generator (SDG) module for managing data generation flows."""
4
+
2
5
  # Standard
3
6
  from concurrent.futures import ThreadPoolExecutor, as_completed
4
- from typing import List
7
+ from typing import List, Optional, Tuple
5
8
  import traceback
6
- import uuid
7
9
 
8
10
  # Third Party
9
- from datasets import Dataset, load_dataset
10
- from datasets.data_files import EmptyDatasetError
11
+ from datasets import Dataset
11
12
  from tqdm import tqdm
12
13
 
13
14
  # Local
15
+ from .checkpointer import Checkpointer
16
+ from .flow import Flow
14
17
  from .logger_config import setup_logger
15
- from .pipeline import Pipeline
16
18
  from .utils.datautils import safe_concatenate_datasets
17
19
 
18
-
19
20
  logger = setup_logger(__name__)
20
21
 
21
22
 
22
23
  class SDG:
24
+ """Synthetic Data Generator class.
25
+
26
+ This class manages the generation of synthetic data using one or more
27
+ data generation flows.
28
+
29
+ Parameters
30
+ ----------
31
+ flows : List[Flow]
32
+ List of flows to execute.
33
+ num_workers : int, optional
34
+ Number of worker threads to use, by default 1
35
+ batch_size : Optional[int], optional
36
+ Size of batches to process, by default None
37
+ save_freq : Optional[int], optional
38
+ Frequency of checkpoint saves, by default None
39
+
40
+ Attributes
41
+ ----------
42
+ flows : List[Flow]
43
+ List of flows to execute.
44
+ num_workers : int
45
+ Number of worker threads to use.
46
+ batch_size : Optional[int]
47
+ Size of batches to process.
48
+ save_freq : Optional[int]
49
+ Frequency of checkpoint saves.
50
+ """
51
+
23
52
  def __init__(
24
- self, pipelines: List[Pipeline], num_workers=1, batch_size=None, save_freq=None
53
+ self,
54
+ flows: List[Flow],
55
+ num_workers: int = 1,
56
+ batch_size: Optional[int] = None,
57
+ save_freq: Optional[int] = None,
25
58
  ) -> None:
26
- self.pipelines = pipelines
59
+ self.flows = flows
27
60
  self.num_workers = num_workers
28
61
  self.batch_size = batch_size
29
62
  self.save_freq = save_freq
30
63
 
31
- def _split_dataset(self, dataset: Dataset, batch_size: int) -> List[Dataset]:
32
- """Split the dataset into smaller batches."""
64
+ def _split_dataset(
65
+ self, dataset: Dataset, batch_size: int
66
+ ) -> List[Tuple[int, int]]:
67
+ """Split the dataset into smaller batches.
68
+
69
+ Parameters
70
+ ----------
71
+ dataset : Dataset
72
+ The dataset to split.
73
+ batch_size : int
74
+ Size of each batch.
75
+
76
+ Returns
77
+ -------
78
+ List[Tuple[int, int]]
79
+ List of (start, end) indices for each batch.
80
+ """
33
81
  total_size = len(dataset)
34
82
  num_batches = (total_size + batch_size - 1) // batch_size
35
83
 
@@ -40,94 +88,87 @@ class SDG:
40
88
 
41
89
  return batches
42
90
 
43
- def _get_missing_data(self, seed_data, generated_data):
44
- # Get the common columns between the two datasets
45
- common_columns = list(
46
- set(seed_data.column_names) & set(generated_data.column_names)
47
- )
48
-
49
- # Extract the relevant data based on common columns
50
- seed_data_common = seed_data.select_columns(common_columns)
51
- generated_data_common = generated_data.select_columns(common_columns)
52
-
53
- # Convert to Pandas DataFrames for easier comparison
54
- seed_df = seed_data_common.to_pandas()
55
- generated_df = generated_data_common.to_pandas()
56
-
57
- # Identify missing rows
58
- missing_df = seed_df[
59
- ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))
60
- ]
61
-
62
- # Convert back to Dataset
63
- missing_data = Dataset.from_pandas(missing_df, preserve_index=False)
64
-
65
- return missing_data
66
-
67
- def _save_intermediate_checkpoint(self, dataset, checkpoint_dir):
68
- checkpoint_id = uuid.uuid4().hex
69
- checkpoint_file = f"{checkpoint_dir}/data_checkpoint_{checkpoint_id}.jsonl"
70
- logger.info(f"Saving checkpoint to {checkpoint_file}")
71
- dataset.to_json(checkpoint_file, orient="records", lines=True)
72
-
73
91
  @staticmethod
74
- def _generate_data(pipelines, input_split, ds, i=None):
92
+ def _generate_data(
93
+ flows: List[Flow],
94
+ input_split: Tuple[int, int],
95
+ ds: Dataset,
96
+ i: Optional[int] = None,
97
+ ) -> Optional[Dataset]:
98
+ """Generate data for a single split using the provided flows.
99
+
100
+ Parameters
101
+ ----------
102
+ flows : List[Flow]
103
+ List of flows to execute.
104
+ input_split : Tuple[int, int]
105
+ (start, end) indices for the current split.
106
+ ds : Dataset
107
+ The full input dataset.
108
+ i : Optional[int], optional
109
+ Split index for logging, by default None
110
+
111
+ Returns
112
+ -------
113
+ Optional[Dataset]
114
+ Generated dataset for the split, or None if generation failed.
115
+ """
75
116
  logger.info(f"Processing split {i}")
76
117
  input_split = ds.select(range(input_split[0], input_split[1]))
77
118
  try:
78
- for pipeline in pipelines:
79
- input_split = pipeline.generate(input_split)
119
+ for flow in flows:
120
+ input_split = flow.generate(input_split)
80
121
  return input_split
81
122
  except Exception as e:
82
123
  logger.error(f"Error processing split {i}: {e}")
83
124
  traceback.print_exc()
84
125
  return None
85
126
 
86
- def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
87
- # check if checkpoint_dir exists
88
- pre_generated_data = []
89
- if checkpoint_dir is not None:
90
- try:
91
- # check if there are any existing checkpoints
92
- pre_generated_data = load_dataset(
93
- "json", data_dir=checkpoint_dir, split="train"
94
- )
95
- logger.info(
96
- f"Loading existing checkpoints from {checkpoint_dir}, with {pre_generated_data.num_rows} rows"
97
- )
98
- seed_data = self._get_missing_data(dataset, pre_generated_data)
99
- if seed_data.num_rows == 0:
100
- logger.info(
101
- f"All seed data has been generated, no missing rows found, returning data from {checkpoint_dir}"
102
- )
103
- return pre_generated_data
104
- logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")
105
-
106
- except EmptyDatasetError:
107
- logger.info(
108
- f"No existing checkpoints found in {checkpoint_dir}, generating from scratch"
109
- )
110
- seed_data = dataset
111
-
112
- else:
113
- seed_data = dataset
127
+ def generate(
128
+ self, dataset: Dataset, checkpoint_dir: Optional[str] = None
129
+ ) -> Dataset:
130
+ """Generate synthetic data using the configured flows.
131
+
132
+ Parameters
133
+ ----------
134
+ dataset : Dataset
135
+ The input dataset to process.
136
+ checkpoint_dir : Optional[str], optional
137
+ Directory to save checkpoints, by default None
138
+
139
+ Returns
140
+ -------
141
+ Dataset
142
+ The generated dataset.
143
+
144
+ Notes
145
+ -----
146
+ If checkpoint_dir is provided, the generation process can be resumed
147
+ from the last checkpoint in case of interruption.
148
+ """
149
+ # Initialize checkpointer
150
+ checkpointer = Checkpointer(checkpoint_dir, self.save_freq)
151
+
152
+ # Load existing checkpoints and determine missing data
153
+ seed_data, pre_generated_data = checkpointer.load_existing_data(dataset)
154
+
155
+ # If all data has been generated, return the pre-generated data
156
+ if seed_data.num_rows == 0 and pre_generated_data is not None:
157
+ return pre_generated_data
114
158
 
115
159
  if not self.batch_size:
116
160
  # If batch size is not provided, generate the dataset in a single pass
117
161
  generated_dataset = seed_data
118
- # generated_data is initialized with seed_data, and it gets updated with each pipeline
119
- for pipeline in self.pipelines:
120
- generated_dataset = pipeline.generate(seed_data)
162
+ # generated_data is initialized with seed_data, and it gets updated with each flow
163
+ for flow in self.flows:
164
+ generated_dataset = flow.generate(generated_dataset)
121
165
  return generated_dataset
122
-
166
+
123
167
  logger.info("Splitting the dataset into smaller batches")
124
- input_splits = (
125
- self._split_dataset(seed_data, self.batch_size)
126
- if self.batch_size
127
- else [seed_data]
128
- )
168
+ input_splits = self._split_dataset(seed_data, self.batch_size)
129
169
  logger.info(
130
- f"Generating dataset with {len(input_splits)} splits, batch size {self.batch_size}, and {self.num_workers} workers"
170
+ f"Generating dataset with {len(input_splits)} splits, "
171
+ f"batch size {self.batch_size}, and {self.num_workers} workers"
131
172
  )
132
173
 
133
174
  generated_data = [pre_generated_data] if pre_generated_data else []
@@ -136,7 +177,7 @@ class SDG:
136
177
  with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
137
178
  futures = [
138
179
  executor.submit(
139
- self._generate_data, self.pipelines, input_split, seed_data, i
180
+ self._generate_data, self.flows, input_split, seed_data, i
140
181
  )
141
182
  for i, input_split in enumerate(input_splits)
142
183
  ]
@@ -147,16 +188,17 @@ class SDG:
147
188
  if generated_data_split:
148
189
  generated_data.append(generated_data_split)
149
190
  logger.info(f"Finished future processing split {i} \n\n")
150
- if self.save_freq and (i + 1) % self.save_freq == 0:
191
+
192
+ # Use checkpointer to handle intermediate saves
193
+ if checkpointer.should_save_checkpoint(i):
151
194
  # Save only the new splits since the last checkpoint
152
195
  new_splits = generated_data[last_saved_split_index : i + 1]
153
196
  checkpoint_dataset = safe_concatenate_datasets(new_splits)
154
197
  # check if checkpoint_dataset is not None
155
198
  if checkpoint_dataset:
156
- self._save_intermediate_checkpoint(
157
- checkpoint_dataset, checkpoint_dir
199
+ checkpointer.save_intermediate_checkpoint(
200
+ checkpoint_dataset
158
201
  )
159
-
160
202
  last_saved_split_index = i + 1
161
203
 
162
204
  generated_dataset = safe_concatenate_datasets(generated_data)
@@ -0,0 +1,91 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Configuration validation utilities for SDG Hub.
3
+
4
+ This module provides functions to validate configuration files used by blocks,
5
+ ensuring they meet the required schema and contain all necessary fields.
6
+ """
7
+
8
+ # Standard
9
+ from typing import Any, Dict, List
10
+
11
+ # Local
12
+ from ..logger_config import setup_logger
13
+
14
+ logger = setup_logger(__name__)
15
+
16
+
17
+ def validate_prompt_config_schema(
18
+ config: Dict[str, Any], config_path: str
19
+ ) -> tuple[bool, List[str]]:
20
+ """Validate that a prompt configuration file has the required schema fields.
21
+
22
+ For prompt template configs, 'system' and 'generation' are required fields.
23
+ Other fields like 'introduction', 'principles', 'examples', 'start_tags', 'end_tags' are optional.
24
+
25
+ Parameters
26
+ ----------
27
+ config : Dict[str, Any]
28
+ The loaded configuration dictionary.
29
+ config_path : str
30
+ The path to the configuration file (for error reporting).
31
+
32
+ Returns
33
+ -------
34
+ tuple[bool, List[str]]
35
+ A tuple containing:
36
+ - bool: True if schema is valid, False otherwise
37
+ - List[str]: List of validation error messages (empty if valid)
38
+ """
39
+ required_fields = ["system", "generation"]
40
+ errors = []
41
+
42
+ # Ensure config is a dictionary
43
+ if not isinstance(config, dict):
44
+ errors.append(f"Configuration must be a dictionary, got {type(config).__name__}")
45
+ return False, errors
46
+
47
+ # Check for missing required fields
48
+ missing_fields = [field for field in required_fields if field not in config]
49
+ if missing_fields:
50
+ errors.append(f"Missing required fields: {missing_fields}")
51
+
52
+ # Check for empty or null required fields and validate they are strings
53
+ for field in required_fields:
54
+ if field in config:
55
+ value = config[field]
56
+ if value is None:
57
+ errors.append(f"Required field '{field}' is null")
58
+ elif not isinstance(value, str):
59
+ errors.append(f"Required field '{field}' must be a string, got {type(value).__name__}")
60
+ elif not value.strip():
61
+ errors.append(f"Required field '{field}' is empty")
62
+
63
+ # Check optional string fields are strings when present
64
+ string_fields = ["introduction", "principles", "examples"]
65
+ for field in string_fields:
66
+ if field in config:
67
+ value = config[field]
68
+ if value is not None and not isinstance(value, str):
69
+ errors.append(f"Field '{field}' must be a string, got {type(value).__name__}")
70
+
71
+ # Check start_tags and end_tags are lists of strings when present
72
+ tag_fields = ["start_tags", "end_tags"]
73
+ for field in tag_fields:
74
+ if field in config:
75
+ value = config[field]
76
+ if value is not None:
77
+ if not isinstance(value, list):
78
+ errors.append(f"Field '{field}' must be a list, got {type(value).__name__}")
79
+ else:
80
+ for i, tag in enumerate(value):
81
+ if not isinstance(tag, str):
82
+ errors.append(f"Field '{field}[{i}]' must be a string, got {type(tag).__name__}")
83
+
84
+ # Log validation results
85
+ if errors:
86
+ for error in errors:
87
+ logger.error(f"Config validation failed for {config_path}: {error}")
88
+ return False, errors
89
+
90
+ logger.debug(f"Config validation passed for {config_path}")
91
+ return True, []