caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,1028 +0,0 @@
- """Improved vLLM worker with proper connection recovery and chunk abandonment.
-
- Key improvements:
- 1. Detects disconnection and stops current chunk processing
- 2. Clears all queues and abandons current chunk on disconnect
- 3. Maintains vLLM instance across reconnections
- 4. Properly handles connection state in all threads
- """
-
- import os
-
- os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
-
- import asyncio
- import io
- import json
- import logging
- import ssl
- import shlex
- import time
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Dict, Any, Optional, List
- from queue import Queue, Empty, Full
- from threading import Thread, Lock, Event
- from collections import deque
-
- import websockets
- from websockets.client import WebSocketClientProtocol
- from PIL import Image
- import numpy as np
- import webdataset as wds
- from huggingface_hub import get_token
-
- from .models import JobStatus, Job
- from .utils import CaptionUtils
- from .utils.dataset_loader import DatasetLoader
- from .utils.vllm_config import VLLMConfigManager
-
- logger = logging.getLogger(__name__)
-
-
- @dataclass
- class ShardChunk:
-     """Shard chunk assignment from orchestrator."""
-
-     chunk_id: str
-     shard_url: str
-     shard_name: str
-     start_index: int
-     chunk_size: int
-
-
- @dataclass
- class ProcessingItem:
-     """Item being processed."""
-
-     chunk_id: str
-     item_key: str
-     image: Image.Image
-     image_data: bytes
-
-
- @dataclass
- class ProcessedResult:
-     """Result with multiple captions and metadata."""
-
-     chunk_id: str
-     shard_name: str
-     item_key: str
-     captions: List[str]
-     image_width: int
-     image_height: int
-     image_format: str
-     file_size: int
-     processing_time_ms: float
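
For orientation: chunk assignments arrive as JSON objects whose keys mirror the ShardChunk fields, and the message handler below hydrates them directly via ShardChunk(**chunk_data). A minimal sketch of one such payload (values illustrative, not taken from a real orchestrator):

    chunk_data = {
        "chunk_id": "data-00042_chunk_3",                          # illustrative
        "shard_url": "https://example.org/shards/data-00042.tar",  # illustrative
        "shard_name": "data-00042",
        "start_index": 3000,
        "chunk_size": 1000,
    }
    chunk = ShardChunk(**chunk_data)

The "_chunk_" naming matters: _inference_thread later recovers the shard name with Path(item.chunk_id).stem.rsplit("_chunk_", 1)[0].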
-
-
- class VLLMWorker:
-     """Worker that processes shard chunks directly with proper reconnection."""
-
-     def __init__(self, config: Dict[str, Any]):
-         self.config = config
-         self.server_url = config["server"]
-         self.token = config["token"]
-         self.name = config.get("name", "worker")
-
-         # Dataset configuration will be received from orchestrator
-         self.dataset_config = None
-         self.dataset_loader = None
-         self.dataset_type = None
-         self.hf_token = get_token()
-
-         # vLLM configuration will be received from orchestrator
-         self.vllm_config = None
-         self.inference_prompts = None
-         self.vllm_config_manager = VLLMConfigManager()
-
-         # Backward compatibility: local config for GPU selection
-         self.gpu_id = config.get("gpu_id", 0)
-
-         # SSL configuration
-         self.ssl_context = self._setup_ssl()
-
-         # State
-         self.worker_id: Optional[str] = None
-         self.websocket: Optional[WebSocketClientProtocol] = None
-         self.running = False
-         self.main_loop: Optional[asyncio.AbstractEventLoop] = None  # Store main event loop
-
-         # Connection state events
-         self.connected = Event()
-         self.should_stop_processing = Event()
-
-         # Inference components (initialized in setup)
-         self.llm = None
-         self.processor = None
-         self.tokenizer = None
-         self.sampling_params = None
-
-         # Shard chunk processing
-         self.chunk_lock = Lock()
-         self.assigned_chunks = deque()
-         self.current_chunk = None
-         self.current_chunk_progress = 0
-         # Batching queues - will be cleared on disconnect
-         self.readahead_queue = Queue(maxsize=256)
-         self.inference_queue = Queue(maxsize=128)
-         self.result_queue = Queue()
-
-         # Metrics
-         self.items_processed = 0
-         self.items_failed = 0
-         self.chunks_completed = 0
-
-         # Job mode for shards vs jobs and job queue.
-         self.job_mode = config.get("job_mode", False)
-         self.job_queue = Queue(maxsize=32)
-
-     def _setup_ssl(self) -> Optional[ssl.SSLContext]:
-         """Configure SSL context."""
-         if self.server_url.startswith("ws://"):
-             logger.warning("Using insecure WebSocket connection")
-             return None
-
-         if not self.config.get("verify_ssl", True):
-             context = ssl.create_default_context()
-             context.check_hostname = False
-             context.verify_mode = ssl.CERT_NONE
-             return context
-
-         return ssl.create_default_context()
-
-     def _setup_dataset_loader(self, dataset_config: Dict[str, Any]):
-         """Initialize dataset loader with config from orchestrator."""
-         dataset_path = dataset_config.get("dataset_path") or dataset_config.get("path")
-         dataset_type = dataset_config.get("dataset_type") or dataset_config.get(
-             "type", "huggingface"
-         )
-
-         if dataset_path:
-             logger.info(f"Initializing dataset loader for {dataset_type}: {dataset_path}")
-             self.dataset_loader = DatasetLoader(dataset_path, dataset_type)
-             self.dataset_config = dataset_config
-             self.dataset_type = dataset_type
-         else:
-             logger.warning("No dataset path provided by orchestrator")
-
-     def _setup_vllm(self):
-         """Initialize vLLM components."""
-         if not self.vllm_config:
-             raise RuntimeError("vLLM config not received from orchestrator")
-
-         os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
-
-         from vllm import LLM, SamplingParams
-         from transformers import AutoTokenizer, AutoProcessor
-
-         model_name = self.vllm_config["model"]
-         logger.info(f"Loading {model_name} on GPU {self.gpu_id}")
-
-         # Always reload tokenizer/processor (they're model-specific)
-         self.tokenizer = AutoTokenizer.from_pretrained(
-             model_name, trust_remote_code=True, use_fast=True
-         )
-         self.processor = AutoProcessor.from_pretrained(model_name)
-
-         # Initialize LLM with settings from orchestrator using config manager
-         vllm_params = self.vllm_config_manager.get_vllm_init_params(self.vllm_config)
-         self.llm = LLM(**vllm_params)
-
-         # Create sampling params from orchestrator config
-         self.sampling_params = self.vllm_config_manager.create_sampling_params(self.vllm_config)
-
-         logger.info("vLLM initialization complete")
-
-         # Update config manager's tracking
-         self.vllm_config_manager.current_config = self.vllm_config
-
-     async def _handle_job_assignment(self, job_data: Dict):
-         """Handle job assignment from orchestrator."""
-         try:
-             # Convert to processing item
-             image = Image.open(io.BytesIO(job_data["image_data"]))
-
-             item = ProcessingItem(
-                 chunk_id=job_data["job_id"],
-                 item_key=job_data["sample_id"],
-                 image=image,
-                 image_data=job_data["image_data"],
-             )
-
-             # Add to inference queue
-             self.readahead_queue.put(item)
-             logger.debug(f"Queued job {job_data['job_id']} for processing")
-
-         except Exception as e:
-             logger.error(f"Error handling job assignment: {e}")
-
-     async def _job_request_loop(self):
-         """Request jobs from orchestrator in job mode."""
-         while self.running and self.connected.is_set():
-             try:
-                 # Check if we need more work
-                 if self.readahead_queue.qsize() < self.vllm_config.get("batch_size", 8):
-                     await self.websocket.send(json.dumps({"type": "request_job"}))
-
-                 await asyncio.sleep(1)
-
-             except Exception as e:
-                 logger.error(f"Job request error: {e}")
-                 await asyncio.sleep(5)
-
-     def _handle_vllm_config_update(self, new_config: Dict[str, Any]) -> bool:
-         """
-         Handle vLLM configuration updates.
-
-         Returns:
-             True if the config was applied successfully, False if a required reload failed
-         """
-         if not new_config:
-             return True
-
-         # Check what changed
-         change = self.vllm_config_manager.analyze_config_change(self.vllm_config, new_config)
-
-         if not change.changed_fields:
-             # No changes
-             return True
-
-         if change.requires_reload:
-             # Need to reload vLLM
-             logger.info(f"vLLM config changes require reload: {change.changed_fields}")
-
-             # Save old config
-             old_config = self.vllm_config
-             self.vllm_config = new_config
-
-             try:
-                 # Reload vLLM with new config
-                 logger.info("Reloading vLLM with new configuration...")
-
-                 # Clean up old instance
-                 if hasattr(self, "llm") and self.llm:
-                     del self.llm
-
-                 # Also clean up tokenizer/processor if model changed
-                 if change.model_changed:
-                     if hasattr(self, "tokenizer"):
-                         del self.tokenizer
-                     if hasattr(self, "processor"):
-                         del self.processor
-
-                 import gc
-
-                 gc.collect()
-
-                 # Reload with new config
-                 self._setup_vllm()
-
-                 # Update prompts
-                 self.inference_prompts = new_config.get("inference_prompts", self.inference_prompts)
-
-                 logger.info("vLLM reload complete")
-                 return True
-
-             except Exception as e:
-                 logger.error(f"Failed to reload vLLM: {e}")
-                 # Restore old config
-                 self.vllm_config = old_config
-                 return False
-
-         else:
-             # Can update without reload
-             logger.info(f"Updating vLLM config without reload: {change.changed_fields}")
-
-             # Update sampling params if changed
-             if change.sampling_changed:
-                 self.sampling_params = self.vllm_config_manager.create_sampling_params(new_config)
-
-             # Update prompts if changed
-             if change.prompts_changed:
-                 self.inference_prompts = new_config.get("inference_prompts", self.inference_prompts)
-                 logger.info(f"Updated inference prompts: {len(self.inference_prompts)} prompts")
-
-             # Update config
-             self.vllm_config = new_config
-             logger.info("vLLM configuration updated successfully without reload")
-             return True
-
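Only a handful of vllm_config keys are read directly in this file ("model", "batch_size", "inference_prompts"); engine and sampling parameters are delegated to VLLMConfigManager. A minimal illustrative payload under that assumption (not an exhaustive schema, and the model id is only an example):

    vllm_config = {
        "model": "Qwen/Qwen2-VL-2B-Instruct",  # illustrative model id
        "batch_size": 8,
        "inference_prompts": [
            "describe this image in detail",
        ],
        # engine/sampling fields are interpreted by
        # VLLMConfigManager.get_vllm_init_params() / create_sampling_params()
    }
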
-     def _clear_state_on_disconnect(self):
-         """Clear all processing state when disconnected."""
-         logger.info("Clearing state due to disconnection")
-
-         # Signal threads to stop current processing
-         self.should_stop_processing.set()
-
-         with self.chunk_lock:
-             # Clear assigned chunks
-             self.assigned_chunks.clear()
-             self.current_chunk = None
-             self.current_chunk_progress = 0
-
-         # Clear all queues
-         self._clear_queue(self.readahead_queue)
-         self._clear_queue(self.inference_queue)
-         self._clear_queue(self.result_queue)
-
-         logger.info("State cleared, ready for reconnection")
-
-     def _clear_queue(self, queue: Queue):
-         """Clear all items from a queue."""
-         try:
-             while True:
-                 queue.get_nowait()
-         except Empty:
-             pass
-
-     async def start(self):
-         """Start the worker with automatic reconnection."""
-         self.running = True
-
-         # Wait for initial connection to get vLLM config
-         logger.info("Connecting to orchestrator for configuration...")
-
-         # Try initial connection to get config
-         config_received = False
-         while not config_received and self.running:
-             try:
-                 await self._initial_connect_for_config()
-                 config_received = True
-             except Exception as e:
-                 logger.error(f"Failed to get config: {e}")
-                 await asyncio.sleep(5)
-
-         # Initialize vLLM once we have config
-         self._setup_vllm()
-
-         # Capture the main event loop for use in background threads
-         self.main_loop = asyncio.get_running_loop()
-
-         # Start shard reader thread
-         reader_thread = Thread(target=self._shard_reader_thread, daemon=True)
-         reader_thread.start()
-
-         # Start inference thread
-         inference_thread = Thread(target=self._inference_thread, daemon=True)
-         inference_thread.start()
-
-         # Reconnection with exponential backoff
-         reconnect_delay = 5
-         max_delay = 60
-
-         # Connect to orchestrator with retries
-         while self.running:
-             try:
-                 await self._connect_and_run()
-
-                 # Reset delay on successful connection
-                 reconnect_delay = 5
-
-             except Exception as e:
-                 logger.error(f"Connection error: {e}")
-
-                 # Mark as disconnected
-                 self.connected.clear()
-                 self.websocket = None
-
-                 # Clear all state on disconnect
-                 self._clear_state_on_disconnect()
-
-                 if self.running:
-                     logger.info(f"Reconnecting in {reconnect_delay} seconds...")
-                     await asyncio.sleep(reconnect_delay)
-
-                     # Exponential backoff
-                     reconnect_delay = min(reconnect_delay * 2, max_delay)
-
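With reconnect_delay starting at 5 s, doubling per failed attempt, and capped at max_delay = 60 s, consecutive failures wait 5, 10, 20, 40, 60, 60, ... seconds. A quick check of the sequence:

    delay, delays = 5, []
    for _ in range(6):
        delays.append(delay)
        delay = min(delay * 2, 60)
    print(delays)  # [5, 10, 20, 40, 60, 60]
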
-     async def _initial_connect_for_config(self):
-         """Connect initially just to get configuration."""
-         async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
-             # Authenticate
-             await websocket.send(json.dumps({"token": self.token, "name": self.name}))
-
-             # Wait for welcome message with config
-             welcome = await websocket.recv()
-             welcome_data = json.loads(welcome)
-
-             if "error" in welcome_data:
-                 raise RuntimeError(f"Authentication failed: {welcome_data['error']}")
-
-             # Extract vLLM configuration
-             self.vllm_config = welcome_data.get("vllm_config")
-             if not self.vllm_config:
-                 raise RuntimeError("No vLLM configuration received from orchestrator")
-
-             self.inference_prompts = self.vllm_config.get(
-                 "inference_prompts",
-                 [
-                     "describe this image in detail",
-                     "provide a comprehensive description of the visual content",
-                     "what are the key elements in this image?",
-                 ],
-             )
-
-             # Store config in manager
-             self.vllm_config_manager.current_config = self.vllm_config
-
-             # Extract dataset configuration
-             dataset_config = welcome_data.get("dataset_config", {})
-             if dataset_config:
-                 self._setup_dataset_loader(dataset_config)
-
-             logger.info("Received configuration from orchestrator")
-             # Disconnect after getting config
-
-     async def _connect_and_run(self):
-         """Connect to orchestrator and process chunks."""
-         logger.info(f"Connecting to {self.server_url}")
-
-         async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
-             self.websocket = websocket
-             self.connected.set()
-
-             # Clear stop signal now that we're connected
-             self.should_stop_processing.clear()
-
-             # Authenticate
-             await websocket.send(json.dumps({"token": self.token, "name": self.name}))
-
-             # Wait for welcome message with dataset config
-             welcome = await websocket.recv()
-             welcome_data = json.loads(welcome)
-
-             if "error" in welcome_data:
-                 logger.error(f"Authentication failed: {welcome_data['error']}")
-                 self.running = False
-                 return
-
-             self.worker_id = welcome_data.get("worker_id")
-             logger.info(f"Connected as {self.worker_id}")
-
-             # Extract and setup dataset configuration from orchestrator
-             dataset_config = welcome_data.get("dataset_config", {})
-             if dataset_config:
-                 self._setup_dataset_loader(dataset_config)
-                 logger.info(f"Received dataset config: {dataset_config}")
-             else:
-                 logger.warning("No dataset configuration received from orchestrator")
-
-             # Update vLLM config if provided (in case it changed)
-             new_vllm_config = welcome_data.get("vllm_config")
-             if new_vllm_config and new_vllm_config != self.vllm_config:
-                 logger.info("Received updated vLLM configuration")
-
-                 # Handle config update (may trigger reload)
-                 if not self._handle_vllm_config_update(new_vllm_config):
-                     logger.error("Failed to update vLLM configuration")
-                     # Continue with existing config
-
-             if not self.job_mode:
-                 # Request initial chunks
-                 await websocket.send(json.dumps({"type": "request_chunks", "count": 2}))
-
-             # Start processing
-             try:
-                 # Create tasks
-                 tasks = [
-                     asyncio.create_task(self._heartbeat_loop()),
-                     asyncio.create_task(self._message_handler()),
-                     asyncio.create_task(self._result_sender()),
-                 ]
-                 if self.job_mode:
-                     # In job mode, request individual jobs instead of chunks
-                     tasks.append(asyncio.create_task(self._job_request_loop()))
-
-                 # Wait for any task to complete (likely due to disconnection)
-                 done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
-                 # Cancel remaining tasks
-                 for task in pending:
-                     task.cancel()
-                     try:
-                         await task
-                     except asyncio.CancelledError:
-                         pass
-
-             finally:
-                 # Ensure we mark as disconnected
-                 self.connected.clear()
-                 self.websocket = None
-
-     async def _message_handler(self):
-         """Handle messages from orchestrator."""
-         try:
-             async for message in self.websocket:
-                 try:
-                     data = json.loads(message)
-                     msg_type = data.get("type")
-
-                     if msg_type == "shard_assignment":
-                         chunks = data["chunks"]
-                         for chunk_data in chunks:
-                             chunk = ShardChunk(**chunk_data)
-                             with self.chunk_lock:
-                                 self.assigned_chunks.append(chunk)
-                             logger.info(f"Received chunk assignment: {chunk.chunk_id}")
-
-                     elif msg_type == "no_chunks":
-                         reason = data.get("reason", "unknown")
-                         logger.info(f"No chunks available from orchestrator (reason: {reason})")
-
-                         # Different wait times based on reason
-                         wait_time = 2 if reason == "state_restoring" else 10
-                         await asyncio.sleep(wait_time)
-
-                         # Request again after waiting
-                         if self.websocket and self.connected.is_set():
-                             await self.websocket.send(
-                                 json.dumps({"type": "request_chunks", "count": 2})
-                             )
-
-                     elif msg_type == "reload_vllm":
-                         # Orchestrator requested vLLM reload
-                         logger.info("Orchestrator requested vLLM reload")
-                         new_config = data.get("vllm_config")
-                         if new_config:
-                             self._handle_vllm_config_update(new_config)
-
-                     elif msg_type == "job_assignment":
-                         await self._handle_job_assignment(data["job"])
-
-                     elif msg_type == "no_jobs":
-                         logger.debug("No jobs available")
-                         await asyncio.sleep(2)
-
-                 except json.JSONDecodeError as e:
-                     logger.error(f"Invalid message format: {e}")
-                 except Exception as e:
-                     logger.error(f"Error handling message: {e}")
-
-         except websockets.exceptions.ConnectionClosed as e:
-             logger.info(f"Connection closed by orchestrator: {e}")
-             raise  # Re-raise to trigger cleanup
-         except Exception as e:
-             logger.error(f"Message handler error: {e}")
-             raise
-
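Taken together with the senders elsewhere in this file, the wire protocol this worker speaks is small. A summary sketch (shapes as read from this file, field values illustrative):

    # worker -> orchestrator
    {"type": "request_chunks", "count": 2}
    {"type": "request_job"}
    {"type": "chunk_complete", "chunk_id": "..."}
    {"type": "chunk_failed", "chunk_id": "...", "error": "..."}
    {"type": "submit_captions", ...}   # full field list in _result_sender
    {"type": "heartbeat", ...}         # full field list in _heartbeat_loop

    # orchestrator -> worker
    {"type": "shard_assignment", "chunks": [...]}
    {"type": "no_chunks", "reason": "state_restoring"}
    {"type": "reload_vllm", "vllm_config": {...}}
    {"type": "job_assignment", "job": {...}}
    {"type": "no_jobs"}
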
-     def _shard_reader_thread(self):
-         """Background thread that reads from WebDataset shards."""
-         logger.info("Starting shard reader thread")
-
-         while self.running:
-             # Check if we should stop processing
-             if self.should_stop_processing.is_set():
-                 logger.info("Shard reader waiting for reconnection")
-                 time.sleep(1)
-                 continue
-
-             # Only process if connected
-             if not self.connected.is_set():
-                 time.sleep(1)
-                 continue
-
-             # Get next chunk to process
-             with self.chunk_lock:
-                 if not self.current_chunk and self.assigned_chunks:
-                     self.current_chunk = self.assigned_chunks.popleft()
-                     self.current_chunk_progress = 0
-                     logger.info(f"Starting chunk {self.current_chunk.chunk_id}")
-
-             if not self.current_chunk:
-                 time.sleep(1)
-                 continue
-
-             try:
-                 # Process the chunk
-                 self._process_shard_chunk(self.current_chunk)
-
-                 # Only mark complete if still connected
-                 if self.connected.is_set() and not self.should_stop_processing.is_set():
-                     logger.info(f"Completed chunk {self.current_chunk.chunk_id}")
-                     self.chunks_completed += 1
-
-                     # Notify orchestrator if connected
-                     if self.websocket and self.main_loop:
-                         try:
-                             # Notify completion
-                             asyncio.run_coroutine_threadsafe(
-                                 self.websocket.send(
-                                     json.dumps(
-                                         {
-                                             "type": "chunk_complete",
-                                             "chunk_id": self.current_chunk.chunk_id,
-                                         }
-                                     )
-                                 ),
-                                 self.main_loop,
-                             ).result(timeout=5)
-
-                             # Request more chunks if queue is low
-                             with self.chunk_lock:
-                                 queue_size = len(self.assigned_chunks)
-
-                             if queue_size < 2:
-                                 logger.info(f"Requesting more chunks (queue size: {queue_size})")
-                                 asyncio.run_coroutine_threadsafe(
-                                     self.websocket.send(
-                                         json.dumps({"type": "request_chunks", "count": 2})
-                                     ),
-                                     self.main_loop,
-                                 ).result(timeout=5)
-
-                         except Exception as e:
-                             logger.warning(f"Could not notify orchestrator: {e}")
-
-                 with self.chunk_lock:
-                     self.current_chunk = None
-
-             except Exception as e:
-                 logger.error(f"Error processing chunk: {e}")
-
-                 # Only notify of failure if still connected
-                 if self.connected.is_set() and self.websocket and self.main_loop:
-                     try:
-                         asyncio.run_coroutine_threadsafe(
-                             self.websocket.send(
-                                 json.dumps(
-                                     {
-                                         "type": "chunk_failed",
-                                         "chunk_id": (
-                                             self.current_chunk.chunk_id
-                                             if self.current_chunk
-                                             else "unknown"
-                                         ),
-                                         "error": str(e),
-                                     }
-                                 )
-                             ),
-                             self.main_loop,
-                         ).result(timeout=5)
-                     except Exception as send_error:
-                         logger.warning(
-                             f"Could not notify orchestrator of chunk failure: {send_error}"
-                         )
-
-                 with self.chunk_lock:
-                     self.current_chunk = None
-
-     def _process_shard_chunk(self, chunk: ShardChunk):
-         """Process a single shard chunk."""
-         logger.info(f"Processing shard {chunk.shard_name} from index {chunk.start_index}")
-
-         # Create WebDataset pipeline
-         if self.dataset_type == "huggingface":
-             # Use curl with auth for HuggingFace
-             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
-             ds = wds.DataPipeline(
-                 wds.SimpleShardList(url_cmd),
-                 wds.tarfile_to_samples(),
-                 wds.to_tuple("__key__", "jpg;png;jpeg;webp"),
-             )
-         else:
-             # Local file
-             ds = wds.DataPipeline(
-                 wds.SimpleShardList(chunk.shard_url),
-                 wds.tarfile_to_samples(),
-                 wds.to_tuple("__key__", "jpg;png;jpeg;webp"),
-             )
-
-         # Process items with readahead
-         items_processed = 0
-         items_to_skip = chunk.start_index
-
-         for key, image_data in ds:
-             # Check if we should stop
-             if (
-                 not self.running
-                 or self.should_stop_processing.is_set()
-                 or not self.connected.is_set()
-             ):
-                 logger.info("Stopping chunk processing early due to disconnect")
-                 break
-
-             # Skip to start index
-             if items_to_skip > 0:
-                 items_to_skip -= 1
-                 continue
-
-             # Check if we've processed enough
-             if items_processed >= chunk.chunk_size:
-                 break
-
-             try:
-                 # Load image
-                 img = Image.open(io.BytesIO(image_data))
-
-                 # Create processing item
-                 item = ProcessingItem(
-                     chunk_id=chunk.chunk_id, item_key=key, image=img, image_data=image_data
-                 )
-
-                 # Add to readahead queue (blocks if full - provides backpressure)
-                 # Use timeout to allow checking for disconnection
-                 timeout_end = time.time() + 30
-                 while (
-                     self.running
-                     and not self.should_stop_processing.is_set()
-                     and self.connected.is_set()
-                 ):
-                     try:
-                         self.readahead_queue.put(item, timeout=1)
-                         break
-                     except Full:
-                         if time.time() > timeout_end:
-                             raise TimeoutError("Queue put timeout")
-                         continue
-
-                 # If we couldn't queue due to disconnection, skip this item
-                 if not self.connected.is_set() or self.should_stop_processing.is_set():
-                     logger.debug(f"Skipping item {key} due to disconnection")
-                     break
-
-                 items_processed += 1
-                 self.current_chunk_progress = items_processed
-
-                 # Batch items for inference
-                 batch_size = self.vllm_config.get("batch_size", 8)
-                 if self.readahead_queue.qsize() >= batch_size:
-                     self._batch_for_inference()
-
-             except Exception as e:
-                 if self.should_stop_processing.is_set():
-                     break
-                 logger.error(f"Error processing item {key}: {e}")
-                 self.items_failed += 1
-
-         # Process remaining items only if still connected
-         if not self.should_stop_processing.is_set():
-             self._batch_for_inference()
-
-         logger.info(f"Chunk {chunk.chunk_id} processed {items_processed} items")
-
-     def _batch_for_inference(self):
-         """Batch items from readahead queue for inference."""
-         batch = []
-         batch_size = self.vllm_config.get("batch_size", 8)
-
-         try:
-             while len(batch) < batch_size:
-                 item = self.readahead_queue.get_nowait()
-                 batch.append(item)
-         except Empty:
-             pass
-
-         if batch:
-             self.inference_queue.put(batch)
-
-     def _inference_thread(self):
-         """Background thread for vLLM inference."""
-         logger.info("Starting inference thread")
-
-         while self.running:
-             try:
-                 # Get batch from queue with timeout
-                 batch = self.inference_queue.get(timeout=1)
-
-                 if not batch:
-                     continue
-
-                 # Skip if disconnected
-                 if self.should_stop_processing.is_set():
-                     continue
-
-                 logger.debug(f"Processing batch of {len(batch)} images")
-                 start_time = time.time()
-
-                 # Prepare vLLM inputs
-                 requests = []
-                 for item in batch:
-                     # Resize for consistency
-                     item.image.thumbnail((512, 512), Image.BILINEAR)
-
-                     for prompt in self.inference_prompts:
-                         req = self._build_vllm_input(item.image, prompt)
-                         requests.append(req)
-
-                 # Run inference
-                 outputs = self.llm.generate(requests, self.sampling_params)
-
-                 # Process outputs only if still connected
-                 if not self.should_stop_processing.is_set():
-                     for i, item in enumerate(batch):
-                         # Get all prompt outputs as a list
-                         idx = i * len(self.inference_prompts)
-                         captions = []
-
-                         for j in range(len(self.inference_prompts)):
-                             if idx + j < len(outputs) and outputs[idx + j].outputs:
-                                 caption_text = self._clean_output(outputs[idx + j].outputs[0].text)
-                                 if caption_text:  # Only add non-empty captions
-                                     captions.append(caption_text)
-
-                         # Only create result if we have at least one caption
-                         if captions:
-                             result = ProcessedResult(
-                                 chunk_id=item.chunk_id,
-                                 shard_name=Path(item.chunk_id).stem.rsplit("_chunk_", 1)[0],
-                                 item_key=item.item_key,
-                                 captions=captions,
-                                 image_width=item.image.width,
-                                 image_height=item.image.height,
-                                 image_format=item.image.format or "unknown",
-                                 file_size=len(item.image_data),
-                                 processing_time_ms=(time.time() - start_time) * 1000 / len(batch),
-                             )
-
-                             self.result_queue.put(result)
-                             self.items_processed += 1
-                         else:
-                             logger.warning(f"No valid captions generated for item {item.item_key}")
-                             self.items_failed += 1
-
-             except Empty:
-                 continue
-             except Exception as e:
-                 if self.should_stop_processing.is_set():
-                     continue
-                 logger.error(f"Inference error: {e}")
-
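Because requests are built item-major then prompt-major, item i's output for prompt j lands at index i * len(prompts) + j in outputs. For example, with the default three prompts, batch item 2's captions come from outputs[6], outputs[7], and outputs[8]:

    num_prompts = 3
    i = 2
    indices = [i * num_prompts + j for j in range(num_prompts)]
    print(indices)  # [6, 7, 8]
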
-     def _build_vllm_input(self, image: Image.Image, prompt: str) -> Dict:
-         """Build vLLM input."""
-         try:
-             from qwen_vl_utils import process_vision_info
-
-             messages = [
-                 {
-                     "role": "user",
-                     "content": [
-                         {"type": "image", "image": image},
-                         {"type": "text", "text": prompt},
-                     ],
-                 }
-             ]
-
-             prompt_text = self.processor.apply_chat_template(
-                 messages, tokenize=False, add_generation_prompt=True
-             )
-             image_inputs, _ = process_vision_info(messages)
-             prompt_ids = self.tokenizer(prompt_text, add_special_tokens=False).input_ids
-
-             return {
-                 "prompt_token_ids": prompt_ids,
-                 "multi_modal_data": {"image": image_inputs},
-             }
-         except ImportError:
-             return {
-                 "prompt": f"<|user|>\n<|image_pad|>\n{prompt}<|end|>\n<|assistant|>",
-                 "multi_modal_data": {"image": [image]},
-             }
-
-     def _clean_output(self, text: str) -> str:
-         """Clean model output."""
-         if not text:
-             return ""
-
-         # Remove common artifacts
-         for token in ["<|end|>", "<|endoftext|>", "<|im_end|>", "I'm sorry", "I cannot"]:
-             if token in text:
-                 text = text.split(token)[0]
-
-         return text.strip()
-
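A quick check of the truncation behaviour (hypothetical strings run through the method above):

    worker._clean_output("A red bicycle leaning on a wall.<|im_end|>")
    # -> "A red bicycle leaning on a wall."
    worker._clean_output("I cannot describe this image.")
    # -> "" (the split keeps only what precedes the first artifact token)
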
-     async def _result_sender(self):
-         """Send results back to orchestrator."""
-         pending_results = []  # Buffer for results during disconnection
-
-         try:
-             while self.running and self.connected.is_set():
-                 try:
-                     # Get result (with timeout to allow checking self.running)
-                     try:
-                         result = await asyncio.get_event_loop().run_in_executor(
-                             None, self.result_queue.get, True, 1
-                         )
-                         pending_results.append(result)
-                     except Empty:
-                         pass
-
-                     # Only try to send if connected
-                     if pending_results and self.websocket and self.connected.is_set():
-                         sent_results = []
-                         for result in pending_results:
-                             try:
-                                 # Send result with all captions
-                                 await self.websocket.send(
-                                     json.dumps(
-                                         {
-                                             "type": "submit_captions",
-                                             "chunk_id": result.chunk_id,
-                                             "dataset": self.dataset_config.get(
-                                                 "dataset_path", "unknown"
-                                             ),
-                                             "shard": result.shard_name,
-                                             "item_key": result.item_key,
-                                             "captions": result.captions,
-                                             "caption_count": len(result.captions),
-                                             "image_width": result.image_width,
-                                             "image_height": result.image_height,
-                                             "image_format": result.image_format,
-                                             "file_size": result.file_size,
-                                             "processing_time_ms": result.processing_time_ms,
-                                         }
-                                     )
-                                 )
-                                 sent_results.append(result)
-
-                                 if self.items_processed % 100 == 0:
-                                     logger.info(
-                                         f"Processed {self.items_processed} items "
-                                         f"(~{self.items_processed * 3} captions)"
-                                     )
-                             except websockets.exceptions.ConnectionClosed as e:
-                                 logger.warning(f"Connection lost while sending result: {e}")
-                                 raise  # Re-raise to trigger task completion
-                             except Exception as e:
-                                 logger.error(f"Error sending result: {e}")
-                                 break
-
-                         # Remove successfully sent results
-                         for result in sent_results:
-                             pending_results.remove(result)
-
-                     # Clear pending results if disconnected and buffer is too large
-                     if not self.connected.is_set() and len(pending_results) > 1000:
-                         logger.warning(
-                             f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
-                         )
-                         pending_results.clear()
-
-                     await asyncio.sleep(0.1)
-
-                 except Exception as e:
-                     if isinstance(e, websockets.exceptions.ConnectionClosed):
-                         raise  # Re-raise connection errors
-                     logger.error(f"Unexpected error in result sender: {e}")
-                     await asyncio.sleep(1)
-
-         except asyncio.CancelledError:
-             logger.debug("Result sender cancelled")
-             raise
-
-     async def _heartbeat_loop(self):
-         """Send periodic heartbeats with connection checking."""
-         try:
-             while self.running and self.connected.is_set():
-                 try:
-                     if self.websocket:
-                         await self.websocket.send(
-                             json.dumps(
-                                 {
-                                     "type": "heartbeat",
-                                     "processed": self.items_processed,
-                                     "failed": self.items_failed,
-                                     "chunks_completed": self.chunks_completed,
-                                     "current_chunk": (
-                                         self.current_chunk.chunk_id if self.current_chunk else None
-                                     ),
-                                     "chunk_progress": self.current_chunk_progress,
-                                     "queue_sizes": {
-                                         "readahead": self.readahead_queue.qsize(),
-                                         "inference": self.inference_queue.qsize(),
-                                         "results": self.result_queue.qsize(),
-                                     },
-                                 }
-                             )
-                         )
-                     await asyncio.sleep(30)
-                 except websockets.exceptions.ConnectionClosed as e:
-                     logger.info(f"Connection lost during heartbeat: {e}")
-                     raise  # Re-raise to trigger task completion
-                 except Exception as e:
-                     logger.error(f"Heartbeat error: {e}")
-                     raise  # Re-raise to trigger task completion
-         except asyncio.CancelledError:
-             logger.debug("Heartbeat loop cancelled")
-             raise
-
-     async def shutdown(self):
-         """Graceful shutdown."""
-         logger.info("Shutting down worker...")
-         self.running = False
-         self.connected.clear()
-         self.should_stop_processing.set()
-
-         # Stop processing threads by adding stop signals
-         self.readahead_queue.put(None)
-         self.inference_queue.put(None)
-
-         # Close websocket if connected
-         if self.websocket:
-             try:
-                 await self.websocket.close()
-             except Exception:
-                 pass
-             self.websocket = None
-
-         logger.info("Worker shutdown complete")