caption-flow 0.2.3-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +1 -1
- caption_flow/cli.py +307 -0
- caption_flow/models.py +26 -0
- caption_flow/orchestrator.py +9 -9
- caption_flow/processors/huggingface.py +636 -464
- caption_flow/processors/webdataset.py +379 -534
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +410 -303
- caption_flow/utils/__init__.py +0 -2
- caption_flow/utils/chunk_tracker.py +196 -164
- caption_flow/utils/image_processor.py +19 -132
- caption_flow/viewer.py +594 -0
- caption_flow/workers/caption.py +164 -129
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/METADATA +45 -177
- caption_flow-0.3.1.dist-info/RECORD +33 -0
- caption_flow/utils/dataset_loader.py +0 -222
- caption_flow/utils/dataset_metadata_cache.py +0 -67
- caption_flow/utils/job_queue.py +0 -41
- caption_flow/utils/shard_processor.py +0 -119
- caption_flow/utils/shard_tracker.py +0 -83
- caption_flow-0.2.3.dist-info/RECORD +0 -35
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/top_level.txt +0 -0
caption_flow/workers/caption.py
CHANGED

@@ -34,7 +34,7 @@ from ..utils.prompt_template import PromptTemplateManager
 from ..models import ProcessingStage, StageResult
 
 logger = logging.getLogger(__name__)
-
+logger.setLevel(logging.INFO)
 
 
 @dataclass
@@ -163,14 +163,14 @@ class CaptionWorker(BaseWorker):
     def __init__(self, config: Dict[str, Any]):
         super().__init__(config)
 
-        # Processor configuration
+        # Processor configuration
         self.processor_type = None
         self.processor: Optional[
             Union[
                 WebDatasetWorkerProcessor,
                 HuggingFaceDatasetWorkerProcessor,
                 LocalFilesystemWorkerProcessor,
-            ]
+            ]
         ] = None
         self.dataset_path: Optional[str] = None
 
@@ -181,12 +181,16 @@ class CaptionWorker(BaseWorker):
         self.vllm_config_manager = VLLMConfigManager()
         self.model_manager = None
 
+        # Mock mode flag
+        self.mock_mode = False
+
         # GPU selection
         self.gpu_id = config.get("gpu_id", 0)
         self.hf_token = get_token()
 
         # Image processor
-        batch_image_processing = config.get("batch_image_processing",
+        batch_image_processing = config.get("batch_image_processing", True)
+        logger.info(f"Using batch processing: {batch_image_processing}")
         self.image_processor = ImageProcessor() if batch_image_processing else None
 
         # Work processing
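The constructor now defaults `batch_image_processing` to `True` and logs the choice. A minimal sketch of the two worker-config keys this hunk reads; everything else a real worker config needs is omitted, and the surrounding structure is an assumption:

```python
# Partial worker config sketch: only the keys read in this hunk are shown.
worker_config = {
    "gpu_id": 0,                     # read via config.get("gpu_id", 0)
    "batch_image_processing": True,  # now defaults to True; False skips creating an ImageProcessor
}
```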
@@ -194,9 +198,7 @@ class CaptionWorker(BaseWorker):
         self.assigned_units = deque()
         self.current_unit: Optional[WorkUnit] = None
 
-        #
-        self.readahead_queue = Queue(maxsize=256)
-        self.inference_queue = Queue(maxsize=128)
+        # Single result queue for sending back to orchestrator
         self.result_queue = Queue()
 
         # Processing control
@@ -230,13 +232,18 @@ class CaptionWorker(BaseWorker):
                 logger.error(f"Failed to get config: {e}")
                 await asyncio.sleep(5)
 
-        #
-        if self.vllm_config
-            self._setup_vllm()
+        # Check for mock mode
+        self.mock_mode = self.vllm_config.get("mock_results", False) if self.vllm_config else False
 
-
-
-
+        if self.mock_mode:
+            logger.info("🎭 MOCK MODE ENABLED - No vLLM models will be loaded")
+        else:
+            # Initialize vLLM once we have config
+            if self.vllm_config:
+                self._setup_vllm()
+
+        # Start processing thread
+        Thread(target=self._processing_thread, daemon=True).start()
 
     async def _initial_connect_for_config(self):
         """Connect initially just to get configuration."""
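In mock mode the worker skips `_setup_vllm()` entirely and still starts the processing thread. A minimal sketch of the `vllm` config section as this file reads it, with keys taken from the `.get()` calls in the diff; how your orchestrator nests this section is an assumption:

```python
# Sketch only: keys and defaults mirror the .get() calls in caption.py.
vllm_config = {
    "mock_results": True,      # new flag: skip model loading and emit dummy captions
    "batch_size": 8,           # items per _process_batch call
    "tensor_parallel_size": 1,
    "max_model_len": 16384,
    "sampling": {},            # base sampling params shared by stages
}
```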
@@ -269,8 +276,6 @@ class CaptionWorker(BaseWorker):
         self.assigned_units.clear()
         self.current_unit = None
 
-        self._clear_queue(self.readahead_queue)
-        self._clear_queue(self.inference_queue)
         self._clear_queue(self.result_queue)
 
         # Reset counters
@@ -297,6 +302,7 @@ class CaptionWorker(BaseWorker):
 
         self.processor.initialize(processor_config)
         self.dataset_path = self.processor.dataset_path
+        self.units_per_request = processor_config.config.get("chunks_per_request", 1)
 
         # Update vLLM config if provided
         new_vllm_config = welcome_data.get("processor_config", {}).get("vllm")
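The welcome handshake now sets `units_per_request` from the processor config's `chunks_per_request` key (default 1), and that value later becomes the `count` in work requests. A small illustration, with `processor_config` modeled here as a plain dict rather than the config object the code uses, and the value 4 invented:

```python
# Hypothetical values; only the key name and the message shape come from the diff.
processor_config = {"chunks_per_request": 4}

units_per_request = processor_config.get("chunks_per_request", 1)
request = {"type": "get_work_units", "count": units_per_request}
print(request)  # {'type': 'get_work_units', 'count': 4}
```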
@@ -309,7 +315,7 @@
 
         # Request initial work
         if self.websocket:
-            await self.websocket.send(json.dumps({"type": "
+            await self.websocket.send(json.dumps({"type": "get_work_units", "count": 2}))
 
     async def _handle_message(self, data: Dict[str, Any]):
         """Handle message from orchestrator."""
@@ -327,7 +333,7 @@
             await asyncio.sleep(10)
 
         if self.websocket and self.connected.is_set():
-            await self.websocket.send(json.dumps({"type": "
+            await self.websocket.send(json.dumps({"type": "get_work_units", "count": 2}))
 
     def _parse_stages_config(self, vllm_config: Dict[str, Any]) -> List[ProcessingStage]:
         """Parse stages configuration from vLLM config."""
@@ -405,6 +411,7 @@
         self.model_manager = MultiStageVLLMManager(self.gpu_id)
 
         # Get base config for models
+        logger.info(f"vLLM config: {self.vllm_config}")
         base_config = {
             "tensor_parallel_size": self.vllm_config.get("tensor_parallel_size", 1),
             "max_model_len": self.vllm_config.get("max_model_len", 16384),
@@ -437,6 +444,13 @@
         if not new_config:
             return True
 
+        # Check if mock mode changed
+        old_mock_mode = self.mock_mode
+        self.mock_mode = new_config.get("mock_results", False)
+
+        if old_mock_mode != self.mock_mode:
+            logger.info(f"Mock mode changed from {old_mock_mode} to {self.mock_mode}")
+
         # Parse new stages
         new_stages = self._parse_stages_config(new_config)
 
@@ -453,35 +467,43 @@
                 stages_changed = True
                 break
 
-        if stages_changed:
-            logger.info("
+        if stages_changed or old_mock_mode != self.mock_mode:
+            logger.info("Configuration changed significantly")
 
             old_config = self.vllm_config
             self.vllm_config = new_config
             self.stages = new_stages
             self.stage_order = self._topological_sort_stages(self.stages)
 
-
+            if not self.mock_mode:
+                try:
+                    if self.model_manager:
+                        self.model_manager.cleanup()
+                    self._setup_vllm()
+                    return True
+                except Exception as e:
+                    logger.error(f"Failed to reload vLLM: {e}")
+                    self.vllm_config = old_config
+                    return False
+            else:
+                # Clean up models if switching to mock mode
                 if self.model_manager:
                     self.model_manager.cleanup()
-
+                    self.model_manager = None
                 return True
-            except Exception as e:
-                logger.error(f"Failed to reload vLLM: {e}")
-                self.vllm_config = old_config
-                return False
         else:
             # Just update sampling params
-
-
-
-            self.
+            if not self.mock_mode:
+                logger.info("Updating sampling parameters without model reload")
+                base_sampling = new_config.get("sampling", {})
+                for stage in self.stages:
+                    self.model_manager.create_sampling_params(stage, base_sampling)
             self.vllm_config = new_config
             return True
 
-    def
-        """
-        logger.info("Starting
+    def _processing_thread(self):
+        """Main processing thread that handles work units."""
+        logger.info("Starting processing thread")
 
         while self.running:
             if self.should_stop_processing.is_set():
@@ -513,11 +535,13 @@
             with self.work_lock:
                 queue_size = len(self.assigned_units)
 
-            if queue_size <
+            if queue_size < self.units_per_request and self.websocket and self.main_loop:
                 try:
                     asyncio.run_coroutine_threadsafe(
                         self.websocket.send(
-                            json.dumps(
+                            json.dumps(
+                                {"type": "get_work_units", "count": self.units_per_request}
+                            )
                         ),
                         self.main_loop,
                     ).result(timeout=5)
@@ -533,16 +557,20 @@
            self.current_unit = None
 
     def _process_work_unit(self, unit: WorkUnit):
-        """Process a single work unit."""
+        """Process a single work unit with batching."""
        if not self.processor:
            logger.error("Processor not initialized")
            return
 
-
-
+        batch = []
+        batch_size = self.vllm_config.get("batch_size", 8)
+        context = {}
 
-        #
+        # Collect items for batching
         for item_data in self.processor.process_unit(unit, context):
+            if self.should_stop_processing.is_set() or not self.connected.is_set():
+                break
+
             try:
                 # Create processing item
                 item = ProcessingItem(
@@ -551,35 +579,19 @@
                     job_id=item_data["job_id"],
                     item_key=item_data["item_key"],
                     item_index=item_data["item_index"],
-                    image=item_data
+                    image=item_data.get("image", None),
                     image_data=item_data.get("image_data", b""),
                     metadata=item_data.get("metadata", {}),
                 )
+                if "_processed_indices" in item_data:
+                    context["_processed_indices"] = item_data.pop("_processed_indices", [])
 
-
-                timeout_end = time.time() + 30
-                while (
-                    self.running
-                    and not self.should_stop_processing.is_set()
-                    and self.connected.is_set()
-                ):
-                    try:
-                        self.readahead_queue.put(item, timeout=1)
-                        break
-                    except:
-                        if time.time() > timeout_end:
-                            raise TimeoutError("Queue put timeout")
-                        continue
-
-                if not self.connected.is_set() or self.should_stop_processing.is_set():
-                    break
-
-                items_processed += 1
+                batch.append(item)
 
-                #
-
-
-
+                # Process batch when it reaches size
+                if len(batch) >= batch_size:
+                    self._process_batch(batch)
+                    batch = []
 
             except Exception as e:
                 if self.should_stop_processing.is_set():
@@ -587,77 +599,95 @@
                 logger.error(f"Error processing item {item_data.get('item_key')}: {e}")
                 self.items_failed += 1
 
-        # Process
-        if not self.should_stop_processing.is_set():
-            self.
-
-
+        # Process remaining items in batch
+        if batch and not self.should_stop_processing.is_set():
+            self._process_batch(batch)
+
+        # Notify orchestrator that unit is complete
+        if self.connected.is_set() and self.websocket:
+            try:
                 asyncio.run_coroutine_threadsafe(
                     self.websocket.send(
                         json.dumps({"type": "work_complete", "unit_id": unit.unit_id})
                     ),
                     self.main_loop,
                 ).result(timeout=5)
+            except Exception as e:
+                logger.warning(f"Could not notify work complete: {e}")
 
-
+    def _process_batch(self, batch: List[ProcessingItem]):
+        """Process a batch of items through all stages."""
+        if not batch:
+            return
 
-
-
-        batch = []
-        batch_size = self.vllm_config.get("batch_size", 8)
+        logger.debug(f"Processing batch of {len(batch)} images")
+        start_time = time.time()
 
         try:
-
-
-
-
-
+            # Process batch through all stages
+            if self.mock_mode:
+                results = self._process_batch_mock(batch)
+            else:
+                results = self._process_batch_multi_stage(batch)
 
-
-
+            # Calculate processing time
+            if results:
+                processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
+
+                for item, result_outputs in results:
+                    self.result_queue.put(
+                        {
+                            "item": item,
+                            "outputs": result_outputs,
+                            "processing_time_ms": processing_time_per_item,
+                        }
+                    )
 
-
-        """Background thread for multi-stage vLLM inference."""
-        logger.info("Starting multi-stage inference thread")
+                logger.debug(f"Batch processing complete: {len(results)} successful")
 
-
-
-                batch = self.inference_queue.get(timeout=1)
-                if not batch:
-                    continue
+        except Exception as e:
+            logger.error(f"Batch processing error: {e}", exc_info=True)
 
-
-
+    def _process_batch_mock(self, batch: List[ProcessingItem]) -> List[Tuple[ProcessingItem, Dict]]:
+        """Process a batch in mock mode - return dummy captions."""
+        results = []
 
-
-
-                )
-                start_time = time.time()
+        # Simulate some processing time
+        time.sleep(0.1)
 
-
-
+        for item in batch:
+            # Generate mock outputs for each stage
+            for stage_name in self.stage_order:
+                stage = next(s for s in self.stages if s.name == stage_name)
+
+                # Create mock outputs based on stage prompts
+                stage_outputs = []
+                for i, prompt in enumerate(stage.prompts):
+                    mock_output = (
+                        f"Mock {stage_name} output {i+1} for job {item.job_id} - {item.item_key}"
+                    )
+                    stage_outputs.append(mock_output)
 
-                #
-
-
+                # Store stage result
+                stage_result = StageResult(
+                    stage_name=stage_name,
+                    output_field=stage.output_field,
+                    outputs=stage_outputs,
+                )
+                item.stage_results[stage_name] = stage_result
 
-
-
-
-
-                        "outputs": result_outputs,
-                        "processing_time_ms": processing_time_per_item,
-                    }
-                )
+            # Aggregate outputs by field
+            outputs_by_field = defaultdict(list)
+            for stage_result in item.stage_results.values():
+                outputs_by_field[stage_result.output_field].extend(stage_result.outputs)
 
-
+            results.append((item, dict(outputs_by_field)))
+            self.items_processed += 1
 
-
-
-
-
-                    continue
-                logger.error(f"Inference error: {e}", exc_info=True)
+            if self.items_processed % 10 == 0:
+                logger.info(f"🎭 Mock mode: Processed {self.items_processed} items")
+
+        return results
 
     def _process_batch_multi_stage(
         self, batch: List[ProcessingItem], max_attempts: int = 3
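The mock path builds a `StageResult` per stage and then flattens the outputs by `output_field` before queueing. A small worked example of that aggregation step, with plain dicts standing in for `StageResult` and the stage names and `captions` field invented for illustration:

```python
from collections import defaultdict

# Plain dicts stand in for StageResult; stage names and the "captions" field are assumptions.
stage_results = {
    "caption_short": {"output_field": "captions", "outputs": ["Mock caption_short output 1 for job j1 - key1"]},
    "caption_long": {"output_field": "captions", "outputs": ["Mock caption_long output 1 for job j1 - key1"]},
}

outputs_by_field = defaultdict(list)
for result in stage_results.values():
    outputs_by_field[result["output_field"]].extend(result["outputs"])

print(dict(outputs_by_field))
# {'captions': ['Mock caption_short output 1 for job j1 - key1',
#               'Mock caption_long output 1 for job j1 - key1']}
```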
@@ -685,8 +715,8 @@
         for idx, (original_idx, item, attempt_count) in enumerate(items_to_process):
             current_batch.append((original_idx, item, attempt_count))
 
-            # Prepare image
-            converted_img = ImageProcessor.prepare_for_inference(item
+            # Prepare image from PIL frame or bytes
+            converted_img = ImageProcessor.prepare_for_inference(item)
 
             # Create template manager
             template_manager = PromptTemplateManager(stage.prompts)
@@ -832,12 +862,11 @@
             "units_completed": self.units_completed,
             "current_unit": self._get_current_unit_id() if self.current_unit else None,
             "queue_sizes": {
-                "readahead": self.readahead_queue.qsize(),
-                "inference": self.inference_queue.qsize(),
                 "results": self.result_queue.qsize(),
             },
             "stages": len(self.stages),
             "models_loaded": len(self.model_manager.models) if self.model_manager else 0,
+            "mock_mode": self.mock_mode,
         }
 
     async def _create_tasks(self) -> list:
@@ -852,7 +881,7 @@
         """Send results back to orchestrator."""
         while self.running and self.connected.is_set():
             try:
-                # Get result
+                # Get result with timeout
                 result_data = await asyncio.get_event_loop().run_in_executor(
                     None, self.result_queue.get, True, 1
                 )
@@ -863,7 +892,6 @@
                 outputs = result_data["outputs"]
 
                 # Create work result
-                # logger.info(f"Processed item: {item}")
                 work_result = WorkResult(
                     unit_id=item.unit_id,
                     source_id=item.metadata.get("shard_name", "unknown"),
@@ -873,9 +901,21 @@
                     metadata={
                         "item_key": item.item_key,
                         "item_index": item.metadata.get("_item_index"),
-                        "image_width":
-
-
+                        "image_width": (
+                            item.image.width
+                            if item.image is not None
+                            else item.metadata.get("image_width")
+                        ),
+                        "image_height": (
+                            item.image.height
+                            if item.image is not None
+                            else item.metadata.get("image_height")
+                        ),
+                        "image_format": (
+                            item.image.format
+                            if item.image is not None
+                            else item.metadata.get("image_format", "unknown")
+                        ),
                         "file_size": len(item.image_data) if item.image_data else 0,
                         **item.metadata,
                     },
@@ -883,7 +923,7 @@
                     error=result_data.get("error", None),
                 )
 
-                # Send result
+                # Send result
                 await self.websocket.send(
                     json.dumps(
                         {
@@ -920,9 +960,7 @@
         self.assigned_units.clear()
         self.current_unit = None
 
-        # Clear
-        self._clear_queue(self.readahead_queue)
-        self._clear_queue(self.inference_queue)
+        # Clear result queue
         self._clear_queue(self.result_queue)
 
     def _clear_queue(self, queue: Queue):
@@ -935,9 +973,6 @@
 
     async def _pre_shutdown(self):
         """Cleanup before shutdown."""
-        self.readahead_queue.put(None)
-        self.inference_queue.put(None)
-
         if self.image_processor:
             self.image_processor.shutdown()
 