caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +65 -42
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +13 -3
- caption_flow/orchestrator.py +1049 -264
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +567 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.1.dist-info/METADATA +370 -0
- caption_flow-0.2.1.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt +0 -0
caption_flow/orchestrator.py
CHANGED
@@ -21,15 +21,15 @@ from collections import deque, defaultdict
 import threading
 from queue import Queue, Empty

+from .workers import data
 import websockets
 from websockets.server import WebSocketServerProtocol

 from .storage import StorageManager
 from .models import Caption, Contributor
 from .utils.auth import AuthManager
-from .utils
+from .utils import DatasetLoader, ShardTracker, ChunkTracker
 from .utils.json_utils import safe_dict, safe_json_dumps, to_json_dict
-from .utils.chunk_tracker import ChunkTracker

 logger = logging.getLogger(__name__)

@@ -48,6 +48,43 @@ class ShardChunk:
     assigned_at: Optional[datetime] = None
     completed_at: Optional[datetime] = None

+    @classmethod
+    def create(
+        cls, shard_url: str, shard_name: str, start_index: int, chunk_size: int
+    ) -> "ShardChunk":
+        """Factory method to create a chunk with consistent ID."""
+        # Always use consistent format: dataset_chunk_startindex
+        if shard_url.startswith("hf_dataset:"):
+            # Extract dataset path
+            parts = shard_url.split(":")
+            dataset_path = parts[1] if len(parts) > 1 else "unknown"
+            chunk_id = f"{dataset_path.replace('/', '_')}_chunk_{start_index}"
+        else:
+            # WebDataset format
+            chunk_id = f"{shard_name}_chunk_{start_index}"
+
+        return cls(
+            chunk_id=chunk_id,
+            shard_url=shard_url,
+            shard_name=shard_name,
+            start_index=start_index,
+            chunk_size=chunk_size,
+        )
+
+    def belongs_to_shard(self, shard_identifier: str) -> bool:
+        """Check if this chunk belongs to a given shard."""
+        return self.shard_name == shard_identifier
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dict for JSON serialization (for workers)."""
+        return {
+            "chunk_id": self.chunk_id,
+            "shard_url": self.shard_url,
+            "shard_name": self.shard_name,
+            "start_index": self.start_index,
+            "chunk_size": self.chunk_size,
+        }
+

 class ChunkManager:
     """Manages shard chunk creation and assignment."""
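The new ShardChunk.create factory above derives the chunk ID from the shard URL format. A minimal usage sketch (not part of the diff; the URLs and values are hypothetical, the import path is assumed from the module shown):

# Hypothetical usage of the ShardChunk.create factory shown in the hunk above.
from caption_flow.orchestrator import ShardChunk  # assumed import path

# WebDataset shard: the ID is "<shard_name>_chunk_<start_index>"
wds_chunk = ShardChunk.create(
    shard_url="https://example.com/data/shard-00042.tar",  # hypothetical URL
    shard_name="shard-00042",
    start_index=10000,
    chunk_size=1000,
)
print(wds_chunk.chunk_id)  # shard-00042_chunk_10000

# Virtual HuggingFace shard: the ID is the dataset path with "/" -> "_" plus the start index
hf_chunk = ShardChunk.create(
    shard_url="hf_dataset:org/dataset:chunk:20000",  # hypothetical virtual shard URL
    shard_name="hf_dataset:org/dataset:chunk:20000",
    start_index=20000,
    chunk_size=1000,
)
print(hf_chunk.chunk_id)  # org_dataset_chunk_20000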
@@ -67,9 +104,7 @@ class ChunkManager:
         chunks = []

         for start_idx in range(0, total_items, self.chunk_size):
-
-            chunk = ShardChunk(
-                chunk_id=chunk_id,
+            chunk = ShardChunk.create(
                 shard_url=shard_url,
                 shard_name=shard_name,
                 start_index=start_idx,
@@ -77,8 +112,8 @@
             )

             with self.lock:
-                self.chunks[chunk_id] = chunk
-                self.pending_chunks.append(chunk_id)
+                self.chunks[chunk.chunk_id] = chunk
+                self.pending_chunks.append(chunk.chunk_id)

             chunks.append(chunk)

@@ -86,24 +121,84 @@

     def get_chunks_for_worker(
         self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
-    ) -> List[
-        """Get available chunks for a worker."""
+    ) -> List[Dict[str, Any]]:
+        """Get available chunks with unprocessed items for a worker."""
         assigned = []

         with self.lock:
+            # FIRST PRIORITY: Check if this worker already has assigned chunks
+            # Workers should complete their current chunks before getting new ones
+            if worker_id in self.assigned_chunks:
+                existing_chunk_ids = list(self.assigned_chunks[worker_id])
+                for chunk_id in existing_chunk_ids:
+                    if len(assigned) >= count:
+                        break
+
+                    chunk = self.chunks.get(chunk_id)
+                    if not chunk:
+                        continue
+
+                    # Check if chunk still has unprocessed items
+                    if tracker:
+                        chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
+                        if chunk_info and chunk_info["unprocessed_ranges"]:
+                            assigned.append(
+                                {
+                                    "chunk": chunk,
+                                    "unprocessed_ranges": chunk_info["unprocessed_ranges"],
+                                }
+                            )
+                    else:
+                        # No tracker, assume chunk needs processing
+                        assigned.append(
+                            {
+                                "chunk": chunk,
+                                "unprocessed_ranges": [(0, chunk.chunk_size - 1)],
+                            }
+                        )
+
+            # SECOND PRIORITY: Get new pending chunks
+            # Only if worker doesn't have enough chunks already
             while len(assigned) < count and self.pending_chunks:
                 chunk_id = self.pending_chunks.popleft()
-                chunk = self.chunks
+                chunk = self.chunks.get(chunk_id)
+
+                if not chunk:
+                    continue

+                # Verify chunk is truly pending (defensive check)
+                if chunk.status != "pending" or chunk.assigned_to is not None:
+                    logger.warning(
+                        f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
+                    )
+                    continue
+
+                # Assign to this worker
                 chunk.assigned_to = worker_id
                 chunk.status = "assigned"
                 chunk.assigned_at = datetime.utcnow()
-
                 self.assigned_chunks[worker_id].add(chunk_id)
-
+
+                # Get unprocessed ranges
+                unprocessed_ranges = [(0, chunk.chunk_size - 1)]  # Default
                 if tracker:
+                    chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
+                    if chunk_info:
+                        unprocessed_ranges = chunk_info["unprocessed_ranges"]
                     tracker.mark_assigned(chunk_id, worker_id)

+                assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
+
+        # Log what we're assigning
+        if assigned:
+            chunk_summary = ", ".join(
+                [
+                    f"{info['chunk'].chunk_id}[{len(info['unprocessed_ranges'])} ranges]"
+                    for info in assigned
+                ]
+            )
+            logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
+
         return assigned

     def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
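Illustrative sketch (not part of the diff) of the structure now returned by get_chunks_for_worker: each entry pairs a chunk with the inclusive index ranges that still need work. Values below are hypothetical.

example_assignment = [
    {
        "chunk": "<ShardChunk shard-00042_chunk_0>",   # the real code holds the ShardChunk object itself
        "unprocessed_ranges": [(0, 249), (400, 999)],  # items 250-399 were already processed
    }
]
# _process_worker_message later calls info["chunk"].to_dict(), attaches
# "unprocessed_ranges", and sends the result in a "shard_assignment" message.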
@@ -173,6 +268,27 @@ class Orchestrator:
         self.dataset_config = config.get("dataset", {})
         self.dataset_path = self.dataset_config.get("path")
         self.dataset_type = self.dataset_config.get("type", "huggingface")
+        self.dataset_split = self.dataset_config.get("split", "train")  # Add split configuration
+        self.dataset_image_column = self.dataset_config.get(
+            "image_column", "image"
+        )  # Add image column config
+
+        # Dataset components
+        self.dataset_loader = None
+        self.shard_tracker = None
+        self.chunk_tracker = None
+
+        if self.dataset_path:
+            self.dataset_loader = DatasetLoader(
+                self.dataset_path,
+                self.dataset_type,
+                self.dataset_split,
+                self.dataset_image_column,
+            )
+            checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
+            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")

         # vLLM configuration to distribute to workers
         self.vllm_config = config.get(
@@ -233,6 +349,11 @@ class Orchestrator:

         # Initialize chunk manager with reference to chunk tracker
         self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)
+        self.pending_processed_items = defaultdict(list)  # chunk_id -> list of indices
+        self.item_batch_lock = threading.Lock()
+        self.last_item_batch_flush = time.time()
+        self.item_batch_interval = 5  # Flush every 5 seconds
+        self.item_batch_size = 100  # Or every 100 items

         # Track connections
         self.workers: Dict[str, WebSocketServerProtocol] = {}
@@ -242,17 +363,15 @@
         self.ssl_context = self._setup_ssl()

         # Statistics
+        self.is_generating_stats = False
         self.stats = {
             "total_chunks": 0,
             "completed_chunks": 0,
             "failed_chunks": 0,
-            "total_captions": 0,
             "connected_workers": 0,
             "total_shards": 0,
             "completed_shards": 0,
             "current_shard": None,
-            "buffer_size": 0,
-            "total_written": 0,
             "last_checkpoint": None,
         }

@@ -266,7 +385,7 @@
             "expected_rate": 0.0,
         }

-        # Data sample queue for
+        # Data sample queue for CaptionWorker
         self.data_sample_queue = asyncio.Queue(maxsize=1000)
         self.data_workers: Dict[str, WebSocketServerProtocol] = {}

@@ -310,10 +429,23 @@
         # Mark state as not restored until we process checkpoints
         self.state_restored.clear()

+        # Get dataset info to check format
+        dataset_info = self.dataset_loader.get_dataset_info()
+        dataset_format = dataset_info.get("dataset_format", "unknown")
+        logger.info(f"Dataset format: {dataset_format}")
+
         # Get all shards
         self.all_shards = self.dataset_loader.get_shard_list()
         self.stats["total_shards"] = len(self.all_shards)

+        # For HuggingFace datasets, we might need to dynamically create more shards
+        if dataset_format == "huggingface_datasets":
+            self._is_hf_dataset = True
+            self._hf_chunk_size = 10000  # Items per virtual shard
+            self._next_hf_shard_index = len(self.all_shards)  # For creating new virtual shards
+        else:
+            self._is_hf_dataset = False
+
         # Get shard status from ChunkTracker
         shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
         completed_shards = {
@@ -336,7 +468,10 @@

         # Filter out shards that already have chunks created
         remaining_shards = [
-            shard
+            shard
+            for shard in remaining_shards
+            if (shard if shard.startswith("hf_dataset:") else Path(shard).stem)
+            not in shards_with_chunks
         ]

         self.stats["completed_shards"] = len(completed_shards)
@@ -356,25 +491,18 @@
             with self.chunk_manager.lock:
                 for chunk_state in shard_info["chunks"]:
                     if chunk_state.status in ["pending", "failed", "assigned"]:
-                        #
-
-
-
-
-
-
-
-
-
-
-
-                            start_index=chunk_state.start_index,
-                            chunk_size=chunk_state.chunk_size,
-                        )
-                        self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
-                        self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
-                        requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
-                        initial_pending += 1
+                        # ChunkState already has shard_url stored
+                        chunk = ShardChunk(
+                            chunk_id=chunk_state.chunk_id,
+                            shard_url=chunk_state.shard_url,
+                            shard_name=chunk_state.shard_name,
+                            start_index=chunk_state.start_index,
+                            chunk_size=chunk_state.chunk_size,
+                        )
+                        self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
+                        self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
+                        requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
+                        initial_pending += 1

         logger.info(f"Re-queued {initial_pending} existing pending chunks")
         for shard_name, chunk_ids in requeued_chunks_by_shard.items():
@@ -426,7 +554,13 @@
             if current_shard_url is None or current_shard_index >= current_shard_items:
                 try:
                     current_shard_url = next(shard_iter)
-
+
+                    # Extract shard name based on type
+                    if current_shard_url.startswith("hf_dataset:"):
+                        current_shard_name = current_shard_url  # Use full ID for virtual shards
+                    else:
+                        current_shard_name = Path(current_shard_url).stem
+
                     self.stats["current_shard"] = current_shard_name

                     # Skip if we already have chunks from this shard
@@ -439,16 +573,74 @@

                     # Count items in new shard
                     logger.info(f"Loading new shard {current_shard_name}")
-
-
-                    )
+
+                    # For virtual HF dataset shards, use the chunk size directly
+                    if current_shard_url.startswith("hf_dataset:"):
+                        current_shard_items = self.dataset_loader.count_shard_items(
+                            current_shard_url
+                        )
+                        logger.info(
+                            f"Virtual shard {current_shard_name} has {current_shard_items} items"
+                        )
+                    else:
+                        # For WebDataset, actually count items
+                        current_shard_items = sum(
+                            1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
+                        )
+                        logger.info(
+                            f"Shard {current_shard_name} has {current_shard_items} items"
+                        )
+
                     current_shard_index = 0
-                    logger.info(f"Shard {current_shard_name} has {current_shard_items} items")

                 except StopIteration:
-                    # No more shards
+                    # No more shards in the iterator
+                    if self._is_hf_dataset:
+                        # Before creating new virtual shards, check if we have pending chunks
+                        with self.chunk_manager.lock:
+                            pending_count = len(self.chunk_manager.pending_chunks)
+
+                        if pending_count > 0:
+                            # Don't create new shards if we have pending chunks
+                            logger.debug(
+                                f"Have {pending_count} pending chunks, not creating new virtual shards yet"
+                            )
+                            current_shard_url = None
+                            time.sleep(2)
+                            continue
+
+                        # For HF datasets, we can create more virtual shards on demand
+                        logger.info(
+                            "Creating additional virtual shards for HuggingFace dataset"
+                        )
+
+                        # Create 10 more virtual shards
+                        new_shards = []
+                        for i in range(10):
+                            shard_id = f"hf_dataset:{self.dataset_path}:chunk:{self._next_hf_shard_index * self._hf_chunk_size}"
+                            new_shards.append(shard_id)
+                            self._next_hf_shard_index += 1
+
+                        # Add to all_shards and create new iterator
+                        self.all_shards.extend(new_shards)
+                        self.stats["total_shards"] = len(self.all_shards)
+
+                        # Filter for unprocessed shards
+                        remaining_new_shards = [
+                            s
+                            for s in new_shards
+                            if s not in shards_summary and s not in completed_shards
+                        ]
+
+                        if remaining_new_shards:
+                            shard_iter = iter(remaining_new_shards)
+                            logger.info(f"Added {len(remaining_new_shards)} new virtual shards")
+                            continue
+
+                    # No more shards to process
                     logger.info("No more shards to process")
                     break
+
                 except Exception as e:
                     logger.error(f"Error loading shard {current_shard_name}: {e}")
                     current_shard_url = None
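A small worked sketch of how the virtual-shard IDs created in the StopIteration branch above are numbered (illustrative only; the dataset path and starting index are hypothetical, the 10,000-item size comes from the diff):

dataset_path = "org/dataset"
hf_chunk_size = 10000          # items per virtual shard (self._hf_chunk_size in the diff)
next_hf_shard_index = 3        # e.g. three virtual shards already exist

new_shards = []
for _ in range(10):
    shard_id = f"hf_dataset:{dataset_path}:chunk:{next_hf_shard_index * hf_chunk_size}"
    new_shards.append(shard_id)
    next_hf_shard_index += 1

print(new_shards[0])   # hf_dataset:org/dataset:chunk:30000
print(new_shards[-1])  # hf_dataset:org/dataset:chunk:120000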
@@ -456,25 +648,40 @@

             # Create a chunk from current shard
             if current_shard_url and current_shard_index < current_shard_items:
-
-
+                # Calculate the absolute dataset index for this chunk
+                if current_shard_url.startswith("hf_dataset:"):
+                    # Parse the virtual shard URL to get the base start index
+                    parts = current_shard_url.split(":")
+                    if len(parts) >= 4 and parts[2] == "chunk":
+                        shard_base_index = int(parts[3])
+                    else:
+                        shard_base_index = 0
+
+                    # The absolute start index for this chunk in the dataset
+                    absolute_start_index = shard_base_index + current_shard_index
+                else:
+                    # For WebDataset, current_shard_index is already absolute
+                    absolute_start_index = current_shard_index
+
+                # Create chunk with absolute index
+                chunk = ShardChunk.create(
+                    shard_url=current_shard_url,
+                    shard_name=current_shard_name,
+                    start_index=absolute_start_index,
+                    chunk_size=min(self.chunk_size, current_shard_items - current_shard_index),
+                )

-                # Add to ChunkTracker
+                # Add to ChunkTracker with all required fields
                 if self.chunk_tracker and self.chunk_tracker.add_chunk(
-                    chunk_id,
+                    chunk.chunk_id,
+                    chunk.shard_name,
+                    chunk.shard_url,
+                    chunk.start_index,
+                    chunk.chunk_size,
                 ):
-                    # Create chunk
-                    chunk = ShardChunk(
-                        chunk_id=chunk_id,
-                        shard_url=current_shard_url,
-                        shard_name=current_shard_name,
-                        start_index=current_shard_index,
-                        chunk_size=chunk_size,
-                    )
-
                     with self.chunk_manager.lock:
-                        self.chunk_manager.chunks[chunk_id] = chunk
-                        self.chunk_manager.pending_chunks.append(chunk_id)
+                        self.chunk_manager.chunks[chunk.chunk_id] = chunk
+                        self.chunk_manager.pending_chunks.append(chunk.chunk_id)

                     chunks_created += 1
                     self.stats["total_chunks"] += 1
@@ -484,10 +691,14 @@
             if chunks_created > 0:
                 logger.info(f"Created {chunks_created} chunks on demand")

-            # If we couldn't create any chunks and there are no more shards,
+            # If we couldn't create any chunks and there are no more shards, check if it's HF dataset
             if chunks_created == 0 and current_shard_url is None:
-
-
+                if self._is_hf_dataset:
+                    # We can always create more virtual shards for HF datasets
+                    logger.debug("Will create more virtual shards on next iteration")
+                else:
+                    logger.info("All shards processed, chunk creation complete")
+                    break

             # Brief pause to avoid spinning
             time.sleep(1)
@@ -558,7 +769,9 @@
             elif auth_ticket.role == "admin":
                 await self._handle_admin(websocket, auth_ticket)
             else:
-                await websocket.send(
+                await websocket.send(
+                    safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"})
+                )

         except Exception as e:
             logger.error(f"Connection error: {e}")
@@ -604,81 +817,118 @@
         requires_worker_restart = False

         try:
+            # Extract orchestrator section if present
+            if "orchestrator" in new_config:
+                # Config has orchestrator wrapper, extract it
+                orchestrator_config = new_config["orchestrator"]
+            else:
+                # Config is already at orchestrator level
+                orchestrator_config = new_config
+
+            # Helper function for deep comparison
+            def deep_equal(a, b):
+                """Deep comparison of two values including nested dicts and lists."""
+                if type(a) != type(b):
+                    return False
+                if isinstance(a, dict):
+                    if set(a.keys()) != set(b.keys()):
+                        return False
+                    return all(deep_equal(a[k], b[k]) for k in a.keys())
+                elif isinstance(a, (list, tuple)):
+                    if len(a) != len(b):
+                        return False
+                    return all(deep_equal(x, y) for x, y in zip(a, b))
+                else:
+                    return a == b
+
             # Update vLLM configuration
-            if "vllm" in
+            if "vllm" in orchestrator_config:
                 old_vllm = self.vllm_config.copy()
+                new_vllm = orchestrator_config["vllm"]

-                # Check
-                vllm_changed =
-                for key, value in new_config["vllm"].items():
-                    if self.vllm_config.get(key) != value:
-                        self.vllm_config[key] = value
-                        vllm_changed = True
+                # Check if vLLM config actually changed using deep comparison
+                vllm_changed = not deep_equal(old_vllm, new_vllm)

                 if vllm_changed:
+                    # Update the vLLM config
+                    self.vllm_config = new_vllm.copy()
                     updated_sections.append("vllm")

                     # Check if critical changes require worker restart
                     if (
-                        old_vllm.get("model") !=
+                        old_vllm.get("model") != new_vllm.get("model")
                         or old_vllm.get("gpu_memory_utilization")
-                        !=
+                        != new_vllm.get("gpu_memory_utilization")
                         or old_vllm.get("tensor_parallel_size")
-                        !=
+                        != new_vllm.get("tensor_parallel_size")
+                        or old_vllm.get("dtype") != new_vllm.get("dtype")
+                        or old_vllm.get("max_model_len") != new_vllm.get("max_model_len")
                     ):
                         requires_worker_restart = True
                         warnings.append(
                             "Critical vLLM changes detected - workers will be disconnected to reload"
                         )
+                        logger.info(
+                            f"Model change: {old_vllm.get('model')} -> {new_vllm.get('model')}"
+                        )

             # Update dataset configuration
-            if "dataset" in
-
-
-
-
-                dataset_changed = True
+            if "dataset" in orchestrator_config:
+                old_dataset = self.dataset_config.copy()
+                new_dataset = orchestrator_config["dataset"]
+
+                dataset_changed = not deep_equal(old_dataset, new_dataset)

                 if dataset_changed:
+                    self.dataset_config = new_dataset.copy()
                     self.dataset_path = self.dataset_config.get("path")
                     self.dataset_type = self.dataset_config.get("type", "huggingface")
                     updated_sections.append("dataset")
                     warnings.append("Dataset changes will apply to new chunks only")

             # Update chunk settings
-            if
-
+            if (
+                "chunk_size" in orchestrator_config
+                and self.chunk_size != orchestrator_config["chunk_size"]
+            ):
+                self.chunk_size = orchestrator_config["chunk_size"]
                 self.chunk_manager.chunk_size = self.chunk_size
                 updated_sections.append("chunk_size")

             if (
-                "chunks_per_request" in
-                and self.chunks_per_request !=
+                "chunks_per_request" in orchestrator_config
+                and self.chunks_per_request != orchestrator_config["chunks_per_request"]
             ):
-                self.chunks_per_request =
+                self.chunks_per_request = orchestrator_config["chunks_per_request"]
                 updated_sections.append("chunks_per_request")

-            #
-
+            # Update auth configuration
+            if "auth" in orchestrator_config:
+                try:
+                    self.auth = AuthManager({"auth": orchestrator_config["auth"]})
+                    updated_sections.append("auth")
+                except Exception as e:
+                    logger.error(f"Failed to update AuthManager: {e}")
+                    warnings.append(f"Auth update failed: {e}")

             # Update buffer settings
             if (
-                "chunk_buffer_multiplier" in
-                and self.chunk_buffer_multiplier !=
+                "chunk_buffer_multiplier" in orchestrator_config
+                and self.chunk_buffer_multiplier != orchestrator_config["chunk_buffer_multiplier"]
             ):
-                self.chunk_buffer_multiplier =
+                self.chunk_buffer_multiplier = orchestrator_config["chunk_buffer_multiplier"]
                 updated_sections.append("chunk_buffer_multiplier")

             if (
-                "min_chunk_buffer" in
-                and self.min_chunk_buffer !=
+                "min_chunk_buffer" in orchestrator_config
+                and self.min_chunk_buffer != orchestrator_config["min_chunk_buffer"]
             ):
-                self.min_chunk_buffer =
+                self.min_chunk_buffer = orchestrator_config["min_chunk_buffer"]
                 updated_sections.append("min_chunk_buffer")

             # Update storage settings
-            if "storage" in
-                storage_config =
+            if "storage" in orchestrator_config:
+                storage_config = orchestrator_config["storage"]
                 storage_changed = False

                 if (
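The reload handler above replaces per-key comparisons with the nested deep_equal helper. A standalone restatement of that helper with a hypothetical config pair (illustrative only; the model name is made up):

def deep_equal(a, b):
    # Recursive comparison of nested dicts, lists, and tuples, as defined in the diff above.
    if type(a) != type(b):
        return False
    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            return False
        return all(deep_equal(a[k], b[k]) for k in a.keys())
    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            return False
        return all(deep_equal(x, y) for x, y in zip(a, b))
    return a == b

old_vllm = {"model": "Qwen/Qwen2-VL-7B", "sampling": {"temperature": 0.7}}  # hypothetical values
new_vllm = {"model": "Qwen/Qwen2-VL-7B", "sampling": {"temperature": 0.8}}
print(deep_equal(old_vllm, new_vllm))  # False -> the vllm section is marked as changed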
@@ -701,21 +951,6 @@ class Orchestrator:
                 if storage_changed:
                     updated_sections.append("storage")

-            # Update data worker storage config
-            if "data_worker_storage" in new_config:
-                current_dw_storage = self.config.get("data_worker_storage", {})
-                if current_dw_storage != new_config["data_worker_storage"]:
-                    self.config["data_worker_storage"] = new_config["data_worker_storage"]
-                    updated_sections.append("data_worker_storage")
-                    warnings.append("Data worker storage config will apply to new connections only")
-
-            # Update backpressure threshold
-            if "backpressure_threshold" in new_config:
-                current_threshold = getattr(self, "backpressure_threshold", 800)
-                if current_threshold != new_config["backpressure_threshold"]:
-                    self.backpressure_threshold = new_config["backpressure_threshold"]
-                    updated_sections.append("backpressure_threshold")
-
             # Check if any changes were made
             if not updated_sections:
                 await websocket.send(
@@ -729,29 +964,49 @@
                 logger.info("Configuration reload requested but no changes detected")
                 return

-            # Update the main config
-
+            # Update the main config
+            if "orchestrator" in new_config:
+                self.config["orchestrator"] = orchestrator_config
+            else:
+                self.config.update(orchestrator_config)

             # Handle worker restart if needed
             if requires_worker_restart:
                 logger.info("Disconnecting all workers for configuration reload...")

-                #
-
-
+                # Send reload message to workers first
+                reload_msg = safe_json_dumps(
+                    {
+                        "type": "reload_vllm",
+                        "vllm_config": self.vllm_config,
+                    }
+                )
+
+                # Create a list of worker items to avoid modifying dict during iteration
+                worker_items = list(self.workers.items())
+                disconnected = []
+
+                for worker_id, ws in worker_items:
                     try:
-                        await
-
-                        )
+                        await ws.send(reload_msg)
+                        # Give worker time to process before disconnect
+                        await asyncio.sleep(0.5)
+                        await ws.close(code=1012, reason="Configuration reload")
+                        disconnected.append(worker_id)
                     except:
-
+                        disconnected.append(worker_id)  # Still mark as disconnected if error
+
+                # Now safely clear workers dict
+                for worker_id in disconnected:
+                    if worker_id in self.workers:
+                        del self.workers[worker_id]

                 warnings.append(
-                    f"
+                    f"Sent reload message to {len(disconnected)} workers - they will reconnect with new config"
                 )
             else:
-                # Just notify workers about config changes
-
+                # Just notify workers about config changes without disconnecting
+                config_update_msg = safe_json_dumps(
                     {
                         "type": "config_update",
                         "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
@@ -761,15 +1016,21 @@
                     }
                 )

+                # Create a list of worker items to avoid modifying dict during iteration
+                worker_items = list(self.workers.items())
                 disconnected = []
-
+
+                for worker_id, ws in worker_items:
                     try:
-                        await ws.send(
+                        await ws.send(config_update_msg)
+                        logger.info(f"Sent config update to worker {worker_id}")
                     except:
                         disconnected.append(worker_id)

+                # Now safely remove disconnected workers
                 for worker_id in disconnected:
-
+                    if worker_id in self.workers:
+                        del self.workers[worker_id]

             # Send success response
             await websocket.send(
@@ -788,34 +1049,58 @@

         except Exception as e:
             logger.error(f"Configuration reload failed: {e}")
+            import traceback
+
+            logger.error(traceback.format_exc())
             await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))

     async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
         """Handle worker connection lifecycle."""
-
+        # Generate unique worker ID even if using same token
+        base_name = getattr(auth_ticket, "name", "worker")
+        worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"  # Add unique suffix
+
+        # Track the original token/user for accounting
+        worker_user = base_name  # Keep track of which user/token this worker belongs to
+
         self.workers[worker_id] = websocket
         self.stats["connected_workers"] = len(self.workers)

-        #
-
-
-        )
-
+        # Optionally track workers by user/token
+        if not hasattr(self, "workers_by_user"):
+            self.workers_by_user = defaultdict(set)
+        self.workers_by_user[worker_user].add(worker_id)
+
+        # Register contributor with the base name (for aggregating stats per user)
+        contributor = await self.storage.get_contributor(worker_user)
+        if not contributor:
+            contributor = Contributor(
+                contributor_id=worker_user,
+                name=worker_user,
+                total_captions=0,
+                trust_level=1,
+            )
+            await self.storage.save_contributor(contributor)

-        logger.info(f"Worker {worker_id} connected")
+        logger.info(f"Worker {worker_id} (user: {worker_user}) connected")
         await self._broadcast_stats()
-        await self._send_activity(f"Worker {worker_id} connected")
+        await self._send_activity(f"Worker {worker_id} (user: {worker_user}) connected")

         try:
             # Send welcome message with dataset configuration
             welcome_message = {
                 "type": "welcome",
                 "worker_id": worker_id,
+                "user_id": worker_user,
                 "dataset_config": {
                     "dataset_path": self.dataset_path,
                     "dataset_type": self.dataset_type,
-                    "
-                    "
+                    "dataset_split": self.dataset_split,
+                    "dataset_image_column": self.dataset_image_column,
+                    "path": self.dataset_path,
+                    "type": self.dataset_type,
+                    "split": self.dataset_split,
+                    "image_column": self.dataset_image_column,
                 },
                 "vllm_config": self.vllm_config,
             }
@@ -826,21 +1111,29 @@
                 await self._process_worker_message(worker_id, data)

         except websockets.exceptions.ConnectionClosed:
-            logger.info(f"Worker {worker_id} disconnected")
+            logger.info(f"Worker {worker_id} (user: {worker_user}) disconnected")
         finally:
-
+            if worker_id in self.workers:
+                del self.workers[worker_id]
+
+            # Clean up user tracking
+            if hasattr(self, "workers_by_user") and worker_user in self.workers_by_user:
+                self.workers_by_user[worker_user].discard(worker_id)
+                if not self.workers_by_user[worker_user]:
+                    del self.workers_by_user[worker_user]
+
             self.stats["connected_workers"] = len(self.workers)
-
+
+            # Release chunks
             self.chunk_manager.release_worker_chunks(worker_id)
             if self.chunk_tracker:
-                # Mark released chunks as pending in tracker
                 released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
                 logger.info(
                     f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
                 )

             await self._broadcast_stats()
-            await self._send_activity(f"Worker {worker_id} disconnected")
+            await self._send_activity(f"Worker {worker_id} (user: {worker_user}) disconnected")

     async def _process_worker_message(self, worker_id: str, data: Dict):
         """Process message from worker."""
@@ -856,28 +1149,26 @@
                 return

             count = data.get("count", self.chunks_per_request)
-
+            chunk_infos = self.chunk_manager.get_chunks_for_worker(
+                worker_id, count, self.chunk_tracker
+            )

-            if
-                #
-
-                for
-
-
-
-                        "shard_url": chunk.shard_url,
-                        "shard_name": chunk.shard_name,
-                        "start_index": chunk.start_index,
-                        "chunk_size": chunk.chunk_size,
-                    }
-                )
+            if chunk_infos:
+                # Send chunks with unprocessed ranges
+                chunks_data = []
+                for info in chunk_infos:
+                    chunk_dict = info["chunk"].to_dict()
+                    chunk_dict["unprocessed_ranges"] = info["unprocessed_ranges"]
+                    chunks_data.append(chunk_dict)

                 await self.workers[worker_id].send(
-                    safe_json_dumps({"type": "shard_assignment", "chunks":
+                    safe_json_dumps({"type": "shard_assignment", "chunks": chunks_data})
+                )
+
+                chunk_ids = [c["chunk_id"] for c in chunks_data]
+                logger.info(
+                    f"Assigned {len(chunks_data)} chunks to worker {worker_id}: {chunk_ids}"
                 )
-                chunk_ids = [c["chunk_id"] for c in chunk_data]
-                logger.info(f"Assigned {len(chunks)} chunks to worker {worker_id}: {chunk_ids}")
-                await self._send_activity(f"Assigned {len(chunks)} chunks to {worker_id}")
             else:
                 await self.workers[worker_id].send(safe_json_dumps({"type": "no_chunks"}))

@@ -907,7 +1198,7 @@
         elif msg_type == "submit_captions":
             await self._handle_captions_submission(worker_id, data)
         elif msg_type == "request_job":
-            #
+            # CaptionWorker requesting a job from data samples
             try:
                 job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
                 await self.workers[worker_id].send(
@@ -921,76 +1212,132 @@
             logger.debug(f"Heartbeat from {worker_id}: {data}")

     async def _handle_captions_submission(self, worker_id: str, data: Dict):
-        """Process
+        """Process caption submission from worker - now handles multi-stage outputs."""
         chunk_id = data.get("chunk_id")
         item_key = data["item_key"]
-        captions_list = data["captions"]

-
-
-
+        item_index = data.get("item_index")  # Worker should send this
+        if item_index is None:
+            # Try to extract from item_key (format: dataset_XXXXXXXX)
+            try:
+                item_index = int(item_key.split("_")[-1])
+            except:
+                logger.warning(f"Could not extract item index from key: {item_key}")

-        #
+        # Extract user from worker_id (format: "username_uuid")
+        worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
+        # Handle both old format (captions list) and new format (outputs dict)
+        if "outputs" in data:
+            # New multi-stage format
+            outputs = data["outputs"]
+            captions_list = outputs.get("captions", [])
+            total_outputs = sum(len(v) for v in outputs.values())
+
+            logger.debug(
+                f"Received multi-stage outputs for item {item_key} from worker {worker_id}: "
+                f"{total_outputs} outputs across {len(outputs)} fields"
+            )
+        else:
+            # Old format - single captions list
+            captions_list = data["captions"]
+            outputs = {"captions": captions_list}
+            total_outputs = len(captions_list)
+
+            logger.debug(
+                f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
+            )
+
+        # Create caption record with multi-stage outputs
         caption = Caption(
-            job_id=f"{chunk_id}_{item_key}",
+            job_id=f"{chunk_id}_{item_key}",
             dataset=data.get("dataset"),
             shard=data.get("shard"),
             item_key=item_key,
-            captions=captions_list,
-
+            captions=captions_list,
+            outputs=outputs,
+            contributor_id=worker_user,
             timestamp=datetime.utcnow(),
-            quality_scores=None,
+            quality_scores=None,
             # Image metadata
             image_width=data.get("image_width"),
             image_height=data.get("image_height"),
             image_format=data.get("image_format"),
             file_size=data.get("file_size"),
             # Processing metadata
-            caption_count=
+            caption_count=total_outputs,
             processing_time_ms=data.get("processing_time_ms"),
             chunk_id=chunk_id,
+            metadata=data.get("metadata", {}),
         )

-        # Add to central storage buffer
+        # Add to central storage buffer
         await self.storage.save_caption(caption)

-        #
-
-
+        # Handle item tracking with fixed deadlock
+        should_flush = False
+        if chunk_id and item_index is not None and self.chunk_tracker:
+            with self.item_batch_lock:
+                self.pending_processed_items[chunk_id].append(item_index)

-
-
+                # Check if we should flush
+                total_pending = sum(
+                    len(indices) for indices in self.pending_processed_items.values()
+                )
+                time_since_flush = time.time() - self.last_item_batch_flush
+
+                if (
+                    total_pending >= self.item_batch_size
+                    or time_since_flush >= self.item_batch_interval
+                ):
+                    should_flush = True
+
+        if should_flush:
+            await self._flush_processed_items()
+
+        # Update contributor stats (use user, not worker)
+        contributor = await self.storage.get_contributor(worker_user)
         if contributor:
-            contributor.total_captions +=
+            contributor.total_captions += total_outputs
             await self.storage.save_contributor(contributor)

         # Broadcast updated stats
         await self._broadcast_stats()

         # Log progress periodically
-
-
+        total_outputs = self.stats.get("total_outputs", 0)
+        if total_outputs > 0 and total_outputs % 100 == 0:
+            if (
+                not hasattr(self, "_last_logged_outputs")
+                or self._last_logged_outputs != total_outputs
+            ):
+                logger.info(f"Collected {total_outputs} outputs centrally")
+                self._last_logged_outputs = total_outputs

     async def _check_shard_completion(self, chunk_id: str):
         """Check if a shard is complete after chunk completion."""
-        #
-
+        # Get the chunk
+        chunk = self.chunk_manager.chunks.get(chunk_id)
+        if not chunk:
+            return
+
+        shard_name = chunk.shard_name

-        #
-        chunk_stats = self.chunk_manager.get_stats()
+        # Find all chunks for this shard
         shard_chunks = [
-            cid
-            for cid, chunk in self.chunk_manager.chunks.items()
-            if chunk.shard_name == shard_name
+            cid for cid, c in self.chunk_manager.chunks.items() if c.belongs_to_shard(shard_name)
         ]

+        # Check if all are completed
         completed_chunks = [
             cid for cid in shard_chunks if self.chunk_manager.chunks[cid].status == "completed"
         ]

-        if len(completed_chunks) == len(shard_chunks):
+        if len(completed_chunks) == len(shard_chunks) and len(shard_chunks) > 0:
             logger.info(f"Shard {shard_name} complete!")
-
+            # Don't mark virtual shards as complete in ShardTracker
+            if not shard_name.startswith("hf_dataset:"):
+                self.shard_tracker.mark_complete(shard_name)
             self.stats["completed_shards"] += 1
             await self._send_activity(f"Shard {shard_name} completed!")

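Hypothetical example of a new-style submit_captions payload that the reworked _handle_captions_submission above accepts (field names are taken from the diff; every value is illustrative only):

submission = {
    "type": "submit_captions",
    "chunk_id": "org_dataset_chunk_20000",
    "item_key": "dataset_00020123",
    "item_index": 20123,
    "dataset": "org/dataset",
    "shard": "hf_dataset:org/dataset:chunk:20000",
    "outputs": {
        "captions": ["a red bicycle leaning against a wall"],
        "tags": ["bicycle", "wall", "outdoor"],
    },
    "image_width": 1024,
    "image_height": 768,
    "image_format": "JPEG",
    "file_size": 245183,
    "processing_time_ms": 812.5,
    "metadata": {},
}
# Per the diff, caption_count is derived server-side as the total number of
# outputs across all fields (1 caption + 3 tags = 4 here), and the contributor
# is taken from the "username" part of the worker_id.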
@@ -1063,47 +1410,198 @@ class Orchestrator:
         finally:
             del self.data_workers[worker_id]

-    async def
-        """
-
-
+    async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+        """Send leaderboard data to a specific monitor."""
+        total_start = time.time()
+        try:
+            if websocket not in self.monitors:
+                return
+
+            # Get contributors asynchronously
+            contributors_start = time.time()
+            contributors = await self.storage.get_top_contributors(10)
+            logger.debug(
+                f"Contributors retrieved in {(time.time() - contributors_start)*1000:.1f}ms"
+            )
+
+            # Get worker counts in thread pool
+            worker_counts_start = time.time()
+            loop = asyncio.get_event_loop()
+            worker_counts = await loop.run_in_executor(
+                None,
+                lambda: (
+                    self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+                ),
+            )
+            logger.debug(
+                f"Worker counts retrieved in {(time.time() - worker_counts_start)*1000:.1f}ms"
+            )

+            # Build enhanced contributors list
+            build_start = time.time()
+            enhanced_contributors = []
+            for contributor in contributors:
+                contrib_dict = {
+                    "contributor_id": contributor.contributor_id,
+                    "name": contributor.name,
+                    "total_captions": contributor.total_captions,
+                    "trust_level": contributor.trust_level,
+                    "active_workers": len(
+                        worker_counts.get(contributor.contributor_id, {}).get("worker_ids", [])
+                    ),
+                }
+                enhanced_contributors.append(contrib_dict)
+            logger.debug(f"Enhanced contributors built in {(time.time() - build_start)*1000:.1f}ms")
+
+            # Cache for future monitors
+            self._cached_leaderboard = enhanced_contributors
+
+            # Send if still connected
+            if websocket in self.monitors:
+                send_start = time.time()
+                await websocket.send(
+                    safe_json_dumps({"type": "leaderboard", "data": enhanced_contributors})
+                )
+                logger.debug(
+                    f"Leaderboard sent to monitor in {(time.time() - send_start)*1000:.1f}ms"
+                )
+
+            logger.debug(
+                f"Leaderboard send to monitor completed in {(time.time() - total_start)*1000:.1f}ms"
+            )
+
+        except websockets.exceptions.ConnectionClosed:
+            logger.debug("Monitor disconnected during leaderboard send")
+        except Exception as e:
+            logger.error(f"Error sending leaderboard to monitor: {e}")
+
+    async def _send_initial_monitor_data(self, websocket: WebSocketServerProtocol):
+        """Send initial data to monitor in a separate task to avoid blocking."""
+        total_start = time.time()
         try:
-            #
+            # Check if websocket is still in monitors set
+            if websocket not in self.monitors:
+                logger.debug("Monitor disconnected before initial data send")
+                return
+
+            # Send current stats (already in memory)
+            stats_start = time.time()
             await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
+            logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
+
+            # Get chunk stats asynchronously
+            chunk_stats_start = time.time()
+            loop = asyncio.get_event_loop()
+            chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+            logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
+
+            if websocket not in self.monitors:
+                return

-
-            chunk_stats = self.chunk_manager.get_stats()
+            chunk_send_start = time.time()
             await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
+            logger.debug(f"Chunk stats sent in {(time.time() - chunk_send_start)*1000:.1f}ms")

-            #
-
-
-
-
+            # For leaderboard, check if we have a cached version first
+            if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
+                # Use cached leaderboard if available
+                cache_send_start = time.time()
+                await websocket.send(
+                    safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
                 )
+                logger.debug(
+                    f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
+                )
+            else:
+                # Schedule leaderboard update separately
+                leaderboard_task_start = time.time()
+                asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
+                logger.debug(
+                    f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
+                )
+
+            logger.debug(
+                f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
             )

-
-
-
+        except websockets.exceptions.ConnectionClosed:
+            logger.debug("Monitor disconnected during initial data send")
+        except Exception as e:
+            logger.error(f"Error sending initial monitor data: {e}")
+
+    async def _handle_monitor(self, websocket: WebSocketServerProtocol):
+        """Handle monitor connection - truly non-blocking version."""
+        monitor_start = time.time()
+        self.monitors.add(websocket)
+        logger.info(f"Monitor connected (total monitors: {len(self.monitors)})")
+
+        try:
+            # Send welcome message immediately
+            welcome_start = time.time()
+            await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
+            logger.debug(f"Monitor welcome sent in {(time.time() - welcome_start)*1000:.1f}ms")
+
+            # Schedule initial data send as a separate task to avoid blocking
+            task_create_start = time.time()
+            asyncio.create_task(self._send_initial_monitor_data(websocket))
+            logger.debug(
+                f"Monitor initial data task created in {(time.time() - task_create_start)*1000:.1f}ms"
+            )
+
+            # Just keep the connection alive - no blocking work here
+            try:
+                async for message in websocket:
+                    # Handle any incoming messages from monitor if needed
+                    # For now, just ignore them
+                    pass
+            except websockets.exceptions.ConnectionClosed:
+                pass  # Normal disconnection

         except websockets.exceptions.ConnectionClosed:
             logger.info("Monitor disconnected")
+        except Exception as e:
+            logger.error(f"Error in monitor handler: {e}")
         finally:
             self.monitors.discard(websocket)
+            logger.debug(f"Monitor handler completed in {(time.time() - monitor_start)*1000:.1f}ms")

     async def _broadcast_stats(self):
-        """Broadcast statistics to all monitors."""
+        """Broadcast statistics to all monitors - truly non-blocking version."""
         if not self.monitors:
             return
-
-
-
-
+        if self.is_generating_stats:
+            return  # Already generating stats, skip this call
+        self.is_generating_stats = True
+        total_start = time.time()
+
+        # Prepare all the data first
+        data_prep_start = time.time()
+        loop = asyncio.get_event_loop()
+
+        # Get storage stats (already async)
+        storage_stats_start = time.time()
+        storage_stats = await self.storage.get_storage_stats()
+        logger.debug(f"Storage stats retrieved in {(time.time() - storage_stats_start)*1000:.1f}ms")
+
+        caption_stats_start = time.time()
+        caption_stats = await self.storage.get_caption_stats()
+        logger.debug(f"Caption stats retrieved in {(time.time() - caption_stats_start)*1000:.1f}ms")
+
+        # Get chunk stats in thread pool
+        chunk_stats_start = time.time()
+        chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+        logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
+
+        # Build stats dict
+        build_stats_start = time.time()
+        stats_update = self.stats.copy()
+        stats_update.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
+        stats_update.update(storage_stats)
+        stats_update["field_breakdown"] = caption_stats.get("field_stats", {})
+        stats_update["output_fields_list"] = caption_stats.get("output_fields", [])

         # Add rate information
-
+        stats_update.update(
             {
                 "current_rate": self.rate_tracker["current_rate"],
                 "average_rate": self.rate_tracker["average_rate"],
@@ -1112,22 +1610,227 @@ class Orchestrator:
         )

         # Add vLLM info
-
-
-
-
-
-
-
-
+        stats_update["vllm_model"] = self.vllm_config.get("model", "unknown")
+        stats_update["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
+
+        # Add stage information
+        stages = self.vllm_config.get("stages", [])
+        if stages:
+            stats_update["stage_count"] = len(stages)
+            stats_update["stage_names"] = [s.get("name", "unnamed") for s in stages]
+        else:
+            stats_update["stage_count"] = 1
+            stats_update["stage_names"] = ["default"]
+
+        # Get field stats
+        field_stats_start = time.time()
+        field_stats = await self.storage.get_output_field_stats()
+        stats_update["output_fields"] = field_stats
+        logger.debug(f"Field stats retrieved in {(time.time() - field_stats_start)*1000:.1f}ms")
+
+        # Update our internal stats
+        self.stats = stats_update
+        logger.debug(f"Stats prepared in {(time.time() - build_stats_start)*1000:.1f}ms")
+
+        logger.debug(f"Total data preparation took {(time.time() - data_prep_start)*1000:.1f}ms")
+
+        # Create message once
+        message_create_start = time.time()
+        stats_message = safe_json_dumps({"type": "stats", "data": self.stats})
+        logger.debug(f"Stats message created in {(time.time() - message_create_start)*1000:.1f}ms")
+
+        # Send to all monitors asynchronously in parallel
+        send_start = time.time()
+
+        async def send_to_monitor(monitor):
             try:
-                await monitor.send(
+                await monitor.send(stats_message)
             except websockets.exceptions.ConnectionClosed:
-
+                return monitor  # Return for removal
+            except Exception as e:
+                logger.debug(f"Error sending stats to monitor: {e}")
+                return monitor  # Return for removal
+            return None
+
+        # Send to all monitors in parallel
+        monitors_copy = self.monitors.copy()
+        results = await asyncio.gather(
+            *[send_to_monitor(m) for m in monitors_copy], return_exceptions=True
+        )

-        #
+        # Remove disconnected monitors
+        disconnected = {
+            m
+            for m, r in zip(monitors_copy, results)
+            if r is not None and not isinstance(r, Exception)
+        }
         self.monitors -= disconnected

+        logger.debug(
+            f"Stats sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
+        )
+
+        # Send leaderboard update in a separate task to avoid blocking
+        leaderboard_task_start = time.time()
+        asyncio.create_task(self._broadcast_leaderboard())
+        self.is_generating_stats = False
+        logger.debug(
+            f"Leaderboard broadcast task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
+        )
+        logger.debug(f"Stats broadcast completed in {(time.time() - total_start)*1000:.1f}ms")
+
+    async def _broadcast_leaderboard(self):
+        """Send leaderboard updates to monitors - separate from stats to avoid blocking."""
+        if not self.monitors:
+            return
+
+        total_start = time.time()
+        try:
+            # Get contributors
+            contributors_start = time.time()
+            contributors = await self.storage.get_top_contributors(10)
+            logger.debug(
+                f"Contributors retrieved for broadcast in {(time.time() - contributors_start)*1000:.1f}ms"
+            )
+
+            # Get worker counts
+            worker_counts_start = time.time()
+            loop = asyncio.get_event_loop()
+            worker_counts = await loop.run_in_executor(
+                None,
+                lambda: (
+                    self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+                ),
+            )
+            logger.debug(
+                f"Worker counts retrieved for broadcast in {(time.time() - worker_counts_start)*1000:.1f}ms"
+            )
+
+            # Build enhanced contributors list
+            build_start = time.time()
+            enhanced_contributors = []
+            for contributor in contributors:
+                contrib_dict = {
+                    "contributor_id": contributor.contributor_id,
+                    "name": contributor.name,
+                    "total_captions": contributor.total_captions,
+                    "trust_level": contributor.trust_level,
+                    "active_workers": len(
+                        worker_counts.get(contributor.contributor_id, {}).get("worker_ids", [])
+                    ),
+                }
+                enhanced_contributors.append(contrib_dict)
+            logger.debug(
+                f"Enhanced contributors built for broadcast in {(time.time() - build_start)*1000:.1f}ms"
+            )
+
+            # Cache it
+            self._cached_leaderboard = enhanced_contributors
+
+            # Create message once
+            message_create_start = time.time()
+            leaderboard_message = safe_json_dumps(
+                {"type": "leaderboard", "data": enhanced_contributors}
+            )
+            logger.debug(
+                f"Leaderboard message created in {(time.time() - message_create_start)*1000:.1f}ms"
+            )
+
+            # Send to all monitors in parallel
+            send_start = time.time()
+
+            async def send_leaderboard(monitor):
+                try:
+                    await monitor.send(leaderboard_message)
+                except:
+                    return monitor  # Mark for removal
+                return None
+
+            monitors_copy = self.monitors.copy()
+            results = await asyncio.gather(
+                *[send_leaderboard(m) for m in monitors_copy], return_exceptions=True
+            )
+
+            # Remove disconnected
+            disconnected = {
+                m
+                for m, r in zip(monitors_copy, results)
+                if r is not None and not isinstance(r, Exception)
+            }
+            self.monitors -= disconnected
+
+            logger.debug(
+                f"Leaderboard sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
+            )
+            logger.debug(
+                f"Leaderboard broadcast completed in {(time.time() - total_start)*1000:.1f}ms"
+            )
+
+        except Exception as e:
+            logger.error(f"Error broadcasting leaderboard: {e}")
+
+    def _get_queue_stats(self) -> Dict[str, int]:
+        """Get queue statistics - synchronous helper for thread pool."""
+        with self.chunk_manager.lock:
+            return {
+                "pending_chunks": len(self.chunk_manager.pending_chunks),
+                "assigned_chunks": sum(
+                    len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
+                ),
+            }
+
+    async def _flush_processed_items(self):
+        """Flush batched processed items to chunk tracker."""
+        with self.item_batch_lock:
+            if not self.pending_processed_items:
+                return
+
+            for chunk_id, indices in self.pending_processed_items.items():
+                if not indices:
+                    continue
|
+
|
1792
|
+
# Indices here are ABSOLUTE dataset indices
|
1793
|
+
# Sort indices
|
1794
|
+
indices.sort()
|
1795
|
+
|
1796
|
+
# Group consecutive indices into ranges
|
1797
|
+
ranges = []
|
1798
|
+
start = indices[0]
|
1799
|
+
end = indices[0]
|
1800
|
+
|
1801
|
+
for i in range(1, len(indices)):
|
1802
|
+
if indices[i] == end + 1:
|
1803
|
+
# Consecutive, extend range
|
1804
|
+
end = indices[i]
|
1805
|
+
else:
|
1806
|
+
# Gap found, save current range and start new one
|
1807
|
+
ranges.append((start, end))
|
1808
|
+
start = indices[i]
|
1809
|
+
end = indices[i]
|
1810
|
+
|
1811
|
+
# Don't forget the last range
|
1812
|
+
ranges.append((start, end))
|
1813
|
+
|
1814
|
+
# Mark ranges as processed (mark_items_processed expects absolute indices)
|
1815
|
+
for start_idx, end_idx in ranges:
|
1816
|
+
self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
|
1817
|
+
|
1818
|
+
# Clear pending items
|
1819
|
+
self.pending_processed_items.clear()
|
1820
|
+
self.last_item_batch_flush = time.time()
|
1821
|
+
|
1822
|
+
def get_workers_by_user_stats(self) -> Dict[str, Any]:
|
1823
|
+
"""Get statistics about workers grouped by user/token - thread-safe version."""
|
1824
|
+
if not hasattr(self, "workers_by_user"):
|
1825
|
+
return {}
|
1826
|
+
|
1827
|
+
# Create a copy to avoid issues with concurrent modification
|
1828
|
+
stats = {}
|
1829
|
+
workers_snapshot = dict(self.workers_by_user)
|
1830
|
+
for user, worker_ids in workers_snapshot.items():
|
1831
|
+
stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
|
1832
|
+
return stats
|
1833
|
+
|
1131
1834
|
async def _send_activity(self, activity: str):
|
1132
1835
|
"""Send activity update to monitors."""
|
1133
1836
|
if not self.monitors:
|
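Note: the new `_flush_processed_items` above batches absolute dataset indices per chunk and collapses them into contiguous ranges before handing them to `chunk_tracker.mark_items_processed`. A minimal standalone sketch of that grouping step (the helper name and example values are illustrative, not part of the package):

    from typing import List, Tuple

    def group_into_ranges(indices: List[int]) -> List[Tuple[int, int]]:
        """Collapse a list of integers into inclusive (start, end) runs."""
        if not indices:
            return []
        indices = sorted(indices)
        ranges = []
        start = end = indices[0]
        for idx in indices[1:]:
            if idx == end + 1:
                end = idx  # still contiguous, extend the current run
            else:
                ranges.append((start, end))  # gap found, close the current run
                start = end = idx
        ranges.append((start, end))  # close the final run
        return ranges

    # e.g. items 10-12 and 15 were processed -> two ranges
    assert group_into_ranges([12, 10, 11, 15]) == [(10, 12), (15, 15)]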
@@ -1149,21 +1852,63 @@ class Orchestrator:
     async def _heartbeat_loop(self):
         """Send periodic heartbeats to maintain connections."""
         while True:
-
+            try:
+                await asyncio.sleep(30)

-
-
-
-            try:
-                await ws.ping()
-            except:
-                disconnected.append(worker_id)
+                # Create a copy of worker items to avoid modification during iteration
+                worker_items = list(self.workers.items())
+                disconnected = []

-
-
-
-
-
+                for worker_id, ws in worker_items:
+                    try:
+                        # Check if worker still exists before pinging
+                        if worker_id not in self.workers:
+                            continue
+
+                        # Send ping with timeout
+                        pong_waiter = await ws.ping()
+                        try:
+                            await asyncio.wait_for(pong_waiter, timeout=10)
+                        except asyncio.TimeoutError:
+                            logger.warning(f"Worker {worker_id} failed to respond to ping")
+                            disconnected.append(worker_id)
+                    except websockets.exceptions.ConnectionClosed:
+                        logger.info(f"Worker {worker_id} connection already closed")
+                        disconnected.append(worker_id)
+                    except Exception as e:
+                        logger.error(f"Error pinging worker {worker_id}: {e}")
+                        disconnected.append(worker_id)
+
+                # Clean up disconnected workers
+                for worker_id in disconnected:
+                    if worker_id in self.workers:
+                        logger.info(f"Removing unresponsive worker {worker_id}")
+                        del self.workers[worker_id]
+                        self.chunk_manager.release_worker_chunks(worker_id)
+
+                        # Update stats
+                        self.stats["connected_workers"] = len(self.workers)
+
+                        # Also clean up from workers_by_user if it exists
+                        if hasattr(self, "workers_by_user"):
+                            worker_user = (
+                                worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+                            )
+                            if worker_user in self.workers_by_user:
+                                self.workers_by_user[worker_user].discard(worker_id)
+                                if not self.workers_by_user[worker_user]:
+                                    del self.workers_by_user[worker_user]
+
+                        # Notify monitors
+                        await self._broadcast_stats()
+                        await self._send_activity(
+                            f"Worker {worker_id} removed due to heartbeat timeout"
+                        )
+
+            except Exception as e:
+                logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
+                # Continue the loop even if there's an error
+                await asyncio.sleep(5)

     async def _checkpoint_loop(self):
         """Periodically checkpoint storage."""
@@ -1172,42 +1917,58 @@ class Orchestrator:
         while True:
             await asyncio.sleep(60)

+            # Get current caption count from storage
+            storage_stats = await self.storage.get_storage_stats()
+            total_captions = storage_stats["total_captions"]
+
             # Force checkpoint at regular intervals
-            if
-                logger.info(f"Triggering checkpoint at {
+            if total_captions > 0 and total_captions % interval == 0:
+                logger.info(f"Triggering checkpoint at {total_captions} captions")
                 await self.storage.checkpoint()

                 # Update stats
                 self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
-
-                self.stats["buffer_size"] = len(self.storage.caption_buffer)
+                # No need to update total_written or buffer_size - they come from storage

                 await self._broadcast_stats()
                 logger.info(
-                    f"Checkpoint complete. Total written to disk: {
+                    f"Checkpoint complete. Total written to disk: {storage_stats['total_written']}"
                 )

     async def _stats_update_loop(self):
-        """Periodically update and broadcast stats."""
+        """Periodically update and broadcast stats - non-blocking version."""
+        # Get the event loop for running blocking operations
+        loop = asyncio.get_event_loop()
+
        # Track session start values
-
+        storage_stats = await self.storage.get_storage_stats()
+        session_start_outputs = storage_stats["total_captions"]  # This now counts ALL outputs
        session_start_time = time.time()

+        # Track the last known total to detect flushes
+        last_known_total = session_start_outputs
+
        while True:
            await asyncio.sleep(10)

-            # Update chunk stats
-            chunk_stats = self.chunk_manager.get_stats
+            # Update chunk stats in thread pool to avoid blocking
+            chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+            storage_stats = await self.storage.get_storage_stats()
+            current_total_outputs = storage_stats["total_captions"]  # ALL outputs
+            if self.chunk_tracker:
+                await self._flush_processed_items()
+
            self.stats["total_chunks"] = chunk_stats["total"]
            self.stats["completed_chunks"] = chunk_stats["completed"]
            self.stats["failed_chunks"] = chunk_stats["failed"]

-            #
-
-
-
-
-
+            # Update total outputs stat (rename from total_captions for clarity)
+            self.stats["total_outputs"] = current_total_outputs
+            self.stats["total_captions"] = current_total_outputs  # Keep for backward compatibility
+
+            # Get queue stats in thread pool to avoid blocking
+            queue_stats = await loop.run_in_executor(None, self._get_queue_stats)
+            self.stats.update(queue_stats)

            # Calculate if we need more chunks
            worker_count = self.stats.get("connected_workers", 0)
@@ -1220,33 +1981,57 @@ class Orchestrator:
            elapsed_since_update = current_time - self.rate_tracker["last_update_time"]

            if elapsed_since_update > 0:
-                #
-
-
-
-
+                # FIX: Handle the case where duplicates were skipped during save
+                # If current total is less than last known, it means duplicates were skipped
+                # We should not count this as negative progress
+                if current_total_outputs < last_known_total:
+                    logger.debug(
+                        f"Detected duplicate skip during save: {last_known_total} -> {current_total_outputs}"
+                    )
+                    # Don't calculate negative rate, just update the baseline
+                    self.rate_tracker["last_caption_count"] = current_total_outputs
+                    self.rate_tracker["current_rate"] = 0.0  # Set to 0 during flush
+                else:
+                    # Normal rate calculation
+                    output_diff = current_total_outputs - self.rate_tracker["last_caption_count"]
+                    self.rate_tracker["current_rate"] = (output_diff / elapsed_since_update) * 60
+                    self.rate_tracker["last_caption_count"] = current_total_outputs

            # Calculate average rate since THIS SESSION started
            session_elapsed = current_time - session_start_time
            if session_elapsed > 0:
-
-
+                # Always use the difference from session start for average
+                session_outputs = current_total_outputs - session_start_outputs
+                self.rate_tracker["average_rate"] = (session_outputs / session_elapsed) * 60

-                # Calculate expected rate based on workers
-                # Assume each worker processes batch_size images every ~2 seconds with 3 captions each
+                # Calculate expected rate based on workers and stages
                batch_size = self.vllm_config.get("batch_size", 8)
-
+
+                # Count total prompts across all stages
+                total_prompts = 0
+                stages = self.vllm_config.get("stages", [])
+                if stages:
+                    for stage in stages:
+                        total_prompts += len(stage.get("prompts", []))
+                else:
+                    # Backward compatibility
+                    total_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
+
                images_per_minute = 30  # Rough estimate: 30 images/min per worker
-                self.rate_tracker["expected_rate"] =
+                self.rate_tracker["expected_rate"] = (
+                    worker_count * images_per_minute * total_prompts
+                )

            # Update trackers
            self.rate_tracker["last_update_time"] = current_time
-
+            last_known_total = current_total_outputs

            # Log rate information when workers are connected
-            if
+            if (
+                worker_count > 0 and self.rate_tracker["current_rate"] >= 0
+            ):  # Only log non-negative rates
                logger.info(
-                    f"Rate: {self.rate_tracker['current_rate']:.1f}
+                    f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
                    f"(avg: {self.rate_tracker['average_rate']:.1f}, "
                    f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
                    f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
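Note: the reworked `_stats_update_loop` above polls the stored total every ten seconds, converts the delta into an outputs-per-minute rate, and zeroes the instantaneous rate (instead of reporting a negative one) when a deduplicating save makes the total drop. A simplified sketch of that update step, assuming a plain dict as the tracker (names and values are illustrative, not part of the package):

    import time
    from typing import Optional

    def update_rate(tracker: dict, current_total: int, now: Optional[float] = None) -> None:
        """Update outputs/min from the change in the stored output count."""
        now = time.time() if now is None else now
        elapsed = now - tracker["last_update_time"]
        if elapsed <= 0:
            return
        if current_total < tracker["last_caption_count"]:
            # Duplicates were skipped during a save; don't report a negative rate.
            tracker["current_rate"] = 0.0
        else:
            diff = current_total - tracker["last_caption_count"]
            tracker["current_rate"] = (diff / elapsed) * 60  # per minute
        tracker["last_caption_count"] = current_total
        tracker["last_update_time"] = now

    tracker = {"last_update_time": time.time() - 10.0, "last_caption_count": 100, "current_rate": 0.0}
    update_rate(tracker, 150)  # ~50 new outputs over ~10 s -> roughly 300 outputs/min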
@@ -1256,16 +2041,16 @@ class Orchestrator:

     async def _restore_state(self):
         """Restore state from storage on startup."""
-
-
-
-        logger.info(f"Restored state: {self.stats['total_captions']} captions")
+        total_captions = await self.storage.count_captions()
+        logger.info(f"Restored state: {total_captions} captions")

     async def shutdown(self):
         """Graceful shutdown."""
         logger.info("Shutting down orchestrator...")

         # Stop chunk creation
+        if self.chunk_tracker:
+            await self._flush_processed_items()
         self.stop_chunk_creation.set()
         if self.chunk_creation_thread:
             self.chunk_creation_thread.join(timeout=5)
@@ -1287,7 +2072,7 @@ class Orchestrator:

         # Save chunk state
         if self.chunk_tracker:
-            self.chunk_tracker.
+            self.chunk_tracker.save()

         # Final checkpoint
         logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")
|