caption-flow 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +934 -415
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +2 -3
  5. caption_flow/orchestrator.py +153 -104
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +439 -67
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +28 -22
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +30 -28
  18. caption_flow/utils/chunk_tracker.py +153 -56
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +5 -4
  25. caption_flow/workers/caption.py +265 -90
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
  28. caption_flow-0.4.0.dist-info/RECORD +33 -0
  29. caption_flow-0.3.4.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,36 @@
1
1
  """HuggingFace Datasets processor implementation - Memory Optimized Version."""
2
2
 
3
- import logging
4
- import threading
5
- import re
6
- import queue
7
- import requests
8
- import json
3
+ import gc
9
4
  import io
5
+ import json
6
+ import logging
10
7
  import os
11
- import gc
12
- import psutil
13
- from concurrent.futures import ThreadPoolExecutor, Future
14
- from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
15
- from collections import deque, defaultdict
16
- from pathlib import Path
8
+ import queue
9
+ import re
10
+ import threading
11
+ import time
12
+ from collections import defaultdict, deque
13
+ from concurrent.futures import Future, ThreadPoolExecutor
17
14
  from datetime import datetime
18
- from PIL import Image
19
- import pyarrow as pa
15
+ from pathlib import Path
16
+ from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Tuple
17
+
18
+ import psutil
20
19
  import pyarrow.parquet as pq
20
+ import requests
21
21
  from datasets import get_dataset_config_names, get_dataset_split_names
22
- from huggingface_hub import hf_hub_download, get_token
22
+ from huggingface_hub import get_token, hf_hub_download
23
+ from PIL import Image
24
+ from tqdm import tqdm
25
+
23
26
  from caption_flow.storage import StorageManager
24
27
 
25
- from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
26
- from ..utils import ChunkTracker
27
28
  from ..models import JobId
29
+ from ..utils import ChunkTracker
30
+ from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit
28
31
 
29
32
  logger = logging.getLogger(__name__)
30
- logger.setLevel(logging.DEBUG)
33
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
31
34
 
32
35
 
33
36
  def log_memory(location: str):
@@ -41,10 +44,6 @@ def log_memory(location: str):
41
44
  gc.collect()
42
45
 
43
46
 
44
- logger = logging.getLogger(__name__)
45
- logger.setLevel(logging.DEBUG)
46
-
47
-
48
47
  class NonBlockingQueueHandler:
49
48
  """Handles non-blocking retrieval from queues using concurrent futures."""
50
49
 
@@ -146,6 +145,9 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
146
145
 
147
146
  cfg = config.config
148
147
 
148
+ # Store storage reference for chunk state synchronization
149
+ self.storage = storage
150
+
149
151
  # Dataset configuration
150
152
  dataset_cfg = cfg.get("dataset", {})
151
153
  self.dataset_name = dataset_cfg.get("dataset_path")
@@ -340,6 +342,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
340
342
 
341
343
  # Cache shard info
342
344
  try:
