caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -73,12 +73,12 @@ class CheckpointTracker(ABC):
             logger.debug(f"Saved checkpoint to {self.checkpoint_path}")

         except Exception as e:
-            logger.error(f"Error saving checkpoint: {e}", exc_info=True)
+            # logger.error(f"Error saving checkpoint: {e}", exc_info=True)
             # Try direct write as fallback
             try:
                 with open(self.checkpoint_path, "w") as f:
                     json.dump(data, f, indent=2)
-                logger.info("Saved checkpoint using fallback direct write")
+                # logger.info("Saved checkpoint using fallback direct write")
             except Exception as fallback_error:
                 logger.error(f"Fallback save also failed: {fallback_error}")

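The hunk above silences the primary-path error log and the fallback success log in CheckpointTracker's save routine; only the final fallback failure is still reported. For orientation, a minimal sketch of the overall pattern, assuming the primary path is an atomic temp-file-plus-rename write (an assumption, not the package's exact code):

    import json
    import logging
    import tempfile
    from pathlib import Path

    logger = logging.getLogger(__name__)

    def save_checkpoint(checkpoint_path: Path, data: dict) -> None:
        try:
            # Assumed primary path: write a sibling temp file, then atomically replace.
            with tempfile.NamedTemporaryFile(
                "w", dir=checkpoint_path.parent, delete=False, suffix=".tmp"
            ) as tmp:
                json.dump(data, tmp, indent=2)
            Path(tmp.name).replace(checkpoint_path)
            logger.debug(f"Saved checkpoint to {checkpoint_path}")
        except Exception:
            # Fallback mirrored from the diff: plain, non-atomic write.
            try:
                with open(checkpoint_path, "w") as f:
                    json.dump(data, f, indent=2)
            except Exception as fallback_error:
                logger.error(f"Fallback save also failed: {fallback_error}")

The remaining hunks apply to the chunk-tracker module, the one that defines ChunkState and ChunkTracker and imports CheckpointTracker.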
@@ -10,6 +10,7 @@ from dataclasses import dataclass, asdict, field
 from .checkpoint_tracker import CheckpointTracker

 logger = logging.getLogger(__name__)
+# logger.setLevel(logging.DEBUG)


 @dataclass
@@ -58,11 +59,15 @@ class ChunkState:
     def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
         """Get ranges that haven't been processed yet."""
         if not self.processed_ranges:
+            logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
             return [(0, self.chunk_size - 1)]

         unprocessed = []
         current = 0

+        logger.info(
+            f"Processing {len(self.processed_ranges)} processed ranges for chunk {self.chunk_id}"
+        )
         for start, end in self.processed_ranges:
             if current < start:
                 unprocessed.append((current, start - 1))
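The hunk cuts off mid-loop; the gap-finding logic it instruments can be sketched standalone as below, assuming processed_ranges holds sorted, non-overlapping, inclusive (start, end) tuples, which is what the visible code implies:

    from typing import List, Tuple

    def unprocessed_ranges(
        processed: List[Tuple[int, int]], chunk_size: int
    ) -> List[Tuple[int, int]]:
        # No work recorded yet: the whole chunk is unprocessed.
        if not processed:
            return [(0, chunk_size - 1)]
        gaps = []
        current = 0
        for start, end in processed:
            if current < start:  # gap before this processed range
                gaps.append((current, start - 1))
            current = max(current, end + 1)
        if current <= chunk_size - 1:  # tail gap after the last processed range
            gaps.append((current, chunk_size - 1))
        return gaps

    # Example: unprocessed_ranges([(0, 9), (20, 29)], 40) -> [(10, 19), (30, 39)]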
@@ -132,6 +137,11 @@ class ChunkTracker(CheckpointTracker):
         self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
     ) -> bool:
         """Add a new chunk. Returns False if chunk already exists and is completed."""
+        if chunk_id in self.chunks:
+            logger.debug(
+                f"Chunk {chunk_id} already exists with status: {self.chunks[chunk_id].status}, not creating"
+            )
+            return False
         if chunk_id in self.completed_chunks:
             logger.debug(f"Chunk {chunk_id} already completed, skipping")
             return False
@@ -166,7 +176,7 @@ class ChunkTracker(CheckpointTracker):
         chunk.completed_at = datetime.utcnow()
         self.completed_chunks.add(chunk_id)
         self.save()
-        logger.info(f"Chunk {chunk_id} marked as completed")
+        logger.debug(f"Chunk {chunk_id} marked as completed")

     def mark_failed(self, chunk_id: str):
         """Mark chunk as failed."""
@@ -207,6 +217,49 @@ class ChunkTracker(CheckpointTracker):
             pending.append(chunk_id)
         return pending

+    def get_processed_indices_for_chunk(
+        self, chunk_id: str, processed_job_ids: Set[str]
+    ) -> List[Tuple[int, int]]:
+        """Convert processed job_ids back to ranges for a chunk."""
+        # Extract indices from job_ids like "data-0000:chunk:0:idx:42"
+        processed_indices = []
+        # this will be slow as shit, but it's simple for now, Proof of Concept.
+        for job_id in processed_job_ids:
+            test_chunk_id = chunk_id.replace("_", ":")
+            if test_chunk_id in job_id:
+                parts = job_id.split(":")
+                logger.debug(
+                    f"Found matching job_id {job_id} for chunk {chunk_id} with {len(parts)=} and {parts[3]=}"
+                )
+                if len(parts) >= 5 and parts[3] == "idx":
+                    idx = int(parts[4])
+                    processed_indices.append(idx)
+
+        # Convert to ranges
+        if not processed_indices:
+            # logger.warning(
+            #     f"Chunk {chunk_id} had no pre-processed ranges discovered, will process all elements"
+            # )
+            return []
+        else:
+            logger.debug(f"Chunk {chunk_id} has {len(processed_indices)} pre-processed indices")
+
+        processed_indices.sort()
+        ranges = []
+        start = processed_indices[0]
+        end = processed_indices[0]
+
+        for idx in processed_indices[1:]:
+            if idx == end + 1:
+                end = idx
+            else:
+                ranges.append((start, end))
+                start = idx
+                end = idx
+
+        ranges.append((start, end))
+        return ranges
+
     def is_shard_complete(self, shard_name: str) -> bool:
         """Check if all chunks for a shard are complete."""
         shard_chunks = [chunk for chunk in self.chunks.values() if chunk.shard_name == shard_name]
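The new get_processed_indices_for_chunk recovers per-item indices from job ids shaped like "data-0000:chunk:0:idx:42" (per its docstring) and coalesces them into inclusive ranges. A small self-contained illustration of that behavior, with hypothetical job ids:

    def coalesce(indices):
        # Collapse sorted integer indices into inclusive (start, end) runs.
        runs = []
        for i in sorted(indices):
            if runs and i == runs[-1][1] + 1:
                runs[-1] = (runs[-1][0], i)
            else:
                runs.append((i, i))
        return runs

    job_ids = {
        "data-0000:chunk:0:idx:3",
        "data-0000:chunk:0:idx:4",
        "data-0000:chunk:0:idx:5",
        "data-0000:chunk:0:idx:9",
        "other-shard:chunk:1:idx:0",  # different chunk: ignored
    }
    chunk_key = "data-0000_chunk_0".replace("_", ":")  # -> "data-0000:chunk:0"
    indices = [
        int(j.split(":")[4])
        for j in job_ids
        if chunk_key in j and j.split(":")[3] == "idx"
    ]
    print(coalesce(indices))  # [(3, 5), (9, 9)]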
@@ -236,20 +289,8 @@ class ChunkTracker(CheckpointTracker):

         for chunk_id, chunk_state in self.chunks.items():
             shard_name = chunk_state.shard_name
-
-            # For virtual HF dataset shards, normalize the shard name
-            if shard_name.startswith("hf_dataset:"):
-                parts = shard_name.split(":")
-                if len(parts) >= 4 and parts[2] == "chunk":
-                    # Use just the dataset identifier as the shard name
-                    normalized_shard_name = ":".join(parts[:2])
-                else:
-                    normalized_shard_name = shard_name
-            else:
-                normalized_shard_name = shard_name
-
-            if normalized_shard_name not in shards:
-                shards[normalized_shard_name] = {
+            if shard_name not in shards:
+                shards[shard_name] = {
                     "total_chunks": 0,
                     "completed_chunks": 0,
                     "pending_chunks": 0,
@@ -259,20 +300,20 @@
                 "chunks": [],
             }

-            shards[normalized_shard_name]["chunks"].append(chunk_state)
-            shards[normalized_shard_name]["total_chunks"] += 1
+            shards[shard_name]["chunks"].append(chunk_state)
+            shards[shard_name]["total_chunks"] += 1

             if chunk_state.status == "completed":
-                shards[normalized_shard_name]["completed_chunks"] += 1
+                shards[shard_name]["completed_chunks"] += 1
             elif chunk_state.status == "pending":
-                shards[normalized_shard_name]["pending_chunks"] += 1
-                shards[normalized_shard_name]["is_complete"] = False
+                shards[shard_name]["pending_chunks"] += 1
+                shards[shard_name]["is_complete"] = False
             elif chunk_state.status == "assigned":
-                shards[normalized_shard_name]["assigned_chunks"] += 1
-                shards[normalized_shard_name]["is_complete"] = False
+                shards[shard_name]["assigned_chunks"] += 1
+                shards[shard_name]["is_complete"] = False
             elif chunk_state.status == "failed":
-                shards[normalized_shard_name]["failed_chunks"] += 1
-                shards[normalized_shard_name]["is_complete"] = False
+                shards[shard_name]["failed_chunks"] += 1
+                shards[shard_name]["is_complete"] = False

         return shards

@@ -322,13 +363,7 @@ class ChunkTracker(CheckpointTracker):
                 continue

             # Infer shard URL and create chunk with default size
-            if shard_name.replace("_", "/") in chunk_id or "_" in shard_name:
-                # HF dataset
-                dataset_path = shard_name.replace("_", "/")
-                shard_url = f"hf_dataset:{dataset_path}:chunk:{start_idx}"
-            else:
-                # WebDataset
-                shard_url = f"unknown://{shard_name}.tar"
+            shard_url = f"unknown://{shard_name}.tar"

             self.chunks[chunk_id] = ChunkState(
                 chunk_id=chunk_id,
@@ -410,6 +445,7 @@ class ChunkTracker(CheckpointTracker):
         """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
         if chunk_id not in self.chunks:
             logger.error(f"Unknown chunk: {chunk_id}")
+            logger.debug(f"Known chunks: {list(self.chunks.keys())}")
             return

         chunk = self.chunks[chunk_id]
@@ -441,9 +477,32 @@ class ChunkTracker(CheckpointTracker):
         )

     def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
-        """Get chunk info including unprocessed ranges."""
-        if chunk_id not in self.chunks:
+        """Get chunk info with unprocessed item ranges."""
+        chunk_state = self.chunks.get(chunk_id)
+        if not chunk_state:
             return None

-        chunk = self.chunks[chunk_id]
-        return {"chunk": chunk.to_dict(), "unprocessed_ranges": chunk.get_unprocessed_ranges()}
+        # During startup or if no worker is assigned, treat all unprocessed as available
+        if not hasattr(self, "_startup_complete"):
+            self._startup_complete = False
+
+        if not self._startup_complete or (
+            not chunk_state.assigned_to or chunk_state.completed_at is None
+        ):
+            # Return all unprocessed ranges
+            logger.debug(
+                f"Returning all unprocessed ranges. Status {self._startup_complete=} {chunk_state=}"
+            )
+            return {
+                "chunk_id": chunk_id,
+                "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
+                "status": chunk_state.status,
+            }
+
+        # Normal operation - only return ranges not being worked on
+        # This would need more complex tracking of which ranges each worker is processing
+        return {
+            "chunk_id": chunk_id,
+            "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
+            "status": chunk_state.status,
+        }
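Note that the rewritten get_chunk_with_unprocessed_items changes the return shape from {"chunk": ..., "unprocessed_ranges": ...} to keys "chunk_id", "unprocessed_ranges", and "status", so callers relying on the old "chunk" key need updating. A hedged usage sketch (tracker and dispatch_items are hypothetical stand-ins, not names from the package):

    info = tracker.get_chunk_with_unprocessed_items("data-0000:chunk:0")
    if info is not None:
        for start, end in info["unprocessed_ranges"]:
            dispatch_items(info["chunk_id"], start, end)  # hypothetical consumer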