caption-flow 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +937 -416
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +5 -3
  5. caption_flow/orchestrator.py +186 -116
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +440 -68
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +66 -25
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +41 -19
  18. caption_flow/utils/chunk_tracker.py +200 -65
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +12 -6
  25. caption_flow/workers/caption.py +272 -91
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
  28. caption_flow-0.4.0.dist-info/RECORD +33 -0
  29. caption_flow-0.3.3.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/models.py CHANGED
@@ -1,12 +1,18 @@
 """Data models for CaptionFlow."""
 
-import PIL
+import datetime as _datetime
+import logging
+import os
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
+
 from PIL import Image
 
+logger = logging.getLogger(__name__)
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
+
 
 class JobStatus(Enum):
     """Job processing status."""
@@ -37,7 +43,7 @@ class Job:
 
     def __post_init__(self):
         if self.created_at is None:
-            self.created_at = datetime.utcnow()
+            self.created_at = datetime.now(_datetime.UTC)
 
 
 @dataclass
@@ -69,7 +75,43 @@ class JobId:
         parts = job_id.split(":")
         if len(parts) != 5:
             raise ValueError(f"Invalid job_id format: {job_id}")
-        return JobId(shard_id=parts[0], chunk_id=parts[2], sample_id=parts[4])
+
+        shard_id = parts[0]
+        chunk_keyword = parts[1]
+        chunk_id = parts[2]
+        idx_keyword = parts[3]
+        sample_id = parts[4]
+
+        # Validate format
+        if not shard_id:
+            raise ValueError(f"Invalid job_id format: empty shard_id in {job_id}")
+        if chunk_keyword != "chunk":
+            raise ValueError(
+                f"Invalid job_id format: expected 'chunk' keyword, got '{chunk_keyword}' in {job_id}"
+            )
+        if idx_keyword != "idx":
+            raise ValueError(
+                f"Invalid job_id format: expected 'idx' keyword, got '{idx_keyword}' in {job_id}"
+            )
+
+        # Validate numeric fields
+        try:
+            int(chunk_id)
+        except ValueError:
+            raise ValueError(
+                f"Invalid job_id format: chunk_id must be numeric, got '{chunk_id}' in {job_id}"
+            )
+
+        # sample_id can be empty/None for some use cases, but if provided must be numeric
+        if sample_id:
+            try:
+                int(sample_id)
+            except ValueError:
+                raise ValueError(
+                    f"Invalid job_id format: sample_id must be numeric if provided, got '{sample_id}' in {job_id}"
+                )
+
+        return JobId(shard_id=shard_id, chunk_id=chunk_id, sample_id=sample_id)
 
 
 @dataclass
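
The stricter `from_str` above locks in the five-part `<shard>:chunk:<chunk_id>:idx:<sample_id>` format that the orchestrator's inline comments illustrate with `data-0000:chunk:0:idx:...`. A minimal sketch of the round trip it now guarantees (the concrete id strings are illustrative):

    from caption_flow.models import JobId

    job = JobId.from_str("data-0000:chunk:0:idx:42")
    assert (job.shard_id, job.chunk_id, job.sample_id) == ("data-0000", "0", "42")

    # Malformed ids now fail fast instead of silently mis-parsing:
    for bad in ("data-0000:chunk:0", "data-0000:part:0:idx:42", "data-0000:chunk:x:idx:42"):
        try:
            JobId.from_str(bad)
        except ValueError as exc:
            print(exc)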
caption_flow/monitor.py CHANGED
@@ -4,9 +4,8 @@ import asyncio
 import json
 import logging
 import ssl
-import time
 from datetime import datetime
-from typing import Dict, Any, List, Optional
+from typing import Any, Dict, Optional
 
 import websockets
 from rich.console import Console
@@ -73,6 +72,9 @@ class Monitor:
         async with websockets.connect(
             self.server_url,
             ssl=self.ssl_context if self.server_url.startswith("wss://") else None,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
         ) as websocket:
             # Authenticate
             await websocket.send(json.dumps({"token": self.token}))
@@ -107,7 +109,7 @@ class Monitor:
         """Main display update loop."""
         layout = self._create_layout()
 
-        with Live(layout, console=self.console, refresh_per_second=1, screen=True) as live:
+        with Live(layout, console=self.console, refresh_per_second=1, screen=True):
            while self.running:
                self._update_layout(layout)
                await asyncio.sleep(0.25)
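
The three new keyword arguments are standard `websockets` keepalive settings: the client sends a Ping every `ping_interval` seconds, drops the connection if no Pong arrives within `ping_timeout` seconds, and allows `close_timeout` seconds for the closing handshake. A standalone sketch of the same connect call (the URL and token are placeholders, not from the package):

    import asyncio
    import json

    import websockets

    async def main():
        async with websockets.connect(
            "ws://localhost:8765",  # placeholder endpoint
            ping_interval=20,   # send a Ping every 20 s
            ping_timeout=60,    # give the peer 60 s to answer with a Pong
            close_timeout=10,   # bound the closing handshake
        ) as ws:
            await ws.send(json.dumps({"token": "example-token"}))

    asyncio.run(main())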
caption_flow/orchestrator.py CHANGED
@@ -1,34 +1,34 @@
-import time
 import asyncio
+import datetime as _datetime
 import json
 import logging
+import os
 import ssl
+import time
 import uuid
+from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Set, Optional, Any, List
-from collections import defaultdict
-import threading
+from typing import Any, Dict, Optional, Set
 
 import websockets
-from websockets.server import WebSocketServerProtocol
+from websockets.asyncio.server import ServerConnection
 
-from .storage import StorageManager
 from .models import Caption, Contributor, JobId
-from .utils.auth import AuthManager
-from .utils.json_utils import safe_json_dumps
 from .processors import (
+    HuggingFaceDatasetOrchestratorProcessor,
+    LocalFilesystemOrchestratorProcessor,
     ProcessorConfig,
+    WebDatasetOrchestratorProcessor,
     WorkAssignment,
     WorkResult,
-    WorkUnit,
-    WebDatasetOrchestratorProcessor,
-    HuggingFaceDatasetOrchestratorProcessor,
-    LocalFilesystemOrchestratorProcessor,
 )
+from .storage import StorageManager
+from .utils.auth import AuthManager
+from .utils.json_utils import safe_json_dumps
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
 class Orchestrator:
@@ -69,8 +69,8 @@ class Orchestrator:
         self.chunks_per_request = config.get("chunks_per_request", 2)
 
         # Track connections
-        self.workers: Dict[str, WebSocketServerProtocol] = {}
-        self.monitors: Set[WebSocketServerProtocol] = set()
+        self.workers: Dict[str, ServerConnection] = {}
+        self.monitors: Set[ServerConnection] = set()
         self.workers_by_user = defaultdict(set)
 
         # SSL configuration
@@ -160,11 +160,11 @@ class Orchestrator:
         processed_job_ids = self.storage.get_all_processed_job_ids()
         self.processor.update_from_storage(processed_job_ids)
 
-    async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+    async def _send_leaderboard_to_monitor(self, websocket: ServerConnection):
         """Alias for _send_monitor_leaderboard for backward compatibility."""
         await self._send_monitor_leaderboard(websocket)
 
-    async def handle_connection(self, websocket: WebSocketServerProtocol):
+    async def handle_connection(self, websocket: ServerConnection):
         """Handle new WebSocket connection."""
         try:
             # Authenticate
@@ -193,7 +193,7 @@ class Orchestrator:
             logger.error(f"Connection error: {e}", exc_info=True)
             await websocket.close()
 
-    async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
+    async def _handle_worker(self, websocket: ServerConnection, auth_ticket):
         """Handle worker connection lifecycle."""
         # Generate unique worker ID
         base_name = getattr(auth_ticket, "name", "worker")
@@ -250,9 +250,38 @@ class Orchestrator:
         self.processor.release_assignments(worker_id)
         logger.info(f"Worker {worker_id} has safely disconnected")
 
-    async def _handle_config_reload(
-        self, websocket: WebSocketServerProtocol, new_config: Dict[str, Any]
-    ):
+    def _auth_configs_equal(
+        self, current_config: Dict[str, Any], new_config: Dict[str, Any]
+    ) -> bool:
+        """Compare two auth configurations for equality."""
+
+        # Helper function to normalize token lists for comparison
+        def normalize_tokens(tokens):
+            if not tokens:
+                return []
+            # Sort by token for consistent comparison
+            return sorted(
+                [{"name": t.get("name"), "token": t.get("token")} for t in tokens],
+                key=lambda x: x.get("token", ""),
+            )
+
+        # Compare each token type
+        current_workers = normalize_tokens(current_config.get("worker_tokens", []))
+        new_workers = normalize_tokens(new_config.get("worker_tokens", []))
+
+        current_admins = normalize_tokens(current_config.get("admin_tokens", []))
+        new_admins = normalize_tokens(new_config.get("admin_tokens", []))
+
+        current_monitors = normalize_tokens(current_config.get("monitor_tokens", []))
+        new_monitors = normalize_tokens(new_config.get("monitor_tokens", []))
+
+        return (
+            current_workers == new_workers
+            and current_admins == new_admins
+            and current_monitors == new_monitors
+        )
+
+    async def _handle_config_reload(self, websocket: ServerConnection, new_config: Dict[str, Any]):
         """Handle configuration reload request."""
         logger.info("Processing configuration reload request")
 
@@ -293,8 +322,16 @@ class Orchestrator:
         # Update auth configuration
         if "auth" in orchestrator_config:
             try:
-                self.auth = AuthManager(orchestrator_config["auth"])
-                updated_sections.append("auth")
+                current_auth_config = self.config.get("auth", {})
+                new_auth_config = orchestrator_config["auth"]
+
+                # Only recreate AuthManager if auth config has actually changed
+                if not self._auth_configs_equal(current_auth_config, new_auth_config):
+                    self.auth = AuthManager(new_auth_config)
+                    updated_sections.append("auth")
+                    logger.info("Auth configuration updated due to changes")
+                else:
+                    logger.info("Auth configuration unchanged, preserving existing AuthManager")
             except Exception as e:
                 logger.error(f"Failed to update AuthManager: {e}")
                 warnings.append(f"Auth update failed: {e}")
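
Because `normalize_tokens` sorts each list by token value and keeps only the `name`/`token` keys, a reload whose auth section is merely reordered now compares equal and leaves the existing `AuthManager` untouched. A sketch (assuming an `Orchestrator` instance named `orchestrator`; the token values are illustrative):

    old_auth = {"worker_tokens": [{"name": "a", "token": "t1"}, {"name": "b", "token": "t2"}]}
    new_auth = {"worker_tokens": [{"name": "b", "token": "t2"}, {"name": "a", "token": "t1"}]}

    # Same tokens, different order: compares equal, AuthManager is preserved.
    assert orchestrator._auth_configs_equal(old_auth, new_auth)

    # Adding a token is a real change, so the AuthManager would be recreated.
    new_auth["worker_tokens"].append({"name": "c", "token": "t3"})
    assert not orchestrator._auth_configs_equal(old_auth, new_auth)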
@@ -344,7 +381,7 @@ class Orchestrator:
             assignment_id=str(uuid.uuid4()),
             worker_id=worker_id,
             units=units,
-            assigned_at=datetime.utcnow(),
+            assigned_at=datetime.now(_datetime.UTC),
         )
 
         await self.workers[worker_id].send(
@@ -374,75 +411,93 @@ class Orchestrator:
         logger.debug(f"Heartbeat from {worker_id}: {data}")
 
     async def _handle_results_submission(self, worker_id: str, data: Dict):
-        """Process results submission from worker."""
-        # Extract user from worker_id
-        worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
-        # Create work result
-        _job_id = data.get("job_id")
-        job_id = JobId.from_str(_job_id)
-        shard_name = job_id.shard_id  # >data-0000<
-        chunk_name = job_id.chunk_id  # data-0000:chunk:>0<
-        # logger.debug(f"({job_id}) Worker result: {data}")
-        result = WorkResult(
-            unit_id=data["unit_id"],
-            source_id=shard_name,
-            chunk_id=job_id.get_chunk_str(),  # we want the full string here
-            sample_id=data["sample_id"],
-            dataset=data["dataset"],
-            outputs=data["outputs"],
-            metadata=data.get("metadata", {}),
-            processing_time_ms=data.get("processing_time_ms", 0),
-        )
+        """Process results submission from worker - fires off async task and returns immediately."""
+        # Fire and forget - process in background
+        asyncio.create_task(self._process_result_async(worker_id, data))
 
-        # Let processor handle any custom processing
-        processed = self.processor.handle_result(result)
-
-        # Create caption record for storage
-        total_outputs = sum(len(v) for v in result.outputs.values())
-
-        filename = result.metadata.pop("_filename", None)
-        url = result.metadata.pop("_url", None)
-        image_height = result.metadata.pop("image_height", None)
-        image_width = result.metadata.pop("image_width", None)
-        file_size = result.metadata.pop("file_size", None)
-        image_format = result.metadata.pop("image_format", None)
-        item_index = result.metadata.pop("item_index", None)
-        item_key = result.metadata.pop("item_key", None)
-        to_delete_metadata_keys = ["_image_format", "_job_id"]
-        for key in to_delete_metadata_keys:
-            if key in result.metadata:
-                del result.metadata[key]
-        caption = Caption(
-            job_id=job_id,
-            dataset=result.dataset,
-            shard=processed["source_id"],
-            chunk_id=chunk_name,
-            item_key=item_key,
-            captions=result.outputs.get("captions", []),
-            outputs=result.outputs,
-            contributor_id=worker_user,
-            timestamp=datetime.utcnow(),
-            caption_count=total_outputs,
-            processing_time_ms=result.processing_time_ms,
-            metadata=result.metadata,
-            image_height=image_height,
-            image_width=image_width,
-            filename=filename,
-            url=url,
-            file_size=file_size,
-            image_format=image_format,
-        )
+    async def _process_result_async(self, worker_id: str, data: Dict):
+        """Actually process the result in background."""
+        try:
+            # Extract user from worker_id
+            worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
+            # Create work result
+            _job_id = data.get("job_id")
+            job_id = JobId.from_str(_job_id)
+            shard_name = job_id.shard_id
+            chunk_name = job_id.chunk_id
+
+            result = WorkResult(
+                unit_id=data["unit_id"],
+                source_id=shard_name,
+                chunk_id=job_id.get_chunk_str(),
+                sample_id=data["sample_id"],
+                dataset=data["dataset"],
+                outputs=data["outputs"],
+                metadata=data.get("metadata", {}),
+                processing_time_ms=data.get("processing_time_ms", 0),
+            )
 
-        # Save to storage
-        await self.storage.save_caption(caption)
+            # Let processor handle any custom processing - this updates chunk tracker
+            # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
+            # regardless of whether the item is a duplicate
+            processed = self.processor.handle_result(result)
+
+            # Create caption record for storage
+            total_outputs = sum(len(v) for v in result.outputs.values())
+
+            filename = result.metadata.pop("_filename", None)
+            url = result.metadata.pop("_url", None)
+            image_height = result.metadata.pop("image_height", None)
+            image_width = result.metadata.pop("image_width", None)
+            file_size = result.metadata.pop("file_size", None)
+            image_format = result.metadata.pop("image_format", None)
+            result.metadata.pop("item_index", None)
+            item_key = result.metadata.pop("item_key", None)
+
+            to_delete_metadata_keys = ["_image_format", "_job_id"]
+            for key in to_delete_metadata_keys:
+                if key in result.metadata:
+                    del result.metadata[key]
+
+            caption = Caption(
+                job_id=job_id,
+                dataset=result.dataset,
+                shard=processed["source_id"],
+                chunk_id=chunk_name,
+                item_key=item_key,
+                captions=result.outputs.get("captions", []),
+                outputs=result.outputs,
+                contributor_id=worker_user,
+                timestamp=datetime.now(_datetime.UTC),
+                caption_count=total_outputs,
+                processing_time_ms=result.processing_time_ms,
+                metadata=result.metadata,
+                image_height=image_height,
+                image_width=image_width,
+                filename=filename,
+                url=url,
+                file_size=file_size,
+                image_format=image_format,
+            )
 
-        # Update contributor stats
-        contributor = await self.storage.get_contributor(worker_user)
-        if contributor:
-            contributor.total_captions += total_outputs
-            await self.storage.save_contributor(contributor)
+            # Save to storage (might skip if duplicate)
+            saved = await self.storage.save_caption(caption)
+
+            # Update contributor stats only if actually saved
+            if saved:
+                contributor = await self.storage.get_contributor(worker_user)
+                if contributor:
+                    contributor.total_captions += total_outputs
+                    await self.storage.save_contributor(contributor)
+
+        except Exception as e:
+            logger.error(
+                f"Error processing result from {worker_id} for unit {data.get('unit_id', 'unknown')}: {e}",
+                exc_info=True,
+            )
 
-    async def _handle_monitor(self, websocket: WebSocketServerProtocol):
+    async def _handle_monitor(self, websocket: ServerConnection):
         """Handle monitor connection."""
         self.monitors.add(websocket)
         logger.info(f"Monitor connected (total: {len(self.monitors)})")
@@ -455,7 +510,7 @@ class Orchestrator:
             await self._send_monitor_stats(websocket)
 
             # Keep connection alive
-            async for message in websocket:
+            async for _message in websocket:
                 pass
 
         except websockets.exceptions.ConnectionClosed:
@@ -463,7 +518,7 @@ class Orchestrator:
         finally:
             self.monitors.discard(websocket)
 
-    async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
+    async def _handle_admin(self, websocket: ServerConnection, auth_ticket):
         """Handle admin connection."""
         admin_id = getattr(auth_ticket, "name", "admin")
         logger.info(f"Admin {admin_id} connected")
@@ -485,7 +540,7 @@ class Orchestrator:
         except websockets.exceptions.ConnectionClosed:
             logger.info(f"Admin {admin_id} disconnected")
 
-    async def _handle_data_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
+    async def _handle_data_worker(self, websocket: ServerConnection, auth_ticket):
         """Handle data worker connection."""
         worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
         self.data_workers[worker_id] = websocket
@@ -554,7 +609,7 @@ class Orchestrator:
         finally:
             del self.data_workers[worker_id]
 
-    async def _send_monitor_initial_data(self, websocket: WebSocketServerProtocol):
+    async def _send_monitor_initial_data(self, websocket: ServerConnection):
         """Send initial data to monitor in a separate task to avoid blocking."""
         total_start = time.time()
         try:
@@ -611,7 +666,7 @@ class Orchestrator:
         except Exception as e:
             logger.error(f"Error sending initial monitor data: {e}")
 
-    async def _send_monitor_leaderboard(self, websocket: WebSocketServerProtocol):
+    async def _send_monitor_leaderboard(self, websocket: ServerConnection):
         """Send leaderboard data to a specific monitor."""
         total_start = time.time()
         try:
@@ -676,7 +731,7 @@ class Orchestrator:
         except Exception as e:
             logger.error(f"Error sending leaderboard to monitor: {e}")
 
-    async def _send_monitor_stats(self, websocket: WebSocketServerProtocol):
+    async def _send_monitor_stats(self, websocket: ServerConnection):
         """Send current stats to a monitor."""
         # Get processor stats
         processor_stats = self.processor.get_stats()
@@ -788,7 +843,7 @@ class Orchestrator:
         # Remove disconnected
         disconnected = {
             m
-            for m, r in zip(monitors_copy, results)
+            for m, r in zip(monitors_copy, results, strict=False)
             if r is not None and not isinstance(r, Exception)
         }
         self.monitors -= disconnected
@@ -840,39 +895,54 @@ class Orchestrator:
         self.monitors -= disconnected
 
     async def _heartbeat_loop(self):
-        """Send periodic heartbeats to maintain connections."""
+        """Collect and log worker status periodically."""
         while True:
             await asyncio.sleep(30)
 
-            disconnected = []
+            # Just collect status - no ping/pong
+            active_workers = []
             for worker_id, ws in list(self.workers.items()):
-                try:
-                    pong_waiter = await ws.ping()
-                    await asyncio.wait_for(pong_waiter, timeout=10)
-                except:
-                    disconnected.append(worker_id)
-
-            # Clean up disconnected workers
-            for worker_id in disconnected:
-                logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
-                if worker_id in self.workers:
+                # Check if WebSocket is still open (don't ping)
+                if ws.state == websockets.protocol.State.OPEN:
+                    active_workers.append(worker_id)
+                else:
+                    # Clean up closed connections
+                    logger.info(f"Worker {worker_id} connection closed")
                     del self.workers[worker_id]
-                    logger.warning(
-                        f"Releasing assignments for worker {worker_id} because it did not respond to ping"
-                    )
                     self.processor.release_assignments(worker_id)
-            self.stats["connected_workers"] = len(self.workers)
+
+            # Log status
+            if active_workers:
+                logger.debug(
+                    f"Inactive workers: {len(self.workers) - len(active_workers)}/{len(active_workers)} - {', '.join(active_workers[:5])}"
+                )
+            # add to self.stats
+            self.stats["active_workers"] = len(active_workers)
+            self.stats["inactive_workers"] = len(self.workers) - len(active_workers)
 
     async def _checkpoint_loop(self):
-        """Periodically checkpoint storage."""
+        """Periodically checkpoint storage and chunk tracker."""
        interval = self.config.get("storage", {}).get("checkpoint_interval", 60)
 
        while True:
            await asyncio.sleep(interval)
 
-            await self.storage.checkpoint()
-            self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
-            logger.info("Storage checkpoint complete")
+            try:
+                # Checkpoint storage
+                await self.storage.checkpoint()
+
+                # Also checkpoint the chunk tracker if using webdataset processor
+                if hasattr(self.processor, "chunk_tracker") and self.processor.chunk_tracker:
+                    # Save checkpoint in thread pool to avoid blocking
+                    await asyncio.get_event_loop().run_in_executor(
+                        None, self.processor.chunk_tracker.save
+                    )
+                    logger.debug("Saved chunk tracker checkpoint")
+
+                self.stats["last_checkpoint"] = datetime.now(_datetime.UTC).isoformat()
+                logger.info("Storage and chunk tracker checkpoint complete")
+            except Exception as e:
+                logger.error(f"Error during checkpoint: {e}", exc_info=True)
 
    async def _stats_update_loop(self):
        """Periodically update and broadcast stats."""
caption_flow/processors/__init__.py CHANGED
@@ -1,11 +1,11 @@
 from .base import (
     OrchestratorProcessor,
-    WorkerProcessor,
     ProcessorConfig,
-    WorkUnit,
     WorkAssignment,
+    WorkerProcessor,
     WorkResult,
+    WorkUnit,
 )
 from .huggingface import HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor
-from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
 from .local_filesystem import LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor
+from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
caption_flow/processors/base.py CHANGED
@@ -2,9 +2,8 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Dict, Any, List, Optional, Iterator, Tuple
 from datetime import datetime
-from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional
 
 
 @dataclass
@@ -98,9 +97,7 @@ class WorkResult:
         return self.error is None and bool(self.outputs)
 
     def to_repr(self, filter_outputs: bool = True):
-        """
-        Print the WorkResult, optionally without captions to save on screen wall-of-text dumpage.
-        """
+        """Print the WorkResult, optionally without captions to save on screen wall-of-text dumpage."""
         if filter_outputs:
             outputs = "...filtered from logs..."
         else:
@@ -172,6 +169,8 @@ class OrchestratorProcessor(ABC):
 class WorkerProcessor(ABC):
     """Base processor for worker side - processes work units."""
 
+    gpu_id: Optional[int] = None
+
     @abstractmethod
     def initialize(self, config: ProcessorConfig) -> None:
         """Initialize the processor with configuration."""
@@ -179,18 +178,20 @@ class WorkerProcessor(ABC):
 
     @abstractmethod
     def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
-        """
-        Process a single work unit, yielding items to be captioned.
+        """Process a single work unit, yielding items to be captioned.
 
         Args:
+        ----
            unit: The work unit to process
            context: Runtime context (e.g., models, sampling params)
 
         Yields:
+        ------
            Dict containing:
            - image: PIL Image
            - metadata: Dict of metadata
            - item_key: Unique identifier for this item
+
        """
        pass
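
For context, a minimal `process_unit` implementation satisfying the documented contract might look like the sketch below; the class name and the `unit.data` layout are illustrative assumptions, not part of the package API:

    from typing import Any, Dict, Iterator

    from PIL import Image

    class ExampleWorkerProcessor(WorkerProcessor):
        def initialize(self, config: ProcessorConfig) -> None:
            self.size = 64

        def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
            # Yield one placeholder item per sample index carried by the unit
            # ("sample_indices" is a hypothetical field used for illustration).
            for idx in unit.data.get("sample_indices", []):
                yield {
                    "image": Image.new("RGB", (self.size, self.size)),
                    "metadata": {"item_index": idx},
                    "item_key": f"{unit.unit_id}:{idx}",
                }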