sdg-hub 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sdg_hub/core/flow/base.py CHANGED
@@ -4,18 +4,26 @@
 # Standard
 from pathlib import Path
 from typing import Any, Optional, Union
+import time
 
 # Third Party
 from datasets import Dataset
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from rich.console import Console
+from rich.panel import Panel
+from rich.table import Table
+from rich.tree import Tree
 import yaml
 
 # Local
 from ..blocks.base import BaseBlock
 from ..blocks.registry import BlockRegistry
+from ..utils.datautils import safe_concatenate_with_validation
 from ..utils.error_handling import EmptyDatasetError, FlowValidationError
 from ..utils.logger_config import setup_logger
 from ..utils.path_resolution import resolve_path
+from ..utils.yaml_utils import save_flow_yaml
+from .checkpointer import FlowCheckpointer
 from .metadata import FlowMetadata, FlowParameter
 from .migration import FlowMigration
 from .validation import FlowValidator
@@ -133,7 +141,17 @@ class Flow(BaseModel):
         -------
         Flow
             Validated Flow instance.
+
+        Raises
+        ------
+        FlowValidationError
+            If yaml_path is None or the file doesn't exist.
         """
+        if yaml_path is None:
+            raise FlowValidationError(
+                "Flow path cannot be None. Please provide a valid YAML file path or check that the flow exists in the registry."
+            )
+
         yaml_path = resolve_path(yaml_path, [])
         yaml_dir = Path(yaml_path).parent
 
@@ -160,6 +178,8 @@ class Flow(BaseModel):
             flow_config, migrated_runtime_params = FlowMigration.migrate_to_new_format(
                 flow_config, yaml_path
             )
+            # Save migrated config back to YAML to persist id
+            save_flow_yaml(yaml_path, flow_config, "migrated to new format")
 
         # Validate YAML structure
         validator = FlowValidator()
@@ -221,6 +241,17 @@ class Flow(BaseModel):
         # Create and validate the flow
         try:
             flow = cls(blocks=blocks, metadata=metadata, parameters=parameters)
+            # Persist generated id back to the YAML file (only on initial load)
+            # If the file had no metadata.id originally, update and rewrite
+            if not flow_config.get("metadata", {}).get("id"):
+                flow_config.setdefault("metadata", {})["id"] = flow.metadata.id
+                save_flow_yaml(
+                    yaml_path,
+                    flow_config,
+                    f"added generated id: {flow.metadata.id}",
+                )
+            else:
+                logger.debug(f"Flow already had id: {flow.metadata.id}")
             # Store migrated runtime params and client for backward compatibility
             if migrated_runtime_params:
                 flow._migrated_runtime_params = migrated_runtime_params
@@ -324,6 +355,8 @@ class Flow(BaseModel):
         self,
         dataset: Dataset,
         runtime_params: Optional[dict[str, dict[str, Any]]] = None,
+        checkpoint_dir: Optional[str] = None,
+        save_freq: Optional[int] = None,
     ) -> Dataset:
         """Execute the flow blocks in sequence to generate data.
 
@@ -340,6 +373,11 @@ class Flow(BaseModel):
                "block_name": {"param1": value1, "param2": value2},
                "other_block": {"param3": value3}
            }
+        checkpoint_dir : Optional[str], optional
+            Directory to save/load checkpoints. If provided, enables checkpointing.
+        save_freq : Optional[int], optional
+            Number of completed samples after which to save a checkpoint.
+            If None, only saves final results when checkpointing is enabled.
 
         Returns
         -------
