caption-flow 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +2 -1
- caption_flow/models.py +108 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1595
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage.py +415 -406
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +94 -35
- caption_flow/utils/dataset_loader.py +64 -522
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +4 -200
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/METADATA +29 -27
- caption_flow-0.2.3.dist-info/RECORD +35 -0
- caption_flow-0.2.1.dist-info/RECORD +0 -29
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/top_level.txt +0 -0
caption_flow/cli.py
CHANGED
@@ -124,7 +124,7 @@ def setup_logging(verbose: bool = False):
|
|
124
124
|
level = logging.DEBUG if verbose else logging.INFO
|
125
125
|
logging.basicConfig(
|
126
126
|
level=level,
|
127
|
-
format="%(
|
127
|
+
format="%(message)s",
|
128
128
|
datefmt="[%Y-%m-%d %H:%M:%S]",
|
129
129
|
handlers=[
|
130
130
|
RichHandler(
|
@@ -161,6 +161,7 @@ def main(ctx, verbose: bool):
|
|
161
161
|
@click.option("--key", help="SSL key path")
|
162
162
|
@click.option("--no-ssl", is_flag=True, help="Disable SSL (development only)")
|
163
163
|
@click.option("--vllm", is_flag=True, help="Use vLLM orchestrator for WebDataset/HF datasets")
|
164
|
+
@click.option("--verbose", is_flag=True, help="Enable verbose logging")
|
164
165
|
@click.pass_context
|
165
166
|
def orchestrator(ctx, config: Optional[str], **kwargs):
|
166
167
|
"""Start the orchestrator server."""
|
caption_flow/models.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
"""Data models for CaptionFlow."""
|
2
2
|
|
3
|
+
import PIL
|
3
4
|
from dataclasses import dataclass, field
|
4
5
|
from datetime import datetime
|
5
6
|
from enum import Enum
|
6
|
-
from typing import Any, Dict, List, Optional
|
7
|
+
from typing import Any, Dict, List, Optional, Tuple
|
8
|
+
from PIL import Image
|
7
9
|
|
8
10
|
|
9
11
|
class JobStatus(Enum):
|
@@ -38,6 +40,38 @@ class Job:
|
|
38
40
|
self.created_at = datetime.utcnow()
|
39
41
|
|
40
42
|
|
43
|
+
@dataclass
|
44
|
+
class JobId:
|
45
|
+
shard_id: str
|
46
|
+
chunk_id: str
|
47
|
+
sample_id: str
|
48
|
+
|
49
|
+
def get_shard_str(self):
|
50
|
+
return f"{self.shard_id}"
|
51
|
+
|
52
|
+
def get_chunk_str(self):
|
53
|
+
return f"{self.shard_id}:chunk:{self.chunk_id}"
|
54
|
+
|
55
|
+
def get_sample_str(self):
|
56
|
+
return f"{self.shard_id}:chunk:{self.chunk_id}:idx:{self.sample_id}"
|
57
|
+
|
58
|
+
@staticmethod
|
59
|
+
def from_dict(job: dict) -> "JobId":
|
60
|
+
return JobId(shard_id=job["shard_id"], chunk_id=job["chunk_id"], sample_id=job["sample_id"])
|
61
|
+
|
62
|
+
@staticmethod
|
63
|
+
def from_values(shard_id: str, chunk_id: str, sample_id: str) -> "JobId":
|
64
|
+
return JobId(shard_id=shard_id, chunk_id=chunk_id, sample_id=sample_id)
|
65
|
+
|
66
|
+
@staticmethod
|
67
|
+
def from_str(job_id: str):
|
68
|
+
# from data-0000:chunk:0:idx:0
|
69
|
+
parts = job_id.split(":")
|
70
|
+
if len(parts) != 5:
|
71
|
+
raise ValueError(f"Invalid job_id format: {job_id}")
|
72
|
+
return JobId(shard_id=parts[0], chunk_id=parts[2], sample_id=parts[4])
|
73
|
+
|
74
|
+
|
41
75
|
@dataclass
|
42
76
|
class Caption:
|
43
77
|
"""Generated caption with attribution and image metadata."""
|
@@ -61,6 +95,8 @@ class Caption:
|
|
61
95
|
image_height: Optional[int] = None
|
62
96
|
image_format: Optional[str] = None
|
63
97
|
file_size: Optional[int] = None
|
98
|
+
filename: Optional[str] = None
|
99
|
+
url: Optional[str] = None
|
64
100
|
|
65
101
|
# Processing metadata
|
66
102
|
caption_index: Optional[int] = None # Which caption this is (0, 1, 2...)
|
@@ -82,3 +118,74 @@ class Contributor:
|
|
82
118
|
name: str
|
83
119
|
total_captions: int = 0
|
84
120
|
trust_level: int = 1
|
121
|
+
|
122
|
+
|
123
|
+
@dataclass
|
124
|
+
class ProcessingStage:
|
125
|
+
"""Configuration for a single processing stage."""
|
126
|
+
|
127
|
+
name: str
|
128
|
+
model: str
|
129
|
+
prompts: List[str]
|
130
|
+
output_field: str
|
131
|
+
requires: List[str] = field(default_factory=list)
|
132
|
+
sampling: Optional[Dict[str, Any]] = None
|
133
|
+
|
134
|
+
# Model-specific overrides
|
135
|
+
tensor_parallel_size: Optional[int] = None
|
136
|
+
max_model_len: Optional[int] = None
|
137
|
+
dtype: Optional[str] = None
|
138
|
+
gpu_memory_utilization: Optional[float] = None
|
139
|
+
|
140
|
+
|
141
|
+
@dataclass
|
142
|
+
class StageResult:
|
143
|
+
"""Results from a single stage."""
|
144
|
+
|
145
|
+
stage_name: str
|
146
|
+
output_field: str
|
147
|
+
outputs: List[str] # Multiple outputs from multiple prompts
|
148
|
+
error: Optional[str] = None
|
149
|
+
|
150
|
+
def is_success(self) -> bool:
|
151
|
+
return self.error is None and bool(self.outputs)
|
152
|
+
|
153
|
+
|
154
|
+
@dataclass
|
155
|
+
class ShardChunk:
|
156
|
+
"""Shard chunk assignment with unprocessed ranges."""
|
157
|
+
|
158
|
+
chunk_id: str
|
159
|
+
shard_url: str
|
160
|
+
shard_name: str
|
161
|
+
start_index: int
|
162
|
+
chunk_size: int
|
163
|
+
unprocessed_ranges: List[Tuple[int, int]] = field(default_factory=list)
|
164
|
+
|
165
|
+
|
166
|
+
@dataclass
|
167
|
+
class ProcessingItem:
|
168
|
+
"""Item being processed."""
|
169
|
+
|
170
|
+
chunk_id: str
|
171
|
+
item_key: str
|
172
|
+
image: Image.Image
|
173
|
+
image_data: bytes
|
174
|
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
175
|
+
stage_results: Dict[str, StageResult] = field(default_factory=dict) # Accumulated results
|
176
|
+
|
177
|
+
|
178
|
+
@dataclass
|
179
|
+
class ProcessedResult:
|
180
|
+
"""Result with multi-stage outputs."""
|
181
|
+
|
182
|
+
chunk_id: str
|
183
|
+
shard_name: str
|
184
|
+
item_key: str
|
185
|
+
outputs: Dict[str, List[str]] # field_name -> list of outputs
|
186
|
+
image_width: int
|
187
|
+
image_height: int
|
188
|
+
image_format: str
|
189
|
+
file_size: int
|
190
|
+
processing_time_ms: float
|
191
|
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
caption_flow/monitor.py
CHANGED