caption-flow 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,103 +1,59 @@
- """Caption worker for vLLM-based distributed image captioning with multi-stage processing."""
+ """Caption worker with processor abstraction for distributed captioning."""
 
  import os
 
  os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
 
  import asyncio
- import io
  import json
  import logging
  import websockets
  import time
- from dataclasses import dataclass, field
- from pathlib import Path
- from typing import Dict, Any, Optional, List, Tuple
+ from dataclasses import dataclass
+ from typing import Dict, Any, Optional, List, Tuple, Union
  from queue import Queue, Empty
- from threading import Thread, Lock, Event
- from collections import deque, defaultdict
+ from threading import Thread, Event, Lock
+ from collections import defaultdict, deque
 
  from PIL import Image
- import numpy as np
  from huggingface_hub import get_token
 
  from .base import BaseWorker
- from ..models import JobStatus, Job
- from ..utils import CaptionUtils
- from ..utils.dataset_loader import DatasetLoader
+ from ..processors import (
+ ProcessorConfig,
+ WorkAssignment,
+ WorkUnit,
+ WorkResult,
+ WebDatasetWorkerProcessor,
+ HuggingFaceDatasetWorkerProcessor,
+ LocalFilesystemWorkerProcessor,
+ )
  from ..utils.vllm_config import VLLMConfigManager
  from ..utils.image_processor import ImageProcessor
- from ..utils.shard_processor import HFDatasetShardProcessor, WebDatasetShardProcessor
  from ..utils.prompt_template import PromptTemplateManager
+ from ..models import ProcessingStage, StageResult
 
  logger = logging.getLogger(__name__)
-
-
- @dataclass
- class ProcessingStage:
- """Configuration for a single processing stage."""
-
- name: str
- model: str
- prompts: List[str]
- output_field: str
- requires: List[str] = field(default_factory=list)
- sampling: Optional[Dict[str, Any]] = None
-
- # Model-specific overrides
- tensor_parallel_size: Optional[int] = None
- max_model_len: Optional[int] = None
- dtype: Optional[str] = None
- gpu_memory_utilization: Optional[float] = None
-
-
- @dataclass
- class StageResult:
- """Results from a single stage."""
-
- stage_name: str
- output_field: str
- outputs: List[str] # Multiple outputs from multiple prompts
-
-
- @dataclass
- class ShardChunk:
- """Shard chunk assignment with unprocessed ranges."""
-
- chunk_id: str
- shard_url: str
- shard_name: str
- start_index: int
- chunk_size: int
- unprocessed_ranges: List[Tuple[int, int]] = field(default_factory=list)
+ # logger.setLevel(logging.DEBUG)
 
 
  @dataclass
  class ProcessingItem:
- """Item being processed."""
+ """Item being processed through stages."""
 
+ unit_id: str
+ job_id: str
  chunk_id: str
  item_key: str
+ item_index: int
  image: Image.Image
  image_data: bytes
- metadata: Dict[str, Any] = field(default_factory=dict)
- stage_results: Dict[str, StageResult] = field(default_factory=dict) # Accumulated results
+ metadata: Dict[str, Any]
+ stage_results: Dict[str, StageResult] = None
 
-
- @dataclass
- class ProcessedResult:
- """Result with multi-stage outputs."""
-
- chunk_id: str
- shard_name: str
- item_key: str
- outputs: Dict[str, List[str]] # field_name -> list of outputs
- image_width: int
- image_height: int
- image_format: str
- file_size: int
- processing_time_ms: float
- metadata: Dict[str, Any] = field(default_factory=dict)
+ def __post_init__(self):
+ if self.stage_results is None:
+ self.stage_results = {}
 
 
  class MultiStageVLLMManager:
@@ -121,7 +77,7 @@ class MultiStageVLLMManager:
 
  logger.info(f"Loading model {model_name} for stage {stage.name}")
 
- # Build model-specific config by merging base config with stage overrides
+ # Build model-specific config
  model_config = base_config.copy()
  model_config["model"] = model_name
 
@@ -163,10 +119,7 @@ class MultiStageVLLMManager:
  """Create sampling params for a stage."""
  from vllm import SamplingParams
 
- # Start with base sampling config
  sampling_config = base_sampling.copy()
-
- # Override with stage-specific sampling if provided
  if stage.sampling:
  sampling_config.update(stage.sampling)
 
@@ -183,16 +136,7 @@ class MultiStageVLLMManager:
  return params
 
  def get_model_for_stage(self, stage_name: str, model_name: str) -> Tuple[Any, Any, Any, Any]:
- """
- Get model components for a stage.
-
- Returns:
- tuple: A tuple containing:
- - llm: The language model instance for the given model name.
- - processor: The processor associated with the model.
- - tokenizer: The tokenizer for the model.
- - sampling_params: The sampling parameters for the given stage.
- """
+ """Get model components for a stage."""
  return (
  self.models[model_name],
  self.processors[model_name],
@@ -214,74 +158,69 @@ class MultiStageVLLMManager:
 
 
  class CaptionWorker(BaseWorker):
- """Worker that processes shard chunks for image captioning using multi-stage vLLM."""
+ """Worker that processes work units for image captioning using multi-stage vLLM."""
 
  def __init__(self, config: Dict[str, Any]):
  super().__init__(config)
 
- batch_image_processing = config.get("batch_image_processing", False)
-
- # Dataset configuration will be received from orchestrator
- self.dataset_config = None
- self.dataset_loader = None
- self.dataset_type = None
- self.dataset_split = None
- self.dataset_image_column = None
- self.hf_token = get_token()
-
- # vLLM configuration will be received from orchestrator
+ # Processor configuration - will be set from orchestrator
+ self.processor_type = None
+ self.processor: Optional[
+ Union[
+ WebDatasetWorkerProcessor,
+ HuggingFaceDatasetWorkerProcessor,
+ LocalFilesystemWorkerProcessor,
+ ],
+ ] = None
+ self.dataset_path: Optional[str] = None
+
+ # vLLM configuration
  self.vllm_config = None
  self.stages: List[ProcessingStage] = []
- self.stage_order: List[str] = [] # Topologically sorted stage names
+ self.stage_order: List[str] = []
  self.vllm_config_manager = VLLMConfigManager()
  self.model_manager = None
 
- # Backward compatibility: local config for GPU selection
+ # GPU selection
  self.gpu_id = config.get("gpu_id", 0)
-
- # Connection state events
- self.should_stop_processing = Event()
+ self.hf_token = get_token()
 
  # Image processor
- self.image_processor = None
- if batch_image_processing:
- self.image_processor = ImageProcessor()
-
- # Shard chunk processing
- self.hf_processor = HFDatasetShardProcessor()
- self.webdataset_processor = WebDatasetShardProcessor(
- hf_token=self.hf_token, dataset_type=self.dataset_type
- )
- self.chunk_lock = Lock()
- self.assigned_chunks = deque()
- self.current_chunk = None
- self.current_chunk_progress = 0
+ batch_image_processing = config.get("batch_image_processing", False)
+ self.image_processor = ImageProcessor() if batch_image_processing else None
+
+ # Work processing
+ self.work_lock = Lock()
+ self.assigned_units = deque()
+ self.current_unit: Optional[WorkUnit] = None
 
- # Batching queues - will be cleared on disconnect
+ # Processing queues
  self.readahead_queue = Queue(maxsize=256)
  self.inference_queue = Queue(maxsize=128)
  self.result_queue = Queue()
 
- # Job mode for shards vs jobs and job queue
- self.job_mode = config.get("job_mode", False)
- self.job_queue = Queue(maxsize=32)
+ # Processing control
+ self.should_stop_processing = Event()
 
  def _init_metrics(self):
  """Initialize worker metrics."""
  self.items_processed = 0
  self.items_failed = 0
- self.chunks_completed = 0
+ self.units_completed = 0
 
  def _get_auth_data(self) -> Dict[str, Any]:
  """Get authentication data."""
  return {"token": self.token, "name": self.name}
 
+ def _get_current_unit_id(self) -> Optional[str]:
+ """Get the current unit ID."""
+ return self.current_unit.unit_id if self.current_unit else None
+
  async def _pre_start(self):
  """Initialize before starting connection loop."""
- # Wait for initial connection to get vLLM config
+ # Wait for initial connection to get config
  logger.info("Connecting to orchestrator for configuration...")
 
- # Try initial connection to get config
  config_received = False
  while not config_received and self.running:
  try:
@@ -292,21 +231,110 @@ class CaptionWorker(BaseWorker):
  await asyncio.sleep(5)
 
  # Initialize vLLM once we have config
- self._setup_vllm()
+ if self.vllm_config:
+ self._setup_vllm()
 
  # Start background threads
- reader_thread = Thread(target=self._shard_reader_thread, daemon=True)
- reader_thread.start()
+ Thread(target=self._unit_processor_thread, daemon=True).start()
+ Thread(target=self._inference_thread, daemon=True).start()
+
+ async def _initial_connect_for_config(self):
+ """Connect initially just to get configuration."""
+ logger.info(f"Connecting to {self.server_url}")
+ async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
+ await websocket.send(json.dumps(self._get_auth_data()))
+
+ welcome = await websocket.recv()
+ welcome_data = json.loads(welcome)
+
+ if "error" in welcome_data:
+ raise RuntimeError(f"Authentication failed: {welcome_data['error']}")
+
+ # Extract vLLM config from processor config
+ processor_config = welcome_data.get("processor_config", {})
+ self.vllm_config = processor_config.get("vllm", {})
+
+ if not self.vllm_config:
+ raise RuntimeError("No vLLM configuration received from orchestrator")
+
+ # Parse stages
+ self.stages = self._parse_stages_config(self.vllm_config)
+ self.stage_order = self._topological_sort_stages(self.stages)
+
+ logger.info(f"Configured {len(self.stages)} processing stages: {self.stage_order}")
+
+ async def _handle_welcome(self, welcome_data: Dict[str, Any]):
+ """Handle welcome message from orchestrator."""
+ with self.work_lock:
+ self.assigned_units.clear()
+ self.current_unit = None
+
+ self._clear_queue(self.readahead_queue)
+ self._clear_queue(self.inference_queue)
+ self._clear_queue(self.result_queue)
+
+ # Reset counters
+ self.items_processed = 0
+ self.items_failed = 0
+ self.units_completed = 0
+
+ # Setup processor
+ self.processor_type = welcome_data.get("processor_type", None)
+ assert self.processor_type is not None, "Processor type not found in welcome data"
+ logger.info(f"Creating {self.processor_type} processor")
+ processor_config = ProcessorConfig(
+ processor_type=self.processor_type, config=welcome_data.get("processor_config", {})
+ )
+
+ if self.processor_type == "webdataset":
+ self.processor = WebDatasetWorkerProcessor()
+ elif self.processor_type == "huggingface_datasets":
+ self.processor = HuggingFaceDatasetWorkerProcessor()
+ elif self.processor_type == "local_filesystem":
+ self.processor = LocalFilesystemWorkerProcessor()
+ else:
+ raise ValueError(f"Unknown processor type: {self.processor_type}")
+
+ self.processor.initialize(processor_config)
+ self.dataset_path = self.processor.dataset_path
+
+ # Update vLLM config if provided
+ new_vllm_config = welcome_data.get("processor_config", {}).get("vllm")
+ if new_vllm_config and new_vllm_config != self.vllm_config:
+ logger.info("Received updated vLLM configuration")
+ self._handle_vllm_config_update(new_vllm_config)
+
+ # Clear stop signal
+ self.should_stop_processing.clear()
+
+ # Request initial work
+ if self.websocket:
+ await self.websocket.send(json.dumps({"type": "request_work", "count": 2}))
+
+ async def _handle_message(self, data: Dict[str, Any]):
+ """Handle message from orchestrator."""
+ msg_type = data.get("type")
+
+ if msg_type == "work_assignment":
+ assignment = WorkAssignment.from_dict(data["assignment"])
+ with self.work_lock:
+ for unit in assignment.units:
+ self.assigned_units.append(unit)
+ logger.info(f"Received {len(assignment.units)} work units")
 
- inference_thread = Thread(target=self._inference_thread, daemon=True)
- inference_thread.start()
+ elif msg_type == "no_work":
+ logger.info("No work available")
+ await asyncio.sleep(10)
+
+ if self.websocket and self.connected.is_set():
+ await self.websocket.send(json.dumps({"type": "request_work", "count": 2}))
 
  def _parse_stages_config(self, vllm_config: Dict[str, Any]) -> List[ProcessingStage]:
  """Parse stages configuration from vLLM config."""
  stages_config = vllm_config.get("stages", [])
 
  if not stages_config:
- # Backward compatibility: create single stage from old config
+ # Backward compatibility
  return [
  ProcessingStage(
  name="default",
@@ -338,10 +366,8 @@ class CaptionWorker(BaseWorker):
 
  def _topological_sort_stages(self, stages: List[ProcessingStage]) -> List[str]:
  """Sort stages by dependencies."""
- # Build dependency graph
  graph = defaultdict(list)
  in_degree = defaultdict(int)
-
  stage_map = {s.name: s for s in stages}
 
  for stage in stages:
@@ -351,7 +377,6 @@ class CaptionWorker(BaseWorker):
  raise ValueError(f"Stage '{stage.name}' requires missing dependency '{dep}'")
  graph[dep].append(stage.name)
 
- # Topological sort using Kahn's algorithm
  queue = deque([name for name, degree in in_degree.items() if degree == 0])
  result = []
 
@@ -369,204 +394,10 @@ class CaptionWorker(BaseWorker):
369
394
 
370
395
  return result
371
396
 
372
- async def _handle_welcome(self, welcome_data: Dict[str, Any]):
373
- """Handle welcome message from orchestrator."""
374
- # Extract and setup dataset configuration
375
- dataset_config = welcome_data.get("dataset_config", {})
376
- if dataset_config:
377
- self._setup_dataset_loader(dataset_config)
378
- logger.info(f"Received dataset config: {dataset_config}")
379
- else:
380
- logger.warning("No dataset configuration received from orchestrator")
381
-
382
- # Update vLLM config if provided (in case it changed)
383
- new_vllm_config = welcome_data.get("vllm_config")
384
- if new_vllm_config and new_vllm_config != self.vllm_config:
385
- logger.info("Received updated vLLM configuration")
386
- if not self._handle_vllm_config_update(new_vllm_config):
387
- logger.error("Failed to update vLLM configuration")
388
-
389
- # Clear stop signal now that we're connected
390
- self.should_stop_processing.clear()
391
-
392
- # Request initial chunks if not in job mode
393
- if not self.job_mode and self.websocket:
394
- await self.websocket.send(json.dumps({"type": "request_chunks", "count": 2}))
395
-
396
- async def _handle_message(self, data: Dict[str, Any]):
397
- """Handle message from orchestrator."""
398
- msg_type = data.get("type")
399
-
400
- if msg_type == "shard_assignment":
401
- chunks = data["chunks"]
402
- for chunk_data in chunks:
403
- chunk = ShardChunk(**chunk_data)
404
- with self.chunk_lock:
405
- self.assigned_chunks.append(chunk)
406
- logger.info(f"Received chunk assignment: {chunk.chunk_id}")
407
-
408
- elif msg_type == "no_chunks":
409
- reason = data.get("reason", "unknown")
410
- logger.info(f"No chunks available from orchestrator (reason: {reason})")
411
-
412
- wait_time = 2 if reason == "state_restoring" else 10
413
- await asyncio.sleep(wait_time)
414
-
415
- if self.websocket and self.connected.is_set():
416
- await self.websocket.send(json.dumps({"type": "request_chunks", "count": 2}))
417
-
418
- elif msg_type == "reload_vllm":
419
- logger.info("Orchestrator requested vLLM reload")
420
- new_config = data.get("vllm_config")
421
- if new_config:
422
- self._handle_vllm_config_update(new_config)
423
-
424
- elif msg_type == "config_update":
425
- # Soft config update without reload
426
- if data.get("vllm_config"):
427
- self._handle_vllm_config_update(data["vllm_config"])
428
-
429
- elif msg_type == "job_assignment":
430
- await self._handle_job_assignment(data["job"])
431
-
432
- elif msg_type == "no_jobs":
433
- logger.debug("No jobs available")
434
- await asyncio.sleep(2)
435
-
436
- def _get_heartbeat_data(self) -> Dict[str, Any]:
437
- """Get heartbeat data."""
438
- return {
439
- "type": "heartbeat",
440
- "processed": self.items_processed,
441
- "failed": self.items_failed,
442
- "chunks_completed": self.chunks_completed,
443
- "current_chunk": self.current_chunk.chunk_id if self.current_chunk else None,
444
- "chunk_progress": self.current_chunk_progress,
445
- "queue_sizes": {
446
- "readahead": self.readahead_queue.qsize(),
447
- "inference": self.inference_queue.qsize(),
448
- "results": self.result_queue.qsize(),
449
- },
450
- "stages": len(self.stages),
451
- "models_loaded": len(self.model_manager.models) if self.model_manager else 0,
452
- }
453
-
454
- async def _create_tasks(self) -> list:
455
- """Create async tasks to run."""
456
- tasks = [
457
- asyncio.create_task(self._heartbeat_loop()),
458
- asyncio.create_task(self._base_message_handler()),
459
- asyncio.create_task(self._result_sender()),
460
- ]
461
-
462
- if self.job_mode:
463
- tasks.append(asyncio.create_task(self._job_request_loop()))
464
-
465
- return tasks
466
-
467
- async def _on_disconnect(self):
468
- """Handle disconnection."""
469
- self._clear_state_on_disconnect()
470
-
471
- async def _pre_shutdown(self):
472
- """Cleanup before shutdown."""
473
- # Stop processing threads by adding stop signals
474
- self.readahead_queue.put(None)
475
- self.inference_queue.put(None)
476
-
477
- # Shutdown image processor
478
- if self.image_processor is not None:
479
- self.image_processor.shutdown()
480
-
481
- # Cleanup model manager
482
- if self.model_manager:
483
- self.model_manager.cleanup()
484
-
485
- async def _initial_connect_for_config(self):
486
- """Connect initially just to get configuration."""
487
- logger.info(f"Connecting to {self.server_url}")
488
- async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
489
- await websocket.send(json.dumps(self._get_auth_data()))
490
-
491
- welcome = await websocket.recv()
492
- welcome_data = json.loads(welcome)
493
-
494
- if "error" in welcome_data:
495
- raise RuntimeError(f"Authentication failed: {welcome_data['error']}")
496
-
497
- self.vllm_config = welcome_data.get("vllm_config")
498
- if not self.vllm_config:
499
- raise RuntimeError("No vLLM configuration received from orchestrator")
500
-
501
- # Parse stages configuration
502
- self.stages = self._parse_stages_config(self.vllm_config)
503
- self.stage_order = self._topological_sort_stages(self.stages)
504
-
505
- logger.info(f"Configured {len(self.stages)} processing stages: {self.stage_order}")
506
-
507
- self.vllm_config_manager.current_config = self.vllm_config
508
-
509
- dataset_config = welcome_data.get("dataset_config", {})
510
- if dataset_config:
511
- self._setup_dataset_loader(dataset_config)
512
-
513
- logger.info("Received configuration from orchestrator")
514
-
515
- def _clear_state_on_disconnect(self):
516
- """Clear all processing state when disconnected."""
517
- logger.info("Clearing state due to disconnection")
518
-
519
- self.should_stop_processing.set()
520
-
521
- with self.chunk_lock:
522
- self.assigned_chunks.clear()
523
- self.current_chunk = None
524
- self.current_chunk_progress = 0
525
-
526
- self._clear_queue(self.readahead_queue)
527
- self._clear_queue(self.inference_queue)
528
- self._clear_queue(self.result_queue)
529
-
530
- logger.info("State cleared, ready for reconnection")
531
-
532
- def _clear_queue(self, queue: Queue):
533
- """Clear all items from a queue."""
534
- try:
535
- while True:
536
- queue.get_nowait()
537
- except Empty:
538
- pass
539
-
540
- def _setup_dataset_loader(self, dataset_config: Dict[str, Any]):
541
- """Initialize dataset loader with config from orchestrator."""
542
- dataset_path = dataset_config.get("dataset_path") or dataset_config.get("path")
543
- dataset_type = dataset_config.get("dataset_type") or dataset_config.get(
544
- "type", "huggingface"
545
- )
546
- dataset_split = dataset_config.get("dataset_split") or dataset_config.get("split", "train")
547
- dataset_image_column = dataset_config.get("dataset_image_column") or dataset_config.get(
548
- "image_column", "image"
549
- )
550
-
551
- if dataset_path:
552
- logger.info(
553
- f"Initializing dataset loader for {dataset_type}: {dataset_path} "
554
- f"(split: {dataset_split}, image_column: {dataset_image_column})"
555
- )
556
- self.dataset_loader = DatasetLoader(
557
- dataset_path, dataset_type, dataset_split, dataset_image_column
558
- )
559
- self.dataset_config = dataset_config
560
- self.dataset_type = dataset_type
561
- self.dataset_split = dataset_split
562
- self.dataset_image_column = dataset_image_column
563
- else:
564
- logger.warning("No dataset path provided by orchestrator")
565
-
566
397
  def _setup_vllm(self):
567
398
  """Initialize multi-stage vLLM components."""
568
399
  if not self.vllm_config:
569
- raise RuntimeError("vLLM config not received from orchestrator")
400
+ raise RuntimeError("vLLM config not received")
570
401
 
571
402
  os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
572
403
 
@@ -601,11 +432,8 @@ class CaptionWorker(BaseWorker):
 
  logger.info("Multi-stage vLLM initialization complete")
 
- # Update config manager's tracking
- self.vllm_config_manager.current_config = self.vllm_config
-
  def _handle_vllm_config_update(self, new_config: Dict[str, Any]) -> bool:
- """Handle vLLM configuration updates for multi-stage."""
+ """Handle vLLM configuration updates."""
  if not new_config:
  return True
 
@@ -628,112 +456,107 @@ class CaptionWorker(BaseWorker):
628
456
  if stages_changed:
629
457
  logger.info("Stage configuration changed, reloading all models")
630
458
 
631
- # Save old config
632
459
  old_config = self.vllm_config
633
460
  self.vllm_config = new_config
634
461
  self.stages = new_stages
635
462
  self.stage_order = self._topological_sort_stages(self.stages)
636
463
 
637
464
  try:
638
- # Cleanup old models
639
465
  if self.model_manager:
640
466
  self.model_manager.cleanup()
641
-
642
- # Reload with new config
643
467
  self._setup_vllm()
644
-
645
- logger.info("Multi-stage vLLM reload complete")
646
468
  return True
647
-
648
469
  except Exception as e:
649
470
  logger.error(f"Failed to reload vLLM: {e}")
650
- # Restore old config
651
471
  self.vllm_config = old_config
652
472
  return False
653
473
  else:
654
- # Just update sampling params for existing stages
474
+ # Just update sampling params
655
475
  logger.info("Updating sampling parameters without model reload")
656
-
657
476
  base_sampling = new_config.get("sampling", {})
658
477
  for stage in self.stages:
659
478
  self.model_manager.create_sampling_params(stage, base_sampling)
660
-
661
479
  self.vllm_config = new_config
662
480
  return True
663
481
 
664
- async def _handle_job_assignment(self, job_data: Dict):
665
- """Handle job assignment from orchestrator."""
666
- try:
667
- # Convert to processing item
668
- image = Image.open(io.BytesIO(job_data["image_data"]))
669
-
670
- item = ProcessingItem(
671
- chunk_id=job_data["job_id"],
672
- item_key=job_data["sample_id"],
673
- image=image,
674
- image_data=job_data["image_data"],
675
- )
482
+ def _unit_processor_thread(self):
483
+ """Background thread that processes work units."""
484
+ logger.info("Starting unit processor thread")
676
485
 
677
- # Add to inference queue
678
- self.readahead_queue.put(item)
679
- logger.debug(f"Queued job {job_data['job_id']} for processing")
486
+ while self.running:
487
+ if self.should_stop_processing.is_set():
488
+ time.sleep(1)
489
+ continue
680
490
 
681
- except Exception as e:
682
- logger.error(f"Error handling job assignment: {e}")
491
+ if not self.connected.is_set():
492
+ time.sleep(1)
493
+ continue
494
+
495
+ # Get next unit
496
+ with self.work_lock:
497
+ if not self.current_unit and self.assigned_units:
498
+ self.current_unit = self.assigned_units.popleft()
499
+ logger.info(f"Starting unit {self._get_current_unit_id()}")
500
+
501
+ if not self.current_unit:
502
+ time.sleep(1)
503
+ continue
683
504
 
684
- async def _job_request_loop(self):
685
- """Request jobs from orchestrator in job mode."""
686
- while self.running and self.connected.is_set():
687
505
  try:
688
- # Check if we need more work
689
- if self.readahead_queue.qsize() < self.vllm_config.get("batch_size", 8):
690
- await self.websocket.send(json.dumps({"type": "request_job"}))
506
+ self._process_work_unit(self.current_unit)
691
507
 
692
- await asyncio.sleep(1)
508
+ if self.connected.is_set() and not self.should_stop_processing.is_set():
509
+ logger.info(f"Completed unit {self._get_current_unit_id()}")
510
+ self.units_completed += 1
693
511
 
694
- except Exception as e:
695
- logger.error(f"Job request error: {e}")
696
- await asyncio.sleep(5)
512
+ # Request more work if needed
513
+ with self.work_lock:
514
+ queue_size = len(self.assigned_units)
697
515
 
698
- def _process_shard_chunk(self, chunk: ShardChunk):
699
- """Process a single shard chunk with item-level tracking."""
700
- logger.info(
701
- f"Processing shard {chunk.shard_name} with unprocessed ranges: {chunk.unprocessed_ranges}"
702
- )
516
+ if queue_size < 2 and self.websocket and self.main_loop:
517
+ try:
518
+ asyncio.run_coroutine_threadsafe(
519
+ self.websocket.send(
520
+ json.dumps({"type": "request_work", "count": 2})
521
+ ),
522
+ self.main_loop,
523
+ ).result(timeout=5)
524
+ except Exception as e:
525
+ logger.warning(f"Could not request more work: {e}")
703
526
 
704
- # Select appropriate processor
705
- if chunk.shard_url.startswith("hf_dataset:"):
706
- processor = self.hf_processor
707
- else:
708
- processor = self.webdataset_processor
527
+ with self.work_lock:
528
+ self.current_unit = None
529
+
530
+ except Exception as e:
531
+ logger.error(f"Error processing unit: {e}", exc_info=True)
532
+ with self.work_lock:
533
+ self.current_unit = None
534
+
535
+ def _process_work_unit(self, unit: WorkUnit):
536
+ """Process a single work unit."""
537
+ if not self.processor:
538
+ logger.error("Processor not initialized")
539
+ return
709
540
 
710
541
  items_processed = 0
542
+ context = {} # Will store processed indices
711
543
 
712
- # Let the processor handle the range filtering
713
- for key, url, image_data, metadata in processor.iterate_chunk_with_metadata(
714
- chunk, self.dataset_loader, self.should_stop_processing, self.connected
715
- ):
544
+ # Get items from processor
545
+ for item_data in self.processor.process_unit(unit, context):
716
546
  try:
717
- # Load image
718
- img = Image.open(io.BytesIO(image_data))
719
-
720
547
  # Create processing item
721
548
  item = ProcessingItem(
722
- chunk_id=chunk.chunk_id,
723
- item_key=key,
724
- image=img,
725
- image_data=image_data,
726
- metadata=metadata,
549
+ unit_id=unit.unit_id,
550
+ chunk_id=unit.chunk_id,
551
+ job_id=item_data["job_id"],
552
+ item_key=item_data["item_key"],
553
+ item_index=item_data["item_index"],
554
+ image=item_data["image"],
555
+ image_data=item_data.get("image_data", b""),
556
+ metadata=item_data.get("metadata", {}),
727
557
  )
728
558
 
729
- # Store absolute item index for tracking
730
- # The processor should provide the correct index in metadata
731
- if "_chunk_relative_index" in metadata:
732
- item.metadata["_item_index"] = (
733
- chunk.start_index + metadata["_chunk_relative_index"]
734
- )
735
-
736
- # Add to readahead queue with timeout handling
559
+ # Add to readahead queue
737
560
  timeout_end = time.time() + 30
738
561
  while (
739
562
  self.running
@@ -748,9 +571,7 @@ class CaptionWorker(BaseWorker):
  raise TimeoutError("Queue put timeout")
  continue
 
- # If we couldn't queue due to disconnection, stop processing
  if not self.connected.is_set() or self.should_stop_processing.is_set():
- logger.debug(f"Skipping remaining items due to disconnection")
  break
 
  items_processed += 1
@@ -763,199 +584,22 @@ class CaptionWorker(BaseWorker):
763
584
  except Exception as e:
764
585
  if self.should_stop_processing.is_set():
765
586
  break
766
- logger.error(f"Error processing item {key}: {e}")
587
+ logger.error(f"Error processing item {item_data.get('item_key')}: {e}")
767
588
  self.items_failed += 1
768
589
 
769
- # Process any remaining items in queue
590
+ # Process any remaining items
770
591
  if not self.should_stop_processing.is_set():
771
592
  self._batch_for_inference()
593
+ if self.connected.is_set():
594
+ # Notify orchestrator that unit is complete
595
+ asyncio.run_coroutine_threadsafe(
596
+ self.websocket.send(
597
+ json.dumps({"type": "work_complete", "unit_id": unit.unit_id})
598
+ ),
599
+ self.main_loop,
600
+ ).result(timeout=5)
772
601
 
773
- logger.info(
774
- f"Chunk {chunk.chunk_id} processed {items_processed} items from unprocessed ranges"
775
- )
776
-
777
- def _shard_reader_thread(self):
778
- """Background thread that reads from WebDataset shards."""
779
- logger.info("Starting shard reader thread")
780
-
781
- while self.running:
782
- # Check if we should stop processing
783
- if self.should_stop_processing.is_set():
784
- logger.info("Shard reader waiting for reconnection")
785
- time.sleep(1)
786
- continue
787
-
788
- # Only process if connected
789
- if not self.connected.is_set():
790
- time.sleep(1)
791
- continue
792
-
793
- # Get next chunk to process
794
- with self.chunk_lock:
795
- if not self.current_chunk and self.assigned_chunks:
796
- self.current_chunk = self.assigned_chunks.popleft()
797
- self.current_chunk_progress = 0
798
- logger.info(f"Starting chunk {self.current_chunk.chunk_id}")
799
-
800
- if not self.current_chunk:
801
- time.sleep(1)
802
- continue
803
-
804
- try:
805
- # Process the chunk
806
- self._process_shard_chunk(self.current_chunk)
807
-
808
- # Only mark complete if still connected
809
- if self.connected.is_set() and not self.should_stop_processing.is_set():
810
- logger.info(f"Completed chunk {self.current_chunk.chunk_id}")
811
- self.chunks_completed += 1
812
-
813
- # Notify orchestrator if connected
814
- if self.websocket and self.main_loop:
815
- try:
816
- # Notify completion
817
- asyncio.run_coroutine_threadsafe(
818
- self.websocket.send(
819
- json.dumps(
820
- {
821
- "type": "chunk_complete",
822
- "chunk_id": self.current_chunk.chunk_id,
823
- }
824
- )
825
- ),
826
- self.main_loop,
827
- ).result(timeout=5)
828
-
829
- # Request more chunks if queue is low
830
- with self.chunk_lock:
831
- queue_size = len(self.assigned_chunks)
832
-
833
- if queue_size < 2:
834
- logger.info(f"Requesting more chunks (queue size: {queue_size})")
835
- asyncio.run_coroutine_threadsafe(
836
- self.websocket.send(
837
- json.dumps({"type": "request_chunks", "count": 2})
838
- ),
839
- self.main_loop,
840
- ).result(timeout=5)
841
-
842
- except Exception as e:
843
- logger.warning(f"Could not notify orchestrator: {e}")
844
-
845
- with self.chunk_lock:
846
- self.current_chunk = None
847
-
848
- except Exception as e:
849
- logger.error(f"Error processing chunk: {e}")
850
-
851
- # Only notify of failure if still connected
852
- if self.connected.is_set() and self.websocket and self.main_loop:
853
- try:
854
- asyncio.run_coroutine_threadsafe(
855
- self.websocket.send(
856
- json.dumps(
857
- {
858
- "type": "chunk_failed",
859
- "chunk_id": (
860
- self.current_chunk.chunk_id
861
- if self.current_chunk
862
- else "unknown"
863
- ),
864
- "error": str(e),
865
- }
866
- )
867
- ),
868
- self.main_loop,
869
- ).result(timeout=5)
870
- except Exception as send_error:
871
- logger.warning(
872
- f"Could not notify orchestrator of chunk failure: {send_error}"
873
- )
874
-
875
- with self.chunk_lock:
876
- self.current_chunk = None
877
-
878
- async def _result_sender(self):
879
- """Send results back to orchestrator with item index."""
880
- pending_results = []
881
-
882
- try:
883
- while self.running and self.connected.is_set():
884
- try:
885
- # Get result with timeout
886
- try:
887
- result = await asyncio.get_event_loop().run_in_executor(
888
- None, self.result_queue.get, True, 1
889
- )
890
- pending_results.append(result)
891
- except Empty:
892
- pass
893
-
894
- # Only try to send if connected
895
- if pending_results and self.websocket and self.connected.is_set():
896
- sent_results = []
897
- for result in pending_results:
898
- try:
899
- # Build message with item index
900
- message_data = {
901
- "type": "submit_captions",
902
- "chunk_id": result.chunk_id,
903
- "dataset": self.dataset_config.get("dataset_path", "unknown"),
904
- "shard": result.shard_name,
905
- "item_key": result.item_key,
906
- "item_index": result.item_index, # NEW: Include index
907
- "outputs": result.outputs,
908
- "captions": result.outputs.get("captions", []), # Compatibility
909
- "caption_count": sum(len(v) for v in result.outputs.values()),
910
- "image_width": result.image_width,
911
- "image_height": result.image_height,
912
- "image_format": result.image_format,
913
- "file_size": result.file_size,
914
- "processing_time_ms": result.processing_time_ms,
915
- "metadata": result.metadata,
916
- }
917
-
918
- await self.websocket.send(json.dumps(message_data))
919
- sent_results.append(result)
920
-
921
- if self.items_processed % 100 == 0:
922
- total_outputs = sum(
923
- len(outputs) for outputs in result.outputs.values()
924
- )
925
- logger.info(
926
- f"Processed {self.items_processed} items "
927
- f"(~{total_outputs} outputs across {len(result.outputs)} fields)"
928
- )
929
-
930
- except websockets.exceptions.ConnectionClosed as e:
931
- logger.warning(f"Connection lost while sending result: {e}")
932
- raise
933
- except Exception as e:
934
- logger.error(f"Error sending result: {e}")
935
- break
936
-
937
- # Remove successfully sent results
938
- for result in sent_results:
939
- pending_results.remove(result)
940
-
941
- # Clear pending results if disconnected and buffer is too large
942
- if not self.connected.is_set() and len(pending_results) > 1000:
943
- logger.warning(
944
- f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
945
- )
946
- pending_results.clear()
947
-
948
- await asyncio.sleep(0.1)
949
-
950
- except Exception as e:
951
- if isinstance(e, websockets.exceptions.ConnectionClosed):
952
- raise
953
- logger.error(f"Unexpected error in result sender: {e}")
954
- await asyncio.sleep(1)
955
-
956
- except asyncio.CancelledError:
957
- logger.debug("Result sender cancelled")
958
- raise
602
+ logger.info(f"Unit {unit.unit_id} processed {items_processed} items")
959
603
 
960
604
  def _batch_for_inference(self):
961
605
  """Batch items from readahead queue for inference."""
@@ -972,9 +616,52 @@ class CaptionWorker(BaseWorker):
  if batch:
  self.inference_queue.put(batch)
 
+ def _inference_thread(self):
+ """Background thread for multi-stage vLLM inference."""
+ logger.info("Starting multi-stage inference thread")
+
+ while self.running:
+ try:
+ batch = self.inference_queue.get(timeout=1)
+ if not batch:
+ continue
+
+ if self.should_stop_processing.is_set():
+ continue
+
+ logger.debug(
+ f"Processing batch of {len(batch)} images through {len(self.stages)} stages"
+ )
+ start_time = time.time()
+
+ # Process batch through all stages
+ results = self._process_batch_multi_stage(batch)
+
+ # Calculate processing time
+ if results:
+ processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
+
+ for item, result_outputs in results:
+ self.result_queue.put(
+ {
+ "item": item,
+ "outputs": result_outputs,
+ "processing_time_ms": processing_time_per_item,
+ }
+ )
+
+ logger.debug(f"Batch processing complete: {len(results)} successful")
+
+ except Empty:
+ continue
+ except Exception as e:
+ if self.should_stop_processing.is_set():
+ continue
+ logger.error(f"Inference error: {e}", exc_info=True)
+
  def _process_batch_multi_stage(
  self, batch: List[ProcessingItem], max_attempts: int = 3
- ) -> List[ProcessedResult]:
+ ) -> List[Tuple[ProcessingItem, Dict]]:
  """Process a batch through all stages sequentially."""
  results = []
 
@@ -983,7 +670,7 @@ class CaptionWorker(BaseWorker):
  stage = next(s for s in self.stages if s.name == stage_name)
  logger.debug(f"Processing batch through stage: {stage_name}")
 
- # Get model components for this stage
+ # Get model components
  llm, processor, tokenizer, sampling_params = self.model_manager.get_model_for_stage(
  stage_name, stage.model
  )
@@ -992,39 +679,34 @@ class CaptionWorker(BaseWorker):
  items_to_process = [(i, item, 0) for i, item in enumerate(batch)]
 
  while items_to_process:
- # Build requests for current items
  current_batch = []
- current_indices = []
  requests = []
 
  for idx, (original_idx, item, attempt_count) in enumerate(items_to_process):
  current_batch.append((original_idx, item, attempt_count))
- current_indices.append(idx)
 
  # Prepare image
  converted_img = ImageProcessor.prepare_for_inference(item.image)
 
- # Create template manager for this stage's prompts
+ # Create template manager
  template_manager = PromptTemplateManager(stage.prompts)
 
- # Build context including metadata and previous stage results
+ # Build context
  context = item.metadata.copy()
 
- # Add previous stage outputs to context
+ # Add previous stage results
  for prev_stage_name, stage_result in item.stage_results.items():
- # Add outputs with stage name prefix
  for i, output in enumerate(stage_result.outputs):
  context[f"{prev_stage_name}_output_{i}"] = output
- # Also add under output field name
  if len(stage_result.outputs) == 1:
  context[stage_result.output_field] = stage_result.outputs[0]
  else:
  context[stage_result.output_field] = stage_result.outputs
 
- # Format prompts with context
+ # Format prompts
  formatted_prompts = template_manager.format_all(context)
 
- # Build requests for all prompts
+ # Build requests
  for prompt in formatted_prompts:
  req = self._build_vllm_input(converted_img, prompt, processor, tokenizer)
  requests.append(req)
@@ -1037,11 +719,10 @@ class CaptionWorker(BaseWorker):
  failed_items = []
 
  for idx, (original_idx, item, attempt_count) in enumerate(current_batch):
- # Check if we should stop
  if self.should_stop_processing.is_set():
  return results
 
- # Extract outputs for this item
+ # Extract outputs
  base_idx = idx * len(stage.prompts)
  stage_outputs = []
 
@@ -1051,13 +732,9 @@ class CaptionWorker(BaseWorker):
  cleaned_output = self._clean_output(original_output)
  if cleaned_output:
  stage_outputs.append(cleaned_output)
- else:
- logger.warning(
- f"(stage {stage_name}, item {item.item_key}) output destroyed: {original_output}"
- )
 
  if stage_outputs:
- # Success - add stage result to item
+ # Success
  stage_result = StageResult(
  stage_name=stage_name,
  output_field=stage.output_field,
@@ -1066,102 +743,44 @@ class CaptionWorker(BaseWorker):
  item.stage_results[stage_name] = stage_result
  successful_items.append((original_idx, item))
  else:
- # Failed - check if we should retry
+ # Failed - check retry
  if attempt_count + 1 < max_attempts:
  failed_items.append((original_idx, item, attempt_count + 1))
- logger.warning(
- f"Stage {stage_name} failed for item {item.item_key} "
- f"(attempt {attempt_count + 1}/{max_attempts}), will retry"
- )
  else:
- logger.error(
- f"Stage {stage_name} failed for item {item.item_key} "
- f"after {max_attempts} attempts"
- )
+ logger.error(f"Stage {stage_name} failed for item {item.item_key}")
  self.items_failed += 1
+ stage_result = StageResult(
+ stage_name=stage_name,
+ output_field=stage.output_field,
+ outputs=[],
+ error=f"Failed after {max_attempts} attempts",
+ )
+ item.stage_results[stage_name] = stage_result
+ self.result_queue.put(
+ {
+ "item": item,
+ "outputs": {},
+ "processing_time_ms": 0.0,
+ "error": f"Failed stage {stage_name} after {max_attempts} attempts",
+ }
+ )
 
- # Update items to process for next iteration
+ # Update for next iteration
  items_to_process = failed_items
-
- # Update batch with successful items for next stage
  batch = [item for _, item in successful_items]
 
- # Log retry status if we have items to retry
- if items_to_process:
- logger.info(
- f"Retrying {len(items_to_process)} failed items for stage {stage_name}"
- )
-
- # Convert batch items to results
+ # Convert to results
  for item in batch:
- # Aggregate outputs by field name
+ # Aggregate outputs by field
  outputs_by_field = defaultdict(list)
-
  for stage_result in item.stage_results.values():
  outputs_by_field[stage_result.output_field].extend(stage_result.outputs)
 
- result = ProcessedResult(
- chunk_id=item.chunk_id,
- shard_name=Path(item.chunk_id).stem.rsplit("_chunk_", 1)[0],
- item_key=item.item_key,
- outputs=dict(outputs_by_field), # Convert defaultdict to dict
- image_width=item.image.width,
- image_height=item.image.height,
- image_format=item.image.format or "unknown",
- file_size=len(item.image_data),
- processing_time_ms=0, # Will be calculated by caller
- metadata=item.metadata,
- )
- results.append(result)
+ results.append((item, dict(outputs_by_field)))
  self.items_processed += 1
 
  return results
 
- def _inference_thread(self):
- """Background thread for multi-stage vLLM inference."""
- logger.info("Starting multi-stage inference thread")
-
- while self.running:
- try:
- # Get batch from queue with timeout
- batch = self.inference_queue.get(timeout=1)
-
- if not batch:
- continue
-
- # Skip if disconnected
- if self.should_stop_processing.is_set():
- continue
-
- logger.debug(
- f"Processing batch of {len(batch)} images through {len(self.stages)} stages"
- )
- start_time = time.time()
-
- # Process batch through all stages
- results = self._process_batch_multi_stage(batch)
-
- # Calculate processing time per item
- if results:
- processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
-
- # Update processing time and queue results
- for result in results:
- result.processing_time_ms = processing_time_per_item
- self.result_queue.put(result)
-
- logger.debug(
- f"Multi-stage batch processing complete: {len(results)} successful, "
- f"{len(batch) - len(results)} failed"
- )
-
- except Empty:
- continue
- except Exception as e:
- if self.should_stop_processing.is_set():
- continue
- logger.error(f"Inference error: {e}", exc_info=True)
-
  def _build_vllm_input(self, image: Image.Image, prompt: str, processor, tokenizer) -> Dict:
  """Build vLLM input."""
  try:
@@ -1198,124 +817,129 @@ class CaptionWorker(BaseWorker):
1198
817
  if not text:
1199
818
  return ""
1200
819
 
1201
- # Remove common artifacts
1202
820
  for token in ["<|end|>", "<|endoftext|>", "<|im_end|>", "I'm sorry", "I cannot"]:
1203
821
  if token in text:
1204
822
  text = text.split(token)[0]
1205
823
 
1206
824
  return text.strip()
1207
825
 
826
+ def _get_heartbeat_data(self) -> Dict[str, Any]:
827
+ """Get heartbeat data."""
828
+ return {
829
+ "type": "heartbeat",
830
+ "processed": self.items_processed,
831
+ "failed": self.items_failed,
832
+ "units_completed": self.units_completed,
833
+ "current_unit": self._get_current_unit_id() if self.current_unit else None,
834
+ "queue_sizes": {
835
+ "readahead": self.readahead_queue.qsize(),
836
+ "inference": self.inference_queue.qsize(),
837
+ "results": self.result_queue.qsize(),
838
+ },
839
+ "stages": len(self.stages),
840
+ "models_loaded": len(self.model_manager.models) if self.model_manager else 0,
841
+ }
842
+
843
+ async def _create_tasks(self) -> list:
844
+ """Create async tasks to run."""
845
+ return [
846
+ asyncio.create_task(self._heartbeat_loop()),
847
+ asyncio.create_task(self._base_message_handler()),
848
+ asyncio.create_task(self._result_sender()),
849
+ ]
850
+
1208
851
  async def _result_sender(self):
1209
- """Send results back to orchestrator with multi-stage outputs."""
1210
- pending_results = [] # Buffer for results during disconnection
852
+ """Send results back to orchestrator."""
853
+ while self.running and self.connected.is_set():
854
+ try:
855
+ # Get result
856
+ result_data = await asyncio.get_event_loop().run_in_executor(
857
+ None, self.result_queue.get, True, 1
858
+ )
1211
859
 
1212
- try:
1213
- while self.running and self.connected.is_set():
1214
- try:
1215
- # Get result (with timeout to allow checking self.running)
1216
- try:
1217
- result = await asyncio.get_event_loop().run_in_executor(
1218
- None, self.result_queue.get, True, 1
860
+ if self.websocket and self.connected.is_set():
861
+ item = result_data["item"]
862
+ logger.debug(f"Handling results for item: {item}")
863
+ outputs = result_data["outputs"]
864
+
865
+ # Create work result
866
+ # logger.info(f"Processed item: {item}")
867
+ work_result = WorkResult(
868
+ unit_id=item.unit_id,
869
+ source_id=item.metadata.get("shard_name", "unknown"),
870
+ chunk_id=item.chunk_id,
871
+ sample_id=f"{item.item_key}",
872
+ outputs=outputs,
873
+ metadata={
874
+ "item_key": item.item_key,
875
+ "item_index": item.metadata.get("_item_index"),
876
+ "image_width": item.image.width,
877
+ "image_height": item.image.height,
878
+ "image_format": item.image.format or "unknown",
879
+ "file_size": len(item.image_data) if item.image_data else 0,
880
+ **item.metadata,
881
+ },
882
+ processing_time_ms=result_data["processing_time_ms"],
883
+ error=result_data.get("error", None),
884
+ )
885
+
886
+ # Send result in format that orchestrator expects
887
+ await self.websocket.send(
888
+ json.dumps(
889
+ {
890
+ "type": "submit_results",
891
+ "unit_id": work_result.unit_id,
892
+ "job_id": item.job_id,
893
+ "dataset": self.dataset_path,
894
+ "sample_id": work_result.sample_id,
895
+ "source_id": work_result.source_id,
896
+ "outputs": work_result.outputs,
897
+ "metadata": work_result.metadata,
898
+ "processing_time_ms": work_result.processing_time_ms,
899
+ }
1219
900
  )
1220
- pending_results.append(result)
1221
- except Empty:
1222
- pass
1223
-
1224
- # Only try to send if connected
1225
- if pending_results and self.websocket and self.connected.is_set():
1226
- sent_results = []
1227
- for result in pending_results:
1228
- try:
1229
- # For backward compatibility, if there's only one output field "captions"
1230
- # send it in the old format
1231
- if len(result.outputs) == 1 and "captions" in result.outputs:
1232
- # Old format for single-stage compatibility
1233
- await self.websocket.send(
1234
- json.dumps(
1235
- {
1236
- "type": "submit_captions",
1237
- "chunk_id": result.chunk_id,
1238
- "dataset": self.dataset_config.get(
1239
- "dataset_path", "unknown"
1240
- ),
1241
- "shard": result.shard_name,
1242
- "item_key": result.item_key,
1243
- "item_index": result.metadata.get("_item_index"),
1244
- "captions": result.outputs["captions"],
1245
- "caption_count": len(result.outputs["captions"]),
1246
- "image_width": result.image_width,
1247
- "image_height": result.image_height,
1248
- "image_format": result.image_format,
1249
- "file_size": result.file_size,
1250
- "processing_time_ms": result.processing_time_ms,
1251
- }
1252
- )
1253
- )
1254
- else:
1255
- # New format for multi-stage outputs
1256
- await self.websocket.send(
1257
- json.dumps(
1258
- {
1259
- "type": "submit_captions",
1260
- "chunk_id": result.chunk_id,
1261
- "dataset": self.dataset_config.get(
1262
- "dataset_path", "unknown"
1263
- ),
1264
- "shard": result.shard_name,
1265
- "item_key": result.item_key,
1266
- "outputs": result.outputs, # Dict of field -> list of outputs
1267
- "captions": result.outputs.get(
1268
- "captions", []
1269
- ), # For compatibility
1270
- "caption_count": sum(
1271
- len(v) for v in result.outputs.values()
1272
- ),
1273
- "image_width": result.image_width,
1274
- "image_height": result.image_height,
1275
- "image_format": result.image_format,
1276
- "file_size": result.file_size,
1277
- "processing_time_ms": result.processing_time_ms,
1278
- "metadata": result.metadata,
1279
- }
1280
- )
1281
- )
1282
-
1283
- sent_results.append(result)
1284
-
1285
- if self.items_processed % 100 == 0:
1286
- total_outputs = sum(
1287
- len(outputs) for outputs in result.outputs.values()
1288
- )
1289
- logger.info(
1290
- f"Processed {self.items_processed} items "
1291
- f"(~{total_outputs} outputs across {len(result.outputs)} fields)"
1292
- )
1293
- except websockets.exceptions.ConnectionClosed as e:
1294
- logger.warning(f"Connection lost while sending result: {e}")
1295
- raise # Re-raise to trigger task completion
1296
- except Exception as e:
1297
- logger.error(f"Error sending result: {e}")
1298
- break
1299
-
1300
- # Remove successfully sent results
1301
- for result in sent_results:
1302
- pending_results.remove(result)
1303
-
1304
- # Clear pending results if disconnected and buffer is too large
1305
- if not self.connected.is_set() and len(pending_results) > 1000:
1306
- logger.warning(
1307
- f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
901
+ )
902
+
903
+ if self.items_processed % 100 == 0:
904
+ total_outputs = sum(len(v) for v in outputs.values())
905
+ logger.info(
906
+ f"Processed {self.items_processed} items (~{total_outputs} outputs)"
1308
907
  )
1309
- pending_results.clear()
1310
908
 
1311
- await asyncio.sleep(0.1)
909
+ except Empty:
910
+ continue
911
+ except Exception as e:
912
+ logger.error(f"Error sending result: {e}", exc_info=True)
913
+ await asyncio.sleep(1)
914
+
915
+ async def _on_disconnect(self):
916
+ """Handle disconnection."""
917
+ self.should_stop_processing.set()
918
+
919
+ with self.work_lock:
920
+ self.assigned_units.clear()
921
+ self.current_unit = None
922
+
923
+ # Clear queues
924
+ self._clear_queue(self.readahead_queue)
925
+ self._clear_queue(self.inference_queue)
926
+ self._clear_queue(self.result_queue)
927
+
928
+ def _clear_queue(self, queue: Queue):
929
+ """Clear all items from a queue."""
930
+ try:
931
+ while True:
932
+ queue.get_nowait()
933
+ except Empty:
934
+ pass
935
+
936
+ async def _pre_shutdown(self):
937
+ """Cleanup before shutdown."""
938
+ self.readahead_queue.put(None)
939
+ self.inference_queue.put(None)
1312
940
 
1313
- except Exception as e:
1314
- if isinstance(e, websockets.exceptions.ConnectionClosed):
1315
- raise # Re-raise connection errors
1316
- logger.error(f"Unexpected error in result sender: {e}")
1317
- await asyncio.sleep(1)
941
+ if self.image_processor:
942
+ self.image_processor.shutdown()
1318
943
 
1319
- except asyncio.CancelledError:
1320
- logger.debug("Result sender cancelled")
1321
- raise
944
+ if self.model_manager:
945
+ self.model_manager.cleanup()
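Illustrative note (not part of either released wheel): in 0.2.3 the worker no longer builds a dataset loader from a dataset_config; it dispatches on the orchestrator-supplied processor_type, as _handle_welcome does in the diff above. Below is a minimal sketch of that dispatch, assuming the worker processor classes are importable from caption_flow.processors and that ProcessorConfig takes the two fields shown in the diff; the registry dict and make_processor helper are illustrative, not part of the package API.

# Hedged sketch of the 0.2.3 processor selection shown in _handle_welcome above.
from caption_flow.processors import (
    ProcessorConfig,
    WebDatasetWorkerProcessor,
    HuggingFaceDatasetWorkerProcessor,
    LocalFilesystemWorkerProcessor,
)

# Maps the orchestrator's processor_type string to a worker processor class.
WORKER_PROCESSORS = {
    "webdataset": WebDatasetWorkerProcessor,
    "huggingface_datasets": HuggingFaceDatasetWorkerProcessor,
    "local_filesystem": LocalFilesystemWorkerProcessor,
}

def make_processor(welcome_data: dict):
    """Pick and initialize a worker processor from the orchestrator's welcome message."""
    processor_type = welcome_data["processor_type"]
    if processor_type not in WORKER_PROCESSORS:
        raise ValueError(f"Unknown processor type: {processor_type}")
    config = ProcessorConfig(
        processor_type=processor_type,
        config=welcome_data.get("processor_config", {}),
    )
    processor = WORKER_PROCESSORS[processor_type]()
    processor.initialize(config)
    return processor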