@@ -353,6 +391,12 @@ class Flow(BaseModel):
         FlowValidationError
             If flow validation fails or if model configuration is required but not set.
         """
+        # Validate save_freq parameter early to prevent range() errors
+        if save_freq is not None and save_freq <= 0:
+            raise FlowValidationError(
+                f"save_freq must be greater than 0, got {save_freq}"
+            )
+
         # Validate preconditions
         if not self.blocks:
             raise FlowValidationError("Cannot generate with empty flow")
@@ -376,18 +420,119 @@ class Flow(BaseModel):
                 "Dataset validation failed:\n" + "\n".join(dataset_errors)
             )
 
+        # Initialize checkpointer if enabled
+        checkpointer = None
+        completed_dataset = None
+        if checkpoint_dir:
+            checkpointer = FlowCheckpointer(
+                checkpoint_dir=checkpoint_dir,
+                save_freq=save_freq,
+                flow_id=self.metadata.id,
+            )
+
+            # Load existing progress
+            remaining_dataset, completed_dataset = checkpointer.load_existing_progress(
+                dataset
+            )
+
+            if len(remaining_dataset) == 0:
+                logger.info("All samples already completed, returning existing results")
+                return completed_dataset
+
+            dataset = remaining_dataset
+            logger.info(f"Resuming with {len(dataset)} remaining samples")
+
         logger.info(
             f"Starting flow '{self.metadata.name}' v{self.metadata.version} "
             f"with {len(dataset)} samples across {len(self.blocks)} blocks"
         )
 
-        current_dataset = dataset
         # Merge migrated runtime params with provided ones (provided ones take precedence)
         merged_runtime_params = self._migrated_runtime_params.copy()
         if runtime_params:
             merged_runtime_params.update(runtime_params)
         runtime_params = merged_runtime_params
 
+        # Process dataset in chunks if checkpointing with save_freq
+        if checkpointer and save_freq:
+            all_processed = []
+
+            # Process in chunks of save_freq
+            for i in range(0, len(dataset), save_freq):
+                chunk_end = min(i + save_freq, len(dataset))
+                chunk_dataset = dataset.select(range(i, chunk_end))
+
+                logger.info(
+                    f"Processing chunk {i // save_freq + 1}: samples {i} to {chunk_end - 1}"
+                )
+
+                # Execute all blocks on this chunk
+                processed_chunk = self._execute_blocks_on_dataset(
+                    chunk_dataset, runtime_params
+                )
+                all_processed.append(processed_chunk)
+
+                # Save checkpoint after chunk completion
+                checkpointer.add_completed_samples(processed_chunk)
+
+            # Save final checkpoint for any remaining samples
+            checkpointer.save_final_checkpoint()
+
+            # Combine all processed chunks
+            final_dataset = safe_concatenate_with_validation(
+                all_processed, "processed chunks from flow execution"
+            )
+
+            # Combine with previously completed samples if any
+            if checkpointer and completed_dataset:
+                final_dataset = safe_concatenate_with_validation(
+                    [completed_dataset, final_dataset],
+                    "completed checkpoint data with newly processed data",
+                )
+
+        else:
+            # Process entire dataset at once
+            final_dataset = self._execute_blocks_on_dataset(dataset, runtime_params)
+
+            # Save final checkpoint if checkpointing enabled
+            if checkpointer:
+                checkpointer.add_completed_samples(final_dataset)
+                checkpointer.save_final_checkpoint()
+
+                # Combine with previously completed samples if any
+                if completed_dataset:
+                    final_dataset = safe_concatenate_with_validation(
+                        [completed_dataset, final_dataset],
+                        "completed checkpoint data with newly processed data",
+                    )
+
+        logger.info(
+            f"Flow '{self.metadata.name}' completed successfully: "
+            f"{len(final_dataset)} final samples, "
+            f"{len(final_dataset.column_names)} final columns"
+        )
+
+        return final_dataset
+
+    def _execute_blocks_on_dataset(
+        self, dataset: Dataset, runtime_params: dict[str, dict[str, Any]]
+    ) -> Dataset:
+        """Execute all blocks in sequence on the given dataset.
+
+        Parameters
+        ----------
+        dataset : Dataset
+            Dataset to process through all blocks.
+        runtime_params : Dict[str, Dict[str, Any]]
+            Runtime parameters for block execution.
+
+        Returns
+        -------
+        Dataset
+            Dataset after processing through all blocks.
+        """
+        current_dataset = dataset
+
         # Execute blocks in sequence
         for i, block in enumerate(self.blocks):
             logger.info(
@@ -436,12 +581,6 @@ class Flow(BaseModel):
                     f"Block '{block.block_name}' execution failed: {exc}"
                 ) from exc
 
-        logger.info(
-            f"Flow '{self.metadata.name}' completed successfully: "
-            f"{len(current_dataset)} final samples, "
-            f"{len(current_dataset.column_names)} final columns"
-        )
-
         return current_dataset
 
     def _prepare_block_kwargs(
@@ -784,9 +923,6 @@ class Flow(BaseModel):
             "execution_time_seconds": 0,
         }
 
-        # Standard
-        import time
-
         start_time = time.time()
 
         try:
@@ -930,6 +1066,102 @@ class Flow(BaseModel):
             "block_names": [block.block_name for block in self.blocks],
         }
 
+    def print_info(self) -> None:
+        """
+        Print an interactive summary of the Flow in the console.
+
+        The summary contains:
+        1. Flow metadata (name, version, author, description)
+        2. Defined runtime parameters with type hints and defaults
+        3. A table of all blocks with their input and output columns
+
+        Notes
+        -----
+        Uses the `rich` library for colourised output; install with
+        `pip install rich` if not already present.
+
+        Returns
+        -------
+        None
+        """
+
+        console = Console()
+
+        # Create main tree structure
+        flow_tree = Tree(
+            f"[bold bright_blue]{self.metadata.name}[/bold bright_blue] Flow"
+        )
+
+        # Metadata section
+        metadata_branch = flow_tree.add(
+            "[bold bright_green]Metadata[/bold bright_green]"
+        )
+        metadata_branch.add(
+            f"Version: [bright_cyan]{self.metadata.version}[/bright_cyan]"
+        )
+        metadata_branch.add(
+            f"Author: [bright_cyan]{self.metadata.author}[/bright_cyan]"
+        )
+        if self.metadata.description:
+            metadata_branch.add(
+                f"Description: [white]{self.metadata.description}[/white]"
+            )
+
+        # Parameters section
+        if self.parameters:
+            params_branch = flow_tree.add(
+                "[bold bright_yellow]Parameters[/bold bright_yellow]"
+            )
+            for name, param in self.parameters.items():
+                param_info = f"[bright_cyan]{name}[/bright_cyan]: [white]{param.type_hint}[/white]"
+                if param.default is not None:
+                    param_info += f" = [bright_white]{param.default}[/bright_white]"
+                params_branch.add(param_info)
+
+        # Blocks overview
+        flow_tree.add(
+            f"[bold bright_magenta]Blocks[/bold bright_magenta] ({len(self.blocks)} total)"
+        )
+
+        # Create blocks table
+        blocks_table = Table(show_header=True, header_style="bold bright_white")
+        blocks_table.add_column("Block Name", style="bright_cyan")
+        blocks_table.add_column("Type", style="bright_green")
+        blocks_table.add_column("Input Cols", style="bright_yellow")
+        blocks_table.add_column("Output Cols", style="bright_red")
+
+        for block in self.blocks:
+            input_cols = getattr(block, "input_cols", None)
+            output_cols = getattr(block, "output_cols", None)
+
+            blocks_table.add_row(
+                block.block_name,
+                block.__class__.__name__,
+                str(input_cols) if input_cols else "[bright_black]None[/bright_black]",
+                str(output_cols)
+                if output_cols
+                else "[bright_black]None[/bright_black]",
+            )
+
+        # Print everything
+        console.print()
+        console.print(
+            Panel(
+                flow_tree,
+                title="[bold bright_white]Flow Information[/bold bright_white]",
+                border_style="bright_blue",
+            )
+        )
+        console.print()
+        console.print(
+            Panel(
+                blocks_table,
+                title="[bold bright_white]Block Details[/bold bright_white]",
+                border_style="bright_magenta",
+            )
+        )
+        console.print()
+
     def to_yaml(self, output_path: str) -> None:
         """Save flow configuration to YAML file.
 