345
+ # make dir if it doesn't exist already
346
+ shard_info_cache_path.parent.mkdir(parents=True, exist_ok=True)
343
347
  cache_data = {
344
348
  "dataset": self.dataset_name,
345
349
  "config": self.config,
@@ -367,9 +371,18 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
367
371
  raise ValueError(f"Global index {global_index} not found in any shard")
368
372
 
369
373
  def _restore_state(self, storage: StorageManager) -> None:
370
- """Restore state from chunk tracker."""
371
- logger.debug("Restoring state from chunk tracker")
374
+ """Restore state from chunk tracker and synchronize with storage."""
375
+ logger.debug("Restoring state from chunk tracker and synchronizing with storage")
376
+
377
+ # FIRST: Update chunk tracker from storage (like WebDataset does)
378
+ if storage:
379
+ processed_job_ids = storage.get_all_processed_job_ids()
380
+ if processed_job_ids:
381
+ self.update_from_storage(processed_job_ids)
382
+
383
+ # THEN: Restore work units from chunk tracker
372
384
  if not self.chunk_tracker:
385
+ logger.warning("No chunk tracker available for state restoration")
373
386
  return
374
387
 
375
388
  with self.lock:
@@ -382,14 +395,120 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
382
395
  # Only add incomplete chunks to pending
383
396
  if chunk_state.status != "completed":
384
397
  self.pending_units.append(chunk_id)
385
- elif chunk_state.status == "completed" and chunk_state.processed_ranges:
386
- logger.warning(
387
- f"Chunk {chunk_id} has processed_ranges stored in the checkpoint."
388
- )
389
398
 
390
399
  self.current_chunk_index = max_chunk_index + 1
391
400
  logger.info(f"Resuming from chunk index {self.current_chunk_index}")
392
401
 
402
+ # Flush checkpoint after major update
403
+ self.chunk_tracker.flush()
404
+
405
+ def _create_work_units_from_chunk(self, chunk_index: int) -> List[WorkUnit]:
406
+ """Create one or more work units from a chunk, splitting large gaps in unprocessed ranges."""
407
+ units = []
408
+ base_unit = self._create_work_unit(chunk_index)
409
+
410
+ if not base_unit:
411
+ return []
412
+
413
+ # Check if we should split this into multiple work units based on gaps
414
+ unprocessed_ranges = base_unit.data["unprocessed_ranges"]
415
+
416
+ if len(unprocessed_ranges) <= 1:
417
+ # Single range or no ranges, return as-is
418
+ return [base_unit]
419
+
420
+ # Check for large gaps between ranges that suggest we should split
421
+ total_span = unprocessed_ranges[-1][1] - unprocessed_ranges[0][0] + 1
422
+ total_work = sum(end - start + 1 for start, end in unprocessed_ranges)
423
+ gap_ratio = (total_span - total_work) / total_span if total_span > 0 else 0
424
+
425
+ # If gaps are more than 50% of the span, split into separate work units
426
+ if gap_ratio > 0.5 and len(unprocessed_ranges) > 1:
427
+ logger.debug(
428
+ f"Splitting chunk {chunk_index} with {len(unprocessed_ranges)} ranges (gap ratio: {gap_ratio:.1%})"
429
+ )
430
+
431
+ # Create separate work units for each contiguous range group
432
+ current_group = []
433
+
434
+ for i, (start, end) in enumerate(unprocessed_ranges):
435
+ if not current_group:
436
+ current_group = [(start, end)]
437
+ else:
438
+ # Check gap to previous range
439
+ prev_end = current_group[-1][1]
440
+ gap_size = start - prev_end - 1
441
+
442
+ # If gap is large (>100 items), start new group
443
+ if gap_size > 100:
444
+ # Create work unit for current group
445
+ units.append(self._create_range_work_unit(chunk_index, current_group))
446
+ current_group = [(start, end)]
447
+ else:
448
+ # Add to current group
449
+ current_group.append((start, end))
450
+
451
+ # Don't forget the last group
452
+ if current_group:
453
+ units.append(self._create_range_work_unit(chunk_index, current_group))
454
+
455
+ return [unit for unit in units if unit is not None]
456
+ else:
457
+ # Keep as single unit
458
+ return [base_unit]
459
+
460
+ def _create_range_work_unit(
461
+ self, chunk_index: int, ranges: List[Tuple[int, int]]
462
+ ) -> Optional[WorkUnit]:
463
+ """Create a work unit for specific ranges within a chunk."""
464
+ if not ranges:
465
+ return None
466
+
467
+ current_index = chunk_index * self.chunk_size
468
+ chunk_size = min(self.chunk_size, self.total_items - current_index)
469
+
470
+ # Find shard for this chunk
471
+ shard_id, local_idx = self._get_shard_for_index(current_index)
472
+ shard_name = Path(self.shard_info[shard_id]["filename"]).stem
473
+
474
+ # Create unique unit ID that includes range info
475
+ range_suffix = f"r{len(ranges)}" # r2 = 2 ranges, etc.
476
+ job_id_obj = JobId(
477
+ shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(current_index)
478
+ )
479
+ base_unit_id = job_id_obj.get_chunk_str()
480
+ unit_id = f"{base_unit_id}_{range_suffix}"
481
+
482
+ unprocessed_items = sum(end - start + 1 for start, end in ranges)
483
+
484
+ unit = WorkUnit(
485
+ unit_id=unit_id,
486
+ chunk_id=base_unit_id, # Keep original chunk_id for tracking
487
+ source_id=shard_name,
488
+ unit_size=unprocessed_items,
489
+ data={
490
+ "dataset_name": self.dataset_name,
491
+ "config": self.config,
492
+ "split": self.split,
493
+ "start_index": current_index,
494
+ "chunk_size": chunk_size,
495
+ "actual_work_size": unprocessed_items,
496
+ "unprocessed_ranges": ranges,
497
+ "range_based": True,
498
+ "is_split_unit": True, # Flag to indicate this is a split from larger chunk
499
+ "shard_ids": [shard_id],
500
+ "data_files": self.data_files,
501
+ },
502
+ metadata={
503
+ "dataset": self.dataset_name,
504
+ "shard_name": shard_name,
505
+ "chunk_index": chunk_index,
506
+ "range_count": len(ranges),
507
+ },
508
+ )
509
+
510
+ return unit
511
+
393
512
  def _create_work_unit(self, chunk_index: int) -> Optional[WorkUnit]:
394
513
  """Create a single work unit for a chunk index."""
395
514
  current_index = chunk_index * self.chunk_size
@@ -400,39 +519,65 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
400
519
  chunk_size = min(self.chunk_size, self.total_items - current_index)
401
520
 
402
521
  # Find shard for this chunk
403
- shard_id, _ = self._get_shard_for_index(current_index)
522
+ shard_id, local_idx = self._get_shard_for_index(current_index)
404
523
  shard_name = Path(self.shard_info[shard_id]["filename"]).stem
405
524
 
406
- job_id_obj = JobId(shard_id=shard_name, chunk_id=chunk_index, sample_id=current_index)
525
+ # Calculate RELATIVE chunk index within the shard
526
+ job_id_obj = JobId(
527
+ shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(current_index)
528
+ )
407
529
  unit_id = job_id_obj.get_chunk_str()
408
530
 
409
531
  # Calculate unprocessed ranges based on existing chunk state
410
532
  unprocessed_ranges = [(current_index, current_index + chunk_size - 1)]
411
-
412
533
  if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
413
534
  chunk_state = self.chunk_tracker.chunks[unit_id]
414
535
  if chunk_state.processed_ranges:
536
+ # Convert relative processed ranges to absolute ranges
537
+ absolute_processed_ranges = [
538
+ (start + current_index, end + current_index)
539
+ for start, end in chunk_state.processed_ranges
540
+ ]
415
541
  # Subtract processed ranges from total range
542
+ range_to_subtract = (current_index, current_index + chunk_size - 1)
543
+ logger.debug(
544
+ f"Chunk {unit_id} has processed ranges: {chunk_state.processed_ranges} (relative), {absolute_processed_ranges} (absolute)"
545
+ )
416
546
  unprocessed_ranges = self._subtract_ranges(
417
- [(current_index, current_index + chunk_size - 1)], chunk_state.processed_ranges
547
+ [range_to_subtract], absolute_processed_ranges
418
548
  )
419
549
 
420
550
  # If all ranges are processed, return None (shouldn't happen if status tracking is correct)
421
551
  if not unprocessed_ranges:
552
+ logger.debug(f"Chunk {unit_id} has no unprocessed ranges, skipping")
553
+ return None
554
+
555
+ # Calculate actual unprocessed items and total work to be assigned
556
+ unprocessed_items = sum(end - start + 1 for start, end in unprocessed_ranges)
557
+
558
+ # Skip assignment if there are very few unprocessed items (< 10 items)
559
+ if unprocessed_items < 10:
560
+ logger.debug(
561
+ f"Chunk {unit_id} has only {unprocessed_items} unprocessed items, skipping assignment"
562
+ )
422
563
  return None
423
564
 
565
+ # Create work unit that represents ONLY the unprocessed ranges
566
+ # This is the key fix: don't assign the full chunk, assign only unprocessed parts
424
567
  unit = WorkUnit(
425
568
  unit_id=unit_id,
426
569
  chunk_id=unit_id,
427
570
  source_id=shard_name,
428
- unit_size=chunk_size,
571
+ unit_size=unprocessed_items, # Only the unprocessed items
429
572
  data={
430
573
  "dataset_name": self.dataset_name,
431
574
  "config": self.config,
432
575
  "split": self.split,
433
- "start_index": current_index,
434
- "chunk_size": chunk_size,
435
- "unprocessed_ranges": unprocessed_ranges, # Use calculated ranges
576
+ "start_index": current_index, # Keep original chunk start for reference
577
+ "chunk_size": chunk_size, # Keep original chunk size for reference
578
+ "actual_work_size": unprocessed_items, # NEW: actual work to be done
579
+ "unprocessed_ranges": unprocessed_ranges, # The specific ranges to process
580
+ "range_based": True, # NEW: flag to indicate this is range-based
436
581
  "shard_ids": [shard_id],
437
582
  "data_files": self.data_files,
438
583
  },
@@ -455,23 +600,35 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
455
600
  assigned_count = sum(len(units) for units in self.assigned_units.values())
456
601
  worker_count = max(1, len(self.assigned_units))
457
602
 
603
+ # Check if all data has been processed
604
+ if self.current_chunk_index * self.chunk_size >= self.total_items:
605
+ # All chunks processed - exit the background thread
606
+ logger.debug("All chunks processed, exiting background thread")
607
+ break
608
+
458
609
  target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
459
610
  units_needed = max(0, target_buffer - (pending_count + assigned_count))
460
611
 
461
612
  if units_needed == 0:
462
- threading.Event().wait(5)
613
+ self.stop_creation.wait(5)
463
614
  continue
464
615
 
465
616
  # Create units as needed
466
617
  units_created = 0
467
618
 
619
+ # Progress bar
620
+ progress_bar = tqdm(total=units_needed, desc="Creating work units", unit="unit")
621
+
468
622
  while units_created < units_needed:
469
- logger.debug(f"Creating work unit for chunk {self.current_chunk_index}")
623
+ # logger.debug(f"Creating work unit for chunk {self.current_chunk_index}")
470
624
  if self.current_chunk_index * self.chunk_size >= self.total_items:
471
- threading.Event().wait(30)
625
+ # No more data available - exit immediately instead of waiting
626
+ logger.debug(
627
+ f"All chunks processed (chunk_index={self.current_chunk_index}, total_items={self.total_items})"
628
+ )
472
629
  break
473
630
  # Get shard info for proper unit_id
474
- current_index = self.current_chunk_index * self.chunk_size
631
+ current_index = self.current_chunk_index
475
632
  if current_index < self.total_items:
476
633
  shard_id, _ = self._get_shard_for_index(current_index)
477
634
  shard_name = Path(self.shard_info[shard_id]["filename"]).stem
@@ -509,14 +666,14 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
509
666
  units_created += 1
510
667
  self.current_chunk_index += 1
511
668
 
669
+ progress_bar.update(1)
512
670
  if units_created > 0:
513
671
  logger.debug(f"Created {units_created} work unit IDs")
514
672
 
515
673
  logger.info("Thread for creating units has completed. Exiting thread.")
516
674
 
517
675
  def process_responses_non_blocking(self, response_queue: queue.Queue) -> Optional[WorkResult]:
518
- """
519
- Non-blocking method to process responses from workers.
676
+ """Non-blocking method to process responses from workers.
520
677
  Returns a WorkResult if one is available, None otherwise.
521
678
  """
522
679
  # Check for response without blocking
@@ -551,33 +708,64 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
551
708
 
552
709
  # Force checkpoint save if needed
553
710
  if self.chunk_tracker:
554
- self.chunk_tracker.save()
711
+ # Flush statistics updates immediately
712
+ self.chunk_tracker.flush()
555
713
 
556
714
  def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
557
715
  """Get available work units for a worker."""
558
-
559
716
  logger.debug(
560
717
  "get_work_units called: count=%d worker_id=%s, pending: %d",
561
718
  count,
562
719
  worker_id,
563
720
  len(self.pending_units),
564
721
  )
722
+
723
+ # Periodically sync with storage to ensure we don't assign already-completed work
724
+ # This is especially important when workers reconnect
725
+ if hasattr(self, "storage") and self.storage and len(self.pending_units) > 0:
726
+ try:
727
+ processed_job_ids = self.storage.get_all_processed_job_ids()
728
+ if processed_job_ids:
729
+ logger.debug(
730
+ f"Syncing chunk tracker with {len(processed_job_ids)} processed items before assignment"
731
+ )
732
+ self.update_from_storage(processed_job_ids)
733
+ # Flush after storage sync to ensure consistency
734
+ self.chunk_tracker.flush()
735
+ except Exception as e:
736
+ logger.warning(f"Failed to sync with storage during work assignment: {e}")
737
+
565
738
  assigned = []
566
739
  with self.lock:
567
740
  while len(assigned) < count and self.pending_units:
568
741
  unit_id = self.pending_units.popleft()
569
742
 
570
- # Create work unit on demand
743
+ # Create work units on demand (may create multiple units from one chunk)
571
744
  chunk_index = int(unit_id.split(":")[-1])
572
- unit = self._create_work_unit(chunk_index)
745
+ units = self._create_work_units_from_chunk(chunk_index)
573
746
 
574
- if unit:
575
- self.assigned_units[worker_id].add(unit_id)
747
+ for unit in units:
748
+ if len(assigned) >= count:
749
+ # Put remaining units back in queue for next worker
750
+ self.pending_units.appendleft(
751
+ f"{unit.metadata['shard_name']}:chunk:{chunk_index}"
752
+ )
753
+ break
754
+
755
+ # Use the unit's actual unit_id for tracking
756
+ actual_unit_id = unit.unit_id
757
+ self.assigned_units[worker_id].add(actual_unit_id)
576
758
  assigned.append(unit)
577
- logger.debug("Assigning unit %s to worker %s", unit_id, worker_id)
759
+ logger.debug(
760
+ "Assigning unit %s (%d items) to worker %s",
761
+ actual_unit_id,
762
+ unit.data.get("actual_work_size", unit.unit_size),
763
+ worker_id,
764
+ )
578
765
 
579
766
  if self.chunk_tracker:
580
- self.chunk_tracker.mark_assigned(unit_id, worker_id)
767
+ # Track assignment using the base chunk_id for chunk tracker compatibility
768
+ self.chunk_tracker.mark_assigned(unit.chunk_id, worker_id)
581
769
 
582
770
  logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
583
771
  return assigned
@@ -610,12 +798,52 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
610
798
  def release_assignments(self, worker_id: str) -> None:
611
799
  """Release all assignments for a disconnected worker."""
612
800
  logger.debug("Releasing assignments for worker %s", worker_id)
801
+
802
+ # FIRST: Sync with storage to ensure chunk tracker is up-to-date
803
+ # This prevents reassigning work that was already completed by this or other workers
804
+ if hasattr(self, "storage") and self.storage:
805
+ try:
806
+ processed_job_ids = self.storage.get_all_processed_job_ids()
807
+ if processed_job_ids:
808
+ logger.info(
809
+ f"Syncing chunk tracker with {len(processed_job_ids)} processed items before releasing assignments"
810
+ )
811
+ self.update_from_storage(processed_job_ids)
812
+ # Flush after storage sync to ensure consistency
813
+ self.chunk_tracker.flush()
814
+ except Exception as e:
815
+ logger.warning(f"Failed to sync with storage before releasing assignments: {e}")
816
+
613
817
  with self.lock:
614
818
  unit_ids = list(self.assigned_units.get(worker_id, []))
615
819
 
616
820
  for unit_id in unit_ids:
617
- logger.debug(f"Adding {unit_id} to pending queue")
618
- self.pending_units.append(unit_id)
821
+ # Check if this chunk is already fully processed before re-queuing
822
+ should_requeue = True
823
+ if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
824
+ chunk_state = self.chunk_tracker.chunks[unit_id]
825
+ if chunk_state.status == "completed":
826
+ logger.info(f"Not re-queuing completed chunk {unit_id}")
827
+ should_requeue = False
828
+ elif chunk_state.processed_ranges:
829
+ # Check if chunk is mostly complete (>90% processed)
830
+ total_items = chunk_state.chunk_size
831
+ processed_items = sum(
832
+ end - start + 1 for start, end in chunk_state.processed_ranges
833
+ )
834
+ completion_rate = processed_items / total_items if total_items > 0 else 0
835
+
836
+ if completion_rate > 0.9:
837
+ logger.info(
838
+ f"Not re-queuing nearly complete chunk {unit_id} ({completion_rate:.1%} done)"
839
+ )
840
+ should_requeue = False
841
+
842
+ if should_requeue:
843
+ logger.debug(f"Adding {unit_id} to pending queue")
844
+ self.pending_units.append(unit_id)
845
+ else:
846
+ logger.debug(f"Skipping re-queue of {unit_id} (already processed)")
619
847
 
620
848
  if worker_id in self.assigned_units:
621
849
  del self.assigned_units[worker_id]
@@ -624,9 +852,87 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
624
852
  self.chunk_tracker.release_worker_chunks(worker_id)
625
853
 
626
854
  def update_from_storage(self, processed_job_ids: Set[str]) -> None:
627
- """Update work units based on what's been processed."""
855
+ """Update chunk tracker based on what's been processed in storage."""
628
856
  logger.info(f"Updating from storage with {len(processed_job_ids)} processed jobs")
629
- # No need to update in-memory work units since we create on demand
857
+
858
+ if not self.chunk_tracker:
859
+ return
860
+
861
+ # Group by chunk
862
+ processed_by_chunk = defaultdict(set)
863
+
864
+ for job_id_str in processed_job_ids:
865
+ try:
866
+ # Parse job ID to get chunk and sample index
867
+ job_id = JobId.from_str(job_id_str)
868
+ chunk_id = job_id.get_chunk_str()
869
+ sample_idx = int(job_id.sample_id)
870
+ processed_by_chunk[chunk_id].add(sample_idx)
871
+ except ValueError as e:
872
+ logger.warning(f"Invalid job ID format: {job_id_str} - {e}")
873
+ continue
874
+
875
+ # Update chunk tracker with processed items
876
+ for chunk_id, indices in processed_by_chunk.items():
877
+ if not indices:
878
+ continue
879
+
880
+ # Get or create chunk state
881
+ chunk_state = self.chunk_tracker.chunks.get(chunk_id)
882
+
883
+ if not chunk_state:
884
+ # Parse chunk_id using JobId to get info (reuse existing validation)
885
+ try:
886
+ # Reconstruct a valid job_id to parse chunk info
887
+ sample_job_id = f"{chunk_id}:idx:0" # Use dummy sample_id
888
+ job_id = JobId.from_str(sample_job_id)
889
+
890
+ shard_name = job_id.shard_id
891
+ chunk_idx = int(job_id.chunk_id)
892
+ start_index = chunk_idx * self.chunk_size
893
+
894
+ # Add chunk to tracker
895
+ self.chunk_tracker.add_chunk(
896
+ chunk_id,
897
+ shard_name,
898
+ "", # URL not needed for HuggingFace
899
+ start_index,
900
+ self.chunk_size,
901
+ )
902
+ chunk_state = self.chunk_tracker.chunks[chunk_id]
903
+ logger.info(f"Created chunk state for {chunk_id} from storage")
904
+ except ValueError as e:
905
+ logger.error(f"Failed to parse chunk_id {chunk_id}: {e}")
906
+ continue
907
+
908
+ # Get chunk start index for conversion (not used in this implementation but kept for clarity)
909
+ # chunk_start = chunk_state.start_index
910
+
911
+ # Sort absolute indices for range creation
912
+ sorted_indices = sorted(indices)
913
+
914
+ # Convert to contiguous ranges using absolute indices
915
+ ranges = []
916
+ start_range = sorted_indices[0]
917
+ end_range = sorted_indices[0]
918
+
919
+ for i in range(1, len(sorted_indices)):
920
+ if sorted_indices[i] == end_range + 1:
921
+ end_range = sorted_indices[i]
922
+ else:
923
+ ranges.append((start_range, end_range))
924
+ start_range = sorted_indices[i]
925
+ end_range = sorted_indices[i]
926
+ ranges.append((start_range, end_range))
927
+
928
+ # Mark ranges as processed (WITH ABSOLUTE INDICES)
929
+ logger.info(f"Marking {len(ranges)} ranges as processed in chunk {chunk_id}")
930
+ for start_idx, end_idx in ranges:
931
+ self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
932
+
933
+ # Save updated chunk tracker
934
+ self.chunk_tracker.save()
935
+ logger.info("Chunk tracker synchronized with storage")
630
936
 
631
937
  def get_stats(self) -> Dict[str, Any]:
632
938
  """Get processor statistics."""
@@ -761,7 +1067,7 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
761
1067
  return url.split("/")[-1]
762
1068
 
763
1069
  def _create_dummy_image(self, index: int, metadata: Dict[str, Any]) -> Image.Image:
764
- """Create a dummy image"""
1070
+ """Create a dummy image."""
765
1071
  color = (0, 0, 0)
766
1072
  width, height = 128, 128
767
1073
  image = Image.new("RGB", (width, height), color=color)
@@ -901,17 +1207,83 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
901
1207
  )
902
1208
  else:
903
1209
  # Normal processing - load real images
904
- if self.url_column and self.url_column in item:
905
- image_url = item[self.url_column]
906
- try:
907
- response = requests.get(image_url, timeout=30)
908
- response.raise_for_status()
909
- image = Image.open(io.BytesIO(response.content))
910
- except Exception as e:
911
- logger.error(
912
- f"Error downloading image from {image_url}: {e}"
1210
+ if self.url_column:
1211
+ if self.url_column in item:
1212
+ image_url = item[self.url_column]
1213
+ try:
1214
+ max_retries = 3
1215
+ backoff_factor = 2
1216
+ initial_delay = 1 # seconds
1217
+ response = None
1218
+
1219
+ for attempt in range(max_retries):
1220
+ try:
1221
+ response = requests.get(image_url, timeout=30)
1222
+ response.raise_for_status()
1223
+ break # Success
1224
+ except requests.exceptions.HTTPError as http_err:
1225
+ if (
1226
+ response is not None
1227
+ and response.status_code == 429
1228
+ ):
1229
+ retry_after = response.headers.get(
1230
+ "Retry-After"
1231
+ )
1232
+ sleep_time = initial_delay * (
1233
+ backoff_factor**attempt
1234
+ )
1235
+ if retry_after:
1236
+ try:
1237
+ sleep_time = int(retry_after)
1238
+ except ValueError:
1239
+ pass # Keep exponential backoff
1240
+ logger.warning(
1241
+ f"Rate limited (429) for {image_url}. Retrying in {sleep_time}s..."
1242
+ )
1243
+ time.sleep(sleep_time)
1244
+ elif (
1245
+ response is not None
1246
+ and 500 <= response.status_code < 600
1247
+ ):
1248
+ delay = initial_delay * (
1249
+ backoff_factor**attempt
1250
+ )
1251
+ logger.warning(
1252
+ f"Server error ({response.status_code}) for {image_url}. Retrying in {delay:.1f}s..."
1253
+ )
1254
+ time.sleep(delay)
1255
+ else:
1256
+ # Non-retriable HTTP error
1257
+ raise http_err
1258
+ except (
1259
+ requests.exceptions.RequestException
1260
+ ) as req_err:
1261
+ if attempt == max_retries - 1:
1262
+ raise req_err # Re-raise on last attempt
1263
+ delay = initial_delay * (
1264
+ backoff_factor**attempt
1265
+ )
1266
+ logger.warning(
1267
+ f"Request failed for {image_url}. Retrying in {delay:.1f}s... Error: {req_err}"
1268
+ )
1269
+ time.sleep(delay)
1270
+
1271
+ if response is None or not response.ok:
1272
+ logger.error(
1273
+ f"Failed to download image from {image_url} after {max_retries} retries."
1274
+ )
1275
+ continue
1276
+
1277
+ image = Image.open(io.BytesIO(response.content))
1278
+ except Exception as e:
1279
+ logger.error(
1280
+ f"Error downloading image from {image_url}: {e}"
1281
+ )
1282
+ continue
1283
+ else:
1284
+ logger.warning(
1285
+ f"URL column '{self.url_column}' not found in item at index {global_idx}"
913
1286
  )
914
- continue
915
1287
 
916
1288
  elif self.image_column and self.image_column in item:
917
1289
  image_data = item[self.image_column]
@@ -930,7 +1302,7 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
930
1302
  job_id_obj = JobId(
931
1303
  shard_id=shard_name,
932
1304
  chunk_id=str(chunk_index),
933
- sample_id=str(global_idx),
1305
+ sample_id=str(local_idx),
934
1306
  )
935
1307
  job_id = job_id_obj.get_sample_str()
936
1308
 
@@ -963,7 +1335,7 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
963
1335
  "_processed_indices": processed_indices,
964
1336
  }
965
1337
 
966
- processed_indices.append(global_idx)
1338
+ processed_indices.append(local_idx)
967
1339
 
968
1340
  except Exception as e:
969
1341
  logger.error(f"Error processing item at index {global_idx}: {e}")