caption-flow 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +56 -39
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +12 -2
- caption_flow/orchestrator.py +729 -217
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +392 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.0.dist-info/METADATA +369 -0
- caption_flow-0.2.0.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/top_level.txt +0 -0
caption_flow/orchestrator.py
CHANGED
@@ -21,15 +21,15 @@ from collections import deque, defaultdict
 import threading
 from queue import Queue, Empty
 
+from .workers import data
 import websockets
 from websockets.server import WebSocketServerProtocol
 
 from .storage import StorageManager
 from .models import Caption, Contributor
 from .utils.auth import AuthManager
-from .utils
+from .utils import DatasetLoader, ShardTracker, ChunkTracker
 from .utils.json_utils import safe_dict, safe_json_dumps, to_json_dict
-from .utils.chunk_tracker import ChunkTracker
 
 logger = logging.getLogger(__name__)
 
@@ -48,6 +48,43 @@ class ShardChunk:
     assigned_at: Optional[datetime] = None
     completed_at: Optional[datetime] = None
 
+    @classmethod
+    def create(
+        cls, shard_url: str, shard_name: str, start_index: int, chunk_size: int
+    ) -> "ShardChunk":
+        """Factory method to create a chunk with consistent ID."""
+        # Always use consistent format: dataset_chunk_startindex
+        if shard_url.startswith("hf_dataset:"):
+            # Extract dataset path
+            parts = shard_url.split(":")
+            dataset_path = parts[1] if len(parts) > 1 else "unknown"
+            chunk_id = f"{dataset_path.replace('/', '_')}_chunk_{start_index}"
+        else:
+            # WebDataset format
+            chunk_id = f"{shard_name}_chunk_{start_index}"
+
+        return cls(
+            chunk_id=chunk_id,
+            shard_url=shard_url,
+            shard_name=shard_name,
+            start_index=start_index,
+            chunk_size=chunk_size,
+        )
+
+    def belongs_to_shard(self, shard_identifier: str) -> bool:
+        """Check if this chunk belongs to a given shard."""
+        return self.shard_name == shard_identifier
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dict for JSON serialization (for workers)."""
+        return {
+            "chunk_id": self.chunk_id,
+            "shard_url": self.shard_url,
+            "shard_name": self.shard_name,
+            "start_index": self.start_index,
+            "chunk_size": self.chunk_size,
+        }
+
 
 class ChunkManager:
     """Manages shard chunk creation and assignment."""
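For reference, a minimal self-contained sketch of the chunk-ID scheme used by the new ShardChunk.create factory above; the dataset path, shard name, and indices are hypothetical:

# Hypothetical values, mirroring the ID logic in the hunk above
hf_url = "hf_dataset:org/dataset:chunk:0"
dataset_path = hf_url.split(":")[1]                    # "org/dataset"
print(f"{dataset_path.replace('/', '_')}_chunk_{0}")   # org_dataset_chunk_0

shard_name = "data-0001"                               # WebDataset shard stem
print(f"{shard_name}_chunk_{2000}")                    # data-0001_chunk_2000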
@@ -67,9 +104,7 @@ class ChunkManager:
         chunks = []
 
         for start_idx in range(0, total_items, self.chunk_size):
-
-            chunk = ShardChunk(
-                chunk_id=chunk_id,
+            chunk = ShardChunk.create(
                 shard_url=shard_url,
                 shard_name=shard_name,
                 start_index=start_idx,
@@ -77,8 +112,8 @@ class ChunkManager:
             )
 
             with self.lock:
-                self.chunks[chunk_id] = chunk
-                self.pending_chunks.append(chunk_id)
+                self.chunks[chunk.chunk_id] = chunk
+                self.pending_chunks.append(chunk.chunk_id)
 
             chunks.append(chunk)
 
@@ -86,24 +121,84 @@ class ChunkManager:
 
     def get_chunks_for_worker(
         self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
-    ) -> List[
-        """Get available chunks for a worker."""
+    ) -> List[Dict[str, Any]]:
+        """Get available chunks with unprocessed items for a worker."""
         assigned = []
 
         with self.lock:
+            # FIRST PRIORITY: Check if this worker already has assigned chunks
+            # Workers should complete their current chunks before getting new ones
+            if worker_id in self.assigned_chunks:
+                existing_chunk_ids = list(self.assigned_chunks[worker_id])
+                for chunk_id in existing_chunk_ids:
+                    if len(assigned) >= count:
+                        break
+
+                    chunk = self.chunks.get(chunk_id)
+                    if not chunk:
+                        continue
+
+                    # Check if chunk still has unprocessed items
+                    if tracker:
+                        chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
+                        if chunk_info and chunk_info["unprocessed_ranges"]:
+                            assigned.append(
+                                {
+                                    "chunk": chunk,
+                                    "unprocessed_ranges": chunk_info["unprocessed_ranges"],
+                                }
+                            )
+                    else:
+                        # No tracker, assume chunk needs processing
+                        assigned.append(
+                            {
+                                "chunk": chunk,
+                                "unprocessed_ranges": [(0, chunk.chunk_size - 1)],
+                            }
+                        )
+
+            # SECOND PRIORITY: Get new pending chunks
+            # Only if worker doesn't have enough chunks already
             while len(assigned) < count and self.pending_chunks:
                 chunk_id = self.pending_chunks.popleft()
-                chunk = self.chunks
+                chunk = self.chunks.get(chunk_id)
+
+                if not chunk:
+                    continue
 
+                # Verify chunk is truly pending (defensive check)
+                if chunk.status != "pending" or chunk.assigned_to is not None:
+                    logger.warning(
+                        f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
+                    )
+                    continue
+
+                # Assign to this worker
                 chunk.assigned_to = worker_id
                 chunk.status = "assigned"
                 chunk.assigned_at = datetime.utcnow()
-
                 self.assigned_chunks[worker_id].add(chunk_id)
-
+
+                # Get unprocessed ranges
+                unprocessed_ranges = [(0, chunk.chunk_size - 1)]  # Default
                 if tracker:
+                    chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
+                    if chunk_info:
+                        unprocessed_ranges = chunk_info["unprocessed_ranges"]
                     tracker.mark_assigned(chunk_id, worker_id)
 
+                assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
+
+        # Log what we're assigning
+        if assigned:
+            chunk_summary = ", ".join(
+                [
+                    f"{info['chunk'].chunk_id}[{len(info['unprocessed_ranges'])} ranges]"
+                    for info in assigned
+                ]
+            )
+            logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
+
         return assigned
 
     def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
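For reference, the reworked get_chunks_for_worker returns a list of dicts pairing each chunk with its unprocessed ranges rather than bare chunk objects; a hedged sketch of the shape, with placeholder values:

assigned = [
    {"chunk": "<ShardChunk data-0001_chunk_0>", "unprocessed_ranges": [(0, 999)]},
    {"chunk": "<ShardChunk data-0001_chunk_1000>", "unprocessed_ranges": [(120, 340), (600, 999)]},
]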
@@ -173,6 +268,27 @@ class Orchestrator:
         self.dataset_config = config.get("dataset", {})
         self.dataset_path = self.dataset_config.get("path")
         self.dataset_type = self.dataset_config.get("type", "huggingface")
+        self.dataset_split = self.dataset_config.get("split", "train")  # Add split configuration
+        self.dataset_image_column = self.dataset_config.get(
+            "image_column", "image"
+        )  # Add image column config
+
+        # Dataset components
+        self.dataset_loader = None
+        self.shard_tracker = None
+        self.chunk_tracker = None
+
+        if self.dataset_path:
+            self.dataset_loader = DatasetLoader(
+                self.dataset_path,
+                self.dataset_type,
+                self.dataset_split,
+                self.dataset_image_column,
+            )
+            checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
+            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
 
         # vLLM configuration to distribute to workers
         self.vllm_config = config.get(
@@ -233,6 +349,11 @@ class Orchestrator:
 
         # Initialize chunk manager with reference to chunk tracker
         self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)
+        self.pending_processed_items = defaultdict(list)  # chunk_id -> list of indices
+        self.item_batch_lock = threading.Lock()
+        self.last_item_batch_flush = time.time()
+        self.item_batch_interval = 5  # Flush every 5 seconds
+        self.item_batch_size = 100  # Or every 100 items
 
         # Track connections
         self.workers: Dict[str, WebSocketServerProtocol] = {}
@@ -246,13 +367,10 @@ class Orchestrator:
             "total_chunks": 0,
             "completed_chunks": 0,
             "failed_chunks": 0,
-            "total_captions": 0,
             "connected_workers": 0,
             "total_shards": 0,
             "completed_shards": 0,
             "current_shard": None,
-            "buffer_size": 0,
-            "total_written": 0,
             "last_checkpoint": None,
         }
 
@@ -266,7 +384,7 @@ class Orchestrator:
             "expected_rate": 0.0,
         }
 
-        # Data sample queue for
+        # Data sample queue for CaptionWorker
         self.data_sample_queue = asyncio.Queue(maxsize=1000)
         self.data_workers: Dict[str, WebSocketServerProtocol] = {}
 
@@ -310,10 +428,23 @@ class Orchestrator:
         # Mark state as not restored until we process checkpoints
         self.state_restored.clear()
 
+        # Get dataset info to check format
+        dataset_info = self.dataset_loader.get_dataset_info()
+        dataset_format = dataset_info.get("dataset_format", "unknown")
+        logger.info(f"Dataset format: {dataset_format}")
+
         # Get all shards
         self.all_shards = self.dataset_loader.get_shard_list()
         self.stats["total_shards"] = len(self.all_shards)
 
+        # For HuggingFace datasets, we might need to dynamically create more shards
+        if dataset_format == "huggingface_datasets":
+            self._is_hf_dataset = True
+            self._hf_chunk_size = 10000  # Items per virtual shard
+            self._next_hf_shard_index = len(self.all_shards)  # For creating new virtual shards
+        else:
+            self._is_hf_dataset = False
+
         # Get shard status from ChunkTracker
         shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
         completed_shards = {
@@ -336,7 +467,10 @@ class Orchestrator:
 
         # Filter out shards that already have chunks created
         remaining_shards = [
-            shard
+            shard
+            for shard in remaining_shards
+            if (shard if shard.startswith("hf_dataset:") else Path(shard).stem)
+            not in shards_with_chunks
         ]
 
         self.stats["completed_shards"] = len(completed_shards)
@@ -356,25 +490,18 @@ class Orchestrator:
             with self.chunk_manager.lock:
                 for chunk_state in shard_info["chunks"]:
                     if chunk_state.status in ["pending", "failed", "assigned"]:
-                        #
-
-
-
-
-
-
-
-
-
-
-
-                            start_index=chunk_state.start_index,
-                            chunk_size=chunk_state.chunk_size,
-                        )
-                        self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
-                        self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
-                        requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
-                        initial_pending += 1
+                        # ChunkState already has shard_url stored
+                        chunk = ShardChunk(
+                            chunk_id=chunk_state.chunk_id,
+                            shard_url=chunk_state.shard_url,
+                            shard_name=chunk_state.shard_name,
+                            start_index=chunk_state.start_index,
+                            chunk_size=chunk_state.chunk_size,
+                        )
+                        self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
+                        self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
+                        requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
+                        initial_pending += 1
 
         logger.info(f"Re-queued {initial_pending} existing pending chunks")
         for shard_name, chunk_ids in requeued_chunks_by_shard.items():
@@ -426,7 +553,13 @@ class Orchestrator:
             if current_shard_url is None or current_shard_index >= current_shard_items:
                 try:
                     current_shard_url = next(shard_iter)
-
+
+                    # Extract shard name based on type
+                    if current_shard_url.startswith("hf_dataset:"):
+                        current_shard_name = current_shard_url  # Use full ID for virtual shards
+                    else:
+                        current_shard_name = Path(current_shard_url).stem
+
                     self.stats["current_shard"] = current_shard_name
 
                     # Skip if we already have chunks from this shard
@@ -439,16 +572,74 @@ class Orchestrator:
 
                     # Count items in new shard
                     logger.info(f"Loading new shard {current_shard_name}")
-
-
-                    )
+
+                    # For virtual HF dataset shards, use the chunk size directly
+                    if current_shard_url.startswith("hf_dataset:"):
+                        current_shard_items = self.dataset_loader.count_shard_items(
+                            current_shard_url
+                        )
+                        logger.info(
+                            f"Virtual shard {current_shard_name} has {current_shard_items} items"
+                        )
+                    else:
+                        # For WebDataset, actually count items
+                        current_shard_items = sum(
+                            1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
+                        )
+                        logger.info(
+                            f"Shard {current_shard_name} has {current_shard_items} items"
+                        )
+
                     current_shard_index = 0
-                    logger.info(f"Shard {current_shard_name} has {current_shard_items} items")
 
                 except StopIteration:
-                    # No more shards
+                    # No more shards in the iterator
+                    if self._is_hf_dataset:
+                        # Before creating new virtual shards, check if we have pending chunks
+                        with self.chunk_manager.lock:
+                            pending_count = len(self.chunk_manager.pending_chunks)
+
+                        if pending_count > 0:
+                            # Don't create new shards if we have pending chunks
+                            logger.debug(
+                                f"Have {pending_count} pending chunks, not creating new virtual shards yet"
+                            )
+                            current_shard_url = None
+                            time.sleep(2)
+                            continue
+
+                        # For HF datasets, we can create more virtual shards on demand
+                        logger.info(
+                            "Creating additional virtual shards for HuggingFace dataset"
+                        )
+
+                        # Create 10 more virtual shards
+                        new_shards = []
+                        for i in range(10):
+                            shard_id = f"hf_dataset:{self.dataset_path}:chunk:{self._next_hf_shard_index * self._hf_chunk_size}"
+                            new_shards.append(shard_id)
+                            self._next_hf_shard_index += 1
+
+                        # Add to all_shards and create new iterator
+                        self.all_shards.extend(new_shards)
+                        self.stats["total_shards"] = len(self.all_shards)
+
+                        # Filter for unprocessed shards
+                        remaining_new_shards = [
+                            s
+                            for s in new_shards
+                            if s not in shards_summary and s not in completed_shards
+                        ]
+
+                        if remaining_new_shards:
+                            shard_iter = iter(remaining_new_shards)
+                            logger.info(f"Added {len(remaining_new_shards)} new virtual shards")
+                            continue
+
+                    # No more shards to process
                     logger.info("No more shards to process")
                     break
+
                 except Exception as e:
                     logger.error(f"Error loading shard {current_shard_name}: {e}")
                     current_shard_url = None
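For reference, a minimal sketch of the virtual-shard IDs generated in the StopIteration branch above for HuggingFace datasets; the dataset path and counter values are hypothetical:

dataset_path = "org/dataset"   # hypothetical
hf_chunk_size = 10000          # items per virtual shard, as set in the diff
next_index = 3
new_shards = [f"hf_dataset:{dataset_path}:chunk:{(next_index + i) * hf_chunk_size}" for i in range(10)]
print(new_shards[0])           # hf_dataset:org/dataset:chunk:30000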
@@ -456,25 +647,40 @@ class Orchestrator:
 
             # Create a chunk from current shard
             if current_shard_url and current_shard_index < current_shard_items:
-
-
+                # Calculate the absolute dataset index for this chunk
+                if current_shard_url.startswith("hf_dataset:"):
+                    # Parse the virtual shard URL to get the base start index
+                    parts = current_shard_url.split(":")
+                    if len(parts) >= 4 and parts[2] == "chunk":
+                        shard_base_index = int(parts[3])
+                    else:
+                        shard_base_index = 0
+
+                    # The absolute start index for this chunk in the dataset
+                    absolute_start_index = shard_base_index + current_shard_index
+                else:
+                    # For WebDataset, current_shard_index is already absolute
+                    absolute_start_index = current_shard_index
+
+                # Create chunk with absolute index
+                chunk = ShardChunk.create(
+                    shard_url=current_shard_url,
+                    shard_name=current_shard_name,
+                    start_index=absolute_start_index,
+                    chunk_size=min(self.chunk_size, current_shard_items - current_shard_index),
+                )
 
-                # Add to ChunkTracker
+                # Add to ChunkTracker with all required fields
                 if self.chunk_tracker and self.chunk_tracker.add_chunk(
-                    chunk_id,
+                    chunk.chunk_id,
+                    chunk.shard_name,
+                    chunk.shard_url,
+                    chunk.start_index,
+                    chunk.chunk_size,
                 ):
-                    # Create chunk
-                    chunk = ShardChunk(
-                        chunk_id=chunk_id,
-                        shard_url=current_shard_url,
-                        shard_name=current_shard_name,
-                        start_index=current_shard_index,
-                        chunk_size=chunk_size,
-                    )
-
                     with self.chunk_manager.lock:
-                        self.chunk_manager.chunks[chunk_id] = chunk
-                        self.chunk_manager.pending_chunks.append(chunk_id)
+                        self.chunk_manager.chunks[chunk.chunk_id] = chunk
+                        self.chunk_manager.pending_chunks.append(chunk.chunk_id)
 
                     chunks_created += 1
                     self.stats["total_chunks"] += 1
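For reference, a small worked example of how the absolute dataset index is derived from a virtual shard URL in the hunk above; the URL and offsets are hypothetical:

shard_url = "hf_dataset:org/dataset:chunk:30000"   # hypothetical virtual shard
parts = shard_url.split(":")
shard_base_index = int(parts[3]) if len(parts) >= 4 and parts[2] == "chunk" else 0
current_shard_index = 250                          # position within the virtual shard
absolute_start_index = shard_base_index + current_shard_index
print(absolute_start_index)                        # 30250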
@@ -484,10 +690,14 @@ class Orchestrator:
             if chunks_created > 0:
                 logger.info(f"Created {chunks_created} chunks on demand")
 
-            # If we couldn't create any chunks and there are no more shards,
+            # If we couldn't create any chunks and there are no more shards, check if it's HF dataset
             if chunks_created == 0 and current_shard_url is None:
-
-
+                if self._is_hf_dataset:
+                    # We can always create more virtual shards for HF datasets
+                    logger.debug("Will create more virtual shards on next iteration")
+                else:
+                    logger.info("All shards processed, chunk creation complete")
+                    break
 
             # Brief pause to avoid spinning
             time.sleep(1)
@@ -558,7 +768,9 @@ class Orchestrator:
             elif auth_ticket.role == "admin":
                 await self._handle_admin(websocket, auth_ticket)
             else:
-                await websocket.send(
+                await websocket.send(
+                    safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"})
+                )
 
         except Exception as e:
             logger.error(f"Connection error: {e}")
@@ -604,81 +816,118 @@ class Orchestrator:
         requires_worker_restart = False
 
         try:
+            # Extract orchestrator section if present
+            if "orchestrator" in new_config:
+                # Config has orchestrator wrapper, extract it
+                orchestrator_config = new_config["orchestrator"]
+            else:
+                # Config is already at orchestrator level
+                orchestrator_config = new_config
+
+            # Helper function for deep comparison
+            def deep_equal(a, b):
+                """Deep comparison of two values including nested dicts and lists."""
+                if type(a) != type(b):
+                    return False
+                if isinstance(a, dict):
+                    if set(a.keys()) != set(b.keys()):
+                        return False
+                    return all(deep_equal(a[k], b[k]) for k in a.keys())
+                elif isinstance(a, (list, tuple)):
+                    if len(a) != len(b):
+                        return False
+                    return all(deep_equal(x, y) for x, y in zip(a, b))
+                else:
+                    return a == b
+
             # Update vLLM configuration
-            if "vllm" in
+            if "vllm" in orchestrator_config:
                 old_vllm = self.vllm_config.copy()
+                new_vllm = orchestrator_config["vllm"]
 
-                # Check
-                vllm_changed =
-                for key, value in new_config["vllm"].items():
-                    if self.vllm_config.get(key) != value:
-                        self.vllm_config[key] = value
-                        vllm_changed = True
+                # Check if vLLM config actually changed using deep comparison
+                vllm_changed = not deep_equal(old_vllm, new_vllm)
 
                 if vllm_changed:
+                    # Update the vLLM config
+                    self.vllm_config = new_vllm.copy()
                     updated_sections.append("vllm")
 
                     # Check if critical changes require worker restart
                     if (
-                        old_vllm.get("model") !=
+                        old_vllm.get("model") != new_vllm.get("model")
                         or old_vllm.get("gpu_memory_utilization")
-                        !=
+                        != new_vllm.get("gpu_memory_utilization")
                        or old_vllm.get("tensor_parallel_size")
-                        !=
+                        != new_vllm.get("tensor_parallel_size")
+                        or old_vllm.get("dtype") != new_vllm.get("dtype")
+                        or old_vllm.get("max_model_len") != new_vllm.get("max_model_len")
                     ):
                         requires_worker_restart = True
                         warnings.append(
                             "Critical vLLM changes detected - workers will be disconnected to reload"
                         )
+                        logger.info(
+                            f"Model change: {old_vllm.get('model')} -> {new_vllm.get('model')}"
+                        )
 
             # Update dataset configuration
-            if "dataset" in
-
-
-
-
-                dataset_changed = True
+            if "dataset" in orchestrator_config:
+                old_dataset = self.dataset_config.copy()
+                new_dataset = orchestrator_config["dataset"]
+
+                dataset_changed = not deep_equal(old_dataset, new_dataset)
 
                 if dataset_changed:
+                    self.dataset_config = new_dataset.copy()
                     self.dataset_path = self.dataset_config.get("path")
                     self.dataset_type = self.dataset_config.get("type", "huggingface")
                     updated_sections.append("dataset")
                     warnings.append("Dataset changes will apply to new chunks only")
 
             # Update chunk settings
-            if
-
+            if (
+                "chunk_size" in orchestrator_config
+                and self.chunk_size != orchestrator_config["chunk_size"]
+            ):
+                self.chunk_size = orchestrator_config["chunk_size"]
                 self.chunk_manager.chunk_size = self.chunk_size
                 updated_sections.append("chunk_size")
 
             if (
-                "chunks_per_request" in
-                and self.chunks_per_request !=
+                "chunks_per_request" in orchestrator_config
+                and self.chunks_per_request != orchestrator_config["chunks_per_request"]
             ):
-                self.chunks_per_request =
+                self.chunks_per_request = orchestrator_config["chunks_per_request"]
                 updated_sections.append("chunks_per_request")
 
-            #
-
+            # Update auth configuration
+            if "auth" in orchestrator_config:
+                try:
+                    self.auth = AuthManager({"auth": orchestrator_config["auth"]})
+                    updated_sections.append("auth")
+                except Exception as e:
+                    logger.error(f"Failed to update AuthManager: {e}")
+                    warnings.append(f"Auth update failed: {e}")
 
             # Update buffer settings
             if (
-                "chunk_buffer_multiplier" in
-                and self.chunk_buffer_multiplier !=
+                "chunk_buffer_multiplier" in orchestrator_config
+                and self.chunk_buffer_multiplier != orchestrator_config["chunk_buffer_multiplier"]
            ):
-                self.chunk_buffer_multiplier =
+                self.chunk_buffer_multiplier = orchestrator_config["chunk_buffer_multiplier"]
                 updated_sections.append("chunk_buffer_multiplier")
 
             if (
-                "min_chunk_buffer" in
-                and self.min_chunk_buffer !=
+                "min_chunk_buffer" in orchestrator_config
+                and self.min_chunk_buffer != orchestrator_config["min_chunk_buffer"]
             ):
-                self.min_chunk_buffer =
+                self.min_chunk_buffer = orchestrator_config["min_chunk_buffer"]
                 updated_sections.append("min_chunk_buffer")
 
             # Update storage settings
-            if "storage" in
-                storage_config =
+            if "storage" in orchestrator_config:
+                storage_config = orchestrator_config["storage"]
                 storage_changed = False
 
                 if (
@@ -701,21 +950,6 @@ class Orchestrator:
                 if storage_changed:
                     updated_sections.append("storage")
 
-            # Update data worker storage config
-            if "data_worker_storage" in new_config:
-                current_dw_storage = self.config.get("data_worker_storage", {})
-                if current_dw_storage != new_config["data_worker_storage"]:
-                    self.config["data_worker_storage"] = new_config["data_worker_storage"]
-                    updated_sections.append("data_worker_storage")
-                    warnings.append("Data worker storage config will apply to new connections only")
-
-            # Update backpressure threshold
-            if "backpressure_threshold" in new_config:
-                current_threshold = getattr(self, "backpressure_threshold", 800)
-                if current_threshold != new_config["backpressure_threshold"]:
-                    self.backpressure_threshold = new_config["backpressure_threshold"]
-                    updated_sections.append("backpressure_threshold")
-
             # Check if any changes were made
             if not updated_sections:
                 await websocket.send(
@@ -729,29 +963,49 @@ class Orchestrator:
                 logger.info("Configuration reload requested but no changes detected")
                 return
 
-            # Update the main config
-
+            # Update the main config
+            if "orchestrator" in new_config:
+                self.config["orchestrator"] = orchestrator_config
+            else:
+                self.config.update(orchestrator_config)
 
             # Handle worker restart if needed
             if requires_worker_restart:
                 logger.info("Disconnecting all workers for configuration reload...")
 
-                #
-
-
+                # Send reload message to workers first
+                reload_msg = safe_json_dumps(
+                    {
+                        "type": "reload_vllm",
+                        "vllm_config": self.vllm_config,
+                    }
+                )
+
+                # Create a list of worker items to avoid modifying dict during iteration
+                worker_items = list(self.workers.items())
+                disconnected = []
+
+                for worker_id, ws in worker_items:
                     try:
-                        await
-
-                        )
+                        await ws.send(reload_msg)
+                        # Give worker time to process before disconnect
+                        await asyncio.sleep(0.5)
+                        await ws.close(code=1012, reason="Configuration reload")
+                        disconnected.append(worker_id)
                     except:
-
+                        disconnected.append(worker_id)  # Still mark as disconnected if error
+
+                # Now safely clear workers dict
+                for worker_id in disconnected:
+                    if worker_id in self.workers:
+                        del self.workers[worker_id]
 
                 warnings.append(
-                    f"
+                    f"Sent reload message to {len(disconnected)} workers - they will reconnect with new config"
                 )
             else:
-                # Just notify workers about config changes
-
+                # Just notify workers about config changes without disconnecting
+                config_update_msg = safe_json_dumps(
                     {
                         "type": "config_update",
                         "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
@@ -761,15 +1015,21 @@ class Orchestrator:
                     }
                 )
 
+                # Create a list of worker items to avoid modifying dict during iteration
+                worker_items = list(self.workers.items())
                 disconnected = []
-
+
+                for worker_id, ws in worker_items:
                     try:
-                        await ws.send(
+                        await ws.send(config_update_msg)
+                        logger.info(f"Sent config update to worker {worker_id}")
                     except:
                         disconnected.append(worker_id)
 
+                # Now safely remove disconnected workers
                 for worker_id in disconnected:
-
+                    if worker_id in self.workers:
+                        del self.workers[worker_id]
 
             # Send success response
             await websocket.send(
@@ -788,34 +1048,58 @@ class Orchestrator:
 
         except Exception as e:
            logger.error(f"Configuration reload failed: {e}")
+            import traceback
+
+            logger.error(traceback.format_exc())
             await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))
 
     async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
         """Handle worker connection lifecycle."""
-
+        # Generate unique worker ID even if using same token
+        base_name = getattr(auth_ticket, "name", "worker")
+        worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"  # Add unique suffix
+
+        # Track the original token/user for accounting
+        worker_user = base_name  # Keep track of which user/token this worker belongs to
+
         self.workers[worker_id] = websocket
         self.stats["connected_workers"] = len(self.workers)
 
-        #
-
-
-        )
-
+        # Optionally track workers by user/token
+        if not hasattr(self, "workers_by_user"):
+            self.workers_by_user = defaultdict(set)
+        self.workers_by_user[worker_user].add(worker_id)
+
+        # Register contributor with the base name (for aggregating stats per user)
+        contributor = await self.storage.get_contributor(worker_user)
+        if not contributor:
+            contributor = Contributor(
+                contributor_id=worker_user,
+                name=worker_user,
+                total_captions=0,
+                trust_level=1,
+            )
+            await self.storage.save_contributor(contributor)
 
-        logger.info(f"Worker {worker_id} connected")
+        logger.info(f"Worker {worker_id} (user: {worker_user}) connected")
         await self._broadcast_stats()
-        await self._send_activity(f"Worker {worker_id} connected")
+        await self._send_activity(f"Worker {worker_id} (user: {worker_user}) connected")
 
        try:
             # Send welcome message with dataset configuration
             welcome_message = {
                 "type": "welcome",
                 "worker_id": worker_id,
+                "user_id": worker_user,
                 "dataset_config": {
                     "dataset_path": self.dataset_path,
                     "dataset_type": self.dataset_type,
-                    "
-                    "
+                    "dataset_split": self.dataset_split,
+                    "dataset_image_column": self.dataset_image_column,
+                    "path": self.dataset_path,
+                    "type": self.dataset_type,
+                    "split": self.dataset_split,
+                    "image_column": self.dataset_image_column,
                 },
                 "vllm_config": self.vllm_config,
             }
@@ -826,21 +1110,29 @@ class Orchestrator:
                 await self._process_worker_message(worker_id, data)
 
         except websockets.exceptions.ConnectionClosed:
-            logger.info(f"Worker {worker_id} disconnected")
+            logger.info(f"Worker {worker_id} (user: {worker_user}) disconnected")
         finally:
-
+            if worker_id in self.workers:
+                del self.workers[worker_id]
+
+            # Clean up user tracking
+            if hasattr(self, "workers_by_user") and worker_user in self.workers_by_user:
+                self.workers_by_user[worker_user].discard(worker_id)
+                if not self.workers_by_user[worker_user]:
+                    del self.workers_by_user[worker_user]
+
             self.stats["connected_workers"] = len(self.workers)
-
+
+            # Release chunks
             self.chunk_manager.release_worker_chunks(worker_id)
             if self.chunk_tracker:
-                # Mark released chunks as pending in tracker
                 released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
                 logger.info(
                     f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
                 )
 
             await self._broadcast_stats()
-            await self._send_activity(f"Worker {worker_id} disconnected")
+            await self._send_activity(f"Worker {worker_id} (user: {worker_user}) disconnected")
@@ -856,28 +1148,26 @@ class Orchestrator:
                 return
 
             count = data.get("count", self.chunks_per_request)
-
+            chunk_infos = self.chunk_manager.get_chunks_for_worker(
+                worker_id, count, self.chunk_tracker
+            )
 
-            if
-                #
-
-                for
-
-
-
-                        "shard_url": chunk.shard_url,
-                        "shard_name": chunk.shard_name,
-                        "start_index": chunk.start_index,
-                        "chunk_size": chunk.chunk_size,
-                    }
-                )
+            if chunk_infos:
+                # Send chunks with unprocessed ranges
+                chunks_data = []
+                for info in chunk_infos:
+                    chunk_dict = info["chunk"].to_dict()
+                    chunk_dict["unprocessed_ranges"] = info["unprocessed_ranges"]
+                    chunks_data.append(chunk_dict)
 
                 await self.workers[worker_id].send(
-                    safe_json_dumps({"type": "shard_assignment", "chunks":
+                    safe_json_dumps({"type": "shard_assignment", "chunks": chunks_data})
+                )
+
+                chunk_ids = [c["chunk_id"] for c in chunks_data]
+                logger.info(
+                    f"Assigned {len(chunks_data)} chunks to worker {worker_id}: {chunk_ids}"
                 )
-                chunk_ids = [c["chunk_id"] for c in chunk_data]
-                logger.info(f"Assigned {len(chunks)} chunks to worker {worker_id}: {chunk_ids}")
-                await self._send_activity(f"Assigned {len(chunks)} chunks to {worker_id}")
             else:
                 await self.workers[worker_id].send(safe_json_dumps({"type": "no_chunks"}))
 
@@ -907,7 +1197,7 @@ class Orchestrator:
         elif msg_type == "submit_captions":
             await self._handle_captions_submission(worker_id, data)
         elif msg_type == "request_job":
-            #
+            # CaptionWorker requesting a job from data samples
             try:
                 job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
                 await self.workers[worker_id].send(
@@ -921,76 +1211,132 @@ class Orchestrator:
                 logger.debug(f"Heartbeat from {worker_id}: {data}")
 
     async def _handle_captions_submission(self, worker_id: str, data: Dict):
-        """Process
+        """Process caption submission from worker - now handles multi-stage outputs."""
         chunk_id = data.get("chunk_id")
         item_key = data["item_key"]
-        captions_list = data["captions"]
 
-
-
-
+        item_index = data.get("item_index")  # Worker should send this
+        if item_index is None:
+            # Try to extract from item_key (format: dataset_XXXXXXXX)
+            try:
+                item_index = int(item_key.split("_")[-1])
+            except:
+                logger.warning(f"Could not extract item index from key: {item_key}")
 
-        #
+        # Extract user from worker_id (format: "username_uuid")
+        worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
+        # Handle both old format (captions list) and new format (outputs dict)
+        if "outputs" in data:
+            # New multi-stage format
+            outputs = data["outputs"]
+            captions_list = outputs.get("captions", [])
+            total_outputs = sum(len(v) for v in outputs.values())
+
+            logger.debug(
+                f"Received multi-stage outputs for item {item_key} from worker {worker_id}: "
+                f"{total_outputs} outputs across {len(outputs)} fields"
+            )
+        else:
+            # Old format - single captions list
+            captions_list = data["captions"]
+            outputs = {"captions": captions_list}
+            total_outputs = len(captions_list)
+
+            logger.debug(
+                f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
+            )
+
+        # Create caption record with multi-stage outputs
         caption = Caption(
-            job_id=f"{chunk_id}_{item_key}",
+            job_id=f"{chunk_id}_{item_key}",
             dataset=data.get("dataset"),
             shard=data.get("shard"),
             item_key=item_key,
-            captions=captions_list,
-
+            captions=captions_list,
+            outputs=outputs,
+            contributor_id=worker_user,
             timestamp=datetime.utcnow(),
-            quality_scores=None,
+            quality_scores=None,
             # Image metadata
             image_width=data.get("image_width"),
             image_height=data.get("image_height"),
             image_format=data.get("image_format"),
             file_size=data.get("file_size"),
             # Processing metadata
-            caption_count=
+            caption_count=total_outputs,
             processing_time_ms=data.get("processing_time_ms"),
             chunk_id=chunk_id,
+            metadata=data.get("metadata", {}),
         )
 
-        # Add to central storage buffer
+        # Add to central storage buffer
         await self.storage.save_caption(caption)
 
-        #
-
-
+        # Handle item tracking with fixed deadlock
+        should_flush = False
+        if chunk_id and item_index is not None and self.chunk_tracker:
+            with self.item_batch_lock:
+                self.pending_processed_items[chunk_id].append(item_index)
 
-
-
+                # Check if we should flush
+                total_pending = sum(
+                    len(indices) for indices in self.pending_processed_items.values()
+                )
+                time_since_flush = time.time() - self.last_item_batch_flush
+
+                if (
+                    total_pending >= self.item_batch_size
+                    or time_since_flush >= self.item_batch_interval
+                ):
+                    should_flush = True
+
+        if should_flush:
+            await self._flush_processed_items()
+
+        # Update contributor stats (use user, not worker)
+        contributor = await self.storage.get_contributor(worker_user)
         if contributor:
-            contributor.total_captions +=
+            contributor.total_captions += total_outputs
             await self.storage.save_contributor(contributor)
 
         # Broadcast updated stats
         await self._broadcast_stats()
 
         # Log progress periodically
-
-
+        total_outputs = self.stats.get("total_outputs", 0)
+        if total_outputs > 0 and total_outputs % 100 == 0:
+            if (
+                not hasattr(self, "_last_logged_outputs")
+                or self._last_logged_outputs != total_outputs
+            ):
+                logger.info(f"Collected {total_outputs} outputs centrally")
+                self._last_logged_outputs = total_outputs
 
     async def _check_shard_completion(self, chunk_id: str):
         """Check if a shard is complete after chunk completion."""
-        #
-
+        # Get the chunk
+        chunk = self.chunk_manager.chunks.get(chunk_id)
+        if not chunk:
+            return
 
-
-
+        shard_name = chunk.shard_name
+
+        # Find all chunks for this shard
         shard_chunks = [
-            cid
-            for cid, chunk in self.chunk_manager.chunks.items()
-            if chunk.shard_name == shard_name
+            cid for cid, c in self.chunk_manager.chunks.items() if c.belongs_to_shard(shard_name)
         ]
 
+        # Check if all are completed
         completed_chunks = [
             cid for cid in shard_chunks if self.chunk_manager.chunks[cid].status == "completed"
         ]
 
-        if len(completed_chunks) == len(shard_chunks):
+        if len(completed_chunks) == len(shard_chunks) and len(shard_chunks) > 0:
             logger.info(f"Shard {shard_name} complete!")
-
+            # Don't mark virtual shards as complete in ShardTracker
+            if not shard_name.startswith("hf_dataset:"):
+                self.shard_tracker.mark_complete(shard_name)
             self.stats["completed_shards"] += 1
             await self._send_activity(f"Shard {shard_name} completed!")
 
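For reference, the submission handler above accepts both the old captions-list payload and the new multi-stage outputs payload; a hedged sketch of the two shapes with invented field values:

old_style = {
    "chunk_id": "data-0001_chunk_0",
    "item_key": "data-0001_000000042",
    "captions": ["a red car parked on a street"],
}
new_style = {
    "chunk_id": "data-0001_chunk_0",
    "item_key": "data-0001_000000042",
    "item_index": 42,
    "outputs": {"captions": ["a red car parked on a street"], "tags": ["car", "red", "street"]},
}
total_outputs = sum(len(v) for v in new_style["outputs"].values())
print(total_outputs)  # 4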
@@ -1076,12 +1422,29 @@ class Orchestrator:
         chunk_stats = self.chunk_manager.get_stats()
         await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
 
-        # Send contributor leaderboard
+        # Send contributor leaderboard with active worker counts
         contributors = await self.storage.get_top_contributors(10)
+
+        # Enhance contributor data with active worker counts
+        enhanced_contributors = []
+        worker_counts = (
+            self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+        )
+
+        for contributor in contributors:
+            contrib_dict = {
+                "contributor_id": contributor.contributor_id,
+                "name": contributor.name,
+                "total_captions": contributor.total_captions,
+                "trust_level": contributor.trust_level,
+                "active_workers": len(
+                    worker_counts.get(contributor.contributor_id, {}).get("worker_ids", [])
+                ),
+            }
+            enhanced_contributors.append(contrib_dict)
+
         await websocket.send(
-            safe_json_dumps(
-                {"type": "leaderboard", "data": [safe_dict(c) for c in contributors]}
-            )
+            safe_json_dumps({"type": "leaderboard", "data": enhanced_contributors})
         )
 
         # Keep connection alive
@@ -1094,14 +1457,23 @@ class Orchestrator:
             self.monitors.discard(websocket)
 
     async def _broadcast_stats(self):
-        """Broadcast statistics to all monitors."""
+        """Broadcast statistics to all monitors - enhanced for multi-stage."""
         if not self.monitors:
             return
 
+        # Get storage stats
+        storage_stats = await self.storage.get_storage_stats()
+        caption_stats = await self.storage.get_caption_stats()
+
         # Include chunk stats
         chunk_stats = self.chunk_manager.get_stats()
         self.stats.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
 
+        # Merge storage stats
+        self.stats.update(storage_stats)
+        self.stats["field_breakdown"] = caption_stats.get("field_stats", {})
+        self.stats["output_fields_list"] = caption_stats.get("output_fields", [])
+
         # Add rate information
         self.stats.update(
             {
@@ -1111,23 +1483,123 @@ class Orchestrator:
             }
         )
 
-        # Add vLLM info
+        # Add vLLM info - now includes stage count
         self.stats["vllm_model"] = self.vllm_config.get("model", "unknown")
         self.stats["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
 
+        # NEW: Add stage information
+        stages = self.vllm_config.get("stages", [])
+        if stages:
+            self.stats["stage_count"] = len(stages)
+            self.stats["stage_names"] = [s.get("name", "unnamed") for s in stages]
+        else:
+            self.stats["stage_count"] = 1  # Backward compatibility
+            self.stats["stage_names"] = ["default"]
+
+        field_stats = await self.storage.get_output_field_stats()
+        self.stats["output_fields"] = field_stats
+
         message = safe_json_dumps({"type": "stats", "data": self.stats})
 
         # Send to all monitors
         disconnected = set()
-
+        _monitors = self.monitors.copy()
+        for monitor in _monitors:
             try:
                 await monitor.send(message)
             except websockets.exceptions.ConnectionClosed:
                 disconnected.add(monitor)
 
+        # send updated leaderboard
+        try:
+            contributors = await self.storage.get_top_contributors(10)
+            enhanced_contributors = []
+            worker_counts = (
+                self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+            )
+
+            for contributor in contributors:
+                contrib_dict = {
+                    "contributor_id": contributor.contributor_id,
+                    "name": contributor.name,
+                    "total_captions": contributor.total_captions,
+                    "trust_level": contributor.trust_level,
+                    "active_workers": len(
+                        worker_counts.get(contributor.contributor_id, {}).get("worker_ids", [])
+                    ),
+                }
+                enhanced_contributors.append(contrib_dict)
+
+            leaderboard_message = safe_json_dumps(
+                {"type": "leaderboard", "data": enhanced_contributors}
+            )
+
+            # Send to all monitors
+            disconnected = set()
+            for monitor in self.monitors.copy():
+                try:
+                    await monitor.send(leaderboard_message)
+                except websockets.exceptions.ConnectionClosed:
+                    disconnected.add(monitor)
+
+            self.monitors -= disconnected
+
+        except Exception as e:
+            logger.error(f"Error sending leaderboard update: {e}")
+
         # Clean up disconnected monitors
         self.monitors -= disconnected
 
+    async def _flush_processed_items(self):
+        """Flush batched processed items to chunk tracker."""
+        with self.item_batch_lock:
+            if not self.pending_processed_items:
+                return
+
+            for chunk_id, indices in self.pending_processed_items.items():
+                if not indices:
+                    continue
+
+                # Indices here are ABSOLUTE dataset indices
+                # Sort indices
+                indices.sort()
+
+                # Group consecutive indices into ranges
+                ranges = []
+                start = indices[0]
+                end = indices[0]
+
+                for i in range(1, len(indices)):
+                    if indices[i] == end + 1:
+                        # Consecutive, extend range
+                        end = indices[i]
+                    else:
+                        # Gap found, save current range and start new one
+                        ranges.append((start, end))
+                        start = indices[i]
+                        end = indices[i]
+
+                # Don't forget the last range
+                ranges.append((start, end))
+
+                # Mark ranges as processed (mark_items_processed expects absolute indices)
+                for start_idx, end_idx in ranges:
+                    self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
+
+            # Clear pending items
+            self.pending_processed_items.clear()
+            self.last_item_batch_flush = time.time()
+
+    def get_workers_by_user_stats(self) -> Dict[str, Any]:
+        """Get statistics about workers grouped by user/token."""
+        if not hasattr(self, "workers_by_user"):
+            return {}
+
+        stats = {}
+        for user, worker_ids in self.workers_by_user.items():
+            stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
+        return stats
+
     async def _send_activity(self, activity: str):
         """Send activity update to monitors."""
         if not self.monitors:
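For reference, a runnable sketch of the consecutive-index grouping that _flush_processed_items above performs before calling mark_items_processed:

indices = [5, 1, 2, 3, 9, 10]
indices.sort()
ranges, start, end = [], indices[0], indices[0]
for i in indices[1:]:
    if i == end + 1:
        end = i                 # consecutive, extend the current range
    else:
        ranges.append((start, end))
        start = end = i         # gap found, start a new range
ranges.append((start, end))
print(ranges)                   # [(1, 3), (5, 5), (9, 10)]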
@@ -1172,36 +1644,52 @@ class Orchestrator:
         while True:
             await asyncio.sleep(60)
 
+            # Get current caption count from storage
+            storage_stats = await self.storage.get_storage_stats()
+            total_captions = storage_stats["total_captions"]
+
             # Force checkpoint at regular intervals
-            if
-                logger.info(f"Triggering checkpoint at {
+            if total_captions > 0 and total_captions % interval == 0:
+                logger.info(f"Triggering checkpoint at {total_captions} captions")
                 await self.storage.checkpoint()
 
                 # Update stats
                 self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
-
-                self.stats["buffer_size"] = len(self.storage.caption_buffer)
+                # No need to update total_written or buffer_size - they come from storage
 
                 await self._broadcast_stats()
                 logger.info(
-                    f"Checkpoint complete. Total written to disk: {
+                    f"Checkpoint complete. Total written to disk: {storage_stats['total_written']}"
                 )
 
     async def _stats_update_loop(self):
         """Periodically update and broadcast stats."""
         # Track session start values
-
+        storage_stats = await self.storage.get_storage_stats()
+        session_start_outputs = storage_stats["total_captions"]  # This now counts ALL outputs
         session_start_time = time.time()
 
+        # Track the last known total to detect flushes
+        last_known_total = session_start_outputs
+
         while True:
             await asyncio.sleep(10)
 
             # Update chunk stats
             chunk_stats = self.chunk_manager.get_stats()
+            storage_stats = await self.storage.get_storage_stats()
+            current_total_outputs = storage_stats["total_captions"]  # ALL outputs
+            if self.chunk_tracker:
+                await self._flush_processed_items()
+
             self.stats["total_chunks"] = chunk_stats["total"]
             self.stats["completed_chunks"] = chunk_stats["completed"]
             self.stats["failed_chunks"] = chunk_stats["failed"]
 
+            # Update total outputs stat (rename from total_captions for clarity)
+            self.stats["total_outputs"] = current_total_outputs
+            self.stats["total_captions"] = current_total_outputs  # Keep for backward compatibility
+
             # Add queue information
             with self.chunk_manager.lock:
                 self.stats["pending_chunks"] = len(self.chunk_manager.pending_chunks)
@@ -1220,33 +1708,57 @@ class Orchestrator:
             elapsed_since_update = current_time - self.rate_tracker["last_update_time"]
 
             if elapsed_since_update > 0:
-                #
-
-
-
-
+                # FIX: Handle the case where duplicates were skipped during save
+                # If current total is less than last known, it means duplicates were skipped
+                # We should not count this as negative progress
+                if current_total_outputs < last_known_total:
+                    logger.debug(
+                        f"Detected duplicate skip during save: {last_known_total} -> {current_total_outputs}"
+                    )
+                    # Don't calculate negative rate, just update the baseline
+                    self.rate_tracker["last_caption_count"] = current_total_outputs
+                    self.rate_tracker["current_rate"] = 0.0  # Set to 0 during flush
+                else:
+                    # Normal rate calculation
+                    output_diff = current_total_outputs - self.rate_tracker["last_caption_count"]
+                    self.rate_tracker["current_rate"] = (output_diff / elapsed_since_update) * 60
+                    self.rate_tracker["last_caption_count"] = current_total_outputs
 
             # Calculate average rate since THIS SESSION started
             session_elapsed = current_time - session_start_time
             if session_elapsed > 0:
-
-
+                # Always use the difference from session start for average
+                session_outputs = current_total_outputs - session_start_outputs
+                self.rate_tracker["average_rate"] = (session_outputs / session_elapsed) * 60
 
-            # Calculate expected rate based on workers
-            # Assume each worker processes batch_size images every ~2 seconds with 3 captions each
+            # Calculate expected rate based on workers and stages
             batch_size = self.vllm_config.get("batch_size", 8)
-
+
+            # Count total prompts across all stages
+            total_prompts = 0
+            stages = self.vllm_config.get("stages", [])
+            if stages:
+                for stage in stages:
+                    total_prompts += len(stage.get("prompts", []))
+            else:
+                # Backward compatibility
+                total_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
+
             images_per_minute = 30  # Rough estimate: 30 images/min per worker
-            self.rate_tracker["expected_rate"] =
+            self.rate_tracker["expected_rate"] = (
+                worker_count * images_per_minute * total_prompts
+            )
 
             # Update trackers
             self.rate_tracker["last_update_time"] = current_time
-
+            last_known_total = current_total_outputs
 
             # Log rate information when workers are connected
-            if
+            if (
+                worker_count > 0 and self.rate_tracker["current_rate"] >= 0
+            ):  # Only log non-negative rates
                 logger.info(
-                    f"Rate: {self.rate_tracker['current_rate']:.1f}
+                    f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
                     f"(avg: {self.rate_tracker['average_rate']:.1f}, "
                    f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
                     f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
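For reference, the expected-rate estimate above is worker_count * images_per_minute * total_prompts; a worked example with hypothetical numbers:

worker_count = 4
images_per_minute = 30          # rough per-worker estimate used in the diff
total_prompts = 5               # e.g. 2 + 3 prompts across two hypothetical stages
expected_rate = worker_count * images_per_minute * total_prompts
print(expected_rate)            # 600 outputs per minute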
@@ -1256,16 +1768,16 @@ class Orchestrator:
 
     async def _restore_state(self):
         """Restore state from storage on startup."""
-
-
-
-        logger.info(f"Restored state: {self.stats['total_captions']} captions")
+        total_captions = await self.storage.count_captions()
+        logger.info(f"Restored state: {total_captions} captions")
 
     async def shutdown(self):
         """Graceful shutdown."""
         logger.info("Shutting down orchestrator...")
 
         # Stop chunk creation
+        if self.chunk_tracker:
+            await self._flush_processed_items()
         self.stop_chunk_creation.set()
         if self.chunk_creation_thread:
             self.chunk_creation_thread.join(timeout=5)
@@ -1287,7 +1799,7 @@ class Orchestrator:
 
         # Save chunk state
         if self.chunk_tracker:
-            self.chunk_tracker.
+            self.chunk_tracker.save()
 
         # Final checkpoint
         logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")