sdg-hub 0.1.0a4__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. sdg_hub/_version.py +2 -2
  2. sdg_hub/blocks/__init__.py +35 -5
  3. sdg_hub/blocks/block.py +58 -16
  4. sdg_hub/blocks/llmblock.py +121 -193
  5. sdg_hub/blocks/utilblocks.py +500 -43
  6. sdg_hub/checkpointer.py +139 -0
  7. sdg_hub/configs/annotations/detailed_annotations.yaml +28 -0
  8. sdg_hub/configs/annotations/simple_annotations.yaml +9 -0
  9. sdg_hub/configs/knowledge/atomic_facts.yaml +1 -0
  10. sdg_hub/configs/knowledge/detailed_summary.yaml +1 -0
  11. sdg_hub/configs/knowledge/extractive_summary.yaml +1 -0
  12. sdg_hub/configs/knowledge/generate_questions.yaml +82 -0
  13. sdg_hub/configs/knowledge/generate_responses.yaml +86 -0
  14. sdg_hub/configs/skills/contexts.yaml +18 -11
  15. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +79 -12
  16. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +60 -28
  17. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +95 -30
  18. sdg_hub/configs/skills/freeform_questions.yaml +21 -16
  19. sdg_hub/configs/skills/freeform_responses.yaml +19 -25
  20. sdg_hub/configs/skills/router.yaml +53 -6
  21. sdg_hub/flow.py +351 -21
  22. sdg_hub/flow_runner.py +216 -0
  23. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +26 -9
  24. sdg_hub/flows/generation/skills/{agentic_improve_skill.yaml → improve_responses.yaml} +26 -31
  25. sdg_hub/flows/generation/skills/synth_skills.yaml +4 -4
  26. sdg_hub/pipeline.py +67 -12
  27. sdg_hub/prompts.py +21 -0
  28. sdg_hub/sdg.py +128 -86
  29. sdg_hub/utils/config_validation.py +91 -0
  30. sdg_hub/utils/validation_result.py +10 -0
  31. sdg_hub-0.1.1.dist-info/METADATA +190 -0
  32. sdg_hub-0.1.1.dist-info/RECORD +86 -0
  33. {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/WHEEL +1 -1
  34. sdg_hub/blocks/filterblock.py +0 -76
  35. sdg_hub/blocks/iterblock.py +0 -31
  36. sdg_hub/blocks/rmblocks.py +0 -194
  37. sdg_hub/configs/annotations/simple.yaml +0 -10
  38. sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +0 -3
  39. sdg_hub/configs/skills/data_recipe/default_recipe.yaml +0 -6
  40. sdg_hub/flows/annotation/emotion/detailed_description.yaml +0 -19
  41. sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +0 -19
  42. sdg_hub/flows/annotation/emotion/simple.yaml +0 -19
  43. sdg_hub/utils/chunking.py +0 -73
  44. sdg_hub/utils/docprocessor.py +0 -357
  45. sdg_hub/utils/parse_and_convert.py +0 -392
  46. sdg_hub-0.1.0a4.dist-info/METADATA +0 -309
  47. sdg_hub-0.1.0a4.dist-info/RECORD +0 -90
  48. /sdg_hub/configs/{knowledge/data_recipe → reasoning}/__init__.py +0 -0
  49. /sdg_hub/configs/skills/{_G_.yaml → icl_examples/STEM.yaml} +0 -0
  50. /sdg_hub/configs/skills/{data_recipe → icl_examples}/__init__.py +0 -0
  51. /sdg_hub/configs/skills/{_A_.yaml → icl_examples/coding.yaml} +0 -0
  52. /sdg_hub/configs/skills/{_B_.yaml → icl_examples/extraction.yaml} +0 -0
  53. /sdg_hub/configs/skills/{_C_.yaml → icl_examples/humanities.yaml} +0 -0
  54. /sdg_hub/configs/skills/{_D_.yaml → icl_examples/math.yaml} +0 -0
  55. /sdg_hub/configs/skills/{_E_.yaml → icl_examples/reasoning.yaml} +0 -0
  56. /sdg_hub/configs/skills/{_F_.yaml → icl_examples/roleplay.yaml} +0 -0
  57. /sdg_hub/configs/skills/{_H_.yaml → icl_examples/writing.yaml} +0 -0
  58. {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/licenses/LICENSE +0 -0
  59. {sdg_hub-0.1.0a4.dist-info → sdg_hub-0.1.1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,14 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
+ """Utility blocks for dataset manipulation and transformation.
3
+
4
+ This module provides various utility blocks for operations like column manipulation,
5
+ data population, selection, and transformation of datasets.
6
+ """
7
+
8
+ # Standard
9
+ import operator
10
+ from typing import Any, Callable, Dict, List, Optional, Type, Union
11
+
2
12
  # Third Party
3
13
  from datasets import Dataset
4
14
 
@@ -10,12 +20,156 @@ from ..logger_config import setup_logger
10
20
  logger = setup_logger(__name__)
11
21
 
12
22
 
23
@BlockRegistry.register("FilterByValueBlock")
class FilterByValueBlock(Block):
    """A block for filtering datasets based on column values.

    This block allows filtering of datasets using various operations
    (e.g., equals, contains) on specified column values, with optional
    data type conversion.
    """

    def __init__(
        self,
        block_name: str,
        filter_column: str,
        filter_value: Union[Any, List[Any]],
        operation: Callable[[Any, Any], bool],
        convert_dtype: Optional[Union[Type[float], Type[int]]] = None,
        **batch_kwargs: Dict[str, Any],
    ) -> None:
        """Initialize a new FilterByValueBlock instance.

        Parameters
        ----------
        block_name : str
            Name of the block.
        filter_column : str
            The name of the column in the dataset to apply the filter on.
        filter_value : Union[Any, List[Any]]
            The value(s) to filter by. A scalar is normalized to a
            one-element list.
        operation : Callable[[Any, Any], bool]
            A binary operator from the operator module (e.g., operator.eq,
            operator.contains) that takes two arguments and returns a boolean.
        convert_dtype : Optional[Union[Type[float], Type[int]]], optional
            Type to convert the filter column to. Can be either float or int.
            If None, no conversion is performed.
        **batch_kwargs : Dict[str, Any]
            Additional kwargs for batch processing; only ``num_procs``
            (default 1) is read.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the operation is not from the operator module.
        """
        super().__init__(block_name=block_name)
        # Validate that operation is from operator module; functions defined
        # there report their module as "_operator" (the C implementation).
        if operation.__module__ != "_operator":
            logger.error("Invalid operation: %s", operation)
            raise ValueError("Operation must be from operator module")

        self.value = filter_value if isinstance(filter_value, list) else [filter_value]
        self.column_name = filter_column
        self.operation = operation
        self.convert_dtype = convert_dtype
        self.num_procs = batch_kwargs.get("num_procs", 1)

    def _convert_dtype(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the data type of the filter column.

        Values that cannot be converted are replaced with None so the
        null-filter in :meth:`generate` drops them later.

        Parameters
        ----------
        sample : Dict[str, Any]
            The sample dictionary containing the column to convert.

        Returns
        -------
        Dict[str, Any]
            The sample with converted column value.
        """
        try:
            sample[self.column_name] = self.convert_dtype(sample[self.column_name])
        except (ValueError, TypeError) as e:
            # TypeError is caught in addition to ValueError: int()/float()
            # raise TypeError (not ValueError) for None and other
            # non-string/non-numeric values, and an uncaught exception here
            # would abort the whole .map() call.
            logger.error(
                "Error converting dtype: %s, filling with None to be filtered later", e
            )
            sample[self.column_name] = None
        return sample

    def generate(self, samples: Dataset) -> Dataset:
        """Generate filtered dataset based on specified conditions.

        Parameters
        ----------
        samples : Dataset
            The input dataset to filter.

        Returns
        -------
        Dataset
            The filtered dataset.
        """
        if self.convert_dtype:
            samples = samples.map(
                self._convert_dtype,
                num_proc=self.num_procs,
            )

        if self.operation == operator.contains:
            # Membership pass: keep rows whose column value appears in the
            # filter-value list. NOTE(review): the generic filter below is
            # still applied afterwards, running contains(column_value, value)
            # as a substring-style test on the surviving rows — confirm both
            # passes are intended for the contains operation.
            samples = samples.filter(
                lambda x: self.operation(self.value, x[self.column_name]),
                num_proc=self.num_procs,
            )

        # Drop rows whose value is missing (including failed dtype conversions).
        samples = samples.filter(
            lambda x: x[self.column_name] is not None,
            num_proc=self.num_procs,
        )

        # Keep rows where the operation matches any of the filter values.
        samples = samples.filter(
            lambda x: any(
                self.operation(x[self.column_name], value) for value in self.value
            ),
            num_proc=self.num_procs,
        )

        return samples
140
+
141
+
13
142
  @BlockRegistry.register("SamplePopulatorBlock")
14
143
  class SamplePopulatorBlock(Block):
15
- def __init__(self, config_paths, column_name, post_fix="", **batch_kwargs) -> None:
16
- super().__init__(
17
- block_name=self.__class__.__name__
18
- ) # Call the base class's __init__
144
+ """Block for populating dataset with data from configuration files.
145
+
146
+ This block reads data from one or more configuration files and populates a
147
+ dataset with the data. The data is stored in a dictionary, with the keys
148
+ being the names of the configuration files.
149
+
150
+ Parameters
151
+ ----------
152
+ block_name : str
153
+ Name of the block.
154
+ config_paths : List[str]
155
+ List of paths to configuration files to load.
156
+ column_name : str
157
+ Name of the column to use as key for populating data.
158
+ post_fix : str, optional
159
+ Suffix to append to configuration filenames, by default "".
160
+ **batch_kwargs : Dict[str, Any]
161
+ Additional keyword arguments for batch processing.
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ block_name: str,
167
+ config_paths: List[str],
168
+ column_name: str,
169
+ post_fix: str = "",
170
+ **batch_kwargs: Dict[str, Any],
171
+ ) -> None:
172
+ super().__init__(block_name=block_name)
19
173
  self.configs = {}
20
174
  for config in config_paths:
21
175
  if post_fix:
@@ -27,114 +181,417 @@ class SamplePopulatorBlock(Block):
27
181
  self.column_name = column_name
28
182
  self.num_procs = batch_kwargs.get("num_procs", 8)
29
183
 
30
- def _generate(self, sample) -> dict:
184
+ def _generate(self, sample: Dict[str, Any]) -> Dict[str, Any]:
185
+ """Generate a new sample by populating it with configuration data.
186
+
187
+ Parameters
188
+ ----------
189
+ sample : Dict[str, Any]
190
+ Input sample to populate with configuration data.
191
+
192
+ Returns
193
+ -------
194
+ Dict[str, Any]
195
+ Sample populated with configuration data.
196
+ """
31
197
  sample = {**sample, **self.configs[sample[self.column_name]]}
32
198
  return sample
33
199
 
34
- def generate(self, samples) -> Dataset:
200
+ def generate(self, samples: Dataset) -> Dataset:
201
+ """Generate a new dataset with populated configuration data.
202
+
203
+ Parameters
204
+ ----------
205
+ samples : Dataset
206
+ Input dataset to populate with configuration data.
207
+
208
+ Returns
209
+ -------
210
+ Dataset
211
+ Dataset populated with configuration data.
212
+ """
35
213
  samples = samples.map(self._generate, num_proc=self.num_procs)
36
214
  return samples
37
215
 
38
216
 
39
217
@BlockRegistry.register("SelectorBlock")
class SelectorBlock(Block):
    """Block for selecting and mapping values from one column to another.

    This block uses a mapping dictionary to select values from one column and
    store them in a new output column based on a choice column's value.

    Parameters
    ----------
    block_name : str
        Name of the block.
    choice_map : Dict[str, str]
        Dictionary mapping choice values to column names.
    choice_col : str
        Name of the column containing choice values.
    output_col : str
        Name of the column to store selected values.
    **batch_kwargs : Dict[str, Any]
        Additional keyword arguments for batch processing; only
        ``num_procs`` (default 8) is read.
    """

    def __init__(
        self,
        block_name: str,
        choice_map: Dict[str, str],
        choice_col: str,
        output_col: str,
        **batch_kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(block_name=block_name)
        self.choice_map = choice_map
        self.choice_col = choice_col
        self.output_col = output_col
        self.num_procs = batch_kwargs.get("num_procs", 8)

    def _generate(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Copy the chosen source column's value into the output column.

        Parameters
        ----------
        sample : Dict[str, Any]
            Input sample to process.

        Returns
        -------
        Dict[str, Any]
            Sample with the selected value stored in the output column.
        """
        # The choice column's value picks a source column via choice_map;
        # that source column's value becomes the output.
        source_column = self.choice_map[sample[self.choice_col]]
        sample[self.output_col] = sample[source_column]
        return sample

    def generate(self, samples: Dataset) -> Dataset:
        """Generate a new dataset with selected values.

        Parameters
        ----------
        samples : Dataset
            Input dataset to process.

        Returns
        -------
        Dataset
            Dataset with selected values stored in the output column.
        """
        return samples.map(self._generate, num_proc=self.num_procs)
55
283
 
56
284
 
57
285
@BlockRegistry.register("CombineColumnsBlock")
class CombineColumnsBlock(Block):
    r"""Block for combining multiple columns into a single column.

    This block concatenates values from multiple columns into a single output
    column, using a specified separator between values.

    Parameters
    ----------
    block_name : str
        Name of the block.
    columns : List[str]
        List of column names to combine.
    output_col : str
        Name of the column to store combined values.
    separator : str, optional
        String to use as separator between combined values, by default "\n\n".
    **batch_kwargs : Dict[str, Any]
        Additional keyword arguments for batch processing; only
        ``num_procs`` (default 8) is read.
    """

    def __init__(
        self,
        block_name: str,
        columns: List[str],
        output_col: str,
        separator: str = "\n\n",
        **batch_kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(block_name=block_name)
        self.columns = columns
        self.output_col = output_col
        self.separator = separator
        self.num_procs = batch_kwargs.get("num_procs", 8)

    def _generate(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Join the configured columns of one sample into the output column.

        Parameters
        ----------
        sample : Dict[str, Any]
            Input sample to process.

        Returns
        -------
        Dict[str, Any]
            Sample with combined values stored in the output column.
        """
        # Coerce each value through str() so non-string columns can be joined.
        parts = [str(sample[column]) for column in self.columns]
        sample[self.output_col] = self.separator.join(parts)
        return sample

    def generate(self, samples: Dataset) -> Dataset:
        """Generate a new dataset with combined columns.

        Parameters
        ----------
        samples : Dataset
            Input dataset to process.

        Returns
        -------
        Dataset
            Dataset with combined values stored in the output column.
        """
        return samples.map(self._generate, num_proc=self.num_procs)
75
353
 
76
354
 
77
355
@BlockRegistry.register("FlattenColumnsBlock")
class FlattenColumnsBlock(Block):
    """Block for flattening multiple columns into a long format.

    This block transforms a wide dataset format into a long format by melting
    specified columns into rows, creating new variable and value columns.

    Parameters
    ----------
    block_name : str
        Name of the block.
    var_cols : List[str]
        List of column names to be melted into rows.
    value_name : str
        Name of the new column that will contain the values.
    var_name : str
        Name of the new column that will contain the variable names.
    """

    def __init__(
        self,
        block_name: str,
        var_cols: List[str],
        value_name: str,
        var_name: str,
    ) -> None:
        super().__init__(block_name=block_name)
        self.var_cols = var_cols
        self.value_name = value_name
        self.var_name = var_name

    def generate(self, samples: Dataset) -> Dataset:
        """Generate a flattened dataset in long format.

        Parameters
        ----------
        samples : Dataset
            Input dataset to flatten.

        Returns
        -------
        Dataset
            Flattened dataset in long format with new variable and value
            columns.
        """
        df = samples.to_pandas()
        # Every column not being melted serves as an identifier column.
        id_cols = [c for c in samples.column_names if c not in self.var_cols]
        melted = df.melt(
            id_vars=id_cols,
            value_vars=self.var_cols,
            value_name=self.value_name,
            var_name=self.var_name,
        )
        return Dataset.from_pandas(melted)
94
408
 
95
409
 
96
410
@BlockRegistry.register("DuplicateColumns")
class DuplicateColumns(Block):
    """Block for duplicating existing columns with new names.

    This block creates copies of existing columns with new names as specified
    in the columns mapping dictionary.

    Parameters
    ----------
    block_name : str
        Name of the block.
    columns_map : Dict[str, str]
        Dictionary mapping existing column names to new column names.
        Keys are existing column names, values are new column names.
    """

    def __init__(
        self,
        block_name: str,
        columns_map: Dict[str, str],
    ) -> None:
        super().__init__(block_name=block_name)
        self.columns_map = columns_map

    def generate(self, samples: Dataset) -> Dataset:
        """Generate a dataset with duplicated columns.

        Parameters
        ----------
        samples : Dataset
            Input dataset to duplicate columns from.

        Returns
        -------
        Dataset
            Dataset with additional duplicated columns.
        """
        # add_column returns a new Dataset, so rebind on every iteration.
        for source_col, new_col in self.columns_map.items():
            samples = samples.add_column(new_col, samples[source_col])
        return samples
112
452
 
113
453
 
114
454
@BlockRegistry.register("RenameColumns")
class RenameColumns(Block):
    """Block for renaming columns in a dataset.

    This block renames columns in a dataset according to a mapping dictionary,
    where keys are existing column names and values are new column names.

    Parameters
    ----------
    block_name : str
        Name of the block.
    columns_map : Dict[str, str]
        Dictionary mapping existing column names to new column names.
        Keys are existing column names, values are new column names.
    """

    def __init__(
        self,
        block_name: str,
        columns_map: Dict[str, str],
    ) -> None:
        super().__init__(block_name=block_name)
        self.columns_map = columns_map

    def generate(self, samples: Dataset) -> Dataset:
        """Generate a dataset with renamed columns.

        Parameters
        ----------
        samples : Dataset
            Input dataset to rename columns in.

        Returns
        -------
        Dataset
            Dataset with renamed columns.
        """
        # Dataset.rename_columns accepts the old-name -> new-name mapping
        # directly and returns a new Dataset.
        return samples.rename_columns(self.columns_map)
129
493
 
130
494
 
131
495
@BlockRegistry.register("SetToMajorityValue")
class SetToMajorityValue(Block):
    """Block for setting all values in a column to the most frequent value.

    This block finds the most common value (mode) in a specified column and
    replaces all values in that column with this majority value.

    Parameters
    ----------
    block_name : str
        Name of the block.
    col_name : str
        Name of the column to set to majority value.
    """

    def __init__(
        self,
        block_name: str,
        col_name: str,
    ) -> None:
        super().__init__(block_name=block_name)
        self.col_name = col_name

    def generate(self, samples: Dataset) -> Dataset:
        """Generate a dataset with the column set to its majority value.

        Parameters
        ----------
        samples : Dataset
            Input dataset to process.

        Returns
        -------
        Dataset
            Dataset with the specified column set to its majority value.
        """
        df = samples.to_pandas()
        # pandas mode() returns the modal values sorted; [0] picks the
        # first (smallest) when there are ties.
        majority = df[self.col_name].mode()[0]
        df[self.col_name] = majority
        return Dataset.from_pandas(df)
534
+
535
+
536
@BlockRegistry.register("IterBlock")
class IterBlock(Block):
    """Block for iteratively applying another block multiple times.

    This block takes another block type and applies it repeatedly to generate
    multiple samples from the input dataset.

    Parameters
    ----------
    block_name : str
        Name of the block.
    num_iters : int
        Number of times to apply the block.
    block_type : Type[Block]
        The block class to instantiate and apply.
    block_kwargs : Dict[str, Any]
        Keyword arguments to pass to the block constructor.
    **kwargs : Dict[str, Any]
        Additional keyword arguments. Supports:
        - gen_kwargs: Dict[str, Any]
            Arguments to pass to the block's generate method.
    """

    def __init__(
        self,
        block_name: str,
        num_iters: int,
        block_type: Type[Block],
        block_kwargs: Dict[str, Any],
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(block_name)
        self.num_iters = num_iters
        # Instantiate the wrapped block once; it is reused on every iteration.
        self.block = block_type(**block_kwargs)
        self.gen_kwargs = kwargs.get("gen_kwargs", {})

    def generate(self, samples: Dataset, **gen_kwargs: Dict[str, Any]) -> Dataset:
        """Generate multiple samples by iteratively applying the block.

        Parameters
        ----------
        samples : Dataset
            Input dataset to process.
        **gen_kwargs : Dict[str, Any]
            Additional keyword arguments to pass to the block's generate
            method. These are merged with (and override) the gen_kwargs
            provided at initialization.

        Returns
        -------
        Dataset
            Dataset containing all generated samples from all iterations.
        """
        # Call-time kwargs take precedence over constructor-time ones.
        merged_kwargs = {**self.gen_kwargs, **gen_kwargs}
        collected: List[Dict[str, Any]] = []
        for _ in range(self.num_iters):
            collected.extend(self.block.generate(samples, **merged_kwargs))
        return Dataset.from_list(collected)