caption-flow 0.2.2-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,516 +1,106 @@
- """Enhanced orchestrator with shard chunk assignment for vLLM workers.
-
- This orchestrator:
- 1. Divides dataset shards into chunks for parallel processing
- 2. Assigns chunks to workers on request
- 3. Collects captions from workers centrally
- 4. Manages checkpoints and fault tolerance
- """
-
  import time
  import asyncio
  import json
  import logging
  import ssl
  import uuid
- from dataclasses import dataclass, asdict
  from datetime import datetime
  from pathlib import Path
- from typing import Dict, Set, Optional, Any, List, Deque, Tuple
- from collections import deque, defaultdict
+ from typing import Dict, Set, Optional, Any, List
+ from collections import defaultdict
  import threading
- from queue import Queue, Empty

- from .workers import data
  import websockets
  from websockets.server import WebSocketServerProtocol

  from .storage import StorageManager
- from .models import Caption, Contributor
+ from .models import Caption, Contributor, JobId
  from .utils.auth import AuthManager
- from .utils import DatasetLoader, ShardTracker, ChunkTracker
- from .utils.json_utils import safe_dict, safe_json_dumps, to_json_dict
+ from .utils.json_utils import safe_json_dumps
+ from .processors import (
+     ProcessorConfig,
+     WorkAssignment,
+     WorkResult,
+     WorkUnit,
+     WebDatasetOrchestratorProcessor,
+     HuggingFaceDatasetOrchestratorProcessor,
+     LocalFilesystemOrchestratorProcessor,
+ )

  logger = logging.getLogger(__name__)
-
-
- @dataclass
- class ShardChunk:
-     """Represents a chunk of a shard for processing."""
-
-     chunk_id: str
-     shard_url: str
-     shard_name: str
-     start_index: int
-     chunk_size: int
-     assigned_to: Optional[str] = None
-     status: str = "pending"  # pending, assigned, completed, failed
-     assigned_at: Optional[datetime] = None
-     completed_at: Optional[datetime] = None
-
-     @classmethod
-     def create(
-         cls, shard_url: str, shard_name: str, start_index: int, chunk_size: int
-     ) -> "ShardChunk":
-         """Factory method to create a chunk with consistent ID."""
-         # Always use consistent format: dataset_chunk_startindex
-         if shard_url.startswith("hf_dataset:"):
-             # Extract dataset path
-             parts = shard_url.split(":")
-             dataset_path = parts[1] if len(parts) > 1 else "unknown"
-             chunk_id = f"{dataset_path.replace('/', '_')}_chunk_{start_index}"
-         else:
-             # WebDataset format
-             chunk_id = f"{shard_name}_chunk_{start_index}"
-
-         return cls(
-             chunk_id=chunk_id,
-             shard_url=shard_url,
-             shard_name=shard_name,
-             start_index=start_index,
-             chunk_size=chunk_size,
-         )
-
-     def belongs_to_shard(self, shard_identifier: str) -> bool:
-         """Check if this chunk belongs to a given shard."""
-         return self.shard_name == shard_identifier
-
-     def to_dict(self) -> Dict[str, Any]:
-         """Convert to dict for JSON serialization (for workers)."""
-         return {
-             "chunk_id": self.chunk_id,
-             "shard_url": self.shard_url,
-             "shard_name": self.shard_name,
-             "start_index": self.start_index,
-             "chunk_size": self.chunk_size,
-         }
-
-
- class ChunkManager:
-     """Manages shard chunk creation and assignment."""
-
-     def __init__(self, chunk_size: int = 1000, tracker: Optional[ChunkTracker] = None):
-         self.chunk_size = chunk_size
-         self.chunks: Dict[str, ShardChunk] = {}
-         self.pending_chunks: Deque[str] = deque()
-         self.assigned_chunks: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> chunk_ids
-         self.lock = threading.Lock()
-         self.tracker = tracker  # Reference to chunk tracker
-
-         # NEW: Track assigned ranges to prevent double allocation
-         # Format: {chunk_id: {(start, end): worker_id}}
-         self.assigned_ranges: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)
-
-     def get_chunks_for_worker(
-         self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
-     ) -> List[Dict[str, Any]]:
-         """Get available chunks with unprocessed items for a worker."""
-         assigned = []
-
-         with self.lock:
-             # FIRST PRIORITY: Check if this worker already has assigned chunks
-             if worker_id in self.assigned_chunks:
-                 existing_chunk_ids = list(self.assigned_chunks[worker_id])
-                 for chunk_id in existing_chunk_ids:
-                     if len(assigned) >= count:
-                         break
-
-                     chunk = self.chunks.get(chunk_id)
-                     if not chunk:
-                         continue
-
-                     # Check if chunk still has unprocessed items
-                     if tracker:
-                         chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
-                         if chunk_info and chunk_info["unprocessed_ranges"]:
-                             # Filter out ranges that are assigned to other workers
-                             clean_ranges = []
-                             for start, end in chunk_info["unprocessed_ranges"]:
-                                 range_key = (start, end)
-                                 if range_key in self.assigned_ranges[chunk_id]:
-                                     assigned_worker = self.assigned_ranges[chunk_id][range_key]
-                                     if assigned_worker != worker_id:
-                                         # Skip this range - it's assigned to another worker
-                                         logger.warning(
-                                             f"Skipping range {start}-{end} in chunk {chunk_id} "
-                                             f"(assigned to {assigned_worker}, not {worker_id})"
-                                         )
-                                         continue
-                                 # else: this worker already owns this range, include it
-                                 clean_ranges.append((start, end))
-
-                             if clean_ranges:
-                                 assigned.append(
-                                     {
-                                         "chunk": chunk,
-                                         "unprocessed_ranges": clean_ranges,
-                                     }
-                                 )
-                     else:
-                         # No tracker, assume chunk needs processing
-                         assigned.append(
-                             {
-                                 "chunk": chunk,
-                                 "unprocessed_ranges": [(0, chunk.chunk_size - 1)],
-                             }
-                         )
-
-             # SECOND PRIORITY: Get new pending chunks
-             while len(assigned) < count and self.pending_chunks:
-                 chunk_id = self.pending_chunks.popleft()
-                 chunk = self.chunks.get(chunk_id)
-
-                 if not chunk:
-                     continue
-
-                 # Verify chunk is truly pending
-                 if chunk.status != "pending" or chunk.assigned_to is not None:
-                     logger.warning(
-                         f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
-                     )
-                     continue
-
-                 # Assign to this worker
-                 chunk.assigned_to = worker_id
-                 chunk.status = "assigned"
-                 chunk.assigned_at = datetime.utcnow()
-                 self.assigned_chunks[worker_id].add(chunk_id)
-
-                 # Get unprocessed ranges and filter out any that are somehow already assigned
-                 unprocessed_ranges = [(0, chunk.chunk_size - 1)]  # Default
-                 if tracker:
-                     chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
-                     if chunk_info:
-                         # Filter out any ranges that are already assigned (shouldn't happen for new chunks)
-                         clean_ranges = []
-                         for start, end in chunk_info["unprocessed_ranges"]:
-                             range_key = (start, end)
-                             if range_key not in self.assigned_ranges[chunk_id]:
-                                 clean_ranges.append((start, end))
-                             else:
-                                 logger.error(
-                                     f"Range {start}-{end} in newly assigned chunk {chunk_id} "
-                                     f"is already assigned to {self.assigned_ranges[chunk_id][range_key]}!"
-                                 )
-                         unprocessed_ranges = clean_ranges if clean_ranges else []
-
-                     tracker.mark_assigned(chunk_id, worker_id)
-
-                 if unprocessed_ranges:
-                     assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
-
-             # Track assigned ranges and verify no double allocation
-             for info in assigned:
-                 chunk_id = info["chunk"].chunk_id
-                 for start, end in info["unprocessed_ranges"]:
-                     range_key = (start, end)
-
-                     # Check if this range is already assigned
-                     if range_key in self.assigned_ranges[chunk_id]:
-                         existing_worker = self.assigned_ranges[chunk_id][range_key]
-                         if existing_worker != worker_id:
-                             # This should never happen - raise assertion
-                             raise AssertionError(
-                                 f"CRITICAL: Attempting to assign range {start}-{end} in chunk {chunk_id} "
-                                 f"to worker {worker_id}, but it's already assigned to {existing_worker}! "
-                                 f"This would cause duplicate processing."
-                             )
-
-                     # Track this assignment
-                     self.assigned_ranges[chunk_id][range_key] = worker_id
-
-             # Log what we're assigning
-             if assigned:
-                 chunk_summary = ", ".join(
-                     [
-                         f"{info['chunk'].chunk_id}[{len(info['unprocessed_ranges'])} ranges]"
-                         for info in assigned
-                     ]
-                 )
-                 logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
-
-                 # Detailed range logging for debugging
-                 for info in assigned:
-                     chunk_id = info["chunk"].chunk_id
-                     ranges_str = ", ".join([f"{s}-{e}" for s, e in info["unprocessed_ranges"]])
-                     logger.debug(f"  Chunk {chunk_id} ranges: {ranges_str}")
-
-         return assigned
-
-     def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
-         """Mark a chunk as completed."""
-         with self.lock:
-             if chunk_id in self.chunks:
-                 chunk = self.chunks[chunk_id]
-                 if chunk.assigned_to == worker_id and chunk.status == "assigned":
-                     chunk.status = "completed"
-                     chunk.completed_at = datetime.utcnow()
-                     self.assigned_chunks[worker_id].discard(chunk_id)
-
-                     # Clear assigned ranges for this chunk
-                     if chunk_id in self.assigned_ranges:
-                         # Log what ranges we're clearing
-                         ranges_to_clear = list(self.assigned_ranges[chunk_id].keys())
-                         logger.debug(
-                             f"Clearing {len(ranges_to_clear)} assigned ranges for completed chunk {chunk_id}"
-                         )
-                         del self.assigned_ranges[chunk_id]
-
-                     return True
-         return False
-
-     def fail_chunk(self, chunk_id: str, worker_id: str) -> bool:
-         """Mark a chunk as failed and requeue it."""
-         with self.lock:
-             if chunk_id in self.chunks:
-                 chunk = self.chunks[chunk_id]
-                 if chunk.assigned_to == worker_id:
-                     chunk.status = "pending"
-                     chunk.assigned_to = None
-                     chunk.assigned_at = None
-                     self.assigned_chunks[worker_id].discard(chunk_id)
-                     self.pending_chunks.append(chunk_id)
-
-                     # Clear assigned ranges for this chunk/worker
-                     if chunk_id in self.assigned_ranges:
-                         ranges_to_clear = [
-                             range_key
-                             for range_key, assigned_worker in self.assigned_ranges[chunk_id].items()
-                             if assigned_worker == worker_id
-                         ]
-                         for range_key in ranges_to_clear:
-                             del self.assigned_ranges[chunk_id][range_key]
-                         logger.debug(
-                             f"Cleared {len(ranges_to_clear)} assigned ranges for failed chunk {chunk_id}"
-                         )
-
-                     return True
-         return False
-
-     def release_worker_chunks(self, worker_id: str):
-         """Release all chunks assigned to a worker."""
-         with self.lock:
-             chunk_ids = list(self.assigned_chunks.get(worker_id, []))
-             for chunk_id in chunk_ids:
-                 if chunk_id in self.chunks:
-                     chunk = self.chunks[chunk_id]
-                     if chunk.status == "assigned":
-                         chunk.status = "pending"
-                         chunk.assigned_to = None
-                         chunk.assigned_at = None
-                         self.pending_chunks.append(chunk_id)
-
-                 # Clear assigned ranges for this worker
-                 if chunk_id in self.assigned_ranges:
-                     ranges_to_clear = [
-                         range_key
-                         for range_key, assigned_worker in self.assigned_ranges[
-                             chunk_id
-                         ].items()
-                         if assigned_worker == worker_id
-                     ]
-                     for range_key in ranges_to_clear:
-                         del self.assigned_ranges[chunk_id][range_key]
-
-                     if ranges_to_clear:
-                         logger.info(
-                             f"Released {len(ranges_to_clear)} ranges from chunk {chunk_id} "
-                             f"previously assigned to disconnected worker {worker_id}"
-                         )
-
-             if worker_id in self.assigned_chunks:
-                 del self.assigned_chunks[worker_id]
-
-     def mark_ranges_processed(
-         self, chunk_id: str, processed_ranges: List[Tuple[int, int]], worker_id: str
-     ):
-         """Remove ranges from assignment tracking once they're processed."""
-         with self.lock:
-             if chunk_id in self.assigned_ranges:
-                 for start, end in processed_ranges:
-                     range_key = (start, end)
-                     if range_key in self.assigned_ranges[chunk_id]:
-                         assigned_worker = self.assigned_ranges[chunk_id][range_key]
-                         if assigned_worker == worker_id:
-                             del self.assigned_ranges[chunk_id][range_key]
-                             logger.debug(
-                                 f"Cleared assignment of range {start}-{end} in chunk {chunk_id} "
-                                 f"after processing by {worker_id}"
-                             )
-                         else:
-                             logger.warning(
-                                 f"Worker {worker_id} claims to have processed range {start}-{end} "
-                                 f"in chunk {chunk_id}, but it was assigned to {assigned_worker}"
-                             )
-
-     def get_stats(self) -> Dict[str, int]:
-         """Get chunk statistics."""
-         with self.lock:
-             # Count total assigned ranges
-             total_assigned_ranges = sum(len(ranges) for ranges in self.assigned_ranges.values())
-
-             stats = {
-                 "total": len(self.chunks),
-                 "pending": len(self.pending_chunks),
-                 "assigned": sum(len(chunks) for chunks in self.assigned_chunks.values()),
-                 "completed": sum(1 for c in self.chunks.values() if c.status == "completed"),
-                 "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
-                 "assigned_ranges": total_assigned_ranges,
-             }
-         return stats
+ logger.setLevel(logging.INFO)


  class Orchestrator:
-     """Enhanced orchestrator for vLLM-based distributed captioning with chunk assignment."""
+     """Generic orchestrator for distributed work processing."""

      def __init__(self, config: Dict[str, Any]):
          self.config = config
          self.host = config.get("host", "0.0.0.0")
          self.port = config.get("port", 8765)

-         # Dataset configuration
-         self.dataset_config = config.get("dataset", {})
-         self.dataset_path = self.dataset_config.get("path")
-         self.dataset_type = self.dataset_config.get("type", "huggingface")
-         self.dataset_split = self.dataset_config.get("split", "train")  # Add split configuration
-         self.dataset_image_column = self.dataset_config.get(
-             "image_column", "image"
-         )  # Add image column config
-
-         # Dataset components
-         self.dataset_loader = None
-         self.shard_tracker = None
-         self.chunk_tracker = None
-
-         if self.dataset_path:
-             self.dataset_loader = DatasetLoader(
-                 self.dataset_path,
-                 self.dataset_type,
-                 self.dataset_split,
-                 self.dataset_image_column,
-             )
-             checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
-             checkpoint_dir.mkdir(parents=True, exist_ok=True)
-             self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
-             self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
-
-         # vLLM configuration to distribute to workers
-         self.vllm_config = config.get(
-             "vllm",
-             {
-                 "model": "Qwen/Qwen2.5-VL-3B-Instruct",
-                 "gpu_memory_utilization": 0.92,
-                 "max_model_len": 16384,
-                 "tensor_parallel_size": 1,
-                 "dtype": "float16",
-                 "enforce_eager": True,
-                 "limit_mm_per_prompt": {"image": 1},
-                 "disable_mm_preprocessor_cache": True,
-                 "sampling": {
-                     "temperature": 0.7,
-                     "top_p": 0.95,
-                     "max_tokens": 256,
-                     "repetition_penalty": 1.05,
-                     "stop": ["<|end|>", "<|endoftext|>", "<|im_end|>"],
-                 },
-                 "inference_prompts": [
-                     "describe this image in detail",
-                     "provide a comprehensive description of the visual content",
-                     "what are the key elements in this image?",
-                 ],
-             },
-         )
-
-         # Chunk configuration
-         self.chunk_size = config.get("chunk_size", 1000)
-         self.chunks_per_request = config.get("chunks_per_request", 2)
-
-         # Demand-driven chunk creation settings
-         self.chunk_buffer_multiplier = config.get("chunk_buffer_multiplier", 3)
-         self.min_chunk_buffer = config.get("min_chunk_buffer", 10)
+         # Processor configuration
+         processor_type = config.get("dataset", {}).get("processor_type", None)
+         assert (
+             processor_type is not None
+         ), "You must supply processor_type in your orchestrator dataset configuration."
+         processor_config = ProcessorConfig(processor_type=processor_type, config=config)
+
+         # Initialize processor
+         if processor_type == "webdataset":
+             self.processor = WebDatasetOrchestratorProcessor()
+         elif processor_type == "huggingface_datasets":
+             self.processor = HuggingFaceDatasetOrchestratorProcessor()
+         elif processor_type == "local_filesystem":
+             self.processor = LocalFilesystemOrchestratorProcessor()
+         else:
+             raise ValueError(f"Unknown processor type: {processor_type}")

          # Initialize components
          storage_config = config.get("storage", {})
          self.storage = StorageManager(
              Path(storage_config.get("data_dir", "./caption_data")),
              caption_buffer_size=storage_config.get("caption_buffer_size", 1000),
-             job_buffer_size=storage_config.get("job_buffer_size", 100),
-             contributor_buffer_size=storage_config.get("contributor_buffer_size", 10),
          )
          self.auth = AuthManager(config.get("auth", {}))
+         self.processor.initialize(processor_config, self.storage)

-         # Dataset components
-         self.dataset_loader = None
-         self.shard_tracker = None
-         self.chunk_tracker = None
-
-         if self.dataset_path:
-             self.dataset_loader = DatasetLoader(self.dataset_path, self.dataset_type)
-             checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
-             checkpoint_dir.mkdir(parents=True, exist_ok=True)
-             self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
-             self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
-
-         # Initialize chunk manager with reference to chunk tracker
-         self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)
-         self.pending_processed_items = defaultdict(list)  # chunk_id -> list of indices
-         self.item_batch_lock = threading.Lock()
-         self.last_item_batch_flush = time.time()
-         self.item_batch_interval = 5  # Flush every 5 seconds
-         self.item_batch_size = 100  # Or every 100 items
+         # Processing configuration
+         self.units_per_request = config.get("units_per_request", 2)

          # Track connections
          self.workers: Dict[str, WebSocketServerProtocol] = {}
          self.monitors: Set[WebSocketServerProtocol] = set()
+         self.workers_by_user = defaultdict(set)

          # SSL configuration
          self.ssl_context = self._setup_ssl()

          # Statistics
-         self.is_generating_stats = False
          self.stats = {
-             "total_chunks": 0,
-             "completed_chunks": 0,
-             "failed_chunks": 0,
              "connected_workers": 0,
-             "total_shards": 0,
-             "completed_shards": 0,
-             "current_shard": None,
+             "total_outputs": 0,
              "last_checkpoint": None,
+             "processor_stats": {},
          }

+         # Cache for leaderboard
+         self._cached_leaderboard = None
+
+         # Data worker stuff
+         self.data_workers = {}
+         self.data_sample_queue = asyncio.Queue()
+         self.backpressure_threshold = config.get("backpressure_threshold", 1000)
+
          # Rate tracking
          self.rate_tracker = {
              "start_time": time.time(),
              "last_update_time": time.time(),
-             "last_caption_count": 0,
+             "last_output_count": 0,
              "current_rate": 0.0,
              "average_rate": 0.0,
-             "expected_rate": 0.0,
          }

-         # Data sample queue for CaptionWorker
-         self.data_sample_queue = asyncio.Queue(maxsize=1000)
-         self.data_workers: Dict[str, WebSocketServerProtocol] = {}
-
-         # Backpressure threshold
-         self.backpressure_threshold = config.get("backpressure_threshold", 800)
-
-         # Shard processing state
-         self.all_shards = []
-         self.current_shard_index = 0
-         self.shard_lock = threading.Lock()
-
-         # Background chunk creation
-         self.chunk_creation_thread = None
-         self.stop_chunk_creation = threading.Event()
-
-         # State restoration flag
-         self.state_restored = threading.Event()
-         # If no dataset, state is already "restored"
-         if not self.dataset_loader:
-             self.state_restored.set()
-
      def _setup_ssl(self) -> Optional[ssl.SSLContext]:
          """Configure SSL if certificates are provided."""
          ssl_config = self.config.get("ssl", {})
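A note on the new constructor above: the added `__init__` selects the orchestrator-side processor with an if/elif chain over `processor_type`. The same dispatch can be expressed as a lookup table, which keeps the set of supported types in one place. A minimal sketch, not package code, and assuming the installed module path is `caption_flow.processors` (implied by the relative imports above):

    from caption_flow.processors import (
        WebDatasetOrchestratorProcessor,
        HuggingFaceDatasetOrchestratorProcessor,
        LocalFilesystemOrchestratorProcessor,
    )

    # Maps the "processor_type" config key to an orchestrator processor class.
    PROCESSORS = {
        "webdataset": WebDatasetOrchestratorProcessor,
        "huggingface_datasets": HuggingFaceDatasetOrchestratorProcessor,
        "local_filesystem": LocalFilesystemOrchestratorProcessor,
    }

    def make_processor(processor_type: str):
        try:
            return PROCESSORS[processor_type]()
        except KeyError:
            raise ValueError(f"Unknown processor type: {processor_type}") from None

Either form fails fast on an unrecognized type, which matches the `raise ValueError` branch in the diff.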
@@ -521,326 +111,20 @@ class Orchestrator:
              context.load_cert_chain(ssl_config["cert"], ssl_config["key"])
          return context

-     def _create_chunks_from_dataset(self):
-         """Background thread to create chunks from dataset shards on demand."""
-         if not self.dataset_loader:
-             logger.warning("No dataset configured, skipping chunk creation")
-             self.state_restored.set()  # No state to restore
-             return
-
-         logger.info("Starting chunk creation thread")
-
-         # Mark state as not restored until we process checkpoints
-         self.state_restored.clear()
-
-         # Get dataset info to check format
-         dataset_info = self.dataset_loader.get_dataset_info()
-         dataset_format = dataset_info.get("dataset_format", "unknown")
-         logger.info(f"Dataset format: {dataset_format}")
-
-         # Get all shards
-         self.all_shards = self.dataset_loader.get_shard_list()
-         self.stats["total_shards"] = len(self.all_shards)
-
-         # For HuggingFace datasets, we might need to dynamically create more shards
-         if dataset_format == "huggingface_datasets":
-             self._is_hf_dataset = True
-             self._hf_chunk_size = 10000  # Items per virtual shard
-             self._next_hf_shard_index = len(self.all_shards)  # For creating new virtual shards
-         else:
-             self._is_hf_dataset = False
-
-         # Get shard status from ChunkTracker
-         shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
-         completed_shards = {
-             shard_name for shard_name, info in shards_summary.items() if info["is_complete"]
-         }
-
-         # Update ShardTracker for completed shards
-         for shard_name in completed_shards:
-             if not self.shard_tracker.is_complete(shard_name):
-                 logger.info(f"Marking shard {shard_name} as complete in ShardTracker")
-                 self.shard_tracker.mark_complete(shard_name)
-
-         # Get shards that need processing
-         remaining_shards = self.shard_tracker.get_remaining_shards(self.all_shards)
-
-         # Also check which shards already have chunks (partial or complete)
-         shards_with_chunks = set()
-         for shard_name in shards_summary.keys():
-             shards_with_chunks.add(shard_name)
-
-         # Filter out shards that already have chunks created
-         remaining_shards = [
-             shard
-             for shard in remaining_shards
-             if (shard if shard.startswith("hf_dataset:") else Path(shard).stem)
-             not in shards_with_chunks
-         ]
-
-         self.stats["completed_shards"] = len(completed_shards)
-
-         logger.info(
-             f"Total shards: {len(self.all_shards)}, "
-             f"Completed: {self.stats['completed_shards']}, "
-             f"Shards with chunks: {len(shards_with_chunks)}, "
-             f"Remaining to process: {len(remaining_shards)}"
-         )
-
-         # First, re-queue any existing pending chunks
-         initial_pending = 0
-         requeued_chunks_by_shard = defaultdict(list)
-
-         for shard_name, shard_info in shards_summary.items():
-             with self.chunk_manager.lock:
-                 for chunk_state in shard_info["chunks"]:
-                     if chunk_state.status in ["pending", "failed", "assigned"]:
-                         # For assigned chunks, reset them to pending since workers don't exist
-                         chunk = ShardChunk(
-                             chunk_id=chunk_state.chunk_id,
-                             shard_url=chunk_state.shard_url,
-                             shard_name=chunk_state.shard_name,
-                             start_index=chunk_state.start_index,
-                             chunk_size=chunk_state.chunk_size,
-                             status="pending",  # Reset to pending
-                             assigned_to=None,  # Clear assignment
-                         )
-                         self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
-                         self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
-                         requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
-                         initial_pending += 1
-
-         logger.info(f"Re-queued {initial_pending} existing pending chunks")
-         for shard_name, chunk_ids in requeued_chunks_by_shard.items():
-             logger.info(f"  Shard {shard_name}: {len(chunk_ids)} chunks - {chunk_ids}")
-
-         # Mark state as restored
-         self.state_restored.set()
-         logger.info("State restoration complete, accepting chunk requests")
-
-         # Process shards on-demand
-         shard_iter = iter(remaining_shards)
-         current_shard_url = None
-         current_shard_name = None
-         current_shard_items = None
-         current_shard_index = 0
-
-         while not self.stop_chunk_creation.is_set():
-             # Check how many chunks we need
-             with self.chunk_manager.lock:
-                 pending_count = len(self.chunk_manager.pending_chunks)
-                 assigned_count = sum(
-                     len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
-                 )
-                 total_active = pending_count + assigned_count
-
-             # Target buffer: configurable multiplier × number of workers
-             worker_count = max(1, self.stats.get("connected_workers", 0))
-             target_buffer = max(
-                 self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier
-             )
-
-             chunks_needed = max(0, target_buffer - total_active)
-
-             if chunks_needed == 0:
-                 # We have enough chunks, wait a bit
-                 time.sleep(5)
-                 continue
-
-             logger.debug(
-                 f"Need {chunks_needed} more chunks (pending: {pending_count}, "
-                 f"assigned: {assigned_count}, workers: {worker_count})"
-             )
-
-             # Create chunks as needed
-             chunks_created = 0
-
-             while chunks_created < chunks_needed and not self.stop_chunk_creation.is_set():
-                 # Need to load next shard?
-                 if current_shard_url is None or current_shard_index >= current_shard_items:
-                     try:
-                         current_shard_url = next(shard_iter)
-
-                         # Extract shard name based on type
-                         if current_shard_url.startswith("hf_dataset:"):
-                             current_shard_name = current_shard_url  # Use full ID for virtual shards
-                         else:
-                             current_shard_name = Path(current_shard_url).stem
-
-                         self.stats["current_shard"] = current_shard_name
-
-                         # Skip if we already have chunks from this shard
-                         if current_shard_name in shards_summary:
-                             logger.debug(
-                                 f"Skipping shard {current_shard_name} - already has chunks"
-                             )
-                             current_shard_url = None
-                             continue
-
-                         # Count items in new shard
-                         logger.info(f"Loading new shard {current_shard_name}")
-
-                         # For virtual HF dataset shards, use the chunk size directly
-                         if current_shard_url.startswith("hf_dataset:"):
-                             current_shard_items = self.dataset_loader.count_shard_items(
-                                 current_shard_url
-                             )
-                             logger.info(
-                                 f"Virtual shard {current_shard_name} has {current_shard_items} items"
-                             )
-                         else:
-                             # For WebDataset, actually count items
-                             current_shard_items = sum(
-                                 1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
-                             )
-                             logger.info(
-                                 f"Shard {current_shard_name} has {current_shard_items} items"
-                             )
-
-                         current_shard_index = 0
-
-                     except StopIteration:
-                         # No more shards in the iterator
-                         if self._is_hf_dataset:
-                             # Before creating new virtual shards, check if we have pending chunks
-                             with self.chunk_manager.lock:
-                                 pending_count = len(self.chunk_manager.pending_chunks)
-
-                             if pending_count > 0:
-                                 # Don't create new shards if we have pending chunks
-                                 logger.debug(
-                                     f"Have {pending_count} pending chunks, not creating new virtual shards yet"
-                                 )
-                                 current_shard_url = None
-                                 time.sleep(2)
-                                 continue
-
-                             # For HF datasets, we can create more virtual shards on demand
-                             logger.info(
-                                 "Creating additional virtual shards for HuggingFace dataset"
-                             )
-
-                             # Create 10 more virtual shards
-                             new_shards = []
-                             for i in range(10):
-                                 shard_id = f"hf_dataset:{self.dataset_path}:chunk:{self._next_hf_shard_index * self._hf_chunk_size}"
-                                 new_shards.append(shard_id)
-                                 self._next_hf_shard_index += 1
-
-                             # Add to all_shards and create new iterator
-                             self.all_shards.extend(new_shards)
-                             self.stats["total_shards"] = len(self.all_shards)
-
-                             # Filter for unprocessed shards
-                             remaining_new_shards = [
-                                 s
-                                 for s in new_shards
-                                 if s not in shards_summary and s not in completed_shards
-                             ]
-
-                             if remaining_new_shards:
-                                 shard_iter = iter(remaining_new_shards)
-                                 logger.info(f"Added {len(remaining_new_shards)} new virtual shards")
-                                 continue
-
-                         # No more shards to process
-                         logger.info("No more shards to process")
-                         break
-
-                     except Exception as e:
-                         logger.error(f"Error loading shard {current_shard_name}: {e}")
-                         current_shard_url = None
-                         continue
-
-                 # Create a chunk from current shard
-                 if current_shard_url and current_shard_index < current_shard_items:
-                     # Calculate the absolute dataset index for this chunk
-                     if current_shard_url.startswith("hf_dataset:"):
-                         # Parse the virtual shard URL to get the base start index
-                         parts = current_shard_url.split(":")
-                         if len(parts) >= 4 and parts[2] == "chunk":
-                             shard_base_index = int(parts[3])
-                         else:
-                             shard_base_index = 0
-
-                         # The absolute start index for this chunk in the dataset
-                         absolute_start_index = shard_base_index + current_shard_index
-                     else:
-                         # For WebDataset, current_shard_index is already absolute
-                         absolute_start_index = current_shard_index
-
-                     # Create chunk with absolute index
-                     chunk = ShardChunk.create(
-                         shard_url=current_shard_url,
-                         shard_name=current_shard_name,
-                         start_index=absolute_start_index,
-                         chunk_size=min(self.chunk_size, current_shard_items - current_shard_index),
-                     )
-
-                     # Add to ChunkTracker with all required fields
-                     if self.chunk_tracker and self.chunk_tracker.add_chunk(
-                         chunk.chunk_id,
-                         chunk.shard_name,
-                         chunk.shard_url,
-                         chunk.start_index,
-                         chunk.chunk_size,
-                     ):
-                         with self.chunk_manager.lock:
-                             self.chunk_manager.chunks[chunk.chunk_id] = chunk
-                             self.chunk_manager.pending_chunks.append(chunk.chunk_id)
-
-                         chunks_created += 1
-                         self.stats["total_chunks"] += 1
-
-                     current_shard_index += self.chunk_size
-
-             if chunks_created > 0:
-                 logger.info(f"Created {chunks_created} chunks on demand")
-
-             # If we couldn't create any chunks and there are no more shards, check if it's HF dataset
-             if chunks_created == 0 and current_shard_url is None:
-                 if self._is_hf_dataset:
-                     # We can always create more virtual shards for HF datasets
-                     logger.debug("Will create more virtual shards on next iteration")
-                 else:
-                     logger.info("All shards processed, chunk creation complete")
-                     break
-
-             # Brief pause to avoid spinning
-             time.sleep(1)
-
-         # Final stats
-         if self.chunk_tracker:
-             final_stats = self.chunk_tracker.get_stats()
-             logger.info(
-                 f"Chunk creation thread ending. Total: {final_stats['total']}, "
-                 f"Pending: {final_stats['pending']}, Completed: {final_stats['completed']}"
-             )
-
-         logger.info("Chunk creation thread finished")
-
      async def start(self):
          """Start the orchestrator server."""
-         logger.info(f"Starting vLLM orchestrator on {self.host}:{self.port}")
-         logger.info(
-             f"vLLM config: model={self.vllm_config.get('model')}, batch_size={self.vllm_config.get('batch_size')}"
-         )
-
-         # Load existing state BEFORE accepting connections
-         await self.storage.initialize()
-         if self.chunk_tracker:
-             await self.chunk_tracker.sync_with_storage(self.storage)
-         await self._restore_state()
-
-         # Start chunk creation thread if dataset is configured
-         if self.dataset_loader:
-             self.chunk_creation_thread = threading.Thread(
-                 target=self._create_chunks_from_dataset, daemon=True
+         logger.info(f"Starting orchestrator on {self.host}:{self.port}")
+         processor_type = self.config.get("dataset", {}).get("processor_type", None)
+         if not processor_type:
+             logger.info(f"Config: {self.config}")
+             raise ValueError(
+                 "You must supply processor_type in your orchestrator dataset configuration."
              )
-         self.chunk_creation_thread.start()
+         logger.info(f"Processor type: {processor_type}")

-         # Give chunk creation thread time to restore existing chunks
-         await asyncio.sleep(2)
+         # Initialize storage
+         await self.storage.initialize()
+         await self.update_unprocessed_ranges()

          # Start background tasks
          asyncio.create_task(self._heartbeat_loop())
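For reference, the demand-driven sizing removed above kept roughly `max(min_chunk_buffer, connected_workers * chunk_buffer_multiplier)` chunks pending or assigned, with defaults of 10 and 3 per the old constructor. A small standalone illustration of that arithmetic (not package code):

    def chunks_needed(pending: int, assigned: int, workers: int,
                      multiplier: int = 3, floor: int = 10) -> int:
        # Target buffer scales with the worker count but never drops below the floor.
        target = max(floor, max(1, workers) * multiplier)
        return max(0, target - (pending + assigned))

    # Four workers with 6 chunks in flight: target is 12, so 6 more get created.
    assert chunks_needed(pending=4, assigned=2, workers=4) == 6

In 0.2.3 this bookkeeping moves out of the orchestrator entirely: the processor owns work-unit creation, and startup reduces to initializing storage and reconciling unprocessed ranges.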
@@ -848,12 +132,37 @@ class Orchestrator:
          asyncio.create_task(self._stats_update_loop())

          # Start WebSocket server
+         websocket_logger = logging.getLogger("websockets")
+         websocket_logger.setLevel(logging.WARNING)
          async with websockets.serve(
-             self.handle_connection, self.host, self.port, ssl=self.ssl_context
+             self.handle_connection,
+             self.host,
+             self.port,
+             ssl=self.ssl_context,
+             logger=websocket_logger,
          ):
-             logger.info("vLLM Orchestrator ready for connections")
+             logger.info("Orchestrator ready for connections")
              await asyncio.Future()  # Run forever

+     def get_workers_by_user_stats(self) -> Dict[str, Dict]:
+         """Get worker statistics grouped by user."""
+         stats = {}
+         for user, worker_ids in self.workers_by_user.items():
+             stats[user] = {"worker_ids": list(worker_ids), "count": len(worker_ids)}
+         return stats
+
+     async def update_unprocessed_ranges(self):
+         """Update unprocessed ranges based on what's actually in storage."""
+         if not self.processor or not self.storage:
+             return
+
+         processed_job_ids = self.storage.get_all_processed_job_ids()
+         self.processor.update_from_storage(processed_job_ids)
+
+     async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+         """Alias for _send_monitor_leaderboard for backward compatibility."""
+         await self._send_monitor_leaderboard(websocket)
+
      async def handle_connection(self, websocket: WebSocketServerProtocol):
          """Handle new WebSocket connection."""
          try:
@@ -868,49 +177,77 @@ class Orchestrator:

              if auth_ticket.role == "worker":
                  await self._handle_worker(websocket, auth_ticket)
-             elif auth_ticket.role == "data_worker":
-                 await self._handle_data_worker(websocket, auth_ticket)
              elif auth_ticket.role == "monitor":
                  await self._handle_monitor(websocket)
              elif auth_ticket.role == "admin":
                  await self._handle_admin(websocket, auth_ticket)
+             elif auth_ticket.role == "data_worker":
+                 await self._handle_data_worker(websocket, auth_ticket)
              else:
                  await websocket.send(
                      safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"})
                  )

          except Exception as e:
-             logger.error(f"Connection error: {e}")
-             import traceback
-
-             logger.error(traceback.format_exc())
+             logger.error(f"Connection error: {e}", exc_info=True)
              await websocket.close()

-     async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
-         """Handle admin connection for configuration updates."""
-         admin_id = getattr(auth_ticket, "name", "admin")
-         logger.info(f"Admin {admin_id} connected")
+     async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
+         """Handle worker connection lifecycle."""
+         # Generate unique worker ID
+         base_name = getattr(auth_ticket, "name", "worker")
+         worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"
+         worker_user = base_name
+
+         self.workers[worker_id] = websocket
+         self.workers_by_user[worker_user].add(worker_id)
+         self.stats["connected_workers"] = len(self.workers)

+         # Register contributor
+         contributor = await self.storage.get_contributor(worker_user)
+         if not contributor:
+             contributor = Contributor(
+                 contributor_id=worker_user,
+                 name=worker_user,
+                 total_captions=0,
+                 trust_level=1,
+             )
+             await self.storage.save_contributor(contributor)
+
+         logger.info(f"Worker {worker_id} (user: {worker_user}) is retrieving configuration")
          try:
-             # Send welcome
-             await websocket.send(safe_json_dumps({"type": "welcome", "role": "admin"}))
+             # Send welcome message with processor config
+             filtered_config = self.config.copy()
+             for unwanted_key in ["auth", "orchestrator", "storage"]:
+                 filtered_config.pop(unwanted_key, None)
+             welcome_message = {
+                 "type": "welcome",
+                 "worker_id": worker_id,
+                 "user_id": worker_user,
+                 "processor_type": self.config.get("dataset", {}).get("processor_type", None),
+                 "processor_config": filtered_config,
+             }
+             await websocket.send(safe_json_dumps(welcome_message))

              async for message in websocket:
-                 try:
-                     data = json.loads(message)
-                     msg_type = data.get("type")
+                 data = json.loads(message)
+                 await self._process_worker_message(worker_id, data)

-                     if msg_type == "reload_config":
-                         await self._handle_config_reload(websocket, data.get("config", {}))
+         except websockets.exceptions.ConnectionClosed:
+             logger.info(f"Worker {worker_id} has disconnected due to websocket connection closure")
+         finally:
+             if worker_id in self.workers:
+                 del self.workers[worker_id]

-                 except json.JSONDecodeError as e:
-                     logger.error(f"Invalid admin message: {e}")
-                     await websocket.send(
-                         safe_json_dumps({"type": "error", "error": "Invalid message format"})
-                     )
+             self.workers_by_user[worker_user].discard(worker_id)
+             if not self.workers_by_user[worker_user]:
+                 del self.workers_by_user[worker_user]

-         except websockets.exceptions.ConnectionClosed:
-             logger.info(f"Admin {admin_id} disconnected")
+             self.stats["connected_workers"] = len(self.workers)
+
+             # Release assignments
+             self.processor.release_assignments(worker_id)
+             logger.info(f"Worker {worker_id} has safely disconnected")

      async def _handle_config_reload(
          self, websocket: WebSocketServerProtocol, new_config: Dict[str, Any]
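Taken together with the handlers in the next two hunks, the worker protocol is: receive a `welcome` carrying the processor configuration, then loop sending `request_work` and acting on `work_assignment` or `no_work`. A bare-bones client sketch under those assumptions (the auth handshake is not visible in this diff and is omitted; the `assignment` payload shape follows `assignment.to_dict()` as used below, so exact field names may differ):

    import asyncio
    import json
    import websockets

    async def run_worker(url: str = "ws://localhost:8765"):
        async with websockets.connect(url) as ws:
            welcome = json.loads(await ws.recv())  # {"type": "welcome", ...}
            print("processor:", welcome.get("processor_type"))
            await ws.send(json.dumps({"type": "request_work", "count": 2}))
            reply = json.loads(await ws.recv())  # "work_assignment" or "no_work"
            if reply["type"] == "work_assignment":
                for unit in reply["assignment"].get("units", []):
                    ...  # process the unit, then send submit_results / work_complete

    asyncio.run(run_worker())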
@@ -920,224 +257,61 @@ class Orchestrator:
          updated_sections = []
          warnings = []
-         requires_worker_restart = False

          try:
              # Extract orchestrator section if present
              if "orchestrator" in new_config:
-                 # Config has orchestrator wrapper, extract it
                  orchestrator_config = new_config["orchestrator"]
              else:
-                 # Config is already at orchestrator level
                  orchestrator_config = new_config

-             # Helper function for deep comparison
-             def deep_equal(a, b):
-                 """Deep comparison of two values including nested dicts and lists."""
-                 if type(a) != type(b):
-                     return False
-                 if isinstance(a, dict):
-                     if set(a.keys()) != set(b.keys()):
-                         return False
-                     return all(deep_equal(a[k], b[k]) for k in a.keys())
-                 elif isinstance(a, (list, tuple)):
-                     if len(a) != len(b):
-                         return False
-                     return all(deep_equal(x, y) for x, y in zip(a, b))
+             # Update processor configuration if present
+             if "processor_type" in orchestrator_config:
+                 old_type = self.config.get("processor_type")
+                 new_type = orchestrator_config["processor_type"]
+
+                 if old_type != new_type:
+                     warnings.append("Processor type changes require orchestrator restart")
+                     updated_sections.append("processor_type")
                  else:
-                     return a == b
-
-             # Update vLLM configuration
-             if "vllm" in orchestrator_config:
-                 old_vllm = self.vllm_config.copy()
-                 new_vllm = orchestrator_config["vllm"]
-
-                 # Check if vLLM config actually changed using deep comparison
-                 vllm_changed = not deep_equal(old_vllm, new_vllm)
-
-                 if vllm_changed:
-                     # Update the vLLM config
-                     self.vllm_config = new_vllm.copy()
-                     updated_sections.append("vllm")
-
-                     # Check if critical changes require worker restart
-                     if (
-                         old_vllm.get("model") != new_vllm.get("model")
-                         or old_vllm.get("gpu_memory_utilization")
-                         != new_vllm.get("gpu_memory_utilization")
-                         or old_vllm.get("tensor_parallel_size")
-                         != new_vllm.get("tensor_parallel_size")
-                         or old_vllm.get("dtype") != new_vllm.get("dtype")
-                         or old_vllm.get("max_model_len") != new_vllm.get("max_model_len")
-                     ):
-                         requires_worker_restart = True
-                         warnings.append(
-                             "Critical vLLM changes detected - workers will be disconnected to reload"
-                         )
-                         logger.info(
-                             f"Model change: {old_vllm.get('model')} -> {new_vllm.get('model')}"
-                         )
-
-             # Update dataset configuration
-             if "dataset" in orchestrator_config:
-                 old_dataset = self.dataset_config.copy()
-                 new_dataset = orchestrator_config["dataset"]
-
-                 dataset_changed = not deep_equal(old_dataset, new_dataset)
-
-                 if dataset_changed:
-                     self.dataset_config = new_dataset.copy()
-                     self.dataset_path = self.dataset_config.get("path")
-                     self.dataset_type = self.dataset_config.get("type", "huggingface")
-                     updated_sections.append("dataset")
-                     warnings.append("Dataset changes will apply to new chunks only")
-
-             # Update chunk settings
-             if (
-                 "chunk_size" in orchestrator_config
-                 and self.chunk_size != orchestrator_config["chunk_size"]
-             ):
-                 self.chunk_size = orchestrator_config["chunk_size"]
-                 self.chunk_manager.chunk_size = self.chunk_size
-                 updated_sections.append("chunk_size")
-
-             if (
-                 "chunks_per_request" in orchestrator_config
-                 and self.chunks_per_request != orchestrator_config["chunks_per_request"]
-             ):
-                 self.chunks_per_request = orchestrator_config["chunks_per_request"]
-                 updated_sections.append("chunks_per_request")
+                     # Update processor config
+                     self.config.update(orchestrator_config)
+
+                     # Reinitialize processor with new config
+                     processor_config = ProcessorConfig(
+                         processor_type=new_type, config=orchestrator_config
+                     )
+                     self.processor.initialize(processor_config)
+                     updated_sections.append("processor_config")
+
+             # Update units per request
+             if "units_per_request" in orchestrator_config:
+                 self.units_per_request = orchestrator_config["units_per_request"]
+                 updated_sections.append("units_per_request")

              # Update auth configuration
              if "auth" in orchestrator_config:
                  try:
-                     self.auth = AuthManager({"auth": orchestrator_config["auth"]})
+                     self.auth = AuthManager(orchestrator_config["auth"])
                      updated_sections.append("auth")
                  except Exception as e:
                      logger.error(f"Failed to update AuthManager: {e}")
                      warnings.append(f"Auth update failed: {e}")

-             # Update buffer settings
-             if (
-                 "chunk_buffer_multiplier" in orchestrator_config
-                 and self.chunk_buffer_multiplier != orchestrator_config["chunk_buffer_multiplier"]
-             ):
-                 self.chunk_buffer_multiplier = orchestrator_config["chunk_buffer_multiplier"]
-                 updated_sections.append("chunk_buffer_multiplier")
-
-             if (
-                 "min_chunk_buffer" in orchestrator_config
-                 and self.min_chunk_buffer != orchestrator_config["min_chunk_buffer"]
-             ):
-                 self.min_chunk_buffer = orchestrator_config["min_chunk_buffer"]
-                 updated_sections.append("min_chunk_buffer")
-
              # Update storage settings
              if "storage" in orchestrator_config:
                  storage_config = orchestrator_config["storage"]
-                 storage_changed = False

-                 if (
-                     "caption_buffer_size" in storage_config
-                     and self.storage.caption_buffer_size != storage_config["caption_buffer_size"]
-                 ):
+                 if "caption_buffer_size" in storage_config:
                      self.storage.caption_buffer_size = storage_config["caption_buffer_size"]
-                     storage_changed = True
+                     updated_sections.append("storage.caption_buffer_size")

-                 if "checkpoint_interval" in storage_config:
-                     current_interval = self.config.get("storage", {}).get(
-                         "checkpoint_interval", 1000
-                     )
-                     if current_interval != storage_config["checkpoint_interval"]:
-                         self.config.setdefault("storage", {})["checkpoint_interval"] = (
-                             storage_config["checkpoint_interval"]
-                         )
-                         storage_changed = True
-
-                 if storage_changed:
-                     updated_sections.append("storage")
-
-             # Check if any changes were made
-             if not updated_sections:
-                 await websocket.send(
-                     safe_json_dumps(
-                         {
-                             "type": "reload_complete",
-                             "message": "No changes applied - configuration is identical",
-                         }
-                     )
-                 )
-                 logger.info("Configuration reload requested but no changes detected")
-                 return
-
-             # Update the main config
+             # Update main config
              if "orchestrator" in new_config:
-                 self.config["orchestrator"] = orchestrator_config
+                 self.config = new_config["orchestrator"]
              else:
                  self.config.update(orchestrator_config)

-             # Handle worker restart if needed
-             if requires_worker_restart:
-                 logger.info("Disconnecting all workers for configuration reload...")
-
-                 # Send reload message to workers first
-                 reload_msg = safe_json_dumps(
-                     {
-                         "type": "reload_vllm",
-                         "vllm_config": self.vllm_config,
-                     }
-                 )
-
-                 # Create a list of worker items to avoid modifying dict during iteration
-                 worker_items = list(self.workers.items())
-                 disconnected = []
-
-                 for worker_id, ws in worker_items:
-                     try:
-                         await ws.send(reload_msg)
-                         # Give worker time to process before disconnect
-                         await asyncio.sleep(0.5)
-                         await ws.close(code=1012, reason="Configuration reload")
-                         disconnected.append(worker_id)
-                     except:
-                         disconnected.append(worker_id)  # Still mark as disconnected if error
-
-                 # Now safely clear workers dict
-                 for worker_id in disconnected:
-                     if worker_id in self.workers:
-                         del self.workers[worker_id]
-
-                 warnings.append(
-                     f"Sent reload message to {len(disconnected)} workers - they will reconnect with new config"
-                 )
-             else:
-                 # Just notify workers about config changes without disconnecting
-                 config_update_msg = safe_json_dumps(
-                     {
-                         "type": "config_update",
-                         "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
-                         "dataset_config": (
-                             self.dataset_config if "dataset" in updated_sections else None
-                         ),
-                     }
-                 )
-
-                 # Create a list of worker items to avoid modifying dict during iteration
-                 worker_items = list(self.workers.items())
-                 disconnected = []
-
-                 for worker_id, ws in worker_items:
-                     try:
-                         await ws.send(config_update_msg)
-                         logger.info(f"Sent config update to worker {worker_id}")
-                     except:
-                         disconnected.append(worker_id)
-
-                 # Now safely remove disconnected workers
-                 for worker_id in disconnected:
-                     if worker_id in self.workers:
-                         del self.workers[worker_id]
-
              # Send success response
              await websocket.send(
                  safe_json_dumps(
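The final hunk below reworks result submission around the new JobId model. Only two pieces of its format are visible in this diff: `job_id.shard_id` is the shard name (the inline comments show "data-0000") and `job_id.get_chunk_str()` yields "<shard>:chunk:<n>". A hypothetical stand-in covering just those two fields (the real class lives in `.models` and its full string format is not shown here):

    from dataclasses import dataclass

    @dataclass
    class JobIdSketch:
        shard_id: str  # e.g. "data-0000"
        chunk_id: str  # e.g. "0"

        def get_chunk_str(self) -> str:
            # Matches the "data-0000:chunk:0" form shown in the handler's comments.
            return f"{self.shard_id}:chunk:{self.chunk_id}"

    assert JobIdSketch("data-0000", "0").get_chunk_str() == "data-0000:chunk:0"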
@@ -1146,306 +320,169 @@ class Orchestrator:
1146
320
  )
1147
321
 
1148
322
  logger.info(f"Configuration reloaded. Updated sections: {', '.join(updated_sections)}")
1149
-
1150
- # Broadcast stats update to monitors
1151
- await self._broadcast_stats()
1152
323
  await self._send_activity(
1153
324
  f"Configuration reloaded by admin: {', '.join(updated_sections)}"
1154
325
  )
1155
326
 
1156
327
  except Exception as e:
1157
328
  logger.error(f"Configuration reload failed: {e}")
1158
- import traceback
1159
-
1160
- logger.error(traceback.format_exc())
1161
329
  await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))
1162
330
 
1163
- async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
1164
- """Handle worker connection lifecycle."""
1165
- # Generate unique worker ID even if using same token
1166
- base_name = getattr(auth_ticket, "name", "worker")
1167
- worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}" # Add unique suffix
1168
-
1169
- # Track the original token/user for accounting
1170
- worker_user = base_name # Keep track of which user/token this worker belongs to
1171
-
1172
- self.workers[worker_id] = websocket
1173
- self.stats["connected_workers"] = len(self.workers)
1174
-
1175
- # Optionally track workers by user/token
1176
- if not hasattr(self, "workers_by_user"):
1177
- self.workers_by_user = defaultdict(set)
1178
- self.workers_by_user[worker_user].add(worker_id)
1179
-
1180
- # Register contributor with the base name (for aggregating stats per user)
1181
- contributor = await self.storage.get_contributor(worker_user)
1182
- if not contributor:
1183
- contributor = Contributor(
1184
- contributor_id=worker_user,
1185
- name=worker_user,
1186
- total_captions=0,
1187
- trust_level=1,
1188
- )
1189
- await self.storage.save_contributor(contributor)
1190
-
1191
- logger.info(f"Worker {worker_id} (user: {worker_user}) connected")
1192
- await self._broadcast_stats()
1193
- await self._send_activity(f"Worker {worker_id} (user: {worker_user}) connected")
1194
-
1195
- try:
1196
- # Send welcome message with dataset configuration
1197
- welcome_message = {
1198
- "type": "welcome",
1199
- "worker_id": worker_id,
1200
- "user_id": worker_user,
1201
- "dataset_config": {
1202
- "dataset_path": self.dataset_path,
1203
- "dataset_type": self.dataset_type,
1204
- "dataset_split": self.dataset_split,
1205
- "dataset_image_column": self.dataset_image_column,
1206
- "path": self.dataset_path,
1207
- "type": self.dataset_type,
1208
- "split": self.dataset_split,
1209
- "image_column": self.dataset_image_column,
1210
- },
1211
- "vllm_config": self.vllm_config,
1212
- }
1213
- await websocket.send(safe_json_dumps(welcome_message))
1214
-
1215
- async for message in websocket:
1216
- data = json.loads(message)
1217
- await self._process_worker_message(worker_id, data)
1218
-
1219
- except websockets.exceptions.ConnectionClosed:
1220
- logger.info(f"Worker {worker_id} (user: {worker_user}) disconnected")
1221
- finally:
1222
- if worker_id in self.workers:
1223
- del self.workers[worker_id]
1224
-
1225
- # Clean up user tracking
1226
- if hasattr(self, "workers_by_user") and worker_user in self.workers_by_user:
1227
- self.workers_by_user[worker_user].discard(worker_id)
1228
- if not self.workers_by_user[worker_user]:
1229
- del self.workers_by_user[worker_user]
1230
-
1231
- self.stats["connected_workers"] = len(self.workers)
1232
-
1233
- # Release chunks
1234
- self.chunk_manager.release_worker_chunks(worker_id)
1235
- if self.chunk_tracker:
1236
- released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
1237
- logger.info(
1238
- f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
1239
- )
1240
-
1241
- await self._broadcast_stats()
1242
- await self._send_activity(f"Worker {worker_id} (user: {worker_user}) disconnected")
1243
-
1244
331
  async def _process_worker_message(self, worker_id: str, data: Dict):
1245
332
  """Process message from worker."""
1246
333
  msg_type = data.get("type")
1247
334
 
1248
- if msg_type == "request_chunks":
1249
- # Wait for state restoration to complete
1250
- if not self.state_restored.is_set():
1251
- logger.info(f"Worker {worker_id} requesting chunks, but state not yet restored")
1252
- await self.workers[worker_id].send(
1253
- safe_json_dumps({"type": "no_chunks", "reason": "state_restoring"})
335
+ if msg_type == "request_work":
336
+ count = data.get("count", self.units_per_request)
337
+ units = self.processor.get_work_units(count, worker_id)
338
+ logger.debug(f"Assigning units: {[unit.chunk_id for unit in units]}")
339
+
340
+ if units:
341
+ # Create assignment
342
+ assignment = WorkAssignment(
343
+ assignment_id=str(uuid.uuid4()),
344
+ worker_id=worker_id,
345
+ units=units,
346
+ assigned_at=datetime.utcnow(),
1254
347
  )
1255
- return
1256
-
1257
- count = data.get("count", self.chunks_per_request)
1258
- chunk_infos = self.chunk_manager.get_chunks_for_worker(
1259
- worker_id, count, self.chunk_tracker
1260
- )
1261
-
1262
- if chunk_infos:
1263
- # Send chunks with unprocessed ranges
1264
- chunks_data = []
1265
- for info in chunk_infos:
1266
- chunk_dict = info["chunk"].to_dict()
1267
- chunk_dict["unprocessed_ranges"] = info["unprocessed_ranges"]
1268
- chunks_data.append(chunk_dict)
1269
348
 
1270
349
  await self.workers[worker_id].send(
1271
- safe_json_dumps({"type": "shard_assignment", "chunks": chunks_data})
350
+ safe_json_dumps({"type": "work_assignment", "assignment": assignment.to_dict()})
1272
351
  )
1273
352
 
1274
- chunk_ids = [c["chunk_id"] for c in chunks_data]
1275
- logger.info(
1276
- f"Assigned {len(chunks_data)} chunks to worker {worker_id}: {chunk_ids}"
1277
- )
353
+ logger.debug(f"Assigned {len(units)} work units to worker {worker_id}")
1278
354
  else:
1279
- await self.workers[worker_id].send(safe_json_dumps({"type": "no_chunks"}))
1280
-
1281
- elif msg_type == "chunk_complete":
1282
- chunk_id = data["chunk_id"]
1283
- if self.chunk_manager.complete_chunk(chunk_id, worker_id):
1284
- self.stats["completed_chunks"] += 1
355
+ await self.workers[worker_id].send(safe_json_dumps({"type": "no_work"}))
1285
356
 
1286
- if self.chunk_tracker:
1287
- self.chunk_tracker.mark_completed(chunk_id)
357
+ elif msg_type == "work_complete":
358
+ unit_id = data["unit_id"]
359
+ self.processor.mark_completed(unit_id, worker_id)
360
+ logger.debug(f"Work unit {unit_id} completed by worker {worker_id}")
1288
361
 
1289
- logger.info(f"Chunk {chunk_id} completed by worker {worker_id}")
1290
- await self._check_shard_completion(chunk_id)
1291
- await self._send_activity(f"Chunk {chunk_id} completed by {worker_id}")
1292
- elif msg_type == "chunk_failed":
1293
- chunk_id = data["chunk_id"]
362
+ elif msg_type == "work_failed":
363
+ unit_id = data["unit_id"]
1294
364
  error = data.get("error", "Unknown error")
1295
- if self.chunk_manager.fail_chunk(chunk_id, worker_id):
1296
- self.stats["failed_chunks"] += 1
365
+ self.processor.mark_failed(unit_id, worker_id, error)
366
+ logger.warning(f"Work unit {unit_id} failed on worker {worker_id}: {error}")
1297
367
 
1298
- if self.chunk_tracker:
1299
- self.chunk_tracker.mark_failed(chunk_id)
368
+ elif msg_type == "submit_results":
369
+ await self._handle_results_submission(worker_id, data)
1300
370
 
1301
- logger.warning(f"Chunk {chunk_id} failed on worker {worker_id}: {error}")
1302
- await self._send_activity(f"Chunk {chunk_id} failed on {worker_id}: {error}")
1303
-
1304
- elif msg_type == "submit_captions":
1305
- await self._handle_captions_submission(worker_id, data)
1306
- elif msg_type == "request_job":
1307
- # CaptionWorker requesting a job from data samples
1308
- try:
1309
- job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
1310
- await self.workers[worker_id].send(
1311
- json.dumps({"type": "job_assignment", "job": job})
1312
- )
1313
- logger.debug(f"Assigned job {job['job_id']} to worker {worker_id}")
1314
- except asyncio.TimeoutError:
1315
- await self.workers[worker_id].send(json.dumps({"type": "no_jobs"}))
1316
371
  elif msg_type == "heartbeat":
1317
- # Update worker stats
1318
372
  logger.debug(f"Heartbeat from {worker_id}: {data}")
1319
373
 
- async def _handle_captions_submission(self, worker_id: str, data: Dict):
- """Process caption submission from worker - now handles multi-stage outputs."""
- chunk_id = data.get("chunk_id")
- item_key = data["item_key"]
-
- item_index = data.get("item_index") # Worker should send this
- if item_index is None:
- # Try to extract from item_key (format: dataset_XXXXXXXX)
- try:
- item_index = int(item_key.split("_")[-1])
- except:
- logger.warning(f"Could not extract item index from key: {item_key}")
-
- # Extract user from worker_id (format: "username_uuid")
+ async def _handle_results_submission(self, worker_id: str, data: Dict):
+ """Process results submission from worker."""
+ # Extract user from worker_id
  worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id

- # Handle both old format (captions list) and new format (outputs dict)
- if "outputs" in data:
- # New multi-stage format
- outputs = data["outputs"]
- captions_list = outputs.get("captions", [])
- total_outputs = sum(len(v) for v in outputs.values())
-
- logger.debug(
- f"Received multi-stage outputs for item {item_key} from worker {worker_id}: "
- f"{total_outputs} outputs across {len(outputs)} fields"
- )
- else:
- # Old format - single captions list
- captions_list = data["captions"]
- outputs = {"captions": captions_list}
- total_outputs = len(captions_list)
-
- logger.debug(
- f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
- )
+ # Create work result
+ _job_id = data.get("job_id")
+ job_id = JobId.from_str(_job_id)
+ shard_name = job_id.shard_id # >data-0000<
+ chunk_name = job_id.chunk_id # data-0000:chunk:>0<
+ # logger.debug(f"({job_id}) Worker result: {data}")
+ result = WorkResult(
+ unit_id=data["unit_id"],
+ source_id=shard_name,
+ chunk_id=job_id.get_chunk_str(), # we want the full string here
+ sample_id=data["sample_id"],
+ dataset=data["dataset"],
+ outputs=data["outputs"],
+ metadata=data.get("metadata", {}),
+ processing_time_ms=data.get("processing_time_ms", 0),
+ )

- # Create caption record with multi-stage outputs
+ # Let processor handle any custom processing
+ processed = self.processor.handle_result(result)
+
+ # Create caption record for storage
+ total_outputs = sum(len(v) for v in result.outputs.values())
+
+ filename = result.metadata.pop("_filename", None)
+ url = result.metadata.pop("_url", None)
+ image_height = result.metadata.pop("image_height", None)
+ image_width = result.metadata.pop("image_width", None)
+ file_size = result.metadata.pop("file_size", None)
+ image_format = result.metadata.pop("image_format", None)
+ item_index = result.metadata.pop("item_index", None)
+ item_key = result.metadata.pop("item_key", None)
+ to_delete_metadata_keys = ["_image_format", "_job_id"]
+ for key in to_delete_metadata_keys:
+ if key in result.metadata:
+ del result.metadata[key]
  caption = Caption(
- job_id=f"{chunk_id}_{item_key}",
- dataset=data.get("dataset"),
- shard=data.get("shard"),
+ job_id=job_id,
+ dataset=result.dataset,
+ shard=processed["source_id"],
+ chunk_id=chunk_name,
  item_key=item_key,
- captions=captions_list,
- outputs=outputs,
+ captions=result.outputs.get("captions", []),
+ outputs=result.outputs,
  contributor_id=worker_user,
  timestamp=datetime.utcnow(),
- quality_scores=None,
- # Image metadata
- image_width=data.get("image_width"),
- image_height=data.get("image_height"),
- image_format=data.get("image_format"),
- file_size=data.get("file_size"),
- # Processing metadata
  caption_count=total_outputs,
- processing_time_ms=data.get("processing_time_ms"),
- chunk_id=chunk_id,
- metadata=data.get("metadata", {}),
+ processing_time_ms=result.processing_time_ms,
+ metadata=result.metadata,
+ image_height=image_height,
+ image_width=image_width,
+ filename=filename,
+ url=url,
+ file_size=file_size,
+ image_format=image_format,
  )

- # Add to central storage buffer
+ # Save to storage
  await self.storage.save_caption(caption)

- # Handle item tracking with fixed deadlock
- should_flush = False
- if chunk_id and item_index is not None and self.chunk_tracker:
- with self.item_batch_lock:
- self.pending_processed_items[chunk_id].append(item_index)
-
- # Check if we should flush
- total_pending = sum(
- len(indices) for indices in self.pending_processed_items.values()
- )
- time_since_flush = time.time() - self.last_item_batch_flush
-
- if (
- total_pending >= self.item_batch_size
- or time_since_flush >= self.item_batch_interval
- ):
- should_flush = True
-
- if should_flush:
- await self._flush_processed_items()
-
- # Update contributor stats (use user, not worker)
+ # Update contributor stats
  contributor = await self.storage.get_contributor(worker_user)
  if contributor:
  contributor.total_captions += total_outputs
  await self.storage.save_contributor(contributor)

- # Broadcast updated stats
- await self._broadcast_stats()
-
- # Log progress periodically
- total_outputs = self.stats.get("total_outputs", 0)
- if total_outputs > 0 and total_outputs % 100 == 0:
- if (
- not hasattr(self, "_last_logged_outputs")
- or self._last_logged_outputs != total_outputs
- ):
- logger.info(f"Collected {total_outputs} outputs centrally")
- self._last_logged_outputs = total_outputs
-
- async def _check_shard_completion(self, chunk_id: str):
- """Check if a shard is complete after chunk completion."""
- # Get the chunk
- chunk = self.chunk_manager.chunks.get(chunk_id)
- if not chunk:
- return
+ async def _handle_monitor(self, websocket: WebSocketServerProtocol):
+ """Handle monitor connection."""
+ self.monitors.add(websocket)
+ logger.info(f"Monitor connected (total: {len(self.monitors)})")
+
+ try:
+ # Send welcome
+ await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))

- shard_name = chunk.shard_name
+ # Send initial stats
+ await self._send_monitor_stats(websocket)

- # Find all chunks for this shard
- shard_chunks = [
- cid for cid, c in self.chunk_manager.chunks.items() if c.belongs_to_shard(shard_name)
- ]
+ # Keep connection alive
+ async for message in websocket:
+ pass

- # Check if all are completed
- completed_chunks = [
- cid for cid in shard_chunks if self.chunk_manager.chunks[cid].status == "completed"
- ]
+ except websockets.exceptions.ConnectionClosed:
+ logger.info("Monitor disconnected")
+ finally:
+ self.monitors.discard(websocket)

- if len(completed_chunks) == len(shard_chunks) and len(shard_chunks) > 0:
- logger.info(f"Shard {shard_name} complete!")
- # Don't mark virtual shards as complete in ShardTracker
- if not shard_name.startswith("hf_dataset:"):
- self.shard_tracker.mark_complete(shard_name)
- self.stats["completed_shards"] += 1
- await self._send_activity(f"Shard {shard_name} completed!")
+ async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
+ """Handle admin connection."""
+ admin_id = getattr(auth_ticket, "name", "admin")
+ logger.info(f"Admin {admin_id} connected")
+
+ try:
+ await websocket.send(safe_json_dumps({"type": "welcome", "role": "admin"}))
+
+ async for message in websocket:
+ try:
+ data = json.loads(message)
+ if data.get("type") == "reload_config":
+ await self._handle_config_reload(websocket, data.get("config", {}))
+ elif data.get("type") == "get_stats":
+ await self._send_monitor_stats(websocket)
+
+ except json.JSONDecodeError as e:
+ logger.error(f"Invalid admin message: {e}")
+
+ except websockets.exceptions.ConnectionClosed:
+ logger.info(f"Admin {admin_id} disconnected")

  async def _handle_data_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
  """Handle data worker connection."""
@@ -1516,7 +553,64 @@ class Orchestrator:
  finally:
  del self.data_workers[worker_id]

- async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+ async def _send_monitor_initial_data(self, websocket: WebSocketServerProtocol):
+ """Send initial data to monitor in a separate task to avoid blocking."""
+ total_start = time.time()
+ try:
+ # Check if websocket is still in monitors set
+ if websocket not in self.monitors:
+ logger.debug("Monitor disconnected before initial data send")
+ return
+
+ # Send current stats (already in memory)
+ stats_start = time.time()
+ await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
+ logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
+
+ # Get processor stats instead of chunk stats
+ processor_stats_start = time.time()
+ processor_stats = self.processor.get_stats()
+ logger.debug(
+ f"Processor stats retrieved in {(time.time() - processor_stats_start)*1000:.1f}ms"
+ )
+
+ stats_send_start = time.time()
+ await websocket.send(
+ safe_json_dumps({"type": "processor_stats", "data": processor_stats})
+ )
+ logger.debug(f"Processor stats sent in {(time.time() - stats_send_start)*1000:.1f}ms")
+
+ if websocket not in self.monitors:
+ return
+
+ # For leaderboard, check if we have a cached version first
+ if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
+ # Use cached leaderboard if available
+ cache_send_start = time.time()
+ await websocket.send(
+ safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
+ )
+ logger.debug(
+ f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
+ )
+ else:
+ # Schedule leaderboard update separately
+ leaderboard_task_start = time.time()
+ asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
+ logger.debug(
+ f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
+ )
+
+ logger.debug(
+ f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
+ )
+
+ except websockets.exceptions.ConnectionClosed:
+ logger.debug("Monitor disconnected during initial data send")
+ except Exception as e:
+ logger.error(f"Error sending initial monitor data: {e}")
+
+ async def _send_monitor_leaderboard(self, websocket: WebSocketServerProtocol):
  """Send leaderboard data to a specific monitor."""
  total_start = time.time()
  try:
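
`_send_monitor_initial_data` above pushes `stats`, then `processor_stats`, then a `leaderboard` message (served from `_cached_leaderboard` when available, otherwise scheduled as a follow-up task), after `_handle_monitor` has sent `welcome`. A bare-bones monitor client could consume that stream as in this sketch; the endpoint URL is a placeholder and any authentication handshake the deployment requires is omitted:

    import asyncio
    import json
    import websockets

    async def watch(url: str) -> None:
        # Hypothetical monitor client: print each message type pushed by
        # the orchestrator (welcome, stats, processor_stats, leaderboard,
        # activity).
        async with websockets.connect(url) as ws:
            async for raw in ws:
                msg = json.loads(raw)
                print(msg.get("type"), "->", str(msg.get("data"))[:80])

    asyncio.run(watch("ws://localhost:8765"))  # placeholder endpoint
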
@@ -1581,210 +675,43 @@ class Orchestrator:
  except Exception as e:
  logger.error(f"Error sending leaderboard to monitor: {e}")

- async def _send_initial_monitor_data(self, websocket: WebSocketServerProtocol):
- """Send initial data to monitor in a separate task to avoid blocking."""
- total_start = time.time()
- try:
- # Check if websocket is still in monitors set
- if websocket not in self.monitors:
- logger.debug("Monitor disconnected before initial data send")
- return
-
- # Send current stats (already in memory)
- stats_start = time.time()
- await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
- logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
-
- # Get chunk stats asynchronously
- chunk_stats_start = time.time()
- loop = asyncio.get_event_loop()
- chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
- logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
-
- if websocket not in self.monitors:
- return
-
- chunk_send_start = time.time()
- await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
- logger.debug(f"Chunk stats sent in {(time.time() - chunk_send_start)*1000:.1f}ms")
-
- # For leaderboard, check if we have a cached version first
- if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
- # Use cached leaderboard if available
- cache_send_start = time.time()
- await websocket.send(
- safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
- )
- logger.debug(
- f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
- )
- else:
- # Schedule leaderboard update separately
- leaderboard_task_start = time.time()
- asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
- logger.debug(
- f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
- )
-
- logger.debug(
- f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
- )
-
- except websockets.exceptions.ConnectionClosed:
- logger.debug("Monitor disconnected during initial data send")
- except Exception as e:
- logger.error(f"Error sending initial monitor data: {e}")
-
- async def _handle_monitor(self, websocket: WebSocketServerProtocol):
- """Handle monitor connection - truly non-blocking version."""
- monitor_start = time.time()
- self.monitors.add(websocket)
- logger.info(f"Monitor connected (total monitors: {len(self.monitors)})")
+ async def _send_monitor_stats(self, websocket: WebSocketServerProtocol):
+ """Send current stats to a monitor."""
+ # Get processor stats
+ processor_stats = self.processor.get_stats()

- try:
- # Send welcome message immediately
- welcome_start = time.time()
- await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
- logger.debug(f"Monitor welcome sent in {(time.time() - welcome_start)*1000:.1f}ms")
-
- # Schedule initial data send as a separate task to avoid blocking
- task_create_start = time.time()
- asyncio.create_task(self._send_initial_monitor_data(websocket))
- logger.debug(
- f"Monitor initial data task created in {(time.time() - task_create_start)*1000:.1f}ms"
- )
+ # Get storage stats
+ storage_stats = await self.storage.get_storage_stats()

- # Just keep the connection alive - no blocking work here
- try:
- async for message in websocket:
- # Handle any incoming messages from monitor if needed
- # For now, just ignore them
- pass
- except websockets.exceptions.ConnectionClosed:
- pass # Normal disconnection
+ # Combine all stats
+ all_stats = {
+ **self.stats,
+ **storage_stats,
+ "processor_stats": processor_stats,
+ "current_rate": self.rate_tracker["current_rate"],
+ "average_rate": self.rate_tracker["average_rate"],
+ }

- except websockets.exceptions.ConnectionClosed:
- logger.info("Monitor disconnected")
- except Exception as e:
- logger.error(f"Error in monitor handler: {e}")
- finally:
- self.monitors.discard(websocket)
- logger.debug(f"Monitor handler completed in {(time.time() - monitor_start)*1000:.1f}ms")
+ await websocket.send(safe_json_dumps({"type": "stats", "data": all_stats}))
 
- async def _broadcast_stats(self):
- """Broadcast statistics to all monitors - truly non-blocking version."""
+ async def _send_activity(self, activity: str):
+ """Send activity update to monitors."""
  if not self.monitors:
  return
- if self.is_generating_stats:
- return # Already generating stats, skip this call
- self.is_generating_stats = True
- total_start = time.time()

- # Prepare all the data first
- data_prep_start = time.time()
- loop = asyncio.get_event_loop()
-
- # Get storage stats (already async)
- storage_stats_start = time.time()
- storage_stats = await self.storage.get_storage_stats()
- logger.debug(f"Storage stats retrieved in {(time.time() - storage_stats_start)*1000:.1f}ms")
-
- caption_stats_start = time.time()
- caption_stats = await self.storage.get_caption_stats()
- logger.debug(f"Caption stats retrieved in {(time.time() - caption_stats_start)*1000:.1f}ms")
-
- # Get chunk stats in thread pool
- chunk_stats_start = time.time()
- chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
- logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
-
- # Build stats dict
- build_stats_start = time.time()
- stats_update = self.stats.copy()
- stats_update.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
- stats_update.update(storage_stats)
- stats_update["field_breakdown"] = caption_stats.get("field_stats", {})
- stats_update["output_fields_list"] = caption_stats.get("output_fields", [])
-
- # Add rate information
- stats_update.update(
- {
- "current_rate": self.rate_tracker["current_rate"],
- "average_rate": self.rate_tracker["average_rate"],
- "expected_rate": self.rate_tracker["expected_rate"],
- }
+ message = safe_json_dumps(
+ {"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
  )

- # Add vLLM info
- stats_update["vllm_model"] = self.vllm_config.get("model", "unknown")
- stats_update["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
-
- # Add stage information
- stages = self.vllm_config.get("stages", [])
- if stages:
- stats_update["stage_count"] = len(stages)
- stats_update["stage_names"] = [s.get("name", "unnamed") for s in stages]
- else:
- stats_update["stage_count"] = 1
- stats_update["stage_names"] = ["default"]
-
- # Get field stats
- field_stats_start = time.time()
- field_stats = await self.storage.get_output_field_stats()
- stats_update["output_fields"] = field_stats
- logger.debug(f"Field stats retrieved in {(time.time() - field_stats_start)*1000:.1f}ms")
-
- # Update our internal stats
- self.stats = stats_update
- logger.debug(f"Stats prepared in {(time.time() - build_stats_start)*1000:.1f}ms")
-
- logger.debug(f"Total data preparation took {(time.time() - data_prep_start)*1000:.1f}ms")
-
- # Create message once
- message_create_start = time.time()
- stats_message = safe_json_dumps({"type": "stats", "data": self.stats})
- logger.debug(f"Stats message created in {(time.time() - message_create_start)*1000:.1f}ms")
-
- # Send to all monitors asynchronously in parallel
- send_start = time.time()
-
- async def send_to_monitor(monitor):
+ disconnected = set()
+ for monitor in self.monitors:
  try:
- await monitor.send(stats_message)
+ await monitor.send(message)
  except websockets.exceptions.ConnectionClosed:
- return monitor # Return for removal
- except Exception as e:
- logger.debug(f"Error sending stats to monitor: {e}")
- return monitor # Return for removal
- return None
-
- # Send to all monitors in parallel
- monitors_copy = self.monitors.copy()
- results = await asyncio.gather(
- *[send_to_monitor(m) for m in monitors_copy], return_exceptions=True
- )
+ disconnected.add(monitor)

- # Remove disconnected monitors
- disconnected = {
- m
- for m, r in zip(monitors_copy, results)
- if r is not None and not isinstance(r, Exception)
- }
  self.monitors -= disconnected

- logger.debug(
- f"Stats sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
- )
-
- # Send leaderboard update in a separate task to avoid blocking
- leaderboard_task_start = time.time()
- asyncio.create_task(self._broadcast_leaderboard())
- self.is_generating_stats = False
- logger.debug(
- f"Leaderboard broadcast task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
- )
- logger.debug(f"Stats broadcast completed in {(time.time() - total_start)*1000:.1f}ms")
-
  async def _broadcast_leaderboard(self):
  """Send leaderboard updates to monitors - separate from stats to avoid blocking."""
  if not self.monitors:
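
`_send_activity` above and the rewritten `_broadcast_stats` below share one send-then-prune pattern: iterate over the monitor set, collect sockets whose send raises `ConnectionClosed`, and subtract them afterwards so the set is never mutated mid-iteration. As a standalone sketch (the function name is illustrative, not part of the package):

    import websockets

    async def broadcast(clients: set, message: str) -> None:
        # Collect dead sockets instead of removing them while iterating,
        # then prune the set in a single step.
        dead = set()
        for ws in clients:
            try:
                await ws.send(message)
            except websockets.exceptions.ConnectionClosed:
                dead.add(ws)
        clients -= dead
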
@@ -1875,96 +802,38 @@ class Orchestrator:
  except Exception as e:
  logger.error(f"Error broadcasting leaderboard: {e}")

- def _get_queue_stats(self) -> Dict[str, int]:
- """Get queue statistics - synchronous helper for thread pool."""
- with self.chunk_manager.lock:
- return {
- "pending_chunks": len(self.chunk_manager.pending_chunks),
- "assigned_chunks": sum(
- len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
- ),
- }
-
- async def _flush_processed_items(self):
- """Flush batched processed items to chunk tracker."""
- with self.item_batch_lock:
- if not self.pending_processed_items:
- return
-
- for chunk_id, indices in self.pending_processed_items.items():
- if not indices:
- continue
-
- # Indices here are ABSOLUTE dataset indices
- # Sort indices
- indices.sort()
-
- # Group consecutive indices into ranges
- ranges = []
- start = indices[0]
- end = indices[0]
-
- for i in range(1, len(indices)):
- if indices[i] == end + 1:
- # Consecutive, extend range
- end = indices[i]
- else:
- # Gap found, save current range and start new one
- ranges.append((start, end))
- start = indices[i]
- end = indices[i]
-
- # Don't forget the last range
- ranges.append((start, end))
-
- # Mark ranges as processed
- for start_idx, end_idx in ranges:
- self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
-
- with self.chunk_manager.lock:
- if chunk_id in self.chunk_manager.assigned_ranges:
- for start_idx, end_idx in ranges:
- # Clear any assignments in this range
- to_remove = []
- for range_start, range_end in self.chunk_manager.assigned_ranges[
- chunk_id
- ]:
- if range_start >= start_idx and range_end <= end_idx:
- to_remove.append((range_start, range_end))
-
- for range_key in to_remove:
- del self.chunk_manager.assigned_ranges[chunk_id][range_key]
-
- # Clear pending items
- self.pending_processed_items.clear()
- self.last_item_batch_flush = time.time()
-
- def get_workers_by_user_stats(self) -> Dict[str, Any]:
- """Get statistics about workers grouped by user/token - thread-safe version."""
- if not hasattr(self, "workers_by_user"):
- return {}
-
- # Create a copy to avoid issues with concurrent modification
- stats = {}
- workers_snapshot = dict(self.workers_by_user)
- for user, worker_ids in workers_snapshot.items():
- stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
- return stats
-
- async def _send_activity(self, activity: str):
- """Send activity update to monitors."""
+ async def _broadcast_stats(self):
+ """Broadcast statistics to all monitors."""
  if not self.monitors:
  return

- message = safe_json_dumps(
- {"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
+ # Get current stats
+ processor_stats = self.processor.get_stats()
+ storage_stats = await self.storage.get_storage_stats()
+
+ # Update main stats
+ self.stats["processor_stats"] = processor_stats
+ self.stats["total_outputs"] = storage_stats["total_captions"]
+
+ # Create message
+ stats_message = safe_json_dumps(
+ {
+ "type": "stats",
+ "data": {
+ **self.stats,
+ **storage_stats,
+ "current_rate": self.rate_tracker["current_rate"],
+ "average_rate": self.rate_tracker["average_rate"],
+ },
+ }
  )

+ # Send to all monitors
  disconnected = set()
  for monitor in self.monitors:
  try:
- await monitor.send(message)
- except websockets.exceptions.ConnectionClosed:
+ await monitor.send(stats_message)
+ except:
  disconnected.add(monitor)

  self.monitors -= disconnected
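
The deleted `_flush_processed_items` above compacted sorted absolute item indices into inclusive (start, end) ranges before handing them to `chunk_tracker.mark_items_processed`. The core of that bookkeeping, extracted as a standalone sketch:

    from typing import List, Tuple

    def compact(indices: List[int]) -> List[Tuple[int, int]]:
        # Group sorted indices into inclusive (start, end) ranges.
        if not indices:
            return []
        indices = sorted(indices)
        ranges = [[indices[0], indices[0]]]
        for i in indices[1:]:
            if i == ranges[-1][1] + 1:
                ranges[-1][1] = i      # consecutive: extend current range
            else:
                ranges.append([i, i])  # gap: start a new range
        return [tuple(r) for r in ranges]

    # compact([3, 4, 5, 9, 10, 12]) -> [(3, 5), (9, 10), (12, 12)]
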
@@ -1972,235 +841,74 @@ class Orchestrator:
  async def _heartbeat_loop(self):
  """Send periodic heartbeats to maintain connections."""
  while True:
- try:
- await asyncio.sleep(30)
-
- # Create a copy of worker items to avoid modification during iteration
- worker_items = list(self.workers.items())
- disconnected = []
-
- for worker_id, ws in worker_items:
- try:
- # Check if worker still exists before pinging
- if worker_id not in self.workers:
- continue
-
- # Send ping with timeout
- pong_waiter = await ws.ping()
- try:
- await asyncio.wait_for(pong_waiter, timeout=10)
- except asyncio.TimeoutError:
- logger.warning(f"Worker {worker_id} failed to respond to ping")
- disconnected.append(worker_id)
- except websockets.exceptions.ConnectionClosed:
- logger.info(f"Worker {worker_id} connection already closed")
- disconnected.append(worker_id)
- except Exception as e:
- logger.error(f"Error pinging worker {worker_id}: {e}")
- disconnected.append(worker_id)
-
- # Clean up disconnected workers
- for worker_id in disconnected:
- if worker_id in self.workers:
- logger.info(f"Removing unresponsive worker {worker_id}")
- del self.workers[worker_id]
- self.chunk_manager.release_worker_chunks(worker_id)
-
- # Update stats
- self.stats["connected_workers"] = len(self.workers)
-
- # Also clean up from workers_by_user if it exists
- if hasattr(self, "workers_by_user"):
- worker_user = (
- worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
- )
- if worker_user in self.workers_by_user:
- self.workers_by_user[worker_user].discard(worker_id)
- if not self.workers_by_user[worker_user]:
- del self.workers_by_user[worker_user]
-
- # Notify monitors
- await self._broadcast_stats()
- await self._send_activity(
- f"Worker {worker_id} removed due to heartbeat timeout"
- )
-
- except Exception as e:
- logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
- # Continue the loop even if there's an error
- await asyncio.sleep(5)
+ await asyncio.sleep(30)
+
+ disconnected = []
+ for worker_id, ws in list(self.workers.items()):
+ try:
+ pong_waiter = await ws.ping()
+ await asyncio.wait_for(pong_waiter, timeout=10)
+ except:
+ disconnected.append(worker_id)
+
+ # Clean up disconnected workers
+ for worker_id in disconnected:
+ logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
+ if worker_id in self.workers:
+ del self.workers[worker_id]
+ logger.warning(
+ f"Releasing assignments for worker {worker_id} because it did not respond to ping"
+ )
+ self.processor.release_assignments(worker_id)
+ self.stats["connected_workers"] = len(self.workers)

  async def _checkpoint_loop(self):
  """Periodically checkpoint storage."""
- interval = self.config.get("storage", {}).get("checkpoint_interval", 1000)
+ interval = self.config.get("storage", {}).get("checkpoint_interval", 60)

  while True:
- await asyncio.sleep(60)
-
- # Get current caption count from storage
- storage_stats = await self.storage.get_storage_stats()
- total_captions = storage_stats["total_captions"]
-
- # Force checkpoint at regular intervals
- if total_captions > 0 and total_captions % interval == 0:
- logger.info(f"Triggering checkpoint at {total_captions} captions")
- await self.storage.checkpoint()
+ await asyncio.sleep(interval)

- # Update stats
- self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
- # No need to update total_written or buffer_size - they come from storage
-
- await self._broadcast_stats()
- logger.info(
- f"Checkpoint complete. Total written to disk: {storage_stats['total_written']}"
- )
+ await self.storage.checkpoint()
+ self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
+ logger.info("Storage checkpoint complete")

  async def _stats_update_loop(self):
- """Periodically update and broadcast stats - non-blocking version."""
- # Get the event loop for running blocking operations
- loop = asyncio.get_event_loop()
-
- # Track session start values
- storage_stats = await self.storage.get_storage_stats()
- session_start_outputs = storage_stats["total_captions"] # This now counts ALL outputs
- session_start_time = time.time()
-
- # Track the last known total to detect flushes
- last_known_total = session_start_outputs
-
+ """Periodically update and broadcast stats."""
  while True:
  await asyncio.sleep(10)

- # Update chunk stats in thread pool to avoid blocking
- chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+ # Update rate tracking
  storage_stats = await self.storage.get_storage_stats()
- current_total_outputs = storage_stats["total_captions"] # ALL outputs
- if self.chunk_tracker:
- await self._flush_processed_items()
-
- self.stats["total_chunks"] = chunk_stats["total"]
- self.stats["completed_chunks"] = chunk_stats["completed"]
- self.stats["failed_chunks"] = chunk_stats["failed"]
-
- # Update total outputs stat (rename from total_captions for clarity)
- self.stats["total_outputs"] = current_total_outputs
- self.stats["total_captions"] = current_total_outputs # Keep for backward compatibility
-
- # Get queue stats in thread pool to avoid blocking
- queue_stats = await loop.run_in_executor(None, self._get_queue_stats)
- self.stats.update(queue_stats)
-
- # Calculate if we need more chunks
- worker_count = self.stats.get("connected_workers", 0)
- target_buffer = max(self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier)
- active_chunks = self.stats["pending_chunks"] + self.stats["assigned_chunks"]
- self.stats["chunk_buffer_status"] = f"{active_chunks}/{target_buffer}"
-
- # Update rate information
+ current_total = storage_stats["total_captions"]
  current_time = time.time()
- elapsed_since_update = current_time - self.rate_tracker["last_update_time"]
-
- if elapsed_since_update > 0:
- # FIX: Handle the case where duplicates were skipped during save
- # If current total is less than last known, it means duplicates were skipped
- # We should not count this as negative progress
- if current_total_outputs < last_known_total:
- logger.debug(
- f"Detected duplicate skip during save: {last_known_total} -> {current_total_outputs}"
- )
- # Don't calculate negative rate, just update the baseline
- self.rate_tracker["last_caption_count"] = current_total_outputs
- self.rate_tracker["current_rate"] = 0.0 # Set to 0 during flush
- else:
- # Normal rate calculation
- output_diff = current_total_outputs - self.rate_tracker["last_caption_count"]
- self.rate_tracker["current_rate"] = (output_diff / elapsed_since_update) * 60
- self.rate_tracker["last_caption_count"] = current_total_outputs
-
- # Calculate average rate since THIS SESSION started
- session_elapsed = current_time - session_start_time
- if session_elapsed > 0:
- # Always use the difference from session start for average
- session_outputs = current_total_outputs - session_start_outputs
- self.rate_tracker["average_rate"] = (session_outputs / session_elapsed) * 60
-
- # Calculate expected rate based on workers and stages
- batch_size = self.vllm_config.get("batch_size", 8)
-
- # Count total prompts across all stages
- total_prompts = 0
- stages = self.vllm_config.get("stages", [])
- if stages:
- for stage in stages:
- total_prompts += len(stage.get("prompts", []))
- else:
- # Backward compatibility
- total_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))

- images_per_minute = 30 # Rough estimate: 30 images/min per worker
- self.rate_tracker["expected_rate"] = (
- worker_count * images_per_minute * total_prompts
- )
-
- # Update trackers
+ elapsed = current_time - self.rate_tracker["last_update_time"]
+ if elapsed > 0:
+ output_diff = current_total - self.rate_tracker["last_output_count"]
+ self.rate_tracker["current_rate"] = (output_diff / elapsed) * 60
+ self.rate_tracker["last_output_count"] = current_total
  self.rate_tracker["last_update_time"] = current_time
- last_known_total = current_total_outputs
-
- # Log rate information when workers are connected
- # if (
- # worker_count > 0 and self.rate_tracker["current_rate"] >= 0
- # ): # Only log non-negative rates
- # logger.info(
- # f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
- # f"(avg: {self.rate_tracker['average_rate']:.1f}, "
- # f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
- # f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
- # )

- await self._broadcast_stats()
+ # Average rate since start
+ total_elapsed = current_time - self.rate_tracker["start_time"]
+ if total_elapsed > 0:
+ self.rate_tracker["average_rate"] = (current_total / total_elapsed) * 60

- async def _restore_state(self):
- """Restore state from storage on startup."""
- total_captions = await self.storage.count_captions()
- logger.info(f"Restored state: {total_captions} captions")
+ await self._broadcast_stats()

  async def shutdown(self):
  """Graceful shutdown."""
  logger.info("Shutting down orchestrator...")

- # Stop chunk creation
- if self.chunk_tracker:
- await self._flush_processed_items()
- self.stop_chunk_creation.set()
- if self.chunk_creation_thread:
- self.chunk_creation_thread.join(timeout=5)
-
- # Release all assigned chunks before closing connections
- for worker_id in list(self.workers.keys()):
- self.chunk_manager.release_worker_chunks(worker_id)
- if self.chunk_tracker:
- # Update chunk tracker to mark assigned chunks as pending
- with self.chunk_manager.lock:
- for chunk_id in list(self.chunk_manager.assigned_chunks.get(worker_id, [])):
- self.chunk_tracker.mark_pending(chunk_id)
-
  # Close all connections
  for ws in list(self.workers.values()):
  await ws.close()
  for ws in list(self.monitors):
  await ws.close()

- # Save chunk state
- if self.chunk_tracker:
- self.chunk_tracker.save()
-
  # Final checkpoint
- logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")
  await self.storage.checkpoint()
-
- # Log final statistics
- logger.info(
- f"Shutdown complete. Total captions collected: {self.storage.total_captions_written}"
- )
-
  await self.storage.close()
+
+ logger.info("Shutdown complete")