caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +65 -42
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +13 -3
- caption_flow/orchestrator.py +1049 -264
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +567 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.1.dist-info/METADATA +370 -0
- caption_flow-0.2.1.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1321 @@
|
|
1
|
+
"""Caption worker for vLLM-based distributed image captioning with multi-stage processing."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
|
5
|
+
# Set VLLM_ENABLE_V1_MULTIPROCESSING=0 at import time. By its name this turns
# off vLLM V1-engine multiprocessing; it is presumably set here so it is already
# in the environment when `vllm` is first imported (that import is done lazily
# inside MultiStageVLLMManager) — TODO confirm against the vLLM docs.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import io
|
9
|
+
import json
|
10
|
+
import logging
|
11
|
+
import websockets
|
12
|
+
import time
|
13
|
+
from dataclasses import dataclass, field
|
14
|
+
from pathlib import Path
|
15
|
+
from typing import Dict, Any, Optional, List, Tuple
|
16
|
+
from queue import Queue, Empty
|
17
|
+
from threading import Thread, Lock, Event
|
18
|
+
from collections import deque, defaultdict
|
19
|
+
|
20
|
+
from PIL import Image
|
21
|
+
import numpy as np
|
22
|
+
from huggingface_hub import get_token
|
23
|
+
|
24
|
+
from .base import BaseWorker
|
25
|
+
from ..models import JobStatus, Job
|
26
|
+
from ..utils import CaptionUtils
|
27
|
+
from ..utils.dataset_loader import DatasetLoader
|
28
|
+
from ..utils.vllm_config import VLLMConfigManager
|
29
|
+
from ..utils.image_processor import ImageProcessor
|
30
|
+
from ..utils.shard_processor import HFDatasetShardProcessor, WebDatasetShardProcessor
|
31
|
+
from ..utils.prompt_template import PromptTemplateManager
|
32
|
+
|
33
|
+
# Module-level logger used by the classes and methods defined below.
logger = logging.getLogger(__name__)
|
34
|
+
|
35
|
+
|
36
|
+
@dataclass
class ProcessingStage:
    """One step of the multi-stage captioning pipeline.

    ``model`` is the model this stage uses, ``prompts`` are its prompts, and
    ``output_field`` names the field its outputs are reported under.
    ``requires`` lists the names of stages that must be ordered before this
    one. ``sampling`` is merged over the worker-wide sampling config. The four
    trailing fields override the base vLLM engine settings when this stage's
    model is loaded; ``None`` means "keep the base value".
    """

    name: str
    model: str
    prompts: List[str]
    output_field: str
    requires: List[str] = field(default_factory=list)  # prerequisite stage names
    sampling: Optional[Dict[str, Any]] = field(default=None)  # per-stage sampling overrides

    # Engine overrides applied on top of the base config (None = use base value)
    tensor_parallel_size: Optional[int] = field(default=None)
    max_model_len: Optional[int] = field(default=None)
    dtype: Optional[str] = field(default=None)
    gpu_memory_utilization: Optional[float] = field(default=None)
|
52
|
+
|
53
|
+
|
54
|
+
@dataclass
class StageResult:
    """Text produced by a single stage for a single item."""

    stage_name: str
    output_field: str  # field name the outputs are reported under
    # One entry per prompt; presumably in prompt order — TODO confirm.
    outputs: List[str]
|
61
|
+
|
62
|
+
|
63
|
+
@dataclass
class ShardChunk:
    """A slice of a shard assigned to this worker by the orchestrator.

    Instances are built straight from the ``shard_assignment`` payload via
    keyword expansion, so the field names must match that payload's keys.
    """

    chunk_id: str
    shard_url: str  # "hf_dataset:" URLs use the HF processor, others WebDataset
    shard_name: str
    start_index: int  # absolute index of the chunk's first item
    chunk_size: int
    # Ranges still to process; presumably (start, end) pairs — TODO confirm
    # inclusivity. Arriving via JSON they will be lists rather than tuples.
    unprocessed_ranges: List[Tuple[int, int]] = field(default_factory=list)
|
73
|
+
|
74
|
+
|
75
|
+
@dataclass
class ProcessingItem:
    """A decoded image waiting for (or undergoing) inference."""

    chunk_id: str  # owning chunk id; in job mode this holds the job id
    item_key: str
    image: Image.Image  # decoded image
    image_data: bytes  # the original encoded bytes
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Results accumulated so far; presumably keyed by stage name — TODO confirm.
    stage_results: Dict[str, StageResult] = field(default_factory=dict)
|
85
|
+
|
86
|
+
|
87
|
+
@dataclass
class ProcessedResult:
    """Final per-item record combining every stage's outputs with image facts."""

    chunk_id: str
    shard_name: str
    item_key: str
    outputs: Dict[str, List[str]]  # output field name -> generated texts
    image_width: int
    image_height: int
    image_format: str
    file_size: int  # presumably size of the encoded image in bytes — TODO confirm
    processing_time_ms: float
    metadata: Dict[str, Any] = field(default_factory=dict)
|
101
|
+
|
102
|
+
|
103
|
+
class MultiStageVLLMManager:
    """Hold the vLLM engines, processors, tokenizers and sampling params for all stages.

    Engines, processors and tokenizers are keyed by model name, so stages that
    name the same model share one loaded instance. Sampling parameters are
    keyed by stage name, so each stage keeps its own.
    """

    def __init__(self, gpu_id: int = 0):
        # NOTE(review): gpu_id is stored but not read anywhere in this class.
        self.gpu_id = gpu_id
        self.models: Dict[str, Any] = {}  # model_name -> LLM instance
        self.processors: Dict[str, Any] = {}  # model_name -> processor
        self.tokenizers: Dict[str, Any] = {}  # model_name -> tokenizer
        self.sampling_params: Dict[str, Any] = {}  # stage_name -> SamplingParams

    def load_model(self, model_name: str, stage: ProcessingStage, base_config: Dict[str, Any]):
        """Load tokenizer, processor and vLLM engine for *model_name* unless already loaded.

        The engine config is *base_config* with every non-None engine override
        from *stage* applied on top. If the model is already loaded, the
        existing instance is reused and this stage's overrides have no effect.

        The three components are registered only after all of them have loaded,
        so a failure (e.g. engine OOM) leaves no partial entries behind.
        """
        if model_name in self.models:
            logger.info(f"Model {model_name} already loaded, reusing instance")
            return

        from vllm import LLM
        from transformers import AutoTokenizer, AutoProcessor

        logger.info(f"Loading model {model_name} for stage {stage.name}")

        # Merge base config with stage-specific overrides (None = keep base value)
        model_config = base_config.copy()
        model_config["model"] = model_name
        overrides = {
            "tensor_parallel_size": stage.tensor_parallel_size,
            "max_model_len": stage.max_model_len,
            "dtype": stage.dtype,
            "gpu_memory_utilization": stage.gpu_memory_utilization,
        }
        model_config.update({key: value for key, value in overrides.items() if value is not None})

        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, use_fast=True
        )
        # NOTE(review): unlike the tokenizer and engine, the processor is loaded
        # without trust_remote_code — confirm this is intended.
        processor = AutoProcessor.from_pretrained(model_name)

        vllm_params = {
            "model": model_name,
            "trust_remote_code": True,
            "tensor_parallel_size": model_config.get("tensor_parallel_size", 1),
            "max_model_len": model_config.get("max_model_len", 16384),
            "enforce_eager": model_config.get("enforce_eager", True),
            "gpu_memory_utilization": model_config.get("gpu_memory_utilization", 0.92),
            "dtype": model_config.get("dtype", "float16"),
            "limit_mm_per_prompt": model_config.get("limit_mm_per_prompt", {"image": 1}),
            "disable_mm_preprocessor_cache": model_config.get(
                "disable_mm_preprocessor_cache", True
            ),
        }
        llm = LLM(**vllm_params)

        # Register only once everything succeeded.
        self.tokenizers[model_name] = tokenizer
        self.processors[model_name] = processor
        self.models[model_name] = llm
        logger.info(f"Model {model_name} loaded successfully")

    def create_sampling_params(self, stage: ProcessingStage, base_sampling: Dict[str, Any]):
        """Build, store and return the SamplingParams for *stage*.

        Stage-level ``sampling`` keys override *base_sampling*; any key still
        missing falls back to the defaults below.
        """
        from vllm import SamplingParams

        sampling_config = base_sampling.copy()
        if stage.sampling:
            sampling_config.update(stage.sampling)

        params = SamplingParams(
            temperature=sampling_config.get("temperature", 0.7),
            top_p=sampling_config.get("top_p", 0.95),
            max_tokens=sampling_config.get("max_tokens", 256),
            stop=sampling_config.get("stop", ["<|end|>", "<|endoftext|>", "<|im_end|>"]),
            repetition_penalty=sampling_config.get("repetition_penalty", 1.05),
            skip_special_tokens=sampling_config.get("skip_special_tokens", True),
        )

        self.sampling_params[stage.name] = params
        return params

    def get_model_for_stage(self, stage_name: str, model_name: str) -> Tuple[Any, Any, Any, Any]:
        """
        Get model components for a stage.

        Returns:
            tuple: A tuple containing:
                - llm: The language model instance for the given model name.
                - processor: The processor associated with the model.
                - tokenizer: The tokenizer for the model.
                - sampling_params: The sampling parameters for the given stage.

        Raises:
            KeyError: if the model has not been loaded or the stage has no
                sampling parameters.
        """
        return (
            self.models[model_name],
            self.processors[model_name],
            self.tokenizers[model_name],
            self.sampling_params[stage_name],
        )

    def cleanup(self):
        """Drop every engine, processor, tokenizer and sampling param, then run GC.

        All registries are cleared outright rather than only the entries keyed
        by loaded engines, so nothing survives a cleanup.
        NOTE(review): dropping references may not release all GPU memory held
        by vLLM — confirm whether an explicit engine shutdown is required.
        """
        self.models.clear()
        self.processors.clear()
        self.tokenizers.clear()
        self.sampling_params.clear()

        import gc

        gc.collect()
