caption-flow 0.2.2-py3-none-any.whl → 0.2.4-py3-none-any.whl
This diff shows the content changes between publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- caption_flow/cli.py +308 -0
- caption_flow/models.py +134 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1715
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +489 -401
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +73 -32
- caption_flow/utils/dataset_loader.py +58 -298
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +5 -265
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/viewer.py +594 -0
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/METADATA +49 -180
- caption_flow-0.2.4.dist-info/RECORD +38 -0
- caption_flow-0.2.2.dist-info/RECORD +0 -29
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/top_level.txt +0 -0
caption_flow/processors/__init__.py
@@ -0,0 +1,11 @@
+from .base import (
+    OrchestratorProcessor,
+    WorkerProcessor,
+    ProcessorConfig,
+    WorkUnit,
+    WorkAssignment,
+    WorkResult,
+)
+from .huggingface import HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor
+from .webdataset import WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor
+from .local_filesystem import LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor
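The new processors package exposes a matched orchestrator-side and worker-side class for each of the three data sources. As a rough sketch of how a deployment might resolve a processor_type string to the matching pair (the dispatch table and its string keys below are hypothetical, not code from this package):

# Hypothetical dispatch table over the classes exported above.
from caption_flow.processors import (
    HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor,
    WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor,
    LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor,
)

PROCESSOR_PAIRS = {
    "huggingface": (HuggingFaceDatasetOrchestratorProcessor, HuggingFaceDatasetWorkerProcessor),
    "webdataset": (WebDatasetOrchestratorProcessor, WebDatasetWorkerProcessor),
    "local_filesystem": (LocalFilesystemOrchestratorProcessor, LocalFilesystemWorkerProcessor),
}

def resolve(processor_type: str):
    # Return the (orchestrator, worker) class pair for a given source type.
    return PROCESSOR_PAIRS[processor_type]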
caption_flow/processors/base.py
@@ -0,0 +1,219 @@
+"""Base processor abstractions for data source handling."""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Dict, Any, List, Optional, Iterator, Tuple
+from datetime import datetime
+from pathlib import Path
+
+
+@dataclass
+class WorkUnit:
+    """Generic unit of work that can be processed."""
+
+    unit_id: str  # usually, but not always, the chunk id
+    chunk_id: str  # always the chunk id
+    source_id: str  # the shard name
+    data: Dict[str, Any]
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    priority: int = 0
+    sample_id: str = ""
+
+    def get_size_hint(self) -> int:
+        """Get estimated size/complexity of this work unit."""
+        return self.metadata.get("size_hint", 1)
+
+
+@dataclass
+class WorkAssignment:
+    """Assignment of work units to a worker."""
+
+    assignment_id: str
+    worker_id: str
+    units: List[WorkUnit]
+    assigned_at: datetime
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dict for network transmission."""
+        return {
+            "assignment_id": self.assignment_id,
+            "worker_id": self.worker_id,
+            "units": [
+                {
+                    "unit_id": u.unit_id,
+                    "source_id": u.source_id,
+                    "chunk_id": u.chunk_id,
+                    "data": u.data,
+                    "metadata": u.metadata,
+                    "priority": u.priority,
+                }
+                for u in self.units
+            ],
+            "assigned_at": self.assigned_at.isoformat(),
+            "metadata": self.metadata,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "WorkAssignment":
+        """Create from dict received over network."""
+        units = [
+            WorkUnit(
+                unit_id=u["unit_id"],
+                chunk_id=u["chunk_id"],
+                source_id=u["source_id"],
+                data=u["data"],
+                metadata=u.get("metadata", {}),
+                priority=u.get("priority", 0),
+            )
+            for u in data["units"]
+        ]
+        return cls(
+            assignment_id=data["assignment_id"],
+            worker_id=data["worker_id"],
+            units=units,
+            assigned_at=datetime.fromisoformat(data["assigned_at"]),
+            metadata=data.get("metadata", {}),
+        )
+
+
+@dataclass
+class WorkResult:
+    """Result from processing a work unit."""
+
+    unit_id: str
+    source_id: str
+    chunk_id: str
+    sample_id: str
+    outputs: Dict[str, List[Any]]  # field_name -> list of outputs
+    dataset: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    processing_time_ms: float = 0
+    error: Optional[str] = None
+
+    def is_success(self) -> bool:
+        return self.error is None and bool(self.outputs)
+
+    def to_repr(self, filter_outputs: bool = True):
+        """
+        Return a dict view of the WorkResult, optionally omitting the outputs to avoid wall-of-text caption dumps in the logs.
+        """
+        if filter_outputs:
+            outputs = "...filtered from logs..."
+        else:
+            outputs = self.outputs
+
+        return {
+            "unit_id": self.unit_id,
+            "source_id": self.source_id,
+            "chunk_id": self.chunk_id,
+            "sample_id": self.sample_id,
+            "outputs": outputs,
+            "metadata": self.metadata,
+            "processing_time_ms": self.processing_time_ms,
+            "error": self.error,
+        }
+
+
+@dataclass
+class ProcessorConfig:
+    """Configuration for a processor."""
+
+    processor_type: str
+    config: Dict[str, Any]
+
+
+class OrchestratorProcessor(ABC):
+    """Base processor for orchestrator side - manages work distribution."""
+
+    @abstractmethod
+    def initialize(self, config: ProcessorConfig) -> None:
+        """Initialize the processor with configuration."""
+        pass
+
+    @abstractmethod
+    def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
+        """Get available work units for a worker."""
+        pass
+
+    @abstractmethod
+    def mark_completed(self, unit_id: str, worker_id: str) -> None:
+        """Mark a work unit as completed."""
+        pass
+
+    @abstractmethod
+    def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
+        """Mark a work unit as failed."""
+        pass
+
+    @abstractmethod
+    def release_assignments(self, worker_id: str) -> None:
+        """Release all assignments for a disconnected worker."""
+        pass
+
+    @abstractmethod
+    def get_stats(self) -> Dict[str, Any]:
+        """Get processor statistics."""
+        pass
+
+    def handle_result(self, result: WorkResult) -> Dict[str, Any]:
+        """Handle a work result - can be overridden for custom processing."""
+        return {
+            "unit_id": result.unit_id,
+            "source_id": result.source_id,
+            "outputs": result.outputs,
+            "metadata": result.metadata,
+        }
+
+
+class WorkerProcessor(ABC):
+    """Base processor for worker side - processes work units."""
+
+    @abstractmethod
+    def initialize(self, config: ProcessorConfig) -> None:
+        """Initialize the processor with configuration."""
+        pass
+
+    @abstractmethod
+    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+        """
+        Process a single work unit, yielding items to be captioned.
+
+        Args:
+            unit: The work unit to process
+            context: Runtime context (e.g., models, sampling params)
+
+        Yields:
+            Dict containing:
+            - image: PIL Image
+            - metadata: Dict of metadata
+            - item_key: Unique identifier for this item
+        """
+        pass
+
+    def prepare_result(
+        self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
+    ) -> WorkResult:
+        """Prepare a work result from processed outputs."""
+        # Aggregate outputs by field
+        aggregated = {}
+        for output in outputs:
+            for field, values in output.items():
+                if field not in aggregated:
+                    aggregated[field] = []
+                aggregated[field].extend(values if isinstance(values, list) else [values])
+
+        return WorkResult(
+            unit_id=unit.unit_id,
+            source_id=unit.source_id,
+            chunk_id=unit.chunk_id,
+            sample_id=unit.sample_id,
+            outputs=aggregated,
+            metadata={"item_count": len(outputs), **unit.metadata},
+            processing_time_ms=processing_time_ms,
+        )
+
+    @abstractmethod
+    def get_dataset_info(self) -> Dict[str, Any]:
+        """Get information about the dataset/source being processed."""
+        pass
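To show how these abstractions fit together, here is a minimal worker-side sketch. The InMemoryWorkerProcessor class and the demo wiring are hypothetical illustrations of the WorkerProcessor, WorkUnit, and WorkResult contracts in this diff, not code from the package.

from typing import Any, Dict, Iterator, List

from caption_flow.processors import ProcessorConfig, WorkerProcessor, WorkUnit

class InMemoryWorkerProcessor(WorkerProcessor):
    """Hypothetical processor serving items from a list held in memory."""

    def initialize(self, config: ProcessorConfig) -> None:
        # Real processors would open a shard or dataset here.
        self.items: List[Dict[str, Any]] = config.config.get("items", [])

    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        # Yield one dict per sample, matching the contract documented on the base class.
        for idx, item in enumerate(self.items):
            yield {
                "image": item.get("image"),  # a PIL Image in real use
                "metadata": item.get("metadata", {}),
                "item_key": f"{unit.unit_id}_{idx}",
            }

    def get_dataset_info(self) -> Dict[str, Any]:
        return {"type": "in_memory", "total_items": len(self.items)}

# Demo wiring: caption each yielded item, then aggregate via the inherited prepare_result.
proc = InMemoryWorkerProcessor()
proc.initialize(ProcessorConfig(processor_type="in_memory", config={"items": [{"image": None}]}))
unit = WorkUnit(unit_id="demo:chunk:0", chunk_id="demo:chunk:0", source_id="demo", data={})
outputs = [{"captions": [f"caption for {item['item_key']}"]} for item in proc.process_unit(unit, {})]
result = proc.prepare_result(unit, outputs, processing_time_ms=1.5)
assert result.is_success()

Because prepare_result and WorkAssignment.to_dict/from_dict are concrete on the base classes, a subclass only has to implement the three abstract methods above; serialization of assignments across the network boundary comes for free.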