caption-flow 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +934 -415
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +2 -3
- caption_flow/orchestrator.py +153 -104
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +439 -67
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +28 -22
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +30 -28
- caption_flow/utils/chunk_tracker.py +153 -56
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +5 -4
- caption_flow/workers/caption.py +265 -90
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
- caption_flow-0.4.0.dist-info/RECORD +33 -0
- caption_flow-0.3.4.dist-info/RECORD +0 -33
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,36 @@
 """HuggingFace Datasets processor implementation - Memory Optimized Version."""
 
-import
-import threading
-import re
-import queue
-import requests
-import json
+import gc
 import io
+import json
+import logging
 import os
-import
-import
-
-
-from collections import
-from
+import queue
+import re
+import threading
+import time
+from collections import defaultdict, deque
+from concurrent.futures import Future, ThreadPoolExecutor
 from datetime import datetime
-from
-import
+from pathlib import Path
+from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Tuple
+
+import psutil
 import pyarrow.parquet as pq
+import requests
 from datasets import get_dataset_config_names, get_dataset_split_names
-from huggingface_hub import
+from huggingface_hub import get_token, hf_hub_download
+from PIL import Image
+from tqdm import tqdm
+
 from caption_flow.storage import StorageManager
 
-from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
-from ..utils import ChunkTracker
 from ..models import JobId
+from ..utils import ChunkTracker
+from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit
 
 logger = logging.getLogger(__name__)
-logger.setLevel(
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
 def log_memory(location: str):
@@ -41,10 +44,6 @@ def log_memory(location: str):
     gc.collect()
 
 
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-
 class NonBlockingQueueHandler:
     """Handles non-blocking retrieval from queues using concurrent futures."""
 
@@ -146,6 +145,9 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         cfg = config.config
 
+        # Store storage reference for chunk state synchronization
+        self.storage = storage
+
         # Dataset configuration
         dataset_cfg = cfg.get("dataset", {})
         self.dataset_name = dataset_cfg.get("dataset_path")
@@ -340,6 +342,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         # Cache shard info
         try:
+            # make dir if it doesn't exist already
+            shard_info_cache_path.parent.mkdir(parents=True, exist_ok=True)
             cache_data = {
                 "dataset": self.dataset_name,
                 "config": self.config,
@@ -367,9 +371,18 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         raise ValueError(f"Global index {global_index} not found in any shard")
 
     def _restore_state(self, storage: StorageManager) -> None:
-        """Restore state from chunk tracker."""
-        logger.debug("Restoring state from chunk tracker")
+        """Restore state from chunk tracker and synchronize with storage."""
+        logger.debug("Restoring state from chunk tracker and synchronizing with storage")
+
+        # FIRST: Update chunk tracker from storage (like WebDataset does)
+        if storage:
+            processed_job_ids = storage.get_all_processed_job_ids()
+            if processed_job_ids:
+                self.update_from_storage(processed_job_ids)
+
+        # THEN: Restore work units from chunk tracker
         if not self.chunk_tracker:
+            logger.warning("No chunk tracker available for state restoration")
             return
 
         with self.lock:
@@ -382,14 +395,120 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
                 # Only add incomplete chunks to pending
                 if chunk_state.status != "completed":
                     self.pending_units.append(chunk_id)
-                elif chunk_state.status == "completed" and chunk_state.processed_ranges:
-                    logger.warning(
-                        f"Chunk {chunk_id} has processed_ranges stored in the checkpoint."
-                    )
 
             self.current_chunk_index = max_chunk_index + 1
             logger.info(f"Resuming from chunk index {self.current_chunk_index}")
 
+        # Flush checkpoint after major update
+        self.chunk_tracker.flush()
+
+    def _create_work_units_from_chunk(self, chunk_index: int) -> List[WorkUnit]:
+        """Create one or more work units from a chunk, splitting large gaps in unprocessed ranges."""
+        units = []
+        base_unit = self._create_work_unit(chunk_index)
+
+        if not base_unit:
+            return []
+
+        # Check if we should split this into multiple work units based on gaps
+        unprocessed_ranges = base_unit.data["unprocessed_ranges"]
+
+        if len(unprocessed_ranges) <= 1:
+            # Single range or no ranges, return as-is
+            return [base_unit]
+
+        # Check for large gaps between ranges that suggest we should split
+        total_span = unprocessed_ranges[-1][1] - unprocessed_ranges[0][0] + 1
+        total_work = sum(end - start + 1 for start, end in unprocessed_ranges)
+        gap_ratio = (total_span - total_work) / total_span if total_span > 0 else 0
+
+        # If gaps are more than 50% of the span, split into separate work units
+        if gap_ratio > 0.5 and len(unprocessed_ranges) > 1:
+            logger.debug(
+                f"Splitting chunk {chunk_index} with {len(unprocessed_ranges)} ranges (gap ratio: {gap_ratio:.1%})"
+            )
+
+            # Create separate work units for each contiguous range group
+            current_group = []
+
+            for i, (start, end) in enumerate(unprocessed_ranges):
+                if not current_group:
+                    current_group = [(start, end)]
+                else:
+                    # Check gap to previous range
+                    prev_end = current_group[-1][1]
+                    gap_size = start - prev_end - 1
+
+                    # If gap is large (>100 items), start new group
+                    if gap_size > 100:
+                        # Create work unit for current group
+                        units.append(self._create_range_work_unit(chunk_index, current_group))
+                        current_group = [(start, end)]
+                    else:
+                        # Add to current group
+                        current_group.append((start, end))
+
+            # Don't forget the last group
+            if current_group:
+                units.append(self._create_range_work_unit(chunk_index, current_group))
+
+            return [unit for unit in units if unit is not None]
+        else:
+            # Keep as single unit
+            return [base_unit]
+
+    def _create_range_work_unit(
+        self, chunk_index: int, ranges: List[Tuple[int, int]]
+    ) -> Optional[WorkUnit]:
+        """Create a work unit for specific ranges within a chunk."""
+        if not ranges:
+            return None
+
+        current_index = chunk_index * self.chunk_size
+        chunk_size = min(self.chunk_size, self.total_items - current_index)
+
+        # Find shard for this chunk
+        shard_id, local_idx = self._get_shard_for_index(current_index)
+        shard_name = Path(self.shard_info[shard_id]["filename"]).stem
+
+        # Create unique unit ID that includes range info
+        range_suffix = f"r{len(ranges)}"  # r2 = 2 ranges, etc.
+        job_id_obj = JobId(
+            shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(current_index)
+        )
+        base_unit_id = job_id_obj.get_chunk_str()
+        unit_id = f"{base_unit_id}_{range_suffix}"
+
+        unprocessed_items = sum(end - start + 1 for start, end in ranges)
+
+        unit = WorkUnit(
+            unit_id=unit_id,
+            chunk_id=base_unit_id,  # Keep original chunk_id for tracking
+            source_id=shard_name,
+            unit_size=unprocessed_items,
+            data={
+                "dataset_name": self.dataset_name,
+                "config": self.config,
+                "split": self.split,
+                "start_index": current_index,
+                "chunk_size": chunk_size,
+                "actual_work_size": unprocessed_items,
+                "unprocessed_ranges": ranges,
+                "range_based": True,
+                "is_split_unit": True,  # Flag to indicate this is a split from larger chunk
+                "shard_ids": [shard_id],
+                "data_files": self.data_files,
+            },
+            metadata={
+                "dataset": self.dataset_name,
+                "shard_name": shard_name,
+                "chunk_index": chunk_index,
+                "range_count": len(ranges),
+            },
+        )
+
+        return unit
+
     def _create_work_unit(self, chunk_index: int) -> Optional[WorkUnit]:
         """Create a single work unit for a chunk index."""
         current_index = chunk_index * self.chunk_size
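The gap-ratio split above turns one logical chunk into several work units when most of its span is already processed. A minimal standalone sketch of the grouping step (`group_ranges` is a hypothetical helper name; the shipped code wraps each group in a `WorkUnit` via `_create_range_work_unit`):

```python
# Sketch of the gap-based grouping used by _create_work_units_from_chunk.
from typing import List, Tuple

def group_ranges(ranges: List[Tuple[int, int]], max_gap: int = 100) -> List[List[Tuple[int, int]]]:
    """Split inclusive (start, end) ranges into groups wherever the gap exceeds max_gap."""
    groups: List[List[Tuple[int, int]]] = []
    current: List[Tuple[int, int]] = []
    for start, end in ranges:
        if current and start - current[-1][1] - 1 > max_gap:
            groups.append(current)
            current = []
        current.append((start, end))
    if current:
        groups.append(current)
    return groups

ranges = [(0, 49), (5000, 5099), (5200, 5249)]
total_span = ranges[-1][1] - ranges[0][0] + 1       # 5250
total_work = sum(e - s + 1 for s, e in ranges)      # 200
gap_ratio = (total_span - total_work) / total_span  # ~0.96 > 0.5, so split
print(group_ranges(ranges))                         # [[(0, 49)], [(5000, 5099), (5200, 5249)]]
```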
@@ -400,39 +519,65 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         chunk_size = min(self.chunk_size, self.total_items - current_index)
 
         # Find shard for this chunk
-        shard_id,
+        shard_id, local_idx = self._get_shard_for_index(current_index)
         shard_name = Path(self.shard_info[shard_id]["filename"]).stem
 
-
+        # Calculate RELATIVE chunk index within the shard
+        job_id_obj = JobId(
+            shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(current_index)
+        )
         unit_id = job_id_obj.get_chunk_str()
 
         # Calculate unprocessed ranges based on existing chunk state
         unprocessed_ranges = [(current_index, current_index + chunk_size - 1)]
-
         if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
             chunk_state = self.chunk_tracker.chunks[unit_id]
             if chunk_state.processed_ranges:
+                # Convert relative processed ranges to absolute ranges
+                absolute_processed_ranges = [
+                    (start + current_index, end + current_index)
+                    for start, end in chunk_state.processed_ranges
+                ]
                 # Subtract processed ranges from total range
+                range_to_subtract = (current_index, current_index + chunk_size - 1)
+                logger.debug(
+                    f"Chunk {unit_id} has processed ranges: {chunk_state.processed_ranges} (relative), {absolute_processed_ranges} (absolute)"
+                )
                 unprocessed_ranges = self._subtract_ranges(
-                    [
+                    [range_to_subtract], absolute_processed_ranges
                 )
 
         # If all ranges are processed, return None (shouldn't happen if status tracking is correct)
         if not unprocessed_ranges:
+            logger.debug(f"Chunk {unit_id} has no unprocessed ranges, skipping")
+            return None
+
+        # Calculate actual unprocessed items and total work to be assigned
+        unprocessed_items = sum(end - start + 1 for start, end in unprocessed_ranges)
+
+        # Skip assignment if there are very few unprocessed items (< 10 items)
+        if unprocessed_items < 10:
+            logger.debug(
+                f"Chunk {unit_id} has only {unprocessed_items} unprocessed items, skipping assignment"
+            )
             return None
 
+        # Create work unit that represents ONLY the unprocessed ranges
+        # This is the key fix: don't assign the full chunk, assign only unprocessed parts
         unit = WorkUnit(
             unit_id=unit_id,
             chunk_id=unit_id,
             source_id=shard_name,
-            unit_size=
+            unit_size=unprocessed_items,  # Only the unprocessed items
             data={
                 "dataset_name": self.dataset_name,
                 "config": self.config,
                 "split": self.split,
-                "start_index": current_index,
-                "chunk_size": chunk_size,
-                "
+                "start_index": current_index,  # Keep original chunk start for reference
+                "chunk_size": chunk_size,  # Keep original chunk size for reference
+                "actual_work_size": unprocessed_items,  # NEW: actual work to be done
+                "unprocessed_ranges": unprocessed_ranges,  # The specific ranges to process
+                "range_based": True,  # NEW: flag to indicate this is range-based
                 "shard_ids": [shard_id],
                 "data_files": self.data_files,
             },
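`_subtract_ranges` itself is outside this hunk; the call above only needs inclusive-interval subtraction over (start, end) pairs. A sketch of the assumed behavior, not the packaged implementation:

```python
from typing import List, Tuple

def subtract_ranges(keep: List[Tuple[int, int]], remove: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Subtract inclusive [start, end] ranges in `remove` from those in `keep`."""
    result = []
    for start, end in keep:
        pieces = [(start, end)]
        for r_start, r_end in remove:
            next_pieces = []
            for p_start, p_end in pieces:
                if r_end < p_start or r_start > p_end:  # no overlap, keep piece whole
                    next_pieces.append((p_start, p_end))
                    continue
                if p_start < r_start:                   # left remainder survives
                    next_pieces.append((p_start, r_start - 1))
                if r_end < p_end:                       # right remainder survives
                    next_pieces.append((r_end + 1, p_end))
            pieces = next_pieces
        result.extend(pieces)
    return result

# A chunk covering items 1000-1999 with 1000-1499 already processed:
print(subtract_ranges([(1000, 1999)], [(1000, 1499)]))  # [(1500, 1999)]
```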
@@ -455,23 +600,35 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             assigned_count = sum(len(units) for units in self.assigned_units.values())
             worker_count = max(1, len(self.assigned_units))
 
+            # Check if all data has been processed
+            if self.current_chunk_index * self.chunk_size >= self.total_items:
+                # All chunks processed - exit the background thread
+                logger.debug("All chunks processed, exiting background thread")
+                break
+
             target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
             units_needed = max(0, target_buffer - (pending_count + assigned_count))
 
             if units_needed == 0:
-
+                self.stop_creation.wait(5)
                 continue
 
             # Create units as needed
             units_created = 0
 
+            # Progress bar
+            progress_bar = tqdm(total=units_needed, desc="Creating work units", unit="unit")
+
             while units_created < units_needed:
-                logger.debug(f"Creating work unit for chunk {self.current_chunk_index}")
+                # logger.debug(f"Creating work unit for chunk {self.current_chunk_index}")
                 if self.current_chunk_index * self.chunk_size >= self.total_items:
-
+                    # No more data available - exit immediately instead of waiting
+                    logger.debug(
+                        f"All chunks processed (chunk_index={self.current_chunk_index}, total_items={self.total_items})"
+                    )
                     break
                 # Get shard info for proper unit_id
-                current_index = self.current_chunk_index
+                current_index = self.current_chunk_index
                 if current_index < self.total_items:
                     shard_id, _ = self._get_shard_for_index(current_index)
                     shard_name = Path(self.shard_info[shard_id]["filename"]).stem
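For concreteness, the buffer sizing in this loop is a one-line formula; with illustrative values (not package defaults):

```python
# Target buffer scales with connected workers, floored at min_buffer.
min_buffer, buffer_multiplier = 10, 3
worker_count = 4
pending_count, assigned_count = 5, 4

target_buffer = max(min_buffer, worker_count * buffer_multiplier)        # max(10, 12) = 12
units_needed = max(0, target_buffer - (pending_count + assigned_count))  # 12 - 9 = 3
```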
@@ -509,14 +666,14 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
                     units_created += 1
                     self.current_chunk_index += 1
 
+                progress_bar.update(1)
             if units_created > 0:
                 logger.debug(f"Created {units_created} work unit IDs")
 
         logger.info("Thread for creating units has completed. Exiting thread.")
 
     def process_responses_non_blocking(self, response_queue: queue.Queue) -> Optional[WorkResult]:
-        """
-        Non-blocking method to process responses from workers.
+        """Non-blocking method to process responses from workers.
         Returns a WorkResult if one is available, None otherwise.
         """
         # Check for response without blocking
@@ -551,33 +708,64 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         # Force checkpoint save if needed
         if self.chunk_tracker:
-
+            # Flush statistics updates immediately
+            self.chunk_tracker.flush()
 
     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
         """Get available work units for a worker."""
-
         logger.debug(
             "get_work_units called: count=%d worker_id=%s, pending: %d",
             count,
             worker_id,
             len(self.pending_units),
         )
+
+        # Periodically sync with storage to ensure we don't assign already-completed work
+        # This is especially important when workers reconnect
+        if hasattr(self, "storage") and self.storage and len(self.pending_units) > 0:
+            try:
+                processed_job_ids = self.storage.get_all_processed_job_ids()
+                if processed_job_ids:
+                    logger.debug(
+                        f"Syncing chunk tracker with {len(processed_job_ids)} processed items before assignment"
+                    )
+                    self.update_from_storage(processed_job_ids)
+                    # Flush after storage sync to ensure consistency
+                    self.chunk_tracker.flush()
+            except Exception as e:
+                logger.warning(f"Failed to sync with storage during work assignment: {e}")
+
         assigned = []
         with self.lock:
             while len(assigned) < count and self.pending_units:
                 unit_id = self.pending_units.popleft()
 
-                # Create work
+                # Create work units on demand (may create multiple units from one chunk)
                 chunk_index = int(unit_id.split(":")[-1])
-
+                units = self._create_work_units_from_chunk(chunk_index)
 
-
-
+                for unit in units:
+                    if len(assigned) >= count:
+                        # Put remaining units back in queue for next worker
+                        self.pending_units.appendleft(
+                            f"{unit.metadata['shard_name']}:chunk:{chunk_index}"
+                        )
+                        break
+
+                    # Use the unit's actual unit_id for tracking
+                    actual_unit_id = unit.unit_id
+                    self.assigned_units[worker_id].add(actual_unit_id)
                     assigned.append(unit)
-                logger.debug(
+                    logger.debug(
+                        "Assigning unit %s (%d items) to worker %s",
+                        actual_unit_id,
+                        unit.data.get("actual_work_size", unit.unit_size),
+                        worker_id,
+                    )
 
                     if self.chunk_tracker:
-
+                        # Track assignment using the base chunk_id for chunk tracker compatibility
+                        self.chunk_tracker.mark_assigned(unit.chunk_id, worker_id)
 
         logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
         return assigned
@@ -610,12 +798,52 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
     def release_assignments(self, worker_id: str) -> None:
         """Release all assignments for a disconnected worker."""
         logger.debug("Releasing assignments for worker %s", worker_id)
+
+        # FIRST: Sync with storage to ensure chunk tracker is up-to-date
+        # This prevents reassigning work that was already completed by this or other workers
+        if hasattr(self, "storage") and self.storage:
+            try:
+                processed_job_ids = self.storage.get_all_processed_job_ids()
+                if processed_job_ids:
+                    logger.info(
+                        f"Syncing chunk tracker with {len(processed_job_ids)} processed items before releasing assignments"
+                    )
+                    self.update_from_storage(processed_job_ids)
+                    # Flush after storage sync to ensure consistency
+                    self.chunk_tracker.flush()
+            except Exception as e:
+                logger.warning(f"Failed to sync with storage before releasing assignments: {e}")
+
         with self.lock:
             unit_ids = list(self.assigned_units.get(worker_id, []))
 
             for unit_id in unit_ids:
-
-
+                # Check if this chunk is already fully processed before re-queuing
+                should_requeue = True
+                if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
+                    chunk_state = self.chunk_tracker.chunks[unit_id]
+                    if chunk_state.status == "completed":
+                        logger.info(f"Not re-queuing completed chunk {unit_id}")
+                        should_requeue = False
+                    elif chunk_state.processed_ranges:
+                        # Check if chunk is mostly complete (>90% processed)
+                        total_items = chunk_state.chunk_size
+                        processed_items = sum(
+                            end - start + 1 for start, end in chunk_state.processed_ranges
+                        )
+                        completion_rate = processed_items / total_items if total_items > 0 else 0
+
+                        if completion_rate > 0.9:
+                            logger.info(
+                                f"Not re-queuing nearly complete chunk {unit_id} ({completion_rate:.1%} done)"
+                            )
+                            should_requeue = False
+
+                if should_requeue:
+                    logger.debug(f"Adding {unit_id} to pending queue")
+                    self.pending_units.append(unit_id)
+                else:
+                    logger.debug(f"Skipping re-queue of {unit_id} (already processed)")
 
             if worker_id in self.assigned_units:
                 del self.assigned_units[worker_id]
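Condensed, the re-queue guard above makes a single boolean decision per unit; a sketch with plain values standing in for the real `ChunkState` object:

```python
from typing import List, Tuple

def should_requeue(status: str, chunk_size: int, processed_ranges: List[Tuple[int, int]]) -> bool:
    """Requeue unless the chunk is completed or more than 90% processed (inclusive ranges)."""
    if status == "completed":
        return False
    if processed_ranges and chunk_size > 0:
        done = sum(end - start + 1 for start, end in processed_ranges)
        if done / chunk_size > 0.9:
            return False
    return True

print(should_requeue("assigned", 1000, [(0, 949)]))  # False: 95% done, not worth reassigning
print(should_requeue("assigned", 1000, [(0, 499)]))  # True: only 50% done
```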
@@ -624,9 +852,87 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             self.chunk_tracker.release_worker_chunks(worker_id)
 
     def update_from_storage(self, processed_job_ids: Set[str]) -> None:
-        """Update
+        """Update chunk tracker based on what's been processed in storage."""
         logger.info(f"Updating from storage with {len(processed_job_ids)} processed jobs")
-
+
+        if not self.chunk_tracker:
+            return
+
+        # Group by chunk
+        processed_by_chunk = defaultdict(set)
+
+        for job_id_str in processed_job_ids:
+            try:
+                # Parse job ID to get chunk and sample index
+                job_id = JobId.from_str(job_id_str)
+                chunk_id = job_id.get_chunk_str()
+                sample_idx = int(job_id.sample_id)
+                processed_by_chunk[chunk_id].add(sample_idx)
+            except ValueError as e:
+                logger.warning(f"Invalid job ID format: {job_id_str} - {e}")
+                continue
+
+        # Update chunk tracker with processed items
+        for chunk_id, indices in processed_by_chunk.items():
+            if not indices:
+                continue
+
+            # Get or create chunk state
+            chunk_state = self.chunk_tracker.chunks.get(chunk_id)
+
+            if not chunk_state:
+                # Parse chunk_id using JobId to get info (reuse existing validation)
+                try:
+                    # Reconstruct a valid job_id to parse chunk info
+                    sample_job_id = f"{chunk_id}:idx:0"  # Use dummy sample_id
+                    job_id = JobId.from_str(sample_job_id)
+
+                    shard_name = job_id.shard_id
+                    chunk_idx = int(job_id.chunk_id)
+                    start_index = chunk_idx * self.chunk_size
+
+                    # Add chunk to tracker
+                    self.chunk_tracker.add_chunk(
+                        chunk_id,
+                        shard_name,
+                        "",  # URL not needed for HuggingFace
+                        start_index,
+                        self.chunk_size,
+                    )
+                    chunk_state = self.chunk_tracker.chunks[chunk_id]
+                    logger.info(f"Created chunk state for {chunk_id} from storage")
+                except ValueError as e:
+                    logger.error(f"Failed to parse chunk_id {chunk_id}: {e}")
+                    continue
+
+            # Get chunk start index for conversion (not used in this implementation but kept for clarity)
+            # chunk_start = chunk_state.start_index
+
+            # Sort absolute indices for range creation
+            sorted_indices = sorted(indices)
+
+            # Convert to contiguous ranges using absolute indices
+            ranges = []
+            start_range = sorted_indices[0]
+            end_range = sorted_indices[0]
+
+            for i in range(1, len(sorted_indices)):
+                if sorted_indices[i] == end_range + 1:
+                    end_range = sorted_indices[i]
+                else:
+                    ranges.append((start_range, end_range))
+                    start_range = sorted_indices[i]
+                    end_range = sorted_indices[i]
+            ranges.append((start_range, end_range))
+
+            # Mark ranges as processed (WITH ABSOLUTE INDICES)
+            logger.info(f"Marking {len(ranges)} ranges as processed in chunk {chunk_id}")
+            for start_idx, end_idx in ranges:
+                self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
+
+        # Save updated chunk tracker
+        self.chunk_tracker.save()
+        logger.info("Chunk tracker synchronized with storage")
 
 
     def get_stats(self) -> Dict[str, Any]:
         """Get processor statistics."""
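The index-to-range collapse inside `update_from_storage` is a standard run-length pass over sorted indices; a self-contained equivalent for reference:

```python
from typing import Iterable, List, Tuple

def to_ranges(indices: Iterable[int]) -> List[Tuple[int, int]]:
    """Collapse arbitrary indices into sorted, inclusive, contiguous (start, end) ranges."""
    sorted_indices = sorted(indices)
    if not sorted_indices:
        return []
    ranges = []
    start = end = sorted_indices[0]
    for idx in sorted_indices[1:]:
        if idx == end + 1:
            end = idx          # extend the current run
        else:
            ranges.append((start, end))
            start = end = idx  # begin a new run
    ranges.append((start, end))
    return ranges

print(to_ranges({0, 1, 2, 5, 6, 9}))  # [(0, 2), (5, 6), (9, 9)]
```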
@@ -761,7 +1067,7 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
         return url.split("/")[-1]
 
     def _create_dummy_image(self, index: int, metadata: Dict[str, Any]) -> Image.Image:
-        """Create a dummy image"""
+        """Create a dummy image."""
         color = (0, 0, 0)
         width, height = 128, 128
         image = Image.new("RGB", (width, height), color=color)
@@ -901,17 +1207,83 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
                         )
                     else:
                         # Normal processing - load real images
-                        if self.url_column
-
-
-
-
-
-
-
-
+                        if self.url_column:
+                            if self.url_column in item:
+                                image_url = item[self.url_column]
+                                try:
+                                    max_retries = 3
+                                    backoff_factor = 2
+                                    initial_delay = 1  # seconds
+                                    response = None
+
+                                    for attempt in range(max_retries):
+                                        try:
+                                            response = requests.get(image_url, timeout=30)
+                                            response.raise_for_status()
+                                            break  # Success
+                                        except requests.exceptions.HTTPError as http_err:
+                                            if (
+                                                response is not None
+                                                and response.status_code == 429
+                                            ):
+                                                retry_after = response.headers.get(
+                                                    "Retry-After"
+                                                )
+                                                sleep_time = initial_delay * (
+                                                    backoff_factor**attempt
+                                                )
+                                                if retry_after:
+                                                    try:
+                                                        sleep_time = int(retry_after)
+                                                    except ValueError:
+                                                        pass  # Keep exponential backoff
+                                                logger.warning(
+                                                    f"Rate limited (429) for {image_url}. Retrying in {sleep_time}s..."
+                                                )
+                                                time.sleep(sleep_time)
+                                            elif (
+                                                response is not None
+                                                and 500 <= response.status_code < 600
+                                            ):
+                                                delay = initial_delay * (
+                                                    backoff_factor**attempt
+                                                )
+                                                logger.warning(
+                                                    f"Server error ({response.status_code}) for {image_url}. Retrying in {delay:.1f}s..."
+                                                )
+                                                time.sleep(delay)
+                                            else:
+                                                # Non-retriable HTTP error
+                                                raise http_err
+                                        except (
+                                            requests.exceptions.RequestException
+                                        ) as req_err:
+                                            if attempt == max_retries - 1:
+                                                raise req_err  # Re-raise on last attempt
+                                            delay = initial_delay * (
+                                                backoff_factor**attempt
+                                            )
+                                            logger.warning(
+                                                f"Request failed for {image_url}. Retrying in {delay:.1f}s... Error: {req_err}"
+                                            )
+                                            time.sleep(delay)
+
+                                    if response is None or not response.ok:
+                                        logger.error(
+                                            f"Failed to download image from {image_url} after {max_retries} retries."
+                                        )
+                                        continue
+
+                                    image = Image.open(io.BytesIO(response.content))
+                                except Exception as e:
+                                    logger.error(
+                                        f"Error downloading image from {image_url}: {e}"
+                                    )
+                                    continue
+                            else:
+                                logger.warning(
+                                    f"URL column '{self.url_column}' not found in item at index {global_idx}"
                                 )
-                            continue
 
                         elif self.image_column and self.image_column in item:
                             image_data = item[self.image_column]
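The download loop above is a capped exponential backoff that honors a numeric `Retry-After` header on 429 responses and also retries 5xx errors. The same policy condensed into a sketch (`fetch_with_backoff` is a hypothetical name, not the packaged code):

```python
import time
import requests

def fetch_with_backoff(url: str, max_retries: int = 3,
                       initial_delay: float = 1.0, backoff_factor: float = 2.0):
    """GET with retries on 429/5xx and exponential backoff; returns None if all attempts fail."""
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            return response
        except requests.exceptions.HTTPError:
            status = response.status_code
            if status != 429 and not 500 <= status < 600:
                raise  # non-retriable HTTP error (e.g. 404)
            delay = initial_delay * backoff_factor**attempt
            if status == 429:
                retry_after = response.headers.get("Retry-After")
                if retry_after and retry_after.isdigit():
                    delay = int(retry_after)  # server-specified wait wins
        except requests.exceptions.RequestException:
            if attempt == max_retries - 1:
                raise  # re-raise on last attempt
            delay = initial_delay * backoff_factor**attempt
        time.sleep(delay)
    return None
```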
@@ -930,7 +1302,7 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
                         job_id_obj = JobId(
                             shard_id=shard_name,
                             chunk_id=str(chunk_index),
-                            sample_id=str(
+                            sample_id=str(local_idx),
                         )
                         job_id = job_id_obj.get_sample_str()
 
@@ -963,7 +1335,7 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
                             "_processed_indices": processed_indices,
                         }
 
-                        processed_indices.append(
+                        processed_indices.append(local_idx)
 
                     except Exception as e:
                         logger.error(f"Error processing item at index {global_idx}: {e}")