sdg-hub 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff shows the changes between package versions as they appear in their public registries. It is provided for informational purposes only.
- sdg_hub/_version.py +2 -2
- sdg_hub/core/blocks/llm/__init__.py +2 -0
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +491 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +77 -30
- sdg_hub/core/blocks/registry.py +1 -1
- sdg_hub/core/flow/base.py +243 -14
- sdg_hub/core/flow/checkpointer.py +333 -0
- sdg_hub/core/flow/metadata.py +45 -0
- sdg_hub/core/flow/migration.py +12 -1
- sdg_hub/core/flow/registry.py +121 -58
- sdg_hub/core/flow/validation.py +12 -0
- sdg_hub/core/utils/__init__.py +2 -1
- sdg_hub/core/utils/datautils.py +52 -1
- sdg_hub/core/utils/flow_id_words.yaml +231 -0
- sdg_hub/core/utils/flow_identifier.py +94 -0
- sdg_hub/core/utils/yaml_utils.py +59 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +1 -0
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.1.dist-info}/METADATA +21 -18
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.1.dist-info}/RECORD +22 -17
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.1.dist-info}/WHEEL +0 -0
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.1.dist-info}/top_level.txt +0 -0
sdg_hub/core/flow/base.py
CHANGED
@@ -4,18 +4,26 @@
 # Standard
 from pathlib import Path
 from typing import Any, Optional, Union
+import time
 
 # Third Party
 from datasets import Dataset
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from rich.console import Console
+from rich.panel import Panel
+from rich.table import Table
+from rich.tree import Tree
 import yaml
 
 # Local
 from ..blocks.base import BaseBlock
 from ..blocks.registry import BlockRegistry
+from ..utils.datautils import safe_concatenate_with_validation
 from ..utils.error_handling import EmptyDatasetError, FlowValidationError
 from ..utils.logger_config import setup_logger
 from ..utils.path_resolution import resolve_path
+from ..utils.yaml_utils import save_flow_yaml
+from .checkpointer import FlowCheckpointer
 from .metadata import FlowMetadata, FlowParameter
 from .migration import FlowMigration
 from .validation import FlowValidator
@@ -133,7 +141,17 @@ class Flow(BaseModel):
         -------
         Flow
             Validated Flow instance.
+
+        Raises
+        ------
+        FlowValidationError
+            If yaml_path is None or the file doesn't exist.
         """
+        if yaml_path is None:
+            raise FlowValidationError(
+                "Flow path cannot be None. Please provide a valid YAML file path or check that the flow exists in the registry."
+            )
+
         yaml_path = resolve_path(yaml_path, [])
         yaml_dir = Path(yaml_path).parent
 
@@ -160,6 +178,8 @@ class Flow(BaseModel):
             flow_config, migrated_runtime_params = FlowMigration.migrate_to_new_format(
                 flow_config, yaml_path
            )
+            # Save migrated config back to YAML to persist id
+            save_flow_yaml(yaml_path, flow_config, "migrated to new format")
 
         # Validate YAML structure
         validator = FlowValidator()
@@ -221,6 +241,17 @@ class Flow(BaseModel):
         # Create and validate the flow
         try:
             flow = cls(blocks=blocks, metadata=metadata, parameters=parameters)
+            # Persist generated id back to the YAML file (only on initial load)
+            # If the file had no metadata.id originally, update and rewrite
+            if not flow_config.get("metadata", {}).get("id"):
+                flow_config.setdefault("metadata", {})["id"] = flow.metadata.id
+                save_flow_yaml(
+                    yaml_path,
+                    flow_config,
+                    f"added generated id: {flow.metadata.id}",
+                )
+            else:
+                logger.debug(f"Flow already had id: {flow.metadata.id}")
             # Store migrated runtime params and client for backward compatibility
             if migrated_runtime_params:
                 flow._migrated_runtime_params = migrated_runtime_params
@@ -324,6 +355,8 @@ class Flow(BaseModel):
         self,
         dataset: Dataset,
         runtime_params: Optional[dict[str, dict[str, Any]]] = None,
+        checkpoint_dir: Optional[str] = None,
+        save_freq: Optional[int] = None,
     ) -> Dataset:
         """Execute the flow blocks in sequence to generate data.
 
@@ -340,6 +373,11 @@ class Flow(BaseModel):
                 "block_name": {"param1": value1, "param2": value2},
                 "other_block": {"param3": value3}
             }
+        checkpoint_dir : Optional[str], optional
+            Directory to save/load checkpoints. If provided, enables checkpointing.
+        save_freq : Optional[int], optional
+            Number of completed samples after which to save a checkpoint.
+            If None, only saves final results when checkpointing is enabled.
 
         Returns
         -------
@@ -353,6 +391,12 @@
         FlowValidationError
             If flow validation fails or if model configuration is required but not set.
         """
+        # Validate save_freq parameter early to prevent range() errors
+        if save_freq is not None and save_freq <= 0:
+            raise FlowValidationError(
+                f"save_freq must be greater than 0, got {save_freq}"
+            )
+
         # Validate preconditions
         if not self.blocks:
             raise FlowValidationError("Cannot generate with empty flow")
@@ -376,18 +420,119 @@ class Flow(BaseModel):
                 "Dataset validation failed:\n" + "\n".join(dataset_errors)
             )
 
+        # Initialize checkpointer if enabled
+        checkpointer = None
+        completed_dataset = None
+        if checkpoint_dir:
+            checkpointer = FlowCheckpointer(
+                checkpoint_dir=checkpoint_dir,
+                save_freq=save_freq,
+                flow_id=self.metadata.id,
+            )
+
+            # Load existing progress
+            remaining_dataset, completed_dataset = checkpointer.load_existing_progress(
+                dataset
+            )
+
+            if len(remaining_dataset) == 0:
+                logger.info("All samples already completed, returning existing results")
+                return completed_dataset
+
+            dataset = remaining_dataset
+            logger.info(f"Resuming with {len(dataset)} remaining samples")
+
         logger.info(
             f"Starting flow '{self.metadata.name}' v{self.metadata.version} "
             f"with {len(dataset)} samples across {len(self.blocks)} blocks"
         )
 
-        current_dataset = dataset
         # Merge migrated runtime params with provided ones (provided ones take precedence)
         merged_runtime_params = self._migrated_runtime_params.copy()
         if runtime_params:
             merged_runtime_params.update(runtime_params)
         runtime_params = merged_runtime_params
 
+        # Process dataset in chunks if checkpointing with save_freq
+        if checkpointer and save_freq:
+            all_processed = []
+
+            # Process in chunks of save_freq
+            for i in range(0, len(dataset), save_freq):
+                chunk_end = min(i + save_freq, len(dataset))
+                chunk_dataset = dataset.select(range(i, chunk_end))
+
+                logger.info(
+                    f"Processing chunk {i // save_freq + 1}: samples {i} to {chunk_end - 1}"
+                )
+
+                # Execute all blocks on this chunk
+                processed_chunk = self._execute_blocks_on_dataset(
+                    chunk_dataset, runtime_params
+                )
+                all_processed.append(processed_chunk)
+
+                # Save checkpoint after chunk completion
+                checkpointer.add_completed_samples(processed_chunk)
+
+            # Save final checkpoint for any remaining samples
+            checkpointer.save_final_checkpoint()
+
+            # Combine all processed chunks
+            final_dataset = safe_concatenate_with_validation(
+                all_processed, "processed chunks from flow execution"
+            )
+
+            # Combine with previously completed samples if any
+            if checkpointer and completed_dataset:
+                final_dataset = safe_concatenate_with_validation(
+                    [completed_dataset, final_dataset],
+                    "completed checkpoint data with newly processed data",
+                )
+
+        else:
+            # Process entire dataset at once
+            final_dataset = self._execute_blocks_on_dataset(dataset, runtime_params)
+
+            # Save final checkpoint if checkpointing enabled
+            if checkpointer:
+                checkpointer.add_completed_samples(final_dataset)
+                checkpointer.save_final_checkpoint()
+
+                # Combine with previously completed samples if any
+                if completed_dataset:
+                    final_dataset = safe_concatenate_with_validation(
+                        [completed_dataset, final_dataset],
+                        "completed checkpoint data with newly processed data",
+                    )
+
+        logger.info(
+            f"Flow '{self.metadata.name}' completed successfully: "
+            f"{len(final_dataset)} final samples, "
+            f"{len(final_dataset.column_names)} final columns"
+        )
+
+        return final_dataset
+
+    def _execute_blocks_on_dataset(
+        self, dataset: Dataset, runtime_params: dict[str, dict[str, Any]]
+    ) -> Dataset:
+        """Execute all blocks in sequence on the given dataset.
+
+        Parameters
+        ----------
+        dataset : Dataset
+            Dataset to process through all blocks.
+        runtime_params : Dict[str, Dict[str, Any]]
+            Runtime parameters for block execution.
+
+        Returns
+        -------
+        Dataset
+            Dataset after processing through all blocks.
+        """
+        current_dataset = dataset
+
         # Execute blocks in sequence
         for i, block in enumerate(self.blocks):
             logger.info(
@@ -436,12 +581,6 @@ class Flow(BaseModel):
                     f"Block '{block.block_name}' execution failed: {exc}"
                 ) from exc
 
-        logger.info(
-            f"Flow '{self.metadata.name}' completed successfully: "
-            f"{len(current_dataset)} final samples, "
-            f"{len(current_dataset.column_names)} final columns"
-        )
-
         return current_dataset
 
     def _prepare_block_kwargs(
@@ -784,9 +923,6 @@ class Flow(BaseModel):
                 "execution_time_seconds": 0,
             }
 
-        # Standard
-        import time
-
         start_time = time.time()
 
         try:
@@ -930,6 +1066,102 @@ class Flow(BaseModel):
             "block_names": [block.block_name for block in self.blocks],
         }
 
+    def print_info(self) -> None:
+        """
+        Print an interactive summary of the Flow in the console.
+
+        The summary contains:
+        1. Flow metadata (name, version, author, description)
+        2. Defined runtime parameters with type hints and defaults
+        3. A table of all blocks with their input and output columns
+
+        Notes
+        -----
+        Uses the `rich` library for colourised output; install with
+        `pip install rich` if not already present.
+
+        Returns
+        -------
+        None
+        """
+
+        console = Console()
+
+        # Create main tree structure
+        flow_tree = Tree(
+            f"[bold bright_blue]{self.metadata.name}[/bold bright_blue] Flow"
+        )
+
+        # Metadata section
+        metadata_branch = flow_tree.add(
+            "[bold bright_green]Metadata[/bold bright_green]"
+        )
+        metadata_branch.add(
+            f"Version: [bright_cyan]{self.metadata.version}[/bright_cyan]"
+        )
+        metadata_branch.add(
+            f"Author: [bright_cyan]{self.metadata.author}[/bright_cyan]"
+        )
+        if self.metadata.description:
+            metadata_branch.add(
+                f"Description: [white]{self.metadata.description}[/white]"
+            )
+
+        # Parameters section
+        if self.parameters:
+            params_branch = flow_tree.add(
+                "[bold bright_yellow]Parameters[/bold bright_yellow]"
+            )
+            for name, param in self.parameters.items():
+                param_info = f"[bright_cyan]{name}[/bright_cyan]: [white]{param.type_hint}[/white]"
+                if param.default is not None:
+                    param_info += f" = [bright_white]{param.default}[/bright_white]"
+                params_branch.add(param_info)
+
+        # Blocks overview
+        flow_tree.add(
+            f"[bold bright_magenta]Blocks[/bold bright_magenta] ({len(self.blocks)} total)"
+        )
+
+        # Create blocks table
+        blocks_table = Table(show_header=True, header_style="bold bright_white")
+        blocks_table.add_column("Block Name", style="bright_cyan")
+        blocks_table.add_column("Type", style="bright_green")
+        blocks_table.add_column("Input Cols", style="bright_yellow")
+        blocks_table.add_column("Output Cols", style="bright_red")
+
+        for block in self.blocks:
+            input_cols = getattr(block, "input_cols", None)
+            output_cols = getattr(block, "output_cols", None)
+
+            blocks_table.add_row(
+                block.block_name,
+                block.__class__.__name__,
+                str(input_cols) if input_cols else "[bright_black]None[/bright_black]",
+                str(output_cols)
+                if output_cols
+                else "[bright_black]None[/bright_black]",
+            )
+
+        # Print everything
+        console.print()
+        console.print(
+            Panel(
+                flow_tree,
+                title="[bold bright_white]Flow Information[/bold bright_white]",
+                border_style="bright_blue",
+            )
+        )
+        console.print()
+        console.print(
+            Panel(
+                blocks_table,
+                title="[bold bright_white]Block Details[/bold bright_white]",
+                border_style="bright_magenta",
+            )
+        )
+        console.print()
+
     def to_yaml(self, output_path: str) -> None:
         """Save flow configuration to YAML file.
 
@@ -952,10 +1184,7 @@ class Flow(BaseModel):
             name: param.model_dump() for name, param in self.parameters.items()
         }
 
-
-        yaml.dump(config, f, default_flow_style=False, sort_keys=False)
-
-        logger.info(f"Flow configuration saved to: {output_path}")
+        save_flow_yaml(output_path, config)
 
     def __len__(self) -> int:
         """Number of blocks in the flow."""
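Taken together, the `generate()` changes above make a flow run checkpointable and resumable. The sketch below shows how the new `checkpoint_dir` and `save_freq` parameters might be used; it is a minimal example, not code from the release. The flow YAML path, the dataset contents, and the import path (`sdg_hub.core.flow.base` is the module path shown in this diff; the package may also re-export `Flow` at a higher level) are assumptions for illustration.

```python
# Hedged sketch: resumable generation with the new checkpointing parameters.
# "flows/my_flow.yaml", the "document" column, and the import path are
# illustrative assumptions, not values taken from the release.
from datasets import Dataset
from sdg_hub.core.flow.base import Flow

flow = Flow.from_yaml("flows/my_flow.yaml")  # also persists metadata.id back to the YAML
flow.print_info()                            # rich summary of metadata, parameters, blocks

dataset = Dataset.from_list([{"document": "first doc"}, {"document": "second doc"}])

# Checkpoints are written under ./checkpoints after every completed sample;
# rerunning the same call after an interruption resumes from the remaining samples.
result = flow.generate(
    dataset,
    checkpoint_dir="./checkpoints",
    save_freq=1,
)
```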
sdg_hub/core/flow/checkpointer.py
ADDED
@@ -0,0 +1,333 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Flow-level checkpointing with sample-level tracking for data generation pipelines."""
+
+# Standard
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+import json
+import os
+import uuid
+
+# Third Party
+from datasets import Dataset
+
+# Local
+from ..utils.datautils import safe_concatenate_with_validation
+from ..utils.logger_config import setup_logger
+
+logger = setup_logger(__name__)
+
+
+class FlowCheckpointer:
+    """Enhanced checkpointer for Flow execution with sample-level tracking.
+
+    Provides data-level checkpointing where progress is saved after processing
+    a specified number of samples through the entire flow pipeline.
+    """
+
+    def __init__(
+        self,
+        checkpoint_dir: Optional[str] = None,
+        save_freq: Optional[int] = None,
+        flow_id: Optional[str] = None,
+    ):
+        """Initialize the FlowCheckpointer.
+
+        Parameters
+        ----------
+        checkpoint_dir : Optional[str]
+            Directory to save/load checkpoints. If None, checkpointing is disabled.
+        save_freq : Optional[int]
+            Number of completed samples after which to save a checkpoint.
+            If None, only final results are saved.
+        flow_id : Optional[str]
+            Unique ID of the flow for checkpoint identification.
+        """
+        self.checkpoint_dir = checkpoint_dir
+        self.save_freq = save_freq
+        self.flow_id = flow_id or "unknown_flow"
+
+        # Internal state
+        self._samples_processed = 0
+        self._checkpoint_counter = 0
+        self._pending_samples: List[Dict[str, Any]] = []
+
+        # Ensure checkpoint directory exists
+        if self.checkpoint_dir:
+            Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
+
+    @property
+    def is_enabled(self) -> bool:
+        """Check if checkpointing is enabled."""
+        return self.checkpoint_dir is not None
+
+    @property
+    def metadata_path(self) -> str:
+        """Path to the flow metadata file."""
+        return os.path.join(self.checkpoint_dir, "flow_metadata.json")
+
+    def load_existing_progress(
+        self, input_dataset: Dataset
+    ) -> Tuple[Dataset, Optional[Dataset]]:
+        """Load existing checkpoint data and determine remaining work.
+
+        Parameters
+        ----------
+        input_dataset : Dataset
+            Original input dataset for the flow.
+
+        Returns
+        -------
+        Tuple[Dataset, Optional[Dataset]]
+            (remaining_samples_to_process, completed_samples_dataset)
+            If no checkpoints exist, returns (input_dataset, None)
+        """
+        if not self.is_enabled:
+            return input_dataset, None
+
+        try:
+            # Load flow metadata
+            metadata = self._load_metadata()
+            if not metadata:
+                logger.info(f"No existing checkpoints found in {self.checkpoint_dir}")
+                return input_dataset, None
+
+            # Validate flow identity to prevent mixing checkpoints from different flows
+            saved_flow_id = metadata.get("flow_id")
+            if saved_flow_id and saved_flow_id != self.flow_id:
+                logger.warning(
+                    f"Flow ID mismatch: saved checkpoints are for flow ID '{saved_flow_id}' "
+                    f"but current flow ID is '{self.flow_id}'. Starting fresh to avoid "
+                    f"mixing incompatible checkpoint data."
+                )
+                return input_dataset, None
+
+            # Load all completed samples from checkpoints
+            completed_dataset = self._load_completed_samples()
+            if completed_dataset is None or len(completed_dataset) == 0:
+                logger.info("No completed samples found in checkpoints")
+                return input_dataset, None
+
+            # Find samples that still need processing
+            remaining_dataset = self._find_remaining_samples(
+                input_dataset, completed_dataset
+            )
+
+            self._samples_processed = len(completed_dataset)
+            self._checkpoint_counter = metadata.get("checkpoint_counter", 0)
+
+            logger.info(
+                f"Loaded {len(completed_dataset)} completed samples, "
+                f"{len(remaining_dataset)} samples remaining"
+            )
+
+            return remaining_dataset, completed_dataset
+
+        except Exception as exc:
+            logger.warning(f"Failed to load checkpoints: {exc}. Starting from scratch.")
+            return input_dataset, None
+
+    def add_completed_samples(self, samples: Dataset) -> None:
+        """Add samples that have completed the entire flow.
+
+        Parameters
+        ----------
+        samples : Dataset
+            Samples that have completed processing through all blocks.
+        """
+        if not self.is_enabled:
+            return
+
+        # Add to pending samples
+        for sample in samples:
+            self._pending_samples.append(sample)
+            self._samples_processed += 1
+
+        # Check if we should save a checkpoint
+        if self.save_freq and len(self._pending_samples) >= self.save_freq:
+            self._save_checkpoint()
+
+    def save_final_checkpoint(self) -> None:
+        """Save any remaining pending samples as final checkpoint."""
+        if not self.is_enabled:
+            return
+
+        if self._pending_samples:
+            sample_count = len(self._pending_samples)
+            self._save_checkpoint()
+            logger.info(f"Saved final checkpoint with {sample_count} samples")
+
+    def _save_checkpoint(self) -> None:
+        """Save current pending samples to a checkpoint file."""
+        if not self._pending_samples:
+            return
+
+        self._checkpoint_counter += 1
+        checkpoint_file = os.path.join(
+            self.checkpoint_dir, f"checkpoint_{self._checkpoint_counter:04d}.jsonl"
+        )
+
+        # Convert pending samples to dataset and save
+        checkpoint_dataset = Dataset.from_list(self._pending_samples)
+        checkpoint_dataset.to_json(checkpoint_file, orient="records", lines=True)
+
+        # Update metadata
+        self._save_metadata()
+
+        logger.info(
+            f"Saved checkpoint {self._checkpoint_counter} with "
+            f"{len(self._pending_samples)} samples to {checkpoint_file}"
+        )
+
+        # Clear pending samples
+        self._pending_samples.clear()
+
+    def _save_metadata(self) -> None:
+        """Save flow execution metadata."""
+        metadata = {
+            "flow_id": self.flow_id,
+            "save_freq": self.save_freq,
+            "samples_processed": self._samples_processed,
+            "checkpoint_counter": self._checkpoint_counter,
+            "last_updated": str(uuid.uuid4()),  # Simple versioning
+        }
+
+        with open(self.metadata_path, "w", encoding="utf-8") as f:
+            json.dump(metadata, f, indent=2)
+
+    def _load_metadata(self) -> Optional[Dict[str, Any]]:
+        """Load flow execution metadata."""
+        if not os.path.exists(self.metadata_path):
+            return None
+
+        try:
+            with open(self.metadata_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        except Exception as exc:
+            logger.warning(f"Failed to load metadata: {exc}")
+            return None
+
+    def _load_completed_samples(self) -> Optional[Dataset]:
+        """Load all completed samples from checkpoint files."""
+        checkpoint_files = []
+        checkpoint_dir = Path(self.checkpoint_dir)
+
+        # Find all checkpoint files
+        for file_path in checkpoint_dir.glob("checkpoint_*.jsonl"):
+            checkpoint_files.append(str(file_path))
+
+        if not checkpoint_files:
+            return None
+
+        # Sort checkpoint files by number
+        checkpoint_files.sort()
+
+        # Load and concatenate all checkpoint datasets
+        datasets = []
+        for file_path in checkpoint_files:
+            try:
+                dataset = Dataset.from_json(file_path)
+                if len(dataset) > 0:
+                    datasets.append(dataset)
+                    logger.debug(
+                        f"Loaded checkpoint: {file_path} ({len(dataset)} samples)"
+                    )
+            except Exception as exc:
+                logger.warning(f"Failed to load checkpoint {file_path}: {exc}")
+
+        if not datasets:
+            return None
+
+        return safe_concatenate_with_validation(datasets, "checkpoint files")
+
+    def _find_remaining_samples(
+        self, input_dataset: Dataset, completed_dataset: Dataset
+    ) -> Dataset:
+        """Find samples from input_dataset that are not in completed_dataset.
+
+        Note: Assumes input_dataset contains unique samples. For datasets with
+        duplicates, multiset semantics with collections.Counter would be needed.
+
+        Parameters
+        ----------
+        input_dataset : Dataset
+            Original input dataset (assumed to contain unique samples).
+        completed_dataset : Dataset
+            Dataset of completed samples.
+
+        Returns
+        -------
+        Dataset
+            Samples that still need processing.
+        """
+        # Get common columns for comparison
+        input_columns = set(input_dataset.column_names)
+        completed_columns = set(completed_dataset.column_names)
+        common_columns = list(input_columns & completed_columns)
+
+        if not common_columns:
+            logger.warning(
+                "No common columns found between input and completed datasets. "
+                "Processing all input samples."
+            )
+            return input_dataset
+
+        # Convert to pandas for easier comparison
+        input_df = input_dataset.select_columns(common_columns).to_pandas()
+        completed_df = completed_dataset.select_columns(common_columns).to_pandas()
+
+        # Find rows that haven't been completed
+        # Use tuple representation for comparison
+        input_tuples = set(input_df.apply(tuple, axis=1))
+        completed_tuples = set(completed_df.apply(tuple, axis=1))
+        remaining_tuples = input_tuples - completed_tuples
+
+        # Filter input dataset to only remaining samples
+        remaining_mask = input_df.apply(tuple, axis=1).isin(remaining_tuples)
+        remaining_indices = input_df[remaining_mask].index.tolist()
+
+        if not remaining_indices:
+            # Return empty dataset with same structure
+            return input_dataset.select([])
+
+        return input_dataset.select(remaining_indices)
+
+    def get_progress_info(self) -> Dict[str, Any]:
+        """Get information about current progress.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Progress information including samples processed, checkpoints saved, etc.
+        """
+        return {
+            "checkpoint_dir": self.checkpoint_dir,
+            "save_freq": self.save_freq,
+            "flow_id": self.flow_id,
+            "samples_processed": self._samples_processed,
+            "checkpoint_counter": self._checkpoint_counter,
+            "pending_samples": len(self._pending_samples),
+            "is_enabled": self.is_enabled,
+        }
+
+    def cleanup_checkpoints(self) -> None:
+        """Remove all checkpoint files and metadata."""
+        if not self.is_enabled:
+            return
+
+        checkpoint_dir = Path(self.checkpoint_dir)
+        if not checkpoint_dir.exists():
+            return
+
+        # Remove all checkpoint files
+        for file_path in checkpoint_dir.glob("checkpoint_*.jsonl"):
+            file_path.unlink()
+            logger.debug(f"Removed checkpoint file: {file_path}")
+
+        # Remove metadata file
+        metadata_path = Path(self.metadata_path)
+        if metadata_path.exists():
+            metadata_path.unlink()
+            logger.debug(f"Removed metadata file: {metadata_path}")
+
+        logger.info(f"Cleaned up all checkpoints in {self.checkpoint_dir}")
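For reference, the checkpointer can also be exercised on its own. The sketch below follows the API exactly as it appears in this diff; the checkpoint directory, flow id, and sample fields are illustrative assumptions, and in a real run the samples passed to `add_completed_samples` would be flow outputs rather than raw inputs.

```python
# Hedged sketch of the FlowCheckpointer lifecycle; "./checkpoints", "demo-flow",
# and the sample fields are illustrative, not values from the release.
from datasets import Dataset
from sdg_hub.core.flow.checkpointer import FlowCheckpointer

dataset = Dataset.from_list([{"id": i, "text": f"sample {i}"} for i in range(10)])

checkpointer = FlowCheckpointer(
    checkpoint_dir="./checkpoints", save_freq=4, flow_id="demo-flow"
)

# On a rerun this returns only the samples not already present in earlier checkpoints.
remaining, completed = checkpointer.load_existing_progress(dataset)

# Buffer "processed" samples; a checkpoint_NNNN.jsonl file is written each time
# the pending buffer reaches save_freq, and the final flush writes the rest.
checkpointer.add_completed_samples(remaining)
checkpointer.save_final_checkpoint()

print(checkpointer.get_progress_info())
```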