caption-flow 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +308 -0
- caption_flow/models.py +134 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1715
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +489 -401
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +73 -32
- caption_flow/utils/dataset_loader.py +58 -298
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +5 -265
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/viewer.py +594 -0
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/METADATA +49 -180
- caption_flow-0.2.4.dist-info/RECORD +38 -0
- caption_flow-0.2.2.dist-info/RECORD +0 -29
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/top_level.txt +0 -0
caption_flow/orchestrator.py
CHANGED
@@ -1,516 +1,106 @@
|
|
1
|
-
"""Enhanced orchestrator with shard chunk assignment for vLLM workers.
|
2
|
-
|
3
|
-
This orchestrator:
|
4
|
-
1. Divides dataset shards into chunks for parallel processing
|
5
|
-
2. Assigns chunks to workers on request
|
6
|
-
3. Collects captions from workers centrally
|
7
|
-
4. Manages checkpoints and fault tolerance
|
8
|
-
"""
|
9
|
-
|
10
1
|
import time
|
11
2
|
import asyncio
|
12
3
|
import json
|
13
4
|
import logging
|
14
5
|
import ssl
|
15
6
|
import uuid
|
16
|
-
from dataclasses import dataclass, asdict
|
17
7
|
from datetime import datetime
|
18
8
|
from pathlib import Path
|
19
|
-
from typing import Dict, Set, Optional, Any, List
|
20
|
-
from collections import
|
9
|
+
from typing import Dict, Set, Optional, Any, List
|
10
|
+
from collections import defaultdict
|
21
11
|
import threading
|
22
|
-
from queue import Queue, Empty
|
23
12
|
|
24
|
-
from .workers import data
|
25
13
|
import websockets
|
26
14
|
from websockets.server import WebSocketServerProtocol
|
27
15
|
|
28
16
|
from .storage import StorageManager
|
29
|
-
from .models import Caption, Contributor
|
17
|
+
from .models import Caption, Contributor, JobId
|
30
18
|
from .utils.auth import AuthManager
|
31
|
-
from .utils import
|
32
|
-
from .
|
19
|
+
from .utils.json_utils import safe_json_dumps
|
20
|
+
from .processors import (
|
21
|
+
ProcessorConfig,
|
22
|
+
WorkAssignment,
|
23
|
+
WorkResult,
|
24
|
+
WorkUnit,
|
25
|
+
WebDatasetOrchestratorProcessor,
|
26
|
+
HuggingFaceDatasetOrchestratorProcessor,
|
27
|
+
LocalFilesystemOrchestratorProcessor,
|
28
|
+
)
|
33
29
|
|
34
30
|
logger = logging.getLogger(__name__)
|
35
|
-
|
36
|
-
|
37
|
-
@dataclass
|
38
|
-
class ShardChunk:
|
39
|
-
"""Represents a chunk of a shard for processing."""
|
40
|
-
|
41
|
-
chunk_id: str
|
42
|
-
shard_url: str
|
43
|
-
shard_name: str
|
44
|
-
start_index: int
|
45
|
-
chunk_size: int
|
46
|
-
assigned_to: Optional[str] = None
|
47
|
-
status: str = "pending" # pending, assigned, completed, failed
|
48
|
-
assigned_at: Optional[datetime] = None
|
49
|
-
completed_at: Optional[datetime] = None
|
50
|
-
|
51
|
-
@classmethod
|
52
|
-
def create(
|
53
|
-
cls, shard_url: str, shard_name: str, start_index: int, chunk_size: int
|
54
|
-
) -> "ShardChunk":
|
55
|
-
"""Factory method to create a chunk with consistent ID."""
|
56
|
-
# Always use consistent format: dataset_chunk_startindex
|
57
|
-
if shard_url.startswith("hf_dataset:"):
|
58
|
-
# Extract dataset path
|
59
|
-
parts = shard_url.split(":")
|
60
|
-
dataset_path = parts[1] if len(parts) > 1 else "unknown"
|
61
|
-
chunk_id = f"{dataset_path.replace('/', '_')}_chunk_{start_index}"
|
62
|
-
else:
|
63
|
-
# WebDataset format
|
64
|
-
chunk_id = f"{shard_name}_chunk_{start_index}"
|
65
|
-
|
66
|
-
return cls(
|
67
|
-
chunk_id=chunk_id,
|
68
|
-
shard_url=shard_url,
|
69
|
-
shard_name=shard_name,
|
70
|
-
start_index=start_index,
|
71
|
-
chunk_size=chunk_size,
|
72
|
-
)
|
73
|
-
|
74
|
-
def belongs_to_shard(self, shard_identifier: str) -> bool:
|
75
|
-
"""Check if this chunk belongs to a given shard."""
|
76
|
-
return self.shard_name == shard_identifier
|
77
|
-
|
78
|
-
def to_dict(self) -> Dict[str, Any]:
|
79
|
-
"""Convert to dict for JSON serialization (for workers)."""
|
80
|
-
return {
|
81
|
-
"chunk_id": self.chunk_id,
|
82
|
-
"shard_url": self.shard_url,
|
83
|
-
"shard_name": self.shard_name,
|
84
|
-
"start_index": self.start_index,
|
85
|
-
"chunk_size": self.chunk_size,
|
86
|
-
}
|
87
|
-
|
88
|
-
|
89
|
-
class ChunkManager:
|
90
|
-
"""Manages shard chunk creation and assignment."""
|
91
|
-
|
92
|
-
def __init__(self, chunk_size: int = 1000, tracker: Optional[ChunkTracker] = None):
|
93
|
-
self.chunk_size = chunk_size
|
94
|
-
self.chunks: Dict[str, ShardChunk] = {}
|
95
|
-
self.pending_chunks: Deque[str] = deque()
|
96
|
-
self.assigned_chunks: Dict[str, Set[str]] = defaultdict(set) # worker_id -> chunk_ids
|
97
|
-
self.lock = threading.Lock()
|
98
|
-
self.tracker = tracker # Reference to chunk tracker
|
99
|
-
|
100
|
-
# NEW: Track assigned ranges to prevent double allocation
|
101
|
-
# Format: {chunk_id: {(start, end): worker_id}}
|
102
|
-
self.assigned_ranges: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)
|
103
|
-
|
104
|
-
def get_chunks_for_worker(
|
105
|
-
self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
|
106
|
-
) -> List[Dict[str, Any]]:
|
107
|
-
"""Get available chunks with unprocessed items for a worker."""
|
108
|
-
assigned = []
|
109
|
-
|
110
|
-
with self.lock:
|
111
|
-
# FIRST PRIORITY: Check if this worker already has assigned chunks
|
112
|
-
if worker_id in self.assigned_chunks:
|
113
|
-
existing_chunk_ids = list(self.assigned_chunks[worker_id])
|
114
|
-
for chunk_id in existing_chunk_ids:
|
115
|
-
if len(assigned) >= count:
|
116
|
-
break
|
117
|
-
|
118
|
-
chunk = self.chunks.get(chunk_id)
|
119
|
-
if not chunk:
|
120
|
-
continue
|
121
|
-
|
122
|
-
# Check if chunk still has unprocessed items
|
123
|
-
if tracker:
|
124
|
-
chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
|
125
|
-
if chunk_info and chunk_info["unprocessed_ranges"]:
|
126
|
-
# Filter out ranges that are assigned to other workers
|
127
|
-
clean_ranges = []
|
128
|
-
for start, end in chunk_info["unprocessed_ranges"]:
|
129
|
-
range_key = (start, end)
|
130
|
-
if range_key in self.assigned_ranges[chunk_id]:
|
131
|
-
assigned_worker = self.assigned_ranges[chunk_id][range_key]
|
132
|
-
if assigned_worker != worker_id:
|
133
|
-
# Skip this range - it's assigned to another worker
|
134
|
-
logger.warning(
|
135
|
-
f"Skipping range {start}-{end} in chunk {chunk_id} "
|
136
|
-
f"(assigned to {assigned_worker}, not {worker_id})"
|
137
|
-
)
|
138
|
-
continue
|
139
|
-
# else: this worker already owns this range, include it
|
140
|
-
clean_ranges.append((start, end))
|
141
|
-
|
142
|
-
if clean_ranges:
|
143
|
-
assigned.append(
|
144
|
-
{
|
145
|
-
"chunk": chunk,
|
146
|
-
"unprocessed_ranges": clean_ranges,
|
147
|
-
}
|
148
|
-
)
|
149
|
-
else:
|
150
|
-
# No tracker, assume chunk needs processing
|
151
|
-
assigned.append(
|
152
|
-
{
|
153
|
-
"chunk": chunk,
|
154
|
-
"unprocessed_ranges": [(0, chunk.chunk_size - 1)],
|
155
|
-
}
|
156
|
-
)
|
157
|
-
|
158
|
-
# SECOND PRIORITY: Get new pending chunks
|
159
|
-
while len(assigned) < count and self.pending_chunks:
|
160
|
-
chunk_id = self.pending_chunks.popleft()
|
161
|
-
chunk = self.chunks.get(chunk_id)
|
162
|
-
|
163
|
-
if not chunk:
|
164
|
-
continue
|
165
|
-
|
166
|
-
# Verify chunk is truly pending
|
167
|
-
if chunk.status != "pending" or chunk.assigned_to is not None:
|
168
|
-
logger.warning(
|
169
|
-
f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
|
170
|
-
)
|
171
|
-
continue
|
172
|
-
|
173
|
-
# Assign to this worker
|
174
|
-
chunk.assigned_to = worker_id
|
175
|
-
chunk.status = "assigned"
|
176
|
-
chunk.assigned_at = datetime.utcnow()
|
177
|
-
self.assigned_chunks[worker_id].add(chunk_id)
|
178
|
-
|
179
|
-
# Get unprocessed ranges and filter out any that are somehow already assigned
|
180
|
-
unprocessed_ranges = [(0, chunk.chunk_size - 1)] # Default
|
181
|
-
if tracker:
|
182
|
-
chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
|
183
|
-
if chunk_info:
|
184
|
-
# Filter out any ranges that are already assigned (shouldn't happen for new chunks)
|
185
|
-
clean_ranges = []
|
186
|
-
for start, end in chunk_info["unprocessed_ranges"]:
|
187
|
-
range_key = (start, end)
|
188
|
-
if range_key not in self.assigned_ranges[chunk_id]:
|
189
|
-
clean_ranges.append((start, end))
|
190
|
-
else:
|
191
|
-
logger.error(
|
192
|
-
f"Range {start}-{end} in newly assigned chunk {chunk_id} "
|
193
|
-
f"is already assigned to {self.assigned_ranges[chunk_id][range_key]}!"
|
194
|
-
)
|
195
|
-
unprocessed_ranges = clean_ranges if clean_ranges else []
|
196
|
-
|
197
|
-
tracker.mark_assigned(chunk_id, worker_id)
|
198
|
-
|
199
|
-
if unprocessed_ranges:
|
200
|
-
assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
|
201
|
-
|
202
|
-
# Track assigned ranges and verify no double allocation
|
203
|
-
for info in assigned:
|
204
|
-
chunk_id = info["chunk"].chunk_id
|
205
|
-
for start, end in info["unprocessed_ranges"]:
|
206
|
-
range_key = (start, end)
|
207
|
-
|
208
|
-
# Check if this range is already assigned
|
209
|
-
if range_key in self.assigned_ranges[chunk_id]:
|
210
|
-
existing_worker = self.assigned_ranges[chunk_id][range_key]
|
211
|
-
if existing_worker != worker_id:
|
212
|
-
# This should never happen - raise assertion
|
213
|
-
raise AssertionError(
|
214
|
-
f"CRITICAL: Attempting to assign range {start}-{end} in chunk {chunk_id} "
|
215
|
-
f"to worker {worker_id}, but it's already assigned to {existing_worker}! "
|
216
|
-
f"This would cause duplicate processing."
|
217
|
-
)
|
218
|
-
|
219
|
-
# Track this assignment
|
220
|
-
self.assigned_ranges[chunk_id][range_key] = worker_id
|
221
|
-
|
222
|
-
# Log what we're assigning
|
223
|
-
if assigned:
|
224
|
-
chunk_summary = ", ".join(
|
225
|
-
[
|
226
|
-
f"{info['chunk'].chunk_id}[{len(info['unprocessed_ranges'])} ranges]"
|
227
|
-
for info in assigned
|
228
|
-
]
|
229
|
-
)
|
230
|
-
logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
|
231
|
-
|
232
|
-
# Detailed range logging for debugging
|
233
|
-
for info in assigned:
|
234
|
-
chunk_id = info["chunk"].chunk_id
|
235
|
-
ranges_str = ", ".join([f"{s}-{e}" for s, e in info["unprocessed_ranges"]])
|
236
|
-
logger.debug(f" Chunk {chunk_id} ranges: {ranges_str}")
|
237
|
-
|
238
|
-
return assigned
|
239
|
-
|
240
|
-
def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
|
241
|
-
"""Mark a chunk as completed."""
|
242
|
-
with self.lock:
|
243
|
-
if chunk_id in self.chunks:
|
244
|
-
chunk = self.chunks[chunk_id]
|
245
|
-
if chunk.assigned_to == worker_id and chunk.status == "assigned":
|
246
|
-
chunk.status = "completed"
|
247
|
-
chunk.completed_at = datetime.utcnow()
|
248
|
-
self.assigned_chunks[worker_id].discard(chunk_id)
|
249
|
-
|
250
|
-
# Clear assigned ranges for this chunk
|
251
|
-
if chunk_id in self.assigned_ranges:
|
252
|
-
# Log what ranges we're clearing
|
253
|
-
ranges_to_clear = list(self.assigned_ranges[chunk_id].keys())
|
254
|
-
logger.debug(
|
255
|
-
f"Clearing {len(ranges_to_clear)} assigned ranges for completed chunk {chunk_id}"
|
256
|
-
)
|
257
|
-
del self.assigned_ranges[chunk_id]
|
258
|
-
|
259
|
-
return True
|
260
|
-
return False
|
261
|
-
|
262
|
-
def fail_chunk(self, chunk_id: str, worker_id: str) -> bool:
|
263
|
-
"""Mark a chunk as failed and requeue it."""
|
264
|
-
with self.lock:
|
265
|
-
if chunk_id in self.chunks:
|
266
|
-
chunk = self.chunks[chunk_id]
|
267
|
-
if chunk.assigned_to == worker_id:
|
268
|
-
chunk.status = "pending"
|
269
|
-
chunk.assigned_to = None
|
270
|
-
chunk.assigned_at = None
|
271
|
-
self.assigned_chunks[worker_id].discard(chunk_id)
|
272
|
-
self.pending_chunks.append(chunk_id)
|
273
|
-
|
274
|
-
# Clear assigned ranges for this chunk/worker
|
275
|
-
if chunk_id in self.assigned_ranges:
|
276
|
-
ranges_to_clear = [
|
277
|
-
range_key
|
278
|
-
for range_key, assigned_worker in self.assigned_ranges[chunk_id].items()
|
279
|
-
if assigned_worker == worker_id
|
280
|
-
]
|
281
|
-
for range_key in ranges_to_clear:
|
282
|
-
del self.assigned_ranges[chunk_id][range_key]
|
283
|
-
logger.debug(
|
284
|
-
f"Cleared {len(ranges_to_clear)} assigned ranges for failed chunk {chunk_id}"
|
285
|
-
)
|
286
|
-
|
287
|
-
return True
|
288
|
-
return False
|
289
|
-
|
290
|
-
def release_worker_chunks(self, worker_id: str):
|
291
|
-
"""Release all chunks assigned to a worker."""
|
292
|
-
with self.lock:
|
293
|
-
chunk_ids = list(self.assigned_chunks.get(worker_id, []))
|
294
|
-
for chunk_id in chunk_ids:
|
295
|
-
if chunk_id in self.chunks:
|
296
|
-
chunk = self.chunks[chunk_id]
|
297
|
-
if chunk.status == "assigned":
|
298
|
-
chunk.status = "pending"
|
299
|
-
chunk.assigned_to = None
|
300
|
-
chunk.assigned_at = None
|
301
|
-
self.pending_chunks.append(chunk_id)
|
302
|
-
|
303
|
-
# Clear assigned ranges for this worker
|
304
|
-
if chunk_id in self.assigned_ranges:
|
305
|
-
ranges_to_clear = [
|
306
|
-
range_key
|
307
|
-
for range_key, assigned_worker in self.assigned_ranges[
|
308
|
-
chunk_id
|
309
|
-
].items()
|
310
|
-
if assigned_worker == worker_id
|
311
|
-
]
|
312
|
-
for range_key in ranges_to_clear:
|
313
|
-
del self.assigned_ranges[chunk_id][range_key]
|
314
|
-
|
315
|
-
if ranges_to_clear:
|
316
|
-
logger.info(
|
317
|
-
f"Released {len(ranges_to_clear)} ranges from chunk {chunk_id} "
|
318
|
-
f"previously assigned to disconnected worker {worker_id}"
|
319
|
-
)
|
320
|
-
|
321
|
-
if worker_id in self.assigned_chunks:
|
322
|
-
del self.assigned_chunks[worker_id]
|
323
|
-
|
324
|
-
def mark_ranges_processed(
|
325
|
-
self, chunk_id: str, processed_ranges: List[Tuple[int, int]], worker_id: str
|
326
|
-
):
|
327
|
-
"""Remove ranges from assignment tracking once they're processed."""
|
328
|
-
with self.lock:
|
329
|
-
if chunk_id in self.assigned_ranges:
|
330
|
-
for start, end in processed_ranges:
|
331
|
-
range_key = (start, end)
|
332
|
-
if range_key in self.assigned_ranges[chunk_id]:
|
333
|
-
assigned_worker = self.assigned_ranges[chunk_id][range_key]
|
334
|
-
if assigned_worker == worker_id:
|
335
|
-
del self.assigned_ranges[chunk_id][range_key]
|
336
|
-
logger.debug(
|
337
|
-
f"Cleared assignment of range {start}-{end} in chunk {chunk_id} "
|
338
|
-
f"after processing by {worker_id}"
|
339
|
-
)
|
340
|
-
else:
|
341
|
-
logger.warning(
|
342
|
-
f"Worker {worker_id} claims to have processed range {start}-{end} "
|
343
|
-
f"in chunk {chunk_id}, but it was assigned to {assigned_worker}"
|
344
|
-
)
|
345
|
-
|
346
|
-
def get_stats(self) -> Dict[str, int]:
|
347
|
-
"""Get chunk statistics."""
|
348
|
-
with self.lock:
|
349
|
-
# Count total assigned ranges
|
350
|
-
total_assigned_ranges = sum(len(ranges) for ranges in self.assigned_ranges.values())
|
351
|
-
|
352
|
-
stats = {
|
353
|
-
"total": len(self.chunks),
|
354
|
-
"pending": len(self.pending_chunks),
|
355
|
-
"assigned": sum(len(chunks) for chunks in self.assigned_chunks.values()),
|
356
|
-
"completed": sum(1 for c in self.chunks.values() if c.status == "completed"),
|
357
|
-
"failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
|
358
|
-
"assigned_ranges": total_assigned_ranges,
|
359
|
-
}
|
360
|
-
return stats
|
31
|
+
logger.setLevel(logging.INFO)
|
361
32
|
|
362
33
|
|
363
34
|
class Orchestrator:
|
364
|
-
"""
|
35
|
+
"""Generic orchestrator for distributed work processing."""
|
365
36
|
|
366
37
|
def __init__(self, config: Dict[str, Any]):
|
367
38
|
self.config = config
|
368
39
|
self.host = config.get("host", "0.0.0.0")
|
369
40
|
self.port = config.get("port", 8765)
|
370
41
|
|
371
|
-
#
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
self.dataset_path,
|
388
|
-
self.dataset_type,
|
389
|
-
self.dataset_split,
|
390
|
-
self.dataset_image_column,
|
391
|
-
)
|
392
|
-
checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
|
393
|
-
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
394
|
-
self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
|
395
|
-
self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
|
396
|
-
|
397
|
-
# vLLM configuration to distribute to workers
|
398
|
-
self.vllm_config = config.get(
|
399
|
-
"vllm",
|
400
|
-
{
|
401
|
-
"model": "Qwen/Qwen2.5-VL-3B-Instruct",
|
402
|
-
"gpu_memory_utilization": 0.92,
|
403
|
-
"max_model_len": 16384,
|
404
|
-
"tensor_parallel_size": 1,
|
405
|
-
"dtype": "float16",
|
406
|
-
"enforce_eager": True,
|
407
|
-
"limit_mm_per_prompt": {"image": 1},
|
408
|
-
"disable_mm_preprocessor_cache": True,
|
409
|
-
"sampling": {
|
410
|
-
"temperature": 0.7,
|
411
|
-
"top_p": 0.95,
|
412
|
-
"max_tokens": 256,
|
413
|
-
"repetition_penalty": 1.05,
|
414
|
-
"stop": ["<|end|>", "<|endoftext|>", "<|im_end|>"],
|
415
|
-
},
|
416
|
-
"inference_prompts": [
|
417
|
-
"describe this image in detail",
|
418
|
-
"provide a comprehensive description of the visual content",
|
419
|
-
"what are the key elements in this image?",
|
420
|
-
],
|
421
|
-
},
|
422
|
-
)
|
423
|
-
|
424
|
-
# Chunk configuration
|
425
|
-
self.chunk_size = config.get("chunk_size", 1000)
|
426
|
-
self.chunks_per_request = config.get("chunks_per_request", 2)
|
427
|
-
|
428
|
-
# Demand-driven chunk creation settings
|
429
|
-
self.chunk_buffer_multiplier = config.get("chunk_buffer_multiplier", 3)
|
430
|
-
self.min_chunk_buffer = config.get("min_chunk_buffer", 10)
|
42
|
+
# Processor configuration
|
43
|
+
processor_type = config.get("dataset", {}).get("processor_type", None)
|
44
|
+
assert (
|
45
|
+
processor_type is not None
|
46
|
+
), "You must supply processor_type in your orchestrator dataset configuration."
|
47
|
+
processor_config = ProcessorConfig(processor_type=processor_type, config=config)
|
48
|
+
|
49
|
+
# Initialize processor
|
50
|
+
if processor_type == "webdataset":
|
51
|
+
self.processor = WebDatasetOrchestratorProcessor()
|
52
|
+
elif processor_type == "huggingface_datasets":
|
53
|
+
self.processor = HuggingFaceDatasetOrchestratorProcessor()
|
54
|
+
elif processor_type == "local_filesystem":
|
55
|
+
self.processor = LocalFilesystemOrchestratorProcessor()
|
56
|
+
else:
|
57
|
+
raise ValueError(f"Unknown processor type: {processor_type}")
|
431
58
|
|
432
59
|
# Initialize components
|
433
60
|
storage_config = config.get("storage", {})
|
434
61
|
self.storage = StorageManager(
|
435
62
|
Path(storage_config.get("data_dir", "./caption_data")),
|
436
63
|
caption_buffer_size=storage_config.get("caption_buffer_size", 1000),
|
437
|
-
job_buffer_size=storage_config.get("job_buffer_size", 100),
|
438
|
-
contributor_buffer_size=storage_config.get("contributor_buffer_size", 10),
|
439
64
|
)
|
440
65
|
self.auth = AuthManager(config.get("auth", {}))
|
66
|
+
self.processor.initialize(processor_config, self.storage)
|
441
67
|
|
442
|
-
#
|
443
|
-
self.
|
444
|
-
self.shard_tracker = None
|
445
|
-
self.chunk_tracker = None
|
446
|
-
|
447
|
-
if self.dataset_path:
|
448
|
-
self.dataset_loader = DatasetLoader(self.dataset_path, self.dataset_type)
|
449
|
-
checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
|
450
|
-
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
451
|
-
self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
|
452
|
-
self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
|
453
|
-
|
454
|
-
# Initialize chunk manager with reference to chunk tracker
|
455
|
-
self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)
|
456
|
-
self.pending_processed_items = defaultdict(list) # chunk_id -> list of indices
|
457
|
-
self.item_batch_lock = threading.Lock()
|
458
|
-
self.last_item_batch_flush = time.time()
|
459
|
-
self.item_batch_interval = 5 # Flush every 5 seconds
|
460
|
-
self.item_batch_size = 100 # Or every 100 items
|
68
|
+
# Processing configuration
|
69
|
+
self.units_per_request = config.get("units_per_request", 2)
|
461
70
|
|
462
71
|
# Track connections
|
463
72
|
self.workers: Dict[str, WebSocketServerProtocol] = {}
|
464
73
|
self.monitors: Set[WebSocketServerProtocol] = set()
|
74
|
+
self.workers_by_user = defaultdict(set)
|
465
75
|
|
466
76
|
# SSL configuration
|
467
77
|
self.ssl_context = self._setup_ssl()
|
468
78
|
|
469
79
|
# Statistics
|
470
|
-
self.is_generating_stats = False
|
471
80
|
self.stats = {
|
472
|
-
"total_chunks": 0,
|
473
|
-
"completed_chunks": 0,
|
474
|
-
"failed_chunks": 0,
|
475
81
|
"connected_workers": 0,
|
476
|
-
"
|
477
|
-
"completed_shards": 0,
|
478
|
-
"current_shard": None,
|
82
|
+
"total_outputs": 0,
|
479
83
|
"last_checkpoint": None,
|
84
|
+
"processor_stats": {},
|
480
85
|
}
|
481
86
|
|
87
|
+
# Cache for leaderboard
|
88
|
+
self._cached_leaderboard = None
|
89
|
+
|
90
|
+
# Data worker stuff
|
91
|
+
self.data_workers = {}
|
92
|
+
self.data_sample_queue = asyncio.Queue()
|
93
|
+
self.backpressure_threshold = config.get("backpressure_threshold", 1000)
|
94
|
+
|
482
95
|
# Rate tracking
|
483
96
|
self.rate_tracker = {
|
484
97
|
"start_time": time.time(),
|
485
98
|
"last_update_time": time.time(),
|
486
|
-
"
|
99
|
+
"last_output_count": 0,
|
487
100
|
"current_rate": 0.0,
|
488
101
|
"average_rate": 0.0,
|
489
|
-
"expected_rate": 0.0,
|
490
102
|
}
|
491
103
|
|
492
|
-
# Data sample queue for CaptionWorker
|
493
|
-
self.data_sample_queue = asyncio.Queue(maxsize=1000)
|
494
|
-
self.data_workers: Dict[str, WebSocketServerProtocol] = {}
|
495
|
-
|
496
|
-
# Backpressure threshold
|
497
|
-
self.backpressure_threshold = config.get("backpressure_threshold", 800)
|
498
|
-
|
499
|
-
# Shard processing state
|
500
|
-
self.all_shards = []
|
501
|
-
self.current_shard_index = 0
|
502
|
-
self.shard_lock = threading.Lock()
|
503
|
-
|
504
|
-
# Background chunk creation
|
505
|
-
self.chunk_creation_thread = None
|
506
|
-
self.stop_chunk_creation = threading.Event()
|
507
|
-
|
508
|
-
# State restoration flag
|
509
|
-
self.state_restored = threading.Event()
|
510
|
-
# If no dataset, state is already "restored"
|
511
|
-
if not self.dataset_loader:
|
512
|
-
self.state_restored.set()
|
513
|
-
|
514
104
|
def _setup_ssl(self) -> Optional[ssl.SSLContext]:
|
515
105
|
"""Configure SSL if certificates are provided."""
|
516
106
|
ssl_config = self.config.get("ssl", {})
|
@@ -521,326 +111,20 @@ class Orchestrator:
|
|
521
111
|
context.load_cert_chain(ssl_config["cert"], ssl_config["key"])
|
522
112
|
return context
|
523
113
|
|
524
|
-
def _create_chunks_from_dataset(self):
|
525
|
-
"""Background thread to create chunks from dataset shards on demand."""
|
526
|
-
if not self.dataset_loader:
|
527
|
-
logger.warning("No dataset configured, skipping chunk creation")
|
528
|
-
self.state_restored.set() # No state to restore
|
529
|
-
return
|
530
|
-
|
531
|
-
logger.info("Starting chunk creation thread")
|
532
|
-
|
533
|
-
# Mark state as not restored until we process checkpoints
|
534
|
-
self.state_restored.clear()
|
535
|
-
|
536
|
-
# Get dataset info to check format
|
537
|
-
dataset_info = self.dataset_loader.get_dataset_info()
|
538
|
-
dataset_format = dataset_info.get("dataset_format", "unknown")
|
539
|
-
logger.info(f"Dataset format: {dataset_format}")
|
540
|
-
|
541
|
-
# Get all shards
|
542
|
-
self.all_shards = self.dataset_loader.get_shard_list()
|
543
|
-
self.stats["total_shards"] = len(self.all_shards)
|
544
|
-
|
545
|
-
# For HuggingFace datasets, we might need to dynamically create more shards
|
546
|
-
if dataset_format == "huggingface_datasets":
|
547
|
-
self._is_hf_dataset = True
|
548
|
-
self._hf_chunk_size = 10000 # Items per virtual shard
|
549
|
-
self._next_hf_shard_index = len(self.all_shards) # For creating new virtual shards
|
550
|
-
else:
|
551
|
-
self._is_hf_dataset = False
|
552
|
-
|
553
|
-
# Get shard status from ChunkTracker
|
554
|
-
shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
|
555
|
-
completed_shards = {
|
556
|
-
shard_name for shard_name, info in shards_summary.items() if info["is_complete"]
|
557
|
-
}
|
558
|
-
|
559
|
-
# Update ShardTracker for completed shards
|
560
|
-
for shard_name in completed_shards:
|
561
|
-
if not self.shard_tracker.is_complete(shard_name):
|
562
|
-
logger.info(f"Marking shard {shard_name} as complete in ShardTracker")
|
563
|
-
self.shard_tracker.mark_complete(shard_name)
|
564
|
-
|
565
|
-
# Get shards that need processing
|
566
|
-
remaining_shards = self.shard_tracker.get_remaining_shards(self.all_shards)
|
567
|
-
|
568
|
-
# Also check which shards already have chunks (partial or complete)
|
569
|
-
shards_with_chunks = set()
|
570
|
-
for shard_name in shards_summary.keys():
|
571
|
-
shards_with_chunks.add(shard_name)
|
572
|
-
|
573
|
-
# Filter out shards that already have chunks created
|
574
|
-
remaining_shards = [
|
575
|
-
shard
|
576
|
-
for shard in remaining_shards
|
577
|
-
if (shard if shard.startswith("hf_dataset:") else Path(shard).stem)
|
578
|
-
not in shards_with_chunks
|
579
|
-
]
|
580
|
-
|
581
|
-
self.stats["completed_shards"] = len(completed_shards)
|
582
|
-
|
583
|
-
logger.info(
|
584
|
-
f"Total shards: {len(self.all_shards)}, "
|
585
|
-
f"Completed: {self.stats['completed_shards']}, "
|
586
|
-
f"Shards with chunks: {len(shards_with_chunks)}, "
|
587
|
-
f"Remaining to process: {len(remaining_shards)}"
|
588
|
-
)
|
589
|
-
|
590
|
-
# First, re-queue any existing pending chunks
|
591
|
-
initial_pending = 0
|
592
|
-
requeued_chunks_by_shard = defaultdict(list)
|
593
|
-
|
594
|
-
for shard_name, shard_info in shards_summary.items():
|
595
|
-
with self.chunk_manager.lock:
|
596
|
-
for chunk_state in shard_info["chunks"]:
|
597
|
-
if chunk_state.status in ["pending", "failed", "assigned"]:
|
598
|
-
# For assigned chunks, reset them to pending since workers don't exist
|
599
|
-
chunk = ShardChunk(
|
600
|
-
chunk_id=chunk_state.chunk_id,
|
601
|
-
shard_url=chunk_state.shard_url,
|
602
|
-
shard_name=chunk_state.shard_name,
|
603
|
-
start_index=chunk_state.start_index,
|
604
|
-
chunk_size=chunk_state.chunk_size,
|
605
|
-
status="pending", # Reset to pending
|
606
|
-
assigned_to=None, # Clear assignment
|
607
|
-
)
|
608
|
-
self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
|
609
|
-
self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
|
610
|
-
requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
|
611
|
-
initial_pending += 1
|
612
|
-
|
613
|
-
logger.info(f"Re-queued {initial_pending} existing pending chunks")
|
614
|
-
for shard_name, chunk_ids in requeued_chunks_by_shard.items():
|
615
|
-
logger.info(f" Shard {shard_name}: {len(chunk_ids)} chunks - {chunk_ids}")
|
616
|
-
|
617
|
-
# Mark state as restored
|
618
|
-
self.state_restored.set()
|
619
|
-
logger.info("State restoration complete, accepting chunk requests")
|
620
|
-
|
621
|
-
# Process shards on-demand
|
622
|
-
shard_iter = iter(remaining_shards)
|
623
|
-
current_shard_url = None
|
624
|
-
current_shard_name = None
|
625
|
-
current_shard_items = None
|
626
|
-
current_shard_index = 0
|
627
|
-
|
628
|
-
while not self.stop_chunk_creation.is_set():
|
629
|
-
# Check how many chunks we need
|
630
|
-
with self.chunk_manager.lock:
|
631
|
-
pending_count = len(self.chunk_manager.pending_chunks)
|
632
|
-
assigned_count = sum(
|
633
|
-
len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
|
634
|
-
)
|
635
|
-
total_active = pending_count + assigned_count
|
636
|
-
|
637
|
-
# Target buffer: configurable multiplier × number of workers
|
638
|
-
worker_count = max(1, self.stats.get("connected_workers", 0))
|
639
|
-
target_buffer = max(
|
640
|
-
self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier
|
641
|
-
)
|
642
|
-
|
643
|
-
chunks_needed = max(0, target_buffer - total_active)
|
644
|
-
|
645
|
-
if chunks_needed == 0:
|
646
|
-
# We have enough chunks, wait a bit
|
647
|
-
time.sleep(5)
|
648
|
-
continue
|
649
|
-
|
650
|
-
logger.debug(
|
651
|
-
f"Need {chunks_needed} more chunks (pending: {pending_count}, "
|
652
|
-
f"assigned: {assigned_count}, workers: {worker_count})"
|
653
|
-
)
|
654
|
-
|
655
|
-
# Create chunks as needed
|
656
|
-
chunks_created = 0
|
657
|
-
|
658
|
-
while chunks_created < chunks_needed and not self.stop_chunk_creation.is_set():
|
659
|
-
# Need to load next shard?
|
660
|
-
if current_shard_url is None or current_shard_index >= current_shard_items:
|
661
|
-
try:
|
662
|
-
current_shard_url = next(shard_iter)
|
663
|
-
|
664
|
-
# Extract shard name based on type
|
665
|
-
if current_shard_url.startswith("hf_dataset:"):
|
666
|
-
current_shard_name = current_shard_url # Use full ID for virtual shards
|
667
|
-
else:
|
668
|
-
current_shard_name = Path(current_shard_url).stem
|
669
|
-
|
670
|
-
self.stats["current_shard"] = current_shard_name
|
671
|
-
|
672
|
-
# Skip if we already have chunks from this shard
|
673
|
-
if current_shard_name in shards_summary:
|
674
|
-
logger.debug(
|
675
|
-
f"Skipping shard {current_shard_name} - already has chunks"
|
676
|
-
)
|
677
|
-
current_shard_url = None
|
678
|
-
continue
|
679
|
-
|
680
|
-
# Count items in new shard
|
681
|
-
logger.info(f"Loading new shard {current_shard_name}")
|
682
|
-
|
683
|
-
# For virtual HF dataset shards, use the chunk size directly
|
684
|
-
if current_shard_url.startswith("hf_dataset:"):
|
685
|
-
current_shard_items = self.dataset_loader.count_shard_items(
|
686
|
-
current_shard_url
|
687
|
-
)
|
688
|
-
logger.info(
|
689
|
-
f"Virtual shard {current_shard_name} has {current_shard_items} items"
|
690
|
-
)
|
691
|
-
else:
|
692
|
-
# For WebDataset, actually count items
|
693
|
-
current_shard_items = sum(
|
694
|
-
1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
|
695
|
-
)
|
696
|
-
logger.info(
|
697
|
-
f"Shard {current_shard_name} has {current_shard_items} items"
|
698
|
-
)
|
699
|
-
|
700
|
-
current_shard_index = 0
|
701
|
-
|
702
|
-
except StopIteration:
|
703
|
-
# No more shards in the iterator
|
704
|
-
if self._is_hf_dataset:
|
705
|
-
# Before creating new virtual shards, check if we have pending chunks
|
706
|
-
with self.chunk_manager.lock:
|
707
|
-
pending_count = len(self.chunk_manager.pending_chunks)
|
708
|
-
|
709
|
-
if pending_count > 0:
|
710
|
-
# Don't create new shards if we have pending chunks
|
711
|
-
logger.debug(
|
712
|
-
f"Have {pending_count} pending chunks, not creating new virtual shards yet"
|
713
|
-
)
|
714
|
-
current_shard_url = None
|
715
|
-
time.sleep(2)
|
716
|
-
continue
|
717
|
-
|
718
|
-
# For HF datasets, we can create more virtual shards on demand
|
719
|
-
logger.info(
|
720
|
-
"Creating additional virtual shards for HuggingFace dataset"
|
721
|
-
)
|
722
|
-
|
723
|
-
# Create 10 more virtual shards
|
724
|
-
new_shards = []
|
725
|
-
for i in range(10):
|
726
|
-
shard_id = f"hf_dataset:{self.dataset_path}:chunk:{self._next_hf_shard_index * self._hf_chunk_size}"
|
727
|
-
new_shards.append(shard_id)
|
728
|
-
self._next_hf_shard_index += 1
|
729
|
-
|
730
|
-
# Add to all_shards and create new iterator
|
731
|
-
self.all_shards.extend(new_shards)
|
732
|
-
self.stats["total_shards"] = len(self.all_shards)
|
733
|
-
|
734
|
-
# Filter for unprocessed shards
|
735
|
-
remaining_new_shards = [
|
736
|
-
s
|
737
|
-
for s in new_shards
|
738
|
-
if s not in shards_summary and s not in completed_shards
|
739
|
-
]
|
740
|
-
|
741
|
-
if remaining_new_shards:
|
742
|
-
shard_iter = iter(remaining_new_shards)
|
743
|
-
logger.info(f"Added {len(remaining_new_shards)} new virtual shards")
|
744
|
-
continue
|
745
|
-
|
746
|
-
# No more shards to process
|
747
|
-
logger.info("No more shards to process")
|
748
|
-
break
|
749
|
-
|
750
|
-
except Exception as e:
|
751
|
-
logger.error(f"Error loading shard {current_shard_name}: {e}")
|
752
|
-
current_shard_url = None
|
753
|
-
continue
|
754
|
-
|
755
|
-
# Create a chunk from current shard
|
756
|
-
if current_shard_url and current_shard_index < current_shard_items:
|
757
|
-
# Calculate the absolute dataset index for this chunk
|
758
|
-
if current_shard_url.startswith("hf_dataset:"):
|
759
|
-
# Parse the virtual shard URL to get the base start index
|
760
|
-
parts = current_shard_url.split(":")
|
761
|
-
if len(parts) >= 4 and parts[2] == "chunk":
|
762
|
-
shard_base_index = int(parts[3])
|
763
|
-
else:
|
764
|
-
shard_base_index = 0
|
765
|
-
|
766
|
-
# The absolute start index for this chunk in the dataset
|
767
|
-
absolute_start_index = shard_base_index + current_shard_index
|
768
|
-
else:
|
769
|
-
# For WebDataset, current_shard_index is already absolute
|
770
|
-
absolute_start_index = current_shard_index
|
771
|
-
|
772
|
-
# Create chunk with absolute index
|
773
|
-
chunk = ShardChunk.create(
|
774
|
-
shard_url=current_shard_url,
|
775
|
-
shard_name=current_shard_name,
|
776
|
-
start_index=absolute_start_index,
|
777
|
-
chunk_size=min(self.chunk_size, current_shard_items - current_shard_index),
|
778
|
-
)
|
779
|
-
|
780
|
-
# Add to ChunkTracker with all required fields
|
781
|
-
if self.chunk_tracker and self.chunk_tracker.add_chunk(
|
782
|
-
chunk.chunk_id,
|
783
|
-
chunk.shard_name,
|
784
|
-
chunk.shard_url,
|
785
|
-
chunk.start_index,
|
786
|
-
chunk.chunk_size,
|
787
|
-
):
|
788
|
-
with self.chunk_manager.lock:
|
789
|
-
self.chunk_manager.chunks[chunk.chunk_id] = chunk
|
790
|
-
self.chunk_manager.pending_chunks.append(chunk.chunk_id)
|
791
|
-
|
792
|
-
chunks_created += 1
|
793
|
-
self.stats["total_chunks"] += 1
|
794
|
-
|
795
|
-
current_shard_index += self.chunk_size
|
796
|
-
|
797
|
-
if chunks_created > 0:
|
798
|
-
logger.info(f"Created {chunks_created} chunks on demand")
|
799
|
-
|
800
|
-
# If we couldn't create any chunks and there are no more shards, check if it's HF dataset
|
801
|
-
if chunks_created == 0 and current_shard_url is None:
|
802
|
-
if self._is_hf_dataset:
|
803
|
-
# We can always create more virtual shards for HF datasets
|
804
|
-
logger.debug("Will create more virtual shards on next iteration")
|
805
|
-
else:
|
806
|
-
logger.info("All shards processed, chunk creation complete")
|
807
|
-
break
|
808
|
-
|
809
|
-
# Brief pause to avoid spinning
|
810
|
-
time.sleep(1)
|
811
|
-
|
812
|
-
# Final stats
|
813
|
-
if self.chunk_tracker:
|
814
|
-
final_stats = self.chunk_tracker.get_stats()
|
815
|
-
logger.info(
|
816
|
-
f"Chunk creation thread ending. Total: {final_stats['total']}, "
|
817
|
-
f"Pending: {final_stats['pending']}, Completed: {final_stats['completed']}"
|
818
|
-
)
|
819
|
-
|
820
|
-
logger.info("Chunk creation thread finished")
|
821
|
-
|
822
114
|
async def start(self):
|
823
115
|
"""Start the orchestrator server."""
|
824
|
-
logger.info(f"Starting
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
await self.storage.initialize()
|
831
|
-
if self.chunk_tracker:
|
832
|
-
await self.chunk_tracker.sync_with_storage(self.storage)
|
833
|
-
await self._restore_state()
|
834
|
-
|
835
|
-
# Start chunk creation thread if dataset is configured
|
836
|
-
if self.dataset_loader:
|
837
|
-
self.chunk_creation_thread = threading.Thread(
|
838
|
-
target=self._create_chunks_from_dataset, daemon=True
|
116
|
+
logger.info(f"Starting orchestrator on {self.host}:{self.port}")
|
117
|
+
processor_type = self.config.get("dataset", {}).get("processor_type", None)
|
118
|
+
if not processor_type:
|
119
|
+
logger.info(f"Config: {self.config}")
|
120
|
+
raise ValueError(
|
121
|
+
"You must supply processor_type in your orchestrator dataset configuration."
|
839
122
|
)
|
840
|
-
|
123
|
+
logger.info(f"Processor type: {processor_type}")
|
841
124
|
|
842
|
-
|
843
|
-
|
125
|
+
# Initialize storage
|
126
|
+
await self.storage.initialize()
|
127
|
+
await self.update_unprocessed_ranges()
|
844
128
|
|
845
129
|
# Start background tasks
|
846
130
|
asyncio.create_task(self._heartbeat_loop())
|
@@ -848,12 +132,37 @@ class Orchestrator:
|
|
848
132
|
asyncio.create_task(self._stats_update_loop())
|
849
133
|
|
850
134
|
# Start WebSocket server
|
135
|
+
websocket_logger = logging.getLogger("websockets")
|
136
|
+
websocket_logger.setLevel(logging.WARNING)
|
851
137
|
async with websockets.serve(
|
852
|
-
self.handle_connection,
|
138
|
+
self.handle_connection,
|
139
|
+
self.host,
|
140
|
+
self.port,
|
141
|
+
ssl=self.ssl_context,
|
142
|
+
logger=websocket_logger,
|
853
143
|
):
|
854
|
-
logger.info("
|
144
|
+
logger.info("Orchestrator ready for connections")
|
855
145
|
await asyncio.Future() # Run forever
|
856
146
|
|
147
|
+
def get_workers_by_user_stats(self) -> Dict[str, Dict]:
|
148
|
+
"""Get worker statistics grouped by user."""
|
149
|
+
stats = {}
|
150
|
+
for user, worker_ids in self.workers_by_user.items():
|
151
|
+
stats[user] = {"worker_ids": list(worker_ids), "count": len(worker_ids)}
|
152
|
+
return stats
|
153
|
+
|
154
|
+
async def update_unprocessed_ranges(self):
|
155
|
+
"""Update unprocessed ranges based on what's actually in storage."""
|
156
|
+
if not self.processor or not self.storage:
|
157
|
+
return
|
158
|
+
|
159
|
+
processed_job_ids = self.storage.get_all_processed_job_ids()
|
160
|
+
self.processor.update_from_storage(processed_job_ids)
|
161
|
+
|
162
|
+
async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
|
163
|
+
"""Alias for _send_monitor_leaderboard for backward compatibility."""
|
164
|
+
await self._send_monitor_leaderboard(websocket)
|
165
|
+
|
857
166
|
async def handle_connection(self, websocket: WebSocketServerProtocol):
|
858
167
|
"""Handle new WebSocket connection."""
|
859
168
|
try:
|
@@ -868,49 +177,77 @@ class Orchestrator:
|
|
868
177
|
|
869
178
|
if auth_ticket.role == "worker":
|
870
179
|
await self._handle_worker(websocket, auth_ticket)
|
871
|
-
elif auth_ticket.role == "data_worker":
|
872
|
-
await self._handle_data_worker(websocket, auth_ticket)
|
873
180
|
elif auth_ticket.role == "monitor":
|
874
181
|
await self._handle_monitor(websocket)
|
875
182
|
elif auth_ticket.role == "admin":
|
876
183
|
await self._handle_admin(websocket, auth_ticket)
|
184
|
+
elif auth_ticket.role == "data_worker":
|
185
|
+
await self._handle_data_worker(websocket, auth_ticket)
|
877
186
|
else:
|
878
187
|
await websocket.send(
|
879
188
|
safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"})
|
880
189
|
)
|
881
190
|
|
882
191
|
except Exception as e:
|
883
|
-
logger.error(f"Connection error: {e}")
|
884
|
-
import traceback
|
885
|
-
|
886
|
-
logger.error(traceback.format_exc())
|
192
|
+
logger.error(f"Connection error: {e}", exc_info=True)
|
887
193
|
await websocket.close()
|
888
194
|
|
889
|
-
async def
|
890
|
-
"""Handle
|
891
|
-
|
892
|
-
|
195
|
+
async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
|
196
|
+
"""Handle worker connection lifecycle."""
|
197
|
+
# Generate unique worker ID
|
198
|
+
base_name = getattr(auth_ticket, "name", "worker")
|
199
|
+
worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"
|
200
|
+
worker_user = base_name
|
201
|
+
|
202
|
+
self.workers[worker_id] = websocket
|
203
|
+
self.workers_by_user[worker_user].add(worker_id)
|
204
|
+
self.stats["connected_workers"] = len(self.workers)
|
893
205
|
|
206
|
+
# Register contributor
|
207
|
+
contributor = await self.storage.get_contributor(worker_user)
|
208
|
+
if not contributor:
|
209
|
+
contributor = Contributor(
|
210
|
+
contributor_id=worker_user,
|
211
|
+
name=worker_user,
|
212
|
+
total_captions=0,
|
213
|
+
trust_level=1,
|
214
|
+
)
|
215
|
+
await self.storage.save_contributor(contributor)
|
216
|
+
|
217
|
+
logger.info(f"Worker {worker_id} (user: {worker_user}) is retrieving configuration")
|
894
218
|
try:
|
895
|
-
# Send welcome
|
896
|
-
|
219
|
+
# Send welcome message with processor config
|
220
|
+
filtered_config = self.config.copy()
|
221
|
+
for unwanted_key in ["auth", "orchestrator", "storage"]:
|
222
|
+
filtered_config.pop(unwanted_key, None)
|
223
|
+
welcome_message = {
|
224
|
+
"type": "welcome",
|
225
|
+
"worker_id": worker_id,
|
226
|
+
"user_id": worker_user,
|
227
|
+
"processor_type": self.config.get("dataset", {}).get("processor_type", None),
|
228
|
+
"processor_config": filtered_config,
|
229
|
+
}
|
230
|
+
await websocket.send(safe_json_dumps(welcome_message))
|
897
231
|
|
898
232
|
async for message in websocket:
|
899
|
-
|
900
|
-
|
901
|
-
msg_type = data.get("type")
|
233
|
+
data = json.loads(message)
|
234
|
+
await self._process_worker_message(worker_id, data)
|
902
235
|
|
903
|
-
|
904
|
-
|
236
|
+
except websockets.exceptions.ConnectionClosed:
|
237
|
+
logger.info(f"Worker {worker_id} has disconnected due to websocket connection closure")
|
238
|
+
finally:
|
239
|
+
if worker_id in self.workers:
|
240
|
+
del self.workers[worker_id]
|
905
241
|
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
safe_json_dumps({"type": "error", "error": "Invalid message format"})
|
910
|
-
)
|
242
|
+
self.workers_by_user[worker_user].discard(worker_id)
|
243
|
+
if not self.workers_by_user[worker_user]:
|
244
|
+
del self.workers_by_user[worker_user]
|
911
245
|
|
912
|
-
|
913
|
-
|
246
|
+
self.stats["connected_workers"] = len(self.workers)
|
247
|
+
|
248
|
+
# Release assignments
|
249
|
+
self.processor.release_assignments(worker_id)
|
250
|
+
logger.info(f"Worker {worker_id} has safely disconnected")
|
914
251
|
|
915
252
|
async def _handle_config_reload(
|
916
253
|
self, websocket: WebSocketServerProtocol, new_config: Dict[str, Any]
|
@@ -920,224 +257,61 @@ class Orchestrator:
|
|
920
257
|
|
921
258
|
updated_sections = []
|
922
259
|
warnings = []
|
923
|
-
requires_worker_restart = False
|
924
260
|
|
925
261
|
try:
|
926
262
|
# Extract orchestrator section if present
|
927
263
|
if "orchestrator" in new_config:
|
928
|
-
# Config has orchestrator wrapper, extract it
|
929
264
|
orchestrator_config = new_config["orchestrator"]
|
930
265
|
else:
|
931
|
-
# Config is already at orchestrator level
|
932
266
|
orchestrator_config = new_config
|
933
267
|
|
934
|
-
#
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
if
|
940
|
-
|
941
|
-
|
942
|
-
return all(deep_equal(a[k], b[k]) for k in a.keys())
|
943
|
-
elif isinstance(a, (list, tuple)):
|
944
|
-
if len(a) != len(b):
|
945
|
-
return False
|
946
|
-
return all(deep_equal(x, y) for x, y in zip(a, b))
|
268
|
+
# Update processor configuration if present
|
269
|
+
if "processor_type" in orchestrator_config:
|
270
|
+
old_type = self.config.get("processor_type")
|
271
|
+
new_type = orchestrator_config["processor_type"]
|
272
|
+
|
273
|
+
if old_type != new_type:
|
274
|
+
warnings.append("Processor type changes require orchestrator restart")
|
275
|
+
updated_sections.append("processor_type")
|
947
276
|
else:
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
# Check if critical changes require worker restart
|
964
|
-
if (
|
965
|
-
old_vllm.get("model") != new_vllm.get("model")
|
966
|
-
or old_vllm.get("gpu_memory_utilization")
|
967
|
-
!= new_vllm.get("gpu_memory_utilization")
|
968
|
-
or old_vllm.get("tensor_parallel_size")
|
969
|
-
!= new_vllm.get("tensor_parallel_size")
|
970
|
-
or old_vllm.get("dtype") != new_vllm.get("dtype")
|
971
|
-
or old_vllm.get("max_model_len") != new_vllm.get("max_model_len")
|
972
|
-
):
|
973
|
-
requires_worker_restart = True
|
974
|
-
warnings.append(
|
975
|
-
"Critical vLLM changes detected - workers will be disconnected to reload"
|
976
|
-
)
|
977
|
-
logger.info(
|
978
|
-
f"Model change: {old_vllm.get('model')} -> {new_vllm.get('model')}"
|
979
|
-
)
|
980
|
-
|
981
|
-
# Update dataset configuration
|
982
|
-
if "dataset" in orchestrator_config:
|
983
|
-
old_dataset = self.dataset_config.copy()
|
984
|
-
new_dataset = orchestrator_config["dataset"]
|
985
|
-
|
986
|
-
dataset_changed = not deep_equal(old_dataset, new_dataset)
|
987
|
-
|
988
|
-
if dataset_changed:
|
989
|
-
self.dataset_config = new_dataset.copy()
|
990
|
-
self.dataset_path = self.dataset_config.get("path")
|
991
|
-
self.dataset_type = self.dataset_config.get("type", "huggingface")
|
992
|
-
updated_sections.append("dataset")
|
993
|
-
warnings.append("Dataset changes will apply to new chunks only")
|
994
|
-
|
995
|
-
# Update chunk settings
|
996
|
-
if (
|
997
|
-
"chunk_size" in orchestrator_config
|
998
|
-
and self.chunk_size != orchestrator_config["chunk_size"]
|
999
|
-
):
|
1000
|
-
self.chunk_size = orchestrator_config["chunk_size"]
|
1001
|
-
self.chunk_manager.chunk_size = self.chunk_size
|
1002
|
-
updated_sections.append("chunk_size")
|
1003
|
-
|
1004
|
-
if (
|
1005
|
-
"chunks_per_request" in orchestrator_config
|
1006
|
-
and self.chunks_per_request != orchestrator_config["chunks_per_request"]
|
1007
|
-
):
|
1008
|
-
self.chunks_per_request = orchestrator_config["chunks_per_request"]
|
1009
|
-
updated_sections.append("chunks_per_request")
|
277
|
+
# Update processor config
|
278
|
+
self.config.update(orchestrator_config)
|
279
|
+
|
280
|
+
# Reinitialize processor with new config
|
281
|
+
processor_config = ProcessorConfig(
|
282
|
+
processor_type=new_type, config=orchestrator_config
|
283
|
+
)
|
284
|
+
self.processor.initialize(processor_config)
|
285
|
+
updated_sections.append("processor_config")
|
286
|
+
|
287
|
+
# Update units per request
|
288
|
+
if "units_per_request" in orchestrator_config:
|
289
|
+
self.units_per_request = orchestrator_config["units_per_request"]
|
290
|
+
updated_sections.append("units_per_request")
|
1010
291
|
|
1011
292
|
# Update auth configuration
|
1012
293
|
if "auth" in orchestrator_config:
|
1013
294
|
try:
|
1014
|
-
self.auth = AuthManager(
|
295
|
+
self.auth = AuthManager(orchestrator_config["auth"])
|
1015
296
|
updated_sections.append("auth")
|
1016
297
|
except Exception as e:
|
1017
298
|
logger.error(f"Failed to update AuthManager: {e}")
|
1018
299
|
warnings.append(f"Auth update failed: {e}")
|
1019
300
|
|
1020
|
-
# Update buffer settings
|
1021
|
-
if (
|
1022
|
-
"chunk_buffer_multiplier" in orchestrator_config
|
1023
|
-
and self.chunk_buffer_multiplier != orchestrator_config["chunk_buffer_multiplier"]
|
1024
|
-
):
|
1025
|
-
self.chunk_buffer_multiplier = orchestrator_config["chunk_buffer_multiplier"]
|
1026
|
-
updated_sections.append("chunk_buffer_multiplier")
|
1027
|
-
|
1028
|
-
if (
|
1029
|
-
"min_chunk_buffer" in orchestrator_config
|
1030
|
-
and self.min_chunk_buffer != orchestrator_config["min_chunk_buffer"]
|
1031
|
-
):
|
1032
|
-
self.min_chunk_buffer = orchestrator_config["min_chunk_buffer"]
|
1033
|
-
updated_sections.append("min_chunk_buffer")
|
1034
|
-
|
1035
301
|
# Update storage settings
|
1036
302
|
if "storage" in orchestrator_config:
|
1037
303
|
storage_config = orchestrator_config["storage"]
|
1038
|
-
storage_changed = False
|
1039
304
|
|
1040
|
-
if
|
1041
|
-
"caption_buffer_size" in storage_config
|
1042
|
-
and self.storage.caption_buffer_size != storage_config["caption_buffer_size"]
|
1043
|
-
):
|
305
|
+
if "caption_buffer_size" in storage_config:
|
1044
306
|
self.storage.caption_buffer_size = storage_config["caption_buffer_size"]
|
1045
|
-
|
307
|
+
updated_sections.append("storage.caption_buffer_size")
|
1046
308
|
|
1047
|
-
|
1048
|
-
current_interval = self.config.get("storage", {}).get(
|
1049
|
-
"checkpoint_interval", 1000
|
1050
|
-
)
|
1051
|
-
if current_interval != storage_config["checkpoint_interval"]:
|
1052
|
-
self.config.setdefault("storage", {})["checkpoint_interval"] = (
|
1053
|
-
storage_config["checkpoint_interval"]
|
1054
|
-
)
|
1055
|
-
storage_changed = True
|
1056
|
-
|
1057
|
-
if storage_changed:
|
1058
|
-
updated_sections.append("storage")
|
1059
|
-
|
1060
|
-
# Check if any changes were made
|
1061
|
-
if not updated_sections:
|
1062
|
-
await websocket.send(
|
1063
|
-
safe_json_dumps(
|
1064
|
-
{
|
1065
|
-
"type": "reload_complete",
|
1066
|
-
"message": "No changes applied - configuration is identical",
|
1067
|
-
}
|
1068
|
-
)
|
1069
|
-
)
|
1070
|
-
logger.info("Configuration reload requested but no changes detected")
|
1071
|
-
return
|
1072
|
-
|
1073
|
-
# Update the main config
|
309
|
+
# Update main config
|
1074
310
|
if "orchestrator" in new_config:
|
1075
|
-
self.config["orchestrator"]
|
311
|
+
self.config = new_config["orchestrator"]
|
1076
312
|
else:
|
1077
313
|
self.config.update(orchestrator_config)
|
1078
314
|
|
1079
|
-
# Handle worker restart if needed
|
1080
|
-
if requires_worker_restart:
|
1081
|
-
logger.info("Disconnecting all workers for configuration reload...")
|
1082
|
-
|
1083
|
-
# Send reload message to workers first
|
1084
|
-
reload_msg = safe_json_dumps(
|
1085
|
-
{
|
1086
|
-
"type": "reload_vllm",
|
1087
|
-
"vllm_config": self.vllm_config,
|
1088
|
-
}
|
1089
|
-
)
|
1090
|
-
|
1091
|
-
# Create a list of worker items to avoid modifying dict during iteration
|
1092
|
-
worker_items = list(self.workers.items())
|
1093
|
-
disconnected = []
|
1094
|
-
|
1095
|
-
for worker_id, ws in worker_items:
|
1096
|
-
try:
|
1097
|
-
await ws.send(reload_msg)
|
1098
|
-
# Give worker time to process before disconnect
|
1099
|
-
await asyncio.sleep(0.5)
|
1100
|
-
await ws.close(code=1012, reason="Configuration reload")
|
1101
|
-
disconnected.append(worker_id)
|
1102
|
-
except:
|
1103
|
-
disconnected.append(worker_id) # Still mark as disconnected if error
|
1104
|
-
|
1105
|
-
# Now safely clear workers dict
|
1106
|
-
for worker_id in disconnected:
|
1107
|
-
if worker_id in self.workers:
|
1108
|
-
del self.workers[worker_id]
|
1109
|
-
|
1110
|
-
warnings.append(
|
1111
|
-
f"Sent reload message to {len(disconnected)} workers - they will reconnect with new config"
|
1112
|
-
)
|
1113
|
-
else:
|
1114
|
-
# Just notify workers about config changes without disconnecting
|
1115
|
-
config_update_msg = safe_json_dumps(
|
1116
|
-
{
|
1117
|
-
"type": "config_update",
|
1118
|
-
"vllm_config": self.vllm_config if "vllm" in updated_sections else None,
|
1119
|
-
"dataset_config": (
|
1120
|
-
self.dataset_config if "dataset" in updated_sections else None
|
1121
|
-
),
|
1122
|
-
}
|
1123
|
-
)
|
1124
|
-
|
1125
|
-
# Create a list of worker items to avoid modifying dict during iteration
|
1126
|
-
worker_items = list(self.workers.items())
|
1127
|
-
disconnected = []
|
1128
|
-
|
1129
|
-
for worker_id, ws in worker_items:
|
1130
|
-
try:
|
1131
|
-
await ws.send(config_update_msg)
|
1132
|
-
logger.info(f"Sent config update to worker {worker_id}")
|
1133
|
-
except:
|
1134
|
-
disconnected.append(worker_id)
|
1135
|
-
|
1136
|
-
# Now safely remove disconnected workers
|
1137
|
-
for worker_id in disconnected:
|
1138
|
-
if worker_id in self.workers:
|
1139
|
-
del self.workers[worker_id]
|
1140
|
-
|
1141
315
|
# Send success response
|
1142
316
|
await websocket.send(
|
1143
317
|
safe_json_dumps(
|
@@ -1146,306 +320,169 @@ class Orchestrator:
|
|
1146
320
|
)
|
1147
321
|
|
1148
322
|
logger.info(f"Configuration reloaded. Updated sections: {', '.join(updated_sections)}")
|
1149
|
-
|
1150
|
-
# Broadcast stats update to monitors
|
1151
|
-
await self._broadcast_stats()
|
1152
323
|
await self._send_activity(
|
1153
324
|
f"Configuration reloaded by admin: {', '.join(updated_sections)}"
|
1154
325
|
)
|
1155
326
|
|
1156
327
|
except Exception as e:
|
1157
328
|
logger.error(f"Configuration reload failed: {e}")
|
1158
|
-
import traceback
|
1159
|
-
|
1160
|
-
logger.error(traceback.format_exc())
|
1161
329
|
await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))
|
1162
330
|
|
1163
|
-
async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
|
1164
|
-
"""Handle worker connection lifecycle."""
|
1165
|
-
# Generate unique worker ID even if using same token
|
1166
|
-
base_name = getattr(auth_ticket, "name", "worker")
|
1167
|
-
worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}" # Add unique suffix
|
1168
|
-
|
1169
|
-
# Track the original token/user for accounting
|
1170
|
-
worker_user = base_name # Keep track of which user/token this worker belongs to
|
1171
|
-
|
1172
|
-
self.workers[worker_id] = websocket
|
1173
|
-
self.stats["connected_workers"] = len(self.workers)
|
1174
|
-
|
1175
|
-
# Optionally track workers by user/token
|
1176
|
-
if not hasattr(self, "workers_by_user"):
|
1177
|
-
self.workers_by_user = defaultdict(set)
|
1178
|
-
self.workers_by_user[worker_user].add(worker_id)
|
1179
|
-
|
1180
|
-
# Register contributor with the base name (for aggregating stats per user)
|
1181
|
-
contributor = await self.storage.get_contributor(worker_user)
|
1182
|
-
if not contributor:
|
1183
|
-
contributor = Contributor(
|
1184
|
-
contributor_id=worker_user,
|
1185
|
-
name=worker_user,
|
1186
|
-
total_captions=0,
|
1187
|
-
trust_level=1,
|
1188
|
-
)
|
1189
|
-
await self.storage.save_contributor(contributor)
|
1190
|
-
|
1191
|
-
logger.info(f"Worker {worker_id} (user: {worker_user}) connected")
|
1192
|
-
await self._broadcast_stats()
|
1193
|
-
await self._send_activity(f"Worker {worker_id} (user: {worker_user}) connected")
|
1194
|
-
|
1195
|
-
try:
|
1196
|
-
# Send welcome message with dataset configuration
|
1197
|
-
welcome_message = {
|
1198
|
-
"type": "welcome",
|
1199
|
-
"worker_id": worker_id,
|
1200
|
-
"user_id": worker_user,
|
1201
|
-
"dataset_config": {
|
1202
|
-
"dataset_path": self.dataset_path,
|
1203
|
-
"dataset_type": self.dataset_type,
|
1204
|
-
"dataset_split": self.dataset_split,
|
1205
|
-
"dataset_image_column": self.dataset_image_column,
|
1206
|
-
"path": self.dataset_path,
|
1207
|
-
"type": self.dataset_type,
|
1208
|
-
"split": self.dataset_split,
|
1209
|
-
"image_column": self.dataset_image_column,
|
1210
|
-
},
|
1211
|
-
"vllm_config": self.vllm_config,
|
1212
|
-
}
|
1213
|
-
await websocket.send(safe_json_dumps(welcome_message))
|
1214
|
-
|
1215
|
-
async for message in websocket:
|
1216
|
-
data = json.loads(message)
|
1217
|
-
await self._process_worker_message(worker_id, data)
|
1218
|
-
|
1219
|
-
except websockets.exceptions.ConnectionClosed:
|
1220
|
-
logger.info(f"Worker {worker_id} (user: {worker_user}) disconnected")
|
1221
|
-
finally:
|
1222
|
-
if worker_id in self.workers:
|
1223
|
-
del self.workers[worker_id]
|
1224
|
-
|
1225
|
-
# Clean up user tracking
|
1226
|
-
if hasattr(self, "workers_by_user") and worker_user in self.workers_by_user:
|
1227
|
-
self.workers_by_user[worker_user].discard(worker_id)
|
1228
|
-
if not self.workers_by_user[worker_user]:
|
1229
|
-
del self.workers_by_user[worker_user]
|
1230
|
-
|
1231
|
-
self.stats["connected_workers"] = len(self.workers)
|
1232
|
-
|
1233
|
-
# Release chunks
|
1234
|
-
self.chunk_manager.release_worker_chunks(worker_id)
|
1235
|
-
if self.chunk_tracker:
|
1236
|
-
released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
|
1237
|
-
logger.info(
|
1238
|
-
f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
|
1239
|
-
)
|
1240
|
-
|
1241
|
-
await self._broadcast_stats()
|
1242
|
-
await self._send_activity(f"Worker {worker_id} (user: {worker_user}) disconnected")
|
1243
|
-
|
1244
331
|
async def _process_worker_message(self, worker_id: str, data: Dict):
|
1245
332
|
"""Process message from worker."""
|
1246
333
|
msg_type = data.get("type")
|
1247
334
|
|
1248
|
-
if msg_type == "
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
335
|
+
if msg_type == "request_work":
|
336
|
+
count = data.get("count", self.units_per_request)
|
337
|
+
units = self.processor.get_work_units(count, worker_id)
|
338
|
+
logger.debug(f"Assigning units: {[unit.chunk_id for unit in units]}")
|
339
|
+
|
340
|
+
if units:
|
341
|
+
# Create assignment
|
342
|
+
assignment = WorkAssignment(
|
343
|
+
assignment_id=str(uuid.uuid4()),
|
344
|
+
worker_id=worker_id,
|
345
|
+
units=units,
|
346
|
+
assigned_at=datetime.utcnow(),
|
1254
347
|
)
|
1255
|
-
return
|
1256
|
-
|
1257
|
-
count = data.get("count", self.chunks_per_request)
|
1258
|
-
chunk_infos = self.chunk_manager.get_chunks_for_worker(
|
1259
|
-
worker_id, count, self.chunk_tracker
|
1260
|
-
)
|
1261
|
-
|
1262
|
-
if chunk_infos:
|
1263
|
-
# Send chunks with unprocessed ranges
|
1264
|
-
chunks_data = []
|
1265
|
-
for info in chunk_infos:
|
1266
|
-
chunk_dict = info["chunk"].to_dict()
|
1267
|
-
chunk_dict["unprocessed_ranges"] = info["unprocessed_ranges"]
|
1268
|
-
chunks_data.append(chunk_dict)
|
1269
348
|
|
1270
349
|
await self.workers[worker_id].send(
|
1271
|
-
safe_json_dumps({"type": "
|
350
|
+
safe_json_dumps({"type": "work_assignment", "assignment": assignment.to_dict()})
|
1272
351
|
)
|
1273
352
|
|
1274
|
-
|
1275
|
-
logger.info(
|
1276
|
-
f"Assigned {len(chunks_data)} chunks to worker {worker_id}: {chunk_ids}"
|
1277
|
-
)
|
353
|
+
logger.debug(f"Assigned {len(units)} work units to worker {worker_id}")
|
1278
354
|
else:
|
1279
|
-
await self.workers[worker_id].send(safe_json_dumps({"type": "
|
1280
|
-
|
1281
|
-
elif msg_type == "chunk_complete":
|
1282
|
-
chunk_id = data["chunk_id"]
|
1283
|
-
if self.chunk_manager.complete_chunk(chunk_id, worker_id):
|
1284
|
-
self.stats["completed_chunks"] += 1
|
355
|
+
await self.workers[worker_id].send(safe_json_dumps({"type": "no_work"}))
|
1285
356
|
|
1286
|
-
|
1287
|
-
|
357
|
+
elif msg_type == "work_complete":
|
358
|
+
unit_id = data["unit_id"]
|
359
|
+
self.processor.mark_completed(unit_id, worker_id)
|
360
|
+
logger.debug(f"Work unit {unit_id} completed by worker {worker_id}")
|
1288
361
|
|
1289
|
-
|
1290
|
-
|
1291
|
-
await self._send_activity(f"Chunk {chunk_id} completed by {worker_id}")
|
1292
|
-
elif msg_type == "chunk_failed":
|
1293
|
-
chunk_id = data["chunk_id"]
|
362
|
+
elif msg_type == "work_failed":
|
363
|
+
unit_id = data["unit_id"]
|
1294
364
|
error = data.get("error", "Unknown error")
|
1295
|
-
|
1296
|
-
|
365
|
+
self.processor.mark_failed(unit_id, worker_id, error)
|
366
|
+
logger.warning(f"Work unit {unit_id} failed on worker {worker_id}: {error}")
|
1297
367
|
|
1298
|
-
|
1299
|
-
|
368
|
+
elif msg_type == "submit_results":
|
369
|
+
await self._handle_results_submission(worker_id, data)
|
1300
370
|
|
1301
|
-
logger.warning(f"Chunk {chunk_id} failed on worker {worker_id}: {error}")
|
1302
|
-
await self._send_activity(f"Chunk {chunk_id} failed on {worker_id}: {error}")
|
1303
|
-
|
1304
|
-
elif msg_type == "submit_captions":
|
1305
|
-
await self._handle_captions_submission(worker_id, data)
|
1306
|
-
elif msg_type == "request_job":
|
1307
|
-
# CaptionWorker requesting a job from data samples
|
1308
|
-
try:
|
1309
|
-
job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
|
1310
|
-
await self.workers[worker_id].send(
|
1311
|
-
json.dumps({"type": "job_assignment", "job": job})
|
1312
|
-
)
|
1313
|
-
logger.debug(f"Assigned job {job['job_id']} to worker {worker_id}")
|
1314
|
-
except asyncio.TimeoutError:
|
1315
|
-
await self.workers[worker_id].send(json.dumps({"type": "no_jobs"}))
|
1316
371
|
elif msg_type == "heartbeat":
|
1317
|
-
# Update worker stats
|
1318
372
|
logger.debug(f"Heartbeat from {worker_id}: {data}")
|
1319
373
|
|
1320
|
-
async def
|
1321
|
-
"""Process
|
1322
|
-
|
1323
|
-
item_key = data["item_key"]
|
1324
|
-
|
1325
|
-
item_index = data.get("item_index") # Worker should send this
|
1326
|
-
if item_index is None:
|
1327
|
-
# Try to extract from item_key (format: dataset_XXXXXXXX)
|
1328
|
-
try:
|
1329
|
-
item_index = int(item_key.split("_")[-1])
|
1330
|
-
except:
|
1331
|
-
logger.warning(f"Could not extract item index from key: {item_key}")
|
1332
|
-
|
1333
|
-
# Extract user from worker_id (format: "username_uuid")
|
374
|
+
async def _handle_results_submission(self, worker_id: str, data: Dict):
|
375
|
+
"""Process results submission from worker."""
|
376
|
+
# Extract user from worker_id
|
1334
377
|
worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
|
1335
378
|
|
1336
|
-
#
|
1337
|
-
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1353
|
-
logger.debug(
|
1354
|
-
f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
|
1355
|
-
)
|
379
|
+
# Create work result
|
380
|
+
_job_id = data.get("job_id")
|
381
|
+
job_id = JobId.from_str(_job_id)
|
382
|
+
shard_name = job_id.shard_id # >data-0000<
|
383
|
+
chunk_name = job_id.chunk_id # data-0000:chunk:>0<
|
384
|
+
# logger.debug(f"({job_id}) Worker result: {data}")
|
385
|
+
result = WorkResult(
|
386
|
+
unit_id=data["unit_id"],
|
387
|
+
source_id=shard_name,
|
388
|
+
chunk_id=job_id.get_chunk_str(), # we want the full string here
|
389
|
+
sample_id=data["sample_id"],
|
390
|
+
dataset=data["dataset"],
|
391
|
+
outputs=data["outputs"],
|
392
|
+
metadata=data.get("metadata", {}),
|
393
|
+
processing_time_ms=data.get("processing_time_ms", 0),
|
394
|
+
)
|
1356
395
|
|
1357
|
-
#
|
396
|
+
# Let processor handle any custom processing
|
397
|
+
processed = self.processor.handle_result(result)
|
398
|
+
|
399
|
+
# Create caption record for storage
|
400
|
+
total_outputs = sum(len(v) for v in result.outputs.values())
|
401
|
+
|
402
|
+
filename = result.metadata.pop("_filename", None)
|
403
|
+
url = result.metadata.pop("_url", None)
|
404
|
+
image_height = result.metadata.pop("image_height", None)
|
405
|
+
image_width = result.metadata.pop("image_width", None)
|
406
|
+
file_size = result.metadata.pop("file_size", None)
|
407
|
+
image_format = result.metadata.pop("image_format", None)
|
408
|
+
item_index = result.metadata.pop("item_index", None)
|
409
|
+
item_key = result.metadata.pop("item_key", None)
|
410
|
+
to_delete_metadata_keys = ["_image_format", "_job_id"]
|
411
|
+
for key in to_delete_metadata_keys:
|
412
|
+
if key in result.metadata:
|
413
|
+
del result.metadata[key]
|
1358
414
|
caption = Caption(
|
1359
|
-
job_id=
|
1360
|
-
dataset=
|
1361
|
-
shard=
|
415
|
+
job_id=job_id,
|
416
|
+
dataset=result.dataset,
|
417
|
+
shard=processed["source_id"],
|
418
|
+
chunk_id=chunk_name,
|
1362
419
|
item_key=item_key,
|
1363
|
-
captions=
|
1364
|
-
outputs=outputs,
|
420
|
+
captions=result.outputs.get("captions", []),
|
421
|
+
outputs=result.outputs,
|
1365
422
|
contributor_id=worker_user,
|
1366
423
|
timestamp=datetime.utcnow(),
|
1367
|
-
quality_scores=None,
|
1368
|
-
# Image metadata
|
1369
|
-
image_width=data.get("image_width"),
|
1370
|
-
image_height=data.get("image_height"),
|
1371
|
-
image_format=data.get("image_format"),
|
1372
|
-
file_size=data.get("file_size"),
|
1373
|
-
# Processing metadata
|
1374
424
|
caption_count=total_outputs,
|
1375
|
-
processing_time_ms=
|
1376
|
-
|
1377
|
-
|
425
|
+
processing_time_ms=result.processing_time_ms,
|
426
|
+
metadata=result.metadata,
|
427
|
+
image_height=image_height,
|
428
|
+
image_width=image_width,
|
429
|
+
filename=filename,
|
430
|
+
url=url,
|
431
|
+
file_size=file_size,
|
432
|
+
image_format=image_format,
|
1378
433
|
)
|
1379
434
|
|
1380
|
-
#
|
435
|
+
# Save to storage
|
1381
436
|
await self.storage.save_caption(caption)
|
1382
437
|
|
1383
|
-
#
|
1384
|
-
should_flush = False
|
1385
|
-
if chunk_id and item_index is not None and self.chunk_tracker:
|
1386
|
-
with self.item_batch_lock:
|
1387
|
-
self.pending_processed_items[chunk_id].append(item_index)
|
1388
|
-
|
1389
|
-
# Check if we should flush
|
1390
|
-
total_pending = sum(
|
1391
|
-
len(indices) for indices in self.pending_processed_items.values()
|
1392
|
-
)
|
1393
|
-
time_since_flush = time.time() - self.last_item_batch_flush
|
1394
|
-
|
1395
|
-
if (
|
1396
|
-
total_pending >= self.item_batch_size
|
1397
|
-
or time_since_flush >= self.item_batch_interval
|
1398
|
-
):
|
1399
|
-
should_flush = True
|
1400
|
-
|
1401
|
-
if should_flush:
|
1402
|
-
await self._flush_processed_items()
|
1403
|
-
|
1404
|
-
# Update contributor stats (use user, not worker)
|
438
|
+
# Update contributor stats
|
1405
439
|
contributor = await self.storage.get_contributor(worker_user)
|
1406
440
|
if contributor:
|
1407
441
|
contributor.total_captions += total_outputs
|
1408
442
|
await self.storage.save_contributor(contributor)
|
1409
443
|
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
|
1416
|
-
|
1417
|
-
|
1418
|
-
or self._last_logged_outputs != total_outputs
|
1419
|
-
):
|
1420
|
-
logger.info(f"Collected {total_outputs} outputs centrally")
|
1421
|
-
self._last_logged_outputs = total_outputs
|
1422
|
-
|
1423
|
-
async def _check_shard_completion(self, chunk_id: str):
|
1424
|
-
"""Check if a shard is complete after chunk completion."""
|
1425
|
-
# Get the chunk
|
1426
|
-
chunk = self.chunk_manager.chunks.get(chunk_id)
|
1427
|
-
if not chunk:
|
1428
|
-
return
|
444
|
+
async def _handle_monitor(self, websocket: WebSocketServerProtocol):
|
445
|
+
"""Handle monitor connection."""
|
446
|
+
self.monitors.add(websocket)
|
447
|
+
logger.info(f"Monitor connected (total: {len(self.monitors)})")
|
448
|
+
|
449
|
+
try:
|
450
|
+
# Send welcome
|
451
|
+
await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
|
1429
452
|
|
1430
|
-
|
453
|
+
# Send initial stats
|
454
|
+
await self._send_monitor_stats(websocket)
|
1431
455
|
|
1432
|
-
|
1433
|
-
|
1434
|
-
|
1435
|
-
]
|
456
|
+
# Keep connection alive
|
457
|
+
async for message in websocket:
|
458
|
+
pass
|
1436
459
|
|
1437
|
-
|
1438
|
-
|
1439
|
-
|
1440
|
-
|
460
|
+
except websockets.exceptions.ConnectionClosed:
|
461
|
+
logger.info("Monitor disconnected")
|
462
|
+
finally:
|
463
|
+
self.monitors.discard(websocket)
|
1441
464
|
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1446
|
-
|
1447
|
-
|
1448
|
-
await
|
465
|
+
async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
|
466
|
+
"""Handle admin connection."""
|
467
|
+
admin_id = getattr(auth_ticket, "name", "admin")
|
468
|
+
logger.info(f"Admin {admin_id} connected")
|
469
|
+
|
470
|
+
try:
|
471
|
+
await websocket.send(safe_json_dumps({"type": "welcome", "role": "admin"}))
|
472
|
+
|
473
|
+
async for message in websocket:
|
474
|
+
try:
|
475
|
+
data = json.loads(message)
|
476
|
+
if data.get("type") == "reload_config":
|
477
|
+
await self._handle_config_reload(websocket, data.get("config", {}))
|
478
|
+
elif data.get("type") == "get_stats":
|
479
|
+
await self._send_monitor_stats(websocket)
|
480
|
+
|
481
|
+
except json.JSONDecodeError as e:
|
482
|
+
logger.error(f"Invalid admin message: {e}")
|
483
|
+
|
484
|
+
except websockets.exceptions.ConnectionClosed:
|
485
|
+
logger.info(f"Admin {admin_id} disconnected")
|
1449
486
|
|
1450
487
|
async def _handle_data_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
|
1451
488
|
"""Handle data worker connection."""
|
@@ -1516,7 +553,64 @@ class Orchestrator:
|
|
1516
553
|
finally:
|
1517
554
|
del self.data_workers[worker_id]
|
1518
555
|
|
1519
|
-
async def
|
556
|
+
async def _send_monitor_initial_data(self, websocket: WebSocketServerProtocol):
|
557
|
+
"""Send initial data to monitor in a separate task to avoid blocking."""
|
558
|
+
total_start = time.time()
|
559
|
+
try:
|
560
|
+
# Check if websocket is still in monitors set
|
561
|
+
if websocket not in self.monitors:
|
562
|
+
logger.debug("Monitor disconnected before initial data send")
|
563
|
+
return
|
564
|
+
|
565
|
+
# Send current stats (already in memory)
|
566
|
+
stats_start = time.time()
|
567
|
+
await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
|
568
|
+
logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
|
569
|
+
|
570
|
+
# Get processor stats instead of chunk stats
|
571
|
+
processor_stats_start = time.time()
|
572
|
+
processor_stats = self.processor.get_stats()
|
573
|
+
logger.debug(
|
574
|
+
f"Processor stats retrieved in {(time.time() - processor_stats_start)*1000:.1f}ms"
|
575
|
+
)
|
576
|
+
|
577
|
+
stats_send_start = time.time()
|
578
|
+
await websocket.send(
|
579
|
+
safe_json_dumps({"type": "processor_stats", "data": processor_stats})
|
580
|
+
)
|
581
|
+
logger.debug(f"Processor stats sent in {(time.time() - stats_send_start)*1000:.1f}ms")
|
582
|
+
|
583
|
+
if websocket not in self.monitors:
|
584
|
+
return
|
585
|
+
|
586
|
+
# For leaderboard, check if we have a cached version first
|
587
|
+
if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
|
588
|
+
# Use cached leaderboard if available
|
589
|
+
cache_send_start = time.time()
|
590
|
+
await websocket.send(
|
591
|
+
safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
|
592
|
+
)
|
593
|
+
logger.debug(
|
594
|
+
f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
|
595
|
+
)
|
596
|
+
else:
|
597
|
+
# Schedule leaderboard update separately
|
598
|
+
leaderboard_task_start = time.time()
|
599
|
+
asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
|
600
|
+
logger.debug(
|
601
|
+
f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
|
602
|
+
)
|
603
|
+
|
604
|
+
logger.debug(
|
605
|
+
f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
|
606
|
+
)
|
607
|
+
|
608
|
+
except websockets.exceptions.ConnectionClosed:
|
609
|
+
logger.debug("Monitor disconnected during initial data send")
|
610
|
+
except Exception as e:
|
611
|
+
logger.error(f"Error sending initial monitor data: {e}")
|
612
|
+
|
613
|
+
async def _send_monitor_leaderboard(self, websocket: WebSocketServerProtocol):
|
1520
614
|
"""Send leaderboard data to a specific monitor."""
|
1521
615
|
total_start = time.time()
|
1522
616
|
try:
|
@@ -1581,210 +675,43 @@ class Orchestrator:
|
|
1581
675
|
except Exception as e:
|
1582
676
|
logger.error(f"Error sending leaderboard to monitor: {e}")
|
1583
677
|
|
1584
|
-
async def
|
1585
|
-
"""Send
|
1586
|
-
|
1587
|
-
|
1588
|
-
# Check if websocket is still in monitors set
|
1589
|
-
if websocket not in self.monitors:
|
1590
|
-
logger.debug("Monitor disconnected before initial data send")
|
1591
|
-
return
|
1592
|
-
|
1593
|
-
# Send current stats (already in memory)
|
1594
|
-
stats_start = time.time()
|
1595
|
-
await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
|
1596
|
-
logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
|
1597
|
-
|
1598
|
-
# Get chunk stats asynchronously
|
1599
|
-
chunk_stats_start = time.time()
|
1600
|
-
loop = asyncio.get_event_loop()
|
1601
|
-
chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
|
1602
|
-
logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
|
1603
|
-
|
1604
|
-
if websocket not in self.monitors:
|
1605
|
-
return
|
1606
|
-
|
1607
|
-
chunk_send_start = time.time()
|
1608
|
-
await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
|
1609
|
-
logger.debug(f"Chunk stats sent in {(time.time() - chunk_send_start)*1000:.1f}ms")
|
1610
|
-
|
1611
|
-
# For leaderboard, check if we have a cached version first
|
1612
|
-
if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
|
1613
|
-
# Use cached leaderboard if available
|
1614
|
-
cache_send_start = time.time()
|
1615
|
-
await websocket.send(
|
1616
|
-
safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
|
1617
|
-
)
|
1618
|
-
logger.debug(
|
1619
|
-
f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
|
1620
|
-
)
|
1621
|
-
else:
|
1622
|
-
# Schedule leaderboard update separately
|
1623
|
-
leaderboard_task_start = time.time()
|
1624
|
-
asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
|
1625
|
-
logger.debug(
|
1626
|
-
f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
|
1627
|
-
)
|
1628
|
-
|
1629
|
-
logger.debug(
|
1630
|
-
f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
|
1631
|
-
)
|
1632
|
-
|
1633
|
-
except websockets.exceptions.ConnectionClosed:
|
1634
|
-
logger.debug("Monitor disconnected during initial data send")
|
1635
|
-
except Exception as e:
|
1636
|
-
logger.error(f"Error sending initial monitor data: {e}")
|
1637
|
-
|
1638
|
-
async def _handle_monitor(self, websocket: WebSocketServerProtocol):
|
1639
|
-
"""Handle monitor connection - truly non-blocking version."""
|
1640
|
-
monitor_start = time.time()
|
1641
|
-
self.monitors.add(websocket)
|
1642
|
-
logger.info(f"Monitor connected (total monitors: {len(self.monitors)})")
|
678
|
+
async def _send_monitor_stats(self, websocket: WebSocketServerProtocol):
|
679
|
+
"""Send current stats to a monitor."""
|
680
|
+
# Get processor stats
|
681
|
+
processor_stats = self.processor.get_stats()
|
1643
682
|
|
1644
|
-
|
1645
|
-
|
1646
|
-
welcome_start = time.time()
|
1647
|
-
await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
|
1648
|
-
logger.debug(f"Monitor welcome sent in {(time.time() - welcome_start)*1000:.1f}ms")
|
1649
|
-
|
1650
|
-
# Schedule initial data send as a separate task to avoid blocking
|
1651
|
-
task_create_start = time.time()
|
1652
|
-
asyncio.create_task(self._send_initial_monitor_data(websocket))
|
1653
|
-
logger.debug(
|
1654
|
-
f"Monitor initial data task created in {(time.time() - task_create_start)*1000:.1f}ms"
|
1655
|
-
)
|
683
|
+
# Get storage stats
|
684
|
+
storage_stats = await self.storage.get_storage_stats()
|
1656
685
|
|
1657
|
-
|
1658
|
-
|
1659
|
-
|
1660
|
-
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
|
686
|
+
# Combine all stats
|
687
|
+
all_stats = {
|
688
|
+
**self.stats,
|
689
|
+
**storage_stats,
|
690
|
+
"processor_stats": processor_stats,
|
691
|
+
"current_rate": self.rate_tracker["current_rate"],
|
692
|
+
"average_rate": self.rate_tracker["average_rate"],
|
693
|
+
}
|
1665
694
|
|
1666
|
-
|
1667
|
-
logger.info("Monitor disconnected")
|
1668
|
-
except Exception as e:
|
1669
|
-
logger.error(f"Error in monitor handler: {e}")
|
1670
|
-
finally:
|
1671
|
-
self.monitors.discard(websocket)
|
1672
|
-
logger.debug(f"Monitor handler completed in {(time.time() - monitor_start)*1000:.1f}ms")
|
695
|
+
await websocket.send(safe_json_dumps({"type": "stats", "data": all_stats}))
|
1673
696
|
|
1674
|
-
async def
|
1675
|
-
"""
|
697
|
+
async def _send_activity(self, activity: str):
|
698
|
+
"""Send activity update to monitors."""
|
1676
699
|
if not self.monitors:
|
1677
700
|
return
|
1678
|
-
if self.is_generating_stats:
|
1679
|
-
return # Already generating stats, skip this call
|
1680
|
-
self.is_generating_stats = True
|
1681
|
-
total_start = time.time()
|
1682
701
|
|
1683
|
-
|
1684
|
-
|
1685
|
-
loop = asyncio.get_event_loop()
|
1686
|
-
|
1687
|
-
# Get storage stats (already async)
|
1688
|
-
storage_stats_start = time.time()
|
1689
|
-
storage_stats = await self.storage.get_storage_stats()
|
1690
|
-
logger.debug(f"Storage stats retrieved in {(time.time() - storage_stats_start)*1000:.1f}ms")
|
1691
|
-
|
1692
|
-
caption_stats_start = time.time()
|
1693
|
-
caption_stats = await self.storage.get_caption_stats()
|
1694
|
-
logger.debug(f"Caption stats retrieved in {(time.time() - caption_stats_start)*1000:.1f}ms")
|
1695
|
-
|
1696
|
-
# Get chunk stats in thread pool
|
1697
|
-
chunk_stats_start = time.time()
|
1698
|
-
chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
|
1699
|
-
logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
|
1700
|
-
|
1701
|
-
# Build stats dict
|
1702
|
-
build_stats_start = time.time()
|
1703
|
-
stats_update = self.stats.copy()
|
1704
|
-
stats_update.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
|
1705
|
-
stats_update.update(storage_stats)
|
1706
|
-
stats_update["field_breakdown"] = caption_stats.get("field_stats", {})
|
1707
|
-
stats_update["output_fields_list"] = caption_stats.get("output_fields", [])
|
1708
|
-
|
1709
|
-
# Add rate information
|
1710
|
-
stats_update.update(
|
1711
|
-
{
|
1712
|
-
"current_rate": self.rate_tracker["current_rate"],
|
1713
|
-
"average_rate": self.rate_tracker["average_rate"],
|
1714
|
-
"expected_rate": self.rate_tracker["expected_rate"],
|
1715
|
-
}
|
702
|
+
message = safe_json_dumps(
|
703
|
+
{"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
|
1716
704
|
)
|
1717
705
|
|
1718
|
-
|
1719
|
-
|
1720
|
-
stats_update["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
|
1721
|
-
|
1722
|
-
# Add stage information
|
1723
|
-
stages = self.vllm_config.get("stages", [])
|
1724
|
-
if stages:
|
1725
|
-
stats_update["stage_count"] = len(stages)
|
1726
|
-
stats_update["stage_names"] = [s.get("name", "unnamed") for s in stages]
|
1727
|
-
else:
|
1728
|
-
stats_update["stage_count"] = 1
|
1729
|
-
stats_update["stage_names"] = ["default"]
|
1730
|
-
|
1731
|
-
# Get field stats
|
1732
|
-
field_stats_start = time.time()
|
1733
|
-
field_stats = await self.storage.get_output_field_stats()
|
1734
|
-
stats_update["output_fields"] = field_stats
|
1735
|
-
logger.debug(f"Field stats retrieved in {(time.time() - field_stats_start)*1000:.1f}ms")
|
1736
|
-
|
1737
|
-
# Update our internal stats
|
1738
|
-
self.stats = stats_update
|
1739
|
-
logger.debug(f"Stats prepared in {(time.time() - build_stats_start)*1000:.1f}ms")
|
1740
|
-
|
1741
|
-
logger.debug(f"Total data preparation took {(time.time() - data_prep_start)*1000:.1f}ms")
|
1742
|
-
|
1743
|
-
# Create message once
|
1744
|
-
message_create_start = time.time()
|
1745
|
-
stats_message = safe_json_dumps({"type": "stats", "data": self.stats})
|
1746
|
-
logger.debug(f"Stats message created in {(time.time() - message_create_start)*1000:.1f}ms")
|
1747
|
-
|
1748
|
-
# Send to all monitors asynchronously in parallel
|
1749
|
-
send_start = time.time()
|
1750
|
-
|
1751
|
-
async def send_to_monitor(monitor):
|
706
|
+
disconnected = set()
|
707
|
+
for monitor in self.monitors:
|
1752
708
|
try:
|
1753
|
-
await monitor.send(
|
709
|
+
await monitor.send(message)
|
1754
710
|
except websockets.exceptions.ConnectionClosed:
|
1755
|
-
|
1756
|
-
except Exception as e:
|
1757
|
-
logger.debug(f"Error sending stats to monitor: {e}")
|
1758
|
-
return monitor # Return for removal
|
1759
|
-
return None
|
1760
|
-
|
1761
|
-
# Send to all monitors in parallel
|
1762
|
-
monitors_copy = self.monitors.copy()
|
1763
|
-
results = await asyncio.gather(
|
1764
|
-
*[send_to_monitor(m) for m in monitors_copy], return_exceptions=True
|
1765
|
-
)
|
711
|
+
disconnected.add(monitor)
|
1766
712
|
|
1767
|
-
# Remove disconnected monitors
|
1768
|
-
disconnected = {
|
1769
|
-
m
|
1770
|
-
for m, r in zip(monitors_copy, results)
|
1771
|
-
if r is not None and not isinstance(r, Exception)
|
1772
|
-
}
|
1773
713
|
self.monitors -= disconnected
|
1774
714
|
|
1775
|
-
logger.debug(
|
1776
|
-
f"Stats sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
|
1777
|
-
)
|
1778
|
-
|
1779
|
-
# Send leaderboard update in a separate task to avoid blocking
|
1780
|
-
leaderboard_task_start = time.time()
|
1781
|
-
asyncio.create_task(self._broadcast_leaderboard())
|
1782
|
-
self.is_generating_stats = False
|
1783
|
-
logger.debug(
|
1784
|
-
f"Leaderboard broadcast task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
|
1785
|
-
)
|
1786
|
-
logger.debug(f"Stats broadcast completed in {(time.time() - total_start)*1000:.1f}ms")
|
1787
|
-
|
1788
715
|
async def _broadcast_leaderboard(self):
|
1789
716
|
"""Send leaderboard updates to monitors - separate from stats to avoid blocking."""
|
1790
717
|
if not self.monitors:
|
@@ -1875,96 +802,38 @@ class Orchestrator:
|
|
1875
802
|
except Exception as e:
|
1876
803
|
logger.error(f"Error broadcasting leaderboard: {e}")
|
1877
804
|
|
1878
|
-
def
|
1879
|
-
"""
|
1880
|
-
with self.chunk_manager.lock:
|
1881
|
-
return {
|
1882
|
-
"pending_chunks": len(self.chunk_manager.pending_chunks),
|
1883
|
-
"assigned_chunks": sum(
|
1884
|
-
len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
|
1885
|
-
),
|
1886
|
-
}
|
1887
|
-
|
1888
|
-
async def _flush_processed_items(self):
|
1889
|
-
"""Flush batched processed items to chunk tracker."""
|
1890
|
-
with self.item_batch_lock:
|
1891
|
-
if not self.pending_processed_items:
|
1892
|
-
return
|
1893
|
-
|
1894
|
-
for chunk_id, indices in self.pending_processed_items.items():
|
1895
|
-
if not indices:
|
1896
|
-
continue
|
1897
|
-
|
1898
|
-
# Indices here are ABSOLUTE dataset indices
|
1899
|
-
# Sort indices
|
1900
|
-
indices.sort()
|
1901
|
-
|
1902
|
-
# Group consecutive indices into ranges
|
1903
|
-
ranges = []
|
1904
|
-
start = indices[0]
|
1905
|
-
end = indices[0]
|
1906
|
-
|
1907
|
-
for i in range(1, len(indices)):
|
1908
|
-
if indices[i] == end + 1:
|
1909
|
-
# Consecutive, extend range
|
1910
|
-
end = indices[i]
|
1911
|
-
else:
|
1912
|
-
# Gap found, save current range and start new one
|
1913
|
-
ranges.append((start, end))
|
1914
|
-
start = indices[i]
|
1915
|
-
end = indices[i]
|
1916
|
-
|
1917
|
-
# Don't forget the last range
|
1918
|
-
ranges.append((start, end))
|
1919
|
-
|
1920
|
-
# Mark ranges as processed
|
1921
|
-
for start_idx, end_idx in ranges:
|
1922
|
-
self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
|
1923
|
-
|
1924
|
-
with self.chunk_manager.lock:
|
1925
|
-
if chunk_id in self.chunk_manager.assigned_ranges:
|
1926
|
-
for start_idx, end_idx in ranges:
|
1927
|
-
# Clear any assignments in this range
|
1928
|
-
to_remove = []
|
1929
|
-
for range_start, range_end in self.chunk_manager.assigned_ranges[
|
1930
|
-
chunk_id
|
1931
|
-
]:
|
1932
|
-
if range_start >= start_idx and range_end <= end_idx:
|
1933
|
-
to_remove.append((range_start, range_end))
|
1934
|
-
|
1935
|
-
for range_key in to_remove:
|
1936
|
-
del self.chunk_manager.assigned_ranges[chunk_id][range_key]
|
1937
|
-
|
1938
|
-
# Clear pending items
|
1939
|
-
self.pending_processed_items.clear()
|
1940
|
-
self.last_item_batch_flush = time.time()
|
1941
|
-
|
1942
|
-
def get_workers_by_user_stats(self) -> Dict[str, Any]:
|
1943
|
-
"""Get statistics about workers grouped by user/token - thread-safe version."""
|
1944
|
-
if not hasattr(self, "workers_by_user"):
|
1945
|
-
return {}
|
1946
|
-
|
1947
|
-
# Create a copy to avoid issues with concurrent modification
|
1948
|
-
stats = {}
|
1949
|
-
workers_snapshot = dict(self.workers_by_user)
|
1950
|
-
for user, worker_ids in workers_snapshot.items():
|
1951
|
-
stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
|
1952
|
-
return stats
|
1953
|
-
|
1954
|
-
async def _send_activity(self, activity: str):
|
1955
|
-
"""Send activity update to monitors."""
|
805
|
+
async def _broadcast_stats(self):
|
806
|
+
"""Broadcast statistics to all monitors."""
|
1956
807
|
if not self.monitors:
|
1957
808
|
return
|
1958
809
|
|
1959
|
-
|
1960
|
-
|
810
|
+
# Get current stats
|
811
|
+
processor_stats = self.processor.get_stats()
|
812
|
+
storage_stats = await self.storage.get_storage_stats()
|
813
|
+
|
814
|
+
# Update main stats
|
815
|
+
self.stats["processor_stats"] = processor_stats
|
816
|
+
self.stats["total_outputs"] = storage_stats["total_captions"]
|
817
|
+
|
818
|
+
# Create message
|
819
|
+
stats_message = safe_json_dumps(
|
820
|
+
{
|
821
|
+
"type": "stats",
|
822
|
+
"data": {
|
823
|
+
**self.stats,
|
824
|
+
**storage_stats,
|
825
|
+
"current_rate": self.rate_tracker["current_rate"],
|
826
|
+
"average_rate": self.rate_tracker["average_rate"],
|
827
|
+
},
|
828
|
+
}
|
1961
829
|
)
|
1962
830
|
|
831
|
+
# Send to all monitors
|
1963
832
|
disconnected = set()
|
1964
833
|
for monitor in self.monitors:
|
1965
834
|
try:
|
1966
|
-
await monitor.send(
|
1967
|
-
except
|
835
|
+
await monitor.send(stats_message)
|
836
|
+
except:
|
1968
837
|
disconnected.add(monitor)
|
1969
838
|
|
1970
839
|
self.monitors -= disconnected
|
@@ -1972,235 +841,74 @@ class Orchestrator:
|
|
1972
841
|
async def _heartbeat_loop(self):
|
1973
842
|
"""Send periodic heartbeats to maintain connections."""
|
1974
843
|
while True:
|
1975
|
-
|
1976
|
-
|
1977
|
-
|
1978
|
-
|
1979
|
-
|
1980
|
-
|
1981
|
-
|
1982
|
-
|
1983
|
-
|
1984
|
-
|
1985
|
-
|
1986
|
-
|
1987
|
-
|
1988
|
-
|
1989
|
-
|
1990
|
-
|
1991
|
-
|
1992
|
-
|
1993
|
-
|
1994
|
-
|
1995
|
-
except websockets.exceptions.ConnectionClosed:
|
1996
|
-
logger.info(f"Worker {worker_id} connection already closed")
|
1997
|
-
disconnected.append(worker_id)
|
1998
|
-
except Exception as e:
|
1999
|
-
logger.error(f"Error pinging worker {worker_id}: {e}")
|
2000
|
-
disconnected.append(worker_id)
|
2001
|
-
|
2002
|
-
# Clean up disconnected workers
|
2003
|
-
for worker_id in disconnected:
|
2004
|
-
if worker_id in self.workers:
|
2005
|
-
logger.info(f"Removing unresponsive worker {worker_id}")
|
2006
|
-
del self.workers[worker_id]
|
2007
|
-
self.chunk_manager.release_worker_chunks(worker_id)
|
2008
|
-
|
2009
|
-
# Update stats
|
2010
|
-
self.stats["connected_workers"] = len(self.workers)
|
2011
|
-
|
2012
|
-
# Also clean up from workers_by_user if it exists
|
2013
|
-
if hasattr(self, "workers_by_user"):
|
2014
|
-
worker_user = (
|
2015
|
-
worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
|
2016
|
-
)
|
2017
|
-
if worker_user in self.workers_by_user:
|
2018
|
-
self.workers_by_user[worker_user].discard(worker_id)
|
2019
|
-
if not self.workers_by_user[worker_user]:
|
2020
|
-
del self.workers_by_user[worker_user]
|
2021
|
-
|
2022
|
-
# Notify monitors
|
2023
|
-
await self._broadcast_stats()
|
2024
|
-
await self._send_activity(
|
2025
|
-
f"Worker {worker_id} removed due to heartbeat timeout"
|
2026
|
-
)
|
2027
|
-
|
2028
|
-
except Exception as e:
|
2029
|
-
logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
|
2030
|
-
# Continue the loop even if there's an error
|
2031
|
-
await asyncio.sleep(5)
|
844
|
+
await asyncio.sleep(30)
|
845
|
+
|
846
|
+
disconnected = []
|
847
|
+
for worker_id, ws in list(self.workers.items()):
|
848
|
+
try:
|
849
|
+
pong_waiter = await ws.ping()
|
850
|
+
await asyncio.wait_for(pong_waiter, timeout=10)
|
851
|
+
except:
|
852
|
+
disconnected.append(worker_id)
|
853
|
+
|
854
|
+
# Clean up disconnected workers
|
855
|
+
for worker_id in disconnected:
|
856
|
+
logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
|
857
|
+
if worker_id in self.workers:
|
858
|
+
del self.workers[worker_id]
|
859
|
+
logger.warning(
|
860
|
+
f"Releasing assignments for worker {worker_id} because it did not respond to ping"
|
861
|
+
)
|
862
|
+
self.processor.release_assignments(worker_id)
|
863
|
+
self.stats["connected_workers"] = len(self.workers)
|
2032
864
|
|
2033
865
|
async def _checkpoint_loop(self):
|
2034
866
|
"""Periodically checkpoint storage."""
|
2035
|
-
interval = self.config.get("storage", {}).get("checkpoint_interval",
|
867
|
+
interval = self.config.get("storage", {}).get("checkpoint_interval", 60)
|
2036
868
|
|
2037
869
|
while True:
|
2038
|
-
await asyncio.sleep(
|
2039
|
-
|
2040
|
-
# Get current caption count from storage
|
2041
|
-
storage_stats = await self.storage.get_storage_stats()
|
2042
|
-
total_captions = storage_stats["total_captions"]
|
2043
|
-
|
2044
|
-
# Force checkpoint at regular intervals
|
2045
|
-
if total_captions > 0 and total_captions % interval == 0:
|
2046
|
-
logger.info(f"Triggering checkpoint at {total_captions} captions")
|
2047
|
-
await self.storage.checkpoint()
|
870
|
+
await asyncio.sleep(interval)
|
2048
871
|
|
2049
|
-
|
2050
|
-
|
2051
|
-
|
2052
|
-
|
2053
|
-
await self._broadcast_stats()
|
2054
|
-
logger.info(
|
2055
|
-
f"Checkpoint complete. Total written to disk: {storage_stats['total_written']}"
|
2056
|
-
)
|
872
|
+
await self.storage.checkpoint()
|
873
|
+
self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
|
874
|
+
logger.info("Storage checkpoint complete")
|
2057
875
|
|
2058
876
|
async def _stats_update_loop(self):
|
2059
|
-
"""Periodically update and broadcast stats
|
2060
|
-
# Get the event loop for running blocking operations
|
2061
|
-
loop = asyncio.get_event_loop()
|
2062
|
-
|
2063
|
-
# Track session start values
|
2064
|
-
storage_stats = await self.storage.get_storage_stats()
|
2065
|
-
session_start_outputs = storage_stats["total_captions"] # This now counts ALL outputs
|
2066
|
-
session_start_time = time.time()
|
2067
|
-
|
2068
|
-
# Track the last known total to detect flushes
|
2069
|
-
last_known_total = session_start_outputs
|
2070
|
-
|
877
|
+
"""Periodically update and broadcast stats."""
|
2071
878
|
while True:
|
2072
879
|
await asyncio.sleep(10)
|
2073
880
|
|
2074
|
-
# Update
|
2075
|
-
chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
|
881
|
+
# Update rate tracking
|
2076
882
|
storage_stats = await self.storage.get_storage_stats()
|
2077
|
-
|
2078
|
-
if self.chunk_tracker:
|
2079
|
-
await self._flush_processed_items()
|
2080
|
-
|
2081
|
-
self.stats["total_chunks"] = chunk_stats["total"]
|
2082
|
-
self.stats["completed_chunks"] = chunk_stats["completed"]
|
2083
|
-
self.stats["failed_chunks"] = chunk_stats["failed"]
|
2084
|
-
|
2085
|
-
# Update total outputs stat (rename from total_captions for clarity)
|
2086
|
-
self.stats["total_outputs"] = current_total_outputs
|
2087
|
-
self.stats["total_captions"] = current_total_outputs # Keep for backward compatibility
|
2088
|
-
|
2089
|
-
# Get queue stats in thread pool to avoid blocking
|
2090
|
-
queue_stats = await loop.run_in_executor(None, self._get_queue_stats)
|
2091
|
-
self.stats.update(queue_stats)
|
2092
|
-
|
2093
|
-
# Calculate if we need more chunks
|
2094
|
-
worker_count = self.stats.get("connected_workers", 0)
|
2095
|
-
target_buffer = max(self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier)
|
2096
|
-
active_chunks = self.stats["pending_chunks"] + self.stats["assigned_chunks"]
|
2097
|
-
self.stats["chunk_buffer_status"] = f"{active_chunks}/{target_buffer}"
|
2098
|
-
|
2099
|
-
# Update rate information
|
883
|
+
current_total = storage_stats["total_captions"]
|
2100
884
|
current_time = time.time()
|
2101
|
-
elapsed_since_update = current_time - self.rate_tracker["last_update_time"]
|
2102
|
-
|
2103
|
-
if elapsed_since_update > 0:
|
2104
|
-
# FIX: Handle the case where duplicates were skipped during save
|
2105
|
-
# If current total is less than last known, it means duplicates were skipped
|
2106
|
-
# We should not count this as negative progress
|
2107
|
-
if current_total_outputs < last_known_total:
|
2108
|
-
logger.debug(
|
2109
|
-
f"Detected duplicate skip during save: {last_known_total} -> {current_total_outputs}"
|
2110
|
-
)
|
2111
|
-
# Don't calculate negative rate, just update the baseline
|
2112
|
-
self.rate_tracker["last_caption_count"] = current_total_outputs
|
2113
|
-
self.rate_tracker["current_rate"] = 0.0 # Set to 0 during flush
|
2114
|
-
else:
|
2115
|
-
# Normal rate calculation
|
2116
|
-
output_diff = current_total_outputs - self.rate_tracker["last_caption_count"]
|
2117
|
-
self.rate_tracker["current_rate"] = (output_diff / elapsed_since_update) * 60
|
2118
|
-
self.rate_tracker["last_caption_count"] = current_total_outputs
|
2119
|
-
|
2120
|
-
# Calculate average rate since THIS SESSION started
|
2121
|
-
session_elapsed = current_time - session_start_time
|
2122
|
-
if session_elapsed > 0:
|
2123
|
-
# Always use the difference from session start for average
|
2124
|
-
session_outputs = current_total_outputs - session_start_outputs
|
2125
|
-
self.rate_tracker["average_rate"] = (session_outputs / session_elapsed) * 60
|
2126
|
-
|
2127
|
-
# Calculate expected rate based on workers and stages
|
2128
|
-
batch_size = self.vllm_config.get("batch_size", 8)
|
2129
|
-
|
2130
|
-
# Count total prompts across all stages
|
2131
|
-
total_prompts = 0
|
2132
|
-
stages = self.vllm_config.get("stages", [])
|
2133
|
-
if stages:
|
2134
|
-
for stage in stages:
|
2135
|
-
total_prompts += len(stage.get("prompts", []))
|
2136
|
-
else:
|
2137
|
-
# Backward compatibility
|
2138
|
-
total_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
|
2139
885
|
|
2140
|
-
|
2141
|
-
|
2142
|
-
|
2143
|
-
)
|
2144
|
-
|
2145
|
-
# Update trackers
|
886
|
+
elapsed = current_time - self.rate_tracker["last_update_time"]
|
887
|
+
if elapsed > 0:
|
888
|
+
output_diff = current_total - self.rate_tracker["last_output_count"]
|
889
|
+
self.rate_tracker["current_rate"] = (output_diff / elapsed) * 60
|
890
|
+
self.rate_tracker["last_output_count"] = current_total
|
2146
891
|
self.rate_tracker["last_update_time"] = current_time
|
2147
|
-
last_known_total = current_total_outputs
|
2148
|
-
|
2149
|
-
# Log rate information when workers are connected
|
2150
|
-
# if (
|
2151
|
-
# worker_count > 0 and self.rate_tracker["current_rate"] >= 0
|
2152
|
-
# ): # Only log non-negative rates
|
2153
|
-
# logger.info(
|
2154
|
-
# f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
|
2155
|
-
# f"(avg: {self.rate_tracker['average_rate']:.1f}, "
|
2156
|
-
# f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
|
2157
|
-
# f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
|
2158
|
-
# )
|
2159
892
|
|
2160
|
-
|
893
|
+
# Average rate since start
|
894
|
+
total_elapsed = current_time - self.rate_tracker["start_time"]
|
895
|
+
if total_elapsed > 0:
|
896
|
+
self.rate_tracker["average_rate"] = (current_total / total_elapsed) * 60
|
2161
897
|
|
2162
|
-
|
2163
|
-
"""Restore state from storage on startup."""
|
2164
|
-
total_captions = await self.storage.count_captions()
|
2165
|
-
logger.info(f"Restored state: {total_captions} captions")
|
898
|
+
await self._broadcast_stats()
|
2166
899
|
|
2167
900
|
async def shutdown(self):
|
2168
901
|
"""Graceful shutdown."""
|
2169
902
|
logger.info("Shutting down orchestrator...")
|
2170
903
|
|
2171
|
-
# Stop chunk creation
|
2172
|
-
if self.chunk_tracker:
|
2173
|
-
await self._flush_processed_items()
|
2174
|
-
self.stop_chunk_creation.set()
|
2175
|
-
if self.chunk_creation_thread:
|
2176
|
-
self.chunk_creation_thread.join(timeout=5)
|
2177
|
-
|
2178
|
-
# Release all assigned chunks before closing connections
|
2179
|
-
for worker_id in list(self.workers.keys()):
|
2180
|
-
self.chunk_manager.release_worker_chunks(worker_id)
|
2181
|
-
if self.chunk_tracker:
|
2182
|
-
# Update chunk tracker to mark assigned chunks as pending
|
2183
|
-
with self.chunk_manager.lock:
|
2184
|
-
for chunk_id in list(self.chunk_manager.assigned_chunks.get(worker_id, [])):
|
2185
|
-
self.chunk_tracker.mark_pending(chunk_id)
|
2186
|
-
|
2187
904
|
# Close all connections
|
2188
905
|
for ws in list(self.workers.values()):
|
2189
906
|
await ws.close()
|
2190
907
|
for ws in list(self.monitors):
|
2191
908
|
await ws.close()
|
2192
909
|
|
2193
|
-
# Save chunk state
|
2194
|
-
if self.chunk_tracker:
|
2195
|
-
self.chunk_tracker.save()
|
2196
|
-
|
2197
910
|
# Final checkpoint
|
2198
|
-
logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")
|
2199
911
|
await self.storage.checkpoint()
|
2200
|
-
|
2201
|
-
# Log final statistics
|
2202
|
-
logger.info(
|
2203
|
-
f"Shutdown complete. Total captions collected: {self.storage.total_captions_written}"
|
2204
|
-
)
|
2205
|
-
|
2206
912
|
await self.storage.close()
|
913
|
+
|
914
|
+
logger.info("Shutdown complete")
|