sdg-hub 0.5.1__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. sdg_hub/_version.py +2 -2
  2. sdg_hub/core/blocks/base.py +60 -58
  3. sdg_hub/core/blocks/filtering/column_value_filter.py +29 -16
  4. sdg_hub/core/blocks/llm/__init__.py +0 -2
  5. sdg_hub/core/blocks/llm/llm_chat_block.py +42 -36
  6. sdg_hub/core/blocks/llm/llm_parser_block.py +13 -59
  7. sdg_hub/core/blocks/llm/prompt_builder_block.py +15 -10
  8. sdg_hub/core/blocks/llm/text_parser_block.py +14 -61
  9. sdg_hub/core/blocks/transform/duplicate_columns.py +9 -8
  10. sdg_hub/core/blocks/transform/index_based_mapper.py +29 -15
  11. sdg_hub/core/blocks/transform/json_structure_block.py +16 -13
  12. sdg_hub/core/blocks/transform/melt_columns.py +13 -12
  13. sdg_hub/core/blocks/transform/rename_columns.py +20 -9
  14. sdg_hub/core/blocks/transform/text_concat.py +20 -21
  15. sdg_hub/core/blocks/transform/uniform_col_val_setter.py +6 -5
  16. sdg_hub/core/flow/base.py +139 -106
  17. sdg_hub/core/flow/checkpointer.py +34 -36
  18. sdg_hub/core/flow/validation.py +4 -4
  19. sdg_hub/core/utils/datautils.py +52 -54
  20. sdg_hub/core/utils/flow_metrics.py +9 -6
  21. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/flow.yaml +1 -0
  22. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/METADATA +5 -9
  23. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/RECORD +26 -28
  24. sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +0 -771
  25. sdg_hub/core/utils/temp_manager.py +0 -57
  26. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/WHEEL +0 -0
  27. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/licenses/LICENSE +0 -0
  28. {sdg_hub-0.5.1.dist-info → sdg_hub-0.6.1.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.5.1'
-__version_tuple__ = version_tuple = (0, 5, 1)
+__version__ = version = '0.6.1'
+__version_tuple__ = version_tuple = (0, 6, 1)
 
 __commit_id__ = commit_id = None
sdg_hub/core/blocks/base.py CHANGED
@@ -9,13 +9,14 @@ with unified constructor patterns, column handling, and common functionality.
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union
 
-# Third Party
-from datasets import Dataset
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
 
+# Third Party
+import pandas as pd
+
 # Local
 from ..utils.error_handling import (
     EmptyDatasetError,
@@ -32,7 +33,7 @@ class BaseBlock(BaseModel, ABC):
     """Base class for all blocks, with standardized patterns and full Pydantic compatibility.
 
     This class defines a unified, configurable base for building composable data processing blocks
-    that operate over HuggingFace Datasets. It supports field-based initialization, validation,
+    that operate over pandas DataFrames. It supports field-based initialization, validation,
     and rich logging for inputs and outputs.
 
     Attributes
@@ -40,9 +41,9 @@ class BaseBlock(BaseModel, ABC):
     block_name : str
         Unique identifier for this block instance.
     input_cols : Union[List[str], Dict[str, Any]]
-        Input columns from the dataset (string, list of strings, or mapping).
+        Input columns from the DataFrame (string, list of strings, or mapping).
     output_cols : Union[List[str], Dict[str, Any]]
-        Output columns to write to the dataset (string, list of strings, or mapping).
+        Output columns to write to the DataFrame (string, list of strings, or mapping).
     """
 
     block_name: str = Field(
@@ -55,7 +56,7 @@ class BaseBlock(BaseModel, ABC):
         None, description="Output columns: str, list, or dict"
     )
 
-    # Allow extra config fields and complex types like Dataset
+    # Allow extra config fields and complex types like DataFrame
     model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
 
     # Normalize input columns before model construction
@@ -101,13 +102,13 @@ class BaseBlock(BaseModel, ABC):
             return dict(cols)
         raise ValueError(f"Invalid column specification: {cols} (type: {type(cols)})")
 
-    def _validate_columns(self, dataset: Dataset) -> None:
-        """Check that all required input columns are present in the dataset.
+    def _validate_columns(self, df: pd.DataFrame) -> None:
+        """Check that all required input columns are present in the DataFrame.
 
         Parameters
         ----------
-        dataset : Dataset
-            HuggingFace dataset to validate against.
+        df : pd.DataFrame
+            DataFrame to validate against.
 
         Raises
         ------
@@ -121,28 +122,29 @@ class BaseBlock(BaseModel, ABC):
             if isinstance(self.input_cols, dict)
             else self.input_cols
         )
+        available_columns = df.columns.tolist()
         missing_columns = [
-            col for col in columns_to_check if col not in dataset.column_names
+            col for col in columns_to_check if col not in available_columns
         ]
         if missing_columns:
             raise MissingColumnError(
                 block_name=self.block_name,
                 missing_columns=missing_columns,
-                available_columns=dataset.column_names,
+                available_columns=available_columns,
             )
 
-    def _validate_output_columns(self, dataset: Dataset) -> None:
+    def _validate_output_columns(self, df: pd.DataFrame) -> None:
         """Check that the output columns will not overwrite existing ones.
 
         Parameters
         ----------
-        dataset : Dataset
-            HuggingFace dataset to validate.
+        df : pd.DataFrame
+            DataFrame to validate.
 
         Raises
         ------
         OutputColumnCollisionError
-            If output columns already exist in the dataset.
+            If output columns already exist in the DataFrame.
         """
         if not self.output_cols:
             return
@@ -151,42 +153,43 @@ class BaseBlock(BaseModel, ABC):
             if isinstance(self.output_cols, dict)
             else self.output_cols
         )
-        collisions = [col for col in columns_to_check if col in dataset.column_names]
+        available_columns = df.columns.tolist()
+        collisions = [col for col in columns_to_check if col in available_columns]
         if collisions:
             raise OutputColumnCollisionError(
                 block_name=self.block_name,
                 collision_columns=collisions,
-                existing_columns=dataset.column_names,
+                existing_columns=available_columns,
             )
 
-    def _validate_dataset_not_empty(self, dataset: Dataset) -> None:
-        """Raise an error if the dataset is empty.
+    def _validate_dataframe_not_empty(self, df: pd.DataFrame) -> None:
+        """Raise an error if the DataFrame is empty.
 
         Parameters
         ----------
-        dataset : Dataset
+        df : pd.DataFrame
 
         Raises
        ------
         EmptyDatasetError
         """
-        if len(dataset) == 0:
+        if len(df) == 0:
             raise EmptyDatasetError(block_name=self.block_name)
 
-    def _validate_dataset(self, dataset: Dataset) -> None:
-        """Perform all default dataset validations."""
-        self._validate_dataset_not_empty(dataset)
-        self._validate_columns(dataset)
-        self._validate_output_columns(dataset)
+    def _validate_dataframe(self, df: pd.DataFrame) -> None:
+        """Perform all default DataFrame validations."""
+        self._validate_dataframe_not_empty(df)
+        self._validate_columns(df)
+        self._validate_output_columns(df)
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, df: pd.DataFrame) -> None:
         """Hook for subclasses to add extra validation logic."""
         pass
 
-    def _log_input_data(self, dataset: Dataset) -> None:
-        """Print a summary of the input dataset with Rich formatting."""
-        row_count = len(dataset)
-        columns = dataset.column_names
+    def _log_input_data(self, df: pd.DataFrame) -> None:
+        """Print a summary of the input DataFrame with Rich formatting."""
+        row_count = len(df)
+        columns = df.columns.tolist()
         content = Text()
         content.append("\U0001f4ca Processing Input Data\n", style="bold blue")
         content.append(f"Block Type: {self.__class__.__name__}\n", style="cyan")
@@ -207,13 +210,12 @@ class BaseBlock(BaseModel, ABC):
             Panel(content, title=f"[bold]{self.block_name}[/bold]", border_style="blue")
         )
 
-    def _log_output_data(self, input_dataset: Dataset, output_dataset: Dataset) -> None:
-        """Print a Rich panel summarizing output dataset differences."""
-        in_rows, out_rows = len(input_dataset), len(output_dataset)
-        in_cols, out_cols = (
-            set(input_dataset.column_names),
-            set(output_dataset.column_names),
-        )
+    def _log_output_data(self, input_df: pd.DataFrame, output_df: pd.DataFrame) -> None:
+        """Print a Rich panel summarizing output DataFrame differences."""
+        in_rows, out_rows = len(input_df), len(output_df)
+        in_cols = set(input_df.columns.tolist())
+        out_cols = set(output_df.columns.tolist())
+
         added_cols, removed_cols = out_cols - in_cols, in_cols - out_cols
         content = Text()
         content.append("\u2705 Processing Complete\n", style="bold green")
@@ -239,35 +241,35 @@ class BaseBlock(BaseModel, ABC):
         )
 
     @abstractmethod
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Subclass method to implement data generation logic.
 
         Parameters
         ----------
-        samples : Dataset
-            Input dataset to process.
+        samples : pd.DataFrame
+            Input DataFrame to process.
 
         Returns
         -------
-        Dataset
-            Transformed dataset with new columns or values.
+        pd.DataFrame
+            Transformed DataFrame with new columns or values.
         """
         pass
 
-    def __call__(self, samples: Dataset, **kwargs: Any) -> Dataset:
-        """Run the block on a dataset with full validation and logging.
+    def __call__(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+        """Run the block on a DataFrame with full validation and logging.
 
         Parameters
         ----------
-        samples : Dataset
-            Input dataset.
+        samples : pd.DataFrame
+            Input DataFrame.
         **kwargs : Any
             Runtime parameters to override block configuration
 
         Returns
         -------
-        Dataset
-            Output dataset after block processing.
+        pd.DataFrame
+            Output DataFrame after block processing.
         """
         # Handle runtime kwargs overrides
         if kwargs:
@@ -310,12 +312,12 @@ class BaseBlock(BaseModel, ABC):
 
             try:
                 self._log_input_data(samples)
-                self._validate_dataset(samples)
+                self._validate_dataframe(samples)
                 self._validate_custom(samples)
                 # Pass ALL kwargs to generate (including flow params)
-                output_dataset = self.generate(samples, **kwargs)
-                self._log_output_data(samples, output_dataset)
-                return output_dataset
+                output_df = self.generate(samples, **kwargs)
+                self._log_output_data(samples, output_df)
+                return output_df
             finally:
                 # Always restore original values for block fields
                 for key, value in original_values.items():
@@ -323,11 +325,11 @@ class BaseBlock(BaseModel, ABC):
         else:
             # Normal execution without overrides
             self._log_input_data(samples)
-            self._validate_dataset(samples)
+            self._validate_dataframe(samples)
             self._validate_custom(samples)
-            output_dataset = self.generate(samples)
-            self._log_output_data(samples, output_dataset)
-            return output_dataset
+            output_df = self.generate(samples)
+            self._log_output_data(samples, output_df)
+            return output_df
 
     def __repr__(self) -> str:
         """Compact string representation."""
sdg_hub/core/blocks/filtering/column_value_filter.py CHANGED
@@ -9,10 +9,11 @@ using various operations with optional data type conversion.
 from typing import Any, Optional, Union
 import operator
 
-# Third Party
-from datasets import Dataset
 from pydantic import Field, field_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
 from ..base import BaseBlock
@@ -158,32 +159,44 @@ class ColumnValueFilterBlock(BaseBlock):
             sample[self.column_name] = None
         return sample
 
-    def generate(self, samples: Dataset, **_kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **_kwargs: Any) -> pd.DataFrame:
         """Generate filtered dataset based on specified conditions.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
             The input dataset to filter.
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
             The filtered dataset.
         """
+        result = samples.copy()
+
+        # Convert dtype if specified
         if self._convert_dtype_func:
-            samples = samples.map(self._convert_dtype)
 
-        samples = samples.filter(
-            lambda x: x[self.column_name] is not None,
-        )
+            def safe_convert(x):
+                """Safely convert value, returning None on error."""
+                if pd.isna(x):
+                    return None
+                try:
+                    return self._convert_dtype_func(x)
+                except (ValueError, TypeError):
+                    return None
 
-        # Apply filter operation
-        samples = samples.filter(
-            lambda x: any(
-                self._operation_func(x[self.column_name], value)
-                for value in self.filter_value
-            )
+            result[self.column_name] = result[self.column_name].apply(safe_convert)
+
+        # Filter out None values
+        result = result[result[self.column_name].notna()]
+
+        # Apply filter operation using boolean indexing
+        # Create a mask that checks if any filter value matches
+        mask = result[self.column_name].apply(
+            lambda x: any(self._operation_func(x, value) for value in self.filter_value)
         )
 
-        return samples
+        result = result[mask]
+
+        return result
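Note: the new generate() replaces datasets.Dataset.map/filter with pandas apply and boolean masks. A standalone sketch of the same convert-then-filter pattern, with made-up data and an int conversion standing in for the block's dtype converter:

import operator

import pandas as pd

df = pd.DataFrame({"score": ["3", "7", "not-a-number", "9"]})
filter_values = [7, 9]
op = operator.eq


def safe_convert(x):
    """Convert to int, returning None when conversion fails."""
    if pd.isna(x):
        return None
    try:
        return int(x)
    except (ValueError, TypeError):
        return None


result = df.copy()
result["score"] = result["score"].apply(safe_convert)
result = result[result["score"].notna()]  # drop rows that failed conversion
mask = result["score"].apply(lambda v: any(op(v, f) for f in filter_values))
print(result[mask])  # keeps the rows with score 7 or 9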
sdg_hub/core/blocks/llm/__init__.py CHANGED
@@ -9,7 +9,6 @@ local models (vLLM, Ollama), and more.
 # Local
 from .error_handler import ErrorCategory, LLMErrorHandler
 from .llm_chat_block import LLMChatBlock
-from .llm_chat_with_parsing_retry_block import LLMChatWithParsingRetryBlock
 from .llm_parser_block import LLMParserBlock
 from .prompt_builder_block import PromptBuilderBlock
 from .text_parser_block import TextParserBlock
@@ -18,7 +17,6 @@ __all__ = [
     "LLMErrorHandler",
     "ErrorCategory",
     "LLMChatBlock",
-    "LLMChatWithParsingRetryBlock",
     "LLMParserBlock",
     "PromptBuilderBlock",
     "TextParserBlock",
sdg_hub/core/blocks/llm/llm_chat_block.py CHANGED
@@ -5,12 +5,13 @@
 from typing import Any, Optional
 import asyncio
 
-# Third Party
-from datasets import Dataset
 from litellm import acompletion, completion
 from pydantic import ConfigDict, Field, field_validator
 import litellm
 
+# Third Party
+import pandas as pd
+
 from ...utils.error_handling import BlockValidationError
 from ...utils.logger_config import setup_logger
 
@@ -167,12 +168,12 @@ class LLMChatBlock(BaseBlock):
         },
     )
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         """Generate responses from the LLM.
 
         Parameters
         ----------
-        samples : Dataset
+        samples : pd.DataFrame
             Input dataset containing the messages column.
         **kwargs : Any
             Runtime parameters that override initialization defaults.
@@ -180,7 +181,7 @@ class LLMChatBlock(BaseBlock):
 
         Returns
         -------
-        Dataset
+        pd.DataFrame
             Dataset with responses added to the output column.
 
         Raises
@@ -201,8 +202,8 @@ class LLMChatBlock(BaseBlock):
         # Build completion kwargs from ALL fields + runtime overrides
         completion_kwargs = self._build_completion_kwargs(**kwargs)
 
-        # Extract messages
-        messages_list = samples[self.input_cols[0]]
+        # Extract messages from pandas DataFrame
+        messages_list = samples[self.input_cols[0]].tolist()
 
         # Log generation start
         logger.info(
@@ -269,7 +270,9 @@ class LLMChatBlock(BaseBlock):
         )
 
         # Add responses as new column
-        return samples.add_column(self.output_cols[0], responses)
+        result = samples.copy()
+        result[self.output_cols[0]] = responses
+        return result
 
     def _build_completion_kwargs(self, **overrides) -> dict[str, Any]:
         """Build kwargs for LiteLLM completion call.
@@ -513,12 +516,14 @@ class LLMChatBlock(BaseBlock):
             )
             raise
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, dataset: pd.DataFrame) -> None:
         """Custom validation for LLMChatBlock message format.
 
+        Uses vectorized operations where possible for better performance.
+
         Parameters
         ----------
-        dataset : Dataset
+        dataset : pd.DataFrame
             The dataset to validate.
 
         Raises
@@ -526,28 +531,32 @@ class LLMChatBlock(BaseBlock):
         BlockValidationError
             If message format validation fails.
         """
+        messages_col = dataset[self.input_cols[0]]
+
+        # avoid using pd iterrows() when possible, it is notoriously slow: https://github.com/pandas-dev/pandas/issues/7683
+        # Vectorized check: all values must be lists
+        is_list = messages_col.apply(lambda x: isinstance(x, list))
+        if not is_list.all():
+            invalid_idx = is_list[~is_list].index[0]
+            invalid_value = messages_col.loc[invalid_idx]
+            raise BlockValidationError(
+                f"Messages column '{self.input_cols[0]}' must contain a list, "
+                f"got {type(invalid_value)} in row {invalid_idx}",
+                details=f"Block: {self.block_name}, Row: {invalid_idx}, Value: {invalid_value}",
+            )
 
-        def validate_sample(sample_with_index):
-            """Validate a single sample's message format."""
-            idx, sample = sample_with_index
-            messages = sample[self.input_cols[0]]
-
-            # Validate messages is a list
-            if not isinstance(messages, list):
-                raise BlockValidationError(
-                    f"Messages column '{self.input_cols[0]}' must contain a list, "
-                    f"got {type(messages)} in row {idx}",
-                    details=f"Block: {self.block_name}, Row: {idx}, Value: {messages}",
-                )
-
-            # Validate messages is not empty
-            if not messages:
-                raise BlockValidationError(
-                    f"Messages list is empty in row {idx}",
-                    details=f"Block: {self.block_name}, Row: {idx}",
-                )
+        # Vectorized check: no empty lists
+        is_empty = messages_col.apply(lambda x: len(x) == 0)
+        if is_empty.any():
+            invalid_idx = is_empty[is_empty].index[0]
+            raise BlockValidationError(
+                f"Messages list is empty in row {invalid_idx}",
+                details=f"Block: {self.block_name}, Row: {invalid_idx}",
+            )
 
-            # Validate each message format
+        # Validate nested message structure (requires iteration over messages column only)
+        def validate_message_structure(messages, idx):
+            """Validate structure of messages list."""
             for msg_idx, message in enumerate(messages):
                 if not isinstance(message, dict):
                     raise BlockValidationError(
@@ -555,7 +564,6 @@ class LLMChatBlock(BaseBlock):
                         details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Value: {message}",
                     )
 
-                # Validate required fields
                 if "role" not in message or message["role"] is None:
                     raise BlockValidationError(
                         f"Message {msg_idx} in row {idx} missing required 'role' field",
@@ -568,11 +576,9 @@ class LLMChatBlock(BaseBlock):
                         details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Available fields: {list(message.keys())}",
                     )
 
-            return True
-
-        # Validate all samples
-        indexed_samples = [(i, sample) for i, sample in enumerate(dataset)]
-        list(map(validate_sample, indexed_samples))
+        # Iterate only over the messages column (not the entire DataFrame)
+        for idx, messages in messages_col.items():
+            validate_message_structure(messages, idx)
 
     def __repr__(self) -> str:
         """String representation of the block."""
sdg_hub/core/blocks/llm/llm_parser_block.py CHANGED
@@ -7,16 +7,14 @@ This module provides the LLMParserBlock for extracting specific fields
 
 # Standard
 from typing import Any
-from weakref import finalize
-import json
 
-# Third Party
-from datasets import Dataset, load_dataset
 from pydantic import Field, model_validator
 
+# Third Party
+import pandas as pd
+
 # Local
 from ...utils.logger_config import setup_logger
-from ...utils.temp_manager import cleanup_path, create_temp_dir, create_temp_file
 from ..base import BaseBlock
 from ..registry import BlockRegistry
 
@@ -107,12 +105,12 @@ class LLMParserBlock(BaseBlock):
 
         return self
 
-    def _validate_custom(self, dataset: Dataset) -> None:
+    def _validate_custom(self, dataset: pd.DataFrame) -> None:
         """Validate LLMParserBlock specific requirements.
 
         Parameters
         ----------
-        dataset : Dataset
+        dataset : pd.DataFrame
             The dataset to validate.
 
         Raises
@@ -313,60 +311,16 @@ class LLMParserBlock(BaseBlock):
         extracted = self._extract_fields_from_response(raw_output)
         return [{**sample, **extracted}]
 
-    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
         logger.debug(f"Extracting fields from {len(samples)} samples")
         if len(samples) == 0:
             logger.warning("No samples to process, returning empty dataset")
-            return Dataset.from_list([])
+            return pd.DataFrame()
 
-        tmp_jsonl_path = kwargs.get("_flow_tmp_jsonl_path")
-        cleanup_locally = False
+        new_data = []
+        samples = samples.to_dict("records")  # Avoid Iterrows() when possible
 
-        if tmp_jsonl_path is None:
-            tmp_jsonl_path = str(
-                create_temp_file(
-                    prefix=f"{self.block_name}_llm_parser", suffix=".jsonl"
-                )
-            )
-            cleanup_locally = True
-
-        rows_written = 0
-        batch = []
-        with open(tmp_jsonl_path, "w") as f:
-            for sample in samples:
-                out = self._generate(sample)
-                for row in out:
-                    batch.append(json.dumps(row) + "\n")
-                    rows_written += 1
-                    if len(batch) >= 5:
-                        f.writelines(batch)
-                        batch.clear()
-            if batch:
-                f.writelines(batch)
-
-        if rows_written == 0:
-            if cleanup_locally:
-                cleanup_path(tmp_jsonl_path)
-            return Dataset.from_list([])
-
-        hf_cache_dir = None
-        try:
-            hf_cache_dir = create_temp_dir(
-                prefix=f"{self.block_name}_llm_parser_hf_cache"
-            )
-            ret = load_dataset(
-                "json",
-                data_files=tmp_jsonl_path,
-                split="train",
-                keep_in_memory=False,
-                cache_dir=str(hf_cache_dir),
-            )
-            finalize(ret, cleanup_path, hf_cache_dir)
-            return ret
-        except Exception:
-            if hf_cache_dir is not None:
-                cleanup_path(hf_cache_dir)
-            raise
-        finally:
-            if cleanup_locally:
-                cleanup_path(tmp_jsonl_path)
+        for sample in samples:
+            new_data.extend(self._generate(sample))
+
+        return pd.DataFrame(new_data)
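Note: the temp-file/JSONL round trip through load_dataset is gone; generate() now expands each input record in memory and builds one DataFrame at the end. A toy sketch of that records-based pattern; the _generate stand-in below is invented, while the real method extracts fields from LLM responses:

import pandas as pd


def fake_generate(sample: dict) -> list[dict]:
    """Stand-in for LLMParserBlock._generate: one input row -> a list of output rows."""
    return [{**sample, "parsed": sample["raw"].strip().upper()}]


df = pd.DataFrame({"raw": [" alpha ", "beta"]})

new_data = []
for sample in df.to_dict("records"):  # records iteration, avoiding iterrows()
    new_data.extend(fake_generate(sample))

print(pd.DataFrame(new_data))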