caption-flow 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +921 -427
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +2 -3
  5. caption_flow/orchestrator.py +153 -104
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +463 -68
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +28 -22
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +30 -28
  18. caption_flow/utils/chunk_tracker.py +153 -56
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +5 -4
  25. caption_flow/workers/caption.py +303 -92
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA +9 -4
  28. caption_flow-0.4.1.dist-info/RECORD +33 -0
  29. caption_flow-0.3.4.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/top_level.txt +0 -0
caption_flow/models.py CHANGED
@@ -1,12 +1,18 @@
  """Data models for CaptionFlow."""

- import PIL
+ import datetime as _datetime
+ import logging
+ import os
  from dataclasses import dataclass, field
  from datetime import datetime
  from enum import Enum
  from typing import Any, Dict, List, Optional, Tuple
+
  from PIL import Image

+ logger = logging.getLogger(__name__)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
+

  class JobStatus(Enum):
      """Job processing status."""
@@ -37,7 +43,7 @@ class Job:

      def __post_init__(self):
          if self.created_at is None:
-             self.created_at = datetime.utcnow()
+             self.created_at = datetime.now(_datetime.UTC)


  @dataclass
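The deprecated, naive datetime.utcnow() is replaced by the timezone-aware datetime.now(datetime.UTC) throughout this release. The datetime.UTC alias only exists on Python 3.11+; a quick sketch of the difference:

    import datetime as _datetime
    from datetime import datetime, timezone

    naive = datetime.utcnow()            # deprecated since 3.12; tzinfo is None
    aware = datetime.now(_datetime.UTC)  # _datetime.UTC is timezone.utc (3.11+)
    assert naive.tzinfo is None
    assert aware.tzinfo is timezone.utc
    # On older interpreters, datetime.now(timezone.utc) is the equivalent spelling.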
@@ -69,7 +75,43 @@ class JobId:
          parts = job_id.split(":")
          if len(parts) != 5:
              raise ValueError(f"Invalid job_id format: {job_id}")
-         return JobId(shard_id=parts[0], chunk_id=parts[2], sample_id=parts[4])
+
+         shard_id = parts[0]
+         chunk_keyword = parts[1]
+         chunk_id = parts[2]
+         idx_keyword = parts[3]
+         sample_id = parts[4]
+
+         # Validate format
+         if not shard_id:
+             raise ValueError(f"Invalid job_id format: empty shard_id in {job_id}")
+         if chunk_keyword != "chunk":
+             raise ValueError(
+                 f"Invalid job_id format: expected 'chunk' keyword, got '{chunk_keyword}' in {job_id}"
+             )
+         if idx_keyword != "idx":
+             raise ValueError(
+                 f"Invalid job_id format: expected 'idx' keyword, got '{idx_keyword}' in {job_id}"
+             )
+
+         # Validate numeric fields
+         try:
+             int(chunk_id)
+         except ValueError:
+             raise ValueError(
+                 f"Invalid job_id format: chunk_id must be numeric, got '{chunk_id}' in {job_id}"
+             )
+
+         # sample_id can be empty/None for some use cases, but if provided must be numeric
+         if sample_id:
+             try:
+                 int(sample_id)
+             except ValueError:
+                 raise ValueError(
+                     f"Invalid job_id format: sample_id must be numeric if provided, got '{sample_id}' in {job_id}"
+                 )
+
+         return JobId(shard_id=shard_id, chunk_id=chunk_id, sample_id=sample_id)


  @dataclass
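JobId.from_str previously trusted the positional parts; it now validates the full shard:chunk:&lt;n&gt;:idx:&lt;m&gt; shape and raises descriptive ValueErrors. A usage sketch (the id strings below are invented for illustration):

    from caption_flow.models import JobId

    job = JobId.from_str("shard-0001:chunk:42:idx:7")
    print(job.shard_id, job.chunk_id, job.sample_id)  # shard-0001 42 7

    # Malformed ids now fail loudly instead of silently mis-parsing:
    JobId.from_str("shard-0001:part:42:idx:7")
    # ValueError: Invalid job_id format: expected 'chunk' keyword, got 'part' ...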
caption_flow/monitor.py CHANGED
@@ -4,9 +4,8 @@ import asyncio
  import json
  import logging
  import ssl
- import time
  from datetime import datetime
- from typing import Dict, Any, List, Optional
+ from typing import Any, Dict, Optional

  import websockets
  from rich.console import Console
@@ -110,7 +109,7 @@ class Monitor:
          """Main display update loop."""
          layout = self._create_layout()

-         with Live(layout, console=self.console, refresh_per_second=1, screen=True) as live:
+         with Live(layout, console=self.console, refresh_per_second=1, screen=True):
              while self.running:
                  self._update_layout(layout)
                  await asyncio.sleep(0.25)
caption_flow/orchestrator.py CHANGED
@@ -1,34 +1,34 @@
- import time
  import asyncio
+ import datetime as _datetime
  import json
  import logging
+ import os
  import ssl
+ import time
  import uuid
+ from collections import defaultdict
  from datetime import datetime
  from pathlib import Path
- from typing import Dict, Set, Optional, Any, List
- from collections import defaultdict
- import threading
+ from typing import Any, Dict, Optional, Set

  import websockets
- from websockets.server import WebSocketServerProtocol
+ from websockets.asyncio.server import ServerConnection

- from .storage import StorageManager
  from .models import Caption, Contributor, JobId
- from .utils.auth import AuthManager
- from .utils.json_utils import safe_json_dumps
  from .processors import (
+     HuggingFaceDatasetOrchestratorProcessor,
+     LocalFilesystemOrchestratorProcessor,
      ProcessorConfig,
+     WebDatasetOrchestratorProcessor,
      WorkAssignment,
      WorkResult,
-     WorkUnit,
-     WebDatasetOrchestratorProcessor,
-     HuggingFaceDatasetOrchestratorProcessor,
-     LocalFilesystemOrchestratorProcessor,
  )
+ from .storage import StorageManager
+ from .utils.auth import AuthManager
+ from .utils.json_utils import safe_json_dumps

  logger = logging.getLogger(__name__)
- logger.setLevel(logging.INFO)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())


  class Orchestrator:
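The connection-type swap follows the websockets library itself: WebSocketServerProtocol comes from the deprecated legacy implementation, while ServerConnection is the handler argument type of the newer asyncio implementation (websockets >= 13). A minimal sketch of a server typed against it; the echo handler and port are illustrative:

    import asyncio

    from websockets.asyncio.server import ServerConnection, serve

    async def handler(websocket: ServerConnection):
        async for message in websocket:
            await websocket.send(message)  # echo back

    async def main():
        async with serve(handler, "localhost", 8765):
            await asyncio.Future()  # run until cancelled

    asyncio.run(main())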
@@ -69,8 +69,8 @@ class Orchestrator:
          self.chunks_per_request = config.get("chunks_per_request", 2)

          # Track connections
-         self.workers: Dict[str, WebSocketServerProtocol] = {}
-         self.monitors: Set[WebSocketServerProtocol] = set()
+         self.workers: Dict[str, ServerConnection] = {}
+         self.monitors: Set[ServerConnection] = set()
          self.workers_by_user = defaultdict(set)

          # SSL configuration
@@ -160,11 +160,11 @@ class Orchestrator:
          processed_job_ids = self.storage.get_all_processed_job_ids()
          self.processor.update_from_storage(processed_job_ids)

-     async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+     async def _send_leaderboard_to_monitor(self, websocket: ServerConnection):
          """Alias for _send_monitor_leaderboard for backward compatibility."""
          await self._send_monitor_leaderboard(websocket)

-     async def handle_connection(self, websocket: WebSocketServerProtocol):
+     async def handle_connection(self, websocket: ServerConnection):
          """Handle new WebSocket connection."""
          try:
              # Authenticate
@@ -193,7 +193,7 @@ class Orchestrator:
              logger.error(f"Connection error: {e}", exc_info=True)
              await websocket.close()

-     async def _handle_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
+     async def _handle_worker(self, websocket: ServerConnection, auth_ticket):
          """Handle worker connection lifecycle."""
          # Generate unique worker ID
          base_name = getattr(auth_ticket, "name", "worker")
@@ -250,9 +250,38 @@ class Orchestrator:
          self.processor.release_assignments(worker_id)
          logger.info(f"Worker {worker_id} has safely disconnected")

-     async def _handle_config_reload(
-         self, websocket: WebSocketServerProtocol, new_config: Dict[str, Any]
-     ):
+     def _auth_configs_equal(
+         self, current_config: Dict[str, Any], new_config: Dict[str, Any]
+     ) -> bool:
+         """Compare two auth configurations for equality."""
+
+         # Helper function to normalize token lists for comparison
+         def normalize_tokens(tokens):
+             if not tokens:
+                 return []
+             # Sort by token for consistent comparison
+             return sorted(
+                 [{"name": t.get("name"), "token": t.get("token")} for t in tokens],
+                 key=lambda x: x.get("token", ""),
+             )
+
+         # Compare each token type
+         current_workers = normalize_tokens(current_config.get("worker_tokens", []))
+         new_workers = normalize_tokens(new_config.get("worker_tokens", []))
+
+         current_admins = normalize_tokens(current_config.get("admin_tokens", []))
+         new_admins = normalize_tokens(new_config.get("admin_tokens", []))
+
+         current_monitors = normalize_tokens(current_config.get("monitor_tokens", []))
+         new_monitors = normalize_tokens(new_config.get("monitor_tokens", []))
+
+         return (
+             current_workers == new_workers
+             and current_admins == new_admins
+             and current_monitors == new_monitors
+         )
+
+     async def _handle_config_reload(self, websocket: ServerConnection, new_config: Dict[str, Any]):
          """Handle configuration reload request."""
          logger.info("Processing configuration reload request")

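The new _auth_configs_equal lets the reload path below skip rebuilding the AuthManager when a config reload did not actually touch the tokens. The comparison is order-insensitive and ignores unknown keys, because each list is reduced to (name, token) pairs sorted by token, as this standalone sketch of the same helper shows:

    def normalize_tokens(tokens):
        if not tokens:
            return []
        return sorted(
            [{"name": t.get("name"), "token": t.get("token")} for t in tokens],
            key=lambda x: x.get("token", ""),
        )

    a = [{"name": "w1", "token": "abc"}, {"name": "w2", "token": "xyz"}]
    b = [{"name": "w2", "token": "xyz"}, {"name": "w1", "token": "abc", "note": "x"}]
    assert normalize_tokens(a) == normalize_tokens(b)  # order and extra keys ignored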
@@ -293,8 +322,16 @@ class Orchestrator:
          # Update auth configuration
          if "auth" in orchestrator_config:
              try:
-                 self.auth = AuthManager(orchestrator_config["auth"])
-                 updated_sections.append("auth")
+                 current_auth_config = self.config.get("auth", {})
+                 new_auth_config = orchestrator_config["auth"]
+
+                 # Only recreate AuthManager if auth config has actually changed
+                 if not self._auth_configs_equal(current_auth_config, new_auth_config):
+                     self.auth = AuthManager(new_auth_config)
+                     updated_sections.append("auth")
+                     logger.info("Auth configuration updated due to changes")
+                 else:
+                     logger.info("Auth configuration unchanged, preserving existing AuthManager")
              except Exception as e:
                  logger.error(f"Failed to update AuthManager: {e}")
                  warnings.append(f"Auth update failed: {e}")
@@ -344,7 +381,7 @@ class Orchestrator:
              assignment_id=str(uuid.uuid4()),
              worker_id=worker_id,
              units=units,
-             assigned_at=datetime.utcnow(),
+             assigned_at=datetime.now(_datetime.UTC),
          )

          await self.workers[worker_id].send(
@@ -374,80 +411,93 @@ class Orchestrator:
          logger.debug(f"Heartbeat from {worker_id}: {data}")

      async def _handle_results_submission(self, worker_id: str, data: Dict):
-         """Process results submission from worker."""
-         # Extract user from worker_id
-         worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
-
-         # Create work result
-         _job_id = data.get("job_id")
-         job_id = JobId.from_str(_job_id)
-         shard_name = job_id.shard_id
-         chunk_name = job_id.chunk_id
-
-         result = WorkResult(
-             unit_id=data["unit_id"],
-             source_id=shard_name,
-             chunk_id=job_id.get_chunk_str(),
-             sample_id=data["sample_id"],
-             dataset=data["dataset"],
-             outputs=data["outputs"],
-             metadata=data.get("metadata", {}),
-             processing_time_ms=data.get("processing_time_ms", 0),
-         )
+         """Process results submission from worker - fires off async task and returns immediately."""
+         # Fire and forget - process in background
+         asyncio.create_task(self._process_result_async(worker_id, data))

-         # Let processor handle any custom processing - this updates chunk tracker
-         # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
-         # regardless of whether the item is a duplicate
-         processed = self.processor.handle_result(result)
-
-         # Create caption record for storage
-         total_outputs = sum(len(v) for v in result.outputs.values())
-
-         filename = result.metadata.pop("_filename", None)
-         url = result.metadata.pop("_url", None)
-         image_height = result.metadata.pop("image_height", None)
-         image_width = result.metadata.pop("image_width", None)
-         file_size = result.metadata.pop("file_size", None)
-         image_format = result.metadata.pop("image_format", None)
-         item_index = result.metadata.pop("item_index", None)
-         item_key = result.metadata.pop("item_key", None)
-         to_delete_metadata_keys = ["_image_format", "_job_id"]
-         for key in to_delete_metadata_keys:
-             if key in result.metadata:
-                 del result.metadata[key]
-
-         caption = Caption(
-             job_id=job_id,
-             dataset=result.dataset,
-             shard=processed["source_id"],
-             chunk_id=chunk_name,
-             item_key=item_key,
-             captions=result.outputs.get("captions", []),
-             outputs=result.outputs,
-             contributor_id=worker_user,
-             timestamp=datetime.utcnow(),
-             caption_count=total_outputs,
-             processing_time_ms=result.processing_time_ms,
-             metadata=result.metadata,
-             image_height=image_height,
-             image_width=image_width,
-             filename=filename,
-             url=url,
-             file_size=file_size,
-             image_format=image_format,
-         )
+     async def _process_result_async(self, worker_id: str, data: Dict):
+         """Actually process the result in background."""
+         try:
+             # Extract user from worker_id
+             worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
+             # Create work result
+             _job_id = data.get("job_id")
+             job_id = JobId.from_str(_job_id)
+             shard_name = job_id.shard_id
+             chunk_name = job_id.chunk_id
+
+             result = WorkResult(
+                 unit_id=data["unit_id"],
+                 source_id=shard_name,
+                 chunk_id=job_id.get_chunk_str(),
+                 sample_id=data["sample_id"],
+                 dataset=data["dataset"],
+                 outputs=data["outputs"],
+                 metadata=data.get("metadata", {}),
+                 processing_time_ms=data.get("processing_time_ms", 0),
+             )
+
+             # Let processor handle any custom processing - this updates chunk tracker
+             # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
+             # regardless of whether the item is a duplicate
+             processed = self.processor.handle_result(result)
+
+             # Create caption record for storage
+             total_outputs = sum(len(v) for v in result.outputs.values())
+
+             filename = result.metadata.pop("_filename", None)
+             url = result.metadata.pop("_url", None)
+             image_height = result.metadata.pop("image_height", None)
+             image_width = result.metadata.pop("image_width", None)
+             file_size = result.metadata.pop("file_size", None)
+             image_format = result.metadata.pop("image_format", None)
+             result.metadata.pop("item_index", None)
+             item_key = result.metadata.pop("item_key", None)
+
+             to_delete_metadata_keys = ["_image_format", "_job_id"]
+             for key in to_delete_metadata_keys:
+                 if key in result.metadata:
+                     del result.metadata[key]
+
+             caption = Caption(
+                 job_id=job_id,
+                 dataset=result.dataset,
+                 shard=processed["source_id"],
+                 chunk_id=chunk_name,
+                 item_key=item_key,
+                 captions=result.outputs.get("captions", []),
+                 outputs=result.outputs,
+                 contributor_id=worker_user,
+                 timestamp=datetime.now(_datetime.UTC),
+                 caption_count=total_outputs,
+                 processing_time_ms=result.processing_time_ms,
+                 metadata=result.metadata,
+                 image_height=image_height,
+                 image_width=image_width,
+                 filename=filename,
+                 url=url,
+                 file_size=file_size,
+                 image_format=image_format,
+             )

-         # Save to storage (might skip if duplicate)
-         saved = await self.storage.save_caption(caption)
+             # Save to storage (might skip if duplicate)
+             saved = await self.storage.save_caption(caption)

-         # Update contributor stats only if actually saved
-         if saved:
-             contributor = await self.storage.get_contributor(worker_user)
-             if contributor:
-                 contributor.total_captions += total_outputs
-                 await self.storage.save_contributor(contributor)
+             # Update contributor stats only if actually saved
+             if saved:
+                 contributor = await self.storage.get_contributor(worker_user)
+                 if contributor:
+                     contributor.total_captions += total_outputs
+                     await self.storage.save_contributor(contributor)

-     async def _handle_monitor(self, websocket: WebSocketServerProtocol):
+         except Exception as e:
+             logger.error(
+                 f"Error processing result from {worker_id} for unit {data.get('unit_id', 'unknown')}: {e}",
+                 exc_info=True,
+             )
+
+     async def _handle_monitor(self, websocket: ServerConnection):
          """Handle monitor connection."""
          self.monitors.add(websocket)
          logger.info(f"Monitor connected (total: {len(self.monitors)})")
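Result submission is now fire-and-forget: the handler schedules _process_result_async and returns immediately, so a slow storage write can no longer stall the worker's message loop, and the surrounding try/except keeps background failures visible in the logs. One standard caveat with this pattern (not addressed in the diff): the event loop keeps only weak references to tasks, so callers that want a guarantee against mid-flight garbage collection typically hold them in a set:

    import asyncio

    background_tasks: set[asyncio.Task] = set()

    def fire_and_forget(coro) -> asyncio.Task:
        """Schedule coro, keeping a strong reference until it finishes."""
        task = asyncio.create_task(coro)  # requires a running event loop
        background_tasks.add(task)        # strong reference
        task.add_done_callback(background_tasks.discard)  # self-pruning
        return task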
@@ -460,7 +510,7 @@ class Orchestrator:
              await self._send_monitor_stats(websocket)

              # Keep connection alive
-             async for message in websocket:
+             async for _message in websocket:
                  pass

          except websockets.exceptions.ConnectionClosed:
@@ -468,7 +518,7 @@ class Orchestrator:
          finally:
              self.monitors.discard(websocket)

-     async def _handle_admin(self, websocket: WebSocketServerProtocol, auth_ticket):
+     async def _handle_admin(self, websocket: ServerConnection, auth_ticket):
          """Handle admin connection."""
          admin_id = getattr(auth_ticket, "name", "admin")
          logger.info(f"Admin {admin_id} connected")
@@ -490,7 +540,7 @@ class Orchestrator:
          except websockets.exceptions.ConnectionClosed:
              logger.info(f"Admin {admin_id} disconnected")

-     async def _handle_data_worker(self, websocket: WebSocketServerProtocol, auth_ticket):
+     async def _handle_data_worker(self, websocket: ServerConnection, auth_ticket):
          """Handle data worker connection."""
          worker_id = getattr(auth_ticket, "name", str(uuid.uuid4()))
          self.data_workers[worker_id] = websocket
@@ -559,7 +609,7 @@ class Orchestrator:
          finally:
              del self.data_workers[worker_id]

-     async def _send_monitor_initial_data(self, websocket: WebSocketServerProtocol):
+     async def _send_monitor_initial_data(self, websocket: ServerConnection):
          """Send initial data to monitor in a separate task to avoid blocking."""
          total_start = time.time()
          try:
@@ -616,7 +666,7 @@ class Orchestrator:
          except Exception as e:
              logger.error(f"Error sending initial monitor data: {e}")

-     async def _send_monitor_leaderboard(self, websocket: WebSocketServerProtocol):
+     async def _send_monitor_leaderboard(self, websocket: ServerConnection):
          """Send leaderboard data to a specific monitor."""
          total_start = time.time()
          try:
@@ -681,7 +731,7 @@ class Orchestrator:
          except Exception as e:
              logger.error(f"Error sending leaderboard to monitor: {e}")

-     async def _send_monitor_stats(self, websocket: WebSocketServerProtocol):
+     async def _send_monitor_stats(self, websocket: ServerConnection):
          """Send current stats to a monitor."""
          # Get processor stats
          processor_stats = self.processor.get_stats()
@@ -793,7 +843,7 @@ class Orchestrator:
          # Remove disconnected
          disconnected = {
              m
-             for m, r in zip(monitors_copy, results)
+             for m, r in zip(monitors_copy, results, strict=False)
              if r is not None and not isinstance(r, Exception)
          }
          self.monitors -= disconnected
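The strict=False keyword makes zip()'s default truncating behaviour explicit (the flag exists since Python 3.10); with strict=True a length mismatch raises instead:

    pairs = list(zip([1, 2, 3], ["a", "b"], strict=False))  # [(1, 'a'), (2, 'b')]

    try:
        list(zip([1, 2, 3], ["a", "b"], strict=True))
    except ValueError as e:
        print(e)  # zip() argument 2 is shorter than argument 1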
@@ -864,9 +914,8 @@ class Orchestrator:
          # Log status
          if active_workers:
              logger.debug(
-                 f"Active workers: {len(active_workers)} - {', '.join(active_workers[:5])}"
+                 f"Inactive workers: {len(self.workers) - len(active_workers)}/{len(active_workers)} - {', '.join(active_workers[:5])}"
              )
-             logger.debug(f"Inactive workers: {len(self.workers) - len(active_workers)}")
          # add to self.stats
          self.stats["active_workers"] = len(active_workers)
          self.stats["inactive_workers"] = len(self.workers) - len(active_workers)
@@ -890,7 +939,7 @@ class Orchestrator:
                  )
                  logger.debug("Saved chunk tracker checkpoint")

-             self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
+             self.stats["last_checkpoint"] = datetime.now(_datetime.UTC).isoformat()
              logger.info("Storage and chunk tracker checkpoint complete")
          except Exception as e:
              logger.error(f"Error during checkpoint: {e}", exc_info=True)
caption_flow/processors/__init__.py CHANGED
@@ -1,11 +1,11 @@
  from .base import (
      OrchestratorProcessor,
-     WorkerProcessor,
      ProcessorConfig,
-     WorkUnit,
      WorkAssignment,
+     WorkerProcessor,
      WorkResult,
+     WorkUnit,
  )
  from .huggingface import HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor
- from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
  from .local_filesystem import LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor
+ from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
caption_flow/processors/base.py CHANGED
@@ -2,9 +2,8 @@

  from abc import ABC, abstractmethod
  from dataclasses import dataclass, field
- from typing import Dict, Any, List, Optional, Iterator, Tuple
  from datetime import datetime
- from pathlib import Path
+ from typing import Any, Dict, Iterator, List, Optional


  @dataclass
@@ -98,9 +97,7 @@ class WorkResult:
          return self.error is None and bool(self.outputs)

      def to_repr(self, filter_outputs: bool = True):
-         """
-         Print the WorkResult, optionally without captions to save on screen wall-of-text dumpage.
-         """
+         """Print the WorkResult, optionally without captions to save on screen wall-of-text dumpage."""
          if filter_outputs:
              outputs = "...filtered from logs..."
          else:
@@ -172,6 +169,8 @@ class OrchestratorProcessor(ABC):
  class WorkerProcessor(ABC):
      """Base processor for worker side - processes work units."""

+     gpu_id: Optional[int] = None
+
      @abstractmethod
      def initialize(self, config: ProcessorConfig) -> None:
          """Initialize the processor with configuration."""
@@ -179,18 +178,20 @@ class WorkerProcessor(ABC):

      @abstractmethod
      def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
-         """
-         Process a single work unit, yielding items to be captioned.
+         """Process a single work unit, yielding items to be captioned.

          Args:
+         ----
              unit: The work unit to process
              context: Runtime context (e.g., models, sampling params)

          Yields:
+         ------
              Dict containing:
              - image: PIL Image
              - metadata: Dict of metadata
              - item_key: Unique identifier for this item
+
          """
          pass
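For orientation, a hypothetical subclass sketching the generator contract that the revised process_unit docstring spells out; the class name and the unit.data / unit.unit_id fields are assumptions for illustration, with the real implementations living in huggingface.py, webdataset.py, and local_filesystem.py:

    from typing import Any, Dict, Iterator

    from caption_flow.processors import ProcessorConfig, WorkerProcessor, WorkUnit

    class InMemoryWorkerProcessor(WorkerProcessor):  # hypothetical example
        def initialize(self, config: ProcessorConfig) -> None:
            self.config = config

        def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
            for i, image in enumerate(getattr(unit, "data", [])):  # field assumed
                yield {
                    "image": image,                 # a PIL Image in real processors
                    "metadata": {"item_index": i},
                    "item_key": f"{getattr(unit, 'unit_id', 'unit')}_{i}",
                }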