caption-flow 0.2.1-py3-none-any.whl → 0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caption_flow/cli.py CHANGED
@@ -124,7 +124,7 @@ def setup_logging(verbose: bool = False):
  level = logging.DEBUG if verbose else logging.INFO
  logging.basicConfig(
  level=level,
- format="%(asctime)s %(message)s",
+ format="%(message)s",
  datefmt="[%Y-%m-%d %H:%M:%S]",
  handlers=[
  RichHandler(
caption_flow/orchestrator.py CHANGED
@@ -16,7 +16,7 @@ import uuid
  from dataclasses import dataclass, asdict
  from datetime import datetime
  from pathlib import Path
- from typing import Dict, Set, Optional, Any, List, Deque
+ from typing import Dict, Set, Optional, Any, List, Deque, Tuple
  from collections import deque, defaultdict
  import threading
  from queue import Queue, Empty
@@ -97,27 +97,9 @@ class ChunkManager:
  self.lock = threading.Lock()
  self.tracker = tracker # Reference to chunk tracker

- def create_chunks_from_shard(
- self, shard_url: str, shard_name: str, total_items: int
- ) -> List[ShardChunk]:
- """Create chunks from a shard."""
- chunks = []
-
- for start_idx in range(0, total_items, self.chunk_size):
- chunk = ShardChunk.create(
- shard_url=shard_url,
- shard_name=shard_name,
- start_index=start_idx,
- chunk_size=min(self.chunk_size, total_items - start_idx),
- )
-
- with self.lock:
- self.chunks[chunk.chunk_id] = chunk
- self.pending_chunks.append(chunk.chunk_id)
-
- chunks.append(chunk)
-
- return chunks
+ # NEW: Track assigned ranges to prevent double allocation
+ # Format: {chunk_id: {(start, end): worker_id}}
+ self.assigned_ranges: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)

  def get_chunks_for_worker(
  self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
@@ -127,7 +109,6 @@ class ChunkManager:

  with self.lock:
  # FIRST PRIORITY: Check if this worker already has assigned chunks
- # Workers should complete their current chunks before getting new ones
  if worker_id in self.assigned_chunks:
  existing_chunk_ids = list(self.assigned_chunks[worker_id])
  for chunk_id in existing_chunk_ids:
@@ -142,12 +123,29 @@ class ChunkManager:
  if tracker:
  chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
  if chunk_info and chunk_info["unprocessed_ranges"]:
- assigned.append(
- {
- "chunk": chunk,
- "unprocessed_ranges": chunk_info["unprocessed_ranges"],
- }
- )
+ # Filter out ranges that are assigned to other workers
+ clean_ranges = []
+ for start, end in chunk_info["unprocessed_ranges"]:
+ range_key = (start, end)
+ if range_key in self.assigned_ranges[chunk_id]:
+ assigned_worker = self.assigned_ranges[chunk_id][range_key]
+ if assigned_worker != worker_id:
+ # Skip this range - it's assigned to another worker
+ logger.warning(
+ f"Skipping range {start}-{end} in chunk {chunk_id} "
+ f"(assigned to {assigned_worker}, not {worker_id})"
+ )
+ continue
+ # else: this worker already owns this range, include it
+ clean_ranges.append((start, end))
+
+ if clean_ranges:
+ assigned.append(
+ {
+ "chunk": chunk,
+ "unprocessed_ranges": clean_ranges,
+ }
+ )
  else:
  # No tracker, assume chunk needs processing
  assigned.append(
@@ -158,7 +156,6 @@ class ChunkManager:
  )

  # SECOND PRIORITY: Get new pending chunks
- # Only if worker doesn't have enough chunks already
  while len(assigned) < count and self.pending_chunks:
  chunk_id = self.pending_chunks.popleft()
  chunk = self.chunks.get(chunk_id)
@@ -166,7 +163,7 @@ class ChunkManager:
  if not chunk:
  continue

- # Verify chunk is truly pending (defensive check)
+ # Verify chunk is truly pending
  if chunk.status != "pending" or chunk.assigned_to is not None:
  logger.warning(
  f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
@@ -179,15 +176,48 @@ class ChunkManager:
  chunk.assigned_at = datetime.utcnow()
  self.assigned_chunks[worker_id].add(chunk_id)

- # Get unprocessed ranges
+ # Get unprocessed ranges and filter out any that are somehow already assigned
  unprocessed_ranges = [(0, chunk.chunk_size - 1)] # Default
  if tracker:
  chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
  if chunk_info:
- unprocessed_ranges = chunk_info["unprocessed_ranges"]
+ # Filter out any ranges that are already assigned (shouldn't happen for new chunks)
+ clean_ranges = []
+ for start, end in chunk_info["unprocessed_ranges"]:
+ range_key = (start, end)
+ if range_key not in self.assigned_ranges[chunk_id]:
+ clean_ranges.append((start, end))
+ else:
+ logger.error(
+ f"Range {start}-{end} in newly assigned chunk {chunk_id} "
+ f"is already assigned to {self.assigned_ranges[chunk_id][range_key]}!"
+ )
+ unprocessed_ranges = clean_ranges if clean_ranges else []
+
  tracker.mark_assigned(chunk_id, worker_id)

- assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
+ if unprocessed_ranges:
+ assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
+
+ # Track assigned ranges and verify no double allocation
+ for info in assigned:
+ chunk_id = info["chunk"].chunk_id
+ for start, end in info["unprocessed_ranges"]:
+ range_key = (start, end)
+
+ # Check if this range is already assigned
+ if range_key in self.assigned_ranges[chunk_id]:
+ existing_worker = self.assigned_ranges[chunk_id][range_key]
+ if existing_worker != worker_id:
+ # This should never happen - raise assertion
+ raise AssertionError(
+ f"CRITICAL: Attempting to assign range {start}-{end} in chunk {chunk_id} "
+ f"to worker {worker_id}, but it's already assigned to {existing_worker}! "
+ f"This would cause duplicate processing."
+ )
+
+ # Track this assignment
+ self.assigned_ranges[chunk_id][range_key] = worker_id

  # Log what we're assigning
  if assigned:
@@ -199,6 +229,12 @@ class ChunkManager:
  )
  logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")

+ # Detailed range logging for debugging
+ for info in assigned:
+ chunk_id = info["chunk"].chunk_id
+ ranges_str = ", ".join([f"{s}-{e}" for s, e in info["unprocessed_ranges"]])
+ logger.debug(f" Chunk {chunk_id} ranges: {ranges_str}")
+
  return assigned

  def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
@@ -210,6 +246,16 @@ class ChunkManager:
  chunk.status = "completed"
  chunk.completed_at = datetime.utcnow()
  self.assigned_chunks[worker_id].discard(chunk_id)
+
+ # Clear assigned ranges for this chunk
+ if chunk_id in self.assigned_ranges:
+ # Log what ranges we're clearing
+ ranges_to_clear = list(self.assigned_ranges[chunk_id].keys())
+ logger.debug(
+ f"Clearing {len(ranges_to_clear)} assigned ranges for completed chunk {chunk_id}"
+ )
+ del self.assigned_ranges[chunk_id]
+
  return True
  return False

@@ -224,6 +270,20 @@ class ChunkManager:
  chunk.assigned_at = None
  self.assigned_chunks[worker_id].discard(chunk_id)
  self.pending_chunks.append(chunk_id)
+
+ # Clear assigned ranges for this chunk/worker
+ if chunk_id in self.assigned_ranges:
+ ranges_to_clear = [
+ range_key
+ for range_key, assigned_worker in self.assigned_ranges[chunk_id].items()
+ if assigned_worker == worker_id
+ ]
+ for range_key in ranges_to_clear:
+ del self.assigned_ranges[chunk_id][range_key]
+ logger.debug(
+ f"Cleared {len(ranges_to_clear)} assigned ranges for failed chunk {chunk_id}"
+ )
+
  return True
  return False

@@ -240,18 +300,62 @@ class ChunkManager:
  chunk.assigned_at = None
  self.pending_chunks.append(chunk_id)

+ # Clear assigned ranges for this worker
+ if chunk_id in self.assigned_ranges:
+ ranges_to_clear = [
+ range_key
+ for range_key, assigned_worker in self.assigned_ranges[
+ chunk_id
+ ].items()
+ if assigned_worker == worker_id
+ ]
+ for range_key in ranges_to_clear:
+ del self.assigned_ranges[chunk_id][range_key]
+
+ if ranges_to_clear:
+ logger.info(
+ f"Released {len(ranges_to_clear)} ranges from chunk {chunk_id} "
+ f"previously assigned to disconnected worker {worker_id}"
+ )
+
  if worker_id in self.assigned_chunks:
  del self.assigned_chunks[worker_id]

+ def mark_ranges_processed(
+ self, chunk_id: str, processed_ranges: List[Tuple[int, int]], worker_id: str
+ ):
+ """Remove ranges from assignment tracking once they're processed."""
+ with self.lock:
+ if chunk_id in self.assigned_ranges:
+ for start, end in processed_ranges:
+ range_key = (start, end)
+ if range_key in self.assigned_ranges[chunk_id]:
+ assigned_worker = self.assigned_ranges[chunk_id][range_key]
+ if assigned_worker == worker_id:
+ del self.assigned_ranges[chunk_id][range_key]
+ logger.debug(
+ f"Cleared assignment of range {start}-{end} in chunk {chunk_id} "
+ f"after processing by {worker_id}"
+ )
+ else:
+ logger.warning(
+ f"Worker {worker_id} claims to have processed range {start}-{end} "
+ f"in chunk {chunk_id}, but it was assigned to {assigned_worker}"
+ )
+
  def get_stats(self) -> Dict[str, int]:
  """Get chunk statistics."""
  with self.lock:
+ # Count total assigned ranges
+ total_assigned_ranges = sum(len(ranges) for ranges in self.assigned_ranges.values())
+
  stats = {
  "total": len(self.chunks),
  "pending": len(self.pending_chunks),
  "assigned": sum(len(chunks) for chunks in self.assigned_chunks.values()),
  "completed": sum(1 for c in self.chunks.values() if c.status == "completed"),
  "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
+ "assigned_ranges": total_assigned_ranges,
  }
  return stats

@@ -491,13 +595,15 @@ class Orchestrator:
  with self.chunk_manager.lock:
  for chunk_state in shard_info["chunks"]:
  if chunk_state.status in ["pending", "failed", "assigned"]:
- # ChunkState already has shard_url stored
+ # For assigned chunks, reset them to pending since workers don't exist
  chunk = ShardChunk(
  chunk_id=chunk_state.chunk_id,
  shard_url=chunk_state.shard_url,
  shard_name=chunk_state.shard_name,
  start_index=chunk_state.start_index,
  chunk_size=chunk_state.chunk_size,
+ status="pending", # Reset to pending
+ assigned_to=None, # Clear assignment
  )
  self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
  self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
@@ -1811,10 +1917,24 @@ class Orchestrator:
  # Don't forget the last range
  ranges.append((start, end))

- # Mark ranges as processed (mark_items_processed expects absolute indices)
+ # Mark ranges as processed
  for start_idx, end_idx in ranges:
  self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)

+ with self.chunk_manager.lock:
+ if chunk_id in self.chunk_manager.assigned_ranges:
+ for start_idx, end_idx in ranges:
+ # Clear any assignments in this range
+ to_remove = []
+ for range_start, range_end in self.chunk_manager.assigned_ranges[
+ chunk_id
+ ]:
+ if range_start >= start_idx and range_end <= end_idx:
+ to_remove.append((range_start, range_end))
+
+ for range_key in to_remove:
+ del self.chunk_manager.assigned_ranges[chunk_id][range_key]
+
  # Clear pending items
  self.pending_processed_items.clear()
  self.last_item_batch_flush = time.time()
@@ -2027,15 +2147,15 @@ class Orchestrator:
  last_known_total = current_total_outputs

  # Log rate information when workers are connected
- if (
- worker_count > 0 and self.rate_tracker["current_rate"] >= 0
- ): # Only log non-negative rates
- logger.info(
- f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
- f"(avg: {self.rate_tracker['average_rate']:.1f}, "
- f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
- f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
- )
+ # if (
+ # worker_count > 0 and self.rate_tracker["current_rate"] >= 0
+ # ): # Only log non-negative rates
+ # logger.info(
+ # f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
+ # f"(avg: {self.rate_tracker['average_rate']:.1f}, "
+ # f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
+ # f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
+ # )

  await self._broadcast_stats()

caption_flow/storage.py CHANGED
@@ -386,10 +386,15 @@ class StorageManager:

  # Filter new data to exclude duplicates
  new_rows = []
+ duplicate_rows = []
  for row in prepared_buffer:
  if row["job_id"] not in existing_job_ids:
  new_rows.append(row)
+ elif row not in duplicate_rows:
+ duplicate_rows.append(row)

+ if duplicate_rows:
+ logger.info(f"Example duplicate row: {duplicate_rows[0]}")
  if new_rows:
  # Create table from new rows only
  new_table = pa.Table.from_pylist(new_rows, schema=self.caption_schema)
caption_flow/utils/chunk_tracker.py CHANGED
@@ -441,9 +441,27 @@ class ChunkTracker(CheckpointTracker):
  )

  def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
- """Get chunk info including unprocessed ranges."""
- if chunk_id not in self.chunks:
+ """Get chunk info with unprocessed item ranges."""
+ chunk_state = self.chunks.get(chunk_id)
+ if not chunk_state:
  return None

- chunk = self.chunks[chunk_id]
- return {"chunk": chunk.to_dict(), "unprocessed_ranges": chunk.get_unprocessed_ranges()}
+ # During startup or if no worker is assigned, treat all unprocessed as available
+ if not hasattr(self, "_startup_complete"):
+ self._startup_complete = False
+
+ if not self._startup_complete or not chunk_state.assigned_to:
+ # Return all unprocessed ranges
+ return {
+ "chunk_id": chunk_id,
+ "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
+ "status": chunk_state.status,
+ }
+
+ # Normal operation - only return ranges not being worked on
+ # This would need more complex tracking of which ranges each worker is processing
+ return {
+ "chunk_id": chunk_id,
+ "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
+ "status": chunk_state.status,
+ }
caption_flow/utils/dataset_loader.py CHANGED
@@ -217,17 +217,26 @@ class DatasetLoader:
  return dataset_path, start_idx, chunk_size

  def iterate_shard(
- self, shard_url: str, processed_keys: Optional[set] = None
+ self,
+ shard_url: str,
+ processed_keys: Optional[set] = None,
+ unprocessed_ranges: Optional[List[Tuple[int, int]]] = None,
  ) -> Generator[Tuple[str, str, bytes], None, None]:
  """
  Iterate over items in a shard.

+ Args:
+ shard_url: URL or identifier of the shard
+ processed_keys: Set of already processed keys to skip
+ unprocessed_ranges: Specific ranges to process (for HF datasets)
+
  Yields:
  Tuple of (key, url, image_bytes)
  """
- # Check if this is a virtual HuggingFace dataset shard
  if shard_url.startswith("hf_dataset:"):
- yield from self._iterate_hf_dataset_shard(shard_url, processed_keys)
+ raise ValueError(
+ "Virtual HuggingFace dataset shards should use iterate_shard_with_metadata()"
+ )
  else:
  # Regular WebDataset shard
  ds = self.load_shard(shard_url, processed_keys)
@@ -296,296 +305,69 @@ class DatasetLoader:
  )

  try:
- # Try optimized approach for large skips
- if start_idx > 100:
- dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
- if dataset:
- items_processed = 0
-
- for item in dataset:
- # Stop after processing chunk_size items
- if items_processed >= chunk_size:
- break
-
- # Generate a unique key for this item
- key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
- if key in processed_keys:
- items_processed += 1
- continue
-
- try:
- # Extract image data
- if self.image_column in item:
- img_data = item[self.image_column]
-
- # Process image to bytes
- image_bytes = ImageProcessor.process_image_data(img_data)
-
- if image_bytes:
- # Extract all metadata (excluding the image column)
- metadata = {
- k: v for k, v in item.items() if k != self.image_column
- }
-
- # URL is virtual for HF datasets
- url = f"hf://{dataset_path}#{start_idx + items_processed}"
- items_processed += 1
- yield key, url, image_bytes, metadata
- else:
- logger.warning(
- f"Failed to process image for item at index {start_idx + items_processed}"
- )
- items_processed += 1
- continue
- else:
- logger.warning(
- f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
- f"Available columns: {list(item.keys())}"
- )
- items_processed += 1
-
- except Exception as e:
- logger.error(
- f"Error processing item at index {start_idx + items_processed}: {e}"
- )
- items_processed += 1
- continue
-
- return
-
- # Fall back to regular approach for small skips or if StatefulDataLoader not available
- dataset = load_dataset(
- dataset_path,
- split=self.split,
- streaming=True,
- token=self.token,
- )
-
- # Skip to start index if needed
- if start_idx > 0:
- dataset = dataset.skip(start_idx)
-
+ # For HF datasets, we iterate through the full chunk range
+ # The actual range filtering happens in the shard processor
  items_processed = 0
+ current_abs_idx = start_idx
+
+ while items_processed < chunk_size:
+ # Create a fresh dataset iterator for each batch
+ # This avoids issues with stateful iterators
+ batch_size = min(1000, chunk_size - items_processed) # Process in smaller batches
+
+ dataset = load_dataset(
+ dataset_path,
+ split=self.split,
+ streaming=True,
+ token=self.token,
+ )

- for item in dataset:
- # Stop after processing chunk_size items
- if items_processed >= chunk_size:
- break
-
- # Generate a unique key for this item
- key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
- if key in processed_keys:
- items_processed += 1
- continue
-
- try:
- # Extract image data
- if self.image_column in item:
- img_data = item[self.image_column]
+ # Skip to current position
+ if current_abs_idx > 0:
+ dataset = dataset.skip(current_abs_idx)

- # Process image to bytes
- image_bytes = ImageProcessor.process_image_data(img_data)
+ batch_processed = 0
+ for item in dataset:
+ if batch_processed >= batch_size or items_processed >= chunk_size:
+ break

- if image_bytes:
- # Extract all metadata (excluding the image column)
- metadata = {k: v for k, v in item.items() if k != self.image_column}
+ # Generate key
+ key = f"{dataset_path.replace('/', '_')}_{current_abs_idx:08d}"

- # URL is virtual for HF datasets
- url = f"hf://{dataset_path}#{start_idx + items_processed}"
- items_processed += 1
- yield key, url, image_bytes, metadata
- else:
- logger.warning(
- f"Failed to process image for item at index {start_idx + items_processed}"
- )
- items_processed += 1
- continue
- else:
- logger.warning(
- f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
- f"Available columns: {list(item.keys())}"
- )
+ if key in processed_keys:
+ current_abs_idx += 1
+ batch_processed += 1
  items_processed += 1
+ continue

- except Exception as e:
- logger.error(
- f"Error processing item at index {start_idx + items_processed}: {e}"
- )
- items_processed += 1
- continue
+ try:
+ if self.image_column in item:
+ img_data = item[self.image_column]
+ image_bytes = ImageProcessor.process_image_data(img_data)

- except Exception as e:
- logger.error(f"Error loading HuggingFace dataset: {e}")
- return
+ if image_bytes:
+ metadata = {k: v for k, v in item.items() if k != self.image_column}
+ url = f"hf://{dataset_path}#{current_abs_idx}"

- def _iterate_hf_dataset_shard(
- self, shard_url: str, processed_keys: Optional[set] = None
- ) -> Generator[Tuple[str, str, bytes], None, None]:
- """Iterate over a virtual HuggingFace dataset shard."""
- if processed_keys is None:
- processed_keys = set()
+ yield key, url, image_bytes, metadata

- dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
-
- # IMPORTANT: Check if start_idx is beyond dataset bounds
- if self._hf_total_items is not None and start_idx >= self._hf_total_items:
- logger.warning(
- f"Virtual shard starts at index {start_idx} but dataset only has "
- f"{self._hf_total_items} items. Skipping this shard."
- )
- return
-
- logger.info(
- f"Loading HuggingFace dataset in streaming mode: {dataset_path} "
- f"(split: {self.split}, start: {start_idx}, chunk_size: {chunk_size})"
- )
-
- try:
- # Try optimized approach for large skips
- if start_idx > 100:
- dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
- if dataset:
- items_processed = 0
-
- for item in dataset:
- # Stop after processing chunk_size items
- if items_processed >= chunk_size:
- logger.info(f"Completed chunk: processed {items_processed} items")
- break
-
- # Also stop if we've reached the dataset end
- if (
- self._hf_total_items
- and (start_idx + items_processed) >= self._hf_total_items
- ):
- logger.info(
- f"Reached dataset end at item {start_idx + items_processed} "
- f"(total: {self._hf_total_items})"
- )
- break
-
- # Generate a unique key for this item
- key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
- if key in processed_keys:
- items_processed += 1
- continue
-
- try:
- # Extract image data
- if self.image_column in item:
- img_data = item[self.image_column]
-
- # Delegate image processing to ImageProcessor
- image_bytes = ImageProcessor.process_image_data(img_data)
-
- if image_bytes:
- # URL is virtual for HF datasets
- url = f"hf://{dataset_path}#{start_idx + items_processed}"
- items_processed += 1
- yield key, url, image_bytes
- else:
- logger.warning(
- f"Failed to process image for item at index {start_idx + items_processed}"
- )
- items_processed += 1
- continue
- else:
- logger.warning(
- f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
- f"Available columns: {list(item.keys())}"
- )
- items_processed += 1
-
- except Exception as e:
- logger.error(
- f"Error processing item at index {start_idx + items_processed}: {e}"
- )
+ current_abs_idx += 1
+ batch_processed += 1
  items_processed += 1
- continue
-
- logger.info(
- f"Virtual shard complete: processed {items_processed} items "
- f"(start_idx: {start_idx})"
- )
- return
-
- # Fall back to regular approach for small skips or if StatefulDataLoader not available
- dataset = load_dataset(
- dataset_path,
- split=self.split,
- streaming=True,
- token=self.token,
- )
-
- # Use dataset.skip() for efficient skipping
- if start_idx > 0:
- dataset = dataset.skip(start_idx)
- logger.info(f"Skipped to index {start_idx}")
-
- items_processed = 0
-
- # Now enumerate starts from 0 after skip
- for item in dataset:
- # Stop after processing chunk_size items
- if items_processed >= chunk_size:
- logger.info(f"Completed chunk: processed {items_processed} items")
- break
-
- # Also stop if we've reached the dataset end
- if self._hf_total_items and (start_idx + items_processed) >= self._hf_total_items:
- logger.info(
- f"Reached dataset end at item {start_idx + items_processed} "
- f"(total: {self._hf_total_items})"
- )
- break
-
- # Generate a unique key for this item - ensure proper formatting
- key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
- if key in processed_keys:
- items_processed += 1
- continue
-
- try:
- # Extract image data - check configured column name
- if self.image_column in item:
- img_data = item[self.image_column]
-
- # Delegate image processing to ImageProcessor
- image_bytes = ImageProcessor.process_image_data(img_data)
-
- if image_bytes:
- # URL is virtual for HF datasets
- url = f"hf://{dataset_path}#{start_idx + items_processed}"
- items_processed += 1
- yield key, url, image_bytes
  else:
  logger.warning(
- f"Failed to process image for item at index {start_idx + items_processed}"
+ f"No image column '{self.image_column}' at index {current_abs_idx}"
  )
+ current_abs_idx += 1
+ batch_processed += 1
  items_processed += 1
- continue
- else:
- logger.warning(
- f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
- f"Available columns: {list(item.keys())}"
- )
- items_processed += 1
-
- except Exception as e:
- logger.error(
- f"Error processing item at index {start_idx + items_processed}: {e}"
- )
- items_processed += 1
- continue

- logger.info(
- f"Virtual shard complete: processed {items_processed} items "
- f"(start_idx: {start_idx})"
- )
+ except Exception as e:
+ logger.error(f"Error processing item at index {current_abs_idx}: {e}")
+ current_abs_idx += 1
+ batch_processed += 1
+ items_processed += 1
+ continue

  except Exception as e:
  logger.error(f"Error loading HuggingFace dataset: {e}")
caption_flow/utils/shard_processor.py CHANGED
@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
  from pathlib import Path
  from typing import Generator, Tuple, Optional, Dict, Any
  from dataclasses import dataclass
+ from datasets import load_dataset
+ from .image_processor import ImageProcessor
  from threading import Event
  import shlex

@@ -108,10 +110,7 @@ class HFDatasetShardProcessor(ShardProcessor):
  connected: Event,
  ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
  """
- Process HuggingFace virtual shard chunk with metadata.
-
- Yields:
- Tuple of (key, url, image_data, metadata)
+ Process HuggingFace virtual shard chunk with metadata, range by range.
  """
  if not dataset_loader:
  logger.error("No dataset loader configured for HuggingFace dataset shard")
@@ -121,49 +120,114 @@ class HFDatasetShardProcessor(ShardProcessor):
  unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])

  logger.info(
- f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+ f"Processing HF dataset chunk {chunk.chunk_id} with {len(unprocessed_ranges)} ranges"
  )

- items_processed = 0
- current_idx = 0
-
- # Construct proper virtual shard URL
- parts = chunk.shard_url.split("_chunk_")
- if len(parts) == 2:
- base_path = parts[0]
- virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
- else:
- virtual_shard_url = chunk.shard_url
-
- logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
+ items_yielded = 0

- # Use the new iterate method that includes metadata
- for key, url, image_data, metadata in dataset_loader.iterate_shard_with_metadata(
- virtual_shard_url
- ):
- # Check if we should stop
+ # Process each range independently with its own iterator
+ for range_start, range_end in unprocessed_ranges:
  if should_stop.is_set() or not connected.is_set():
  logger.info(f"Stopping chunk processing early due to disconnect")
  break

- # Check if current index is in any unprocessed range
- in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
-
- if not in_range:
- current_idx += 1
- continue # Skip already processed items
+ # Calculate absolute indices for this range
+ abs_start = chunk.start_index + range_start
+ abs_end = chunk.start_index + range_end
+ range_size = range_end - range_start + 1

- # Check if we've processed enough for this chunk
- if current_idx >= chunk.chunk_size:
- break
+ logger.debug(
+ f"Processing range [{range_start}, {range_end}] "
+ f"(absolute: [{abs_start}, {abs_end}])"
+ )

- items_processed += 1
- current_idx += 1
- yield key, url, image_data, metadata
+ try:
+ # Create a fresh dataset iterator for this range
+ dataset = load_dataset(
+ dataset_loader.dataset_path,
+ split=dataset_loader.split,
+ streaming=True,
+ token=dataset_loader.token,
+ )
+
+ # Use state_dict if available for efficient positioning
+ if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
+ try:
+ state = dataset.state_dict()
+ # Modify state to jump to abs_start
+ if "num_examples_since_previous_state" in state:
+ state["num_examples_since_previous_state"] = abs_start
+ if "examples_iterable" in state and isinstance(
+ state["examples_iterable"], dict
+ ):
+ if "shard_example_idx" in state["examples_iterable"]:
+ state["examples_iterable"]["shard_example_idx"] = abs_start
+ dataset.load_state_dict(state)
+ logger.debug(f"Positioned dataset at index {abs_start} using state_dict")
+ except Exception as e:
+ logger.debug(f"Could not use state_dict, falling back to skip: {e}")
+ dataset = dataset.skip(abs_start)
+ else:
+ # Fall back to skip
+ dataset = dataset.skip(abs_start)
+
+ # Process items in this range
+ range_items = 0
+ for item in dataset:
+ if range_items >= range_size:
+ break
+
+ if should_stop.is_set() or not connected.is_set():
+ break
+
+ # Generate key for this item
+ current_abs_idx = abs_start + range_items
+ key = f"{dataset_loader.dataset_path.replace('/', '_')}_{current_abs_idx:08d}"
+
+ try:
+ if dataset_loader.image_column in item:
+ img_data = item[dataset_loader.image_column]
+ image_bytes = ImageProcessor.process_image_data(img_data)
+
+ if image_bytes:
+ # Extract metadata
+ metadata = {
+ k: v
+ for k, v in item.items()
+ if k != dataset_loader.image_column
+ }
+ # Add chunk-relative index to metadata
+ metadata["_chunk_relative_index"] = range_start + range_items
+
+ url = f"hf://{dataset_loader.dataset_path}#{current_abs_idx}"
+
+ items_yielded += 1
+ range_items += 1
+
+ yield key, url, image_bytes, metadata
+ else:
+ logger.warning(
+ f"Failed to process image at index {current_abs_idx}"
+ )
+ range_items += 1
+ else:
+ logger.warning(
+ f"No image column '{dataset_loader.image_column}' at index {current_abs_idx}"
+ )
+ range_items += 1
+
+ except Exception as e:
+ logger.error(f"Error processing item at index {current_abs_idx}: {e}")
+ range_items += 1
+ continue
+
+ except Exception as e:
+ logger.error(f"Error processing range [{range_start}, {range_end}]: {e}")
+ continue

  logger.info(
- f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
- f"from ranges {unprocessed_ranges}"
+ f"HF dataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
+ f"from {len(unprocessed_ranges)} ranges"
  )

caption_flow-0.2.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: caption-flow
- Version: 0.2.1
+ Version: 0.2.2
  Summary: Self-contained distributed community captioning system
  Author-email: bghira <bghira@users.github.com>
  License: MIT
caption_flow-0.2.2.dist-info/RECORD CHANGED
@@ -1,29 +1,29 @@
  caption_flow/__init__.py,sha256=NLPJ25lRN7xHqncXweINDNwbt0q8lgjZ30G21zlPdRs,303
- caption_flow/cli.py,sha256=bHxx66CPsCmSieaH3pw8NZBojIIbniRTdU9mEBHMmWA,28832
+ caption_flow/cli.py,sha256=fkyQHzs5kei6-9ftkbJjko-K67TARxd7yNf7x9e7KSs,28820
  caption_flow/models.py,sha256=qo6lQiO10UISbaBVr6Cs-fSW_pmjwE6kmiTmmU_l3Wk,2140
  caption_flow/monitor.py,sha256=ZZCSasYLKJ-UzA3-RoAtytv-tbNA-m3h5YjlZg_vukg,7870
- caption_flow/orchestrator.py,sha256=bZ8NnGdqoXSmu7Nq-_7cOSH1DLHkBT88cne0uDyPeNY,89112
- caption_flow/storage.py,sha256=hC6ZHT_PHFoUVjqD5JUwy3_79oAD1e1H30neA_xsz7s,40748
+ caption_flow/orchestrator.py,sha256=9yWKVcaR-S6naNQSd7Np8AemwV5lNDmB_lCufpvVrS0,96282
+ caption_flow/storage.py,sha256=kGv9iQAgxwLLlAIPU6TBrlagdfxA339eBz1xG0yYRsc,40981
  caption_flow/utils/__init__.py,sha256=F1BChVoCsj9zn1GJRBOLHET1kLW6xrAmsbzcR7hHy6Y,202
  caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
  caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
  caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
  caption_flow/utils/checkpoint_tracker.py,sha256=8tsTFF-HcygitK92YcS-QWzeg-qRm9AuCpQoQRfC8M0,3335
- caption_flow/utils/chunk_tracker.py,sha256=hKn8CN6ubErc9kuCWZMj12ZCZKxVlqXqAEocbzjfa-k,17296
- caption_flow/utils/dataset_loader.py,sha256=ZplJv655ZMyUbaZC4BBiL5II18sBy4JSJhxGZtK_VmA,29107
+ caption_flow/utils/chunk_tracker.py,sha256=SO6ERvEwGXuikGDVaXFota_3Ix8BnePMU7CiZJKBAnQ,18025
+ caption_flow/utils/dataset_loader.py,sha256=Bvo-aa5jWtjzqXW0rEisdiWaN7Q-aH02rXXUu9uXqGo,19194
  caption_flow/utils/image_processor.py,sha256=Zl8TAv9gYPdAYat3UiTuuNdIb2fXNfZ35AxsxuovJTs,5650
  caption_flow/utils/job_queue.py,sha256=itdfXcrkvGjmXn4qtpgMF63k1ufRBaejDe4V6WcxzgU,1104
  caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
  caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
- caption_flow/utils/shard_processor.py,sha256=CRda6M4xh4U0vwvYlzq9nJEzz4d_4yzUBosYAeBcPEA,10854
+ caption_flow/utils/shard_processor.py,sha256=c6COBKhFzZyUeJqot5uGVR3ANeOReBfs8-DR27mrdcA,14242
  caption_flow/utils/shard_tracker.py,sha256=Wt2oE-O85F2FxSnqIocJiaYeFn00OVVjIiklZIZRGL8,3233
  caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
  caption_flow/workers/base.py,sha256=jPm_Xw4Lxd0cnrPs-biBqKRQKkTOJLvHLolmp0Gb1CI,7530
  caption_flow/workers/caption.py,sha256=NZ9kTjk2uOoNwyyNSkB_arYk213vLr5mowHN-OjiFkk,54631
  caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
- caption_flow-0.2.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
- caption_flow-0.2.1.dist-info/METADATA,sha256=fxNfSOqkCklb96aq3ZFU7SvRuXEBUQ11xbjkQn7Yzuo,11941
- caption_flow-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- caption_flow-0.2.1.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
- caption_flow-0.2.1.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
- caption_flow-0.2.1.dist-info/RECORD,,
+ caption_flow-0.2.2.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+ caption_flow-0.2.2.dist-info/METADATA,sha256=h9VN2ZWXVDH935Eavb-1kfsBpuW7m4Oph3tjh9ucc3w,11941
+ caption_flow-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ caption_flow-0.2.2.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
+ caption_flow-0.2.2.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
+ caption_flow-0.2.2.dist-info/RECORD,,
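
The core behavioral change in 0.2.2 is the orchestrator's new per-chunk map of item ranges to the worker that owns them ({chunk_id: {(start, end): worker_id}}), which every assignment and release path now consults. The standalone sketch below is illustrative only and not part of the package (the name RangeLedger is made up); it assumes inclusive (start, end) tuples with a single owner per range, mirroring the bookkeeping the diff adds to ChunkManager.

from collections import defaultdict
from typing import Dict, List, Tuple

class RangeLedger:
    """Toy model of the 0.2.2 range-assignment bookkeeping (illustrative only)."""

    def __init__(self) -> None:
        # {chunk_id: {(start, end): worker_id}}
        self.assigned_ranges: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)

    def claim(self, chunk_id: str, ranges: List[Tuple[int, int]], worker_id: str) -> List[Tuple[int, int]]:
        """Return only the ranges this worker may process, recording ownership."""
        granted: List[Tuple[int, int]] = []
        for key in ranges:
            owner = self.assigned_ranges[chunk_id].get(key)
            if owner is None or owner == worker_id:
                self.assigned_ranges[chunk_id][key] = worker_id
                granted.append(key)
            # A range owned by a different worker is skipped: no double allocation.
        return granted

    def release_worker(self, chunk_id: str, worker_id: str) -> None:
        """Free every range a failed or disconnected worker held in a chunk."""
        held = [k for k, w in self.assigned_ranges[chunk_id].items() if w == worker_id]
        for key in held:
            del self.assigned_ranges[chunk_id][key]

    def complete_chunk(self, chunk_id: str) -> None:
        """Drop all tracking once a chunk is fully processed."""
        self.assigned_ranges.pop(chunk_id, None)

if __name__ == "__main__":
    ledger = RangeLedger()
    print(ledger.claim("c0", [(0, 99), (100, 199)], "worker-a"))     # both granted
    print(ledger.claim("c0", [(100, 199), (200, 299)], "worker-b"))  # only (200, 299)
    ledger.release_worker("c0", "worker-a")
    print(ledger.claim("c0", [(0, 99)], "worker-b"))                 # now grantable

Ranges held by another worker are skipped at assignment time, and the ledger entry for a chunk is dropped when the chunk completes or its worker disconnects, which is the same lifecycle the completion and worker-release paths implement in the orchestrator diff above.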