caption-flow 0.1.0-py3-none-any.whl → 0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1321 @@
1
+ """Caption worker for vLLM-based distributed image captioning with multi-stage processing."""
2
+
3
+ import os
4
+
5
+ os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
6
+
7
+ import asyncio
8
+ import io
9
+ import json
10
+ import logging
11
+ import websockets
12
+ import time
13
+ from dataclasses import dataclass, field
14
+ from pathlib import Path
15
+ from typing import Dict, Any, Optional, List, Tuple
16
+ from queue import Queue, Empty, Full
17
+ from threading import Thread, Lock, Event
18
+ from collections import deque, defaultdict
19
+
20
+ from PIL import Image
21
+ import numpy as np
22
+ from huggingface_hub import get_token
23
+
24
+ from .base import BaseWorker
25
+ from ..models import JobStatus, Job
26
+ from ..utils import CaptionUtils
27
+ from ..utils.dataset_loader import DatasetLoader
28
+ from ..utils.vllm_config import VLLMConfigManager
29
+ from ..utils.image_processor import ImageProcessor
30
+ from ..utils.shard_processor import HFDatasetShardProcessor, WebDatasetShardProcessor
31
+ from ..utils.prompt_template import PromptTemplateManager
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class ProcessingStage:
38
+ """Configuration for a single processing stage."""
39
+
40
+ name: str
41
+ model: str
42
+ prompts: List[str]
43
+ output_field: str
44
+ requires: List[str] = field(default_factory=list)
45
+ sampling: Optional[Dict[str, Any]] = None
46
+
47
+ # Model-specific overrides
48
+ tensor_parallel_size: Optional[int] = None
49
+ max_model_len: Optional[int] = None
50
+ dtype: Optional[str] = None
51
+ gpu_memory_utilization: Optional[float] = None
52
+
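For orientation, here is a minimal sketch of the `stages` configuration that `_parse_stages_config` further down consumes, written as the Python dict the orchestrator would send. The model names, prompt text, and the `{captions}` placeholder are illustrative assumptions; the actual placeholder syntax depends on PromptTemplateManager.

    # Hypothetical two-stage configuration; keys mirror the ProcessingStage fields above.
    example_vllm_config = {
        "stages": [
            {
                "name": "caption",
                "model": "Qwen/Qwen2.5-VL-3B-Instruct",
                "prompts": ["describe this image"],
                "output_field": "captions",
            },
            {
                "name": "tags",
                "model": "Qwen/Qwen2.5-VL-3B-Instruct",
                "prompts": ["List key objects. Caption: {captions}"],  # assumes {field} templating
                "output_field": "tags",
                "requires": ["caption"],
                "sampling": {"temperature": 0.3, "max_tokens": 128},
            },
        ]
    }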
53
+
54
+ @dataclass
55
+ class StageResult:
56
+ """Results from a single stage."""
57
+
58
+ stage_name: str
59
+ output_field: str
60
+ outputs: List[str] # Multiple outputs from multiple prompts
61
+
62
+
63
+ @dataclass
64
+ class ShardChunk:
65
+ """Shard chunk assignment with unprocessed ranges."""
66
+
67
+ chunk_id: str
68
+ shard_url: str
69
+ shard_name: str
70
+ start_index: int
71
+ chunk_size: int
72
+ unprocessed_ranges: List[Tuple[int, int]] = field(default_factory=list)
73
+
74
+
75
+ @dataclass
76
+ class ProcessingItem:
77
+ """Item being processed."""
78
+
79
+ chunk_id: str
80
+ item_key: str
81
+ image: Image.Image
82
+ image_data: bytes
83
+ metadata: Dict[str, Any] = field(default_factory=dict)
84
+ stage_results: Dict[str, StageResult] = field(default_factory=dict) # Accumulated results
85
+
86
+
87
+ @dataclass
88
+ class ProcessedResult:
89
+ """Result with multi-stage outputs."""
90
+
91
+ chunk_id: str
92
+ shard_name: str
93
+ item_key: str
94
+ outputs: Dict[str, List[str]] # field_name -> list of outputs
95
+ image_width: int
96
+ image_height: int
97
+ image_format: str
98
+ file_size: int
99
+ processing_time_ms: float
100
+ metadata: Dict[str, Any] = field(default_factory=dict)
101
+
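The `outputs` field groups stage outputs by `output_field`; for the hypothetical two-stage configuration sketched above it would look roughly like this (values are illustrative):

    example_outputs = {
        "captions": ["a dog running on a beach"],  # from the "caption" stage
        "tags": ["dog, beach, ocean"],             # from the "tags" stage
    }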
102
+
103
+ class MultiStageVLLMManager:
104
+ """Manages multiple vLLM instances for different models."""
105
+
106
+ def __init__(self, gpu_id: int = 0):
107
+ self.gpu_id = gpu_id
108
+ self.models: Dict[str, Any] = {} # model_name -> LLM instance
109
+ self.processors: Dict[str, Any] = {} # model_name -> processor
110
+ self.tokenizers: Dict[str, Any] = {} # model_name -> tokenizer
111
+ self.sampling_params: Dict[str, Any] = {} # stage_name -> SamplingParams
112
+
113
+ def load_model(self, model_name: str, stage: ProcessingStage, base_config: Dict[str, Any]):
114
+ """Load a model if not already loaded."""
115
+ if model_name in self.models:
116
+ logger.info(f"Model {model_name} already loaded, reusing instance")
117
+ return
118
+
119
+ from vllm import LLM, SamplingParams
120
+ from transformers import AutoTokenizer, AutoProcessor
121
+
122
+ logger.info(f"Loading model {model_name} for stage {stage.name}")
123
+
124
+ # Build model-specific config by merging base config with stage overrides
125
+ model_config = base_config.copy()
126
+ model_config["model"] = model_name
127
+
128
+ # Apply stage-specific overrides
129
+ if stage.tensor_parallel_size is not None:
130
+ model_config["tensor_parallel_size"] = stage.tensor_parallel_size
131
+ if stage.max_model_len is not None:
132
+ model_config["max_model_len"] = stage.max_model_len
133
+ if stage.dtype is not None:
134
+ model_config["dtype"] = stage.dtype
135
+ if stage.gpu_memory_utilization is not None:
136
+ model_config["gpu_memory_utilization"] = stage.gpu_memory_utilization
137
+
138
+ # Load tokenizer and processor
139
+ self.tokenizers[model_name] = AutoTokenizer.from_pretrained(
140
+ model_name, trust_remote_code=True, use_fast=True
141
+ )
142
+ self.processors[model_name] = AutoProcessor.from_pretrained(model_name)
143
+
144
+ # Initialize LLM
145
+ vllm_params = {
146
+ "model": model_name,
147
+ "trust_remote_code": True,
148
+ "tensor_parallel_size": model_config.get("tensor_parallel_size", 1),
149
+ "max_model_len": model_config.get("max_model_len", 16384),
150
+ "enforce_eager": model_config.get("enforce_eager", True),
151
+ "gpu_memory_utilization": model_config.get("gpu_memory_utilization", 0.92),
152
+ "dtype": model_config.get("dtype", "float16"),
153
+ "limit_mm_per_prompt": model_config.get("limit_mm_per_prompt", {"image": 1}),
154
+ "disable_mm_preprocessor_cache": model_config.get(
155
+ "disable_mm_preprocessor_cache", True
156
+ ),
157
+ }
158
+
159
+ self.models[model_name] = LLM(**vllm_params)
160
+ logger.info(f"Model {model_name} loaded successfully")
161
+
162
+ def create_sampling_params(self, stage: ProcessingStage, base_sampling: Dict[str, Any]):
163
+ """Create sampling params for a stage."""
164
+ from vllm import SamplingParams
165
+
166
+ # Start with base sampling config
167
+ sampling_config = base_sampling.copy()
168
+
169
+ # Override with stage-specific sampling if provided
170
+ if stage.sampling:
171
+ sampling_config.update(stage.sampling)
172
+
173
+ params = SamplingParams(
174
+ temperature=sampling_config.get("temperature", 0.7),
175
+ top_p=sampling_config.get("top_p", 0.95),
176
+ max_tokens=sampling_config.get("max_tokens", 256),
177
+ stop=sampling_config.get("stop", ["<|end|>", "<|endoftext|>", "<|im_end|>"]),
178
+ repetition_penalty=sampling_config.get("repetition_penalty", 1.05),
179
+ skip_special_tokens=sampling_config.get("skip_special_tokens", True),
180
+ )
181
+
182
+ self.sampling_params[stage.name] = params
183
+ return params
184
+
185
+ def get_model_for_stage(self, stage_name: str, model_name: str) -> Tuple[Any, Any, Any, Any]:
186
+ """
187
+ Get model components for a stage.
188
+
189
+ Returns:
190
+ tuple: A tuple containing:
191
+ - llm: The language model instance for the given model name.
192
+ - processor: The processor associated with the model.
193
+ - tokenizer: The tokenizer for the model.
194
+ - sampling_params: The sampling parameters for the given stage.
195
+ """
196
+ return (
197
+ self.models[model_name],
198
+ self.processors[model_name],
199
+ self.tokenizers[model_name],
200
+ self.sampling_params[stage_name],
201
+ )
202
+
203
+ def cleanup(self):
204
+ """Clean up all loaded models."""
205
+ for model_name in list(self.models.keys()):
206
+ del self.models[model_name]
207
+ del self.processors[model_name]
208
+ del self.tokenizers[model_name]
209
+ self.sampling_params.clear()
210
+
211
+ import gc
212
+
213
+ gc.collect()
214
+
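A rough usage sketch of the manager, assuming vLLM is installed, a GPU is available, and the (illustrative) model can be downloaded; `base_config` and `base_sampling` mirror what `_setup_vllm` builds later in the file.

    stages = [
        ProcessingStage(
            name="caption",
            model="Qwen/Qwen2.5-VL-3B-Instruct",
            prompts=["describe this image"],
            output_field="captions",
        )
    ]
    manager = MultiStageVLLMManager(gpu_id=0)
    base_config = {"tensor_parallel_size": 1, "max_model_len": 16384, "dtype": "float16"}
    base_sampling = {"temperature": 0.7, "max_tokens": 256}
    for stage in stages:
        manager.load_model(stage.model, stage, base_config)
        manager.create_sampling_params(stage, base_sampling)
    llm, processor, tokenizer, params = manager.get_model_for_stage("caption", stages[0].model)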
215
+
216
+ class CaptionWorker(BaseWorker):
217
+ """Worker that processes shard chunks for image captioning using multi-stage vLLM."""
218
+
219
+ def __init__(self, config: Dict[str, Any]):
220
+ super().__init__(config)
221
+
222
+ batch_image_processing = config.get("batch_image_processing", False)
223
+
224
+ # Dataset configuration will be received from orchestrator
225
+ self.dataset_config = None
226
+ self.dataset_loader = None
227
+ self.dataset_type = None
228
+ self.dataset_split = None
229
+ self.dataset_image_column = None
230
+ self.hf_token = get_token()
231
+
232
+ # vLLM configuration will be received from orchestrator
233
+ self.vllm_config = None
234
+ self.stages: List[ProcessingStage] = []
235
+ self.stage_order: List[str] = [] # Topologically sorted stage names
236
+ self.vllm_config_manager = VLLMConfigManager()
237
+ self.model_manager = None
238
+
239
+ # Backward compatibility: local config for GPU selection
240
+ self.gpu_id = config.get("gpu_id", 0)
241
+
242
+ # Connection state events
243
+ self.should_stop_processing = Event()
244
+
245
+ # Image processor
246
+ self.image_processor = None
247
+ if batch_image_processing:
248
+ self.image_processor = ImageProcessor()
249
+
250
+ # Shard chunk processing
251
+ self.hf_processor = HFDatasetShardProcessor()
252
+ self.webdataset_processor = WebDatasetShardProcessor(
253
+ hf_token=self.hf_token, dataset_type=self.dataset_type
254
+ )
255
+ self.chunk_lock = Lock()
256
+ self.assigned_chunks = deque()
257
+ self.current_chunk = None
258
+ self.current_chunk_progress = 0
259
+
260
+ # Batching queues - will be cleared on disconnect
261
+ self.readahead_queue = Queue(maxsize=256)
262
+ self.inference_queue = Queue(maxsize=128)
263
+ self.result_queue = Queue()
264
+
265
+ # Job mode for shards vs jobs and job queue
266
+ self.job_mode = config.get("job_mode", False)
267
+ self.job_queue = Queue(maxsize=32)
268
+
269
+ def _init_metrics(self):
270
+ """Initialize worker metrics."""
271
+ self.items_processed = 0
272
+ self.items_failed = 0
273
+ self.chunks_completed = 0
274
+
275
+ def _get_auth_data(self) -> Dict[str, Any]:
276
+ """Get authentication data."""
277
+ return {"token": self.token, "name": self.name}
278
+
279
+ async def _pre_start(self):
280
+ """Initialize before starting connection loop."""
281
+ # Wait for initial connection to get vLLM config
282
+ logger.info("Connecting to orchestrator for configuration...")
283
+
284
+ # Try initial connection to get config
285
+ config_received = False
286
+ while not config_received and self.running:
287
+ try:
288
+ await self._initial_connect_for_config()
289
+ config_received = True
290
+ except Exception as e:
291
+ logger.error(f"Failed to get config: {e}")
292
+ await asyncio.sleep(5)
293
+
294
+ # Initialize vLLM once we have config
295
+ self._setup_vllm()
296
+
297
+ # Start background threads
298
+ reader_thread = Thread(target=self._shard_reader_thread, daemon=True)
299
+ reader_thread.start()
300
+
301
+ inference_thread = Thread(target=self._inference_thread, daemon=True)
302
+ inference_thread.start()
303
+
304
+ def _parse_stages_config(self, vllm_config: Dict[str, Any]) -> List[ProcessingStage]:
305
+ """Parse stages configuration from vLLM config."""
306
+ stages_config = vllm_config.get("stages", [])
307
+
308
+ if not stages_config:
309
+ # Backward compatibility: create single stage from old config
310
+ return [
311
+ ProcessingStage(
312
+ name="default",
313
+ model=vllm_config.get("model", "Qwen/Qwen2.5-VL-3B-Instruct"),
314
+ prompts=vllm_config.get("inference_prompts", ["describe this image"]),
315
+ output_field="captions",
316
+ requires=[],
317
+ )
318
+ ]
319
+
320
+ # Parse stages
321
+ stages = []
322
+ for stage_cfg in stages_config:
323
+ stage = ProcessingStage(
324
+ name=stage_cfg["name"],
325
+ model=stage_cfg.get("model", vllm_config.get("model")),
326
+ prompts=stage_cfg.get("prompts", []),
327
+ output_field=stage_cfg.get("output_field", "captions"),
328
+ requires=stage_cfg.get("requires", []),
329
+ sampling=stage_cfg.get("sampling"),
330
+ tensor_parallel_size=stage_cfg.get("tensor_parallel_size"),
331
+ max_model_len=stage_cfg.get("max_model_len"),
332
+ dtype=stage_cfg.get("dtype"),
333
+ gpu_memory_utilization=stage_cfg.get("gpu_memory_utilization"),
334
+ )
335
+ stages.append(stage)
336
+
337
+ return stages
338
+
339
+ def _topological_sort_stages(self, stages: List[ProcessingStage]) -> List[str]:
340
+ """Sort stages by dependencies."""
341
+ # Build dependency graph
342
+ graph = defaultdict(list)
343
+ in_degree = defaultdict(int)
344
+
345
+ stage_map = {s.name: s for s in stages}
346
+
347
+ for stage in stages:
348
+ in_degree[stage.name] = len(stage.requires)
349
+ for dep in stage.requires:
350
+ if dep not in stage_map:
351
+ raise ValueError(f"Stage '{stage.name}' requires missing dependency '{dep}'")
352
+ graph[dep].append(stage.name)
353
+
354
+ # Topological sort using Kahn's algorithm
355
+ queue = deque([name for name, degree in in_degree.items() if degree == 0])
356
+ result = []
357
+
358
+ while queue:
359
+ current = queue.popleft()
360
+ result.append(current)
361
+
362
+ for neighbor in graph[current]:
363
+ in_degree[neighbor] -= 1
364
+ if in_degree[neighbor] == 0:
365
+ queue.append(neighbor)
366
+
367
+ if len(result) != len(stages):
368
+ raise ValueError("Circular dependency detected in stages")
369
+
370
+ return result
371
+
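As a worked example of the ordering (stage names are hypothetical): a `refine` stage that requires `caption` starts with in-degree 1, so Kahn's algorithm emits `caption` first; if the two stages required each other, no node would ever reach in-degree 0 and the length check would raise the circular-dependency error.

    # Hypothetical dependency chain: "refine" consumes the caption output.
    chain = [
        ProcessingStage(name="refine", model="m", prompts=["p"], output_field="captions",
                        requires=["caption"]),
        ProcessingStage(name="caption", model="m", prompts=["p"], output_field="captions"),
    ]
    # _topological_sort_stages(chain) -> ["caption", "refine"]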
372
+ async def _handle_welcome(self, welcome_data: Dict[str, Any]):
373
+ """Handle welcome message from orchestrator."""
374
+ # Extract and setup dataset configuration
375
+ dataset_config = welcome_data.get("dataset_config", {})
376
+ if dataset_config:
377
+ self._setup_dataset_loader(dataset_config)
378
+ logger.info(f"Received dataset config: {dataset_config}")
379
+ else:
380
+ logger.warning("No dataset configuration received from orchestrator")
381
+
382
+ # Update vLLM config if provided (in case it changed)
383
+ new_vllm_config = welcome_data.get("vllm_config")
384
+ if new_vllm_config and new_vllm_config != self.vllm_config:
385
+ logger.info("Received updated vLLM configuration")
386
+ if not self._handle_vllm_config_update(new_vllm_config):
387
+ logger.error("Failed to update vLLM configuration")
388
+
389
+ # Clear stop signal now that we're connected
390
+ self.should_stop_processing.clear()
391
+
392
+ # Request initial chunks if not in job mode
393
+ if not self.job_mode and self.websocket:
394
+ await self.websocket.send(json.dumps({"type": "request_chunks", "count": 2}))
395
+
396
+ async def _handle_message(self, data: Dict[str, Any]):
397
+ """Handle message from orchestrator."""
398
+ msg_type = data.get("type")
399
+
400
+ if msg_type == "shard_assignment":
401
+ chunks = data["chunks"]
402
+ for chunk_data in chunks:
403
+ chunk = ShardChunk(**chunk_data)
404
+ with self.chunk_lock:
405
+ self.assigned_chunks.append(chunk)
406
+ logger.info(f"Received chunk assignment: {chunk.chunk_id}")
407
+
408
+ elif msg_type == "no_chunks":
409
+ reason = data.get("reason", "unknown")
410
+ logger.info(f"No chunks available from orchestrator (reason: {reason})")
411
+
412
+ wait_time = 2 if reason == "state_restoring" else 10
413
+ await asyncio.sleep(wait_time)
414
+
415
+ if self.websocket and self.connected.is_set():
416
+ await self.websocket.send(json.dumps({"type": "request_chunks", "count": 2}))
417
+
418
+ elif msg_type == "reload_vllm":
419
+ logger.info("Orchestrator requested vLLM reload")
420
+ new_config = data.get("vllm_config")
421
+ if new_config:
422
+ self._handle_vllm_config_update(new_config)
423
+
424
+ elif msg_type == "config_update":
425
+ # Soft config update without reload
426
+ if data.get("vllm_config"):
427
+ self._handle_vllm_config_update(data["vllm_config"])
428
+
429
+ elif msg_type == "job_assignment":
430
+ await self._handle_job_assignment(data["job"])
431
+
432
+ elif msg_type == "no_jobs":
433
+ logger.debug("No jobs available")
434
+ await asyncio.sleep(2)
435
+
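For reference, the shape of a `shard_assignment` message implied by the handler above and the ShardChunk dataclass (all values are made up):

    example_shard_assignment = {
        "type": "shard_assignment",
        "chunks": [
            {
                "chunk_id": "data-00000_chunk_0",
                "shard_url": "hf_dataset:org/dataset",
                "shard_name": "data-00000",
                "start_index": 0,
                "chunk_size": 1000,
                "unprocessed_ranges": [[0, 999]],  # item-index ranges still to process
            }
        ],
    }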
436
+ def _get_heartbeat_data(self) -> Dict[str, Any]:
437
+ """Get heartbeat data."""
438
+ return {
439
+ "type": "heartbeat",
440
+ "processed": self.items_processed,
441
+ "failed": self.items_failed,
442
+ "chunks_completed": self.chunks_completed,
443
+ "current_chunk": self.current_chunk.chunk_id if self.current_chunk else None,
444
+ "chunk_progress": self.current_chunk_progress,
445
+ "queue_sizes": {
446
+ "readahead": self.readahead_queue.qsize(),
447
+ "inference": self.inference_queue.qsize(),
448
+ "results": self.result_queue.qsize(),
449
+ },
450
+ "stages": len(self.stages),
451
+ "models_loaded": len(self.model_manager.models) if self.model_manager else 0,
452
+ }
453
+
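An illustrative heartbeat payload as the orchestrator would receive it (all numbers are made up):

    example_heartbeat = {
        "type": "heartbeat",
        "processed": 1200,
        "failed": 3,
        "chunks_completed": 4,
        "current_chunk": "data-00000_chunk_5",
        "chunk_progress": 0,
        "queue_sizes": {"readahead": 32, "inference": 2, "results": 0},
        "stages": 2,
        "models_loaded": 1,
    }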
454
+ async def _create_tasks(self) -> list:
455
+ """Create async tasks to run."""
456
+ tasks = [
457
+ asyncio.create_task(self._heartbeat_loop()),
458
+ asyncio.create_task(self._base_message_handler()),
459
+ asyncio.create_task(self._result_sender()),
460
+ ]
461
+
462
+ if self.job_mode:
463
+ tasks.append(asyncio.create_task(self._job_request_loop()))
464
+
465
+ return tasks
466
+
467
+ async def _on_disconnect(self):
468
+ """Handle disconnection."""
469
+ self._clear_state_on_disconnect()
470
+
471
+ async def _pre_shutdown(self):
472
+ """Cleanup before shutdown."""
473
+ # Stop processing threads by adding stop signals
474
+ self.readahead_queue.put(None)
475
+ self.inference_queue.put(None)
476
+
477
+ # Shutdown image processor
478
+ if self.image_processor is not None:
479
+ self.image_processor.shutdown()
480
+
481
+ # Cleanup model manager
482
+ if self.model_manager:
483
+ self.model_manager.cleanup()
484
+
485
+ async def _initial_connect_for_config(self):
486
+ """Connect initially just to get configuration."""
487
+ logger.info(f"Connecting to {self.server_url}")
488
+ async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
489
+ await websocket.send(json.dumps(self._get_auth_data()))
490
+
491
+ welcome = await websocket.recv()
492
+ welcome_data = json.loads(welcome)
493
+
494
+ if "error" in welcome_data:
495
+ raise RuntimeError(f"Authentication failed: {welcome_data['error']}")
496
+
497
+ self.vllm_config = welcome_data.get("vllm_config")
498
+ if not self.vllm_config:
499
+ raise RuntimeError("No vLLM configuration received from orchestrator")
500
+
501
+ # Parse stages configuration
502
+ self.stages = self._parse_stages_config(self.vllm_config)
503
+ self.stage_order = self._topological_sort_stages(self.stages)
504
+
505
+ logger.info(f"Configured {len(self.stages)} processing stages: {self.stage_order}")
506
+
507
+ self.vllm_config_manager.current_config = self.vllm_config
508
+
509
+ dataset_config = welcome_data.get("dataset_config", {})
510
+ if dataset_config:
511
+ self._setup_dataset_loader(dataset_config)
512
+
513
+ logger.info("Received configuration from orchestrator")
514
+
515
+ def _clear_state_on_disconnect(self):
516
+ """Clear all processing state when disconnected."""
517
+ logger.info("Clearing state due to disconnection")
518
+
519
+ self.should_stop_processing.set()
520
+
521
+ with self.chunk_lock:
522
+ self.assigned_chunks.clear()
523
+ self.current_chunk = None
524
+ self.current_chunk_progress = 0
525
+
526
+ self._clear_queue(self.readahead_queue)
527
+ self._clear_queue(self.inference_queue)
528
+ self._clear_queue(self.result_queue)
529
+
530
+ logger.info("State cleared, ready for reconnection")
531
+
532
+ def _clear_queue(self, queue: Queue):
533
+ """Clear all items from a queue."""
534
+ try:
535
+ while True:
536
+ queue.get_nowait()
537
+ except Empty:
538
+ pass
539
+
540
+ def _setup_dataset_loader(self, dataset_config: Dict[str, Any]):
541
+ """Initialize dataset loader with config from orchestrator."""
542
+ dataset_path = dataset_config.get("dataset_path") or dataset_config.get("path")
543
+ dataset_type = dataset_config.get("dataset_type") or dataset_config.get(
544
+ "type", "huggingface"
545
+ )
546
+ dataset_split = dataset_config.get("dataset_split") or dataset_config.get("split", "train")
547
+ dataset_image_column = dataset_config.get("dataset_image_column") or dataset_config.get(
548
+ "image_column", "image"
549
+ )
550
+
551
+ if dataset_path:
552
+ logger.info(
553
+ f"Initializing dataset loader for {dataset_type}: {dataset_path} "
554
+ f"(split: {dataset_split}, image_column: {dataset_image_column})"
555
+ )
556
+ self.dataset_loader = DatasetLoader(
557
+ dataset_path, dataset_type, dataset_split, dataset_image_column
558
+ )
559
+ self.dataset_config = dataset_config
560
+ self.dataset_type = dataset_type
561
+ self.dataset_split = dataset_split
562
+ self.dataset_image_column = dataset_image_column
563
+ else:
564
+ logger.warning("No dataset path provided by orchestrator")
565
+
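Both the long and short key spellings are accepted above; a minimal config using the short forms might look like this (the dataset name is illustrative):

    example_dataset_config = {
        "path": "org/my-image-dataset",
        "type": "huggingface",
        "split": "train",
        "image_column": "image",
    }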
566
+ def _setup_vllm(self):
567
+ """Initialize multi-stage vLLM components."""
568
+ if not self.vllm_config:
569
+ raise RuntimeError("vLLM config not received from orchestrator")
570
+
571
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
572
+
573
+ # Initialize model manager
574
+ self.model_manager = MultiStageVLLMManager(self.gpu_id)
575
+
576
+ # Get base config for models
577
+ base_config = {
578
+ "tensor_parallel_size": self.vllm_config.get("tensor_parallel_size", 1),
579
+ "max_model_len": self.vllm_config.get("max_model_len", 16384),
580
+ "dtype": self.vllm_config.get("dtype", "float16"),
581
+ "gpu_memory_utilization": self.vllm_config.get("gpu_memory_utilization", 0.92),
582
+ "enforce_eager": self.vllm_config.get("enforce_eager", True),
583
+ "disable_mm_preprocessor_cache": self.vllm_config.get(
584
+ "disable_mm_preprocessor_cache", True
585
+ ),
586
+ "limit_mm_per_prompt": self.vllm_config.get("limit_mm_per_prompt", {"image": 1}),
587
+ }
588
+
589
+ base_sampling = self.vllm_config.get("sampling", {})
590
+
591
+ # Load models for all stages
592
+ unique_models = set()
593
+ for stage in self.stages:
594
+ unique_models.add(stage.model)
595
+
596
+ logger.info(f"Loading {len(unique_models)} unique models for {len(self.stages)} stages")
597
+
598
+ for stage in self.stages:
599
+ self.model_manager.load_model(stage.model, stage, base_config)
600
+ self.model_manager.create_sampling_params(stage, base_sampling)
601
+
602
+ logger.info("Multi-stage vLLM initialization complete")
603
+
604
+ # Update config manager's tracking
605
+ self.vllm_config_manager.current_config = self.vllm_config
606
+
607
+ def _handle_vllm_config_update(self, new_config: Dict[str, Any]) -> bool:
608
+ """Handle vLLM configuration updates for multi-stage."""
609
+ if not new_config:
610
+ return True
611
+
612
+ # Parse new stages
613
+ new_stages = self._parse_stages_config(new_config)
614
+
615
+ # Check if stages changed significantly
616
+ stages_changed = len(new_stages) != len(self.stages)
617
+ if not stages_changed:
618
+ for old, new in zip(self.stages, new_stages):
619
+ if (
620
+ old.name != new.name
621
+ or old.model != new.model
622
+ or old.prompts != new.prompts
623
+ or old.output_field != new.output_field
624
+ ):
625
+ stages_changed = True
626
+ break
627
+
628
+ if stages_changed:
629
+ logger.info("Stage configuration changed, reloading all models")
630
+
631
+ # Save old config
632
+ old_config = self.vllm_config
633
+ self.vllm_config = new_config
634
+ self.stages = new_stages
635
+ self.stage_order = self._topological_sort_stages(self.stages)
636
+
637
+ try:
638
+ # Cleanup old models
639
+ if self.model_manager:
640
+ self.model_manager.cleanup()
641
+
642
+ # Reload with new config
643
+ self._setup_vllm()
644
+
645
+ logger.info("Multi-stage vLLM reload complete")
646
+ return True
647
+
648
+ except Exception as e:
649
+ logger.error(f"Failed to reload vLLM: {e}")
650
+ # Restore old config
651
+ self.vllm_config = old_config
652
+ return False
653
+ else:
654
+ # Just update sampling params for existing stages
655
+ logger.info("Updating sampling parameters without model reload")
656
+
657
+ base_sampling = new_config.get("sampling", {})
658
+ for stage in self.stages:
659
+ self.model_manager.create_sampling_params(stage, base_sampling)
660
+
661
+ self.vllm_config = new_config
662
+ return True
663
+
664
+ async def _handle_job_assignment(self, job_data: Dict):
665
+ """Handle job assignment from orchestrator."""
666
+ try:
667
+ # Convert to processing item
668
+ image = Image.open(io.BytesIO(job_data["image_data"]))
669
+
670
+ item = ProcessingItem(
671
+ chunk_id=job_data["job_id"],
672
+ item_key=job_data["sample_id"],
673
+ image=image,
674
+ image_data=job_data["image_data"],
675
+ )
676
+
677
+ # Add to inference queue
678
+ self.readahead_queue.put(item)
679
+ logger.debug(f"Queued job {job_data['job_id']} for processing")
680
+
681
+ except Exception as e:
682
+ logger.error(f"Error handling job assignment: {e}")
683
+
684
+ async def _job_request_loop(self):
685
+ """Request jobs from orchestrator in job mode."""
686
+ while self.running and self.connected.is_set():
687
+ try:
688
+ # Check if we need more work
689
+ if self.readahead_queue.qsize() < self.vllm_config.get("batch_size", 8):
690
+ await self.websocket.send(json.dumps({"type": "request_job"}))
691
+
692
+ await asyncio.sleep(1)
693
+
694
+ except Exception as e:
695
+ logger.error(f"Job request error: {e}")
696
+ await asyncio.sleep(5)
697
+
698
+ def _process_shard_chunk(self, chunk: ShardChunk):
699
+ """Process a single shard chunk with item-level tracking."""
700
+ logger.info(
701
+ f"Processing shard {chunk.shard_name} with unprocessed ranges: {chunk.unprocessed_ranges}"
702
+ )
703
+
704
+ # Select appropriate processor
705
+ if chunk.shard_url.startswith("hf_dataset:"):
706
+ processor = self.hf_processor
707
+ else:
708
+ processor = self.webdataset_processor
709
+
710
+ items_processed = 0
711
+
712
+ # Let the processor handle the range filtering
713
+ for key, url, image_data, metadata in processor.iterate_chunk_with_metadata(
714
+ chunk, self.dataset_loader, self.should_stop_processing, self.connected
715
+ ):
716
+ try:
717
+ # Load image
718
+ img = Image.open(io.BytesIO(image_data))
719
+
720
+ # Create processing item
721
+ item = ProcessingItem(
722
+ chunk_id=chunk.chunk_id,
723
+ item_key=key,
724
+ image=img,
725
+ image_data=image_data,
726
+ metadata=metadata,
727
+ )
728
+
729
+ # Store absolute item index for tracking
730
+ # The processor should provide the correct index in metadata
731
+ if "_chunk_relative_index" in metadata:
732
+ item.metadata["_item_index"] = (
733
+ chunk.start_index + metadata["_chunk_relative_index"]
734
+ )
735
+
736
+ # Add to readahead queue with timeout handling
737
+ timeout_end = time.time() + 30
738
+ while (
739
+ self.running
740
+ and not self.should_stop_processing.is_set()
741
+ and self.connected.is_set()
742
+ ):
743
+ try:
744
+ self.readahead_queue.put(item, timeout=1)
745
+ break
746
+ except Full:
747
+ if time.time() > timeout_end:
748
+ raise TimeoutError("Queue put timeout")
749
+ continue
750
+
751
+ # If we couldn't queue due to disconnection, stop processing
752
+ if not self.connected.is_set() or self.should_stop_processing.is_set():
753
+ logger.debug("Skipping remaining items due to disconnection")
754
+ break
755
+
756
+ items_processed += 1
757
+
758
+ # Batch items for inference
759
+ batch_size = self.vllm_config.get("batch_size", 8)
760
+ if self.readahead_queue.qsize() >= batch_size:
761
+ self._batch_for_inference()
762
+
763
+ except Exception as e:
764
+ if self.should_stop_processing.is_set():
765
+ break
766
+ logger.error(f"Error processing item {key}: {e}")
767
+ self.items_failed += 1
768
+
769
+ # Process any remaining items in queue
770
+ if not self.should_stop_processing.is_set():
771
+ self._batch_for_inference()
772
+
773
+ logger.info(
774
+ f"Chunk {chunk.chunk_id} processed {items_processed} items from unprocessed ranges"
775
+ )
776
+
777
+ def _shard_reader_thread(self):
778
+ """Background thread that reads items from assigned shard chunks (WebDataset or HF dataset)."""
779
+ logger.info("Starting shard reader thread")
780
+
781
+ while self.running:
782
+ # Check if we should stop processing
783
+ if self.should_stop_processing.is_set():
784
+ logger.info("Shard reader waiting for reconnection")
785
+ time.sleep(1)
786
+ continue
787
+
788
+ # Only process if connected
789
+ if not self.connected.is_set():
790
+ time.sleep(1)
791
+ continue
792
+
793
+ # Get next chunk to process
794
+ with self.chunk_lock:
795
+ if not self.current_chunk and self.assigned_chunks:
796
+ self.current_chunk = self.assigned_chunks.popleft()
797
+ self.current_chunk_progress = 0
798
+ logger.info(f"Starting chunk {self.current_chunk.chunk_id}")
799
+
800
+ if not self.current_chunk:
801
+ time.sleep(1)
802
+ continue
803
+
804
+ try:
805
+ # Process the chunk
806
+ self._process_shard_chunk(self.current_chunk)
807
+
808
+ # Only mark complete if still connected
809
+ if self.connected.is_set() and not self.should_stop_processing.is_set():
810
+ logger.info(f"Completed chunk {self.current_chunk.chunk_id}")
811
+ self.chunks_completed += 1
812
+
813
+ # Notify orchestrator if connected
814
+ if self.websocket and self.main_loop:
815
+ try:
816
+ # Notify completion
817
+ asyncio.run_coroutine_threadsafe(
818
+ self.websocket.send(
819
+ json.dumps(
820
+ {
821
+ "type": "chunk_complete",
822
+ "chunk_id": self.current_chunk.chunk_id,
823
+ }
824
+ )
825
+ ),
826
+ self.main_loop,
827
+ ).result(timeout=5)
828
+
829
+ # Request more chunks if queue is low
830
+ with self.chunk_lock:
831
+ queue_size = len(self.assigned_chunks)
832
+
833
+ if queue_size < 2:
834
+ logger.info(f"Requesting more chunks (queue size: {queue_size})")
835
+ asyncio.run_coroutine_threadsafe(
836
+ self.websocket.send(
837
+ json.dumps({"type": "request_chunks", "count": 2})
838
+ ),
839
+ self.main_loop,
840
+ ).result(timeout=5)
841
+
842
+ except Exception as e:
843
+ logger.warning(f"Could not notify orchestrator: {e}")
844
+
845
+ with self.chunk_lock:
846
+ self.current_chunk = None
847
+
848
+ except Exception as e:
849
+ logger.error(f"Error processing chunk: {e}")
850
+
851
+ # Only notify of failure if still connected
852
+ if self.connected.is_set() and self.websocket and self.main_loop:
853
+ try:
854
+ asyncio.run_coroutine_threadsafe(
855
+ self.websocket.send(
856
+ json.dumps(
857
+ {
858
+ "type": "chunk_failed",
859
+ "chunk_id": (
860
+ self.current_chunk.chunk_id
861
+ if self.current_chunk
862
+ else "unknown"
863
+ ),
864
+ "error": str(e),
865
+ }
866
+ )
867
+ ),
868
+ self.main_loop,
869
+ ).result(timeout=5)
870
+ except Exception as send_error:
871
+ logger.warning(
872
+ f"Could not notify orchestrator of chunk failure: {send_error}"
873
+ )
874
+
875
+ with self.chunk_lock:
876
+ self.current_chunk = None
877
+
878
+ async def _result_sender(self):
879
+ """Send results back to orchestrator with item index."""
880
+ pending_results = []
881
+
882
+ try:
883
+ while self.running and self.connected.is_set():
884
+ try:
885
+ # Get result with timeout
886
+ try:
887
+ result = await asyncio.get_event_loop().run_in_executor(
888
+ None, self.result_queue.get, True, 1
889
+ )
890
+ pending_results.append(result)
891
+ except Empty:
892
+ pass
893
+
894
+ # Only try to send if connected
895
+ if pending_results and self.websocket and self.connected.is_set():
896
+ sent_results = []
897
+ for result in pending_results:
898
+ try:
899
+ # Build message with item index
900
+ message_data = {
901
+ "type": "submit_captions",
902
+ "chunk_id": result.chunk_id,
903
+ "dataset": self.dataset_config.get("dataset_path", "unknown"),
904
+ "shard": result.shard_name,
905
+ "item_key": result.item_key,
906
+ "item_index": result.metadata.get("_item_index"),  # absolute item index when provided by the reader
907
+ "outputs": result.outputs,
908
+ "captions": result.outputs.get("captions", []), # Compatibility
909
+ "caption_count": sum(len(v) for v in result.outputs.values()),
910
+ "image_width": result.image_width,
911
+ "image_height": result.image_height,
912
+ "image_format": result.image_format,
913
+ "file_size": result.file_size,
914
+ "processing_time_ms": result.processing_time_ms,
915
+ "metadata": result.metadata,
916
+ }
917
+
918
+ await self.websocket.send(json.dumps(message_data))
919
+ sent_results.append(result)
920
+
921
+ if self.items_processed % 100 == 0:
922
+ total_outputs = sum(
923
+ len(outputs) for outputs in result.outputs.values()
924
+ )
925
+ logger.info(
926
+ f"Processed {self.items_processed} items "
927
+ f"(~{total_outputs} outputs across {len(result.outputs)} fields)"
928
+ )
929
+
930
+ except websockets.exceptions.ConnectionClosed as e:
931
+ logger.warning(f"Connection lost while sending result: {e}")
932
+ raise
933
+ except Exception as e:
934
+ logger.error(f"Error sending result: {e}")
935
+ break
936
+
937
+ # Remove successfully sent results
938
+ for result in sent_results:
939
+ pending_results.remove(result)
940
+
941
+ # Clear pending results if disconnected and buffer is too large
942
+ if not self.connected.is_set() and len(pending_results) > 1000:
943
+ logger.warning(
944
+ f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
945
+ )
946
+ pending_results.clear()
947
+
948
+ await asyncio.sleep(0.1)
949
+
950
+ except Exception as e:
951
+ if isinstance(e, websockets.exceptions.ConnectionClosed):
952
+ raise
953
+ logger.error(f"Unexpected error in result sender: {e}")
954
+ await asyncio.sleep(1)
955
+
956
+ except asyncio.CancelledError:
957
+ logger.debug("Result sender cancelled")
958
+ raise
959
+
960
+ def _batch_for_inference(self):
961
+ """Batch items from readahead queue for inference."""
962
+ batch = []
963
+ batch_size = self.vllm_config.get("batch_size", 8)
964
+
965
+ try:
966
+ while len(batch) < batch_size:
967
+ item = self.readahead_queue.get_nowait()
968
+ batch.append(item)
969
+ except Empty:
970
+ pass
971
+
972
+ if batch:
973
+ self.inference_queue.put(batch)
974
+
975
+ def _process_batch_multi_stage(
976
+ self, batch: List[ProcessingItem], max_attempts: int = 3
977
+ ) -> List[ProcessedResult]:
978
+ """Process a batch through all stages sequentially."""
979
+ results = []
980
+
981
+ # Process each stage in order
982
+ for stage_name in self.stage_order:
983
+ stage = next(s for s in self.stages if s.name == stage_name)
984
+ logger.debug(f"Processing batch through stage: {stage_name}")
985
+
986
+ # Get model components for this stage
987
+ llm, processor, tokenizer, sampling_params = self.model_manager.get_model_for_stage(
988
+ stage_name, stage.model
989
+ )
990
+
991
+ # Track items for retry
992
+ items_to_process = [(i, item, 0) for i, item in enumerate(batch)]
993
+
994
+ while items_to_process:
995
+ # Build requests for current items
996
+ current_batch = []
997
+ current_indices = []
998
+ requests = []
999
+
1000
+ for idx, (original_idx, item, attempt_count) in enumerate(items_to_process):
1001
+ current_batch.append((original_idx, item, attempt_count))
1002
+ current_indices.append(idx)
1003
+
1004
+ # Prepare image
1005
+ converted_img = ImageProcessor.prepare_for_inference(item.image)
1006
+
1007
+ # Create template manager for this stage's prompts
1008
+ template_manager = PromptTemplateManager(stage.prompts)
1009
+
1010
+ # Build context including metadata and previous stage results
1011
+ context = item.metadata.copy()
1012
+
1013
+ # Add previous stage outputs to context
1014
+ for prev_stage_name, stage_result in item.stage_results.items():
1015
+ # Add outputs with stage name prefix
1016
+ for i, output in enumerate(stage_result.outputs):
1017
+ context[f"{prev_stage_name}_output_{i}"] = output
1018
+ # Also add under output field name
1019
+ if len(stage_result.outputs) == 1:
1020
+ context[stage_result.output_field] = stage_result.outputs[0]
1021
+ else:
1022
+ context[stage_result.output_field] = stage_result.outputs
1023
+
1024
+ # Format prompts with context
1025
+ formatted_prompts = template_manager.format_all(context)
1026
+
1027
+ # Build requests for all prompts
1028
+ for prompt in formatted_prompts:
1029
+ req = self._build_vllm_input(converted_img, prompt, processor, tokenizer)
1030
+ requests.append(req)
1031
+
1032
+ # Run inference
1033
+ outputs = llm.generate(requests, sampling_params)
1034
+
1035
+ # Process outputs
1036
+ successful_items = []
1037
+ failed_items = []
1038
+
1039
+ for idx, (original_idx, item, attempt_count) in enumerate(current_batch):
1040
+ # Check if we should stop
1041
+ if self.should_stop_processing.is_set():
1042
+ return results
1043
+
1044
+ # Extract outputs for this item
1045
+ base_idx = idx * len(stage.prompts)
1046
+ stage_outputs = []
1047
+
1048
+ for j in range(len(stage.prompts)):
1049
+ if base_idx + j < len(outputs) and outputs[base_idx + j].outputs:
1050
+ original_output = outputs[base_idx + j].outputs[0].text
1051
+ cleaned_output = self._clean_output(original_output)
1052
+ if cleaned_output:
1053
+ stage_outputs.append(cleaned_output)
1054
+ else:
1055
+ logger.warning(
1056
+ f"(stage {stage_name}, item {item.item_key}) output was empty after cleaning: {original_output}"
1057
+ )
1058
+
1059
+ if stage_outputs:
1060
+ # Success - add stage result to item
1061
+ stage_result = StageResult(
1062
+ stage_name=stage_name,
1063
+ output_field=stage.output_field,
1064
+ outputs=stage_outputs,
1065
+ )
1066
+ item.stage_results[stage_name] = stage_result
1067
+ successful_items.append((original_idx, item))
1068
+ else:
1069
+ # Failed - check if we should retry
1070
+ if attempt_count + 1 < max_attempts:
1071
+ failed_items.append((original_idx, item, attempt_count + 1))
1072
+ logger.warning(
1073
+ f"Stage {stage_name} failed for item {item.item_key} "
1074
+ f"(attempt {attempt_count + 1}/{max_attempts}), will retry"
1075
+ )
1076
+ else:
1077
+ logger.error(
1078
+ f"Stage {stage_name} failed for item {item.item_key} "
1079
+ f"after {max_attempts} attempts"
1080
+ )
1081
+ self.items_failed += 1
1082
+
1083
+ # Update items to process for next iteration
1084
+ items_to_process = failed_items
1085
+
1086
+ # Update batch with successful items for next stage
1087
+ batch = [item for _, item in successful_items]
1088
+
1089
+ # Log retry status if we have items to retry
1090
+ if items_to_process:
1091
+ logger.info(
1092
+ f"Retrying {len(items_to_process)} failed items for stage {stage_name}"
1093
+ )
1094
+
1095
+ # Convert batch items to results
1096
+ for item in batch:
1097
+ # Aggregate outputs by field name
1098
+ outputs_by_field = defaultdict(list)
1099
+
1100
+ for stage_result in item.stage_results.values():
1101
+ outputs_by_field[stage_result.output_field].extend(stage_result.outputs)
1102
+
1103
+ result = ProcessedResult(
1104
+ chunk_id=item.chunk_id,
1105
+ shard_name=Path(item.chunk_id).stem.rsplit("_chunk_", 1)[0],
1106
+ item_key=item.item_key,
1107
+ outputs=dict(outputs_by_field), # Convert defaultdict to dict
1108
+ image_width=item.image.width,
1109
+ image_height=item.image.height,
1110
+ image_format=item.image.format or "unknown",
1111
+ file_size=len(item.image_data),
1112
+ processing_time_ms=0, # Will be calculated by caller
1113
+ metadata=item.metadata,
1114
+ )
1115
+ results.append(result)
1116
+ self.items_processed += 1
1117
+
1118
+ return results
1119
+
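To make the context handed to PromptTemplateManager.format_all concrete: after a completed `caption` stage whose output_field is `captions`, a later stage would see something like the following merged with the item metadata (values illustrative; whether templates actually reference these keys depends on PromptTemplateManager's placeholder syntax):

    example_context = {
        "caption_output_0": "a dog running on a beach",  # "<stage>_output_<i>" keys
        "captions": "a dog running on a beach",          # single output also exposed under its output_field
    }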
1120
+ def _inference_thread(self):
1121
+ """Background thread for multi-stage vLLM inference."""
1122
+ logger.info("Starting multi-stage inference thread")
1123
+
1124
+ while self.running:
1125
+ try:
1126
+ # Get batch from queue with timeout
1127
+ batch = self.inference_queue.get(timeout=1)
1128
+
1129
+ if not batch:
1130
+ continue
1131
+
1132
+ # Skip if disconnected
1133
+ if self.should_stop_processing.is_set():
1134
+ continue
1135
+
1136
+ logger.debug(
1137
+ f"Processing batch of {len(batch)} images through {len(self.stages)} stages"
1138
+ )
1139
+ start_time = time.time()
1140
+
1141
+ # Process batch through all stages
1142
+ results = self._process_batch_multi_stage(batch)
1143
+
1144
+ # Calculate processing time per item
1145
+ if results:
1146
+ processing_time_per_item = (time.time() - start_time) * 1000 / len(batch)
1147
+
1148
+ # Update processing time and queue results
1149
+ for result in results:
1150
+ result.processing_time_ms = processing_time_per_item
1151
+ self.result_queue.put(result)
1152
+
1153
+ logger.debug(
1154
+ f"Multi-stage batch processing complete: {len(results)} successful, "
1155
+ f"{len(batch) - len(results)} failed"
1156
+ )
1157
+
1158
+ except Empty:
1159
+ continue
1160
+ except Exception as e:
1161
+ if self.should_stop_processing.is_set():
1162
+ continue
1163
+ logger.error(f"Inference error: {e}", exc_info=True)
1164
+
1165
+ def _build_vllm_input(self, image: Image.Image, prompt: str, processor, tokenizer) -> Dict:
1166
+ """Build vLLM input."""
1167
+ try:
1168
+ from qwen_vl_utils import process_vision_info
1169
+
1170
+ messages = [
1171
+ {
1172
+ "role": "user",
1173
+ "content": [
1174
+ {"type": "image", "image": image},
1175
+ {"type": "text", "text": prompt},
1176
+ ],
1177
+ }
1178
+ ]
1179
+
1180
+ prompt_text = processor.apply_chat_template(
1181
+ messages, tokenize=False, add_generation_prompt=True
1182
+ )
1183
+ image_inputs, _ = process_vision_info(messages)
1184
+ prompt_ids = tokenizer(prompt_text, add_special_tokens=False).input_ids
1185
+
1186
+ return {
1187
+ "prompt_token_ids": prompt_ids,
1188
+ "multi_modal_data": {"image": image_inputs},
1189
+ }
1190
+ except ImportError:
1191
+ return {
1192
+ "prompt": f"<|user|>\n<|image_pad|>\n{prompt}<|end|>\n<|assistant|>",
1193
+ "multi_modal_data": {"image": [image]},
1194
+ }
1195
+
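A sketch of how one request flows into generation, assuming `worker`, an opened `image`, and the stage's `llm`, `processor`, `tokenizer`, and `params` (from get_model_for_stage) are already in scope:

    request = worker._build_vllm_input(image, "describe this image", processor, tokenizer)
    outputs = llm.generate([request], params)  # same call used in _process_batch_multi_stage
    text = outputs[0].outputs[0].text if outputs and outputs[0].outputs else ""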
1196
+ def _clean_output(self, text: str) -> str:
1197
+ """Clean model output."""
1198
+ if not text:
1199
+ return ""
1200
+
1201
+ # Remove common artifacts
1202
+ for token in ["<|end|>", "<|endoftext|>", "<|im_end|>", "I'm sorry", "I cannot"]:
1203
+ if token in text:
1204
+ text = text.split(token)[0]
1205
+
1206
+ return text.strip()
1207
+
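For example, the cleaning step truncates at the first stop token or refusal marker and trims whitespace:

    # _clean_output("A red bicycle leaning on a wall.<|im_end|> extra tokens")
    #   -> "A red bicycle leaning on a wall."
    # _clean_output("I'm sorry, I can't help with that")
    #   -> ""  (everything from the marker onward is dropped, leaving an empty string)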
1208
+ async def _result_sender(self):
1209
+ """Send results back to orchestrator with multi-stage outputs."""
1210
+ pending_results = [] # Buffer for results during disconnection
1211
+
1212
+ try:
1213
+ while self.running and self.connected.is_set():
1214
+ try:
1215
+ # Get result (with timeout to allow checking self.running)
1216
+ try:
1217
+ result = await asyncio.get_event_loop().run_in_executor(
1218
+ None, self.result_queue.get, True, 1
1219
+ )
1220
+ pending_results.append(result)
1221
+ except Empty:
1222
+ pass
1223
+
1224
+ # Only try to send if connected
1225
+ if pending_results and self.websocket and self.connected.is_set():
1226
+ sent_results = []
1227
+ for result in pending_results:
1228
+ try:
1229
+ # For backward compatibility, if there's only one output field "captions"
1230
+ # send it in the old format
1231
+ if len(result.outputs) == 1 and "captions" in result.outputs:
1232
+ # Old format for single-stage compatibility
1233
+ await self.websocket.send(
1234
+ json.dumps(
1235
+ {
1236
+ "type": "submit_captions",
1237
+ "chunk_id": result.chunk_id,
1238
+ "dataset": self.dataset_config.get(
1239
+ "dataset_path", "unknown"
1240
+ ),
1241
+ "shard": result.shard_name,
1242
+ "item_key": result.item_key,
1243
+ "item_index": result.metadata.get("_item_index"),
1244
+ "captions": result.outputs["captions"],
1245
+ "caption_count": len(result.outputs["captions"]),
1246
+ "image_width": result.image_width,
1247
+ "image_height": result.image_height,
1248
+ "image_format": result.image_format,
1249
+ "file_size": result.file_size,
1250
+ "processing_time_ms": result.processing_time_ms,
1251
+ }
1252
+ )
1253
+ )
1254
+ else:
1255
+ # New format for multi-stage outputs
1256
+ await self.websocket.send(
1257
+ json.dumps(
1258
+ {
1259
+ "type": "submit_captions",
1260
+ "chunk_id": result.chunk_id,
1261
+ "dataset": self.dataset_config.get(
1262
+ "dataset_path", "unknown"
1263
+ ),
1264
+ "shard": result.shard_name,
1265
+ "item_key": result.item_key,
1266
+ "outputs": result.outputs, # Dict of field -> list of outputs
1267
+ "captions": result.outputs.get(
1268
+ "captions", []
1269
+ ), # For compatibility
1270
+ "caption_count": sum(
1271
+ len(v) for v in result.outputs.values()
1272
+ ),
1273
+ "image_width": result.image_width,
1274
+ "image_height": result.image_height,
1275
+ "image_format": result.image_format,
1276
+ "file_size": result.file_size,
1277
+ "processing_time_ms": result.processing_time_ms,
1278
+ "metadata": result.metadata,
1279
+ }
1280
+ )
1281
+ )
1282
+
1283
+ sent_results.append(result)
1284
+
1285
+ if self.items_processed % 100 == 0:
1286
+ total_outputs = sum(
1287
+ len(outputs) for outputs in result.outputs.values()
1288
+ )
1289
+ logger.info(
1290
+ f"Processed {self.items_processed} items "
1291
+ f"(~{total_outputs} outputs across {len(result.outputs)} fields)"
1292
+ )
1293
+ except websockets.exceptions.ConnectionClosed as e:
1294
+ logger.warning(f"Connection lost while sending result: {e}")
1295
+ raise # Re-raise to trigger task completion
1296
+ except Exception as e:
1297
+ logger.error(f"Error sending result: {e}")
1298
+ break
1299
+
1300
+ # Remove successfully sent results
1301
+ for result in sent_results:
1302
+ pending_results.remove(result)
1303
+
1304
+ # Clear pending results if disconnected and buffer is too large
1305
+ if not self.connected.is_set() and len(pending_results) > 1000:
1306
+ logger.warning(
1307
+ f"Clearing {len(pending_results)} pending results due to prolonged disconnection"
1308
+ )
1309
+ pending_results.clear()
1310
+
1311
+ await asyncio.sleep(0.1)
1312
+
1313
+ except Exception as e:
1314
+ if isinstance(e, websockets.exceptions.ConnectionClosed):
1315
+ raise # Re-raise connection errors
1316
+ logger.error(f"Unexpected error in result sender: {e}")
1317
+ await asyncio.sleep(1)
1318
+
1319
+ except asyncio.CancelledError:
1320
+ logger.debug("Result sender cancelled")
1321
+ raise