@@ -952,10 +1184,7 @@ class Flow(BaseModel):
             name: param.model_dump() for name, param in self.parameters.items()
         }
 
-        with open(output_path, "w", encoding="utf-8") as f:
-            yaml.dump(config, f, default_flow_style=False, sort_keys=False)
-
-        logger.info(f"Flow configuration saved to: {output_path}")
+        save_flow_yaml(output_path, config)
 
     def __len__(self) -> int:
        """Number of blocks in the flow."""
@@ -0,0 +1,333 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Flow-level checkpointing with sample-level tracking for data generation pipelines."""
+
+# Standard
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+import json
+import os
+import uuid
+
+# Third Party
+from datasets import Dataset
+
+# Local
+from ..utils.datautils import safe_concatenate_with_validation
+from ..utils.logger_config import setup_logger
+
+logger = setup_logger(__name__)
+
+
+class FlowCheckpointer:
+    """Enhanced checkpointer for Flow execution with sample-level tracking.
+
+    Provides data-level checkpointing where progress is saved after processing
+    a specified number of samples through the entire flow pipeline.
+    """
+
+    def __init__(
+        self,
+        checkpoint_dir: Optional[str] = None,
+        save_freq: Optional[int] = None,
+        flow_id: Optional[str] = None,
+    ):
+        """Initialize the FlowCheckpointer.
+
+        Parameters
+        ----------
+        checkpoint_dir : Optional[str]
+            Directory to save/load checkpoints. If None, checkpointing is disabled.
+        save_freq : Optional[int]
+            Number of completed samples after which to save a checkpoint.
+            If None, only final results are saved.
+        flow_id : Optional[str]
+            Unique ID of the flow for checkpoint identification.
+        """
+        self.checkpoint_dir = checkpoint_dir
+        self.save_freq = save_freq
+        self.flow_id = flow_id or "unknown_flow"
+
+        # Internal state
+        self._samples_processed = 0
+        self._checkpoint_counter = 0
+        self._pending_samples: List[Dict[str, Any]] = []
+
+        # Ensure checkpoint directory exists
+        if self.checkpoint_dir:
+            Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
+
+    @property
+    def is_enabled(self) -> bool:
+        """Check if checkpointing is enabled."""
+        return self.checkpoint_dir is not None
+
+    @property
+    def metadata_path(self) -> str:
+        """Path to the flow metadata file."""
+        return os.path.join(self.checkpoint_dir, "flow_metadata.json")
+
+    def load_existing_progress(
+        self, input_dataset: Dataset
+    ) -> Tuple[Dataset, Optional[Dataset]]:
+        """Load existing checkpoint data and determine remaining work.
+
+        Parameters
+        ----------
+        input_dataset : Dataset
+            Original input dataset for the flow.
+
+        Returns
+        -------
+        Tuple[Dataset, Optional[Dataset]]
+            (remaining_samples_to_process, completed_samples_dataset)
+            If no checkpoints exist, returns (input_dataset, None)
+        """
+        if not self.is_enabled:
+            return input_dataset, None
+
+        try:
+            # Load flow metadata
+            metadata = self._load_metadata()
+            if not metadata:
+                logger.info(f"No existing checkpoints found in {self.checkpoint_dir}")
+                return input_dataset, None
+
+            # Validate flow identity to prevent mixing checkpoints from different flows
+            saved_flow_id = metadata.get("flow_id")
+            if saved_flow_id and saved_flow_id != self.flow_id:
+                logger.warning(
+                    f"Flow ID mismatch: saved checkpoints are for flow ID '{saved_flow_id}' "
+                    f"but current flow ID is '{self.flow_id}'. Starting fresh to avoid "
+                    f"mixing incompatible checkpoint data."
+                )
+                return input_dataset, None
+
+            # Load all completed samples from checkpoints
+            completed_dataset = self._load_completed_samples()
+            if completed_dataset is None or len(completed_dataset) == 0:
+                logger.info("No completed samples found in checkpoints")
+                return input_dataset, None
+
+            # Find samples that still need processing
+            remaining_dataset = self._find_remaining_samples(
+                input_dataset, completed_dataset
+            )
+
+            self._samples_processed = len(completed_dataset)
+            self._checkpoint_counter = metadata.get("checkpoint_counter", 0)
+
+            logger.info(
+                f"Loaded {len(completed_dataset)} completed samples, "
+                f"{len(remaining_dataset)} samples remaining"
+            )
+
+            return remaining_dataset, completed_dataset
+
+        except Exception as exc:
+            logger.warning(f"Failed to load checkpoints: {exc}. Starting from scratch.")
+            return input_dataset, None
+
+    def add_completed_samples(self, samples: Dataset) -> None:
+        """Add samples that have completed the entire flow.
+
+        Parameters
+        ----------
+        samples : Dataset
+            Samples that have completed processing through all blocks.
+        """
+        if not self.is_enabled:
+            return
+
+        # Add to pending samples
+        for sample in samples:
+            self._pending_samples.append(sample)
+            self._samples_processed += 1
+
+        # Check if we should save a checkpoint
+        if self.save_freq and len(self._pending_samples) >= self.save_freq:
+            self._save_checkpoint()
+
+    def save_final_checkpoint(self) -> None:
+        """Save any remaining pending samples as final checkpoint."""
+        if not self.is_enabled:
+            return
+
+        if self._pending_samples:
+            sample_count = len(self._pending_samples)
+            self._save_checkpoint()
+            logger.info(f"Saved final checkpoint with {sample_count} samples")
+
+    def _save_checkpoint(self) -> None:
+        """Save current pending samples to a checkpoint file."""
+        if not self._pending_samples:
+            return
+
+        self._checkpoint_counter += 1
+        checkpoint_file = os.path.join(
+            self.checkpoint_dir, f"checkpoint_{self._checkpoint_counter:04d}.jsonl"
+        )
+
+        # Convert pending samples to dataset and save
+        checkpoint_dataset = Dataset.from_list(self._pending_samples)
+        checkpoint_dataset.to_json(checkpoint_file, orient="records", lines=True)
+
+        # Update metadata
+        self._save_metadata()
+
+        logger.info(
+            f"Saved checkpoint {self._checkpoint_counter} with "
+            f"{len(self._pending_samples)} samples to {checkpoint_file}"
+        )
+
+        # Clear pending samples
+        self._pending_samples.clear()
+
+    def _save_metadata(self) -> None:
+        """Save flow execution metadata."""
+        metadata = {
+            "flow_id": self.flow_id,
+            "save_freq": self.save_freq,
+            "samples_processed": self._samples_processed,
+            "checkpoint_counter": self._checkpoint_counter,
+            "last_updated": str(uuid.uuid4()),  # Simple versioning
+        }
+
+        with open(self.metadata_path, "w", encoding="utf-8") as f:
+            json.dump(metadata, f, indent=2)
+
+    def _load_metadata(self) -> Optional[Dict[str, Any]]:
+        """Load flow execution metadata."""
+        if not os.path.exists(self.metadata_path):
+            return None
+
+        try:
+            with open(self.metadata_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        except Exception as exc:
+            logger.warning(f"Failed to load metadata: {exc}")
+            return None
+
+    def _load_completed_samples(self) -> Optional[Dataset]:
+        """Load all completed samples from checkpoint files."""
+        checkpoint_files = []
+        checkpoint_dir = Path(self.checkpoint_dir)
+
+        # Find all checkpoint files
+        for file_path in checkpoint_dir.glob("checkpoint_*.jsonl"):
+            checkpoint_files.append(str(file_path))
+
+        if not checkpoint_files:
+            return None
+
+        # Sort checkpoint files by number
+        checkpoint_files.sort()
+
+        # Load and concatenate all checkpoint datasets
+        datasets = []
+        for file_path in checkpoint_files:
+            try:
+                dataset = Dataset.from_json(file_path)
+                if len(dataset) > 0:
+                    datasets.append(dataset)
+                    logger.debug(
+                        f"Loaded checkpoint: {file_path} ({len(dataset)} samples)"
+                    )
+            except Exception as exc:
+                logger.warning(f"Failed to load checkpoint {file_path}: {exc}")
+
+        if not datasets:
+            return None
+
+        return safe_concatenate_with_validation(datasets, "checkpoint files")
+
+    def _find_remaining_samples(
+        self, input_dataset: Dataset, completed_dataset: Dataset
+    ) -> Dataset:
+        """Find samples from input_dataset that are not in completed_dataset.
+
+        Note: Assumes input_dataset contains unique samples. For datasets with
+        duplicates, multiset semantics with collections.Counter would be needed.
+
+        Parameters
+        ----------
+        input_dataset : Dataset
+            Original input dataset (assumed to contain unique samples).
+        completed_dataset : Dataset
+            Dataset of completed samples.
+
+        Returns
+        -------
+        Dataset
+            Samples that still need processing.
+        """
+        # Get common columns for comparison
+        input_columns = set(input_dataset.column_names)
+        completed_columns = set(completed_dataset.column_names)
+        common_columns = list(input_columns & completed_columns)
+
+        if not common_columns:
+            logger.warning(
+                "No common columns found between input and completed datasets. "
+                "Processing all input samples."
+            )
+            return input_dataset
+
+        # Convert to pandas for easier comparison
+        input_df = input_dataset.select_columns(common_columns).to_pandas()
+        completed_df = completed_dataset.select_columns(common_columns).to_pandas()
+
+        # Find rows that haven't been completed
+        # Use tuple representation for comparison
+        input_tuples = set(input_df.apply(tuple, axis=1))
+        completed_tuples = set(completed_df.apply(tuple, axis=1))
+        remaining_tuples = input_tuples - completed_tuples
+
+        # Filter input dataset to only remaining samples
+        remaining_mask = input_df.apply(tuple, axis=1).isin(remaining_tuples)
+        remaining_indices = input_df[remaining_mask].index.tolist()
+
+        if not remaining_indices:
+            # Return empty dataset with same structure
+            return input_dataset.select([])
+
+        return input_dataset.select(remaining_indices)
+
+    def get_progress_info(self) -> Dict[str, Any]:
+        """Get information about current progress.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Progress information including samples processed, checkpoints saved, etc.
+        """
+        return {
+            "checkpoint_dir": self.checkpoint_dir,
+            "save_freq": self.save_freq,
+            "flow_id": self.flow_id,
+            "samples_processed": self._samples_processed,
+            "checkpoint_counter": self._checkpoint_counter,
+            "pending_samples": len(self._pending_samples),
+            "is_enabled": self.is_enabled,
+        }
+
+    def cleanup_checkpoints(self) -> None:
+        """Remove all checkpoint files and metadata."""
+        if not self.is_enabled:
+            return
+
+        checkpoint_dir = Path(self.checkpoint_dir)
+        if not checkpoint_dir.exists():
+            return
+
+        # Remove all checkpoint files
+        for file_path in checkpoint_dir.glob("checkpoint_*.jsonl"):
+            file_path.unlink()
+            logger.debug(f"Removed checkpoint file: {file_path}")
+
+        # Remove metadata file
+        metadata_path = Path(self.metadata_path)
+        if metadata_path.exists():
+            metadata_path.unlink()
+            logger.debug(f"Removed metadata file: {metadata_path}")
+
+        logger.info(f"Cleaned up all checkpoints in {self.checkpoint_dir}")