caption-flow 0.1.0__py3-none-any.whl
This diff shows the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in that registry.
- caption_flow/__init__.py +9 -0
- caption_flow/cli.py +709 -0
- caption_flow/models.py +82 -0
- caption_flow/monitor.py +211 -0
- caption_flow/orchestrator.py +1301 -0
- caption_flow/storage.py +694 -0
- caption_flow/utils/__init__.py +4 -0
- caption_flow/utils/auth.py +67 -0
- caption_flow/utils/caption_utils.py +172 -0
- caption_flow/utils/certificates.py +140 -0
- caption_flow/utils/chunk_tracker.py +365 -0
- caption_flow/utils/dataset_loader.py +186 -0
- caption_flow/utils/image_processor.py +51 -0
- caption_flow/utils/job_queue.py +41 -0
- caption_flow/utils/json_utils.py +201 -0
- caption_flow/utils/vllm_config.py +164 -0
- caption_flow/worker.py +300 -0
- caption_flow/worker_data.py +482 -0
- caption_flow/worker_vllm.py +1028 -0
- caption_flow-0.1.0.dist-info/METADATA +427 -0
- caption_flow-0.1.0.dist-info/RECORD +25 -0
- caption_flow-0.1.0.dist-info/WHEEL +5 -0
- caption_flow-0.1.0.dist-info/entry_points.txt +2 -0
- caption_flow-0.1.0.dist-info/licenses/LICENSE +661 -0
- caption_flow-0.1.0.dist-info/top_level.txt +1 -0
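Editor's note: Orchestrator (diffed below) reads its settings from a plain dict passed to its constructor. The following sketch is assembled from the config.get() calls and their defaults in caption_flow/orchestrator.py; it is an illustration, not a file shipped in the wheel, and the dataset path, certificate paths, and auth block are placeholders.

    config = {
        "host": "0.0.0.0",
        "port": 8765,
        "dataset": {"path": "<hf-dataset-or-local-path>", "type": "huggingface"},
        "chunk_size": 1000,
        "chunks_per_request": 2,
        "chunk_buffer_multiplier": 3,
        "min_chunk_buffer": 10,
        "backpressure_threshold": 800,
        "storage": {
            "data_dir": "./caption_data",
            "checkpoint_dir": "./checkpoints",
            "caption_buffer_size": 1000,
            "checkpoint_interval": 1000,
        },
        "ssl": {"cert": "<path-to-cert.pem>", "key": "<path-to-key.pem>"},  # omit to run without TLS
        "auth": {},  # structure consumed by AuthManager (caption_flow/utils/auth.py), not shown here
        # "vllm": {...},  # optional; defaults are defined inline in Orchestrator.__init__
    }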
caption_flow/orchestrator.py
@@ -0,0 +1,1301 @@
"""Enhanced orchestrator with shard chunk assignment for vLLM workers.

This orchestrator:
1. Divides dataset shards into chunks for parallel processing
2. Assigns chunks to workers on request
3. Collects captions from workers centrally
4. Manages checkpoints and fault tolerance
"""

import time
import asyncio
import json
import logging
import ssl
import uuid
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, Set, Optional, Any, List, Deque
from collections import deque, defaultdict
import threading
from queue import Queue, Empty

import websockets
from websockets.server import WebSocketServerProtocol

from .storage import StorageManager
from .models import Caption, Contributor
from .utils.auth import AuthManager
from .utils.dataset_loader import DatasetLoader, ShardTracker
from .utils.json_utils import safe_dict, safe_json_dumps, to_json_dict
from .utils.chunk_tracker import ChunkTracker

logger = logging.getLogger(__name__)


@dataclass
class ShardChunk:
    """Represents a chunk of a shard for processing."""

    chunk_id: str
    shard_url: str
    shard_name: str
    start_index: int
    chunk_size: int
    assigned_to: Optional[str] = None
    status: str = "pending"  # pending, assigned, completed, failed
    assigned_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None


class ChunkManager:
    """Manages shard chunk creation and assignment."""

    def __init__(self, chunk_size: int = 1000, tracker: Optional[ChunkTracker] = None):
        self.chunk_size = chunk_size
        self.chunks: Dict[str, ShardChunk] = {}
        self.pending_chunks: Deque[str] = deque()
        self.assigned_chunks: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> chunk_ids
        self.lock = threading.Lock()
        self.tracker = tracker  # Reference to chunk tracker

    def create_chunks_from_shard(
        self, shard_url: str, shard_name: str, total_items: int
    ) -> List[ShardChunk]:
        """Create chunks from a shard."""
        chunks = []

        for start_idx in range(0, total_items, self.chunk_size):
            chunk_id = f"{shard_name}_chunk_{start_idx}"
            chunk = ShardChunk(
                chunk_id=chunk_id,
                shard_url=shard_url,
                shard_name=shard_name,
                start_index=start_idx,
                chunk_size=min(self.chunk_size, total_items - start_idx),
            )

            with self.lock:
                self.chunks[chunk_id] = chunk
                self.pending_chunks.append(chunk_id)

            chunks.append(chunk)

        return chunks

    def get_chunks_for_worker(
        self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
    ) -> List[ShardChunk]:
        """Get available chunks for a worker."""
        assigned = []

        with self.lock:
            while len(assigned) < count and self.pending_chunks:
                chunk_id = self.pending_chunks.popleft()
                chunk = self.chunks[chunk_id]

                chunk.assigned_to = worker_id
                chunk.status = "assigned"
                chunk.assigned_at = datetime.utcnow()

                self.assigned_chunks[worker_id].add(chunk_id)
                assigned.append(chunk)
                if tracker:
                    tracker.mark_assigned(chunk_id, worker_id)

        return assigned

    def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
        """Mark a chunk as completed."""
        with self.lock:
            if chunk_id in self.chunks:
                chunk = self.chunks[chunk_id]
                if chunk.assigned_to == worker_id and chunk.status == "assigned":
                    chunk.status = "completed"
                    chunk.completed_at = datetime.utcnow()
                    self.assigned_chunks[worker_id].discard(chunk_id)
                    return True
        return False

    def fail_chunk(self, chunk_id: str, worker_id: str) -> bool:
        """Mark a chunk as failed and requeue it."""
        with self.lock:
            if chunk_id in self.chunks:
                chunk = self.chunks[chunk_id]
                if chunk.assigned_to == worker_id:
                    chunk.status = "pending"
                    chunk.assigned_to = None
                    chunk.assigned_at = None
                    self.assigned_chunks[worker_id].discard(chunk_id)
                    self.pending_chunks.append(chunk_id)
                    return True
        return False

    def release_worker_chunks(self, worker_id: str):
        """Release all chunks assigned to a worker."""
        with self.lock:
            chunk_ids = list(self.assigned_chunks.get(worker_id, []))
            for chunk_id in chunk_ids:
                if chunk_id in self.chunks:
                    chunk = self.chunks[chunk_id]
                    if chunk.status == "assigned":
                        chunk.status = "pending"
                        chunk.assigned_to = None
                        chunk.assigned_at = None
                        self.pending_chunks.append(chunk_id)

            if worker_id in self.assigned_chunks:
                del self.assigned_chunks[worker_id]

    def get_stats(self) -> Dict[str, int]:
        """Get chunk statistics."""
        with self.lock:
            stats = {
                "total": len(self.chunks),
                "pending": len(self.pending_chunks),
                "assigned": sum(len(chunks) for chunks in self.assigned_chunks.values()),
                "completed": sum(1 for c in self.chunks.values() if c.status == "completed"),
                "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
            }
        return stats

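# --- Editor's annotation (not part of the released module) ----------------------
# ChunkManager implements a small work queue around a deque and a lock:
#   create_chunks_from_shard()    -> new chunks enter pending_chunks
#   get_chunks_for_worker()       -> pending -> assigned (tracked per worker)
#   complete_chunk()/fail_chunk() -> assigned -> completed, or requeued as pending
#   release_worker_chunks()       -> requeues everything a disconnected worker held
# A rough usage sketch, with illustrative shard names and URLs:
#   manager = ChunkManager(chunk_size=1000)
#   manager.create_chunks_from_shard("https://example.org/shard-0000.tar", "shard-0000", 5000)
#   chunks = manager.get_chunks_for_worker("worker-1", count=2)
#   manager.complete_chunk(chunks[0].chunk_id, "worker-1")
# ---------------------------------------------------------------------------------
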
class Orchestrator:
    """Enhanced orchestrator for vLLM-based distributed captioning with chunk assignment."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.host = config.get("host", "0.0.0.0")
        self.port = config.get("port", 8765)

        # Dataset configuration
        self.dataset_config = config.get("dataset", {})
        self.dataset_path = self.dataset_config.get("path")
        self.dataset_type = self.dataset_config.get("type", "huggingface")

        # vLLM configuration to distribute to workers
        self.vllm_config = config.get(
            "vllm",
            {
                "model": "Qwen/Qwen2.5-VL-3B-Instruct",
                "gpu_memory_utilization": 0.92,
                "max_model_len": 16384,
                "tensor_parallel_size": 1,
                "dtype": "float16",
                "enforce_eager": True,
                "limit_mm_per_prompt": {"image": 1},
                "disable_mm_preprocessor_cache": True,
                "sampling": {
                    "temperature": 0.7,
                    "top_p": 0.95,
                    "max_tokens": 256,
                    "repetition_penalty": 1.05,
                    "stop": ["<|end|>", "<|endoftext|>", "<|im_end|>"],
                },
                "inference_prompts": [
                    "describe this image in detail",
                    "provide a comprehensive description of the visual content",
                    "what are the key elements in this image?",
                ],
            },
        )

        # Chunk configuration
        self.chunk_size = config.get("chunk_size", 1000)
        self.chunks_per_request = config.get("chunks_per_request", 2)

        # Demand-driven chunk creation settings
        self.chunk_buffer_multiplier = config.get("chunk_buffer_multiplier", 3)
        self.min_chunk_buffer = config.get("min_chunk_buffer", 10)

        # Initialize components
        storage_config = config.get("storage", {})
        self.storage = StorageManager(
            Path(storage_config.get("data_dir", "./caption_data")),
            caption_buffer_size=storage_config.get("caption_buffer_size", 1000),
            job_buffer_size=storage_config.get("job_buffer_size", 100),
            contributor_buffer_size=storage_config.get("contributor_buffer_size", 10),
        )
        self.auth = AuthManager(config.get("auth", {}))

        # Dataset components
        self.dataset_loader = None
        self.shard_tracker = None
        self.chunk_tracker = None

        if self.dataset_path:
            self.dataset_loader = DatasetLoader(self.dataset_path, self.dataset_type)
            checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")

        # Initialize chunk manager with reference to chunk tracker
        self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)

        # Track connections
        self.workers: Dict[str, WebSocketServerProtocol] = {}
        self.monitors: Set[WebSocketServerProtocol] = set()

        # SSL configuration
        self.ssl_context = self._setup_ssl()

        # Statistics
        self.stats = {
            "total_chunks": 0,
            "completed_chunks": 0,
            "failed_chunks": 0,
            "total_captions": 0,
            "connected_workers": 0,
            "total_shards": 0,
            "completed_shards": 0,
            "current_shard": None,
            "buffer_size": 0,
            "total_written": 0,
            "last_checkpoint": None,
        }

        # Rate tracking
        self.rate_tracker = {
            "start_time": time.time(),
            "last_update_time": time.time(),
            "last_caption_count": 0,
            "current_rate": 0.0,
            "average_rate": 0.0,
            "expected_rate": 0.0,
        }

        # Data sample queue for VLLMWorkers
        self.data_sample_queue = asyncio.Queue(maxsize=1000)
        self.data_workers: Dict[str, WebSocketServerProtocol] = {}

        # Backpressure threshold
        self.backpressure_threshold = config.get("backpressure_threshold", 800)

        # Shard processing state
        self.all_shards = []
        self.current_shard_index = 0
        self.shard_lock = threading.Lock()

        # Background chunk creation
        self.chunk_creation_thread = None
        self.stop_chunk_creation = threading.Event()

        # State restoration flag
        self.state_restored = threading.Event()
        # If no dataset, state is already "restored"
        if not self.dataset_loader:
            self.state_restored.set()

    def _setup_ssl(self) -> Optional[ssl.SSLContext]:
        """Configure SSL if certificates are provided."""
        ssl_config = self.config.get("ssl", {})
        if not ssl_config.get("cert") or not ssl_config.get("key"):
            return None

        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        context.load_cert_chain(ssl_config["cert"], ssl_config["key"])
        return context

    def _create_chunks_from_dataset(self):
        """Background thread to create chunks from dataset shards on demand."""
        if not self.dataset_loader:
            logger.warning("No dataset configured, skipping chunk creation")
            self.state_restored.set()  # No state to restore
            return

        logger.info("Starting chunk creation thread")

        # Mark state as not restored until we process checkpoints
        self.state_restored.clear()

        # Get all shards
        self.all_shards = self.dataset_loader.get_shard_list()
        self.stats["total_shards"] = len(self.all_shards)

        # Get shard status from ChunkTracker
        shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
        completed_shards = {
            shard_name for shard_name, info in shards_summary.items() if info["is_complete"]
        }

        # Update ShardTracker for completed shards
        for shard_name in completed_shards:
            if not self.shard_tracker.is_complete(shard_name):
                logger.info(f"Marking shard {shard_name} as complete in ShardTracker")
                self.shard_tracker.mark_complete(shard_name)

        # Get shards that need processing
        remaining_shards = self.shard_tracker.get_remaining_shards(self.all_shards)

        # Also check which shards already have chunks (partial or complete)
        shards_with_chunks = set()
        for shard_name in shards_summary.keys():
            shards_with_chunks.add(shard_name)

        # Filter out shards that already have chunks created
        remaining_shards = [
            shard for shard in remaining_shards if Path(shard).stem not in shards_with_chunks
        ]

        self.stats["completed_shards"] = len(completed_shards)

        logger.info(
            f"Total shards: {len(self.all_shards)}, "
            f"Completed: {self.stats['completed_shards']}, "
            f"Shards with chunks: {len(shards_with_chunks)}, "
            f"Remaining to process: {len(remaining_shards)}"
        )

        # First, re-queue any existing pending chunks
        initial_pending = 0
        requeued_chunks_by_shard = defaultdict(list)

        for shard_name, shard_info in shards_summary.items():
            with self.chunk_manager.lock:
                for chunk_state in shard_info["chunks"]:
                    if chunk_state.status in ["pending", "failed", "assigned"]:
                        # Find shard URL
                        shard_url = None
                        for url in self.all_shards:
                            if Path(url).stem == shard_name:
                                shard_url = url
                                break

                        if shard_url:
                            chunk = ShardChunk(
                                chunk_id=chunk_state.chunk_id,
                                shard_url=shard_url,
                                shard_name=chunk_state.shard_name,
                                start_index=chunk_state.start_index,
                                chunk_size=chunk_state.chunk_size,
                            )
                            self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
                            self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
                            requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
                            initial_pending += 1

        logger.info(f"Re-queued {initial_pending} existing pending chunks")
        for shard_name, chunk_ids in requeued_chunks_by_shard.items():
            logger.info(f"  Shard {shard_name}: {len(chunk_ids)} chunks - {chunk_ids}")

        # Mark state as restored
        self.state_restored.set()
        logger.info("State restoration complete, accepting chunk requests")

        # Process shards on-demand
        shard_iter = iter(remaining_shards)
        current_shard_url = None
        current_shard_name = None
        current_shard_items = None
        current_shard_index = 0

        while not self.stop_chunk_creation.is_set():
            # Check how many chunks we need
            with self.chunk_manager.lock:
                pending_count = len(self.chunk_manager.pending_chunks)
                assigned_count = sum(
                    len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
                )
                total_active = pending_count + assigned_count

            # Target buffer: configurable multiplier × number of workers
            worker_count = max(1, self.stats.get("connected_workers", 0))
            target_buffer = max(
                self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier
            )

            chunks_needed = max(0, target_buffer - total_active)

            if chunks_needed == 0:
                # We have enough chunks, wait a bit
                time.sleep(5)
                continue

            logger.debug(
                f"Need {chunks_needed} more chunks (pending: {pending_count}, "
                f"assigned: {assigned_count}, workers: {worker_count})"
            )

            # Create chunks as needed
            chunks_created = 0

            while chunks_created < chunks_needed and not self.stop_chunk_creation.is_set():
                # Need to load next shard?
                if current_shard_url is None or current_shard_index >= current_shard_items:
                    try:
                        current_shard_url = next(shard_iter)
                        current_shard_name = Path(current_shard_url).stem
                        self.stats["current_shard"] = current_shard_name

                        # Skip if we already have chunks from this shard
                        if current_shard_name in shards_summary:
                            logger.debug(
                                f"Skipping shard {current_shard_name} - already has chunks"
                            )
                            current_shard_url = None
                            continue

                        # Count items in new shard
                        logger.info(f"Loading new shard {current_shard_name}")
                        current_shard_items = sum(
                            1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
                        )
                        current_shard_index = 0
                        logger.info(f"Shard {current_shard_name} has {current_shard_items} items")

                    except StopIteration:
                        # No more shards
                        logger.info("No more shards to process")
                        break
                    except Exception as e:
                        logger.error(f"Error loading shard {current_shard_name}: {e}")
                        current_shard_url = None
                        continue

                # Create a chunk from current shard
                if current_shard_url and current_shard_index < current_shard_items:
                    chunk_id = f"{current_shard_name}_chunk_{current_shard_index}"
                    chunk_size = min(self.chunk_size, current_shard_items - current_shard_index)

                    # Add to ChunkTracker
                    if self.chunk_tracker and self.chunk_tracker.add_chunk(
                        chunk_id, current_shard_name, current_shard_index, chunk_size
                    ):
                        # Create chunk
                        chunk = ShardChunk(
                            chunk_id=chunk_id,
                            shard_url=current_shard_url,
                            shard_name=current_shard_name,
                            start_index=current_shard_index,
                            chunk_size=chunk_size,
                        )

                        with self.chunk_manager.lock:
                            self.chunk_manager.chunks[chunk_id] = chunk
                            self.chunk_manager.pending_chunks.append(chunk_id)

                        chunks_created += 1
                        self.stats["total_chunks"] += 1

                    current_shard_index += self.chunk_size

            if chunks_created > 0:
                logger.info(f"Created {chunks_created} chunks on demand")

            # If we couldn't create any chunks and there are no more shards, we're done
            if chunks_created == 0 and current_shard_url is None:
                logger.info("All shards processed, chunk creation complete")
                break

            # Brief pause to avoid spinning
            time.sleep(1)

        # Final stats
        if self.chunk_tracker:
            final_stats = self.chunk_tracker.get_stats()
            logger.info(
                f"Chunk creation thread ending. Total: {final_stats['total']}, "
                f"Pending: {final_stats['pending']}, Completed: {final_stats['completed']}"
            )

        logger.info("Chunk creation thread finished")

    async def start(self):
        """Start the orchestrator server."""
        logger.info(f"Starting vLLM orchestrator on {self.host}:{self.port}")
        logger.info(
            f"vLLM config: model={self.vllm_config.get('model')}, batch_size={self.vllm_config.get('batch_size')}"
        )

        # Load existing state BEFORE accepting connections
        await self.storage.initialize()
        if self.chunk_tracker:
            await self.chunk_tracker.sync_with_storage(self.storage)
        await self._restore_state()

        # Start chunk creation thread if dataset is configured
        if self.dataset_loader:
            self.chunk_creation_thread = threading.Thread(
                target=self._create_chunks_from_dataset, daemon=True
            )
            self.chunk_creation_thread.start()

            # Give chunk creation thread time to restore existing chunks
            await asyncio.sleep(2)

        # Start background tasks
        asyncio.create_task(self._heartbeat_loop())
        asyncio.create_task(self._checkpoint_loop())
        asyncio.create_task(self._stats_update_loop())

        # Start WebSocket server
        async with websockets.serve(
            self.handle_connection, self.host, self.port, ssl=self.ssl_context
        ):
            logger.info("vLLM Orchestrator ready for connections")
            await asyncio.Future()  # Run forever

    async def handle_connection(self, websocket: WebSocketServerProtocol):
        """Handle new WebSocket connection."""
        try:
            # Authenticate
            auth_msg = await websocket.recv()
            auth_data = json.loads(auth_msg)

            auth_ticket = self.auth.authenticate(auth_data.get("token"))
            if not auth_ticket.role:
                await websocket.send(safe_json_dumps({"error": "Invalid token"}))
                return

            if auth_ticket.role == "worker":
                await self._handle_worker(websocket, auth_ticket)
            elif auth_ticket.role == "data_worker":
                await self._handle_data_worker(websocket, auth_ticket)
            elif auth_ticket.role == "monitor":
                await self._handle_monitor(websocket)
            elif auth_ticket.role == "admin":
                await self._handle_admin(websocket, auth_ticket)
            else:
                await websocket.send(safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"}))

        except Exception as e:
            logger.error(f"Connection error: {e}")
            import traceback

            logger.error(traceback.format_exc())
            await websocket.close()

    async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
        """Handle admin connection for configuration updates."""
        admin_id = getattr(auth_ticket, "name", "admin")
        logger.info(f"Admin {admin_id} connected")

        try:
            # Send welcome
            await websocket.send(safe_json_dumps({"type": "welcome", "role": "admin"}))

            async for message in websocket:
                try:
                    data = json.loads(message)
                    msg_type = data.get("type")

                    if msg_type == "reload_config":
                        await self._handle_config_reload(websocket, data.get("config", {}))

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid admin message: {e}")
                    await websocket.send(
                        safe_json_dumps({"type": "error", "error": "Invalid message format"})
                    )

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Admin {admin_id} disconnected")

    async def _handle_config_reload(
        self, websocket: WebSocketServerProtocol, new_config: Dict[str, Any]
    ):
        """Handle configuration reload request."""
        logger.info("Processing configuration reload request")

        updated_sections = []
        warnings = []
        requires_worker_restart = False

        try:
            # Update vLLM configuration
            if "vllm" in new_config:
                old_vllm = self.vllm_config.copy()

                # Check each field for actual changes
                vllm_changed = False
                for key, value in new_config["vllm"].items():
                    if self.vllm_config.get(key) != value:
                        self.vllm_config[key] = value
                        vllm_changed = True

                if vllm_changed:
                    updated_sections.append("vllm")

                    # Check if critical changes require worker restart
                    if (
                        old_vllm.get("model") != self.vllm_config.get("model")
                        or old_vllm.get("gpu_memory_utilization")
                        != self.vllm_config.get("gpu_memory_utilization")
                        or old_vllm.get("tensor_parallel_size")
                        != self.vllm_config.get("tensor_parallel_size")
                    ):
                        requires_worker_restart = True
                        warnings.append(
                            "Critical vLLM changes detected - workers will be disconnected to reload"
                        )

            # Update dataset configuration
            if "dataset" in new_config:
                dataset_changed = False
                for key, value in new_config["dataset"].items():
                    if self.dataset_config.get(key) != value:
                        self.dataset_config[key] = value
                        dataset_changed = True

                if dataset_changed:
                    self.dataset_path = self.dataset_config.get("path")
                    self.dataset_type = self.dataset_config.get("type", "huggingface")
                    updated_sections.append("dataset")
                    warnings.append("Dataset changes will apply to new chunks only")

            # Update chunk settings
            if "chunk_size" in new_config and self.chunk_size != new_config["chunk_size"]:
                self.chunk_size = new_config["chunk_size"]
                self.chunk_manager.chunk_size = self.chunk_size
                updated_sections.append("chunk_size")

            if (
                "chunks_per_request" in new_config
                and self.chunks_per_request != new_config["chunks_per_request"]
            ):
                self.chunks_per_request = new_config["chunks_per_request"]
                updated_sections.append("chunks_per_request")

            # Recreate auth manager
            self.auth = AuthManager(config=new_config)

            # Update buffer settings
            if (
                "chunk_buffer_multiplier" in new_config
                and self.chunk_buffer_multiplier != new_config["chunk_buffer_multiplier"]
            ):
                self.chunk_buffer_multiplier = new_config["chunk_buffer_multiplier"]
                updated_sections.append("chunk_buffer_multiplier")

            if (
                "min_chunk_buffer" in new_config
                and self.min_chunk_buffer != new_config["min_chunk_buffer"]
            ):
                self.min_chunk_buffer = new_config["min_chunk_buffer"]
                updated_sections.append("min_chunk_buffer")

            # Update storage settings
            if "storage" in new_config:
                storage_config = new_config["storage"]
                storage_changed = False

                if (
                    "caption_buffer_size" in storage_config
                    and self.storage.caption_buffer_size != storage_config["caption_buffer_size"]
                ):
                    self.storage.caption_buffer_size = storage_config["caption_buffer_size"]
                    storage_changed = True

                if "checkpoint_interval" in storage_config:
                    current_interval = self.config.get("storage", {}).get(
                        "checkpoint_interval", 1000
                    )
                    if current_interval != storage_config["checkpoint_interval"]:
                        self.config.setdefault("storage", {})["checkpoint_interval"] = (
                            storage_config["checkpoint_interval"]
                        )
                        storage_changed = True

                if storage_changed:
                    updated_sections.append("storage")

            # Update data worker storage config
            if "data_worker_storage" in new_config:
                current_dw_storage = self.config.get("data_worker_storage", {})
                if current_dw_storage != new_config["data_worker_storage"]:
                    self.config["data_worker_storage"] = new_config["data_worker_storage"]
                    updated_sections.append("data_worker_storage")
                    warnings.append("Data worker storage config will apply to new connections only")

            # Update backpressure threshold
            if "backpressure_threshold" in new_config:
                current_threshold = getattr(self, "backpressure_threshold", 800)
                if current_threshold != new_config["backpressure_threshold"]:
                    self.backpressure_threshold = new_config["backpressure_threshold"]
                    updated_sections.append("backpressure_threshold")

            # Check if any changes were made
            if not updated_sections:
                await websocket.send(
                    safe_json_dumps(
                        {
                            "type": "reload_complete",
                            "message": "No changes applied - configuration is identical",
                        }
                    )
                )
                logger.info("Configuration reload requested but no changes detected")
                return

            # Update the main config for any other fields
            self.config.update(new_config)

            # Handle worker restart if needed
            if requires_worker_restart:
                logger.info("Disconnecting all workers for configuration reload...")

                # Disconnect all workers
                worker_ids = list(self.workers.keys())
                for worker_id in worker_ids:
                    try:
                        await self.workers[worker_id].close(
                            code=1012, reason="Configuration reload"
                        )
                    except:
                        pass

                warnings.append(
                    f"Disconnected {len(worker_ids)} workers - they will reconnect with new config"
                )
            else:
                # Just notify workers about config changes
                reload_msg = safe_json_dumps(
                    {
                        "type": "config_update",
                        "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
                        "dataset_config": (
                            self.dataset_config if "dataset" in updated_sections else None
                        ),
                    }
                )

                disconnected = []
                for worker_id, ws in self.workers.items():
                    try:
                        await ws.send(reload_msg)
                    except:
                        disconnected.append(worker_id)

                for worker_id in disconnected:
                    del self.workers[worker_id]

            # Send success response
            await websocket.send(
                safe_json_dumps(
                    {"type": "reload_complete", "updated": updated_sections, "warnings": warnings}
                )
            )

            logger.info(f"Configuration reloaded. Updated sections: {', '.join(updated_sections)}")

            # Broadcast stats update to monitors
            await self._broadcast_stats()
            await self._send_activity(
                f"Configuration reloaded by admin: {', '.join(updated_sections)}"
            )

        except Exception as e:
            logger.error(f"Configuration reload failed: {e}")
            await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))

    async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
        """Handle worker connection lifecycle."""
        worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
        self.workers[worker_id] = websocket
        self.stats["connected_workers"] = len(self.workers)

        # Register contributor
        contributor = Contributor(
            contributor_id=worker_id, name=worker_id, total_captions=0, trust_level=1
        )
        await self.storage.save_contributor(contributor)

        logger.info(f"Worker {worker_id} connected")
        await self._broadcast_stats()
        await self._send_activity(f"Worker {worker_id} connected")

        try:
            # Send welcome message with dataset configuration
            welcome_message = {
                "type": "welcome",
                "worker_id": worker_id,
                "dataset_config": {
                    "dataset_path": self.dataset_path,
                    "dataset_type": self.dataset_type,
                    "path": self.dataset_path,  # For compatibility
                    "type": self.dataset_type,  # For compatibility
                },
                "vllm_config": self.vllm_config,
            }
            await websocket.send(safe_json_dumps(welcome_message))

            async for message in websocket:
                data = json.loads(message)
                await self._process_worker_message(worker_id, data)

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Worker {worker_id} disconnected")
        finally:
            del self.workers[worker_id]
            self.stats["connected_workers"] = len(self.workers)
            # Release chunks in both managers
            self.chunk_manager.release_worker_chunks(worker_id)
            if self.chunk_tracker:
                # Mark released chunks as pending in tracker
                released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
                logger.info(
                    f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
                )

            await self._broadcast_stats()
            await self._send_activity(f"Worker {worker_id} disconnected")

    async def _process_worker_message(self, worker_id: str, data: Dict):
        """Process message from worker."""
        msg_type = data.get("type")

        if msg_type == "request_chunks":
            # Wait for state restoration to complete
            if not self.state_restored.is_set():
                logger.info(f"Worker {worker_id} requesting chunks, but state not yet restored")
                await self.workers[worker_id].send(
                    safe_json_dumps({"type": "no_chunks", "reason": "state_restoring"})
                )
                return

            count = data.get("count", self.chunks_per_request)
            chunks = self.chunk_manager.get_chunks_for_worker(worker_id, count, self.chunk_tracker)

            if chunks:
                # Only send the fields that worker expects
                chunk_data = []
                for chunk in chunks:
                    chunk_data.append(
                        {
                            "chunk_id": chunk.chunk_id,
                            "shard_url": chunk.shard_url,
                            "shard_name": chunk.shard_name,
                            "start_index": chunk.start_index,
                            "chunk_size": chunk.chunk_size,
                        }
                    )

                await self.workers[worker_id].send(
                    safe_json_dumps({"type": "shard_assignment", "chunks": chunk_data})
                )
                chunk_ids = [c["chunk_id"] for c in chunk_data]
                logger.info(f"Assigned {len(chunks)} chunks to worker {worker_id}: {chunk_ids}")
                await self._send_activity(f"Assigned {len(chunks)} chunks to {worker_id}")
            else:
                await self.workers[worker_id].send(safe_json_dumps({"type": "no_chunks"}))

        elif msg_type == "chunk_complete":
            chunk_id = data["chunk_id"]
            if self.chunk_manager.complete_chunk(chunk_id, worker_id):
                self.stats["completed_chunks"] += 1

                if self.chunk_tracker:
                    self.chunk_tracker.mark_completed(chunk_id)

                logger.info(f"Chunk {chunk_id} completed by worker {worker_id}")
                await self._check_shard_completion(chunk_id)
                await self._send_activity(f"Chunk {chunk_id} completed by {worker_id}")
        elif msg_type == "chunk_failed":
            chunk_id = data["chunk_id"]
            error = data.get("error", "Unknown error")
            if self.chunk_manager.fail_chunk(chunk_id, worker_id):
                self.stats["failed_chunks"] += 1

                if self.chunk_tracker:
                    self.chunk_tracker.mark_failed(chunk_id)

                logger.warning(f"Chunk {chunk_id} failed on worker {worker_id}: {error}")
                await self._send_activity(f"Chunk {chunk_id} failed on {worker_id}: {error}")

        elif msg_type == "submit_captions":
            await self._handle_captions_submission(worker_id, data)
        elif msg_type == "request_job":
            # VLLMWorker requesting a job from data samples
            try:
                job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
                await self.workers[worker_id].send(
                    json.dumps({"type": "job_assignment", "job": job})
                )
                logger.debug(f"Assigned job {job['job_id']} to worker {worker_id}")
            except asyncio.TimeoutError:
                await self.workers[worker_id].send(json.dumps({"type": "no_jobs"}))
        elif msg_type == "heartbeat":
            # Update worker stats
            logger.debug(f"Heartbeat from {worker_id}: {data}")

    async def _handle_captions_submission(self, worker_id: str, data: Dict):
        """Process multiple captions submission from worker."""
        chunk_id = data.get("chunk_id")
        item_key = data["item_key"]
        captions_list = data["captions"]

        logger.debug(
            f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
        )

        # Create a SINGLE caption record with ALL captions as a list
        caption = Caption(
            job_id=f"{chunk_id}_{item_key}",  # Single ID for the item
            dataset=data.get("dataset"),
            shard=data.get("shard"),
            item_key=item_key,
            captions=captions_list,  # Store ALL captions as a list
            contributor_id=worker_id,
            timestamp=datetime.utcnow(),
            quality_scores=None,  # Could be a list of scores matching captions
            # Image metadata
            image_width=data.get("image_width"),
            image_height=data.get("image_height"),
            image_format=data.get("image_format"),
            file_size=data.get("file_size"),
            # Processing metadata
            caption_count=len(captions_list),
            processing_time_ms=data.get("processing_time_ms"),
            chunk_id=chunk_id,
        )

        # Add to central storage buffer as a single entry
        await self.storage.save_caption(caption)

        # Update statistics
        self.stats["total_captions"] += len(captions_list)
        self.stats["buffer_size"] = len(self.storage.caption_buffer)

        # Update contributor stats
        contributor = await self.storage.get_contributor(worker_id)
        if contributor:
            contributor.total_captions += len(captions_list)
            await self.storage.save_contributor(contributor)

        # Broadcast updated stats
        await self._broadcast_stats()

        # Log progress periodically
        if self.stats["total_captions"] % 100 == 0:
            logger.info(f"Collected {self.stats['total_captions']} captions centrally")

    async def _check_shard_completion(self, chunk_id: str):
        """Check if a shard is complete after chunk completion."""
        # Extract shard name from chunk_id
        shard_name = chunk_id.rsplit("_chunk_", 1)[0]

        # Check if all chunks for this shard are complete
        chunk_stats = self.chunk_manager.get_stats()
        shard_chunks = [
            cid
            for cid, chunk in self.chunk_manager.chunks.items()
            if chunk.shard_name == shard_name
        ]

        completed_chunks = [
            cid for cid in shard_chunks if self.chunk_manager.chunks[cid].status == "completed"
        ]

        if len(completed_chunks) == len(shard_chunks):
            logger.info(f"Shard {shard_name} complete!")
            self.shard_tracker.mark_complete(shard_name)
            self.stats["completed_shards"] += 1
            await self._send_activity(f"Shard {shard_name} completed!")

    async def _handle_data_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
        """Handle data worker connection."""
        worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
        self.data_workers[worker_id] = websocket

        logger.info(f"Data worker {worker_id} connected")

        try:
            # Send welcome with storage config
            storage_config = self.config.get(
                "data_worker_storage",
                {
                    "forward_to_orchestrator": True,
                    "local": {"enabled": False},
                    "s3": {"enabled": False},
                },
            )

            await websocket.send(
                json.dumps(
                    {"type": "welcome", "worker_id": worker_id, "storage_config": storage_config}
                )
            )

            # Track if we've sent backpressure
            backpressure_sent = False

            async for message in websocket:
                data = json.loads(message)
                msg_type = data.get("type")

                if msg_type == "submit_samples":
                    # Check queue size for backpressure
                    if self.data_sample_queue.qsize() > self.backpressure_threshold:
                        if not backpressure_sent:
                            await websocket.send(json.dumps({"type": "backpressure"}))
                            backpressure_sent = True
                            logger.warning(f"Backpressure applied to data worker {worker_id}")
                    else:
                        if backpressure_sent:
                            await websocket.send(json.dumps({"type": "resume"}))
                            backpressure_sent = False

                    # Receive image data for each sample
                    samples = data["samples"]
                    for sample in samples:
                        # Receive binary image data
                        image_data = await websocket.recv()

                        # Create job and add to queue
                        job = {
                            "job_id": f"data_{worker_id}_{sample['sample_id']}",
                            "sample_id": sample["sample_id"],
                            "image_data": image_data,
                            "metadata": sample.get("metadata", {}),
                            "source": "data_worker",
                            "worker_id": worker_id,
                        }

                        await self.data_sample_queue.put(job)

                elif msg_type == "heartbeat":
                    logger.debug(f"Data worker {worker_id} heartbeat: {data}")

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Data worker {worker_id} disconnected")
        finally:
            del self.data_workers[worker_id]

    async def _handle_monitor(self, websocket: WebSocketServerProtocol):
        """Handle monitor connection."""
        self.monitors.add(websocket)
        logger.info("Monitor connected")

        try:
            # Send initial stats
            await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))

            # Send chunk stats
            chunk_stats = self.chunk_manager.get_stats()
            await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))

            # Send contributor leaderboard
            contributors = await self.storage.get_top_contributors(10)
            await websocket.send(
                safe_json_dumps(
                    {"type": "leaderboard", "data": [safe_dict(c) for c in contributors]}
                )
            )

            # Keep connection alive
            async for _ in websocket:
                pass

        except websockets.exceptions.ConnectionClosed:
            logger.info("Monitor disconnected")
        finally:
            self.monitors.discard(websocket)

    async def _broadcast_stats(self):
        """Broadcast statistics to all monitors."""
        if not self.monitors:
            return

        # Include chunk stats
        chunk_stats = self.chunk_manager.get_stats()
        self.stats.update({f"chunks_{k}": v for k, v in chunk_stats.items()})

        # Add rate information
        self.stats.update(
            {
                "current_rate": self.rate_tracker["current_rate"],
                "average_rate": self.rate_tracker["average_rate"],
                "expected_rate": self.rate_tracker["expected_rate"],
            }
        )

        # Add vLLM info
        self.stats["vllm_model"] = self.vllm_config.get("model", "unknown")
        self.stats["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)

        message = safe_json_dumps({"type": "stats", "data": self.stats})

        # Send to all monitors
        disconnected = set()
        for monitor in self.monitors:
            try:
                await monitor.send(message)
            except websockets.exceptions.ConnectionClosed:
                disconnected.add(monitor)

        # Clean up disconnected monitors
        self.monitors -= disconnected

    async def _send_activity(self, activity: str):
        """Send activity update to monitors."""
        if not self.monitors:
            return

        message = safe_json_dumps(
            {"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
        )

        disconnected = set()
        for monitor in self.monitors:
            try:
                await monitor.send(message)
            except websockets.exceptions.ConnectionClosed:
                disconnected.add(monitor)

        self.monitors -= disconnected

    async def _heartbeat_loop(self):
        """Send periodic heartbeats to maintain connections."""
        while True:
            await asyncio.sleep(30)

            # Ping workers
            disconnected = []
            for worker_id, ws in self.workers.items():
                try:
                    await ws.ping()
                except:
                    disconnected.append(worker_id)

            # Clean up disconnected workers
            for worker_id in disconnected:
                if worker_id in self.workers:
                    del self.workers[worker_id]
                    self.chunk_manager.release_worker_chunks(worker_id)

    async def _checkpoint_loop(self):
        """Periodically checkpoint storage."""
        interval = self.config.get("storage", {}).get("checkpoint_interval", 1000)

        while True:
            await asyncio.sleep(60)

            # Force checkpoint at regular intervals
            if self.stats["total_captions"] > 0 and self.stats["total_captions"] % interval == 0:
                logger.info(f"Triggering checkpoint at {self.stats['total_captions']} captions")
                await self.storage.checkpoint()

                # Update stats
                self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
                self.stats["total_written"] = self.storage.total_captions_written
                self.stats["buffer_size"] = len(self.storage.caption_buffer)

                await self._broadcast_stats()
                logger.info(
                    f"Checkpoint complete. Total written to disk: {self.stats['total_written']}"
                )

    async def _stats_update_loop(self):
        """Periodically update and broadcast stats."""
        # Track session start values
        session_start_captions = self.stats["total_captions"]
        session_start_time = time.time()

        while True:
            await asyncio.sleep(10)

            # Update chunk stats
            chunk_stats = self.chunk_manager.get_stats()
            self.stats["total_chunks"] = chunk_stats["total"]
            self.stats["completed_chunks"] = chunk_stats["completed"]
            self.stats["failed_chunks"] = chunk_stats["failed"]

            # Add queue information
            with self.chunk_manager.lock:
                self.stats["pending_chunks"] = len(self.chunk_manager.pending_chunks)
                self.stats["assigned_chunks"] = sum(
                    len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
                )

            # Calculate if we need more chunks
            worker_count = self.stats.get("connected_workers", 0)
            target_buffer = max(self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier)
            active_chunks = self.stats["pending_chunks"] + self.stats["assigned_chunks"]
            self.stats["chunk_buffer_status"] = f"{active_chunks}/{target_buffer}"

            # Update rate information
            current_time = time.time()
            elapsed_since_update = current_time - self.rate_tracker["last_update_time"]

            if elapsed_since_update > 0:
                # Calculate current rate (captions per minute)
                caption_diff = (
                    self.stats["total_captions"] - self.rate_tracker["last_caption_count"]
                )
                self.rate_tracker["current_rate"] = (caption_diff / elapsed_since_update) * 60

                # Calculate average rate since THIS SESSION started
                session_elapsed = current_time - session_start_time
                if session_elapsed > 0:
                    session_captions = self.stats["total_captions"] - session_start_captions
                    self.rate_tracker["average_rate"] = (session_captions / session_elapsed) * 60

                # Calculate expected rate based on workers
                # Assume each worker processes batch_size images every ~2 seconds with 3 captions each
                batch_size = self.vllm_config.get("batch_size", 8)
                num_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
                images_per_minute = 30  # Rough estimate: 30 images/min per worker
                self.rate_tracker["expected_rate"] = worker_count * images_per_minute * num_prompts

                # Update trackers
                self.rate_tracker["last_update_time"] = current_time
                self.rate_tracker["last_caption_count"] = self.stats["total_captions"]

            # Log rate information when workers are connected
            if worker_count > 0:
                logger.info(
                    f"Rate: {self.rate_tracker['current_rate']:.1f} captions/min "
                    f"(avg: {self.rate_tracker['average_rate']:.1f}, "
                    f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
                    f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
                )

            await self._broadcast_stats()

    async def _restore_state(self):
        """Restore state from storage on startup."""
        # Update statistics
        self.stats["total_captions"] = await self.storage.count_captions()

        logger.info(f"Restored state: {self.stats['total_captions']} captions")

    async def shutdown(self):
        """Graceful shutdown."""
        logger.info("Shutting down orchestrator...")

        # Stop chunk creation
        self.stop_chunk_creation.set()
        if self.chunk_creation_thread:
            self.chunk_creation_thread.join(timeout=5)

        # Release all assigned chunks before closing connections
        for worker_id in list(self.workers.keys()):
            self.chunk_manager.release_worker_chunks(worker_id)
            if self.chunk_tracker:
                # Update chunk tracker to mark assigned chunks as pending
                with self.chunk_manager.lock:
                    for chunk_id in list(self.chunk_manager.assigned_chunks.get(worker_id, [])):
                        self.chunk_tracker.mark_pending(chunk_id)

        # Close all connections
        for ws in list(self.workers.values()):
            await ws.close()
        for ws in list(self.monitors):
            await ws.close()

        # Save chunk state
        if self.chunk_tracker:
            self.chunk_tracker.save_checkpoint()

        # Final checkpoint
        logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")
        await self.storage.checkpoint()

        # Log final statistics
        logger.info(
            f"Shutdown complete. Total captions collected: {self.storage.total_captions_written}"
        )

        await self.storage.close()
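Editor's note: a minimal launch sketch, assuming the wheel is installed and importable as caption_flow. The real entry point is caption_flow/cli.py, whose contents are not shown in this diff; Orchestrator(config), start(), and shutdown() are the calls defined above.

    import asyncio

    from caption_flow.orchestrator import Orchestrator

    async def main() -> None:
        # Without a "dataset" key the orchestrator still accepts connections;
        # chunk creation only starts when dataset.path is configured.
        orchestrator = Orchestrator({"host": "0.0.0.0", "port": 8765})
        try:
            await orchestrator.start()  # serves WebSocket clients until cancelled
        finally:
            await orchestrator.shutdown()

    if __name__ == "__main__":
        asyncio.run(main())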