sdg-hub 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/_version.py +16 -3
- sdg_hub/core/blocks/deprecated_blocks/selector.py +1 -1
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +175 -416
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +174 -415
- sdg_hub/core/blocks/evaluation/verify_question_block.py +180 -415
- sdg_hub/core/blocks/llm/__init__.py +2 -0
- sdg_hub/core/blocks/llm/client_manager.py +61 -24
- sdg_hub/core/blocks/llm/config.py +1 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +62 -7
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +653 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +75 -30
- sdg_hub/core/blocks/registry.py +49 -35
- sdg_hub/core/blocks/transform/index_based_mapper.py +1 -1
- sdg_hub/core/flow/base.py +370 -20
- sdg_hub/core/flow/checkpointer.py +333 -0
- sdg_hub/core/flow/metadata.py +45 -0
- sdg_hub/core/flow/migration.py +12 -1
- sdg_hub/core/flow/registry.py +121 -58
- sdg_hub/core/flow/validation.py +12 -0
- sdg_hub/core/utils/__init__.py +2 -1
- sdg_hub/core/utils/datautils.py +81 -1
- sdg_hub/core/utils/flow_id_words.yaml +231 -0
- sdg_hub/core/utils/flow_identifier.py +94 -0
- sdg_hub/core/utils/yaml_utils.py +59 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +1 -7
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/METADATA +59 -31
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/RECORD +30 -25
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/WHEEL +0 -0
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/top_level.txt +0 -0
sdg_hub/core/flow/checkpointer.py
ADDED
@@ -0,0 +1,333 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Flow-level checkpointing with sample-level tracking for data generation pipelines."""
+
+# Standard
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+import json
+import os
+import uuid
+
+# Third Party
+from datasets import Dataset
+
+# Local
+from ..utils.datautils import safe_concatenate_with_validation
+from ..utils.logger_config import setup_logger
+
+logger = setup_logger(__name__)
+
+
+class FlowCheckpointer:
+    """Enhanced checkpointer for Flow execution with sample-level tracking.
+
+    Provides data-level checkpointing where progress is saved after processing
+    a specified number of samples through the entire flow pipeline.
+    """
+
+    def __init__(
+        self,
+        checkpoint_dir: Optional[str] = None,
+        save_freq: Optional[int] = None,
+        flow_id: Optional[str] = None,
+    ):
+        """Initialize the FlowCheckpointer.
+
+        Parameters
+        ----------
+        checkpoint_dir : Optional[str]
+            Directory to save/load checkpoints. If None, checkpointing is disabled.
+        save_freq : Optional[int]
+            Number of completed samples after which to save a checkpoint.
+            If None, only final results are saved.
+        flow_id : Optional[str]
+            Unique ID of the flow for checkpoint identification.
+        """
+        self.checkpoint_dir = checkpoint_dir
+        self.save_freq = save_freq
+        self.flow_id = flow_id or "unknown_flow"
+
+        # Internal state
+        self._samples_processed = 0
+        self._checkpoint_counter = 0
+        self._pending_samples: List[Dict[str, Any]] = []
+
+        # Ensure checkpoint directory exists
+        if self.checkpoint_dir:
+            Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
+
+    @property
+    def is_enabled(self) -> bool:
+        """Check if checkpointing is enabled."""
+        return self.checkpoint_dir is not None
+
+    @property
+    def metadata_path(self) -> str:
+        """Path to the flow metadata file."""
+        return os.path.join(self.checkpoint_dir, "flow_metadata.json")
+
+    def load_existing_progress(
+        self, input_dataset: Dataset
+    ) -> Tuple[Dataset, Optional[Dataset]]:
+        """Load existing checkpoint data and determine remaining work.
+
+        Parameters
+        ----------
+        input_dataset : Dataset
+            Original input dataset for the flow.
+
+        Returns
+        -------
+        Tuple[Dataset, Optional[Dataset]]
+            (remaining_samples_to_process, completed_samples_dataset)
+            If no checkpoints exist, returns (input_dataset, None)
+        """
+        if not self.is_enabled:
+            return input_dataset, None
+
+        try:
+            # Load flow metadata
+            metadata = self._load_metadata()
+            if not metadata:
+                logger.info(f"No existing checkpoints found in {self.checkpoint_dir}")
+                return input_dataset, None
+
+            # Validate flow identity to prevent mixing checkpoints from different flows
+            saved_flow_id = metadata.get("flow_id")
+            if saved_flow_id and saved_flow_id != self.flow_id:
+                logger.warning(
+                    f"Flow ID mismatch: saved checkpoints are for flow ID '{saved_flow_id}' "
+                    f"but current flow ID is '{self.flow_id}'. Starting fresh to avoid "
+                    f"mixing incompatible checkpoint data."
+                )
+                return input_dataset, None
+
+            # Load all completed samples from checkpoints
+            completed_dataset = self._load_completed_samples()
+            if completed_dataset is None or len(completed_dataset) == 0:
+                logger.info("No completed samples found in checkpoints")
+                return input_dataset, None
+
+            # Find samples that still need processing
+            remaining_dataset = self._find_remaining_samples(
+                input_dataset, completed_dataset
+            )
+
+            self._samples_processed = len(completed_dataset)
+            self._checkpoint_counter = metadata.get("checkpoint_counter", 0)
+
+            logger.info(
+                f"Loaded {len(completed_dataset)} completed samples, "
+                f"{len(remaining_dataset)} samples remaining"
+            )
+
+            return remaining_dataset, completed_dataset
+
+        except Exception as exc:
+            logger.warning(f"Failed to load checkpoints: {exc}. Starting from scratch.")
+            return input_dataset, None
+
+    def add_completed_samples(self, samples: Dataset) -> None:
+        """Add samples that have completed the entire flow.
+
+        Parameters
+        ----------
+        samples : Dataset
+            Samples that have completed processing through all blocks.
+        """
+        if not self.is_enabled:
+            return
+
+        # Add to pending samples
+        for sample in samples:
+            self._pending_samples.append(sample)
+            self._samples_processed += 1
+
+        # Check if we should save a checkpoint
+        if self.save_freq and len(self._pending_samples) >= self.save_freq:
+            self._save_checkpoint()
+
+    def save_final_checkpoint(self) -> None:
+        """Save any remaining pending samples as final checkpoint."""
+        if not self.is_enabled:
+            return
+
+        if self._pending_samples:
+            sample_count = len(self._pending_samples)
+            self._save_checkpoint()
+            logger.info(f"Saved final checkpoint with {sample_count} samples")
+
+    def _save_checkpoint(self) -> None:
+        """Save current pending samples to a checkpoint file."""
+        if not self._pending_samples:
+            return
+
+        self._checkpoint_counter += 1
+        checkpoint_file = os.path.join(
+            self.checkpoint_dir, f"checkpoint_{self._checkpoint_counter:04d}.jsonl"
+        )
+
+        # Convert pending samples to dataset and save
+        checkpoint_dataset = Dataset.from_list(self._pending_samples)
+        checkpoint_dataset.to_json(checkpoint_file, orient="records", lines=True)
+
+        # Update metadata
+        self._save_metadata()
+
+        logger.info(
+            f"Saved checkpoint {self._checkpoint_counter} with "
+            f"{len(self._pending_samples)} samples to {checkpoint_file}"
+        )
+
+        # Clear pending samples
+        self._pending_samples.clear()
+
+    def _save_metadata(self) -> None:
+        """Save flow execution metadata."""
+        metadata = {
+            "flow_id": self.flow_id,
+            "save_freq": self.save_freq,
+            "samples_processed": self._samples_processed,
+            "checkpoint_counter": self._checkpoint_counter,
+            "last_updated": str(uuid.uuid4()),  # Simple versioning
+        }
+
+        with open(self.metadata_path, "w", encoding="utf-8") as f:
+            json.dump(metadata, f, indent=2)
+
+    def _load_metadata(self) -> Optional[Dict[str, Any]]:
+        """Load flow execution metadata."""
+        if not os.path.exists(self.metadata_path):
+            return None
+
+        try:
+            with open(self.metadata_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        except Exception as exc:
+            logger.warning(f"Failed to load metadata: {exc}")
+            return None
+
+    def _load_completed_samples(self) -> Optional[Dataset]:
+        """Load all completed samples from checkpoint files."""
+        checkpoint_files = []
+        checkpoint_dir = Path(self.checkpoint_dir)
+
+        # Find all checkpoint files
+        for file_path in checkpoint_dir.glob("checkpoint_*.jsonl"):
+            checkpoint_files.append(str(file_path))
+
+        if not checkpoint_files:
+            return None
+
+        # Sort checkpoint files by number
+        checkpoint_files.sort()
+
+        # Load and concatenate all checkpoint datasets
+        datasets = []
+        for file_path in checkpoint_files:
+            try:
+                dataset = Dataset.from_json(file_path)
+                if len(dataset) > 0:
+                    datasets.append(dataset)
+                    logger.debug(
+                        f"Loaded checkpoint: {file_path} ({len(dataset)} samples)"
+                    )
+            except Exception as exc:
+                logger.warning(f"Failed to load checkpoint {file_path}: {exc}")
+
+        if not datasets:
+            return None
+
+        return safe_concatenate_with_validation(datasets, "checkpoint files")
+
+    def _find_remaining_samples(
+        self, input_dataset: Dataset, completed_dataset: Dataset
+    ) -> Dataset:
+        """Find samples from input_dataset that are not in completed_dataset.
+
+        Note: Assumes input_dataset contains unique samples. For datasets with
+        duplicates, multiset semantics with collections.Counter would be needed.
+
+        Parameters
+        ----------
+        input_dataset : Dataset
+            Original input dataset (assumed to contain unique samples).
+        completed_dataset : Dataset
+            Dataset of completed samples.
+
+        Returns
+        -------
+        Dataset
+            Samples that still need processing.
+        """
+        # Get common columns for comparison
+        input_columns = set(input_dataset.column_names)
+        completed_columns = set(completed_dataset.column_names)
+        common_columns = list(input_columns & completed_columns)
+
+        if not common_columns:
+            logger.warning(
+                "No common columns found between input and completed datasets. "
+                "Processing all input samples."
+            )
+            return input_dataset
+
+        # Convert to pandas for easier comparison
+        input_df = input_dataset.select_columns(common_columns).to_pandas()
+        completed_df = completed_dataset.select_columns(common_columns).to_pandas()
+
+        # Find rows that haven't been completed
+        # Use tuple representation for comparison
+        input_tuples = set(input_df.apply(tuple, axis=1))
+        completed_tuples = set(completed_df.apply(tuple, axis=1))
+        remaining_tuples = input_tuples - completed_tuples
+
+        # Filter input dataset to only remaining samples
+        remaining_mask = input_df.apply(tuple, axis=1).isin(remaining_tuples)
+        remaining_indices = input_df[remaining_mask].index.tolist()
+
+        if not remaining_indices:
+            # Return empty dataset with same structure
+            return input_dataset.select([])
+
+        return input_dataset.select(remaining_indices)
+
+    def get_progress_info(self) -> Dict[str, Any]:
+        """Get information about current progress.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Progress information including samples processed, checkpoints saved, etc.
+        """
+        return {
+            "checkpoint_dir": self.checkpoint_dir,
+            "save_freq": self.save_freq,
+            "flow_id": self.flow_id,
+            "samples_processed": self._samples_processed,
+            "checkpoint_counter": self._checkpoint_counter,
+            "pending_samples": len(self._pending_samples),
+            "is_enabled": self.is_enabled,
+        }
+
+    def cleanup_checkpoints(self) -> None:
+        """Remove all checkpoint files and metadata."""
+        if not self.is_enabled:
+            return
+
+        checkpoint_dir = Path(self.checkpoint_dir)
+        if not checkpoint_dir.exists():
+            return
+
+        # Remove all checkpoint files
+        for file_path in checkpoint_dir.glob("checkpoint_*.jsonl"):
+            file_path.unlink()
+            logger.debug(f"Removed checkpoint file: {file_path}")
+
+        # Remove metadata file
+        metadata_path = Path(self.metadata_path)
+        if metadata_path.exists():
+            metadata_path.unlink()
+            logger.debug(f"Removed metadata file: {metadata_path}")
+
+        logger.info(f"Cleaned up all checkpoints in {self.checkpoint_dir}")
sdg_hub/core/flow/metadata.py
CHANGED
@@ -9,6 +9,9 @@ from typing import Any, Optional
 # Third Party
 from pydantic import BaseModel, Field, field_validator, model_validator
 
+# Local
+from ..utils.flow_identifier import get_flow_identifier
+
 
 class ModelCompatibility(str, Enum):
     """Model compatibility levels."""
@@ -238,6 +241,8 @@ class FlowMetadata(BaseModel):
 
     Attributes
     ----------
+    id : str
+        Unique identifier for the flow.
     name : str
         Human-readable name of the flow.
     description : str
@@ -267,6 +272,9 @@ class FlowMetadata(BaseModel):
     """
 
     name: str = Field(..., min_length=1, description="Human-readable name")
