caption-flow 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +56 -39
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +12 -2
- caption_flow/orchestrator.py +729 -217
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +392 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.0.dist-info/METADATA +369 -0
- caption_flow-0.2.0.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/top_level.txt +0 -0
caption_flow/worker_vllm.py
DELETED
@@ -1,1028 +0,0 @@
"""Improved vLLM worker with proper connection recovery and chunk abandonment.

Key improvements:
1. Detects disconnection and stops current chunk processing
2. Clears all queues and abandons current chunk on disconnect
3. Maintains vLLM instance across reconnections
4. Properly handles connection state in all threads
"""

import os

os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

import asyncio
import io
import json
import logging
import ssl
import shlex
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, Optional, List
from queue import Queue, Empty
from threading import Thread, Lock, Event
from collections import deque

import websockets
from websockets.client import WebSocketClientProtocol
from PIL import Image
import numpy as np
import webdataset as wds
from huggingface_hub import get_token

from .models import JobStatus, Job
from .utils import CaptionUtils
from .utils.dataset_loader import DatasetLoader
from .utils.vllm_config import VLLMConfigManager

logger = logging.getLogger(__name__)


@dataclass
class ShardChunk:
    """Shard chunk assignment from orchestrator."""

    chunk_id: str
    shard_url: str
    shard_name: str
    start_index: int
    chunk_size: int


@dataclass
class ProcessingItem:
    """Item being processed."""

    chunk_id: str
    item_key: str
    image: Image.Image
    image_data: bytes


@dataclass
class ProcessedResult:
    """Result with multiple captions and metadata."""

    chunk_id: str
    shard_name: str
    item_key: str
    captions: List[str]
    image_width: int
    image_height: int
    image_format: str
    file_size: int
    processing_time_ms: float


