caption-flow 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +308 -0
- caption_flow/models.py +134 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1715
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +489 -401
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +73 -32
- caption_flow/utils/dataset_loader.py +58 -298
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +5 -265
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/viewer.py +594 -0
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/METADATA +49 -180
- caption_flow-0.2.4.dist-info/RECORD +38 -0
- caption_flow-0.2.2.dist-info/RECORD +0 -29
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/top_level.txt +0 -0
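Most of the churn below comes from the new caption_flow/processors package, which replaces the worker's built-in dataset/shard handling with pluggable work-unit processors. As a rough sketch of that pattern (assuming only the ProcessorConfig and *WorkerProcessor interfaces that appear in the caption.py diff below; build_processor and the welcome_data payload are illustrative, not part of the package):

# Hypothetical sketch of the worker-side processor selection introduced in 0.2.4.
# Assumes the caption_flow.processors interfaces shown in the caption.py diff below;
# welcome_data stands in for the orchestrator's welcome message, not a verbatim payload.
from caption_flow.processors import (
    ProcessorConfig,
    WebDatasetWorkerProcessor,
    HuggingFaceDatasetWorkerProcessor,
    LocalFilesystemWorkerProcessor,
)

PROCESSOR_CLASSES = {
    "webdataset": WebDatasetWorkerProcessor,
    "huggingface_datasets": HuggingFaceDatasetWorkerProcessor,
    "local_filesystem": LocalFilesystemWorkerProcessor,
}


def build_processor(welcome_data: dict):
    """Instantiate and initialize the processor named in the orchestrator's welcome message."""
    processor_type = welcome_data.get("processor_type")
    if processor_type not in PROCESSOR_CLASSES:
        raise ValueError(f"Unknown processor type: {processor_type}")
    processor = PROCESSOR_CLASSES[processor_type]()
    processor.initialize(
        ProcessorConfig(
            processor_type=processor_type,
            config=welcome_data.get("processor_config", {}),
        )
    )
    return processor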
caption_flow/workers/caption.py
CHANGED
@@ -1,103 +1,59 @@
-"""Caption worker for
+"""Caption worker with processor abstraction for distributed captioning."""
 
 import os
 
 os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
 
 import asyncio
-import io
 import json
 import logging
 import websockets
 import time
-from dataclasses import dataclass
-from
-from typing import Dict, Any, Optional, List, Tuple
+from dataclasses import dataclass
+from typing import Dict, Any, Optional, List, Tuple, Union
 from queue import Queue, Empty
-from threading import Thread,
-from collections import
+from threading import Thread, Event, Lock
+from collections import defaultdict, deque
 
 from PIL import Image
-import numpy as np
 from huggingface_hub import get_token
 
 from .base import BaseWorker
-from ..
-
-
+from ..processors import (
+    ProcessorConfig,
+    WorkAssignment,
+    WorkUnit,
+    WorkResult,
+    WebDatasetWorkerProcessor,
+    HuggingFaceDatasetWorkerProcessor,
+    LocalFilesystemWorkerProcessor,
+)
 from ..utils.vllm_config import VLLMConfigManager
 from ..utils.image_processor import ImageProcessor
-from ..utils.shard_processor import HFDatasetShardProcessor, WebDatasetShardProcessor
 from ..utils.prompt_template import PromptTemplateManager
+from ..models import ProcessingStage, StageResult
 
 logger = logging.getLogger(__name__)
-
-
-@dataclass
-class ProcessingStage:
-    """Configuration for a single processing stage."""
-
-    name: str
-    model: str
-    prompts: List[str]
-    output_field: str
-    requires: List[str] = field(default_factory=list)
-    sampling: Optional[Dict[str, Any]] = None
-
-    # Model-specific overrides
-    tensor_parallel_size: Optional[int] = None
-    max_model_len: Optional[int] = None
-    dtype: Optional[str] = None
-    gpu_memory_utilization: Optional[float] = None
-
-
-@dataclass
-class StageResult:
-    """Results from a single stage."""
-
-    stage_name: str
-    output_field: str
-    outputs: List[str] # Multiple outputs from multiple prompts
-
-
-@dataclass
-class ShardChunk:
-    """Shard chunk assignment with unprocessed ranges."""
-
-    chunk_id: str
-    shard_url: str
-    shard_name: str
-    start_index: int
-    chunk_size: int
-    unprocessed_ranges: List[Tuple[int, int]] = field(default_factory=list)
+# logger.setLevel(logging.DEBUG)
 
 
 @dataclass
 class ProcessingItem:
-    """Item being processed."""
+    """Item being processed through stages."""
 
+    unit_id: str
+    job_id: str
     chunk_id: str
     item_key: str
+    item_index: int
     image: Image.Image
     image_data: bytes
-    metadata: Dict[str, Any]
-    stage_results: Dict[str, StageResult] =
+    metadata: Dict[str, Any]
+    stage_results: Dict[str, StageResult] = None
 
-
-
-
-    """Result with multi-stage outputs."""
-
-    chunk_id: str
-    shard_name: str
-    item_key: str
-    outputs: Dict[str, List[str]] # field_name -> list of outputs
-    image_width: int
-    image_height: int
-    image_format: str
-    file_size: int
-    processing_time_ms: float
-    metadata: Dict[str, Any] = field(default_factory=dict)
+    def __post_init__(self):
+        if self.stage_results is None:
+            self.stage_results = {}
 
 
 class MultiStageVLLMManager:
@@ -121,7 +77,7 @@ class MultiStageVLLMManager:
 
         logger.info(f"Loading model {model_name} for stage {stage.name}")
 
-        # Build model-specific config
+        # Build model-specific config
         model_config = base_config.copy()
         model_config["model"] = model_name
 
@@ -163,10 +119,7 @@ class MultiStageVLLMManager:
         """Create sampling params for a stage."""
         from vllm import SamplingParams
 
-        # Start with base sampling config
         sampling_config = base_sampling.copy()
-
-        # Override with stage-specific sampling if provided
         if stage.sampling:
             sampling_config.update(stage.sampling)
 
@@ -183,16 +136,7 @@ class MultiStageVLLMManager:
         return params
 
     def get_model_for_stage(self, stage_name: str, model_name: str) -> Tuple[Any, Any, Any, Any]:
-        """
-        Get model components for a stage.
-
-        Returns:
-            tuple: A tuple containing:
-                - llm: The language model instance for the given model name.
-                - processor: The processor associated with the model.
-                - tokenizer: The tokenizer for the model.
-                - sampling_params: The sampling parameters for the given stage.
-        """
+        """Get model components for a stage."""
        return (
             self.models[model_name],
             self.processors[model_name],
@@ -214,74 +158,69 @@ class MultiStageVLLMManager:
 
 
 class CaptionWorker(BaseWorker):
-    """Worker that processes
+    """Worker that processes work units for image captioning using multi-stage vLLM."""
 
     def __init__(self, config: Dict[str, Any]):
         super().__init__(config)
 
-
-
-
-
-
-
-
-
-
-
-
+        # Processor configuration - will be set from orchestrator
+        self.processor_type = None
+        self.processor: Optional[
+            Union[
+                WebDatasetWorkerProcessor,
+                HuggingFaceDatasetWorkerProcessor,
+                LocalFilesystemWorkerProcessor,
+            ],
+        ] = None
+        self.dataset_path: Optional[str] = None
+
+        # vLLM configuration
         self.vllm_config = None
         self.stages: List[ProcessingStage] = []
-        self.stage_order: List[str] = []
+        self.stage_order: List[str] = []
         self.vllm_config_manager = VLLMConfigManager()
         self.model_manager = None
 
-        #
+        # GPU selection
         self.gpu_id = config.get("gpu_id", 0)
-
-        # Connection state events
-        self.should_stop_processing = Event()
+        self.hf_token = get_token()
 
         # Image processor
-
-        if batch_image_processing
-
-
-
-        self.
-        self.
-            hf_token=self.hf_token, dataset_type=self.dataset_type
-        )
-        self.chunk_lock = Lock()
-        self.assigned_chunks = deque()
-        self.current_chunk = None
-        self.current_chunk_progress = 0
+        batch_image_processing = config.get("batch_image_processing", False)
+        self.image_processor = ImageProcessor() if batch_image_processing else None
+
+        # Work processing
+        self.work_lock = Lock()
+        self.assigned_units = deque()
+        self.current_unit: Optional[WorkUnit] = None
 
-        #
+        # Processing queues
         self.readahead_queue = Queue(maxsize=256)
         self.inference_queue = Queue(maxsize=128)
         self.result_queue = Queue()
 
-        #
-        self.
-        self.job_queue = Queue(maxsize=32)
+        # Processing control
+        self.should_stop_processing = Event()
 
     def _init_metrics(self):
         """Initialize worker metrics."""
         self.items_processed = 0
         self.items_failed = 0
-        self.
+        self.units_completed = 0
 
     def _get_auth_data(self) -> Dict[str, Any]:
         """Get authentication data."""
         return {"token": self.token, "name": self.name}
 
+    def _get_current_unit_id(self) -> Optional[str]:
+        """Get the current unit ID."""
+        return self.current_unit.unit_id if self.current_unit else None
+
     async def _pre_start(self):
         """Initialize before starting connection loop."""
-        # Wait for initial connection to get
+        # Wait for initial connection to get config
         logger.info("Connecting to orchestrator for configuration...")
 
-        # Try initial connection to get config
         config_received = False
         while not config_received and self.running:
            try:
@@ -292,21 +231,110 @@ class CaptionWorker(BaseWorker):
                await asyncio.sleep(5)
 
        # Initialize vLLM once we have config
-        self.
+        if self.vllm_config:
+            self._setup_vllm()
 
        # Start background threads
-
-
+        Thread(target=self._unit_processor_thread, daemon=True).start()
+        Thread(target=self._inference_thread, daemon=True).start()
+
+    async def _initial_connect_for_config(self):
+        """Connect initially just to get configuration."""
+        logger.info(f"Connecting to {self.server_url}")
+        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
+            await websocket.send(json.dumps(self._get_auth_data()))
+
+            welcome = await websocket.recv()
+            welcome_data = json.loads(welcome)
+
+            if "error" in welcome_data:
+                raise RuntimeError(f"Authentication failed: {welcome_data['error']}")
+
+            # Extract vLLM config from processor config
+            processor_config = welcome_data.get("processor_config", {})
+            self.vllm_config = processor_config.get("vllm", {})
+
+            if not self.vllm_config:
+                raise RuntimeError("No vLLM configuration received from orchestrator")
+
+            # Parse stages
+            self.stages = self._parse_stages_config(self.vllm_config)
+            self.stage_order = self._topological_sort_stages(self.stages)
+
+            logger.info(f"Configured {len(self.stages)} processing stages: {self.stage_order}")
+
+    async def _handle_welcome(self, welcome_data: Dict[str, Any]):
+        """Handle welcome message from orchestrator."""
+        with self.work_lock:
+            self.assigned_units.clear()
+            self.current_unit = None
+
+        self._clear_queue(self.readahead_queue)
+        self._clear_queue(self.inference_queue)
+        self._clear_queue(self.result_queue)
+
+        # Reset counters
+        self.items_processed = 0
+        self.items_failed = 0
+        self.units_completed = 0
+
+        # Setup processor
+        self.processor_type = welcome_data.get("processor_type", None)
+        assert self.processor_type is not None, "Processor type not found in welcome data"
+        logger.info(f"Creating {self.processor_type} processor")
+        processor_config = ProcessorConfig(
+            processor_type=self.processor_type, config=welcome_data.get("processor_config", {})
+        )
+
+        if self.processor_type == "webdataset":
+            self.processor = WebDatasetWorkerProcessor()
+        elif self.processor_type == "huggingface_datasets":
+            self.processor = HuggingFaceDatasetWorkerProcessor()
+        elif self.processor_type == "local_filesystem":
+            self.processor = LocalFilesystemWorkerProcessor()
+        else:
+            raise ValueError(f"Unknown processor type: {self.processor_type}")
+
+        self.processor.initialize(processor_config)
+        self.dataset_path = self.processor.dataset_path
+
+        # Update vLLM config if provided
+        new_vllm_config = welcome_data.get("processor_config", {}).get("vllm")
+        if new_vllm_config and new_vllm_config != self.vllm_config:
+            logger.info("Received updated vLLM configuration")
+            self._handle_vllm_config_update(new_vllm_config)
+
+        # Clear stop signal
+        self.should_stop_processing.clear()
+
+        # Request initial work
+        if self.websocket:
+            await self.websocket.send(json.dumps({"type": "request_work", "count": 2}))
+
+    async def _handle_message(self, data: Dict[str, Any]):
+        """Handle message from orchestrator."""
+        msg_type = data.get("type")
+
+        if msg_type == "work_assignment":
+            assignment = WorkAssignment.from_dict(data["assignment"])
+            with self.work_lock:
+                for unit in assignment.units:
+                    self.assigned_units.append(unit)
+            logger.info(f"Received {len(assignment.units)} work units")
 
-
-
+        elif msg_type == "no_work":
+            logger.info("No work available")
+            await asyncio.sleep(10)
+
+            if self.websocket and self.connected.is_set():
+                await self.websocket.send(json.dumps({"type": "request_work", "count": 2}))
 
     def _parse_stages_config(self, vllm_config: Dict[str, Any]) -> List[ProcessingStage]:
         """Parse stages configuration from vLLM config."""
         stages_config = vllm_config.get("stages", [])
 
         if not stages_config:
-            # Backward compatibility
+            # Backward compatibility
             return [
                 ProcessingStage(
                     name="default",
@@ -338,10 +366,8 @@ class CaptionWorker(BaseWorker):
 
     def _topological_sort_stages(self, stages: List[ProcessingStage]) -> List[str]:
         """Sort stages by dependencies."""
-        # Build dependency graph
         graph = defaultdict(list)
         in_degree = defaultdict(int)
-
         stage_map = {s.name: s for s in stages}
 
         for stage in stages:
@@ -351,7 +377,6 @@ class CaptionWorker(BaseWorker):
                     raise ValueError(f"Stage '{stage.name}' requires missing dependency '{dep}'")
                 graph[dep].append(stage.name)
 
-        # Topological sort using Kahn's algorithm
        queue = deque([name for name, degree in in_degree.items() if degree == 0])
        result = []
 
@@ -369,204 +394,10 @@ class CaptionWorker(BaseWorker):
 
         return result
 
-    async def _handle_welcome(self, welcome_data: Dict[str, Any]):
-        """Handle welcome message from orchestrator."""
-        # Extract and setup dataset configuration
-        dataset_config = welcome_data.get("dataset_config", {})
-        if dataset_config:
-            self._setup_dataset_loader(dataset_config)
-            logger.info(f"Received dataset config: {dataset_config}")
-        else:
-            logger.warning("No dataset configuration received from orchestrator")
-
-        # Update vLLM config if provided (in case it changed)
-        new_vllm_config = welcome_data.get("vllm_config")
-        if new_vllm_config and new_vllm_config != self.vllm_config:
-            logger.info("Received updated vLLM configuration")
-            if not self._handle_vllm_config_update(new_vllm_config):
-                logger.error("Failed to update vLLM configuration")
-
-        # Clear stop signal now that we're connected
-        self.should_stop_processing.clear()
-
-        # Request initial chunks if not in job mode
-        if not self.job_mode and self.websocket:
-            await self.websocket.send(json.dumps({"type": "request_chunks", "count": 2}))
-
-    async def _handle_message(self, data: Dict[str, Any]):
-        """Handle message from orchestrator."""
-        msg_type = data.get("type")
-
-        if msg_type == "shard_assignment":
-            chunks = data["chunks"]
-            for chunk_data in chunks:
-                chunk = ShardChunk(**chunk_data)
-                with self.chunk_lock:
-                    self.assigned_chunks.append(chunk)
-                logger.info(f"Received chunk assignment: {chunk.chunk_id}")
-
-        elif msg_type == "no_chunks":
-            reason = data.get("reason", "unknown")
-            logger.info(f"No chunks available from orchestrator (reason: {reason})")
-
-            wait_time = 2 if reason == "state_restoring" else 10
-            await asyncio.sleep(wait_time)
-
-            if self.websocket and self.connected.is_set():
-                await self.websocket.send(json.dumps({"type": "request_chunks", "count": 2}))
-
-        elif msg_type == "reload_vllm":
-            logger.info("Orchestrator requested vLLM reload")
-            new_config = data.get("vllm_config")
-            if new_config:
-                self._handle_vllm_config_update(new_config)
-
-        elif msg_type == "config_update":
-            # Soft config update without reload
-            if data.get("vllm_config"):
-                self._handle_vllm_config_update(data["vllm_config"])
-
-        elif msg_type == "job_assignment":
-            await self._handle_job_assignment(data["job"])
-
-        elif msg_type == "no_jobs":
-            logger.debug("No jobs available")
-            await asyncio.sleep(2)
-
-    def _get_heartbeat_data(self) -> Dict[str, Any]:
-        """Get heartbeat data."""
-        return {
-            "type": "heartbeat",
-            "processed": self.items_processed,
-            "failed": self.items_failed,
-            "chunks_completed": self.chunks_completed,
-            "current_chunk": self.current_chunk.chunk_id if self.current_chunk else None,
-            "chunk_progress": self.current_chunk_progress,
-            "queue_sizes": {
-                "readahead": self.readahead_queue.qsize(),
-                "inference": self.inference_queue.qsize(),
-                "results": self.result_queue.qsize(),
-            },
-            "stages": len(self.stages),
-            "models_loaded": len(self.model_manager.models) if self.model_manager else 0,
-        }
-
-    async def _create_tasks(self) -> list:
-        """Create async tasks to run."""
-        tasks = [
-            asyncio.create_task(self._heartbeat_loop()),
-            asyncio.create_task(self._base_message_handler()),
-            asyncio.create_task(self._result_sender()),
-        ]
-
-        if self.job_mode:
-            tasks.append(asyncio.create_task(self._job_request_loop()))
-
-        return tasks
-
-    async def _on_disconnect(self):
-        """Handle disconnection."""
-        self._clear_state_on_disconnect()
-
-    async def _pre_shutdown(self):
-        """Cleanup before shutdown."""
-        # Stop processing threads by adding stop signals
-        self.readahead_queue.put(None)
-        self.inference_queue.put(None)
-
-        # Shutdown image processor
-        if self.image_processor is not None:
-            self.image_processor.shutdown()
-
-        # Cleanup model manager
-        if self.model_manager:
-            self.model_manager.cleanup()
-
-    async def _initial_connect_for_config(self):
-        """Connect initially just to get configuration."""
-        logger.info(f"Connecting to {self.server_url}")
-        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
-            await websocket.send(json.dumps(self._get_auth_data()))
-
-            welcome = await websocket.recv()
-            welcome_data = json.loads(welcome)
-
-            if "error" in welcome_data:
-                raise RuntimeError(f"Authentication failed: {welcome_data['error']}")
-
-            self.vllm_config = welcome_data.get("vllm_config")
-            if not self.vllm_config:
-                raise RuntimeError("No vLLM configuration received from orchestrator")
-
-            # Parse stages configuration
-            self.stages = self._parse_stages_config(self.vllm_config)
-            self.stage_order = self._topological_sort_stages(self.stages)
-
-            logger.info(f"Configured {len(self.stages)} processing stages: {self.stage_order}")
-
-            self.vllm_config_manager.current_config = self.vllm_config
-
-            dataset_config = welcome_data.get("dataset_config", {})
-            if dataset_config:
-                self._setup_dataset_loader(dataset_config)
-
-            logger.info("Received configuration from orchestrator")
-
-    def _clear_state_on_disconnect(self):
-        """Clear all processing state when disconnected."""
-        logger.info("Clearing state due to disconnection")
-
-        self.should_stop_processing.set()
-
-        with self.chunk_lock:
-            self.assigned_chunks.clear()
-            self.current_chunk = None
-            self.current_chunk_progress = 0
-
-        self._clear_queue(self.readahead_queue)
-        self._clear_queue(self.inference_queue)
-        self._clear_queue(self.result_queue)
-
-        logger.info("State cleared, ready for reconnection")
-
-    def _clear_queue(self, queue: Queue):
-        """Clear all items from a queue."""
-        try:
-            while True:
-                queue.get_nowait()
-        except Empty:
-            pass
-
-    def _setup_dataset_loader(self, dataset_config: Dict[str, Any]):
-        """Initialize dataset loader with config from orchestrator."""
-        dataset_path = dataset_config.get("dataset_path") or dataset_config.get("path")
-        dataset_type = dataset_config.get("dataset_type") or dataset_config.get(
-            "type", "huggingface"
-        )
-        dataset_split = dataset_config.get("dataset_split") or dataset_config.get("split", "train")
-        dataset_image_column = dataset_config.get("dataset_image_column") or dataset_config.get(
-            "image_column", "image"
-        )
-
-        if dataset_path:
-            logger.info(
-                f"Initializing dataset loader for {dataset_type}: {dataset_path} "
-                f"(split: {dataset_split}, image_column: {dataset_image_column})"
-            )
-            self.dataset_loader = DatasetLoader(
-                dataset_path, dataset_type, dataset_split, dataset_image_column
-            )
-            self.dataset_config = dataset_config
-            self.dataset_type = dataset_type
-            self.dataset_split = dataset_split
-            self.dataset_image_column = dataset_image_column
-        else:
-            logger.warning("No dataset path provided by orchestrator")
-
     def _setup_vllm(self):
         """Initialize multi-stage vLLM components."""
         if not self.vllm_config:
-            raise RuntimeError("vLLM config not received
+            raise RuntimeError("vLLM config not received")
 
         os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
 
@@ -601,11 +432,8 @@ class CaptionWorker(BaseWorker):
 
         logger.info("Multi-stage vLLM initialization complete")
 
-        # Update config manager's tracking
-        self.vllm_config_manager.current_config = self.vllm_config
-
     def _handle_vllm_config_update(self, new_config: Dict[str, Any]) -> bool:
-        """Handle vLLM configuration updates
+        """Handle vLLM configuration updates."""
         if not new_config:
             return True
 
@@ -628,112 +456,107 @@ class CaptionWorker(BaseWorker):
         if stages_changed:
             logger.info("Stage configuration changed, reloading all models")
 
-            # Save old config
             old_config = self.vllm_config
             self.vllm_config = new_config
             self.stages = new_stages
             self.stage_order = self._topological_sort_stages(self.stages)
 
             try:
-                # Cleanup old models
                 if self.model_manager:
                     self.model_manager.cleanup()
-
-                # Reload with new config
                 self._setup_vllm()
-
-                logger.info("Multi-stage vLLM reload complete")
                 return True
-
             except Exception as e:
                 logger.error(f"Failed to reload vLLM: {e}")
-                # Restore old config
                 self.vllm_config = old_config
                 return False
         else:
-            # Just update sampling params
+            # Just update sampling params
             logger.info("Updating sampling parameters without model reload")
-
             base_sampling = new_config.get("sampling", {})
             for stage in self.stages:
                 self.model_manager.create_sampling_params(stage, base_sampling)
-
             self.vllm_config = new_config
             return True
 
-
-        """
-
-        # Convert to processing item
-        image = Image.open(io.BytesIO(job_data["image_data"]))
-
-        item = ProcessingItem(
-            chunk_id=job_data["job_id"],
-            item_key=job_data["sample_id"],
-            image=image,
-            image_data=job_data["image_data"],
-        )
+    def _unit_processor_thread(self):
+        """Background thread that processes work units."""
+        logger.info("Starting unit processor thread")
 
-
-        self.
-
+        while self.running:
+            if self.should_stop_processing.is_set():
+                time.sleep(1)
+                continue
 
-
-
+            if not self.connected.is_set():
+                time.sleep(1)
+                continue
+
+            # Get next unit
+            with self.work_lock:
+                if not self.current_unit and self.assigned_units:
+                    self.current_unit = self.assigned_units.popleft()
+                    logger.info(f"Starting unit {self._get_current_unit_id()}")
+
+            if not self.current_unit:
+                time.sleep(1)
+                continue
 
-    async def _job_request_loop(self):
-        """Request jobs from orchestrator in job mode."""
-        while self.running and self.connected.is_set():
            try:
-
-                if self.readahead_queue.qsize() < self.vllm_config.get("batch_size", 8):
-                    await self.websocket.send(json.dumps({"type": "request_job"}))
+                self._process_work_unit(self.current_unit)
 
-
+                if self.connected.is_set() and not self.should_stop_processing.is_set():
+                    logger.info(f"Completed unit {self._get_current_unit_id()}")
+                    self.units_completed += 1
 
-
-
-
+                    # Request more work if needed
+                    with self.work_lock:
+                        queue_size = len(self.assigned_units)
 
-
-
-
-
-
+                    if queue_size < 2 and self.websocket and self.main_loop:
+                        try:
+                            asyncio.run_coroutine_threadsafe(
+                                self.websocket.send(
+                                    json.dumps({"type": "request_work", "count": 2})
+                                ),
+                                self.main_loop,
+                            ).result(timeout=5)
+                        except Exception as e:
+                            logger.warning(f"Could not request more work: {e}")
 
-
-
-
-
-
+                with self.work_lock:
+                    self.current_unit = None
+
+            except Exception as e:
+                logger.error(f"Error processing unit: {e}", exc_info=True)
+                with self.work_lock:
+                    self.current_unit = None
+
+    def _process_work_unit(self, unit: WorkUnit):
+        """Process a single work unit."""
+        if not self.processor:
+            logger.error("Processor not initialized")
+            return
 
         items_processed = 0
+        context = {}  # Will store processed indices
 
-        #
-        for
-            chunk, self.dataset_loader, self.should_stop_processing, self.connected
-        ):
+        # Get items from processor
+        for item_data in self.processor.process_unit(unit, context):
            try:
-                # Load image
-                img = Image.open(io.BytesIO(image_data))
-
                # Create processing item
                item = ProcessingItem(
-
-
-
-
-
+                    unit_id=unit.unit_id,
+                    chunk_id=unit.chunk_id,
+                    job_id=item_data["job_id"],
+                    item_key=item_data["item_key"],
+                    item_index=item_data["item_index"],
+                    image=item_data["image"],
+                    image_data=item_data.get("image_data", b""),
+                    metadata=item_data.get("metadata", {}),
                )
 
-                #
-                # The processor should provide the correct index in metadata
-                if "_chunk_relative_index" in metadata:
-                    item.metadata["_item_index"] = (
-                        chunk.start_index + metadata["_chunk_relative_index"]
-                    )
-
-                # Add to readahead queue with timeout handling
+                # Add to readahead queue
                timeout_end = time.time() + 30
                while (
                    self.running
@@ -748,9 +571,7 @@ class CaptionWorker(BaseWorker):
                            raise TimeoutError("Queue put timeout")
                        continue
 
-                # If we couldn't queue due to disconnection, stop processing
                if not self.connected.is_set() or self.should_stop_processing.is_set():
-                    logger.debug(f"Skipping remaining items due to disconnection")
                    break
 
                items_processed += 1
@@ -763,199 +584,22 @@ class CaptionWorker(BaseWorker):
            except Exception as e:
                if self.should_stop_processing.is_set():
                    break
-                logger.error(f"Error processing item {
+                logger.error(f"Error processing item {item_data.get('item_key')}: {e}")
                self.items_failed += 1
 
-        # Process any remaining items
+        # Process any remaining items
        if not self.should_stop_processing.is_set():
            self._batch_for_inference()
+            if self.connected.is_set():
+                # Notify orchestrator that unit is complete
+                asyncio.run_coroutine_threadsafe(
+                    self.websocket.send(
+                        json.dumps({"type": "work_complete", "unit_id": unit.unit_id})
+                    ),
+                    self.main_loop,
+                ).result(timeout=5)
 
-        logger.info(
-            f"Chunk {chunk.chunk_id} processed {items_processed} items from unprocessed ranges"
-        )
-
-    def _shard_reader_thread(self):
-        """Background thread that reads from WebDataset shards."""
-        logger.info("Starting shard reader thread")
-
-        while self.running:
-            # Check if we should stop processing
-            if self.should_stop_processing.is_set():
-                logger.info("Shard reader waiting for reconnection")
-                time.sleep(1)
-                continue
-
-            # Only process if connected
-            if not self.connected.is_set():
-                time.sleep(1)
-                continue
-
-            # Get next chunk to process
-            with self.chunk_lock:
-                if not self.current_chunk and self.assigned_chunks:
-                    self.current_chunk = self.assigned_chunks.popleft()
-                    self.current_chunk_progress = 0
-                    logger.info(f"Starting chunk {self.current_chunk.chunk_id}")
-
-            if not self.current_chunk:
-                time.sleep(1)
-                continue
-
-            try:
-                # Process the chunk
-                self._process_shard_chunk(self.current_chunk)
-
-                # Only mark complete if still connected
-                if self.connected.is_set() and not self.should_stop_processing.is_set():
-                    logger.info(f"Completed chunk {self.current_chunk.chunk_id}")
-                    self.chunks_completed += 1
-
-                    # Notify orchestrator if connected
-                    if self.websocket and self.main_loop:
-                        try:
-                            # Notify completion
-                            asyncio.run_coroutine_threadsafe(
-                                self.websocket.send(
-                                    json.dumps(
-                                        {
-                                            "type": "chunk_complete",
-                                            "chunk_id": self.current_chunk.chunk_id,
-                                        }
-                                    )
-                                ),
-                                self.main_loop,
-                            ).result(timeout=5)
-
-                            # Request more chunks if queue is low
-                            with self.chunk_lock:
-                                queue_size = len(self.assigned_chunks)
-
-                            if queue_size < 2:
-                                logger.info(f"Requesting more chunks (queue size: {queue_size})")
-                                asyncio.run_coroutine_threadsafe(
-                                    self.websocket.send(
-                                        json.dumps({"type": "request_chunks", "count": 2})
-                                    ),
-                                    self.main_loop,
-                                ).result(timeout=5)
-
-                        except Exception as e:
-                            logger.warning(f"Could not notify orchestrator: {e}")
-
-                with self.chunk_lock:
-                    self.current_chunk = None
-
-            except Exception as e:
-                logger.error(f"Error processing chunk: {e}")
-
-                # Only notify of failure if still connected
-                if self.connected.is_set() and self.websocket and self.main_loop:
-                    try:
-                        asyncio.run_coroutine_threadsafe(
-                            self.websocket.send(
-                                json.dumps(
-                                    {
-                                        "type": "chunk_failed",
-                                        "chunk_id": (
-                                            self.current_chunk.chunk_id
-                                            if self.current_chunk
-                                            else "unknown"
-                                        ),
-                                        "error": str(e),
-                                    }
-                                )
-                            ),
-                            self.main_loop,
-                        ).result(timeout=5)
-                    except Exception as send_error:
-                        logger.warning(
-                            f"Could not notify orchestrator of chunk failure: {send_error}"
-                        )
-
-                with self.chunk_lock:
-                    self.current_chunk = None
-
-    async def _result_sender(self):
-        """Send results back to orchestrator with item index."""
-        pending_results = []
-
-        try:
-            while self.running and self.connected.is_set():
-                try:
-                    # Get result with timeout
-                    try:
-                        result = await asyncio.get_event_loop().run_in_executor(
-                            None, self.result_queue.get, True, 1
-                        )
-                        pending_results.append(result)
-                    except Empty:
-                        pass
-
-                    # Only try to send if connected
-                    if pending_results and self.websocket and self.connected.is_set():
-                        sent_results = []
-                        for result in pending_results:
-                            try:
-                                # Build message with item index
-                                message_data = {
-                                    "type": "submit_captions",
-                                    "chunk_id": result.chunk_id,
-                                    "dataset": self.dataset_config.get("dataset_path", "unknown"),
-                                    "shard": result.shard_name,
-                                    "item_key": result.item_key,
-                                    "item_index": result.item_index, # NEW: Include index
-                                    "outputs": result.outputs,
-                                    "captions": result.outputs.get("captions", []), # Compatibility
-                                    "caption_count": sum(len(v) for v in result.outputs.values()),
-                                    "image_width": result.image_width,
-                                    "image_height": result.image_height,
-                                    "image_format": result.image_format,
-                                    "file_size": result.file_size,
-                                    "processing_time_ms": result.processing_time_ms,
-                                    "metadata": result.metadata,
-                                }
-
-                                await self.websocket.send(json.dumps(message_data))
-                                sent_results.append(result)
-
-                                if self.items_processed % 100 == 0:
-                                    total_outputs = sum(
-                                        len(outputs) for outputs in result.outputs.values()
-                                    )
-                                    logger.info(
-                                        f"Processed {self.items_processed} items "
-                                        f"(~{total_outputs} outputs across {len(result.outputs)} fields)"
-                                    )
-
-                            except websockets.exceptions.ConnectionClosed as e:
-                                logger.warning(f"Connection lost while sending result: {e}")
-                                raise
-                            except Exception as e:
-                                logger.error(f"Error sending result: {e}")
-                                break
-
-                        # Remove successfully sent results
-                        for result in sent_results:
-                            pending_results.remove(result)
-
-                    # Clear pending results if disconnected and buffer is too large
-                    if not self.connected.is_set() and len(pending_results) > 1000:
-                        logger.warning(
-                            f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
-                        )
-                        pending_results.clear()
-
-                    await asyncio.sleep(0.1)
-
-                except Exception as e:
-                    if isinstance(e, websockets.exceptions.ConnectionClosed):
-                        raise
-                    logger.error(f"Unexpected error in result sender: {e}")
-                    await asyncio.sleep(1)
-
-        except asyncio.CancelledError:
-            logger.debug("Result sender cancelled")
-            raise
+        logger.info(f"Unit {unit.unit_id} processed {items_processed} items")
 
     def _batch_for_inference(self):
         """Batch items from readahead queue for inference."""
@@ -972,9 +616,52 @@ class CaptionWorker(BaseWorker):
         if batch:
             self.inference_queue.put(batch)
 
+    def _inference_thread(self):
+        """Background thread for multi-stage vLLM inference."""
+        logger.info("Starting multi-stage inference thread")
+
+        while self.running:
+            try:
+                batch = self.inference_queue.get(timeout=1)
+                if not batch:
+                    continue
+
+                if self.should_stop_processing.is_set():
+                    continue
+
+                logger.debug(
+                    f"Processing batch of {len(batch)} images through {len(self.stages)} stages"
+                )
+                start_time = time.time()
+
+                # Process batch through all stages
+                results = self._process_batch_multi_stage(batch)
+
+                # Calculate processing time
+                if results:
+                    processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
+
+                    for item, result_outputs in results:
+                        self.result_queue.put(
+                            {
+                                "item": item,
+                                "outputs": result_outputs,
+                                "processing_time_ms": processing_time_per_item,
+                            }
+                        )
+
+                logger.debug(f"Batch processing complete: {len(results)} successful")
+
+            except Empty:
+                continue
+            except Exception as e:
+                if self.should_stop_processing.is_set():
+                    continue
+                logger.error(f"Inference error: {e}", exc_info=True)
+
     def _process_batch_multi_stage(
         self, batch: List[ProcessingItem], max_attempts: int = 3
-    ) -> List[
+    ) -> List[Tuple[ProcessingItem, Dict]]:
         """Process a batch through all stages sequentially."""
         results = []
 
@@ -983,7 +670,7 @@ class CaptionWorker(BaseWorker):
             stage = next(s for s in self.stages if s.name == stage_name)
             logger.debug(f"Processing batch through stage: {stage_name}")
 
-            # Get model components
+            # Get model components
             llm, processor, tokenizer, sampling_params = self.model_manager.get_model_for_stage(
                 stage_name, stage.model
             )
@@ -992,39 +679,34 @@ class CaptionWorker(BaseWorker):
             items_to_process = [(i, item, 0) for i, item in enumerate(batch)]
 
             while items_to_process:
-                # Build requests for current items
                current_batch = []
-                current_indices = []
                requests = []
 
                for idx, (original_idx, item, attempt_count) in enumerate(items_to_process):
                    current_batch.append((original_idx, item, attempt_count))
-                    current_indices.append(idx)
 
                    # Prepare image
                    converted_img = ImageProcessor.prepare_for_inference(item.image)
 
-                    # Create template manager
+                    # Create template manager
                    template_manager = PromptTemplateManager(stage.prompts)
 
-                    # Build context
+                    # Build context
                    context = item.metadata.copy()
 
-                    # Add previous stage
+                    # Add previous stage results
                    for prev_stage_name, stage_result in item.stage_results.items():
-                        # Add outputs with stage name prefix
                        for i, output in enumerate(stage_result.outputs):
                            context[f"{prev_stage_name}_output_{i}"] = output
-                        # Also add under output field name
                        if len(stage_result.outputs) == 1:
                            context[stage_result.output_field] = stage_result.outputs[0]
                        else:
                            context[stage_result.output_field] = stage_result.outputs
 
-                    # Format prompts
+                    # Format prompts
                    formatted_prompts = template_manager.format_all(context)
 
-                    # Build requests
+                    # Build requests
                    for prompt in formatted_prompts:
                        req = self._build_vllm_input(converted_img, prompt, processor, tokenizer)
                        requests.append(req)
@@ -1037,11 +719,10 @@ class CaptionWorker(BaseWorker):
                failed_items = []
 
                for idx, (original_idx, item, attempt_count) in enumerate(current_batch):
-                    # Check if we should stop
                    if self.should_stop_processing.is_set():
                        return results
 
-                    # Extract outputs
+                    # Extract outputs
                    base_idx = idx * len(stage.prompts)
                    stage_outputs = []
 
@@ -1051,13 +732,9 @@ class CaptionWorker(BaseWorker):
                        cleaned_output = self._clean_output(original_output)
                        if cleaned_output:
                            stage_outputs.append(cleaned_output)
-                        else:
-                            logger.warning(
-                                f"(stage {stage_name}, item {item.item_key}) output destroyed: {original_output}"
-                            )
 
                    if stage_outputs:
-                        # Success
+                        # Success
                        stage_result = StageResult(
                            stage_name=stage_name,
                            output_field=stage.output_field,
@@ -1066,102 +743,44 @@ class CaptionWorker(BaseWorker):
                        item.stage_results[stage_name] = stage_result
                        successful_items.append((original_idx, item))
                    else:
-                        # Failed - check
+                        # Failed - check retry
                        if attempt_count + 1 < max_attempts:
                            failed_items.append((original_idx, item, attempt_count + 1))
-                            logger.warning(
-                                f"Stage {stage_name} failed for item {item.item_key} "
-                                f"(attempt {attempt_count + 1}/{max_attempts}), will retry"
-                            )
                        else:
-                            logger.error(
-                                f"Stage {stage_name} failed for item {item.item_key} "
-                                f"after {max_attempts} attempts"
-                            )
+                            logger.error(f"Stage {stage_name} failed for item {item.item_key}")
                            self.items_failed += 1
+                            stage_result = StageResult(
+                                stage_name=stage_name,
+                                output_field=stage.output_field,
+                                outputs=[],
+                                error=f"Failed after {max_attempts} attempts",
+                            )
+                            item.stage_results[stage_name] = stage_result
+                            self.result_queue.put(
+                                {
+                                    "item": item,
+                                    "outputs": {},
+                                    "processing_time_ms": 0.0,
+                                    "error": f"Failed stage {stage_name} after {max_attempts} attempts",
+                                }
+                            )
 
-                # Update
+                # Update for next iteration
                items_to_process = failed_items
-
-                # Update batch with successful items for next stage
                batch = [item for _, item in successful_items]
 
-
-                if items_to_process:
-                    logger.info(
-                        f"Retrying {len(items_to_process)} failed items for stage {stage_name}"
-                    )
-
-        # Convert batch items to results
+        # Convert to results
        for item in batch:
-            # Aggregate outputs by field
+            # Aggregate outputs by field
            outputs_by_field = defaultdict(list)
-
            for stage_result in item.stage_results.values():
                outputs_by_field[stage_result.output_field].extend(stage_result.outputs)
 
-
-                chunk_id=item.chunk_id,
-                shard_name=Path(item.chunk_id).stem.rsplit("_chunk_", 1)[0],
-                item_key=item.item_key,
-                outputs=dict(outputs_by_field), # Convert defaultdict to dict
-                image_width=item.image.width,
-                image_height=item.image.height,
-                image_format=item.image.format or "unknown",
-                file_size=len(item.image_data),
-                processing_time_ms=0, # Will be calculated by caller
-                metadata=item.metadata,
-            )
-            results.append(result)
+            results.append((item, dict(outputs_by_field)))
            self.items_processed += 1
 
        return results
 
-    def _inference_thread(self):
-        """Background thread for multi-stage vLLM inference."""
-        logger.info("Starting multi-stage inference thread")
-
-        while self.running:
-            try:
-                # Get batch from queue with timeout
-                batch = self.inference_queue.get(timeout=1)
-
-                if not batch:
-                    continue
-
-                # Skip if disconnected
-                if self.should_stop_processing.is_set():
-                    continue
-
-                logger.debug(
-                    f"Processing batch of {len(batch)} images through {len(self.stages)} stages"
-                )
-                start_time = time.time()
-
-                # Process batch through all stages
-                results = self._process_batch_multi_stage(batch)
-
-                # Calculate processing time per item
-                if results:
-                    processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
-
-                    # Update processing time and queue results
-                    for result in results:
-                        result.processing_time_ms = processing_time_per_item
-                        self.result_queue.put(result)
-
-                logger.debug(
-                    f"Multi-stage batch processing complete: {len(results)} successful, "
-                    f"{len(batch) - len(results)} failed"
-                )
-
-            except Empty:
-                continue
-            except Exception as e:
-                if self.should_stop_processing.is_set():
-                    continue
-                logger.error(f"Inference error: {e}", exc_info=True)
-
     def _build_vllm_input(self, image: Image.Image, prompt: str, processor, tokenizer) -> Dict:
         """Build vLLM input."""
         try:
@@ -1198,124 +817,129 @@ class CaptionWorker(BaseWorker):
         if not text:
             return ""
 
-        # Remove common artifacts
         for token in ["<|end|>", "<|endoftext|>", "<|im_end|>", "I'm sorry", "I cannot"]:
             if token in text:
                 text = text.split(token)[0]
 
         return text.strip()
 
+    def _get_heartbeat_data(self) -> Dict[str, Any]:
+        """Get heartbeat data."""
+        return {
+            "type": "heartbeat",
+            "processed": self.items_processed,
+            "failed": self.items_failed,
+            "units_completed": self.units_completed,
+            "current_unit": self._get_current_unit_id() if self.current_unit else None,
+            "queue_sizes": {
+                "readahead": self.readahead_queue.qsize(),
+                "inference": self.inference_queue.qsize(),
+                "results": self.result_queue.qsize(),
+            },
+            "stages": len(self.stages),
+            "models_loaded": len(self.model_manager.models) if self.model_manager else 0,
+        }
+
+    async def _create_tasks(self) -> list:
+        """Create async tasks to run."""
+        return [
+            asyncio.create_task(self._heartbeat_loop()),
+            asyncio.create_task(self._base_message_handler()),
+            asyncio.create_task(self._result_sender()),
+        ]
+
     async def _result_sender(self):
-        """Send results back to orchestrator
-
+        """Send results back to orchestrator."""
+        while self.running and self.connected.is_set():
+            try:
+                # Get result
+                result_data = await asyncio.get_event_loop().run_in_executor(
+                    None, self.result_queue.get, True, 1
+                )
 
-
-
-
-
-
-
-
+                if self.websocket and self.connected.is_set():
+                    item = result_data["item"]
+                    logger.debug(f"Handling results for item: {item}")
+                    outputs = result_data["outputs"]
+
+                    # Create work result
+                    # logger.info(f"Processed item: {item}")
+                    work_result = WorkResult(
+                        unit_id=item.unit_id,
+                        source_id=item.metadata.get("shard_name", "unknown"),
+                        chunk_id=item.chunk_id,
+                        sample_id=f"{item.item_key}",
+                        outputs=outputs,
+                        metadata={
+                            "item_key": item.item_key,
+                            "item_index": item.metadata.get("_item_index"),
+                            "image_width": item.image.width,
+                            "image_height": item.image.height,
+                            "image_format": item.image.format or "unknown",
+                            "file_size": len(item.image_data) if item.image_data else 0,
+                            **item.metadata,
+                        },
+                        processing_time_ms=result_data["processing_time_ms"],
+                        error=result_data.get("error", None),
+                    )
+
+                    # Send result in format that orchestrator expects
+                    await self.websocket.send(
+                        json.dumps(
+                            {
+                                "type": "submit_results",
+                                "unit_id": work_result.unit_id,
+                                "job_id": item.job_id,
+                                "dataset": self.dataset_path,
+                                "sample_id": work_result.sample_id,
+                                "source_id": work_result.source_id,
+                                "outputs": work_result.outputs,
+                                "metadata": work_result.metadata,
+                                "processing_time_ms": work_result.processing_time_ms,
+                            }
                        )
-
-
-
-
-
-
-                        sent_results = []
-                        for result in pending_results:
-                            try:
-                                # For backward compatibility, if there's only one output field "captions"
-                                # send it in the old format
-                                if len(result.outputs) == 1 and "captions" in result.outputs:
-                                    # Old format for single-stage compatibility
-                                    await self.websocket.send(
-                                        json.dumps(
-                                            {
-                                                "type": "submit_captions",
-                                                "chunk_id": result.chunk_id,
-                                                "dataset": self.dataset_config.get(
-                                                    "dataset_path", "unknown"
-                                                ),
-                                                "shard": result.shard_name,
-                                                "item_key": result.item_key,
-                                                "item_index": result.metadata.get("_item_index"),
-                                                "captions": result.outputs["captions"],
-                                                "caption_count": len(result.outputs["captions"]),
-                                                "image_width": result.image_width,
-                                                "image_height": result.image_height,
-                                                "image_format": result.image_format,
-                                                "file_size": result.file_size,
-                                                "processing_time_ms": result.processing_time_ms,
-                                            }
-                                        )
-                                    )
-                                else:
-                                    # New format for multi-stage outputs
-                                    await self.websocket.send(
-                                        json.dumps(
-                                            {
-                                                "type": "submit_captions",
-                                                "chunk_id": result.chunk_id,
-                                                "dataset": self.dataset_config.get(
-                                                    "dataset_path", "unknown"
-                                                ),
-                                                "shard": result.shard_name,
-                                                "item_key": result.item_key,
-                                                "outputs": result.outputs, # Dict of field -> list of outputs
-                                                "captions": result.outputs.get(
-                                                    "captions", []
-                                                ), # For compatibility
-                                                "caption_count": sum(
-                                                    len(v) for v in result.outputs.values()
-                                                ),
-                                                "image_width": result.image_width,
-                                                "image_height": result.image_height,
-                                                "image_format": result.image_format,
-                                                "file_size": result.file_size,
-                                                "processing_time_ms": result.processing_time_ms,
-                                                "metadata": result.metadata,
-                                            }
-                                        )
-                                    )
-
-                                sent_results.append(result)
-
-                                if self.items_processed % 100 == 0:
-                                    total_outputs = sum(
-                                        len(outputs) for outputs in result.outputs.values()
-                                    )
-                                    logger.info(
-                                        f"Processed {self.items_processed} items "
-                                        f"(~{total_outputs} outputs across {len(result.outputs)} fields)"
-                                    )
-                            except websockets.exceptions.ConnectionClosed as e:
-                                logger.warning(f"Connection lost while sending result: {e}")
-                                raise # Re-raise to trigger task completion
-                            except Exception as e:
-                                logger.error(f"Error sending result: {e}")
-                                break
-
-                        # Remove successfully sent results
-                        for result in sent_results:
-                            pending_results.remove(result)
-
-                    # Clear pending results if disconnected and buffer is too large
-                    if not self.connected.is_set() and len(pending_results) > 1000:
-                        logger.warning(
-                            f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
+                    )
+
+                    if self.items_processed % 100 == 0:
+                        total_outputs = sum(len(v) for v in outputs.values())
+                        logger.info(
+                            f"Processed {self.items_processed} items (~{total_outputs} outputs)"
                        )
-                        pending_results.clear()
 
-
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(f"Error sending result: {e}", exc_info=True)
+                await asyncio.sleep(1)
+
+    async def _on_disconnect(self):
+        """Handle disconnection."""
+        self.should_stop_processing.set()
+
+        with self.work_lock:
+            self.assigned_units.clear()
+            self.current_unit = None
+
+        # Clear queues
+        self._clear_queue(self.readahead_queue)
+        self._clear_queue(self.inference_queue)
+        self._clear_queue(self.result_queue)
+
+    def _clear_queue(self, queue: Queue):
+        """Clear all items from a queue."""
+        try:
+            while True:
+                queue.get_nowait()
+        except Empty:
+            pass
+
+    async def _pre_shutdown(self):
+        """Cleanup before shutdown."""
+        self.readahead_queue.put(None)
+        self.inference_queue.put(None)
 
-
-
-                    raise # Re-raise connection errors
-                logger.error(f"Unexpected error in result sender: {e}")
-                await asyncio.sleep(1)
+        if self.image_processor:
+            self.image_processor.shutdown()
 
-
-
-            raise
+        if self.model_manager:
+            self.model_manager.cleanup()