caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +2 -1
- caption_flow/models.py +108 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1595
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage.py +415 -406
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +94 -35
- caption_flow/utils/dataset_loader.py +64 -522
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +4 -200
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/METADATA +29 -27
- caption_flow-0.2.3.dist-info/RECORD +35 -0
- caption_flow-0.2.1.dist-info/RECORD +0 -29
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/top_level.txt +0 -0
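
The largest change in 0.2.3 is the new caption_flow/processors/ package: chunk management moves out of orchestrator.py into per-backend processor classes chosen by the dataset.processor_type config key. A minimal sketch of that selection logic, condensed from the new orchestrator __init__ in the diff below; the config dict here is an assumed example, only the processor class names and the dataset.processor_type lookup come from the released code:

    # Sketch only: condensed from the 0.2.3 orchestrator.py diff below.
    # The config dict is an assumed example, not a config shipped with the package.
    from caption_flow.processors import (
        WebDatasetOrchestratorProcessor,
        HuggingFaceDatasetOrchestratorProcessor,
        LocalFilesystemOrchestratorProcessor,
    )

    config = {"dataset": {"processor_type": "webdataset"}}  # assumed example value

    processor_type = config.get("dataset", {}).get("processor_type", None)
    if processor_type == "webdataset":
        processor = WebDatasetOrchestratorProcessor()
    elif processor_type == "huggingface_datasets":
        processor = HuggingFaceDatasetOrchestratorProcessor()
    elif processor_type == "local_filesystem":
        processor = LocalFilesystemOrchestratorProcessor()
    else:
        raise ValueError(f"Unknown processor type: {processor_type}")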
caption_flow/orchestrator.py
CHANGED
@@ -1,412 +1,106 @@
-"""Enhanced orchestrator with shard chunk assignment for vLLM workers.
-
-This orchestrator:
-1. Divides dataset shards into chunks for parallel processing
-2. Assigns chunks to workers on request
-3. Collects captions from workers centrally
-4. Manages checkpoints and fault tolerance
-"""
-
 import time
 import asyncio
 import json
 import logging
 import ssl
 import uuid
-from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Set, Optional, Any, List
-from collections import
+from typing import Dict, Set, Optional, Any, List
+from collections import defaultdict
 import threading
-from queue import Queue, Empty
 
-from .workers import data
 import websockets
 from websockets.server import WebSocketServerProtocol
 
 from .storage import StorageManager
-from .models import Caption, Contributor
+from .models import Caption, Contributor, JobId
 from .utils.auth import AuthManager
-from .utils import
-from .
+from .utils.json_utils import safe_json_dumps
+from .processors import (
+    ProcessorConfig,
+    WorkAssignment,
+    WorkResult,
+    WorkUnit,
+    WebDatasetOrchestratorProcessor,
+    HuggingFaceDatasetOrchestratorProcessor,
+    LocalFilesystemOrchestratorProcessor,
+)
 
 logger = logging.getLogger(__name__)
-
-
-@dataclass
-class ShardChunk:
-    """Represents a chunk of a shard for processing."""
-
-    chunk_id: str
-    shard_url: str
-    shard_name: str
-    start_index: int
-    chunk_size: int
-    assigned_to: Optional[str] = None
-    status: str = "pending"  # pending, assigned, completed, failed
-    assigned_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
-
-    @classmethod
-    def create(
-        cls, shard_url: str, shard_name: str, start_index: int, chunk_size: int
-    ) -> "ShardChunk":
-        """Factory method to create a chunk with consistent ID."""
-        # Always use consistent format: dataset_chunk_startindex
-        if shard_url.startswith("hf_dataset:"):
-            # Extract dataset path
-            parts = shard_url.split(":")
-            dataset_path = parts[1] if len(parts) > 1 else "unknown"
-            chunk_id = f"{dataset_path.replace('/', '_')}_chunk_{start_index}"
-        else:
-            # WebDataset format
-            chunk_id = f"{shard_name}_chunk_{start_index}"
-
-        return cls(
-            chunk_id=chunk_id,
-            shard_url=shard_url,
-            shard_name=shard_name,
-            start_index=start_index,
-            chunk_size=chunk_size,
-        )
-
-    def belongs_to_shard(self, shard_identifier: str) -> bool:
-        """Check if this chunk belongs to a given shard."""
-        return self.shard_name == shard_identifier
-
-    def to_dict(self) -> Dict[str, Any]:
-        """Convert to dict for JSON serialization (for workers)."""
-        return {
-            "chunk_id": self.chunk_id,
-            "shard_url": self.shard_url,
-            "shard_name": self.shard_name,
-            "start_index": self.start_index,
-            "chunk_size": self.chunk_size,
-        }
-
-
-class ChunkManager:
-    """Manages shard chunk creation and assignment."""
-
-    def __init__(self, chunk_size: int = 1000, tracker: Optional[ChunkTracker] = None):
-        self.chunk_size = chunk_size
-        self.chunks: Dict[str, ShardChunk] = {}
-        self.pending_chunks: Deque[str] = deque()
-        self.assigned_chunks: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> chunk_ids
-        self.lock = threading.Lock()
-        self.tracker = tracker  # Reference to chunk tracker
-
-    def create_chunks_from_shard(
-        self, shard_url: str, shard_name: str, total_items: int
-    ) -> List[ShardChunk]:
-        """Create chunks from a shard."""
-        chunks = []
-
-        for start_idx in range(0, total_items, self.chunk_size):
-            chunk = ShardChunk.create(
-                shard_url=shard_url,
-                shard_name=shard_name,
-                start_index=start_idx,
-                chunk_size=min(self.chunk_size, total_items - start_idx),
-            )
-
-            with self.lock:
-                self.chunks[chunk.chunk_id] = chunk
-                self.pending_chunks.append(chunk.chunk_id)
-
-            chunks.append(chunk)
-
-        return chunks
-
-    def get_chunks_for_worker(
-        self, worker_id: str, count: int = 1, tracker: Optional["ChunkTracker"] = None
-    ) -> List[Dict[str, Any]]:
-        """Get available chunks with unprocessed items for a worker."""
-        assigned = []
-
-        with self.lock:
-            # FIRST PRIORITY: Check if this worker already has assigned chunks
-            # Workers should complete their current chunks before getting new ones
-            if worker_id in self.assigned_chunks:
-                existing_chunk_ids = list(self.assigned_chunks[worker_id])
-                for chunk_id in existing_chunk_ids:
-                    if len(assigned) >= count:
-                        break
-
-                    chunk = self.chunks.get(chunk_id)
-                    if not chunk:
-                        continue
-
-                    # Check if chunk still has unprocessed items
-                    if tracker:
-                        chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
-                        if chunk_info and chunk_info["unprocessed_ranges"]:
-                            assigned.append(
-                                {
-                                    "chunk": chunk,
-                                    "unprocessed_ranges": chunk_info["unprocessed_ranges"],
-                                }
-                            )
-                    else:
-                        # No tracker, assume chunk needs processing
-                        assigned.append(
-                            {
-                                "chunk": chunk,
-                                "unprocessed_ranges": [(0, chunk.chunk_size - 1)],
-                            }
-                        )
-
-            # SECOND PRIORITY: Get new pending chunks
-            # Only if worker doesn't have enough chunks already
-            while len(assigned) < count and self.pending_chunks:
-                chunk_id = self.pending_chunks.popleft()
-                chunk = self.chunks.get(chunk_id)
-
-                if not chunk:
-                    continue
-
-                # Verify chunk is truly pending (defensive check)
-                if chunk.status != "pending" or chunk.assigned_to is not None:
-                    logger.warning(
-                        f"Chunk {chunk_id} in pending queue but status={chunk.status}, assigned_to={chunk.assigned_to}"
-                    )
-                    continue
-
-                # Assign to this worker
-                chunk.assigned_to = worker_id
-                chunk.status = "assigned"
-                chunk.assigned_at = datetime.utcnow()
-                self.assigned_chunks[worker_id].add(chunk_id)
-
-                # Get unprocessed ranges
-                unprocessed_ranges = [(0, chunk.chunk_size - 1)]  # Default
-                if tracker:
-                    chunk_info = tracker.get_chunk_with_unprocessed_items(chunk_id)
-                    if chunk_info:
-                        unprocessed_ranges = chunk_info["unprocessed_ranges"]
-                    tracker.mark_assigned(chunk_id, worker_id)
-
-                assigned.append({"chunk": chunk, "unprocessed_ranges": unprocessed_ranges})
-
-            # Log what we're assigning
-            if assigned:
-                chunk_summary = ", ".join(
-                    [
-                        f"{info['chunk'].chunk_id}[{len(info['unprocessed_ranges'])} ranges]"
-                        for info in assigned
-                    ]
-                )
-                logger.info(f"Assigning to worker {worker_id}: {chunk_summary}")
-
-            return assigned
-
-    def complete_chunk(self, chunk_id: str, worker_id: str) -> bool:
-        """Mark a chunk as completed."""
-        with self.lock:
-            if chunk_id in self.chunks:
-                chunk = self.chunks[chunk_id]
-                if chunk.assigned_to == worker_id and chunk.status == "assigned":
-                    chunk.status = "completed"
-                    chunk.completed_at = datetime.utcnow()
-                    self.assigned_chunks[worker_id].discard(chunk_id)
-                    return True
-        return False
-
-    def fail_chunk(self, chunk_id: str, worker_id: str) -> bool:
-        """Mark a chunk as failed and requeue it."""
-        with self.lock:
-            if chunk_id in self.chunks:
-                chunk = self.chunks[chunk_id]
-                if chunk.assigned_to == worker_id:
-                    chunk.status = "pending"
-                    chunk.assigned_to = None
-                    chunk.assigned_at = None
-                    self.assigned_chunks[worker_id].discard(chunk_id)
-                    self.pending_chunks.append(chunk_id)
-                    return True
-        return False
-
-    def release_worker_chunks(self, worker_id: str):
-        """Release all chunks assigned to a worker."""
-        with self.lock:
-            chunk_ids = list(self.assigned_chunks.get(worker_id, []))
-            for chunk_id in chunk_ids:
-                if chunk_id in self.chunks:
-                    chunk = self.chunks[chunk_id]
-                    if chunk.status == "assigned":
-                        chunk.status = "pending"
-                        chunk.assigned_to = None
-                        chunk.assigned_at = None
-                        self.pending_chunks.append(chunk_id)
-
-            if worker_id in self.assigned_chunks:
-                del self.assigned_chunks[worker_id]
-
-    def get_stats(self) -> Dict[str, int]:
-        """Get chunk statistics."""
-        with self.lock:
-            stats = {
-                "total": len(self.chunks),
-                "pending": len(self.pending_chunks),
-                "assigned": sum(len(chunks) for chunks in self.assigned_chunks.values()),
-                "completed": sum(1 for c in self.chunks.values() if c.status == "completed"),
-                "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
-            }
-            return stats
+logger.setLevel(logging.INFO)
 
 
 class Orchestrator:
-    """
+    """Generic orchestrator for distributed work processing."""
 
     def __init__(self, config: Dict[str, Any]):
         self.config = config
         self.host = config.get("host", "0.0.0.0")
         self.port = config.get("port", 8765)
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.dataset_path,
-            self.dataset_type,
-            self.dataset_split,
-            self.dataset_image_column,
-        )
-        checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
-        checkpoint_dir.mkdir(parents=True, exist_ok=True)
-        self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
-        self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
-
-        # vLLM configuration to distribute to workers
-        self.vllm_config = config.get(
-            "vllm",
-            {
-                "model": "Qwen/Qwen2.5-VL-3B-Instruct",
-                "gpu_memory_utilization": 0.92,
-                "max_model_len": 16384,
-                "tensor_parallel_size": 1,
-                "dtype": "float16",
-                "enforce_eager": True,
-                "limit_mm_per_prompt": {"image": 1},
-                "disable_mm_preprocessor_cache": True,
-                "sampling": {
-                    "temperature": 0.7,
-                    "top_p": 0.95,
-                    "max_tokens": 256,
-                    "repetition_penalty": 1.05,
-                    "stop": ["<|end|>", "<|endoftext|>", "<|im_end|>"],
-                },
-                "inference_prompts": [
-                    "describe this image in detail",
-                    "provide a comprehensive description of the visual content",
-                    "what are the key elements in this image?",
-                ],
-            },
-        )
-
-        # Chunk configuration
-        self.chunk_size = config.get("chunk_size", 1000)
-        self.chunks_per_request = config.get("chunks_per_request", 2)
-
-        # Demand-driven chunk creation settings
-        self.chunk_buffer_multiplier = config.get("chunk_buffer_multiplier", 3)
-        self.min_chunk_buffer = config.get("min_chunk_buffer", 10)
+        # Processor configuration
+        processor_type = config.get("dataset", {}).get("processor_type", None)
+        assert (
+            processor_type is not None
+        ), "You must supply processor_type in your orchestrator dataset configuration."
+        processor_config = ProcessorConfig(processor_type=processor_type, config=config)
+
+        # Initialize processor
+        if processor_type == "webdataset":
+            self.processor = WebDatasetOrchestratorProcessor()
+        elif processor_type == "huggingface_datasets":
+            self.processor = HuggingFaceDatasetOrchestratorProcessor()
+        elif processor_type == "local_filesystem":
+            self.processor = LocalFilesystemOrchestratorProcessor()
+        else:
+            raise ValueError(f"Unknown processor type: {processor_type}")
 
         # Initialize components
         storage_config = config.get("storage", {})
         self.storage = StorageManager(
             Path(storage_config.get("data_dir", "./caption_data")),
             caption_buffer_size=storage_config.get("caption_buffer_size", 1000),
-            job_buffer_size=storage_config.get("job_buffer_size", 100),
-            contributor_buffer_size=storage_config.get("contributor_buffer_size", 10),
         )
         self.auth = AuthManager(config.get("auth", {}))
+        self.processor.initialize(processor_config, self.storage)
 
-        #
-        self.
-        self.shard_tracker = None
-        self.chunk_tracker = None
-
-        if self.dataset_path:
-            self.dataset_loader = DatasetLoader(self.dataset_path, self.dataset_type)
-            checkpoint_dir = Path(config.get("storage", {}).get("checkpoint_dir", "./checkpoints"))
-            checkpoint_dir.mkdir(parents=True, exist_ok=True)
-            self.shard_tracker = ShardTracker(checkpoint_dir / "shards.json")
-            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
-
-        # Initialize chunk manager with reference to chunk tracker
-        self.chunk_manager = ChunkManager(self.chunk_size, self.chunk_tracker)
-        self.pending_processed_items = defaultdict(list)  # chunk_id -> list of indices
-        self.item_batch_lock = threading.Lock()
-        self.last_item_batch_flush = time.time()
-        self.item_batch_interval = 5  # Flush every 5 seconds
-        self.item_batch_size = 100  # Or every 100 items
+        # Processing configuration
+        self.units_per_request = config.get("units_per_request", 2)
 
         # Track connections
         self.workers: Dict[str, WebSocketServerProtocol] = {}
         self.monitors: Set[WebSocketServerProtocol] = set()
+        self.workers_by_user = defaultdict(set)
 
         # SSL configuration
         self.ssl_context = self._setup_ssl()
 
         # Statistics
-        self.is_generating_stats = False
         self.stats = {
-            "total_chunks": 0,
-            "completed_chunks": 0,
-            "failed_chunks": 0,
             "connected_workers": 0,
-            "
-            "completed_shards": 0,
-            "current_shard": None,
+            "total_outputs": 0,
            "last_checkpoint": None,
+            "processor_stats": {},
         }
 
+        # Cache for leaderboard
+        self._cached_leaderboard = None
+
+        # Data worker stuff
+        self.data_workers = {}
+        self.data_sample_queue = asyncio.Queue()
+        self.backpressure_threshold = config.get("backpressure_threshold", 1000)
+
         # Rate tracking
         self.rate_tracker = {
             "start_time": time.time(),
             "last_update_time": time.time(),
-            "
+            "last_output_count": 0,
             "current_rate": 0.0,
             "average_rate": 0.0,
-            "expected_rate": 0.0,
         }
 
-        # Data sample queue for CaptionWorker
-        self.data_sample_queue = asyncio.Queue(maxsize=1000)
-        self.data_workers: Dict[str, WebSocketServerProtocol] = {}
-
-        # Backpressure threshold
-        self.backpressure_threshold = config.get("backpressure_threshold", 800)
-
-        # Shard processing state
-        self.all_shards = []
-        self.current_shard_index = 0
-        self.shard_lock = threading.Lock()
-
-        # Background chunk creation
-        self.chunk_creation_thread = None
-        self.stop_chunk_creation = threading.Event()
-
-        # State restoration flag
-        self.state_restored = threading.Event()
-        # If no dataset, state is already "restored"
-        if not self.dataset_loader:
-            self.state_restored.set()
-
     def _setup_ssl(self) -> Optional[ssl.SSLContext]:
         """Configure SSL if certificates are provided."""
         ssl_config = self.config.get("ssl", {})
@@ -417,324 +111,20 @@ class Orchestrator:
             context.load_cert_chain(ssl_config["cert"], ssl_config["key"])
             return context
 
-    def _create_chunks_from_dataset(self):
-        """Background thread to create chunks from dataset shards on demand."""
-        if not self.dataset_loader:
-            logger.warning("No dataset configured, skipping chunk creation")
-            self.state_restored.set()  # No state to restore
-            return
-
-        logger.info("Starting chunk creation thread")
-
-        # Mark state as not restored until we process checkpoints
-        self.state_restored.clear()
-
-        # Get dataset info to check format
-        dataset_info = self.dataset_loader.get_dataset_info()
-        dataset_format = dataset_info.get("dataset_format", "unknown")
-        logger.info(f"Dataset format: {dataset_format}")
-
-        # Get all shards
-        self.all_shards = self.dataset_loader.get_shard_list()
-        self.stats["total_shards"] = len(self.all_shards)
-
-        # For HuggingFace datasets, we might need to dynamically create more shards
-        if dataset_format == "huggingface_datasets":
-            self._is_hf_dataset = True
-            self._hf_chunk_size = 10000  # Items per virtual shard
-            self._next_hf_shard_index = len(self.all_shards)  # For creating new virtual shards
-        else:
-            self._is_hf_dataset = False
-
-        # Get shard status from ChunkTracker
-        shards_summary = self.chunk_tracker.get_shards_summary() if self.chunk_tracker else {}
-        completed_shards = {
-            shard_name for shard_name, info in shards_summary.items() if info["is_complete"]
-        }
-
-        # Update ShardTracker for completed shards
-        for shard_name in completed_shards:
-            if not self.shard_tracker.is_complete(shard_name):
-                logger.info(f"Marking shard {shard_name} as complete in ShardTracker")
-                self.shard_tracker.mark_complete(shard_name)
-
-        # Get shards that need processing
-        remaining_shards = self.shard_tracker.get_remaining_shards(self.all_shards)
-
-        # Also check which shards already have chunks (partial or complete)
-        shards_with_chunks = set()
-        for shard_name in shards_summary.keys():
-            shards_with_chunks.add(shard_name)
-
-        # Filter out shards that already have chunks created
-        remaining_shards = [
-            shard
-            for shard in remaining_shards
-            if (shard if shard.startswith("hf_dataset:") else Path(shard).stem)
-            not in shards_with_chunks
-        ]
-
-        self.stats["completed_shards"] = len(completed_shards)
-
-        logger.info(
-            f"Total shards: {len(self.all_shards)}, "
-            f"Completed: {self.stats['completed_shards']}, "
-            f"Shards with chunks: {len(shards_with_chunks)}, "
-            f"Remaining to process: {len(remaining_shards)}"
-        )
-
-        # First, re-queue any existing pending chunks
-        initial_pending = 0
-        requeued_chunks_by_shard = defaultdict(list)
-
-        for shard_name, shard_info in shards_summary.items():
-            with self.chunk_manager.lock:
-                for chunk_state in shard_info["chunks"]:
-                    if chunk_state.status in ["pending", "failed", "assigned"]:
-                        # ChunkState already has shard_url stored
-                        chunk = ShardChunk(
-                            chunk_id=chunk_state.chunk_id,
-                            shard_url=chunk_state.shard_url,
-                            shard_name=chunk_state.shard_name,
-                            start_index=chunk_state.start_index,
-                            chunk_size=chunk_state.chunk_size,
-                        )
-                        self.chunk_manager.chunks[chunk_state.chunk_id] = chunk
-                        self.chunk_manager.pending_chunks.append(chunk_state.chunk_id)
-                        requeued_chunks_by_shard[shard_name].append(chunk_state.chunk_id)
-                        initial_pending += 1
-
-        logger.info(f"Re-queued {initial_pending} existing pending chunks")
-        for shard_name, chunk_ids in requeued_chunks_by_shard.items():
-            logger.info(f"  Shard {shard_name}: {len(chunk_ids)} chunks - {chunk_ids}")
-
-        # Mark state as restored
-        self.state_restored.set()
-        logger.info("State restoration complete, accepting chunk requests")
-
-        # Process shards on-demand
-        shard_iter = iter(remaining_shards)
-        current_shard_url = None
-        current_shard_name = None
-        current_shard_items = None
-        current_shard_index = 0
-
-        while not self.stop_chunk_creation.is_set():
-            # Check how many chunks we need
-            with self.chunk_manager.lock:
-                pending_count = len(self.chunk_manager.pending_chunks)
-                assigned_count = sum(
-                    len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
-                )
-                total_active = pending_count + assigned_count
-
-            # Target buffer: configurable multiplier × number of workers
-            worker_count = max(1, self.stats.get("connected_workers", 0))
-            target_buffer = max(
-                self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier
-            )
-
-            chunks_needed = max(0, target_buffer - total_active)
-
-            if chunks_needed == 0:
-                # We have enough chunks, wait a bit
-                time.sleep(5)
-                continue
-
-            logger.debug(
-                f"Need {chunks_needed} more chunks (pending: {pending_count}, "
-                f"assigned: {assigned_count}, workers: {worker_count})"
-            )
-
-            # Create chunks as needed
-            chunks_created = 0
-
-            while chunks_created < chunks_needed and not self.stop_chunk_creation.is_set():
-                # Need to load next shard?
-                if current_shard_url is None or current_shard_index >= current_shard_items:
-                    try:
-                        current_shard_url = next(shard_iter)
-
-                        # Extract shard name based on type
-                        if current_shard_url.startswith("hf_dataset:"):
-                            current_shard_name = current_shard_url  # Use full ID for virtual shards
-                        else:
-                            current_shard_name = Path(current_shard_url).stem
-
-                        self.stats["current_shard"] = current_shard_name
-
-                        # Skip if we already have chunks from this shard
-                        if current_shard_name in shards_summary:
-                            logger.debug(
-                                f"Skipping shard {current_shard_name} - already has chunks"
-                            )
-                            current_shard_url = None
-                            continue
-
-                        # Count items in new shard
-                        logger.info(f"Loading new shard {current_shard_name}")
-
-                        # For virtual HF dataset shards, use the chunk size directly
-                        if current_shard_url.startswith("hf_dataset:"):
-                            current_shard_items = self.dataset_loader.count_shard_items(
-                                current_shard_url
-                            )
-                            logger.info(
-                                f"Virtual shard {current_shard_name} has {current_shard_items} items"
-                            )
-                        else:
-                            # For WebDataset, actually count items
-                            current_shard_items = sum(
-                                1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
-                            )
-                            logger.info(
-                                f"Shard {current_shard_name} has {current_shard_items} items"
-                            )
-
-                        current_shard_index = 0
-
-                    except StopIteration:
-                        # No more shards in the iterator
-                        if self._is_hf_dataset:
-                            # Before creating new virtual shards, check if we have pending chunks
-                            with self.chunk_manager.lock:
-                                pending_count = len(self.chunk_manager.pending_chunks)
-
-                            if pending_count > 0:
-                                # Don't create new shards if we have pending chunks
-                                logger.debug(
-                                    f"Have {pending_count} pending chunks, not creating new virtual shards yet"
-                                )
-                                current_shard_url = None
-                                time.sleep(2)
-                                continue
-
-                            # For HF datasets, we can create more virtual shards on demand
-                            logger.info(
-                                "Creating additional virtual shards for HuggingFace dataset"
-                            )
-
-                            # Create 10 more virtual shards
-                            new_shards = []
-                            for i in range(10):
-                                shard_id = f"hf_dataset:{self.dataset_path}:chunk:{self._next_hf_shard_index * self._hf_chunk_size}"
-                                new_shards.append(shard_id)
-                                self._next_hf_shard_index += 1
-
-                            # Add to all_shards and create new iterator
-                            self.all_shards.extend(new_shards)
-                            self.stats["total_shards"] = len(self.all_shards)
-
-                            # Filter for unprocessed shards
-                            remaining_new_shards = [
-                                s
-                                for s in new_shards
-                                if s not in shards_summary and s not in completed_shards
-                            ]
-
-                            if remaining_new_shards:
-                                shard_iter = iter(remaining_new_shards)
-                                logger.info(f"Added {len(remaining_new_shards)} new virtual shards")
-                                continue
-
-                        # No more shards to process
-                        logger.info("No more shards to process")
-                        break
-
-                    except Exception as e:
-                        logger.error(f"Error loading shard {current_shard_name}: {e}")
-                        current_shard_url = None
-                        continue
-
-                # Create a chunk from current shard
-                if current_shard_url and current_shard_index < current_shard_items:
-                    # Calculate the absolute dataset index for this chunk
-                    if current_shard_url.startswith("hf_dataset:"):
-                        # Parse the virtual shard URL to get the base start index
-                        parts = current_shard_url.split(":")
-                        if len(parts) >= 4 and parts[2] == "chunk":
-                            shard_base_index = int(parts[3])
-                        else:
-                            shard_base_index = 0
-
-                        # The absolute start index for this chunk in the dataset
-                        absolute_start_index = shard_base_index + current_shard_index
-                    else:
-                        # For WebDataset, current_shard_index is already absolute
-                        absolute_start_index = current_shard_index
-
-                    # Create chunk with absolute index
-                    chunk = ShardChunk.create(
-                        shard_url=current_shard_url,
-                        shard_name=current_shard_name,
-                        start_index=absolute_start_index,
-                        chunk_size=min(self.chunk_size, current_shard_items - current_shard_index),
-                    )
-
-                    # Add to ChunkTracker with all required fields
-                    if self.chunk_tracker and self.chunk_tracker.add_chunk(
-                        chunk.chunk_id,
-                        chunk.shard_name,
-                        chunk.shard_url,
-                        chunk.start_index,
-                        chunk.chunk_size,
-                    ):
-                        with self.chunk_manager.lock:
-                            self.chunk_manager.chunks[chunk.chunk_id] = chunk
-                            self.chunk_manager.pending_chunks.append(chunk.chunk_id)
-
-                        chunks_created += 1
-                        self.stats["total_chunks"] += 1
-
-                    current_shard_index += self.chunk_size
-
-            if chunks_created > 0:
-                logger.info(f"Created {chunks_created} chunks on demand")
-
-            # If we couldn't create any chunks and there are no more shards, check if it's HF dataset
-            if chunks_created == 0 and current_shard_url is None:
-                if self._is_hf_dataset:
-                    # We can always create more virtual shards for HF datasets
-                    logger.debug("Will create more virtual shards on next iteration")
-                else:
-                    logger.info("All shards processed, chunk creation complete")
-                    break
-
-            # Brief pause to avoid spinning
-            time.sleep(1)
-
-        # Final stats
-        if self.chunk_tracker:
-            final_stats = self.chunk_tracker.get_stats()
-            logger.info(
-                f"Chunk creation thread ending. Total: {final_stats['total']}, "
-                f"Pending: {final_stats['pending']}, Completed: {final_stats['completed']}"
-            )
-
-        logger.info("Chunk creation thread finished")
-
     async def start(self):
         """Start the orchestrator server."""
-        logger.info(f"Starting
-
-
-
-
-
-        await self.storage.initialize()
-        if self.chunk_tracker:
-            await self.chunk_tracker.sync_with_storage(self.storage)
-        await self._restore_state()
-
-        # Start chunk creation thread if dataset is configured
-        if self.dataset_loader:
-            self.chunk_creation_thread = threading.Thread(
-                target=self._create_chunks_from_dataset, daemon=True
+        logger.info(f"Starting orchestrator on {self.host}:{self.port}")
+        processor_type = self.config.get("dataset", {}).get("processor_type", None)
+        if not processor_type:
+            logger.info(f"Config: {self.config}")
+            raise ValueError(
+                "You must supply processor_type in your orchestrator dataset configuration."
             )
-
+        logger.info(f"Processor type: {processor_type}")
 
-
-
+        # Initialize storage
+        await self.storage.initialize()
+        await self.update_unprocessed_ranges()
 
         # Start background tasks
         asyncio.create_task(self._heartbeat_loop())
@@ -742,12 +132,37 @@ class Orchestrator:
         asyncio.create_task(self._stats_update_loop())
 
         # Start WebSocket server
+        websocket_logger = logging.getLogger("websockets")
+        websocket_logger.setLevel(logging.WARNING)
         async with websockets.serve(
-            self.handle_connection,
+            self.handle_connection,
+            self.host,
+            self.port,
+            ssl=self.ssl_context,
+            logger=websocket_logger,
         ):
-            logger.info("
+            logger.info("Orchestrator ready for connections")
             await asyncio.Future()  # Run forever
 
+    def get_workers_by_user_stats(self) -> Dict[str, Dict]:
+        """Get worker statistics grouped by user."""
+        stats = {}
+        for user, worker_ids in self.workers_by_user.items():
+            stats[user] = {"worker_ids": list(worker_ids), "count": len(worker_ids)}
+        return stats
+
+    async def update_unprocessed_ranges(self):
+        """Update unprocessed ranges based on what's actually in storage."""
+        if not self.processor or not self.storage:
+            return
+
+        processed_job_ids = self.storage.get_all_processed_job_ids()
+        self.processor.update_from_storage(processed_job_ids)
+
+    async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+        """Alias for _send_monitor_leaderboard for backward compatibility."""
+        await self._send_monitor_leaderboard(websocket)
+
     async def handle_connection(self, websocket: WebSocketServerProtocol):
         """Handle new WebSocket connection."""
         try:
@@ -762,49 +177,77 @@ class Orchestrator:
 
             if auth_ticket.role == "worker":
                 await self._handle_worker(websocket, auth_ticket)
-            elif auth_ticket.role == "data_worker":
-                await self._handle_data_worker(websocket, auth_ticket)
             elif auth_ticket.role == "monitor":
                 await self._handle_monitor(websocket)
             elif auth_ticket.role == "admin":
                 await self._handle_admin(websocket, auth_ticket)
+            elif auth_ticket.role == "data_worker":
+                await self._handle_data_worker(websocket, auth_ticket)
             else:
                 await websocket.send(
                     safe_json_dumps({"error": f"Unknown role: {auth_ticket.role}"})
                 )
 
         except Exception as e:
-            logger.error(f"Connection error: {e}")
-            import traceback
-
-            logger.error(traceback.format_exc())
+            logger.error(f"Connection error: {e}", exc_info=True)
             await websocket.close()
 
-    async def
-        """Handle
-
-
+    async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
+        """Handle worker connection lifecycle."""
+        # Generate unique worker ID
+        base_name = getattr(auth_ticket, "name", "worker")
+        worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"
+        worker_user = base_name
+
+        self.workers[worker_id] = websocket
+        self.workers_by_user[worker_user].add(worker_id)
+        self.stats["connected_workers"] = len(self.workers)
 
+        # Register contributor
+        contributor = await self.storage.get_contributor(worker_user)
+        if not contributor:
+            contributor = Contributor(
+                contributor_id=worker_user,
+                name=worker_user,
+                total_captions=0,
+                trust_level=1,
+            )
+            await self.storage.save_contributor(contributor)
+
+        logger.info(f"Worker {worker_id} (user: {worker_user}) is retrieving configuration")
         try:
-            # Send welcome
-
+            # Send welcome message with processor config
+            filtered_config = self.config.copy()
+            for unwanted_key in ["auth", "orchestrator", "storage"]:
+                filtered_config.pop(unwanted_key, None)
+            welcome_message = {
+                "type": "welcome",
+                "worker_id": worker_id,
+                "user_id": worker_user,
+                "processor_type": self.config.get("dataset", {}).get("processor_type", None),
+                "processor_config": filtered_config,
+            }
+            await websocket.send(safe_json_dumps(welcome_message))
 
             async for message in websocket:
-
-
-                msg_type = data.get("type")
+                data = json.loads(message)
+                await self._process_worker_message(worker_id, data)
 
-
-
+        except websockets.exceptions.ConnectionClosed:
+            logger.info(f"Worker {worker_id} has disconnected due to websocket connection closure")
+        finally:
+            if worker_id in self.workers:
+                del self.workers[worker_id]
 
-
-
-
-                        safe_json_dumps({"type": "error", "error": "Invalid message format"})
-                    )
+            self.workers_by_user[worker_user].discard(worker_id)
+            if not self.workers_by_user[worker_user]:
+                del self.workers_by_user[worker_user]
 
-
-
+            self.stats["connected_workers"] = len(self.workers)
+
+            # Release assignments
+            self.processor.release_assignments(worker_id)
+            logger.info(f"Worker {worker_id} has safely disconnected")
 
     async def _handle_config_reload(
         self, websocket: WebSocketServerProtocol, new_config: Dict[str, Any]
@@ -814,224 +257,61 @@ class Orchestrator:
 
         updated_sections = []
         warnings = []
-        requires_worker_restart = False
 
         try:
             # Extract orchestrator section if present
             if "orchestrator" in new_config:
-                # Config has orchestrator wrapper, extract it
                 orchestrator_config = new_config["orchestrator"]
             else:
-                # Config is already at orchestrator level
                 orchestrator_config = new_config
 
-            #
-
-
-
-
-                if
-
-
-                    return all(deep_equal(a[k], b[k]) for k in a.keys())
-                elif isinstance(a, (list, tuple)):
-                    if len(a) != len(b):
-                        return False
-                    return all(deep_equal(x, y) for x, y in zip(a, b))
+            # Update processor configuration if present
+            if "processor_type" in orchestrator_config:
+                old_type = self.config.get("processor_type")
+                new_type = orchestrator_config["processor_type"]
+
+                if old_type != new_type:
+                    warnings.append("Processor type changes require orchestrator restart")
+                    updated_sections.append("processor_type")
                 else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                # Check if critical changes require worker restart
-                if (
-                    old_vllm.get("model") != new_vllm.get("model")
-                    or old_vllm.get("gpu_memory_utilization")
-                    != new_vllm.get("gpu_memory_utilization")
-                    or old_vllm.get("tensor_parallel_size")
-                    != new_vllm.get("tensor_parallel_size")
-                    or old_vllm.get("dtype") != new_vllm.get("dtype")
-                    or old_vllm.get("max_model_len") != new_vllm.get("max_model_len")
-                ):
-                    requires_worker_restart = True
-                    warnings.append(
-                        "Critical vLLM changes detected - workers will be disconnected to reload"
-                    )
-                    logger.info(
-                        f"Model change: {old_vllm.get('model')} -> {new_vllm.get('model')}"
-                    )
-
-            # Update dataset configuration
-            if "dataset" in orchestrator_config:
-                old_dataset = self.dataset_config.copy()
-                new_dataset = orchestrator_config["dataset"]
-
-                dataset_changed = not deep_equal(old_dataset, new_dataset)
-
-                if dataset_changed:
-                    self.dataset_config = new_dataset.copy()
-                    self.dataset_path = self.dataset_config.get("path")
-                    self.dataset_type = self.dataset_config.get("type", "huggingface")
-                    updated_sections.append("dataset")
-                    warnings.append("Dataset changes will apply to new chunks only")
-
-            # Update chunk settings
-            if (
-                "chunk_size" in orchestrator_config
-                and self.chunk_size != orchestrator_config["chunk_size"]
-            ):
-                self.chunk_size = orchestrator_config["chunk_size"]
-                self.chunk_manager.chunk_size = self.chunk_size
-                updated_sections.append("chunk_size")
-
-            if (
-                "chunks_per_request" in orchestrator_config
-                and self.chunks_per_request != orchestrator_config["chunks_per_request"]
-            ):
-                self.chunks_per_request = orchestrator_config["chunks_per_request"]
-                updated_sections.append("chunks_per_request")
+                    # Update processor config
+                    self.config.update(orchestrator_config)
+
+                    # Reinitialize processor with new config
+                    processor_config = ProcessorConfig(
+                        processor_type=new_type, config=orchestrator_config
+                    )
+                    self.processor.initialize(processor_config)
+                    updated_sections.append("processor_config")
+
+            # Update units per request
+            if "units_per_request" in orchestrator_config:
+                self.units_per_request = orchestrator_config["units_per_request"]
+                updated_sections.append("units_per_request")
 
             # Update auth configuration
             if "auth" in orchestrator_config:
                 try:
-                    self.auth = AuthManager(
+                    self.auth = AuthManager(orchestrator_config["auth"])
                     updated_sections.append("auth")
                 except Exception as e:
                     logger.error(f"Failed to update AuthManager: {e}")
                     warnings.append(f"Auth update failed: {e}")
 
-            # Update buffer settings
-            if (
-                "chunk_buffer_multiplier" in orchestrator_config
-                and self.chunk_buffer_multiplier != orchestrator_config["chunk_buffer_multiplier"]
-            ):
-                self.chunk_buffer_multiplier = orchestrator_config["chunk_buffer_multiplier"]
-                updated_sections.append("chunk_buffer_multiplier")
-
-            if (
-                "min_chunk_buffer" in orchestrator_config
-                and self.min_chunk_buffer != orchestrator_config["min_chunk_buffer"]
-            ):
-                self.min_chunk_buffer = orchestrator_config["min_chunk_buffer"]
-                updated_sections.append("min_chunk_buffer")
-
             # Update storage settings
             if "storage" in orchestrator_config:
                 storage_config = orchestrator_config["storage"]
-                storage_changed = False
 
-                if
-                    "caption_buffer_size" in storage_config
-                    and self.storage.caption_buffer_size != storage_config["caption_buffer_size"]
-                ):
+                if "caption_buffer_size" in storage_config:
                     self.storage.caption_buffer_size = storage_config["caption_buffer_size"]
-
-
-                if "checkpoint_interval" in storage_config:
-                    current_interval = self.config.get("storage", {}).get(
-                        "checkpoint_interval", 1000
-                    )
-                    if current_interval != storage_config["checkpoint_interval"]:
-                        self.config.setdefault("storage", {})["checkpoint_interval"] = (
-                            storage_config["checkpoint_interval"]
-                        )
-                        storage_changed = True
+                    updated_sections.append("storage.caption_buffer_size")
 
-
-                    updated_sections.append("storage")
-
-            # Check if any changes were made
-            if not updated_sections:
-                await websocket.send(
-                    safe_json_dumps(
-                        {
-                            "type": "reload_complete",
-                            "message": "No changes applied - configuration is identical",
-                        }
-                    )
-                )
-                logger.info("Configuration reload requested but no changes detected")
-                return
-
-            # Update the main config
+            # Update main config
             if "orchestrator" in new_config:
-                self.config["orchestrator"]
+                self.config = new_config["orchestrator"]
             else:
                 self.config.update(orchestrator_config)
 
-            # Handle worker restart if needed
-            if requires_worker_restart:
-                logger.info("Disconnecting all workers for configuration reload...")
-
-                # Send reload message to workers first
-                reload_msg = safe_json_dumps(
-                    {
-                        "type": "reload_vllm",
-                        "vllm_config": self.vllm_config,
-                    }
-                )
-
-                # Create a list of worker items to avoid modifying dict during iteration
-                worker_items = list(self.workers.items())
-                disconnected = []
-
-                for worker_id, ws in worker_items:
-                    try:
-                        await ws.send(reload_msg)
-                        # Give worker time to process before disconnect
-                        await asyncio.sleep(0.5)
-                        await ws.close(code=1012, reason="Configuration reload")
-                        disconnected.append(worker_id)
-                    except:
-                        disconnected.append(worker_id)  # Still mark as disconnected if error
-
-                # Now safely clear workers dict
-                for worker_id in disconnected:
-                    if worker_id in self.workers:
-                        del self.workers[worker_id]
-
-                warnings.append(
-                    f"Sent reload message to {len(disconnected)} workers - they will reconnect with new config"
-                )
-            else:
-                # Just notify workers about config changes without disconnecting
-                config_update_msg = safe_json_dumps(
-                    {
-                        "type": "config_update",
-                        "vllm_config": self.vllm_config if "vllm" in updated_sections else None,
-                        "dataset_config": (
-                            self.dataset_config if "dataset" in updated_sections else None
-                        ),
-                    }
-                )
-
-                # Create a list of worker items to avoid modifying dict during iteration
-                worker_items = list(self.workers.items())
-                disconnected = []
-
-                for worker_id, ws in worker_items:
-                    try:
-                        await ws.send(config_update_msg)
-                        logger.info(f"Sent config update to worker {worker_id}")
-                    except:
-                        disconnected.append(worker_id)
-
-                # Now safely remove disconnected workers
-                for worker_id in disconnected:
-                    if worker_id in self.workers:
-                        del self.workers[worker_id]
-
             # Send success response
             await websocket.send(
                 safe_json_dumps(
@@ -1040,306 +320,169 @@ class Orchestrator:
             )
 
         logger.info(f"Configuration reloaded. Updated sections: {', '.join(updated_sections)}")
-
-        # Broadcast stats update to monitors
-        await self._broadcast_stats()
         await self._send_activity(
             f"Configuration reloaded by admin: {', '.join(updated_sections)}"
         )
 
         except Exception as e:
             logger.error(f"Configuration reload failed: {e}")
-            import traceback
-
-            logger.error(traceback.format_exc())
             await websocket.send(safe_json_dumps({"type": "reload_failed", "error": str(e)}))
 
-    async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
-        """Handle worker connection lifecycle."""
-        # Generate unique worker ID even if using same token
-        base_name = getattr(auth_ticket, "name", "worker")
-        worker_id = f"{base_name}_{str(uuid.uuid4())[:8]}"  # Add unique suffix
-
-        # Track the original token/user for accounting
-        worker_user = base_name  # Keep track of which user/token this worker belongs to
-
-        self.workers[worker_id] = websocket
-        self.stats["connected_workers"] = len(self.workers)
-
-        # Optionally track workers by user/token
-        if not hasattr(self, "workers_by_user"):
-            self.workers_by_user = defaultdict(set)
-        self.workers_by_user[worker_user].add(worker_id)
-
-        # Register contributor with the base name (for aggregating stats per user)
-        contributor = await self.storage.get_contributor(worker_user)
-        if not contributor:
-            contributor = Contributor(
-                contributor_id=worker_user,
-                name=worker_user,
-                total_captions=0,
-                trust_level=1,
-            )
-            await self.storage.save_contributor(contributor)
-
-        logger.info(f"Worker {worker_id} (user: {worker_user}) connected")
-        await self._broadcast_stats()
-        await self._send_activity(f"Worker {worker_id} (user: {worker_user}) connected")
-
-        try:
-            # Send welcome message with dataset configuration
-            welcome_message = {
-                "type": "welcome",
-                "worker_id": worker_id,
-                "user_id": worker_user,
-                "dataset_config": {
-                    "dataset_path": self.dataset_path,
-                    "dataset_type": self.dataset_type,
-                    "dataset_split": self.dataset_split,
-                    "dataset_image_column": self.dataset_image_column,
-                    "path": self.dataset_path,
-                    "type": self.dataset_type,
-                    "split": self.dataset_split,
-                    "image_column": self.dataset_image_column,
-                },
-                "vllm_config": self.vllm_config,
-            }
-            await websocket.send(safe_json_dumps(welcome_message))
-
-            async for message in websocket:
-                data = json.loads(message)
-                await self._process_worker_message(worker_id, data)
-
-        except websockets.exceptions.ConnectionClosed:
-            logger.info(f"Worker {worker_id} (user: {worker_user}) disconnected")
-        finally:
-            if worker_id in self.workers:
-                del self.workers[worker_id]
-
-            # Clean up user tracking
-            if hasattr(self, "workers_by_user") and worker_user in self.workers_by_user:
-                self.workers_by_user[worker_user].discard(worker_id)
-                if not self.workers_by_user[worker_user]:
-                    del self.workers_by_user[worker_user]
-
-            self.stats["connected_workers"] = len(self.workers)
-
-            # Release chunks
-            self.chunk_manager.release_worker_chunks(worker_id)
-            if self.chunk_tracker:
-                released_chunks = self.chunk_tracker.release_worker_chunks(worker_id)
-                logger.info(
-                    f"Released {len(released_chunks) if released_chunks is not None else 0} chunks from worker {worker_id}"
-                )
-
-            await self._broadcast_stats()
-            await self._send_activity(f"Worker {worker_id} (user: {worker_user}) disconnected")
-
     async def _process_worker_message(self, worker_id: str, data: Dict):
         """Process message from worker."""
         msg_type = data.get("type")
 
-        if msg_type == "
-
-
-
-
-
+        if msg_type == "request_work":
+            count = data.get("count", self.units_per_request)
+            units = self.processor.get_work_units(count, worker_id)
+            logger.debug(f"Assigning units: {[unit.chunk_id for unit in units]}")
+
+            if units:
+                # Create assignment
+                assignment = WorkAssignment(
+                    assignment_id=str(uuid.uuid4()),
+                    worker_id=worker_id,
+                    units=units,
+                    assigned_at=datetime.utcnow(),
                 )
-            return
-
-        count = data.get("count", self.chunks_per_request)
-        chunk_infos = self.chunk_manager.get_chunks_for_worker(
-            worker_id, count, self.chunk_tracker
-        )
-
-        if chunk_infos:
-            # Send chunks with unprocessed ranges
-            chunks_data = []
-            for info in chunk_infos:
-                chunk_dict = info["chunk"].to_dict()
-                chunk_dict["unprocessed_ranges"] = info["unprocessed_ranges"]
-                chunks_data.append(chunk_dict)
 
                 await self.workers[worker_id].send(
-                    safe_json_dumps({"type": "
+                    safe_json_dumps({"type": "work_assignment", "assignment": assignment.to_dict()})
                 )
 
-
-            logger.info(
-                f"Assigned {len(chunks_data)} chunks to worker {worker_id}: {chunk_ids}"
-            )
+                logger.debug(f"Assigned {len(units)} work units to worker {worker_id}")
             else:
-                await self.workers[worker_id].send(safe_json_dumps({"type": "
+                await self.workers[worker_id].send(safe_json_dumps({"type": "no_work"}))
 
-        elif msg_type == "
-
-
-
+        elif msg_type == "work_complete":
+            unit_id = data["unit_id"]
+            self.processor.mark_completed(unit_id, worker_id)
+            logger.debug(f"Work unit {unit_id} completed by worker {worker_id}")
 
-
-
-
-            logger.info(f"Chunk {chunk_id} completed by worker {worker_id}")
-            await self._check_shard_completion(chunk_id)
-            await self._send_activity(f"Chunk {chunk_id} completed by {worker_id}")
-        elif msg_type == "chunk_failed":
-            chunk_id = data["chunk_id"]
+        elif msg_type == "work_failed":
+            unit_id = data["unit_id"]
             error = data.get("error", "Unknown error")
-
-
-
-            if self.chunk_tracker:
-                self.chunk_tracker.mark_failed(chunk_id)
+            self.processor.mark_failed(unit_id, worker_id, error)
+            logger.warning(f"Work unit {unit_id} failed on worker {worker_id}: {error}")
 
-
-
+        elif msg_type == "submit_results":
+            await self._handle_results_submission(worker_id, data)
 
-        elif msg_type == "submit_captions":
-            await self._handle_captions_submission(worker_id, data)
-        elif msg_type == "request_job":
-            # CaptionWorker requesting a job from data samples
-            try:
-                job = await asyncio.wait_for(self.data_sample_queue.get(), timeout=5)
-                await self.workers[worker_id].send(
-                    json.dumps({"type": "job_assignment", "job": job})
-                )
-                logger.debug(f"Assigned job {job['job_id']} to worker {worker_id}")
-            except asyncio.TimeoutError:
-                await self.workers[worker_id].send(json.dumps({"type": "no_jobs"}))
         elif msg_type == "heartbeat":
-            # Update worker stats
             logger.debug(f"Heartbeat from {worker_id}: {data}")
 
-    async def
-        """Process
-
-        item_key = data["item_key"]
-
-        item_index = data.get("item_index")  # Worker should send this
-        if item_index is None:
-            # Try to extract from item_key (format: dataset_XXXXXXXX)
-            try:
-                item_index = int(item_key.split("_")[-1])
-            except:
-                logger.warning(f"Could not extract item index from key: {item_key}")
-
-        # Extract user from worker_id (format: "username_uuid")
+    async def _handle_results_submission(self, worker_id: str, data: Dict):
+        """Process results submission from worker."""
+        # Extract user from worker_id
         worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        logger.debug(
-            f"Received {len(captions_list)} captions for item {item_key} from worker {worker_id}"
-        )
+        # Create work result
+        _job_id = data.get("job_id")
+        job_id = JobId.from_str(_job_id)
+        shard_name = job_id.shard_id  # >data-0000<
+        chunk_name = job_id.chunk_id  # data-0000:chunk:>0<
+        # logger.debug(f"({job_id}) Worker result: {data}")
+        result = WorkResult(
+            unit_id=data["unit_id"],
+            source_id=shard_name,
+            chunk_id=job_id.get_chunk_str(),  # we want the full string here
+            sample_id=data["sample_id"],
+            dataset=data["dataset"],
+            outputs=data["outputs"],
+            metadata=data.get("metadata", {}),
+            processing_time_ms=data.get("processing_time_ms", 0),
+        )
 
-        #
+        # Let processor handle any custom processing
+        processed = self.processor.handle_result(result)
+
+        # Create caption record for storage
+        total_outputs = sum(len(v) for v in result.outputs.values())
+
+        filename = result.metadata.pop("_filename", None)
+        url = result.metadata.pop("_url", None)
+        image_height = result.metadata.pop("image_height", None)
+        image_width = result.metadata.pop("image_width", None)
+        file_size = result.metadata.pop("file_size", None)
+        image_format = result.metadata.pop("image_format", None)
+        item_index = result.metadata.pop("item_index", None)
+        item_key = result.metadata.pop("item_key", None)
+        to_delete_metadata_keys = ["_image_format", "_job_id"]
+        for key in to_delete_metadata_keys:
+            if key in result.metadata:
+                del result.metadata[key]
         caption = Caption(
-            job_id=
-            dataset=
-            shard=
+            job_id=job_id,
+            dataset=result.dataset,
+            shard=processed["source_id"],
+            chunk_id=chunk_name,
             item_key=item_key,
-            captions=
-            outputs=outputs,
+            captions=result.outputs.get("captions", []),
+            outputs=result.outputs,
             contributor_id=worker_user,
             timestamp=datetime.utcnow(),
-            quality_scores=None,
-            # Image metadata
-            image_width=data.get("image_width"),
-            image_height=data.get("image_height"),
-            image_format=data.get("image_format"),
-            file_size=data.get("file_size"),
-            # Processing metadata
             caption_count=total_outputs,
-            processing_time_ms=
-
-
+            processing_time_ms=result.processing_time_ms,
+            metadata=result.metadata,
+            image_height=image_height,
+            image_width=image_width,
+            filename=filename,
+            url=url,
+            file_size=file_size,
+            image_format=image_format,
         )
 
-        #
+        # Save to storage
         await self.storage.save_caption(caption)
 
-        #
-        should_flush = False
-        if chunk_id and item_index is not None and self.chunk_tracker:
-            with self.item_batch_lock:
-                self.pending_processed_items[chunk_id].append(item_index)
-
-                # Check if we should flush
-                total_pending = sum(
-                    len(indices) for indices in self.pending_processed_items.values()
-                )
-                time_since_flush = time.time() - self.last_item_batch_flush
-
-                if (
-                    total_pending >= self.item_batch_size
-                    or time_since_flush >= self.item_batch_interval
-                ):
-                    should_flush = True
-
-        if should_flush:
-            await self._flush_processed_items()
-
-        # Update contributor stats (use user, not worker)
+        # Update contributor stats
         contributor = await self.storage.get_contributor(worker_user)
|
1300
440
|
if contributor:
|
1301
441
|
contributor.total_captions += total_outputs
|
1302
442
|
await self.storage.save_contributor(contributor)
|
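The new _handle_results_submission path pops the per-image fields out of result.metadata before building the Caption record, and drops a couple of internal keys entirely. The sketch below restates that split in isolation; the field names come straight from the diff, but the helper itself and the sample values are illustrative, not part of the package.

from typing import Any, Dict, Tuple

def split_result_metadata(metadata: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pop per-image fields out of the free-form metadata, as the handler above does."""
    image_info = {
        "filename": metadata.pop("_filename", None),
        "url": metadata.pop("_url", None),
        "image_height": metadata.pop("image_height", None),
        "image_width": metadata.pop("image_width", None),
        "file_size": metadata.pop("file_size", None),
        "image_format": metadata.pop("image_format", None),
        "item_index": metadata.pop("item_index", None),
        "item_key": metadata.pop("item_key", None),
    }
    for key in ("_image_format", "_job_id"):   # internal keys are dropped entirely
        metadata.pop(key, None)
    return image_info, metadata

# Example with made-up worker metadata:
image_info, extra = split_result_metadata({
    "_filename": "00000042.jpg",
    "image_width": 512,
    "image_height": 512,
    "_job_id": "data-0000:chunk:0:idx:42",   # hypothetical job id string
    "aesthetic_score": 6.1,
})
assert image_info["filename"] == "00000042.jpg" and "aesthetic_score" in extra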
1303
443
|
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
total_outputs = self.stats.get("total_outputs", 0)
|
1309
|
-
if total_outputs > 0 and total_outputs % 100 == 0:
|
1310
|
-
if (
|
1311
|
-
not hasattr(self, "_last_logged_outputs")
|
1312
|
-
or self._last_logged_outputs != total_outputs
|
1313
|
-
):
|
1314
|
-
logger.info(f"Collected {total_outputs} outputs centrally")
|
1315
|
-
self._last_logged_outputs = total_outputs
|
1316
|
-
|
1317
|
-
async def _check_shard_completion(self, chunk_id: str):
|
1318
|
-
"""Check if a shard is complete after chunk completion."""
|
1319
|
-
# Get the chunk
|
1320
|
-
chunk = self.chunk_manager.chunks.get(chunk_id)
|
1321
|
-
if not chunk:
|
1322
|
-
return
|
444
|
+
async def _handle_monitor(self, websocket: WebSocketServerProtocol):
|
445
|
+
"""Handle monitor connection."""
|
446
|
+
self.monitors.add(websocket)
|
447
|
+
logger.info(f"Monitor connected (total: {len(self.monitors)})")
|
1323
448
|
|
1324
|
-
|
449
|
+
try:
|
450
|
+
# Send welcome
|
451
|
+
await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
|
1325
452
|
|
1326
|
-
|
1327
|
-
|
1328
|
-
cid for cid, c in self.chunk_manager.chunks.items() if c.belongs_to_shard(shard_name)
|
1329
|
-
]
|
453
|
+
# Send initial stats
|
454
|
+
await self._send_monitor_stats(websocket)
|
1330
455
|
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
]
|
456
|
+
# Keep connection alive
|
457
|
+
async for message in websocket:
|
458
|
+
pass
|
1335
459
|
|
1336
|
-
|
1337
|
-
logger.info(
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
460
|
+
except websockets.exceptions.ConnectionClosed:
|
461
|
+
logger.info("Monitor disconnected")
|
462
|
+
finally:
|
463
|
+
self.monitors.discard(websocket)
|
464
|
+
|
465
|
+
async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
|
466
|
+
"""Handle admin connection."""
|
467
|
+
admin_id = getattr(auth_ticket, "name", "admin")
|
468
|
+
logger.info(f"Admin {admin_id} connected")
|
469
|
+
|
470
|
+
try:
|
471
|
+
await websocket.send(safe_json_dumps({"type": "welcome", "role": "admin"}))
|
472
|
+
|
473
|
+
async for message in websocket:
|
474
|
+
try:
|
475
|
+
data = json.loads(message)
|
476
|
+
if data.get("type") == "reload_config":
|
477
|
+
await self._handle_config_reload(websocket, data.get("config", {}))
|
478
|
+
elif data.get("type") == "get_stats":
|
479
|
+
await self._send_monitor_stats(websocket)
|
480
|
+
|
481
|
+
except json.JSONDecodeError as e:
|
482
|
+
logger.error(f"Invalid admin message: {e}")
|
483
|
+
|
484
|
+
except websockets.exceptions.ConnectionClosed:
|
485
|
+
logger.info(f"Admin {admin_id} disconnected")
|
1343
486
|
|
1344
487
|
async def _handle_data_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
|
1345
488
|
"""Handle data worker connection."""
|
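The monitor and admin handlers above speak the same JSON protocol from the other side: the orchestrator sends a welcome message plus stats, and an admin connection may send "get_stats" or "reload_config". The following is a minimal client sketch, assuming a placeholder URL and omitting whatever authentication handshake the orchestrator performs before routing a connection to _handle_admin; it is not part of the package.

import asyncio
import json
import websockets  # same client library the orchestrator already depends on

async def admin_session(url: str = "wss://localhost:8765"):
    """Hedged sketch of an admin client for the handler above; URL and auth are placeholders."""
    async with websockets.connect(url) as ws:
        welcome = json.loads(await ws.recv())           # {"type": "welcome", "role": "admin"}
        print("connected as", welcome.get("role"))

        await ws.send(json.dumps({"type": "get_stats"}))   # triggers _send_monitor_stats
        stats = json.loads(await ws.recv())
        print("stats keys:", sorted(stats.get("data", {}))[:5])

        # reload_config carries a config payload; its contents here are illustrative only
        await ws.send(json.dumps({"type": "reload_config", "config": {}}))

# asyncio.run(admin_session())  # requires a running orchestrator and valid credentials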
@@ -1410,7 +553,64 @@ class Orchestrator:
|
|
1410
553
|
finally:
|
1411
554
|
del self.data_workers[worker_id]
|
1412
555
|
|
1413
|
-
async def
|
556
|
+
async def _send_monitor_initial_data(self, websocket: WebSocketServerProtocol):
|
557
|
+
"""Send initial data to monitor in a separate task to avoid blocking."""
|
558
|
+
total_start = time.time()
|
559
|
+
try:
|
560
|
+
# Check if websocket is still in monitors set
|
561
|
+
if websocket not in self.monitors:
|
562
|
+
logger.debug("Monitor disconnected before initial data send")
|
563
|
+
return
|
564
|
+
|
565
|
+
# Send current stats (already in memory)
|
566
|
+
stats_start = time.time()
|
567
|
+
await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
|
568
|
+
logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
|
569
|
+
|
570
|
+
# Get processor stats instead of chunk stats
|
571
|
+
processor_stats_start = time.time()
|
572
|
+
processor_stats = self.processor.get_stats()
|
573
|
+
logger.debug(
|
574
|
+
f"Processor stats retrieved in {(time.time() - processor_stats_start)*1000:.1f}ms"
|
575
|
+
)
|
576
|
+
|
577
|
+
stats_send_start = time.time()
|
578
|
+
await websocket.send(
|
579
|
+
safe_json_dumps({"type": "processor_stats", "data": processor_stats})
|
580
|
+
)
|
581
|
+
logger.debug(f"Processor stats sent in {(time.time() - stats_send_start)*1000:.1f}ms")
|
582
|
+
|
583
|
+
if websocket not in self.monitors:
|
584
|
+
return
|
585
|
+
|
586
|
+
# For leaderboard, check if we have a cached version first
|
587
|
+
if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
|
588
|
+
# Use cached leaderboard if available
|
589
|
+
cache_send_start = time.time()
|
590
|
+
await websocket.send(
|
591
|
+
safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
|
592
|
+
)
|
593
|
+
logger.debug(
|
594
|
+
f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
|
595
|
+
)
|
596
|
+
else:
|
597
|
+
# Schedule leaderboard update separately
|
598
|
+
leaderboard_task_start = time.time()
|
599
|
+
asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
|
600
|
+
logger.debug(
|
601
|
+
f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
|
602
|
+
)
|
603
|
+
|
604
|
+
logger.debug(
|
605
|
+
f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
|
606
|
+
)
|
607
|
+
|
608
|
+
except websockets.exceptions.ConnectionClosed:
|
609
|
+
logger.debug("Monitor disconnected during initial data send")
|
610
|
+
except Exception as e:
|
611
|
+
logger.error(f"Error sending initial monitor data: {e}")
|
612
|
+
|
613
|
+
async def _send_monitor_leaderboard(self, websocket: WebSocketServerProtocol):
|
1414
614
|
"""Send leaderboard data to a specific monitor."""
|
1415
615
|
total_start = time.time()
|
1416
616
|
try:
|
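_send_monitor_initial_data above wraps each send in time.time() bookkeeping and logs the elapsed milliseconds. The helper below is not in the package; it is just a compact restatement of that timing pattern as a context manager, in case the same instrumentation is wanted elsewhere.

import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)

@contextmanager
def log_duration(label: str):
    """Log how long the wrapped block took, in the same ms format used above."""
    start = time.time()
    try:
        yield
    finally:
        logger.debug(f"{label} in {(time.time() - start) * 1000:.1f}ms")

# with log_duration("Monitor stats sent"):
#     await websocket.send(stats_message)   # inside an async function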
@@ -1475,210 +675,43 @@ class Orchestrator:
|
|
1475
675
|
except Exception as e:
|
1476
676
|
logger.error(f"Error sending leaderboard to monitor: {e}")
|
1477
677
|
|
1478
|
-
async def
|
1479
|
-
"""Send
|
1480
|
-
|
1481
|
-
|
1482
|
-
# Check if websocket is still in monitors set
|
1483
|
-
if websocket not in self.monitors:
|
1484
|
-
logger.debug("Monitor disconnected before initial data send")
|
1485
|
-
return
|
1486
|
-
|
1487
|
-
# Send current stats (already in memory)
|
1488
|
-
stats_start = time.time()
|
1489
|
-
await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
|
1490
|
-
logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
|
1491
|
-
|
1492
|
-
# Get chunk stats asynchronously
|
1493
|
-
chunk_stats_start = time.time()
|
1494
|
-
loop = asyncio.get_event_loop()
|
1495
|
-
chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
|
1496
|
-
logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
|
1497
|
-
|
1498
|
-
if websocket not in self.monitors:
|
1499
|
-
return
|
1500
|
-
|
1501
|
-
chunk_send_start = time.time()
|
1502
|
-
await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
|
1503
|
-
logger.debug(f"Chunk stats sent in {(time.time() - chunk_send_start)*1000:.1f}ms")
|
1504
|
-
|
1505
|
-
# For leaderboard, check if we have a cached version first
|
1506
|
-
if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
|
1507
|
-
# Use cached leaderboard if available
|
1508
|
-
cache_send_start = time.time()
|
1509
|
-
await websocket.send(
|
1510
|
-
safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
|
1511
|
-
)
|
1512
|
-
logger.debug(
|
1513
|
-
f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
|
1514
|
-
)
|
1515
|
-
else:
|
1516
|
-
# Schedule leaderboard update separately
|
1517
|
-
leaderboard_task_start = time.time()
|
1518
|
-
asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
|
1519
|
-
logger.debug(
|
1520
|
-
f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
|
1521
|
-
)
|
1522
|
-
|
1523
|
-
logger.debug(
|
1524
|
-
f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
|
1525
|
-
)
|
1526
|
-
|
1527
|
-
except websockets.exceptions.ConnectionClosed:
|
1528
|
-
logger.debug("Monitor disconnected during initial data send")
|
1529
|
-
except Exception as e:
|
1530
|
-
logger.error(f"Error sending initial monitor data: {e}")
|
678
|
+
async def _send_monitor_stats(self, websocket: WebSocketServerProtocol):
|
679
|
+
"""Send current stats to a monitor."""
|
680
|
+
# Get processor stats
|
681
|
+
processor_stats = self.processor.get_stats()
|
1531
682
|
|
1532
|
-
|
1533
|
-
|
1534
|
-
monitor_start = time.time()
|
1535
|
-
self.monitors.add(websocket)
|
1536
|
-
logger.info(f"Monitor connected (total monitors: {len(self.monitors)})")
|
1537
|
-
|
1538
|
-
try:
|
1539
|
-
# Send welcome message immediately
|
1540
|
-
welcome_start = time.time()
|
1541
|
-
await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
|
1542
|
-
logger.debug(f"Monitor welcome sent in {(time.time() - welcome_start)*1000:.1f}ms")
|
683
|
+
# Get storage stats
|
684
|
+
storage_stats = await self.storage.get_storage_stats()
|
1543
685
|
|
1544
|
-
|
1545
|
-
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
686
|
+
# Combine all stats
|
687
|
+
all_stats = {
|
688
|
+
**self.stats,
|
689
|
+
**storage_stats,
|
690
|
+
"processor_stats": processor_stats,
|
691
|
+
"current_rate": self.rate_tracker["current_rate"],
|
692
|
+
"average_rate": self.rate_tracker["average_rate"],
|
693
|
+
}
|
1550
694
|
|
1551
|
-
|
1552
|
-
try:
|
1553
|
-
async for message in websocket:
|
1554
|
-
# Handle any incoming messages from monitor if needed
|
1555
|
-
# For now, just ignore them
|
1556
|
-
pass
|
1557
|
-
except websockets.exceptions.ConnectionClosed:
|
1558
|
-
pass # Normal disconnection
|
695
|
+
await websocket.send(safe_json_dumps({"type": "stats", "data": all_stats}))
|
1559
696
|
|
1560
|
-
|
1561
|
-
|
1562
|
-
except Exception as e:
|
1563
|
-
logger.error(f"Error in monitor handler: {e}")
|
1564
|
-
finally:
|
1565
|
-
self.monitors.discard(websocket)
|
1566
|
-
logger.debug(f"Monitor handler completed in {(time.time() - monitor_start)*1000:.1f}ms")
|
1567
|
-
|
1568
|
-
async def _broadcast_stats(self):
|
1569
|
-
"""Broadcast statistics to all monitors - truly non-blocking version."""
|
697
|
+
async def _send_activity(self, activity: str):
|
698
|
+
"""Send activity update to monitors."""
|
1570
699
|
if not self.monitors:
|
1571
700
|
return
|
1572
|
-
if self.is_generating_stats:
|
1573
|
-
return # Already generating stats, skip this call
|
1574
|
-
self.is_generating_stats = True
|
1575
|
-
total_start = time.time()
|
1576
|
-
|
1577
|
-
# Prepare all the data first
|
1578
|
-
data_prep_start = time.time()
|
1579
|
-
loop = asyncio.get_event_loop()
|
1580
701
|
|
1581
|
-
|
1582
|
-
|
1583
|
-
storage_stats = await self.storage.get_storage_stats()
|
1584
|
-
logger.debug(f"Storage stats retrieved in {(time.time() - storage_stats_start)*1000:.1f}ms")
|
1585
|
-
|
1586
|
-
caption_stats_start = time.time()
|
1587
|
-
caption_stats = await self.storage.get_caption_stats()
|
1588
|
-
logger.debug(f"Caption stats retrieved in {(time.time() - caption_stats_start)*1000:.1f}ms")
|
1589
|
-
|
1590
|
-
# Get chunk stats in thread pool
|
1591
|
-
chunk_stats_start = time.time()
|
1592
|
-
chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
|
1593
|
-
logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
|
1594
|
-
|
1595
|
-
# Build stats dict
|
1596
|
-
build_stats_start = time.time()
|
1597
|
-
stats_update = self.stats.copy()
|
1598
|
-
stats_update.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
|
1599
|
-
stats_update.update(storage_stats)
|
1600
|
-
stats_update["field_breakdown"] = caption_stats.get("field_stats", {})
|
1601
|
-
stats_update["output_fields_list"] = caption_stats.get("output_fields", [])
|
1602
|
-
|
1603
|
-
# Add rate information
|
1604
|
-
stats_update.update(
|
1605
|
-
{
|
1606
|
-
"current_rate": self.rate_tracker["current_rate"],
|
1607
|
-
"average_rate": self.rate_tracker["average_rate"],
|
1608
|
-
"expected_rate": self.rate_tracker["expected_rate"],
|
1609
|
-
}
|
702
|
+
message = safe_json_dumps(
|
703
|
+
{"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
|
1610
704
|
)
|
1611
705
|
|
1612
|
-
|
1613
|
-
|
1614
|
-
stats_update["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
|
1615
|
-
|
1616
|
-
# Add stage information
|
1617
|
-
stages = self.vllm_config.get("stages", [])
|
1618
|
-
if stages:
|
1619
|
-
stats_update["stage_count"] = len(stages)
|
1620
|
-
stats_update["stage_names"] = [s.get("name", "unnamed") for s in stages]
|
1621
|
-
else:
|
1622
|
-
stats_update["stage_count"] = 1
|
1623
|
-
stats_update["stage_names"] = ["default"]
|
1624
|
-
|
1625
|
-
# Get field stats
|
1626
|
-
field_stats_start = time.time()
|
1627
|
-
field_stats = await self.storage.get_output_field_stats()
|
1628
|
-
stats_update["output_fields"] = field_stats
|
1629
|
-
logger.debug(f"Field stats retrieved in {(time.time() - field_stats_start)*1000:.1f}ms")
|
1630
|
-
|
1631
|
-
# Update our internal stats
|
1632
|
-
self.stats = stats_update
|
1633
|
-
logger.debug(f"Stats prepared in {(time.time() - build_stats_start)*1000:.1f}ms")
|
1634
|
-
|
1635
|
-
logger.debug(f"Total data preparation took {(time.time() - data_prep_start)*1000:.1f}ms")
|
1636
|
-
|
1637
|
-
# Create message once
|
1638
|
-
message_create_start = time.time()
|
1639
|
-
stats_message = safe_json_dumps({"type": "stats", "data": self.stats})
|
1640
|
-
logger.debug(f"Stats message created in {(time.time() - message_create_start)*1000:.1f}ms")
|
1641
|
-
|
1642
|
-
# Send to all monitors asynchronously in parallel
|
1643
|
-
send_start = time.time()
|
1644
|
-
|
1645
|
-
async def send_to_monitor(monitor):
|
706
|
+
disconnected = set()
|
707
|
+
for monitor in self.monitors:
|
1646
708
|
try:
|
1647
|
-
await monitor.send(
|
709
|
+
await monitor.send(message)
|
1648
710
|
except websockets.exceptions.ConnectionClosed:
|
1649
|
-
|
1650
|
-
except Exception as e:
|
1651
|
-
logger.debug(f"Error sending stats to monitor: {e}")
|
1652
|
-
return monitor # Return for removal
|
1653
|
-
return None
|
1654
|
-
|
1655
|
-
# Send to all monitors in parallel
|
1656
|
-
monitors_copy = self.monitors.copy()
|
1657
|
-
results = await asyncio.gather(
|
1658
|
-
*[send_to_monitor(m) for m in monitors_copy], return_exceptions=True
|
1659
|
-
)
|
711
|
+
disconnected.add(monitor)
|
1660
712
|
|
1661
|
-
# Remove disconnected monitors
|
1662
|
-
disconnected = {
|
1663
|
-
m
|
1664
|
-
for m, r in zip(monitors_copy, results)
|
1665
|
-
if r is not None and not isinstance(r, Exception)
|
1666
|
-
}
|
1667
713
|
self.monitors -= disconnected
|
1668
714
|
|
1669
|
-
logger.debug(
|
1670
|
-
f"Stats sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
|
1671
|
-
)
|
1672
|
-
|
1673
|
-
# Send leaderboard update in a separate task to avoid blocking
|
1674
|
-
leaderboard_task_start = time.time()
|
1675
|
-
asyncio.create_task(self._broadcast_leaderboard())
|
1676
|
-
self.is_generating_stats = False
|
1677
|
-
logger.debug(
|
1678
|
-
f"Leaderboard broadcast task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
|
1679
|
-
)
|
1680
|
-
logger.debug(f"Stats broadcast completed in {(time.time() - total_start)*1000:.1f}ms")
|
1681
|
-
|
1682
715
|
async def _broadcast_leaderboard(self):
|
1683
716
|
"""Send leaderboard updates to monitors - separate from stats to avoid blocking."""
|
1684
717
|
if not self.monitors:
|
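The new _send_activity shown earlier in this hunk timestamps each activity line and broadcasts it to every monitor, pruning sockets that raise ConnectionClosed. A small sketch of the payload it builds, using plain json.dumps as a stand-in for safe_json_dumps and an invented activity string:

import json
from datetime import datetime

def format_activity(activity: str) -> str:
    """Build the activity payload the way _send_activity above does."""
    return json.dumps(
        {"type": "activity", "data": f"[{datetime.now().strftime('%H:%M:%S')}] {activity}"}
    )

print(format_activity("Worker alice_1234 connected"))
# e.g. {"type": "activity", "data": "[14:02:11] Worker alice_1234 connected"}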
@@ -1769,82 +802,38 @@ class Orchestrator:
|
|
1769
802
|
except Exception as e:
|
1770
803
|
logger.error(f"Error broadcasting leaderboard: {e}")
|
1771
804
|
|
1772
|
-
def
|
1773
|
-
"""
|
1774
|
-
with self.chunk_manager.lock:
|
1775
|
-
return {
|
1776
|
-
"pending_chunks": len(self.chunk_manager.pending_chunks),
|
1777
|
-
"assigned_chunks": sum(
|
1778
|
-
len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
|
1779
|
-
),
|
1780
|
-
}
|
1781
|
-
|
1782
|
-
async def _flush_processed_items(self):
|
1783
|
-
"""Flush batched processed items to chunk tracker."""
|
1784
|
-
with self.item_batch_lock:
|
1785
|
-
if not self.pending_processed_items:
|
1786
|
-
return
|
1787
|
-
|
1788
|
-
for chunk_id, indices in self.pending_processed_items.items():
|
1789
|
-
if not indices:
|
1790
|
-
continue
|
1791
|
-
|
1792
|
-
# Indices here are ABSOLUTE dataset indices
|
1793
|
-
# Sort indices
|
1794
|
-
indices.sort()
|
1795
|
-
|
1796
|
-
# Group consecutive indices into ranges
|
1797
|
-
ranges = []
|
1798
|
-
start = indices[0]
|
1799
|
-
end = indices[0]
|
1800
|
-
|
1801
|
-
for i in range(1, len(indices)):
|
1802
|
-
if indices[i] == end + 1:
|
1803
|
-
# Consecutive, extend range
|
1804
|
-
end = indices[i]
|
1805
|
-
else:
|
1806
|
-
# Gap found, save current range and start new one
|
1807
|
-
ranges.append((start, end))
|
1808
|
-
start = indices[i]
|
1809
|
-
end = indices[i]
|
1810
|
-
|
1811
|
-
# Don't forget the last range
|
1812
|
-
ranges.append((start, end))
|
1813
|
-
|
1814
|
-
# Mark ranges as processed (mark_items_processed expects absolute indices)
|
1815
|
-
for start_idx, end_idx in ranges:
|
1816
|
-
self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
|
1817
|
-
|
1818
|
-
# Clear pending items
|
1819
|
-
self.pending_processed_items.clear()
|
1820
|
-
self.last_item_batch_flush = time.time()
|
1821
|
-
|
1822
|
-
def get_workers_by_user_stats(self) -> Dict[str, Any]:
|
1823
|
-
"""Get statistics about workers grouped by user/token - thread-safe version."""
|
1824
|
-
if not hasattr(self, "workers_by_user"):
|
1825
|
-
return {}
|
1826
|
-
|
1827
|
-
# Create a copy to avoid issues with concurrent modification
|
1828
|
-
stats = {}
|
1829
|
-
workers_snapshot = dict(self.workers_by_user)
|
1830
|
-
for user, worker_ids in workers_snapshot.items():
|
1831
|
-
stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
|
1832
|
-
return stats
|
1833
|
-
|
1834
|
-
async def _send_activity(self, activity: str):
|
1835
|
-
"""Send activity update to monitors."""
|
805
|
+
async def _broadcast_stats(self):
|
806
|
+
"""Broadcast statistics to all monitors."""
|
1836
807
|
if not self.monitors:
|
1837
808
|
return
|
1838
809
|
|
1839
|
-
|
1840
|
-
|
810
|
+
# Get current stats
|
811
|
+
processor_stats = self.processor.get_stats()
|
812
|
+
storage_stats = await self.storage.get_storage_stats()
|
813
|
+
|
814
|
+
# Update main stats
|
815
|
+
self.stats["processor_stats"] = processor_stats
|
816
|
+
self.stats["total_outputs"] = storage_stats["total_captions"]
|
817
|
+
|
818
|
+
# Create message
|
819
|
+
stats_message = safe_json_dumps(
|
820
|
+
{
|
821
|
+
"type": "stats",
|
822
|
+
"data": {
|
823
|
+
**self.stats,
|
824
|
+
**storage_stats,
|
825
|
+
"current_rate": self.rate_tracker["current_rate"],
|
826
|
+
"average_rate": self.rate_tracker["average_rate"],
|
827
|
+
},
|
828
|
+
}
|
1841
829
|
)
|
1842
830
|
|
831
|
+
# Send to all monitors
|
1843
832
|
disconnected = set()
|
1844
833
|
for monitor in self.monitors:
|
1845
834
|
try:
|
1846
|
-
await monitor.send(
|
1847
|
-
except
|
835
|
+
await monitor.send(stats_message)
|
836
|
+
except:
|
1848
837
|
disconnected.add(monitor)
|
1849
838
|
|
1850
839
|
self.monitors -= disconnected
|
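One detail of the stats payload built above: it is assembled by dict unpacking, so later sources win on key collisions. The toy example below uses invented key names and values purely to show that override order; it is not taken from the package.

# Later entries win on key collisions when the stats payload above is built.
base_stats = {"connected_workers": 4, "total_outputs": 900}       # illustrative values
storage_stats = {"total_captions": 12345, "total_outputs": 1000}
payload = {**base_stats, **storage_stats, "current_rate": 210.0, "average_rate": 180.0}
assert payload["total_outputs"] == 1000   # storage-level value overrides the in-memory one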
@@ -1852,235 +841,74 @@ class Orchestrator:
|
|
1852
841
|
async def _heartbeat_loop(self):
|
1853
842
|
"""Send periodic heartbeats to maintain connections."""
|
1854
843
|
while True:
|
1855
|
-
|
1856
|
-
|
1857
|
-
|
1858
|
-
|
1859
|
-
|
1860
|
-
|
1861
|
-
|
1862
|
-
|
1863
|
-
|
1864
|
-
|
1865
|
-
|
1866
|
-
|
1867
|
-
|
1868
|
-
|
1869
|
-
|
1870
|
-
|
1871
|
-
|
1872
|
-
|
1873
|
-
|
1874
|
-
|
1875
|
-
except websockets.exceptions.ConnectionClosed:
|
1876
|
-
logger.info(f"Worker {worker_id} connection already closed")
|
1877
|
-
disconnected.append(worker_id)
|
1878
|
-
except Exception as e:
|
1879
|
-
logger.error(f"Error pinging worker {worker_id}: {e}")
|
1880
|
-
disconnected.append(worker_id)
|
1881
|
-
|
1882
|
-
# Clean up disconnected workers
|
1883
|
-
for worker_id in disconnected:
|
1884
|
-
if worker_id in self.workers:
|
1885
|
-
logger.info(f"Removing unresponsive worker {worker_id}")
|
1886
|
-
del self.workers[worker_id]
|
1887
|
-
self.chunk_manager.release_worker_chunks(worker_id)
|
1888
|
-
|
1889
|
-
# Update stats
|
1890
|
-
self.stats["connected_workers"] = len(self.workers)
|
1891
|
-
|
1892
|
-
# Also clean up from workers_by_user if it exists
|
1893
|
-
if hasattr(self, "workers_by_user"):
|
1894
|
-
worker_user = (
|
1895
|
-
worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
|
1896
|
-
)
|
1897
|
-
if worker_user in self.workers_by_user:
|
1898
|
-
self.workers_by_user[worker_user].discard(worker_id)
|
1899
|
-
if not self.workers_by_user[worker_user]:
|
1900
|
-
del self.workers_by_user[worker_user]
|
1901
|
-
|
1902
|
-
# Notify monitors
|
1903
|
-
await self._broadcast_stats()
|
1904
|
-
await self._send_activity(
|
1905
|
-
f"Worker {worker_id} removed due to heartbeat timeout"
|
1906
|
-
)
|
1907
|
-
|
1908
|
-
except Exception as e:
|
1909
|
-
logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
|
1910
|
-
# Continue the loop even if there's an error
|
1911
|
-
await asyncio.sleep(5)
|
844
|
+
await asyncio.sleep(30)
|
845
|
+
|
846
|
+
disconnected = []
|
847
|
+
for worker_id, ws in list(self.workers.items()):
|
848
|
+
try:
|
849
|
+
pong_waiter = await ws.ping()
|
850
|
+
await asyncio.wait_for(pong_waiter, timeout=10)
|
851
|
+
except:
|
852
|
+
disconnected.append(worker_id)
|
853
|
+
|
854
|
+
# Clean up disconnected workers
|
855
|
+
for worker_id in disconnected:
|
856
|
+
logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
|
857
|
+
if worker_id in self.workers:
|
858
|
+
del self.workers[worker_id]
|
859
|
+
logger.warning(
|
860
|
+
f"Releasing assignments for worker {worker_id} because it did not respond to ping"
|
861
|
+
)
|
862
|
+
self.processor.release_assignments(worker_id)
|
863
|
+
self.stats["connected_workers"] = len(self.workers)
|
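The simplified heartbeat loop above pings every worker socket, waits up to ten seconds for the pong, and hands the assignments of unresponsive workers back to the processor. A standalone sketch of that pruning step, with `workers` and `processor` as stand-ins for the orchestrator's attributes:

import asyncio

async def prune_unresponsive(workers: dict, processor) -> list:
    """Ping each worker websocket and drop those that miss the 10 s pong window,
    mirroring the heartbeat loop above."""
    disconnected = []
    for worker_id, ws in list(workers.items()):
        try:
            pong_waiter = await ws.ping()                 # websockets returns a pong waiter
            await asyncio.wait_for(pong_waiter, timeout=10)
        except Exception:
            disconnected.append(worker_id)
    for worker_id in disconnected:
        workers.pop(worker_id, None)
        processor.release_assignments(worker_id)          # hand their work units back
    return disconnected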
1912
864
|
|
1913
865
|
async def _checkpoint_loop(self):
|
1914
866
|
"""Periodically checkpoint storage."""
|
1915
|
-
interval = self.config.get("storage", {}).get("checkpoint_interval",
|
867
|
+
interval = self.config.get("storage", {}).get("checkpoint_interval", 60)
|
1916
868
|
|
1917
869
|
while True:
|
1918
|
-
await asyncio.sleep(
|
870
|
+
await asyncio.sleep(interval)
|
1919
871
|
|
1920
|
-
|
1921
|
-
|
1922
|
-
|
1923
|
-
|
1924
|
-
# Force checkpoint at regular intervals
|
1925
|
-
if total_captions > 0 and total_captions % interval == 0:
|
1926
|
-
logger.info(f"Triggering checkpoint at {total_captions} captions")
|
1927
|
-
await self.storage.checkpoint()
|
1928
|
-
|
1929
|
-
# Update stats
|
1930
|
-
self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
|
1931
|
-
# No need to update total_written or buffer_size - they come from storage
|
1932
|
-
|
1933
|
-
await self._broadcast_stats()
|
1934
|
-
logger.info(
|
1935
|
-
f"Checkpoint complete. Total written to disk: {storage_stats['total_written']}"
|
1936
|
-
)
|
872
|
+
await self.storage.checkpoint()
|
873
|
+
self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
|
874
|
+
logger.info("Storage checkpoint complete")
|
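The checkpoint cadence is read from the orchestrator config, falling back to 60 seconds when the key is absent. A minimal sketch of the lookup, with the nesting taken from the line above and everything else assumed:

config = {"storage": {"checkpoint_interval": 60}}   # illustrative config fragment
interval = config.get("storage", {}).get("checkpoint_interval", 60)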
1937
875
|
|
1938
876
|
async def _stats_update_loop(self):
|
1939
|
-
"""Periodically update and broadcast stats
|
1940
|
-
# Get the event loop for running blocking operations
|
1941
|
-
loop = asyncio.get_event_loop()
|
1942
|
-
|
1943
|
-
# Track session start values
|
1944
|
-
storage_stats = await self.storage.get_storage_stats()
|
1945
|
-
session_start_outputs = storage_stats["total_captions"] # This now counts ALL outputs
|
1946
|
-
session_start_time = time.time()
|
1947
|
-
|
1948
|
-
# Track the last known total to detect flushes
|
1949
|
-
last_known_total = session_start_outputs
|
1950
|
-
|
877
|
+
"""Periodically update and broadcast stats."""
|
1951
878
|
while True:
|
1952
879
|
await asyncio.sleep(10)
|
1953
880
|
|
1954
|
-
# Update
|
1955
|
-
chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
|
881
|
+
# Update rate tracking
|
1956
882
|
storage_stats = await self.storage.get_storage_stats()
|
1957
|
-
|
1958
|
-
if self.chunk_tracker:
|
1959
|
-
await self._flush_processed_items()
|
1960
|
-
|
1961
|
-
self.stats["total_chunks"] = chunk_stats["total"]
|
1962
|
-
self.stats["completed_chunks"] = chunk_stats["completed"]
|
1963
|
-
self.stats["failed_chunks"] = chunk_stats["failed"]
|
1964
|
-
|
1965
|
-
# Update total outputs stat (rename from total_captions for clarity)
|
1966
|
-
self.stats["total_outputs"] = current_total_outputs
|
1967
|
-
self.stats["total_captions"] = current_total_outputs # Keep for backward compatibility
|
1968
|
-
|
1969
|
-
# Get queue stats in thread pool to avoid blocking
|
1970
|
-
queue_stats = await loop.run_in_executor(None, self._get_queue_stats)
|
1971
|
-
self.stats.update(queue_stats)
|
1972
|
-
|
1973
|
-
# Calculate if we need more chunks
|
1974
|
-
worker_count = self.stats.get("connected_workers", 0)
|
1975
|
-
target_buffer = max(self.min_chunk_buffer, worker_count * self.chunk_buffer_multiplier)
|
1976
|
-
active_chunks = self.stats["pending_chunks"] + self.stats["assigned_chunks"]
|
1977
|
-
self.stats["chunk_buffer_status"] = f"{active_chunks}/{target_buffer}"
|
1978
|
-
|
1979
|
-
# Update rate information
|
883
|
+
current_total = storage_stats["total_captions"]
|
1980
884
|
current_time = time.time()
|
1981
|
-
elapsed_since_update = current_time - self.rate_tracker["last_update_time"]
|
1982
|
-
|
1983
|
-
if elapsed_since_update > 0:
|
1984
|
-
# FIX: Handle the case where duplicates were skipped during save
|
1985
|
-
# If current total is less than last known, it means duplicates were skipped
|
1986
|
-
# We should not count this as negative progress
|
1987
|
-
if current_total_outputs < last_known_total:
|
1988
|
-
logger.debug(
|
1989
|
-
f"Detected duplicate skip during save: {last_known_total} -> {current_total_outputs}"
|
1990
|
-
)
|
1991
|
-
# Don't calculate negative rate, just update the baseline
|
1992
|
-
self.rate_tracker["last_caption_count"] = current_total_outputs
|
1993
|
-
self.rate_tracker["current_rate"] = 0.0 # Set to 0 during flush
|
1994
|
-
else:
|
1995
|
-
# Normal rate calculation
|
1996
|
-
output_diff = current_total_outputs - self.rate_tracker["last_caption_count"]
|
1997
|
-
self.rate_tracker["current_rate"] = (output_diff / elapsed_since_update) * 60
|
1998
|
-
self.rate_tracker["last_caption_count"] = current_total_outputs
|
1999
|
-
|
2000
|
-
# Calculate average rate since THIS SESSION started
|
2001
|
-
session_elapsed = current_time - session_start_time
|
2002
|
-
if session_elapsed > 0:
|
2003
|
-
# Always use the difference from session start for average
|
2004
|
-
session_outputs = current_total_outputs - session_start_outputs
|
2005
|
-
self.rate_tracker["average_rate"] = (session_outputs / session_elapsed) * 60
|
2006
|
-
|
2007
|
-
# Calculate expected rate based on workers and stages
|
2008
|
-
batch_size = self.vllm_config.get("batch_size", 8)
|
2009
|
-
|
2010
|
-
# Count total prompts across all stages
|
2011
|
-
total_prompts = 0
|
2012
|
-
stages = self.vllm_config.get("stages", [])
|
2013
|
-
if stages:
|
2014
|
-
for stage in stages:
|
2015
|
-
total_prompts += len(stage.get("prompts", []))
|
2016
|
-
else:
|
2017
|
-
# Backward compatibility
|
2018
|
-
total_prompts = len(self.vllm_config.get("inference_prompts", ["", "", ""]))
|
2019
885
|
|
2020
|
-
|
2021
|
-
|
2022
|
-
|
2023
|
-
)
|
2024
|
-
|
2025
|
-
# Update trackers
|
886
|
+
elapsed = current_time - self.rate_tracker["last_update_time"]
|
887
|
+
if elapsed > 0:
|
888
|
+
output_diff = current_total - self.rate_tracker["last_output_count"]
|
889
|
+
self.rate_tracker["current_rate"] = (output_diff / elapsed) * 60
|
890
|
+
self.rate_tracker["last_output_count"] = current_total
|
2026
891
|
self.rate_tracker["last_update_time"] = current_time
|
2027
|
-
last_known_total = current_total_outputs
|
2028
|
-
|
2029
|
-
# Log rate information when workers are connected
|
2030
|
-
if (
|
2031
|
-
worker_count > 0 and self.rate_tracker["current_rate"] >= 0
|
2032
|
-
): # Only log non-negative rates
|
2033
|
-
logger.info(
|
2034
|
-
f"Rate: {self.rate_tracker['current_rate']:.1f} outputs/min "
|
2035
|
-
f"(avg: {self.rate_tracker['average_rate']:.1f}, "
|
2036
|
-
f"expected: {self.rate_tracker['expected_rate']:.1f}) | "
|
2037
|
-
f"Workers: {worker_count}, Chunks: {active_chunks}/{target_buffer}"
|
2038
|
-
)
|
2039
892
|
|
2040
|
-
|
893
|
+
# Average rate since start
|
894
|
+
total_elapsed = current_time - self.rate_tracker["start_time"]
|
895
|
+
if total_elapsed > 0:
|
896
|
+
self.rate_tracker["average_rate"] = (current_total / total_elapsed) * 60
|
2041
897
|
|
2042
|
-
|
2043
|
-
"""Restore state from storage on startup."""
|
2044
|
-
total_captions = await self.storage.count_captions()
|
2045
|
-
logger.info(f"Restored state: {total_captions} captions")
|
898
|
+
await self._broadcast_stats()
|
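The slimmed-down stats loop above derives two outputs-per-minute figures: a current rate from the delta since the last tick, and an average rate over the whole run. A hedged sketch of that arithmetic as a pure function; the tracker keys ("last_update_time", "last_output_count", "start_time") follow the diff, the function itself is illustrative.

def update_rates(tracker: dict, current_total: int, now: float) -> dict:
    """Recompute outputs-per-minute the way the loop above does."""
    elapsed = now - tracker["last_update_time"]
    if elapsed > 0:
        diff = current_total - tracker["last_output_count"]
        tracker["current_rate"] = (diff / elapsed) * 60
        tracker["last_output_count"] = current_total
        tracker["last_update_time"] = now

    total_elapsed = now - tracker["start_time"]
    if total_elapsed > 0:
        tracker["average_rate"] = (current_total / total_elapsed) * 60
    return tracker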
2046
899
|
|
2047
900
|
async def shutdown(self):
|
2048
901
|
"""Graceful shutdown."""
|
2049
902
|
logger.info("Shutting down orchestrator...")
|
2050
903
|
|
2051
|
-
# Stop chunk creation
|
2052
|
-
if self.chunk_tracker:
|
2053
|
-
await self._flush_processed_items()
|
2054
|
-
self.stop_chunk_creation.set()
|
2055
|
-
if self.chunk_creation_thread:
|
2056
|
-
self.chunk_creation_thread.join(timeout=5)
|
2057
|
-
|
2058
|
-
# Release all assigned chunks before closing connections
|
2059
|
-
for worker_id in list(self.workers.keys()):
|
2060
|
-
self.chunk_manager.release_worker_chunks(worker_id)
|
2061
|
-
if self.chunk_tracker:
|
2062
|
-
# Update chunk tracker to mark assigned chunks as pending
|
2063
|
-
with self.chunk_manager.lock:
|
2064
|
-
for chunk_id in list(self.chunk_manager.assigned_chunks.get(worker_id, [])):
|
2065
|
-
self.chunk_tracker.mark_pending(chunk_id)
|
2066
|
-
|
2067
904
|
# Close all connections
|
2068
905
|
for ws in list(self.workers.values()):
|
2069
906
|
await ws.close()
|
2070
907
|
for ws in list(self.monitors):
|
2071
908
|
await ws.close()
|
2072
909
|
|
2073
|
-
# Save chunk state
|
2074
|
-
if self.chunk_tracker:
|
2075
|
-
self.chunk_tracker.save()
|
2076
|
-
|
2077
910
|
# Final checkpoint
|
2078
|
-
logger.info(f"Final flush: {len(self.storage.caption_buffer)} captions in buffer")
|
2079
911
|
await self.storage.checkpoint()
|
2080
|
-
|
2081
|
-
# Log final statistics
|
2082
|
-
logger.info(
|
2083
|
-
f"Shutdown complete. Total captions collected: {self.storage.total_captions_written}"
|
2084
|
-
)
|
2085
|
-
|
2086
912
|
await self.storage.close()
|
913
|
+
|
914
|
+
logger.info("Shutdown complete")
|