caption-flow 0.2.1.tar.gz → 0.2.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {caption_flow-0.2.1/src/caption_flow.egg-info → caption_flow-0.2.2}/PKG-INFO +1 -1
- {caption_flow-0.2.1 → caption_flow-0.2.2}/pyproject.toml +1 -1
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/cli.py +1 -1
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/orchestrator.py +165 -45
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/storage.py +5 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/chunk_tracker.py +22 -4
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/dataset_loader.py +59 -277
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/shard_processor.py +100 -36
- {caption_flow-0.2.1 → caption_flow-0.2.2/src/caption_flow.egg-info}/PKG-INFO +1 -1
- {caption_flow-0.2.1 → caption_flow-0.2.2}/LICENSE +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/README.md +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/setup.cfg +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/__init__.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/models.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/monitor.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/__init__.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/auth.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/caption_utils.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/certificates.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/checkpoint_tracker.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/image_processor.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/job_queue.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/json_utils.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/prompt_template.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/shard_tracker.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/vllm_config.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/workers/base.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/workers/caption.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/workers/data.py +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow.egg-info/SOURCES.txt +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow.egg-info/dependency_links.txt +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow.egg-info/entry_points.txt +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow.egg-info/requires.txt +0 -0
- {caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow.egg-info/top_level.txt +0 -0
{caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/cli.py

@@ -124,7 +124,7 @@ def setup_logging(verbose: bool = False):
     level = logging.DEBUG if verbose else logging.INFO
     logging.basicConfig(
         level=level,
-        format="%(
+        format="%(message)s",
         datefmt="[%Y-%m-%d %H:%M:%S]",
         handlers=[
             RichHandler(
{caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/orchestrator.py

@@ -16,7 +16,7 @@ import uuid
 from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Set, Optional, Any, List, Deque
+from typing import Dict, Set, Optional, Any, List, Deque, Tuple
 from collections import deque, defaultdict
 import threading
 from queue import Queue, Empty
@@ -97,27 +97,9 @@ class ChunkManager:
         self.lock = threading.Lock()
         self.tracker = tracker  # Reference to chunk tracker
 
-
-
-
-        """Create chunks from a shard."""
-        chunks = []
-
-        for start_idx in range(0, total_items, self.chunk_size):
-            chunk = ShardChunk.create(
-                shard_url=shard_url,
-                shard_name=shard_name,
-                start_index=start_idx,
-                chunk_size=min(self.chunk_size, total_items - start_idx),
-            )
-
-            with self.lock:
-                self.chunks[chunk.chunk_id] = chunk
-                self.pending_chunks.append(chunk.chunk_id)
-
-            chunks.append(chunk)
-
-        return chunks
+        # NEW: Track assigned ranges to prevent double allocation
+        # Format: {chunk_id: {(start, end): worker_id}}
+        self.assigned_ranges: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)
 
     def get_chunks_for_worker(
         self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
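The `assigned_ranges` ledger introduced above is the heart of this release: it maps each chunk to the exact item ranges handed out and the worker holding them, so two workers can never be given the same range. A minimal standalone sketch of the idea, assuming the `{chunk_id: {(start, end): worker_id}}` shape shown in the diff (the module-level names and helper below are illustrative, not caption_flow's actual API):

```python
# Sketch of the range-ledger pattern added in 0.2.2.
from collections import defaultdict
from typing import Dict, Tuple

assigned_ranges: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)

def try_assign(chunk_id: str, rng: Tuple[int, int], worker_id: str) -> bool:
    """Record an assignment unless another worker already holds the range."""
    holder = assigned_ranges[chunk_id].get(rng)
    if holder is not None and holder != worker_id:
        return False  # taken by someone else; the diff logs a warning and skips
    assigned_ranges[chunk_id][rng] = worker_id
    return True

assert try_assign("chunk-0", (0, 999), "worker-a")
assert not try_assign("chunk-0", (0, 999), "worker-b")  # double allocation blocked
```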
@@ -127,7 +109,6 @@ class ChunkManager:
 
         with self.lock:
             # FIRST PRIORITY: Check if this worker already has assigned chunks
-            # Workers should complete their current chunks before getting new ones
             if worker_id in self.assigned_chunks:
                 existing_chunk_ids = list(self.assigned_chunks[worker_id])
                 for chunk_id in existing_chunk_ids:
@@ -142,12 +123,29 @@ class ChunkManager:
                     if tracker:
                         chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
                         if chunk_info and chunk_info["unprocessed_ranges"]:
-                            assigned
-
-
-
-
-
+                            # Filter out ranges that are assigned to other workers
+                            clean_ranges = []
+                            for start, end in chunk_info["unprocessed_ranges"]:
+                                range_key = (start, end)
+                                if range_key in self.assigned_ranges[chunk_id]:
+                                    assigned_worker = self.assigned_ranges[chunk_id][range_key]
+                                    if assigned_worker != worker_id:
+                                        # Skip this range - it's assigned to another worker
+                                        logger.warning(
+                                            f"Skipping range {start}-{end} in chunk {chunk_id} "
+                                            f"(assigned to {assigned_worker}, not {worker_id})"
+                                        )
+                                        continue
+                                    # else: this worker already owns this range, include it
+                                clean_ranges.append((start, end))
+
+                            if clean_ranges:
+                                assigned.append(
+                                    {
+                                        "chunk": chunk,
+                                        "unprocessed_ranges": clean_ranges,
+                                    }
+                                )
                     else:
                         # No tracker, assume chunk needs processing
                         assigned.append(
@@ -158,7 +156,6 @@ class ChunkManager:
                        )
 
            # SECOND PRIORITY: Get new pending chunks
-           # Only if worker doesn't have enough chunks already
            while len(assigned) < count and self.pending_chunks:
                chunk_id = self.pending_chunks.popleft()
                chunk = self.chunks.get(chunk_id)
@@ -166,7 +163,7 @@ class ChunkManager:
                if not chunk:
                    continue
 
-               # Verify chunk is truly pending
+               # Verify chunk is truly pending
                if chunk.status != "pending" or chunk.assigned_to is not None:
                    logger.warning(
                        f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
@@ -179,15 +176,48 @@ class ChunkManager:
                chunk.assigned_at = datetime.utcnow()
                self.assigned_chunks[worker_id].add(chunk_id)
 
-               # Get unprocessed ranges
+               # Get unprocessed ranges and filter out any that are somehow already assigned
                unprocessed_ranges = [(0, chunk.chunk_size - 1)]  # Default
                if tracker:
                    chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
                    if chunk_info:
-
+                       # Filter out any ranges that are already assigned (shouldn't happen for new chunks)
+                       clean_ranges = []
+                       for start, end in chunk_info["unprocessed_ranges"]:
+                           range_key = (start, end)
+                           if range_key not in self.assigned_ranges[chunk_id]:
+                               clean_ranges.append((start, end))
+                           else:
+                               logger.error(
+                                   f"Range {start}-{end} in newly assigned chunk {chunk_id} "
+                                   f"is already assigned to {self.assigned_ranges[chunk_id][range_key]}!"
+                               )
+                       unprocessed_ranges = clean_ranges if clean_ranges else []
+
                    tracker.mark_assigned(chunk_id, worker_id)
 
-
+               if unprocessed_ranges:
+                   assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
+
+           # Track assigned ranges and verify no double allocation
+           for info in assigned:
+               chunk_id = info["chunk"].chunk_id
+               for start, end in info["unprocessed_ranges"]:
+                   range_key = (start, end)
+
+                   # Check if this range is already assigned
+                   if range_key in self.assigned_ranges[chunk_id]:
+                       existing_worker = self.assigned_ranges[chunk_id][range_key]
+                       if existing_worker != worker_id:
+                           # This should never happen - raise assertion
+                           raise AssertionError(
+                               f"CRITICAL: Attempting to assign range {start}-{end} in chunk {chunk_id} "
+                               f"to worker {worker_id}, but it's already assigned to {existing_worker}! "
+                               f"This would cause duplicate processing."
+                           )
+
+                   # Track this assignment
+                   self.assigned_ranges[chunk_id][range_key] = worker_id
 
            # Log what we're assigning
            if assigned:
@@ -199,6 +229,12 @@ class ChunkManager:
                )
                logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
 
+               # Detailed range logging for debugging
+               for info in assigned:
+                   chunk_id = info["chunk"].chunk_id
+                   ranges_str = ", ".join([f"{s}-{e}" for s, e in info["unprocessed_ranges"]])
+                   logger.debug(f"  Chunk {chunk_id} ranges: {ranges_str}")
+
            return assigned
 
    def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
@@ -210,6 +246,16 @@ class ChunkManager:
                chunk.status = "completed"
                chunk.completed_at = datetime.utcnow()
                self.assigned_chunks[worker_id].discard(chunk_id)
+
+               # Clear assigned ranges for this chunk
+               if chunk_id in self.assigned_ranges:
+                   # Log what ranges we're clearing
+                   ranges_to_clear = list(self.assigned_ranges[chunk_id].keys())
+                   logger.debug(
+                       f"Clearing {len(ranges_to_clear)} assigned ranges for completed chunk {chunk_id}"
+                   )
+                   del self.assigned_ranges[chunk_id]
+
                return True
            return False
 
@@ -224,6 +270,20 @@ class ChunkManager:
                chunk.assigned_at = None
                self.assigned_chunks[worker_id].discard(chunk_id)
                self.pending_chunks.append(chunk_id)
+
+               # Clear assigned ranges for this chunk/worker
+               if chunk_id in self.assigned_ranges:
+                   ranges_to_clear = [
+                       range_key
+                       for range_key, assigned_worker in self.assigned_ranges[chunk_id].items()
+                       if assigned_worker == worker_id
+                   ]
+                   for range_key in ranges_to_clear:
+                       del self.assigned_ranges[chunk_id][range_key]
+                   logger.debug(
+                       f"Cleared {len(ranges_to_clear)} assigned ranges for failed chunk {chunk_id}"
+                   )
+
                return True
            return False
 
@@ -240,18 +300,62 @@ class ChunkManager:
                chunk.assigned_at = None
                self.pending_chunks.append(chunk_id)
 
+               # Clear assigned ranges for this worker
+               if chunk_id in self.assigned_ranges:
+                   ranges_to_clear = [
+                       range_key
+                       for range_key, assigned_worker in self.assigned_ranges[
+                           chunk_id
+                       ].items()
+                       if assigned_worker == worker_id
+                   ]
+                   for range_key in ranges_to_clear:
+                       del self.assigned_ranges[chunk_id][range_key]
+
+                   if ranges_to_clear:
+                       logger.info(
+                           f"Released {len(ranges_to_clear)} ranges from chunk {chunk_id} "
+                           f"previously assigned to disconnected worker {worker_id}"
+                       )
+
            if worker_id in self.assigned_chunks:
                del self.assigned_chunks[worker_id]
 
+   def mark_ranges_processed(
+       self, chunk_id: str, processed_ranges: List[Tuple[int, int]], worker_id: str
+   ):
+       """Remove ranges from assignment tracking once they're processed."""
+       with self.lock:
+           if chunk_id in self.assigned_ranges:
+               for start, end in processed_ranges:
+                   range_key = (start, end)
+                   if range_key in self.assigned_ranges[chunk_id]:
+                       assigned_worker = self.assigned_ranges[chunk_id][range_key]
+                       if assigned_worker == worker_id:
+                           del self.assigned_ranges[chunk_id][range_key]
+                           logger.debug(
+                               f"Cleared assignment of range {start}-{end} in chunk {chunk_id} "
+                               f"after processing by {worker_id}"
+                           )
+                       else:
+                           logger.warning(
+                               f"Worker {worker_id} claims to have processed range {start}-{end} "
+                               f"in chunk {chunk_id}, but it was assigned to {assigned_worker}"
+                           )
+
    def get_stats(self) -> Dict[str, int]:
        """Get chunk statistics."""
        with self.lock:
+           # Count total assigned ranges
+           total_assigned_ranges = sum(len(ranges) for ranges in self.assigned_ranges.values())
+
            stats = {
                "total": len(self.chunks),
                "pending": len(self.pending_chunks),
                "assigned": sum(len(chunks) for chunks in self.assigned_chunks.values()),
                "completed": sum(1 for c in self.chunks.values() if c.status == "completed"),
                "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
+               "assigned_ranges": total_assigned_ranges,
            }
            return stats
 
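Taken together, the new bookkeeping gives each range a simple lifecycle: recorded in `get_chunks_for_worker`, cleared by `mark_ranges_processed` or `complete_chunk`, and released wholesale when a worker disconnects. A hypothetical walk-through of that lifecycle, reduced to the plain dict operations these methods perform (locking and logging stripped out):

```python
# Illustrative only: the ledger operations behind the new ChunkManager methods.
from collections import defaultdict

ledger = defaultdict(dict)  # chunk_id -> {(start, end): worker_id}

ledger["c1"][(0, 499)] = "w1"
ledger["c1"][(500, 999)] = "w2"

# mark_ranges_processed("c1", [(0, 499)], "w1") boils down to:
if ledger["c1"].get((0, 499)) == "w1":
    del ledger["c1"][(0, 499)]

# releasing chunks of a disconnected "w2" boils down to:
stale = [rng for rng, worker in ledger["c1"].items() if worker == "w2"]
for rng in stale:
    del ledger["c1"][rng]

assert ledger["c1"] == {}
```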
@@ -491,13 +595,15 @@ class Orchestrator:
        with self.chunk_manager.lock:
            for chunk_state in shard_info["chunks"]:
                if chunk_state.status in ["pending", "failed", "assigned"]:
-                   #
+                   # For assigned chunks, reset them to pending since workers don't exist
                    chunk = ShardChunk(
                        chunk_id=chunk_state.chunk_id,
                        shard_url=chunk_state.shard_url,
                        shard_name=chunk_state.shard_name,
                        start_index=chunk_state.start_index,
                        chunk_size=chunk_state.chunk_size,
+                       status="pending",  # Reset to pending
+                       assigned_to=None,  # Clear assignment
                    )
                    self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
                    self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
@@ -1811,10 +1917,24 @@ class Orchestrator:
        # Don't forget the last range
        ranges.append((start, end))
 
-       # Mark ranges as processed
+       # Mark ranges as processed
        for start_idx, end_idx in ranges:
            self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
 
+       with self.chunk_manager.lock:
+           if chunk_id in self.chunk_manager.assigned_ranges:
+               for start_idx, end_idx in ranges:
+                   # Clear any assignments in this range
+                   to_remove = []
+                   for range_start, range_end in self.chunk_manager.assigned_ranges[
+                       chunk_id
+                   ]:
+                       if range_start >= start_idx and range_end <= end_idx:
+                           to_remove.append((range_start, range_end))
+
+                   for range_key in to_remove:
+                       del self.chunk_manager.assigned_ranges[chunk_id][range_key]
+
        # Clear pending items
        self.pending_processed_items.clear()
        self.last_item_batch_flush = time.time()
@@ -2027,15 +2147,15 @@ class Orchestrator:
                last_known_total = current_total_outputs
 
                # Log rate information when workers are connected
-               if (
-
-               ):  # Only log non-negative rates
-
-
-
-
-
-
+               # if (
+               #     worker_count > 0 and self.rate_tracker["current_rate"] >= 0
+               # ):  # Only log non-negative rates
+               #     logger.info(
+               #         f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
+               #         f"(avg: {self.rate_tracker['average_rate']:.1f}, "
+               #         f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
+               #         f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
+               #     )
 
                await self._broadcast_stats()
 
{caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/storage.py

@@ -386,10 +386,15 @@ class StorageManager:
 
        # Filter new data to exclude duplicates
        new_rows = []
+       duplicate_rows = []
        for row in prepared_buffer:
            if row["job_id"] not in existing_job_ids:
                new_rows.append(row)
+           elif row not in duplicate_rows:
+               duplicate_rows.append(row)
 
+       if duplicate_rows:
+           logger.info(f"Example duplicate row: {duplicate_rows[0]}")
        if new_rows:
            # Create table from new rows only
            new_table = pa.Table.from_pylist(new_rows, schema=self.caption_schema)
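The storage change is purely diagnostic: rows whose `job_id` was already written are now collected and a sample is logged, instead of being dropped silently. A self-contained sketch of the dedup-and-report pattern around this hunk (the two-field schema is made up; the real `caption_schema` is wider):

```python
# Hedged sketch of the job_id de-duplication in StorageManager.
import pyarrow as pa

caption_schema = pa.schema([("job_id", pa.string()), ("caption", pa.string())])
existing_job_ids = {"job-1"}
prepared_buffer = [
    {"job_id": "job-1", "caption": "already stored"},  # duplicate
    {"job_id": "job-2", "caption": "new row"},
]

new_rows = [r for r in prepared_buffer if r["job_id"] not in existing_job_ids]
duplicates = [r for r in prepared_buffer if r["job_id"] in existing_job_ids]
if duplicates:
    print(f"Example duplicate row: {duplicates[0]}")
if new_rows:
    new_table = pa.Table.from_pylist(new_rows, schema=caption_schema)
    print(new_table.num_rows)  # 1
```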
{caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/chunk_tracker.py

@@ -441,9 +441,27 @@ class ChunkTracker(CheckpointTracker):
        )
 
    def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
-       """Get chunk info
-
+       """Get chunk info with unprocessed item ranges."""
+       chunk_state = self.chunks.get(chunk_id)
+       if not chunk_state:
            return None
 
-
-
+       # During startup or if no worker is assigned, treat all unprocessed as available
+       if not hasattr(self, "_startup_complete"):
+           self._startup_complete = False
+
+       if not self._startup_complete or not chunk_state.assigned_to:
+           # Return all unprocessed ranges
+           return {
+               "chunk_id": chunk_id,
+               "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
+               "status": chunk_state.status,
+           }
+
+       # Normal operation - only return ranges not being worked on
+       # This would need more complex tracking of which ranges each worker is processing
+       return {
+           "chunk_id": chunk_id,
+           "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
+           "status": chunk_state.status,
+       }
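`get_chunk_with_unprocessed_items` leans on `ChunkState.get_unprocessed_ranges()`, whose implementation is not part of this diff. Presumably it returns the gaps left in `[0, chunk_size - 1]` by the already-processed ranges; a guess at that computation:

```python
# Assumption: this mirrors what ChunkState.get_unprocessed_ranges() computes.
# The real implementation in chunk_tracker.py may differ.
from typing import List, Tuple

def unprocessed_ranges(chunk_size: int,
                       processed: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    gaps, cursor = [], 0
    for start, end in sorted(processed):
        if start > cursor:
            gaps.append((cursor, start - 1))  # gap before this processed run
        cursor = max(cursor, end + 1)
    if cursor <= chunk_size - 1:
        gaps.append((cursor, chunk_size - 1))  # tail gap
    return gaps

print(unprocessed_ranges(1000, [(0, 99), (500, 599)]))
# [(100, 499), (600, 999)]
```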
{caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/dataset_loader.py

@@ -217,17 +217,26 @@ class DatasetLoader:
        return dataset_path, start_idx, chunk_size
 
    def iterate_shard(
-       self,
+       self,
+       shard_url: str,
+       processed_keys: Optional[set] = None,
+       unprocessed_ranges: Optional[List[Tuple[int, int]]] = None,
    ) -> Generator[Tuple[str, str, bytes], None, None]:
        """
        Iterate over items in a shard.
 
+       Args:
+           shard_url: URL or identifier of the shard
+           processed_keys: Set of already processed keys to skip
+           unprocessed_ranges: Specific ranges to process (for HF datasets)
+
        Yields:
            Tuple of (key, url, image_bytes)
        """
-       # Check if this is a virtual HuggingFace dataset shard
        if shard_url.startswith("hf_dataset:"):
-
+           raise ValueError(
+               "Virtual HuggingFace dataset shards should use iterate_shard_with_metadata()"
+           )
        else:
            # Regular WebDataset shard
            ds = self.load_shard(shard_url, processed_keys)
@@ -296,296 +305,69 @@ class DatasetLoader:
        )
 
        try:
-           #
-
-           dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
-           if dataset:
-               items_processed = 0
-
-               for item in dataset:
-                   # Stop after processing chunk_size items
-                   if items_processed >= chunk_size:
-                       break
-
-                   # Generate a unique key for this item
-                   key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                   if key in processed_keys:
-                       items_processed += 1
-                       continue
-
-                   try:
-                       # Extract image data
-                       if self.image_column in item:
-                           img_data = item[self.image_column]
-
-                           # Process image to bytes
-                           image_bytes = ImageProcessor.process_image_data(img_data)
-
-                           if image_bytes:
-                               # Extract all metadata (excluding the image column)
-                               metadata = {
-                                   k: v for k, v in item.items() if k != self.image_column
-                               }
-
-                               # URL is virtual for HF datasets
-                               url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                               items_processed += 1
-                               yield key, url, image_bytes, metadata
-                           else:
-                               logger.warning(
-                                   f"Failed to process image for item at index {start_idx + items_processed}"
-                               )
-                               items_processed += 1
-                               continue
-                       else:
-                           logger.warning(
-                               f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                               f"Available columns: {list(item.keys())}"
-                           )
-                           items_processed += 1
-
-                   except Exception as e:
-                       logger.error(
-                           f"Error processing item at index {start_idx + items_processed}: {e}"
-                       )
-                       items_processed += 1
-                       continue
-
-               return
-
-           # Fall back to regular approach for small skips or if StatefulDataLoader not available
-           dataset = load_dataset(
-               dataset_path,
-               split=self.split,
-               streaming=True,
-               token=self.token,
-           )
-
-           # Skip to start index if needed
-           if start_idx > 0:
-               dataset = dataset.skip(start_idx)
-
+           # For HF datasets, we iterate through the full chunk range
+           # The actual range filtering happens in the shard processor
            items_processed = 0
+           current_abs_idx = start_idx
+
+           while items_processed < chunk_size:
+               # Create a fresh dataset iterator for each batch
+               # This avoids issues with stateful iterators
+               batch_size = min(1000, chunk_size - items_processed)  # Process in smaller batches
+
+               dataset = load_dataset(
+                   dataset_path,
+                   split=self.split,
+                   streaming=True,
+                   token=self.token,
+               )
 
-
-
-
-                       break
-
-           # Generate a unique key for this item
-           key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-           if key in processed_keys:
-               items_processed += 1
-               continue
-
-           try:
-               # Extract image data
-               if self.image_column in item:
-                   img_data = item[self.image_column]
+               # Skip to current position
+               if current_abs_idx > 0:
+                   dataset = dataset.skip(current_abs_idx)
 
-
-
+               batch_processed = 0
+               for item in dataset:
+                   if batch_processed >= batch_size or items_processed >= chunk_size:
+                       break
 
-
-
-                   metadata = {k: v for k, v in item.items() if k != self.image_column}
+                   # Generate key
+                   key = f"{dataset_path.replace('/', '_')}_{current_abs_idx:08d}"
 
-
-
-
-                   yield key, url, image_bytes, metadata
-               else:
-                   logger.warning(
-                       f"Failed to process image for item at index {start_idx + items_processed}"
-                   )
-                   items_processed += 1
-                   continue
-           else:
-               logger.warning(
-                   f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                   f"Available columns: {list(item.keys())}"
-               )
+                   if key in processed_keys:
+                       current_abs_idx += 1
+                       batch_processed += 1
                        items_processed += 1
+                       continue
 
-
-
-
-
-                       items_processed += 1
-                       continue
+                   try:
+                       if self.image_column in item:
+                           img_data = item[self.image_column]
+                           image_bytes = ImageProcessor.process_image_data(img_data)
 
-
-
-
+                           if image_bytes:
+                               metadata = {k: v for k, v in item.items() if k != self.image_column}
+                               url = f"hf://{dataset_path}#{current_abs_idx}"
 
-
-           self, shard_url: str, processed_keys: Optional[set] = None
-       ) -> Generator[Tuple[str, str, bytes], None, None]:
-           """Iterate over a virtual HuggingFace dataset shard."""
-           if processed_keys is None:
-               processed_keys = set()
+                               yield key, url, image_bytes, metadata
 
-
-
-           # IMPORTANT: Check if start_idx is beyond dataset bounds
-           if self._hf_total_items is not None and start_idx >= self._hf_total_items:
-               logger.warning(
-                   f"Virtual shard starts at index {start_idx} but dataset only has "
-                   f"{self._hf_total_items} items. Skipping this shard."
-               )
-               return
-
-           logger.info(
-               f"Loading HuggingFace dataset in streaming mode: {dataset_path} "
-               f"(split: {self.split}, start: {start_idx}, chunk_size: {chunk_size})"
-           )
-
-           try:
-               # Try optimized approach for large skips
-               if start_idx > 100:
-                   dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
-                   if dataset:
-                       items_processed = 0
-
-                       for item in dataset:
-                           # Stop after processing chunk_size items
-                           if items_processed >= chunk_size:
-                               logger.info(f"Completed chunk: processed {items_processed} items")
-                               break
-
-                           # Also stop if we've reached the dataset end
-                           if (
-                               self._hf_total_items
-                               and (start_idx + items_processed) >= self._hf_total_items
-                           ):
-                               logger.info(
-                                   f"Reached dataset end at item {start_idx + items_processed} "
-                                   f"(total: {self._hf_total_items})"
-                               )
-                               break
-
-                           # Generate a unique key for this item
-                           key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                           if key in processed_keys:
-                               items_processed += 1
-                               continue
-
-                           try:
-                               # Extract image data
-                               if self.image_column in item:
-                                   img_data = item[self.image_column]
-
-                                   # Delegate image processing to ImageProcessor
-                                   image_bytes = ImageProcessor.process_image_data(img_data)
-
-                                   if image_bytes:
-                                       # URL is virtual for HF datasets
-                                       url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                                       items_processed += 1
-                                       yield key, url, image_bytes
-                                   else:
-                                       logger.warning(
-                                           f"Failed to process image for item at index {start_idx + items_processed}"
-                                       )
-                                       items_processed += 1
-                                       continue
-                               else:
-                                   logger.warning(
-                                       f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                                       f"Available columns: {list(item.keys())}"
-                                   )
-                                   items_processed += 1
-
-                           except Exception as e:
-                               logger.error(
-                                   f"Error processing item at index {start_idx + items_processed}: {e}"
-                               )
+                           current_abs_idx += 1
+                           batch_processed += 1
                            items_processed += 1
-                           continue
-
-                       logger.info(
-                           f"Virtual shard complete: processed {items_processed} items "
-                           f"(start_idx: {start_idx})"
-                       )
-                       return
-
-               # Fall back to regular approach for small skips or if StatefulDataLoader not available
-               dataset = load_dataset(
-                   dataset_path,
-                   split=self.split,
-                   streaming=True,
-                   token=self.token,
-               )
-
-               # Use dataset.skip() for efficient skipping
-               if start_idx > 0:
-                   dataset = dataset.skip(start_idx)
-                   logger.info(f"Skipped to index {start_idx}")
-
-               items_processed = 0
-
-               # Now enumerate starts from 0 after skip
-               for item in dataset:
-                   # Stop after processing chunk_size items
-                   if items_processed >= chunk_size:
-                       logger.info(f"Completed chunk: processed {items_processed} items")
-                       break
-
-                   # Also stop if we've reached the dataset end
-                   if self._hf_total_items and (start_idx + items_processed) >= self._hf_total_items:
-                       logger.info(
-                           f"Reached dataset end at item {start_idx + items_processed} "
-                           f"(total: {self._hf_total_items})"
-                       )
-                       break
-
-                   # Generate a unique key for this item - ensure proper formatting
-                   key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                   if key in processed_keys:
-                       items_processed += 1
-                       continue
-
-                   try:
-                       # Extract image data - check configured column name
-                       if self.image_column in item:
-                           img_data = item[self.image_column]
-
-                           # Delegate image processing to ImageProcessor
-                           image_bytes = ImageProcessor.process_image_data(img_data)
-
-                           if image_bytes:
-                               # URL is virtual for HF datasets
-                               url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                               items_processed += 1
-                               yield key, url, image_bytes
                        else:
                            logger.warning(
-                               f"
+                               f"No image column '{self.image_column}' at index {current_abs_idx}"
                            )
+                           current_abs_idx += 1
+                           batch_processed += 1
                            items_processed += 1
-                           continue
-                       else:
-                           logger.warning(
-                               f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                               f"Available columns: {list(item.keys())}"
-                           )
-                           items_processed += 1
-
-                   except Exception as e:
-                       logger.error(
-                           f"Error processing item at index {start_idx + items_processed}: {e}"
-                       )
-                       items_processed += 1
-                       continue
 
-
-
-
-
+                   except Exception as e:
+                       logger.error(f"Error processing item at index {current_abs_idx}: {e}")
+                       current_abs_idx += 1
+                       batch_processed += 1
+                       items_processed += 1
+                       continue
 
        except Exception as e:
            logger.error(f"Error loading HuggingFace dataset: {e}")
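The rewritten loop trades the removed `_create_dataset_at_position` fast path for something simpler: re-open the stream once per batch and `.skip()` to the current absolute index. Note that `.skip()` on a streaming dataset still iterates past the skipped items, so reaching a deep chunk remains proportional to `start_idx`. A condensed sketch of the new control flow, under those assumptions (function name is illustrative):

```python
# Minimal sketch of the batch-restart iteration this hunk introduces.
from datasets import load_dataset  # pip install datasets

def iter_chunk(dataset_path, split, start_idx, chunk_size, batch_size=1000):
    done, abs_idx = 0, start_idx
    while done < chunk_size:
        # fresh iterator per batch, as in the diff, to avoid stateful-iterator issues
        ds = load_dataset(dataset_path, split=split, streaming=True)
        if abs_idx > 0:
            ds = ds.skip(abs_idx)
        limit = min(batch_size, chunk_size - done)
        took = 0
        for item in ds:
            if took >= limit:
                break
            yield abs_idx, item
            abs_idx += 1
            took += 1
            done += 1
        if took == 0:  # stream exhausted before the chunk filled
            break
```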
{caption_flow-0.2.1 → caption_flow-0.2.2}/src/caption_flow/utils/shard_processor.py

@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Generator, Tuple, Optional, Dict, Any
 from dataclasses import dataclass
+from datasets import load_dataset
+from .image_processor import ImageProcessor
 from threading import Event
 import shlex
 
@@ -108,10 +110,7 @@ class HFDatasetShardProcessor(ShardProcessor):
        connected: Event,
    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
        """
-       Process HuggingFace virtual shard chunk with metadata.
-
-       Yields:
-           Tuple of (key, url, image_data, metadata)
+       Process HuggingFace virtual shard chunk with metadata, range by range.
        """
        if not dataset_loader:
            logger.error("No dataset loader configured for HuggingFace dataset shard")
@@ -121,49 +120,114 @@ class HFDatasetShardProcessor(ShardProcessor):
        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
 
        logger.info(
-           f"Processing HF dataset chunk {chunk.chunk_id} with
+           f"Processing HF dataset chunk {chunk.chunk_id} with {len(unprocessed_ranges)} ranges"
        )
 
-
-       current_idx = 0
-
-       # Construct proper virtual shard URL
-       parts = chunk.shard_url.split("_chunk_")
-       if len(parts) == 2:
-           base_path = parts[0]
-           virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
-       else:
-           virtual_shard_url = chunk.shard_url
-
-       logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
+       items_yielded = 0
 
-       #
-       for
-           virtual_shard_url
-       ):
-           # Check if we should stop
+       # Process each range independently with its own iterator
+       for range_start, range_end in unprocessed_ranges:
            if should_stop.is_set() or not connected.is_set():
                logger.info(f"Stopping chunk processing early due to disconnect")
                break
 
-           #
-
-
-
-           current_idx += 1
-           continue  # Skip already processed items
+           # Calculate absolute indices for this range
+           abs_start = chunk.start_index + range_start
+           abs_end = chunk.start_index + range_end
+           range_size = range_end - range_start + 1
 
-
-
-
+           logger.debug(
+               f"Processing range [{range_start}, {range_end}] "
+               f"(absolute: [{abs_start}, {abs_end}])"
+           )
 
-
-
-
+           try:
+               # Create a fresh dataset iterator for this range
+               dataset = load_dataset(
+                   dataset_loader.dataset_path,
+                   split=dataset_loader.split,
+                   streaming=True,
+                   token=dataset_loader.token,
+               )
+
+               # Use state_dict if available for efficient positioning
+               if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
+                   try:
+                       state = dataset.state_dict()
+                       # Modify state to jump to abs_start
+                       if "num_examples_since_previous_state" in state:
+                           state["num_examples_since_previous_state"] = abs_start
+                       if "examples_iterable" in state and isinstance(
+                           state["examples_iterable"], dict
+                       ):
+                           if "shard_example_idx" in state["examples_iterable"]:
+                               state["examples_iterable"]["shard_example_idx"] = abs_start
+                       dataset.load_state_dict(state)
+                       logger.debug(f"Positioned dataset at index {abs_start} using state_dict")
+                   except Exception as e:
+                       logger.debug(f"Could not use state_dict, falling back to skip: {e}")
+                       dataset = dataset.skip(abs_start)
+               else:
+                   # Fall back to skip
+                   dataset = dataset.skip(abs_start)
+
+               # Process items in this range
+               range_items = 0
+               for item in dataset:
+                   if range_items >= range_size:
+                       break
+
+                   if should_stop.is_set() or not connected.is_set():
+                       break
+
+                   # Generate key for this item
+                   current_abs_idx = abs_start + range_items
+                   key = f"{dataset_loader.dataset_path.replace('/', '_')}_{current_abs_idx:08d}"
+
+                   try:
+                       if dataset_loader.image_column in item:
+                           img_data = item[dataset_loader.image_column]
+                           image_bytes = ImageProcessor.process_image_data(img_data)
+
+                           if image_bytes:
+                               # Extract metadata
+                               metadata = {
+                                   k: v
+                                   for k, v in item.items()
+                                   if k != dataset_loader.image_column
+                               }
+                               # Add chunk-relative index to metadata
+                               metadata["_chunk_relative_index"] = range_start + range_items
+
+                               url = f"hf://{dataset_loader.dataset_path}#{current_abs_idx}"
+
+                               items_yielded += 1
+                               range_items += 1
+
+                               yield key, url, image_bytes, metadata
+                           else:
+                               logger.warning(
+                                   f"Failed to process image at index {current_abs_idx}"
+                               )
+                               range_items += 1
+                       else:
+                           logger.warning(
+                               f"No image column '{dataset_loader.image_column}' at index {current_abs_idx}"
+                           )
+                           range_items += 1
+
+                   except Exception as e:
+                       logger.error(f"Error processing item at index {current_abs_idx}: {e}")
+                       range_items += 1
+                       continue
+
+           except Exception as e:
+               logger.error(f"Error processing range [{range_start}, {range_end}]: {e}")
+               continue
 
        logger.info(
-           f"HF dataset chunk {chunk.chunk_id}: yielded {
-           f"from
+           f"HF dataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
+           f"from {len(unprocessed_ranges)} ranges"
        )
 
 
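The state_dict manipulation above pokes at internal checkpoint keys of the `datasets` streaming iterator (`num_examples_since_previous_state`, `examples_iterable.shard_example_idx`), which is why every path falls back to a plain `.skip()`. The probe-then-fall-back pattern, condensed (the key names are library internals and may change between versions; the helper name is illustrative):

```python
# Sketch of the positioning strategy in this hunk; assumes a datasets
# release whose IterableDataset exposes state_dict()/load_state_dict().
from datasets import load_dataset

def position_stream(dataset_path, split, token, abs_start):
    ds = load_dataset(dataset_path, split=split, streaming=True, token=token)
    if hasattr(ds, "state_dict") and hasattr(ds, "load_state_dict"):
        try:
            state = ds.state_dict()
            if "num_examples_since_previous_state" in state:
                state["num_examples_since_previous_state"] = abs_start
            ex = state.get("examples_iterable")
            if isinstance(ex, dict) and "shard_example_idx" in ex:
                ex["shard_example_idx"] = abs_start
            ds.load_state_dict(state)
            return ds
        except Exception:
            pass  # fall through to the skip() path below
    return ds.skip(abs_start)
```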