caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,15 +21,15 @@ from collections import deque, defaultdict
  import threading
  from queue import Queue, Empty

+ from .workers import data
  import websockets
  from websockets.server import WebSocketServerProtocol

  from .storage import StorageManager
  from .models import Caption, Contributor
  from .utils.auth import AuthManager
- from .utils.dataset_loader import DatasetLoader, ShardTracker
+ from .utils import DatasetLoader, ShardTracker, ChunkTracker
  from .utils.json_utils import safe_dict, safe_json_dumps, to_json_dict
- from .utils.chunk_tracker import ChunkTracker

  logger = logging.getLogger(__name__)

@@ -48,6 +48,43 @@ class ShardChunk:
      assigned_at: Optional[datetime] = None
      completed_at: Optional[datetime] = None

+     @classmethod
+     def create(
+         cls, shard_url: str, shard_name: str, start_index: int, chunk_size: int
+     ) -> "ShardChunk":
+         """Factory method to create a chunk with consistent ID."""
+         # Always use consistent format: dataset_chunk_startindex
+         if shard_url.startswith("hf_dataset:"):
+             # Extract dataset path
+             parts = shard_url.split(":")
+             dataset_path = parts[1] if len(parts) > 1 else "unknown"
+             chunk_id = f"{dataset_path.replace('/', '_')}_chunk_{start_index}"
+         else:
+             # WebDataset format
+             chunk_id = f"{shard_name}_chunk_{start_index}"
+
+         return cls(
+             chunk_id=chunk_id,
+             shard_url=shard_url,
+             shard_name=shard_name,
+             start_index=start_index,
+             chunk_size=chunk_size,
+         )
+
+     def belongs_to_shard(self, shard_identifier: str) -> bool:
+         """Check if this chunk belongs to a given shard."""
+         return self.shard_name == shard_identifier
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dict for JSON serialization (for workers)."""
+         return {
+             "chunk_id": self.chunk_id,
+             "shard_url": self.shard_url,
+             "shard_name": self.shard_name,
+             "start_index": self.start_index,
+             "chunk_size": self.chunk_size,
+         }
+

  class ChunkManager:
      """Manages shard chunk creation and assignment."""
@@ -67,9 +104,7 @@ class ChunkManager:
          chunks = []

          for start_idx in range(0, total_items, self.chunk_size):
-             chunk_id = f"{shard_name}_chunk_{start_idx}"
-             chunk = ShardChunk(
-                 chunk_id=chunk_id,
+             chunk = ShardChunk.create(
                  shard_url=shard_url,
                  shard_name=shard_name,
                  start_index=start_idx,
@@ -77,8 +112,8 @@
              )

              with self.lock:
-                 self.chunks[chunk_id] = chunk
-                 self.pending_chunks.append(chunk_id)
+                 self.chunks[chunk.chunk_id] = chunk
+                 self.pending_chunks.append(chunk.chunk_id)

              chunks.append(chunk)

@@ -86,24 +121,84 @@

      def get_chunks_for_worker(
          self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
-     ) -> List[ShardChunk]:
-         """Get available chunks for a worker."""
+     ) -> List[Dict[str, Any]]:
+         """Get available chunks with unprocessed items for a worker."""
          assigned = []

          with self.lock:
+             # FIRST PRIORITY: Check if this worker already has assigned chunks
+             # Workers should complete their current chunks before getting new ones
+             if worker_id in self.assigned_chunks:
+                 existing_chunk_ids = list(self.assigned_chunks[worker_id])
+                 for chunk_id in existing_chunk_ids:
+                     if len(assigned) >= count:
+                         break
+
+                     chunk = self.chunks.get(chunk_id)
+                     if not chunk:
+                         continue
+
+                     # Check if chunk still has unprocessed items
+                     if tracker:
+                         chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
+                         if chunk_info and chunk_info["unprocessed_ranges"]:
+                             assigned.append(
+                                 {
+                                     "chunk": chunk,
+                                     "unprocessed_ranges": chunk_info["unprocessed_ranges"],
+                                 }
+                             )
+                     else:
+                         # No tracker, assume chunk needs processing
+                         assigned.append(
+                             {
+                                 "chunk": chunk,
+                                 "unprocessed_ranges": [(0, chunk.chunk_size - 1)],
+                             }
+                         )
+
+             # SECOND PRIORITY: Get new pending chunks
+             # Only if worker doesn't have enough chunks already
              while len(assigned) < count and self.pending_chunks:
                  chunk_id = self.pending_chunks.popleft()
-                 chunk = self.chunks[chunk_id]
+                 chunk = self.chunks.get(chunk_id)
+
+                 if not chunk:
+                     continue

+                 # Verify chunk is truly pending (defensive check)
+                 if chunk.status != "pending" or chunk.assigned_to is not None:
+                     logger.warning(
+                         f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
+                     )
+                     continue
+
+                 # Assign to this worker
                  chunk.assigned_to = worker_id
                  chunk.status = "assigned"
                  chunk.assigned_at = datetime.utcnow()
-
                  self.assigned_chunks[worker_id].add(chunk_id)
-                 assigned.append(chunk)
+
+                 # Get unprocessed ranges
+                 unprocessed_ranges = [(0, chunk.chunk_size - 1)]  # Default
                  if tracker:
+                     chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
+                     if chunk_info:
+                         unprocessed_ranges = chunk_info["unprocessed_ranges"]
                      tracker.mark_assigned(chunk_id, worker_id)

+                 assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
+
+         # Log what we're assigning
+         if assigned:
+             chunk_summary = ", ".join(
+                 [
+                     f"{info['chunk'].chunk_id}[{len(info['unprocessed_ranges'])} ranges]"
+                     for info in assigned
+                 ]
+             )
+             logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
+
          return assigned

      def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
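
Each entry returned by `get_chunks_for_worker` now pairs a chunk with the ranges still needing work, and the message handler later in this diff flattens that via `to_dict()` into the `shard_assignment` payload. A hedged sketch of what a worker receives — the values are illustrative, and whether ranges are chunk-relative or absolute depends on what the tracker stores (the in-code default shown above is chunk-relative):

    {
        "type": "shard_assignment",
        "chunks": [
            {
                "chunk_id": "shard-00042_chunk_2000",
                "shard_url": "https://example.com/data/shard-00042.tar",
                "shard_name": "shard-00042",
                "start_index": 2000,
                "chunk_size": 1000,
                "unprocessed_ranges": [[0, 499], [750, 999]],
            }
        ],
    }
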
@@ -173,6 +268,27 @@ class Orchestrator:
          self.dataset_config = config.get("dataset", {})
          self.dataset_path = self.dataset_config.get("path")
          self.dataset_type = self.dataset_config.get("type", "huggingface")
+         self.dataset_split = self.dataset_config.get("split", "train")  # Add split configuration
+         self.dataset_image_column = self.dataset_config.get(
+             "image_column", "image"
+         )  # Add image column config
+
+         # Dataset components
+         self.dataset_loader = None
+         self.shard_tracker = None
+         self.chunk_tracker = None
+
+         if self.dataset_path:
+             self.dataset_loader = DatasetLoader(
+                 self.dataset_path,
+                 self.dataset_type,
+                 self.dataset_split,
+                 self.dataset_image_column,
+             )
+             checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
+             checkpoint_dir.mkdir(parents=True, exist_ok=True)
+             self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
+             self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")

          # vLLM configuration to distribute to workers
          self.vllm_config = config.get(
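
With the new `split` and `image_column` keys, a dataset section exercising all four fields might look like the following Python-dict sketch (the path is hypothetical; the defaults shown are the ones in the code above):

    config = {
        "dataset": {
            "path": "org/my-dataset",
            "type": "huggingface",      # default
            "split": "train",           # default
            "image_column": "image",    # default
        },
        "storage": {"checkpoint_dir": "./checkpoints"},
    }
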
@@ -233,6 +349,11 @@

          # Initialize chunk manager with reference to chunk tracker
          self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)
+         self.pending_processed_items = defaultdict(list)  # chunk_id -> list of indices
+         self.item_batch_lock = threading.Lock()
+         self.last_item_batch_flush = time.time()
+         self.item_batch_interval = 5  # Flush every 5 seconds
+         self.item_batch_size = 100  # Or every 100 items

          # Track connections
          self.workers: Dict[str, WebSocketServerProtocol] = {}
@@ -242,17 +363,15 @@
          self.ssl_context = self._setup_ssl()

          # Statistics
+         self.is_generating_stats = False
          self.stats = {
              "total_chunks": 0,
              "completed_chunks": 0,
              "failed_chunks": 0,
-             "total_captions": 0,
              "connected_workers": 0,
              "total_shards": 0,
              "completed_shards": 0,
              "current_shard": None,
-             "buffer_size": 0,
-             "total_written": 0,
              "last_checkpoint": None,
          }

@@ -266,7 +385,7 @@
              "expected_rate": 0.0,
          }

-         # Data sample queue for VLLMWorkers
+         # Data sample queue for CaptionWorker
          self.data_sample_queue = asyncio.Queue(maxsize=1000)
          self.data_workers: Dict[str, WebSocketServerProtocol] = {}

@@ -310,10 +429,23 @@
          # Mark state as not restored until we process checkpoints
          self.state_restored.clear()

+         # Get dataset info to check format
+         dataset_info = self.dataset_loader.get_dataset_info()
+         dataset_format = dataset_info.get("dataset_format", "unknown")
+         logger.info(f"Dataset format: {dataset_format}")
+
          # Get all shards
          self.all_shards = self.dataset_loader.get_shard_list()
          self.stats["total_shards"] = len(self.all_shards)

+         # For HuggingFace datasets, we might need to dynamically create more shards
+         if dataset_format == "huggingface_datasets":
+             self._is_hf_dataset = True
+             self._hf_chunk_size = 10000  # Items per virtual shard
+             self._next_hf_shard_index = len(self.all_shards)  # For creating new virtual shards
+         else:
+             self._is_hf_dataset = False
+
          # Get shard status from ChunkTracker
          shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
          completed_shards = {
@@ -336,7 +468,10 @@

          # Filter out shards that already have chunks created
          remaining_shards = [
-             shard for shard in remaining_shards if Path(shard).stem not in shards_with_chunks
+             shard
+             for shard in remaining_shards
+             if (shard if shard.startswith("hf_dataset:") else Path(shard).stem)
+             not in shards_with_chunks
          ]

          self.stats["completed_shards"] = len(completed_shards)
@@ -356,25 +491,18 @@
          with self.chunk_manager.lock:
              for chunk_state in shard_info["chunks"]:
                  if chunk_state.status in ["pending", "failed", "assigned"]:
-                     # Find shard URL
-                     shard_url = None
-                     for url in self.all_shards:
-                         if Path(url).stem == shard_name:
-                             shard_url = url
-                             break
-
-                     if shard_url:
-                         chunk = ShardChunk(
-                             chunk_id=chunk_state.chunk_id,
-                             shard_url=shard_url,
-                             shard_name=chunk_state.shard_name,
-                             start_index=chunk_state.start_index,
-                             chunk_size=chunk_state.chunk_size,
-                         )
-                         self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
-                         self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
-                         requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
-                         initial_pending += 1
+                     # ChunkState already has shard_url stored
+                     chunk = ShardChunk(
+                         chunk_id=chunk_state.chunk_id,
+                         shard_url=chunk_state.shard_url,
+                         shard_name=chunk_state.shard_name,
+                         start_index=chunk_state.start_index,
+                         chunk_size=chunk_state.chunk_size,
+                     )
+                     self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
+                     self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
+                     requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
+                     initial_pending += 1

          logger.info(f"Re-queued {initial_pending} existing pending chunks")
          for shard_name, chunk_ids in requeued_chunks_by_shard.items():
@@ -426,7 +554,13 @@
              if current_shard_url is None or current_shard_index >= current_shard_items:
                  try:
                      current_shard_url = next(shard_iter)
-                     current_shard_name = Path(current_shard_url).stem
+
+                     # Extract shard name based on type
+                     if current_shard_url.startswith("hf_dataset:"):
+                         current_shard_name = current_shard_url  # Use full ID for virtual shards
+                     else:
+                         current_shard_name = Path(current_shard_url).stem
+
                      self.stats["current_shard"] = current_shard_name

                      # Skip if we already have chunks from this shard
@@ -439,16 +573,74 @@

                      # Count items in new shard
                      logger.info(f"Loading new shard {current_shard_name}")
-                     current_shard_items = sum(
-                         1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
-                     )
+
+                     # For virtual HF dataset shards, use the chunk size directly
+                     if current_shard_url.startswith("hf_dataset:"):
+                         current_shard_items = self.dataset_loader.count_shard_items(
+                             current_shard_url
+                         )
+                         logger.info(
+                             f"Virtual shard {current_shard_name} has {current_shard_items} items"
+                         )
+                     else:
+                         # For WebDataset, actually count items
+                         current_shard_items = sum(
+                             1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
+                         )
+                         logger.info(
+                             f"Shard {current_shard_name} has {current_shard_items} items"
+                         )
+
                      current_shard_index = 0
-                     logger.info(f"Shard {current_shard_name} has {current_shard_items} items")

                  except StopIteration:
-                     # No more shards
+                     # No more shards in the iterator
+                     if self._is_hf_dataset:
+                         # Before creating new virtual shards, check if we have pending chunks
+                         with self.chunk_manager.lock:
+                             pending_count = len(self.chunk_manager.pending_chunks)
+
+                         if pending_count > 0:
+                             # Don't create new shards if we have pending chunks
+                             logger.debug(
+                                 f"Have {pending_count} pending chunks, not creating new virtual shards yet"
+                             )
+                             current_shard_url = None
+                             time.sleep(2)
+                             continue
+
+                         # For HF datasets, we can create more virtual shards on demand
+                         logger.info(
+                             "Creating additional virtual shards for HuggingFace dataset"
+                         )
+
+                         # Create 10 more virtual shards
+                         new_shards = []
+                         for i in range(10):
+                             shard_id = f"hf_dataset:{self.dataset_path}:chunk:{self._next_hf_shard_index * self._hf_chunk_size}"
+                             new_shards.append(shard_id)
+                             self._next_hf_shard_index += 1
+
+                         # Add to all_shards and create new iterator
+                         self.all_shards.extend(new_shards)
+                         self.stats["total_shards"] = len(self.all_shards)
+
+                         # Filter for unprocessed shards
+                         remaining_new_shards = [
+                             s
+                             for s in new_shards
+                             if s not in shards_summary and s not in completed_shards
+                         ]
+
+                         if remaining_new_shards:
+                             shard_iter = iter(remaining_new_shards)
+                             logger.info(f"Added {len(remaining_new_shards)} new virtual shards")
+                             continue
+
+                     # No more shards to process
                      logger.info("No more shards to process")
                      break
+
                  except Exception as e:
                      logger.error(f"Error loading shard {current_shard_name}: {e}")
                      current_shard_url = None
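
Virtual shards are plain strings with a fixed layout, and the same string is re-parsed when chunks are created in the next hunk. A small sketch of the round trip, with a hypothetical dataset path:

    shard_id = "hf_dataset:org/my-dataset:chunk:30000"

    parts = shard_id.split(":")        # ["hf_dataset", "org/my-dataset", "chunk", "30000"]
    assert parts[2] == "chunk"
    shard_base_index = int(parts[3])   # 30000: absolute index of the shard's first item

    # A chunk beginning 250 items into this shard starts at absolute index:
    absolute_start_index = shard_base_index + 250   # 30250
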
@@ -456,25 +648,40 @@

              # Create a chunk from current shard
              if current_shard_url and current_shard_index < current_shard_items:
-                 chunk_id = f"{current_shard_name}_chunk_{current_shard_index}"
-                 chunk_size = min(self.chunk_size, current_shard_items - current_shard_index)
+                 # Calculate the absolute dataset index for this chunk
+                 if current_shard_url.startswith("hf_dataset:"):
+                     # Parse the virtual shard URL to get the base start index
+                     parts = current_shard_url.split(":")
+                     if len(parts) >= 4 and parts[2] == "chunk":
+                         shard_base_index = int(parts[3])
+                     else:
+                         shard_base_index = 0
+
+                     # The absolute start index for this chunk in the dataset
+                     absolute_start_index = shard_base_index + current_shard_index
+                 else:
+                     # For WebDataset, current_shard_index is already absolute
+                     absolute_start_index = current_shard_index
+
+                 # Create chunk with absolute index
+                 chunk = ShardChunk.create(
+                     shard_url=current_shard_url,
+                     shard_name=current_shard_name,
+                     start_index=absolute_start_index,
+                     chunk_size=min(self.chunk_size, current_shard_items - current_shard_index),
+                 )

-                 # Add to ChunkTracker
+                 # Add to ChunkTracker with all required fields
                  if self.chunk_tracker and self.chunk_tracker.add_chunk(
-                     chunk_id, current_shard_name, current_shard_index, chunk_size
+                     chunk.chunk_id,
+                     chunk.shard_name,
+                     chunk.shard_url,
+                     chunk.start_index,
+                     chunk.chunk_size,
                  ):
-                     # Create chunk
-                     chunk = ShardChunk(
-                         chunk_id=chunk_id,
-                         shard_url=current_shard_url,
-                         shard_name=current_shard_name,
-                         start_index=current_shard_index,
-                         chunk_size=chunk_size,
-                     )
-
                      with self.chunk_manager.lock:
-                         self.chunk_manager.chunks[chunk_id] = chunk
-                         self.chunk_manager.pending_chunks.append(chunk_id)
+                         self.chunk_manager.chunks[chunk.chunk_id] = chunk
+                         self.chunk_manager.pending_chunks.append(chunk.chunk_id)

                      chunks_created += 1
                      self.stats["total_chunks"] += 1
@@ -484,10 +691,14 @@
              if chunks_created > 0:
                  logger.info(f"Created {chunks_created} chunks on demand")

-             # If we couldn't create any chunks and there are no more shards, we're done
+             # If we couldn't create any chunks and there are no more shards, check if it's HF dataset
              if chunks_created == 0 and current_shard_url is None:
-                 logger.info("All shards processed, chunk creation complete")
-                 break
+                 if self._is_hf_dataset:
+                     # We can always create more virtual shards for HF datasets
+                     logger.debug("Will create more virtual shards on next iteration")
+                 else:
+                     logger.info("All shards processed, chunk creation complete")
+                     break

              # Brief pause to avoid spinning
              time.sleep(1)
@@ -558,7 +769,9 @@
              elif auth_ticket.role == "admin":
                  await self._handle_admin(websocket, auth_ticket)
              else:
-                 await websocket.send(safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"}))
+                 await websocket.send(
+                     safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"})
+                 )

          except Exception as e:
              logger.error(f"Connection error: {e}")
@@ -604,81 +817,118 @@
          requires_worker_restart = False

          try:
+             # Extract orchestrator section if present
+             if "orchestrator" in new_config:
+                 # Config has orchestrator wrapper, extract it
+                 orchestrator_config = new_config["orchestrator"]
+             else:
+                 # Config is already at orchestrator level
+                 orchestrator_config = new_config
+
+             # Helper function for deep comparison
+             def deep_equal(a, b):
+                 """Deep comparison of two values including nested dicts and lists."""
+                 if type(a) != type(b):
+                     return False
+                 if isinstance(a, dict):
+                     if set(a.keys()) != set(b.keys()):
+                         return False
+                     return all(deep_equal(a[k], b[k]) for k in a.keys())
+                 elif isinstance(a, (list, tuple)):
+                     if len(a) != len(b):
+                         return False
+                     return all(deep_equal(x, y) for x, y in zip(a, b))
+                 else:
+                     return a == b
+
              # Update vLLM configuration
-             if "vllm" in new_config:
+             if "vllm" in orchestrator_config:
                  old_vllm = self.vllm_config.copy()
+                 new_vllm = orchestrator_config["vllm"]

-                 # Check each field for actual changes
-                 vllm_changed = False
-                 for key, value in new_config["vllm"].items():
-                     if self.vllm_config.get(key) != value:
-                         self.vllm_config[key] = value
-                         vllm_changed = True
+                 # Check if vLLM config actually changed using deep comparison
+                 vllm_changed = not deep_equal(old_vllm, new_vllm)

                  if vllm_changed:
+                     # Update the vLLM config
+                     self.vllm_config = new_vllm.copy()
                      updated_sections.append("vllm")

                      # Check if critical changes require worker restart
                      if (
-                         old_vllm.get("model") != self.vllm_config.get("model")
+                         old_vllm.get("model") != new_vllm.get("model")
                          or old_vllm.get("gpu_memory_utilization")
-                         != self.vllm_config.get("gpu_memory_utilization")
+                         != new_vllm.get("gpu_memory_utilization")
                          or old_vllm.get("tensor_parallel_size")
-                         != self.vllm_config.get("tensor_parallel_size")
+                         != new_vllm.get("tensor_parallel_size")
+                         or old_vllm.get("dtype") != new_vllm.get("dtype")
+                         or old_vllm.get("max_model_len") != new_vllm.get("max_model_len")
                      ):
                          requires_worker_restart = True
                          warnings.append(
                              "Critical vLLM changes detected - workers will be disconnected to reload"
                          )
+                         logger.info(
+                             f"Model change: {old_vllm.get('model')} -> {new_vllm.get('model')}"
+                         )

              # Update dataset configuration
-             if "dataset" in new_config:
-                 dataset_changed = False
-                 for key, value in new_config["dataset"].items():
-                     if self.dataset_config.get(key) != value:
-                         self.dataset_config[key] = value
-                         dataset_changed = True
+             if "dataset" in orchestrator_config:
+                 old_dataset = self.dataset_config.copy()
+                 new_dataset = orchestrator_config["dataset"]
+
+                 dataset_changed = not deep_equal(old_dataset, new_dataset)

                  if dataset_changed:
+                     self.dataset_config = new_dataset.copy()
                      self.dataset_path = self.dataset_config.get("path")
                      self.dataset_type = self.dataset_config.get("type", "huggingface")
                      updated_sections.append("dataset")
                      warnings.append("Dataset changes will apply to new chunks only")

              # Update chunk settings
-             if "chunk_size" in new_config and self.chunk_size != new_config["chunk_size"]:
-                 self.chunk_size = new_config["chunk_size"]
+             if (
+                 "chunk_size" in orchestrator_config
+                 and self.chunk_size != orchestrator_config["chunk_size"]
+             ):
+                 self.chunk_size = orchestrator_config["chunk_size"]
                  self.chunk_manager.chunk_size = self.chunk_size
                  updated_sections.append("chunk_size")

              if (
-                 "chunks_per_request" in new_config
-                 and self.chunks_per_request != new_config["chunks_per_request"]
+                 "chunks_per_request" in orchestrator_config
+                 and self.chunks_per_request != orchestrator_config["chunks_per_request"]
              ):
-                 self.chunks_per_request = new_config["chunks_per_request"]
+                 self.chunks_per_request = orchestrator_config["chunks_per_request"]
                  updated_sections.append("chunks_per_request")

-             # Recreate auth manager
-             self.auth = AuthManager(config=new_config)
+             # Update auth configuration
+             if "auth" in orchestrator_config:
+                 try:
+                     self.auth = AuthManager({"auth": orchestrator_config["auth"]})
+                     updated_sections.append("auth")
+                 except Exception as e:
+                     logger.error(f"Failed to update AuthManager: {e}")
+                     warnings.append(f"Auth update failed: {e}")

              # Update buffer settings
              if (
-                 "chunk_buffer_multiplier" in new_config
-                 and self.chunk_buffer_multiplier != new_config["chunk_buffer_multiplier"]
+                 "chunk_buffer_multiplier" in orchestrator_config
+                 and self.chunk_buffer_multiplier != orchestrator_config["chunk_buffer_multiplier"]
              ):
-                 self.chunk_buffer_multiplier = new_config["chunk_buffer_multiplier"]
+                 self.chunk_buffer_multiplier = orchestrator_config["chunk_buffer_multiplier"]
                  updated_sections.append("chunk_buffer_multiplier")

              if (
-                 "min_chunk_buffer" in new_config
-                 and self.min_chunk_buffer != new_config["min_chunk_buffer"]
+                 "min_chunk_buffer" in orchestrator_config
+                 and self.min_chunk_buffer != orchestrator_config["min_chunk_buffer"]
              ):
-                 self.min_chunk_buffer = new_config["min_chunk_buffer"]
+                 self.min_chunk_buffer = orchestrator_config["min_chunk_buffer"]
                  updated_sections.append("min_chunk_buffer")

              # Update storage settings
-             if "storage" in new_config:
-                 storage_config = new_config["storage"]
+             if "storage" in orchestrator_config:
+                 storage_config = orchestrator_config["storage"]
                  storage_changed = False

                  if (
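
`deep_equal` replaces the old shallow per-key comparison, so nested structures such as stage lists under the vllm section are compared by value before deciding whether anything changed. A quick illustration with hypothetical config fragments:

    old = {"model": "m", "stages": [{"name": "caption", "prompts": ["a", "b"]}]}
    new = {"model": "m", "stages": [{"name": "caption", "prompts": ["a", "b"]}]}
    assert deep_equal(old, new)        # equal by value: no reload is triggered

    new["stages"][0]["prompts"].append("c")
    assert not deep_equal(old, new)    # nested change detected
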
@@ -701,21 +951,6 @@
                  if storage_changed:
                      updated_sections.append("storage")

-             # Update data worker storage config
-             if "data_worker_storage" in new_config:
-                 current_dw_storage = self.config.get("data_worker_storage", {})
-                 if current_dw_storage != new_config["data_worker_storage"]:
-                     self.config["data_worker_storage"] = new_config["data_worker_storage"]
-                     updated_sections.append("data_worker_storage")
-                     warnings.append("Data worker storage config will apply to new connections only")
-
-             # Update backpressure threshold
-             if "backpressure_threshold" in new_config:
-                 current_threshold = getattr(self, "backpressure_threshold", 800)
-                 if current_threshold != new_config["backpressure_threshold"]:
-                     self.backpressure_threshold = new_config["backpressure_threshold"]
-                     updated_sections.append("backpressure_threshold")
-
              # Check if any changes were made
              if not updated_sections:
                  await websocket.send(
@@ -729,29 +964,49 @@
                  logger.info("Configuration reload requested but no changes detected")
                  return

-             # Update the main config for any other fields
-             self.config.update(new_config)
+             # Update the main config
+             if "orchestrator" in new_config:
+                 self.config["orchestrator"] = orchestrator_config
+             else:
+                 self.config.update(orchestrator_config)

              # Handle worker restart if needed
              if requires_worker_restart:
                  logger.info("Disconnecting all workers for configuration reload...")

-                 # Disconnect all workers
-                 worker_ids = list(self.workers.keys())
-                 for worker_id in worker_ids:
+                 # Send reload message to workers first
+                 reload_msg = safe_json_dumps(
+                     {
+                         "type": "reload_vllm",
+                         "vllm_config": self.vllm_config,
+                     }
+                 )
+
+                 # Create a list of worker items to avoid modifying dict during iteration
+                 worker_items = list(self.workers.items())
+                 disconnected = []
+
+                 for worker_id, ws in worker_items:
                      try:
-                         await self.workers[worker_id].close(
-                             code=1012, reason="Configuration reload"
-                         )
+                         await ws.send(reload_msg)
+                         # Give worker time to process before disconnect
+                         await asyncio.sleep(0.5)
+                         await ws.close(code=1012, reason="Configuration reload")
+                         disconnected.append(worker_id)
                      except:
-                         pass
+                         disconnected.append(worker_id)  # Still mark as disconnected if error
+
+                 # Now safely clear workers dict
+                 for worker_id in disconnected:
+                     if worker_id in self.workers:
+                         del self.workers[worker_id]

                  warnings.append(
-                     f"Disconnected {len(worker_ids)} workers - they will reconnect with new config"
+                     f"Sent reload message to {len(disconnected)} workers - they will reconnect with new config"
                  )
              else:
-                 # Just notify workers about config changes
-                 reload_msg = safe_json_dumps(
+                 # Just notify workers about config changes without disconnecting
+                 config_update_msg = safe_json_dumps(
                      {
                          "type": "config_update",
                          "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
@@ -761,15 +1016,21 @@
                      }
                  )

+                 # Create a list of worker items to avoid modifying dict during iteration
+                 worker_items = list(self.workers.items())
                  disconnected = []
-                 for worker_id, ws in self.workers.items():
+
+                 for worker_id, ws in worker_items:
                      try:
-                         await ws.send(reload_msg)
+                         await ws.send(config_update_msg)
+                         logger.info(f"Sent config update to worker {worker_id}")
                      except:
                          disconnected.append(worker_id)

+                 # Now safely remove disconnected workers
                  for worker_id in disconnected:
-                     del self.workers[worker_id]
+                     if worker_id in self.workers:
+                         del self.workers[worker_id]

                  # Send success response
                  await websocket.send(
@@ -788,34 +1049,58 @@

          except Exception as e:
              logger.error(f"Configuration reload failed: {e}")
+             import traceback
+
+             logger.error(traceback.format_exc())
              await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))

      async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
          """Handle worker connection lifecycle."""
-         worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
+         # Generate unique worker ID even if using same token
+         base_name = getattr(auth_ticket, "name", "worker")
+         worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"  # Add unique suffix
+
+         # Track the original token/user for accounting
+         worker_user = base_name  # Keep track of which user/token this worker belongs to
+
          self.workers[worker_id] = websocket
          self.stats["connected_workers"] = len(self.workers)

-         # Register contributor
-         contributor = Contributor(
-             contributor_id=worker_id, name=worker_id, total_captions=0, trust_level=1
-         )
-         await self.storage.save_contributor(contributor)
+         # Optionally track workers by user/token
+         if not hasattr(self, "workers_by_user"):
+             self.workers_by_user = defaultdict(set)
+         self.workers_by_user[worker_user].add(worker_id)
+
+         # Register contributor with the base name (for aggregating stats per user)
+         contributor = await self.storage.get_contributor(worker_user)
+         if not contributor:
+             contributor = Contributor(
+                 contributor_id=worker_user,
+                 name=worker_user,
+                 total_captions=0,
+                 trust_level=1,
+             )
+             await self.storage.save_contributor(contributor)

-         logger.info(f"Worker {worker_id} connected")
+         logger.info(f"Worker {worker_id} (user: {worker_user}) connected")
          await self._broadcast_stats()
-         await self._send_activity(f"Worker {worker_id} connected")
+         await self._send_activity(f"Worker {worker_id} (user: {worker_user}) connected")

          try:
              # Send welcome message with dataset configuration
              welcome_message = {
                  "type": "welcome",
                  "worker_id": worker_id,
+                 "user_id": worker_user,
                  "dataset_config": {
                      "dataset_path": self.dataset_path,
                      "dataset_type": self.dataset_type,
-                     "path": self.dataset_path,  # For compatibility
-                     "type": self.dataset_type,  # For compatibility
+                     "dataset_split": self.dataset_split,
+                     "dataset_image_column": self.dataset_image_column,
+                     "path": self.dataset_path,
+                     "type": self.dataset_type,
+                     "split": self.dataset_split,
+                     "image_column": self.dataset_image_column,
                  },
                  "vllm_config": self.vllm_config,
              }
@@ -826,21 +1111,29 @@
                  await self._process_worker_message(worker_id, data)

          except websockets.exceptions.ConnectionClosed:
-             logger.info(f"Worker {worker_id} disconnected")
+             logger.info(f"Worker {worker_id} (user: {worker_user}) disconnected")
          finally:
-             del self.workers[worker_id]
+             if worker_id in self.workers:
+                 del self.workers[worker_id]
+
+             # Clean up user tracking
+             if hasattr(self, "workers_by_user") and worker_user in self.workers_by_user:
+                 self.workers_by_user[worker_user].discard(worker_id)
+                 if not self.workers_by_user[worker_user]:
+                     del self.workers_by_user[worker_user]
+
              self.stats["connected_workers"] = len(self.workers)
-             # Release chunks in both managers
+
+             # Release chunks
              self.chunk_manager.release_worker_chunks(worker_id)
              if self.chunk_tracker:
-                 # Mark released chunks as pending in tracker
                  released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
                  logger.info(
                      f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
                  )

              await self._broadcast_stats()
-             await self._send_activity(f"Worker {worker_id} disconnected")
+             await self._send_activity(f"Worker {worker_id} (user: {worker_user}) disconnected")

      async def _process_worker_message(self, worker_id: str, data: Dict):
          """Process message from worker."""
@@ -856,28 +1149,26 @@
              return

          count = data.get("count", self.chunks_per_request)
-         chunks = self.chunk_manager.get_chunks_for_worker(worker_id, count, self.chunk_tracker)
+         chunk_infos = self.chunk_manager.get_chunks_for_worker(
+             worker_id, count, self.chunk_tracker
+         )

-         if chunks:
-             # Only send the fields that worker expects
-             chunk_data = []
-             for chunk in chunks:
-                 chunk_data.append(
-                     {
-                         "chunk_id": chunk.chunk_id,
-                         "shard_url": chunk.shard_url,
-                         "shard_name": chunk.shard_name,
-                         "start_index": chunk.start_index,
-                         "chunk_size": chunk.chunk_size,
-                     }
-                 )
+         if chunk_infos:
+             # Send chunks with unprocessed ranges
+             chunks_data = []
+             for info in chunk_infos:
+                 chunk_dict = info["chunk"].to_dict()
+                 chunk_dict["unprocessed_ranges"] = info["unprocessed_ranges"]
+                 chunks_data.append(chunk_dict)

              await self.workers[worker_id].send(
-                 safe_json_dumps({"type": "shard_assignment", "chunks": chunk_data})
+                 safe_json_dumps({"type": "shard_assignment", "chunks": chunks_data})
+             )
+
+             chunk_ids = [c["chunk_id"] for c in chunks_data]
+             logger.info(
+                 f"Assigned {len(chunks_data)} chunks to worker {worker_id}: {chunk_ids}"
              )
-             chunk_ids = [c["chunk_id"] for c in chunk_data]
-             logger.info(f"Assigned {len(chunks)} chunks to worker {worker_id}: {chunk_ids}")
-             await self._send_activity(f"Assigned {len(chunks)} chunks to {worker_id}")
          else:
              await self.workers[worker_id].send(safe_json_dumps({"type": "no_chunks"}))

@@ -907,7 +1198,7 @@
          elif msg_type == "submit_captions":
              await self._handle_captions_submission(worker_id, data)
          elif msg_type == "request_job":
-             # VLLMWorker requesting a job from data samples
+             # CaptionWorker requesting a job from data samples
              try:
                  job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
                  await self.workers[worker_id].send(
@@ -921,76 +1212,132 @@
              logger.debug(f"Heartbeat from {worker_id}: {data}")

      async def _handle_captions_submission(self, worker_id: str, data: Dict):
-         """Process multiple captions submission from worker."""
+         """Process caption submission from worker - now handles multi-stage outputs."""
          chunk_id = data.get("chunk_id")
          item_key = data["item_key"]
-         captions_list = data["captions"]

-         logger.debug(
-             f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
-         )
+         item_index = data.get("item_index")  # Worker should send this
+         if item_index is None:
+             # Try to extract from item_key (format: dataset_XXXXXXXX)
+             try:
+                 item_index = int(item_key.split("_")[-1])
+             except:
+                 logger.warning(f"Could not extract item index from key: {item_key}")

-         # Create a SINGLE caption record with ALL captions as a list
+         # Extract user from worker_id (format: "username_uuid")
+         worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
+         # Handle both old format (captions list) and new format (outputs dict)
+         if "outputs" in data:
+             # New multi-stage format
+             outputs = data["outputs"]
+             captions_list = outputs.get("captions", [])
+             total_outputs = sum(len(v) for v in outputs.values())
+
+             logger.debug(
+                 f"Received multi-stage outputs for item {item_key} from worker {worker_id}: "
+                 f"{total_outputs} outputs across {len(outputs)} fields"
+             )
+         else:
+             # Old format - single captions list
+             captions_list = data["captions"]
+             outputs = {"captions": captions_list}
+             total_outputs = len(captions_list)
+
+             logger.debug(
+                 f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
+             )
+
+         # Create caption record with multi-stage outputs
          caption = Caption(
-             job_id=f"{chunk_id}_{item_key}",  # Single ID for the item
+             job_id=f"{chunk_id}_{item_key}",
              dataset=data.get("dataset"),
              shard=data.get("shard"),
              item_key=item_key,
-             captions=captions_list,  # Store ALL captions as a list
-             contributor_id=worker_id,
+             captions=captions_list,
+             outputs=outputs,
+             contributor_id=worker_user,
              timestamp=datetime.utcnow(),
-             quality_scores=None,  # Could be a list of scores matching captions
+             quality_scores=None,
              # Image metadata
              image_width=data.get("image_width"),
              image_height=data.get("image_height"),
              image_format=data.get("image_format"),
              file_size=data.get("file_size"),
              # Processing metadata
-             caption_count=len(captions_list),
+             caption_count=total_outputs,
              processing_time_ms=data.get("processing_time_ms"),
              chunk_id=chunk_id,
+             metadata=data.get("metadata", {}),
          )

-         # Add to central storage buffer as a single entry
+         # Add to central storage buffer
          await self.storage.save_caption(caption)

-         # Update statistics
-         self.stats["total_captions"] += len(captions_list)
-         self.stats["buffer_size"] = len(self.storage.caption_buffer)
+         # Handle item tracking with fixed deadlock
+         should_flush = False
+         if chunk_id and item_index is not None and self.chunk_tracker:
+             with self.item_batch_lock:
+                 self.pending_processed_items[chunk_id].append(item_index)

-         # Update contributor stats
-         contributor = await self.storage.get_contributor(worker_id)
+                 # Check if we should flush
+                 total_pending = sum(
+                     len(indices) for indices in self.pending_processed_items.values()
+                 )
+                 time_since_flush = time.time() - self.last_item_batch_flush
+
+                 if (
+                     total_pending >= self.item_batch_size
+                     or time_since_flush >= self.item_batch_interval
+                 ):
+                     should_flush = True
+
+         if should_flush:
+             await self._flush_processed_items()
+
+         # Update contributor stats (use user, not worker)
+         contributor = await self.storage.get_contributor(worker_user)
          if contributor:
-             contributor.total_captions += len(captions_list)
+             contributor.total_captions += total_outputs
              await self.storage.save_contributor(contributor)

          # Broadcast updated stats
          await self._broadcast_stats()

          # Log progress periodically
-         if self.stats["total_captions"] % 100 == 0:
-             logger.info(f"Collected {self.stats['total_captions']} captions centrally")
+         total_outputs = self.stats.get("total_outputs", 0)
+         if total_outputs > 0 and total_outputs % 100 == 0:
+             if (
+                 not hasattr(self, "_last_logged_outputs")
+                 or self._last_logged_outputs != total_outputs
+             ):
+                 logger.info(f"Collected {total_outputs} outputs centrally")
+                 self._last_logged_outputs = total_outputs

      async def _check_shard_completion(self, chunk_id: str):
          """Check if a shard is complete after chunk completion."""
-         # Extract shard name from chunk_id
-         shard_name = chunk_id.rsplit("_chunk_", 1)[0]
+         # Get the chunk
+         chunk = self.chunk_manager.chunks.get(chunk_id)
+         if not chunk:
+             return
+
+         shard_name = chunk.shard_name

-         # Check if all chunks for this shard are complete
-         chunk_stats = self.chunk_manager.get_stats()
+         # Find all chunks for this shard
          shard_chunks = [
-             cid
-             for cid, chunk in self.chunk_manager.chunks.items()
-             if chunk.shard_name == shard_name
+             cid for cid, c in self.chunk_manager.chunks.items() if c.belongs_to_shard(shard_name)
          ]

+         # Check if all are completed
          completed_chunks = [
              cid for cid in shard_chunks if self.chunk_manager.chunks[cid].status == "completed"
          ]

-         if len(completed_chunks) == len(shard_chunks):
+         if len(completed_chunks) == len(shard_chunks) and len(shard_chunks) > 0:
              logger.info(f"Shard {shard_name} complete!")
-             self.shard_tracker.mark_complete(shard_name)
+             # Don't mark virtual shards as complete in ShardTracker
+             if not shard_name.startswith("hf_dataset:"):
+                 self.shard_tracker.mark_complete(shard_name)
              self.stats["completed_shards"] += 1
              await self._send_activity(f"Shard {shard_name} completed!")

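
A submission may now carry a dict of output lists keyed by field name instead of the single captions list. A hedged sketch of both accepted payload shapes (field names beyond "captions" are illustrative):

    # New multi-stage format
    data = {
        "chunk_id": "shard-00042_chunk_2000",
        "item_key": "mydataset_00002137",
        "item_index": 2137,
        "outputs": {
            "captions": ["a red bicycle leaning against a brick wall"],
            "tags": ["bicycle", "wall", "outdoor"],
        },
    }
    total_outputs = sum(len(v) for v in data["outputs"].values())  # 4

    # Old single-list format, still accepted
    data = {"chunk_id": "shard-00042_chunk_2000", "item_key": "mydataset_00002137",
            "captions": ["a red bicycle leaning against a brick wall"]}
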
@@ -1063,47 +1410,198 @@
          finally:
              del self.data_workers[worker_id]

-     async def _handle_monitor(self, websocket: WebSocketServerProtocol):
-         """Handle monitor connection."""
-         self.monitors.add(websocket)
-         logger.info("Monitor connected")
+     async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+         """Send leaderboard data to a specific monitor."""
+         total_start = time.time()
+         try:
+             if websocket not in self.monitors:
+                 return
+
+             # Get contributors asynchronously
+             contributors_start = time.time()
+             contributors = await self.storage.get_top_contributors(10)
+             logger.debug(
+                 f"Contributors retrieved in {(time.time() - contributors_start)*1000:.1f}ms"
+             )
+
+             # Get worker counts in thread pool
+             worker_counts_start = time.time()
+             loop = asyncio.get_event_loop()
+             worker_counts = await loop.run_in_executor(
+                 None,
+                 lambda: (
+                     self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+                 ),
+             )
+             logger.debug(
+                 f"Worker counts retrieved in {(time.time() - worker_counts_start)*1000:.1f}ms"
+             )

+             # Build enhanced contributors list
+             build_start = time.time()
+             enhanced_contributors = []
+             for contributor in contributors:
+                 contrib_dict = {
+                     "contributor_id": contributor.contributor_id,
+                     "name": contributor.name,
+                     "total_captions": contributor.total_captions,
+                     "trust_level": contributor.trust_level,
+                     "active_workers": len(
+                         worker_counts.get(contributor.contributor_id, {}).get("worker_ids", [])
+                     ),
+                 }
+                 enhanced_contributors.append(contrib_dict)
+             logger.debug(f"Enhanced contributors built in {(time.time() - build_start)*1000:.1f}ms")
+
+             # Cache for future monitors
+             self._cached_leaderboard = enhanced_contributors
+
+             # Send if still connected
+             if websocket in self.monitors:
+                 send_start = time.time()
+                 await websocket.send(
+                     safe_json_dumps({"type": "leaderboard", "data": enhanced_contributors})
+                 )
+                 logger.debug(
+                     f"Leaderboard sent to monitor in {(time.time() - send_start)*1000:.1f}ms"
+                 )
+
+             logger.debug(
+                 f"Leaderboard send to monitor completed in {(time.time() - total_start)*1000:.1f}ms"
+             )
+
+         except websockets.exceptions.ConnectionClosed:
+             logger.debug("Monitor disconnected during leaderboard send")
+         except Exception as e:
+             logger.error(f"Error sending leaderboard to monitor: {e}")
+
+     async def _send_initial_monitor_data(self, websocket: WebSocketServerProtocol):
+         """Send initial data to monitor in a separate task to avoid blocking."""
+         total_start = time.time()
          try:
-             # Send initial stats
+             # Check if websocket is still in monitors set
+             if websocket not in self.monitors:
+                 logger.debug("Monitor disconnected before initial data send")
+                 return
+
+             # Send current stats (already in memory)
+             stats_start = time.time()
              await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
+             logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
+
+             # Get chunk stats asynchronously
+             chunk_stats_start = time.time()
+             loop = asyncio.get_event_loop()
+             chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+             logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
+
+             if websocket not in self.monitors:
+                 return

-             # Send chunk stats
-             chunk_stats = self.chunk_manager.get_stats()
+             chunk_send_start = time.time()
              await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
+             logger.debug(f"Chunk stats sent in {(time.time() - chunk_send_start)*1000:.1f}ms")

-             # Send contributor leaderboard
-             contributors = await self.storage.get_top_contributors(10)
-             await websocket.send(
-                 safe_json_dumps(
-                     {"type": "leaderboard", "data": [safe_dict(c) for c in contributors]}
+             # For leaderboard, check if we have a cached version first
+             if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
+                 # Use cached leaderboard if available
+                 cache_send_start = time.time()
+                 await websocket.send(
+                     safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
                  )
+                 logger.debug(
+                     f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
+                 )
+             else:
+                 # Schedule leaderboard update separately
+                 leaderboard_task_start = time.time()
+                 asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
+                 logger.debug(
+                     f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
+                 )
+
+             logger.debug(
+                 f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
              )

-             # Keep connection alive
-             async for _ in websocket:
-                 pass
+         except websockets.exceptions.ConnectionClosed:
+             logger.debug("Monitor disconnected during initial data send")
+         except Exception as e:
+             logger.error(f"Error sending initial monitor data: {e}")
+
+     async def _handle_monitor(self, websocket: WebSocketServerProtocol):
+         """Handle monitor connection - truly non-blocking version."""
+         monitor_start = time.time()
+         self.monitors.add(websocket)
+         logger.info(f"Monitor connected (total monitors: {len(self.monitors)})")
+
+         try:
+             # Send welcome message immediately
+             welcome_start = time.time()
+             await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
+             logger.debug(f"Monitor welcome sent in {(time.time() - welcome_start)*1000:.1f}ms")
+
+             # Schedule initial data send as a separate task to avoid blocking
+             task_create_start = time.time()
+             asyncio.create_task(self._send_initial_monitor_data(websocket))
+             logger.debug(
+                 f"Monitor initial data task created in {(time.time() - task_create_start)*1000:.1f}ms"
+             )
+
+             # Just keep the connection alive - no blocking work here
+             try:
+                 async for message in websocket:
+                     # Handle any incoming messages from monitor if needed
+                     # For now, just ignore them
+                     pass
+             except websockets.exceptions.ConnectionClosed:
+                 pass  # Normal disconnection

          except websockets.exceptions.ConnectionClosed:
              logger.info("Monitor disconnected")
+         except Exception as e:
+             logger.error(f"Error in monitor handler: {e}")
          finally:
              self.monitors.discard(websocket)
+             logger.debug(f"Monitor handler completed in {(time.time() - monitor_start)*1000:.1f}ms")

      async def _broadcast_stats(self):
-         """Broadcast statistics to all monitors."""
+         """Broadcast statistics to all monitors - truly non-blocking version."""
          if not self.monitors:
              return
-
-         # Include chunk stats
-         chunk_stats = self.chunk_manager.get_stats()
-         self.stats.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
+         if self.is_generating_stats:
+             return  # Already generating stats, skip this call
+         self.is_generating_stats = True
+         total_start = time.time()
+
+         # Prepare all the data first
+         data_prep_start = time.time()
+         loop = asyncio.get_event_loop()
+
+         # Get storage stats (already async)
+         storage_stats_start = time.time()
+         storage_stats = await self.storage.get_storage_stats()
+         logger.debug(f"Storage stats retrieved in {(time.time() - storage_stats_start)*1000:.1f}ms")
+
+         caption_stats_start = time.time()
+         caption_stats = await self.storage.get_caption_stats()
+         logger.debug(f"Caption stats retrieved in {(time.time() - caption_stats_start)*1000:.1f}ms")
+
+         # Get chunk stats in thread pool
+         chunk_stats_start = time.time()
+         chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+         logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
+
+         # Build stats dict
+         build_stats_start = time.time()
+         stats_update = self.stats.copy()
+         stats_update.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
+         stats_update.update(storage_stats)
+         stats_update["field_breakdown"] = caption_stats.get("field_stats", {})
+         stats_update["output_fields_list"] = caption_stats.get("output_fields", [])

          # Add rate information
-         self.stats.update(
+         stats_update.update(
              {
                  "current_rate": self.rate_tracker["current_rate"],
                  "average_rate": self.rate_tracker["average_rate"],
@@ -1112,22 +1610,227 @@ class Orchestrator:
1112
1610
  )
1113
1611
 
1114
1612
  # Add vLLM info
1115
- self.stats["vllm_model"] = self.vllm_config.get("model", "unknown")
1116
- self.stats["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
1117
-
1118
- message = safe_json_dumps({"type": "stats", "data": self.stats})
1119
-
1120
- # Send to all monitors
1121
- disconnected = set()
1122
- for monitor in self.monitors:
1613
+ stats_update["vllm_model"] = self.vllm_config.get("model", "unknown")
1614
+ stats_update["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
1615
+
1616
+ # Add stage information
1617
+ stages = self.vllm_config.get("stages", [])
1618
+ if stages:
1619
+ stats_update["stage_count"] = len(stages)
1620
+ stats_update["stage_names"] = [s.get("name", "unnamed") for s in stages]
1621
+ else:
1622
+ stats_update["stage_count"] = 1
1623
+ stats_update["stage_names"] = ["default"]
1624
+
1625
+ # Get field stats
1626
+ field_stats_start = time.time()
1627
+ field_stats = await self.storage.get_output_field_stats()
1628
+ stats_update["output_fields"] = field_stats
1629
+ logger.debug(f"Field stats retrieved in {(time.time() - field_stats_start)*1000:.1f}ms")
1630
+
1631
+ # Update our internal stats
1632
+ self.stats = stats_update
1633
+ logger.debug(f"Stats prepared in {(time.time() - build_stats_start)*1000:.1f}ms")
1634
+
1635
+ logger.debug(f"Total data preparation took {(time.time() - data_prep_start)*1000:.1f}ms")
1636
+
1637
+ # Create message once
1638
+ message_create_start = time.time()
1639
+ stats_message = safe_json_dumps({"type": "stats", "data": self.stats})
1640
+ logger.debug(f"Stats message created in {(time.time() - message_create_start)*1000:.1f}ms")
1641
+
1642
+ # Send to all monitors asynchronously in parallel
1643
+ send_start = time.time()
1644
+
1645
+ async def send_to_monitor(monitor):
1123
1646
  try:
1124
- await monitor.send(message)
1647
+ await monitor.send(stats_message)
1125
1648
  except websockets.exceptions.ConnectionClosed:
1126
- disconnected.add(monitor)
1649
+ return monitor # Return for removal
1650
+ except Exception as e:
1651
+ logger.debug(f"Error sending stats to monitor: {e}")
1652
+ return monitor # Return for removal
1653
+ return None
1654
+
1655
+ # Send to all monitors in parallel
1656
+ monitors_copy = self.monitors.copy()
1657
+ results = await asyncio.gather(
1658
+ *[send_to_monitor(m) for m in monitors_copy], return_exceptions=True
1659
+ )
1127
1660
 
1128
- # Clean up disconnected monitors
1661
+ # Remove disconnected monitors
1662
+ disconnected = {
1663
+ m
1664
+ for m, r in zip(monitors_copy, results)
1665
+ if r is not None and not isinstance(r, Exception)
1666
+ }
1129
1667
  self.monitors -= disconnected
1130
1668
 
1669
+ logger.debug(
1670
+ f"Stats sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
1671
+ )
1672
+
1673
+ # Send leaderboard update in a separate task to avoid blocking
1674
+ leaderboard_task_start = time.time()
1675
+ asyncio.create_task(self._broadcast_leaderboard())
1676
+ self.is_generating_stats = False
1677
+ logger.debug(
1678
+ f"Leaderboard broadcast task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
1679
+ )
1680
+ logger.debug(f"Stats broadcast completed in {(time.time() - total_start)*1000:.1f}ms")
1681
+
1682
+    async def _broadcast_leaderboard(self):
+        """Send leaderboard updates to monitors - separate from stats to avoid blocking."""
+        if not self.monitors:
+            return
+
+        total_start = time.time()
+        try:
+            # Get contributors
+            contributors_start = time.time()
+            contributors = await self.storage.get_top_contributors(10)
+            logger.debug(
+                f"Contributors retrieved for broadcast in {(time.time() - contributors_start)*1000:.1f}ms"
+            )
+
+            # Get worker counts
+            worker_counts_start = time.time()
+            loop = asyncio.get_event_loop()
+            worker_counts = await loop.run_in_executor(
+                None,
+                lambda: (
+                    self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+                ),
+            )
+            logger.debug(
+                f"Worker counts retrieved for broadcast in {(time.time() - worker_counts_start)*1000:.1f}ms"
+            )
+
+            # Build enhanced contributors list
+            build_start = time.time()
+            enhanced_contributors = []
+            for contributor in contributors:
+                contrib_dict = {
+                    "contributor_id": contributor.contributor_id,
+                    "name": contributor.name,
+                    "total_captions": contributor.total_captions,
+                    "trust_level": contributor.trust_level,
+                    "active_workers": len(
+                        worker_counts.get(contributor.contributor_id, {}).get("worker_ids", [])
+                    ),
+                }
+                enhanced_contributors.append(contrib_dict)
+            logger.debug(
+                f"Enhanced contributors built for broadcast in {(time.time() - build_start)*1000:.1f}ms"
+            )
+
+            # Cache it
+            self._cached_leaderboard = enhanced_contributors
+
+            # Create message once
+            message_create_start = time.time()
+            leaderboard_message = safe_json_dumps(
+                {"type": "leaderboard", "data": enhanced_contributors}
+            )
+            logger.debug(
+                f"Leaderboard message created in {(time.time() - message_create_start)*1000:.1f}ms"
+            )
+
+            # Send to all monitors in parallel
+            send_start = time.time()
+
+            async def send_leaderboard(monitor):
+                try:
+                    await monitor.send(leaderboard_message)
+                except:
+                    return monitor  # Mark for removal
+                return None
+
+            monitors_copy = self.monitors.copy()
+            results = await asyncio.gather(
+                *[send_leaderboard(m) for m in monitors_copy], return_exceptions=True
+            )
+
+            # Remove disconnected
+            disconnected = {
+                m
+                for m, r in zip(monitors_copy, results)
+                if r is not None and not isinstance(r, Exception)
+            }
+            self.monitors -= disconnected
+
+            logger.debug(
+                f"Leaderboard sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
+            )
+            logger.debug(
+                f"Leaderboard broadcast completed in {(time.time() - total_start)*1000:.1f}ms"
+            )
+
+        except Exception as e:
+            logger.error(f"Error broadcasting leaderboard: {e}")
+
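`_broadcast_leaderboard` keeps the event loop responsive by pushing the lock-holding worker-count lookup into the default thread pool. A minimal sketch of that `run_in_executor` pattern, with a hypothetical `slow_sync_call` standing in for the lock-protected read:

    import asyncio
    import time

    def slow_sync_call() -> dict:
        """Hypothetical lock-holding read that can block for a while."""
        time.sleep(0.1)
        return {"workers": 3}

    async def main() -> None:
        loop = asyncio.get_running_loop()
        # Runs in the default ThreadPoolExecutor; the event loop keeps
        # serving websocket traffic while the blocking call is in flight.
        result = await loop.run_in_executor(None, slow_sync_call)
        print(result)

    asyncio.run(main())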
+    def _get_queue_stats(self) -> Dict[str, int]:
+        """Get queue statistics - synchronous helper for thread pool."""
+        with self.chunk_manager.lock:
+            return {
+                "pending_chunks": len(self.chunk_manager.pending_chunks),
+                "assigned_chunks": sum(
+                    len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
+                ),
+            }
+
+    async def _flush_processed_items(self):
+        """Flush batched processed items to chunk tracker."""
+        with self.item_batch_lock:
+            if not self.pending_processed_items:
+                return
+
+            for chunk_id, indices in self.pending_processed_items.items():
+                if not indices:
+                    continue
+
+                # Indices here are ABSOLUTE dataset indices
+                # Sort indices
+                indices.sort()
+
+                # Group consecutive indices into ranges
+                ranges = []
+                start = indices[0]
+                end = indices[0]
+
+                for i in range(1, len(indices)):
+                    if indices[i] == end + 1:
+                        # Consecutive, extend range
+                        end = indices[i]
+                    else:
+                        # Gap found, save current range and start new one
+                        ranges.append((start, end))
+                        start = indices[i]
+                        end = indices[i]
+
+                # Don't forget the last range
+                ranges.append((start, end))
+
+                # Mark ranges as processed (mark_items_processed expects absolute indices)
+                for start_idx, end_idx in ranges:
+                    self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
+
+            # Clear pending items
+            self.pending_processed_items.clear()
+            self.last_item_batch_flush = time.time()
+
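The grouping loop above is the heart of the flush: a sorted list of absolute dataset indices collapses into inclusive `(start, end)` ranges, so the chunk tracker records a handful of ranges instead of thousands of single items. The same logic, extracted as a pure function for illustration:

    from typing import List, Tuple

    def group_consecutive(indices: List[int]) -> List[Tuple[int, int]]:
        """Collapse sorted indices into inclusive (start, end) ranges."""
        if not indices:
            return []
        ranges = []
        start = end = indices[0]
        for idx in indices[1:]:
            if idx == end + 1:
                end = idx          # consecutive: extend the current run
            else:
                ranges.append((start, end))
                start = end = idx  # gap: open a new run
        ranges.append((start, end))
        return ranges

    assert group_consecutive([3, 4, 5, 9, 10, 42]) == [(3, 5), (9, 10), (42, 42)]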
+    def get_workers_by_user_stats(self) -> Dict[str, Any]:
+        """Get statistics about workers grouped by user/token - thread-safe version."""
+        if not hasattr(self, "workers_by_user"):
+            return {}
+
+        # Create a copy to avoid issues with concurrent modification
+        stats = {}
+        workers_snapshot = dict(self.workers_by_user)
+        for user, worker_ids in workers_snapshot.items():
+            stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
+        return stats
+
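The `dict(...)` snapshot matters because other tasks (the heartbeat loop, disconnect handlers) can mutate `workers_by_user` while this method iterates; iterating the live mapping during such a mutation raises `RuntimeError: dictionary changed size during iteration`. A toy demonstration of why the copy is safe:

    workers_by_user = {"alice": {"alice_0"}, "bob": {"bob_0"}}

    # Iterating a snapshot stays safe even if the original is mutated mid-loop.
    for user, ids in dict(workers_by_user).items():
        if user == "alice":
            del workers_by_user["bob"]  # simulates a concurrent disconnect
    assert sorted(workers_by_user) == ["alice"]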
     async def _send_activity(self, activity: str):
         """Send activity update to monitors."""
         if not self.monitors:
@@ -1149,21 +1852,63 @@ class Orchestrator:
     async def _heartbeat_loop(self):
         """Send periodic heartbeats to maintain connections."""
         while True:
-            await asyncio.sleep(30)
+            try:
+                await asyncio.sleep(30)
 
-            # Ping workers
-            disconnected = []
-            for worker_id, ws in self.workers.items():
-                try:
-                    await ws.ping()
-                except:
-                    disconnected.append(worker_id)
+                # Create a copy of worker items to avoid modification during iteration
+                worker_items = list(self.workers.items())
+                disconnected = []
 
-            # Clean up disconnected workers
-            for worker_id in disconnected:
-                if worker_id in self.workers:
-                    del self.workers[worker_id]
-                    self.chunk_manager.release_worker_chunks(worker_id)
+                for worker_id, ws in worker_items:
+                    try:
+                        # Check if worker still exists before pinging
+                        if worker_id not in self.workers:
+                            continue
+
+                        # Send ping with timeout
+                        pong_waiter = await ws.ping()
+                        try:
+                            await asyncio.wait_for(pong_waiter, timeout=10)
+                        except asyncio.TimeoutError:
+                            logger.warning(f"Worker {worker_id} failed to respond to ping")
+                            disconnected.append(worker_id)
+                    except websockets.exceptions.ConnectionClosed:
+                        logger.info(f"Worker {worker_id} connection already closed")
+                        disconnected.append(worker_id)
+                    except Exception as e:
+                        logger.error(f"Error pinging worker {worker_id}: {e}")
+                        disconnected.append(worker_id)
+
+                # Clean up disconnected workers
+                for worker_id in disconnected:
+                    if worker_id in self.workers:
+                        logger.info(f"Removing unresponsive worker {worker_id}")
+                        del self.workers[worker_id]
+                        self.chunk_manager.release_worker_chunks(worker_id)
+
+                        # Update stats
+                        self.stats["connected_workers"] = len(self.workers)
+
+                        # Also clean up from workers_by_user if it exists
+                        if hasattr(self, "workers_by_user"):
+                            worker_user = (
+                                worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+                            )
+                            if worker_user in self.workers_by_user:
+                                self.workers_by_user[worker_user].discard(worker_id)
+                                if not self.workers_by_user[worker_user]:
+                                    del self.workers_by_user[worker_user]
+
+                        # Notify monitors
+                        await self._broadcast_stats()
+                        await self._send_activity(
+                            f"Worker {worker_id} removed due to heartbeat timeout"
+                        )
+
+            except Exception as e:
+                logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
+                # Continue the loop even if there's an error
+                await asyncio.sleep(5)
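The rewritten loop replaces the old fire-and-forget `ws.ping()` with the two-step handshake the `websockets` library provides: awaiting `ping()` returns a waiter that resolves when the matching pong arrives, and `asyncio.wait_for` bounds how long the orchestrator will wait. A condensed sketch of that liveness check, assuming the connection objects come from `websockets`:

    import asyncio
    import websockets

    async def check_liveness(workers: dict, timeout: float = 10) -> list:
        """Return ids of workers that failed the ping/pong round trip."""
        dead = []
        for worker_id, ws in list(workers.items()):  # copy: callers may mutate workers
            try:
                pong_waiter = await ws.ping()        # queues a ping frame
                await asyncio.wait_for(pong_waiter, timeout=timeout)
            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
                dead.append(worker_id)
        return dead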
 
     async def _checkpoint_loop(self):
         """Periodically checkpoint storage."""
@@ -1172,42 +1917,58 @@ class Orchestrator:
         while True:
             await asyncio.sleep(60)
 
+            # Get current caption count from storage
+            storage_stats = await self.storage.get_storage_stats()
+            total_captions = storage_stats["total_captions"]
+
             # Force checkpoint at regular intervals
-            if self.stats["total_captions"] > 0 and self.stats["total_captions"] % interval == 0:
-                logger.info(f"Triggering checkpoint at {self.stats['total_captions']} captions")
+            if total_captions > 0 and total_captions % interval == 0:
+                logger.info(f"Triggering checkpoint at {total_captions} captions")
                 await self.storage.checkpoint()
 
                 # Update stats
                 self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
-                self.stats["total_written"] = self.storage.total_captions_written
-                self.stats["buffer_size"] = len(self.storage.caption_buffer)
+                # No need to update total_written or buffer_size - they come from storage
 
                 await self._broadcast_stats()
                 logger.info(
-                    f"Checkpoint complete. Total written to disk: {self.stats['total_written']}"
+                    f"Checkpoint complete. Total written to disk: {storage_stats['total_written']}"
                 )
 
     async def _stats_update_loop(self):
-        """Periodically update and broadcast stats."""
+        """Periodically update and broadcast stats - non-blocking version."""
+        # Get the event loop for running blocking operations
+        loop = asyncio.get_event_loop()
+
         # Track session start values
-        session_start_captions = self.stats["total_captions"]
+        storage_stats = await self.storage.get_storage_stats()
+        session_start_outputs = storage_stats["total_captions"]  # This now counts ALL outputs
         session_start_time = time.time()
 
+        # Track the last known total to detect flushes
+        last_known_total = session_start_outputs
+
         while True:
             await asyncio.sleep(10)
 
-            # Update chunk stats
-            chunk_stats = self.chunk_manager.get_stats()
+            # Update chunk stats in thread pool to avoid blocking
+            chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+            storage_stats = await self.storage.get_storage_stats()
+            current_total_outputs = storage_stats["total_captions"]  # ALL outputs
+            if self.chunk_tracker:
+                await self._flush_processed_items()
+
             self.stats["total_chunks"] = chunk_stats["total"]
             self.stats["completed_chunks"] = chunk_stats["completed"]
             self.stats["failed_chunks"] = chunk_stats["failed"]
 
-            # Add queue information
-            with self.chunk_manager.lock:
-                self.stats["pending_chunks"] = len(self.chunk_manager.pending_chunks)
-                self.stats["assigned_chunks"] = sum(
-                    len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
-                )
+            # Update total outputs stat (rename from total_captions for clarity)
+            self.stats["total_outputs"] = current_total_outputs
+            self.stats["total_captions"] = current_total_outputs  # Keep for backward compatibility
+
+            # Get queue stats in thread pool to avoid blocking
+            queue_stats = await loop.run_in_executor(None, self._get_queue_stats)
+            self.stats.update(queue_stats)
 
             # Calculate if we need more chunks
             worker_count = self.stats.get("connected_workers", 0)
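The checkpoint trigger in `_checkpoint_loop` is a plain modulo test on the storage-reported total, so it fires only when a 60-second poll happens to observe an exact multiple of the interval. A tiny sketch of the condition, with the interval assumed to be 1000:

    interval = 1000  # assumed checkpoint interval

    def should_checkpoint(total_captions: int) -> bool:
        return total_captions > 0 and total_captions % interval == 0

    assert should_checkpoint(2000) is True
    assert should_checkpoint(2001) is False  # not an exact multiple, no checkpoint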
@@ -1220,33 +1981,57 @@ class Orchestrator:
             elapsed_since_update = current_time - self.rate_tracker["last_update_time"]
 
             if elapsed_since_update > 0:
-                # Calculate current rate (captions per minute)
-                caption_diff = (
-                    self.stats["total_captions"] - self.rate_tracker["last_caption_count"]
-                )
-                self.rate_tracker["current_rate"] = (caption_diff / elapsed_since_update) * 60
+                # FIX: Handle the case where duplicates were skipped during save.
+                # If the current total is less than the last known one, duplicates were
+                # skipped on flush; don't count that as negative progress.
+                if current_total_outputs < last_known_total:
+                    logger.debug(
+                        f"Detected duplicate skip during save: {last_known_total} -> {current_total_outputs}"
+                    )
+                    # Don't calculate a negative rate, just update the baseline
+                    self.rate_tracker["last_caption_count"] = current_total_outputs
+                    self.rate_tracker["current_rate"] = 0.0  # Set to 0 during flush
+                else:
+                    # Normal rate calculation
+                    output_diff = current_total_outputs - self.rate_tracker["last_caption_count"]
+                    self.rate_tracker["current_rate"] = (output_diff / elapsed_since_update) * 60
+                    self.rate_tracker["last_caption_count"] = current_total_outputs
 
             # Calculate average rate since THIS SESSION started
             session_elapsed = current_time - session_start_time
             if session_elapsed > 0:
-                session_captions = self.stats["total_captions"] - session_start_captions
-                self.rate_tracker["average_rate"] = (session_captions / session_elapsed) * 60
+                # Always use the difference from session start for the average
+                session_outputs = current_total_outputs - session_start_outputs
+                self.rate_tracker["average_rate"] = (session_outputs / session_elapsed) * 60
 
-            # Calculate expected rate based on workers
-            # Assume each worker processes batch_size images every ~2 seconds with 3 captions each
+            # Calculate expected rate based on workers and stages
             batch_size = self.vllm_config.get("batch_size", 8)
-            num_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
+
+            # Count total prompts across all stages
+            total_prompts = 0
+            stages = self.vllm_config.get("stages", [])
+            if stages:
+                for stage in stages:
+                    total_prompts += len(stage.get("prompts", []))
+            else:
+                # Backward compatibility
+                total_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
+
             images_per_minute = 30  # Rough estimate: 30 images/min per worker
-            self.rate_tracker["expected_rate"] = worker_count * images_per_minute * num_prompts
+            self.rate_tracker["expected_rate"] = (
+                worker_count * images_per_minute * total_prompts
+            )
 
             # Update trackers
             self.rate_tracker["last_update_time"] = current_time
-            self.rate_tracker["last_caption_count"] = self.stats["total_captions"]
+            last_known_total = current_total_outputs
 
             # Log rate information when workers are connected
-            if worker_count > 0:
+            if (
+                worker_count > 0 and self.rate_tracker["current_rate"] >= 0
+            ):  # Only log non-negative rates
                 logger.info(
-                    f"Rate: {self.rate_tracker['current_rate']:.1f} captions/min "
+                    f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
                     f"(avg: {self.rate_tracker['average_rate']:.1f}, "
                     f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
                     f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
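The rate tracker's arithmetic reduces to an output delta over elapsed seconds, scaled to a per-minute figure, with the baseline reset and the rate clamped to zero whenever the stored total moves backwards after a de-duplicating flush. A condensed sketch of that update:

    def update_rate(tracker: dict, current_total: int, elapsed: float) -> float:
        """Per-minute rate; a backwards-moving total resets the baseline."""
        if current_total < tracker["last_count"]:
            tracker["current_rate"] = 0.0  # duplicates were skipped on save
        else:
            diff = current_total - tracker["last_count"]
            tracker["current_rate"] = (diff / elapsed) * 60
        tracker["last_count"] = current_total
        return tracker["current_rate"]

    tracker = {"last_count": 100, "current_rate": 0.0}
    assert update_rate(tracker, 150, 30.0) == 100.0  # 50 outputs in 30 s
    assert update_rate(tracker, 140, 30.0) == 0.0    # total fell: flush, not regress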
@@ -1256,16 +2041,16 @@ class Orchestrator:
 
     async def _restore_state(self):
         """Restore state from storage on startup."""
-        # Update statistics
-        self.stats["total_captions"] = await self.storage.count_captions()
-
-        logger.info(f"Restored state: {self.stats['total_captions']} captions")
+        total_captions = await self.storage.count_captions()
+        logger.info(f"Restored state: {total_captions} captions")
 
     async def shutdown(self):
         """Graceful shutdown."""
         logger.info("Shutting down orchestrator...")
 
         # Stop chunk creation
+        if self.chunk_tracker:
+            await self._flush_processed_items()
         self.stop_chunk_creation.set()
         if self.chunk_creation_thread:
             self.chunk_creation_thread.join(timeout=5)
@@ -1287,7 +2072,7 @@ class Orchestrator:
 
         # Save chunk state
         if self.chunk_tracker:
-            self.chunk_tracker.save_checkpoint()
+            self.chunk_tracker.save()
 
         # Final checkpoint
         logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")