|
214
|
+
|
215
|
+
|
216
|
+
class CaptionWorker(BaseWorker):
|
217
|
+
"""Worker that processes shard chunks for image captioning using multi-stage vLLM."""
|
218
|
+
|
219
|
+
def __init__(self, config: Dict[str, Any]):
    """Set up local state; dataset and vLLM settings arrive later from the orchestrator.

    Reads ``batch_image_processing`` (default False), ``gpu_id`` (default 0)
    and ``job_mode`` (default False) from *config*.
    """
    super().__init__(config)

    use_batch_image_processing = config.get("batch_image_processing", False)

    # Dataset settings: populated by _setup_dataset_loader once the
    # orchestrator sends its dataset config.
    self.dataset_config = None
    self.dataset_loader = None
    self.dataset_type = None
    self.dataset_split = None
    self.dataset_image_column = None
    self.hf_token = get_token()

    # vLLM settings: also supplied by the orchestrator.
    self.vllm_config = None
    self.stages: List[ProcessingStage] = []
    self.stage_order: List[str] = []  # stage names, dependencies first
    self.vllm_config_manager = VLLMConfigManager()
    self.model_manager = None

    # Local GPU choice, still read from the worker config for backward compatibility.
    self.gpu_id = config.get("gpu_id", 0)

    # Set while disconnected; producer threads pause while it is set.
    self.should_stop_processing = Event()

    # Optional batched image processor.
    self.image_processor = ImageProcessor() if use_batch_image_processing else None

    # Shard readers and chunk bookkeeping.
    self.hf_processor = HFDatasetShardProcessor()
    # NOTE(review): self.dataset_type is always None at this point, so the
    # WebDataset processor is built without a dataset type — confirm intended.
    self.webdataset_processor = WebDatasetShardProcessor(
        hf_token=self.hf_token, dataset_type=self.dataset_type
    )
    self.chunk_lock = Lock()
    self.assigned_chunks = deque()
    self.current_chunk = None
    self.current_chunk_progress = 0

    # Pipeline queues; drained on disconnect.
    self.readahead_queue = Queue(maxsize=256)
    self.inference_queue = Queue(maxsize=128)
    self.result_queue = Queue()

    # Job mode (individual jobs instead of shard chunks) and its queue.
    self.job_mode = config.get("job_mode", False)
    self.job_queue = Queue(maxsize=32)
|
268
|
+
|
269
|
+
def _init_metrics(self):
|
270
|
+
"""Initialize worker metrics."""
|
271
|
+
self.items_processed = 0
|
272
|
+
self.items_failed = 0
|
273
|
+
self.chunks_completed = 0
|
274
|
+
|
275
|
+
def _get_auth_data(self) -> Dict[str, Any]:
|
276
|
+
"""Get authentication data."""
|
277
|
+
return {"token": self.token, "name": self.name}
|
278
|
+
|
279
|
+
async def _pre_start(self):
    """Fetch configuration, load models and start the background threads.

    Retries the configuration handshake every 5 s until it succeeds or the
    worker stops running. If the worker stops before any configuration
    arrives, this returns without loading models or starting threads.
    Previously it went on to call ``_setup_vllm()``, which raised because no
    config was present.
    """
    logger.info("Connecting to orchestrator for configuration...")

    config_received = False
    while not config_received and self.running:
        try:
            await self._initial_connect_for_config()
            config_received = True
        except Exception as e:
            logger.error(f"Failed to get config: {e}")
            await asyncio.sleep(5)

    if not config_received:
        # Stopped while waiting: nothing to initialize.
        logger.info("Worker stopped before configuration was received")
        return

    # Loads every stage's model synchronously.
    self._setup_vllm()

    # Daemon threads: they do not keep the process alive on exit.
    reader_thread = Thread(target=self._shard_reader_thread, daemon=True)
    reader_thread.start()

    inference_thread = Thread(target=self._inference_thread, daemon=True)
    inference_thread.start()
|
303
|
+
|
304
|
+
def _parse_stages_config(self, vllm_config: Dict[str, Any]) -> List[ProcessingStage]:
    """Build ProcessingStage objects from the orchestrator's vLLM config.

    Without a non-empty ``stages`` list, a single ``"default"`` stage is built
    from the legacy top-level ``model`` / ``inference_prompts`` keys (backward
    compatibility). Otherwise each entry becomes one stage.

    A stage without its own ``model`` inherits the top-level ``model``. If that
    is missing too, it gets the same default model as the legacy path. It
    previously got ``None``, which only failed later during model loading.

    Raises:
        KeyError: if a stage entry has no ``name``.
    """
    default_model = vllm_config.get("model", "Qwen/Qwen2.5-VL-3B-Instruct")
    stages_config = vllm_config.get("stages", [])

    if not stages_config:
        # Backward compatibility: single stage from the old flat config
        return [
            ProcessingStage(
                name="default",
                model=default_model,
                prompts=vllm_config.get("inference_prompts", ["describe this image"]),
                output_field="captions",
                requires=[],
            )
        ]

    return [
        ProcessingStage(
            name=stage_cfg["name"],
            model=stage_cfg.get("model", default_model),
            prompts=stage_cfg.get("prompts", []),
            output_field=stage_cfg.get("output_field", "captions"),
            requires=stage_cfg.get("requires", []),
            sampling=stage_cfg.get("sampling"),
            tensor_parallel_size=stage_cfg.get("tensor_parallel_size"),
            max_model_len=stage_cfg.get("max_model_len"),
            dtype=stage_cfg.get("dtype"),
            gpu_memory_utilization=stage_cfg.get("gpu_memory_utilization"),
        )
        for stage_cfg in stages_config
    ]