class VLLMWorker:
    """Worker that processes shard chunks directly with proper reconnection."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.server_url = config["server"]
        self.token = config["token"]
        self.name = config.get("name", "worker")

        # Dataset configuration will be received from orchestrator
        self.dataset_config = None
        self.dataset_loader = None
        self.dataset_type = None
        self.hf_token = get_token()

        # vLLM configuration will be received from orchestrator
        self.vllm_config = None
        self.inference_prompts = None
        self.vllm_config_manager = VLLMConfigManager()

        # Backward compatibility: local config for GPU selection
        self.gpu_id = config.get("gpu_id", 0)

        # SSL configuration
        self.ssl_context = self._setup_ssl()

        # State
        self.worker_id: Optional[str] = None
        self.websocket: Optional[WebSocketClientProtocol] = None
        self.running = False
        self.main_loop: Optional[asyncio.AbstractEventLoop] = None  # Store main event loop

        # Connection state events
        self.connected = Event()
        self.should_stop_processing = Event()

        # Inference components (initialized in setup)
        self.llm = None
        self.processor = None
        self.tokenizer = None
        self.sampling_params = None

        # Shard chunk processing
        self.chunk_lock = Lock()
        self.assigned_chunks = deque()
        self.current_chunk = None
        self.current_chunk_progress = 0
        # Batching queues - will be cleared on disconnect
        self.readahead_queue = Queue(maxsize=256)
        self.inference_queue = Queue(maxsize=128)
        self.result_queue = Queue()

        # Metrics
        self.items_processed = 0
        self.items_failed = 0
        self.chunks_completed = 0

        # Job mode for shards vs jobs and job queue.
        self.job_mode = config.get("job_mode", False)
        self.job_queue = Queue(maxsize=32)

    def _setup_ssl(self) -> Optional[ssl.SSLContext]:
        """Configure SSL context."""
        if self.server_url.startswith("ws://"):
            logger.warning("Using insecure WebSocket connection")
            return None

        if not self.config.get("verify_ssl", True):
            context = ssl.create_default_context()
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            return context

        return ssl.create_default_context()

    def _setup_dataset_loader(self, dataset_config: Dict[str, Any]):
        """Initialize dataset loader with config from orchestrator."""
        dataset_path = dataset_config.get("dataset_path") or dataset_config.get("path")
        dataset_type = dataset_config.get("dataset_type") or dataset_config.get(
            "type", "huggingface"
        )

        if dataset_path:
            logger.info(f"Initializing dataset loader for {dataset_type}: {dataset_path}")
            self.dataset_loader = DatasetLoader(dataset_path, dataset_type)
            self.dataset_config = dataset_config
            self.dataset_type = dataset_type
        else:
            logger.warning("No dataset path provided by orchestrator")

    def _setup_vllm(self):
        """Initialize vLLM components."""
        if not self.vllm_config:
            raise RuntimeError("vLLM config not received from orchestrator")

        os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)

        from vllm import LLM, SamplingParams
        from transformers import AutoTokenizer, AutoProcessor

        model_name = self.vllm_config["model"]
        logger.info(f"Loading {model_name} on GPU {self.gpu_id}")

        # Always reload tokenizer/processor (they're model-specific)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, use_fast=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

        # Initialize LLM with settings from orchestrator using config manager
        vllm_params = self.vllm_config_manager.get_vllm_init_params(self.vllm_config)
        self.llm = LLM(**vllm_params)

        # Create sampling params from orchestrator config
        self.sampling_params = self.vllm_config_manager.create_sampling_params(self.vllm_config)

        logger.info("vLLM initialization complete")

        # Update config manager's tracking
        self.vllm_config_manager.current_config = self.vllm_config

    async def _handle_job_assignment(self, job_data: Dict):
        """Handle job assignment from orchestrator."""
        try:
            # Convert to processing item
            image = Image.open(io.BytesIO(job_data["image_data"]))

            item = ProcessingItem(
                chunk_id=job_data["job_id"],
                item_key=job_data["sample_id"],
                image=image,
                image_data=job_data["image_data"],
            )

            # Add to inference queue
            self.readahead_queue.put(item)
            logger.debug(f"Queued job {job_data['job_id']} for processing")

        except Exception as e:
            logger.error(f"Error handling job assignment: {e}")

    async def _job_request_loop(self):
        """Request jobs from orchestrator in job mode."""
        while self.running and self.connected.is_set():
            try:
                # Check if we need more work
                if self.readahead_queue.qsize() < self.vllm_config.get("batch_size", 8):
                    await self.websocket.send(json.dumps({"type": "request_job"}))

                await asyncio.sleep(1)

            except Exception as e:
                logger.error(f"Job request error: {e}")
                await asyncio.sleep(5)

    def _handle_vllm_config_update(self, new_config: Dict[str, Any]) -> bool:
        """
        Handle vLLM configuration updates.

        Returns:
            True if config was updated successfully, False if reload is needed
        """
        if not new_config:
            return True

        # Check what changed
        change = self.vllm_config_manager.analyze_config_change(self.vllm_config, new_config)

        if not change.changed_fields:
            # No changes
            return True

        if change.requires_reload:
            # Need to reload vLLM
            logger.info(f"vLLM config changes require reload: {change.changed_fields}")

            # Save old config
            old_config = self.vllm_config
            self.vllm_config = new_config

            try:
                # Reload vLLM with new config
                logger.info("Reloading vLLM with new configuration...")

                # Clean up old instance
                if hasattr(self, "llm") and self.llm:
                    del self.llm

                # Also clean up tokenizer/processor if model changed
                if change.model_changed:
                    if hasattr(self, "tokenizer"):
                        del self.tokenizer
                    if hasattr(self, "processor"):
                        del self.processor

                import gc

                gc.collect()

                # Reload with new config
                self._setup_vllm()

                # Update prompts
                self.inference_prompts = new_config.get("inference_prompts", self.inference_prompts)

                logger.info("vLLM reload complete")
                return True

            except Exception as e:
                logger.error(f"Failed to reload vLLM: {e}")
                # Restore old config
                self.vllm_config = old_config
                return False

        else:
            # Can update without reload
            logger.info(f"Updating vLLM config without reload: {change.changed_fields}")

            # Update sampling params if changed
            if change.sampling_changed:
                self.sampling_params = self.vllm_config_manager.create_sampling_params(new_config)

            # Update prompts if changed
            if change.prompts_changed:
                self.inference_prompts = new_config.get("inference_prompts", self.inference_prompts)
                logger.info(f"Updated inference prompts: {len(self.inference_prompts)} prompts")

            # Update config
            self.vllm_config = new_config
            logger.info("vLLM configuration updated successfully without reload")
            return True

    def _clear_state_on_disconnect(self):
        """Clear all processing state when disconnected."""
        logger.info("Clearing state due to disconnection")

        # Signal threads to stop current processing
        self.should_stop_processing.set()

        with self.chunk_lock:
            # Clear assigned chunks
            self.assigned_chunks.clear()
            self.current_chunk = None
            self.current_chunk_progress = 0

        # Clear all queues
        self._clear_queue(self.readahead_queue)
        self._clear_queue(self.inference_queue)
        self._clear_queue(self.result_queue)

        logger.info("State cleared, ready for reconnection")

    def _clear_queue(self, queue: Queue):
        """Clear all items from a queue."""
        try:
            while True:
                queue.get_nowait()
        except Empty:
            pass

    async def start(self):
        """Start the worker with automatic reconnection."""
        self.running = True

        # Wait for initial connection to get vLLM config
        logger.info("Connecting to orchestrator for configuration...")

        # Try initial connection to get config
        config_received = False
        while not config_received and self.running:
            try:
                await self._initial_connect_for_config()
                config_received = True
            except Exception as e:
                logger.error(f"Failed to get config: {e}")
                await asyncio.sleep(5)

        # Initialize vLLM once we have config
        self._setup_vllm()

        # Capture the main event loop for use in background threads
        self.main_loop = asyncio.get_running_loop()

        # Start shard reader thread
        reader_thread = Thread(target=self._shard_reader_thread, daemon=True)
        reader_thread.start()

        # Start inference thread
        inference_thread = Thread(target=self._inference_thread, daemon=True)
        inference_thread.start()

        # Reconnection with exponential backoff
        reconnect_delay = 5
        max_delay = 60

        # Connect to orchestrator with retries
        while self.running:
            try:
                await self._connect_and_run()

                # Reset delay on successful connection
                reconnect_delay = 5

            except Exception as e:
                logger.error(f"Connection error: {e}")

                # Mark as disconnected
                self.connected.clear()
                self.websocket = None

                # Clear all state on disconnect
                self._clear_state_on_disconnect()

                if self.running:
                    logger.info(f"Reconnecting in {reconnect_delay} seconds...")
                    await asyncio.sleep(reconnect_delay)

                    # Exponential backoff
                    reconnect_delay = min(reconnect_delay * 2, max_delay)

    async def _initial_connect_for_config(self):
        """Connect initially just to get configuration."""
        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
            # Authenticate
            await websocket.send(json.dumps({"token": self.token, "name": self.name}))

            # Wait for welcome message with config
            welcome = await websocket.recv()
            welcome_data = json.loads(welcome)

            if "error" in welcome_data:
                raise RuntimeError(f"Authentication failed: {welcome_data['error']}")

            # Extract vLLM configuration
            self.vllm_config = welcome_data.get("vllm_config")
            if not self.vllm_config:
                raise RuntimeError("No vLLM configuration received from orchestrator")

            self.inference_prompts = self.vllm_config.get(
                "inference_prompts",
                [
                    "describe this image in detail",
                    "provide a comprehensive description of the visual content",
                    "what are the key elements in this image?",
                ],
            )

            # Store config in manager
            self.vllm_config_manager.current_config = self.vllm_config

            # Extract dataset configuration
            dataset_config = welcome_data.get("dataset_config", {})
            if dataset_config:
                self._setup_dataset_loader(dataset_config)

            logger.info("Received configuration from orchestrator")
            # Disconnect after getting config

    async def _connect_and_run(self):
        """Connect to orchestrator and process chunks."""
        logger.info(f"Connecting to {self.server_url}")

        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
            self.websocket = websocket
            self.connected.set()

            # Clear stop signal now that we're connected
            self.should_stop_processing.clear()

            # Authenticate
            await websocket.send(json.dumps({"token": self.token, "name": self.name}))

            # Wait for welcome message with dataset config
            welcome = await websocket.recv()
            welcome_data = json.loads(welcome)

            if "error" in welcome_data:
                logger.error(f"Authentication failed: {welcome_data['error']}")
                self.running = False
                return

            self.worker_id = welcome_data.get("worker_id")
            logger.info(f"Connected as {self.worker_id}")

            # Extract and setup dataset configuration from orchestrator
            dataset_config = welcome_data.get("dataset_config", {})
            if dataset_config:
                self._setup_dataset_loader(dataset_config)
                logger.info(f"Received dataset config: {dataset_config}")
            else:
                logger.warning("No dataset configuration received from orchestrator")

            # Update vLLM config if provided (in case it changed)
            new_vllm_config = welcome_data.get("vllm_config")
            if new_vllm_config and new_vllm_config != self.vllm_config:
                logger.info("Received updated vLLM configuration")

                # Handle config update (may trigger reload)
                if not self._handle_vllm_config_update(new_vllm_config):
                    logger.error("Failed to update vLLM configuration")
                    # Continue with existing config

            if self.job_mode:
                # In job mode, request individual jobs instead of chunks
                tasks.append(asyncio.create_task(self._job_request_loop()))
            else:
                # Request initial chunks
                await websocket.send(json.dumps({"type": "request_chunks", "count": 2}))

            # Start processing
            try:
                # Create tasks
                tasks = [
                    asyncio.create_task(self._heartbeat_loop()),
                    asyncio.create_task(self._message_handler()),
                    asyncio.create_task(self._result_sender()),
                ]

                # Wait for any task to complete (likely due to disconnection)
                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

                # Cancel remaining tasks
                for task in pending:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass

            finally:
                # Ensure we mark as disconnected
                self.connected.clear()
                self.websocket = None

    async def _message_handler(self):
        """Handle messages from orchestrator."""
        try:
            async for message in self.websocket:
                try:
                    data = json.loads(message)
                    msg_type = data.get("type")

                    if msg_type == "shard_assignment":
                        chunks = data["chunks"]
                        for chunk_data in chunks:
                            chunk = ShardChunk(**chunk_data)
                            with self.chunk_lock:
                                self.assigned_chunks.append(chunk)
                            logger.info(f"Received chunk assignment: {chunk.chunk_id}")

                    elif msg_type == "no_chunks":
                        reason = data.get("reason", "unknown")
                        logger.info(f"No chunks available from orchestrator (reason: {reason})")

                        # Different wait times based on reason
                        wait_time = 2 if reason == "state_restoring" else 10
                        await asyncio.sleep(wait_time)

                        # Request again after waiting
                        if self.websocket and self.connected.is_set():
                            await self.websocket.send(
                                json.dumps({"type": "request_chunks", "count": 2})
                            )

                    elif msg_type == "reload_vllm":
                        # Orchestrator requested vLLM reload
                        logger.info("Orchestrator requested vLLM reload")
                        new_config = data.get("vllm_config")
                        if new_config:
                            self._handle_vllm_config_update(new_config)

                    elif msg_type == "job_assignment":
                        await self._handle_job_assignment(data["job"])

                    elif msg_type == "no_jobs":
                        logger.debug("No jobs available")
                        await asyncio.sleep(2)

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid message format: {e}")
                except Exception as e:
                    logger.error(f"Error handling message: {e}")

        except websockets.exceptions.ConnectionClosed as e:
            logger.info(f"Connection closed by orchestrator: {e}")
            raise  # Re-raise to trigger cleanup
        except Exception as e:
            logger.error(f"Message handler error: {e}")
            raise

    def _shard_reader_thread(self):
        """Background thread that reads from WebDataset shards."""
        logger.info("Starting shard reader thread")

        while self.running:
            # Check if we should stop processing
            if self.should_stop_processing.is_set():
                logger.info("Shard reader waiting for reconnection")
                time.sleep(1)
                continue

            # Only process if connected
            if not self.connected.is_set():
                time.sleep(1)
                continue

            # Get next chunk to process
            with self.chunk_lock:
                if not self.current_chunk and self.assigned_chunks:
                    self.current_chunk = self.assigned_chunks.popleft()
                    self.current_chunk_progress = 0
                    logger.info(f"Starting chunk {self.current_chunk.chunk_id}")

            if not self.current_chunk:
                time.sleep(1)
                continue

            try:
                # Process the chunk
                self._process_shard_chunk(self.current_chunk)

                # Only mark complete if still connected
                if self.connected.is_set() and not self.should_stop_processing.is_set():
                    logger.info(f"Completed chunk {self.current_chunk.chunk_id}")
                    self.chunks_completed += 1

                    # Notify orchestrator if connected
                    if self.websocket and self.main_loop:
                        try:
                            # Notify completion
                            asyncio.run_coroutine_threadsafe(
                                self.websocket.send(
                                    json.dumps(
                                        {
                                            "type": "chunk_complete",
                                            "chunk_id": self.current_chunk.chunk_id,
                                        }
                                    )
                                ),
                                self.main_loop,
                            ).result(timeout=5)

                            # Request more chunks if queue is low
                            with self.chunk_lock:
                                queue_size = len(self.assigned_chunks)

                            if queue_size < 2:
                                logger.info(f"Requesting more chunks (queue size: {queue_size})")
                                asyncio.run_coroutine_threadsafe(
                                    self.websocket.send(
                                        json.dumps({"type": "request_chunks", "count": 2})
                                    ),
                                    self.main_loop,
                                ).result(timeout=5)

                        except Exception as e:
                            logger.warning(f"Could not notify orchestrator: {e}")

                with self.chunk_lock:
                    self.current_chunk = None

            except Exception as e:
                logger.error(f"Error processing chunk: {e}")

                # Only notify of failure if still connected
                if self.connected.is_set() and self.websocket and self.main_loop:
                    try:
                        asyncio.run_coroutine_threadsafe(
                            self.websocket.send(
                                json.dumps(
                                    {
                                        "type": "chunk_failed",
                                        "chunk_id": (
                                            self.current_chunk.chunk_id
                                            if self.current_chunk
                                            else "unknown"
                                        ),
                                        "error": str(e),
                                    }
                                )
                            ),
                            self.main_loop,
                        ).result(timeout=5)
                    except Exception as send_error:
                        logger.warning(
                            f"Could not notify orchestrator of chunk failure: {send_error}"
                        )

                with self.chunk_lock:
                    self.current_chunk = None

    def _process_shard_chunk(self, chunk: ShardChunk):
        """Process a single shard chunk."""
        logger.info(f"Processing shard {chunk.shard_name} from index {chunk.start_index}")

        # Create WebDataset pipeline
        if self.dataset_type == "huggingface":
            # Use curl with auth for HuggingFace
            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
            ds = wds.DataPipeline(
                wds.SimpleShardList(url_cmd),
                wds.tarfile_to_samples(),
                wds.to_tuple("__key__", "jpg;png;jpeg;webp"),
            )
        else:
            # Local file
            ds = wds.DataPipeline(
                wds.SimpleShardList(chunk.shard_url),
                wds.tarfile_to_samples(),
                wds.to_tuple("__key__", "jpg;png;jpeg;webp"),
            )

        # Process items with readahead
        items_processed = 0
        items_to_skip = chunk.start_index

        for key, image_data in ds:
            # Check if we should stop
            if (
                not self.running
                or self.should_stop_processing.is_set()
                or not self.connected.is_set()
            ):
                logger.info(f"Stopping chunk processing early due to disconnect")
                break

            # Skip to start index
            if items_to_skip > 0:
                items_to_skip -= 1
                continue

            # Check if we've processed enough
            if items_processed >= chunk.chunk_size:
                break

            try:
                # Load image
                img = Image.open(io.BytesIO(image_data))

                # Create processing item
                item = ProcessingItem(
                    chunk_id=chunk.chunk_id, item_key=key, image=img, image_data=image_data
                )

                # Add to readahead queue (blocks if full - provides backpressure)
                # Use timeout to allow checking for disconnection
                timeout_end = time.time() + 30
                while (
                    self.running
                    and not self.should_stop_processing.is_set()
                    and self.connected.is_set()
                ):
                    try:
                        self.readahead_queue.put(item, timeout=1)
                        break
                    except:
                        if time.time() > timeout_end:
                            raise TimeoutError("Queue put timeout")
                        continue

                # If we couldn't queue due to disconnection, skip this item
                if not self.connected.is_set() or self.should_stop_processing.is_set():
                    logger.debug(f"Skipping item {key} due to disconnection")
                    break

                items_processed += 1
                self.current_chunk_progress = items_processed

                # Batch items for inference
                batch_size = self.vllm_config.get("batch_size", 8)
                if self.readahead_queue.qsize() >= batch_size:
                    self._batch_for_inference()

            except Exception as e:
                if self.should_stop_processing.is_set():
                    break
                logger.error(f"Error processing item {key}: {e}")
                self.items_failed += 1

        # Process remaining items only if still connected
        if not self.should_stop_processing.is_set():
            self._batch_for_inference()

        logger.info(f"Chunk {chunk.chunk_id} processed {items_processed} items")

    def _batch_for_inference(self):
        """Batch items from readahead queue for inference."""
        batch = []
        batch_size = self.vllm_config.get("batch_size", 8)

        try:
            while len(batch) < batch_size:
                item = self.readahead_queue.get_nowait()
                batch.append(item)
        except Empty:
            pass

        if batch:
            self.inference_queue.put(batch)

    def _inference_thread(self):
        """Background thread for vLLM inference."""
        logger.info("Starting inference thread")

        while self.running:
            try:
                # Get batch from queue with timeout
                batch = self.inference_queue.get(timeout=1)

                if not batch:
                    continue

                # Skip if disconnected
                if self.should_stop_processing.is_set():
                    continue

                logger.debug(f"Processing batch of {len(batch)} images")
                start_time = time.time()

                # Prepare vLLM inputs
                requests = []
                for item in batch:
                    # Resize for consistency
                    item.image.thumbnail((512, 512), Image.BILINEAR)

                    for prompt in self.inference_prompts:
                        req = self._build_vllm_input(item.image, prompt)
                        requests.append(req)

                # Run inference
                outputs = self.llm.generate(requests, self.sampling_params)

                # Process outputs only if still connected
                if not self.should_stop_processing.is_set():
                    for i, item in enumerate(batch):
                        # Get all prompt outputs as a list
                        idx = i * len(self.inference_prompts)
                        captions = []

                        for j in range(len(self.inference_prompts)):
                            if idx + j < len(outputs) and outputs[idx + j].outputs:
                                caption_text = self._clean_output(outputs[idx + j].outputs[0].text)
                                if caption_text:  # Only add non-empty captions
                                    captions.append(caption_text)

                        # Only create result if we have at least one caption
                        if captions:
                            result = ProcessedResult(
                                chunk_id=item.chunk_id,
                                shard_name=Path(item.chunk_id).stem.rsplit("_chunk_", 1)[0],
                                item_key=item.item_key,
                                captions=captions,
                                image_width=item.image.width,
                                image_height=item.image.height,
                                image_format=item.image.format or "unknown",
                                file_size=len(item.image_data),
                                processing_time_ms=(time.time() - start_time) * 1000 / len(batch),
                            )

                            self.result_queue.put(result)
                            self.items_processed += 1
                        else:
                            logger.warning(f"No valid captions generated for item {item.item_key}")
                            self.items_failed += 1

            except Empty:
                continue
            except Exception as e:
                if self.should_stop_processing.is_set():
                    continue
                logger.error(f"Inference error: {e}")

    def _build_vllm_input(self, image: Image.Image, prompt: str) -> Dict:
        """Build vLLM input."""
        try:
            from qwen_vl_utils import process_vision_info

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            prompt_text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, _ = process_vision_info(messages)
            prompt_ids = self.tokenizer(prompt_text, add_special_tokens=False).input_ids

            return {
                "prompt_token_ids": prompt_ids,
                "multi_modal_data": {"image": image_inputs},
            }
        except ImportError:
            return {
                "prompt": f"<|user|>\n<|image_pad|>\n{prompt}<|end|>\n<|assistant|>",
                "multi_modal_data": {"image": [image]},
            }

    def _clean_output(self, text: str) -> str:
        """Clean model output."""
        if not text:
            return ""

        # Remove common artifacts
        for token in ["<|end|>", "<|endoftext|>", "<|im_end|>", "I'm sorry", "I cannot"]:
            if token in text:
                text = text.split(token)[0]

        return text.strip()

    async def _result_sender(self):
        """Send results back to orchestrator."""
        pending_results = []  # Buffer for results during disconnection

        try:
            while self.running and self.connected.is_set():
                try:
                    # Get result (with timeout to allow checking self.running)
                    try:
                        result = await asyncio.get_event_loop().run_in_executor(
                            None, self.result_queue.get, True, 1
                        )
                        pending_results.append(result)
                    except Empty:
                        pass

                    # Only try to send if connected
                    if pending_results and self.websocket and self.connected.is_set():
                        sent_results = []
                        for result in pending_results:
                            try:
                                # Send result with all captions
                                await self.websocket.send(
                                    json.dumps(
                                        {
                                            "type": "submit_captions",
                                            "chunk_id": result.chunk_id,
                                            "dataset": self.dataset_config.get(
                                                "dataset_path", "unknown"
                                            ),
                                            "shard": result.shard_name,
                                            "item_key": result.item_key,
                                            "captions": result.captions,
                                            "caption_count": len(result.captions),
                                            "image_width": result.image_width,
                                            "image_height": result.image_height,
                                            "image_format": result.image_format,
                                            "file_size": result.file_size,
                                            "processing_time_ms": result.processing_time_ms,
                                        }
                                    )
                                )
                                sent_results.append(result)

                                if self.items_processed % 100 == 0:
                                    logger.info(
                                        f"Processed {self.items_processed} items "
                                        f"(~{self.items_processed * 3} captions)"
                                    )
                            except websockets.exceptions.ConnectionClosed as e:
                                logger.warning(f"Connection lost while sending result: {e}")
                                raise  # Re-raise to trigger task completion
                            except Exception as e:
                                logger.error(f"Error sending result: {e}")
                                break

                        # Remove successfully sent results
                        for result in sent_results:
                            pending_results.remove(result)

                    # Clear pending results if disconnected and buffer is too large
                    if not self.connected.is_set() and len(pending_results) > 1000:
                        logger.warning(
                            f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
                        )
                        pending_results.clear()

                    await asyncio.sleep(0.1)

                except Exception as e:
                    if isinstance(e, websockets.exceptions.ConnectionClosed):
                        raise  # Re-raise connection errors
                    logger.error(f"Unexpected error in result sender: {e}")
                    await asyncio.sleep(1)

        except asyncio.CancelledError:
            logger.debug("Result sender cancelled")
            raise

    async def _heartbeat_loop(self):
        """Send periodic heartbeats with connection checking."""
        try:
            while self.running and self.connected.is_set():
                try:
                    if self.websocket:
                        await self.websocket.send(
                            json.dumps(
                                {
                                    "type": "heartbeat",
                                    "processed": self.items_processed,
                                    "failed": self.items_failed,
                                    "chunks_completed": self.chunks_completed,
                                    "current_chunk": (
                                        self.current_chunk.chunk_id if self.current_chunk else None
                                    ),
                                    "chunk_progress": self.current_chunk_progress,
                                    "queue_sizes": {
                                        "readahead": self.readahead_queue.qsize(),
                                        "inference": self.inference_queue.qsize(),
                                        "results": self.result_queue.qsize(),
                                    },
                                }
                            )
                        )
                    await asyncio.sleep(30)
                except websockets.exceptions.ConnectionClosed as e:
                    logger.info(f"Connection lost during heartbeat: {e}")
                    raise  # Re-raise to trigger task completion
                except Exception as e:
                    logger.error(f"Heartbeat error: {e}")
                    raise  # Re-raise to trigger task completion
        except asyncio.CancelledError:
            logger.debug("Heartbeat loop cancelled")
            raise

    async def shutdown(self):
        """Graceful shutdown."""
        logger.info("Shutting down worker...")
        self.running = False
        self.connected.clear()
        self.should_stop_processing.set()

        # Stop processing threads by adding stop signals
        self.readahead_queue.put(None)
        self.inference_queue.put(None)

        # Close websocket if connected
        if self.websocket:
            try:
                await self.websocket.close()
            except:
                pass
        self.websocket = None

        logger.info("Worker shutdown complete")