caption-flow 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +934 -415
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +2 -3
- caption_flow/orchestrator.py +153 -104
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +439 -67
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +28 -22
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +30 -28
- caption_flow/utils/chunk_tracker.py +153 -56
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +5 -4
- caption_flow/workers/caption.py +265 -90
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
- caption_flow-0.4.0.dist-info/RECORD +33 -0
- caption_flow-0.3.4.dist-info/RECORD +0 -33
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/models.py
CHANGED

@@ -1,12 +1,18 @@
 """Data models for CaptionFlow."""
 
-import
+import datetime as _datetime
+import logging
+import os
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
+
 from PIL import Image
 
+logger = logging.getLogger(__name__)
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
+
 
 class JobStatus(Enum):
     """Job processing status."""
@@ -37,7 +43,7 @@ class Job:
 
     def __post_init__(self):
         if self.created_at is None:
-            self.created_at = datetime.
+            self.created_at = datetime.now(_datetime.UTC)
 
 
 @dataclass
@@ -69,7 +75,43 @@ class JobId:
         parts = job_id.split(":")
         if len(parts) != 5:
             raise ValueError(f"Invalid job_id format: {job_id}")
-
+
+        shard_id = parts[0]
+        chunk_keyword = parts[1]
+        chunk_id = parts[2]
+        idx_keyword = parts[3]
+        sample_id = parts[4]
+
+        # Validate format
+        if not shard_id:
+            raise ValueError(f"Invalid job_id format: empty shard_id in {job_id}")
+        if chunk_keyword != "chunk":
+            raise ValueError(
+                f"Invalid job_id format: expected 'chunk' keyword, got '{chunk_keyword}' in {job_id}"
+            )
+        if idx_keyword != "idx":
+            raise ValueError(
+                f"Invalid job_id format: expected 'idx' keyword, got '{idx_keyword}' in {job_id}"
+            )
+
+        # Validate numeric fields
+        try:
+            int(chunk_id)
+        except ValueError:
+            raise ValueError(
+                f"Invalid job_id format: chunk_id must be numeric, got '{chunk_id}' in {job_id}"
+            )
+
+        # sample_id can be empty/None for some use cases, but if provided must be numeric
+        if sample_id:
+            try:
+                int(sample_id)
+            except ValueError:
+                raise ValueError(
+                    f"Invalid job_id format: sample_id must be numeric if provided, got '{sample_id}' in {job_id}"
+                )
+
+        return JobId(shard_id=shard_id, chunk_id=chunk_id, sample_id=sample_id)
 
 
 @dataclass
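The new `JobId.from_str` validation pins down the job ID layout: five colon-separated fields of the form `<shard>:chunk:<chunk_id>:idx:<sample_id>`, with literal `chunk` and `idx` keywords and numeric IDs. A minimal sketch of what that accepts and rejects (the shard name below is made up for illustration):

from caption_flow.models import JobId

# Well-formed: five fields, literal "chunk" and "idx", numeric chunk_id and sample_id.
job = JobId.from_str("my-shard-0001:chunk:42:idx:7")
print(job.shard_id, job.chunk_id, job.sample_id)  # my-shard-0001 42 7

# Malformed: a wrong keyword in the second field raises ValueError.
try:
    JobId.from_str("my-shard-0001:part:42:idx:7")
except ValueError as exc:
    print(exc)

Worth noting: `datetime.now(_datetime.UTC)` relies on the `datetime.UTC` alias introduced in Python 3.11, so the timezone-aware timestamps in this release assume a 3.11+ runtime (older interpreters would need `datetime.timezone.utc`).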
caption_flow/monitor.py
CHANGED

@@ -4,9 +4,8 @@ import asyncio
 import json
 import logging
 import ssl
-import time
 from datetime import datetime
-from typing import
+from typing import Any, Dict, Optional
 
 import websockets
 from rich.console import Console
@@ -110,7 +109,7 @@ class Monitor:
         """Main display update loop."""
         layout = self._create_layout()
 
-        with Live(layout, console=self.console, refresh_per_second=1, screen=True)
+        with Live(layout, console=self.console, refresh_per_second=1, screen=True):
             while self.running:
                 self._update_layout(layout)
                 await asyncio.sleep(0.25)
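The monitor change rewrites the `with Live(...)` line so the rich Live display is used as a plain context manager, and trims unused imports. A standalone sketch of the same pattern, independent of the Monitor class and its layout helpers:

import asyncio

from rich.layout import Layout
from rich.live import Live


async def display_loop() -> None:
    layout = Layout(name="root")
    layout.update("starting...")
    # Live re-renders the layout in place instead of printing new lines.
    with Live(layout, refresh_per_second=1, screen=True):
        for tick in range(5):
            layout.update(f"tick {tick}")
            await asyncio.sleep(0.25)


if __name__ == "__main__":
    asyncio.run(display_loop())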
caption_flow/orchestrator.py
CHANGED

@@ -1,34 +1,34 @@
-import time
 import asyncio
+import datetime as _datetime
 import json
 import logging
+import os
 import ssl
+import time
 import uuid
+from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import
-from collections import defaultdict
-import threading
+from typing import Any, Dict, Optional, Set
 
 import websockets
-from websockets.server import
+from websockets.asyncio.server import ServerConnection
 
-from .storage import StorageManager
 from .models import Caption, Contributor, JobId
-from .utils.auth import AuthManager
-from .utils.json_utils import safe_json_dumps
 from .processors import (
+    HuggingFaceDatasetOrchestratorProcessor,
+    LocalFilesystemOrchestratorProcessor,
     ProcessorConfig,
+    WebDatasetOrchestratorProcessor,
     WorkAssignment,
     WorkResult,
-    WorkUnit,
-    WebDatasetOrchestratorProcessor,
-    HuggingFaceDatasetOrchestratorProcessor,
-    LocalFilesystemOrchestratorProcessor,
 )
+from .storage import StorageManager
+from .utils.auth import AuthManager
+from .utils.json_utils import safe_json_dumps
 
 logger = logging.getLogger(__name__)
-logger.setLevel(
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
 class Orchestrator:
@@ -69,8 +69,8 @@ class Orchestrator:
         self.chunks_per_request = config.get("chunks_per_request", 2)
 
         # Track connections
-        self.workers: Dict[str,
-        self.monitors: Set[
+        self.workers: Dict[str, ServerConnection] = {}
+        self.monitors: Set[ServerConnection] = set()
         self.workers_by_user = defaultdict(set)
 
         # SSL configuration
@@ -160,11 +160,11 @@ class Orchestrator:
         processed_job_ids = self.storage.get_all_processed_job_ids()
         self.processor.update_from_storage(processed_job_ids)
 
-    async def _send_leaderboard_to_monitor(self, websocket:
+    async def _send_leaderboard_to_monitor(self, websocket: ServerConnection):
         """Alias for _send_monitor_leaderboard for backward compatibility."""
         await self._send_monitor_leaderboard(websocket)
 
-    async def handle_connection(self, websocket:
+    async def handle_connection(self, websocket: ServerConnection):
         """Handle new WebSocket connection."""
         try:
             # Authenticate
@@ -193,7 +193,7 @@ class Orchestrator:
             logger.error(f"Connection error: {e}", exc_info=True)
             await websocket.close()
 
-    async def _handle_worker(self, websocket:
+    async def _handle_worker(self, websocket: ServerConnection, auth_ticket):
         """Handle worker connection lifecycle."""
         # Generate unique worker ID
         base_name = getattr(auth_ticket, "name", "worker")
@@ -250,9 +250,38 @@ class Orchestrator:
         self.processor.release_assignments(worker_id)
         logger.info(f"Worker {worker_id} has safely disconnected")
 
-
-        self,
-    ):
+    def _auth_configs_equal(
+        self, current_config: Dict[str, Any], new_config: Dict[str, Any]
+    ) -> bool:
+        """Compare two auth configurations for equality."""
+
+        # Helper function to normalize token lists for comparison
+        def normalize_tokens(tokens):
+            if not tokens:
+                return []
+            # Sort by token for consistent comparison
+            return sorted(
+                [{"name": t.get("name"), "token": t.get("token")} for t in tokens],
+                key=lambda x: x.get("token", ""),
+            )
+
+        # Compare each token type
+        current_workers = normalize_tokens(current_config.get("worker_tokens", []))
+        new_workers = normalize_tokens(new_config.get("worker_tokens", []))
+
+        current_admins = normalize_tokens(current_config.get("admin_tokens", []))
+        new_admins = normalize_tokens(new_config.get("admin_tokens", []))
+
+        current_monitors = normalize_tokens(current_config.get("monitor_tokens", []))
+        new_monitors = normalize_tokens(new_config.get("monitor_tokens", []))
+
+        return (
+            current_workers == new_workers
+            and current_admins == new_admins
+            and current_monitors == new_monitors
+        )
+
+    async def _handle_config_reload(self, websocket: ServerConnection, new_config: Dict[str, Any]):
         """Handle configuration reload request."""
         logger.info("Processing configuration reload request")
 
@@ -293,8 +322,16 @@ class Orchestrator:
         # Update auth configuration
         if "auth" in orchestrator_config:
             try:
-
-
+                current_auth_config = self.config.get("auth", {})
+                new_auth_config = orchestrator_config["auth"]
+
+                # Only recreate AuthManager if auth config has actually changed
+                if not self._auth_configs_equal(current_auth_config, new_auth_config):
+                    self.auth = AuthManager(new_auth_config)
+                    updated_sections.append("auth")
+                    logger.info("Auth configuration updated due to changes")
+                else:
+                    logger.info("Auth configuration unchanged, preserving existing AuthManager")
             except Exception as e:
                 logger.error(f"Failed to update AuthManager: {e}")
                 warnings.append(f"Auth update failed: {e}")
@@ -344,7 +381,7 @@ class Orchestrator:
             assignment_id=str(uuid.uuid4()),
             worker_id=worker_id,
             units=units,
-            assigned_at=datetime.
+            assigned_at=datetime.now(_datetime.UTC),
         )
 
         await self.workers[worker_id].send(
@@ -374,80 +411,93 @@ class Orchestrator:
         logger.debug(f"Heartbeat from {worker_id}: {data}")
 
     async def _handle_results_submission(self, worker_id: str, data: Dict):
-        """Process results submission from worker."""
-        #
-        # Create work result
-        _job_id = data.get("job_id")
-        job_id = JobId.from_str(_job_id)
-        shard_name = job_id.shard_id
-        chunk_name = job_id.chunk_id
-
-        result = WorkResult(
-            unit_id=data["unit_id"],
-            source_id=shard_name,
-            chunk_id=job_id.get_chunk_str(),
-            sample_id=data["sample_id"],
-            dataset=data["dataset"],
-            outputs=data["outputs"],
-            metadata=data.get("metadata", {}),
-            processing_time_ms=data.get("processing_time_ms", 0),
-        )
+        """Process results submission from worker - fires off async task and returns immediately."""
+        # Fire and forget - process in background
+        asyncio.create_task(self._process_result_async(worker_id, data))
 
+    async def _process_result_async(self, worker_id: str, data: Dict):
+        """Actually process the result in background."""
+        try:
+            # Extract user from worker_id
+            worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
+            # Create work result
+            _job_id = data.get("job_id")
+            job_id = JobId.from_str(_job_id)
+            shard_name = job_id.shard_id
+            chunk_name = job_id.chunk_id
+
+            result = WorkResult(
+                unit_id=data["unit_id"],
+                source_id=shard_name,
+                chunk_id=job_id.get_chunk_str(),
+                sample_id=data["sample_id"],
+                dataset=data["dataset"],
+                outputs=data["outputs"],
+                metadata=data.get("metadata", {}),
+                processing_time_ms=data.get("processing_time_ms", 0),
+            )
+
+            # Let processor handle any custom processing - this updates chunk tracker
+            # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
+            # regardless of whether the item is a duplicate
+            processed = self.processor.handle_result(result)
+
+            # Create caption record for storage
+            total_outputs = sum(len(v) for v in result.outputs.values())
+
+            filename = result.metadata.pop("_filename", None)
+            url = result.metadata.pop("_url", None)
+            image_height = result.metadata.pop("image_height", None)
+            image_width = result.metadata.pop("image_width", None)
+            file_size = result.metadata.pop("file_size", None)
+            image_format = result.metadata.pop("image_format", None)
+            result.metadata.pop("item_index", None)
+            item_key = result.metadata.pop("item_key", None)
+
+            to_delete_metadata_keys = ["_image_format", "_job_id"]
+            for key in to_delete_metadata_keys:
+                if key in result.metadata:
+                    del result.metadata[key]
+
+            caption = Caption(
+                job_id=job_id,
+                dataset=result.dataset,
+                shard=processed["source_id"],
+                chunk_id=chunk_name,
+                item_key=item_key,
+                captions=result.outputs.get("captions", []),
+                outputs=result.outputs,
+                contributor_id=worker_user,
+                timestamp=datetime.now(_datetime.UTC),
+                caption_count=total_outputs,
+                processing_time_ms=result.processing_time_ms,
+                metadata=result.metadata,
+                image_height=image_height,
+                image_width=image_width,
+                filename=filename,
+                url=url,
+                file_size=file_size,
+                image_format=image_format,
+            )
 
+            # Save to storage (might skip if duplicate)
+            saved = await self.storage.save_caption(caption)
 
+            # Update contributor stats only if actually saved
+            if saved:
+                contributor = await self.storage.get_contributor(worker_user)
+                if contributor:
+                    contributor.total_captions += total_outputs
+                    await self.storage.save_contributor(contributor)
 
+        except Exception as e:
+            logger.error(
+                f"Error processing result from {worker_id} for unit {data.get('unit_id', 'unknown')}: {e}",
+                exc_info=True,
+            )
+
+    async def _handle_monitor(self, websocket: ServerConnection):
         """Handle monitor connection."""
         self.monitors.add(websocket)
         logger.info(f"Monitor connected (total: {len(self.monitors)})")
@@ -460,7 +510,7 @@ class Orchestrator:
             await self._send_monitor_stats(websocket)
 
             # Keep connection alive
-            async for
+            async for _message in websocket:
                 pass
 
         except websockets.exceptions.ConnectionClosed:
@@ -468,7 +518,7 @@ class Orchestrator:
         finally:
             self.monitors.discard(websocket)
 
-    async def _handle_admin(self, websocket:
+    async def _handle_admin(self, websocket: ServerConnection, auth_ticket):
         """Handle admin connection."""
         admin_id = getattr(auth_ticket, "name", "admin")
         logger.info(f"Admin {admin_id} connected")
@@ -490,7 +540,7 @@ class Orchestrator:
         except websockets.exceptions.ConnectionClosed:
             logger.info(f"Admin {admin_id} disconnected")
 
-    async def _handle_data_worker(self, websocket:
+    async def _handle_data_worker(self, websocket: ServerConnection, auth_ticket):
         """Handle data worker connection."""
         worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
         self.data_workers[worker_id] = websocket
@@ -559,7 +609,7 @@ class Orchestrator:
         finally:
             del self.data_workers[worker_id]
 
-    async def _send_monitor_initial_data(self, websocket:
+    async def _send_monitor_initial_data(self, websocket: ServerConnection):
        """Send initial data to monitor in a separate task to avoid blocking."""
        total_start = time.time()
        try:
@@ -616,7 +666,7 @@ class Orchestrator:
         except Exception as e:
             logger.error(f"Error sending initial monitor data: {e}")
 
-    async def _send_monitor_leaderboard(self, websocket:
+    async def _send_monitor_leaderboard(self, websocket: ServerConnection):
         """Send leaderboard data to a specific monitor."""
         total_start = time.time()
         try:
@@ -681,7 +731,7 @@ class Orchestrator:
         except Exception as e:
             logger.error(f"Error sending leaderboard to monitor: {e}")
 
-    async def _send_monitor_stats(self, websocket:
+    async def _send_monitor_stats(self, websocket: ServerConnection):
         """Send current stats to a monitor."""
         # Get processor stats
         processor_stats = self.processor.get_stats()
@@ -793,7 +843,7 @@ class Orchestrator:
         # Remove disconnected
         disconnected = {
             m
-            for m, r in zip(monitors_copy, results)
+            for m, r in zip(monitors_copy, results, strict=False)
             if r is not None and not isinstance(r, Exception)
         }
         self.monitors -= disconnected
@@ -864,9 +914,8 @@ class Orchestrator:
         # Log status
         if active_workers:
             logger.debug(
-                f"
+                f"Inactive workers: {len(self.workers) - len(active_workers)}/{len(active_workers)} - {', '.join(active_workers[:5])}"
             )
-            logger.debug(f"Inactive workers: {len(self.workers) - len(active_workers)}")
         # add to self.stats
         self.stats["active_workers"] = len(active_workers)
         self.stats["inactive_workers"] = len(self.workers) - len(active_workers)
@@ -890,7 +939,7 @@ class Orchestrator:
             )
             logger.debug("Saved chunk tracker checkpoint")
 
-            self.stats["last_checkpoint"] = datetime.
+            self.stats["last_checkpoint"] = datetime.now(_datetime.UTC).isoformat()
             logger.info("Storage and chunk tracker checkpoint complete")
         except Exception as e:
             logger.error(f"Error during checkpoint: {e}", exc_info=True)
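The most significant change above is structural: `_handle_results_submission` no longer processes results inline but schedules `_process_result_async` with `asyncio.create_task`, so the websocket handler returns immediately and failures are logged inside the background task. A stripped-down sketch of that fire-and-forget pattern, using a toy storage stand-in rather than CaptionFlow's real StorageManager:

import asyncio
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


class ToyStorage:
    """Stand-in for the orchestrator's storage layer (illustrative only)."""

    async def save_caption(self, data: Dict[str, Any]) -> bool:
        await asyncio.sleep(0.01)  # pretend I/O
        return True


class ResultHandler:
    def __init__(self) -> None:
        self.storage = ToyStorage()

    async def handle_results_submission(self, worker_id: str, data: Dict[str, Any]) -> None:
        # Fire and forget: the caller is not blocked while the result is persisted.
        asyncio.create_task(self._process_result_async(worker_id, data))

    async def _process_result_async(self, worker_id: str, data: Dict[str, Any]) -> None:
        try:
            saved = await self.storage.save_caption(data)
            if saved:
                logger.info("saved result %s from %s", data.get("unit_id"), worker_id)
        except Exception:
            # Errors must be caught here; exceptions in fire-and-forget tasks
            # would otherwise only surface as "task exception was never retrieved".
            logger.exception("error processing result from %s", worker_id)


async def main() -> None:
    handler = ResultHandler()
    await handler.handle_results_submission("alice_0", {"unit_id": "u1"})
    await asyncio.sleep(0.1)  # give the background task time to complete


if __name__ == "__main__":
    asyncio.run(main())

A trade-off visible in the diff is that the orchestrator now acknowledges a submission before the caption is persisted, so deduplication and contributor-stat updates happen asynchronously in the background task.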
caption_flow/processors/__init__.py
CHANGED

@@ -1,11 +1,11 @@
 from .base import (
     OrchestratorProcessor,
-    WorkerProcessor,
     ProcessorConfig,
-    WorkUnit,
     WorkAssignment,
+    WorkerProcessor,
     WorkResult,
+    WorkUnit,
 )
 from .huggingface import HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor
-from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
 from .local_filesystem import LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor
+from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
caption_flow/processors/base.py
CHANGED

@@ -2,9 +2,8 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Dict, Any, List, Optional, Iterator, Tuple
 from datetime import datetime
-from
+from typing import Any, Dict, Iterator, List, Optional
 
 
 @dataclass
@@ -98,9 +97,7 @@ class WorkResult:
         return self.error is None and bool(self.outputs)
 
     def to_repr(self, filter_outputs: bool = True):
-        """
-        Print the WorkResult, optionally without captions to save on screen wall-of-text dumpage.
-        """
+        """Print the WorkResult, optionally without captions to save on screen wall-of-text dumpage."""
         if filter_outputs:
             outputs = "...filtered from logs..."
         else:
@@ -172,6 +169,8 @@ class OrchestratorProcessor(ABC):
 class WorkerProcessor(ABC):
     """Base processor for worker side - processes work units."""
 
+    gpu_id: Optional[int] = None
+
     @abstractmethod
     def initialize(self, config: ProcessorConfig) -> None:
         """Initialize the processor with configuration."""
@@ -179,18 +178,20 @@ class WorkerProcessor(ABC):
 
     @abstractmethod
     def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
-        """
-        Process a single work unit, yielding items to be captioned.
+        """Process a single work unit, yielding items to be captioned.
 
         Args:
+        ----
            unit: The work unit to process
            context: Runtime context (e.g., models, sampling params)
 
         Yields:
+        ------
            Dict containing:
            - image: PIL Image
            - metadata: Dict of metadata
            - item_key: Unique identifier for this item
+
        """
        pass
 
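The reformatted `process_unit` docstring above states the worker-side contract: yield one dict per item with `image` (a PIL image), `metadata`, and `item_key`. A minimal illustrative generator following that shape; it deliberately avoids the real `WorkUnit` and `WorkerProcessor` classes and uses a plain dict as a hypothetical unit:

from typing import Any, Dict, Iterator

from PIL import Image


def process_unit(unit: Dict[str, Any], context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
    """Yield caption-ready items in the shape the base class documents."""
    # "unit" is a plain dict standing in for a WorkUnit; the real field names
    # on WorkUnit may differ.
    for index, path in enumerate(unit.get("paths", [])):
        yield {
            "image": Image.open(path),
            "metadata": {"path": path, "index": index},
            "item_key": f"{unit.get('unit_id', 'unit')}_{index}",
        }


if __name__ == "__main__":
    # Dry run with no files: the generator simply yields nothing.
    for item in process_unit({"unit_id": "demo", "paths": []}, context={}):
        print(item["item_key"])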