sdg-hub 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139) hide show
  1. sdg_hub/__init__.py +28 -1
  2. sdg_hub/_version.py +2 -2
  3. sdg_hub/core/__init__.py +22 -0
  4. sdg_hub/core/blocks/__init__.py +58 -0
  5. sdg_hub/core/blocks/base.py +313 -0
  6. sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
  7. sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
  8. sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
  9. sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
  10. sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
  11. sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
  12. sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
  13. sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
  14. sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
  15. sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
  16. sdg_hub/core/blocks/evaluation/__init__.py +9 -0
  17. sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
  18. sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
  19. sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
  20. sdg_hub/core/blocks/filtering/__init__.py +12 -0
  21. sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
  22. sdg_hub/core/blocks/llm/__init__.py +25 -0
  23. sdg_hub/core/blocks/llm/client_manager.py +398 -0
  24. sdg_hub/core/blocks/llm/config.py +336 -0
  25. sdg_hub/core/blocks/llm/error_handler.py +368 -0
  26. sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
  27. sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
  28. sdg_hub/core/blocks/llm/text_parser_block.py +310 -0
  29. sdg_hub/core/blocks/registry.py +331 -0
  30. sdg_hub/core/blocks/transform/__init__.py +23 -0
  31. sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
  32. sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
  33. sdg_hub/core/blocks/transform/melt_columns.py +126 -0
  34. sdg_hub/core/blocks/transform/rename_columns.py +69 -0
  35. sdg_hub/core/blocks/transform/text_concat.py +102 -0
  36. sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
  37. sdg_hub/core/flow/__init__.py +20 -0
  38. sdg_hub/core/flow/base.py +980 -0
  39. sdg_hub/core/flow/metadata.py +344 -0
  40. sdg_hub/core/flow/migration.py +187 -0
  41. sdg_hub/core/flow/registry.py +330 -0
  42. sdg_hub/core/flow/validation.py +265 -0
  43. sdg_hub/{utils → core/utils}/__init__.py +6 -4
  44. sdg_hub/{utils → core/utils}/datautils.py +1 -3
  45. sdg_hub/core/utils/error_handling.py +208 -0
  46. sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
  47. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
  48. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
  49. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
  50. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
  51. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
  52. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
  53. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +191 -0
  54. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
  55. sdg_hub-0.2.0.dist-info/METADATA +218 -0
  56. sdg_hub-0.2.0.dist-info/RECORD +63 -0
  57. sdg_hub/blocks/__init__.py +0 -42
  58. sdg_hub/blocks/block.py +0 -96
  59. sdg_hub/blocks/llmblock.py +0 -375
  60. sdg_hub/blocks/openaichatblock.py +0 -556
  61. sdg_hub/blocks/utilblocks.py +0 -597
  62. sdg_hub/checkpointer.py +0 -139
  63. sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
  64. sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
  65. sdg_hub/configs/annotations/detailed_description.yaml +0 -10
  66. sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
  67. sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
  68. sdg_hub/configs/knowledge/__init__.py +0 -0
  69. sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
  70. sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
  71. sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
  72. sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
  73. sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
  74. sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
  75. sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
  76. sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
  77. sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
  78. sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
  79. sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
  80. sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
  81. sdg_hub/configs/knowledge/router.yaml +0 -12
  82. sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
  83. sdg_hub/configs/reasoning/__init__.py +0 -0
  84. sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
  85. sdg_hub/configs/skills/__init__.py +0 -0
  86. sdg_hub/configs/skills/analyzer.yaml +0 -48
  87. sdg_hub/configs/skills/annotation.yaml +0 -36
  88. sdg_hub/configs/skills/contexts.yaml +0 -28
  89. sdg_hub/configs/skills/critic.yaml +0 -60
  90. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
  91. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
  92. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
  93. sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
  94. sdg_hub/configs/skills/freeform_questions.yaml +0 -34
  95. sdg_hub/configs/skills/freeform_responses.yaml +0 -39
  96. sdg_hub/configs/skills/grounded_questions.yaml +0 -38
  97. sdg_hub/configs/skills/grounded_responses.yaml +0 -59
  98. sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
  99. sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
  100. sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
  101. sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
  102. sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
  103. sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
  104. sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
  105. sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
  106. sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
  107. sdg_hub/configs/skills/judge.yaml +0 -53
  108. sdg_hub/configs/skills/planner.yaml +0 -67
  109. sdg_hub/configs/skills/respond.yaml +0 -8
  110. sdg_hub/configs/skills/revised_responder.yaml +0 -78
  111. sdg_hub/configs/skills/router.yaml +0 -59
  112. sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
  113. sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
  114. sdg_hub/flow.py +0 -477
  115. sdg_hub/flow_runner.py +0 -450
  116. sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
  117. sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
  118. sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
  119. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
  120. sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
  121. sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
  122. sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
  123. sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
  124. sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
  125. sdg_hub/pipeline.py +0 -121
  126. sdg_hub/prompts.py +0 -80
  127. sdg_hub/registry.py +0 -122
  128. sdg_hub/sdg.py +0 -206
  129. sdg_hub/utils/config_validation.py +0 -91
  130. sdg_hub/utils/error_handling.py +0 -94
  131. sdg_hub/utils/validation_result.py +0 -10
  132. sdg_hub-0.1.4.dist-info/METADATA +0 -190
  133. sdg_hub-0.1.4.dist-info/RECORD +0 -89
  134. sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
  135. /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
  136. /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
  137. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/WHEEL +0 -0
  138. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/licenses/LICENSE +0 -0
  139. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,331 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Enhanced BlockRegistry with metadata and better error handling.