|
338
|
+
|
339
|
+
def _topological_sort_stages(self, stages: List[ProcessingStage]) -> List[str]:
|
340
|
+
"""Sort stages by dependencies."""
|
341
|
+
# Build dependency graph
|
342
|
+
graph = defaultdict(list)
|
343
|
+
in_degree = defaultdict(int)
|
344
|
+
|
345
|
+
stage_map = {s.name: s for s in stages}
|
346
|
+
|
347
|
+
for stage in stages:
|
348
|
+
in_degree[stage.name] = len(stage.requires)
|
349
|
+
for dep in stage.requires:
|
350
|
+
if dep not in stage_map:
|
351
|
+
raise ValueError(f"Stage '{stage.name}' requires missing dependency '{dep}'")
|
352
|
+
graph[dep].append(stage.name)
|
353
|
+
|
354
|
+
# Topological sort using Kahn's algorithm
|
355
|
+
queue = deque([name for name, degree in in_degree.items() if degree == 0])
|
356
|
+
result = []
|
357
|
+
|
358
|
+
while queue:
|
359
|
+
current = queue.popleft()
|
360
|
+
result.append(current)
|
361
|
+
|
362
|
+
for neighbor in graph[current]:
|
363
|
+
in_degree[neighbor] -= 1
|
364
|
+
if in_degree[neighbor] == 0:
|
365
|
+
queue.append(neighbor)
|
366
|
+
|
367
|
+
if len(result) != len(stages):
|
368
|
+
raise ValueError("Circular dependency detected in stages")
|
369
|
+
|
370
|
+
return result
|
371
|
+
|
372
|
+
async def _handle_welcome(self, welcome_data: Dict[str, Any]):
    """Apply the orchestrator's welcome payload and resume processing.

    Re-initializes the dataset loader from ``dataset_config``, applies a
    changed ``vllm_config`` (if any), clears the stop flag and, outside job
    mode, asks the orchestrator for two chunks.
    """
    dataset_config = welcome_data.get("dataset_config", {})
    if not dataset_config:
        logger.warning("No dataset configuration received from orchestrator")
    else:
        self._setup_dataset_loader(dataset_config)
        logger.info(f"Received dataset config: {dataset_config}")

    # Only act on a vLLM config that is present and differs from the current one.
    new_vllm_config = welcome_data.get("vllm_config")
    config_changed = bool(new_vllm_config) and new_vllm_config != self.vllm_config
    if config_changed:
        logger.info("Received updated vLLM configuration")
        update_ok = self._handle_vllm_config_update(new_vllm_config)
        if not update_ok:
            logger.error("Failed to update vLLM configuration")

    # Connected again: let producer threads resume.
    self.should_stop_processing.clear()

    if not self.job_mode and self.websocket:
        request = {"type": "request_chunks", "count": 2}
        await self.websocket.send(json.dumps(request))
|
395
|
+
|
396
|
+
async def _handle_message(self, data: Dict[str, Any]):
|
397
|
+
"""Handle message from orchestrator."""
|
398
|
+
msg_type = data.get("type")
|
399
|
+
|
400
|
+
if msg_type == "shard_assignment":
|
401
|
+
chunks = data["chunks"]
|
402
|
+
for chunk_data in chunks:
|
403
|
+
chunk = ShardChunk(**chunk_data)
|
404
|
+
with self.chunk_lock:
|
405
|
+
self.assigned_chunks.append(chunk)
|
406
|
+
logger.info(f"Received chunk assignment: {chunk.chunk_id}")
|
407
|
+
|
408
|
+
elif msg_type == "no_chunks":
|
409
|
+
reason = data.get("reason", "unknown")
|
410
|
+
logger.info(f"No chunks available from orchestrator (reason: {reason})")
|
411
|
+
|
412
|
+
wait_time = 2 if reason == "state_restoring" else 10
|
413
|
+
await asyncio.sleep(wait_time)
|
414
|
+
|
415
|
+
if self.websocket and self.connected.is_set():
|
416
|
+
await self.websocket.send(json.dumps({"type": "request_chunks", "count": 2}))
|
417
|
+
|
418
|
+
elif msg_type == "reload_vllm":
|
419
|
+
logger.info("Orchestrator requested vLLM reload")
|
420
|
+
new_config = data.get("vllm_config")
|
421
|
+
if new_config:
|
422
|
+
self._handle_vllm_config_update(new_config)
|
423
|
+
|
424
|
+
elif msg_type == "config_update":
|
425
|
+
# Soft config update without reload
|
426
|
+
if data.get("vllm_config"):
|
427
|
+
self._handle_vllm_config_update(data["vllm_config"])
|
428
|
+
|
429
|
+
elif msg_type == "job_assignment":
|
430
|
+
await self._handle_job_assignment(data["job"])
|
431
|
+
|
432
|
+
elif msg_type == "no_jobs":
|
433
|
+
logger.debug("No jobs available")
|
434
|
+
await asyncio.sleep(2)
|
435
|
+
|
436
|
+
def _get_heartbeat_data(self) -> Dict[str, Any]:
|
437
|
+
"""Get heartbeat data."""
|
438
|
+
return {
|
439
|
+
"type": "heartbeat",
|
440
|
+
"processed": self.items_processed,
|
441
|
+
"failed": self.items_failed,
|
442
|
+
"chunks_completed": self.chunks_completed,
|
443
|
+
"current_chunk": self.current_chunk.chunk_id if self.current_chunk else None,
|
444
|
+
"chunk_progress": self.current_chunk_progress,
|
445
|
+
"queue_sizes": {
|
446
|
+
"readahead": self.readahead_queue.qsize(),
|
447
|
+
"inference": self.inference_queue.qsize(),
|
448
|
+
"results": self.result_queue.qsize(),
|
449
|
+
},
|
450
|
+
"stages": len(self.stages),
|
451
|
+
"models_loaded": len(self.model_manager.models) if self.model_manager else 0,
|
452
|
+
}
|
453
|
+
|
454
|
+
async def _create_tasks(self) -> list:
|
455
|
+
"""Create async tasks to run."""
|
456
|
+
tasks = [
|
457
|
+
asyncio.create_task(self._heartbeat_loop()),
|
458
|
+
asyncio.create_task(self._base_message_handler()),
|
459
|
+
asyncio.create_task(self._result_sender()),
|
460
|
+
]
|
461
|
+
|
462
|
+
if self.job_mode:
|
463
|
+
tasks.append(asyncio.create_task(self._job_request_loop()))
|
464
|
+
|
465
|
+
return tasks
|
466
|
+
|
467
|
+
async def _on_disconnect(self):
|
468
|
+
"""Handle disconnection."""
|
469
|
+
self._clear_state_on_disconnect()
|
470
|
+
|
471
|
+
async def _pre_shutdown(self):
|
472
|
+
"""Cleanup before shutdown."""
|
473
|
+
# Stop processing threads by adding stop signals
|
474
|
+
self.readahead_queue.put(None)
|
475
|
+
self.inference_queue.put(None)
|
476
|
+
|
477
|
+
# Shutdown image processor
|
478
|
+
if self.image_processor is not None:
|
479
|
+
self.image_processor.shutdown()
|
480
|
+
|
481
|
+
# Cleanup model manager
|
482
|
+
if self.model_manager:
|
483
|
+
self.model_manager.cleanup()
|
484
|
+
|
485
|
+
async def _initial_connect_for_config(self):
    """Open a short-lived connection just to fetch configuration.

    Sends the auth payload and reads the welcome message. From it, this sets
    ``vllm_config``, the parsed ``stages`` and their dependency order, the
    config manager's current config and, when provided, the dataset loader.
    The socket is closed when this returns or raises.

    Raises:
        RuntimeError: if authentication fails or no vLLM config is provided.
    """
    logger.info(f"Connecting to {self.server_url}")
    async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
        await websocket.send(json.dumps(self._get_auth_data()))
        welcome_data = json.loads(await websocket.recv())

        if "error" in welcome_data:
            raise RuntimeError(f"Authentication failed: {welcome_data['error']}")

        received_config = welcome_data.get("vllm_config")
        self.vllm_config = received_config
        if not received_config:
            raise RuntimeError("No vLLM configuration received from orchestrator")

        # Parse stages and resolve their execution order
        self.stages = self._parse_stages_config(received_config)
        self.stage_order = self._topological_sort_stages(self.stages)
        logger.info(f"Configured {len(self.stages)} processing stages: {self.stage_order}")

        self.vllm_config_manager.current_config = received_config

        dataset_config = welcome_data.get("dataset_config", {})
        if dataset_config:
            self._setup_dataset_loader(dataset_config)

        logger.info("Received configuration from orchestrator")
