caption-flow 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +921 -427
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +2 -3
- caption_flow/orchestrator.py +153 -104
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +463 -68
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +28 -22
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +30 -28
- caption_flow/utils/chunk_tracker.py +153 -56
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +5 -4
- caption_flow/workers/caption.py +303 -92
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA +9 -4
- caption_flow-0.4.1.dist-info/RECORD +33 -0
- caption_flow-0.3.4.dist-info/RECORD +0 -33
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/top_level.txt +0 -0
caption_flow/workers/caption.py
CHANGED
@@ -7,34 +7,34 @@ os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
 import asyncio
 import json
 import logging
-import websockets
 import time
-from dataclasses import dataclass
-from typing import Dict, Any, Optional, List, Tuple, Union
-from queue import Queue, Empty
-from threading import Thread, Event, Lock
 from collections import defaultdict, deque
+from dataclasses import dataclass
+from queue import Empty, Queue
+from threading import Event, Lock, Thread
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-
+import websockets
 from huggingface_hub import get_token
+from PIL import Image
 
-from
+from ..models import ProcessingStage, StageResult
 from ..processors import (
+    HuggingFaceDatasetWorkerProcessor,
+    LocalFilesystemWorkerProcessor,
     ProcessorConfig,
+    WebDatasetWorkerProcessor,
     WorkAssignment,
-    WorkUnit,
     WorkResult,
-
-    HuggingFaceDatasetWorkerProcessor,
-    LocalFilesystemWorkerProcessor,
+    WorkUnit,
 )
-from ..utils.vllm_config import VLLMConfigManager
 from ..utils.image_processor import ImageProcessor
 from ..utils.prompt_template import PromptTemplateManager
-from ..
+from ..utils.vllm_config import VLLMConfigManager
+from .base import BaseWorker
 
 logger = logging.getLogger(__name__)
-logger.setLevel(
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
 @dataclass
@@ -72,8 +72,8 @@ class MultiStageVLLMManager:
         logger.info(f"Model {model_name} already loaded, reusing instance")
         return
 
-        from
-        from
+        from transformers import AutoProcessor, AutoTokenizer
+        from vllm import LLM
 
         logger.info(f"Loading model {model_name} for stage {stage.name}")
 
@@ -137,6 +137,19 @@ class MultiStageVLLMManager:
 
     def get_model_for_stage(self, stage_name: str, model_name: str) -> Tuple[Any, Any, Any, Any]:
        """Get model components for a stage."""
+        if model_name not in self.models:
+            raise KeyError(
+                f"Model '{model_name}' not found in loaded models. Available models: {list(self.models.keys())}"
+            )
+        if model_name not in self.processors:
+            raise KeyError(f"Processor for model '{model_name}' not found")
+        if model_name not in self.tokenizers:
+            raise KeyError(f"Tokenizer for model '{model_name}' not found")
+        if stage_name not in self.sampling_params:
+            raise KeyError(
+                f"Sampling params for stage '{stage_name}' not found. Available stages: {list(self.sampling_params.keys())}"
+            )
+
         return (
             self.models[model_name],
             self.processors[model_name],
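The new guard clauses turn bare dictionary lookups into KeyErrors that list what is actually loaded. A minimal caller-side sketch, with an illustrative manager instance and stage/model names (not taken from the package):

    # Hypothetical usage; `manager` is an already-initialized MultiStageVLLMManager.
    try:
        llm, processor, tokenizer, sampling_params = manager.get_model_for_stage(
            "caption", "some-org/some-vlm"
        )
    except KeyError as exc:
        # The message now enumerates the loaded models/stages, so a misconfigured
        # stage definition shows up directly in the worker logs.
        logger.error("Stage lookup failed: %s", exc)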
@@ -305,7 +318,7 @@ class CaptionWorker(BaseWorker):
             self.processor = LocalFilesystemWorkerProcessor()
         else:
             raise ValueError(f"Unknown processor type: {self.processor_type}")
-
+        self.processor.gpu_id = self.gpu_id
         self.processor.initialize(processor_config)
         self.dataset_path = self.processor.dataset_path
         self.units_per_request = processor_config.config.get("chunks_per_request", 1)
@@ -463,7 +476,7 @@
         # Check if stages changed significantly
         stages_changed = len(new_stages) != len(self.stages)
         if not stages_changed:
-            for old, new in zip(self.stages, new_stages):
+            for old, new in zip(self.stages, new_stages, strict=False):
                 if (
                     old.name != new.name
                     or old.model != new.model
@@ -489,7 +502,19 @@
                 return True
             except Exception as e:
                 logger.error(f"Failed to reload vLLM: {e}")
+                # Restore previous state
                 self.vllm_config = old_config
+                self.stages = self._parse_stages_config(old_config)
+                self.stage_order = self._topological_sort_stages(self.stages)
+                # Attempt to restore previous models
+                try:
+                    self._setup_vllm()
+                except Exception as restore_error:
+                    logger.error(f"Failed to restore previous vLLM state: {restore_error}")
+                    # Clean up broken state
+                    if self.model_manager:
+                        self.model_manager.cleanup()
+                    self.model_manager = None
                 return False
         else:
             # Clean up models if switching to mock mode
@@ -580,6 +605,7 @@
 
             try:
                 # Create processing item
+                logger.debug(f"Processing item data: {item_data}")
                 item = ProcessingItem(
                     unit_id=unit.unit_id,
                     chunk_id=unit.chunk_id,
@@ -610,34 +636,64 @@
         if batch and not self.should_stop_processing.is_set():
             self._process_batch(batch)
 
-        # Notify orchestrator
+        # Notify orchestrator about unit completion or failure
         # Check if the number of processed items matches the expected count for the unit.
         # The context dictionary holds the count of items yielded by the processor.
         total_items_in_unit = unit.unit_size
 
-        if (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if not self.should_stop_processing.is_set() and self.connected.is_set():
+            if self.items_failed == 0 and self.items_processed >= total_items_in_unit:
+                # Unit completed successfully
+                if self.websocket:
+                    try:
+                        asyncio.run_coroutine_threadsafe(
+                            self.websocket.send(
+                                json.dumps({"type": "work_complete", "unit_id": unit.unit_id})
+                            ),
+                            self.main_loop,
+                        ).result(timeout=5)
+                        logger.info(
+                            f"Unit {unit.unit_id} fully processed "
+                            f"({self.items_processed}/{total_items_in_unit}) and marked complete."
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            f"Could not notify work complete for unit {unit.unit_id}: {e}"
+                        )
+            else:
+                # Unit failed or was incomplete
+                if self.items_failed > 0:
+                    error_msg = (
+                        f"Processing failed for {self.items_failed} out of "
+                        f"{total_items_in_unit} items"
                     )
-
-
+                    logger.error(f"Unit {unit.unit_id} failed: {error_msg}")
+                else:
+                    error_msg = (
+                        f"Processing incomplete: {self.items_processed}/"
+                        f"{total_items_in_unit} items processed"
+                    )
+                    logger.warning(f"Unit {unit.unit_id} incomplete: {error_msg}")
+
+                if self.websocket:
+                    try:
+                        asyncio.run_coroutine_threadsafe(
+                            self.websocket.send(
+                                json.dumps(
+                                    {
+                                        "type": "work_failed",
+                                        "unit_id": unit.unit_id,
+                                        "error": error_msg,
+                                    }
+                                )
+                            ),
+                            self.main_loop,
+                        ).result(timeout=5)
+                        logger.info(f"Unit {unit.unit_id} failure reported to orchestrator")
+                    except Exception as e:
+                        logger.warning(f"Could not notify work failed for unit {unit.unit_id}: {e}")
         else:
-            logger.
-                f"Processing of unit {unit.unit_id} was incomplete ({self.items_processed}/{total_items_in_unit}). Not marking as complete."
-            )
+            logger.info(f"Unit {unit.unit_id} processing stopped due to disconnect or shutdown")
 
     def _process_batch(self, batch: List[ProcessingItem]):
         """Process a batch of items through all stages."""
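The completion and failure notifications above go through asyncio.run_coroutine_threadsafe because the batch loop runs on a worker thread while the orchestrator websocket lives on the asyncio event loop. A stripped-down sketch of that pattern, assuming `websocket.send` is a coroutine (as in the websockets library) and `loop` is the running main loop:

    import asyncio
    import json

    def notify_from_worker_thread(websocket, loop: asyncio.AbstractEventLoop, payload: dict):
        # Schedule the coroutine on the event loop from a plain thread, then block
        # briefly for the outcome; result() re-raises send errors and timeouts.
        future = asyncio.run_coroutine_threadsafe(websocket.send(json.dumps(payload)), loop)
        return future.result(timeout=5)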
@@ -672,6 +728,20 @@ class CaptionWorker(BaseWorker):
         except Exception as e:
             logger.error(f"Batch processing error: {e}", exc_info=True)
 
+            # Mark all items in batch as failed
+            self.items_failed += len(batch)
+
+            # Send error results for each item in the batch
+            for item in batch:
+                self.result_queue.put(
+                    {
+                        "item": item,
+                        "outputs": {},
+                        "processing_time_ms": 0.0,
+                        "error": f"Batch processing failed: {str(e)}",
+                    }
+                )
+
     def _process_batch_mock(self, batch: List[ProcessingItem]) -> List[Tuple[ProcessingItem, Dict]]:
         """Process a batch in mock mode - return dummy captions."""
         results = []
@@ -686,9 +756,9 @@
 
             # Create mock outputs based on stage prompts
             stage_outputs = []
-            for i,
+            for i, _prompt in enumerate(stage.prompts):
                 mock_output = (
-                    f"Mock {stage_name} output {i+1} for job {item.job_id} - {item.item_key}"
+                    f"Mock {stage_name} output {i + 1} for job {item.job_id} - {item.item_key}"
                 )
                 stage_outputs.append(mock_output)
 
@@ -713,42 +783,212 @@
 
         return results
 
+    def _validate_and_split_batch(
+        self,
+        batch: List[ProcessingItem],
+        stage: ProcessingStage,
+        processor,
+        tokenizer,
+        sampling_params,
+        max_length: int = 16384,
+    ) -> Tuple[List[ProcessingItem], List[ProcessingItem]]:
+        """Validate batch items and split into processable and too-long items."""
+        logger.debug(
+            f"Validating batch of size {len(batch)} for stage '{stage.name}' "
+            f"with max_length {max_length}"
+        )
+        processable = []
+        too_long = []
+
+        for item in batch:
+            try:
+                # Create a test prompt for this item
+                converted_img = ImageProcessor.prepare_for_inference(item)
+                template_manager = PromptTemplateManager(
+                    stage.prompts[:1]
+                )  # Test with first prompt
+
+                # Build context
+                context = item.metadata.copy()
+                for prev_stage_name, stage_result in item.stage_results.items():
+                    for i, output in enumerate(stage_result.outputs):
+                        context[f"{prev_stage_name}_output_{i}"] = output
+                    if len(stage_result.outputs) == 1:
+                        context[stage_result.output_field] = stage_result.outputs[0]
+                    else:
+                        context[stage_result.output_field] = stage_result.outputs
+                logger.debug(f"Validation context for {item.item_key}: {context}")
+
+                # Format test prompt
+                formatted_prompts = template_manager.format_all(context)
+                if not formatted_prompts:
+                    logger.warning(
+                        f"Could not format prompt for {item.item_key}, marking as too long."
+                    )
+                    too_long.append(item)
+                    continue
+
+                logger.debug(
+                    f"Formatted validation prompt for {item.item_key}: {formatted_prompts[0]}"
+                )
+
+                # Build actual vLLM input to test
+                test_req = self._build_vllm_input(
+                    converted_img, formatted_prompts[0], processor, tokenizer
+                )
+
+                # Use processor to get actual token count
+                if "prompt_token_ids" in test_req:
+                    prompt_length = len(test_req["prompt_token_ids"])
+                else:
+                    # Fallback to tokenizer
+                    prompt_length = len(tokenizer.encode(test_req.get("prompt", "")))
+
+                # Check individual prompt length (prompts are processed one by one)
+                # Use a small safety buffer to account for token estimation variations
+                safety_buffer = 50
+                if prompt_length < max_length - safety_buffer:
+                    processable.append(item)
+                    logger.debug(
+                        f"Item {item.item_key} validated: {prompt_length} tokens per prompt"
+                    )
+                else:
+                    too_long.append(item)
+                    logger.warning(
+                        f"Item {item.item_key} too long: {prompt_length} tokens "
+                        f"vs max {max_length - safety_buffer} (with safety buffer)"
+                    )
+
+            except Exception as e:
+                logger.error(f"Error validating item {item.item_key}: {e}", exc_info=True)
+                too_long.append(item)
+
+        logger.debug(
+            f"Validation complete: {len(processable)} processable, {len(too_long)} too long."
+        )
+        return processable, too_long
+
+    def _resize_image_for_tokens(
+        self, item: ProcessingItem, target_ratio: float = 0.7
+    ) -> ProcessingItem:
+        """Resize image to reduce token count."""
+        if not item.image:
+            return item
+
+        # Calculate new size
+        new_width = int(item.image.width * target_ratio)
+        new_height = int(item.image.height * target_ratio)
+
+        # Resize image
+        resized_image = item.image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+        # Create new item with resized image
+        new_item = ProcessingItem(
+            unit_id=item.unit_id,
+            job_id=item.job_id,
+            chunk_id=item.chunk_id,
+            item_key=item.item_key,
+            item_index=item.item_index,
+            image=resized_image,
+            image_data=item.image_data,  # Keep original data for metadata
+            metadata={**item.metadata, "_resized": True, "_resize_ratio": target_ratio},
+            stage_results=item.stage_results.copy(),
+        )
+
+        return new_item
+
     def _process_batch_multi_stage(
         self, batch: List[ProcessingItem], max_attempts: int = 3
     ) -> List[Tuple[ProcessingItem, Dict]]:
-        """Process a batch through all stages
+        """Process a batch through all stages with token validation."""
         results = []
 
+        # Get max model length from config
+        max_model_len = self.vllm_config.get("max_model_len", 16384)
+
         # Process each stage in order
         for stage_name in self.stage_order:
             stage = next(s for s in self.stages if s.name == stage_name)
             logger.debug(f"Processing batch through stage: {stage_name}")
 
+            # Check if model manager is properly initialized
+            if not self.model_manager:
+                logger.error("Model manager not initialized")
+                self.items_failed += len(batch)
+                return []
+
             # Get model components
-
-
+            try:
+                llm, processor, tokenizer, sampling_params = self.model_manager.get_model_for_stage(
+                    stage_name, stage.model
+                )
+            except KeyError as e:
+                logger.error(f"Model not found during batch processing: {e}")
+                self.items_failed += len(batch)
+                return []
+
+            # Validate batch before processing
+            processable_batch, too_long_items = self._validate_and_split_batch(
+                batch, stage, processor, tokenizer, sampling_params, max_model_len
             )
 
-            #
-
+            # Handle items that are too long
+            for item in too_long_items:
+                logger.warning(f"Item {item.item_key} exceeds token limit, attempting resize")
 
-
-
-            requests = []
+                # Try resizing the image
+                resized_item = self._resize_image_for_tokens(item, target_ratio=0.7)
 
-
-
+                # Re-validate
+                resized_processable, still_too_long = self._validate_and_split_batch(
+                    [resized_item], stage, processor, tokenizer, sampling_params, max_model_len
+                )
 
-
-
+                if resized_processable:
+                    processable_batch.extend(resized_processable)
+                    logger.info(f"Successfully resized {item.item_key} for processing")
+                else:
+                    # Try even smaller
+                    resized_item = self._resize_image_for_tokens(item, target_ratio=0.5)
+                    resized_processable, still_too_long = self._validate_and_split_batch(
+                        [resized_item], stage, processor, tokenizer, sampling_params, max_model_len
+                    )
+
+                    if resized_processable:
+                        processable_batch.extend(resized_processable)
+                        logger.info(f"Successfully resized {item.item_key} to 50% for processing")
+                    else:
+                        logger.error(f"Item {item.item_key} still too long after resize, skipping")
+                        self.items_failed += 1
+
+                        # Send error result
+                        stage_result = StageResult(
+                            stage_name=stage_name,
+                            output_field=stage.output_field,
+                            outputs=[],
+                            error="Image too large even after resizing",
+                        )
+                        item.stage_results[stage_name] = stage_result
+
+                        self.result_queue.put(
+                            {
+                                "item": item,
+                                "outputs": {},
+                                "processing_time_ms": 0.0,
+                                "error": f"Failed stage {stage_name}: token limit exceeded",
+                            }
+                        )
 
-
+            # Process the validated batch
+            if processable_batch:
+                # Build requests for processable items
+                requests = []
+                for item in processable_batch:
+                    converted_img = ImageProcessor.prepare_for_inference(item)
                     template_manager = PromptTemplateManager(stage.prompts)
 
                     # Build context
                     context = item.metadata.copy()
-
-                    # Add previous stage results
                     for prev_stage_name, stage_result in item.stage_results.items():
                         for i, output in enumerate(stage_result.outputs):
                             context[f"{prev_stage_name}_output_{i}"] = output
@@ -769,14 +1009,7 @@ class CaptionWorker(BaseWorker):
                 outputs = llm.generate(requests, sampling_params)
 
                 # Process outputs
-
-                failed_items = []
-
-                for idx, (original_idx, item, attempt_count) in enumerate(current_batch):
-                    if self.should_stop_processing.is_set():
-                        return results
-
-                    # Extract outputs
+                for idx, item in enumerate(processable_batch):
                     base_idx = idx * len(stage.prompts)
                     stage_outputs = []
 
@@ -788,40 +1021,18 @@ class CaptionWorker(BaseWorker):
                         stage_outputs.append(cleaned_output)
 
                     if stage_outputs:
-                        # Success
                         stage_result = StageResult(
                             stage_name=stage_name,
                             output_field=stage.output_field,
                             outputs=stage_outputs,
                         )
                         item.stage_results[stage_name] = stage_result
-                        successful_items.append((original_idx, item))
                     else:
-
-
-
-
-
-                        self.items_failed += 1
-                        stage_result = StageResult(
-                            stage_name=stage_name,
-                            output_field=stage.output_field,
-                            outputs=[],
-                            error=f"Failed after {max_attempts} attempts",
-                        )
-                        item.stage_results[stage_name] = stage_result
-                        self.result_queue.put(
-                            {
-                                "item": item,
-                                "outputs": {},
-                                "processing_time_ms": 0.0,
-                                "error": f"Failed stage {stage_name} after {max_attempts} attempts",
-                            }
-                        )
-
-            # Update for next iteration
-            items_to_process = failed_items
-            batch = [item for _, item in successful_items]
+                        logger.error(f"No outputs for {item.item_key} in stage {stage_name}")
+                        self.items_failed += 1
+
+            # Update batch for next stage
+            batch = processable_batch
 
         # Convert to results
         for item in batch:
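Taken together, the new _validate_and_split_batch and _resize_image_for_tokens helpers give _process_batch_multi_stage a validate, resize, retry policy before generation. A condensed, illustrative sketch of that control flow (not the worker's exact code; `validate` stands in for the token-length check against max_model_len minus the safety buffer):

    def fit_item(item, validate, resize):
        # Accept the item as-is if the rendered prompt already fits.
        if validate(item):
            return item
        # Otherwise try progressively smaller copies of the original image.
        for ratio in (0.7, 0.5):
            candidate = resize(item, target_ratio=ratio)
            if validate(candidate):
                return candidate
        return None  # still too long: reported to the orchestrator as a failed item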
caption_flow/workers/data.py
CHANGED
@@ -1,20 +1,18 @@
 """DataWorker for retrieving data from various sources and forwarding to orchestrator or storage."""
 
 import asyncio
+import io
 import json
 import logging
-import io
-import time
 from dataclasses import dataclass
 from pathlib import Path
-from
-from
-from
+from queue import Empty, Queue
+from threading import Event
+from typing import Any, AsyncIterator, Dict, Optional
 
+import boto3
 import pandas as pd
 import pyarrow.parquet as pq
-from PIL import Image
-import boto3
 from botocore.config import Config
 
 from .base import BaseWorker
@@ -179,7 +177,7 @@ class DataWorker(BaseWorker):
                 try:
                     self.send_queue.put_nowait(batch)
                     batch = []
-                except:
+                except Exception:
                     # Queue full, wait
                     await asyncio.sleep(1)
 
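The bare except was narrowed to except Exception, which still treats a full queue as a back-off signal but no longer swallows KeyboardInterrupt or SystemExit. A minimal sketch of the pattern with illustrative names:

    import asyncio

    async def flush_batch(batch: list, send_queue: asyncio.Queue) -> list:
        try:
            send_queue.put_nowait(batch)
            return []
        except Exception:  # e.g. asyncio.QueueFull; a bare except would also trap SystemExit
            # Queue is full: back off and let the caller retry on the next iteration.
            await asyncio.sleep(1)
            return batch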
{caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: caption-flow
-Version: 0.3.4
+Version: 0.4.1
 Summary: Self-contained distributed community captioning system
 Author-email: bghira <bghira@users.github.com>
 License: MIT
@@ -9,10 +9,9 @@ Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-Requires-Python: <3.13,>=3.
+Requires-Python: <3.13,>=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: websockets>=12.0
@@ -35,7 +34,9 @@ Requires-Dist: boto3<2.0.0,>=1.40.11
 Requires-Dist: torchdata<0.12.0,>=0.11.0
 Requires-Dist: textual<6.0.0,>=5.3.0
 Requires-Dist: urwid<4.0.0,>=3.0.2
-Requires-Dist: webshart<0.5.0,>=0.4.
+Requires-Dist: webshart<0.5.0,>=0.4.3
+Requires-Dist: pylance<0.36.0,>=0.35.0
+Requires-Dist: duckdb<2.0.0,>=1.3.2
 Provides-Extra: dev
 Requires-Dist: pytest>=7.4.0; extra == "dev"
 Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
@@ -47,6 +48,10 @@ Dynamic: license-file
 
 # CaptionFlow
 
+<!-- [![Tests](https://github.com/bghira/CaptionFlow/actions/workflows/tests.yml/badge.svg)](https://github.com/bghira/CaptionFlow/actions/workflows/tests.yml) -->
+[![Coverage](https://codecov.io/github/bghira/CaptionFlow/graph/badge.svg?token=HkVhkzzJxc)](https://codecov.io/github/bghira/CaptionFlow)
+[![PyPI version](https://badge.fury.io/py/caption-flow.svg)](https://badge.fury.io/py/caption-flow)
+
 scalable, fault-tolerant **vLLM-powered image captioning**.
 
 a fast websocket-based orchestrator paired with lightweight gpu workers achieves exceptional performance for batched requests through vLLM.
caption_flow-0.4.1.dist-info/RECORD
ADDED
@@ -0,0 +1,33 @@
+caption_flow/__init__.py,sha256=AanaoBXNzR2j3ow-uWQQXmYpv6sUXLfLrqACm55_BMY,303
+caption_flow/cli.py,sha256=q3M6ekz70huVGD7NBqsO5xZUqMYBhLqe0ZGo85Vb69g,56072
+caption_flow/models.py,sha256=6-IJj_B3HAarucoLo8_PncJRnxofHuLFCsyRnmUXgRk,7063
+caption_flow/monitor.py,sha256=j5RExadSLOUujVZQMe7btMeKNlq-WbZ9bYqfikgYJ8Q,7972
+caption_flow/orchestrator.py,sha256=MWQKaAclI9rMjn7mWdvoSzl9y4b7bU_24aVr8I1YGhE,39645
+caption_flow/viewer.py,sha256=40w2Zj7GaXbK-dgqvYYdFrMzSDE_ZPWNZc6kS0OrymQ,20281
+caption_flow/processors/__init__.py,sha256=l1udEZLxAmqwFYS4-3GsRVcPT6WxnDOIk0s0UqsZsJM,423
+caption_flow/processors/base.py,sha256=Zx6kRZSqG969x8kYJ5VY2Mo5mLeWEgBCEpo8D4GjsBM,6935
+caption_flow/processors/huggingface.py,sha256=i-DZRt5nTnPN8180Yf8FKBiYPUPmxfKMEZ68CUZECWk,61603
+caption_flow/processors/local_filesystem.py,sha256=auAWxnqplEH4YJ1DWZCaFmAd03iyhNLudgt71N8O7NE,27827
+caption_flow/processors/webdataset.py,sha256=66y_7KaJBBntJqBHYKLzCXkBi9ly-TfYYaTCp_7pqTo,34206
+caption_flow/storage/__init__.py,sha256=IVnzcSCPpPuyp-QLlgJirRZ9Sb3tR0F4sfuF5u2cNMk,36
+caption_flow/storage/exporter.py,sha256=6atbxWgxSu_5qg9l8amwgkXRL1SKTZQb2yryu62yPc8,22371
+caption_flow/storage/manager.py,sha256=2jkyNl-2_B2Z7NfjCBua-Jgo7Km_JmJqMKrYsYj5uF4,41416
+caption_flow/utils/__init__.py,sha256=ULJImkcFPc8QH2Zz6TW7AeVXMFdRpvfni2MgEo_PRyY,120
+caption_flow/utils/auth.py,sha256=6HRNnWfX1j1Jh55M23crfSA1olkFGg-9kZ5Booy5wCM,2253
+caption_flow/utils/caption_utils.py,sha256=7k6GnElIAqyyzDHQd3JC3Ffr7r57sFWqS3ET7itzdoM,5309
+caption_flow/utils/certificates.py,sha256=NiHSeeZYKrf5BpAkwg5qOe-1C7-z42jZO3pjQo0N3I8,4889
+caption_flow/utils/checkpoint_tracker.py,sha256=LoCGjb30QOcMESHLF5hKVCd8X8_gWACyyq9EKLTXIn4,4613
+caption_flow/utils/chunk_tracker.py,sha256=And1krrTvpfiwG7xRxh9n6xy-_W8MSWSkcGmFSDFnB8,25460
+caption_flow/utils/image_processor.py,sha256=_dmiKXcAKxjkQ6d9V5QgoZSf_dDOL52tFMOEXa3iA24,1581
+caption_flow/utils/json_utils.py,sha256=AaGcNTToUcVYCQj2TXs2D_hxc_LeEqFquiK4CquS0U8,5537
+caption_flow/utils/prompt_template.py,sha256=mq7FPnpjp8gVCMMh4NtRf0vL_B9LDMuBkbySvACRSZM,4401
+caption_flow/utils/vllm_config.py,sha256=xFOnmniQGkUGwfTabfW6R0V01TF-_rN1UYJy0HwOvUI,6026
+caption_flow/workers/base.py,sha256=Yh_PBsL3j1kXUuIOQHqIdR69Nepfq11je23i01iWSxw,7714
+caption_flow/workers/caption.py,sha256=qph-TVMUqObRQBgriXOJtCgkWOo3qBdTg883D1TuXlw,48994
+caption_flow/workers/data.py,sha256=iWnTM7UgpJeFzhSTly-gHzFu5sIYUGG-XO4yRNn_MQk,14775
+caption_flow-0.4.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+caption_flow-0.4.1.dist-info/METADATA,sha256=2mg45AYJVVZrgBzD611qFaWfNFId_3Xhl8xpwlFNrjg,10123
+caption_flow-0.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+caption_flow-0.4.1.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
+caption_flow-0.4.1.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
+caption_flow-0.4.1.dist-info/RECORD,,
|