caption-flow 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,15 +21,15 @@ from collections import deque, defaultdict
 import threading
 from queue import Queue, Empty
 
+from .workers import data
 import websockets
 from websockets.server import WebSocketServerProtocol
 
 from .storage import StorageManager
 from .models import Caption, Contributor
 from .utils.auth import AuthManager
-from .utils.dataset_loader import DatasetLoader, ShardTracker
+from .utils import DatasetLoader, ShardTracker, ChunkTracker
 from .utils.json_utils import safe_dict, safe_json_dumps, to_json_dict
-from .utils.chunk_tracker import ChunkTracker
 
 logger = logging.getLogger(__name__)
 
@@ -48,6 +48,43 @@ class ShardChunk:
     assigned_at: Optional[datetime] = None
     completed_at: Optional[datetime] = None
 
+    @classmethod
+    def create(
+        cls, shard_url: str, shard_name: str, start_index: int, chunk_size: int
+    ) -> "ShardChunk":
+        """Factory method to create a chunk with consistent ID."""
+        # Always use consistent format: dataset_chunk_startindex
+        if shard_url.startswith("hf_dataset:"):
+            # Extract dataset path
+            parts = shard_url.split(":")
+            dataset_path = parts[1] if len(parts) > 1 else "unknown"
+            chunk_id = f"{dataset_path.replace('/', '_')}_chunk_{start_index}"
+        else:
+            # WebDataset format
+            chunk_id = f"{shard_name}_chunk_{start_index}"
+
+        return cls(
+            chunk_id=chunk_id,
+            shard_url=shard_url,
+            shard_name=shard_name,
+            start_index=start_index,
+            chunk_size=chunk_size,
+        )
+
+    def belongs_to_shard(self, shard_identifier: str) -> bool:
+        """Check if this chunk belongs to a given shard."""
+        return self.shard_name == shard_identifier
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dict for JSON serialization (for workers)."""
+        return {
+            "chunk_id": self.chunk_id,
+            "shard_url": self.shard_url,
+            "shard_name": self.shard_name,
+            "start_index": self.start_index,
+            "chunk_size": self.chunk_size,
+        }
+
 
 class ChunkManager:
     """Manages shard chunk creation and assignment."""
@@ -67,9 +104,7 @@ class ChunkManager:
         chunks = []
 
         for start_idx in range(0, total_items, self.chunk_size):
-            chunk_id = f"{shard_name}_chunk_{start_idx}"
-            chunk = ShardChunk(
-                chunk_id=chunk_id,
+            chunk = ShardChunk.create(
                 shard_url=shard_url,
                 shard_name=shard_name,
                 start_index=start_idx,
@@ -77,8 +112,8 @@ class ChunkManager:
             )
 
             with self.lock:
-                self.chunks[chunk_id] = chunk
-                self.pending_chunks.append(chunk_id)
+                self.chunks[chunk.chunk_id] = chunk
+                self.pending_chunks.append(chunk.chunk_id)
 
             chunks.append(chunk)
 
@@ -86,24 +121,84 @@ class ChunkManager:
 
     def get_chunks_for_worker(
         self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
-    ) -> List[ShardChunk]:
-        """Get available chunks for a worker."""
+    ) -> List[Dict[str, Any]]:
+        """Get available chunks with unprocessed items for a worker."""
         assigned = []
 
         with self.lock:
+            # FIRST PRIORITY: Check if this worker already has assigned chunks
+            # Workers should complete their current chunks before getting new ones
+            if worker_id in self.assigned_chunks:
+                existing_chunk_ids = list(self.assigned_chunks[worker_id])
+                for chunk_id in existing_chunk_ids:
+                    if len(assigned) >= count:
+                        break
+
+                    chunk = self.chunks.get(chunk_id)
+                    if not chunk:
+                        continue
+
+                    # Check if chunk still has unprocessed items
+                    if tracker:
+                        chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
+                        if chunk_info and chunk_info["unprocessed_ranges"]:
+                            assigned.append(
+                                {
+                                    "chunk": chunk,
+                                    "unprocessed_ranges": chunk_info["unprocessed_ranges"],
+                                }
+                            )
+                    else:
+                        # No tracker, assume chunk needs processing
+                        assigned.append(
+                            {
+                                "chunk": chunk,
+                                "unprocessed_ranges": [(0, chunk.chunk_size - 1)],
+                            }
+                        )
+
+            # SECOND PRIORITY: Get new pending chunks
+            # Only if worker doesn't have enough chunks already
             while len(assigned) < count and self.pending_chunks:
                 chunk_id = self.pending_chunks.popleft()
-                chunk = self.chunks[chunk_id]
+                chunk = self.chunks.get(chunk_id)
+
+                if not chunk:
+                    continue
 
+                # Verify chunk is truly pending (defensive check)
+                if chunk.status != "pending" or chunk.assigned_to is not None:
+                    logger.warning(
+                        f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
+                    )
+                    continue
+
+                # Assign to this worker
                 chunk.assigned_to = worker_id
                 chunk.status = "assigned"
                 chunk.assigned_at = datetime.utcnow()
-
                 self.assigned_chunks[worker_id].add(chunk_id)
-                assigned.append(chunk)
+
+                # Get unprocessed ranges
+                unprocessed_ranges = [(0, chunk.chunk_size - 1)]  # Default
                 if tracker:
+                    chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
+                    if chunk_info:
+                        unprocessed_ranges = chunk_info["unprocessed_ranges"]
                     tracker.mark_assigned(chunk_id, worker_id)
 
+                assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
+
+        # Log what we're assigning
+        if assigned:
+            chunk_summary = ", ".join(
+                [
+                    f"{info['chunk'].chunk_id}[{len(info['unprocessed_ranges'])} ranges]"
+                    for info in assigned
+                ]
+            )
+            logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
+
         return assigned
 
     def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
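
Note on the hunk above: the return type changes from a list of `ShardChunk` objects to a list of dicts pairing each chunk with its unprocessed index ranges. A minimal sketch of consuming the new payload (`process_items` is a hypothetical helper, not from the package):

    infos = chunk_manager.get_chunks_for_worker("worker-1", count=2, tracker=chunk_tracker)
    for info in infos:
        chunk = info["chunk"]  # ShardChunk instance
        for start, end in info["unprocessed_ranges"]:  # inclusive (start, end) pairs
            process_items(chunk, start, end)
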
@@ -173,6 +268,27 @@ class Orchestrator:
         self.dataset_config = config.get("dataset", {})
         self.dataset_path = self.dataset_config.get("path")
         self.dataset_type = self.dataset_config.get("type", "huggingface")
+        self.dataset_split = self.dataset_config.get("split", "train")  # Add split configuration
+        self.dataset_image_column = self.dataset_config.get(
+            "image_column", "image"
+        )  # Add image column config
+
+        # Dataset components
+        self.dataset_loader = None
+        self.shard_tracker = None
+        self.chunk_tracker = None
+
+        if self.dataset_path:
+            self.dataset_loader = DatasetLoader(
+                self.dataset_path,
+                self.dataset_type,
+                self.dataset_split,
+                self.dataset_image_column,
+            )
+            checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
+            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
 
         # vLLM configuration to distribute to workers
         self.vllm_config = config.get(
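
Note on the hunk above: the orchestrator now reads two extra dataset keys and builds its loader and trackers in `__init__`. A hypothetical config fragment exercising those keys (key names match the `.get()` calls above; values illustrative):

    config = {
        "dataset": {
            "path": "org/dataset",     # enables DatasetLoader when present
            "type": "huggingface",     # default
            "split": "train",          # new in 0.2.0, default "train"
            "image_column": "image",   # new in 0.2.0, default "image"
        },
        "storage": {"checkpoint_dir": "./checkpoints"},
    }
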
@@ -233,6 +349,11 @@ class Orchestrator:
 
         # Initialize chunk manager with reference to chunk tracker
         self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)
+        self.pending_processed_items = defaultdict(list)  # chunk_id -> list of indices
+        self.item_batch_lock = threading.Lock()
+        self.last_item_batch_flush = time.time()
+        self.item_batch_interval = 5  # Flush every 5 seconds
+        self.item_batch_size = 100  # Or every 100 items
 
         # Track connections
         self.workers: Dict[str, WebSocketServerProtocol] = {}
@@ -246,13 +367,10 @@ class Orchestrator:
             "total_chunks": 0,
             "completed_chunks": 0,
             "failed_chunks": 0,
-            "total_captions": 0,
             "connected_workers": 0,
             "total_shards": 0,
             "completed_shards": 0,
             "current_shard": None,
-            "buffer_size": 0,
-            "total_written": 0,
             "last_checkpoint": None,
         }
 
@@ -266,7 +384,7 @@ class Orchestrator:
             "expected_rate": 0.0,
         }
 
-        # Data sample queue for VLLMWorkers
+        # Data sample queue for CaptionWorker
         self.data_sample_queue = asyncio.Queue(maxsize=1000)
         self.data_workers: Dict[str, WebSocketServerProtocol] = {}
 
@@ -310,10 +428,23 @@ class Orchestrator:
         # Mark state as not restored until we process checkpoints
         self.state_restored.clear()
 
+        # Get dataset info to check format
+        dataset_info = self.dataset_loader.get_dataset_info()
+        dataset_format = dataset_info.get("dataset_format", "unknown")
+        logger.info(f"Dataset format: {dataset_format}")
+
         # Get all shards
         self.all_shards = self.dataset_loader.get_shard_list()
         self.stats["total_shards"] = len(self.all_shards)
 
+        # For HuggingFace datasets, we might need to dynamically create more shards
+        if dataset_format == "huggingface_datasets":
+            self._is_hf_dataset = True
+            self._hf_chunk_size = 10000  # Items per virtual shard
+            self._next_hf_shard_index = len(self.all_shards)  # For creating new virtual shards
+        else:
+            self._is_hf_dataset = False
+
         # Get shard status from ChunkTracker
         shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
         completed_shards = {
@@ -336,7 +467,10 @@ class Orchestrator:
 
         # Filter out shards that already have chunks created
         remaining_shards = [
-            shard for shard in remaining_shards if Path(shard).stem not in shards_with_chunks
+            shard
+            for shard in remaining_shards
+            if (shard if shard.startswith("hf_dataset:") else Path(shard).stem)
+            not in shards_with_chunks
         ]
 
         self.stats["completed_shards"] = len(completed_shards)
@@ -356,25 +490,18 @@ class Orchestrator:
             with self.chunk_manager.lock:
                 for chunk_state in shard_info["chunks"]:
                     if chunk_state.status in ["pending", "failed", "assigned"]:
-                        # Find shard URL
-                        shard_url = None
-                        for url in self.all_shards:
-                            if Path(url).stem == shard_name:
-                                shard_url = url
-                                break
-
-                        if shard_url:
-                            chunk = ShardChunk(
-                                chunk_id=chunk_state.chunk_id,
-                                shard_url=shard_url,
-                                shard_name=chunk_state.shard_name,
-                                start_index=chunk_state.start_index,
-                                chunk_size=chunk_state.chunk_size,
-                            )
-                            self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
-                            self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
-                            requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
-                            initial_pending += 1
+                        # ChunkState already has shard_url stored
+                        chunk = ShardChunk(
+                            chunk_id=chunk_state.chunk_id,
+                            shard_url=chunk_state.shard_url,
+                            shard_name=chunk_state.shard_name,
+                            start_index=chunk_state.start_index,
+                            chunk_size=chunk_state.chunk_size,
+                        )
+                        self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
+                        self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
+                        requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
+                        initial_pending += 1
 
         logger.info(f"Re-queued {initial_pending} existing pending chunks")
         for shard_name, chunk_ids in requeued_chunks_by_shard.items():
@@ -426,7 +553,13 @@ class Orchestrator:
             if current_shard_url is None or current_shard_index >= current_shard_items:
                 try:
                     current_shard_url = next(shard_iter)
-                    current_shard_name = Path(current_shard_url).stem
+
+                    # Extract shard name based on type
+                    if current_shard_url.startswith("hf_dataset:"):
+                        current_shard_name = current_shard_url  # Use full ID for virtual shards
+                    else:
+                        current_shard_name = Path(current_shard_url).stem
+
                     self.stats["current_shard"] = current_shard_name
 
                     # Skip if we already have chunks from this shard
@@ -439,16 +572,74 @@ class Orchestrator:
 
                     # Count items in new shard
                     logger.info(f"Loading new shard {current_shard_name}")
-                    current_shard_items = sum(
-                        1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
-                    )
+
+                    # For virtual HF dataset shards, use the chunk size directly
+                    if current_shard_url.startswith("hf_dataset:"):
+                        current_shard_items = self.dataset_loader.count_shard_items(
+                            current_shard_url
+                        )
+                        logger.info(
+                            f"Virtual shard {current_shard_name} has {current_shard_items} items"
+                        )
+                    else:
+                        # For WebDataset, actually count items
+                        current_shard_items = sum(
+                            1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
+                        )
+                        logger.info(
+                            f"Shard {current_shard_name} has {current_shard_items} items"
+                        )
+
                     current_shard_index = 0
-                    logger.info(f"Shard {current_shard_name} has {current_shard_items} items")
 
                 except StopIteration:
-                    # No more shards
+                    # No more shards in the iterator
+                    if self._is_hf_dataset:
+                        # Before creating new virtual shards, check if we have pending chunks
+                        with self.chunk_manager.lock:
+                            pending_count = len(self.chunk_manager.pending_chunks)
+
+                        if pending_count > 0:
+                            # Don't create new shards if we have pending chunks
+                            logger.debug(
+                                f"Have {pending_count} pending chunks, not creating new virtual shards yet"
+                            )
+                            current_shard_url = None
+                            time.sleep(2)
+                            continue
+
+                        # For HF datasets, we can create more virtual shards on demand
+                        logger.info(
+                            "Creating additional virtual shards for HuggingFace dataset"
+                        )
+
+                        # Create 10 more virtual shards
+                        new_shards = []
+                        for i in range(10):
+                            shard_id = f"hf_dataset:{self.dataset_path}:chunk:{self._next_hf_shard_index * self._hf_chunk_size}"
+                            new_shards.append(shard_id)
+                            self._next_hf_shard_index += 1
+
+                        # Add to all_shards and create new iterator
+                        self.all_shards.extend(new_shards)
+                        self.stats["total_shards"] = len(self.all_shards)
+
+                        # Filter for unprocessed shards
+                        remaining_new_shards = [
+                            s
+                            for s in new_shards
+                            if s not in shards_summary and s not in completed_shards
+                        ]
+
+                        if remaining_new_shards:
+                            shard_iter = iter(remaining_new_shards)
+                            logger.info(f"Added {len(remaining_new_shards)} new virtual shards")
+                            continue
+
+                    # No more shards to process
                     logger.info("No more shards to process")
                     break
+
                 except Exception as e:
                     logger.error(f"Error loading shard {current_shard_name}: {e}")
                     current_shard_url = None
@@ -456,25 +647,40 @@ class Orchestrator:
 
             # Create a chunk from current shard
             if current_shard_url and current_shard_index < current_shard_items:
-                chunk_id = f"{current_shard_name}_chunk_{current_shard_index}"
-                chunk_size = min(self.chunk_size, current_shard_items - current_shard_index)
+                # Calculate the absolute dataset index for this chunk
+                if current_shard_url.startswith("hf_dataset:"):
+                    # Parse the virtual shard URL to get the base start index
+                    parts = current_shard_url.split(":")
+                    if len(parts) >= 4 and parts[2] == "chunk":
+                        shard_base_index = int(parts[3])
+                    else:
+                        shard_base_index = 0
+
+                    # The absolute start index for this chunk in the dataset
+                    absolute_start_index = shard_base_index + current_shard_index
+                else:
+                    # For WebDataset, current_shard_index is already absolute
+                    absolute_start_index = current_shard_index
+
+                # Create chunk with absolute index
+                chunk = ShardChunk.create(
+                    shard_url=current_shard_url,
+                    shard_name=current_shard_name,
+                    start_index=absolute_start_index,
+                    chunk_size=min(self.chunk_size, current_shard_items - current_shard_index),
+                )
 
-                # Add to ChunkTracker
+                # Add to ChunkTracker with all required fields
                 if self.chunk_tracker and self.chunk_tracker.add_chunk(
-                    chunk_id, current_shard_name, current_shard_index, chunk_size
+                    chunk.chunk_id,
+                    chunk.shard_name,
+                    chunk.shard_url,
+                    chunk.start_index,
+                    chunk.chunk_size,
                 ):
-                    # Create chunk
-                    chunk = ShardChunk(
-                        chunk_id=chunk_id,
-                        shard_url=current_shard_url,
-                        shard_name=current_shard_name,
-                        start_index=current_shard_index,
-                        chunk_size=chunk_size,
-                    )
-
                     with self.chunk_manager.lock:
-                        self.chunk_manager.chunks[chunk_id] = chunk
-                        self.chunk_manager.pending_chunks.append(chunk_id)
+                        self.chunk_manager.chunks[chunk.chunk_id] = chunk
+                        self.chunk_manager.pending_chunks.append(chunk.chunk_id)
 
                     chunks_created += 1
                     self.stats["total_chunks"] += 1
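
Note on the hunk above: a virtual shard ID has the shape `hf_dataset:<path>:chunk:<base_index>`, and a chunk's absolute start index is the shard's base index plus the position within the shard. A small sketch (dataset path hypothetical):

    shard_url = "hf_dataset:org/dataset:chunk:20000"
    parts = shard_url.split(":")      # ["hf_dataset", "org/dataset", "chunk", "20000"]
    shard_base_index = int(parts[3])  # 20000
    current_shard_index = 500
    absolute_start_index = shard_base_index + current_shard_index  # 20500
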
@@ -484,10 +690,14 @@ class Orchestrator:
             if chunks_created > 0:
                 logger.info(f"Created {chunks_created} chunks on demand")
 
-            # If we couldn't create any chunks and there are no more shards, we're done
+            # If we couldn't create any chunks and there are no more shards, check if it's HF dataset
             if chunks_created == 0 and current_shard_url is None:
-                logger.info("All shards processed, chunk creation complete")
-                break
+                if self._is_hf_dataset:
+                    # We can always create more virtual shards for HF datasets
+                    logger.debug("Will create more virtual shards on next iteration")
+                else:
+                    logger.info("All shards processed, chunk creation complete")
+                    break
 
             # Brief pause to avoid spinning
             time.sleep(1)
@@ -558,7 +768,9 @@ class Orchestrator:
             elif auth_ticket.role == "admin":
                 await self._handle_admin(websocket, auth_ticket)
             else:
-                await websocket.send(safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"}))
+                await websocket.send(
+                    safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"})
+                )
 
         except Exception as e:
             logger.error(f"Connection error: {e}")
@@ -604,81 +816,118 @@ class Orchestrator:
         requires_worker_restart = False
 
         try:
+            # Extract orchestrator section if present
+            if "orchestrator" in new_config:
+                # Config has orchestrator wrapper, extract it
+                orchestrator_config = new_config["orchestrator"]
+            else:
+                # Config is already at orchestrator level
+                orchestrator_config = new_config
+
+            # Helper function for deep comparison
+            def deep_equal(a, b):
+                """Deep comparison of two values including nested dicts and lists."""
+                if type(a) != type(b):
+                    return False
+                if isinstance(a, dict):
+                    if set(a.keys()) != set(b.keys()):
+                        return False
+                    return all(deep_equal(a[k], b[k]) for k in a.keys())
+                elif isinstance(a, (list, tuple)):
+                    if len(a) != len(b):
+                        return False
+                    return all(deep_equal(x, y) for x, y in zip(a, b))
+                else:
+                    return a == b
+
             # Update vLLM configuration
-            if "vllm" in new_config:
+            if "vllm" in orchestrator_config:
                 old_vllm = self.vllm_config.copy()
+                new_vllm = orchestrator_config["vllm"]
 
-                # Check each field for actual changes
-                vllm_changed = False
-                for key, value in new_config["vllm"].items():
-                    if self.vllm_config.get(key) != value:
-                        self.vllm_config[key] = value
-                        vllm_changed = True
+                # Check if vLLM config actually changed using deep comparison
+                vllm_changed = not deep_equal(old_vllm, new_vllm)
 
                 if vllm_changed:
+                    # Update the vLLM config
+                    self.vllm_config = new_vllm.copy()
                     updated_sections.append("vllm")
 
                     # Check if critical changes require worker restart
                     if (
-                        old_vllm.get("model") != self.vllm_config.get("model")
+                        old_vllm.get("model") != new_vllm.get("model")
                         or old_vllm.get("gpu_memory_utilization")
-                        != self.vllm_config.get("gpu_memory_utilization")
+                        != new_vllm.get("gpu_memory_utilization")
                         or old_vllm.get("tensor_parallel_size")
-                        != self.vllm_config.get("tensor_parallel_size")
+                        != new_vllm.get("tensor_parallel_size")
+                        or old_vllm.get("dtype") != new_vllm.get("dtype")
+                        or old_vllm.get("max_model_len") != new_vllm.get("max_model_len")
                     ):
                         requires_worker_restart = True
                         warnings.append(
                             "Critical vLLM changes detected - workers will be disconnected to reload"
                         )
+                        logger.info(
+                            f"Model change: {old_vllm.get('model')} -> {new_vllm.get('model')}"
+                        )
 
             # Update dataset configuration
-            if "dataset" in new_config:
-                dataset_changed = False
-                for key, value in new_config["dataset"].items():
-                    if self.dataset_config.get(key) != value:
-                        self.dataset_config[key] = value
-                        dataset_changed = True
+            if "dataset" in orchestrator_config:
+                old_dataset = self.dataset_config.copy()
+                new_dataset = orchestrator_config["dataset"]
+
+                dataset_changed = not deep_equal(old_dataset, new_dataset)
 
                 if dataset_changed:
+                    self.dataset_config = new_dataset.copy()
                     self.dataset_path = self.dataset_config.get("path")
                     self.dataset_type = self.dataset_config.get("type", "huggingface")
                     updated_sections.append("dataset")
                     warnings.append("Dataset changes will apply to new chunks only")
 
             # Update chunk settings
-            if "chunk_size" in new_config and self.chunk_size != new_config["chunk_size"]:
-                self.chunk_size = new_config["chunk_size"]
+            if (
+                "chunk_size" in orchestrator_config
+                and self.chunk_size != orchestrator_config["chunk_size"]
+            ):
+                self.chunk_size = orchestrator_config["chunk_size"]
                 self.chunk_manager.chunk_size = self.chunk_size
                 updated_sections.append("chunk_size")
 
             if (
-                "chunks_per_request" in new_config
-                and self.chunks_per_request != new_config["chunks_per_request"]
+                "chunks_per_request" in orchestrator_config
+                and self.chunks_per_request != orchestrator_config["chunks_per_request"]
             ):
-                self.chunks_per_request = new_config["chunks_per_request"]
+                self.chunks_per_request = orchestrator_config["chunks_per_request"]
                 updated_sections.append("chunks_per_request")
 
-            # Recreate auth manager
-            self.auth = AuthManager(config=new_config)
+            # Update auth configuration
+            if "auth" in orchestrator_config:
+                try:
+                    self.auth = AuthManager({"auth": orchestrator_config["auth"]})
+                    updated_sections.append("auth")
+                except Exception as e:
+                    logger.error(f"Failed to update AuthManager: {e}")
+                    warnings.append(f"Auth update failed: {e}")
 
             # Update buffer settings
             if (
-                "chunk_buffer_multiplier" in new_config
-                and self.chunk_buffer_multiplier != new_config["chunk_buffer_multiplier"]
+                "chunk_buffer_multiplier" in orchestrator_config
+                and self.chunk_buffer_multiplier != orchestrator_config["chunk_buffer_multiplier"]
             ):
-                self.chunk_buffer_multiplier = new_config["chunk_buffer_multiplier"]
+                self.chunk_buffer_multiplier = orchestrator_config["chunk_buffer_multiplier"]
                 updated_sections.append("chunk_buffer_multiplier")
 
             if (
-                "min_chunk_buffer" in new_config
-                and self.min_chunk_buffer != new_config["min_chunk_buffer"]
+                "min_chunk_buffer" in orchestrator_config
+                and self.min_chunk_buffer != orchestrator_config["min_chunk_buffer"]
             ):
-                self.min_chunk_buffer = new_config["min_chunk_buffer"]
+                self.min_chunk_buffer = orchestrator_config["min_chunk_buffer"]
                 updated_sections.append("min_chunk_buffer")
 
             # Update storage settings
-            if "storage" in new_config:
-                storage_config = new_config["storage"]
+            if "storage" in orchestrator_config:
+                storage_config = orchestrator_config["storage"]
                 storage_changed = False
 
                 if (
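
Note on the `deep_equal` helper above: it differs from plain `==` mainly by requiring matching types at every level of nesting. A small sketch of the semantics, assuming the helper exactly as defined in the hunk:

    deep_equal({"a": [1, 2]}, {"a": [1, 2]})  # True
    deep_equal({"a": (1, 2)}, {"a": [1, 2]})  # False: tuple vs list
    deep_equal(1, 1.0)                        # False (int vs float), although 1 == 1.0
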
@@ -701,21 +950,6 @@ class Orchestrator:
                 if storage_changed:
                     updated_sections.append("storage")
 
-            # Update data worker storage config
-            if "data_worker_storage" in new_config:
-                current_dw_storage = self.config.get("data_worker_storage", {})
-                if current_dw_storage != new_config["data_worker_storage"]:
-                    self.config["data_worker_storage"] = new_config["data_worker_storage"]
-                    updated_sections.append("data_worker_storage")
-                    warnings.append("Data worker storage config will apply to new connections only")
-
-            # Update backpressure threshold
-            if "backpressure_threshold" in new_config:
-                current_threshold = getattr(self, "backpressure_threshold", 800)
-                if current_threshold != new_config["backpressure_threshold"]:
-                    self.backpressure_threshold = new_config["backpressure_threshold"]
-                    updated_sections.append("backpressure_threshold")
-
             # Check if any changes were made
             if not updated_sections:
                 await websocket.send(
@@ -729,29 +963,49 @@ class Orchestrator:
                 logger.info("Configuration reload requested but no changes detected")
                 return
 
-            # Update the main config for any other fields
-            self.config.update(new_config)
+            # Update the main config
+            if "orchestrator" in new_config:
+                self.config["orchestrator"] = orchestrator_config
+            else:
+                self.config.update(orchestrator_config)
 
             # Handle worker restart if needed
             if requires_worker_restart:
                 logger.info("Disconnecting all workers for configuration reload...")
 
-                # Disconnect all workers
-                worker_ids = list(self.workers.keys())
-                for worker_id in worker_ids:
+                # Send reload message to workers first
+                reload_msg = safe_json_dumps(
+                    {
+                        "type": "reload_vllm",
+                        "vllm_config": self.vllm_config,
+                    }
+                )
+
+                # Create a list of worker items to avoid modifying dict during iteration
+                worker_items = list(self.workers.items())
+                disconnected = []
+
+                for worker_id, ws in worker_items:
                     try:
-                        await self.workers[worker_id].close(
-                            code=1012, reason="Configuration reload"
-                        )
+                        await ws.send(reload_msg)
+                        # Give worker time to process before disconnect
+                        await asyncio.sleep(0.5)
+                        await ws.close(code=1012, reason="Configuration reload")
+                        disconnected.append(worker_id)
                     except:
-                        pass
+                        disconnected.append(worker_id)  # Still mark as disconnected if error
+
+                # Now safely clear workers dict
+                for worker_id in disconnected:
+                    if worker_id in self.workers:
+                        del self.workers[worker_id]
 
                 warnings.append(
-                    f"Disconnected {len(worker_ids)} workers - they will reconnect with new config"
+                    f"Sent reload message to {len(disconnected)} workers - they will reconnect with new config"
                 )
             else:
-                # Just notify workers about config changes
-                reload_msg = safe_json_dumps(
+                # Just notify workers about config changes without disconnecting
+                config_update_msg = safe_json_dumps(
                     {
                         "type": "config_update",
                         "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
761
1015
  }
762
1016
  )
763
1017
 
1018
+ # Create a list of worker items to avoid modifying dict during iteration
1019
+ worker_items = list(self.workers.items())
764
1020
  disconnected = []
765
- for worker_id, ws in self.workers.items():
1021
+
1022
+ for worker_id, ws in worker_items:
766
1023
  try:
767
- await ws.send(reload_msg)
1024
+ await ws.send(config_update_msg)
1025
+ logger.info(f"Sent config update to worker {worker_id}")
768
1026
  except:
769
1027
  disconnected.append(worker_id)
770
1028
 
1029
+ # Now safely remove disconnected workers
771
1030
  for worker_id in disconnected:
772
- del self.workers[worker_id]
1031
+ if worker_id in self.workers:
1032
+ del self.workers[worker_id]
773
1033
 
774
1034
  # Send success response
775
1035
  await websocket.send(
@@ -788,34 +1048,58 @@ class Orchestrator:
 
         except Exception as e:
             logger.error(f"Configuration reload failed: {e}")
+            import traceback
+
+            logger.error(traceback.format_exc())
             await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))
 
     async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
         """Handle worker connection lifecycle."""
-        worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
+        # Generate unique worker ID even if using same token
+        base_name = getattr(auth_ticket, "name", "worker")
+        worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"  # Add unique suffix
+
+        # Track the original token/user for accounting
+        worker_user = base_name  # Keep track of which user/token this worker belongs to
+
         self.workers[worker_id] = websocket
         self.stats["connected_workers"] = len(self.workers)
 
-        # Register contributor
-        contributor = Contributor(
-            contributor_id=worker_id, name=worker_id, total_captions=0, trust_level=1
-        )
-        await self.storage.save_contributor(contributor)
+        # Optionally track workers by user/token
+        if not hasattr(self, "workers_by_user"):
+            self.workers_by_user = defaultdict(set)
+        self.workers_by_user[worker_user].add(worker_id)
+
+        # Register contributor with the base name (for aggregating stats per user)
+        contributor = await self.storage.get_contributor(worker_user)
+        if not contributor:
+            contributor = Contributor(
+                contributor_id=worker_user,
+                name=worker_user,
+                total_captions=0,
+                trust_level=1,
+            )
+            await self.storage.save_contributor(contributor)
 
-        logger.info(f"Worker {worker_id} connected")
+        logger.info(f"Worker {worker_id} (user: {worker_user}) connected")
         await self._broadcast_stats()
-        await self._send_activity(f"Worker {worker_id} connected")
+        await self._send_activity(f"Worker {worker_id} (user: {worker_user}) connected")
 
         try:
             # Send welcome message with dataset configuration
             welcome_message = {
                 "type": "welcome",
                 "worker_id": worker_id,
+                "user_id": worker_user,
                 "dataset_config": {
                     "dataset_path": self.dataset_path,
                     "dataset_type": self.dataset_type,
-                    "path": self.dataset_path,  # For compatibility
-                    "type": self.dataset_type,  # For compatibility
+                    "dataset_split": self.dataset_split,
+                    "dataset_image_column": self.dataset_image_column,
+                    "path": self.dataset_path,
+                    "type": self.dataset_type,
+                    "split": self.dataset_split,
+                    "image_column": self.dataset_image_column,
                 },
                 "vllm_config": self.vllm_config,
             }
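
Note on the hunk above: worker IDs are now `<user>_<8-char-uuid>`, and the welcome payload carries the dataset config under both the `dataset_*` and bare key spellings. A hypothetical example of what a worker receives (values illustrative):

    welcome_message = {
        "type": "welcome",
        "worker_id": "alice_1a2b3c4d",
        "user_id": "alice",
        "dataset_config": {
            "dataset_path": "org/dataset", "dataset_type": "huggingface",
            "dataset_split": "train", "dataset_image_column": "image",
            "path": "org/dataset", "type": "huggingface",
            "split": "train", "image_column": "image",
        },
        "vllm_config": {"model": "some/vision-language-model", "batch_size": 8},
    }
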
@@ -826,21 +1110,29 @@ class Orchestrator:
                     await self._process_worker_message(worker_id, data)
 
         except websockets.exceptions.ConnectionClosed:
-            logger.info(f"Worker {worker_id} disconnected")
+            logger.info(f"Worker {worker_id} (user: {worker_user}) disconnected")
         finally:
-            del self.workers[worker_id]
+            if worker_id in self.workers:
+                del self.workers[worker_id]
+
+            # Clean up user tracking
+            if hasattr(self, "workers_by_user") and worker_user in self.workers_by_user:
+                self.workers_by_user[worker_user].discard(worker_id)
+                if not self.workers_by_user[worker_user]:
+                    del self.workers_by_user[worker_user]
+
             self.stats["connected_workers"] = len(self.workers)
-            # Release chunks in both managers
+
+            # Release chunks
             self.chunk_manager.release_worker_chunks(worker_id)
             if self.chunk_tracker:
-                # Mark released chunks as pending in tracker
                 released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
                 logger.info(
                     f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
                 )
 
             await self._broadcast_stats()
-            await self._send_activity(f"Worker {worker_id} disconnected")
+            await self._send_activity(f"Worker {worker_id} (user: {worker_user}) disconnected")
 
     async def _process_worker_message(self, worker_id: str, data: Dict):
         """Process message from worker."""
@@ -856,28 +1148,26 @@ class Orchestrator:
             return
 
         count = data.get("count", self.chunks_per_request)
-        chunks = self.chunk_manager.get_chunks_for_worker(worker_id, count, self.chunk_tracker)
+        chunk_infos = self.chunk_manager.get_chunks_for_worker(
+            worker_id, count, self.chunk_tracker
+        )
 
-        if chunks:
-            # Only send the fields that worker expects
-            chunk_data = []
-            for chunk in chunks:
-                chunk_data.append(
-                    {
-                        "chunk_id": chunk.chunk_id,
-                        "shard_url": chunk.shard_url,
-                        "shard_name": chunk.shard_name,
-                        "start_index": chunk.start_index,
-                        "chunk_size": chunk.chunk_size,
-                    }
-                )
+        if chunk_infos:
+            # Send chunks with unprocessed ranges
+            chunks_data = []
+            for info in chunk_infos:
+                chunk_dict = info["chunk"].to_dict()
+                chunk_dict["unprocessed_ranges"] = info["unprocessed_ranges"]
+                chunks_data.append(chunk_dict)
 
             await self.workers[worker_id].send(
-                safe_json_dumps({"type": "shard_assignment", "chunks": chunk_data})
+                safe_json_dumps({"type": "shard_assignment", "chunks": chunks_data})
+            )
+
+            chunk_ids = [c["chunk_id"] for c in chunks_data]
+            logger.info(
+                f"Assigned {len(chunks_data)} chunks to worker {worker_id}: {chunk_ids}"
             )
-            chunk_ids = [c["chunk_id"] for c in chunk_data]
-            logger.info(f"Assigned {len(chunks)} chunks to worker {worker_id}: {chunk_ids}")
-            await self._send_activity(f"Assigned {len(chunks)} chunks to {worker_id}")
         else:
             await self.workers[worker_id].send(safe_json_dumps({"type": "no_chunks"}))
 
@@ -907,7 +1197,7 @@ class Orchestrator:
         elif msg_type == "submit_captions":
             await self._handle_captions_submission(worker_id, data)
         elif msg_type == "request_job":
-            # VLLMWorker requesting a job from data samples
+            # CaptionWorker requesting a job from data samples
             try:
                 job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
                 await self.workers[worker_id].send(
921
1211
  logger.debug(f"Heartbeat from {worker_id}: {data}")
922
1212
 
923
1213
  async def _handle_captions_submission(self, worker_id: str, data: Dict):
924
- """Process multiple captions submission from worker."""
1214
+ """Process caption submission from worker - now handles multi-stage outputs."""
925
1215
  chunk_id = data.get("chunk_id")
926
1216
  item_key = data["item_key"]
927
- captions_list = data["captions"]
928
1217
 
929
- logger.debug(
930
- f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
931
- )
1218
+ item_index = data.get("item_index") # Worker should send this
1219
+ if item_index is None:
1220
+ # Try to extract from item_key (format: dataset_XXXXXXXX)
1221
+ try:
1222
+ item_index = int(item_key.split("_")[-1])
1223
+ except:
1224
+ logger.warning(f"Could not extract item index from key: {item_key}")
932
1225
 
933
- # Create a SINGLE caption record with ALL captions as a list
1226
+ # Extract user from worker_id (format: "username_uuid")
1227
+ worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
1228
+
1229
+ # Handle both old format (captions list) and new format (outputs dict)
1230
+ if "outputs" in data:
1231
+ # New multi-stage format
1232
+ outputs = data["outputs"]
1233
+ captions_list = outputs.get("captions", [])
1234
+ total_outputs = sum(len(v) for v in outputs.values())
1235
+
1236
+ logger.debug(
1237
+ f"Received multi-stage outputs for item {item_key} from worker {worker_id}: "
1238
+ f"{total_outputs} outputs across {len(outputs)} fields"
1239
+ )
1240
+ else:
1241
+ # Old format - single captions list
1242
+ captions_list = data["captions"]
1243
+ outputs = {"captions": captions_list}
1244
+ total_outputs = len(captions_list)
1245
+
1246
+ logger.debug(
1247
+ f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
1248
+ )
1249
+
1250
+ # Create caption record with multi-stage outputs
934
1251
  caption = Caption(
935
- job_id=f"{chunk_id}_{item_key}", # Single ID for the item
1252
+ job_id=f"{chunk_id}_{item_key}",
936
1253
  dataset=data.get("dataset"),
937
1254
  shard=data.get("shard"),
938
1255
  item_key=item_key,
939
- captions=captions_list, # Store ALL captions as a list
940
- contributor_id=worker_id,
1256
+ captions=captions_list,
1257
+ outputs=outputs,
1258
+ contributor_id=worker_user,
941
1259
  timestamp=datetime.utcnow(),
942
- quality_scores=None, # Could be a list of scores matching captions
1260
+ quality_scores=None,
943
1261
  # Image metadata
944
1262
  image_width=data.get("image_width"),
945
1263
  image_height=data.get("image_height"),
946
1264
  image_format=data.get("image_format"),
947
1265
  file_size=data.get("file_size"),
948
1266
  # Processing metadata
949
- caption_count=len(captions_list),
1267
+ caption_count=total_outputs,
950
1268
  processing_time_ms=data.get("processing_time_ms"),
951
1269
  chunk_id=chunk_id,
1270
+ metadata=data.get("metadata", {}),
952
1271
  )
953
1272
 
954
- # Add to central storage buffer as a single entry
1273
+ # Add to central storage buffer
955
1274
  await self.storage.save_caption(caption)
956
1275
 
957
- # Update statistics
958
- self.stats["total_captions"] += len(captions_list)
959
- self.stats["buffer_size"] = len(self.storage.caption_buffer)
1276
+ # Handle item tracking with fixed deadlock
1277
+ should_flush = False
1278
+ if chunk_id and item_index is not None and self.chunk_tracker:
1279
+ with self.item_batch_lock:
1280
+ self.pending_processed_items[chunk_id].append(item_index)
960
1281
 
961
- # Update contributor stats
962
- contributor = await self.storage.get_contributor(worker_id)
1282
+ # Check if we should flush
1283
+ total_pending = sum(
1284
+ len(indices) for indices in self.pending_processed_items.values()
1285
+ )
1286
+ time_since_flush = time.time() - self.last_item_batch_flush
1287
+
1288
+ if (
1289
+ total_pending >= self.item_batch_size
1290
+ or time_since_flush >= self.item_batch_interval
1291
+ ):
1292
+ should_flush = True
1293
+
1294
+ if should_flush:
1295
+ await self._flush_processed_items()
1296
+
1297
+ # Update contributor stats (use user, not worker)
1298
+ contributor = await self.storage.get_contributor(worker_user)
963
1299
  if contributor:
964
- contributor.total_captions += len(captions_list)
1300
+ contributor.total_captions += total_outputs
965
1301
  await self.storage.save_contributor(contributor)
966
1302
 
967
1303
  # Broadcast updated stats
968
1304
  await self._broadcast_stats()
969
1305
 
970
1306
  # Log progress periodically
971
- if self.stats["total_captions"] % 100 == 0:
972
- logger.info(f"Collected {self.stats['total_captions']} captions centrally")
1307
+ total_outputs = self.stats.get("total_outputs", 0)
1308
+ if total_outputs > 0 and total_outputs % 100 == 0:
1309
+ if (
1310
+ not hasattr(self, "_last_logged_outputs")
1311
+ or self._last_logged_outputs != total_outputs
1312
+ ):
1313
+ logger.info(f"Collected {total_outputs} outputs centrally")
1314
+ self._last_logged_outputs = total_outputs
973
1315
 
974
1316
  async def _check_shard_completion(self, chunk_id: str):
975
1317
  """Check if a shard is complete after chunk completion."""
976
- # Extract shard name from chunk_id
977
- shard_name = chunk_id.rsplit("_chunk_", 1)[0]
1318
+ # Get the chunk
1319
+ chunk = self.chunk_manager.chunks.get(chunk_id)
1320
+ if not chunk:
1321
+ return
978
1322
 
979
- # Check if all chunks for this shard are complete
980
- chunk_stats = self.chunk_manager.get_stats()
1323
+ shard_name = chunk.shard_name
1324
+
1325
+ # Find all chunks for this shard
981
1326
  shard_chunks = [
982
- cid
983
- for cid, chunk in self.chunk_manager.chunks.items()
984
- if chunk.shard_name == shard_name
1327
+ cid for cid, c in self.chunk_manager.chunks.items() if c.belongs_to_shard(shard_name)
985
1328
  ]
986
1329
 
1330
+ # Check if all are completed
987
1331
  completed_chunks = [
988
1332
  cid for cid in shard_chunks if self.chunk_manager.chunks[cid].status == "completed"
989
1333
  ]
990
1334
 
991
- if len(completed_chunks) == len(shard_chunks):
1335
+ if len(completed_chunks) == len(shard_chunks) and len(shard_chunks) > 0:
992
1336
  logger.info(f"Shard {shard_name} complete!")
993
- self.shard_tracker.mark_complete(shard_name)
1337
+ # Don't mark virtual shards as complete in ShardTracker
1338
+ if not shard_name.startswith("hf_dataset:"):
1339
+ self.shard_tracker.mark_complete(shard_name)
994
1340
  self.stats["completed_shards"] += 1
995
1341
  await self._send_activity(f"Shard {shard_name} completed!")
996
1342
 
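Note on the hunk above: a hypothetical `submit_captions` payload in the new multi-stage format (field names follow the code; values illustrative):

    data = {
        "chunk_id": "shard-00042_chunk_2000",
        "item_key": "dataset_00002042",
        "item_index": 2042,
        "outputs": {
            "captions": ["a short caption", "a longer caption"],
            "tags": ["outdoor", "daytime"],
        },
        "processing_time_ms": 850,
    }
    # total_outputs == 4 (two captions + two tags); the contributor becomes
    # "alice" when worker_id is "alice_1a2b3c4d" (worker_id.rsplit("_", 1)[0]).
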
@@ -1076,12 +1422,29 @@ class Orchestrator:
         chunk_stats = self.chunk_manager.get_stats()
         await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
 
-        # Send contributor leaderboard
+        # Send contributor leaderboard with active worker counts
         contributors = await self.storage.get_top_contributors(10)
+
+        # Enhance contributor data with active worker counts
+        enhanced_contributors = []
+        worker_counts = (
+            self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+        )
+
+        for contributor in contributors:
+            contrib_dict = {
+                "contributor_id": contributor.contributor_id,
+                "name": contributor.name,
+                "total_captions": contributor.total_captions,
+                "trust_level": contributor.trust_level,
+                "active_workers": len(
+                    worker_counts.get(contributor.contributor_id, {}).get("worker_ids", [])
+                ),
+            }
+            enhanced_contributors.append(contrib_dict)
+
         await websocket.send(
-            safe_json_dumps(
-                {"type": "leaderboard", "data": [safe_dict(c) for c in contributors]}
-            )
+            safe_json_dumps({"type": "leaderboard", "data": enhanced_contributors})
         )
 
         # Keep connection alive
@@ -1094,14 +1457,23 @@ class Orchestrator:
             self.monitors.discard(websocket)
 
     async def _broadcast_stats(self):
-        """Broadcast statistics to all monitors."""
+        """Broadcast statistics to all monitors - enhanced for multi-stage."""
         if not self.monitors:
             return
 
+        # Get storage stats
+        storage_stats = await self.storage.get_storage_stats()
+        caption_stats = await self.storage.get_caption_stats()
+
         # Include chunk stats
         chunk_stats = self.chunk_manager.get_stats()
         self.stats.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
 
+        # Merge storage stats
+        self.stats.update(storage_stats)
+        self.stats["field_breakdown"] = caption_stats.get("field_stats", {})
+        self.stats["output_fields_list"] = caption_stats.get("output_fields", [])
+
         # Add rate information
         self.stats.update(
             {
@@ -1111,23 +1483,123 @@ class Orchestrator:
             }
         )
 
-        # Add vLLM info
+        # Add vLLM info - now includes stage count
         self.stats["vllm_model"] = self.vllm_config.get("model", "unknown")
         self.stats["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
 
+        # NEW: Add stage information
+        stages = self.vllm_config.get("stages", [])
+        if stages:
+            self.stats["stage_count"] = len(stages)
+            self.stats["stage_names"] = [s.get("name", "unnamed") for s in stages]
+        else:
+            self.stats["stage_count"] = 1  # Backward compatibility
+            self.stats["stage_names"] = ["default"]
+
+        field_stats = await self.storage.get_output_field_stats()
+        self.stats["output_fields"] = field_stats
+
         message = safe_json_dumps({"type": "stats", "data": self.stats})
 
         # Send to all monitors
         disconnected = set()
-        for monitor in self.monitors:
+        _monitors = self.monitors.copy()
+        for monitor in _monitors:
             try:
                 await monitor.send(message)
             except websockets.exceptions.ConnectionClosed:
                 disconnected.add(monitor)
 
+        # send updated leaderboard
+        try:
+            contributors = await self.storage.get_top_contributors(10)
+            enhanced_contributors = []
+            worker_counts = (
+                self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+            )
+
+            for contributor in contributors:
+                contrib_dict = {
+                    "contributor_id": contributor.contributor_id,
+                    "name": contributor.name,
+                    "total_captions": contributor.total_captions,
+                    "trust_level": contributor.trust_level,
+                    "active_workers": len(
+                        worker_counts.get(contributor.contributor_id, {}).get("worker_ids", [])
+                    ),
+                }
+                enhanced_contributors.append(contrib_dict)
+
+            leaderboard_message = safe_json_dumps(
+                {"type": "leaderboard", "data": enhanced_contributors}
+            )
+
+            # Send to all monitors
+            disconnected = set()
+            for monitor in self.monitors.copy():
+                try:
+                    await monitor.send(leaderboard_message)
+                except websockets.exceptions.ConnectionClosed:
+                    disconnected.add(monitor)
+
+            self.monitors -= disconnected
+
+        except Exception as e:
+            logger.error(f"Error sending leaderboard update: {e}")
+
         # Clean up disconnected monitors
         self.monitors -= disconnected
 
+    async def _flush_processed_items(self):
+        """Flush batched processed items to chunk tracker."""
+        with self.item_batch_lock:
+            if not self.pending_processed_items:
+                return
+
+            for chunk_id, indices in self.pending_processed_items.items():
+                if not indices:
+                    continue
+
+                # Indices here are ABSOLUTE dataset indices
+                # Sort indices
+                indices.sort()
+
+                # Group consecutive indices into ranges
+                ranges = []
+                start = indices[0]
+                end = indices[0]
+
+                for i in range(1, len(indices)):
+                    if indices[i] == end + 1:
+                        # Consecutive, extend range
+                        end = indices[i]
+                    else:
+                        # Gap found, save current range and start new one
+                        ranges.append((start, end))
+                        start = indices[i]
+                        end = indices[i]
+
+                # Don't forget the last range
+                ranges.append((start, end))
+
+                # Mark ranges as processed (mark_items_processed expects absolute indices)
+                for start_idx, end_idx in ranges:
+                    self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
+
+            # Clear pending items
+            self.pending_processed_items.clear()
+            self.last_item_batch_flush = time.time()
+
+    def get_workers_by_user_stats(self) -> Dict[str, Any]:
+        """Get statistics about workers grouped by user/token."""
+        if not hasattr(self, "workers_by_user"):
+            return {}
+
+        stats = {}
+        for user, worker_ids in self.workers_by_user.items():
+            stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
+        return stats
+
     async def _send_activity(self, activity: str):
         """Send activity update to monitors."""
         if not self.monitors:
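
Note on `_flush_processed_items` above: sorted absolute indices are coalesced into inclusive (start, end) runs before being handed to `chunk_tracker.mark_items_processed`. A worked example of that step, using the same loop:

    indices = sorted([20502, 20500, 20501, 20511, 20510, 20600])
    ranges = []
    start = end = indices[0]
    for i in indices[1:]:
        if i == end + 1:
            end = i
        else:
            ranges.append((start, end))
            start = end = i
    ranges.append((start, end))
    print(ranges)  # [(20500, 20502), (20510, 20511), (20600, 20600)]
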
@@ -1172,36 +1644,52 @@ class Orchestrator:
         while True:
             await asyncio.sleep(60)
 
+            # Get current caption count from storage
+            storage_stats = await self.storage.get_storage_stats()
+            total_captions = storage_stats["total_captions"]
+
             # Force checkpoint at regular intervals
-            if self.stats["total_captions"] > 0 and self.stats["total_captions"] % interval == 0:
-                logger.info(f"Triggering checkpoint at {self.stats['total_captions']} captions")
+            if total_captions > 0 and total_captions % interval == 0:
+                logger.info(f"Triggering checkpoint at {total_captions} captions")
                 await self.storage.checkpoint()
 
                 # Update stats
                 self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
-                self.stats["total_written"] = self.storage.total_captions_written
-                self.stats["buffer_size"] = len(self.storage.caption_buffer)
+                # No need to update total_written or buffer_size - they come from storage
 
                 await self._broadcast_stats()
                 logger.info(
-                    f"Checkpoint complete. Total written to disk: {self.stats['total_written']}"
+                    f"Checkpoint complete. Total written to disk: {storage_stats['total_written']}"
                 )
 
     async def _stats_update_loop(self):
         """Periodically update and broadcast stats."""
         # Track session start values
-        session_start_captions = self.stats["total_captions"]
+        storage_stats = await self.storage.get_storage_stats()
+        session_start_outputs = storage_stats["total_captions"]  # This now counts ALL outputs
         session_start_time = time.time()
 
+        # Track the last known total to detect flushes
+        last_known_total = session_start_outputs
+
         while True:
             await asyncio.sleep(10)
 
             # Update chunk stats
             chunk_stats = self.chunk_manager.get_stats()
+            storage_stats = await self.storage.get_storage_stats()
+            current_total_outputs = storage_stats["total_captions"]  # ALL outputs
+            if self.chunk_tracker:
+                await self._flush_processed_items()
+
             self.stats["total_chunks"] = chunk_stats["total"]
             self.stats["completed_chunks"] = chunk_stats["completed"]
             self.stats["failed_chunks"] = chunk_stats["failed"]
 
+            # Update total outputs stat (rename from total_captions for clarity)
+            self.stats["total_outputs"] = current_total_outputs
+            self.stats["total_captions"] = current_total_outputs  # Keep for backward compatibility
+
             # Add queue information
             with self.chunk_manager.lock:
                 self.stats["pending_chunks"] = len(self.chunk_manager.pending_chunks)
@@ -1220,33 +1708,57 @@ class Orchestrator:
                 elapsed_since_update = current_time - self.rate_tracker["last_update_time"]
 
                 if elapsed_since_update > 0:
-                    # Calculate current rate (captions per minute)
-                    caption_diff = (
-                        self.stats["total_captions"] - self.rate_tracker["last_caption_count"]
-                    )
-                    self.rate_tracker["current_rate"] = (caption_diff / elapsed_since_update) * 60
+                    # FIX: Handle the case where duplicates were skipped during save
+                    # If current total is less than last known, it means duplicates were skipped
+                    # We should not count this as negative progress
+                    if current_total_outputs < last_known_total:
+                        logger.debug(
+                            f"Detected duplicate skip during save: {last_known_total} -> {current_total_outputs}"
+                        )
+                        # Don't calculate negative rate, just update the baseline
+                        self.rate_tracker["last_caption_count"] = current_total_outputs
+                        self.rate_tracker["current_rate"] = 0.0  # Set to 0 during flush
+                    else:
+                        # Normal rate calculation
+                        output_diff = current_total_outputs - self.rate_tracker["last_caption_count"]
+                        self.rate_tracker["current_rate"] = (output_diff / elapsed_since_update) * 60
+                        self.rate_tracker["last_caption_count"] = current_total_outputs
 
                     # Calculate average rate since THIS SESSION started
                     session_elapsed = current_time - session_start_time
                     if session_elapsed > 0:
-                        session_captions = self.stats["total_captions"] - session_start_captions
-                        self.rate_tracker["average_rate"] = (session_captions / session_elapsed) * 60
+                        # Always use the difference from session start for average
+                        session_outputs = current_total_outputs - session_start_outputs
+                        self.rate_tracker["average_rate"] = (session_outputs / session_elapsed) * 60
 
-                    # Calculate expected rate based on workers
-                    # Assume each worker processes batch_size images every ~2 seconds with 3 captions each
+                    # Calculate expected rate based on workers and stages
                     batch_size = self.vllm_config.get("batch_size", 8)
-                    num_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
+
+                    # Count total prompts across all stages
+                    total_prompts = 0
+                    stages = self.vllm_config.get("stages", [])
+                    if stages:
+                        for stage in stages:
+                            total_prompts += len(stage.get("prompts", []))
+                    else:
+                        # Backward compatibility
+                        total_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
+
                     images_per_minute = 30  # Rough estimate: 30 images/min per worker
-                    self.rate_tracker["expected_rate"] = worker_count * images_per_minute * num_prompts
+                    self.rate_tracker["expected_rate"] = (
+                        worker_count * images_per_minute * total_prompts
+                    )
 
                 # Update trackers
                 self.rate_tracker["last_update_time"] = current_time
-                self.rate_tracker["last_caption_count"] = self.stats["total_captions"]
+                last_known_total = current_total_outputs
 
                 # Log rate information when workers are connected
-                if worker_count > 0:
+                if (
+                    worker_count > 0 and self.rate_tracker["current_rate"] >= 0
+                ):  # Only log non-negative rates
                     logger.info(
-                        f"Rate: {self.rate_tracker['current_rate']:.1f} captions/min "
+                        f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
                         f"(avg: {self.rate_tracker['average_rate']:.1f}, "
                         f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
                         f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
@@ -1256,16 +1768,16 @@ class Orchestrator:
 
     async def _restore_state(self):
         """Restore state from storage on startup."""
-        # Update statistics
-        self.stats["total_captions"] = await self.storage.count_captions()
-
-        logger.info(f"Restored state: {self.stats['total_captions']} captions")
+        total_captions = await self.storage.count_captions()
+        logger.info(f"Restored state: {total_captions} captions")
 
     async def shutdown(self):
         """Graceful shutdown."""
         logger.info("Shutting down orchestrator...")
 
         # Stop chunk creation
+        if self.chunk_tracker:
+            await self._flush_processed_items()
         self.stop_chunk_creation.set()
         if self.chunk_creation_thread:
             self.chunk_creation_thread.join(timeout=5)
@@ -1287,7 +1799,7 @@ class Orchestrator:
 
         # Save chunk state
         if self.chunk_tracker:
-            self.chunk_tracker.save_checkpoint()
+            self.chunk_tracker.save()
 
         # Final checkpoint
         logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")