caption-flow 0.3.3-py3-none-any.whl → 0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +937 -416
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +5 -3
- caption_flow/orchestrator.py +186 -116
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +440 -68
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +66 -25
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +41 -19
- caption_flow/utils/chunk_tracker.py +200 -65
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +12 -6
- caption_flow/workers/caption.py +272 -91
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
- caption_flow-0.4.0.dist-info/RECORD +33 -0
- caption_flow-0.3.3.dist-info/RECORD +0 -33
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/models.py
CHANGED
@@ -1,12 +1,18 @@
 """Data models for CaptionFlow."""
 
-import
+import datetime as _datetime
+import logging
+import os
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
+
 from PIL import Image
 
+logger = logging.getLogger(__name__)
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
+
 
 class JobStatus(Enum):
     """Job processing status."""
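Note: several modules in this release gate their logging behind a CAPTIONFLOW_LOG_LEVEL environment variable. A minimal sketch of the behavior (standard logging semantics, nothing beyond what the diff shows):

import logging
import os

logger = logging.getLogger("caption_flow.models")
# .upper() lets CAPTIONFLOW_LOG_LEVEL=debug work; logging.Logger.setLevel raises
# ValueError for level names it does not recognize (e.g. "VERBOSE").
logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())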
@@ -37,7 +43,7 @@ class Job:
 
     def __post_init__(self):
         if self.created_at is None:
-            self.created_at = datetime.utcnow()
+            self.created_at = datetime.now(_datetime.UTC)
 
 
 @dataclass
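Note: this is the standard fix for the datetime.utcnow() deprecation (Python 3.12); datetime.now(datetime.UTC) returns a timezone-aware value, and datetime.UTC has been available since Python 3.11. A one-line illustration:

import datetime as _datetime
from datetime import datetime

aware = datetime.now(_datetime.UTC)
assert aware.tzinfo is not None  # utcnow() returned a naive value (tzinfo is None)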
@@ -69,7 +75,43 @@ class JobId:
         parts = job_id.split(":")
         if len(parts) != 5:
             raise ValueError(f"Invalid job_id format: {job_id}")
-
+
+        shard_id = parts[0]
+        chunk_keyword = parts[1]
+        chunk_id = parts[2]
+        idx_keyword = parts[3]
+        sample_id = parts[4]
+
+        # Validate format
+        if not shard_id:
+            raise ValueError(f"Invalid job_id format: empty shard_id in {job_id}")
+        if chunk_keyword != "chunk":
+            raise ValueError(
+                f"Invalid job_id format: expected 'chunk' keyword, got '{chunk_keyword}' in {job_id}"
+            )
+        if idx_keyword != "idx":
+            raise ValueError(
+                f"Invalid job_id format: expected 'idx' keyword, got '{idx_keyword}' in {job_id}"
+            )
+
+        # Validate numeric fields
+        try:
+            int(chunk_id)
+        except ValueError:
+            raise ValueError(
+                f"Invalid job_id format: chunk_id must be numeric, got '{chunk_id}' in {job_id}"
+            )
+
+        # sample_id can be empty/None for some use cases, but if provided must be numeric
+        if sample_id:
+            try:
+                int(sample_id)
+            except ValueError:
+                raise ValueError(
+                    f"Invalid job_id format: sample_id must be numeric if provided, got '{sample_id}' in {job_id}"
+                )
+
+        return JobId(shard_id=shard_id, chunk_id=chunk_id, sample_id=sample_id)
 
 
 @dataclass
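Note: a usage sketch of the stricter JobId.from_str parser. The sample values are hypothetical, but the shard:chunk:N:idx:M layout follows from the keyword checks above:

from caption_flow.models import JobId

job = JobId.from_str("data-0000:chunk:0:idx:42")
assert (job.shard_id, job.chunk_id, job.sample_id) == ("data-0000", "0", "42")

# Each of these now fails fast with a descriptive ValueError:
# JobId.from_str("data-0000:chunk:0")           # wrong part count
# JobId.from_str("data-0000:part:0:idx:42")     # 'chunk' keyword missing
# JobId.from_str("data-0000:chunk:abc:idx:42")  # non-numeric chunk_id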
caption_flow/monitor.py
CHANGED
@@ -4,9 +4,8 @@ import asyncio
 import json
 import logging
 import ssl
-import time
 from datetime import datetime
-from typing import
+from typing import Any, Dict, Optional
 
 import websockets
 from rich.console import Console
@@ -73,6 +72,9 @@ class Monitor:
         async with websockets.connect(
             self.server_url,
             ssl=self.ssl_context if self.server_url.startswith("wss://") else None,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
         ) as websocket:
             # Authenticate
             await websocket.send(json.dumps({"token": self.token}))
@@ -107,7 +109,7 @@
         """Main display update loop."""
         layout = self._create_layout()
 
-        with Live(layout, console=self.console, refresh_per_second=1, screen=True)
+        with Live(layout, console=self.console, refresh_per_second=1, screen=True):
             while self.running:
                 self._update_layout(layout)
                 await asyncio.sleep(0.25)
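Note: ping_interval, ping_timeout, and close_timeout are standard websockets client options; the values above mean a keepalive ping every 20 s, up to 60 s to wait for the pong before declaring the link dead, and at most 10 s for the closing handshake. A minimal client sketch under the same settings (URL and token are placeholders):

import asyncio
import json

import websockets

async def connect_once() -> None:
    async with websockets.connect(
        "wss://orchestrator.example:8765",
        ping_interval=20,
        ping_timeout=60,
        close_timeout=10,
    ) as ws:
        await ws.send(json.dumps({"token": "example-token"}))

asyncio.run(connect_once())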
caption_flow/orchestrator.py
CHANGED
@@ -1,34 +1,34 @@
-import time
 import asyncio
+import datetime as _datetime
 import json
 import logging
+import os
 import ssl
+import time
 import uuid
+from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import
-from collections import defaultdict
-import threading
+from typing import Any, Dict, Optional, Set
 
 import websockets
-from websockets.server import
+from websockets.asyncio.server import ServerConnection
 
-from .storage import StorageManager
 from .models import Caption, Contributor, JobId
-from .utils.auth import AuthManager
-from .utils.json_utils import safe_json_dumps
 from .processors import (
+    HuggingFaceDatasetOrchestratorProcessor,
+    LocalFilesystemOrchestratorProcessor,
     ProcessorConfig,
+    WebDatasetOrchestratorProcessor,
     WorkAssignment,
     WorkResult,
-    WorkUnit,
-    WebDatasetOrchestratorProcessor,
-    HuggingFaceDatasetOrchestratorProcessor,
-    LocalFilesystemOrchestratorProcessor,
 )
+from .storage import StorageManager
+from .utils.auth import AuthManager
+from .utils.json_utils import safe_json_dumps
 
 logger = logging.getLogger(__name__)
-logger.setLevel(
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
 class Orchestrator:
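Note: the import swap from websockets.server to websockets.asyncio.server tracks the websockets library's newer asyncio implementation (the legacy module is deprecated in recent releases); handlers there receive a single ServerConnection argument. A minimal server sketch, assuming websockets >= 13:

import asyncio

from websockets.asyncio.server import ServerConnection, serve

async def handler(ws: ServerConnection) -> None:
    async for message in ws:  # connections are async-iterable
        await ws.send(message)  # echo back

async def main() -> None:
    async with serve(handler, "localhost", 8765):
        await asyncio.get_running_loop().create_future()  # run until cancelled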
@@ -69,8 +69,8 @@ class Orchestrator:
         self.chunks_per_request = config.get("chunks_per_request", 2)
 
         # Track connections
-        self.workers: Dict[str,
-        self.monitors: Set[
+        self.workers: Dict[str, ServerConnection] = {}
+        self.monitors: Set[ServerConnection] = set()
         self.workers_by_user = defaultdict(set)
 
         # SSL configuration
@@ -160,11 +160,11 @@
         processed_job_ids = self.storage.get_all_processed_job_ids()
         self.processor.update_from_storage(processed_job_ids)
 
-    async def _send_leaderboard_to_monitor(self, websocket:
+    async def _send_leaderboard_to_monitor(self, websocket: ServerConnection):
         """Alias for _send_monitor_leaderboard for backward compatibility."""
         await self._send_monitor_leaderboard(websocket)
 
-    async def handle_connection(self, websocket:
+    async def handle_connection(self, websocket: ServerConnection):
         """Handle new WebSocket connection."""
         try:
             # Authenticate
@@ -193,7 +193,7 @@
             logger.error(f"Connection error: {e}", exc_info=True)
             await websocket.close()
 
-    async def _handle_worker(self, websocket:
+    async def _handle_worker(self, websocket: ServerConnection, auth_ticket):
         """Handle worker connection lifecycle."""
         # Generate unique worker ID
         base_name = getattr(auth_ticket, "name", "worker")
@@ -250,9 +250,38 @@
         self.processor.release_assignments(worker_id)
         logger.info(f"Worker {worker_id} has safely disconnected")
 
-
-        self,
-    ):
+    def _auth_configs_equal(
+        self, current_config: Dict[str, Any], new_config: Dict[str, Any]
+    ) -> bool:
+        """Compare two auth configurations for equality."""
+
+        # Helper function to normalize token lists for comparison
+        def normalize_tokens(tokens):
+            if not tokens:
+                return []
+            # Sort by token for consistent comparison
+            return sorted(
+                [{"name": t.get("name"), "token": t.get("token")} for t in tokens],
+                key=lambda x: x.get("token", ""),
+            )
+
+        # Compare each token type
+        current_workers = normalize_tokens(current_config.get("worker_tokens", []))
+        new_workers = normalize_tokens(new_config.get("worker_tokens", []))
+
+        current_admins = normalize_tokens(current_config.get("admin_tokens", []))
+        new_admins = normalize_tokens(new_config.get("admin_tokens", []))
+
+        current_monitors = normalize_tokens(current_config.get("monitor_tokens", []))
+        new_monitors = normalize_tokens(new_config.get("monitor_tokens", []))
+
+        return (
+            current_workers == new_workers
+            and current_admins == new_admins
+            and current_monitors == new_monitors
+        )
+
+    async def _handle_config_reload(self, websocket: ServerConnection, new_config: Dict[str, Any]):
         """Handle configuration reload request."""
         logger.info("Processing configuration reload request")
 
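Note: the comparison is order-insensitive because each token list is normalized to sorted {name, token} pairs (keys other than name and token are ignored), so a config reload that merely reorders tokens keeps the existing AuthManager. A sketch with hypothetical tokens:

current = {"worker_tokens": [{"name": "alice", "token": "t-1"}, {"name": "bob", "token": "t-2"}]}
reloaded = {"worker_tokens": [{"name": "bob", "token": "t-2"}, {"name": "alice", "token": "t-1"}]}
# orchestrator._auth_configs_equal(current, reloaded) -> True: same tokens, different order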
@@ -293,8 +322,16 @@
             # Update auth configuration
             if "auth" in orchestrator_config:
                 try:
-
-
+                    current_auth_config = self.config.get("auth", {})
+                    new_auth_config = orchestrator_config["auth"]
+
+                    # Only recreate AuthManager if auth config has actually changed
+                    if not self._auth_configs_equal(current_auth_config, new_auth_config):
+                        self.auth = AuthManager(new_auth_config)
+                        updated_sections.append("auth")
+                        logger.info("Auth configuration updated due to changes")
+                    else:
+                        logger.info("Auth configuration unchanged, preserving existing AuthManager")
                 except Exception as e:
                     logger.error(f"Failed to update AuthManager: {e}")
                     warnings.append(f"Auth update failed: {e}")
@@ -344,7 +381,7 @@
                 assignment_id=str(uuid.uuid4()),
                 worker_id=worker_id,
                 units=units,
-                assigned_at=datetime.utcnow(),
+                assigned_at=datetime.now(_datetime.UTC),
             )
 
             await self.workers[worker_id].send(
@@ -374,75 +411,93 @@
         logger.debug(f"Heartbeat from {worker_id}: {data}")
 
     async def _handle_results_submission(self, worker_id: str, data: Dict):
-        """Process results submission from worker."""
-        #
-
-        # Create work result
-        _job_id = data.get("job_id")
-        job_id = JobId.from_str(_job_id)
-        shard_name = job_id.shard_id  # >data-0000<
-        chunk_name = job_id.chunk_id  # data-0000:chunk:>0<
-        # logger.debug(f"({job_id}) Worker result: {data}")
-        result = WorkResult(
-            unit_id=data["unit_id"],
-            source_id=shard_name,
-            chunk_id=job_id.get_chunk_str(),  # we want the full string here
-            sample_id=data["sample_id"],
-            dataset=data["dataset"],
-            outputs=data["outputs"],
-            metadata=data.get("metadata", {}),
-            processing_time_ms=data.get("processing_time_ms", 0),
-        )
+        """Process results submission from worker - fires off async task and returns immediately."""
+        # Fire and forget - process in background
+        asyncio.create_task(self._process_result_async(worker_id, data))
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            chunk_id=chunk_name,
-            item_key=item_key,
-            captions=result.outputs.get("captions", []),
-            outputs=result.outputs,
-            contributor_id=worker_user,
-            timestamp=datetime.utcnow(),
-            caption_count=total_outputs,
-            processing_time_ms=result.processing_time_ms,
-            metadata=result.metadata,
-            image_height=image_height,
-            image_width=image_width,
-            filename=filename,
-            url=url,
-            file_size=file_size,
-            image_format=image_format,
-        )
+    async def _process_result_async(self, worker_id: str, data: Dict):
+        """Actually process the result in background."""
+        try:
+            # Extract user from worker_id
+            worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
+            # Create work result
+            _job_id = data.get("job_id")
+            job_id = JobId.from_str(_job_id)
+            shard_name = job_id.shard_id
+            chunk_name = job_id.chunk_id
+
+            result = WorkResult(
+                unit_id=data["unit_id"],
+                source_id=shard_name,
+                chunk_id=job_id.get_chunk_str(),
+                sample_id=data["sample_id"],
+                dataset=data["dataset"],
+                outputs=data["outputs"],
+                metadata=data.get("metadata", {}),
+                processing_time_ms=data.get("processing_time_ms", 0),
+            )
 
-
-
+            # Let processor handle any custom processing - this updates chunk tracker
+            # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
+            # regardless of whether the item is a duplicate
+            processed = self.processor.handle_result(result)
+
+            # Create caption record for storage
+            total_outputs = sum(len(v) for v in result.outputs.values())
+
+            filename = result.metadata.pop("_filename", None)
+            url = result.metadata.pop("_url", None)
+            image_height = result.metadata.pop("image_height", None)
+            image_width = result.metadata.pop("image_width", None)
+            file_size = result.metadata.pop("file_size", None)
+            image_format = result.metadata.pop("image_format", None)
+            result.metadata.pop("item_index", None)
+            item_key = result.metadata.pop("item_key", None)
+
+            to_delete_metadata_keys = ["_image_format", "_job_id"]
+            for key in to_delete_metadata_keys:
+                if key in result.metadata:
+                    del result.metadata[key]
+
+            caption = Caption(
+                job_id=job_id,
+                dataset=result.dataset,
+                shard=processed["source_id"],
+                chunk_id=chunk_name,
+                item_key=item_key,
+                captions=result.outputs.get("captions", []),
+                outputs=result.outputs,
+                contributor_id=worker_user,
+                timestamp=datetime.now(_datetime.UTC),
+                caption_count=total_outputs,
+                processing_time_ms=result.processing_time_ms,
+                metadata=result.metadata,
+                image_height=image_height,
+                image_width=image_width,
+                filename=filename,
+                url=url,
+                file_size=file_size,
+                image_format=image_format,
+            )
 
-
-
-
-        contributor
-
+            # Save to storage (might skip if duplicate)
+            saved = await self.storage.save_caption(caption)
+
+            # Update contributor stats only if actually saved
+            if saved:
+                contributor = await self.storage.get_contributor(worker_user)
+                if contributor:
+                    contributor.total_captions += total_outputs
+                    await self.storage.save_contributor(contributor)
+
+        except Exception as e:
+            logger.error(
+                f"Error processing result from {worker_id} for unit {data.get('unit_id', 'unknown')}: {e}",
+                exc_info=True,
+            )
 
-    async def _handle_monitor(self, websocket:
+    async def _handle_monitor(self, websocket: ServerConnection):
         """Handle monitor connection."""
         self.monitors.add(websocket)
         logger.info(f"Monitor connected (total: {len(self.monitors)})")
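Note: results handling is now fire-and-forget, so a slow storage write no longer blocks the WebSocket receive loop, and _process_result_async logs its own exceptions. One general asyncio caveat with this pattern (from the asyncio documentation, not specific to this diff): the event loop holds only a weak reference to tasks, so long-lived background tasks are commonly anchored in a set:

import asyncio

background: set[asyncio.Task] = set()

def fire_and_forget(coro) -> None:
    task = asyncio.create_task(coro)
    background.add(task)  # strong reference keeps the task alive
    task.add_done_callback(background.discard)  # released once it finishes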
@@ -455,7 +510,7 @@
             await self._send_monitor_stats(websocket)
 
             # Keep connection alive
-            async for
+            async for _message in websocket:
                 pass
 
         except websockets.exceptions.ConnectionClosed:
@@ -463,7 +518,7 @@
         finally:
             self.monitors.discard(websocket)
 
-    async def _handle_admin(self, websocket:
+    async def _handle_admin(self, websocket: ServerConnection, auth_ticket):
         """Handle admin connection."""
         admin_id = getattr(auth_ticket, "name", "admin")
         logger.info(f"Admin {admin_id} connected")
@@ -485,7 +540,7 @@
         except websockets.exceptions.ConnectionClosed:
             logger.info(f"Admin {admin_id} disconnected")
 
-    async def _handle_data_worker(self, websocket:
+    async def _handle_data_worker(self, websocket: ServerConnection, auth_ticket):
         """Handle data worker connection."""
         worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
         self.data_workers[worker_id] = websocket
@@ -554,7 +609,7 @@
         finally:
             del self.data_workers[worker_id]
 
-    async def _send_monitor_initial_data(self, websocket:
+    async def _send_monitor_initial_data(self, websocket: ServerConnection):
         """Send initial data to monitor in a separate task to avoid blocking."""
         total_start = time.time()
         try:
@@ -611,7 +666,7 @@
         except Exception as e:
             logger.error(f"Error sending initial monitor data: {e}")
 
-    async def _send_monitor_leaderboard(self, websocket:
+    async def _send_monitor_leaderboard(self, websocket: ServerConnection):
         """Send leaderboard data to a specific monitor."""
         total_start = time.time()
         try:
@@ -676,7 +731,7 @@
         except Exception as e:
             logger.error(f"Error sending leaderboard to monitor: {e}")
 
-    async def _send_monitor_stats(self, websocket:
+    async def _send_monitor_stats(self, websocket: ServerConnection):
         """Send current stats to a monitor."""
         # Get processor stats
         processor_stats = self.processor.get_stats()
@@ -788,7 +843,7 @@
         # Remove disconnected
         disconnected = {
             m
-            for m, r in zip(monitors_copy, results)
+            for m, r in zip(monitors_copy, results, strict=False)
             if r is not None and not isinstance(r, Exception)
         }
         self.monitors -= disconnected
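Note: zip() grew a strict keyword in Python 3.10, and linters such as Ruff (rule B905) flag bare calls; strict=False simply spells out the old truncating default:

list(zip([1, 2, 3], ["a", "b"]))               # [(1, 'a'), (2, 'b')] - silent truncation
list(zip([1, 2, 3], ["a", "b"], strict=True))  # raises ValueError on length mismatch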
@@ -840,39 +895,54 @@
         self.monitors -= disconnected
 
     async def _heartbeat_loop(self):
-        """
+        """Collect and log worker status periodically."""
         while True:
             await asyncio.sleep(30)
 
-
+            # Just collect status - no ping/pong
+            active_workers = []
             for worker_id, ws in list(self.workers.items()):
-
-
-
-
-
-
-            # Clean up disconnected workers
-            for worker_id in disconnected:
-                logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
-                if worker_id in self.workers:
+                # Check if WebSocket is still open (don't ping)
+                if ws.state == websockets.protocol.State.OPEN:
+                    active_workers.append(worker_id)
+                else:
+                    # Clean up closed connections
+                    logger.info(f"Worker {worker_id} connection closed")
                     del self.workers[worker_id]
-                    logger.warning(
-                        f"Releasing assignments for worker {worker_id} because it did not respond to ping"
-                    )
                     self.processor.release_assignments(worker_id)
-
+
+            # Log status
+            if active_workers:
+                logger.debug(
+                    f"Inactive workers: {len(self.workers) - len(active_workers)}/{len(active_workers)} - {', '.join(active_workers[:5])}"
+                )
+            # add to self.stats
+            self.stats["active_workers"] = len(active_workers)
+            self.stats["inactive_workers"] = len(self.workers) - len(active_workers)
 
     async def _checkpoint_loop(self):
-        """Periodically checkpoint storage."""
+        """Periodically checkpoint storage and chunk tracker."""
         interval = self.config.get("storage", {}).get("checkpoint_interval", 60)
 
         while True:
             await asyncio.sleep(interval)
 
-
-
-
+            try:
+                # Checkpoint storage
+                await self.storage.checkpoint()
+
+                # Also checkpoint the chunk tracker if using webdataset processor
+                if hasattr(self.processor, "chunk_tracker") and self.processor.chunk_tracker:
+                    # Save checkpoint in thread pool to avoid blocking
+                    await asyncio.get_event_loop().run_in_executor(
+                        None, self.processor.chunk_tracker.save
+                    )
+                    logger.debug("Saved chunk tracker checkpoint")
+
+                self.stats["last_checkpoint"] = datetime.now(_datetime.UTC).isoformat()
+                logger.info("Storage and chunk tracker checkpoint complete")
+            except Exception as e:
+                logger.error(f"Error during checkpoint: {e}", exc_info=True)
 
     async def _stats_update_loop(self):
         """Periodically update and broadcast stats."""
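Note: run_in_executor(None, fn) pushes the synchronous chunk_tracker.save onto the default thread pool so the checkpoint write cannot stall the event loop; asyncio.to_thread (Python 3.9+) is the equivalent shorthand:

import asyncio

def save_checkpoint() -> None:
    ...  # stand-in for the blocking chunk_tracker.save

async def checkpoint() -> None:
    await asyncio.get_event_loop().run_in_executor(None, save_checkpoint)
    # or, equivalently:
    await asyncio.to_thread(save_checkpoint)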
caption_flow/processors/__init__.py
CHANGED
@@ -1,11 +1,11 @@
 from .base import (
     OrchestratorProcessor,
-    WorkerProcessor,
     ProcessorConfig,
-    WorkUnit,
     WorkAssignment,
+    WorkerProcessor,
     WorkResult,
+    WorkUnit,
 )
 from .huggingface import HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor
-from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
 from .local_filesystem import LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor
+from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
caption_flow/processors/base.py
CHANGED
@@ -2,9 +2,8 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Dict, Any, List, Optional, Iterator, Tuple
 from datetime import datetime
-from
+from typing import Any, Dict, Iterator, List, Optional
 
 
 @dataclass
@@ -98,9 +97,7 @@ class WorkResult:
         return self.error is None and bool(self.outputs)
 
     def to_repr(self, filter_outputs: bool = True):
-        """
-        Print the WorkResult, optionally without captions to save on screen wall-of-text dumpage.
-        """
+        """Print the WorkResult, optionally without captions to save on screen wall-of-text dumpage."""
         if filter_outputs:
             outputs = "...filtered from logs..."
         else:
@@ -172,6 +169,8 @@ class OrchestratorProcessor(ABC):
 class WorkerProcessor(ABC):
     """Base processor for worker side - processes work units."""
 
+    gpu_id: Optional[int] = None
+
     @abstractmethod
     def initialize(self, config: ProcessorConfig) -> None:
         """Initialize the processor with configuration."""
@@ -179,18 +178,20 @@ class WorkerProcessor(ABC):
 
     @abstractmethod
     def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
-        """
-        Process a single work unit, yielding items to be captioned.
+        """Process a single work unit, yielding items to be captioned.
 
         Args:
+        ----
            unit: The work unit to process
            context: Runtime context (e.g., models, sampling params)
 
         Yields:
+        ------
        Dict containing:
            - image: PIL Image
            - metadata: Dict of metadata
            - item_key: Unique identifier for this item
+
        """
        pass
 
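Note: a toy implementation of the generator contract documented above (a hypothetical subclass for illustration; the unit_id field on WorkUnit is assumed for building item keys):

from typing import Any, Dict, Iterator

from PIL import Image

from caption_flow.processors import ProcessorConfig, WorkerProcessor, WorkUnit

class ToyWorkerProcessor(WorkerProcessor):
    def initialize(self, config: ProcessorConfig) -> None:
        self.size = (64, 64)

    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        for i in range(3):
            yield {
                "image": Image.new("RGB", self.size),  # placeholder image
                "metadata": {"index": i},
                "item_key": f"{unit.unit_id}_{i}",  # unit_id assumed
            }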
|