caption-flow 0.2.4-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- caption_flow/__init__.py +1 -1
- caption_flow/orchestrator.py +9 -9
- caption_flow/processors/huggingface.py +636 -464
- caption_flow/processors/webdataset.py +379 -534
- caption_flow/storage/manager.py +328 -305
- caption_flow/utils/__init__.py +0 -2
- caption_flow/utils/chunk_tracker.py +196 -164
- caption_flow/utils/image_processor.py +19 -132
- caption_flow/workers/caption.py +164 -129
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/METADATA +2 -1
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/RECORD +15 -20
- caption_flow/utils/dataset_loader.py +0 -222
- caption_flow/utils/dataset_metadata_cache.py +0 -67
- caption_flow/utils/job_queue.py +0 -41
- caption_flow/utils/shard_processor.py +0 -119
- caption_flow/utils/shard_tracker.py +0 -83
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/top_level.txt +0 -0
caption_flow/utils/dataset_loader.py
DELETED
@@ -1,222 +0,0 @@
-"""Dataset loading utilities for WebDataset and HuggingFace."""
-
-import asyncio
-import shlex
-import logging
-from pathlib import Path
-from typing import List, Dict, Any, Generator, Optional, Tuple
-import json
-
-import webdataset as wds
-from huggingface_hub import HfFileSystem, get_token, hf_hub_url
-
-logger = logging.getLogger(__name__)
-
-
-class DatasetLoader:
-    """Handles loading datasets from various sources."""
-
-    def __init__(
-        self,
-        dataset_path: str,
-        dataset_type: str = "huggingface",
-        split: str = "train",
-        image_column: str = "image",
-        cache_dir: Optional[Path] = None,
-    ):
-        """
-        Initialize dataset loader.
-
-        Args:
-            dataset_path: Path to dataset (HF repo, local dir, etc.)
-            dataset_type: Type of dataset ("huggingface", "webdataset", "local")
-            split: Split to use for HuggingFace datasets (default: "train")
-            image_column: Column name containing image data or URLs (default: "image")
-        """
-        self.dataset_path = dataset_path
-        self.dataset_type = dataset_type
-        self.split = split
-        self.image_column = image_column
-        self.token = get_token()
-        self.dataset_format = None  # Will be detected: "webdataset" or "huggingface_datasets"
-
-        if not self.token and dataset_type == "huggingface":
-            logger.warning("No HuggingFace token found; run `huggingface-cli login`")
-
-        # Detect the actual format if it's a HuggingFace dataset
-        if dataset_type == "huggingface":
-            self.dataset_format = self._detect_dataset_format()
-            logger.info(f"Detected dataset format: {self.dataset_format}")
-
-    def _detect_dataset_format(self) -> str:
-        """Detect whether it's WebDataset or HuggingFace datasets format."""
-        fs = HfFileSystem(token=self.token)
-
-        # Check for .tar files (WebDataset)
-        tar_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar"))
-        if tar_files:
-            return "webdataset"
-
-        # Check for .parquet files (Huggingface Arrow DB)
-        parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
-        if parquet_files:
-            return "huggingface_datasets"
-
-        raise AssertionError(f"Could not detect dataset format for {self.dataset_path}")
-
-    def get_shard_list(self) -> List[str]:
-        """Get list of all shards in the dataset."""
-        if self.dataset_type == "huggingface":
-            if self.dataset_format == "webdataset":
-                return self._get_hf_webdataset_shards()
-            else:
-                logger.error(f"Unknown dataset format: {self.dataset_format}")
-                return []
-        elif self.dataset_type == "local":
-            return self._get_local_shards()
-        else:
-            raise ValueError(f"Unknown dataset type: {self.dataset_type}")
-
-    def _get_hf_webdataset_shards(self) -> List[str]:
-        """Get shard URLs from HuggingFace WebDataset."""
-        logger.info(f"Getting WebDataset shard list from HuggingFace: {self.dataset_path}")
-
-        fs = HfFileSystem(token=self.token)
-        files = [fs.resolve_path(p) for p in fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar")]
-
-        urls = [hf_hub_url(f.repo_id, f.path_in_repo, repo_type="dataset") for f in files]
-
-        logger.info(f"Found {len(urls)} WebDataset shards")
-        return sorted(urls)
-
-    def _get_local_shards(self) -> List[str]:
-        """Get shard files from local directory."""
-        path = Path(self.dataset_path)
-        if not path.exists():
-            raise ValueError(f"Local dataset path does not exist: {path}")
-
-        shards = list(path.glob("*.tar"))
-        logger.info(f"Found {len(shards)} local shards")
-        return [str(s) for s in sorted(shards)]
-
-    def load_shard(self, shard_url: str, processed_keys: Optional[set] = None) -> wds.DataPipeline:
-        """
-        Load a single shard as a WebDataset pipeline.
-
-        Args:
-            shard_url: URL or path to the shard
-            processed_keys: Set of already processed keys to skip
-        """
-        if processed_keys is None:
-            processed_keys = set()
-
-        if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
-            # Use curl with auth token for HuggingFace
-            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(url_cmd),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "__url__", "jpg;png;jpeg;webp;jxl"),
-                wds.select(lambda x: x[0] not in processed_keys),
-            )
-        else:
-            # Local file access
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(shard_url),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "__url__", "jpg;png;jpeg;webp;jxl"),
-                wds.select(lambda x: x[0] not in processed_keys),
-            )
-
-        return ds
-
-    def iterate_shard(
-        self,
-        shard_url: str,
-        processed_keys: Optional[set] = None,
-        unprocessed_ranges: Optional[List[Tuple[int, int]]] = None,
-    ) -> Generator[Dict[str, Any], None, None]:
-        """
-        Iterate over items in a shard, returning full sample dictionaries.
-
-        Args:
-            shard_url: URL or identifier of the shard
-            processed_keys: Set of already processed keys to skip
-            unprocessed_ranges: Specific ranges to process (for range-based processing)
-
-        Yields:
-            Dictionary containing the full WebDataset sample
-        """
-        if processed_keys is None:
-            processed_keys = set()
-
-        if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
-            # Use curl with auth token for HuggingFace
-            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(url_cmd),
-                wds.tarfile_to_samples(),
-                wds.select(lambda x: x.get("__key__", "") not in processed_keys),
-            )
-        else:
-            # Local file access
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(shard_url),
-                wds.tarfile_to_samples(),
-                wds.select(lambda x: x.get("__key__", "") not in processed_keys),
-            )
-
-        # Return full samples as dictionaries
-        for sample in ds:
-            # Ensure it's a dict and has required fields
-            if isinstance(sample, dict) and "__key__" in sample:
-                yield sample
-
-    def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
-        """Count items in a shard (can be slow for large shards)."""
-        count = 0
-        try:
-            for _ in self.iterate_shard(shard_url, processed_keys):
-                count += 1
-        except Exception as e:
-            logger.error(f"Error counting shard {shard_url}: {e}")
-        return count
-
-    def get_dataset_info(self) -> Dict[str, Any]:
-        """Get information about the dataset."""
-        info = {
-            "dataset_path": self.dataset_path,
-            "dataset_type": self.dataset_type,
-            "dataset_format": self.dataset_format,
-        }
-
-        if self.dataset_format == "huggingface_datasets":
-            # Include cached metadata if available
-            if hasattr(self, "_hf_metadata"):
-                info.update(self._hf_metadata)
-            else:
-
-                try:
-                    # Try to get more info about the dataset
-                    dataset_info = load_dataset(
-                        self.dataset_path, split=self.split, streaming=True, token=self.token
-                    )
-                    # Get features info
-                    if hasattr(dataset_info, "features"):
-                        info["features"] = str(dataset_info.features)
-
-                    # Try to get total size (might not work for all datasets)
-                    try:
-                        # This might be expensive for large datasets
-                        total_examples = len(
-                            load_dataset(self.dataset_path, split=self.split, token=self.token)
-                        )
-                        info["total_examples"] = total_examples
-                        self._hf_total_items = total_examples
-                    except:
-                        info["total_examples"] = "unknown"
-
-                except Exception as e:
-                    logger.error(f"Error getting dataset info: {e}")
-
-        return info
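For context (not part of the package diff): a minimal sketch of how this removed DatasetLoader could be driven in 0.2.4. The dataset name and the processed-key set are illustrative assumptions, not taken from the diff.

from caption_flow.utils.dataset_loader import DatasetLoader  # module removed in 0.3.1

# Hypothetical HuggingFace dataset repo; private/gated repos need `huggingface-cli login`.
loader = DatasetLoader("user/dataset", dataset_type="huggingface", split="train")

shards = loader.get_shard_list()   # resolved .tar URLs for a WebDataset-format repo
processed = set()                  # keys already captioned, e.g. restored from a checkpoint

for sample in loader.iterate_shard(shards[0], processed_keys=processed):
    key = sample["__key__"]        # WebDataset sample key; image bytes sit under e.g. sample["jpg"]
    processed.add(key)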
caption_flow/utils/dataset_metadata_cache.py
DELETED
@@ -1,67 +0,0 @@
-"""Dataset metadata caching for efficient HuggingFace dataset handling."""
-
-import json
-import logging
-from pathlib import Path
-from typing import Dict, Any, Optional, List
-from datetime import datetime
-
-logger = logging.getLogger(__name__)
-
-
-class DatasetMetadataCache:
-    """Caches dataset metadata to avoid repeated full iterations."""
-
-    def __init__(self, cache_dir: Path):
-        self.cache_dir = Path(cache_dir)
-        self.cache_dir.mkdir(parents=True, exist_ok=True)
-        self.cache_file = self.cache_dir / "dataset_metadata.json"
-        self.metadata: Dict[str, Any] = {}
-        self._load_cache()
-
-    def _load_cache(self):
-        """Load cached metadata from disk."""
-        if self.cache_file.exists():
-            try:
-                with open(self.cache_file, "r") as f:
-                    self.metadata = json.load(f)
-                logger.info(f"Loaded dataset metadata cache with {len(self.metadata)} datasets")
-            except Exception as e:
-                logger.error(f"Failed to load metadata cache: {e}")
-                self.metadata = {}
-
-    def _save_cache(self):
-        """Save metadata cache to disk."""
-        try:
-            with open(self.cache_file, "w") as f:
-                json.dump(self.metadata, f, indent=2)
-            logger.debug("Saved dataset metadata cache")
-        except Exception as e:
-            logger.error(f"Failed to save metadata cache: {e}")
-
-    def get_dataset_key(self, dataset_path: str, split: str) -> str:
-        """Generate a unique key for a dataset+split combination."""
-        return f"{dataset_path}:{split}"
-
-    def get_metadata(self, dataset_path: str, split: str) -> Optional[Dict[str, Any]]:
-        """Get cached metadata for a dataset."""
-        key = self.get_dataset_key(dataset_path, split)
-        return self.metadata.get(key)
-
-    def set_metadata(self, dataset_path: str, split: str, metadata: Dict[str, Any]):
-        """Cache metadata for a dataset."""
-        key = self.get_dataset_key(dataset_path, split)
-        metadata["cached_at"] = datetime.utcnow().isoformat()
-        metadata["dataset_path"] = dataset_path
-        metadata["split"] = split
-        self.metadata[key] = metadata
-        self._save_cache()
-        logger.info(f"Cached metadata for {key}: {metadata.get('total_items', 0)} items")
-
-    def invalidate(self, dataset_path: str, split: str):
-        """Remove cached metadata for a dataset."""
-        key = self.get_dataset_key(dataset_path, split)
-        if key in self.metadata:
-            del self.metadata[key]
-            self._save_cache()
-            logger.info(f"Invalidated metadata cache for {key}")
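Not part of the diff: a minimal round-trip sketch for the removed cache. The cache directory, dataset name, and item count are illustrative.

from pathlib import Path
from caption_flow.utils.dataset_metadata_cache import DatasetMetadataCache  # removed in 0.3.1

cache = DatasetMetadataCache(Path("./cache"))          # writes ./cache/dataset_metadata.json

# Store metadata once (total_items value is illustrative)...
cache.set_metadata("user/dataset", "train", {"total_items": 12345})

# ...and reuse it on later runs instead of re-iterating the dataset.
meta = cache.get_metadata("user/dataset", "train")
if meta is not None:
    print(meta["total_items"], meta["cached_at"])

# Drop the entry if the dataset changes.
cache.invalidate("user/dataset", "train")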
caption_flow/utils/job_queue.py
DELETED
@@ -1,41 +0,0 @@
|
|
1
|
-
"""Job queue management."""
|
2
|
-
|
3
|
-
import asyncio
|
4
|
-
from typing import Optional
|
5
|
-
from collections import deque
|
6
|
-
|
7
|
-
from ..models import Job
|
8
|
-
|
9
|
-
|
10
|
-
class JobQueue:
|
11
|
-
"""Priority job queue with backpressure."""
|
12
|
-
|
13
|
-
def __init__(self):
|
14
|
-
self.queue = deque()
|
15
|
-
self.processing = set()
|
16
|
-
self.lock = asyncio.Lock()
|
17
|
-
|
18
|
-
async def add(self, job: Job):
|
19
|
-
"""Add job to queue."""
|
20
|
-
async with self.lock:
|
21
|
-
self.queue.append(job)
|
22
|
-
|
23
|
-
async def get_next(self) -> Optional[Job]:
|
24
|
-
"""Get next available job."""
|
25
|
-
async with self.lock:
|
26
|
-
if self.queue:
|
27
|
-
job = self.queue.popleft()
|
28
|
-
self.processing.add(job.job_id)
|
29
|
-
return job
|
30
|
-
return None
|
31
|
-
|
32
|
-
async def complete(self, job_id: str):
|
33
|
-
"""Mark job as complete."""
|
34
|
-
async with self.lock:
|
35
|
-
self.processing.discard(job_id)
|
36
|
-
|
37
|
-
async def requeue(self, job: Job):
|
38
|
-
"""Requeue a job (for failures)."""
|
39
|
-
async with self.lock:
|
40
|
-
self.processing.discard(job.job_id)
|
41
|
-
self.queue.appendleft(job) # Priority requeue
|
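Not part of the diff: a minimal sketch of the removed JobQueue's consume/complete/requeue cycle. The drain/main helpers are assumptions for illustration; only JobQueue and Job come from the 0.2.4 code above.

import asyncio
from typing import List

from caption_flow.models import Job                     # re-exported from ..models in 0.2.4
from caption_flow.utils.job_queue import JobQueue       # removed in 0.3.1

async def drain(queue: JobQueue) -> None:
    """Pull jobs until the queue is empty, requeueing failures at the front."""
    while (job := await queue.get_next()) is not None:
        try:
            ...  # the actual captioning work happens elsewhere
            await queue.complete(job.job_id)
        except Exception:
            await queue.requeue(job)                     # goes back to the head of the deque

async def main(jobs: List[Job]) -> None:
    queue = JobQueue()
    for job in jobs:
        await queue.add(job)
    await drain(queue)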
caption_flow/utils/shard_processor.py
DELETED
@@ -1,119 +0,0 @@
-"""Shard processing abstraction for different dataset types."""
-
-import io
-import logging
-import time
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Generator, Tuple, Optional, Dict, Any
-from dataclasses import dataclass
-from .image_processor import ImageProcessor
-from threading import Event
-import shlex
-
-import webdataset as wds
-from PIL import Image
-
-from .dataset_loader import DatasetLoader
-
-logger = logging.getLogger(__name__)
-
-
-class ShardProcessor(ABC):
-    """Abstract base for processing dataset shards."""
-
-    @abstractmethod
-    def iterate_chunk_with_metadata(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """
-        Iterate through items in a chunk with metadata.
-
-        Yields:
-            Tuple of (key, url, image_data, metadata)
-        """
-        pass
-
-
-class WebDatasetShardProcessor(ShardProcessor):
-    """Processor for WebDataset tar shards with range support."""
-
-    def __init__(self, hf_token: Optional[str] = None, dataset_type: str = "local"):
-        self.hf_token = hf_token
-        self.dataset_type = dataset_type
-
-    def iterate_chunk_with_metadata(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """Process WebDataset shard chunk with metadata and range support."""
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        # Create WebDataset pipeline
-        if self.dataset_type == "huggingface":
-            # Use curl with auth for HuggingFace WebDataset
-            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(url_cmd),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-        else:
-            # Local file
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(chunk.shard_url),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-
-        # Process items
-        absolute_idx = 0  # Absolute index in the shard
-        items_yielded = 0
-
-        for key, image_data in ds:
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
-                break
-
-            # Skip items before chunk start
-            if absolute_idx < chunk.start_index:
-                absolute_idx += 1
-                continue

-            # Calculate relative index within chunk
-            relative_idx = absolute_idx - chunk.start_index
-
-            # Stop if beyond chunk
-            if relative_idx >= chunk.chunk_size:
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
-
-            if in_range:
-                # Create metadata with the relative index
-                metadata = {
-                    "_chunk_relative_index": relative_idx,
-                }
-                items_yielded += 1
-                yield key, chunk.shard_url, image_data, metadata
-
-            absolute_idx += 1
-
-        logger.info(
-            f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
-            f"from ranges {unprocessed_ranges}"
-        )
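Not part of the diff: a sketch of driving the removed WebDatasetShardProcessor. The Chunk dataclass below is a hypothetical stand-in exposing only the attributes the processor reads (chunk_id, shard_url, start_index, chunk_size, unprocessed_ranges); the shard path and sizes are illustrative.

from dataclasses import dataclass, field
from threading import Event
from typing import List, Tuple

from caption_flow.utils.shard_processor import WebDatasetShardProcessor  # removed in 0.3.1

@dataclass
class Chunk:
    """Hypothetical stand-in for the chunk objects the real orchestrator passes in."""
    chunk_id: str
    shard_url: str
    start_index: int
    chunk_size: int
    unprocessed_ranges: List[Tuple[int, int]] = field(default_factory=list)

should_stop, connected = Event(), Event()
connected.set()  # iteration stops when this clears or should_stop is set

processor = WebDatasetShardProcessor(hf_token=None, dataset_type="local")
chunk = Chunk("shard-0000:0", "/data/shard-0000.tar", start_index=0,
              chunk_size=1000, unprocessed_ranges=[(0, 999)])

for key, url, image_bytes, meta in processor.iterate_chunk_with_metadata(
    chunk, dataset_loader=None, should_stop=should_stop, connected=connected
):
    print(key, meta["_chunk_relative_index"], len(image_bytes))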
caption_flow/utils/shard_tracker.py
DELETED
@@ -1,83 +0,0 @@
-"""Shard tracking using CheckpointTracker base class."""
-
-from pathlib import Path
-from typing import Dict, Any, List, Set
-
-from .checkpoint_tracker import CheckpointTracker
-
-
-class ShardTracker(CheckpointTracker):
-    """Tracks shard processing progress."""
-
-    def __init__(self, checkpoint_path: Path):
-        """Initialize shard tracker with checkpoint file."""
-        self.completed_shards: Set[str] = set()
-        self.partial_shards: Dict[str, Dict[str, Any]] = {}
-        super().__init__(checkpoint_path)
-
-    def _get_default_state(self) -> Dict[str, Any]:
-        """Return default state structure for new checkpoints."""
-        return {"completed_shards": [], "partial_shards": {}}
-
-    def _deserialize_state(self, data: Dict[str, Any]) -> None:
-        """Deserialize loaded data into instance state."""
-        self.completed_shards = set(data.get("completed_shards", []))
-        self.partial_shards = data.get("partial_shards", {})
-
-    def _serialize_state(self) -> Dict[str, Any]:
-        """Serialize instance state for saving."""
-        return {
-            "completed_shards": list(self.completed_shards),
-            "partial_shards": self.partial_shards,
-        }
-
-    def mark_complete(self, shard_name: str) -> None:
-        """Mark a shard as complete."""
-        self.completed_shards.add(shard_name)
-        if shard_name in self.partial_shards:
-            del self.partial_shards[shard_name]
-        self.save()
-
-    def update_partial(self, shard_name: str, processed_keys: List[str]) -> None:
-        """Update partial progress for a shard."""
-        self.partial_shards[shard_name] = {"keys": processed_keys, "count": len(processed_keys)}
-        self.save()
-
-    def get_processed_keys(self, shard_name: str) -> Set[str]:
-        """Get set of processed keys for a shard."""
-        if shard_name in self.completed_shards:
-            return set()  # All done
-
-        if shard_name in self.partial_shards:
-            return set(self.partial_shards[shard_name].get("keys", []))
-
-        return set()
-
-    def is_complete(self, shard_name: str) -> bool:
-        """Check if a shard is complete."""
-        return shard_name in self.completed_shards
-
-    def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
-        """Get list of shards that still need processing."""
-        remaining = []
-        for s in all_shards:
-            shard_name = Path(s).stem
-
-            if shard_name not in self.completed_shards:
-                remaining.append(s)
-
-        return remaining
-
-    def get_stats(self) -> Dict[str, Any]:
-        """Get shard tracking statistics."""
-        base_stats = super().get_stats()
-        base_stats.update(
-            {
-                "completed_shards": len(self.completed_shards),
-                "partial_shards": len(self.partial_shards),
-                "total_partial_keys": sum(
-                    len(data.get("keys", [])) for data in self.partial_shards.values()
-                ),
-            }
-        )
-        return base_stats
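Not part of the diff: a minimal resume-loop sketch for the removed ShardTracker. The checkpoint path, shard paths, and the example key are illustrative assumptions.

from pathlib import Path
from caption_flow.utils.shard_tracker import ShardTracker  # removed in 0.3.1

tracker = ShardTracker(Path("./checkpoints/shards.json"))   # illustrative checkpoint location

all_shards = ["/data/shard-0000.tar", "/data/shard-0001.tar"]
for shard_path in tracker.get_remaining_shards(all_shards):
    name = Path(shard_path).stem
    done_keys = tracker.get_processed_keys(name)             # resume from partial progress
    # ... process the remaining samples, checkpointing along the way ...
    tracker.update_partial(name, sorted(done_keys | {"000123"}))  # "000123" is an example key
    tracker.mark_complete(name)

print(tracker.get_stats())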
{caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/WHEEL
File without changes
{caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/entry_points.txt
File without changes
{caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/licenses/LICENSE
File without changes
{caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/top_level.txt
File without changes