caption-flow 0.2.4__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +1 -1
- caption_flow/orchestrator.py +9 -9
- caption_flow/processors/base.py +3 -0
- caption_flow/processors/huggingface.py +637 -464
- caption_flow/processors/local_filesystem.py +2 -0
- caption_flow/processors/webdataset.py +438 -538
- caption_flow/storage/manager.py +328 -305
- caption_flow/utils/__init__.py +0 -2
- caption_flow/utils/chunk_tracker.py +197 -164
- caption_flow/utils/image_processor.py +19 -132
- caption_flow/workers/caption.py +191 -138
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/METADATA +2 -1
- caption_flow-0.3.2.dist-info/RECORD +33 -0
- caption_flow/utils/dataset_loader.py +0 -222
- caption_flow/utils/dataset_metadata_cache.py +0 -67
- caption_flow/utils/job_queue.py +0 -41
- caption_flow/utils/shard_processor.py +0 -119
- caption_flow/utils/shard_tracker.py +0 -83
- caption_flow-0.2.4.dist-info/RECORD +0 -38
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/top_level.txt +0 -0
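The caption worker picks up three notable behaviors in this release: a mock mode that emits dummy captions without loading any vLLM models, a single batched processing thread replacing the old readahead/inference queue pair, and a configurable number of work units requested per round trip. A minimal config sketch of the new knobs follows; the key names (mock_results, batch_size, batch_image_processing, chunks_per_request) are taken from the diff below, while the nesting around them is an assumption for illustration, not a documented layout:

# Hypothetical sketch; key names come from the 0.3.2 diff, structure is assumed.
worker_config = {
    "gpu_id": 0,
    "batch_image_processing": True,  # defaults to True in 0.3.2
    "vllm": {
        "mock_results": True,  # skip model loading and emit dummy captions
        "batch_size": 8,       # items accumulated before each _process_batch() call
    },
}
# chunks_per_request arrives via the orchestrator's processor_config and sets
# how many work units the worker requests at a time (default 1).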
caption_flow/workers/caption.py
CHANGED
@@ -34,7 +34,7 @@ from ..utils.prompt_template import PromptTemplateManager
 from ..models import ProcessingStage, StageResult
 
 logger = logging.getLogger(__name__)
-
+logger.setLevel(logging.INFO)
 
 
 @dataclass
@@ -163,14 +163,14 @@ class CaptionWorker(BaseWorker):
     def __init__(self, config: Dict[str, Any]):
         super().__init__(config)
 
-        # Processor configuration
+        # Processor configuration
         self.processor_type = None
         self.processor: Optional[
             Union[
                 WebDatasetWorkerProcessor,
                 HuggingFaceDatasetWorkerProcessor,
                 LocalFilesystemWorkerProcessor,
-            ]
+            ]
         ] = None
         self.dataset_path: Optional[str] = None
 
@@ -181,12 +181,16 @@ class CaptionWorker(BaseWorker):
         self.vllm_config_manager = VLLMConfigManager()
         self.model_manager = None
 
+        # Mock mode flag
+        self.mock_mode = False
+
         # GPU selection
         self.gpu_id = config.get("gpu_id", 0)
         self.hf_token = get_token()
 
         # Image processor
-        batch_image_processing = config.get("batch_image_processing",
+        batch_image_processing = config.get("batch_image_processing", True)
+        logger.info(f"Using batch processing: {batch_image_processing}")
         self.image_processor = ImageProcessor() if batch_image_processing else None
 
         # Work processing
@@ -194,9 +198,7 @@ class CaptionWorker(BaseWorker):
         self.assigned_units = deque()
         self.current_unit: Optional[WorkUnit] = None
 
-        #
-        self.readahead_queue = Queue(maxsize=256)
-        self.inference_queue = Queue(maxsize=128)
+        # Single result queue for sending back to orchestrator
         self.result_queue = Queue()
 
         # Processing control
@@ -230,13 +232,18 @@ class CaptionWorker(BaseWorker):
                 logger.error(f"Failed to get config: {e}")
                 await asyncio.sleep(5)
 
-        #
-        if self.vllm_config
-
+        # Check for mock mode
+        self.mock_mode = self.vllm_config.get("mock_results", False) if self.vllm_config else False
+
+        if self.mock_mode:
+            logger.info("🎭 MOCK MODE ENABLED - No vLLM models will be loaded")
+        else:
+            # Initialize vLLM once we have config
+            if self.vllm_config:
+                self._setup_vllm()
 
-        # Start
-        Thread(target=self.
-        Thread(target=self._inference_thread, daemon=True).start()
+        # Start processing thread
+        Thread(target=self._processing_thread, daemon=True).start()
 
     async def _initial_connect_for_config(self):
         """Connect initially just to get configuration."""
@@ -269,8 +276,6 @@ class CaptionWorker(BaseWorker):
         self.assigned_units.clear()
         self.current_unit = None
 
-        self._clear_queue(self.readahead_queue)
-        self._clear_queue(self.inference_queue)
         self._clear_queue(self.result_queue)
 
         # Reset counters
@@ -297,6 +302,7 @@ class CaptionWorker(BaseWorker):
 
         self.processor.initialize(processor_config)
         self.dataset_path = self.processor.dataset_path
+        self.units_per_request = processor_config.config.get("chunks_per_request", 1)
 
         # Update vLLM config if provided
         new_vllm_config = welcome_data.get("processor_config", {}).get("vllm")
@@ -309,7 +315,7 @@ class CaptionWorker(BaseWorker):
 
         # Request initial work
         if self.websocket:
-            await self.websocket.send(json.dumps({"type": "
+            await self.websocket.send(json.dumps({"type": "get_work_units", "count": 2}))
 
     async def _handle_message(self, data: Dict[str, Any]):
         """Handle message from orchestrator."""
@@ -327,7 +333,7 @@ class CaptionWorker(BaseWorker):
             await asyncio.sleep(10)
 
             if self.websocket and self.connected.is_set():
-                await self.websocket.send(json.dumps({"type": "
+                await self.websocket.send(json.dumps({"type": "get_work_units", "count": 2}))
 
     def _parse_stages_config(self, vllm_config: Dict[str, Any]) -> List[ProcessingStage]:
         """Parse stages configuration from vLLM config."""
@@ -405,6 +411,7 @@ class CaptionWorker(BaseWorker):
         self.model_manager = MultiStageVLLMManager(self.gpu_id)
 
         # Get base config for models
+        logger.info(f"vLLM config: {self.vllm_config}")
        base_config = {
            "tensor_parallel_size": self.vllm_config.get("tensor_parallel_size", 1),
            "max_model_len": self.vllm_config.get("max_model_len", 16384),
@@ -437,6 +444,13 @@ class CaptionWorker(BaseWorker):
         if not new_config:
             return True
 
+        # Check if mock mode changed
+        old_mock_mode = self.mock_mode
+        self.mock_mode = new_config.get("mock_results", False)
+
+        if old_mock_mode != self.mock_mode:
+            logger.info(f"Mock mode changed from {old_mock_mode} to {self.mock_mode}")
+
         # Parse new stages
         new_stages = self._parse_stages_config(new_config)
 
@@ -453,35 +467,43 @@ class CaptionWorker(BaseWorker):
                 stages_changed = True
                 break
 
-        if stages_changed:
-            logger.info("
+        if stages_changed or old_mock_mode != self.mock_mode:
+            logger.info("Configuration changed significantly")
 
             old_config = self.vllm_config
             self.vllm_config = new_config
             self.stages = new_stages
             self.stage_order = self._topological_sort_stages(self.stages)
 
-
+            if not self.mock_mode:
+                try:
+                    if self.model_manager:
+                        self.model_manager.cleanup()
+                    self._setup_vllm()
+                    return True
+                except Exception as e:
+                    logger.error(f"Failed to reload vLLM: {e}")
+                    self.vllm_config = old_config
+                    return False
+            else:
+                # Clean up models if switching to mock mode
                 if self.model_manager:
                     self.model_manager.cleanup()
-
+                    self.model_manager = None
                 return True
-            except Exception as e:
-                logger.error(f"Failed to reload vLLM: {e}")
-                self.vllm_config = old_config
-                return False
         else:
             # Just update sampling params
-
-
-
-            self.
+            if not self.mock_mode:
+                logger.info("Updating sampling parameters without model reload")
+                base_sampling = new_config.get("sampling", {})
+                for stage in self.stages:
+                    self.model_manager.create_sampling_params(stage, base_sampling)
             self.vllm_config = new_config
             return True
 
-    def
-    """
-    logger.info("Starting
+    def _processing_thread(self):
+        """Main processing thread that handles work units."""
+        logger.info("Starting processing thread")
 
         while self.running:
             if self.should_stop_processing.is_set():
@@ -513,11 +535,13 @@ class CaptionWorker(BaseWorker):
             with self.work_lock:
                 queue_size = len(self.assigned_units)
 
-            if queue_size <
+            if queue_size < self.units_per_request and self.websocket and self.main_loop:
                 try:
                     asyncio.run_coroutine_threadsafe(
                         self.websocket.send(
-                            json.dumps(
+                            json.dumps(
+                                {"type": "get_work_units", "count": self.units_per_request}
+                            )
                         ),
                         self.main_loop,
                     ).result(timeout=5)
@@ -533,16 +557,21 @@ class CaptionWorker(BaseWorker):
             self.current_unit = None
 
     def _process_work_unit(self, unit: WorkUnit):
-        """Process a single work unit."""
+        """Process a single work unit with batching."""
         if not self.processor:
             logger.error("Processor not initialized")
             return
 
-
-
-
-
+        batch = []
+        batch_size = self.vllm_config.get("batch_size", 8)
+        context = {}
+        self.items_processed = 0
+        self.items_failed = 0
+        # Collect items for batching
         for item_data in self.processor.process_unit(unit, context):
+            if self.should_stop_processing.is_set() or not self.connected.is_set():
+                break
+
             try:
                 # Create processing item
                 item = ProcessingItem(
@@ -551,35 +580,19 @@ class CaptionWorker(BaseWorker):
                     job_id=item_data["job_id"],
                     item_key=item_data["item_key"],
                     item_index=item_data["item_index"],
-                    image=item_data
+                    image=item_data.get("image", None),
                     image_data=item_data.get("image_data", b""),
                     metadata=item_data.get("metadata", {}),
                 )
+                if "_processed_indices" in item_data:
+                    context["_processed_indices"] = item_data.pop("_processed_indices", [])
 
-
-                timeout_end = time.time() + 30
-                while (
-                    self.running
-                    and not self.should_stop_processing.is_set()
-                    and self.connected.is_set()
-                ):
-                    try:
-                        self.readahead_queue.put(item, timeout=1)
-                        break
-                    except:
-                        if time.time() > timeout_end:
-                            raise TimeoutError("Queue put timeout")
-                        continue
-
-                if not self.connected.is_set() or self.should_stop_processing.is_set():
-                    break
-
-                items_processed += 1
+                batch.append(item)
 
-                #
-
-
-
+                # Process batch when it reaches size
+                if len(batch) >= batch_size:
+                    self._process_batch(batch)
+                    batch = []
 
             except Exception as e:
                 if self.should_stop_processing.is_set():
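The two hunks above replace 0.2.4's readahead/inference queues with a plain accumulate-and-flush loop: items append to a list, and the list is flushed through _process_batch whenever it reaches batch_size, with the tail flushed after the loop (as the next hunk shows). A minimal standalone sketch of that pattern, with generic names rather than the worker's own:

from typing import Callable, Iterable, List

def batched(items: Iterable, batch_size: int, flush: Callable[[List], None]) -> None:
    """Accumulate items and flush fixed-size batches, then flush the remainder."""
    batch: List = []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            flush(batch)  # full batch
            batch = []
    if batch:
        flush(batch)  # partial tail batch

# batched(range(20), 8, print) flushes batches of sizes 8, 8, and 4.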
@@ -587,77 +600,112 @@ class CaptionWorker(BaseWorker):
                 logger.error(f"Error processing item {item_data.get('item_key')}: {e}")
                 self.items_failed += 1
 
-        # Process
-        if not self.should_stop_processing.is_set():
-            self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Process remaining items in batch
+        if batch and not self.should_stop_processing.is_set():
+            self._process_batch(batch)
+
+        # Notify orchestrator that unit is complete
+        # Check if the number of processed items matches the expected count for the unit.
+        # The context dictionary holds the count of items yielded by the processor.
+        total_items_in_unit = unit.unit_size
+
+        if (
+            not self.should_stop_processing.is_set()
+            and self.connected.is_set()
+            and self.items_failed == 0
+            and self.items_processed >= total_items_in_unit
+        ):
+            if self.websocket:
+                try:
+                    asyncio.run_coroutine_threadsafe(
+                        self.websocket.send(
+                            json.dumps({"type": "work_complete", "unit_id": unit.unit_id})
+                        ),
+                        self.main_loop,
+                    ).result(timeout=5)
+                    logger.info(
+                        f"Unit {unit.unit_id} fully processed ({self.items_processed}/{total_items_in_unit}) and marked complete."
+                    )
+                except Exception as e:
+                    logger.warning(f"Could not notify work complete for unit {unit.unit_id}: {e}")
+        else:
+            logger.warning(
+                f"Processing of unit {unit.unit_id} was incomplete ({self.items_processed}/{total_items_in_unit}). Not marking as complete."
+            )
+
+    def _process_batch(self, batch: List[ProcessingItem]):
+        """Process a batch of items through all stages."""
+        if not batch:
+            return
+
+        logger.debug(f"Processing batch of {len(batch)} images")
+        start_time = time.time()
 
         try:
-
-
-
-
-
+            # Process batch through all stages
+            if self.mock_mode:
+                results = self._process_batch_mock(batch)
+            else:
+                results = self._process_batch_multi_stage(batch)
 
-
-
+            # Calculate processing time
+            if results:
+                processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
+
+                for item, result_outputs in results:
+                    self.result_queue.put(
+                        {
+                            "item": item,
+                            "outputs": result_outputs,
+                            "processing_time_ms": processing_time_per_item,
+                        }
+                    )
 
-
-        """Background thread for multi-stage vLLM inference."""
-        logger.info("Starting multi-stage inference thread")
+                logger.debug(f"Batch processing complete: {len(results)} successful")
 
-
-
-                batch = self.inference_queue.get(timeout=1)
-                if not batch:
-                    continue
+        except Exception as e:
+            logger.error(f"Batch processing error: {e}", exc_info=True)
 
-
-
+    def _process_batch_mock(self, batch: List[ProcessingItem]) -> List[Tuple[ProcessingItem, Dict]]:
+        """Process a batch in mock mode - return dummy captions."""
+        results = []
 
-
-
-            )
-            start_time = time.time()
+        # Simulate some processing time
+        time.sleep(0.1)
 
-
-
+        for item in batch:
+            # Generate mock outputs for each stage
+            for stage_name in self.stage_order:
+                stage = next(s for s in self.stages if s.name == stage_name)
+
+                # Create mock outputs based on stage prompts
+                stage_outputs = []
+                for i, prompt in enumerate(stage.prompts):
+                    mock_output = (
+                        f"Mock {stage_name} output {i+1} for job {item.job_id} - {item.item_key}"
+                    )
+                    stage_outputs.append(mock_output)
 
-        #
-
-
+                # Store stage result
+                stage_result = StageResult(
+                    stage_name=stage_name,
+                    output_field=stage.output_field,
+                    outputs=stage_outputs,
+                )
+                item.stage_results[stage_name] = stage_result
 
-
-
-
-
-
-
-
-            )
+            # Aggregate outputs by field
+            outputs_by_field = defaultdict(list)
+            for stage_result in item.stage_results.values():
+                outputs_by_field[stage_result.output_field].extend(stage_result.outputs)
+
+            results.append((item, dict(outputs_by_field)))
+            self.items_processed += 1
 
-
+            if self.items_processed % 10 == 0:
+                logger.info(f"🎭 Mock mode: Processed {self.items_processed} items")
 
-
-            continue
-        except Exception as e:
-            if self.should_stop_processing.is_set():
-                continue
-            logger.error(f"Inference error: {e}", exc_info=True)
+        return results
 
     def _process_batch_multi_stage(
         self, batch: List[ProcessingItem], max_attempts: int = 3
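Taken together with the earlier hunks, the worker-to-orchestrator WebSocket protocol visible in this diff is plain JSON: workers pull work with get_work_units (count now driven by chunks_per_request) and report finished units with work_complete. A sketch of just those two message shapes; the unit_id value is illustrative, and any message types not shown in this diff are omitted:

import json

# Worker -> orchestrator: ask for more work units.
request = json.dumps({"type": "get_work_units", "count": 2})

# Worker -> orchestrator: mark a fully processed unit complete.
complete = json.dumps({"type": "work_complete", "unit_id": "shard-00042:0"})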
@@ -685,8 +733,8 @@ class CaptionWorker(BaseWorker):
            for idx, (original_idx, item, attempt_count) in enumerate(items_to_process):
                current_batch.append((original_idx, item, attempt_count))
 
-                # Prepare image
-                converted_img = ImageProcessor.prepare_for_inference(item
+                # Prepare image from PIL frame or bytes
+                converted_img = ImageProcessor.prepare_for_inference(item)
 
                # Create template manager
                template_manager = PromptTemplateManager(stage.prompts)
@@ -832,12 +880,11 @@ class CaptionWorker(BaseWorker):
             "units_completed": self.units_completed,
             "current_unit": self._get_current_unit_id() if self.current_unit else None,
             "queue_sizes": {
-                "readahead": self.readahead_queue.qsize(),
-                "inference": self.inference_queue.qsize(),
                 "results": self.result_queue.qsize(),
             },
             "stages": len(self.stages),
             "models_loaded": len(self.model_manager.models) if self.model_manager else 0,
+            "mock_mode": self.mock_mode,
         }
 
     async def _create_tasks(self) -> list:
@@ -852,7 +899,7 @@ class CaptionWorker(BaseWorker):
         """Send results back to orchestrator."""
         while self.running and self.connected.is_set():
             try:
-                # Get result
+                # Get result with timeout
                 result_data = await asyncio.get_event_loop().run_in_executor(
                     None, self.result_queue.get, True, 1
                 )
@@ -863,7 +910,6 @@ class CaptionWorker(BaseWorker):
                 outputs = result_data["outputs"]
 
                 # Create work result
-                # logger.info(f"Processed item: {item}")
                 work_result = WorkResult(
                     unit_id=item.unit_id,
                     source_id=item.metadata.get("shard_name", "unknown"),
@@ -873,9 +919,21 @@ class CaptionWorker(BaseWorker):
                    metadata={
                        "item_key": item.item_key,
                        "item_index": item.metadata.get("_item_index"),
-                        "image_width":
-
-
+                        "image_width": (
+                            item.image.width
+                            if item.image is not None
+                            else item.metadata.get("image_width")
+                        ),
+                        "image_height": (
+                            item.image.height
+                            if item.image is not None
+                            else item.metadata.get("image_height")
+                        ),
+                        "image_format": (
+                            item.image.format
+                            if item.image is not None
+                            else item.metadata.get("image_format", "unknown")
+                        ),
                        "file_size": len(item.image_data) if item.image_data else 0,
                        **item.metadata,
                    },
@@ -883,7 +941,7 @@ class CaptionWorker(BaseWorker):
                    error=result_data.get("error", None),
                )
 
-                # Send result
+                # Send result
                await self.websocket.send(
                    json.dumps(
                        {
@@ -920,9 +978,7 @@ class CaptionWorker(BaseWorker):
         self.assigned_units.clear()
         self.current_unit = None
 
-        # Clear
-        self._clear_queue(self.readahead_queue)
-        self._clear_queue(self.inference_queue)
+        # Clear result queue
         self._clear_queue(self.result_queue)
 
     def _clear_queue(self, queue: Queue):
@@ -935,9 +991,6 @@ class CaptionWorker(BaseWorker):
 
     async def _pre_shutdown(self):
         """Cleanup before shutdown."""
-        self.readahead_queue.put(None)
-        self.inference_queue.put(None)
-
         if self.image_processor:
             self.image_processor.shutdown()
 
{caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: caption-flow
-Version: 0.2.4
+Version: 0.3.2
 Summary: Self-contained distributed community captioning system
 Author-email: bghira <bghira@users.github.com>
 License: MIT
@@ -35,6 +35,7 @@ Requires-Dist: boto3<2.0.0,>=1.40.11
 Requires-Dist: torchdata<0.12.0,>=0.11.0
 Requires-Dist: textual<6.0.0,>=5.3.0
 Requires-Dist: urwid<4.0.0,>=3.0.2
+Requires-Dist: webshart<0.5.0,>=0.4.0
 Provides-Extra: dev
 Requires-Dist: pytest>=7.4.0; extra == "dev"
 Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
caption_flow-0.3.2.dist-info/RECORD
ADDED
@@ -0,0 +1,33 @@
+caption_flow/__init__.py,sha256=09Vyr0RqKrKe1caUhXq9beficJkmclryjT6BNiASUxQ,303
+caption_flow/cli.py,sha256=t_cYCxJE7f5UtB3br2Es51JjO5KPsWM1JTdDXAxM_Lw,41371
+caption_flow/models.py,sha256=2n6iphTEL62xK2FFcJM6axMsaE8KwsUv5Ak_cCF-TdQ,5652
+caption_flow/monitor.py,sha256=bAt9EJqfPgT_KdbknGdCxwBRH002pRDgyUmYIj6Dyso,7885
+caption_flow/orchestrator.py,sha256=34gZvaW14YZ7a7LagYOO3VKKwlbuS4aw0yoP1L8gwf0,36192
+caption_flow/viewer.py,sha256=HxO98eHR1xtivG0dEdYC2U9T_RgeRfJqqTK-37u9bNM,20471
+caption_flow/processors/__init__.py,sha256=hvq-OuAJWQe6hFglKe7QmkS8473k20FmxZDSxfXpCrg,423
+caption_flow/processors/base.py,sha256=IAEr0pqHRuSkXunvDWk1vf2IKeYQ-2YERqej9iSQm94,6931
+caption_flow/processors/huggingface.py,sha256=w0j7PRosXYyJXZ0A0Y-J6_n-aHCGVW8tbt8lcvguO_Y,41237
+caption_flow/processors/local_filesystem.py,sha256=OuNNDemy0sdtpBBC_5GbI-c1vMqp8OIz983Cq85gdb8,27964
+caption_flow/processors/webdataset.py,sha256=TkC6xZO6m2FcwiBQGJsSQcrshBKcLdr4edFVtnBOd3U,28999
+caption_flow/storage/__init__.py,sha256=IVnzcSCPpPuyp-QLlgJirRZ9Sb3tR0F4sfuF5u2cNMk,36
+caption_flow/storage/exporter.py,sha256=mFJqMDQ61cP-qcXe118_-oL1TUqULdQZ8LdjSTym44I,19697
+caption_flow/storage/manager.py,sha256=KPExcKPuFVQSsBnfCBdne5PO4PwN4NTfd-EJQk13OY0,47459
+caption_flow/utils/__init__.py,sha256=bDcO5uR455TKCQ2hX-_XcdTnRXDBaT8Yn4jWqWzfFsE,120
+caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
+caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
+caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
+caption_flow/utils/checkpoint_tracker.py,sha256=-nN5gLvXyMdKOCT2SNNL2Km6UYm2Hii9wuXeezWhwx4,3339
+caption_flow/utils/chunk_tracker.py,sha256=HntWeINTbJmIERsW21p4q4FK8D9-4xKbZQUsj24DIqo,19975
+caption_flow/utils/image_processor.py,sha256=wmOExkVfM7OeuLfX3AwMefsH-TxL8TNcn22gp0NmJKY,1541
+caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
+caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
+caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
+caption_flow/workers/base.py,sha256=2AGWERC5hbmO-0V_A1MUbgRVvRNN3blqGPyDokvvzmM,7575
+caption_flow/workers/caption.py,sha256=X4BEmb6C1c73hvgJDMsHtgCUlCuECtnloWSVolVpa4s,39353
+caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
+caption_flow-0.3.2.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+caption_flow-0.3.2.dist-info/METADATA,sha256=8bHECzNi4R6_FlbHWSHMx9TDo4uTVKWWgVbqAe5cCIs,9708
+caption_flow-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+caption_flow-0.3.2.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
+caption_flow-0.3.2.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
+caption_flow-0.3.2.dist-info/RECORD,,