|
514
|
+
|
515
|
+
def _clear_state_on_disconnect(self):
    """Drop all in-flight work after the orchestrator connection is lost.

    Sets the stop flag so producer threads pause, forgets assigned and
    current chunks, and drains the readahead, inference and result queues.
    Results not yet sent are discarded.
    """
    logger.info("Clearing state due to disconnection")
    self.should_stop_processing.set()

    with self.chunk_lock:
        self.assigned_chunks.clear()
        self.current_chunk, self.current_chunk_progress = None, 0

    for pending in (self.readahead_queue, self.inference_queue, self.result_queue):
        self._clear_queue(pending)

    logger.info("State cleared, ready for reconnection")
|
531
|
+
|
532
|
+
def _clear_queue(self, queue: Queue):
|
533
|
+
"""Clear all items from a queue."""
|
534
|
+
try:
|
535
|
+
while True:
|
536
|
+
queue.get_nowait()
|
537
|
+
except Empty:
|
538
|
+
pass
|
539
|
+
|
540
|
+
def _setup_dataset_loader(self, dataset_config: Dict[str, Any]):
    """Create the DatasetLoader described by the orchestrator's dataset config.

    Each setting is looked up under its ``dataset_``-prefixed key first. A
    missing or falsy value falls back to the short key (``path``, ``type``,
    ``split``, ``image_column``), whose defaults are none, ``huggingface``,
    ``train`` and ``image``. Without a path, only a warning is logged and the
    existing loader and settings are left unchanged.
    """

    def pick(prefixed_key, short_key, default=None):
        return dataset_config.get(prefixed_key) or dataset_config.get(short_key, default)

    dataset_path = pick("dataset_path", "path")
    dataset_type = pick("dataset_type", "type", "huggingface")
    dataset_split = pick("dataset_split", "split", "train")
    dataset_image_column = pick("dataset_image_column", "image_column", "image")

    if not dataset_path:
        logger.warning("No dataset path provided by orchestrator")
        return

    logger.info(
        f"Initializing dataset loader for {dataset_type}: {dataset_path} "
        f"(split: {dataset_split}, image_column: {dataset_image_column})"
    )
    self.dataset_loader = DatasetLoader(
        dataset_path, dataset_type, dataset_split, dataset_image_column
    )
    self.dataset_config = dataset_config
    self.dataset_type = dataset_type
    # NOTE(review): self.webdataset_processor is not rebuilt with this
    # dataset_type (it was created with None) — confirm whether it needs it.
    self.dataset_split = dataset_split
    self.dataset_image_column = dataset_image_column
|
565
|
+
|
566
|
+
def _setup_vllm(self):
    """Load every stage's model and sampling params into a new MultiStageVLLMManager.

    Engine settings come from the top-level vLLM config, with per-stage
    overrides applied by the manager. Base sampling comes from its
    ``sampling`` key.

    Raises:
        RuntimeError: if no vLLM config has been received yet.
    """
    if not self.vllm_config:
        raise RuntimeError("vLLM config not received from orchestrator")

    # NOTE(review): presumably only effective if CUDA has not yet been
    # initialized in this process (e.g. not on a reload) — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)

    self.model_manager = MultiStageVLLMManager(self.gpu_id)

    cfg = self.vllm_config
    base_config = {
        "tensor_parallel_size": cfg.get("tensor_parallel_size", 1),
        "max_model_len": cfg.get("max_model_len", 16384),
        "dtype": cfg.get("dtype", "float16"),
        "gpu_memory_utilization": cfg.get("gpu_memory_utilization", 0.92),
        "enforce_eager": cfg.get("enforce_eager", True),
        "disable_mm_preprocessor_cache": cfg.get("disable_mm_preprocessor_cache", True),
        "limit_mm_per_prompt": cfg.get("limit_mm_per_prompt", {"image": 1}),
    }
    base_sampling = cfg.get("sampling", {})

    unique_models = {stage.model for stage in self.stages}
    logger.info(f"Loading {len(unique_models)} unique models for {len(self.stages)} stages")

    # The manager reuses an already-loaded model when stages share one.
    for stage in self.stages:
        self.model_manager.load_model(stage.model, stage, base_config)
        self.model_manager.create_sampling_params(stage, base_sampling)

    logger.info("Multi-stage vLLM initialization complete")

    self.vllm_config_manager.current_config = self.vllm_config
|
606
|
+
|
607
|
+
def _handle_vllm_config_update(self, new_config: Dict[str, Any]) -> bool:
    """Apply a new vLLM configuration, reloading models only when needed.

    A full reload happens when the number of stages changes, or when any
    stage's ``name``, ``model``, ``prompts`` or ``output_field`` differs.
    Otherwise the new stage definitions are adopted and only sampling
    parameters are rebuilt, so new per-stage ``sampling`` and ``requires``
    values take effect.

    The new stage graph is validated before any state is touched. If the
    reload fails, the previous config, stages and stage order are restored
    and their models are reloaded, because the old ones were already unloaded.

    Returns:
        True if the update was applied (or ``new_config`` is empty); False if
        it was rejected or the reload failed.
    """
    if not new_config:
        return True

    new_stages = self._parse_stages_config(new_config)

    # Validate first so a bad graph cannot leave half-updated state behind.
    try:
        new_stage_order = self._topological_sort_stages(new_stages)
    except ValueError as e:
        logger.error(f"Rejected vLLM config update: {e}")
        return False

    stages_changed = len(new_stages) != len(self.stages) or any(
        old.name != new.name
        or old.model != new.model
        or old.prompts != new.prompts
        or old.output_field != new.output_field
        for old, new in zip(self.stages, new_stages)
    )

    if not stages_changed:
        # Just update sampling params for existing stages
        logger.info("Updating sampling parameters without model reload")
        self.stages = new_stages
        self.stage_order = new_stage_order
        if self.model_manager:
            base_sampling = new_config.get("sampling", {})
            for stage in self.stages:
                self.model_manager.create_sampling_params(stage, base_sampling)
        self.vllm_config = new_config
        return True

    logger.info("Stage configuration changed, reloading all models")

    old_config, old_stages, old_stage_order = self.vllm_config, self.stages, self.stage_order
    self.vllm_config = new_config
    self.stages = new_stages
    self.stage_order = new_stage_order

    try:
        if self.model_manager:
            self.model_manager.cleanup()
        self._setup_vllm()
        logger.info("Multi-stage vLLM reload complete")
        return True
    except Exception as e:
        logger.error(f"Failed to reload vLLM: {e}")

    # Roll back to the previous configuration and reload its models.
    self.vllm_config, self.stages, self.stage_order = old_config, old_stages, old_stage_order
    if self.model_manager:
        self.model_manager.cleanup()  # drop anything the failed attempt loaded
    if self.vllm_config:
        try:
            self._setup_vllm()
            logger.info("Restored previous vLLM configuration")
        except Exception as e:
            logger.error(f"Failed to restore previous vLLM configuration: {e}")
    return False