3
+
4
+ This module provides a clean registry system for blocks with metadata,
5
+ categorization, and improved error handling.
6
+ """
7
+
8
+ # Standard
9
+ from dataclasses import dataclass
10
+ from difflib import get_close_matches
11
+ from typing import Optional
12
+ import inspect
13
+
14
+ # Third Party
15
+ from rich.console import Console
16
+ from rich.table import Table
17
+
18
+ # Local
19
+ from ..utils.logger_config import setup_logger
20
+
21
+ logger = setup_logger(__name__)
22
+ console = Console()
23
+
24
+
25
@dataclass
class BlockMetadata:
    """Record describing a single registered block.

    Parameters
    ----------
    name : str
        Name the block is registered under. Must be non-empty.
    block_class : type
        The block class itself. Must be a class object.
    category : str
        Grouping label (e.g., 'llm', 'utility', 'filtering').
    description : str, optional
        Human-readable summary of what the block does. Defaults to "".
    deprecated : bool, optional
        Marks the block as deprecated. Defaults to False.
    replacement : str, optional
        Name of the block to use instead when deprecated. Defaults to None.

    Raises
    ------
    ValueError
        If ``name`` is empty or ``block_class`` is not a class.
    """

    name: str
    block_class: type
    category: str
    description: str = ""
    deprecated: bool = False
    replacement: Optional[str] = None

    def __post_init__(self) -> None:
        """Check the name first, then the class, right after construction."""
        name_present = bool(self.name)
        if not name_present:
            raise ValueError("Block name cannot be empty")

        # A class is exactly an instance of ``type``.
        refers_to_class = isinstance(self.block_class, type)
        if not refers_to_class:
            raise ValueError("block_class must be a class")
58
+
59
+
60
class BlockRegistry:
    """Registry for block classes with metadata and enhanced error handling.

    All state is held in two class-level dictionaries:

    * ``_metadata`` maps a registered block name to its ``BlockMetadata``.
    * ``_categories`` maps a category name to the set of block names in it.

    Every method is a classmethod, so the registry is used without
    instantiation.
    """

    _metadata: dict[str, BlockMetadata] = {}
    _categories: dict[str, set[str]] = {}

    @staticmethod
    def _deprecation_message(block_name: str, replacement: Optional[str]) -> str:
        """Build the deprecation warning text used by ``register`` and ``get``."""
        message = f"Block '{block_name}' is deprecated."
        if replacement:
            message += f" Use '{replacement}' instead."
        return message

    @classmethod
    def register(
        cls,
        block_name: str,
        category: str,
        description: str = "",
        deprecated: bool = False,
        replacement: Optional[str] = None,
    ):
        """Register a block class with metadata.

        Re-registering an existing ``block_name`` replaces its metadata. The
        name is also removed from the category it was previously indexed
        under, so the category index never lists a block under a category
        its metadata no longer names.

        Parameters
        ----------
        block_name : str
            Name under which to register the block.
        category : str
            Category for organization.
        description : str, optional
            Human-readable description of the block.
        deprecated : bool, optional
            Whether this block is deprecated.
        replacement : str, optional
            Suggested replacement if deprecated.

        Returns
        -------
        callable
            Decorator that registers the decorated class and returns it
            unchanged.

        Raises
        ------
        ValueError
            Raised by the returned decorator if the decorated object is not
            a valid block class or ``block_name`` is empty.
        """

        def decorator(block_class: type) -> type:
            # Validate before touching any registry state.
            cls._validate_block_class(block_class)

            metadata = BlockMetadata(
                name=block_name,
                block_class=block_class,
                category=category,
                description=description,
                deprecated=deprecated,
                replacement=replacement,
            )

            # Drop the stale category entry when a name moves categories;
            # otherwise category()/all() would disagree with info()/show().
            previous = cls._metadata.get(block_name)
            if previous is not None:
                logger.debug(
                    f"Block '{block_name}' is already registered; "
                    "overwriting previous registration"
                )
                if previous.category != category:
                    stale_names = cls._categories.get(previous.category)
                    if stale_names is not None:
                        stale_names.discard(block_name)
                        if not stale_names:
                            del cls._categories[previous.category]

            cls._metadata[block_name] = metadata
            cls._categories.setdefault(category, set()).add(block_name)

            logger.debug(
                f"Registered block '{block_name}' "
                f"({block_class.__name__}) in category '{category}'"
            )

            if deprecated:
                logger.warning(cls._deprecation_message(block_name, replacement))

            return block_class

        return decorator

    @classmethod
    def _validate_block_class(cls, block_class: type) -> None:
        """Validate that a class is a proper block class.

        The class must inherit from ``BaseBlock``. If ``BaseBlock`` cannot be
        imported, the check falls back to requiring a ``generate`` attribute.

        Parameters
        ----------
        block_class : type
            The class to validate.

        Raises
        ------
        ValueError
            If the class is not a valid block class.
        """
        if not inspect.isclass(block_class):
            raise ValueError(f"Expected a class, got {type(block_class)}")

        try:
            # Local (imported lazily inside the function rather than at
            # module import time)
            from .base import BaseBlock
        except ImportError as exc:
            # BaseBlock not available, check for generate method
            if not hasattr(block_class, "generate"):
                raise ValueError(
                    f"Block class '{block_class.__name__}' must implement 'generate' method"
                ) from exc
        else:
            if not issubclass(block_class, BaseBlock):
                raise ValueError(
                    f"Block class '{block_class.__name__}' must inherit from BaseBlock"
                )

    @classmethod
    def get(cls, block_name: str) -> type:
        """Get a block class with enhanced error handling.

        A deprecation warning is logged when the requested block is marked
        deprecated.

        Parameters
        ----------
        block_name : str
            Name of the block to retrieve.

        Returns
        -------
        type
            The block class.

        Raises
        ------
        KeyError
            If the block is not found. The message carries close-match
            suggestions, the available block names and the categories.
        """
        if block_name not in cls._metadata:
            available_blocks = list(cls._metadata.keys())
            suggestions = get_close_matches(
                block_name, available_blocks, n=3, cutoff=0.6
            )

            error_msg = f"Block '{block_name}' not found in registry."

            if suggestions:
                error_msg += f" Did you mean: {', '.join(suggestions)}?"

            if available_blocks:
                error_msg += (
                    f"\nAvailable blocks: {', '.join(sorted(available_blocks))}"
                )

            if cls._categories:
                error_msg += (
                    f"\nCategories: {', '.join(sorted(cls._categories.keys()))}"
                )

            logger.error(error_msg)
            raise KeyError(error_msg)

        metadata = cls._metadata[block_name]

        if metadata.deprecated:
            logger.warning(
                cls._deprecation_message(block_name, metadata.replacement)
            )

        return metadata.block_class

    @classmethod
    def info(cls, block_name: str) -> BlockMetadata:
        """Get metadata for a specific block.

        Unlike ``get``, this logs nothing and offers no suggestions.

        Parameters
        ----------
        block_name : str
            Name of the block.

        Returns
        -------
        BlockMetadata
            The block's metadata.

        Raises
        ------
        KeyError
            If the block is not found.
        """
        if block_name not in cls._metadata:
            raise KeyError(f"Block '{block_name}' not found in registry.")
        return cls._metadata[block_name]

    @classmethod
    def categories(cls) -> list[str]:
        """Get all available categories.

        Returns
        -------
        list[str]
            Sorted list of categories.
        """
        return sorted(cls._categories.keys())

    @classmethod
    def category(cls, category: str) -> list[str]:
        """Get all blocks in a specific category.

        Parameters
        ----------
        category : str
            The category to filter by.

        Returns
        -------
        list[str]
            Sorted list of block names in the category.

        Raises
        ------
        KeyError
            If the category doesn't exist.
        """
        if category not in cls._categories:
            available_categories = sorted(cls._categories.keys())
            raise KeyError(
                f"Category '{category}' not found. "
                f"Available categories: {', '.join(available_categories)}"
            )
        return sorted(cls._categories[category])

    @classmethod
    def all(cls) -> dict[str, list[str]]:
        """List all blocks organized by category.

        Returns
        -------
        dict[str, list[str]]
            Dictionary mapping categories to sorted lists of block names.
        """
        return {
            category: sorted(blocks) for category, blocks in cls._categories.items()
        }

    @classmethod
    def show(cls) -> None:
        """Print a Rich-formatted table of all available blocks."""
        if not cls._metadata:
            console.print("[yellow]No blocks registered yet.[/yellow]")
            return

        table = Table(
            title="Available Blocks", show_header=True, header_style="bold magenta"
        )
        table.add_column("Block Name", style="cyan", no_wrap=True)
        table.add_column("Category", style="green")
        table.add_column("Description", style="white")

        # Sort blocks by category, then by name
        sorted_blocks = sorted(
            cls._metadata.items(), key=lambda x: (x[1].category, x[0])
        )

        for name, metadata in sorted_blocks:
            description = metadata.description or "No description"

            # Show deprecated blocks with a warning indicator in the name
            block_name = f"⚠️ {name}" if metadata.deprecated else name

            table.add_row(block_name, metadata.category, description)

        console.print(table)

        # Show summary
        total_blocks = len(cls._metadata)
        total_categories = len(cls._categories)
        deprecated_count = sum(1 for m in cls._metadata.values() if m.deprecated)

        console.print(
            f"\n[bold]Summary:[/bold] {total_blocks} blocks across {total_categories} categories"
        )
        if deprecated_count > 0:
            console.print(f"[yellow]⚠️ {deprecated_count} deprecated blocks[/yellow]")
@@ -0,0 +1,23 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Data transformation blocks for dataset manipulation.
3
+
4
+ This module provides blocks for transforming datasets including column operations,
5
+ wide-to-long transformations, value selection, and majority value assignment.
6
+ """
7
+
8
+ # Local
9
+ from .duplicate_columns import DuplicateColumnsBlock
10
+ from .index_based_mapper import IndexBasedMapperBlock
11
+ from .melt_columns import MeltColumnsBlock
12
+ from .rename_columns import RenameColumnsBlock
13
+ from .text_concat import TextConcatBlock
14
+ from .uniform_col_val_setter import UniformColumnValueSetter
15
+
16
+ __all__ = [
17
+ "TextConcatBlock",
18
+ "DuplicateColumnsBlock",
19
+ "MeltColumnsBlock",
20
+ "IndexBasedMapperBlock",
21
+ "RenameColumnsBlock",
22
+ "UniformColumnValueSetter",
23
+ ]
@@ -0,0 +1,88 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Duplicate columns block for dataset column duplication operations.
3
+
4
+ This module provides a block for duplicating existing columns with new names
5
+ according to a mapping specification.
6
+ """
7
+
8
+ # Standard
9
+ from typing import Any
10
+
11
+ # Third Party
12
+ from datasets import Dataset
13
+ from pydantic import field_validator
14
+
15
+ # Local
16
+ from ...utils.logger_config import setup_logger
17
+ from ..base import BaseBlock
18
+ from ..registry import BlockRegistry
19
+
20
+ logger = setup_logger(__name__)
21
+
22
+
23
@BlockRegistry.register(
    "DuplicateColumnsBlock",
    "transform",
    "Duplicates existing columns with new names according to a mapping specification",
)
class DuplicateColumnsBlock(BaseBlock):
    """Block that copies existing dataset columns under new names.

    The copy specification is supplied through ``input_cols`` as a dictionary
    whose keys are the existing column names and whose values are the names
    of the columns to create.

    Attributes
    ----------
    block_name : str
        Name of the block.
    input_cols : Dict[str, str]
        Mapping of existing column name -> new column name.
    """

    @field_validator("input_cols", mode="after")
    @classmethod
    def validate_input_cols(cls, v):
        """Accept only a non-empty dict; emptiness is reported first."""
        if v and isinstance(v, dict):
            return v
        if not v:
            raise ValueError("input_cols cannot be empty")
        raise ValueError(
            "input_cols must be a dictionary mapping existing column names to new column names"
        )

    def model_post_init(self, __context: Any) -> None:
        """Run the parent hook when present, then default ``output_cols``."""
        parent = super()
        if hasattr(parent, "model_post_init"):
            parent.model_post_init(__context)

        # When no output columns were given, they are the newly created names.
        if self.output_cols is None:
            self.output_cols = list(self.input_cols.values())

    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
        """Return ``samples`` extended with the duplicated columns.

        Parameters
        ----------
        samples : Dataset
            Input dataset to duplicate columns from.

        Returns
        -------
        Dataset
            Dataset with additional duplicated columns.

        Raises
        ------
        ValueError
            If a source column named in ``input_cols`` is not in the dataset.
        """
        # The local name is rebound to each ``add_column`` result, so a later
        # mapping entry may use a column created by an earlier one as source.
        dataset = samples

        for existing_name, new_name in self.input_cols.items():
            if existing_name in dataset.column_names:
                dataset = dataset.add_column(new_name, dataset[existing_name])
                continue
            raise ValueError(f"Source column '{existing_name}' not found in dataset")

        return dataset
@@ -0,0 +1,225 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Selector block for column value selection and mapping.
3
+
4
+ This module provides a block for selecting and mapping values from one column
5
+ to another based on a choice column's value.
6
+ """
7
+
8
+ # Standard
9
+ from typing import Any
10
+
11
+ # Third Party
12
+ from datasets import Dataset
13
+ from pydantic import Field, field_validator, model_validator
14
+
15
+ # Local
16
+ from ...utils.error_handling import MissingColumnError
17
+ from ...utils.logger_config import setup_logger
18
+ from ..base import BaseBlock
19
+ from ..registry import BlockRegistry
20
+
21
+ logger = setup_logger(__name__)
22
+
23
+
24
@BlockRegistry.register(
    "IndexBasedMapperBlock",
    "transform",
    "Maps values from source columns to output columns based on choice columns using shared mapping",
)
class IndexBasedMapperBlock(BaseBlock):
    """Block for mapping values from source columns to output columns based on choice columns.

    This block uses a shared mapping dictionary to select values from source columns and
    store them in output columns based on corresponding choice columns' values.
    The choice_cols and output_cols must have the same length - choice_cols[i] determines
    the value for output_cols[i]. The same choice column may appear more than once in
    choice_cols; each occurrence fills its own output column.

    Attributes
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, List[str], Dict[str, Any], None]
        Input column specification. Should include choice columns and mapped columns.
    output_cols : Union[str, List[str], Dict[str, Any], None]
        Output column specification. Must have same length as choice_cols.
    choice_map : Dict[str, str]
        Dictionary mapping choice values to source column names.
    choice_cols : List[str]
        List of column names containing choice values. Must have same length as output_cols.
    """

    choice_map: dict[str, str] = Field(
        ..., description="Dictionary mapping choice values to column names"
    )
    choice_cols: list[str] = Field(
        ..., description="List of column names containing choice values"
    )

    @field_validator("choice_map")
    @classmethod
    def validate_choice_map(cls, v):
        """Validate that choice_map is not empty."""
        if not v:
            raise ValueError("choice_map cannot be empty")
        return v

    @field_validator("choice_cols")
    @classmethod
    def validate_choice_cols_not_empty(cls, v):
        """Validate that choice_cols is not empty."""
        if not v:
            raise ValueError("choice_cols cannot be empty")
        return v

    @model_validator(mode="after")
    def validate_input_output_consistency(self):
        """Validate that choice_cols and output_cols have same length and consistency.

        A length mismatch is an error. Choice or mapped columns missing from a
        list-typed ``input_cols`` only produce a logged warning.
        """
        if len(self.choice_cols) != len(self.output_cols):
            raise ValueError(
                f"choice_cols and output_cols must have same length. "
                f"Got choice_cols: {len(self.choice_cols)}, output_cols: {len(self.output_cols)}"
            )

        if isinstance(self.input_cols, list):
            # Check that all choice_cols are in input_cols
            missing_choice_cols = set(self.choice_cols) - set(self.input_cols)
            if missing_choice_cols:
                logger.warning(
                    f"Choice columns {missing_choice_cols} not found in input_cols {self.input_cols}"
                )

            # Check that all mapped columns are in input_cols
            missing_mapped_cols = set(self.choice_map.values()) - set(self.input_cols)
            if missing_mapped_cols:
                logger.warning(
                    f"Mapped columns {missing_mapped_cols} not found in input_cols {self.input_cols}"
                )

        return self

    def model_post_init(self, __context: Any) -> None:
        """Initialize derived attributes after Pydantic validation."""
        # Kept as a convenience lookup. It is lossy when a choice column is
        # listed more than once, so _generate iterates the paired lists.
        self.choice_to_output_map = dict(zip(self.choice_cols, self.output_cols))

    def _collect_choice_values(self, samples: Dataset) -> set:
        """Return the distinct values found across all choice columns."""
        distinct_values = set()
        for choice_col in self.choice_cols:
            distinct_values.update(samples[choice_col])
        return distinct_values

    def _validate_custom(self, samples: Dataset) -> None:
        """Validate that required columns exist in the dataset.

        Parameters
        ----------
        samples : Dataset
            Input dataset to validate.

        Raises
        ------
        MissingColumnError
            If required columns are missing from the dataset.
        ValueError
            If choice values in data are not found in choice_map.
        """
        available_columns = samples.column_names
        available_lookup = set(available_columns)

        # Check that all choice_cols exist
        missing_choice_cols = [
            col for col in self.choice_cols if col not in available_lookup
        ]
        if missing_choice_cols:
            raise MissingColumnError(
                block_name=self.block_name,
                missing_columns=missing_choice_cols,
                available_columns=available_columns,
            )

        # Check that all mapped columns exist. dict.fromkeys de-duplicates
        # while keeping choice_map order, so the report is deterministic.
        missing_cols = list(
            dict.fromkeys(
                col for col in self.choice_map.values() if col not in available_lookup
            )
        )
        if missing_cols:
            raise MissingColumnError(
                block_name=self.block_name,
                missing_columns=missing_cols,
                available_columns=available_columns,
            )

        # Check that all choice values in all choice columns have corresponding mappings
        mapped_choices = set(self.choice_map.keys())
        unmapped_choices = self._collect_choice_values(samples) - mapped_choices

        if unmapped_choices:
            # key=str keeps sorting from raising TypeError when the data mixes
            # types (e.g. None alongside strings), which would otherwise hide
            # this error.
            raise ValueError(
                f"Choice values {sorted(unmapped_choices, key=str)} not found in choice_map for block '{self.block_name}'. "
                f"Available choices in mapping: {sorted(mapped_choices)}"
            )

    def _generate(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Generate a new sample by selecting values based on choice mapping.

        Parameters
        ----------
        sample : Dict[str, Any]
            Input sample to process. It is updated in place and returned.

        Returns
        -------
        Dict[str, Any]
            Sample with selected values stored in corresponding output columns.

        Raises
        ------
        KeyError
            If a choice value has no entry in ``choice_map``.
        """
        # Iterate the paired lists rather than choice_to_output_map so that a
        # choice column listed twice still fills both of its output columns.
        for choice_col, output_col in zip(self.choice_cols, self.output_cols):
            source_col = self.choice_map[sample[choice_col]]
            sample[output_col] = sample[source_col]
        return sample

    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
        """Generate a new dataset with selected values.

        Parameters
        ----------
        samples : Dataset
            Input dataset to process.
        **kwargs : Any
            Accepted for signature compatibility with other blocks; unused.

        Returns
        -------
        Dataset
            Dataset with selected values stored in the output columns.
        """
        # Gather statistics for the log records below.
        all_unique_choices = self._collect_choice_values(samples)
        mapped_choices = set(self.choice_map.keys())

        logger.info(
            f"Mapping values based on choice columns for block '{self.block_name}'",
            extra={
                "block_name": self.block_name,
                "choice_columns": self.choice_cols,
                "output_columns": self.output_cols,
                "choice_mappings": len(self.choice_map),
                "unique_choices_in_data": len(all_unique_choices),
                "unmapped_choices": len(all_unique_choices - mapped_choices),
            },
        )

        # Apply the mapping
        result = samples.map(self._generate)

        # Fraction of distinct choice values that have a mapping; 0 when the
        # choice columns hold no values at all.
        if all_unique_choices:
            mapping_coverage = len(mapped_choices & all_unique_choices) / len(
                all_unique_choices
            )
        else:
            mapping_coverage = 0

        logger.info(
            f"Successfully applied choice mapping for block '{self.block_name}'",
            extra={
                "block_name": self.block_name,
                "rows_processed": len(result),
                "output_columns": self.output_cols,
                "mapping_coverage": mapping_coverage,
            },
        )

        return result