sdg-hub 0.1.0a4__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/_version.py +2 -2
- sdg_hub/blocks/__init__.py +35 -5
- sdg_hub/blocks/block.py +58 -16
- sdg_hub/blocks/llmblock.py +121 -193
- sdg_hub/blocks/utilblocks.py +500 -43
- sdg_hub/checkpointer.py +139 -0
- sdg_hub/configs/annotations/detailed_annotations.yaml +28 -0
- sdg_hub/configs/annotations/simple_annotations.yaml +9 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +1 -0
- sdg_hub/configs/knowledge/detailed_summary.yaml +1 -0
- sdg_hub/configs/knowledge/extractive_summary.yaml +1 -0
- sdg_hub/configs/knowledge/generate_questions.yaml +82 -0
- sdg_hub/configs/knowledge/generate_responses.yaml +86 -0
- sdg_hub/configs/skills/contexts.yaml +18 -11
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +79 -12
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +60 -28
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +95 -30
- sdg_hub/configs/skills/freeform_questions.yaml +21 -16
- sdg_hub/configs/skills/freeform_responses.yaml +19 -25
- sdg_hub/configs/skills/router.yaml +53 -6
- sdg_hub/flow.py +351 -21
- sdg_hub/flow_runner.py +216 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +26 -9
- sdg_hub/flows/generation/skills/{agentic_improve_skill.yaml → improve_responses.yaml} +26 -31
- sdg_hub/flows/generation/skills/synth_skills.yaml +4 -4
- sdg_hub/pipeline.py +67 -12
- sdg_hub/prompts.py +21 -0
- sdg_hub/sdg.py +128 -86
- sdg_hub/utils/config_validation.py +91 -0
- sdg_hub/utils/validation_result.py +10 -0
- sdg_hub-0.1.1.dist-info/METADATA +190 -0
- sdg_hub-0.1.1.dist-info/RECORD +86 -0
- {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/WHEEL +1 -1
- sdg_hub/blocks/filterblock.py +0 -76
- sdg_hub/blocks/iterblock.py +0 -31
- sdg_hub/blocks/rmblocks.py +0 -194
- sdg_hub/configs/annotations/simple.yaml +0 -10
- sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +0 -3
- sdg_hub/configs/skills/data_recipe/default_recipe.yaml +0 -6
- sdg_hub/flows/annotation/emotion/detailed_description.yaml +0 -19
- sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +0 -19
- sdg_hub/flows/annotation/emotion/simple.yaml +0 -19
- sdg_hub/utils/chunking.py +0 -73
- sdg_hub/utils/docprocessor.py +0 -357
- sdg_hub/utils/parse_and_convert.py +0 -392
- sdg_hub-0.1.0a4.dist-info/METADATA +0 -309
- sdg_hub-0.1.0a4.dist-info/RECORD +0 -90
- /sdg_hub/configs/{knowledge/data_recipe → reasoning}/__init__.py +0 -0
- /sdg_hub/configs/skills/{_G_.yaml → icl_examples/STEM.yaml} +0 -0
- /sdg_hub/configs/skills/{data_recipe → icl_examples}/__init__.py +0 -0
- /sdg_hub/configs/skills/{_A_.yaml → icl_examples/coding.yaml} +0 -0
- /sdg_hub/configs/skills/{_B_.yaml → icl_examples/extraction.yaml} +0 -0
- /sdg_hub/configs/skills/{_C_.yaml → icl_examples/humanities.yaml} +0 -0
- /sdg_hub/configs/skills/{_D_.yaml → icl_examples/math.yaml} +0 -0
- /sdg_hub/configs/skills/{_E_.yaml → icl_examples/reasoning.yaml} +0 -0
- /sdg_hub/configs/skills/{_F_.yaml → icl_examples/roleplay.yaml} +0 -0
- /sdg_hub/configs/skills/{_H_.yaml → icl_examples/writing.yaml} +0 -0
- {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/top_level.txt +0 -0
sdg_hub/flow.py
CHANGED
@@ -1,58 +1,145 @@
+"""
+Flow module for managing data generation pipelines.
+
+This module provides the core Flow class that handles both configuration loading and execution
+of data generation blocks. The Flow class serves as the main interface for defining and running
+data generation pipelines, supporting both direct usage with SDG and backward compatibility
+through the deprecated Pipeline class.
+
+Example:
+    >>> flow = Flow(llm_client)
+    >>> flow = flow.get_flow_from_file("path/to/flow.yaml")
+    >>> dataset = flow.generate(input_dataset)
+
+Note:
+    This module is part of the SDG Hub package and is designed to work in conjunction
+    with the SDG class for distributed data generation.
+"""
+
 # SPDX-License-Identifier: Apache-2.0
 # Standard
 from abc import ABC
 from importlib import resources
-from typing import Optional
+from typing import Any, Callable, Dict, List, Optional
 import operator
 import os

 # Third Party
+from datasets import Dataset
+from datasets.data_files import EmptyDatasetError
+from jinja2 import Environment, meta
+from rich.console import Console
+from rich.table import Table
 import yaml

 # Local
+from .blocks import *  # needed to register blocks
+from .logger_config import setup_logger
+from .prompts import *  # needed to register prompts
 from .registry import BlockRegistry, PromptRegistry
-from . import
-from . import
+from .utils.config_validation import validate_prompt_config_schema
+from .utils.validation_result import ValidationResult
+
+logger = setup_logger(__name__)


-OPERATOR_MAP = {
+OPERATOR_MAP: Dict[str, Callable] = {
     "operator.eq": operator.eq,
     "operator.ge": operator.ge,
+    "operator.le": operator.le,
+    "operator.gt": operator.gt,
+    "operator.lt": operator.lt,
+    "operator.ne": operator.ne,
     "operator.contains": operator.contains,
 }

-CONVERT_DTYPE_MAP = {
+CONVERT_DTYPE_MAP: Dict[str, Callable] = {
     "float": float,
     "int": int,
 }


 class Flow(ABC):
+    """A class representing a data generation flow.
+
+    This class handles both configuration loading and execution of data generation
+    blocks. It can be used directly with SDG or through the deprecated Pipeline class.
+    """
+
     def __init__(
         self,
-        llm_client,
+        llm_client: Any,
         num_samples_to_generate: Optional[int] = None,
+        log_level: Optional[str] = None,
     ) -> None:
+        """
+        Initialize the Flow class.
+
+        Parameters
+        ----------
+        llm_client : Any
+            The LLM client to use for generation.
+        num_samples_to_generate : Optional[int], optional
+            Number of samples to generate, by default None
+        log_level : Optional[str], optional
+            Logging verbosity level, by default None
+
+        Attributes
+        ----------
+        llm_client : Any
+            The LLM client instance.
+        base_path : str
+            Base path for resource files.
+        registered_blocks : Dict[str, Any]
+            Registry of available blocks.
+        chained_blocks : Optional[List[Dict[str, Any]]]
+            List of block configurations.
+        num_samples_to_generate : Optional[int]
+            Number of samples to generate.
+
+        """
         self.llm_client = llm_client
-        self.num_samples_to_generate = num_samples_to_generate
         self.base_path = str(resources.files(__package__))
         self.registered_blocks = BlockRegistry.get_registry()
+        self.chained_blocks = None  # Will be set by get_flow_from_file
+        self.num_samples_to_generate = num_samples_to_generate

-
-        """
-
+        # Logging verbosity level
+        self.log_level = log_level or os.getenv("SDG_HUB_LOG_LEVEL", "normal").lower()
+        self.console = Console() if self.log_level in ["verbose", "debug"] else None

-
-
-
-
+    def _log_block_info(
+        self, index: int, total: int, name: str, ds: Dataset, stage: str
+    ) -> None:
+        if self.log_level in ["verbose", "debug"] and self.console:
+            table = Table(
+                title=f"{stage} Block {index + 1}/{total}: {name}", show_header=True
+            )
+            table.add_column("Metric", style="cyan", no_wrap=True)
+            table.add_column("Value", style="magenta")
+            table.add_row("Rows", str(len(ds)))
+            table.add_row("Columns", ", ".join(ds.column_names))
+            self.console.print(table)
+
+    def _getFilePath(self, dirs: List[str], filename: str) -> str:
+        """Find a named configuration file.

-
-
-
+        Files are checked in the following order:
+        1. Absolute path is always used
+        2. Checked relative to the directories in "dirs"
+        3. Relative to the current directory

-
-
+        Parameters
+        ----------
+        dirs : List[str]
+            Directories in which to search for the file.
+        filename : str
+            The path to the configuration file.
+
+        Returns
+        -------
+        str
+            Selected file path.
         """
         if os.path.isabs(filename):
             return filename
@@ -64,7 +151,175 @@ class Flow(ABC):
         # assume the path is relative to the current directory
         return filename

-    def
+    def _drop_duplicates(self, dataset: Dataset, cols: List[str]) -> Dataset:
+        """Drop duplicates from the dataset based on the columns provided.
+
+        Parameters
+        ----------
+        dataset : Dataset
+            The input dataset.
+        cols : List[str]
+            Columns to consider for duplicate detection.
+
+        Returns
+        -------
+        Dataset
+            Dataset with duplicates removed.
+        """
+        df = dataset.to_pandas()
+        df = df.drop_duplicates(subset=cols).reset_index(drop=True)
+        return Dataset.from_pandas(df)
+
+    def generate(self, dataset: Dataset) -> Dataset:
+        """Generate the dataset by running the pipeline steps.
+
+        Parameters
+        ----------
+        dataset : Dataset
+            The input dataset to process.
+
+        Returns
+        -------
+        Dataset
+            The processed dataset.
+
+        Raises
+        ------
+        ValueError
+            If Flow has not been initialized with blocks.
+        EmptyDatasetError
+            If a block produces an empty dataset.
+        """
+        if self.chained_blocks is None:
+            raise ValueError(
+                "Flow has not been initialized with blocks. "
+                "Call get_flow_from_file() first. "
+                "Or pass a list of blocks to the Flow constructor."
+            )
+
+        for i, block_prop in enumerate(self.chained_blocks):
+            block_type = block_prop["block_type"]
+            block_config = block_prop["block_config"]
+            drop_columns = block_prop.get("drop_columns", [])
+            gen_kwargs = block_prop.get("gen_kwargs", {})
+            drop_duplicates_cols = block_prop.get("drop_duplicates", False)
+            block = block_type(**block_config)
+
+            name = block_config.get("block_name", f"block_{i}")
+
+            # Logging: always show basic progress unless in quiet mode
+            if self.log_level in ["normal", "verbose", "debug"]:
+                logger.info(
+                    f"🔄 Running block {i + 1}/{len(self.chained_blocks)}: {name}"
+                )
+
+            # Log dataset shape before block (verbose/debug)
+            self._log_block_info(i, len(self.chained_blocks), name, dataset, "Input")
+
+            if self.log_level == "debug":
+                logger.debug(f"Input dataset (truncated): {dataset}")
+
+            dataset = block.generate(dataset, **gen_kwargs)
+
+            if len(dataset) == 0:
+                raise EmptyDatasetError(
+                    f"Pipeline stopped: "
+                    f"Empty dataset after running block: "
+                    f"{block_config['block_name']}"
+                )
+
+            drop_columns_in_ds = [e for e in drop_columns if e in dataset.column_names]
+            if drop_columns:
+                dataset = dataset.remove_columns(drop_columns_in_ds)
+
+            if drop_duplicates_cols:
+                dataset = self._drop_duplicates(dataset, cols=drop_duplicates_cols)
+
+            # Log dataset shape after block (verbose/debug)
+            self._log_block_info(i, len(self.chained_blocks), name, dataset, "Output")
+
+            if self.log_level == "debug":
+                logger.debug(f"Output dataset (truncated): {dataset}")
+
+        return dataset
+
+    def validate_config_files(self) -> "ValidationResult":
+        """
+        Validate all configuration file paths referenced in the flow blocks.
+
+        This method checks that all config files specified via `config_path` or `config_paths`
+        in each block:
+        - Exist on the filesystem
+        - Are readable by the current process
+        - Are valid YAML files (optional format check)
+
+        Returns
+        -------
+        ValidationResult
+            An object indicating whether all config files passed validation, along with a list
+            of error messages for any missing, unreadable, or invalid YAML files.
+
+        Notes
+        -----
+        This method is automatically called at the end of `get_flow_from_file()` to ensure
+        early detection of misconfigured blocks.
+        """
+        errors = []
+
+        def check_file(path: str, context: str):
+            if not os.path.isfile(path):
+                errors.append(f"[{context}] File does not exist: {path}")
+            else:
+                try:
+                    with open(path, "r", encoding="utf-8") as f:
+                        config_data = yaml.safe_load(f)
+                    _, validation_errors = validate_prompt_config_schema(config_data, path)
+
+                    if validation_errors:
+                        errors.extend(validation_errors)
+
+                except PermissionError:
+                    errors.append(f"[{context}] File is not readable: {path}")
+                except yaml.YAMLError as e:
+                    errors.append(f"[{context}] YAML load failed: {path} ({e})")
+
+        for i, block in enumerate(self.chained_blocks or []):
+            block_name = block["block_config"].get("block_name", f"block_{i}")
+
+            config_path = block["block_config"].get("config_path")
+            if config_path:
+                check_file(config_path, f"{block_name}.config_path")
+
+            config_paths = block["block_config"].get("config_paths")
+            if isinstance(config_paths, list):
+                for idx, path in enumerate(config_paths):
+                    check_file(path, f"{block_name}.config_paths[{idx}]")
+            elif isinstance(config_paths, dict):
+                for key, path in config_paths.items():
+                    check_file(path, f"{block_name}.config_paths['{key}']")
+
+        return ValidationResult(valid=(len(errors) == 0), errors=errors)
+
+    def get_flow_from_file(self, yaml_path: str) -> "Flow":
+        """Load and initialize flow configuration from a YAML file.
+
+        Parameters
+        ----------
+        yaml_path : str
+            Path to the YAML configuration file.
+
+        Returns
+        -------
+        Flow
+            Self with initialized chained_blocks.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the YAML file cannot be found.
+        KeyError
+            If a required block or prompt is not found in the registry.
+        """
         yaml_path_relative_to_base = os.path.join(self.base_path, yaml_path)
         if os.path.isfile(yaml_path_relative_to_base):
             yaml_path = yaml_path_relative_to_base
@@ -141,4 +396,79 @@ class Flow(ABC):
                     block["block_config"]["convert_dtype"]
                 ]

-        return
+        # Store the chained blocks and return self
+        self.chained_blocks = flow
+
+        # Validate config files
+        result = self.validate_config_files()
+        if not result.valid:
+            raise ValueError("Invalid config files:\n\n".join(result.errors))
+
+        return self
+
+    def validate_flow(self, dataset: Dataset) -> "ValidationResult":
+        """
+        Validate that all required dataset columns are present before executing the flow.
+
+        This includes:
+        - Columns referenced in Jinja templates for LLM blocks
+        - Columns required by specific utility blocks (e.g. filter_column, choice_col, etc.)
+
+        Parameters
+        ----------
+        dataset : Dataset
+            The input dataset to validate against.
+
+        Returns
+        -------
+        ValidationResult
+            Whether the dataset has all required columns, and which ones are missing.
+        """
+        errors = []
+        all_columns = set(dataset.column_names)
+
+        for i, block in enumerate(self.chained_blocks or []):
+            name = block["block_config"].get("block_name", f"block_{i}")
+            block_type = block["block_type"]
+            config = block["block_config"]
+
+            # LLM Block: parse Jinja vars
+            cls_name = block_type.__name__ if isinstance(block_type, type) else block_type.__class__.__name__
+            logger.info(f"Validating block: {name} ({cls_name})")
+            if "LLM" in cls_name:
+                config_path = config.get("config_path")
+                if config_path and os.path.isfile(config_path):
+                    with open(config_path, "r", encoding="utf-8") as f:
+                        content = f.read()
+                        env = Environment()
+                        ast = env.parse(content)
+                        vars_found = meta.find_undeclared_variables(ast)
+                        for var in vars_found:
+                            if var not in all_columns:
+                                errors.append(f"[{name}] Missing column for prompt var: '{var}'")
+
+            # FilterByValueBlock
+            if "FilterByValueBlock" in str(block_type):
+                col = config.get("filter_column")
+                if col and col not in all_columns:
+                    errors.append(f"[{name}] Missing filter_column: '{col}'")
+
+            # SelectorBlock
+            if "SelectorBlock" in str(block_type):
+                col = config.get("choice_col")
+                if col and col not in all_columns:
+                    errors.append(f"[{name}] Missing choice_col: '{col}'")
+
+                choice_map = config.get("choice_map", {})
+                for col in choice_map.values():
+                    if col not in all_columns:
+                        errors.append(f"[{name}] choice_map references missing column: '{col}'")
+
+            # CombineColumnsBlock
+            if "CombineColumnsBlock" in str(block_type):
+                cols = config.get("columns", [])
+                for col in cols:
+                    if col not in all_columns:
+                        errors.append(f"[{name}] CombineColumnsBlock requires column: '{col}'")
+
+        return ValidationResult(valid=(len(errors) == 0), errors=errors)
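Taken together, the flow.py changes give Flow a load → validate → execute lifecycle: get_flow_from_file() now returns self and validates config files on load, validate_flow() checks the input dataset against Jinja prompt variables and block configs, and generate() runs the chained blocks with progress logging. A minimal usage sketch, assuming an OpenAI-compatible endpoint; the endpoint URL and dataset column name are illustrative placeholders, not part of the diff:

from datasets import Dataset
from openai import OpenAI
from sdg_hub.flow import Flow

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint

# Loading also triggers validate_config_files(); a bad config_path fails fast here.
flow = Flow(client, log_level="verbose").get_flow_from_file(
    "flows/generation/skills/synth_skills.yaml"
)

# Check required columns before spending any tokens.
ds = Dataset.from_list([{"task_description": "..."}])  # illustrative column name
result = flow.validate_flow(ds)
if not result.valid:
    raise ValueError("\n".join(result.errors))

generated = flow.generate(ds)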
sdg_hub/flow_runner.py
ADDED
@@ -0,0 +1,216 @@
+"""Script for running data generation flows with configurable parameters."""
+
+# Standard
+import os
+
+# Third Party
+from datasets import load_dataset
+from openai import OpenAI
+import click
+
+# First Party
+from sdg_hub.flow import Flow
+from sdg_hub.logger_config import setup_logger
+from sdg_hub.sdg import SDG
+
+
+logger = setup_logger(__name__)
+
+
+def run_flow(
+    ds_path: str,
+    save_path: str,
+    endpoint: str,
+    flow_path: str,
+    checkpoint_dir: str,
+    batch_size: int = 8,
+    num_workers: int = 32,
+    save_freq: int = 2,
+    debug: bool = False,
+    dataset_start_index: int = 0,
+    dataset_end_index: int = None,
+) -> None:
+    """Process the dataset using the specified configuration.
+
+    Parameters
+    ----------
+    ds_path : str
+        Path to the dataset file.
+    save_path : str
+        Path where the output will be saved.
+    endpoint : str
+        API endpoint for data processing.
+    flow_path : str
+        Path to the flow configuration file.
+    checkpoint_dir : str
+        Directory path for saving checkpoints.
+    batch_size : int, optional
+        Batch size for processing, by default 8.
+    num_workers : int, optional
+        Number of worker processes to use, by default 32.
+    save_freq : int, optional
+        Frequency (in batches) at which to save checkpoints, by default 2.
+    debug : bool, optional
+        If True, enables debug mode with a smaller dataset subset, by default False.
+
+    Returns
+    -------
+    None
+
+    Raises
+    ------
+    FileNotFoundError
+        If the flow configuration file is not found.
+    """
+    logger.info(f"Generation configuration: {locals()}\n\n")
+    ds = load_dataset("json", data_files=ds_path, split="train")
+    if dataset_start_index is not None and dataset_end_index is not None:
+        ds = ds.select(range(dataset_start_index, dataset_end_index))
+        logger.info(f"Dataset sliced from {dataset_start_index} to {dataset_end_index}")
+    if debug:
+        ds = ds.shuffle(seed=42).select(range(30))
+        logger.info("Debug mode enabled. Using a subset of the dataset.")
+
+    openai_api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
+    openai_api_base = endpoint
+
+    client = OpenAI(
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
+
+    if not os.path.exists(flow_path):
+        raise FileNotFoundError(f"Flow file not found: {flow_path}")
+
+    flow = Flow(client).get_flow_from_file(flow_path)
+    sdg = SDG(
+        flows=[flow],
+        num_workers=num_workers,
+        batch_size=batch_size,
+        save_freq=save_freq,
+    )
+
+    generated_data = sdg.generate(ds, checkpoint_dir=checkpoint_dir)
+    if dataset_end_index is not None and dataset_start_index is not None:
+        save_path = save_path.replace(".jsonl", f"_{dataset_start_index}_{dataset_end_index}.jsonl")
+    generated_data.to_json(save_path, orient="records", lines=True)
+    logger.info(f"Data saved to {save_path}")
+
+
+@click.command()
+@click.option(
+    "--ds_path",
+    type=click.Path(exists=True),
+    required=True,
+    help="Path to the dataset.",
+)
+@click.option(
+    "--bs",
+    type=int,
+    default=8,
+    show_default=True,
+    help="Batch size for processing.",
+)
+@click.option(
+    "--num_workers",
+    type=int,
+    default=32,
+    show_default=True,
+    help="Number of worker processes to use.",
+)
+@click.option(
+    "--save_path",
+    type=click.Path(),
+    required=True,
+    help="Path to save the output.",
+)
+@click.option(
+    "--endpoint",
+    type=str,
+    required=True,
+    help="API endpoint for data processing.",
+)
+@click.option(
+    "--flow",
+    type=click.Path(exists=True),
+    required=True,
+    help="Flow configuration for the process.",
+)
+@click.option(
+    "--checkpoint_dir",
+    type=click.Path(),
+    required=True,
+    help="Path to save checkpoints.",
+)
+@click.option(
+    "--save_freq",
+    type=int,
+    default=2,
+    show_default=True,
+    help="Frequency to save checkpoints.",
+)
+@click.option(
+    "--debug",
+    is_flag=True,
+    help="Enable debug mode with a smaller dataset subset.",
+)
+@click.option("--dataset_start_index", type=int, default=0, help="Start index of the dataset.")
+@click.option("--dataset_end_index", type=int, default=None, help="End index of the dataset.")
+def main(
+    ds_path: str,
+    bs: int,
+    num_workers: int,
+    save_path: str,
+    endpoint: str,
+    flow: str,
+    checkpoint_dir: str,
+    save_freq: int,
+    debug: bool,
+    dataset_start_index: int,
+    dataset_end_index: int,
+) -> None:
+    """CLI entry point for running data generation flows.
+
+    Parameters
+    ----------
+    ds_path : str
+        Path to the dataset file.
+    bs : int
+        Batch size for processing.
+    num_workers : int
+        Number of worker processes to use.
+    save_path : str
+        Path where the output will be saved.
+    endpoint : str
+        API endpoint for data processing.
+    flow : str
+        Path to the flow configuration file.
+    checkpoint_dir : str
+        Directory path for saving checkpoints.
+    save_freq : int
+        Frequency (in batches) at which to save checkpoints.
+    debug : bool
+        If True, enables debug mode with a smaller dataset subset.
+
+    Returns
+    -------
+    None
+    """
+    run_flow(
+        ds_path=ds_path,
+        batch_size=bs,
+        num_workers=num_workers,
+        save_path=save_path,
+        endpoint=endpoint,
+        flow_path=flow,
+        checkpoint_dir=checkpoint_dir,
+        save_freq=save_freq,
+        debug=debug,
+        dataset_start_index=dataset_start_index,
+        dataset_end_index=dataset_end_index,
+    )
+
+
+if __name__ == "__main__":
+    # pylint: disable=no-value-for-parameter
+    main()
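The runner wires the same pieces end to end: load a JSONL dataset, optionally slice or subsample it, build an OpenAI client against the given endpoint, and hand the flow to SDG with checkpointing. Besides the click CLI, run_flow can be called directly; a sketch with placeholder paths (none of these paths come from the diff, and flow_path must exist on disk since the runner checks it with os.path.exists):

from sdg_hub.flow_runner import run_flow

run_flow(
    ds_path="data/seed.jsonl",            # placeholder input JSONL
    save_path="output/generated.jsonl",   # placeholder output path
    endpoint="http://localhost:8000/v1",  # any OpenAI-compatible endpoint
    flow_path="src/sdg_hub/flows/generation/skills/synth_skills.yaml",  # placeholder on-disk path
    checkpoint_dir="checkpoints/",
    batch_size=8,
    num_workers=32,
    save_freq=2,
)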
sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml
CHANGED
@@ -12,7 +12,10 @@
     output_cols:
       - summary_detailed
   gen_kwargs:
-    max_tokens:
+    max_tokens: 4096
+    temperature: 0.7
+    seed: 7452
+    n: 50

 - block_type: LLMBlock
   block_config:
@@ -22,7 +25,9 @@
     output_cols:
       - summary_atomic_facts
   gen_kwargs:
-    max_tokens:
+    max_tokens: 4096
+    temperature: 0.7
+    seed: 7452

 - block_type: LLMBlock
   block_config:
@@ -32,7 +37,9 @@
     output_cols:
       - summary_extractive
   gen_kwargs:
-    max_tokens:
+    max_tokens: 4096
+    temperature: 0.7
+    seed: 7452

 - block_type: FlattenColumnsBlock
   block_config:
@@ -55,19 +62,29 @@
 - block_type: LLMBlock
   block_config:
     block_name: knowledge generation
-    config_path: configs/knowledge/
+    config_path: configs/knowledge/generate_questions.yaml
     model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
     output_cols:
       - question
-      - response
     parser_kwargs:
       parser_name: custom
-      parsing_pattern: "\\[(?:Question|QUESTION)\\]\\s*(.*?)\\s
-
-
+      parsing_pattern: "\\[(?:Question|QUESTION)\\]\\s*(.*?)\\s*(?=\\[(?:Question|QUESTION)\\]|$)"
+  gen_kwargs:
+    temperature: 0.7
+    max_tokens: 100
+    seed: 7452
+
+- block_type: LLMBlock
+  block_config:
+    block_name: knowledge generation
+    config_path: configs/knowledge/generate_responses.yaml
+    model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
+    output_cols:
+      - response
   gen_kwargs:
-    temperature: 0.
+    temperature: 0.7
     max_tokens: 2048
+    seed: 7452

 - block_type: LLMBlock
   block_config: