caption-flow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1301 @@
1
+ """Enhanced orchestrator with shard chunk assignment for vLLM workers.
2
+
3
+ This orchestrator:
4
+ 1. Divides dataset shards into chunks for parallel processing
5
+ 2. Assigns chunks to workers on request
6
+ 3. Collects captions from workers centrally
7
+ 4. Manages checkpoints and fault tolerance
8
+ """
9
+
10
+ import time
11
+ import asyncio
12
+ import json
13
+ import logging
14
+ import ssl
15
+ import uuid
16
+ from dataclasses import dataclass, asdict
17
+ from datetime import datetime
18
+ from pathlib import Path
19
+ from typing import Dict, Set, Optional, Any, List, Deque
20
+ from collections import deque, defaultdict
21
+ import threading
22
+ from queue import Queue, Empty
23
+
24
+ import websockets
25
+ from websockets.server import WebSocketServerProtocol
26
+
27
+ from .storage import StorageManager
28
+ from .models import Caption, Contributor
29
+ from .utils.auth import AuthManager
30
+ from .utils.dataset_loader import DatasetLoader, ShardTracker
31
+ from .utils.json_utils import safe_dict, safe_json_dumps, to_json_dict
32
+ from .utils.chunk_tracker import ChunkTracker
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
@dataclass
class ShardChunk:
    """Represents a chunk of a shard for processing.

    A chunk describes ``chunk_size`` items of one shard beginning at
    ``start_index``. The assignment fields (``assigned_to``, ``status``,
    ``assigned_at``, ``completed_at``) are mutated by ``ChunkManager`` as the
    chunk moves through its lifecycle.
    """

    # Unique key of the chunk; the creators in this file build it as
    # "<shard_name>_chunk_<start_index>".
    chunk_id: str
    # Location of the shard this chunk belongs to (sent to workers as-is).
    shard_url: str
    # Short shard name; the orchestrator derives it with Path(shard_url).stem.
    shard_name: str
    # Offset of the first item of this chunk within the shard.
    start_index: int
    # Number of items in this chunk (the last chunk of a shard may be smaller).
    chunk_size: int
    # ID of the worker currently holding the chunk, or None when unassigned.
    assigned_to: Optional[str] = None
    # NOTE(review): the ChunkManager methods in this file only ever write
    # "pending", "assigned" and "completed"; "failed" is listed but a failed
    # chunk is reset to "pending" there - confirm whether it is set elsewhere.
    status: str = "pending"  # pending, assigned, completed, failed
    # Naive UTC timestamp (datetime.utcnow()) taken when the chunk is assigned.
    assigned_at: Optional[datetime] = None
    # Naive UTC timestamp (datetime.utcnow()) taken when the chunk is completed.
    completed_at: Optional[datetime] = None
50
+
51
+
52
class ChunkManager:
    """Manages shard chunk creation and assignment.

    Holds every known chunk in ``chunks`` (keyed by chunk ID), a FIFO queue of
    chunk IDs waiting for a worker in ``pending_chunks``, and the set of chunk
    IDs each worker currently holds in ``assigned_chunks``. Every method takes
    ``self.lock`` before touching these structures.
    """

    def __init__(self, chunk_size: int = 1000, tracker: Optional["ChunkTracker"] = None):
        """Create an empty manager.

        Args:
            chunk_size: Maximum number of items per chunk.
            tracker: Optional chunk tracker kept as ``self.tracker``; the
                methods of this class do not use it themselves.
        """
        self.chunk_size = chunk_size
        self.chunks: Dict[str, "ShardChunk"] = {}
        self.pending_chunks: Deque[str] = deque()
        self.assigned_chunks: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> chunk_ids
        self.lock = threading.Lock()
        self.tracker = tracker  # Reference to chunk tracker

    def create_chunks_from_shard(
        self, shard_url: str, shard_name: str, total_items: int
    ) -> List["ShardChunk"]:
        """Split a shard into chunks, register them and queue them as pending.

        Args:
            shard_url: Location of the shard, copied onto every chunk.
            shard_name: Short shard name used to build the chunk IDs.
            total_items: Number of items in the shard.

        Returns:
            The newly created chunks in shard order; the last one may hold
            fewer than ``self.chunk_size`` items.
        """
        chunks = []

        for start_idx in range(0, total_items, self.chunk_size):
            chunk_id = f"{shard_name}_chunk_{start_idx}"
            chunk = ShardChunk(
                chunk_id=chunk_id,
                shard_url=shard_url,
                shard_name=shard_name,
                start_index=start_idx,
                chunk_size=min(self.chunk_size, total_items - start_idx),
            )

            with self.lock:
                self.chunks[chunk_id] = chunk
                self.pending_chunks.append(chunk_id)

            chunks.append(chunk)

        return chunks

    def get_chunks_for_worker(
        self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
    ) -> List["ShardChunk"]:
        """Assign up to ``count`` pending chunks to a worker.

        Chunks are taken from the front of the pending queue.

        Args:
            worker_id: Worker receiving the chunks.
            count: Maximum number of chunks to hand out.
            tracker: When given, ``mark_assigned`` is called on it for every
                chunk handed out.

        Returns:
            The assigned chunks; empty when nothing is pending.
        """
        assigned = []

        with self.lock:
            while len(assigned) < count and self.pending_chunks:
                chunk_id = self.pending_chunks.popleft()
                chunk = self.chunks[chunk_id]

                chunk.assigned_to = worker_id
                chunk.status = "assigned"
                chunk.assigned_at = datetime.utcnow()

                self.assigned_chunks[worker_id].add(chunk_id)
                assigned.append(chunk)
                if tracker:
                    tracker.mark_assigned(chunk_id, worker_id)

        return assigned

    def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
        """Mark a chunk as completed.

        Returns:
            True when the chunk exists and is currently assigned to
            ``worker_id``; False otherwise (nothing is changed).
        """
        with self.lock:
            chunk = self.chunks.get(chunk_id)
            if chunk is None:
                return False
            if chunk.assigned_to == worker_id and chunk.status == "assigned":
                chunk.status = "completed"
                chunk.completed_at = datetime.utcnow()
                self.assigned_chunks[worker_id].discard(chunk_id)
                return True
        return False

    def fail_chunk(self, chunk_id: str, worker_id: str) -> bool:
        """Return a failed chunk to the pending queue.

        Returns:
            True when the chunk exists and is currently assigned to
            ``worker_id``; False otherwise (nothing is changed).
        """
        with self.lock:
            chunk = self.chunks.get(chunk_id)
            if chunk is None:
                return False
            # The status check matters: complete_chunk() leaves assigned_to
            # untouched, so without it a late or duplicate failure report
            # would push an already completed chunk back into the queue.
            if chunk.assigned_to == worker_id and chunk.status == "assigned":
                chunk.status = "pending"
                chunk.assigned_to = None
                chunk.assigned_at = None
                self.assigned_chunks[worker_id].discard(chunk_id)
                self.pending_chunks.append(chunk_id)
                return True
        return False

    def release_worker_chunks(self, worker_id: str):
        """Requeue every chunk still assigned to a worker and forget the worker."""
        with self.lock:
            chunk_ids = list(self.assigned_chunks.get(worker_id, []))
            for chunk_id in chunk_ids:
                chunk = self.chunks.get(chunk_id)
                if chunk is not None and chunk.status == "assigned":
                    chunk.status = "pending"
                    chunk.assigned_to = None
                    chunk.assigned_at = None
                    self.pending_chunks.append(chunk_id)

            if worker_id in self.assigned_chunks:
                del self.assigned_chunks[worker_id]

    def get_stats(self) -> Dict[str, int]:
        """Return chunk counts: total, pending, assigned, completed, failed.

        "failed" counts chunks whose status is "failed"; no method of this
        class sets that status (failed chunks go back to "pending").
        """
        with self.lock:
            stats = {
                "total": len(self.chunks),
                "pending": len(self.pending_chunks),
                "assigned": sum(len(chunks) for chunks in self.assigned_chunks.values()),
                "completed": sum(1 for c in self.chunks.values() if c.status == "completed"),
                "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
            }
        return stats
162
+
163
+
164
+ class Orchestrator:
165
+ """Enhanced orchestrator for vLLM-based distributed captioning with chunk assignment."""
166
+
167
def __init__(self, config: Dict[str, Any]):
    """Build the orchestrator from a configuration mapping.

    Reads the network, dataset, vLLM, chunking, storage, auth, SSL and
    backpressure settings from ``config`` (falling back to the defaults
    below) and creates the storage, auth, dataset and chunk-management
    components. The dataset loader and both trackers are only created when
    a dataset path is configured.
    """
    self.config = config
    self.host = config.get("host", "0.0.0.0")
    self.port = config.get("port", 8765)

    # --- dataset ---
    dataset_section = config.get("dataset", {})
    self.dataset_config = dataset_section
    self.dataset_path = dataset_section.get("path")
    self.dataset_type = dataset_section.get("type", "huggingface")

    # --- vLLM settings handed to workers (used only when "vllm" is absent) ---
    default_sampling = {
        "temperature": 0.7,
        "top_p": 0.95,
        "max_tokens": 256,
        "repetition_penalty": 1.05,
        "stop": ["<|end|>", "<|endoftext|>", "<|im_end|>"],
    }
    default_prompts = [
        "describe this image in detail",
        "provide a comprehensive description of the visual content",
        "what are the key elements in this image?",
    ]
    default_vllm = {
        "model": "Qwen/Qwen2.5-VL-3B-Instruct",
        "gpu_memory_utilization": 0.92,
        "max_model_len": 16384,
        "tensor_parallel_size": 1,
        "dtype": "float16",
        "enforce_eager": True,
        "limit_mm_per_prompt": {"image": 1},
        "disable_mm_preprocessor_cache": True,
        "sampling": default_sampling,
        "inference_prompts": default_prompts,
    }
    self.vllm_config = config.get("vllm", default_vllm)

    # --- chunking ---
    self.chunk_size = config.get("chunk_size", 1000)
    self.chunks_per_request = config.get("chunks_per_request", 2)

    # --- demand-driven chunk creation ---
    self.chunk_buffer_multiplier = config.get("chunk_buffer_multiplier", 3)
    self.min_chunk_buffer = config.get("min_chunk_buffer", 10)

    # --- storage and auth ---
    storage_section = config.get("storage", {})
    data_dir = Path(storage_section.get("data_dir", "./caption_data"))
    self.storage = StorageManager(
        data_dir,
        caption_buffer_size=storage_section.get("caption_buffer_size", 1000),
        job_buffer_size=storage_section.get("job_buffer_size", 100),
        contributor_buffer_size=storage_section.get("contributor_buffer_size", 10),
    )
    self.auth = AuthManager(config.get("auth", {}))

    # --- dataset components (stay None without a dataset path) ---
    self.dataset_loader = self.shard_tracker = self.chunk_tracker = None
    if self.dataset_path:
        self.dataset_loader = DatasetLoader(self.dataset_path, self.dataset_type)
        checkpoint_dir = Path(storage_section.get("checkpoint_dir", "./checkpoints"))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
        self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")

    # The chunk manager keeps a reference to the (possibly None) chunk tracker.
    self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)

    # --- live connections ---
    self.workers: Dict[str, WebSocketServerProtocol] = {}
    self.monitors: Set[WebSocketServerProtocol] = set()

    # --- TLS ---
    self.ssl_context = self._setup_ssl()

    # --- counters ---
    self.stats = {
        "total_chunks": 0,
        "completed_chunks": 0,
        "failed_chunks": 0,
        "total_captions": 0,
        "connected_workers": 0,
        "total_shards": 0,
        "completed_shards": 0,
        "current_shard": None,
        "buffer_size": 0,
        "total_written": 0,
        "last_checkpoint": None,
    }

    # --- throughput tracking ---
    self.rate_tracker = {
        "start_time": time.time(),
        "last_update_time": time.time(),
        "last_caption_count": 0,
        "current_rate": 0.0,
        "average_rate": 0.0,
        "expected_rate": 0.0,
    }

    # --- data samples queued for VLLMWorkers ---
    self.data_sample_queue = asyncio.Queue(maxsize=1000)
    self.data_workers: Dict[str, WebSocketServerProtocol] = {}

    # --- backpressure ---
    self.backpressure_threshold = config.get("backpressure_threshold", 800)

    # --- shard processing state ---
    self.all_shards = []
    self.current_shard_index = 0
    self.shard_lock = threading.Lock()

    # --- background chunk creation ---
    self.chunk_creation_thread = None
    self.stop_chunk_creation = threading.Event()

    # --- state restoration flag ---
    self.state_restored = threading.Event()
    if not self.dataset_loader:
        # Nothing to restore without a dataset, so the flag starts out set.
        self.state_restored.set()
290
+
291
+ def _setup_ssl(self) -> Optional[ssl.SSLContext]:
292
+ """Configure SSL if certificates are provided."""
293
+ ssl_config = self.config.get("ssl", {})
294
+ if not ssl_config.get("cert") or not ssl_config.get("key"):
295
+ return None
296
+
297
+ context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
298
+ context.load_cert_chain(ssl_config["cert"], ssl_config["key"])
299
+ return context
300
+
301
def _create_chunks_from_dataset(self):
    """Background thread to create chunks from dataset shards on demand.

    Runs in three phases:

    1. Reconcile the shard tracker with the chunk tracker's per-shard summary.
    2. Re-queue every checkpointed chunk whose status is pending, failed or
       assigned, then set ``state_restored`` so chunk requests are served.
    3. Walk the remaining shards and create new chunks whenever the number of
       pending + assigned chunks drops below the target buffer, until the
       shards are exhausted or ``stop_chunk_creation`` is set.
    """
    if not self.dataset_loader:
        logger.warning("No dataset configured, skipping chunk creation")
        self.state_restored.set()  # No state to restore
        return

    logger.info("Starting chunk creation thread")

    # Chunk requests are answered with "state_restoring" while this is clear.
    self.state_restored.clear()

    # Get all shards
    self.all_shards = self.dataset_loader.get_shard_list()
    self.stats["total_shards"] = len(self.all_shards)

    # Get shard status from ChunkTracker
    shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
    completed_shards = {
        shard_name for shard_name, info in shards_summary.items() if info["is_complete"]
    }

    # Update ShardTracker for completed shards
    for shard_name in completed_shards:
        if not self.shard_tracker.is_complete(shard_name):
            logger.info(f"Marking shard {shard_name} as complete in ShardTracker")
            self.shard_tracker.mark_complete(shard_name)

    # Shards that still need processing, minus those that already have
    # chunks (partial or complete) recorded in the chunk tracker.
    shards_with_chunks = set(shards_summary)
    remaining_shards = [
        shard
        for shard in self.shard_tracker.get_remaining_shards(self.all_shards)
        if Path(shard).stem not in shards_with_chunks
    ]

    self.stats["completed_shards"] = len(completed_shards)

    logger.info(
        f"Total shards: {len(self.all_shards)}, "
        f"Completed: {self.stats['completed_shards']}, "
        f"Shards with chunks: {len(shards_with_chunks)}, "
        f"Remaining to process: {len(remaining_shards)}"
    )

    # Map shard name -> URL once instead of scanning all shards per chunk.
    # setdefault keeps the first URL for a name, like the original scan did.
    shard_url_by_name = {}
    for url in self.all_shards:
        shard_url_by_name.setdefault(Path(url).stem, url)

    # First, re-queue any existing unfinished chunks
    initial_pending = 0
    requeued_chunks_by_shard = defaultdict(list)
    restorable_statuses = ("pending", "failed", "assigned")

    for shard_name, shard_info in shards_summary.items():
        shard_url = shard_url_by_name.get(shard_name)
        if not shard_url:
            # Shard is no longer part of the dataset listing; nothing to queue.
            continue

        with self.chunk_manager.lock:
            for chunk_state in shard_info["chunks"]:
                if chunk_state.status not in restorable_statuses:
                    continue

                chunk = ShardChunk(
                    chunk_id=chunk_state.chunk_id,
                    shard_url=shard_url,
                    shard_name=chunk_state.shard_name,
                    start_index=chunk_state.start_index,
                    chunk_size=chunk_state.chunk_size,
                )
                self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
                self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
                requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
                initial_pending += 1

    logger.info(f"Re-queued {initial_pending} existing pending chunks")
    for shard_name, chunk_ids in requeued_chunks_by_shard.items():
        logger.info(f"  Shard {shard_name}: {len(chunk_ids)} chunks - {chunk_ids}")

    # Mark state as restored
    self.state_restored.set()
    logger.info("State restoration complete, accepting chunk requests")

    # Process shards on-demand
    shard_iter = iter(remaining_shards)
    current_shard_url = None
    current_shard_name = None
    current_shard_items = None
    current_shard_index = 0
    # Set once shard_iter is drained. Tracking this explicitly is what lets
    # the thread finish: current_shard_url keeps pointing at the last shard
    # after StopIteration, so it cannot be used to detect exhaustion.
    shards_exhausted = False

    while not self.stop_chunk_creation.is_set():
        # Check how many chunks we need
        with self.chunk_manager.lock:
            pending_count = len(self.chunk_manager.pending_chunks)
            assigned_count = sum(
                len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
            )
            total_active = pending_count + assigned_count

        # Target buffer: configurable multiplier x number of workers
        worker_count = max(1, self.stats.get("connected_workers", 0))
        target_buffer = max(
            self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier
        )

        chunks_needed = max(0, target_buffer - total_active)

        if chunks_needed == 0:
            # We have enough chunks, wait a bit
            time.sleep(5)
            continue

        logger.debug(
            f"Need {chunks_needed} more chunks (pending: {pending_count}, "
            f"assigned: {assigned_count}, workers: {worker_count})"
        )

        # Create chunks as needed
        chunks_created = 0

        while chunks_created < chunks_needed and not self.stop_chunk_creation.is_set():
            # Need to load next shard?
            if current_shard_url is None or current_shard_index >= current_shard_items:
                try:
                    current_shard_url = next(shard_iter)
                    current_shard_name = Path(current_shard_url).stem
                    self.stats["current_shard"] = current_shard_name

                    # Skip if we already have chunks from this shard
                    if current_shard_name in shards_summary:
                        logger.debug(
                            f"Skipping shard {current_shard_name} - already has chunks"
                        )
                        current_shard_url = None
                        continue

                    # Count items in new shard
                    logger.info(f"Loading new shard {current_shard_name}")
                    current_shard_items = sum(
                        1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
                    )
                    current_shard_index = 0
                    logger.info(f"Shard {current_shard_name} has {current_shard_items} items")

                except StopIteration:
                    # No more shards
                    logger.info("No more shards to process")
                    shards_exhausted = True
                    current_shard_url = None
                    break
                except Exception as e:
                    logger.error(f"Error loading shard {current_shard_name}: {e}")
                    current_shard_url = None
                    continue

            # Create a chunk from current shard
            if current_shard_url and current_shard_index < current_shard_items:
                chunk_id = f"{current_shard_name}_chunk_{current_shard_index}"
                chunk_size = min(self.chunk_size, current_shard_items - current_shard_index)

                # Only queue the chunk when the tracker accepts it.
                if self.chunk_tracker and self.chunk_tracker.add_chunk(
                    chunk_id, current_shard_name, current_shard_index, chunk_size
                ):
                    chunk = ShardChunk(
                        chunk_id=chunk_id,
                        shard_url=current_shard_url,
                        shard_name=current_shard_name,
                        start_index=current_shard_index,
                        chunk_size=chunk_size,
                    )

                    with self.chunk_manager.lock:
                        self.chunk_manager.chunks[chunk_id] = chunk
                        self.chunk_manager.pending_chunks.append(chunk_id)

                    chunks_created += 1
                    self.stats["total_chunks"] += 1

                current_shard_index += self.chunk_size

        if chunks_created > 0:
            logger.info(f"Created {chunks_created} chunks on demand")

        # Nothing is left to chunk: end the thread instead of polling forever.
        if shards_exhausted:
            logger.info("All shards processed, chunk creation complete")
            break

        # Brief pause to avoid spinning
        time.sleep(1)

    # Final stats
    if self.chunk_tracker:
        final_stats = self.chunk_tracker.get_stats()
        logger.info(
            f"Chunk creation thread ending. Total: {final_stats['total']}, "
            f"Pending: {final_stats['pending']}, Completed: {final_stats['completed']}"
        )

    logger.info("Chunk creation thread finished")
504
+
505
async def start(self):
    """Start the orchestrator server.

    Initializes storage, syncs the chunk tracker with it and restores state
    before any connection is accepted; then starts the chunk-creation thread
    (when a dataset is configured), the background loops and finally the
    WebSocket server. Does not return while the server is running.
    """
    logger.info(f"Starting vLLM orchestrator on {self.host}:{self.port}")
    logger.info(
        f"vLLM config: model={self.vllm_config.get('model')}, batch_size={self.vllm_config.get('batch_size')}"
    )

    # Load existing state BEFORE accepting connections
    await self.storage.initialize()
    if self.chunk_tracker:
        await self.chunk_tracker.sync_with_storage(self.storage)
    await self._restore_state()

    # Start chunk creation thread if dataset is configured
    if self.dataset_loader:
        self.chunk_creation_thread = threading.Thread(
            target=self._create_chunks_from_dataset, daemon=True
        )
        self.chunk_creation_thread.start()

        # Give chunk creation thread time to restore existing chunks
        await asyncio.sleep(2)

    # Start background tasks. The event loop only keeps weak references to
    # tasks, so hold strong references here; otherwise a loop could be
    # garbage-collected while it is still running.
    self._background_tasks = [
        asyncio.create_task(self._heartbeat_loop()),
        asyncio.create_task(self._checkpoint_loop()),
        asyncio.create_task(self._stats_update_loop()),
    ]

    # Start WebSocket server
    async with websockets.serve(
        self.handle_connection, self.host, self.port, ssl=self.ssl_context
    ):
        logger.info("vLLM Orchestrator ready for connections")
        await asyncio.Future()  # Run forever
539
+
540
async def handle_connection(self, websocket: WebSocketServerProtocol):
    """Authenticate a new WebSocket connection and hand it to its role handler.

    The first message must be JSON carrying a "token". Connections whose
    token yields no role, or an unknown role, receive an error message. Any
    exception is logged with its traceback and the socket is closed.
    """
    try:
        # The first frame on every connection is the authentication message.
        raw_auth = await websocket.recv()
        credentials = json.loads(raw_auth)

        ticket = self.auth.authenticate(credentials.get("token"))
        role = ticket.role
        if not role:
            await websocket.send(safe_json_dumps({"error": "Invalid token"}))
            return

        if role == "worker":
            await self._handle_worker(websocket, ticket)
            return
        if role == "data_worker":
            await self._handle_data_worker(websocket, ticket)
            return
        if role == "monitor":
            await self._handle_monitor(websocket)
            return
        if role == "admin":
            await self._handle_admin(websocket, ticket)
            return

        await websocket.send(safe_json_dumps({"error": f"Unknown role: {role}"}))

    except Exception as e:
        logger.error(f"Connection error: {e}")
        import traceback

        logger.error(traceback.format_exc())
        await websocket.close()
569
+
570
async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
    """Serve an admin connection until it closes.

    Sends a welcome message, then processes incoming messages: a
    "reload_config" message triggers a configuration reload, other types are
    ignored, and messages that are not valid JSON get an error reply.
    """
    admin_name = getattr(auth_ticket, "name", "admin")
    logger.info(f"Admin {admin_name} connected")

    try:
        await websocket.send(safe_json_dumps({"type": "welcome", "role": "admin"}))

        async for raw_message in websocket:
            try:
                payload = json.loads(raw_message)
                if payload.get("type") == "reload_config":
                    await self._handle_config_reload(websocket, payload.get("config", {}))
            except json.JSONDecodeError as e:
                logger.error(f"Invalid admin message: {e}")
                error_reply = {"type": "error", "error": "Invalid message format"}
                await websocket.send(safe_json_dumps(error_reply))

    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Admin {admin_name} disconnected")
595
+
596
async def _handle_config_reload(
    self, websocket: WebSocketServerProtocol, new_config: Dict[str, Any]
):
    """Handle configuration reload request.

    Applies every setting in ``new_config`` that differs from the running
    configuration, then either disconnects all workers (when the vLLM model,
    GPU memory utilization or tensor-parallel size changed) or sends them a
    "config_update" message. The admin socket receives "reload_complete" with
    the updated section names and warnings, or "reload_failed" on error.

    Args:
        websocket: Admin connection that requested the reload.
        new_config: Partial or full configuration mapping to apply.
    """
    logger.info("Processing configuration reload request")

    updated_sections = []
    warnings = []
    requires_worker_restart = False

    try:
        # Update vLLM configuration
        if "vllm" in new_config:
            old_vllm = self.vllm_config.copy()

            # Check each field for actual changes
            vllm_changed = False
            for key, value in new_config["vllm"].items():
                if self.vllm_config.get(key) != value:
                    self.vllm_config[key] = value
                    vllm_changed = True

            if vllm_changed:
                updated_sections.append("vllm")

                # Changes to these keys require workers to reload the model.
                critical_keys = ("model", "gpu_memory_utilization", "tensor_parallel_size")
                if any(old_vllm.get(k) != self.vllm_config.get(k) for k in critical_keys):
                    requires_worker_restart = True
                    warnings.append(
                        "Critical vLLM changes detected - workers will be disconnected to reload"
                    )

        # Update dataset configuration
        if "dataset" in new_config:
            dataset_changed = False
            for key, value in new_config["dataset"].items():
                if self.dataset_config.get(key) != value:
                    self.dataset_config[key] = value
                    dataset_changed = True

            if dataset_changed:
                self.dataset_path = self.dataset_config.get("path")
                self.dataset_type = self.dataset_config.get("type", "huggingface")
                updated_sections.append("dataset")
                warnings.append("Dataset changes will apply to new chunks only")

        # Update chunk settings
        if "chunk_size" in new_config and self.chunk_size != new_config["chunk_size"]:
            self.chunk_size = new_config["chunk_size"]
            self.chunk_manager.chunk_size = self.chunk_size
            updated_sections.append("chunk_size")

        # Plain attributes that mirror a top-level config key of the same name.
        for setting in ("chunks_per_request", "chunk_buffer_multiplier", "min_chunk_buffer"):
            if setting in new_config and getattr(self, setting) != new_config[setting]:
                setattr(self, setting, new_config[setting])
                updated_sections.append(setting)

        # Recreate the auth manager only when a changed "auth" section was
        # sent, and build it from that section exactly like __init__ does.
        # Rebuilding it unconditionally from the whole reload payload would
        # replace the running credentials on every reload.
        # NOTE(review): assumes AuthManager takes the auth section as its
        # first positional argument, as in __init__ - confirm.
        if "auth" in new_config and new_config["auth"] != self.config.get("auth", {}):
            self.auth = AuthManager(new_config["auth"])
            updated_sections.append("auth")

        # Update storage settings
        if "storage" in new_config:
            storage_config = new_config["storage"]
            storage_changed = False

            if (
                "caption_buffer_size" in storage_config
                and self.storage.caption_buffer_size != storage_config["caption_buffer_size"]
            ):
                self.storage.caption_buffer_size = storage_config["caption_buffer_size"]
                storage_changed = True

            if "checkpoint_interval" in storage_config:
                current_interval = self.config.get("storage", {}).get(
                    "checkpoint_interval", 1000
                )
                if current_interval != storage_config["checkpoint_interval"]:
                    self.config.setdefault("storage", {})["checkpoint_interval"] = (
                        storage_config["checkpoint_interval"]
                    )
                    storage_changed = True

            if storage_changed:
                updated_sections.append("storage")

        # Update data worker storage config
        if "data_worker_storage" in new_config:
            current_dw_storage = self.config.get("data_worker_storage", {})
            if current_dw_storage != new_config["data_worker_storage"]:
                self.config["data_worker_storage"] = new_config["data_worker_storage"]
                updated_sections.append("data_worker_storage")
                warnings.append("Data worker storage config will apply to new connections only")

        # Update backpressure threshold
        if "backpressure_threshold" in new_config:
            current_threshold = getattr(self, "backpressure_threshold", 800)
            if current_threshold != new_config["backpressure_threshold"]:
                self.backpressure_threshold = new_config["backpressure_threshold"]
                updated_sections.append("backpressure_threshold")

        # Check if any changes were made
        if not updated_sections:
            await websocket.send(
                safe_json_dumps(
                    {
                        "type": "reload_complete",
                        "message": "No changes applied - configuration is identical",
                    }
                )
            )
            logger.info("Configuration reload requested but no changes detected")
            return

        # Update the main config for any other fields
        self.config.update(new_config)

        # Snapshot the registry: connections can come and go while we await,
        # and a dict must not change size while it is being iterated.
        worker_sockets = list(self.workers.items())

        if requires_worker_restart:
            logger.info("Disconnecting all workers for configuration reload...")

            for worker_id, ws in worker_sockets:
                try:
                    await ws.close(code=1012, reason="Configuration reload")
                except Exception as e:
                    # Best effort: the socket may already be gone.
                    logger.warning(f"Failed to close worker {worker_id}: {e}")

            warnings.append(
                f"Disconnected {len(worker_sockets)} workers - they will reconnect with new config"
            )
        else:
            # Just notify workers about config changes
            reload_msg = safe_json_dumps(
                {
                    "type": "config_update",
                    "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
                    "dataset_config": (
                        self.dataset_config if "dataset" in updated_sections else None
                    ),
                }
            )

            for worker_id, ws in worker_sockets:
                try:
                    await ws.send(reload_msg)
                except Exception as e:
                    # Do not unregister the worker here: _handle_worker owns
                    # the registry entry and releases the worker's chunks in
                    # its own cleanup when the connection ends.
                    logger.warning(f"Could not send config update to worker {worker_id}: {e}")

        # Send success response
        await websocket.send(
            safe_json_dumps(
                {"type": "reload_complete", "updated": updated_sections, "warnings": warnings}
            )
        )

        logger.info(f"Configuration reloaded. Updated sections: {', '.join(updated_sections)}")

        # Broadcast stats update to monitors
        await self._broadcast_stats()
        await self._send_activity(
            f"Configuration reloaded by admin: {', '.join(updated_sections)}"
        )

    except Exception as e:
        logger.error(f"Configuration reload failed: {e}")
        await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))
792
+
793
async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
    """Handle worker connection lifecycle.

    Registers the worker, stores it as a contributor, sends a welcome message
    carrying the dataset and vLLM configuration, then processes its messages
    until the connection ends. Cleanup always runs: the worker is
    unregistered and its assigned chunks are returned to the pending queue.
    """
    worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
    self.workers[worker_id] = websocket
    self.stats["connected_workers"] = len(self.workers)

    # Everything after registration sits inside try/finally so a failure
    # during setup cannot leave a stale entry in self.workers.
    try:
        # Register contributor
        contributor = Contributor(
            contributor_id=worker_id, name=worker_id, total_captions=0, trust_level=1
        )
        await self.storage.save_contributor(contributor)

        logger.info(f"Worker {worker_id} connected")
        await self._broadcast_stats()
        await self._send_activity(f"Worker {worker_id} connected")

        try:
            # Send welcome message with dataset configuration
            welcome_message = {
                "type": "welcome",
                "worker_id": worker_id,
                "dataset_config": {
                    "dataset_path": self.dataset_path,
                    "dataset_type": self.dataset_type,
                    "path": self.dataset_path,  # For compatibility
                    "type": self.dataset_type,  # For compatibility
                },
                "vllm_config": self.vllm_config,
            }
            await websocket.send(safe_json_dumps(welcome_message))

            async for message in websocket:
                data = json.loads(message)
                await self._process_worker_message(worker_id, data)

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Worker {worker_id} disconnected")
    finally:
        # Remove the registry entry only if it is still this connection. It
        # may already be gone, or belong to a newer connection that reused
        # the same worker ID; an unconditional `del` would raise KeyError
        # here and skip the chunk release below.
        registered = self.workers.get(worker_id)
        if registered is websocket:
            del self.workers[worker_id]
            registered = None
        self.stats["connected_workers"] = len(self.workers)

        if registered is None:
            # Release chunks in both managers
            self.chunk_manager.release_worker_chunks(worker_id)
            if self.chunk_tracker:
                # Mark released chunks as pending in tracker
                released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
                logger.info(
                    f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
                )
        else:
            # Chunks are tracked per worker ID, and that ID still has a live
            # connection, so releasing now would take its chunks away.
            logger.info(
                f"Worker {worker_id} has a newer connection; keeping its chunk assignments"
            )

        await self._broadcast_stats()
        await self._send_activity(f"Worker {worker_id} disconnected")
844
+
845
async def _process_worker_message(self, worker_id: str, data: Dict):
    """Process message from worker.

    Dispatches on ``data["type"]``:

    - "request_chunks": assign pending chunks (or reply "no_chunks").
    - "chunk_complete" / "chunk_failed": update chunk manager and tracker.
    - "submit_captions": forward to the captions handler.
    - "request_job": hand out one queued data sample, or "no_jobs" after 5s.
    - "heartbeat": logged only.

    Other message types are ignored.
    """
    msg_type = data.get("type")

    if msg_type == "request_chunks":
        # Wait for state restoration to complete
        if not self.state_restored.is_set():
            logger.info(f"Worker {worker_id} requesting chunks, but state not yet restored")
            await self.workers[worker_id].send(
                safe_json_dumps({"type": "no_chunks", "reason": "state_restoring"})
            )
            return

        # "count" comes from the worker; anything but a positive integer
        # would raise inside the chunk manager and drop the connection.
        count = data.get("count", self.chunks_per_request)
        if isinstance(count, bool) or not isinstance(count, int) or count < 1:
            logger.warning(
                f"Worker {worker_id} sent invalid chunk count {count!r}; "
                f"using {self.chunks_per_request}"
            )
            count = self.chunks_per_request

        chunks = self.chunk_manager.get_chunks_for_worker(worker_id, count, self.chunk_tracker)

        if chunks:
            # Only send the fields that worker expects
            chunk_data = [
                {
                    "chunk_id": chunk.chunk_id,
                    "shard_url": chunk.shard_url,
                    "shard_name": chunk.shard_name,
                    "start_index": chunk.start_index,
                    "chunk_size": chunk.chunk_size,
                }
                for chunk in chunks
            ]

            await self.workers[worker_id].send(
                safe_json_dumps({"type": "shard_assignment", "chunks": chunk_data})
            )
            chunk_ids = [c["chunk_id"] for c in chunk_data]
            logger.info(f"Assigned {len(chunks)} chunks to worker {worker_id}: {chunk_ids}")
            await self._send_activity(f"Assigned {len(chunks)} chunks to {worker_id}")
        else:
            await self.workers[worker_id].send(safe_json_dumps({"type": "no_chunks"}))

    elif msg_type == "chunk_complete":
        # A missing chunk_id is simply not known to the chunk manager, so the
        # message is ignored instead of raising KeyError.
        chunk_id = data.get("chunk_id")
        if self.chunk_manager.complete_chunk(chunk_id, worker_id):
            self.stats["completed_chunks"] += 1

            if self.chunk_tracker:
                self.chunk_tracker.mark_completed(chunk_id)

            logger.info(f"Chunk {chunk_id} completed by worker {worker_id}")
            await self._check_shard_completion(chunk_id)
            await self._send_activity(f"Chunk {chunk_id} completed by {worker_id}")

    elif msg_type == "chunk_failed":
        chunk_id = data.get("chunk_id")
        error = data.get("error", "Unknown error")
        if self.chunk_manager.fail_chunk(chunk_id, worker_id):
            self.stats["failed_chunks"] += 1

            if self.chunk_tracker:
                self.chunk_tracker.mark_failed(chunk_id)

            logger.warning(f"Chunk {chunk_id} failed on worker {worker_id}: {error}")
            await self._send_activity(f"Chunk {chunk_id} failed on {worker_id}: {error}")

    elif msg_type == "submit_captions":
        await self._handle_captions_submission(worker_id, data)

    elif msg_type == "request_job":
        # VLLMWorker requesting a job from data samples
        try:
            job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
            await self.workers[worker_id].send(
                json.dumps({"type": "job_assignment", "job": job})
            )
            logger.debug(f"Assigned job {job['job_id']} to worker {worker_id}")
        except asyncio.TimeoutError:
            await self.workers[worker_id].send(json.dumps({"type": "no_jobs"}))

    elif msg_type == "heartbeat":
        # Update worker stats
        logger.debug(f"Heartbeat from {worker_id}: {data}")
