sdg-hub 0.1.0a4__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. sdg_hub/_version.py +2 -2
  2. sdg_hub/blocks/__init__.py +35 -5
  3. sdg_hub/blocks/block.py +58 -16
  4. sdg_hub/blocks/llmblock.py +121 -193
  5. sdg_hub/blocks/utilblocks.py +500 -43
  6. sdg_hub/checkpointer.py +139 -0
  7. sdg_hub/configs/annotations/detailed_annotations.yaml +28 -0
  8. sdg_hub/configs/annotations/simple_annotations.yaml +9 -0
  9. sdg_hub/configs/knowledge/atomic_facts.yaml +1 -0
  10. sdg_hub/configs/knowledge/detailed_summary.yaml +1 -0
  11. sdg_hub/configs/knowledge/extractive_summary.yaml +1 -0
  12. sdg_hub/configs/knowledge/generate_questions.yaml +82 -0
  13. sdg_hub/configs/knowledge/generate_responses.yaml +86 -0
  14. sdg_hub/configs/skills/contexts.yaml +18 -11
  15. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +79 -12
  16. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +60 -28
  17. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +95 -30
  18. sdg_hub/configs/skills/freeform_questions.yaml +21 -16
  19. sdg_hub/configs/skills/freeform_responses.yaml +19 -25
  20. sdg_hub/configs/skills/router.yaml +53 -6
  21. sdg_hub/flow.py +351 -21
  22. sdg_hub/flow_runner.py +216 -0
  23. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +26 -9
  24. sdg_hub/flows/generation/skills/{agentic_improve_skill.yaml → improve_responses.yaml} +26 -31
  25. sdg_hub/flows/generation/skills/synth_skills.yaml +4 -4
  26. sdg_hub/pipeline.py +67 -12
  27. sdg_hub/prompts.py +21 -0
  28. sdg_hub/sdg.py +128 -86
  29. sdg_hub/utils/config_validation.py +91 -0
  30. sdg_hub/utils/validation_result.py +10 -0
  31. sdg_hub-0.1.1.dist-info/METADATA +190 -0
  32. sdg_hub-0.1.1.dist-info/RECORD +86 -0
  33. {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/WHEEL +1 -1
  34. sdg_hub/blocks/filterblock.py +0 -76
  35. sdg_hub/blocks/iterblock.py +0 -31
  36. sdg_hub/blocks/rmblocks.py +0 -194
  37. sdg_hub/configs/annotations/simple.yaml +0 -10
  38. sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +0 -3
  39. sdg_hub/configs/skills/data_recipe/default_recipe.yaml +0 -6
  40. sdg_hub/flows/annotation/emotion/detailed_description.yaml +0 -19
  41. sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +0 -19
  42. sdg_hub/flows/annotation/emotion/simple.yaml +0 -19
  43. sdg_hub/utils/chunking.py +0 -73
  44. sdg_hub/utils/docprocessor.py +0 -357
  45. sdg_hub/utils/parse_and_convert.py +0 -392
  46. sdg_hub-0.1.0a4.dist-info/METADATA +0 -309
  47. sdg_hub-0.1.0a4.dist-info/RECORD +0 -90
  48. /sdg_hub/configs/{knowledge/data_recipe → reasoning}/__init__.py +0 -0
  49. /sdg_hub/configs/skills/{_G_.yaml → icl_examples/STEM.yaml} +0 -0
  50. /sdg_hub/configs/skills/{data_recipe → icl_examples}/__init__.py +0 -0
  51. /sdg_hub/configs/skills/{_A_.yaml → icl_examples/coding.yaml} +0 -0
  52. /sdg_hub/configs/skills/{_B_.yaml → icl_examples/extraction.yaml} +0 -0
  53. /sdg_hub/configs/skills/{_C_.yaml → icl_examples/humanities.yaml} +0 -0
  54. /sdg_hub/configs/skills/{_D_.yaml → icl_examples/math.yaml} +0 -0
  55. /sdg_hub/configs/skills/{_E_.yaml → icl_examples/reasoning.yaml} +0 -0
  56. /sdg_hub/configs/skills/{_F_.yaml → icl_examples/roleplay.yaml} +0 -0
  57. /sdg_hub/configs/skills/{_H_.yaml → icl_examples/writing.yaml} +0 -0
  58. {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/licenses/LICENSE +0 -0
  59. {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/top_level.txt +0 -0
sdg_hub/flow.py CHANGED
@@ -1,58 +1,145 @@
+ """
+ Flow module for managing data generation pipelines.
+
+ This module provides the core Flow class that handles both configuration loading and execution
+ of data generation blocks. The Flow class serves as the main interface for defining and running
+ data generation pipelines, supporting both direct usage with SDG and backward compatibility
+ through the deprecated Pipeline class.
+
+ Example:
+     >>> flow = Flow(llm_client)
+     >>> flow = flow.get_flow_from_file("path/to/flow.yaml")
+     >>> dataset = flow.generate(input_dataset)
+
+ Note:
+     This module is part of the SDG Hub package and is designed to work in conjunction
+     with the SDG class for distributed data generation.
+ """
+
  # SPDX-License-Identifier: Apache-2.0
  # Standard
  from abc import ABC
  from importlib import resources
- from typing import Optional
+ from typing import Any, Callable, Dict, List, Optional
  import operator
  import os

  # Third Party
+ from datasets import Dataset
+ from datasets.data_files import EmptyDatasetError
+ from jinja2 import Environment, meta
+ from rich.console import Console
+ from rich.table import Table
  import yaml

  # Local
+ from .blocks import *  # needed to register blocks
+ from .logger_config import setup_logger
+ from .prompts import *  # needed to register prompts
  from .registry import BlockRegistry, PromptRegistry
- from . import prompts
- from . import blocks
+ from .utils.config_validation import validate_prompt_config_schema
+ from .utils.validation_result import ValidationResult
+
+ logger = setup_logger(__name__)


- OPERATOR_MAP = {
+ OPERATOR_MAP: Dict[str, Callable] = {
      "operator.eq": operator.eq,
      "operator.ge": operator.ge,
+     "operator.le": operator.le,
+     "operator.gt": operator.gt,
+     "operator.lt": operator.lt,
+     "operator.ne": operator.ne,
      "operator.contains": operator.contains,
  }

- CONVERT_DTYPE_MAP = {
+ CONVERT_DTYPE_MAP: Dict[str, Callable] = {
      "float": float,
      "int": int,
  }


  class Flow(ABC):
+     """A class representing a data generation flow.
+
+     This class handles both configuration loading and execution of data generation
+     blocks. It can be used directly with SDG or through the deprecated Pipeline class.
+     """
+
      def __init__(
          self,
-         llm_client,
+         llm_client: Any,
          num_samples_to_generate: Optional[int] = None,
+         log_level: Optional[str] = None,
      ) -> None:
+         """
+         Initialize the Flow class.
+
+         Parameters
+         ----------
+         llm_client : Any
+             The LLM client to use for generation.
+         num_samples_to_generate : Optional[int], optional
+             Number of samples to generate, by default None
+         log_level : Optional[str], optional
+             Logging verbosity level, by default None
+
+         Attributes
+         ----------
+         llm_client : Any
+             The LLM client instance.
+         base_path : str
+             Base path for resource files.
+         registered_blocks : Dict[str, Any]
+             Registry of available blocks.
+         chained_blocks : Optional[List[Dict[str, Any]]]
+             List of block configurations.
+         num_samples_to_generate : Optional[int]
+             Number of samples to generate.
+         """
          self.llm_client = llm_client
-         self.num_samples_to_generate = num_samples_to_generate
          self.base_path = str(resources.files(__package__))
          self.registered_blocks = BlockRegistry.get_registry()
+         self.chained_blocks = None  # Will be set by get_flow_from_file
+         self.num_samples_to_generate = num_samples_to_generate

-     def _getFilePath(self, dirs, filename):
-         """
-         Find a named configuration file.
+         # Logging verbosity level
+         self.log_level = log_level or os.getenv("SDG_HUB_LOG_LEVEL", "normal").lower()
+         self.console = Console() if self.log_level in ["verbose", "debug"] else None

-         Files are checked in the following order
-         - absulute path is always used
-         - checked relative to the directories in "dirs"
-         - relative the the current directory
+     def _log_block_info(
+         self, index: int, total: int, name: str, ds: Dataset, stage: str
+     ) -> None:
+         if self.log_level in ["verbose", "debug"] and self.console:
+             table = Table(
+                 title=f"{stage} Block {index + 1}/{total}: {name}", show_header=True
+             )
+             table.add_column("Metric", style="cyan", no_wrap=True)
+             table.add_column("Value", style="magenta")
+             table.add_row("Rows", str(len(ds)))
+             table.add_row("Columns", ", ".join(ds.column_names))
+             self.console.print(table)
+
+     def _getFilePath(self, dirs: List[str], filename: str) -> str:
+         """Find a named configuration file.

-         Args:
-             dirs (list): Directories in which to search for "config_path"
-             config_path (str): The path to the configuration file.
+         Files are checked in the following order:
+         1. Absolute path is always used
+         2. Checked relative to the directories in "dirs"
+         3. Relative to the current directory

-         Returns:
-             Selected file path
+         Parameters
+         ----------
+         dirs : List[str]
+             Directories in which to search for the file.
+         filename : str
+             The path to the configuration file.
+
+         Returns
+         -------
+         str
+             Selected file path.
          """
          if os.path.isabs(filename):
              return filename
@@ -64,7 +151,175 @@ class Flow(ABC):
          # assume the path is relative to the current directory
          return filename

-     def get_flow_from_file(self, yaml_path: str) -> list:
+     def _drop_duplicates(self, dataset: Dataset, cols: List[str]) -> Dataset:
+         """Drop duplicates from the dataset based on the columns provided.
+
+         Parameters
+         ----------
+         dataset : Dataset
+             The input dataset.
+         cols : List[str]
+             Columns to consider for duplicate detection.
+
+         Returns
+         -------
+         Dataset
+             Dataset with duplicates removed.
+         """
+         df = dataset.to_pandas()
+         df = df.drop_duplicates(subset=cols).reset_index(drop=True)
+         return Dataset.from_pandas(df)
+
+     def generate(self, dataset: Dataset) -> Dataset:
+         """Generate the dataset by running the pipeline steps.
+
+         Parameters
+         ----------
+         dataset : Dataset
+             The input dataset to process.
+
+         Returns
+         -------
+         Dataset
+             The processed dataset.
+
+         Raises
+         ------
+         ValueError
+             If Flow has not been initialized with blocks.
+         EmptyDatasetError
+             If a block produces an empty dataset.
+         """
+         if self.chained_blocks is None:
+             raise ValueError(
+                 "Flow has not been initialized with blocks. "
+                 "Call get_flow_from_file() first. "
+                 "Or pass a list of blocks to the Flow constructor."
+             )
+
+         for i, block_prop in enumerate(self.chained_blocks):
+             block_type = block_prop["block_type"]
+             block_config = block_prop["block_config"]
+             drop_columns = block_prop.get("drop_columns", [])
+             gen_kwargs = block_prop.get("gen_kwargs", {})
+             drop_duplicates_cols = block_prop.get("drop_duplicates", False)
+             block = block_type(**block_config)
+
+             name = block_config.get("block_name", f"block_{i}")
+
+             # Logging: always show basic progress unless in quiet mode
+             if self.log_level in ["normal", "verbose", "debug"]:
+                 logger.info(
+                     f"🔄 Running block {i + 1}/{len(self.chained_blocks)}: {name}"
+                 )
+
+             # Log dataset shape before block (verbose/debug)
+             self._log_block_info(i, len(self.chained_blocks), name, dataset, "Input")
+
+             if self.log_level == "debug":
+                 logger.debug(f"Input dataset (truncated): {dataset}")
+
+             dataset = block.generate(dataset, **gen_kwargs)
+
+             if len(dataset) == 0:
+                 raise EmptyDatasetError(
+                     f"Pipeline stopped: "
+                     f"Empty dataset after running block: "
+                     f"{block_config['block_name']}"
+                 )
+
+             drop_columns_in_ds = [e for e in drop_columns if e in dataset.column_names]
+             if drop_columns:
+                 dataset = dataset.remove_columns(drop_columns_in_ds)
+
+             if drop_duplicates_cols:
+                 dataset = self._drop_duplicates(dataset, cols=drop_duplicates_cols)
+
+             # Log dataset shape after block (verbose/debug)
+             self._log_block_info(i, len(self.chained_blocks), name, dataset, "Output")
+
+             if self.log_level == "debug":
+                 logger.debug(f"Output dataset (truncated): {dataset}")
+
+         return dataset
+
+     def validate_config_files(self) -> "ValidationResult":
+         """
+         Validate all configuration file paths referenced in the flow blocks.
+
+         This method checks that all config files specified via `config_path` or `config_paths`
+         in each block:
+         - Exist on the filesystem
+         - Are readable by the current process
+         - Are valid YAML files (optional format check)
+
+         Returns
+         -------
+         ValidationResult
+             An object indicating whether all config files passed validation, along with a list
+             of error messages for any missing, unreadable, or invalid YAML files.
+
+         Notes
+         -----
+         This method is automatically called at the end of `get_flow_from_file()` to ensure
+         early detection of misconfigured blocks.
+         """
+         errors = []
+
+         def check_file(path: str, context: str):
+             if not os.path.isfile(path):
+                 errors.append(f"[{context}] File does not exist: {path}")
+             else:
+                 try:
+                     with open(path, "r", encoding="utf-8") as f:
+                         config_data = yaml.safe_load(f)
+                         _, validation_errors = validate_prompt_config_schema(config_data, path)
+
+                         if validation_errors:
+                             errors.extend(validation_errors)
+
+                 except PermissionError:
+                     errors.append(f"[{context}] File is not readable: {path}")
+                 except yaml.YAMLError as e:
+                     errors.append(f"[{context}] YAML load failed: {path} ({e})")
+
+         for i, block in enumerate(self.chained_blocks or []):
+             block_name = block["block_config"].get("block_name", f"block_{i}")
+
+             config_path = block["block_config"].get("config_path")
+             if config_path:
+                 check_file(config_path, f"{block_name}.config_path")
+
+             config_paths = block["block_config"].get("config_paths")
+             if isinstance(config_paths, list):
+                 for idx, path in enumerate(config_paths):
+                     check_file(path, f"{block_name}.config_paths[{idx}]")
+             elif isinstance(config_paths, dict):
+                 for key, path in config_paths.items():
+                     check_file(path, f"{block_name}.config_paths['{key}']")
+
+         return ValidationResult(valid=(len(errors) == 0), errors=errors)
+
+     def get_flow_from_file(self, yaml_path: str) -> "Flow":
+         """Load and initialize flow configuration from a YAML file.
+
+         Parameters
+         ----------
+         yaml_path : str
+             Path to the YAML configuration file.
+
+         Returns
+         -------
+         Flow
+             Self with initialized chained_blocks.
+
+         Raises
+         ------
+         FileNotFoundError
+             If the YAML file cannot be found.
+         KeyError
+             If a required block or prompt is not found in the registry.
+         """
          yaml_path_relative_to_base = os.path.join(self.base_path, yaml_path)
          if os.path.isfile(yaml_path_relative_to_base):
              yaml_path = yaml_path_relative_to_base
@@ -141,4 +396,79 @@ class Flow(ABC):
                          block["block_config"]["convert_dtype"]
                      ]

-         return flow
+         # Store the chained blocks and return self
+         self.chained_blocks = flow
+
+         # Validate config files
+         result = self.validate_config_files()
+         if not result.valid:
+             raise ValueError("Invalid config files:\n\n".join(result.errors))
+
+         return self
+
+     def validate_flow(self, dataset: Dataset) -> "ValidationResult":
+         """
+         Validate that all required dataset columns are present before executing the flow.
+
+         This includes:
+         - Columns referenced in Jinja templates for LLM blocks
+         - Columns required by specific utility blocks (e.g. filter_column, choice_col, etc.)
+
+         Parameters
+         ----------
+         dataset : Dataset
+             The input dataset to validate against.
+
+         Returns
+         -------
+         ValidationResult
+             Whether the dataset has all required columns, and which ones are missing.
+         """
+         errors = []
+         all_columns = set(dataset.column_names)
+
+         for i, block in enumerate(self.chained_blocks or []):
+             name = block["block_config"].get("block_name", f"block_{i}")
+             block_type = block["block_type"]
+             config = block["block_config"]
+
+             # LLM Block: parse Jinja vars
+             cls_name = block_type.__name__ if isinstance(block_type, type) else block_type.__class__.__name__
+             logger.info(f"Validating block: {name} ({cls_name})")
+             if "LLM" in cls_name:
+                 config_path = config.get("config_path")
+                 if config_path and os.path.isfile(config_path):
+                     with open(config_path, "r", encoding="utf-8") as f:
+                         content = f.read()
+                         env = Environment()
+                         ast = env.parse(content)
+                         vars_found = meta.find_undeclared_variables(ast)
+                         for var in vars_found:
+                             if var not in all_columns:
+                                 errors.append(f"[{name}] Missing column for prompt var: '{var}'")
+
+             # FilterByValueBlock
+             if "FilterByValueBlock" in str(block_type):
+                 col = config.get("filter_column")
+                 if col and col not in all_columns:
+                     errors.append(f"[{name}] Missing filter_column: '{col}'")
+
+             # SelectorBlock
+             if "SelectorBlock" in str(block_type):
+                 col = config.get("choice_col")
+                 if col and col not in all_columns:
+                     errors.append(f"[{name}] Missing choice_col: '{col}'")
+
+                 choice_map = config.get("choice_map", {})
+                 for col in choice_map.values():
+                     if col not in all_columns:
+                         errors.append(f"[{name}] choice_map references missing column: '{col}'")
+
+             # CombineColumnsBlock
+             if "CombineColumnsBlock" in str(block_type):
+                 cols = config.get("columns", [])
+                 for col in cols:
+                     if col not in all_columns:
+                         errors.append(f"[{name}] CombineColumnsBlock requires column: '{col}'")
+
+         return ValidationResult(valid=(len(errors) == 0), errors=errors)
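
The reworked Flow API above is self-chaining: get_flow_from_file() now stores the parsed blocks on the Flow, validates every referenced config file, and returns self, while validate_flow() checks an input dataset for the columns the blocks will need. A minimal usage sketch, assuming an OpenAI-compatible endpoint; the endpoint URL, flow path, and dataset columns below are placeholders, not values taken from this diff:

    # Illustrative sketch only: endpoint, flow path, and dataset columns are placeholders.
    from datasets import Dataset
    from openai import OpenAI

    from sdg_hub.flow import Flow

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    flow = Flow(client, log_level="verbose").get_flow_from_file(
        "flows/generation/skills/synth_skills.yaml"
    )

    seed_data = Dataset.from_list(
        [{"task_description": "...", "seed_question": "...", "seed_response": "..."}]
    )

    # Fail fast if the dataset is missing columns referenced by the flow's prompts.
    result = flow.validate_flow(seed_data)
    if not result.valid:
        raise ValueError("\n".join(result.errors))

    generated = flow.generate(seed_data)

Since get_flow_from_file() joins relative paths onto the package's base_path, flows shipped inside the wheel (such as the skills flow referenced above) can be loaded without an absolute path.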
sdg_hub/flow_runner.py ADDED
@@ -0,0 +1,216 @@
+ """Script for running data generation flows with configurable parameters."""
+
+ # Standard
+ import os
+
+ # Third Party
+ from datasets import load_dataset
+ from openai import OpenAI
+ import click
+
+ # First Party
+ from sdg_hub.flow import Flow
+ from sdg_hub.logger_config import setup_logger
+ from sdg_hub.sdg import SDG
+
+
+ logger = setup_logger(__name__)
+
+
+ def run_flow(
+     ds_path: str,
+     save_path: str,
+     endpoint: str,
+     flow_path: str,
+     checkpoint_dir: str,
+     batch_size: int = 8,
+     num_workers: int = 32,
+     save_freq: int = 2,
+     debug: bool = False,
+     dataset_start_index: int = 0,
+     dataset_end_index: int = None,
+ ) -> None:
+     """Process the dataset using the specified configuration.
+
+     Parameters
+     ----------
+     ds_path : str
+         Path to the dataset file.
+     save_path : str
+         Path where the output will be saved.
+     endpoint : str
+         API endpoint for data processing.
+     flow_path : str
+         Path to the flow configuration file.
+     checkpoint_dir : str
+         Directory path for saving checkpoints.
+     batch_size : int, optional
+         Batch size for processing, by default 8.
+     num_workers : int, optional
+         Number of worker processes to use, by default 32.
+     save_freq : int, optional
+         Frequency (in batches) at which to save checkpoints, by default 2.
+     debug : bool, optional
+         If True, enables debug mode with a smaller dataset subset, by default False.
+
+     Returns
+     -------
+     None
+
+     Raises
+     ------
+     FileNotFoundError
+         If the flow configuration file is not found.
+     """
+     logger.info(f"Generation configuration: {locals()}\n\n")
+     ds = load_dataset("json", data_files=ds_path, split="train")
+     if dataset_start_index is not None and dataset_end_index is not None:
+         ds = ds.select(range(dataset_start_index, dataset_end_index))
+         logger.info(f"Dataset sliced from {dataset_start_index} to {dataset_end_index}")
+     if debug:
+         ds = ds.shuffle(seed=42).select(range(30))
+         logger.info("Debug mode enabled. Using a subset of the dataset.")
+
+     openai_api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
+     openai_api_base = endpoint
+
+     client = OpenAI(
+         api_key=openai_api_key,
+         base_url=openai_api_base,
+     )
+
+     if not os.path.exists(flow_path):
+         raise FileNotFoundError(f"Flow file not found: {flow_path}")
+
+     flow = Flow(client).get_flow_from_file(flow_path)
+     sdg = SDG(
+         flows=[flow],
+         num_workers=num_workers,
+         batch_size=batch_size,
+         save_freq=save_freq,
+     )
+
+     generated_data = sdg.generate(ds, checkpoint_dir=checkpoint_dir)
+     if dataset_end_index is not None and dataset_start_index is not None:
+         save_path = save_path.replace(".jsonl", f"_{dataset_start_index}_{dataset_end_index}.jsonl")
+     generated_data.to_json(save_path, orient="records", lines=True)
+     logger.info(f"Data saved to {save_path}")
+
+
+ @click.command()
+ @click.option(
+     "--ds_path",
+     type=click.Path(exists=True),
+     required=True,
+     help="Path to the dataset.",
+ )
+ @click.option(
+     "--bs",
+     type=int,
+     default=8,
+     show_default=True,
+     help="Batch size for processing.",
+ )
+ @click.option(
+     "--num_workers",
+     type=int,
+     default=32,
+     show_default=True,
+     help="Number of worker processes to use.",
+ )
+ @click.option(
+     "--save_path",
+     type=click.Path(),
+     required=True,
+     help="Path to save the output.",
+ )
+ @click.option(
+     "--endpoint",
+     type=str,
+     required=True,
+     help="API endpoint for data processing.",
+ )
+ @click.option(
+     "--flow",
+     type=click.Path(exists=True),
+     required=True,
+     help="Flow configuration for the process.",
+ )
+ @click.option(
+     "--checkpoint_dir",
+     type=click.Path(),
+     required=True,
+     help="Path to save checkpoints.",
+ )
+ @click.option(
+     "--save_freq",
+     type=int,
+     default=2,
+     show_default=True,
+     help="Frequency to save checkpoints.",
+ )
+ @click.option(
+     "--debug",
+     is_flag=True,
+     help="Enable debug mode with a smaller dataset subset.",
+ )
+ @click.option("--dataset_start_index", type=int, default=0, help="Start index of the dataset.")
+ @click.option("--dataset_end_index", type=int, default=None, help="End index of the dataset.")
+ def main(
+     ds_path: str,
+     bs: int,
+     num_workers: int,
+     save_path: str,
+     endpoint: str,
+     flow: str,
+     checkpoint_dir: str,
+     save_freq: int,
+     debug: bool,
+     dataset_start_index: int,
+     dataset_end_index: int,
+ ) -> None:
+     """CLI entry point for running data generation flows.
+
+     Parameters
+     ----------
+     ds_path : str
+         Path to the dataset file.
+     bs : int
+         Batch size for processing.
+     num_workers : int
+         Number of worker processes to use.
+     save_path : str
+         Path where the output will be saved.
+     endpoint : str
+         API endpoint for data processing.
+     flow : str
+         Path to the flow configuration file.
+     checkpoint_dir : str
+         Directory path for saving checkpoints.
+     save_freq : int
+         Frequency (in batches) at which to save checkpoints.
+     debug : bool
+         If True, enables debug mode with a smaller dataset subset.
+
+     Returns
+     -------
+     None
+     """
+     run_flow(
+         ds_path=ds_path,
+         batch_size=bs,
+         num_workers=num_workers,
+         save_path=save_path,
+         endpoint=endpoint,
+         flow_path=flow,
+         checkpoint_dir=checkpoint_dir,
+         save_freq=save_freq,
+         debug=debug,
+         dataset_start_index=dataset_start_index,
+         dataset_end_index=dataset_end_index,
+     )
+
+
+ if __name__ == "__main__":
+     # pylint: disable=no-value-for-parameter
+     main()
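
The new flow_runner wires the pieces together: it loads a JSON/JSONL dataset, builds an OpenAI client against the given endpoint, loads the flow, and hands everything to SDG for checkpointed generation. It can be driven through the click CLI above (--ds_path, --bs, --num_workers, --save_path, --endpoint, --flow, --checkpoint_dir, --save_freq, --debug) or by calling run_flow() directly. A small sketch of the programmatic path; every path and the endpoint below are placeholders:

    # Illustrative sketch only: all paths and the endpoint are placeholders.
    from sdg_hub.flow_runner import run_flow

    run_flow(
        ds_path="seed_data.jsonl",
        save_path="generated.jsonl",
        endpoint="http://localhost:8000/v1",
        flow_path="flows/generation/knowledge/synth_knowledge1.5.yaml",
        checkpoint_dir="checkpoints/",
        batch_size=8,
        num_workers=32,
        save_freq=2,
        debug=False,
    )

With the default dataset_start_index=0 and dataset_end_index=None, no slicing is applied and the full seed dataset is processed.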
sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml CHANGED
@@ -12,7 +12,10 @@
      output_cols:
        - summary_detailed
    gen_kwargs:
-     max_tokens: 2048
+     max_tokens: 4096
+     temperature: 0.7
+     seed: 7452
+     n: 50

  - block_type: LLMBlock
    block_config:
@@ -22,7 +25,9 @@
      output_cols:
        - summary_atomic_facts
    gen_kwargs:
-     max_tokens: 2048
+     max_tokens: 4096
+     temperature: 0.7
+     seed: 7452

  - block_type: LLMBlock
    block_config:
@@ -32,7 +37,9 @@
      output_cols:
        - summary_extractive
    gen_kwargs:
-     max_tokens: 2048
+     max_tokens: 4096
+     temperature: 0.7
+     seed: 7452

  - block_type: FlattenColumnsBlock
    block_config:
@@ -55,19 +62,29 @@
  - block_type: LLMBlock
    block_config:
      block_name: knowledge generation
-     config_path: configs/knowledge/generate_questions_responses.yaml
+     config_path: configs/knowledge/generate_questions.yaml
      model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
      output_cols:
        - question
-       - response
      parser_kwargs:
        parser_name: custom
-       parsing_pattern: "\\[(?:Question|QUESTION)\\]\\s*(.*?)\\s*\\[(?:Answer|ANSWER)\\]\\s*(.*?)\\s*(?=\\[(?:Question|QUESTION)\\]|$)"
-       parser_cleanup_tags:
-         - "[END]"
+       parsing_pattern: "\\[(?:Question|QUESTION)\\]\\s*(.*?)\\s*(?=\\[(?:Question|QUESTION)\\]|$)"
+   gen_kwargs:
+     temperature: 0.7
+     max_tokens: 100
+     seed: 7452
+
+ - block_type: LLMBlock
+   block_config:
+     block_name: knowledge generation
+     config_path: configs/knowledge/generate_responses.yaml
+     model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
+     output_cols:
+       - response
    gen_kwargs:
-     temperature: 0.0
+     temperature: 0.7
      max_tokens: 2048
+     seed: 7452

  - block_type: LLMBlock
    block_config:
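
In the updated knowledge flow above, question generation and response generation are split across two LLMBlocks, and the custom parser now only extracts questions. A quick sketch of what the new parsing_pattern does with a raw completion; the sample text below is invented for illustration, and whether the block applies re.DOTALL (or other flags) internally is not visible in this diff:

    # Illustrative sketch only: the sample completion text is made up.
    import re

    parsing_pattern = r"\[(?:Question|QUESTION)\]\s*(.*?)\s*(?=\[(?:Question|QUESTION)\]|$)"

    raw_output = (
        "[QUESTION] What is the capital of France?\n"
        "[QUESTION] When was the Eiffel Tower completed?\n"
    )

    # One capture group per [QUESTION] tag, whitespace trimmed by the pattern itself.
    questions = re.findall(parsing_pattern, raw_output, flags=re.DOTALL)
    print(questions)
    # ['What is the capital of France?', 'When was the Eiffel Tower completed?']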