sdg-hub 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sdg_hub/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.5.0'
-__version_tuple__ = version_tuple = (0, 5, 0)
+__version__ = version = '0.6.0'
+__version_tuple__ = version_tuple = (0, 6, 0)
 
 __commit_id__ = commit_id = None
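The only change in sdg_hub/_version.py is the setuptools-scm-style generated version bump, which is observable at runtime straight from the module shown above:

# Verifying the bump from the generated version module (grounded in the diff above).
from sdg_hub._version import __version__, __version_tuple__

print(__version__)        # '0.6.0'
print(__version_tuple__)  # (0, 6, 0)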
@@ -9,13 +9,14 @@ with unified constructor patterns, column handling, and common functionality.
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union
 
-# Third Party
-from datasets import Dataset
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
 
+# Third Party
+import pandas as pd
+
 # Local
 from ..utils.error_handling import (
     EmptyDatasetError,
@@ -32,7 +33,7 @@ class BaseBlock(BaseModel, ABC):
     """Base class for all blocks, with standardized patterns and full Pydantic compatibility.
 
     This class defines a unified, configurable base for building composable data processing blocks
-    that operate over HuggingFace Datasets. It supports field-based initialization, validation,
+    that operate over pandas DataFrames. It supports field-based initialization, validation,
     and rich logging for inputs and outputs.
 
     Attributes
@@ -40,9 +41,9 @@ class BaseBlock(BaseModel, ABC):
     block_name : str
        Unique identifier for this block instance.
    input_cols : Union[List[str], Dict[str, Any]]
-        Input columns from the dataset (string, list of strings, or mapping).
+        Input columns from the DataFrame (string, list of strings, or mapping).
     output_cols : Union[List[str], Dict[str, Any]]
-        Output columns to write to the dataset (string, list of strings, or mapping).
+        Output columns to write to the DataFrame (string, list of strings, or mapping).
     """
 
     block_name: str = Field(
@@ -55,7 +56,7 @@ class BaseBlock(BaseModel, ABC):
         None, description="Output columns: str, list, or dict"
     )
 
-    # Allow extra config fields and complex types like Dataset
+    # Allow extra config fields and complex types like DataFrame
     model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
 
     # Normalize input columns before model construction
@@ -101,13 +102,13 @@ class BaseBlock(BaseModel, ABC):
             return dict(cols)
         raise ValueError(f"Invalid column specification: {cols} (type: {type(cols)})")
 
-    def _validate_columns(self, dataset: Dataset) -> None:
-        """Check that all required input columns are present in the dataset.
+    def _validate_columns(self, df: pd.DataFrame) -> None:
+        """Check that all required input columns are present in the DataFrame.
 
         Parameters
         ----------
-        dataset : Dataset
-            HuggingFace dataset to validate against.
+        df : pd.DataFrame
+            DataFrame to validate against.
 
         Raises
         ------
@@ -121,28 +122,29 @@ class BaseBlock(BaseModel, ABC):
             if isinstance(self.input_cols, dict)
             else self.input_cols
         )
+        available_columns = df.columns.tolist()
         missing_columns = [
-            col for col in columns_to_check if col not in dataset.column_names
+            col for col in columns_to_check if col not in available_columns
         ]
         if missing_columns:
             raise MissingColumnError(
                 block_name=self.block_name,
                 missing_columns=missing_columns,
-                available_columns=dataset.column_names,
+                available_columns=available_columns,
             )
 
-    def _validate_output_columns(self, dataset: Dataset) -> None:
+    def _validate_output_columns(self, df: pd.DataFrame) -> None:
         """Check that the output columns will not overwrite existing ones.
 
         Parameters
         ----------
-        dataset : Dataset
-            HuggingFace dataset to validate.
+        df : pd.DataFrame
+            DataFrame to validate.
 
         Raises
         ------
         OutputColumnCollisionError
-            If output columns already exist in the dataset.
+            If output columns already exist in the DataFrame.
         """
         if not self.output_cols:
             return
@@ -151,42 +153,43 @@ class BaseBlock(BaseModel, ABC):
             if isinstance(self.output_cols, dict)
             else self.output_cols
         )
-        collisions = [col for col in columns_to_check if col in dataset.column_names]
+        available_columns = df.columns.tolist()
+        collisions = [col for col in columns_to_check if col in available_columns]
         if collisions:
             raise OutputColumnCollisionError(
                 block_name=self.block_name,
                 collision_columns=collisions,
-                existing_columns=dataset.column_names,
+                existing_columns=available_columns,
             )
 
-    def _validate_dataset_not_empty(self, dataset: Dataset) -> None:
-        """Raise an error if the dataset is empty.
+    def _validate_dataframe_not_empty(self, df: pd.DataFrame) -> None:
+        """Raise an error if the DataFrame is empty.
 
         Parameters
         ----------
-        dataset : Dataset
+        df : pd.DataFrame
 
         Raises
         ------
         EmptyDatasetError
         """
-        if len(dataset) == 0:
+        if len(df) == 0:
             raise EmptyDatasetError(block_name=self.block_name)
 
-    def _validate_dataset(self, dataset: Dataset) -> None:
-        """Perform all default dataset validations."""
-        self._validate_dataset_not_empty(dataset)
-        self._validate_columns(dataset)
-        self._validate_output_columns(dataset)
+    def _validate_dataframe(self, df: pd.DataFrame) -> None:
+        """Perform all default DataFrame validations."""
+        self._validate_dataframe_not_empty(df)
+        self._validate_columns(df)
+        self._validate_output_columns(df)
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, df: pd.DataFrame) -> None:
         """Hook for subclasses to add extra validation logic."""
         pass
 
-    def _log_input_data(self, dataset: Dataset) -> None:
-        """Print a summary of the input dataset with Rich formatting."""
-        row_count = len(dataset)
-        columns = dataset.column_names
+    def _log_input_data(self, df: pd.DataFrame) -> None:
+        """Print a summary of the input DataFrame with Rich formatting."""
+        row_count = len(df)
+        columns = df.columns.tolist()
         content = Text()
         content.append("\U0001f4ca Processing Input Data\n", style="bold blue")
         content.append(f"Block Type: {self.__class__.__name__}\n", style="cyan")
@@ -207,13 +210,12 @@ class BaseBlock(BaseModel, ABC):
             Panel(content, title=f"[bold]{self.block_name}[/bold]", border_style="blue")
         )
 
-    def _log_output_data(self, input_dataset: Dataset, output_dataset: Dataset) -> None:
-        """Print a Rich panel summarizing output dataset differences."""
-        in_rows, out_rows = len(input_dataset), len(output_dataset)
-        in_cols, out_cols = (
-            set(input_dataset.column_names),
-            set(output_dataset.column_names),
-        )
+    def _log_output_data(self, input_df: pd.DataFrame, output_df: pd.DataFrame) -> None:
+        """Print a Rich panel summarizing output DataFrame differences."""
+        in_rows, out_rows = len(input_df), len(output_df)
+        in_cols = set(input_df.columns.tolist())
+        out_cols = set(output_df.columns.tolist())
+
         added_cols, removed_cols = out_cols - in_cols, in_cols - out_cols
         content = Text()
         content.append("\u2705 Processing Complete\n", style="bold green")
@@ -239,35 +241,35 @@ class BaseBlock(BaseModel, ABC):
         )
 
     @abstractmethod
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Subclass method to implement data generation logic.
 
         Parameters
         ----------
-        samples : Dataset
-            Input dataset to process.
+        samples : pd.DataFrame
+            Input DataFrame to process.
 
         Returns
         -------
-        Dataset
-            Transformed dataset with new columns or values.
+        pd.DataFrame
+            Transformed DataFrame with new columns or values.
         """
         pass
 
-    def __call__(self, samples: Dataset, **kwargs: Any) -> Dataset:
-        """Run the block on a dataset with full validation and logging.
+    def __call__(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+        """Run the block on a DataFrame with full validation and logging.
 
         Parameters
         ----------
-        samples : Dataset
-            Input dataset.
+        samples : pd.DataFrame
+            Input DataFrame.
         **kwargs : Any
             Runtime parameters to override block configuration
 
         Returns
         -------
-        Dataset
-            Output dataset after block processing.
+        pd.DataFrame
+            Output DataFrame after block processing.
         """
         # Handle runtime kwargs overrides
         if kwargs:
@@ -310,12 +312,12 @@ class BaseBlock(BaseModel, ABC):
 
             try:
                 self._log_input_data(samples)
-                self._validate_dataset(samples)
+                self._validate_dataframe(samples)
                 self._validate_custom(samples)
                 # Pass ALL kwargs to generate (including flow params)
-                output_dataset = self.generate(samples, **kwargs)
-                self._log_output_data(samples, output_dataset)
-                return output_dataset
+                output_df = self.generate(samples, **kwargs)
+                self._log_output_data(samples, output_df)
+                return output_df
             finally:
                 # Always restore original values for block fields
                 for key, value in original_values.items():
@@ -323,11 +325,11 @@ class BaseBlock(BaseModel, ABC):
         else:
             # Normal execution without overrides
             self._log_input_data(samples)
-            self._validate_dataset(samples)
+            self._validate_dataframe(samples)
             self._validate_custom(samples)
-            output_dataset = self.generate(samples)
-            self._log_output_data(samples, output_dataset)
-            return output_dataset
+            output_df = self.generate(samples)
+            self._log_output_data(samples, output_df)
+            return output_df
 
     def __repr__(self) -> str:
         """Compact string representation."""
@@ -9,10 +9,11 @@ using various operations with optional data type conversion.
 from typing import Any, Optional, Union
 import operator
 
-# Third Party
-from datasets import Dataset
 from pydantic import Field, field_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
 from ..base import BaseBlock
@@ -158,32 +159,44 @@ class ColumnValueFilterBlock(BaseBlock):
             sample[self.column_name] = None
         return sample
 
-    def generate(self, samples: Dataset, **_kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **_kwargs: Any) -> pd.DataFrame:
         """Generate filtered dataset based on specified conditions.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
             The input dataset to filter.
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
             The filtered dataset.
         """
+        result = samples.copy()
+
+        # Convert dtype if specified
         if self._convert_dtype_func:
-            samples = samples.map(self._convert_dtype)
 
-            samples = samples.filter(
-                lambda x: x[self.column_name] is not None,
-            )
+            def safe_convert(x):
+                """Safely convert value, returning None on error."""
+                if pd.isna(x):
+                    return None
+                try:
+                    return self._convert_dtype_func(x)
+                except (ValueError, TypeError):
+                    return None
 
-        # Apply filter operation
-        samples = samples.filter(
-            lambda x: any(
-                self._operation_func(x[self.column_name], value)
-                for value in self.filter_value
-            )
+            result[self.column_name] = result[self.column_name].apply(safe_convert)
+
+            # Filter out None values
+            result = result[result[self.column_name].notna()]
+
+        # Apply filter operation using boolean indexing
+        # Create a mask that checks if any filter value matches
+        mask = result[self.column_name].apply(
+            lambda x: any(self._operation_func(x, value) for value in self.filter_value)
         )
 
-        return samples
+        result = result[mask]
+
+        return result
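The filter rewrite replaces Dataset.map/Dataset.filter with three plain-pandas steps: an apply() for the optional dtype conversion, notna() to drop rows whose conversion failed, and a boolean mask built from the configured operation. The same idiom in isolation, with toy data and operator.eq standing in for the block's configured _operation_func:

import operator

import pandas as pd

df = pd.DataFrame({"score": ["3", "oops", "5", None]})


def safe_convert(x):
    """Convert to float, returning None on failure (mirrors the block's safe_convert)."""
    if pd.isna(x):
        return None
    try:
        return float(x)
    except (ValueError, TypeError):
        return None


df["score"] = df["score"].apply(safe_convert)
df = df[df["score"].notna()]  # drops the "oops" and None rows

filter_values = [3.0, 5.0]
mask = df["score"].apply(lambda x: any(operator.eq(x, v) for v in filter_values))
df = df[mask]  # keeps the rows with 3.0 and 5.0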
@@ -9,7 +9,6 @@ local models (vLLM, Ollama), and more.
 # Local
 from .error_handler import ErrorCategory, LLMErrorHandler
 from .llm_chat_block import LLMChatBlock
-from .llm_chat_with_parsing_retry_block import LLMChatWithParsingRetryBlock
 from .llm_parser_block import LLMParserBlock
 from .prompt_builder_block import PromptBuilderBlock
 from .text_parser_block import TextParserBlock
@@ -18,7 +17,6 @@ __all__ = [
     "LLMErrorHandler",
     "ErrorCategory",
     "LLMChatBlock",
-    "LLMChatWithParsingRetryBlock",
     "LLMParserBlock",
     "PromptBuilderBlock",
     "TextParserBlock",
@@ -5,12 +5,13 @@
 from typing import Any, Optional
 import asyncio
 
-# Third Party
-from datasets import Dataset
 from litellm import acompletion, completion
 from pydantic import ConfigDict, Field, field_validator
 import litellm
 
+# Third Party
+import pandas as pd
+
 from ...utils.error_handling import BlockValidationError
 from ...utils.logger_config import setup_logger
 
@@ -167,12 +168,12 @@ class LLMChatBlock(BaseBlock):
         },
     )
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Generate responses from the LLM.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
             Input dataset containing the messages column.
         **kwargs : Any
             Runtime parameters that override initialization defaults.
@@ -180,7 +181,7 @@ class LLMChatBlock(BaseBlock):
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
             Dataset with responses added to the output column.
 
         Raises
@@ -201,8 +202,8 @@ class LLMChatBlock(BaseBlock):
         # Build completion kwargs from ALL fields + runtime overrides
         completion_kwargs = self._build_completion_kwargs(**kwargs)
 
-        # Extract messages
-        messages_list = samples[self.input_cols[0]]
+        # Extract messages from pandas DataFrame
+        messages_list = samples[self.input_cols[0]].tolist()
 
         # Log generation start
         logger.info(
@@ -269,7 +270,9 @@ class LLMChatBlock(BaseBlock):
         )
 
         # Add responses as new column
-        return samples.add_column(self.output_cols[0], responses)
+        result = samples.copy()
+        result[self.output_cols[0]] = responses
+        return result
 
     def _build_completion_kwargs(self, **overrides) -> dict[str, Any]:
         """Build kwargs for LiteLLM completion call.
@@ -513,12 +516,14 @@ class LLMChatBlock(BaseBlock):
             )
             raise
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, dataset: pd.DataFrame) -> None:
         """Custom validation for LLMChatBlock message format.
 
+        Uses vectorized operations where possible for better performance.
+
         Parameters
         ----------
-        dataset : Dataset
+        dataset : pd.DataFrame
             The dataset to validate.
 
         Raises
@@ -526,28 +531,32 @@ class LLMChatBlock(BaseBlock):
         BlockValidationError
             If message format validation fails.
         """
+        messages_col = dataset[self.input_cols[0]]
+
+        # avoid using pd iterrows() when possible, it is notoriously slow: https://github.com/pandas-dev/pandas/issues/7683
+        # Vectorized check: all values must be lists
+        is_list = messages_col.apply(lambda x: isinstance(x, list))
+        if not is_list.all():
+            invalid_idx = is_list[~is_list].index[0]
+            invalid_value = messages_col.loc[invalid_idx]
+            raise BlockValidationError(
+                f"Messages column '{self.input_cols[0]}' must contain a list, "
+                f"got {type(invalid_value)} in row {invalid_idx}",
+                details=f"Block: {self.block_name}, Row: {invalid_idx}, Value: {invalid_value}",
+            )
 
-        def validate_sample(sample_with_index):
-            """Validate a single sample's message format."""
-            idx, sample = sample_with_index
-            messages = sample[self.input_cols[0]]
-
-            # Validate messages is a list
-            if not isinstance(messages, list):
-                raise BlockValidationError(
-                    f"Messages column '{self.input_cols[0]}' must contain a list, "
-                    f"got {type(messages)} in row {idx}",
-                    details=f"Block: {self.block_name}, Row: {idx}, Value: {messages}",
-                )
-
-            # Validate messages is not empty
-            if not messages:
-                raise BlockValidationError(
-                    f"Messages list is empty in row {idx}",
-                    details=f"Block: {self.block_name}, Row: {idx}",
-                )
+        # Vectorized check: no empty lists
+        is_empty = messages_col.apply(lambda x: len(x) == 0)
+        if is_empty.any():
+            invalid_idx = is_empty[is_empty].index[0]
+            raise BlockValidationError(
+                f"Messages list is empty in row {invalid_idx}",
+                details=f"Block: {self.block_name}, Row: {invalid_idx}",
+            )
 
-            # Validate each message format
+        # Validate nested message structure (requires iteration over messages column only)
+        def validate_message_structure(messages, idx):
+            """Validate structure of messages list."""
             for msg_idx, message in enumerate(messages):
                 if not isinstance(message, dict):
                     raise BlockValidationError(
@@ -555,7 +564,6 @@ class LLMChatBlock(BaseBlock):
                         details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Value: {message}",
                     )
 
-                # Validate required fields
                 if "role" not in message or message["role"] is None:
                     raise BlockValidationError(
                         f"Message {msg_idx} in row {idx} missing required 'role' field",
@@ -568,11 +576,9 @@ class LLMChatBlock(BaseBlock):
                         details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Available fields: {list(message.keys())}",
                     )
 
-            return True
-
-        # Validate all samples
-        indexed_samples = [(i, sample) for i, sample in enumerate(dataset)]
-        list(map(validate_sample, indexed_samples))
+        # Iterate only over the messages column (not the entire DataFrame)
+        for idx, messages in messages_col.items():
+            validate_message_structure(messages, idx)
 
     def __repr__(self) -> str:
         """String representation of the block."""
@@ -8,10 +8,11 @@ This module provides the LLMParserBlock for extracting specific fields
 # Standard
 from typing import Any
 
-# Third Party
-from datasets import Dataset
 from pydantic import Field, model_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
 from ..base import BaseBlock
@@ -26,6 +27,8 @@ logger = setup_logger(__name__)
     "Extracts specified fields from LLM response objects",
 )
 class LLMParserBlock(BaseBlock):
+    _flow_requires_jsonl_tmp: bool = True
+
     """Block for extracting fields from LLM response objects.
 
     This block extracts specified fields from chat completion response objects.
@@ -102,12 +105,12 @@ class LLMParserBlock(BaseBlock):
 
         return self
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, dataset: pd.DataFrame) -> None:
         """Validate LLMParserBlock specific requirements.
 
         Parameters
         ----------
-        dataset : Dataset
+        dataset : pd.DataFrame
             The dataset to validate.
 
         Raises
@@ -308,13 +311,16 @@ class LLMParserBlock(BaseBlock):
         extracted = self._extract_fields_from_response(raw_output)
         return [{**sample, **extracted}]
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         logger.debug(f"Extracting fields from {len(samples)} samples")
         if len(samples) == 0:
             logger.warning("No samples to process, returning empty dataset")
-            return Dataset.from_list([])
+            return pd.DataFrame()
 
         new_data = []
+        samples = samples.to_dict("records")  # Avoid Iterrows() when possible
+
         for sample in samples:
             new_data.extend(self._generate(sample))
-        return Dataset.from_list(new_data)
+
+        return pd.DataFrame(new_data)
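Because _generate returns a list of dicts per input row, LLMParserBlock's generate now expands rows through to_dict("records") and rebuilds the frame once at the end instead of calling Dataset.from_list. The one-to-many pattern, sketched with a toy stand-in for _generate:

import pandas as pd

df = pd.DataFrame({"response": ["a|b", "c"]})


def explode_row(sample: dict) -> list[dict]:
    """Toy stand-in for _generate: one input row may yield several output rows."""
    return [{**sample, "part": p} for p in sample["response"].split("|")]


new_data = []
for sample in df.to_dict("records"):  # records round-trip, no iterrows()
    new_data.extend(explode_row(sample))

result = pd.DataFrame(new_data)  # three rows: a, b, c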
@@ -8,10 +8,11 @@ including conversion to OpenAI Messages format and template rendering.
 # Standard
 from typing import Any, Literal, Optional
 
-# Third Party
-from datasets import Dataset
 from jinja2 import Template, meta
 from pydantic import BaseModel, Field, field_validator
+
+# Third Party
+import pandas as pd
 import yaml
 
 # Local
@@ -279,12 +280,14 @@ class PromptBuilderBlock(BaseBlock):
         message_templates = self.prompt_template_config.get_message_templates()
         self.prompt_renderer = PromptRenderer(message_templates)
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, dataset: pd.DataFrame) -> None:
         if len(dataset) > 0:
             # Get required variables from all message templates
             required_vars = self.prompt_renderer.get_required_variables()
 
-            sample = dataset[0]
+            # Get first row as dict
+            sample = dataset.iloc[0].to_dict()
+
             template_vars = self.prompt_renderer.resolve_template_vars(
                 sample, self.input_cols
             )
@@ -344,25 +347,27 @@ class PromptBuilderBlock(BaseBlock):
 
         return sample
 
-    def generate(self, samples: Dataset, **_kwargs: Any) -> Dataset:
-        """Generate formatted output for all samples using dataset map.
+    def generate(self, samples: pd.DataFrame, **_kwargs: Any) -> pd.DataFrame:
+        """Generate formatted output for all samples.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
             Input dataset containing samples to be formatted.
         **kwargs : Dict[str, Any]
             Additional keyword arguments (unused in this block).
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
             Dataset with the formatted output added to the specified column.
         """
         logger.debug(f"Formatting prompts for {len(samples)} samples")
 
-        # Use dataset map for efficient processing
-        formatted_dataset = samples.map(self._generate)
+        # Convert DataFrame to list of dicts, process each, and convert back
+        samples_list = samples.to_dict("records")
+        formatted_samples = [self._generate(sample) for sample in samples_list]
+        formatted_dataset = pd.DataFrame(formatted_samples)
 
         logger.debug(f"Successfully formatted {len(formatted_dataset)} samples")
         return formatted_dataset
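PromptBuilderBlock's switch from Dataset.map(self._generate) to an explicit records round-trip follows the same shape: dump rows to dicts, apply the per-sample formatter, rebuild the DataFrame. A compact sketch with a hypothetical formatter in place of _generate:

import pandas as pd

samples = pd.DataFrame({"question": ["What is sdg_hub?", "Why pandas?"]})


def add_prompt(sample: dict) -> dict:
    """Hypothetical per-row formatter standing in for PromptBuilderBlock._generate."""
    sample["prompt"] = f"Answer concisely: {sample['question']}"
    return sample


formatted = pd.DataFrame([add_prompt(s) for s in samples.to_dict("records")])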