sdg-hub 0.5.1__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. sdg_hub/_version.py +2 -2
  2. sdg_hub/core/blocks/base.py +60 -58
  3. sdg_hub/core/blocks/filtering/column_value_filter.py +29 -16
  4. sdg_hub/core/blocks/llm/__init__.py +0 -2
  5. sdg_hub/core/blocks/llm/llm_chat_block.py +42 -36
  6. sdg_hub/core/blocks/llm/llm_parser_block.py +13 -59
  7. sdg_hub/core/blocks/llm/prompt_builder_block.py +15 -10
  8. sdg_hub/core/blocks/llm/text_parser_block.py +14 -61
  9. sdg_hub/core/blocks/transform/duplicate_columns.py +9 -8
  10. sdg_hub/core/blocks/transform/index_based_mapper.py +29 -15
  11. sdg_hub/core/blocks/transform/json_structure_block.py +16 -13
  12. sdg_hub/core/blocks/transform/melt_columns.py +13 -12
  13. sdg_hub/core/blocks/transform/rename_columns.py +20 -9
  14. sdg_hub/core/blocks/transform/text_concat.py +20 -21
  15. sdg_hub/core/blocks/transform/uniform_col_val_setter.py +6 -5
  16. sdg_hub/core/flow/base.py +139 -106
  17. sdg_hub/core/flow/checkpointer.py +34 -36
  18. sdg_hub/core/flow/validation.py +4 -4
  19. sdg_hub/core/utils/datautils.py +52 -54
  20. sdg_hub/core/utils/flow_metrics.py +9 -6
  21. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/flow.yaml +1 -0
  22. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/METADATA +5 -9
  23. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/RECORD +26 -28
  24. sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +0 -771
  25. sdg_hub/core/utils/temp_manager.py +0 -57
  26. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/WHEEL +0 -0
  27. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/licenses/LICENSE +0 -0
  28. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/top_level.txt +0 -0
sdg_hub/core/blocks/llm/prompt_builder_block.py

@@ -8,10 +8,11 @@ including conversion to OpenAI Messages format and template rendering.
 # Standard
 from typing import Any, Literal, Optional
 
-# Third Party
-from datasets import Dataset
 from jinja2 import Template, meta
 from pydantic import BaseModel, Field, field_validator
+
+# Third Party
+import pandas as pd
 import yaml
 
 # Local
@@ -279,12 +280,14 @@ class PromptBuilderBlock(BaseBlock):
         message_templates = self.prompt_template_config.get_message_templates()
         self.prompt_renderer = PromptRenderer(message_templates)
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, dataset: pd.DataFrame) -> None:
         if len(dataset) > 0:
             # Get required variables from all message templates
             required_vars = self.prompt_renderer.get_required_variables()
 
-            sample = dataset[0]
+            # Get first row as dict
+            sample = dataset.iloc[0].to_dict()
+
             template_vars = self.prompt_renderer.resolve_template_vars(
                 sample, self.input_cols
             )
@@ -344,25 +347,27 @@ class PromptBuilderBlock(BaseBlock):
 
         return sample
 
-    def generate(self, samples: Dataset, **_kwargs: Any) -> Dataset:
-        """Generate formatted output for all samples using dataset map.
+    def generate(self, samples: pd.DataFrame, **_kwargs: Any) -> pd.DataFrame:
+        """Generate formatted output for all samples.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
             Input dataset containing samples to be formatted.
         **kwargs : Dict[str, Any]
             Additional keyword arguments (unused in this block).
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
             Dataset with the formatted output added to the specified column.
         """
         logger.debug(f"Formatting prompts for {len(samples)} samples")
 
-        # Use dataset map for efficient processing
-        formatted_dataset = samples.map(self._generate)
+        # Convert DataFrame to list of dicts, process each, and convert back
+        samples_list = samples.to_dict("records")
+        formatted_samples = [self._generate(sample) for sample in samples_list]
+        formatted_dataset = pd.DataFrame(formatted_samples)
 
         logger.debug(f"Successfully formatted {len(formatted_dataset)} samples")
         return formatted_dataset
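
Note: the rewritten generate() drops datasets.Dataset.map in favour of a plain records round trip. A minimal standalone sketch of that pattern follows; the toy _generate and column names are illustrative, not the block's real prompt-rendering logic.

import pandas as pd

def _generate(sample: dict) -> dict:
    # Toy per-row transform standing in for the block's prompt rendering
    sample["prompt"] = f"Summarize: {sample['document']}"
    return sample

samples = pd.DataFrame({"document": ["doc one", "doc two"]})

# DataFrame -> list of dicts -> per-row processing -> DataFrame, as in the new generate()
records = samples.to_dict("records")
formatted = pd.DataFrame([_generate(r) for r in records])
print(formatted.columns.tolist())  # ['document', 'prompt']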
sdg_hub/core/blocks/llm/text_parser_block.py

@@ -7,17 +7,15 @@ start/end tags, custom regex patterns, and cleanup operations.
 
 # Standard
 from typing import Any, Optional
-from weakref import finalize
-import json
 import re
 
-# Third Party
-from datasets import Dataset, load_dataset
 from pydantic import Field, field_validator, model_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
-from ...utils.temp_manager import cleanup_path, create_temp_dir, create_temp_file
 from ..base import BaseBlock
 from ..registry import BlockRegistry
 
@@ -122,12 +120,12 @@ class TextParserBlock(BaseBlock):
 
         return self
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, dataset: pd.DataFrame) -> None:
         """Validate TextParserBlock specific requirements.
 
         Parameters
         ----------
-        dataset : Dataset
+        dataset : pd.DataFrame
             The dataset to validate.
 
         Raises
@@ -316,60 +314,15 @@ class TextParserBlock(BaseBlock):
             )
             return []
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         logger.debug(f"Parsing outputs for {len(samples)} samples")
         if len(samples) == 0:
             logger.warning("No samples to parse, returning empty dataset")
-            return Dataset.from_list([])
-
-        tmp_jsonl_path = kwargs.get("_flow_tmp_jsonl_path")
-        cleanup_locally = False
-
-        if tmp_jsonl_path is None:
-            tmp_jsonl_path = str(
-                create_temp_file(
-                    prefix=f"{self.block_name}_text_parser", suffix=".jsonl"
-                )
-            )
-            cleanup_locally = True
-
-        rows_written = 0
-        batch = []
-        with open(tmp_jsonl_path, "w") as f:
-            for sample in samples:
-                out = self._generate(sample)
-                for row in out:
-                    batch.append(json.dumps(row) + "\n")
-                    rows_written += 1
-                    if len(batch) >= 5:
-                        f.writelines(batch)
-                        batch.clear()
-            if batch:
-                f.writelines(batch)
-
-        if rows_written == 0:
-            if cleanup_locally:
-                cleanup_path(tmp_jsonl_path)
-            return Dataset.from_list([])
-
-        hf_cache_dir = None
-        try:
-            hf_cache_dir = create_temp_dir(
-                prefix=f"{self.block_name}_text_parser_hf_cache"
-            )
-            ret = load_dataset(
-                "json",
-                data_files=tmp_jsonl_path,
-                split="train",
-                keep_in_memory=False,
-                cache_dir=str(hf_cache_dir),
-            )
-            finalize(ret, cleanup_path, hf_cache_dir)
-            return ret
-        except Exception:
-            if hf_cache_dir is not None:
-                cleanup_path(hf_cache_dir)
-            raise
-        finally:
-            if cleanup_locally:
-                cleanup_path(tmp_jsonl_path)
+            return pd.DataFrame()
+
+        # Convert DataFrame to list of dicts to avoid iterrows and improve performance
+        samples_list = samples.to_dict("records")
+        new_data: list[dict] = []
+        for sample in samples_list:
+            new_data.extend(self._generate(sample))
+        return pd.DataFrame(new_data)
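
Note: the old generate() spilled parsed rows to a temporary JSONL file and reloaded them via load_dataset; the new version keeps everything in memory as a list of dicts, which still allows one input row to expand into several output rows (or none). A rough sketch of the new flow, assuming a toy _generate parser and made-up column names:

import pandas as pd

def _generate(sample: dict) -> list[dict]:
    # Toy parser standing in for the block's tag/regex extraction:
    # one raw LLM output may yield several parsed rows, or none at all
    return [{"parsed": part} for part in sample["raw_output"].split("|") if part]

samples = pd.DataFrame({"raw_output": ["a|b", "c", ""]})

if len(samples) == 0:
    result = pd.DataFrame()
else:
    new_data: list[dict] = []
    for sample in samples.to_dict("records"):
        new_data.extend(_generate(sample))
    result = pd.DataFrame(new_data)

print(len(result))  # 3 parsed rows from 3 input rows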
sdg_hub/core/blocks/transform/duplicate_columns.py

@@ -8,10 +8,11 @@ according to a mapping specification.
 # Standard
 from typing import Any
 
-# Third Party
-from datasets import Dataset
 from pydantic import field_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
 from ..base import BaseBlock
@@ -62,27 +63,27 @@ class DuplicateColumnsBlock(BaseBlock):
         if self.output_cols is None:
             self.output_cols = list(self.input_cols.values())
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Generate a dataset with duplicated columns.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
            Input dataset to duplicate columns from.
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
            Dataset with additional duplicated columns.
         """
         # Create a copy to avoid modifying the original
-        result = samples
+        result = samples.copy()
 
         # Duplicate each column as specified in the mapping
         for source_col, target_col in self.input_cols.items():
-            if source_col not in result.column_names:
+            if source_col not in result.columns.tolist():
                 raise ValueError(f"Source column '{source_col}' not found in dataset")
 
-            result = result.add_column(target_col, result[source_col])
+            result[target_col] = result[source_col]
 
         return result
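
Note: column duplication now uses plain pandas assignment on a copy instead of Dataset.add_column. A small illustrative sketch; the input_cols mapping here is hypothetical:

import pandas as pd

df = pd.DataFrame({"question": ["q1", "q2"]})
input_cols = {"question": "question_copy"}  # source -> target, mirroring the block's mapping

result = df.copy()  # copy so the caller's frame stays untouched
for source_col, target_col in input_cols.items():
    if source_col not in result.columns:
        raise ValueError(f"Source column '{source_col}' not found in dataset")
    result[target_col] = result[source_col]

print(result.columns.tolist())  # ['question', 'question_copy']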
sdg_hub/core/blocks/transform/index_based_mapper.py

@@ -8,10 +8,11 @@ to another based on a choice column's value.
 # Standard
 from typing import Any
 
-# Third Party
-from datasets import Dataset
 from pydantic import Field, field_validator, model_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.error_handling import MissingColumnError
 from ...utils.logger_config import setup_logger
@@ -103,12 +104,12 @@ class IndexBasedMapperBlock(BaseBlock):
         # Create mapping from choice_col to output_col for easy access
         self.choice_to_output_map = dict(zip(self.choice_cols, self.output_cols))
 
-    def _validate_custom(self, samples: Dataset) -> None:
+    def _validate_custom(self, samples: pd.DataFrame) -> None:
         """Validate that required columns exist in the dataset.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
            Input dataset to validate.
 
         Raises
@@ -120,29 +121,29 @@ class IndexBasedMapperBlock(BaseBlock):
         """
         # Check that all choice_cols exist
         missing_choice_cols = [
-            col for col in self.choice_cols if col not in samples.column_names
+            col for col in self.choice_cols if col not in samples.columns.tolist()
         ]
         if missing_choice_cols:
             raise MissingColumnError(
                 block_name=self.block_name,
                 missing_columns=missing_choice_cols,
-                available_columns=samples.column_names,
+                available_columns=samples.columns.tolist(),
             )
 
         # Check that all mapped columns exist
         mapped_cols = list(self.choice_map.values())
-        missing_cols = list(set(mapped_cols) - set(samples.column_names))
+        missing_cols = list(set(mapped_cols) - set(samples.columns.tolist()))
         if missing_cols:
             raise MissingColumnError(
                 block_name=self.block_name,
                 missing_columns=missing_cols,
-                available_columns=samples.column_names,
+                available_columns=samples.columns.tolist(),
             )
 
         # Check that all choice values in all choice columns have corresponding mappings
         all_unique_choices = set()
         for choice_col in self.choice_cols:
-            all_unique_choices.update(samples[choice_col])
+            all_unique_choices.update(samples[choice_col].unique())
         mapped_choices = set(self.choice_map.keys())
         unmapped_choices = all_unique_choices - mapped_choices
 
@@ -174,23 +175,23 @@ class IndexBasedMapperBlock(BaseBlock):
                 sample[output_col] = sample[source_col]
         return sample
 
-    def generate(self, samples: Dataset, **kwargs) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs) -> pd.DataFrame:
         """Generate a new dataset with selected values.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
            Input dataset to process.
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
            Dataset with selected values stored in output column.
         """
         # Log the operation
         all_unique_choices = set()
         for choice_col in self.choice_cols:
-            all_unique_choices.update(samples[choice_col])
+            all_unique_choices.update(samples[choice_col].unique())
         mapped_choices = set(self.choice_map.keys())
 
         logger.info(
@@ -205,8 +206,21 @@ class IndexBasedMapperBlock(BaseBlock):
            },
        )
 
-        # Apply the mapping
-        result = samples.map(self._generate)
+        # Create a copy to avoid modifying the input
+        result = samples.copy()
+
+        # Handle empty DataFrame case
+        if len(result) == 0:
+            # Add empty output columns
+            for output_col in self.output_cols:
+                result[output_col] = []
+        else:
+            # Apply the mapping for each choice column and output column pair
+            for choice_col, output_col in self.choice_to_output_map.items():
+                # Map the choice values to source columns, then get values from those columns
+                result[output_col] = result.apply(
+                    lambda row: row[self.choice_map[row[choice_col]]], axis=1
+                )
 
         # Log completion
         logger.info(
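
Note: the mapper's row-wise selection is now a DataFrame.apply over each choice/output pair: the row's choice value picks a source column via choice_map, and that column's value for the same row is written to the output column. A toy sketch with made-up column names and choice_map:

import pandas as pd

# Illustrative config; the real block reads choice_cols, choice_map and output_cols from its configuration
choice_map = {"short": "summary_short", "long": "summary_long"}
df = pd.DataFrame(
    {
        "length_choice": ["short", "long"],
        "summary_short": ["s1", "s2"],
        "summary_long": ["l1", "l2"],
    }
)

result = df.copy()
# Per row: the choice value selects a source column, and that column's value is copied over
result["selected_summary"] = result.apply(
    lambda row: row[choice_map[row["length_choice"]]], axis=1
)
print(result["selected_summary"].tolist())  # ['s1', 'l2']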
sdg_hub/core/blocks/transform/json_structure_block.py

@@ -9,10 +9,11 @@ containing a structured JSON object with specified field names.
 from typing import Any, Dict
 import json
 
-# Third Party
-from datasets import Dataset
 from pydantic import Field, field_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
 from ..base import BaseBlock
@@ -90,17 +91,17 @@ class JSONStructureBlock(BaseBlock):
 
         raise ValueError("input_cols must be a list of column names")
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Generate a dataset with JSON structured output.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
            Input dataset to process.
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
            Dataset with JSON structured output in the specified column.
         """
         if not self.output_cols:
@@ -109,17 +110,17 @@ class JSONStructureBlock(BaseBlock):
         output_col = self.output_cols[0]
         field_mapping = self._get_field_mapping()
 
-        def _create_json_structure(sample):
+        def _create_json_structure(row):
             """Create JSON structure from input columns."""
             json_obj = {}
 
             # Build the JSON object using the field mapping
             for json_field, col_name in field_mapping.items():
-                if col_name not in sample:
-                    logger.warning(f"Input column '{col_name}' not found in sample")
+                if col_name not in row.index:
+                    logger.warning(f"Input column '{col_name}' not found in row")
                     json_obj[json_field] = None
                 else:
-                    value = sample[col_name]
+                    value = row[col_name]
                     if self.ensure_json_serializable:
                         value = self._make_json_serializable(value)
                     json_obj[json_field] = value
@@ -130,13 +131,15 @@
                     json_string = json.dumps(json_obj, indent=2, ensure_ascii=False)
                 else:
                     json_string = json.dumps(json_obj, ensure_ascii=False)
-                sample[output_col] = json_string
+                return json_string
             except (TypeError, ValueError) as e:
                 logger.error(f"Failed to serialize JSON object: {e}")
-                sample[output_col] = "{}"
+                return "{}"
 
-            return sample
+        # Create a copy to avoid modifying the input
+        result = samples.copy()
 
         # Apply the JSON structuring to all samples
-        result = samples.map(_create_json_structure)
+        result[output_col] = result.apply(_create_json_structure, axis=1)
+
         return result
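
Note: _create_json_structure now returns the JSON string instead of mutating the sample, because DataFrame.apply(..., axis=1) collects return values into the new column. A compact sketch of that pattern, with an illustrative field_mapping and column names:

import json
import pandas as pd

df = pd.DataFrame({"question": ["q1"], "answer": ["a1"]})
field_mapping = {"q": "question", "a": "answer"}  # JSON field -> source column (illustrative)

def _create_json_structure(row: pd.Series) -> str:
    json_obj = {}
    for json_field, col_name in field_mapping.items():
        # row.index holds the column labels, mirroring the membership check in the block
        json_obj[json_field] = row[col_name] if col_name in row.index else None
    return json.dumps(json_obj, ensure_ascii=False)

result = df.copy()
result["json_output"] = result.apply(_create_json_structure, axis=1)
print(result["json_output"].iloc[0])  # {"q": "q1", "a": "a1"}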
sdg_hub/core/blocks/transform/melt_columns.py

@@ -8,10 +8,11 @@ by melting specified columns into rows.
 # Standard
 from typing import Any
 
-# Third Party
-from datasets import Dataset
 from pydantic import field_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.error_handling import MissingColumnError
 from ...utils.logger_config import setup_logger
@@ -79,12 +80,12 @@ class MeltColumnsBlock(BaseBlock):
             self.input_cols if isinstance(self.input_cols, list) else [self.input_cols]
         )
 
-    def _validate_custom(self, samples: Dataset) -> None:
+    def _validate_custom(self, samples: pd.DataFrame) -> None:
         """Validate that required columns exist in the dataset.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
            Input dataset to validate.
 
         Raises
@@ -93,34 +94,34 @@
            If required columns are missing from the dataset.
         """
         # Check that all var_cols exist in the dataset
-        missing_cols = list(set(self.var_cols) - set(samples.column_names))
+        missing_cols = list(set(self.var_cols) - set(samples.columns.tolist()))
         if missing_cols:
             raise MissingColumnError(
                 block_name=self.block_name,
                 missing_columns=missing_cols,
-                available_columns=samples.column_names,
+                available_columns=samples.columns.tolist(),
             )
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Generate a flattened dataset in long format.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
            Input dataset to flatten.
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
            Flattened dataset in long format with new variable and value columns.
         """
         # Use the original simple logic - just adapted to use derived attributes
-        df = samples.to_pandas()
-        id_cols = [col for col in samples.column_names if col not in self.var_cols]
+        df = samples
+        id_cols = [col for col in samples.columns.tolist() if col not in self.var_cols]
         flatten_df = df.melt(
             id_vars=id_cols,
             value_vars=self.var_cols,
             value_name=self.value_name,
             var_name=self.var_name,
         )
-        return Dataset.from_pandas(flatten_df)
+        return flatten_df
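
Note: since the input is already a DataFrame, the block can call pandas melt directly and return the result without the to_pandas()/from_pandas round trip (melt itself returns a new frame, so skipping the copy is safe here). A minimal wide-to-long sketch with hypothetical column names:

import pandas as pd

df = pd.DataFrame(
    {"doc_id": [1, 2], "summary_a": ["sa1", "sa2"], "summary_b": ["sb1", "sb2"]}
)
var_cols = ["summary_a", "summary_b"]
id_cols = [col for col in df.columns.tolist() if col not in var_cols]

# Wide to long: every (row, melted column) pair becomes its own row
flat = df.melt(
    id_vars=id_cols,
    value_vars=var_cols,
    value_name="summary",
    var_name="summary_type",
)
print(flat.shape)  # (4, 3)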
sdg_hub/core/blocks/transform/rename_columns.py

@@ -8,10 +8,11 @@ to a mapping specification.
 # Standard
 from typing import Any
 
-# Third Party
-from datasets import Dataset
 from pydantic import field_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
 from ..base import BaseBlock
@@ -52,28 +53,38 @@ class RenameColumnsBlock(BaseBlock):
             )
         return v
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Generate a dataset with renamed columns.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
            Input dataset to rename columns in.
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
            Dataset with renamed columns.
 
         Raises
         ------
         ValueError
-            If attempting to rename to a column name that already exists.
+            If attempting to rename to a column name that already exists,
+            or if the original column names don't exist in the dataset.
         """
+        # Check that all original column names exist in the dataset
+        existing_cols = set(samples.columns.tolist())
+        original_cols = set(self.input_cols.keys())
+
+        missing_cols = original_cols - existing_cols
+        if missing_cols:
+            raise ValueError(
+                f"Original column names {sorted(missing_cols)} not in the dataset"
+            )
+
         # Check for column name collisions
         # Strict validation: no target column name can be an existing column name
         # This prevents chained/circular renames which can be confusing
-        existing_cols = set(samples.column_names)
         target_cols = set(self.input_cols.values())
 
         collision = target_cols & existing_cols
@@ -84,5 +95,5 @@
                 "Chained renames are not supported."
             )
 
-        # Rename columns using HuggingFace datasets method
-        return samples.rename_columns(self.input_cols)
+        # Rename columns using pandas method
+        return samples.rename(columns=self.input_cols)
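
Note on why the explicit missing-column check was added: pandas DataFrame.rename ignores mapping keys that are not present (its default is errors="ignore"), so without the check a typo in the mapping would silently do nothing. A small sketch of the new validation plus rename, using hypothetical column names:

import pandas as pd

df = pd.DataFrame({"old_name": [1, 2], "kept": [3, 4]})
input_cols = {"old_name": "new_name"}  # old -> new, as in the block's mapping

# Fail loudly on unknown source columns, since rename would silently skip them
missing = set(input_cols.keys()) - set(df.columns.tolist())
if missing:
    raise ValueError(f"Original column names {sorted(missing)} not in the dataset")

# Reject target names that already exist, to avoid chained/circular renames
collision = set(input_cols.values()) & set(df.columns.tolist())
if collision:
    raise ValueError(f"Target column names {sorted(collision)} already exist")

renamed = df.rename(columns=input_cols)  # returns a new frame; df is untouched
print(renamed.columns.tolist())  # ['new_name', 'kept']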
sdg_hub/core/blocks/transform/text_concat.py

@@ -8,10 +8,11 @@ using a specified separator.
 # Standard
 from typing import Any
 
-# Third Party
-from datasets import Dataset
 from pydantic import Field, field_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
 from ..base import BaseBlock
@@ -65,17 +66,17 @@ class TextConcatBlock(BaseBlock):
             raise ValueError("TextConcatBlock requires exactly one output column")
         return v
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Generate a dataset with combined columns.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
            Input dataset to process.
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
            Dataset with combined values stored in output column.
         """
         if not self.output_cols:
@@ -83,20 +84,18 @@
 
         output_col = self.output_cols[0]
 
-        def _combine_columns(sample):
-            """Combine values from input columns."""
-            # Check that all input columns exist
-            for col in self.input_cols:
-                if col not in sample:
-                    raise ValueError(f"Input column '{col}' not found in sample")
-
-            # Combine values using separator
-            combined_value = self.separator.join(
-                [str(sample[col]) for col in self.input_cols]
-            )
-            sample[output_col] = combined_value
-            return sample
-
-        # Apply the combination to all samples
-        result = samples.map(_combine_columns)
+        # Validate that all input columns exist in the dataset
+        for col in self.input_cols:
+            if col not in samples.columns:
+                raise ValueError(f"Input column '{col}' not found in sample")
+
+        # Create a copy to avoid modifying the input
+        result = samples.copy()
+
+        # Combine columns using vectorized string operations
+        # Convert all input columns to strings and concatenate with separator
+        result[output_col] = (
+            result[self.input_cols].astype(str).agg(self.separator.join, axis=1)
+        )
+
         return result
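
Note: the per-row _combine_columns helper is replaced by a vectorized idiom: select the input columns, cast to str, and join each row's values with the separator. A standalone sketch with illustrative column names and separator:

import pandas as pd

df = pd.DataFrame({"title": ["T1", "T2"], "body": ["b1", "b2"]})
input_cols = ["title", "body"]
separator = "\n\n"  # illustrative; the block takes the separator from its config

result = df.copy()
# astype(str) stringifies every cell; agg(separator.join, axis=1) joins each row's values
result["document"] = result[input_cols].astype(str).agg(separator.join, axis=1)
print(result["document"].iloc[0])  # 'T1\n\nb1'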
sdg_hub/core/blocks/transform/uniform_col_val_setter.py

@@ -8,11 +8,12 @@ mode, min, max, mean, or median.
 # Standard
 from typing import Any, Literal
 
-# Third Party
-from datasets import Dataset
 from pydantic import field_validator
 import numpy as np
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
 from ..base import BaseBlock
@@ -66,8 +67,8 @@ class UniformColumnValueSetter(BaseBlock):
             self.output_cols = []
         self.col_name = self.input_cols[0]
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
-        df = samples.to_pandas()
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+        df = samples.copy()
 
         if df.empty:
             raise ValueError("Cannot compute reduction for empty dataset")
@@ -98,4 +99,4 @@
             )
 
         df[self.col_name] = value
-        return Dataset.from_pandas(df)
+        return df
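
Note: the block now copies the incoming DataFrame, computes a single reduction over the target column, and broadcasts that scalar back onto the column. The reduction logic itself is not part of these hunks, so the sketch below is only a rough stand-in; the column name and reduction choice are illustrative.

import pandas as pd

df = pd.DataFrame({"score": [1, 2, 2, 5]})
col_name = "score"
reduction = "mode"  # illustrative; the block supports mode, min, max, mean and median

if df.empty:
    raise ValueError("Cannot compute reduction for empty dataset")

# mode() returns a Series (there can be ties), so take the first value;
# other reductions can go through Series.agg
value = df[col_name].mode().iloc[0] if reduction == "mode" else df[col_name].agg(reduction)

out = df.copy()
out[col_name] = value  # a scalar assignment broadcasts to every row
print(out[col_name].tolist())  # [2, 2, 2, 2]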