caption_flow-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +9 -0
- caption_flow/cli.py +709 -0
- caption_flow/models.py +82 -0
- caption_flow/monitor.py +211 -0
- caption_flow/orchestrator.py +1301 -0
- caption_flow/storage.py +694 -0
- caption_flow/utils/__init__.py +4 -0
- caption_flow/utils/auth.py +67 -0
- caption_flow/utils/caption_utils.py +172 -0
- caption_flow/utils/certificates.py +140 -0
- caption_flow/utils/chunk_tracker.py +365 -0
- caption_flow/utils/dataset_loader.py +186 -0
- caption_flow/utils/image_processor.py +51 -0
- caption_flow/utils/job_queue.py +41 -0
- caption_flow/utils/json_utils.py +201 -0
- caption_flow/utils/vllm_config.py +164 -0
- caption_flow/worker.py +300 -0
- caption_flow/worker_data.py +482 -0
- caption_flow/worker_vllm.py +1028 -0
- caption_flow-0.1.0.dist-info/METADATA +427 -0
- caption_flow-0.1.0.dist-info/RECORD +25 -0
- caption_flow-0.1.0.dist-info/WHEEL +5 -0
- caption_flow-0.1.0.dist-info/entry_points.txt +2 -0
- caption_flow-0.1.0.dist-info/licenses/LICENSE +661 -0
- caption_flow-0.1.0.dist-info/top_level.txt +1 -0
caption_flow/utils/chunk_tracker.py
@@ -0,0 +1,365 @@

"""Chunk tracking for persistent state across restarts."""

import json
import logging
import shutil
from pathlib import Path
from typing import Set, Dict, List, Optional, Any
from datetime import datetime
from dataclasses import dataclass, asdict

logger = logging.getLogger(__name__)


@dataclass
class ChunkState:
    """State of a chunk."""

    chunk_id: str
    shard_name: str
    start_index: int
    chunk_size: int
    status: str  # pending, assigned, completed, failed
    completed_at: Optional[datetime] = None
    assigned_to: Optional[str] = None
    assigned_at: Optional[datetime] = None

    def to_dict(self):
        """Convert to a dictionary for JSON serialization."""
        d = asdict(self)
        # Convert datetime objects to ISO format strings
        if d["completed_at"]:
            d["completed_at"] = d["completed_at"].isoformat()
        if d["assigned_at"]:
            d["assigned_at"] = d["assigned_at"].isoformat()
        return d

    @classmethod
    def from_dict(cls, d: Dict):
        """Create from a dictionary."""
        d = dict(d)  # Work on a copy so the caller's dict is not mutated
        # Convert ISO format strings back to datetime objects
        if d.get("completed_at"):
            d["completed_at"] = datetime.fromisoformat(d["completed_at"])
        if d.get("assigned_at"):
            d["assigned_at"] = datetime.fromisoformat(d["assigned_at"])
        return cls(**d)


class ChunkTracker:
    """Tracks chunk processing state persistently."""

    def __init__(self, checkpoint_file: Path):
        self.checkpoint_file = checkpoint_file
        self.chunks: Dict[str, ChunkState] = {}
        self.completed_chunks: Set[str] = set()
        self._load_checkpoint()

    def _load_checkpoint(self):
        """Load checkpoint from disk."""
        if self.checkpoint_file.exists():
            try:
                with open(self.checkpoint_file, "r") as f:
                    data = json.load(f)

                # Load chunk states
                for chunk_id, chunk_data in data.get("chunks", {}).items():
                    chunk_state = ChunkState.from_dict(chunk_data)
                    self.chunks[chunk_id] = chunk_state
                    if chunk_state.status == "completed":
                        self.completed_chunks.add(chunk_id)

                logger.info(
                    f"Loaded {len(self.chunks)} chunks from checkpoint, "
                    f"{len(self.completed_chunks)} completed"
                )

            except Exception as e:
                logger.error(f"Error loading chunk checkpoint: {e}")
                self.chunks = {}
                self.completed_chunks = set()

    def save_checkpoint(self):
        """Save checkpoint to disk."""
        # Build the serializable payload outside the try block so the
        # fallback write below can still reference it if the atomic path fails
        data = {
            "chunks": {chunk_id: chunk.to_dict() for chunk_id, chunk in self.chunks.items()},
            "updated_at": datetime.utcnow().isoformat(),
        }

        try:
            # Ensure parent directory exists
            self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)

            # Write atomically: dump to a temp file, then move it into place
            tmp_file = self.checkpoint_file.parent / f"{self.checkpoint_file.name}.tmp"

            with open(tmp_file, "w") as f:
                json.dump(data, f, indent=2)

            # Ensure temp file was created
            if not tmp_file.exists():
                raise IOError(f"Failed to create temporary file: {tmp_file}")

            # Move atomically (a rename when on the same filesystem)
            shutil.move(str(tmp_file), str(self.checkpoint_file))

            logger.debug(
                f"Saved chunk checkpoint with {len(self.chunks)} chunks to {self.checkpoint_file}"
            )

        except Exception as e:
            logger.error(f"Error saving chunk checkpoint: {e}", exc_info=True)
            # Try direct write as fallback
            try:
                with open(self.checkpoint_file, "w") as f:
                    json.dump(data, f, indent=2)
                logger.info("Saved checkpoint using fallback direct write")
            except Exception as fallback_error:
                logger.error(f"Fallback save also failed: {fallback_error}")

    def add_chunk(self, chunk_id: str, shard_name: str, start_index: int, chunk_size: int) -> bool:
        """Add a new chunk. Returns False if the chunk already exists and is completed."""
        if chunk_id in self.completed_chunks:
            logger.debug(f"Chunk {chunk_id} already completed, skipping")
            return False

        if chunk_id not in self.chunks:
            self.chunks[chunk_id] = ChunkState(
                chunk_id=chunk_id,
                shard_name=shard_name,
                start_index=start_index,
                chunk_size=chunk_size,
                status="pending",
            )
            self.save_checkpoint()

        return True

    def mark_assigned(self, chunk_id: str, worker_id: str):
        """Mark chunk as assigned."""
        if chunk_id in self.chunks:
            chunk = self.chunks[chunk_id]
            chunk.status = "assigned"
            chunk.assigned_to = worker_id
            chunk.assigned_at = datetime.utcnow()
            self.save_checkpoint()

    def mark_completed(self, chunk_id: str):
        """Mark chunk as completed."""
        if chunk_id in self.chunks:
            chunk = self.chunks[chunk_id]
            chunk.status = "completed"
            chunk.completed_at = datetime.utcnow()
            self.completed_chunks.add(chunk_id)
            self.save_checkpoint()
            logger.info(f"Chunk {chunk_id} marked as completed")

    def mark_failed(self, chunk_id: str):
        """Mark chunk as failed."""
        if chunk_id in self.chunks:
            chunk = self.chunks[chunk_id]
            chunk.status = "pending"  # Reset to pending for retry
            chunk.assigned_to = None
            chunk.assigned_at = None
            self.save_checkpoint()

    def mark_pending(self, chunk_id: str):
        """Mark chunk as pending (for manual reset)."""
        if chunk_id in self.chunks:
            chunk = self.chunks[chunk_id]
            chunk.status = "pending"
            chunk.assigned_to = None
            chunk.assigned_at = None
            self.save_checkpoint()

    def release_worker_chunks(self, worker_id: str):
        """Release all chunks assigned to a worker."""
        released_chunks = []
        for chunk_id, chunk in self.chunks.items():
            if chunk.assigned_to == worker_id and chunk.status == "assigned":
                chunk.status = "pending"
                chunk.assigned_to = None
                chunk.assigned_at = None
                released_chunks.append(chunk_id)
        self.save_checkpoint()
        return released_chunks

    def get_pending_chunks(self, shard_name: Optional[str] = None) -> List[str]:
        """Get list of pending chunk IDs, optionally filtered by shard."""
        pending = []
        for chunk_id, chunk in self.chunks.items():
            if chunk.status == "pending":
                if shard_name is None or chunk.shard_name == shard_name:
                    pending.append(chunk_id)
        return pending

    def is_shard_complete(self, shard_name: str) -> bool:
        """Check if all chunks for a shard are complete."""
        shard_chunks = [chunk for chunk in self.chunks.values() if chunk.shard_name == shard_name]

        if not shard_chunks:
            return False

        return all(chunk.status == "completed" for chunk in shard_chunks)

    def get_stats(self) -> Dict[str, int]:
        """Get chunk statistics."""
        stats = {
            "total": len(self.chunks),
            "pending": sum(1 for c in self.chunks.values() if c.status == "pending"),
            "assigned": sum(1 for c in self.chunks.values() if c.status == "assigned"),
            "completed": len(self.completed_chunks),
            "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
        }
        return stats

    def get_shards_summary(self) -> Dict[str, Dict[str, Any]]:
        """Get summary of all shards and their chunk status."""
        shards = {}

        for chunk_id, chunk_state in self.chunks.items():
            shard_name = chunk_state.shard_name
            if shard_name not in shards:
                shards[shard_name] = {
                    "total_chunks": 0,
                    "completed_chunks": 0,
                    "pending_chunks": 0,
                    "assigned_chunks": 0,
                    "failed_chunks": 0,
                    "is_complete": True,
                    "chunks": [],
                }

            shards[shard_name]["chunks"].append(chunk_state)
            shards[shard_name]["total_chunks"] += 1

            if chunk_state.status == "completed":
                shards[shard_name]["completed_chunks"] += 1
            elif chunk_state.status == "pending":
                shards[shard_name]["pending_chunks"] += 1
                shards[shard_name]["is_complete"] = False
            elif chunk_state.status == "assigned":
                shards[shard_name]["assigned_chunks"] += 1
                shards[shard_name]["is_complete"] = False
            elif chunk_state.status == "failed":
                shards[shard_name]["failed_chunks"] += 1
                shards[shard_name]["is_complete"] = False

        return shards

    def get_incomplete_shards(self) -> Set[str]:
        """Get set of shard names that have incomplete chunks."""
        incomplete = set()
        for chunk_id, chunk_state in self.chunks.items():
            if chunk_state.status != "completed":
                incomplete.add(chunk_state.shard_name)
        return incomplete

    async def sync_with_storage(self, storage_manager):
        """Sync chunk state with storage to detect already-processed chunks."""
        logger.info("Syncing chunk state with storage...")

        # Get all existing captions from storage
        if storage_manager.captions_path.exists():
            import pyarrow.parquet as pq

            # Only request columns that exist; asking read_table for a
            # missing column would raise
            schema = pq.read_schema(storage_manager.captions_path)
            columns = ["job_id"]
            if "chunk_id" in schema.names:
                columns.append("chunk_id")
            table = pq.read_table(storage_manager.captions_path, columns=columns)
            existing_job_ids = set(table["job_id"].to_pylist())

            # Use chunk_ids if available
            if "chunk_id" in table.column_names:
                existing_chunk_ids = set(
                    cid for cid in table["chunk_id"].to_pylist() if cid is not None
                )

                # Mark existing chunks as completed
                for chunk_id in existing_chunk_ids:
                    if chunk_id in self.chunks:
                        self.mark_completed(chunk_id)
                    else:
                        # Create chunk entries for already-processed chunks;
                        # extract the shard name from the chunk_id
                        # (format: shard_chunk_index)
                        parts = chunk_id.rsplit("_chunk_", 1)
                        if len(parts) == 2:
                            shard_name = parts[0]
                            try:
                                start_idx = int(parts[1])
                                # Exact chunk size is unknown; record the chunk as completed anyway
                                self.chunks[chunk_id] = ChunkState(
                                    chunk_id=chunk_id,
                                    shard_name=shard_name,
                                    start_index=start_idx,
                                    chunk_size=1000,  # Default chunk size
                                    status="completed",
                                    completed_at=datetime.utcnow(),
                                )
                                self.completed_chunks.add(chunk_id)
                            except ValueError:
                                logger.warning(f"Could not parse chunk_id: {chunk_id}")

                logger.info(f"Found {len(existing_chunk_ids)} completed chunks in storage")

            # Otherwise infer completion from the job_id pattern
            else:
                for job_id in existing_job_ids:
                    # Extract chunk_id from job_id (format: chunk_id_item_key)
                    if "_chunk_" in job_id:
                        parts = job_id.split("_")
                        # Find the chunk part
                        for i, part in enumerate(parts):
                            if part == "chunk" and i + 1 < len(parts):
                                try:
                                    # Reconstruct chunk_id
                                    chunk_idx = int(parts[i + 1])
                                    shard_parts = parts[:i]
                                    chunk_id = f"{'_'.join(shard_parts)}_chunk_{chunk_idx}"

                                    if chunk_id not in self.completed_chunks:
                                        self.completed_chunks.add(chunk_id)
                                        logger.debug(
                                            f"Marked chunk {chunk_id} as completed from job_id"
                                        )
                                    break
                                except ValueError:
                                    continue

                logger.info(f"Inferred {len(self.completed_chunks)} completed chunks from job_ids")

        self.save_checkpoint()
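
Taken together, the orchestrator drives ChunkTracker through a simple lifecycle: add, assign, complete, query. A minimal usage sketch, not part of the wheel (the checkpoint path and the chunk/worker IDs below are invented):

from pathlib import Path
from caption_flow.utils.chunk_tracker import ChunkTracker

tracker = ChunkTracker(Path("checkpoints/chunks.json"))

# add_chunk() returns False only when the chunk was already completed
if tracker.add_chunk("shard-0000_chunk_0", "shard-0000", start_index=0, chunk_size=1000):
    tracker.mark_assigned("shard-0000_chunk_0", "worker-1")
    tracker.mark_completed("shard-0000_chunk_0")

print(tracker.get_stats())                      # {'total': 1, ..., 'completed': 1, ...}
print(tracker.is_shard_complete("shard-0000"))  # True

Every mutating call rewrites the JSON checkpoint, so a restart reconstructs the same state from disk.
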
caption_flow/utils/dataset_loader.py
@@ -0,0 +1,186 @@

"""Dataset loading utilities for WebDataset and HuggingFace."""

import json
import logging
import shlex
from pathlib import Path
from typing import List, Dict, Any, Generator, Optional, Tuple

import webdataset as wds
from huggingface_hub import HfFileSystem, get_token, hf_hub_url

logger = logging.getLogger(__name__)


class DatasetLoader:
    """Handles loading datasets from various sources."""

    def __init__(self, dataset_path: str, dataset_type: str = "huggingface"):
        """
        Initialize dataset loader.

        Args:
            dataset_path: Path to dataset (HF repo, local dir, etc.)
            dataset_type: Type of dataset ("huggingface", "webdataset", "local")
        """
        self.dataset_path = dataset_path
        self.dataset_type = dataset_type
        self.token = get_token()

        if not self.token and dataset_type == "huggingface":
            logger.warning("No HuggingFace token found; run `huggingface-cli login`")

    def get_shard_list(self) -> List[str]:
        """Get list of all shards in the dataset."""
        if self.dataset_type == "huggingface":
            return self._get_hf_shards()
        elif self.dataset_type == "local":
            return self._get_local_shards()
        else:
            raise ValueError(f"Unknown dataset type: {self.dataset_type}")

    def _get_hf_shards(self) -> List[str]:
        """Get shard URLs from a HuggingFace dataset."""
        logger.info(f"Getting shard list from HuggingFace: {self.dataset_path}")

        fs = HfFileSystem()
        files = [fs.resolve_path(p) for p in fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar")]

        urls = [hf_hub_url(f.repo_id, f.path_in_repo, repo_type="dataset") for f in files]

        logger.info(f"Found {len(urls)} shards")
        return sorted(urls)

    def _get_local_shards(self) -> List[str]:
        """Get shard files from a local directory."""
        path = Path(self.dataset_path)
        if not path.exists():
            raise ValueError(f"Local dataset path does not exist: {path}")

        shards = list(path.glob("*.tar"))
        logger.info(f"Found {len(shards)} local shards")
        return [str(s) for s in sorted(shards)]

    def load_shard(self, shard_url: str, processed_keys: Optional[set] = None) -> wds.DataPipeline:
        """
        Load a single shard as a WebDataset pipeline.

        Args:
            shard_url: URL or path to the shard
            processed_keys: Set of already processed keys to skip
        """
        if processed_keys is None:
            processed_keys = set()

        if self.dataset_type == "huggingface":
            # Stream via curl with the auth token for HuggingFace. Quote the
            # whole header value; quoting only the token would nest quotes
            # inside the single-quoted header and break the shell command.
            header = shlex.quote(f"Authorization: Bearer {self.token}")
            url_cmd = f"pipe:curl -s -L -H {header} {shlex.quote(shard_url)} || true"
            ds = wds.DataPipeline(
                wds.SimpleShardList(url_cmd),
                wds.tarfile_to_samples(),
                wds.to_tuple("__key__", "__url__", "jpg;png;jpeg;webp;jxl"),
                wds.select(lambda x: x[0] not in processed_keys),
            )
        else:
            # Local file access
            ds = wds.DataPipeline(
                wds.SimpleShardList(shard_url),
                wds.tarfile_to_samples(),
                wds.to_tuple("__key__", "__url__", "jpg;png;jpeg;webp;jxl"),
                wds.select(lambda x: x[0] not in processed_keys),
            )

        return ds

    def iterate_shard(
        self, shard_url: str, processed_keys: Optional[set] = None
    ) -> Generator[Tuple[str, str, bytes], None, None]:
        """
        Iterate over items in a shard.

        Yields:
            Tuple of (key, url, image_bytes)
        """
        ds = self.load_shard(shard_url, processed_keys)

        for key, url, image_data in ds:
            yield key, url, image_data

    def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
        """Count items in a shard (can be slow for large shards)."""
        count = 0
        try:
            for _ in self.iterate_shard(shard_url, processed_keys):
                count += 1
        except Exception as e:
            logger.error(f"Error counting shard {shard_url}: {e}")
        return count


class ShardTracker:
    """Tracks shard processing progress."""

    def __init__(self, checkpoint_path: Path):
        """Initialize shard tracker with checkpoint file."""
        self.checkpoint_path = checkpoint_path
        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        self.completed_shards: set = set()
        self.partial_shards: Dict[str, Dict[str, Any]] = {}
        self.load()

    def load(self):
        """Load checkpoint from disk."""
        if self.checkpoint_path.exists():
            try:
                data = json.loads(self.checkpoint_path.read_text())
                self.completed_shards = set(data.get("completed_shards", []))
                self.partial_shards = data.get("partial_shards", {})
                logger.info(
                    f"Loaded checkpoint: {len(self.completed_shards)} completed, "
                    f"{len(self.partial_shards)} partial shards"
                )
            except Exception as e:
                logger.error(f"Failed to load checkpoint: {e}")

    def save(self):
        """Save checkpoint to disk."""
        data = {
            "completed_shards": list(self.completed_shards),
            "partial_shards": self.partial_shards,
        }

        tmp = self.checkpoint_path.with_suffix(".tmp")
        tmp.write_text(json.dumps(data, indent=2))
        tmp.replace(self.checkpoint_path)

    def mark_complete(self, shard_name: str):
        """Mark a shard as complete."""
        self.completed_shards.add(shard_name)
        if shard_name in self.partial_shards:
            del self.partial_shards[shard_name]
        self.save()

    def update_partial(self, shard_name: str, processed_keys: List[str]):
        """Update partial progress for a shard."""
        self.partial_shards[shard_name] = {"keys": processed_keys, "count": len(processed_keys)}
        self.save()

    def get_processed_keys(self, shard_name: str) -> set:
        """Get set of processed keys for a shard."""
        if shard_name in self.completed_shards:
            return set()  # All done

        if shard_name in self.partial_shards:
            return set(self.partial_shards[shard_name].get("keys", []))

        return set()

    def is_complete(self, shard_name: str) -> bool:
        """Check if a shard is complete."""
        return shard_name in self.completed_shards

    def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
        """Get list of shards that still need processing."""
        return [s for s in all_shards if Path(s).stem not in self.completed_shards]
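
The two classes compose: ShardTracker decides which shards (and which keys within a partially processed shard) still need work, and DatasetLoader streams the rest. A minimal sketch, not part of the wheel (the dataset repo and checkpoint path are placeholders):

from pathlib import Path
from caption_flow.utils.dataset_loader import DatasetLoader, ShardTracker

loader = DatasetLoader("someuser/some-webdataset", dataset_type="huggingface")
tracker = ShardTracker(Path("checkpoints/shards.json"))

for shard_url in tracker.get_remaining_shards(loader.get_shard_list()):
    shard_name = Path(shard_url).stem
    done = tracker.get_processed_keys(shard_name)  # empty unless partially done
    for key, url, image_bytes in loader.iterate_shard(shard_url, done):
        pass  # hand (key, url, image_bytes) to a captioning worker
    tracker.mark_complete(shard_name)
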
caption_flow/utils/image_processor.py
@@ -0,0 +1,51 @@

"""Image preprocessing utilities."""

import asyncio
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import List

import numpy as np
from PIL import Image


class ImageProcessor:
    """Handles image loading and preprocessing."""

    def __init__(self, num_workers: int = 4):
        self.executor = ProcessPoolExecutor(max_workers=num_workers)

    async def process_batch(self, image_paths: List[Path]) -> List[np.ndarray]:
        """Process a batch of images in parallel."""
        loop = asyncio.get_running_loop()

        tasks = []
        for path in image_paths:
            task = loop.run_in_executor(self.executor, self._process_image, path)
            tasks.append(task)

        return await asyncio.gather(*tasks)

    @staticmethod
    def _process_image(path: Path) -> np.ndarray:
        """Process a single image."""
        img = Image.open(path)

        # Convert to RGB first so resampling operates on true color data
        # (LANCZOS on palette-mode images gives poor results)
        if img.mode != "RGB":
            img = img.convert("RGB")

        # Resize to standard size
        img = img.resize((224, 224), Image.Resampling.LANCZOS)

        # Convert to numpy array and normalize to [0, 1]
        arr = np.array(img, dtype=np.float32)
        arr = arr / 255.0

        return arr

    def shutdown(self):
        """Shutdown the executor."""
        self.executor.shutdown(wait=True)
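
A minimal usage sketch, not part of the wheel (the file names are placeholders). Each image is decoded and normalized in a separate process, so the event loop stays responsive:

import asyncio
from pathlib import Path
from caption_flow.utils.image_processor import ImageProcessor

async def main():
    proc = ImageProcessor(num_workers=4)
    try:
        arrays = await proc.process_batch([Path("a.jpg"), Path("b.png")])
        print([a.shape for a in arrays])  # [(224, 224, 3), (224, 224, 3)]
    finally:
        proc.shutdown()

asyncio.run(main())
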
caption_flow/utils/job_queue.py
@@ -0,0 +1,41 @@

"""Job queue management."""

import asyncio
from collections import deque
from typing import Optional

from ..models import Job


class JobQueue:
    """Async FIFO job queue; failed jobs are requeued at the front."""

    def __init__(self):
        self.queue = deque()
        self.processing = set()
        self.lock = asyncio.Lock()

    async def add(self, job: Job):
        """Add job to queue."""
        async with self.lock:
            self.queue.append(job)

    async def get_next(self) -> Optional[Job]:
        """Get next available job."""
        async with self.lock:
            if self.queue:
                job = self.queue.popleft()
                self.processing.add(job.job_id)
                return job
            return None

    async def complete(self, job_id: str):
        """Mark job as complete."""
        async with self.lock:
            self.processing.discard(job_id)

    async def requeue(self, job: Job):
        """Requeue a job (for failures)."""
        async with self.lock:
            self.processing.discard(job.job_id)
            self.queue.appendleft(job)  # Front of the queue, so retries run first
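
A minimal consumer sketch, not part of the wheel: pull a job, and on failure put it back at the front so it is retried before newer work. Job construction is omitted since its fields live in caption_flow/models.py:

from caption_flow.models import Job
from caption_flow.utils.job_queue import JobQueue

async def consume(q: JobQueue):
    job = await q.get_next()
    if job is None:
        return                      # queue drained
    try:
        ...                         # process the job (e.g. caption a chunk)
        await q.complete(job.job_id)
    except Exception:
        await q.requeue(job)        # retried before newer jobs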