caption-flow 0.2.2-py3-none-any.whl → 0.2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +308 -0
- caption_flow/models.py +134 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1715
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +489 -401
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +73 -32
- caption_flow/utils/dataset_loader.py +58 -298
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +5 -265
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/viewer.py +594 -0
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/METADATA +49 -180
- caption_flow-0.2.4.dist-info/RECORD +38 -0
- caption_flow-0.2.2.dist-info/RECORD +0 -29
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/top_level.txt +0 -0
caption_flow/utils/dataset_metadata_cache.py (new file):

```diff
@@ -0,0 +1,67 @@
+"""Dataset metadata caching for efficient HuggingFace dataset handling."""
+
+import json
+import logging
+from pathlib import Path
+from typing import Dict, Any, Optional, List
+from datetime import datetime
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetMetadataCache:
+    """Caches dataset metadata to avoid repeated full iterations."""
+
+    def __init__(self, cache_dir: Path):
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.cache_file = self.cache_dir / "dataset_metadata.json"
+        self.metadata: Dict[str, Any] = {}
+        self._load_cache()
+
+    def _load_cache(self):
+        """Load cached metadata from disk."""
+        if self.cache_file.exists():
+            try:
+                with open(self.cache_file, "r") as f:
+                    self.metadata = json.load(f)
+                logger.info(f"Loaded dataset metadata cache with {len(self.metadata)} datasets")
+            except Exception as e:
+                logger.error(f"Failed to load metadata cache: {e}")
+                self.metadata = {}
+
+    def _save_cache(self):
+        """Save metadata cache to disk."""
+        try:
+            with open(self.cache_file, "w") as f:
+                json.dump(self.metadata, f, indent=2)
+            logger.debug("Saved dataset metadata cache")
+        except Exception as e:
+            logger.error(f"Failed to save metadata cache: {e}")
+
+    def get_dataset_key(self, dataset_path: str, split: str) -> str:
+        """Generate a unique key for a dataset+split combination."""
+        return f"{dataset_path}:{split}"
+
+    def get_metadata(self, dataset_path: str, split: str) -> Optional[Dict[str, Any]]:
+        """Get cached metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        return self.metadata.get(key)
+
+    def set_metadata(self, dataset_path: str, split: str, metadata: Dict[str, Any]):
+        """Cache metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        metadata["cached_at"] = datetime.utcnow().isoformat()
+        metadata["dataset_path"] = dataset_path
+        metadata["split"] = split
+        self.metadata[key] = metadata
+        self._save_cache()
+        logger.info(f"Cached metadata for {key}: {metadata.get('total_items', 0)} items")
+
+    def invalidate(self, dataset_path: str, split: str):
+        """Remove cached metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        if key in self.metadata:
+            del self.metadata[key]
+            self._save_cache()
+            logger.info(f"Invalidated metadata cache for {key}")
```
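The new `DatasetMetadataCache` keys entries on `dataset_path` plus `split`, stamps each entry with `cached_at`, and persists everything to a single `dataset_metadata.json` under the cache directory. A minimal usage sketch follows; the cache directory, dataset name, and `total_items` payload are illustrative, not taken from the package:

```python
from pathlib import Path

from caption_flow.utils.dataset_metadata_cache import DatasetMetadataCache

# Hypothetical cache location; callers in the package choose their own directory.
cache = DatasetMetadataCache(Path("./metadata_cache"))

dataset_path, split = "example-org/example-dataset", "train"  # illustrative values

if cache.get_metadata(dataset_path, split) is None:
    # The expensive full pass over the dataset would happen here;
    # only its summary gets cached.
    cache.set_metadata(dataset_path, split, {"total_items": 12345})

meta = cache.get_metadata(dataset_path, split)
print(meta["total_items"], meta["cached_at"])
```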
caption_flow/utils/image_processor.py:

```diff
@@ -112,10 +112,7 @@ class ImageProcessor:
                 return None

         except Exception as e:
-            logger.error(f"Error processing image data: {e}")
-            import traceback
-
-            logger.error(traceback.format_exc())
+            logger.error(f"Error processing image data: {e}", exc_info=True)
            return None

    @staticmethod
```
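This change swaps the manual `traceback.format_exc()` call for the logging module's built-in `exc_info=True`, which attaches the active exception's traceback to the same log record. A standalone illustration of the equivalent behavior (the raised error here is just a stand-in for a decode failure):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise ValueError("bad image bytes")  # stand-in for an image decode failure
except Exception as e:
    # One record carrying both the message and the formatted traceback,
    # instead of a second logger.error(traceback.format_exc()) call.
    logger.error(f"Error processing image data: {e}", exc_info=True)
```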
caption_flow/utils/shard_processor.py:

```diff
@@ -7,7 +7,6 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Generator, Tuple, Optional, Dict, Any
 from dataclasses import dataclass
-from datasets import load_dataset
 from .image_processor import ImageProcessor
 from threading import Event
 import shlex
```
```diff
@@ -24,213 +23,22 @@ class ShardProcessor(ABC):
     """Abstract base for processing dataset shards."""

     @abstractmethod
-    def iterate_chunk(
+    def iterate_chunk_with_metadata(
         self,
         chunk,
         dataset_loader: Optional[DatasetLoader],
         should_stop: Event,
         connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
         """
-        Iterate through items in a chunk.
+        Iterate through items in a chunk with metadata.

         Yields:
-            Tuple of (key, url, image_data)
+            Tuple of (key, url, image_data, metadata)
         """
         pass


-class HFDatasetShardProcessor(ShardProcessor):
-    """Processor for HuggingFace virtual dataset shards."""
-
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Process HuggingFace virtual shard chunk."""
-        if not dataset_loader:
-            logger.error("No dataset loader configured for HuggingFace dataset shard")
-            return
-
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        items_processed = 0
-        current_idx = 0
-
-        # Construct proper virtual shard URL
-        parts = chunk.shard_url.split("_chunk_")
-        if len(parts) == 2:
-            base_path = parts[0]
-            virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
-        else:
-            virtual_shard_url = chunk.shard_url
-
-        logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
-
-        # Iterate through the virtual shard
-        for key, url, image_data in dataset_loader.iterate_shard(virtual_shard_url):
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping chunk processing early due to disconnect")
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
-
-            if not in_range:
-                current_idx += 1
-                continue  # Skip already processed items
-
-            # Check if we've processed enough for this chunk
-            if current_idx >= chunk.chunk_size:
-                break
-
-            items_processed += 1
-            current_idx += 1
-            yield key, url, image_data
-
-        logger.info(
-            f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
-            f"from ranges {unprocessed_ranges}"
-        )
-
-    def iterate_chunk_with_metadata(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """
-        Process HuggingFace virtual shard chunk with metadata, range by range.
-        """
-        if not dataset_loader:
-            logger.error("No dataset loader configured for HuggingFace dataset shard")
-            return
-
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing HF dataset chunk {chunk.chunk_id} with {len(unprocessed_ranges)} ranges"
-        )
-
-        items_yielded = 0
-
-        # Process each range independently with its own iterator
-        for range_start, range_end in unprocessed_ranges:
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping chunk processing early due to disconnect")
-                break
-
-            # Calculate absolute indices for this range
-            abs_start = chunk.start_index + range_start
-            abs_end = chunk.start_index + range_end
-            range_size = range_end - range_start + 1
-
-            logger.debug(
-                f"Processing range [{range_start}, {range_end}] "
-                f"(absolute: [{abs_start}, {abs_end}])"
-            )
-
-            try:
-                # Create a fresh dataset iterator for this range
-                dataset = load_dataset(
-                    dataset_loader.dataset_path,
-                    split=dataset_loader.split,
-                    streaming=True,
-                    token=dataset_loader.token,
-                )
-
-                # Use state_dict if available for efficient positioning
-                if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
-                    try:
-                        state = dataset.state_dict()
-                        # Modify state to jump to abs_start
-                        if "num_examples_since_previous_state" in state:
-                            state["num_examples_since_previous_state"] = abs_start
-                        if "examples_iterable" in state and isinstance(
-                            state["examples_iterable"], dict
-                        ):
-                            if "shard_example_idx" in state["examples_iterable"]:
-                                state["examples_iterable"]["shard_example_idx"] = abs_start
-                        dataset.load_state_dict(state)
-                        logger.debug(f"Positioned dataset at index {abs_start} using state_dict")
-                    except Exception as e:
-                        logger.debug(f"Could not use state_dict, falling back to skip: {e}")
-                        dataset = dataset.skip(abs_start)
-                else:
-                    # Fall back to skip
-                    dataset = dataset.skip(abs_start)
-
-                # Process items in this range
-                range_items = 0
-                for item in dataset:
-                    if range_items >= range_size:
-                        break
-
-                    if should_stop.is_set() or not connected.is_set():
-                        break
-
-                    # Generate key for this item
-                    current_abs_idx = abs_start + range_items
-                    key = f"{dataset_loader.dataset_path.replace('/', '_')}_{current_abs_idx:08d}"
-
-                    try:
-                        if dataset_loader.image_column in item:
-                            img_data = item[dataset_loader.image_column]
-                            image_bytes = ImageProcessor.process_image_data(img_data)
-
-                            if image_bytes:
-                                # Extract metadata
-                                metadata = {
-                                    k: v
-                                    for k, v in item.items()
-                                    if k != dataset_loader.image_column
-                                }
-                                # Add chunk-relative index to metadata
-                                metadata["_chunk_relative_index"] = range_start + range_items

-                                url = f"hf://{dataset_loader.dataset_path}#{current_abs_idx}"
-
-                                items_yielded += 1
-                                range_items += 1
-
-                                yield key, url, image_bytes, metadata
-                            else:
-                                logger.warning(
-                                    f"Failed to process image at index {current_abs_idx}"
-                                )
-                                range_items += 1
-                        else:
-                            logger.warning(
-                                f"No image column '{dataset_loader.image_column}' at index {current_abs_idx}"
-                            )
-                            range_items += 1
-
-                    except Exception as e:
-                        logger.error(f"Error processing item at index {current_abs_idx}: {e}")
-                        range_items += 1
-                        continue
-
-            except Exception as e:
-                logger.error(f"Error processing range [{range_start}, {range_end}]: {e}")
-                continue
-
-        logger.info(
-            f"HF dataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
-            f"from {len(unprocessed_ranges)} ranges"
-        )
-
-
 class WebDatasetShardProcessor(ShardProcessor):
     """Processor for WebDataset tar shards with range support."""

```
```diff
@@ -238,74 +46,6 @@ class WebDatasetShardProcessor(ShardProcessor):
         self.hf_token = hf_token
         self.dataset_type = dataset_type

-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Process WebDataset shard chunk with unprocessed ranges."""
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        # Create WebDataset pipeline
-        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
-            # Use curl with auth for HuggingFace WebDataset
-            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(url_cmd),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-        else:
-            # Local file
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(chunk.shard_url),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-
-        # Process items
-        current_idx = 0
-        items_yielded = 0
-
-        for key, image_data in ds:
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
-                break
-
-            # Calculate relative index within chunk
-            relative_idx = current_idx - chunk.start_index
-
-            # Skip items before chunk start
-            if current_idx < chunk.start_index:
-                current_idx += 1
-                continue
-
-            # Stop if beyond chunk
-            if relative_idx >= chunk.chunk_size:
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
-
-            if in_range:
-                items_yielded += 1
-                yield key, chunk.shard_url, image_data
-
-            current_idx += 1
-
-        logger.info(
-            f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
-            f"from ranges {unprocessed_ranges}"
-        )
-
     def iterate_chunk_with_metadata(
         self,
         chunk,
```
```diff
@@ -322,7 +62,7 @@ class WebDatasetShardProcessor(ShardProcessor):
         )

         # Create WebDataset pipeline
-        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
+        if self.dataset_type == "huggingface":
             # Use curl with auth for HuggingFace WebDataset
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
             ds = wds.DataPipeline(
```
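Taken together, these shard_processor.py hunks drop the plain `iterate_chunk` API and the `HFDatasetShardProcessor` class (its role apparently taken over by the new `caption_flow/processors/` modules), leaving `iterate_chunk_with_metadata` as the single abstract method, which yields `(key, url, image_bytes, metadata)` tuples. A minimal sketch of a conforming subclass, using a hypothetical in-memory source; `InMemoryShardProcessor` and its constructor are illustrative, not part of the package, and the `DatasetLoader` annotation is left as a string since only the parameter name is shown in the diff:

```python
from threading import Event
from typing import Any, Dict, Generator, Optional, Tuple

from caption_flow.utils.shard_processor import ShardProcessor


class InMemoryShardProcessor(ShardProcessor):
    """Hypothetical processor that replays pre-loaded (key, url, bytes, metadata) items."""

    def __init__(self, items):
        # items: iterable of (key, url, image_bytes, metadata) tuples
        self.items = list(items)

    def iterate_chunk_with_metadata(
        self,
        chunk,
        dataset_loader: Optional["DatasetLoader"],  # provided by caption_flow.utils.dataset_loader
        should_stop: Event,
        connected: Event,
    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
        for key, url, image_bytes, metadata in self.items:
            # Honor the same cooperative-shutdown signals as the built-in processors.
            if should_stop.is_set() or not connected.is_set():
                break
            yield key, url, image_bytes, metadata
```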
caption_flow/utils/shard_tracker.py:

```diff
@@ -61,11 +61,7 @@ class ShardTracker(CheckpointTracker):
         """Get list of shards that still need processing."""
         remaining = []
         for s in all_shards:
-
-            if s.startswith("hf_dataset:"):
-                shard_name = s  # Use full virtual shard ID
-            else:
-                shard_name = Path(s).stem
+            shard_name = Path(s).stem

             if shard_name not in self.completed_shards:
                 remaining.append(s)
```