+    id: str = Field(
+        default="", description="Unique identifier for the flow, generated from name"
+    )
     description: str = Field(default="", description="Detailed description")
     version: str = Field(
         default="1.0.0",
@@ -304,6 +312,31 @@ class FlowMetadata(BaseModel):
         default="", description="Estimated duration for flow execution"
     )
 
+    @field_validator("id")
+    @classmethod
+    def validate_id(cls, v: str) -> str:
+        """Validate flow id."""
+        # Note: Auto-generation is handled in the model_validator since field_validator
+        # doesn't have access to other field values in Pydantic v2
+
+        # Validate id format if provided
+        if v:
+            # Must be lowercase
+            if not v.islower():
+                raise ValueError("id must be lowercase")
+
+            # Must contain only alphanumeric characters and hyphens
+            if not v.replace("-", "").isalnum():
+                raise ValueError(
+                    "id must contain only alphanumeric characters and hyphens"
+                )
+
+            # Must not start or end with a hyphen
+            if v.startswith("-") or v.endswith("-"):
+                raise ValueError("id must not start or end with a hyphen")
+
+        return v
+
     @field_validator("tags")
     @classmethod
     def validate_tags(cls, v: list[str]) -> list[str]:
@@ -323,6 +356,18 @@ class FlowMetadata(BaseModel):
         """Update the updated_at timestamp."""
        self.updated_at = datetime.now().isoformat()
 
