caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- caption_flow/cli.py +2 -1
- caption_flow/models.py +108 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1595
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage.py +415 -406
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +94 -35
- caption_flow/utils/dataset_loader.py +64 -522
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +4 -200
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/METADATA +29 -27
- caption_flow-0.2.3.dist-info/RECORD +35 -0
- caption_flow-0.2.1.dist-info/RECORD +0 -29
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/top_level.txt +0 -0
caption_flow/utils/dataset_metadata_cache.py
ADDED
@@ -0,0 +1,67 @@
+"""Dataset metadata caching for efficient HuggingFace dataset handling."""
+
+import json
+import logging
+from pathlib import Path
+from typing import Dict, Any, Optional, List
+from datetime import datetime
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetMetadataCache:
+    """Caches dataset metadata to avoid repeated full iterations."""
+
+    def __init__(self, cache_dir: Path):
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.cache_file = self.cache_dir / "dataset_metadata.json"
+        self.metadata: Dict[str, Any] = {}
+        self._load_cache()
+
+    def _load_cache(self):
+        """Load cached metadata from disk."""
+        if self.cache_file.exists():
+            try:
+                with open(self.cache_file, "r") as f:
+                    self.metadata = json.load(f)
+                logger.info(f"Loaded dataset metadata cache with {len(self.metadata)} datasets")
+            except Exception as e:
+                logger.error(f"Failed to load metadata cache: {e}")
+                self.metadata = {}
+
+    def _save_cache(self):
+        """Save metadata cache to disk."""
+        try:
+            with open(self.cache_file, "w") as f:
+                json.dump(self.metadata, f, indent=2)
+            logger.debug("Saved dataset metadata cache")
+        except Exception as e:
+            logger.error(f"Failed to save metadata cache: {e}")
+
+    def get_dataset_key(self, dataset_path: str, split: str) -> str:
+        """Generate a unique key for a dataset+split combination."""
+        return f"{dataset_path}:{split}"
+
+    def get_metadata(self, dataset_path: str, split: str) -> Optional[Dict[str, Any]]:
+        """Get cached metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        return self.metadata.get(key)
+
+    def set_metadata(self, dataset_path: str, split: str, metadata: Dict[str, Any]):
+        """Cache metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        metadata["cached_at"] = datetime.utcnow().isoformat()
+        metadata["dataset_path"] = dataset_path
+        metadata["split"] = split
+        self.metadata[key] = metadata
+        self._save_cache()
+        logger.info(f"Cached metadata for {key}: {metadata.get('total_items', 0)} items")
+
+    def invalidate(self, dataset_path: str, split: str):
+        """Remove cached metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        if key in self.metadata:
+            del self.metadata[key]
+            self._save_cache()
+            logger.info(f"Invalidated metadata cache for {key}")
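For orientation, a minimal usage sketch of the cache added above. The cache directory, dataset identifiers, and metadata fields (total_items, columns) are illustrative assumptions; only the DatasetMetadataCache API itself comes from the diff.

    from pathlib import Path
    from caption_flow.utils.dataset_metadata_cache import DatasetMetadataCache

    # Hypothetical cache location and dataset name, for illustration only.
    cache = DatasetMetadataCache(Path("./metadata_cache"))
    meta = cache.get_metadata("some-user/some-dataset", "train")
    if meta is None:
        # In practice this would come from iterating the dataset once;
        # the dict below is made-up example data.
        meta = {"total_items": 1000, "columns": ["image", "caption"]}
        cache.set_metadata("some-user/some-dataset", "train", meta)
    # Subsequent runs read the JSON cache instead of re-iterating the dataset.
    print(meta["total_items"])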
caption_flow/utils/image_processor.py
CHANGED
@@ -112,10 +112,7 @@ class ImageProcessor:
             return None
 
         except Exception as e:
-            logger.error(f"Error processing image data: {e}")
-            import traceback
-
-            logger.error(traceback.format_exc())
+            logger.error(f"Error processing image data: {e}", exc_info=True)
             return None
 
     @staticmethod
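This hunk, like the caption_flow/workers/base.py hunks further down, replaces manual traceback formatting with logging's built-in exc_info flag. A minimal standalone sketch of the pattern; the function name and the error are invented for illustration:

    import logging

    logging.basicConfig(level=logging.ERROR)
    logger = logging.getLogger(__name__)

    def process(data: bytes):
        try:
            raise ValueError("bad image data")  # stand-in for a real failure
        except Exception as e:
            # exc_info=True attaches the full traceback to the log record,
            # so the explicit import-traceback / format_exc() calls are no longer needed.
            logger.error(f"Error processing image data: {e}", exc_info=True)
            return None

    process(b"")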
caption_flow/utils/shard_processor.py
CHANGED
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Generator, Tuple, Optional, Dict, Any
 from dataclasses import dataclass
+from .image_processor import ImageProcessor
 from threading import Event
 import shlex
 
@@ -22,84 +23,6 @@ class ShardProcessor(ABC):
     """Abstract base for processing dataset shards."""
 
     @abstractmethod
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """
-        Iterate through items in a chunk.
-
-        Yields:
-            Tuple of (key, url, image_data)
-        """
-        pass
-
-
-class HFDatasetShardProcessor(ShardProcessor):
-    """Processor for HuggingFace virtual dataset shards."""
-
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Process HuggingFace virtual shard chunk."""
-        if not dataset_loader:
-            logger.error("No dataset loader configured for HuggingFace dataset shard")
-            return
-
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        items_processed = 0
-        current_idx = 0
-
-        # Construct proper virtual shard URL
-        parts = chunk.shard_url.split("_chunk_")
-        if len(parts) == 2:
-            base_path = parts[0]
-            virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
-        else:
-            virtual_shard_url = chunk.shard_url
-
-        logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
-
-        # Iterate through the virtual shard
-        for key, url, image_data in dataset_loader.iterate_shard(virtual_shard_url):
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping chunk processing early due to disconnect")
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
-
-            if not in_range:
-                current_idx += 1
-                continue  # Skip already processed items
-
-            # Check if we've processed enough for this chunk
-            if current_idx >= chunk.chunk_size:
-                break
-
-            items_processed += 1
-            current_idx += 1
-            yield key, url, image_data
-
-        logger.info(
-            f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
-            f"from ranges {unprocessed_ranges}"
-        )
-
     def iterate_chunk_with_metadata(
         self,
         chunk,
@@ -108,63 +31,12 @@ class HFDatasetShardProcessor(ShardProcessor):
         connected: Event,
     ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
         """
-
+        Iterate through items in a chunk with metadata.
 
         Yields:
             Tuple of (key, url, image_data, metadata)
         """
-
-        logger.error("No dataset loader configured for HuggingFace dataset shard")
-        return
-
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        items_processed = 0
-        current_idx = 0
-
-        # Construct proper virtual shard URL
-        parts = chunk.shard_url.split("_chunk_")
-        if len(parts) == 2:
-            base_path = parts[0]
-            virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
-        else:
-            virtual_shard_url = chunk.shard_url
-
-        logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
-
-        # Use the new iterate method that includes metadata
-        for key, url, image_data, metadata in dataset_loader.iterate_shard_with_metadata(
-            virtual_shard_url
-        ):
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping chunk processing early due to disconnect")
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
-
-            if not in_range:
-                current_idx += 1
-                continue  # Skip already processed items
-
-            # Check if we've processed enough for this chunk
-            if current_idx >= chunk.chunk_size:
-                break
-
-            items_processed += 1
-            current_idx += 1
-            yield key, url, image_data, metadata
-
-        logger.info(
-            f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
-            f"from ranges {unprocessed_ranges}"
-        )
+        pass
 
 
 class WebDatasetShardProcessor(ShardProcessor):
@@ -174,74 +46,6 @@ class WebDatasetShardProcessor(ShardProcessor):
         self.hf_token = hf_token
         self.dataset_type = dataset_type
 
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Process WebDataset shard chunk with unprocessed ranges."""
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        # Create WebDataset pipeline
-        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
-            # Use curl with auth for HuggingFace WebDataset
-            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(url_cmd),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-        else:
-            # Local file
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(chunk.shard_url),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-
-        # Process items
-        current_idx = 0
-        items_yielded = 0
-
-        for key, image_data in ds:
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
-                break
-
-            # Calculate relative index within chunk
-            relative_idx = current_idx - chunk.start_index
-
-            # Skip items before chunk start
-            if current_idx < chunk.start_index:
-                current_idx += 1
-                continue
-
-            # Stop if beyond chunk
-            if relative_idx >= chunk.chunk_size:
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
-
-            if in_range:
-                items_yielded += 1
-                yield key, chunk.shard_url, image_data
-
-            current_idx += 1
-
-        logger.info(
-            f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
-            f"from ranges {unprocessed_ranges}"
-        )
-
     def iterate_chunk_with_metadata(
         self,
         chunk,
@@ -258,7 +62,7 @@ class WebDatasetShardProcessor(ShardProcessor):
         )
 
         # Create WebDataset pipeline
-        if self.dataset_type == "huggingface"
+        if self.dataset_type == "huggingface":
             # Use curl with auth for HuggingFace WebDataset
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
             ds = wds.DataPipeline(
caption_flow/utils/shard_tracker.py
CHANGED
@@ -61,11 +61,7 @@ class ShardTracker(CheckpointTracker):
         """Get list of shards that still need processing."""
         remaining = []
         for s in all_shards:
-
-            if s.startswith("hf_dataset:"):
-                shard_name = s  # Use full virtual shard ID
-            else:
-                shard_name = Path(s).stem
+            shard_name = Path(s).stem
 
             if shard_name not in self.completed_shards:
                 remaining.append(s)
caption_flow/workers/base.py
CHANGED
@@ -74,7 +74,7 @@ class BaseWorker(ABC):
                 await self._connect_and_run()
                 reconnect_delay = 5  # Reset delay on successful connection
             except Exception as e:
-                logger.error(f"Connection error: {e}")
+                logger.error(f"Connection error: {e}", exc_info=True)
                 self.connected.clear()
                 self.websocket = None
 
@@ -159,13 +159,13 @@ class BaseWorker(ABC):
                 except json.JSONDecodeError as e:
                     logger.error(f"Invalid message format: {e}")
                 except Exception as e:
-                    logger.error(f"Error handling message: {e}")
+                    logger.error(f"Error handling message: {e}", exc_info=True)
 
         except websockets.exceptions.ConnectionClosed as e:
             logger.info(f"Connection closed by orchestrator: {e}")
             raise
         except Exception as e:
-            logger.error(f"Message handler error: {e}")
+            logger.error(f"Message handler error: {e}", exc_info=True)
             raise
 
     async def shutdown(self):
|