|
663
|
+
|
664
|
+
async def _handle_job_assignment(self, job_data: Dict):
    """Convert an orchestrator job into a ProcessingItem and queue it.

    The item goes onto the readahead queue, the same entry point shard
    processing uses. That queue is bounded, and a blocking ``put`` here would
    freeze the event loop. Instead, while the queue is full this yields with
    short sleeps, and drops the job if the worker stops or disconnects.
    Any other error is logged and the job is dropped.
    """
    from queue import Full

    try:
        # NOTE(review): assumes job_data["image_data"] is raw image bytes —
        # confirm how the orchestrator encodes it in the JSON message.
        image = Image.open(io.BytesIO(job_data["image_data"]))

        item = ProcessingItem(
            chunk_id=job_data["job_id"],
            item_key=job_data["sample_id"],
            image=image,
            image_data=job_data["image_data"],
        )

        # Add to readahead queue without blocking the event loop
        while True:
            try:
                self.readahead_queue.put_nowait(item)
                break
            except Full:
                if not self.running or not self.connected.is_set():
                    logger.warning(
                        f"Dropping job {job_data['job_id']}: worker stopped or disconnected"
                    )
                    return
                await asyncio.sleep(0.1)

        logger.debug(f"Queued job {job_data['job_id']} for processing")

    except Exception as e:
        logger.error(f"Error handling job assignment: {e}")
|
683
|
+
|
684
|
+
async def _job_request_loop(self):
    """In job mode, ask for a job every second while the readahead queue is low.

    "Low" means fewer queued items than the configured ``batch_size``
    (default 8). The loop ends when the worker stops or disconnects; errors
    are logged and followed by a 5 s back-off.
    """
    while self.running and self.connected.is_set():
        try:
            threshold = self.vllm_config.get("batch_size", 8)
            if self.readahead_queue.qsize() < threshold:
                await self.websocket.send(json.dumps({"type": "request_job"}))
            await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"Job request error: {e}")
            await asyncio.sleep(5)
|
697
|
+
|
698
|
+
def _process_shard_chunk(self, chunk: ShardChunk):
    """Decode each unprocessed item of *chunk* and feed it to the readahead queue.

    ``hf_dataset:`` URLs are read with the HF-dataset processor and anything
    else with the WebDataset processor; the processor applies
    ``chunk.unprocessed_ranges``.

    Each item is offered to the queue with a 1 s put timeout, retried for up
    to 30 s. A timeout counts the item as failed. Only ``queue.Full`` is
    retried; the previous bare ``except:`` also swallowed unrelated errors
    and interrupts.

    Processing stops early once the worker stops, disconnects or is told to
    stop. An item that could not be queued for that reason is no longer
    counted as processed. ``_batch_for_inference`` runs whenever the queue
    reaches ``batch_size``, and once more at the end unless stopping.
    """
    from queue import Full

    logger.info(
        f"Processing shard {chunk.shard_name} with unprocessed ranges: {chunk.unprocessed_ranges}"
    )

    if chunk.shard_url.startswith("hf_dataset:"):
        processor = self.hf_processor
    else:
        processor = self.webdataset_processor

    items_processed = 0

    # Let the processor handle the range filtering
    for key, _url, image_data, metadata in processor.iterate_chunk_with_metadata(
        chunk, self.dataset_loader, self.should_stop_processing, self.connected
    ):
        try:
            img = Image.open(io.BytesIO(image_data))

            item = ProcessingItem(
                chunk_id=chunk.chunk_id,
                item_key=key,
                image=img,
                image_data=image_data,
                metadata=metadata,
            )

            # Absolute item index for tracking, when the processor supplies a relative one
            if "_chunk_relative_index" in metadata:
                item.metadata["_item_index"] = (
                    chunk.start_index + metadata["_chunk_relative_index"]
                )

            queued = False
            timeout_end = time.time() + 30
            while (
                self.running
                and not self.should_stop_processing.is_set()
                and self.connected.is_set()
            ):
                try:
                    self.readahead_queue.put(item, timeout=1)
                    queued = True
                    break
                except Full:
                    if time.time() > timeout_end:
                        raise TimeoutError("Queue put timeout")

            # Stopping, disconnected, or the item could not be queued: give up on the rest.
            if (
                not queued
                or not self.connected.is_set()
                or self.should_stop_processing.is_set()
            ):
                logger.debug("Skipping remaining items due to disconnection")
                break

            items_processed += 1

            batch_size = self.vllm_config.get("batch_size", 8)
            if self.readahead_queue.qsize() >= batch_size:
                self._batch_for_inference()

        except Exception as e:
            if self.should_stop_processing.is_set():
                break
            logger.error(f"Error processing item {key}: {e}")
            self.items_failed += 1

    # Process any remaining items in queue
    if not self.should_stop_processing.is_set():
        self._batch_for_inference()

    logger.info(
        f"Chunk {chunk.chunk_id} processed {items_processed} items from unprocessed ranges"
    )