+    @model_validator(mode="after")
+    def ensure_id(self) -> "FlowMetadata":
+        """Ensure id is set.
+
+        Note: YAML persistence is handled by Flow.from_yaml() and FlowRegistry
+        to maintain proper separation of concerns.
+        """
+        if not self.id and self.name:
+            self.id = get_flow_identifier(self.name)
+
+        return self
+
     def get_best_model(
         self, available_models: Optional[list[str]] = None
     ) -> Optional[str]:
sdg_hub/core/flow/migration.py
CHANGED
@@ -114,7 +114,10 @@ class FlowMigration:
     @staticmethod
     def _generate_default_metadata(flow_name: str) -> dict[str, Any]:
         """Generate default metadata for migrated flows."""
-        return {
+        # Import here to avoid circular import
+        from ..utils.flow_identifier import get_flow_identifier
+
+        metadata = {
             "name": flow_name,
             "description": f"Migrated flow: {flow_name}",
             "version": "1.0.0",
@@ -127,6 +130,14 @@ class FlowMigration:
             },
         }
 
+        # Generate id for migrated flows
+        flow_id = get_flow_identifier(flow_name)
+        if flow_id:
+            metadata["id"] = flow_id
+            logger.debug(f"Generated id for migrated flow: {flow_id}")
+
+        return metadata
+
     @staticmethod
     def _migrate_block_config(
         block_config: dict[str, Any],