sdg-hub 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +28 -1
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +22 -0
- sdg_hub/core/blocks/__init__.py +58 -0
- sdg_hub/core/blocks/base.py +313 -0
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
- sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
- sdg_hub/core/blocks/evaluation/__init__.py +9 -0
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
- sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
- sdg_hub/core/blocks/filtering/__init__.py +12 -0
- sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
- sdg_hub/core/blocks/llm/__init__.py +27 -0
- sdg_hub/core/blocks/llm/client_manager.py +398 -0
- sdg_hub/core/blocks/llm/config.py +336 -0
- sdg_hub/core/blocks/llm/error_handler.py +368 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +491 -0
- sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +357 -0
- sdg_hub/core/blocks/registry.py +331 -0
- sdg_hub/core/blocks/transform/__init__.py +23 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
- sdg_hub/core/blocks/transform/melt_columns.py +126 -0
- sdg_hub/core/blocks/transform/rename_columns.py +69 -0
- sdg_hub/core/blocks/transform/text_concat.py +102 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
- sdg_hub/core/flow/__init__.py +20 -0
- sdg_hub/core/flow/base.py +1209 -0
- sdg_hub/core/flow/checkpointer.py +333 -0
- sdg_hub/core/flow/metadata.py +389 -0
- sdg_hub/core/flow/migration.py +198 -0
- sdg_hub/core/flow/registry.py +393 -0
- sdg_hub/core/flow/validation.py +277 -0
- sdg_hub/{utils → core/utils}/__init__.py +7 -4
- sdg_hub/core/utils/datautils.py +63 -0
- sdg_hub/core/utils/error_handling.py +208 -0
- sdg_hub/core/utils/flow_id_words.yaml +231 -0
- sdg_hub/core/utils/flow_identifier.py +94 -0
- sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
- sdg_hub/core/utils/yaml_utils.py +59 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +192 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
- sdg_hub-0.2.1.dist-info/METADATA +221 -0
- sdg_hub-0.2.1.dist-info/RECORD +68 -0
- sdg_hub/blocks/__init__.py +0 -42
- sdg_hub/blocks/block.py +0 -96
- sdg_hub/blocks/llmblock.py +0 -375
- sdg_hub/blocks/openaichatblock.py +0 -556
- sdg_hub/blocks/utilblocks.py +0 -597
- sdg_hub/checkpointer.py +0 -139
- sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
- sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
- sdg_hub/configs/annotations/detailed_description.yaml +0 -10
- sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
- sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
- sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
- sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
- sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
- sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
- sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
- sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
- sdg_hub/configs/knowledge/router.yaml +0 -12
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +0 -48
- sdg_hub/configs/skills/annotation.yaml +0 -36
- sdg_hub/configs/skills/contexts.yaml +0 -28
- sdg_hub/configs/skills/critic.yaml +0 -60
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
- sdg_hub/configs/skills/freeform_questions.yaml +0 -34
- sdg_hub/configs/skills/freeform_responses.yaml +0 -39
- sdg_hub/configs/skills/grounded_questions.yaml +0 -38
- sdg_hub/configs/skills/grounded_responses.yaml +0 -59
- sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
- sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
- sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
- sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
- sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
- sdg_hub/configs/skills/judge.yaml +0 -53
- sdg_hub/configs/skills/planner.yaml +0 -67
- sdg_hub/configs/skills/respond.yaml +0 -8
- sdg_hub/configs/skills/revised_responder.yaml +0 -78
- sdg_hub/configs/skills/router.yaml +0 -59
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
- sdg_hub/flow.py +0 -477
- sdg_hub/flow_runner.py +0 -450
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
- sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
- sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
- sdg_hub/pipeline.py +0 -121
- sdg_hub/prompts.py +0 -80
- sdg_hub/registry.py +0 -122
- sdg_hub/sdg.py +0 -206
- sdg_hub/utils/config_validation.py +0 -91
- sdg_hub/utils/datautils.py +0 -14
- sdg_hub/utils/error_handling.py +0 -94
- sdg_hub/utils/validation_result.py +0 -10
- sdg_hub-0.1.4.dist-info/METADATA +0 -190
- sdg_hub-0.1.4.dist-info/RECORD +0 -89
- sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
- /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
- /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,333 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Flow-level checkpointing with sample-level tracking for data generation pipelines."""
|
3
|
+
|
4
|
+
# Standard
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple
|
7
|
+
import json
|
8
|
+
import os
|
9
|
+
import uuid
|
10
|
+
|
11
|
+
# Third Party
|
12
|
+
from datasets import Dataset
|
13
|
+
|
14
|
+
# Local
|
15
|
+
from ..utils.datautils import safe_concatenate_with_validation
|
16
|
+
from ..utils.logger_config import setup_logger
|
17
|
+
|
18
|
+
# Module-level logger for this checkpointer, created via the project's
# setup_logger helper (imported from ..utils.logger_config).
logger = setup_logger(__name__)
|
19
|
+
|
20
|
+
|
21
|
+
class FlowCheckpointer:
    """Enhanced checkpointer for Flow execution with sample-level tracking.

    Provides data-level checkpointing where progress is saved after processing
    a specified number of samples through the entire flow pipeline.

    On disk, every checkpoint is a ``checkpoint_NNNN.jsonl`` file inside
    ``checkpoint_dir``. A ``flow_metadata.json`` file next to them records the
    flow ID, save frequency, number of processed samples and the checkpoint
    counter.
    """

    def __init__(
        self,
        checkpoint_dir: Optional[str] = None,
        save_freq: Optional[int] = None,
        flow_id: Optional[str] = None,
    ):
        """Initialize the FlowCheckpointer.

        Parameters
        ----------
        checkpoint_dir : Optional[str]
            Directory to save/load checkpoints. If None, checkpointing is disabled.
            When a non-empty path is given, the directory (and any missing
            parents) is created.
        save_freq : Optional[int]
            Number of completed samples after which to save a checkpoint.
            If None, only final results are saved.
        flow_id : Optional[str]
            Unique ID of the flow for checkpoint identification. Falls back to
            ``"unknown_flow"`` when not provided.
        """
        self.checkpoint_dir = checkpoint_dir
        self.save_freq = save_freq
        self.flow_id = flow_id or "unknown_flow"

        # Internal state
        self._samples_processed = 0
        self._checkpoint_counter = 0
        self._pending_samples: List[Dict[str, Any]] = []

        # Ensure checkpoint directory exists
        if self.checkpoint_dir:
            Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    @property
    def is_enabled(self) -> bool:
        """Check if checkpointing is enabled."""
        return self.checkpoint_dir is not None

    @property
    def metadata_path(self) -> str:
        """Path to the flow metadata file (only meaningful when enabled)."""
        return os.path.join(self.checkpoint_dir, "flow_metadata.json")

    def load_existing_progress(
        self, input_dataset: Dataset
    ) -> Tuple[Dataset, Optional[Dataset]]:
        """Load existing checkpoint data and determine remaining work.

        Any failure while loading is logged as a warning and treated as
        "no checkpoints", so the flow starts from scratch instead of crashing.

        Parameters
        ----------
        input_dataset : Dataset
            Original input dataset for the flow.

        Returns
        -------
        Tuple[Dataset, Optional[Dataset]]
            (remaining_samples_to_process, completed_samples_dataset)
            If no checkpoints exist, returns (input_dataset, None)
        """
        if not self.is_enabled:
            return input_dataset, None

        try:
            # Load flow metadata
            metadata = self._load_metadata()
            if not metadata:
                logger.info(f"No existing checkpoints found in {self.checkpoint_dir}")
                return input_dataset, None

            # Validate flow identity to prevent mixing checkpoints from different flows.
            # NOTE(review): the other flow's checkpoint files stay on disk and the
            # counter is not advanced, so new checkpoints restart at checkpoint_0001
            # and may overwrite them (or be mixed with them on a later resume).
            # Consider a dedicated directory per flow or explicit cleanup.
            saved_flow_id = metadata.get("flow_id")
            if saved_flow_id and saved_flow_id != self.flow_id:
                logger.warning(
                    f"Flow ID mismatch: saved checkpoints are for flow ID '{saved_flow_id}' "
                    f"but current flow ID is '{self.flow_id}'. Starting fresh to avoid "
                    f"mixing incompatible checkpoint data."
                )
                return input_dataset, None

            # Load all completed samples from checkpoints
            completed_dataset = self._load_completed_samples()
            if completed_dataset is None or len(completed_dataset) == 0:
                logger.info("No completed samples found in checkpoints")
                return input_dataset, None

            # Find samples that still need processing
            remaining_dataset = self._find_remaining_samples(
                input_dataset, completed_dataset
            )

            self._samples_processed = len(completed_dataset)
            # A crash between writing a checkpoint file and updating the metadata
            # leaves a file numbered past the recorded counter; never reuse a
            # number that already exists on disk, or that file gets overwritten.
            self._checkpoint_counter = max(
                metadata.get("checkpoint_counter") or 0,
                self._highest_checkpoint_number(),
            )

            logger.info(
                f"Loaded {len(completed_dataset)} completed samples, "
                f"{len(remaining_dataset)} samples remaining"
            )

            return remaining_dataset, completed_dataset

        except Exception as exc:
            logger.warning(f"Failed to load checkpoints: {exc}. Starting from scratch.")
            return input_dataset, None

    def add_completed_samples(self, samples: Dataset) -> None:
        """Add samples that have completed the entire flow.

        The save-frequency check runs once after the whole batch has been
        queued, so a single checkpoint may hold more than ``save_freq`` samples.

        Parameters
        ----------
        samples : Dataset
            Samples that have completed processing through all blocks.
        """
        if not self.is_enabled:
            return

        # Queue the samples and count how many were actually added.
        queued_before = len(self._pending_samples)
        self._pending_samples.extend(samples)
        self._samples_processed += len(self._pending_samples) - queued_before

        # Check if we should save a checkpoint
        if self.save_freq and len(self._pending_samples) >= self.save_freq:
            self._save_checkpoint()

    def save_final_checkpoint(self) -> None:
        """Save any remaining pending samples as final checkpoint."""
        if not self.is_enabled:
            return

        if self._pending_samples:
            sample_count = len(self._pending_samples)
            self._save_checkpoint()
            logger.info(f"Saved final checkpoint with {sample_count} samples")

    def _save_checkpoint(self) -> None:
        """Write pending samples to the next numbered checkpoint file.

        The JSONL file is written before the metadata, then the pending queue
        is cleared.
        """
        if not self._pending_samples:
            return

        self._checkpoint_counter += 1
        checkpoint_file = os.path.join(
            self.checkpoint_dir, f"checkpoint_{self._checkpoint_counter:04d}.jsonl"
        )

        # Convert pending samples to dataset and save
        checkpoint_dataset = Dataset.from_list(self._pending_samples)
        checkpoint_dataset.to_json(checkpoint_file, orient="records", lines=True)

        # Update metadata
        self._save_metadata()

        logger.info(
            f"Saved checkpoint {self._checkpoint_counter} with "
            f"{len(self._pending_samples)} samples to {checkpoint_file}"
        )

        # Clear pending samples
        self._pending_samples.clear()

    def _save_metadata(self) -> None:
        """Atomically write flow execution metadata to ``metadata_path``."""
        metadata = {
            "flow_id": self.flow_id,
            "save_freq": self.save_freq,
            "samples_processed": self._samples_processed,
            "checkpoint_counter": self._checkpoint_counter,
            # Random per-write token used as a simple version marker; despite the
            # key name this is not a timestamp.
            "last_updated": str(uuid.uuid4()),
        }

        # Write to a sibling temp file and swap it in, so a crash mid-write can
        # never leave a truncated metadata file (which would be read back as
        # "no checkpoints" and trigger a fresh start).
        tmp_path = f"{self.metadata_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
        os.replace(tmp_path, self.metadata_path)

    def _load_metadata(self) -> Optional[Dict[str, Any]]:
        """Load flow execution metadata.

        Returns
        -------
        Optional[Dict[str, Any]]
            The metadata dict, or None when the file is missing, unreadable,
            not valid JSON, or not a JSON object.
        """
        if not os.path.exists(self.metadata_path):
            return None

        try:
            with open(self.metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
        except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode decode errors
            logger.warning(f"Failed to load metadata: {exc}")
            return None

        if not isinstance(metadata, dict):
            logger.warning(
                "Failed to load metadata: expected a JSON object, "
                f"got {type(metadata).__name__}"
            )
            return None

        return metadata

    @staticmethod
    def _checkpoint_number(file_path: Path) -> Optional[int]:
        """Return N for a ``checkpoint_N.jsonl`` path, or None if N is not numeric."""
        suffix = file_path.stem[len("checkpoint_"):]
        if suffix.isascii() and suffix.isdigit():
            return int(suffix)
        return None

    def _checkpoint_files(self) -> List[Path]:
        """List checkpoint files in ``checkpoint_dir`` in numeric checkpoint order.

        Sorting is numeric rather than lexicographic so that, e.g.,
        ``checkpoint_10000`` comes after ``checkpoint_9999``. Files whose suffix
        is not a plain number are placed last, ordered by name.
        """

        def sort_key(file_path: Path) -> Tuple[bool, int, str]:
            number = self._checkpoint_number(file_path)
            return (number is None, number or 0, file_path.name)

        return sorted(
            Path(self.checkpoint_dir).glob("checkpoint_*.jsonl"), key=sort_key
        )

    def _highest_checkpoint_number(self) -> int:
        """Return the largest checkpoint number present on disk (0 if none)."""
        numbers = (self._checkpoint_number(path) for path in self._checkpoint_files())
        return max((number for number in numbers if number is not None), default=0)

    def _load_completed_samples(self) -> Optional[Dataset]:
        """Load and concatenate all completed samples from checkpoint files.

        Unreadable checkpoint files are logged and skipped.

        Returns
        -------
        Optional[Dataset]
            Concatenated samples in checkpoint order, or None if there are no
            checkpoint files or none of them contain samples.
        """
        checkpoint_files = self._checkpoint_files()
        if not checkpoint_files:
            return None

        loaded = []
        for file_path in checkpoint_files:
            try:
                dataset = Dataset.from_json(str(file_path))
            except Exception as exc:
                # Best effort: a corrupt/partial file only costs re-processing.
                logger.warning(f"Failed to load checkpoint {file_path}: {exc}")
                continue
            if len(dataset) > 0:
                loaded.append(dataset)
                logger.debug(f"Loaded checkpoint: {file_path} ({len(dataset)} samples)")

        if not loaded:
            return None

        return safe_concatenate_with_validation(loaded, "checkpoint files")

    @classmethod
    def _normalize_key_value(cls, value: Any) -> Any:
        """Canonicalize a cell value so JSON-round-tripped data compares equal.

        Checkpoints go through a JSONL round trip, which can turn integers in
        nullable columns into floats and missing values into NaN. Integral
        floats therefore become ``int``, NaN becomes ``None``, and containers
        are normalized recursively (dict keys are stringified so they sort).
        """
        if isinstance(value, float):
            if value != value:  # NaN is the only value not equal to itself
                return None
            return int(value) if value.is_integer() else value
        if isinstance(value, dict):
            return {str(k): cls._normalize_key_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._normalize_key_value(v) for v in value]
        return value

    @classmethod
    def _row_key(cls, row: Dict[str, Any]) -> str:
        """Build a hashable, order-independent key for a row.

        Unlike a plain tuple of values, this works for rows containing lists or
        dicts (e.g. chat message columns), which are unhashable.
        """
        return json.dumps(
            cls._normalize_key_value(row),
            sort_keys=True,
            ensure_ascii=False,
            default=str,
        )

    def _find_remaining_samples(
        self, input_dataset: Dataset, completed_dataset: Dataset
    ) -> Dataset:
        """Find samples from input_dataset that are not in completed_dataset.

        Rows are matched on the columns the two datasets share, using a
        canonical JSON key per row so that list/dict-valued columns are
        supported. Input order is preserved.

        Note: Assumes input_dataset contains unique samples. For datasets with
        duplicates, multiset semantics with collections.Counter would be needed.

        Parameters
        ----------
        input_dataset : Dataset
            Original input dataset (assumed to contain unique samples).
        completed_dataset : Dataset
            Dataset of completed samples.

        Returns
        -------
        Dataset
            Samples that still need processing (possibly empty, same schema as
            input_dataset).
        """
        # Shared columns, in input order so the selection is deterministic.
        completed_columns = set(completed_dataset.column_names)
        common_columns = [
            column
            for column in input_dataset.column_names
            if column in completed_columns
        ]

        if not common_columns:
            logger.warning(
                "No common columns found between input and completed datasets. "
                "Processing all input samples."
            )
            return input_dataset

        # with_format(None) guarantees plain Python rows regardless of any
        # numpy/torch format the caller may have set.
        completed_keys = {
            self._row_key(row)
            for row in completed_dataset.select_columns(common_columns).with_format(None)
        }
        input_rows = input_dataset.select_columns(common_columns).with_format(None)
        remaining_indices = [
            index
            for index, row in enumerate(input_rows)
            if self._row_key(row) not in completed_keys
        ]

        # select([]) yields an empty dataset with the same structure.
        return input_dataset.select(remaining_indices)

    def get_progress_info(self) -> Dict[str, Any]:
        """Get information about current progress.

        Returns
        -------
        Dict[str, Any]
            Progress information including samples processed, checkpoints saved, etc.
        """
        return {
            "checkpoint_dir": self.checkpoint_dir,
            "save_freq": self.save_freq,
            "flow_id": self.flow_id,
            "samples_processed": self._samples_processed,
            "checkpoint_counter": self._checkpoint_counter,
            "pending_samples": len(self._pending_samples),
            "is_enabled": self.is_enabled,
        }

    def cleanup_checkpoints(self) -> None:
        """Remove all checkpoint files and metadata (including any temp metadata)."""
        if not self.is_enabled:
            return

        checkpoint_dir = Path(self.checkpoint_dir)
        if not checkpoint_dir.exists():
            return

        # Remove all checkpoint files
        for file_path in checkpoint_dir.glob("checkpoint_*.jsonl"):
            file_path.unlink()
            logger.debug(f"Removed checkpoint file: {file_path}")

        # Remove the metadata file and any temp file left by an interrupted write.
        for metadata_file in (
            Path(self.metadata_path),
            Path(f"{self.metadata_path}.tmp"),
        ):
            if metadata_file.exists():
                metadata_file.unlink()
                logger.debug(f"Removed metadata file: {metadata_file}")

        logger.info(f"Cleaned up all checkpoints in {self.checkpoint_dir}")
|