caption-flow 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
+ from .base import (
+     OrchestratorProcessor,
+     WorkerProcessor,
+     ProcessorConfig,
+     WorkUnit,
+     WorkAssignment,
+     WorkResult,
+ )
+ from .huggingface import HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor
+ from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
+ from .local_filesystem import LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor
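
The new `__init__.py` re-exports the shared abstractions alongside one orchestrator/worker pair per backend. A minimal sketch of how a caller might select a pair by name; the `caption_flow.processors` module path and the string keys are assumptions, not confirmed by this diff:

```python
# Hypothetical dispatch table -- module path and type keys are assumptions.
from caption_flow.processors import (
    HuggingFaceDatasetOrchestratorProcessor,
    HuggingFaceDatasetWorkerProcessor,
    WebDatasetOrchestratorProcessor,
    WebDatasetWorkerProcessor,
    LocalFilesystemOrchestratorProcessor,
    LocalFilesystemWorkerProcessor,
)

PROCESSOR_PAIRS = {
    "huggingface": (HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor),
    "webdataset": (WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor),
    "local_filesystem": (LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor),
}

orchestrator_cls, worker_cls = PROCESSOR_PAIRS["webdataset"]
```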
@@ -0,0 +1,219 @@
+ """Base processor abstractions for data source handling."""
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from typing import Dict, Any, List, Optional, Iterator, Tuple
+ from datetime import datetime
+ from pathlib import Path
+
+
+ @dataclass
+ class WorkUnit:
+     """Generic unit of work that can be processed."""
+
+     unit_id: str  # usually, but not always, the chunk id
+     chunk_id: str  # always the chunk id
+     source_id: str  # the shard name
+     data: Dict[str, Any]
+     metadata: Dict[str, Any] = field(default_factory=dict)
+     priority: int = 0
+     sample_id: str = ""
+
+     def get_size_hint(self) -> int:
+         """Get estimated size/complexity of this work unit."""
+         return self.metadata.get("size_hint", 1)
+
+
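`get_size_hint` lets the orchestrator weigh units without inspecting `data`. A small illustrative sketch; the field values and import path are assumptions:

```python
from caption_flow.processors import WorkUnit  # import path is an assumption

unit = WorkUnit(
    unit_id="shard-0001:chunk:0",
    chunk_id="shard-0001:chunk:0",
    source_id="shard-0001",
    data={"start_index": 0, "chunk_size": 1000},
    metadata={"size_hint": 1000},
)
assert unit.get_size_hint() == 1000  # falls back to 1 when no hint is set
```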
+ @dataclass
+ class WorkAssignment:
+     """Assignment of work units to a worker."""
+
+     assignment_id: str
+     worker_id: str
+     units: List[WorkUnit]
+     assigned_at: datetime
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dict for network transmission."""
+         return {
+             "assignment_id": self.assignment_id,
+             "worker_id": self.worker_id,
+             "units": [
+                 {
+                     "unit_id": u.unit_id,
+                     "source_id": u.source_id,
+                     "chunk_id": u.chunk_id,
+                     "data": u.data,
+                     "metadata": u.metadata,
+                     "priority": u.priority,
+                 }
+                 for u in self.units
+             ],
+             "assigned_at": self.assigned_at.isoformat(),
+             "metadata": self.metadata,
+         }
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "WorkAssignment":
+         """Create from dict received over network."""
+         units = [
+             WorkUnit(
+                 unit_id=u["unit_id"],
+                 chunk_id=u["chunk_id"],
+                 source_id=u["source_id"],
+                 data=u["data"],
+                 metadata=u.get("metadata", {}),
+                 priority=u.get("priority", 0),
+             )
+             for u in data["units"]
+         ]
+         return cls(
+             assignment_id=data["assignment_id"],
+             worker_id=data["worker_id"],
+             units=units,
+             assigned_at=datetime.fromisoformat(data["assigned_at"]),
+             metadata=data.get("metadata", {}),
+         )
+
+
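`to_dict`/`from_dict` are symmetric, so an assignment should survive a JSON round trip. A quick sanity sketch; the import path is assumed:

```python
import json
from datetime import datetime
from caption_flow.processors import WorkAssignment, WorkUnit  # path assumed

assignment = WorkAssignment(
    assignment_id="a-1",
    worker_id="worker-1",
    units=[WorkUnit(unit_id="u-1", chunk_id="c-1", source_id="s-1", data={})],
    assigned_at=datetime.now(),
)
wire = json.dumps(assignment.to_dict())  # what goes over the socket
restored = WorkAssignment.from_dict(json.loads(wire))
assert restored.units[0].unit_id == "u-1"
assert restored.assigned_at == assignment.assigned_at
```

Note that per-unit `sample_id` is not serialized by `to_dict`, so it resets to its empty-string default after a round trip.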
+ @dataclass
+ class WorkResult:
+     """Result from processing a work unit."""
+
+     unit_id: str
+     source_id: str
+     chunk_id: str
+     sample_id: str
+     outputs: Dict[str, List[Any]]  # field_name -> list of outputs
+     dataset: Optional[str] = None
+     metadata: Dict[str, Any] = field(default_factory=dict)
+     processing_time_ms: float = 0
+     error: Optional[str] = None
+
+     def is_success(self) -> bool:
+         return self.error is None and bool(self.outputs)
+
+     def to_repr(self, filter_outputs: bool = True) -> Dict[str, Any]:
+         """
+         Return a dict representation of the WorkResult, optionally omitting outputs to avoid wall-of-text log dumps.
+         """
+         if filter_outputs:
+             outputs = "...filtered from logs..."
+         else:
+             outputs = self.outputs
+
+         return {
+             "unit_id": self.unit_id,
+             "source_id": self.source_id,
+             "chunk_id": self.chunk_id,
+             "sample_id": self.sample_id,
+             "outputs": outputs,
+             "metadata": self.metadata,
+             "processing_time_ms": self.processing_time_ms,
+             "error": self.error,
+         }
+
+
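With the default `filter_outputs=True`, `to_repr` replaces the potentially large outputs with a placeholder string. A short sketch; the import path is assumed:

```python
from caption_flow.processors import WorkResult  # path assumed

result = WorkResult(
    unit_id="u-1", source_id="s-1", chunk_id="c-1", sample_id="42",
    outputs={"captions": ["a red bicycle leaning against a brick wall"]},
    processing_time_ms=120.5,
)
assert result.is_success()
print(result.to_repr())                      # outputs replaced by placeholder
print(result.to_repr(filter_outputs=False))  # full outputs included
```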
+ @dataclass
+ class ProcessorConfig:
+     """Configuration for a processor."""
+
+     processor_type: str
+     config: Dict[str, Any]
+
+
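`ProcessorConfig` is a thin typed envelope around a free-form dict, so each backend can define its own keys. An illustrative instance; the keys and import path are invented for the example:

```python
from caption_flow.processors import ProcessorConfig  # path assumed

config = ProcessorConfig(
    processor_type="webdataset",
    config={"shard_pattern": "shards/{00000..00099}.tar", "chunk_size": 1000},
)
```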
+ class OrchestratorProcessor(ABC):
+     """Base processor for orchestrator side - manages work distribution."""
+
+     @abstractmethod
+     def initialize(self, config: ProcessorConfig) -> None:
+         """Initialize the processor with configuration."""
+         pass
+
+     @abstractmethod
+     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
+         """Get available work units for a worker."""
+         pass
+
+     @abstractmethod
+     def mark_completed(self, unit_id: str, worker_id: str) -> None:
+         """Mark a work unit as completed."""
+         pass
+
+     @abstractmethod
+     def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
+         """Mark a work unit as failed."""
+         pass
+
+     @abstractmethod
+     def release_assignments(self, worker_id: str) -> None:
+         """Release all assignments for a disconnected worker."""
+         pass
+
+     @abstractmethod
+     def get_stats(self) -> Dict[str, Any]:
+         """Get processor statistics."""
+         pass
+
+     def handle_result(self, result: WorkResult) -> Dict[str, Any]:
+         """Handle a work result - can be overridden for custom processing."""
+         return {
+             "unit_id": result.unit_id,
+             "source_id": result.source_id,
+             "outputs": result.outputs,
+             "metadata": result.metadata,
+         }
+
+
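A concrete orchestrator processor mostly has to track which units are pending, assigned, and finished. Below is a deliberately minimal in-memory sketch that satisfies the abstract interface; it is not the package's actual implementation, and the import path and `num_units` key are assumptions:

```python
from collections import deque
from typing import Any, Dict, List

from caption_flow.processors import (  # path assumed
    OrchestratorProcessor, ProcessorConfig, WorkUnit,
)

class InMemoryOrchestratorProcessor(OrchestratorProcessor):
    """Toy implementation: units are created up front and handed out FIFO."""

    def initialize(self, config: ProcessorConfig) -> None:
        n = config.config.get("num_units", 10)  # hypothetical config key
        self.pending: deque = deque(
            WorkUnit(unit_id=f"u-{i}", chunk_id=f"c-{i}", source_id="demo", data={})
            for i in range(n)
        )
        self.assigned: Dict[str, List[WorkUnit]] = {}  # worker_id -> held units
        self.completed: set = set()
        self.failed: Dict[str, str] = {}               # unit_id -> error message

    def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
        units = [self.pending.popleft() for _ in range(min(count, len(self.pending)))]
        self.assigned.setdefault(worker_id, []).extend(units)
        return units

    def _drop(self, unit_id: str, worker_id: str) -> None:
        held = self.assigned.get(worker_id, [])
        self.assigned[worker_id] = [u for u in held if u.unit_id != unit_id]

    def mark_completed(self, unit_id: str, worker_id: str) -> None:
        self.completed.add(unit_id)
        self._drop(unit_id, worker_id)

    def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
        self.failed[unit_id] = error
        self._drop(unit_id, worker_id)

    def release_assignments(self, worker_id: str) -> None:
        # Put everything the disconnected worker still held back on the queue.
        self.pending.extend(self.assigned.pop(worker_id, []))

    def get_stats(self) -> Dict[str, Any]:
        return {
            "pending": len(self.pending),
            "assigned": sum(len(v) for v in self.assigned.values()),
            "completed": len(self.completed),
            "failed": len(self.failed),
        }
```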
+ class WorkerProcessor(ABC):
+     """Base processor for worker side - processes work units."""
+
+     @abstractmethod
+     def initialize(self, config: ProcessorConfig) -> None:
+         """Initialize the processor with configuration."""
+         pass
+
+     @abstractmethod
+     def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+         """
+         Process a single work unit, yielding items to be captioned.
+
+         Args:
+             unit: The work unit to process
+             context: Runtime context (e.g., models, sampling params)
+
+         Yields:
+             Dict containing:
+                 - image: PIL Image
+                 - metadata: Dict of metadata
+                 - item_key: Unique identifier for this item
+         """
+         pass
+
+     def prepare_result(
+         self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
+     ) -> WorkResult:
+         """Prepare a work result from processed outputs."""
+         # Aggregate outputs by field
+         aggregated: Dict[str, List[Any]] = {}
+         for output in outputs:
+             for field_name, values in output.items():
+                 if field_name not in aggregated:
+                     aggregated[field_name] = []
+                 aggregated[field_name].extend(values if isinstance(values, list) else [values])
+
+         return WorkResult(
+             unit_id=unit.unit_id,
+             source_id=unit.source_id,
+             chunk_id=unit.chunk_id,
+             sample_id=unit.sample_id,
+             outputs=aggregated,
+             metadata={"item_count": len(outputs), **unit.metadata},
+             processing_time_ms=processing_time_ms,
+         )
+
+     @abstractmethod
+     def get_dataset_info(self) -> Dict[str, Any]:
+         """Get information about the dataset/source being processed."""
+         pass
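
On the worker side, the intended flow is: iterate `process_unit`, caption each yielded item, collect one output dict per item, then call `prepare_result`. A hedged end-to-end sketch with the captioning step stubbed out; the subclass, config keys, and import path are all illustrative:

```python
import time
from typing import Any, Dict, Iterator

from caption_flow.processors import (  # path assumed
    ProcessorConfig, WorkerProcessor, WorkUnit,
)

class EchoWorkerProcessor(WorkerProcessor):
    """Toy worker: yields items stored directly in the unit's data."""

    def initialize(self, config: ProcessorConfig) -> None:
        self.dataset_name = config.config.get("dataset", "demo")

    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        # Real processors yield {"image": <PIL Image>, ...} per the base docstring;
        # plain strings stand in for images here.
        for i, item in enumerate(unit.data.get("items", [])):
            yield {"image": item, "metadata": unit.metadata, "item_key": f"{unit.unit_id}:{i}"}

    def get_dataset_info(self) -> Dict[str, Any]:
        return {"dataset": self.dataset_name}

processor = EchoWorkerProcessor()
processor.initialize(ProcessorConfig(processor_type="demo", config={}))
unit = WorkUnit(unit_id="u-1", chunk_id="c-1", source_id="demo",
                data={"items": ["img-a", "img-b"]})

start = time.monotonic()
outputs = []
for item in processor.process_unit(unit, context={}):
    # A real worker would run a captioning model on item["image"] here.
    outputs.append({"captions": [f"caption for {item['item_key']}"]})
elapsed_ms = (time.monotonic() - start) * 1000

result = processor.prepare_result(unit, outputs, elapsed_ms)
assert result.metadata["item_count"] == 2
```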