caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
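In short: 0.2.3 removes the orchestrator's built-in vLLM shard/chunk machinery (ShardChunk, ChunkManager, DatasetLoader, ShardTracker, ChunkTracker) and delegates work distribution to pluggable processors selected by a required dataset.processor_type setting; the worker protocol changes accordingly, from request_chunks / chunk_complete / submit_captions messages to generic request_work / work_complete / submit_results messages. As rough orientation, a 0.2.3-style orchestrator config might look like the sketch below — the key names appear in the diff that follows, but the values and nesting are illustrative assumptions, not the package's documented schema.

# Illustrative sketch only: key names are taken from the 0.2.3 code in this
# diff; concrete values and overall layout are assumptions.
config = {
    "host": "0.0.0.0",  # Orchestrator.__init__ defaults shown in the diff
    "port": 8765,
    "dataset": {
        # Required in 0.2.3; selects the processor class:
        # "webdataset", "huggingface_datasets", or "local_filesystem".
        "processor_type": "webdataset",
    },
    "units_per_request": 2,          # work units handed out per worker request
    "backpressure_threshold": 1000,  # data-sample queue threshold
    "storage": {
        "data_dir": "./caption_data",
        "caption_buffer_size": 1000,
    },
}
# Orchestrator(config) then instantiates the matching
# *OrchestratorProcessor and raises ValueError for an unknown processor_type.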
@@ -1,412 +1,106 @@
-"""Enhanced orchestrator with shard chunk assignment for vLLM workers.
-
-This orchestrator:
-1. Divides dataset shards into chunks for parallel processing
-2. Assigns chunks to workers on request
-3. Collects captions from workers centrally
-4. Manages checkpoints and fault tolerance
-"""
-
 import time
 import asyncio
 import json
 import logging
 import ssl
 import uuid
-from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Set, Optional, Any, List, Deque
-from collections import deque, defaultdict
+from typing import Dict, Set, Optional, Any, List
+from collections import defaultdict
 import threading
-from queue import Queue, Empty
 
-from .workers import data
 import websockets
 from websockets.server import WebSocketServerProtocol
 
 from .storage import StorageManager
-from .models import Caption, Contributor
+from .models import Caption, Contributor, JobId
 from .utils.auth import AuthManager
-from .utils import DatasetLoader, ShardTracker, ChunkTracker
-from .utils.json_utils import safe_dict, safe_json_dumps, to_json_dict
+from .utils.json_utils import safe_json_dumps
+from .processors import (
+    ProcessorConfig,
+    WorkAssignment,
+    WorkResult,
+    WorkUnit,
+    WebDatasetOrchestratorProcessor,
+    HuggingFaceDatasetOrchestratorProcessor,
+    LocalFilesystemOrchestratorProcessor,
+)
 
 logger = logging.getLogger(__name__)
-
-
-@dataclass
-class ShardChunk:
-    """Represents a chunk of a shard for processing."""
-
-    chunk_id: str
-    shard_url: str
-    shard_name: str
-    start_index: int
-    chunk_size: int
-    assigned_to: Optional[str] = None
-    status: str = "pending"  # pending, assigned, completed, failed
-    assigned_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
-
-    @classmethod
-    def create(
-        cls, shard_url: str, shard_name: str, start_index: int, chunk_size: int
-    ) -> "ShardChunk":
-        """Factory method to create a chunk with consistent ID."""
-        # Always use consistent format: dataset_chunk_startindex
-        if shard_url.startswith("hf_dataset:"):
-            # Extract dataset path
-            parts = shard_url.split(":")
-            dataset_path = parts[1] if len(parts) > 1 else "unknown"
-            chunk_id = f"{dataset_path.replace('/', '_')}_chunk_{start_index}"
-        else:
-            # WebDataset format
-            chunk_id = f"{shard_name}_chunk_{start_index}"
-
-        return cls(
-            chunk_id=chunk_id,
-            shard_url=shard_url,
-            shard_name=shard_name,
-            start_index=start_index,
-            chunk_size=chunk_size,
-        )
-
-    def belongs_to_shard(self, shard_identifier: str) -> bool:
-        """Check if this chunk belongs to a given shard."""
-        return self.shard_name == shard_identifier
-
-    def to_dict(self) -> Dict[str, Any]:
-        """Convert to dict for JSON serialization (for workers)."""
-        return {
-            "chunk_id": self.chunk_id,
-            "shard_url": self.shard_url,
-            "shard_name": self.shard_name,
-            "start_index": self.start_index,
-            "chunk_size": self.chunk_size,
-        }
-
-
-class ChunkManager:
-    """Manages shard chunk creation and assignment."""
-
-    def __init__(self, chunk_size: int = 1000, tracker: Optional[ChunkTracker] = None):
-        self.chunk_size = chunk_size
-        self.chunks: Dict[str, ShardChunk] = {}
-        self.pending_chunks: Deque[str] = deque()
-        self.assigned_chunks: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> chunk_ids
-        self.lock = threading.Lock()
-        self.tracker = tracker  # Reference to chunk tracker
-
-    def create_chunks_from_shard(
-        self, shard_url: str, shard_name: str, total_items: int
-    ) -> List[ShardChunk]:
-        """Create chunks from a shard."""
-        chunks = []
-
-        for start_idx in range(0, total_items, self.chunk_size):
-            chunk = ShardChunk.create(
-                shard_url=shard_url,
-                shard_name=shard_name,
-                start_index=start_idx,
-                chunk_size=min(self.chunk_size, total_items - start_idx),
-            )
-
-            with self.lock:
-                self.chunks[chunk.chunk_id] = chunk
-                self.pending_chunks.append(chunk.chunk_id)
-
-            chunks.append(chunk)
-
-        return chunks
-
-    def get_chunks_for_worker(
-        self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
-    ) -> List[Dict[str, Any]]:
-        """Get available chunks with unprocessed items for a worker."""
-        assigned = []
-
-        with self.lock:
-            # FIRST PRIORITY: Check if this worker already has assigned chunks
-            # Workers should complete their current chunks before getting new ones
-            if worker_id in self.assigned_chunks:
-                existing_chunk_ids = list(self.assigned_chunks[worker_id])
-                for chunk_id in existing_chunk_ids:
-                    if len(assigned) >= count:
-                        break
-
-                    chunk = self.chunks.get(chunk_id)
-                    if not chunk:
-                        continue
-
-                    # Check if chunk still has unprocessed items
-                    if tracker:
-                        chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
-                        if chunk_info and chunk_info["unprocessed_ranges"]:
-                            assigned.append(
-                                {
-                                    "chunk": chunk,
-                                    "unprocessed_ranges": chunk_info["unprocessed_ranges"],
-                                }
-                            )
-                    else:
-                        # No tracker, assume chunk needs processing
-                        assigned.append(
-                            {
-                                "chunk": chunk,
-                                "unprocessed_ranges": [(0, chunk.chunk_size - 1)],
-                            }
-                        )
-
-            # SECOND PRIORITY: Get new pending chunks
-            # Only if worker doesn't have enough chunks already
-            while len(assigned) < count and self.pending_chunks:
-                chunk_id = self.pending_chunks.popleft()
-                chunk = self.chunks.get(chunk_id)
-
-                if not chunk:
-                    continue
-
-                # Verify chunk is truly pending (defensive check)
-                if chunk.status != "pending" or chunk.assigned_to is not None:
-                    logger.warning(
-                        f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
-                    )
-                    continue
-
-                # Assign to this worker
-                chunk.assigned_to = worker_id
-                chunk.status = "assigned"
-                chunk.assigned_at = datetime.utcnow()
-                self.assigned_chunks[worker_id].add(chunk_id)
-
-                # Get unprocessed ranges
-                unprocessed_ranges = [(0, chunk.chunk_size - 1)]  # Default
-                if tracker:
-                    chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
-                    if chunk_info:
-                        unprocessed_ranges = chunk_info["unprocessed_ranges"]
-                    tracker.mark_assigned(chunk_id, worker_id)
-
-                assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
-
-        # Log what we're assigning
-        if assigned:
-            chunk_summary = ", ".join(
-                [
-                    f"{info['chunk'].chunk_id}[{len(info['unprocessed_ranges'])} ranges]"
-                    for info in assigned
-                ]
-            )
-            logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
-
-        return assigned
-
-    def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
-        """Mark a chunk as completed."""
-        with self.lock:
-            if chunk_id in self.chunks:
-                chunk = self.chunks[chunk_id]
-                if chunk.assigned_to == worker_id and chunk.status == "assigned":
-                    chunk.status = "completed"
-                    chunk.completed_at = datetime.utcnow()
-                    self.assigned_chunks[worker_id].discard(chunk_id)
-                    return True
-        return False
-
-    def fail_chunk(self, chunk_id: str, worker_id: str) -> bool:
-        """Mark a chunk as failed and requeue it."""
-        with self.lock:
-            if chunk_id in self.chunks:
-                chunk = self.chunks[chunk_id]
-                if chunk.assigned_to == worker_id:
-                    chunk.status = "pending"
-                    chunk.assigned_to = None
-                    chunk.assigned_at = None
-                    self.assigned_chunks[worker_id].discard(chunk_id)
-                    self.pending_chunks.append(chunk_id)
-                    return True
-        return False
-
-    def release_worker_chunks(self, worker_id: str):
-        """Release all chunks assigned to a worker."""
-        with self.lock:
-            chunk_ids = list(self.assigned_chunks.get(worker_id, []))
-            for chunk_id in chunk_ids:
-                if chunk_id in self.chunks:
-                    chunk = self.chunks[chunk_id]
-                    if chunk.status == "assigned":
-                        chunk.status = "pending"
-                        chunk.assigned_to = None
-                        chunk.assigned_at = None
-                        self.pending_chunks.append(chunk_id)
-
-            if worker_id in self.assigned_chunks:
-                del self.assigned_chunks[worker_id]
-
-    def get_stats(self) -> Dict[str, int]:
-        """Get chunk statistics."""
-        with self.lock:
-            stats = {
-                "total": len(self.chunks),
-                "pending": len(self.pending_chunks),
-                "assigned": sum(len(chunks) for chunks in self.assigned_chunks.values()),
-                "completed": sum(1 for c in self.chunks.values() if c.status == "completed"),
-                "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
-            }
-        return stats
+logger.setLevel(logging.INFO)
 
 
 class Orchestrator:
-    """Enhanced orchestrator for vLLM-based distributed captioning with chunk assignment."""
+    """Generic orchestrator for distributed work processing."""
 
     def __init__(self, config: Dict[str, Any]):
         self.config = config
         self.host = config.get("host", "0.0.0.0")
         self.port = config.get("port", 8765)
 
-        # Dataset configuration
-        self.dataset_config = config.get("dataset", {})
-        self.dataset_path = self.dataset_config.get("path")
-        self.dataset_type = self.dataset_config.get("type", "huggingface")
-        self.dataset_split = self.dataset_config.get("split", "train")  # Add split configuration
-        self.dataset_image_column = self.dataset_config.get(
-            "image_column", "image"
-        )  # Add image column config
-
-        # Dataset components
-        self.dataset_loader = None
-        self.shard_tracker = None
-        self.chunk_tracker = None
-
-        if self.dataset_path:
-            self.dataset_loader = DatasetLoader(
-                self.dataset_path,
-                self.dataset_type,
-                self.dataset_split,
-                self.dataset_image_column,
-            )
-            checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
-            checkpoint_dir.mkdir(parents=True, exist_ok=True)
-            self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
-            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
-
-        # vLLM configuration to distribute to workers
-        self.vllm_config = config.get(
-            "vllm",
-            {
-                "model": "Qwen/Qwen2.5-VL-3B-Instruct",
-                "gpu_memory_utilization": 0.92,
-                "max_model_len": 16384,
-                "tensor_parallel_size": 1,
-                "dtype": "float16",
-                "enforce_eager": True,
-                "limit_mm_per_prompt": {"image": 1},
-                "disable_mm_preprocessor_cache": True,
-                "sampling": {
-                    "temperature": 0.7,
-                    "top_p": 0.95,
-                    "max_tokens": 256,
-                    "repetition_penalty": 1.05,
-                    "stop": ["<|end|>", "<|endoftext|>", "<|im_end|>"],
-                },
-                "inference_prompts": [
-                    "describe this image in detail",
-                    "provide a comprehensive description of the visual content",
-                    "what are the key elements in this image?",
-                ],
-            },
-        )
-
-        # Chunk configuration
-        self.chunk_size = config.get("chunk_size", 1000)
-        self.chunks_per_request = config.get("chunks_per_request", 2)
-
-        # Demand-driven chunk creation settings
-        self.chunk_buffer_multiplier = config.get("chunk_buffer_multiplier", 3)
-        self.min_chunk_buffer = config.get("min_chunk_buffer", 10)
+        # Processor configuration
+        processor_type = config.get("dataset", {}).get("processor_type", None)
+        assert (
+            processor_type is not None
+        ), "You must supply processor_type in your orchestrator dataset configuration."
+        processor_config = ProcessorConfig(processor_type=processor_type, config=config)
+
+        # Initialize processor
+        if processor_type == "webdataset":
+            self.processor = WebDatasetOrchestratorProcessor()
+        elif processor_type == "huggingface_datasets":
+            self.processor = HuggingFaceDatasetOrchestratorProcessor()
+        elif processor_type == "local_filesystem":
+            self.processor = LocalFilesystemOrchestratorProcessor()
+        else:
+            raise ValueError(f"Unknown processor type: {processor_type}")
 
         # Initialize components
         storage_config = config.get("storage", {})
         self.storage = StorageManager(
             Path(storage_config.get("data_dir", "./caption_data")),
             caption_buffer_size=storage_config.get("caption_buffer_size", 1000),
-            job_buffer_size=storage_config.get("job_buffer_size", 100),
-            contributor_buffer_size=storage_config.get("contributor_buffer_size", 10),
         )
         self.auth = AuthManager(config.get("auth", {}))
+        self.processor.initialize(processor_config, self.storage)
 
-        # Dataset components
-        self.dataset_loader = None
-        self.shard_tracker = None
-        self.chunk_tracker = None
-
-        if self.dataset_path:
-            self.dataset_loader = DatasetLoader(self.dataset_path, self.dataset_type)
-            checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
-            checkpoint_dir.mkdir(parents=True, exist_ok=True)
-            self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
-            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
-
-        # Initialize chunk manager with reference to chunk tracker
-        self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)
-        self.pending_processed_items = defaultdict(list)  # chunk_id -> list of indices
-        self.item_batch_lock = threading.Lock()
-        self.last_item_batch_flush = time.time()
-        self.item_batch_interval = 5  # Flush every 5 seconds
-        self.item_batch_size = 100  # Or every 100 items
+        # Processing configuration
+        self.units_per_request = config.get("units_per_request", 2)
 
         # Track connections
         self.workers: Dict[str, WebSocketServerProtocol] = {}
         self.monitors: Set[WebSocketServerProtocol] = set()
+        self.workers_by_user = defaultdict(set)
 
         # SSL configuration
         self.ssl_context = self._setup_ssl()
 
         # Statistics
-        self.is_generating_stats = False
         self.stats = {
-            "total_chunks": 0,
-            "completed_chunks": 0,
-            "failed_chunks": 0,
             "connected_workers": 0,
-            "total_shards": 0,
-            "completed_shards": 0,
-            "current_shard": None,
+            "total_outputs": 0,
             "last_checkpoint": None,
+            "processor_stats": {},
         }
 
+        # Cache for leaderboard
+        self._cached_leaderboard = None
+
+        # Data worker stuff
+        self.data_workers = {}
+        self.data_sample_queue = asyncio.Queue()
+        self.backpressure_threshold = config.get("backpressure_threshold", 1000)
+
         # Rate tracking
         self.rate_tracker = {
             "start_time": time.time(),
             "last_update_time": time.time(),
-            "last_caption_count": 0,
+            "last_output_count": 0,
             "current_rate": 0.0,
             "average_rate": 0.0,
-            "expected_rate": 0.0,
         }
 
-        # Data sample queue for CaptionWorker
-        self.data_sample_queue = asyncio.Queue(maxsize=1000)
-        self.data_workers: Dict[str, WebSocketServerProtocol] = {}
-
-        # Backpressure threshold
-        self.backpressure_threshold = config.get("backpressure_threshold", 800)
-
-        # Shard processing state
-        self.all_shards = []
-        self.current_shard_index = 0
-        self.shard_lock = threading.Lock()
-
-        # Background chunk creation
-        self.chunk_creation_thread = None
-        self.stop_chunk_creation = threading.Event()
-
-        # State restoration flag
-        self.state_restored = threading.Event()
-        # If no dataset, state is already "restored"
-        if not self.dataset_loader:
-            self.state_restored.set()
-
     def _setup_ssl(self) -> Optional[ssl.SSLContext]:
         """Configure SSL if certificates are provided."""
         ssl_config = self.config.get("ssl", {})
@@ -417,324 +111,20 @@ class Orchestrator:
         context.load_cert_chain(ssl_config["cert"], ssl_config["key"])
         return context
 
-    def _create_chunks_from_dataset(self):
-        """Background thread to create chunks from dataset shards on demand."""
-        if not self.dataset_loader:
-            logger.warning("No dataset configured, skipping chunk creation")
-            self.state_restored.set()  # No state to restore
-            return
-
-        logger.info("Starting chunk creation thread")
-
-        # Mark state as not restored until we process checkpoints
-        self.state_restored.clear()
-
-        # Get dataset info to check format
-        dataset_info = self.dataset_loader.get_dataset_info()
-        dataset_format = dataset_info.get("dataset_format", "unknown")
-        logger.info(f"Dataset format: {dataset_format}")
-
-        # Get all shards
-        self.all_shards = self.dataset_loader.get_shard_list()
-        self.stats["total_shards"] = len(self.all_shards)
-
-        # For HuggingFace datasets, we might need to dynamically create more shards
-        if dataset_format == "huggingface_datasets":
-            self._is_hf_dataset = True
-            self._hf_chunk_size = 10000  # Items per virtual shard
-            self._next_hf_shard_index = len(self.all_shards)  # For creating new virtual shards
-        else:
-            self._is_hf_dataset = False
-
-        # Get shard status from ChunkTracker
-        shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
-        completed_shards = {
-            shard_name for shard_name, info in shards_summary.items() if info["is_complete"]
-        }
-
-        # Update ShardTracker for completed shards
-        for shard_name in completed_shards:
-            if not self.shard_tracker.is_complete(shard_name):
-                logger.info(f"Marking shard {shard_name} as complete in ShardTracker")
-                self.shard_tracker.mark_complete(shard_name)
-
-        # Get shards that need processing
-        remaining_shards = self.shard_tracker.get_remaining_shards(self.all_shards)
-
-        # Also check which shards already have chunks (partial or complete)
-        shards_with_chunks = set()
-        for shard_name in shards_summary.keys():
-            shards_with_chunks.add(shard_name)
-
-        # Filter out shards that already have chunks created
-        remaining_shards = [
-            shard
-            for shard in remaining_shards
-            if (shard if shard.startswith("hf_dataset:") else Path(shard).stem)
-            not in shards_with_chunks
-        ]
-
-        self.stats["completed_shards"] = len(completed_shards)
-
-        logger.info(
-            f"Total shards: {len(self.all_shards)}, "
-            f"Completed: {self.stats['completed_shards']}, "
-            f"Shards with chunks: {len(shards_with_chunks)}, "
-            f"Remaining to process: {len(remaining_shards)}"
-        )
-
-        # First, re-queue any existing pending chunks
-        initial_pending = 0
-        requeued_chunks_by_shard = defaultdict(list)
-
-        for shard_name, shard_info in shards_summary.items():
-            with self.chunk_manager.lock:
-                for chunk_state in shard_info["chunks"]:
-                    if chunk_state.status in ["pending", "failed", "assigned"]:
-                        # ChunkState already has shard_url stored
-                        chunk = ShardChunk(
-                            chunk_id=chunk_state.chunk_id,
-                            shard_url=chunk_state.shard_url,
-                            shard_name=chunk_state.shard_name,
-                            start_index=chunk_state.start_index,
-                            chunk_size=chunk_state.chunk_size,
-                        )
-                        self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
-                        self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
-                        requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
-                        initial_pending += 1
-
-        logger.info(f"Re-queued {initial_pending} existing pending chunks")
-        for shard_name, chunk_ids in requeued_chunks_by_shard.items():
-            logger.info(f" Shard {shard_name}: {len(chunk_ids)} chunks - {chunk_ids}")
-
-        # Mark state as restored
-        self.state_restored.set()
-        logger.info("State restoration complete, accepting chunk requests")
-
-        # Process shards on-demand
-        shard_iter = iter(remaining_shards)
-        current_shard_url = None
-        current_shard_name = None
-        current_shard_items = None
-        current_shard_index = 0
-
-        while not self.stop_chunk_creation.is_set():
-            # Check how many chunks we need
-            with self.chunk_manager.lock:
-                pending_count = len(self.chunk_manager.pending_chunks)
-                assigned_count = sum(
-                    len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
-                )
-                total_active = pending_count + assigned_count
-
-            # Target buffer: configurable multiplier × number of workers
-            worker_count = max(1, self.stats.get("connected_workers", 0))
-            target_buffer = max(
-                self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier
-            )
-
-            chunks_needed = max(0, target_buffer - total_active)
-
-            if chunks_needed == 0:
-                # We have enough chunks, wait a bit
-                time.sleep(5)
-                continue
-
-            logger.debug(
-                f"Need {chunks_needed} more chunks (pending: {pending_count}, "
-                f"assigned: {assigned_count}, workers: {worker_count})"
-            )
-
-            # Create chunks as needed
-            chunks_created = 0
-
-            while chunks_created < chunks_needed and not self.stop_chunk_creation.is_set():
-                # Need to load next shard?
-                if current_shard_url is None or current_shard_index >= current_shard_items:
-                    try:
-                        current_shard_url = next(shard_iter)
-
-                        # Extract shard name based on type
-                        if current_shard_url.startswith("hf_dataset:"):
-                            current_shard_name = current_shard_url  # Use full ID for virtual shards
-                        else:
-                            current_shard_name = Path(current_shard_url).stem
-
-                        self.stats["current_shard"] = current_shard_name
-
-                        # Skip if we already have chunks from this shard
-                        if current_shard_name in shards_summary:
-                            logger.debug(
-                                f"Skipping shard {current_shard_name} - already has chunks"
-                            )
-                            current_shard_url = None
-                            continue
-
-                        # Count items in new shard
-                        logger.info(f"Loading new shard {current_shard_name}")
-
-                        # For virtual HF dataset shards, use the chunk size directly
-                        if current_shard_url.startswith("hf_dataset:"):
-                            current_shard_items = self.dataset_loader.count_shard_items(
-                                current_shard_url
-                            )
-                            logger.info(
-                                f"Virtual shard {current_shard_name} has {current_shard_items} items"
-                            )
-                        else:
-                            # For WebDataset, actually count items
-                            current_shard_items = sum(
-                                1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
-                            )
-                            logger.info(
-                                f"Shard {current_shard_name} has {current_shard_items} items"
-                            )
-
-                        current_shard_index = 0
-
-                    except StopIteration:
-                        # No more shards in the iterator
-                        if self._is_hf_dataset:
-                            # Before creating new virtual shards, check if we have pending chunks
-                            with self.chunk_manager.lock:
-                                pending_count = len(self.chunk_manager.pending_chunks)
-
-                            if pending_count > 0:
-                                # Don't create new shards if we have pending chunks
-                                logger.debug(
-                                    f"Have {pending_count} pending chunks, not creating new virtual shards yet"
-                                )
-                                current_shard_url = None
-                                time.sleep(2)
-                                continue
-
-                            # For HF datasets, we can create more virtual shards on demand
-                            logger.info(
-                                "Creating additional virtual shards for HuggingFace dataset"
-                            )
-
-                            # Create 10 more virtual shards
-                            new_shards = []
-                            for i in range(10):
-                                shard_id = f"hf_dataset:{self.dataset_path}:chunk:{self._next_hf_shard_index * self._hf_chunk_size}"
-                                new_shards.append(shard_id)
-                                self._next_hf_shard_index += 1
-
-                            # Add to all_shards and create new iterator
-                            self.all_shards.extend(new_shards)
-                            self.stats["total_shards"] = len(self.all_shards)
-
-                            # Filter for unprocessed shards
-                            remaining_new_shards = [
-                                s
-                                for s in new_shards
-                                if s not in shards_summary and s not in completed_shards
-                            ]
-
-                            if remaining_new_shards:
-                                shard_iter = iter(remaining_new_shards)
-                                logger.info(f"Added {len(remaining_new_shards)} new virtual shards")
-                                continue
-
-                        # No more shards to process
-                        logger.info("No more shards to process")
-                        break
-
-                    except Exception as e:
-                        logger.error(f"Error loading shard {current_shard_name}: {e}")
-                        current_shard_url = None
-                        continue
-
-                # Create a chunk from current shard
-                if current_shard_url and current_shard_index < current_shard_items:
-                    # Calculate the absolute dataset index for this chunk
-                    if current_shard_url.startswith("hf_dataset:"):
-                        # Parse the virtual shard URL to get the base start index
-                        parts = current_shard_url.split(":")
-                        if len(parts) >= 4 and parts[2] == "chunk":
-                            shard_base_index = int(parts[3])
-                        else:
-                            shard_base_index = 0
-
-                        # The absolute start index for this chunk in the dataset
-                        absolute_start_index = shard_base_index + current_shard_index
-                    else:
-                        # For WebDataset, current_shard_index is already absolute
-                        absolute_start_index = current_shard_index
-
-                    # Create chunk with absolute index
-                    chunk = ShardChunk.create(
-                        shard_url=current_shard_url,
-                        shard_name=current_shard_name,
-                        start_index=absolute_start_index,
-                        chunk_size=min(self.chunk_size, current_shard_items - current_shard_index),
-                    )
-
-                    # Add to ChunkTracker with all required fields
-                    if self.chunk_tracker and self.chunk_tracker.add_chunk(
-                        chunk.chunk_id,
-                        chunk.shard_name,
-                        chunk.shard_url,
-                        chunk.start_index,
-                        chunk.chunk_size,
-                    ):
-                        with self.chunk_manager.lock:
-                            self.chunk_manager.chunks[chunk.chunk_id] = chunk
-                            self.chunk_manager.pending_chunks.append(chunk.chunk_id)
-
-                        chunks_created += 1
-                        self.stats["total_chunks"] += 1
-
-                    current_shard_index += self.chunk_size
-
-            if chunks_created > 0:
-                logger.info(f"Created {chunks_created} chunks on demand")
-
-            # If we couldn't create any chunks and there are no more shards, check if it's HF dataset
-            if chunks_created == 0 and current_shard_url is None:
-                if self._is_hf_dataset:
-                    # We can always create more virtual shards for HF datasets
-                    logger.debug("Will create more virtual shards on next iteration")
-                else:
-                    logger.info("All shards processed, chunk creation complete")
-                    break
-
-            # Brief pause to avoid spinning
-            time.sleep(1)
-
-        # Final stats
-        if self.chunk_tracker:
-            final_stats = self.chunk_tracker.get_stats()
-            logger.info(
-                f"Chunk creation thread ending. Total: {final_stats['total']}, "
-                f"Pending: {final_stats['pending']}, Completed: {final_stats['completed']}"
-            )
-
-        logger.info("Chunk creation thread finished")
-
     async def start(self):
         """Start the orchestrator server."""
-        logger.info(f"Starting vLLM orchestrator on {self.host}:{self.port}")
-        logger.info(
-            f"vLLM config: model={self.vllm_config.get('model')}, batch_size={self.vllm_config.get('batch_size')}"
-        )
-
-        # Load existing state BEFORE accepting connections
-        await self.storage.initialize()
-        if self.chunk_tracker:
-            await self.chunk_tracker.sync_with_storage(self.storage)
-        await self._restore_state()
-
-        # Start chunk creation thread if dataset is configured
-        if self.dataset_loader:
-            self.chunk_creation_thread = threading.Thread(
-                target=self._create_chunks_from_dataset, daemon=True
+        logger.info(f"Starting orchestrator on {self.host}:{self.port}")
+        processor_type = self.config.get("dataset", {}).get("processor_type", None)
+        if not processor_type:
+            logger.info(f"Config: {self.config}")
+            raise ValueError(
+                "You must supply processor_type in your orchestrator dataset configuration."
             )
-            self.chunk_creation_thread.start()
+        logger.info(f"Processor type: {processor_type}")
 
-        # Give chunk creation thread time to restore existing chunks
-        await asyncio.sleep(2)
+        # Initialize storage
+        await self.storage.initialize()
+        await self.update_unprocessed_ranges()
 
         # Start background tasks
         asyncio.create_task(self._heartbeat_loop())
@@ -742,12 +132,37 @@ class Orchestrator:
         asyncio.create_task(self._stats_update_loop())
 
         # Start WebSocket server
+        websocket_logger = logging.getLogger("websockets")
+        websocket_logger.setLevel(logging.WARNING)
         async with websockets.serve(
-            self.handle_connection, self.host, self.port, ssl=self.ssl_context
+            self.handle_connection,
+            self.host,
+            self.port,
+            ssl=self.ssl_context,
+            logger=websocket_logger,
         ):
-            logger.info("vLLM Orchestrator ready for connections")
+            logger.info("Orchestrator ready for connections")
            await asyncio.Future()  # Run forever
 
+    def get_workers_by_user_stats(self) -> Dict[str, Dict]:
+        """Get worker statistics grouped by user."""
+        stats = {}
+        for user, worker_ids in self.workers_by_user.items():
+            stats[user] = {"worker_ids": list(worker_ids), "count": len(worker_ids)}
+        return stats
+
+    async def update_unprocessed_ranges(self):
+        """Update unprocessed ranges based on what's actually in storage."""
+        if not self.processor or not self.storage:
+            return
+
+        processed_job_ids = self.storage.get_all_processed_job_ids()
+        self.processor.update_from_storage(processed_job_ids)
+
+    async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+        """Alias for _send_monitor_leaderboard for backward compatibility."""
+        await self._send_monitor_leaderboard(websocket)
+
     async def handle_connection(self, websocket: WebSocketServerProtocol):
         """Handle new WebSocket connection."""
         try:
@@ -762,49 +177,77 @@ class Orchestrator:
 
             if auth_ticket.role == "worker":
                 await self._handle_worker(websocket, auth_ticket)
-            elif auth_ticket.role == "data_worker":
-                await self._handle_data_worker(websocket, auth_ticket)
             elif auth_ticket.role == "monitor":
                 await self._handle_monitor(websocket)
             elif auth_ticket.role == "admin":
                 await self._handle_admin(websocket, auth_ticket)
+            elif auth_ticket.role == "data_worker":
+                await self._handle_data_worker(websocket, auth_ticket)
             else:
                 await websocket.send(
                     safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"})
                 )
 
         except Exception as e:
-            logger.error(f"Connection error: {e}")
-            import traceback
-
-            logger.error(traceback.format_exc())
+            logger.error(f"Connection error: {e}", exc_info=True)
             await websocket.close()
 
-    async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
-        """Handle admin connection for configuration updates."""
-        admin_id = getattr(auth_ticket, "name", "admin")
-        logger.info(f"Admin {admin_id} connected")
+    async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
+        """Handle worker connection lifecycle."""
+        # Generate unique worker ID
+        base_name = getattr(auth_ticket, "name", "worker")
+        worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"
+        worker_user = base_name
+
+        self.workers[worker_id] = websocket
+        self.workers_by_user[worker_user].add(worker_id)
+        self.stats["connected_workers"] = len(self.workers)
 
+        # Register contributor
+        contributor = await self.storage.get_contributor(worker_user)
+        if not contributor:
+            contributor = Contributor(
+                contributor_id=worker_user,
+                name=worker_user,
+                total_captions=0,
+                trust_level=1,
+            )
+            await self.storage.save_contributor(contributor)
+
+        logger.info(f"Worker {worker_id} (user: {worker_user}) is retrieving configuration")
         try:
-            # Send welcome
-            await websocket.send(safe_json_dumps({"type": "welcome", "role": "admin"}))
+            # Send welcome message with processor config
+            filtered_config = self.config.copy()
+            for unwanted_key in ["auth", "orchestrator", "storage"]:
+                filtered_config.pop(unwanted_key, None)
+            welcome_message = {
+                "type": "welcome",
+                "worker_id": worker_id,
+                "user_id": worker_user,
+                "processor_type": self.config.get("dataset", {}).get("processor_type", None),
+                "processor_config": filtered_config,
+            }
+            await websocket.send(safe_json_dumps(welcome_message))
 
             async for message in websocket:
-                try:
-                    data = json.loads(message)
-                    msg_type = data.get("type")
+                data = json.loads(message)
+                await self._process_worker_message(worker_id, data)
 
-                    if msg_type == "reload_config":
-                        await self._handle_config_reload(websocket, data.get("config", {}))
+        except websockets.exceptions.ConnectionClosed:
+            logger.info(f"Worker {worker_id} has disconnected due to websocket connection closure")
+        finally:
+            if worker_id in self.workers:
+                del self.workers[worker_id]
 
-                except json.JSONDecodeError as e:
-                    logger.error(f"Invalid admin message: {e}")
-                    await websocket.send(
-                        safe_json_dumps({"type": "error", "error": "Invalid message format"})
-                    )
+            self.workers_by_user[worker_user].discard(worker_id)
+            if not self.workers_by_user[worker_user]:
+                del self.workers_by_user[worker_user]
 
-        except websockets.exceptions.ConnectionClosed:
-            logger.info(f"Admin {admin_id} disconnected")
+            self.stats["connected_workers"] = len(self.workers)
+
+            # Release assignments
+            self.processor.release_assignments(worker_id)
+            logger.info(f"Worker {worker_id} has safely disconnected")
 
     async def _handle_config_reload(
         self, websocket: WebSocketServerProtocol, new_config: Dict[str, Any]
@@ -814,224 +257,61 @@ class Orchestrator:
 
         updated_sections = []
         warnings = []
-        requires_worker_restart = False
 
         try:
             # Extract orchestrator section if present
             if "orchestrator" in new_config:
-                # Config has orchestrator wrapper, extract it
                 orchestrator_config = new_config["orchestrator"]
             else:
-                # Config is already at orchestrator level
                 orchestrator_config = new_config
 
-            # Helper function for deep comparison
-            def deep_equal(a, b):
-                """Deep comparison of two values including nested dicts and lists."""
-                if type(a) != type(b):
-                    return False
-                if isinstance(a, dict):
-                    if set(a.keys()) != set(b.keys()):
-                        return False
-                    return all(deep_equal(a[k], b[k]) for k in a.keys())
-                elif isinstance(a, (list, tuple)):
-                    if len(a) != len(b):
-                        return False
-                    return all(deep_equal(x, y) for x, y in zip(a, b))
+            # Update processor configuration if present
+            if "processor_type" in orchestrator_config:
+                old_type = self.config.get("processor_type")
+                new_type = orchestrator_config["processor_type"]
+
+                if old_type != new_type:
+                    warnings.append("Processor type changes require orchestrator restart")
+                    updated_sections.append("processor_type")
                 else:
-                    return a == b
-
-            # Update vLLM configuration
-            if "vllm" in orchestrator_config:
-                old_vllm = self.vllm_config.copy()
-                new_vllm = orchestrator_config["vllm"]
-
-                # Check if vLLM config actually changed using deep comparison
-                vllm_changed = not deep_equal(old_vllm, new_vllm)
-
-                if vllm_changed:
-                    # Update the vLLM config
-                    self.vllm_config = new_vllm.copy()
-                    updated_sections.append("vllm")
-
-                    # Check if critical changes require worker restart
-                    if (
-                        old_vllm.get("model") != new_vllm.get("model")
-                        or old_vllm.get("gpu_memory_utilization")
-                        != new_vllm.get("gpu_memory_utilization")
-                        or old_vllm.get("tensor_parallel_size")
-                        != new_vllm.get("tensor_parallel_size")
-                        or old_vllm.get("dtype") != new_vllm.get("dtype")
-                        or old_vllm.get("max_model_len") != new_vllm.get("max_model_len")
-                    ):
-                        requires_worker_restart = True
-                        warnings.append(
-                            "Critical vLLM changes detected - workers will be disconnected to reload"
-                        )
-                        logger.info(
-                            f"Model change: {old_vllm.get('model')} -> {new_vllm.get('model')}"
-                        )
-
-            # Update dataset configuration
-            if "dataset" in orchestrator_config:
-                old_dataset = self.dataset_config.copy()
-                new_dataset = orchestrator_config["dataset"]
-
-                dataset_changed = not deep_equal(old_dataset, new_dataset)
-
-                if dataset_changed:
-                    self.dataset_config = new_dataset.copy()
-                    self.dataset_path = self.dataset_config.get("path")
-                    self.dataset_type = self.dataset_config.get("type", "huggingface")
-                    updated_sections.append("dataset")
-                    warnings.append("Dataset changes will apply to new chunks only")
-
-            # Update chunk settings
-            if (
-                "chunk_size" in orchestrator_config
-                and self.chunk_size != orchestrator_config["chunk_size"]
-            ):
-                self.chunk_size = orchestrator_config["chunk_size"]
-                self.chunk_manager.chunk_size = self.chunk_size
-                updated_sections.append("chunk_size")
-
-            if (
-                "chunks_per_request" in orchestrator_config
-                and self.chunks_per_request != orchestrator_config["chunks_per_request"]
-            ):
-                self.chunks_per_request = orchestrator_config["chunks_per_request"]
-                updated_sections.append("chunks_per_request")
+                    # Update processor config
+                    self.config.update(orchestrator_config)
+
+                    # Reinitialize processor with new config
+                    processor_config = ProcessorConfig(
+                        processor_type=new_type, config=orchestrator_config
+                    )
+                    self.processor.initialize(processor_config)
+                    updated_sections.append("processor_config")
+
+            # Update units per request
+            if "units_per_request" in orchestrator_config:
+                self.units_per_request = orchestrator_config["units_per_request"]
+                updated_sections.append("units_per_request")
 
             # Update auth configuration
             if "auth" in orchestrator_config:
                 try:
-                    self.auth = AuthManager({"auth": orchestrator_config["auth"]})
+                    self.auth = AuthManager(orchestrator_config["auth"])
                     updated_sections.append("auth")
                 except Exception as e:
                     logger.error(f"Failed to update AuthManager: {e}")
                     warnings.append(f"Auth update failed: {e}")
 
-            # Update buffer settings
-            if (
-                "chunk_buffer_multiplier" in orchestrator_config
-                and self.chunk_buffer_multiplier != orchestrator_config["chunk_buffer_multiplier"]
-            ):
-                self.chunk_buffer_multiplier = orchestrator_config["chunk_buffer_multiplier"]
-                updated_sections.append("chunk_buffer_multiplier")
-
-            if (
-                "min_chunk_buffer" in orchestrator_config
-                and self.min_chunk_buffer != orchestrator_config["min_chunk_buffer"]
-            ):
-                self.min_chunk_buffer = orchestrator_config["min_chunk_buffer"]
-                updated_sections.append("min_chunk_buffer")
-
             # Update storage settings
             if "storage" in orchestrator_config:
                 storage_config = orchestrator_config["storage"]
-                storage_changed = False
 
-                if (
-                    "caption_buffer_size" in storage_config
-                    and self.storage.caption_buffer_size != storage_config["caption_buffer_size"]
-                ):
+                if "caption_buffer_size" in storage_config:
                     self.storage.caption_buffer_size = storage_config["caption_buffer_size"]
-                    storage_changed = True
-
-                if "checkpoint_interval" in storage_config:
-                    current_interval = self.config.get("storage", {}).get(
-                        "checkpoint_interval", 1000
-                    )
-                    if current_interval != storage_config["checkpoint_interval"]:
-                        self.config.setdefault("storage", {})["checkpoint_interval"] = (
-                            storage_config["checkpoint_interval"]
-                        )
-                        storage_changed = True
+                    updated_sections.append("storage.caption_buffer_size")
 
-                if storage_changed:
-                    updated_sections.append("storage")
-
-            # Check if any changes were made
-            if not updated_sections:
-                await websocket.send(
-                    safe_json_dumps(
-                        {
-                            "type": "reload_complete",
-                            "message": "No changes applied - configuration is identical",
-                        }
-                    )
-                )
-                logger.info("Configuration reload requested but no changes detected")
-                return
-
-            # Update the main config
+            # Update main config
             if "orchestrator" in new_config:
-                self.config["orchestrator"] = orchestrator_config
+                self.config = new_config["orchestrator"]
             else:
                 self.config.update(orchestrator_config)
 
-            # Handle worker restart if needed
-            if requires_worker_restart:
-                logger.info("Disconnecting all workers for configuration reload...")
-
-                # Send reload message to workers first
-                reload_msg = safe_json_dumps(
-                    {
-                        "type": "reload_vllm",
-                        "vllm_config": self.vllm_config,
-                    }
-                )
-
-                # Create a list of worker items to avoid modifying dict during iteration
-                worker_items = list(self.workers.items())
-                disconnected = []
-
-                for worker_id, ws in worker_items:
-                    try:
-                        await ws.send(reload_msg)
-                        # Give worker time to process before disconnect
-                        await asyncio.sleep(0.5)
-                        await ws.close(code=1012, reason="Configuration reload")
-                        disconnected.append(worker_id)
-                    except:
-                        disconnected.append(worker_id)  # Still mark as disconnected if error
-
-                # Now safely clear workers dict
-                for worker_id in disconnected:
-                    if worker_id in self.workers:
-                        del self.workers[worker_id]
-
-                warnings.append(
-                    f"Sent reload message to {len(disconnected)} workers - they will reconnect with new config"
-                )
-            else:
-                # Just notify workers about config changes without disconnecting
-                config_update_msg = safe_json_dumps(
-                    {
-                        "type": "config_update",
-                        "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
-                        "dataset_config": (
-                            self.dataset_config if "dataset" in updated_sections else None
-                        ),
-                    }
-                )
-
-                # Create a list of worker items to avoid modifying dict during iteration
-                worker_items = list(self.workers.items())
-                disconnected = []
-
-                for worker_id, ws in worker_items:
-                    try:
-                        await ws.send(config_update_msg)
-                        logger.info(f"Sent config update to worker {worker_id}")
-                    except:
-                        disconnected.append(worker_id)
-
-                # Now safely remove disconnected workers
-                for worker_id in disconnected:
-                    if worker_id in self.workers:
-                        del self.workers[worker_id]
-
             # Send success response
             await websocket.send(
                 safe_json_dumps(
@@ -1040,306 +320,169 @@ class Orchestrator:
1040
320
  )
1041
321
 
1042
322
  logger.info(f"Configuration reloaded. Updated sections: {', '.join(updated_sections)}")
1043
-
1044
- # Broadcast stats update to monitors
1045
- await self._broadcast_stats()
1046
323
  await self._send_activity(
1047
324
  f"Configuration reloaded by admin: {', '.join(updated_sections)}"
1048
325
  )
1049
326
 
1050
327
  except Exception as e:
1051
328
  logger.error(f"Configuration reload failed: {e}")
1052
- import traceback
1053
-
1054
- logger.error(traceback.format_exc())
1055
329
  await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))
1056
330
 
1057
- async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
1058
- """Handle worker connection lifecycle."""
1059
- # Generate unique worker ID even if using same token
1060
- base_name = getattr(auth_ticket, "name", "worker")
1061
- worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}" # Add unique suffix
1062
-
1063
- # Track the original token/user for accounting
1064
- worker_user = base_name # Keep track of which user/token this worker belongs to
1065
-
1066
- self.workers[worker_id] = websocket
1067
- self.stats["connected_workers"] = len(self.workers)
1068
-
1069
- # Optionally track workers by user/token
1070
- if not hasattr(self, "workers_by_user"):
1071
- self.workers_by_user = defaultdict(set)
1072
- self.workers_by_user[worker_user].add(worker_id)
1073
-
1074
- # Register contributor with the base name (for aggregating stats per user)
1075
- contributor = await self.storage.get_contributor(worker_user)
1076
- if not contributor:
1077
- contributor = Contributor(
1078
- contributor_id=worker_user,
1079
- name=worker_user,
1080
- total_captions=0,
1081
- trust_level=1,
1082
- )
1083
- await self.storage.save_contributor(contributor)
1084
-
1085
- logger.info(f"Worker {worker_id} (user: {worker_user}) connected")
1086
- await self._broadcast_stats()
1087
- await self._send_activity(f"Worker {worker_id} (user: {worker_user}) connected")
1088
-
1089
- try:
1090
- # Send welcome message with dataset configuration
1091
- welcome_message = {
1092
- "type": "welcome",
1093
- "worker_id": worker_id,
1094
- "user_id": worker_user,
1095
- "dataset_config": {
1096
- "dataset_path": self.dataset_path,
1097
- "dataset_type": self.dataset_type,
1098
- "dataset_split": self.dataset_split,
1099
- "dataset_image_column": self.dataset_image_column,
1100
- "path": self.dataset_path,
1101
- "type": self.dataset_type,
1102
- "split": self.dataset_split,
1103
- "image_column": self.dataset_image_column,
1104
- },
1105
- "vllm_config": self.vllm_config,
1106
- }
1107
- await websocket.send(safe_json_dumps(welcome_message))
1108
-
1109
- async for message in websocket:
1110
- data = json.loads(message)
1111
- await self._process_worker_message(worker_id, data)
1112
-
1113
- except websockets.exceptions.ConnectionClosed:
1114
- logger.info(f"Worker {worker_id} (user: {worker_user}) disconnected")
1115
- finally:
1116
- if worker_id in self.workers:
1117
- del self.workers[worker_id]
1118
-
1119
- # Clean up user tracking
1120
- if hasattr(self, "workers_by_user") and worker_user in self.workers_by_user:
1121
- self.workers_by_user[worker_user].discard(worker_id)
1122
- if not self.workers_by_user[worker_user]:
1123
- del self.workers_by_user[worker_user]
1124
-
1125
- self.stats["connected_workers"] = len(self.workers)
1126
-
1127
- # Release chunks
1128
- self.chunk_manager.release_worker_chunks(worker_id)
1129
- if self.chunk_tracker:
1130
- released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
1131
- logger.info(
1132
- f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
1133
- )
1134
-
1135
- await self._broadcast_stats()
1136
- await self._send_activity(f"Worker {worker_id} (user: {worker_user}) disconnected")
1137
-
1138
331
  async def _process_worker_message(self, worker_id: str, data: Dict):
1139
332
  """Process message from worker."""
1140
333
  msg_type = data.get("type")
1141
334
 
1142
- if msg_type == "request_chunks":
1143
- # Wait for state restoration to complete
1144
- if not self.state_restored.is_set():
1145
- logger.info(f"Worker {worker_id} requesting chunks, but state not yet restored")
1146
- await self.workers[worker_id].send(
1147
- safe_json_dumps({"type": "no_chunks", "reason": "state_restoring"})
335
+ if msg_type == "request_work":
336
+ count = data.get("count", self.units_per_request)
337
+ units = self.processor.get_work_units(count, worker_id)
338
+ logger.debug(f"Assigning units: {[unit.chunk_id for unit in units]}")
339
+
340
+ if units:
341
+ # Create assignment
342
+ assignment = WorkAssignment(
343
+ assignment_id=str(uuid.uuid4()),
344
+ worker_id=worker_id,
345
+ units=units,
346
+ assigned_at=datetime.utcnow(),
1148
347
  )
1149
- return
1150
-
1151
- count = data.get("count", self.chunks_per_request)
1152
- chunk_infos = self.chunk_manager.get_chunks_for_worker(
1153
- worker_id, count, self.chunk_tracker
1154
- )
1155
-
1156
- if chunk_infos:
1157
- # Send chunks with unprocessed ranges
1158
- chunks_data = []
1159
- for info in chunk_infos:
1160
- chunk_dict = info["chunk"].to_dict()
1161
- chunk_dict["unprocessed_ranges"] = info["unprocessed_ranges"]
1162
- chunks_data.append(chunk_dict)
1163
348
 
1164
349
  await self.workers[worker_id].send(
1165
- safe_json_dumps({"type": "shard_assignment", "chunks": chunks_data})
350
+ safe_json_dumps({"type": "work_assignment", "assignment": assignment.to_dict()})
1166
351
  )
1167
352
 
1168
- chunk_ids = [c["chunk_id"] for c in chunks_data]
1169
- logger.info(
1170
- f"Assigned {len(chunks_data)} chunks to worker {worker_id}: {chunk_ids}"
1171
- )
353
+ logger.debug(f"Assigned {len(units)} work units to worker {worker_id}")
1172
354
  else:
1173
- await self.workers[worker_id].send(safe_json_dumps({"type": "no_chunks"}))
355
+ await self.workers[worker_id].send(safe_json_dumps({"type": "no_work"}))
1174
356
 
1175
- elif msg_type == "chunk_complete":
1176
- chunk_id = data["chunk_id"]
1177
- if self.chunk_manager.complete_chunk(chunk_id, worker_id):
1178
- self.stats["completed_chunks"] += 1
357
+ elif msg_type == "work_complete":
358
+ unit_id = data["unit_id"]
359
+ self.processor.mark_completed(unit_id, worker_id)
360
+ logger.debug(f"Work unit {unit_id} completed by worker {worker_id}")
1179
361
 
1180
- if self.chunk_tracker:
1181
- self.chunk_tracker.mark_completed(chunk_id)
1182
-
1183
- logger.info(f"Chunk {chunk_id} completed by worker {worker_id}")
1184
- await self._check_shard_completion(chunk_id)
1185
- await self._send_activity(f"Chunk {chunk_id} completed by {worker_id}")
1186
- elif msg_type == "chunk_failed":
1187
- chunk_id = data["chunk_id"]
362
+ elif msg_type == "work_failed":
363
+ unit_id = data["unit_id"]
1188
364
  error = data.get("error", "Unknown error")
1189
- if self.chunk_manager.fail_chunk(chunk_id, worker_id):
1190
- self.stats["failed_chunks"] += 1
1191
-
1192
- if self.chunk_tracker:
1193
- self.chunk_tracker.mark_failed(chunk_id)
365
+ self.processor.mark_failed(unit_id, worker_id, error)
366
+ logger.warning(f"Work unit {unit_id} failed on worker {worker_id}: {error}")
1194
367
 
1195
- logger.warning(f"Chunk {chunk_id} failed on worker {worker_id}: {error}")
1196
- await self._send_activity(f"Chunk {chunk_id} failed on {worker_id}: {error}")
368
+ elif msg_type == "submit_results":
369
+ await self._handle_results_submission(worker_id, data)
1197
370
 
1198
- elif msg_type == "submit_captions":
1199
- await self._handle_captions_submission(worker_id, data)
1200
- elif msg_type == "request_job":
1201
- # CaptionWorker requesting a job from data samples
1202
- try:
1203
- job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
1204
- await self.workers[worker_id].send(
1205
- json.dumps({"type": "job_assignment", "job": job})
1206
- )
1207
- logger.debug(f"Assigned job {job['job_id']} to worker {worker_id}")
1208
- except asyncio.TimeoutError:
1209
- await self.workers[worker_id].send(json.dumps({"type": "no_jobs"}))
1210
371
  elif msg_type == "heartbeat":
1211
- # Update worker stats
1212
372
  logger.debug(f"Heartbeat from {worker_id}: {data}")
1213
373
 
1214
- async def _handle_captions_submission(self, worker_id: str, data: Dict):
1215
- """Process caption submission from worker - now handles multi-stage outputs."""
1216
- chunk_id = data.get("chunk_id")
1217
- item_key = data["item_key"]
1218
-
1219
- item_index = data.get("item_index") # Worker should send this
1220
- if item_index is None:
1221
- # Try to extract from item_key (format: dataset_XXXXXXXX)
1222
- try:
1223
- item_index = int(item_key.split("_")[-1])
1224
- except:
1225
- logger.warning(f"Could not extract item index from key: {item_key}")
1226
-
1227
- # Extract user from worker_id (format: "username_uuid")
374
+ async def _handle_results_submission(self, worker_id: str, data: Dict):
375
+ """Process results submission from worker."""
376
+ # Extract user from worker_id
1228
377
  worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
1229
378
 
1230
- # Handle both old format (captions list) and new format (outputs dict)
1231
- if "outputs" in data:
1232
- # New multi-stage format
1233
- outputs = data["outputs"]
1234
- captions_list = outputs.get("captions", [])
1235
- total_outputs = sum(len(v) for v in outputs.values())
1236
-
1237
- logger.debug(
1238
- f"Received multi-stage outputs for item {item_key} from worker {worker_id}: "
1239
- f"{total_outputs} outputs across {len(outputs)} fields"
1240
- )
1241
- else:
1242
- # Old format - single captions list
1243
- captions_list = data["captions"]
1244
- outputs = {"captions": captions_list}
1245
- total_outputs = len(captions_list)
1246
-
1247
- logger.debug(
1248
- f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
1249
- )
379
+ # Create work result
380
+ _job_id = data.get("job_id")
381
+ job_id = JobId.from_str(_job_id)
382
+ shard_name = job_id.shard_id # >data-0000<
383
+ chunk_name = job_id.chunk_id # data-0000:chunk:>0<
384
+ # logger.debug(f"({job_id}) Worker result: {data}")
385
+ result = WorkResult(
386
+ unit_id=data["unit_id"],
387
+ source_id=shard_name,
388
+ chunk_id=job_id.get_chunk_str(), # we want the full string here
389
+ sample_id=data["sample_id"],
390
+ dataset=data["dataset"],
391
+ outputs=data["outputs"],
392
+ metadata=data.get("metadata", {}),
393
+ processing_time_ms=data.get("processing_time_ms", 0),
394
+ )
1250
395
 
1251
- # Create caption record with multi-stage outputs
396
+ # Let processor handle any custom processing
397
+ processed = self.processor.handle_result(result)
398
+
399
+ # Create caption record for storage
400
+ total_outputs = sum(len(v) for v in result.outputs.values())
401
+
402
+ filename = result.metadata.pop("_filename", None)
403
+ url = result.metadata.pop("_url", None)
404
+ image_height = result.metadata.pop("image_height", None)
405
+ image_width = result.metadata.pop("image_width", None)
406
+ file_size = result.metadata.pop("file_size", None)
407
+ image_format = result.metadata.pop("image_format", None)
408
+ item_index = result.metadata.pop("item_index", None)
409
+ item_key = result.metadata.pop("item_key", None)
410
+ to_delete_metadata_keys = ["_image_format", "_job_id"]
411
+ for key in to_delete_metadata_keys:
412
+ if key in result.metadata:
413
+ del result.metadata[key]
1252
414
  caption = Caption(
1253
- job_id=f"{chunk_id}_{item_key}",
1254
- dataset=data.get("dataset"),
1255
- shard=data.get("shard"),
415
+ job_id=job_id,
416
+ dataset=result.dataset,
417
+ shard=processed["source_id"],
418
+ chunk_id=chunk_name,
1256
419
  item_key=item_key,
1257
- captions=captions_list,
1258
- outputs=outputs,
420
+ captions=result.outputs.get("captions", []),
421
+ outputs=result.outputs,
1259
422
  contributor_id=worker_user,
1260
423
  timestamp=datetime.utcnow(),
1261
- quality_scores=None,
1262
- # Image metadata
1263
- image_width=data.get("image_width"),
1264
- image_height=data.get("image_height"),
1265
- image_format=data.get("image_format"),
1266
- file_size=data.get("file_size"),
1267
- # Processing metadata
1268
424
  caption_count=total_outputs,
1269
- processing_time_ms=data.get("processing_time_ms"),
1270
- chunk_id=chunk_id,
1271
- metadata=data.get("metadata", {}),
425
+ processing_time_ms=result.processing_time_ms,
426
+ metadata=result.metadata,
427
+ image_height=image_height,
428
+ image_width=image_width,
429
+ filename=filename,
430
+ url=url,
431
+ file_size=file_size,
432
+ image_format=image_format,
1272
433
  )
1273
434
 
-         # Add to central storage buffer
+         # Save to storage
          await self.storage.save_caption(caption)

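Note that `total_outputs` counts every generated string across all output fields, not just captions, so the contributor and rate statistics below track multi-stage output volume. For example:

    # Two captions plus one tag entry -> three outputs in total.
    outputs = {"captions": ["a cat on a mat", "a tabby cat"], "tags": ["cat"]}
    total_outputs = sum(len(v) for v in outputs.values())  # 2 + 1 = 3
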
-         # Handle item tracking with fixed deadlock
-         should_flush = False
-         if chunk_id and item_index is not None and self.chunk_tracker:
-             with self.item_batch_lock:
-                 self.pending_processed_items[chunk_id].append(item_index)
-
-                 # Check if we should flush
-                 total_pending = sum(
-                     len(indices) for indices in self.pending_processed_items.values()
-                 )
-                 time_since_flush = time.time() - self.last_item_batch_flush
-
-                 if (
-                     total_pending >= self.item_batch_size
-                     or time_since_flush >= self.item_batch_interval
-                 ):
-                     should_flush = True
-
-         if should_flush:
-             await self._flush_processed_items()
-
-         # Update contributor stats (use user, not worker)
+         # Update contributor stats
          contributor = await self.storage.get_contributor(worker_user)
          if contributor:
              contributor.total_captions += total_outputs
              await self.storage.save_contributor(contributor)

-         # Broadcast updated stats
-         await self._broadcast_stats()
-
-         # Log progress periodically
-         total_outputs = self.stats.get("total_outputs", 0)
-         if total_outputs > 0 and total_outputs % 100 == 0:
-             if (
-                 not hasattr(self, "_last_logged_outputs")
-                 or self._last_logged_outputs != total_outputs
-             ):
-                 logger.info(f"Collected {total_outputs} outputs centrally")
-                 self._last_logged_outputs = total_outputs
-
-     async def _check_shard_completion(self, chunk_id: str):
-         """Check if a shard is complete after chunk completion."""
-         # Get the chunk
-         chunk = self.chunk_manager.chunks.get(chunk_id)
-         if not chunk:
-             return
+     async def _handle_monitor(self, websocket: WebSocketServerProtocol):
+         """Handle monitor connection."""
+         self.monitors.add(websocket)
+         logger.info(f"Monitor connected (total: {len(self.monitors)})")

-         shard_name = chunk.shard_name
+         try:
+             # Send welcome
+             await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))

-         # Find all chunks for this shard
-         shard_chunks = [
-             cid for cid, c in self.chunk_manager.chunks.items() if c.belongs_to_shard(shard_name)
-         ]
+             # Send initial stats
+             await self._send_monitor_stats(websocket)

-         # Check if all are completed
-         completed_chunks = [
-             cid for cid in shard_chunks if self.chunk_manager.chunks[cid].status == "completed"
-         ]
+             # Keep connection alive
+             async for message in websocket:
+                 pass

-         if len(completed_chunks) == len(shard_chunks) and len(shard_chunks) > 0:
-             logger.info(f"Shard {shard_name} complete!")
-             # Don't mark virtual shards as complete in ShardTracker
-             if not shard_name.startswith("hf_dataset:"):
-                 self.shard_tracker.mark_complete(shard_name)
-             self.stats["completed_shards"] += 1
-             await self._send_activity(f"Shard {shard_name} completed!")
+         except websockets.exceptions.ConnectionClosed:
+             logger.info("Monitor disconnected")
+         finally:
+             self.monitors.discard(websocket)
+
+     async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
+         """Handle admin connection."""
+         admin_id = getattr(auth_ticket, "name", "admin")
+         logger.info(f"Admin {admin_id} connected")
+
+         try:
+             await websocket.send(safe_json_dumps({"type": "welcome", "role": "admin"}))
+
+             async for message in websocket:
+                 try:
+                     data = json.loads(message)
+                     if data.get("type") == "reload_config":
+                         await self._handle_config_reload(websocket, data.get("config", {}))
+                     elif data.get("type") == "get_stats":
+                         await self._send_monitor_stats(websocket)
+
+                 except json.JSONDecodeError as e:
+                     logger.error(f"Invalid admin message: {e}")
+
+         except websockets.exceptions.ConnectionClosed:
+             logger.info(f"Admin {admin_id} disconnected")

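The admin protocol above is plain JSON over the same websocket endpoint: a `reload_config` message carries the new config object, and `get_stats` requests a stats snapshot. A minimal client sketch; the URL and config contents are placeholders, and the authentication handshake happens elsewhere in the orchestrator:

    import asyncio
    import json

    import websockets

    async def reload_config(url: str, config: dict) -> None:
        # url and config are hypothetical; only the message shapes
        # come from the admin handler above.
        async with websockets.connect(url) as ws:
            print(await ws.recv())  # {"type": "welcome", "role": "admin"}
            await ws.send(json.dumps({"type": "reload_config", "config": config}))
            await ws.send(json.dumps({"type": "get_stats"}))
            print(await ws.recv())  # stats snapshot

    asyncio.run(reload_config("wss://localhost:8765", {}))
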
      async def _handle_data_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
          """Handle data worker connection."""
@@ -1410,7 +553,64 @@ class Orchestrator:
          finally:
              del self.data_workers[worker_id]

-     async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+     async def _send_monitor_initial_data(self, websocket: WebSocketServerProtocol):
+         """Send initial data to monitor in a separate task to avoid blocking."""
+         total_start = time.time()
+         try:
+             # Check if websocket is still in monitors set
+             if websocket not in self.monitors:
+                 logger.debug("Monitor disconnected before initial data send")
+                 return
+
+             # Send current stats (already in memory)
+             stats_start = time.time()
+             await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
+             logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
+
+             # Get processor stats instead of chunk stats
+             processor_stats_start = time.time()
+             processor_stats = self.processor.get_stats()
+             logger.debug(
+                 f"Processor stats retrieved in {(time.time() - processor_stats_start)*1000:.1f}ms"
+             )
+
+             stats_send_start = time.time()
+             await websocket.send(
+                 safe_json_dumps({"type": "processor_stats", "data": processor_stats})
+             )
+             logger.debug(f"Processor stats sent in {(time.time() - stats_send_start)*1000:.1f}ms")
+
+             if websocket not in self.monitors:
+                 return
+
+             # For leaderboard, check if we have a cached version first
+             if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
+                 # Use cached leaderboard if available
+                 cache_send_start = time.time()
+                 await websocket.send(
+                     safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
+                 )
+                 logger.debug(
+                     f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
+                 )
+             else:
+                 # Schedule leaderboard update separately
+                 leaderboard_task_start = time.time()
+                 asyncio.create_task(self._send_monitor_leaderboard(websocket))
+                 logger.debug(
+                     f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
+                 )
+
+             logger.debug(
+                 f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
+             )
+
+         except websockets.exceptions.ConnectionClosed:
+             logger.debug("Monitor disconnected during initial data send")
+         except Exception as e:
+             logger.error(f"Error sending initial monitor data: {e}")
+
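The pattern here is worth noting: cheap in-memory data is sent inline, while the potentially slow leaderboard rebuild is either served from cache or handed to a fire-and-forget task, so the handler never blocks the connection. The same shape in isolation, with hypothetical stand-ins for the slow path:

    import asyncio

    async def send_initial(ws, cached_leaderboard=None):
        await ws.send("stats")  # cheap: already in memory, send inline
        if cached_leaderboard is not None:
            await ws.send(cached_leaderboard)  # cached: still cheap
        else:
            # slow rebuild: fire-and-forget so this handler returns fast
            asyncio.create_task(rebuild_and_send(ws))

    async def rebuild_and_send(ws):
        # hypothetical slow path standing in for a storage scan
        await asyncio.sleep(1)
        await ws.send("leaderboard")
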
+     async def _send_monitor_leaderboard(self, websocket: WebSocketServerProtocol):
          """Send leaderboard data to a specific monitor."""
          total_start = time.time()
          try:
@@ -1475,210 +675,43 @@ class Orchestrator:
          except Exception as e:
              logger.error(f"Error sending leaderboard to monitor: {e}")

-     async def _send_initial_monitor_data(self, websocket: WebSocketServerProtocol):
-         """Send initial data to monitor in a separate task to avoid blocking."""
-         total_start = time.time()
-         try:
-             # Check if websocket is still in monitors set
-             if websocket not in self.monitors:
-                 logger.debug("Monitor disconnected before initial data send")
-                 return
-
-             # Send current stats (already in memory)
-             stats_start = time.time()
-             await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
-             logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
-
-             # Get chunk stats asynchronously
-             chunk_stats_start = time.time()
-             loop = asyncio.get_event_loop()
-             chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
-             logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
-
-             if websocket not in self.monitors:
-                 return
-
-             chunk_send_start = time.time()
-             await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
-             logger.debug(f"Chunk stats sent in {(time.time() - chunk_send_start)*1000:.1f}ms")
-
-             # For leaderboard, check if we have a cached version first
-             if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
-                 # Use cached leaderboard if available
-                 cache_send_start = time.time()
-                 await websocket.send(
-                     safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
-                 )
-                 logger.debug(
-                     f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
-                 )
-             else:
-                 # Schedule leaderboard update separately
-                 leaderboard_task_start = time.time()
-                 asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
-                 logger.debug(
-                     f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
-                 )
-
-             logger.debug(
-                 f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
-             )
-
-         except websockets.exceptions.ConnectionClosed:
-             logger.debug("Monitor disconnected during initial data send")
-         except Exception as e:
-             logger.error(f"Error sending initial monitor data: {e}")
+     async def _send_monitor_stats(self, websocket: WebSocketServerProtocol):
+         """Send current stats to a monitor."""
+         # Get processor stats
+         processor_stats = self.processor.get_stats()

-     async def _handle_monitor(self, websocket: WebSocketServerProtocol):
-         """Handle monitor connection - truly non-blocking version."""
-         monitor_start = time.time()
-         self.monitors.add(websocket)
-         logger.info(f"Monitor connected (total monitors: {len(self.monitors)})")
-
-         try:
-             # Send welcome message immediately
-             welcome_start = time.time()
-             await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
-             logger.debug(f"Monitor welcome sent in {(time.time() - welcome_start)*1000:.1f}ms")
+         # Get storage stats
+         storage_stats = await self.storage.get_storage_stats()

-             # Schedule initial data send as a separate task to avoid blocking
-             task_create_start = time.time()
-             asyncio.create_task(self._send_initial_monitor_data(websocket))
-             logger.debug(
-                 f"Monitor initial data task created in {(time.time() - task_create_start)*1000:.1f}ms"
-             )
+         # Combine all stats
+         all_stats = {
+             **self.stats,
+             **storage_stats,
+             "processor_stats": processor_stats,
+             "current_rate": self.rate_tracker["current_rate"],
+             "average_rate": self.rate_tracker["average_rate"],
+         }

-             # Just keep the connection alive - no blocking work here
-             try:
-                 async for message in websocket:
-                     # Handle any incoming messages from monitor if needed
-                     # For now, just ignore them
-                     pass
-             except websockets.exceptions.ConnectionClosed:
-                 pass  # Normal disconnection
+         await websocket.send(safe_json_dumps({"type": "stats", "data": all_stats}))

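`all_stats` is assembled with dict unpacking, so later sources win on key collisions: storage-derived counters override any in-memory value of the same name, and the explicit rate keys override both. A quick illustration:

    base = {"total_outputs": 10, "connected_workers": 2}
    storage = {"total_outputs": 12}  # authoritative count from storage
    merged = {**base, **storage, "current_rate": 60.0}
    # {'total_outputs': 12, 'connected_workers': 2, 'current_rate': 60.0}
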
-         except websockets.exceptions.ConnectionClosed:
-             logger.info("Monitor disconnected")
-         except Exception as e:
-             logger.error(f"Error in monitor handler: {e}")
-         finally:
-             self.monitors.discard(websocket)
-             logger.debug(f"Monitor handler completed in {(time.time() - monitor_start)*1000:.1f}ms")
-
-     async def _broadcast_stats(self):
-         """Broadcast statistics to all monitors - truly non-blocking version."""
+     async def _send_activity(self, activity: str):
+         """Send activity update to monitors."""
          if not self.monitors:
              return
-         if self.is_generating_stats:
-             return  # Already generating stats, skip this call
-         self.is_generating_stats = True
-         total_start = time.time()
-
-         # Prepare all the data first
-         data_prep_start = time.time()
-         loop = asyncio.get_event_loop()

-         # Get storage stats (already async)
-         storage_stats_start = time.time()
-         storage_stats = await self.storage.get_storage_stats()
-         logger.debug(f"Storage stats retrieved in {(time.time() - storage_stats_start)*1000:.1f}ms")
-
-         caption_stats_start = time.time()
-         caption_stats = await self.storage.get_caption_stats()
-         logger.debug(f"Caption stats retrieved in {(time.time() - caption_stats_start)*1000:.1f}ms")
-
-         # Get chunk stats in thread pool
-         chunk_stats_start = time.time()
-         chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
-         logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
-
-         # Build stats dict
-         build_stats_start = time.time()
-         stats_update = self.stats.copy()
-         stats_update.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
-         stats_update.update(storage_stats)
-         stats_update["field_breakdown"] = caption_stats.get("field_stats", {})
-         stats_update["output_fields_list"] = caption_stats.get("output_fields", [])
-
-         # Add rate information
-         stats_update.update(
-             {
-                 "current_rate": self.rate_tracker["current_rate"],
-                 "average_rate": self.rate_tracker["average_rate"],
-                 "expected_rate": self.rate_tracker["expected_rate"],
-             }
+         message = safe_json_dumps(
+             {"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
          )

-         # Add vLLM info
-         stats_update["vllm_model"] = self.vllm_config.get("model", "unknown")
-         stats_update["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
-
-         # Add stage information
-         stages = self.vllm_config.get("stages", [])
-         if stages:
-             stats_update["stage_count"] = len(stages)
-             stats_update["stage_names"] = [s.get("name", "unnamed") for s in stages]
-         else:
-             stats_update["stage_count"] = 1
-             stats_update["stage_names"] = ["default"]
-
-         # Get field stats
-         field_stats_start = time.time()
-         field_stats = await self.storage.get_output_field_stats()
-         stats_update["output_fields"] = field_stats
-         logger.debug(f"Field stats retrieved in {(time.time() - field_stats_start)*1000:.1f}ms")
-
-         # Update our internal stats
-         self.stats = stats_update
-         logger.debug(f"Stats prepared in {(time.time() - build_stats_start)*1000:.1f}ms")
-
-         logger.debug(f"Total data preparation took {(time.time() - data_prep_start)*1000:.1f}ms")
-
-         # Create message once
-         message_create_start = time.time()
-         stats_message = safe_json_dumps({"type": "stats", "data": self.stats})
-         logger.debug(f"Stats message created in {(time.time() - message_create_start)*1000:.1f}ms")
-
-         # Send to all monitors asynchronously in parallel
-         send_start = time.time()
-
-         async def send_to_monitor(monitor):
+         disconnected = set()
+         for monitor in self.monitors:
              try:
-                 await monitor.send(stats_message)
+                 await monitor.send(message)
              except websockets.exceptions.ConnectionClosed:
-                 return monitor  # Return for removal
-             except Exception as e:
-                 logger.debug(f"Error sending stats to monitor: {e}")
-                 return monitor  # Return for removal
-             return None
-
-         # Send to all monitors in parallel
-         monitors_copy = self.monitors.copy()
-         results = await asyncio.gather(
-             *[send_to_monitor(m) for m in monitors_copy], return_exceptions=True
-         )
+                 disconnected.add(monitor)

-         # Remove disconnected monitors
-         disconnected = {
-             m
-             for m, r in zip(monitors_copy, results)
-             if r is not None and not isinstance(r, Exception)
-         }
          self.monitors -= disconnected

-         logger.debug(
-             f"Stats sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
-         )
-
-         # Send leaderboard update in a separate task to avoid blocking
-         leaderboard_task_start = time.time()
-         asyncio.create_task(self._broadcast_leaderboard())
-         self.is_generating_stats = False
-         logger.debug(
-             f"Leaderboard broadcast task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
-         )
-         logger.debug(f"Stats broadcast completed in {(time.time() - total_start)*1000:.1f}ms")
-
      async def _broadcast_leaderboard(self):
          """Send leaderboard updates to monitors - separate from stats to avoid blocking."""
          if not self.monitors:
@@ -1769,82 +802,38 @@ class Orchestrator:
          except Exception as e:
              logger.error(f"Error broadcasting leaderboard: {e}")

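An activity update on the wire is just a timestamped string wrapped in the standard message envelope; the worker name below is hypothetical:

    # Produced by _send_activity("Worker alice_0 connected"):
    {"type": "activity", "data": "[14:32:07] Worker alice_0 connected"}
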
-     def _get_queue_stats(self) -> Dict[str, int]:
-         """Get queue statistics - synchronous helper for thread pool."""
-         with self.chunk_manager.lock:
-             return {
-                 "pending_chunks": len(self.chunk_manager.pending_chunks),
-                 "assigned_chunks": sum(
-                     len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
-                 ),
-             }
-
-     async def _flush_processed_items(self):
-         """Flush batched processed items to chunk tracker."""
-         with self.item_batch_lock:
-             if not self.pending_processed_items:
-                 return
-
-             for chunk_id, indices in self.pending_processed_items.items():
-                 if not indices:
-                     continue
-
-                 # Indices here are ABSOLUTE dataset indices
-                 # Sort indices
-                 indices.sort()
-
-                 # Group consecutive indices into ranges
-                 ranges = []
-                 start = indices[0]
-                 end = indices[0]
-
-                 for i in range(1, len(indices)):
-                     if indices[i] == end + 1:
-                         # Consecutive, extend range
-                         end = indices[i]
-                     else:
-                         # Gap found, save current range and start new one
-                         ranges.append((start, end))
-                         start = indices[i]
-                         end = indices[i]
-
-                 # Don't forget the last range
-                 ranges.append((start, end))
-
-                 # Mark ranges as processed (mark_items_processed expects absolute indices)
-                 for start_idx, end_idx in ranges:
-                     self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
-
-             # Clear pending items
-             self.pending_processed_items.clear()
-             self.last_item_batch_flush = time.time()
-
-     def get_workers_by_user_stats(self) -> Dict[str, Any]:
-         """Get statistics about workers grouped by user/token - thread-safe version."""
-         if not hasattr(self, "workers_by_user"):
-             return {}
-
-         # Create a copy to avoid issues with concurrent modification
-         stats = {}
-         workers_snapshot = dict(self.workers_by_user)
-         for user, worker_ids in workers_snapshot.items():
-             stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
-         return stats
-
-     async def _send_activity(self, activity: str):
-         """Send activity update to monitors."""
+     async def _broadcast_stats(self):
+         """Broadcast statistics to all monitors."""
          if not self.monitors:
              return

-         message = safe_json_dumps(
-             {"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
+         # Get current stats
+         processor_stats = self.processor.get_stats()
+         storage_stats = await self.storage.get_storage_stats()
+
+         # Update main stats
+         self.stats["processor_stats"] = processor_stats
+         self.stats["total_outputs"] = storage_stats["total_captions"]
+
+         # Create message
+         stats_message = safe_json_dumps(
+             {
+                 "type": "stats",
+                 "data": {
+                     **self.stats,
+                     **storage_stats,
+                     "current_rate": self.rate_tracker["current_rate"],
+                     "average_rate": self.rate_tracker["average_rate"],
+                 },
+             }
          )

+         # Send to all monitors
          disconnected = set()
          for monitor in self.monitors:
              try:
-                 await monitor.send(message)
-             except websockets.exceptions.ConnectionClosed:
+                 await monitor.send(stats_message)
+             except Exception:
                  disconnected.add(monitor)

          self.monitors -= disconnected
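From the monitor's side, the protocol above amounts to connecting and consuming a handful of message types (`welcome`, `stats`, `activity`, plus `leaderboard` and `processor_stats`). A minimal sketch, with the URL as a placeholder:

    import asyncio
    import json

    import websockets

    async def watch(url: str) -> None:
        async with websockets.connect(url) as ws:
            async for raw in ws:
                msg = json.loads(raw)
                if msg["type"] == "stats":
                    data = msg["data"]
                    print(f"{data.get('total_outputs', 0)} outputs, "
                          f"{data.get('current_rate', 0):.1f}/min")
                elif msg["type"] == "activity":
                    print(msg["data"])

    asyncio.run(watch("wss://localhost:8765"))
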
@@ -1852,235 +841,74 @@
      async def _heartbeat_loop(self):
          """Send periodic heartbeats to maintain connections."""
          while True:
-             try:
-                 await asyncio.sleep(30)
-
-                 # Create a copy of worker items to avoid modification during iteration
-                 worker_items = list(self.workers.items())
-                 disconnected = []
-
-                 for worker_id, ws in worker_items:
-                     try:
-                         # Check if worker still exists before pinging
-                         if worker_id not in self.workers:
-                             continue
-
-                         # Send ping with timeout
-                         pong_waiter = await ws.ping()
-                         try:
-                             await asyncio.wait_for(pong_waiter, timeout=10)
-                         except asyncio.TimeoutError:
-                             logger.warning(f"Worker {worker_id} failed to respond to ping")
-                             disconnected.append(worker_id)
-                     except websockets.exceptions.ConnectionClosed:
-                         logger.info(f"Worker {worker_id} connection already closed")
-                         disconnected.append(worker_id)
-                     except Exception as e:
-                         logger.error(f"Error pinging worker {worker_id}: {e}")
-                         disconnected.append(worker_id)
-
-                 # Clean up disconnected workers
-                 for worker_id in disconnected:
-                     if worker_id in self.workers:
-                         logger.info(f"Removing unresponsive worker {worker_id}")
-                         del self.workers[worker_id]
-                         self.chunk_manager.release_worker_chunks(worker_id)
-
-                         # Update stats
-                         self.stats["connected_workers"] = len(self.workers)
-
-                         # Also clean up from workers_by_user if it exists
-                         if hasattr(self, "workers_by_user"):
-                             worker_user = (
-                                 worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
-                             )
-                             if worker_user in self.workers_by_user:
-                                 self.workers_by_user[worker_user].discard(worker_id)
-                                 if not self.workers_by_user[worker_user]:
-                                     del self.workers_by_user[worker_user]
-
-                         # Notify monitors
-                         await self._broadcast_stats()
-                         await self._send_activity(
-                             f"Worker {worker_id} removed due to heartbeat timeout"
-                         )
-
-             except Exception as e:
-                 logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
-                 # Continue the loop even if there's an error
-                 await asyncio.sleep(5)
+             await asyncio.sleep(30)
+
+             disconnected = []
+             for worker_id, ws in list(self.workers.items()):
+                 try:
+                     pong_waiter = await ws.ping()
+                     await asyncio.wait_for(pong_waiter, timeout=10)
+                 except Exception:
+                     disconnected.append(worker_id)
+
+             # Clean up disconnected workers
+             for worker_id in disconnected:
+                 logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
+                 if worker_id in self.workers:
+                     del self.workers[worker_id]
+                 logger.warning(
+                     f"Releasing assignments for worker {worker_id} because it did not respond to ping"
+                 )
+                 self.processor.release_assignments(worker_id)
+             self.stats["connected_workers"] = len(self.workers)

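The liveness check relies on `websockets`' `ping()`, which returns a waiter that resolves when the matching pong arrives; wrapping it in `wait_for` turns a silent peer into a timeout. The pattern in isolation:

    import asyncio

    import websockets

    async def is_alive(ws, timeout: float = 10.0) -> bool:
        """Return True if the peer answers a ping within `timeout` seconds."""
        try:
            pong_waiter = await ws.ping()
            await asyncio.wait_for(pong_waiter, timeout=timeout)
            return True
        except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
            return False
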
      async def _checkpoint_loop(self):
          """Periodically checkpoint storage."""
-         interval = self.config.get("storage", {}).get("checkpoint_interval", 1000)
+         interval = self.config.get("storage", {}).get("checkpoint_interval", 60)

          while True:
-             await asyncio.sleep(60)
+             await asyncio.sleep(interval)

-             # Get current caption count from storage
-             storage_stats = await self.storage.get_storage_stats()
-             total_captions = storage_stats["total_captions"]
-
-             # Force checkpoint at regular intervals
-             if total_captions > 0 and total_captions % interval == 0:
-                 logger.info(f"Triggering checkpoint at {total_captions} captions")
-                 await self.storage.checkpoint()
-
-                 # Update stats
-                 self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
-                 # No need to update total_written or buffer_size - they come from storage
-
-                 await self._broadcast_stats()
-                 logger.info(
-                     f"Checkpoint complete. Total written to disk: {storage_stats['total_written']}"
-                 )
+             await self.storage.checkpoint()
+             self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
+             logger.info("Storage checkpoint complete")

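Note the semantic change: `storage.checkpoint_interval` used to be a caption-count threshold (default 1000) checked once a minute, and is now simply the sleep between unconditional checkpoints, in seconds (default 60). A reader sketch under that assumption, with a hypothetical value:

    # checkpoint_interval is now seconds between checkpoints, not a caption count.
    config = {"storage": {"checkpoint_interval": 120}}
    interval = config.get("storage", {}).get("checkpoint_interval", 60)  # -> 120
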
      async def _stats_update_loop(self):
-         """Periodically update and broadcast stats - non-blocking version."""
-         # Get the event loop for running blocking operations
-         loop = asyncio.get_event_loop()
-
-         # Track session start values
-         storage_stats = await self.storage.get_storage_stats()
-         session_start_outputs = storage_stats["total_captions"]  # This now counts ALL outputs
-         session_start_time = time.time()
-
-         # Track the last known total to detect flushes
-         last_known_total = session_start_outputs
-
+         """Periodically update and broadcast stats."""
          while True:
              await asyncio.sleep(10)

-             # Update chunk stats in thread pool to avoid blocking
-             chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+             # Update rate tracking
              storage_stats = await self.storage.get_storage_stats()
-             current_total_outputs = storage_stats["total_captions"]  # ALL outputs
-             if self.chunk_tracker:
-                 await self._flush_processed_items()
-
-             self.stats["total_chunks"] = chunk_stats["total"]
-             self.stats["completed_chunks"] = chunk_stats["completed"]
-             self.stats["failed_chunks"] = chunk_stats["failed"]
-
-             # Update total outputs stat (rename from total_captions for clarity)
-             self.stats["total_outputs"] = current_total_outputs
-             self.stats["total_captions"] = current_total_outputs  # Keep for backward compatibility
-
-             # Get queue stats in thread pool to avoid blocking
-             queue_stats = await loop.run_in_executor(None, self._get_queue_stats)
-             self.stats.update(queue_stats)
-
-             # Calculate if we need more chunks
-             worker_count = self.stats.get("connected_workers", 0)
-             target_buffer = max(self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier)
-             active_chunks = self.stats["pending_chunks"] + self.stats["assigned_chunks"]
-             self.stats["chunk_buffer_status"] = f"{active_chunks}/{target_buffer}"
-
-             # Update rate information
+             current_total = storage_stats["total_captions"]
              current_time = time.time()
-             elapsed_since_update = current_time - self.rate_tracker["last_update_time"]
-
-             if elapsed_since_update > 0:
-                 # FIX: Handle the case where duplicates were skipped during save
-                 # If current total is less than last known, it means duplicates were skipped
-                 # We should not count this as negative progress
-                 if current_total_outputs < last_known_total:
-                     logger.debug(
-                         f"Detected duplicate skip during save: {last_known_total} -> {current_total_outputs}"
-                     )
-                     # Don't calculate negative rate, just update the baseline
-                     self.rate_tracker["last_caption_count"] = current_total_outputs
-                     self.rate_tracker["current_rate"] = 0.0  # Set to 0 during flush
-                 else:
-                     # Normal rate calculation
-                     output_diff = current_total_outputs - self.rate_tracker["last_caption_count"]
-                     self.rate_tracker["current_rate"] = (output_diff / elapsed_since_update) * 60
-                     self.rate_tracker["last_caption_count"] = current_total_outputs
-
-                 # Calculate average rate since THIS SESSION started
-                 session_elapsed = current_time - session_start_time
-                 if session_elapsed > 0:
-                     # Always use the difference from session start for average
-                     session_outputs = current_total_outputs - session_start_outputs
-                     self.rate_tracker["average_rate"] = (session_outputs / session_elapsed) * 60
-
-                 # Calculate expected rate based on workers and stages
-                 batch_size = self.vllm_config.get("batch_size", 8)
-
-                 # Count total prompts across all stages
-                 total_prompts = 0
-                 stages = self.vllm_config.get("stages", [])
-                 if stages:
-                     for stage in stages:
-                         total_prompts += len(stage.get("prompts", []))
-                 else:
-                     # Backward compatibility
-                     total_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))

-                 images_per_minute = 30  # Rough estimate: 30 images/min per worker
-                 self.rate_tracker["expected_rate"] = (
-                     worker_count * images_per_minute * total_prompts
-                 )
-
-                 # Update trackers
+             elapsed = current_time - self.rate_tracker["last_update_time"]
+             if elapsed > 0:
+                 output_diff = current_total - self.rate_tracker["last_output_count"]
+                 self.rate_tracker["current_rate"] = (output_diff / elapsed) * 60
+                 self.rate_tracker["last_output_count"] = current_total
              self.rate_tracker["last_update_time"] = current_time
-             last_known_total = current_total_outputs
-
-             # Log rate information when workers are connected
-             if (
-                 worker_count > 0 and self.rate_tracker["current_rate"] >= 0
-             ):  # Only log non-negative rates
-                 logger.info(
-                     f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
-                     f"(avg: {self.rate_tracker['average_rate']:.1f}, "
-                     f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
-                     f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
-                 )

-             await self._broadcast_stats()
+             # Average rate since start
+             total_elapsed = current_time - self.rate_tracker["start_time"]
+             if total_elapsed > 0:
+                 self.rate_tracker["average_rate"] = (current_total / total_elapsed) * 60

-     async def _restore_state(self):
-         """Restore state from storage on startup."""
-         total_captions = await self.storage.count_captions()
-         logger.info(f"Restored state: {total_captions} captions")
+             await self._broadcast_stats()

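The rate math is outputs-per-minute computed from interval deltas: with the 10 s sleep, 25 new outputs in a window gives (25 / 10) * 60 = 150/min, while the average divides the lifetime total by the elapsed time since `start_time`. As a worked check:

    elapsed = 10.0    # seconds since the last update
    output_diff = 25  # outputs written during that window
    current_rate = (output_diff / elapsed) * 60  # 150.0 outputs/min

    current_total, total_elapsed = 9_000, 3_600.0        # lifetime totals
    average_rate = (current_total / total_elapsed) * 60  # 150.0 outputs/min
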
      async def shutdown(self):
          """Graceful shutdown."""
          logger.info("Shutting down orchestrator...")

-         # Stop chunk creation
-         if self.chunk_tracker:
-             await self._flush_processed_items()
-         self.stop_chunk_creation.set()
-         if self.chunk_creation_thread:
-             self.chunk_creation_thread.join(timeout=5)
-
-         # Release all assigned chunks before closing connections
-         for worker_id in list(self.workers.keys()):
-             self.chunk_manager.release_worker_chunks(worker_id)
-             if self.chunk_tracker:
-                 # Update chunk tracker to mark assigned chunks as pending
-                 with self.chunk_manager.lock:
-                     for chunk_id in list(self.chunk_manager.assigned_chunks.get(worker_id, [])):
-                         self.chunk_tracker.mark_pending(chunk_id)
-
          # Close all connections
          for ws in list(self.workers.values()):
              await ws.close()
          for ws in list(self.monitors):
              await ws.close()

-         # Save chunk state
-         if self.chunk_tracker:
-             self.chunk_tracker.save()
-
          # Final checkpoint
-         logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")
          await self.storage.checkpoint()
-
-         # Log final statistics
-         logger.info(
-             f"Shutdown complete. Total captions collected: {self.storage.total_captions_written}"
-         )
-
          await self.storage.close()
+
+         logger.info("Shutdown complete")