caption-flow 0.2.3-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -34,7 +34,7 @@ from ..utils.prompt_template import PromptTemplateManager
 from ..models import ProcessingStage, StageResult

 logger = logging.getLogger(__name__)
-# logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)


 @dataclass
@@ -163,14 +163,14 @@ class CaptionWorker(BaseWorker):
     def __init__(self, config: Dict[str, Any]):
         super().__init__(config)

-        # Processor configuration - will be set from orchestrator
+        # Processor configuration
         self.processor_type = None
         self.processor: Optional[
             Union[
                 WebDatasetWorkerProcessor,
                 HuggingFaceDatasetWorkerProcessor,
                 LocalFilesystemWorkerProcessor,
-            ],
+            ]
         ] = None
         self.dataset_path: Optional[str] = None

@@ -181,12 +181,16 @@ class CaptionWorker(BaseWorker):
         self.vllm_config_manager = VLLMConfigManager()
         self.model_manager = None

+        # Mock mode flag
+        self.mock_mode = False
+
         # GPU selection
         self.gpu_id = config.get("gpu_id", 0)
         self.hf_token = get_token()

         # Image processor
-        batch_image_processing = config.get("batch_image_processing", False)
+        batch_image_processing = config.get("batch_image_processing", True)
+        logger.info(f"Using batch processing: {batch_image_processing}")
         self.image_processor = ImageProcessor() if batch_image_processing else None

         # Work processing
@@ -194,9 +198,7 @@ class CaptionWorker(BaseWorker):
         self.assigned_units = deque()
         self.current_unit: Optional[WorkUnit] = None

-        # Processing queues
-        self.readahead_queue = Queue(maxsize=256)
-        self.inference_queue = Queue(maxsize=128)
+        # Single result queue for sending back to orchestrator
         self.result_queue = Queue()

         # Processing control
@@ -230,13 +232,18 @@ class CaptionWorker(BaseWorker):
                 logger.error(f"Failed to get config: {e}")
                 await asyncio.sleep(5)

-        # Initialize vLLM once we have config
-        if self.vllm_config:
-            self._setup_vllm()
+        # Check for mock mode
+        self.mock_mode = self.vllm_config.get("mock_results", False) if self.vllm_config else False

-        # Start background threads
-        Thread(target=self._unit_processor_thread, daemon=True).start()
-        Thread(target=self._inference_thread, daemon=True).start()
+        if self.mock_mode:
+            logger.info("🎭 MOCK MODE ENABLED - No vLLM models will be loaded")
+        else:
+            # Initialize vLLM once we have config
+            if self.vllm_config:
+                self._setup_vllm()
+
+        # Start processing thread
+        Thread(target=self._processing_thread, daemon=True).start()

     async def _initial_connect_for_config(self):
         """Connect initially just to get configuration."""
@@ -269,8 +276,6 @@ class CaptionWorker(BaseWorker):
         self.assigned_units.clear()
         self.current_unit = None

-        self._clear_queue(self.readahead_queue)
-        self._clear_queue(self.inference_queue)
         self._clear_queue(self.result_queue)

         # Reset counters
@@ -297,6 +302,7 @@ class CaptionWorker(BaseWorker):

         self.processor.initialize(processor_config)
         self.dataset_path = self.processor.dataset_path
+        self.units_per_request = processor_config.config.get("chunks_per_request", 1)

         # Update vLLM config if provided
         new_vllm_config = welcome_data.get("processor_config", {}).get("vllm")
@@ -309,7 +315,7 @@ class CaptionWorker(BaseWorker):

         # Request initial work
         if self.websocket:
-            await self.websocket.send(json.dumps({"type": "request_work", "count": 2}))
+            await self.websocket.send(json.dumps({"type": "get_work_units", "count": 2}))

     async def _handle_message(self, data: Dict[str, Any]):
         """Handle message from orchestrator."""
@@ -327,7 +333,7 @@ class CaptionWorker(BaseWorker):
                 await asyncio.sleep(10)

                 if self.websocket and self.connected.is_set():
-                    await self.websocket.send(json.dumps({"type": "request_work", "count": 2}))
+                    await self.websocket.send(json.dumps({"type": "get_work_units", "count": 2}))

     def _parse_stages_config(self, vllm_config: Dict[str, Any]) -> List[ProcessingStage]:
         """Parse stages configuration from vLLM config."""
@@ -405,6 +411,7 @@ class CaptionWorker(BaseWorker):
         self.model_manager = MultiStageVLLMManager(self.gpu_id)

         # Get base config for models
+        logger.info(f"vLLM config: {self.vllm_config}")
         base_config = {
             "tensor_parallel_size": self.vllm_config.get("tensor_parallel_size", 1),
             "max_model_len": self.vllm_config.get("max_model_len", 16384),
@@ -437,6 +444,13 @@ class CaptionWorker(BaseWorker):
         if not new_config:
             return True

+        # Check if mock mode changed
+        old_mock_mode = self.mock_mode
+        self.mock_mode = new_config.get("mock_results", False)
+
+        if old_mock_mode != self.mock_mode:
+            logger.info(f"Mock mode changed from {old_mock_mode} to {self.mock_mode}")
+
         # Parse new stages
         new_stages = self._parse_stages_config(new_config)

@@ -453,35 +467,43 @@ class CaptionWorker(BaseWorker):
                 stages_changed = True
                 break

-        if stages_changed:
-            logger.info("Stage configuration changed, reloading all models")
+        if stages_changed or old_mock_mode != self.mock_mode:
+            logger.info("Configuration changed significantly")

             old_config = self.vllm_config
             self.vllm_config = new_config
             self.stages = new_stages
             self.stage_order = self._topological_sort_stages(self.stages)

-            try:
+            if not self.mock_mode:
+                try:
+                    if self.model_manager:
+                        self.model_manager.cleanup()
+                    self._setup_vllm()
+                    return True
+                except Exception as e:
+                    logger.error(f"Failed to reload vLLM: {e}")
+                    self.vllm_config = old_config
+                    return False
+            else:
+                # Clean up models if switching to mock mode
                 if self.model_manager:
                     self.model_manager.cleanup()
-                self._setup_vllm()
+                self.model_manager = None
                 return True
-            except Exception as e:
-                logger.error(f"Failed to reload vLLM: {e}")
-                self.vllm_config = old_config
-                return False
         else:
             # Just update sampling params
-            logger.info("Updating sampling parameters without model reload")
-            base_sampling = new_config.get("sampling", {})
-            for stage in self.stages:
-                self.model_manager.create_sampling_params(stage, base_sampling)
+            if not self.mock_mode:
+                logger.info("Updating sampling parameters without model reload")
+                base_sampling = new_config.get("sampling", {})
+                for stage in self.stages:
+                    self.model_manager.create_sampling_params(stage, base_sampling)
             self.vllm_config = new_config
             return True

-    def _unit_processor_thread(self):
-        """Background thread that processes work units."""
-        logger.info("Starting unit processor thread")
+    def _processing_thread(self):
+        """Main processing thread that handles work units."""
+        logger.info("Starting processing thread")

         while self.running:
             if self.should_stop_processing.is_set():
@@ -513,11 +535,13 @@ class CaptionWorker(BaseWorker):
             with self.work_lock:
                 queue_size = len(self.assigned_units)

-            if queue_size < 2 and self.websocket and self.main_loop:
+            if queue_size < self.units_per_request and self.websocket and self.main_loop:
                 try:
                     asyncio.run_coroutine_threadsafe(
                         self.websocket.send(
-                            json.dumps({"type": "request_work", "count": 2})
+                            json.dumps(
+                                {"type": "get_work_units", "count": self.units_per_request}
+                            )
                         ),
                         self.main_loop,
                     ).result(timeout=5)
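
A sketch of the refill rule above: the worker tops up only when its local queue drops below units_per_request. The values here are hypothetical; units_per_request comes from "chunks_per_request" in the processor config, per the earlier hunk:

    import json
    from collections import deque

    units_per_request = 4               # from "chunks_per_request" upstream
    assigned_units = deque(["unit-a"])  # locally tracked pending work units

    if len(assigned_units) < units_per_request:
        request = json.dumps({"type": "get_work_units", "count": units_per_request})
        print(request)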
@@ -533,16 +557,20 @@ class CaptionWorker(BaseWorker):
             self.current_unit = None

     def _process_work_unit(self, unit: WorkUnit):
-        """Process a single work unit."""
+        """Process a single work unit with batching."""
         if not self.processor:
             logger.error("Processor not initialized")
             return

-        items_processed = 0
-        context = {}  # Will store processed indices
+        batch = []
+        batch_size = self.vllm_config.get("batch_size", 8)
+        context = {}

-        # Get items from processor
+        # Collect items for batching
         for item_data in self.processor.process_unit(unit, context):
+            if self.should_stop_processing.is_set() or not self.connected.is_set():
+                break
+
             try:
                 # Create processing item
                 item = ProcessingItem(
@@ -551,35 +579,19 @@ class CaptionWorker(BaseWorker):
                     job_id=item_data["job_id"],
                     item_key=item_data["item_key"],
                     item_index=item_data["item_index"],
-                    image=item_data["image"],
+                    image=item_data.get("image", None),
                     image_data=item_data.get("image_data", b""),
                     metadata=item_data.get("metadata", {}),
                 )
+                if "_processed_indices" in item_data:
+                    context["_processed_indices"] = item_data.pop("_processed_indices", [])

-                # Add to readahead queue
-                timeout_end = time.time() + 30
-                while (
-                    self.running
-                    and not self.should_stop_processing.is_set()
-                    and self.connected.is_set()
-                ):
-                    try:
-                        self.readahead_queue.put(item, timeout=1)
-                        break
-                    except:
-                        if time.time() > timeout_end:
-                            raise TimeoutError("Queue put timeout")
-                        continue
-
-                if not self.connected.is_set() or self.should_stop_processing.is_set():
-                    break
-
-                items_processed += 1
+                batch.append(item)

-                # Batch items for inference
-                batch_size = self.vllm_config.get("batch_size", 8)
-                if self.readahead_queue.qsize() >= batch_size:
-                    self._batch_for_inference()
+                # Process batch when it reaches size
+                if len(batch) >= batch_size:
+                    self._process_batch(batch)
+                    batch = []

             except Exception as e:
                 if self.should_stop_processing.is_set():
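
The loop above follows a plain accumulate-and-flush pattern in place of the old readahead/inference queues; a self-contained sketch with stand-in names:

    def iter_items():
        # Stand-in for processor.process_unit(unit, context)
        yield from range(20)

    def process_batch(batch):
        # Stand-in for CaptionWorker._process_batch
        print(f"processing {len(batch)} items")

    batch, batch_size = [], 8
    for item in iter_items():
        batch.append(item)
        if len(batch) >= batch_size:
            process_batch(batch)
            batch = []
    if batch:  # flush the remainder, as the following hunk does
        process_batch(batch)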
@@ -587,77 +599,95 @@ class CaptionWorker(BaseWorker):
                 logger.error(f"Error processing item {item_data.get('item_key')}: {e}")
                 self.items_failed += 1

-        # Process any remaining items
-        if not self.should_stop_processing.is_set():
-            self._batch_for_inference()
-            if self.connected.is_set():
-                # Notify orchestrator that unit is complete
+        # Process remaining items in batch
+        if batch and not self.should_stop_processing.is_set():
+            self._process_batch(batch)
+
+        # Notify orchestrator that unit is complete
+        if self.connected.is_set() and self.websocket:
+            try:
                 asyncio.run_coroutine_threadsafe(
                     self.websocket.send(
                         json.dumps({"type": "work_complete", "unit_id": unit.unit_id})
                     ),
                     self.main_loop,
                 ).result(timeout=5)
+            except Exception as e:
+                logger.warning(f"Could not notify work complete: {e}")

-        logger.info(f"Unit {unit.unit_id} processed {items_processed} items")
+    def _process_batch(self, batch: List[ProcessingItem]):
+        """Process a batch of items through all stages."""
+        if not batch:
+            return

-    def _batch_for_inference(self):
-        """Batch items from readahead queue for inference."""
-        batch = []
-        batch_size = self.vllm_config.get("batch_size", 8)
+        logger.debug(f"Processing batch of {len(batch)} images")
+        start_time = time.time()

         try:
-            while len(batch) < batch_size:
-                item = self.readahead_queue.get_nowait()
-                batch.append(item)
-        except Empty:
-            pass
+            # Process batch through all stages
+            if self.mock_mode:
+                results = self._process_batch_mock(batch)
+            else:
+                results = self._process_batch_multi_stage(batch)

-        if batch:
-            self.inference_queue.put(batch)
+            # Calculate processing time
+            if results:
+                processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
+
+                for item, result_outputs in results:
+                    self.result_queue.put(
+                        {
+                            "item": item,
+                            "outputs": result_outputs,
+                            "processing_time_ms": processing_time_per_item,
+                        }
+                    )

-    def _inference_thread(self):
-        """Background thread for multi-stage vLLM inference."""
-        logger.info("Starting multi-stage inference thread")
+                logger.debug(f"Batch processing complete: {len(results)} successful")

-        while self.running:
-            try:
-                batch = self.inference_queue.get(timeout=1)
-                if not batch:
-                    continue
+        except Exception as e:
+            logger.error(f"Batch processing error: {e}", exc_info=True)

-                if self.should_stop_processing.is_set():
-                    continue
+    def _process_batch_mock(self, batch: List[ProcessingItem]) -> List[Tuple[ProcessingItem, Dict]]:
+        """Process a batch in mock mode - return dummy captions."""
+        results = []

-                logger.debug(
-                    f"Processing batch of {len(batch)} images through {len(self.stages)} stages"
-                )
-                start_time = time.time()
+        # Simulate some processing time
+        time.sleep(0.1)

-                # Process batch through all stages
-                results = self._process_batch_multi_stage(batch)
+        for item in batch:
+            # Generate mock outputs for each stage
+            for stage_name in self.stage_order:
+                stage = next(s for s in self.stages if s.name == stage_name)
+
+                # Create mock outputs based on stage prompts
+                stage_outputs = []
+                for i, prompt in enumerate(stage.prompts):
+                    mock_output = (
+                        f"Mock {stage_name} output {i+1} for job {item.job_id} - {item.item_key}"
+                    )
+                    stage_outputs.append(mock_output)

-                # Calculate processing time
-                if results:
-                    processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
+                # Store stage result
+                stage_result = StageResult(
+                    stage_name=stage_name,
+                    output_field=stage.output_field,
+                    outputs=stage_outputs,
+                )
+                item.stage_results[stage_name] = stage_result

-                    for item, result_outputs in results:
-                        self.result_queue.put(
-                            {
-                                "item": item,
-                                "outputs": result_outputs,
-                                "processing_time_ms": processing_time_per_item,
-                            }
-                        )
+            # Aggregate outputs by field
+            outputs_by_field = defaultdict(list)
+            for stage_result in item.stage_results.values():
+                outputs_by_field[stage_result.output_field].extend(stage_result.outputs)

-                logger.debug(f"Batch processing complete: {len(results)} successful")
+            results.append((item, dict(outputs_by_field)))
+            self.items_processed += 1

-            except Empty:
-                continue
-            except Exception as e:
-                if self.should_stop_processing.is_set():
-                    continue
-                logger.error(f"Inference error: {e}", exc_info=True)
+            if self.items_processed % 10 == 0:
+                logger.info(f"🎭 Mock mode: Processed {self.items_processed} items")
+
+        return results

     def _process_batch_multi_stage(
         self, batch: List[ProcessingItem], max_attempts: int = 3
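
The per-field aggregation inside _process_batch_mock reduces to the defaultdict pattern below; stage records are simplified to (output_field, outputs) tuples for this sketch rather than real StageResult objects:

    from collections import defaultdict

    stage_results = [
        ("captions", ["Mock caption output 1"]),
        ("captions", ["Mock caption output 2"]),
        ("tags", ["Mock tag output 1"]),
    ]

    outputs_by_field = defaultdict(list)
    for output_field, outputs in stage_results:
        outputs_by_field[output_field].extend(outputs)

    print(dict(outputs_by_field))
    # {'captions': ['Mock caption output 1', 'Mock caption output 2'],
    #  'tags': ['Mock tag output 1']}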
@@ -685,8 +715,8 @@ class CaptionWorker(BaseWorker):
         for idx, (original_idx, item, attempt_count) in enumerate(items_to_process):
             current_batch.append((original_idx, item, attempt_count))

-            # Prepare image
-            converted_img = ImageProcessor.prepare_for_inference(item.image)
+            # Prepare image from PIL frame or bytes
+            converted_img = ImageProcessor.prepare_for_inference(item)

             # Create template manager
             template_manager = PromptTemplateManager(stage.prompts)
@@ -832,12 +862,11 @@ class CaptionWorker(BaseWorker):
             "units_completed": self.units_completed,
             "current_unit": self._get_current_unit_id() if self.current_unit else None,
             "queue_sizes": {
-                "readahead": self.readahead_queue.qsize(),
-                "inference": self.inference_queue.qsize(),
                 "results": self.result_queue.qsize(),
             },
             "stages": len(self.stages),
             "models_loaded": len(self.model_manager.models) if self.model_manager else 0,
+            "mock_mode": self.mock_mode,
         }

  async def _create_tasks(self) -> list:
@@ -852,7 +881,7 @@ class CaptionWorker(BaseWorker):
852
881
  """Send results back to orchestrator."""
853
882
  while self.running and self.connected.is_set():
854
883
  try:
855
- # Get result
884
+ # Get result with timeout
856
885
  result_data = await asyncio.get_event_loop().run_in_executor(
857
886
  None, self.result_queue.get, True, 1
858
887
  )
@@ -863,7 +892,6 @@ class CaptionWorker(BaseWorker):
                 outputs = result_data["outputs"]

                 # Create work result
-                # logger.info(f"Processed item: {item}")
                 work_result = WorkResult(
                     unit_id=item.unit_id,
                     source_id=item.metadata.get("shard_name", "unknown"),
@@ -873,9 +901,21 @@ class CaptionWorker(BaseWorker):
                     metadata={
                         "item_key": item.item_key,
                         "item_index": item.metadata.get("_item_index"),
-                        "image_width": item.image.width,
-                        "image_height": item.image.height,
-                        "image_format": item.image.format or "unknown",
+                        "image_width": (
+                            item.image.width
+                            if item.image is not None
+                            else item.metadata.get("image_width")
+                        ),
+                        "image_height": (
+                            item.image.height
+                            if item.image is not None
+                            else item.metadata.get("image_height")
+                        ),
+                        "image_format": (
+                            item.image.format
+                            if item.image is not None
+                            else item.metadata.get("image_format", "unknown")
+                        ),
                         "file_size": len(item.image_data) if item.image_data else 0,
                         **item.metadata,
                     },
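
Since item.image may now be None (the processor can supply raw bytes without a decoded PIL image, per the earlier ProcessingItem change), the metadata block falls back as sketched here with hypothetical values:

    image = None  # e.g. only raw bytes were kept for this item
    metadata = {"image_width": 512, "image_height": 512}

    image_width = image.width if image is not None else metadata.get("image_width")
    image_height = image.height if image is not None else metadata.get("image_height")
    print(image_width, image_height)  # 512 512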
@@ -883,7 +923,7 @@ class CaptionWorker(BaseWorker):
                     error=result_data.get("error", None),
                 )

-                # Send result in format that orchestrator expects
+                # Send result
                 await self.websocket.send(
                     json.dumps(
                         {
@@ -920,9 +960,7 @@ class CaptionWorker(BaseWorker):
         self.assigned_units.clear()
         self.current_unit = None

-        # Clear queues
-        self._clear_queue(self.readahead_queue)
-        self._clear_queue(self.inference_queue)
+        # Clear result queue
         self._clear_queue(self.result_queue)

     def _clear_queue(self, queue: Queue):
@@ -935,9 +973,6 @@ class CaptionWorker(BaseWorker):

     async def _pre_shutdown(self):
         """Cleanup before shutdown."""
-        self.readahead_queue.put(None)
-        self.inference_queue.put(None)
-
         if self.image_processor:
             self.image_processor.shutdown()