922
+
923
+ async def _handle_captions_submission(self, worker_id: str, data: Dict):
924
+ """Process multiple captions submission from worker."""
925
+ chunk_id = data.get("chunk_id")
926
+ item_key = data["item_key"]
927
+ captions_list = data["captions"]
928
+
929
+ logger.debug(
930
+ f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
931
+ )
932
+
933
+ # Create a SINGLE caption record with ALL captions as a list
934
+ caption = Caption(
935
+ job_id=f"{chunk_id}_{item_key}", # Single ID for the item
936
+ dataset=data.get("dataset"),
937
+ shard=data.get("shard"),
938
+ item_key=item_key,
939
+ captions=captions_list, # Store ALL captions as a list
940
+ contributor_id=worker_id,
941
+ timestamp=datetime.utcnow(),
942
+ quality_scores=None, # Could be a list of scores matching captions
943
+ # Image metadata
944
+ image_width=data.get("image_width"),
945
+ image_height=data.get("image_height"),
946
+ image_format=data.get("image_format"),
947
+ file_size=data.get("file_size"),
948
+ # Processing metadata
949
+ caption_count=len(captions_list),
950
+ processing_time_ms=data.get("processing_time_ms"),
951
+ chunk_id=chunk_id,
952
+ )
953
+
954
+ # Add to central storage buffer as a single entry
955
+ await self.storage.save_caption(caption)
956
+
957
+ # Update statistics
958
+ self.stats["total_captions"] += len(captions_list)
959
+ self.stats["buffer_size"] = len(self.storage.caption_buffer)
960
+
961
+ # Update contributor stats
962
+ contributor = await self.storage.get_contributor(worker_id)
963
+ if contributor:
964
+ contributor.total_captions += len(captions_list)
965
+ await self.storage.save_contributor(contributor)
966
+
967
+ # Broadcast updated stats
968
+ await self._broadcast_stats()
969
+
970
+ # Log progress periodically
971
+ if self.stats["total_captions"] % 100 == 0:
972
+ logger.info(f"Collected {self.stats['total_captions']} captions centrally")
973
+
974
+ async def _check_shard_completion(self, chunk_id: str):
975
+ """Check if a shard is complete after chunk completion."""
976
+ # Extract shard name from chunk_id
977
+ shard_name = chunk_id.rsplit("_chunk_", 1)[0]
978
+
979
+ # Check if all chunks for this shard are complete
980
+ chunk_stats = self.chunk_manager.get_stats()
981
+ shard_chunks = [
982
+ cid
983
+ for cid, chunk in self.chunk_manager.chunks.items()
984
+ if chunk.shard_name == shard_name
985
+ ]
986
+
987
+ completed_chunks = [
988
+ cid for cid in shard_chunks if self.chunk_manager.chunks[cid].status == "completed"
989
+ ]
990
+
991
+ if len(completed_chunks) == len(shard_chunks):
992
+ logger.info(f"Shard {shard_name} complete!")
993
+ self.shard_tracker.mark_complete(shard_name)
994
+ self.stats["completed_shards"] += 1
995
+ await self._send_activity(f"Shard {shard_name} completed!")
996
+
997
+ async def _handle_data_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
998
+ """Handle data worker connection."""
999
+ worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
1000
+ self.data_workers[worker_id] = websocket
1001
+
1002
+ logger.info(f"Data worker {worker_id} connected")
1003
+
1004
+ try:
1005
+ # Send welcome with storage config
1006
+ storage_config = self.config.get(
1007
+ "data_worker_storage",
1008
+ {
1009
+ "forward_to_orchestrator": True,
1010
+ "local": {"enabled": False},
1011
+ "s3": {"enabled": False},
1012
+ },
1013
+ )
1014
+
1015
+ await websocket.send(
1016
+ json.dumps(
1017
+ {"type": "welcome", "worker_id": worker_id, "storage_config": storage_config}
1018
+ )
1019
+ )
1020
+
1021
+ # Track if we've sent backpressure
1022
+ backpressure_sent = False
1023
+
1024
+ async for message in websocket:
1025
+ data = json.loads(message)
1026
+ msg_type = data.get("type")
1027
+
1028
+ if msg_type == "submit_samples":
1029
+ # Check queue size for backpressure
1030
+ if self.data_sample_queue.qsize() > self.backpressure_threshold:
1031
+ if not backpressure_sent:
1032
+ await websocket.send(json.dumps({"type": "backpressure"}))
1033
+ backpressure_sent = True
1034
+ logger.warning(f"Backpressure applied to data worker {worker_id}")
1035
+ else:
1036
+ if backpressure_sent:
1037
+ await websocket.send(json.dumps({"type": "resume"}))
1038
+ backpressure_sent = False
1039
+
1040
+ # Receive image data for each sample
1041
+ samples = data["samples"]
1042
+ for sample in samples:
1043
+ # Receive binary image data
1044
+ image_data = await websocket.recv()
1045
+
1046
+ # Create job and add to queue
1047
+ job = {
1048
+ "job_id": f"data_{worker_id}_{sample['sample_id']}",
1049
+ "sample_id": sample["sample_id"],
1050
+ "image_data": image_data,
1051
+ "metadata": sample.get("metadata", {}),
1052
+ "source": "data_worker",
1053
+ "worker_id": worker_id,
1054
+ }
1055
+
1056
+ await self.data_sample_queue.put(job)
1057
+
1058
+ elif msg_type == "heartbeat":
1059
+ logger.debug(f"Data worker {worker_id} heartbeat: {data}")
1060
+
1061
+ except websockets.exceptions.ConnectionClosed:
1062
+ logger.info(f"Data worker {worker_id} disconnected")
1063
+ finally:
1064
+ del self.data_workers[worker_id]
1065
+
1066
+ async def _handle_monitor(self, websocket: WebSocketServerProtocol):
1067
+ """Handle monitor connection."""
1068
+ self.monitors.add(websocket)
1069
+ logger.info("Monitor connected")
1070
+
1071
+ try:
1072
+ # Send initial stats
1073
+ await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
1074
+
1075
+ # Send chunk stats
1076
+ chunk_stats = self.chunk_manager.get_stats()
1077
+ await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
1078
+
1079
+ # Send contributor leaderboard
1080
+ contributors = await self.storage.get_top_contributors(10)
1081
+ await websocket.send(
1082
+ safe_json_dumps(
1083
+ {"type": "leaderboard", "data": [safe_dict(c) for c in contributors]}
1084
+ )
1085
+ )
1086
+
1087
+ # Keep connection alive
1088
+ async for _ in websocket:
1089
+ pass
1090
+
1091
+ except websockets.exceptions.ConnectionClosed:
1092
+ logger.info("Monitor disconnected")
1093
+ finally:
1094
+ self.monitors.discard(websocket)
1095
+
1096
+ async def _broadcast_stats(self):
1097
+ """Broadcast statistics to all monitors."""
1098
+ if not self.monitors:
1099
+ return
1100
+
1101
+ # Include chunk stats
1102
+ chunk_stats = self.chunk_manager.get_stats()
1103
+ self.stats.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
1104
+
1105
+ # Add rate information
1106
+ self.stats.update(
1107
+ {
1108
+ "current_rate": self.rate_tracker["current_rate"],
1109
+ "average_rate": self.rate_tracker["average_rate"],
1110
+ "expected_rate": self.rate_tracker["expected_rate"],
1111
+ }
1112
+ )
1113
+
1114
+ # Add vLLM info
1115
+ self.stats["vllm_model"] = self.vllm_config.get("model", "unknown")
1116
+ self.stats["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
1117
+
1118
+ message = safe_json_dumps({"type": "stats", "data": self.stats})
1119
+
1120
+ # Send to all monitors
1121
+ disconnected = set()
1122
+ for monitor in self.monitors:
1123
+ try:
1124
+ await monitor.send(message)
1125
+ except websockets.exceptions.ConnectionClosed:
1126
+ disconnected.add(monitor)
1127
+
1128
+ # Clean up disconnected monitors
1129
+ self.monitors -= disconnected
1130
+
1131
+ async def _send_activity(self, activity: str):
1132
+ """Send activity update to monitors."""
1133
+ if not self.monitors:
1134
+ return
1135
+
1136
+ message = safe_json_dumps(
1137
+ {"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
1138
+ )
1139
+
1140
+ disconnected = set()
1141
+ for monitor in self.monitors:
1142
+ try:
1143
+ await monitor.send(message)
1144
+ except websockets.exceptions.ConnectionClosed:
1145
+ disconnected.add(monitor)
1146
+
1147
+ self.monitors -= disconnected
1148
+
1149
+ async def _heartbeat_loop(self):
1150
+ """Send periodic heartbeats to maintain connections."""
1151
+ while True:
1152
+ await asyncio.sleep(30)
1153
+
1154
+ # Ping workers
1155
+ disconnected = []
1156
+ for worker_id, ws in self.workers.items():
1157
+ try:
1158
+ await ws.ping()
1159
+ except:
1160
+ disconnected.append(worker_id)
1161
+
1162
+ # Clean up disconnected workers
1163
+ for worker_id in disconnected:
1164
+ if worker_id in self.workers:
1165
+ del self.workers[worker_id]
1166
+ self.chunk_manager.release_worker_chunks(worker_id)
1167
+
1168
+ async def _checkpoint_loop(self):
1169
+ """Periodically checkpoint storage."""
1170
+ interval = self.config.get("storage", {}).get("checkpoint_interval", 1000)
1171
+
1172
+ while True:
1173
+ await asyncio.sleep(60)
1174
+
1175
+ # Force checkpoint at regular intervals
1176
+ if self.stats["total_captions"] > 0 and self.stats["total_captions"] % interval == 0:
1177
+ logger.info(f"Triggering checkpoint at {self.stats['total_captions']} captions")
1178
+ await self.storage.checkpoint()
1179
+
1180
+ # Update stats
1181
+ self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
1182
+ self.stats["total_written"] = self.storage.total_captions_written
1183
+ self.stats["buffer_size"] = len(self.storage.caption_buffer)
1184
+
1185
+ await self._broadcast_stats()
1186
+ logger.info(
1187
+ f"Checkpoint complete. Total written to disk: {self.stats['total_written']}"
1188
+ )
1189
+
1190
+ async def _stats_update_loop(self):
1191
+ """Periodically update and broadcast stats."""
1192
+ # Track session start values
1193
+ session_start_captions = self.stats["total_captions"]
1194
+ session_start_time = time.time()
1195
+
1196
+ while True:
1197
+ await asyncio.sleep(10)
1198
+
1199
+ # Update chunk stats
1200
+ chunk_stats = self.chunk_manager.get_stats()
1201
+ self.stats["total_chunks"] = chunk_stats["total"]
1202
+ self.stats["completed_chunks"] = chunk_stats["completed"]
1203
+ self.stats["failed_chunks"] = chunk_stats["failed"]
1204
+
1205
+ # Add queue information
1206
+ with self.chunk_manager.lock:
1207
+ self.stats["pending_chunks"] = len(self.chunk_manager.pending_chunks)
1208
+ self.stats["assigned_chunks"] = sum(
1209
+ len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
1210
+ )
1211
+
1212
+ # Calculate if we need more chunks
1213
+ worker_count = self.stats.get("connected_workers", 0)
1214
+ target_buffer = max(self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier)
1215
+ active_chunks = self.stats["pending_chunks"] + self.stats["assigned_chunks"]
1216
+ self.stats["chunk_buffer_status"] = f"{active_chunks}/{target_buffer}"
1217
+
1218
+ # Update rate information
1219
+ current_time = time.time()
1220
+ elapsed_since_update = current_time - self.rate_tracker["last_update_time"]
1221
+
1222
+ if elapsed_since_update > 0:
1223
+ # Calculate current rate (captions per minute)
1224
+ caption_diff = (
1225
+ self.stats["total_captions"] - self.rate_tracker["last_caption_count"]
1226
+ )
1227
+ self.rate_tracker["current_rate"] = (caption_diff / elapsed_since_update) * 60
1228
+
1229
+ # Calculate average rate since THIS SESSION started
1230
+ session_elapsed = current_time - session_start_time
1231
+ if session_elapsed > 0:
1232
+ session_captions = self.stats["total_captions"] - session_start_captions
1233
+ self.rate_tracker["average_rate"] = (session_captions / session_elapsed) * 60
1234
+
1235
+ # Calculate expected rate based on workers
1236
+ # Assume each worker processes batch_size images every ~2 seconds with 3 captions each
1237
+ batch_size = self.vllm_config.get("batch_size", 8)
1238
+ num_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
1239
+ images_per_minute = 30 # Rough estimate: 30 images/min per worker
1240
+ self.rate_tracker["expected_rate"] = worker_count * images_per_minute * num_prompts
1241
+
1242
+ # Update trackers
1243
+ self.rate_tracker["last_update_time"] = current_time
1244
+ self.rate_tracker["last_caption_count"] = self.stats["total_captions"]
1245
+
1246
+ # Log rate information when workers are connected
1247
+ if worker_count > 0:
1248
+ logger.info(
1249
+ f"Rate: {self.rate_tracker['current_rate']:.1f} captions/min "
1250
+ f"(avg: {self.rate_tracker['average_rate']:.1f}, "
1251
+ f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
1252
+ f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
1253
+ )
1254
+
1255
+ await self._broadcast_stats()
1256
+
1257
+ async def _restore_state(self):
1258
+ """Restore state from storage on startup."""
1259
+ # Update statistics
1260
+ self.stats["total_captions"] = await self.storage.count_captions()
1261
+
1262
+ logger.info(f"Restored state: {self.stats['total_captions']} captions")
1263
+
1264
+ async def shutdown(self):
1265
+ """Graceful shutdown."""
1266
+ logger.info("Shutting down orchestrator...")
1267
+
1268
+ # Stop chunk creation
1269
+ self.stop_chunk_creation.set()
1270
+ if self.chunk_creation_thread:
1271
+ self.chunk_creation_thread.join(timeout=5)
1272
+
1273
+ # Release all assigned chunks before closing connections
1274
+ for worker_id in list(self.workers.keys()):
1275
+ self.chunk_manager.release_worker_chunks(worker_id)
1276
+ if self.chunk_tracker:
1277
+ # Update chunk tracker to mark assigned chunks as pending
1278
+ with self.chunk_manager.lock:
1279
+ for chunk_id in list(self.chunk_manager.assigned_chunks.get(worker_id, [])):
1280
+ self.chunk_tracker.mark_pending(chunk_id)
1281
+
1282
+ # Close all connections
1283
+ for ws in list(self.workers.values()):
1284
+ await ws.close()
1285
+ for ws in list(self.monitors):
1286
+ await ws.close()
1287
+
1288
+ # Save chunk state
1289
+ if self.chunk_tracker:
1290
+ self.chunk_tracker.save_checkpoint()
1291
+
1292
+ # Final checkpoint
1293
+ logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")
1294
+ await self.storage.checkpoint()
1295
+
1296
+ # Log final statistics
1297
+ logger.info(
1298
+ f"Shutdown complete. Total captions collected: {self.storage.total_captions_written}"
1299
+ )
1300
+
1301
+ await self.storage.close()