caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +2 -1
- caption_flow/models.py +108 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1595
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage.py +415 -406
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +94 -35
- caption_flow/utils/dataset_loader.py +64 -522
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +4 -200
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/METADATA +29 -27
- caption_flow-0.2.3.dist-info/RECORD +35 -0
- caption_flow-0.2.1.dist-info/RECORD +0 -29
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/top_level.txt +0 -0
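The bulk of the change is structural: dataset and shard handling moves out of orchestrator.py, dataset_loader.py, and shard_processor.py into the new caption_flow/processors package, with webdataset, huggingface, and local_filesystem backends behind a shared base.py. The hunks below cover only the tracker utilities, so the processor interface itself is not visible in this diff. Purely as an orientation sketch (names and signatures are hypothetical, not taken from base.py), the split usually looks something like:

```python
# Hypothetical sketch only -- the real interface lives in caption_flow/processors/base.py
# and is not shown in this diff.
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator


class Processor(ABC):
    """Backend-agnostic source of chunk-sized work units."""

    @abstractmethod
    def enumerate_units(self) -> Iterator[Dict[str, Any]]:
        """Yield work units, e.g. {"shard_name": ..., "start_index": ..., "chunk_size": ...}."""

    @abstractmethod
    def iterate_items(self, unit: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        """Yield the individual samples inside one work unit."""


class WebDatasetProcessor(Processor):
    """Stand-in for the webdataset backend; implementation omitted."""

    def enumerate_units(self):
        raise NotImplementedError

    def iterate_items(self, unit):
        raise NotImplementedError
```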
caption_flow/utils/checkpoint_tracker.py

@@ -73,12 +73,12 @@ class CheckpointTracker(ABC):
             logger.debug(f"Saved checkpoint to {self.checkpoint_path}")

         except Exception as e:
-            logger.error(f"Error saving checkpoint: {e}", exc_info=True)
+            # logger.error(f"Error saving checkpoint: {e}", exc_info=True)
             # Try direct write as fallback
             try:
                 with open(self.checkpoint_path, "w") as f:
                     json.dump(data, f, indent=2)
-                logger.info("Saved checkpoint using fallback direct write")
+                # logger.info("Saved checkpoint using fallback direct write")
             except Exception as fallback_error:
                 logger.error(f"Fallback save also failed: {fallback_error}")

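In 0.2.3 the error/info logging on this save path is commented out, but the control flow stays the same: attempt the normal save, and on any exception fall back to a plain json.dump of the checkpoint. A minimal standalone sketch of that save-with-fallback pattern (the atomic temp-file-and-rename primary path is an assumption; the hunk only shows its debug message and the fallback):

```python
import json
import logging
import os
import tempfile
from pathlib import Path

logger = logging.getLogger(__name__)


def save_checkpoint(checkpoint_path: Path, data: dict) -> None:
    """Write atomically via a temp file, falling back to a direct write."""
    try:
        # Assumed primary path: write to a temp file in the same directory, then rename.
        fd, tmp_name = tempfile.mkstemp(dir=checkpoint_path.parent, suffix=".tmp")
        with os.fdopen(fd, "w") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_name, checkpoint_path)
        logger.debug("Saved checkpoint to %s", checkpoint_path)
    except Exception:
        # Fallback mirrors the diff: a plain direct write of the same JSON.
        try:
            with open(checkpoint_path, "w") as f:
                json.dump(data, f, indent=2)
        except Exception as fallback_error:
            logger.error(f"Fallback save also failed: {fallback_error}")
```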
caption_flow/utils/chunk_tracker.py

@@ -10,6 +10,7 @@ from dataclasses import dataclass, asdict, field
 from .checkpoint_tracker import CheckpointTracker

 logger = logging.getLogger(__name__)
+# logger.setLevel(logging.DEBUG)


 @dataclass
@@ -58,11 +59,15 @@ class ChunkState:
     def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
         """Get ranges that haven't been processed yet."""
         if not self.processed_ranges:
+            logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
             return [(0, self.chunk_size - 1)]

         unprocessed = []
         current = 0

+        logger.info(
+            f"Processing {len(self.processed_ranges)} processed ranges for chunk {self.chunk_id}"
+        )
         for start, end in self.processed_ranges:
             if current < start:
                 unprocessed.append((current, start - 1))
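get_unprocessed_ranges walks the processed ranges in order and records each gap before them; the hunk only shows the start of that walk, so the tail handling (everything after the last processed range) is assumed here. A self-contained sketch of the whole computation, with sorting added defensively:

```python
from typing import List, Tuple


def unprocessed_ranges(
    processed: List[Tuple[int, int]], chunk_size: int
) -> List[Tuple[int, int]]:
    """Return the inclusive index ranges of a chunk not covered by `processed`."""
    if not processed:
        return [(0, chunk_size - 1)]

    gaps = []
    current = 0
    for start, end in sorted(processed):
        if current < start:
            gaps.append((current, start - 1))  # gap before this processed range
        current = max(current, end + 1)        # jump past the processed range
    if current <= chunk_size - 1:
        gaps.append((current, chunk_size - 1))  # tail after the last processed range
    return gaps


# e.g. with items 10-19 and 40-49 already processed in a 100-item chunk:
assert unprocessed_ranges([(10, 19), (40, 49)], 100) == [(0, 9), (20, 39), (50, 99)]
```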
@@ -132,6 +137,11 @@ class ChunkTracker(CheckpointTracker):
         self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
     ) -> bool:
         """Add a new chunk. Returns False if chunk already exists and is completed."""
+        if chunk_id in self.chunks:
+            logger.debug(
+                f"Chunk {chunk_id} already exists with status: {self.chunks[chunk_id].status}, not creating"
+            )
+            return False
         if chunk_id in self.completed_chunks:
             logger.debug(f"Chunk {chunk_id} already completed, skipping")
             return False
@@ -166,7 +176,7 @@ class ChunkTracker(CheckpointTracker):
         chunk.completed_at = datetime.utcnow()
         self.completed_chunks.add(chunk_id)
         self.save()
-        logger.
+        logger.debug(f"Chunk {chunk_id} marked as completed")

     def mark_failed(self, chunk_id: str):
         """Mark chunk as failed."""
@@ -207,6 +217,49 @@ class ChunkTracker(CheckpointTracker):
                 pending.append(chunk_id)
         return pending

+    def get_processed_indices_for_chunk(
+        self, chunk_id: str, processed_job_ids: Set[str]
+    ) -> List[Tuple[int, int]]:
+        """Convert processed job_ids back to ranges for a chunk."""
+        # Extract indices from job_ids like "data-0000:chunk:0:idx:42"
+        processed_indices = []
+        # this will be slow as shit, but it's simple for now, Proof of Concept.
+        for job_id in processed_job_ids:
+            test_chunk_id = chunk_id.replace("_", ":")
+            if test_chunk_id in job_id:
+                parts = job_id.split(":")
+                logger.debug(
+                    f"Found matching job_id {job_id} for chunk {chunk_id} with {len(parts)=} and {parts[3]=}"
+                )
+                if len(parts) >= 5 and parts[3] == "idx":
+                    idx = int(parts[4])
+                    processed_indices.append(idx)
+
+        # Convert to ranges
+        if not processed_indices:
+            # logger.warning(
+            #     f"Chunk {chunk_id} had no pre-processed ranges discovered, will process all elements"
+            # )
+            return []
+        else:
+            logger.debug(f"Chunk {chunk_id} has {len(processed_indices)} pre-processed indices")
+
+        processed_indices.sort()
+        ranges = []
+        start = processed_indices[0]
+        end = processed_indices[0]
+
+        for idx in processed_indices[1:]:
+            if idx == end + 1:
+                end = idx
+            else:
+                ranges.append((start, end))
+                start = idx
+                end = idx
+
+        ranges.append((start, end))
+        return ranges
+
     def is_shard_complete(self, shard_name: str) -> bool:
         """Check if all chunks for a shard are complete."""
         shard_chunks = [chunk for chunk in self.chunks.values() if chunk.shard_name == shard_name]
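The new get_processed_indices_for_chunk scans job IDs shaped like "data-0000:chunk:0:idx:42", pulls out the idx values that belong to the chunk, and then collapses consecutive indices into inclusive (start, end) ranges. The collapse step on its own, runnable in isolation:

```python
from typing import List, Set, Tuple


def indices_to_ranges(indices: Set[int]) -> List[Tuple[int, int]]:
    """Collapse a set of item indices into sorted, inclusive (start, end) ranges."""
    if not indices:
        return []

    ordered = sorted(indices)
    ranges = []
    start = end = ordered[0]
    for idx in ordered[1:]:
        if idx == end + 1:
            end = idx          # still contiguous, extend the current range
        else:
            ranges.append((start, end))
            start = end = idx  # gap found, open a new range
    ranges.append((start, end))
    return ranges


# Indices as parsed from job IDs such as "data-0000:chunk:0:idx:42":
assert indices_to_ranges({0, 1, 2, 5, 6, 9}) == [(0, 2), (5, 6), (9, 9)]
```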
@@ -236,20 +289,8 @@ class ChunkTracker(CheckpointTracker):

         for chunk_id, chunk_state in self.chunks.items():
             shard_name = chunk_state.shard_name
-
-
-            if shard_name.startswith("hf_dataset:"):
-                parts = shard_name.split(":")
-                if len(parts) >= 4 and parts[2] == "chunk":
-                    # Use just the dataset identifier as the shard name
-                    normalized_shard_name = ":".join(parts[:2])
-                else:
-                    normalized_shard_name = shard_name
-            else:
-                normalized_shard_name = shard_name
-
-            if normalized_shard_name not in shards:
-                shards[normalized_shard_name] = {
+            if shard_name not in shards:
+                shards[shard_name] = {
                     "total_chunks": 0,
                     "completed_chunks": 0,
                     "pending_chunks": 0,
@@ -259,20 +300,20 @@ class ChunkTracker(CheckpointTracker):
                     "chunks": [],
                 }

-            shards[
-            shards[
+            shards[shard_name]["chunks"].append(chunk_state)
+            shards[shard_name]["total_chunks"] += 1

             if chunk_state.status == "completed":
-                shards[
+                shards[shard_name]["completed_chunks"] += 1
             elif chunk_state.status == "pending":
-                shards[
-                shards[
+                shards[shard_name]["pending_chunks"] += 1
+                shards[shard_name]["is_complete"] = False
             elif chunk_state.status == "assigned":
-                shards[
-                shards[
+                shards[shard_name]["assigned_chunks"] += 1
+                shards[shard_name]["is_complete"] = False
             elif chunk_state.status == "failed":
-                shards[
-                shards[
+                shards[shard_name]["failed_chunks"] += 1
+                shards[shard_name]["is_complete"] = False

         return shards

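With the hf_dataset: normalization removed, the shard summary is keyed directly by chunk_state.shard_name, and each entry carries per-status counters plus the chunk list. Illustratively (shard names, counts, and flag values below are made up for the example), the returned structure looks like:

```python
# Illustrative shape of the dict built by the rewritten summary code.
summary = {
    "data-0000": {
        "total_chunks": 4,
        "completed_chunks": 3,
        "pending_chunks": 1,
        "assigned_chunks": 0,
        "failed_chunks": 0,
        "is_complete": False,
        "chunks": [],  # ChunkState objects in the real structure
    },
    "data-0001": {
        "total_chunks": 2,
        "completed_chunks": 2,
        "pending_chunks": 0,
        "assigned_chunks": 0,
        "failed_chunks": 0,
        "is_complete": True,
        "chunks": [],
    },
}
```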
@@ -322,13 +363,7 @@ class ChunkTracker(CheckpointTracker):
                 continue

             # Infer shard URL and create chunk with default size
-
-                # HF dataset
-                dataset_path = shard_name.replace("_", "/")
-                shard_url = f"hf_dataset:{dataset_path}:chunk:{start_idx}"
-            else:
-                # WebDataset
-                shard_url = f"unknown://{shard_name}.tar"
+            shard_url = f"unknown://{shard_name}.tar"

             self.chunks[chunk_id] = ChunkState(
                 chunk_id=chunk_id,
@@ -410,6 +445,7 @@ class ChunkTracker(CheckpointTracker):
         """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
         if chunk_id not in self.chunks:
             logger.error(f"Unknown chunk: {chunk_id}")
+            logger.debug(f"Known chunks: {list(self.chunks.keys())}")
             return

         chunk = self.chunks[chunk_id]
@@ -441,9 +477,32 @@ class ChunkTracker(CheckpointTracker):
         )

     def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
-        """Get chunk info
-
+        """Get chunk info with unprocessed item ranges."""
+        chunk_state = self.chunks.get(chunk_id)
+        if not chunk_state:
             return None

-
-
+        # During startup or if no worker is assigned, treat all unprocessed as available
+        if not hasattr(self, "_startup_complete"):
+            self._startup_complete = False
+
+        if not self._startup_complete or (
+            not chunk_state.assigned_to or chunk_state.completed_at is None
+        ):
+            # Return all unprocessed ranges
+            logger.debug(
+                f"Returning all unprocessed ranges. Status {self._startup_complete=} {chunk_state=}"
+            )
+            return {
+                "chunk_id": chunk_id,
+                "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
+                "status": chunk_state.status,
+            }
+
+        # Normal operation - only return ranges not being worked on
+        # This would need more complex tracking of which ranges each worker is processing
+        return {
+            "chunk_id": chunk_id,
+            "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
+            "status": chunk_state.status,
+        }
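Both branches of the rewritten get_chunk_with_unprocessed_items currently return the same {chunk_id, unprocessed_ranges, status} payload (the comment notes that the non-startup path would need per-worker range tracking to be any narrower). A hedged sketch of how a caller might consume that payload; the ChunkTracker constructor argument and the chunk IDs are assumptions, not taken from the diff:

```python
# Hypothetical caller-side sketch; the checkpoint path argument and the chunk IDs
# below are assumptions for illustration only.
from pathlib import Path

from caption_flow.utils.chunk_tracker import ChunkTracker

tracker = ChunkTracker(Path("./checkpoints/chunks.json"))

for chunk_id in ["data-0000:chunk:0", "data-0000:chunk:1"]:
    info = tracker.get_chunk_with_unprocessed_items(chunk_id)
    if info is None:
        continue  # unknown chunk
    for start, end in info["unprocessed_ranges"]:
        # Hand the inclusive [start, end] item range to a worker.
        print(f"{info['chunk_id']} ({info['status']}): items {start}-{end} still to caption")
```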