|
776
|
+
|
777
|
+
def _shard_reader_thread(self):
|
778
|
+
"""Background thread that reads from WebDataset shards."""
|
779
|
+
logger.info("Starting shard reader thread")
|
780
|
+
|
781
|
+
while self.running:
|
782
|
+
# Check if we should stop processing
|
783
|
+
if self.should_stop_processing.is_set():
|
784
|
+
logger.info("Shard reader waiting for reconnection")
|
785
|
+
time.sleep(1)
|
786
|
+
continue
|
787
|
+
|
788
|
+
# Only process if connected
|
789
|
+
if not self.connected.is_set():
|
790
|
+
time.sleep(1)
|
791
|
+
continue
|
792
|
+
|
793
|
+
# Get next chunk to process
|
794
|
+
with self.chunk_lock:
|
795
|
+
if not self.current_chunk and self.assigned_chunks:
|
796
|
+
self.current_chunk = self.assigned_chunks.popleft()
|
797
|
+
self.current_chunk_progress = 0
|
798
|
+
logger.info(f"Starting chunk {self.current_chunk.chunk_id}")
|
799
|
+
|
800
|
+
if not self.current_chunk:
|
801
|
+
time.sleep(1)
|
802
|
+
continue
|
803
|
+
|
804
|
+
try:
|
805
|
+
# Process the chunk
|
806
|
+
self._process_shard_chunk(self.current_chunk)
|
807
|
+
|
808
|
+
# Only mark complete if still connected
|
809
|
+
if self.connected.is_set() and not self.should_stop_processing.is_set():
|
810
|
+
logger.info(f"Completed chunk {self.current_chunk.chunk_id}")
|
811
|
+
self.chunks_completed += 1
|
812
|
+
|
813
|
+
# Notify orchestrator if connected
|
814
|
+
if self.websocket and self.main_loop:
|
815
|
+
try:
|
816
|
+
# Notify completion
|
817
|
+
asyncio.run_coroutine_threadsafe(
|
818
|
+
self.websocket.send(
|
819
|
+
json.dumps(
|
820
|
+
{
|
821
|
+
"type": "chunk_complete",
|
822
|
+
"chunk_id": self.current_chunk.chunk_id,
|
823
|
+
}
|
824
|
+
)
|
825
|
+
),
|
826
|
+
self.main_loop,
|
827
|
+
).result(timeout=5)
|
828
|
+
|
829
|
+
# Request more chunks if queue is low
|
830
|
+
with self.chunk_lock:
|
831
|
+
queue_size = len(self.assigned_chunks)
|
832
|
+
|
833
|
+
if queue_size < 2:
|
834
|
+
logger.info(f"Requesting more chunks (queue size: {queue_size})")
|
835
|
+
asyncio.run_coroutine_threadsafe(
|
836
|
+
self.websocket.send(
|
837
|
+
json.dumps({"type": "request_chunks", "count": 2})
|
838
|
+
),
|
839
|
+
self.main_loop,
|
840
|
+
).result(timeout=5)
|
841
|
+
|
842
|
+
except Exception as e:
|
843
|
+
logger.warning(f"Could not notify orchestrator: {e}")
|
844
|
+
|
845
|
+
with self.chunk_lock:
|
846
|
+
self.current_chunk = None
|
847
|
+
|
848
|
+
except Exception as e:
|
849
|
+
logger.error(f"Error processing chunk: {e}")
|
850
|
+
|
851
|
+
# Only notify of failure if still connected
|
852
|
+
if self.connected.is_set() and self.websocket and self.main_loop:
|
853
|
+
try:
|
854
|
+
asyncio.run_coroutine_threadsafe(
|
855
|
+
self.websocket.send(
|
856
|
+
json.dumps(
|
857
|
+
{
|
858
|
+
"type": "chunk_failed",
|
859
|
+
"chunk_id": (
|
860
|
+
self.current_chunk.chunk_id
|
861
|
+
if self.current_chunk
|
862
|
+
else "unknown"
|
863
|
+
),
|
864
|
+
"error": str(e),
|
865
|
+
}
|
866
|
+
)
|
867
|
+
),
|
868
|
+
self.main_loop,
|
869
|
+
).result(timeout=5)
|
870
|
+
except Exception as send_error:
|
871
|
+
logger.warning(
|
872
|
+
f"Could not notify orchestrator of chunk failure: {send_error}"
|
873
|
+
)
|
874
|
+
|
875
|
+
with self.chunk_lock:
|
876
|
+
self.current_chunk = None
|
877
|
+
|
878
|
+
async def _result_sender(self):
|
879
|
+
"""Send results back to orchestrator with item index."""
|
880
|
+
pending_results = []
|
881
|
+
|
882
|
+
try:
|
883
|
+
while self.running and self.connected.is_set():
|
884
|
+
try:
|
885
|
+
# Get result with timeout
|
886
|
+
try:
|
887
|
+
result = await asyncio.get_event_loop().run_in_executor(
|
888
|
+
None, self.result_queue.get, True, 1
|
889
|
+
)
|
890
|
+
pending_results.append(result)
|
891
|
+
except Empty:
|
892
|
+
pass
|
893
|
+
|
894
|
+
# Only try to send if connected
|
895
|
+
if pending_results and self.websocket and self.connected.is_set():
|
896
|
+
sent_results = []
|
897
|
+
for result in pending_results:
|
898
|
+
try:
|
899
|
+
# Build message with item index
|
900
|
+
message_data = {
|
901
|
+
"type": "submit_captions",
|
902
|
+
"chunk_id": result.chunk_id,
|
903
|
+
"dataset": self.dataset_config.get("dataset_path", "unknown"),
|
904
|
+
"shard": result.shard_name,
|
905
|
+
"item_key": result.item_key,
|
906
|
+
"item_index": result.item_index, # NEW: Include index
|
907
|
+
"outputs": result.outputs,
|
908
|
+
"captions": result.outputs.get("captions", []), # Compatibility
|
909
|
+
"caption_count": sum(len(v) for v in result.outputs.values()),
|
910
|
+
"image_width": result.image_width,
|
911
|
+
"image_height": result.image_height,
|
912
|
+
"image_format": result.image_format,
|
913
|
+
"file_size": result.file_size,
|
914
|
+
"processing_time_ms": result.processing_time_ms,
|
915
|
+
"metadata": result.metadata,
|
916
|
+
}
|
917
|
+
|
918
|
+
await self.websocket.send(json.dumps(message_data))
|
919
|
+
sent_results.append(result)
|
920
|
+
|
921
|
+
if self.items_processed % 100 == 0:
|
922
|
+
total_outputs = sum(
|
923
|
+
len(outputs) for outputs in result.outputs.values()
|
924
|
+
)
|
925
|
+
logger.info(
|
926
|
+
f"Processed {self.items_processed} items "
|
927
|
+
f"(~{total_outputs} outputs across {len(result.outputs)} fields)"
|
928
|
+
)
|
929
|
+
|
930
|
+
except websockets.exceptions.ConnectionClosed as e:
|
931
|
+
logger.warning(f"Connection lost while sending result: {e}")
|
932
|
+
raise
|
933
|
+
except Exception as e:
|
934
|
+
logger.error(f"Error sending result: {e}")
|
935
|
+
break
|
936
|
+
|
937
|
+
# Remove successfully sent results
|
938
|
+
for result in sent_results:
|
939
|
+
pending_results.remove(result)
|
940
|
+
|
941
|
+
# Clear pending results if disconnected and buffer is too large
|
942
|
+
if not self.connected.is_set() and len(pending_results) > 1000:
|
943
|
+
logger.warning(
|
944
|
+
f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
|
945
|
+
)
|
946
|
+
pending_results.clear()
|
947
|
+
|
948
|
+
await asyncio.sleep(0.1)
|
949
|
+
|
950
|
+
except Exception as e:
|
951
|
+
if isinstance(e, websockets.exceptions.ConnectionClosed):
|
952
|
+
raise
|
953
|
+
logger.error(f"Unexpected error in result sender: {e}")
|
954
|
+
await asyncio.sleep(1)
|
955
|
+
|
956
|
+
except asyncio.CancelledError:
|
957
|
+
logger.debug("Result sender cancelled")
|
958
|
+
raise
|
959
|
+
|
960
|
+
def _batch_for_inference(self):
|
961
|
+
"""Batch items from readahead queue for inference."""
|
962
|
+
batch = []
|
963
|
+
batch_size = self.vllm_config.get("batch_size", 8)
|
964
|
+
|
965
|
+
try:
|
966
|
+
while len(batch) < batch_size:
|
967
|
+
item = self.readahead_queue.get_nowait()
|
968
|
+
batch.append(item)
|
969
|
+
except Empty:
|
970
|
+
pass
|
971
|
+
|
972
|
+
if batch:
|
973
|
+
self.inference_queue.put(batch)
|
974
|
+
|
975
|
+
def _process_batch_multi_stage(
|
976
|
+
self, batch: List[ProcessingItem], max_attempts: int = 3
|
977
|
+
) -> List[ProcessedResult]:
|
978
|
+
"""Process a batch through all stages sequentially."""
|
979
|
+
results = []
|
980
|
+
|
981
|
+
# Process each stage in order
|
982
|
+
for stage_name in self.stage_order:
|
983
|
+
stage = next(s for s in self.stages if s.name == stage_name)
|
984
|
+
logger.debug(f"Processing batch through stage: {stage_name}")
|
985
|
+
|
986
|
+
# Get model components for this stage
|
987
|
+
llm, processor, tokenizer, sampling_params = self.model_manager.get_model_for_stage(
|
988
|
+
stage_name, stage.model
|
989
|
+
)
|
990
|
+
|
991
|
+
# Track items for retry
|
992
|
+
items_to_process = [(i, item, 0) for i, item in enumerate(batch)]
|
993
|
+
|
994
|
+
while items_to_process:
|
995
|
+
# Build requests for current items
|
996
|
+
current_batch = []
|
997
|
+
current_indices = []
|
998
|
+
requests = []
|
999
|
+
|
1000
|
+
for idx, (original_idx, item, attempt_count) in enumerate(items_to_process):
|
1001
|
+
current_batch.append((original_idx, item, attempt_count))
|
1002
|
+
current_indices.append(idx)
|
1003
|
+
|
1004
|
+
# Prepare image
|
1005
|
+
converted_img = ImageProcessor.prepare_for_inference(item.image)
|
1006
|
+
|
1007
|
+
# Create template manager for this stage's prompts
|
1008
|
+
template_manager = PromptTemplateManager(stage.prompts)
|
1009
|
+
|
1010
|
+
# Build context including metadata and previous stage results
|
1011
|
+
context = item.metadata.copy()
|
1012
|
+
|
1013
|
+
# Add previous stage outputs to context
|
1014
|
+
for prev_stage_name, stage_result in item.stage_results.items():
|
1015
|
+
# Add outputs with stage name prefix
|
1016
|
+
for i, output in enumerate(stage_result.outputs):
|
1017
|
+
context[f"{prev_stage_name}_output_{i}"] = output
|
1018
|
+
# Also add under output field name
|
1019
|
+
if len(stage_result.outputs) == 1:
|
1020
|
+
context[stage_result.output_field] = stage_result.outputs[0]
|
1021
|
+
else:
|
1022
|
+
context[stage_result.output_field] = stage_result.outputs
|
1023
|
+
|
1024
|
+
# Format prompts with context
|
1025
|
+
formatted_prompts = template_manager.format_all(context)
|
1026
|
+
|
1027
|
+
# Build requests for all prompts
|
1028
|
+
for prompt in formatted_prompts:
|
1029
|
+
req = self._build_vllm_input(converted_img, prompt, processor, tokenizer)
|
1030
|
+
requests.append(req)
|
1031
|
+
|
1032
|
+
# Run inference
|
1033
|
+
outputs = llm.generate(requests, sampling_params)
|
1034
|
+
|
1035
|
+
# Process outputs
|
1036
|
+
successful_items = []
|
1037
|
+
failed_items = []
|
1038
|
+
|
1039
|
+
for idx, (original_idx, item, attempt_count) in enumerate(current_batch):
|
1040
|
+
# Check if we should stop
|
1041
|
+
if self.should_stop_processing.is_set():
|
1042
|
+
return results
|
1043
|
+
|
1044
|
+
# Extract outputs for this item
|
1045
|
+
base_idx = idx * len(stage.prompts)
|
1046
|
+
stage_outputs = []
|
1047
|
+
|
1048
|
+
for j in range(len(stage.prompts)):
|
1049
|
+
if base_idx + j < len(outputs) and outputs[base_idx + j].outputs:
|
1050
|
+
original_output = outputs[base_idx + j].outputs[0].text
|
1051
|
+
cleaned_output = self._clean_output(original_output)
|
1052
|
+
if cleaned_output:
|
1053
|
+
stage_outputs.append(cleaned_output)
|
1054
|
+
else:
|
1055
|
+
logger.warning(
|
1056
|
+
f"(stage {stage_name}, item {item.item_key}) output destroyed: {original_output}"
|
1057
|
+
)
|
1058
|
+
|
1059
|
+
if stage_outputs:
|
1060
|
+
# Success - add stage result to item
|
1061
|
+
stage_result = StageResult(
|
1062
|
+
stage_name=stage_name,
|
1063
|
+
output_field=stage.output_field,
|
1064
|
+
outputs=stage_outputs,
|
1065
|
+
)
|
1066
|
+
item.stage_results[stage_name] = stage_result
|
1067
|
+
successful_items.append((original_idx, item))
|
1068
|
+
else:
|
1069
|
+
# Failed - check if we should retry
|
1070
|
+
if attempt_count + 1 < max_attempts:
|
1071
|
+
failed_items.append((original_idx, item, attempt_count + 1))
|
1072
|
+
logger.warning(
|
1073
|
+
f"Stage {stage_name} failed for item {item.item_key} "
|
1074
|
+
f"(attempt {attempt_count + 1}/{max_attempts}), will retry"
|
1075
|
+
)
|
1076
|
+
else:
|
1077
|
+
logger.error(
|
1078
|
+
f"Stage {stage_name} failed for item {item.item_key} "
|
1079
|
+
f"after {max_attempts} attempts"
|
1080
|
+
)
|
1081
|
+
self.items_failed += 1
|
1082
|
+
|
1083
|
+
# Update items to process for next iteration
|
1084
|
+
items_to_process = failed_items
|
1085
|
+
|
1086
|
+
# Update batch with successful items for next stage
|
1087
|
+
batch = [item for _, item in successful_items]
|
1088
|
+
|
1089
|
+
# Log retry status if we have items to retry
|
1090
|
+
if items_to_process:
|
1091
|
+
logger.info(
|
1092
|
+
f"Retrying {len(items_to_process)} failed items for stage {stage_name}"
|
1093
|
+
)
|
1094
|
+
|
1095
|
+
# Convert batch items to results
|
1096
|
+
for item in batch:
|
1097
|
+
# Aggregate outputs by field name
|
1098
|
+
outputs_by_field = defaultdict(list)
|
1099
|
+
|
1100
|
+
for stage_result in item.stage_results.values():
|
1101
|
+
outputs_by_field[stage_result.output_field].extend(stage_result.outputs)
|
1102
|
+
|
1103
|
+
result = ProcessedResult(
|
1104
|
+
chunk_id=item.chunk_id,
|
1105
|
+
shard_name=Path(item.chunk_id).stem.rsplit("_chunk_", 1)[0],
|
1106
|
+
item_key=item.item_key,
|
1107
|
+
outputs=dict(outputs_by_field), # Convert defaultdict to dict
|
1108
|
+
image_width=item.image.width,
|
1109
|
+
image_height=item.image.height,
|
1110
|
+
image_format=item.image.format or "unknown",
|
1111
|
+
file_size=len(item.image_data),
|
1112
|
+
processing_time_ms=0, # Will be calculated by caller
|
1113
|
+
metadata=item.metadata,
|
1114
|
+
)
|
1115
|
+
results.append(result)
|
1116
|
+
self.items_processed += 1
|
1117
|
+
|
1118
|
+
return results
|
1119
|
+
|
1120
|
+
def _inference_thread(self):
|
1121
|
+
"""Background thread for multi-stage vLLM inference."""
|
1122
|
+
logger.info("Starting multi-stage inference thread")
|
1123
|
+
|
1124
|
+
while self.running:
|
1125
|
+
try:
|
1126
|
+
# Get batch from queue with timeout
|
1127
|
+
batch = self.inference_queue.get(timeout=1)
|
1128
|
+
|
1129
|
+
if not batch:
|
1130
|
+
continue
|
1131
|
+
|
1132
|
+
# Skip if disconnected
|
1133
|
+
if self.should_stop_processing.is_set():
|
1134
|
+
continue
|
1135
|
+
|
1136
|
+
logger.debug(
|
1137
|
+
f"Processing batch of {len(batch)} images through {len(self.stages)} stages"
|
1138
|
+
)
|
1139
|
+
start_time = time.time()
|
1140
|
+
|
1141
|
+
# Process batch through all stages
|
1142
|
+
results = self._process_batch_multi_stage(batch)
|
1143
|
+
|
1144
|
+
# Calculate processing time per item
|
1145
|
+
if results:
|
1146
|
+
processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
|
1147
|
+
|
1148
|
+
# Update processing time and queue results
|
1149
|
+
for result in results:
|
1150
|
+
result.processing_time_ms = processing_time_per_item
|
1151
|
+
self.result_queue.put(result)
|
1152
|
+
|
1153
|
+
logger.debug(
|
1154
|
+
f"Multi-stage batch processing complete: {len(results)} successful, "
|
1155
|
+
f"{len(batch) - len(results)} failed"
|
1156
|
+
)
|
1157
|
+
|
1158
|
+
except Empty:
|
1159
|
+
continue
|
1160
|
+
except Exception as e:
|
1161
|
+
if self.should_stop_processing.is_set():
|
1162
|
+
continue
|
1163
|
+
logger.error(f"Inference error: {e}", exc_info=True)
|
1164
|
+
|
1165
|
+
def _build_vllm_input(self, image: Image.Image, prompt: str, processor, tokenizer) -> Dict:
|
1166
|
+
"""Build vLLM input."""
|
1167
|
+
try:
|
1168
|
+
from qwen_vl_utils import process_vision_info
|
1169
|
+
|
1170
|
+
messages = [
|
1171
|
+
{
|
1172
|
+
"role": "user",
|
1173
|
+
"content": [
|
1174
|
+
{"type": "image", "image": image},
|
1175
|
+
{"type": "text", "text": prompt},
|
1176
|
+
],
|
1177
|
+
}
|
1178
|
+
]
|
1179
|
+
|
1180
|
+
prompt_text = processor.apply_chat_template(
|
1181
|
+
messages, tokenize=False, add_generation_prompt=True
|
1182
|
+
)
|
1183
|
+
image_inputs, _ = process_vision_info(messages)
|
1184
|
+
prompt_ids = tokenizer(prompt_text, add_special_tokens=False).input_ids
|
1185
|
+
|
1186
|
+
return {
|
1187
|
+
"prompt_token_ids": prompt_ids,
|
1188
|
+
"multi_modal_data": {"image": image_inputs},
|
1189
|
+
}
|
1190
|
+
except ImportError:
|
1191
|
+
return {
|
1192
|
+
"prompt": f"<|user|>\n<|image_pad|>\n{prompt}<|end|>\n<|assistant|>",
|
1193
|
+
"multi_modal_data": {"image": [image]},
|
1194
|
+
}
|
1195
|
+
|
1196
|
+
def _clean_output(self, text: str) -> str:
|
1197
|
+
"""Clean model output."""
|
1198
|
+
if not text:
|
1199
|
+
return ""
|
1200
|
+
|
1201
|
+
# Remove common artifacts
|
1202
|
+
for token in ["<|end|>", "<|endoftext|>", "<|im_end|>", "I'm sorry", "I cannot"]:
|
1203
|
+
if token in text:
|
1204
|
+
text = text.split(token)[0]
|
1205
|
+
|
1206
|
+
return text.strip()
|
1207
|
+
|
1208
|
+
async def _result_sender(self):
|
1209
|
+
"""Send results back to orchestrator with multi-stage outputs."""
|
1210
|
+
pending_results = [] # Buffer for results during disconnection
|
1211
|
+
|
1212
|
+
try:
|
1213
|
+
while self.running and self.connected.is_set():
|
1214
|
+
try:
|
1215
|
+
# Get result (with timeout to allow checking self.running)
|
1216
|
+
try:
|
1217
|
+
result = await asyncio.get_event_loop().run_in_executor(
|
1218
|
+
None, self.result_queue.get, True, 1
|
1219
|
+
)
|
1220
|
+
pending_results.append(result)
|
1221
|
+
except Empty:
|
1222
|
+
pass
|
1223
|
+
|
1224
|
+
# Only try to send if connected
|
1225
|
+
if pending_results and self.websocket and self.connected.is_set():
|
1226
|
+
sent_results = []
|
1227
|
+
for result in pending_results:
|
1228
|
+
try:
|
1229
|
+
# For backward compatibility, if there's only one output field "captions"
|
1230
|
+
# send it in the old format
|
1231
|
+
if len(result.outputs) == 1 and "captions" in result.outputs:
|
1232
|
+
# Old format for single-stage compatibility
|
1233
|
+
await self.websocket.send(
|
1234
|
+
json.dumps(
|
1235
|
+
{
|
1236
|
+
"type": "submit_captions",
|
1237
|
+
"chunk_id": result.chunk_id,
|
1238
|
+
"dataset": self.dataset_config.get(
|
1239
|
+
"dataset_path", "unknown"
|
1240
|
+
),
|
1241
|
+
"shard": result.shard_name,
|
1242
|
+
"item_key": result.item_key,
|
1243
|
+
"item_index": result.metadata.get("_item_index"),
|
1244
|
+
"captions": result.outputs["captions"],
|
1245
|
+
"caption_count": len(result.outputs["captions"]),
|
1246
|
+
"image_width": result.image_width,
|
1247
|
+
"image_height": result.image_height,
|
1248
|
+
"image_format": result.image_format,
|
1249
|
+
"file_size": result.file_size,
|
1250
|
+
"processing_time_ms": result.processing_time_ms,
|
1251
|
+
}
|
1252
|
+
)
|
1253
|
+
)
|
1254
|
+
else:
|
1255
|
+
# New format for multi-stage outputs
|
1256
|
+
await self.websocket.send(
|
1257
|
+
json.dumps(
|
1258
|
+
{
|
1259
|
+
"type": "submit_captions",
|
1260
|
+
"chunk_id": result.chunk_id,
|
1261
|
+
"dataset": self.dataset_config.get(
|
1262
|
+
"dataset_path", "unknown"
|
1263
|
+
),
|
1264
|
+
"shard": result.shard_name,
|
1265
|
+
"item_key": result.item_key,
|
1266
|
+
"outputs": result.outputs, # Dict of field -> list of outputs
|
1267
|
+
"captions": result.outputs.get(
|
1268
|
+
"captions", []
|
1269
|
+
), # For compatibility
|
1270
|
+
"caption_count": sum(
|
1271
|
+
len(v) for v in result.outputs.values()
|
1272
|
+
),
|
1273
|
+
"image_width": result.image_width,
|
1274
|
+
"image_height": result.image_height,
|
1275
|
+
"image_format": result.image_format,
|
1276
|
+
"file_size": result.file_size,
|
1277
|
+
"processing_time_ms": result.processing_time_ms,
|
1278
|
+
"metadata": result.metadata,
|
1279
|
+
}
|
1280
|
+
)
|
1281
|
+
)
|
1282
|
+
|
1283
|
+
sent_results.append(result)
|
1284
|
+
|
1285
|
+
if self.items_processed % 100 == 0:
|
1286
|
+
total_outputs = sum(
|
1287
|
+
len(outputs) for outputs in result.outputs.values()
|
1288
|
+
)
|
1289
|
+
logger.info(
|
1290
|
+
f"Processed {self.items_processed} items "
|
1291
|
+
f"(~{total_outputs} outputs across {len(result.outputs)} fields)"
|
1292
|
+
)
|
1293
|
+
except websockets.exceptions.ConnectionClosed as e:
|
1294
|
+
logger.warning(f"Connection lost while sending result: {e}")
|
1295
|
+
raise # Re-raise to trigger task completion
|
1296
|
+
except Exception as e:
|
1297
|
+
logger.error(f"Error sending result: {e}")
|
1298
|
+
break
|
1299
|
+
|
1300
|
+
# Remove successfully sent results
|
1301
|
+
for result in sent_results:
|
1302
|
+
pending_results.remove(result)
|
1303
|
+
|
1304
|
+
# Clear pending results if disconnected and buffer is too large
|
1305
|
+
if not self.connected.is_set() and len(pending_results) > 1000:
|
1306
|
+
logger.warning(
|
1307
|
+
f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
|
1308
|
+
)
|
1309
|
+
pending_results.clear()
|
1310
|
+
|
1311
|
+
await asyncio.sleep(0.1)
|
1312
|
+
|
1313
|
+
except Exception as e:
|
1314
|
+
if isinstance(e, websockets.exceptions.ConnectionClosed):
|
1315
|
+
raise # Re-raise connection errors
|
1316
|
+
logger.error(f"Unexpected error in result sender: {e}")
|
1317
|
+
await asyncio.sleep(1)
|
1318
|
+
|
1319
|
+
except asyncio.CancelledError:
|
1320
|
+
logger.debug("Result sender cancelled")
|
1321
|
+
raise
|