caption-flow 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those package versions as they appear in their respective public registries.
@@ -1,6 +1,4 @@
1
1
  """Utility modules for CaptionFlow."""
2
2
 
3
- from .dataset_loader import DatasetLoader
4
- from .shard_tracker import ShardTracker
5
3
  from .chunk_tracker import ChunkTracker
6
4
  from .caption_utils import CaptionUtils
@@ -1,16 +1,15 @@
1
- """Chunk tracking using CheckpointTracker base class."""
1
+ """Chunk tracking using CheckpointTracker base class with memory optimization."""
2
2
 
3
3
  from collections import defaultdict
4
4
  import logging
5
5
  from pathlib import Path
6
6
  from typing import Set, Dict, List, Optional, Any, Tuple
7
- from datetime import datetime
7
+ from datetime import datetime, timedelta
8
8
  from dataclasses import dataclass, asdict, field
9
9
 
10
10
  from .checkpoint_tracker import CheckpointTracker
11
11
 
12
12
  logger = logging.getLogger(__name__)
13
- # logger.setLevel(logging.DEBUG)
14
13
 
15
14
 
16
15
  @dataclass
@@ -53,11 +52,22 @@ class ChunkState:
53
52
 
54
53
  # Auto-complete if all items processed
55
54
  if self.processed_count >= self.chunk_size:
56
- self.status = "completed"
57
- self.completed_at = datetime.utcnow()
55
+ self.mark_completed()
56
+
57
+ def mark_completed(self):
58
+ """Mark chunk as completed and clear unnecessary data to save memory."""
59
+ self.status = "completed"
60
+ self.completed_at = datetime.utcnow()
61
+ # Clear processed_ranges since we don't need them after completion
62
+ self.processed_ranges = []
63
+ self.assigned_to = None
64
+ self.assigned_at = None
58
65
 
59
66
  def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
60
67
  """Get ranges that haven't been processed yet."""
68
+ if self.status == "completed":
69
+ return []
70
+
61
71
  if not self.processed_ranges:
62
72
  logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
63
73
  return [(0, self.chunk_size - 1)]
@@ -101,37 +111,89 @@ class ChunkState:
101
111
 
102
112
 
103
113
  class ChunkTracker(CheckpointTracker):
104
- """Tracks chunk processing state persistently."""
105
-
106
- def __init__(self, checkpoint_file: Path):
114
+ """Tracks chunk processing state persistently with memory optimization."""
115
+
116
+ def __init__(
117
+ self,
118
+ checkpoint_file: Path,
119
+ max_completed_chunks_in_memory: int = 1000,
120
+ archive_after_hours: int = 24,
121
+ ):
107
122
  self.chunks: Dict[str, ChunkState] = {}
108
- self.completed_chunks: Set[str] = set()
123
+ self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
124
+ self.archive_after_hours = archive_after_hours
125
+ self._completed_count = 0 # Track count without storing all IDs
109
126
  super().__init__(checkpoint_file)
110
127
 
111
128
  def _get_default_state(self) -> Dict[str, Any]:
112
129
  """Return default state structure for new checkpoints."""
113
- return {"chunks": {}}
130
+ return {"chunks": {}, "completed_count": 0}
114
131
 
115
132
  def _deserialize_state(self, data: Dict[str, Any]) -> None:
116
133
  """Deserialize loaded data into instance state."""
117
134
  self.chunks = {}
118
- self.completed_chunks = set()
135
+ self._completed_count = data.get("completed_count", 0)
119
136
 
120
137
  # Load chunk states
138
+ completed_chunks = 0
121
139
  for chunk_id, chunk_data in data.get("chunks", {}).items():
122
140
  chunk_state = ChunkState.from_dict(chunk_data)
123
141
  self.chunks[chunk_id] = chunk_state
124
142
  if chunk_state.status == "completed":
125
- self.completed_chunks.add(chunk_id)
143
+ completed_chunks += 1
126
144
 
127
145
  logger.info(
128
146
  f"Loaded {len(self.chunks)} chunks from checkpoint, "
129
- f"{len(self.completed_chunks)} completed"
147
+ f"{completed_chunks} completed in memory, "
148
+ f"{self._completed_count} total completed"
130
149
  )
131
150
 
132
151
  def _serialize_state(self) -> Dict[str, Any]:
133
152
  """Serialize instance state for saving."""
134
- return {"chunks": {chunk_id: chunk.to_dict() for chunk_id, chunk in self.chunks.items()}}
153
+ return {
154
+ "chunks": {chunk_id: chunk.to_dict() for chunk_id, chunk in self.chunks.items()},
155
+ "completed_count": self._completed_count,
156
+ }
157
+
158
+ def _archive_old_completed_chunks(self):
159
+ """Remove old completed chunks from memory to prevent unbounded growth."""
160
+ if not self.archive_after_hours:
161
+ return
162
+
163
+ cutoff_time = datetime.utcnow() - timedelta(hours=self.archive_after_hours)
164
+ chunks_to_remove = []
165
+
166
+ for chunk_id, chunk in self.chunks.items():
167
+ if (
168
+ chunk.status == "completed"
169
+ and chunk.completed_at
170
+ and chunk.completed_at < cutoff_time
171
+ ):
172
+ chunks_to_remove.append(chunk_id)
173
+
174
+ if chunks_to_remove:
175
+ for chunk_id in chunks_to_remove:
176
+ del self.chunks[chunk_id]
177
+ logger.info(f"Archived {len(chunks_to_remove)} old completed chunks from memory")
178
+ self.save()
179
+
180
+ def _limit_completed_chunks_in_memory(self):
181
+ """Keep only the most recent completed chunks in memory."""
182
+ completed_chunks = [
183
+ (cid, c) for cid, c in self.chunks.items() if c.status == "completed" and c.completed_at
184
+ ]
185
+
186
+ if len(completed_chunks) > self.max_completed_chunks_in_memory:
187
+ # Sort by completion time, oldest first
188
+ completed_chunks.sort(key=lambda x: x[1].completed_at)
189
+
190
+ # Remove oldest chunks
191
+ to_remove = len(completed_chunks) - self.max_completed_chunks_in_memory
192
+ for chunk_id, _ in completed_chunks[:to_remove]:
193
+ del self.chunks[chunk_id]
194
+
195
+ logger.info(f"Removed {to_remove} oldest completed chunks from memory")
196
+ self.save()
135
197
 
136
198
  def add_chunk(
137
199
  self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
@@ -139,23 +201,24 @@ class ChunkTracker(CheckpointTracker):
139
201
  """Add a new chunk. Returns False if chunk already exists and is completed."""
140
202
  if chunk_id in self.chunks:
141
203
  logger.debug(
142
- f"Chunk {chunk_id} already exists with status: {self.chunks[chunk_id].status}, not creating"
204
+ f"Chunk {chunk_id} already exists with status: {self.chunks[chunk_id].status}"
143
205
  )
144
206
  return False
145
- if chunk_id in self.completed_chunks:
146
- logger.debug(f"Chunk {chunk_id} already completed, skipping")
147
- return False
148
207
 
149
- if chunk_id not in self.chunks:
150
- self.chunks[chunk_id] = ChunkState(
151
- chunk_id=chunk_id,
152
- shard_name=shard_name,
153
- shard_url=shard_url, # Now included
154
- start_index=start_index,
155
- chunk_size=chunk_size,
156
- status="pending",
157
- )
158
- self.save()
208
+ self.chunks[chunk_id] = ChunkState(
209
+ chunk_id=chunk_id,
210
+ shard_name=shard_name,
211
+ shard_url=shard_url,
212
+ start_index=start_index,
213
+ chunk_size=chunk_size,
214
+ status="pending",
215
+ )
216
+ self.save()
217
+
218
+ # Periodically clean up old chunks
219
+ if len(self.chunks) % 100 == 0:
220
+ self._archive_old_completed_chunks()
221
+ self._limit_completed_chunks_in_memory()
159
222
 
160
223
  return True
161
224
 
@@ -172,12 +235,17 @@ class ChunkTracker(CheckpointTracker):
172
235
  """Mark chunk as completed."""
173
236
  if chunk_id in self.chunks:
174
237
  chunk = self.chunks[chunk_id]
175
- chunk.status = "completed"
176
- chunk.completed_at = datetime.utcnow()
177
- self.completed_chunks.add(chunk_id)
238
+ was_completed = chunk.status == "completed"
239
+ chunk.mark_completed() # This clears processed_ranges
240
+ if not was_completed:
241
+ self._completed_count += 1
178
242
  self.save()
179
243
  logger.debug(f"Chunk {chunk_id} marked as completed")
180
244
 
245
+ # Check if we need to clean up
246
+ if self._completed_count % 50 == 0:
247
+ self._limit_completed_chunks_in_memory()
248
+
181
249
  def mark_failed(self, chunk_id: str):
182
250
  """Mark chunk as failed."""
183
251
  if chunk_id in self.chunks:
@@ -191,6 +259,8 @@ class ChunkTracker(CheckpointTracker):
191
259
  """Mark chunk as pending (for manual reset)."""
192
260
  if chunk_id in self.chunks:
193
261
  chunk = self.chunks[chunk_id]
262
+ if chunk.status == "completed":
263
+ self._completed_count -= 1
194
264
  chunk.status = "pending"
195
265
  chunk.assigned_to = None
196
266
  chunk.assigned_at = None
@@ -217,48 +287,13 @@ class ChunkTracker(CheckpointTracker):
217
287
  pending.append(chunk_id)
218
288
  return pending
219
289
 
220
- def get_processed_indices_for_chunk(
221
- self, chunk_id: str, processed_job_ids: Set[str]
222
- ) -> List[Tuple[int, int]]:
223
- """Convert processed job_ids back to ranges for a chunk."""
224
- # Extract indices from job_ids like "data-0000:chunk:0:idx:42"
225
- processed_indices = []
226
- # this will be slow as shit, but it's simple for now, Proof of Concept.
227
- for job_id in processed_job_ids:
228
- test_chunk_id = chunk_id.replace("_", ":")
229
- if test_chunk_id in job_id:
230
- parts = job_id.split(":")
231
- logger.debug(
232
- f"Found matching job_id {job_id} for chunk {chunk_id} with {len(parts)=} and {parts[3]=}"
233
- )
234
- if len(parts) >= 5 and parts[3] == "idx":
235
- idx = int(parts[4])
236
- processed_indices.append(idx)
237
-
238
- # Convert to ranges
239
- if not processed_indices:
240
- # logger.warning(
241
- # f"Chunk {chunk_id} had no pre-processed ranges discovered, will process all elements"
242
- # )
243
- return []
244
- else:
245
- logger.debug(f"Chunk {chunk_id} has {len(processed_indices)} pre-processed indices")
246
-
247
- processed_indices.sort()
248
- ranges = []
249
- start = processed_indices[0]
250
- end = processed_indices[0]
251
-
252
- for idx in processed_indices[1:]:
253
- if idx == end + 1:
254
- end = idx
255
- else:
256
- ranges.append((start, end))
257
- start = idx
258
- end = idx
259
-
260
- ranges.append((start, end))
261
- return ranges
290
+ def is_chunk_completed(self, chunk_id: str) -> bool:
291
+ """Check if a chunk is completed (works even if chunk is archived)."""
292
+ if chunk_id in self.chunks:
293
+ return self.chunks[chunk_id].status == "completed"
294
+ # If not in memory, we can't know for sure without loading from disk
295
+ # Could implement a separate completed chunks index if needed
296
+ return False
262
297
 
263
298
  def is_shard_complete(self, shard_name: str) -> bool:
264
299
  """Check if all chunks for a shard are complete."""
@@ -272,13 +307,20 @@ class ChunkTracker(CheckpointTracker):
272
307
  def get_stats(self) -> Dict[str, int]:
273
308
  """Get chunk statistics."""
274
309
  base_stats = super().get_stats()
310
+
311
+ # Count chunks by status in memory
312
+ status_counts = defaultdict(int)
313
+ for chunk in self.chunks.values():
314
+ status_counts[chunk.status] += 1
315
+
275
316
  base_stats.update(
276
317
  {
277
- "total": len(self.chunks),
278
- "pending": sum(1 for c in self.chunks.values() if c.status == "pending"),
279
- "assigned": sum(1 for c in self.chunks.values() if c.status == "assigned"),
280
- "completed": len(self.completed_chunks),
281
- "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
318
+ "total_in_memory": len(self.chunks),
319
+ "pending": status_counts["pending"],
320
+ "assigned": status_counts["assigned"],
321
+ "completed_in_memory": status_counts["completed"],
322
+ "failed": status_counts["failed"],
323
+ "total_completed": self._completed_count,
282
324
  }
283
325
  )
284
326
  return base_stats
@@ -297,10 +339,8 @@ class ChunkTracker(CheckpointTracker):
297
339
  "assigned_chunks": 0,
298
340
  "failed_chunks": 0,
299
341
  "is_complete": True,
300
- "chunks": [],
301
342
  }
302
343
 
303
- shards[shard_name]["chunks"].append(chunk_state)
304
344
  shards[shard_name]["total_chunks"] += 1
305
345
 
306
346
  if chunk_state.status == "completed":
@@ -326,32 +366,37 @@ class ChunkTracker(CheckpointTracker):
326
366
  return incomplete
327
367
 
328
368
  async def sync_with_storage(self, storage_manager):
329
- """Sync chunk state with storage to detect processed items."""
369
+ """Sync chunk state with storage to detect processed items - memory efficient version."""
330
370
  logger.info("Syncing chunk state with storage...")
331
371
 
332
- if storage_manager.captions_path.exists():
333
- import pyarrow.parquet as pq
372
+ if not storage_manager.captions_path.exists():
373
+ return
374
+
375
+ import pyarrow as pa
376
+ import pyarrow.parquet as pq
377
+
378
+ # Check if item_index column exists
379
+ table_metadata = pq.read_metadata(storage_manager.captions_path)
380
+ columns = ["job_id", "chunk_id", "item_key"]
381
+ if "item_index" in table_metadata.schema.names:
382
+ columns.append("item_index")
334
383
 
335
- # Read all relevant columns
336
- columns = ["job_id", "chunk_id", "item_key"]
337
- # Check if item_index column exists (new format)
338
- table_metadata = pq.read_metadata(storage_manager.captions_path)
339
- if "item_index" in table_metadata.schema.names:
340
- columns.append("item_index")
384
+ # Process in batches to avoid loading entire table
385
+ batch_size = 10000
386
+ parquet_file = pq.ParquetFile(storage_manager.captions_path)
341
387
 
342
- table = pq.read_table(storage_manager.captions_path, columns=columns)
388
+ chunk_indices = defaultdict(set)
343
389
 
344
- # Build lookup of chunk_id -> processed indices
345
- chunk_indices = defaultdict(set)
390
+ for batch in parquet_file.iter_batches(batch_size=batch_size, columns=columns):
391
+ batch_dict = batch.to_pydict()
346
392
 
347
- for i in range(len(table)):
348
- chunk_id = table["chunk_id"][i].as_py()
393
+ for i in range(len(batch_dict["chunk_id"])):
394
+ chunk_id = batch_dict["chunk_id"][i]
349
395
  if not chunk_id:
350
396
  continue
351
397
 
352
- # Get the chunk to find its boundaries
398
+ # Get or create chunk
353
399
  if chunk_id not in self.chunks:
354
- # Try to recreate chunk from chunk_id
355
400
  parts = chunk_id.rsplit("_chunk_", 1)
356
401
  if len(parts) != 2:
357
402
  continue
@@ -362,7 +407,6 @@ class ChunkTracker(CheckpointTracker):
362
407
  except ValueError:
363
408
  continue
364
409
 
365
- # Infer shard URL and create chunk with default size
366
410
  shard_url = f"unknown://{shard_name}.tar"
367
411
 
368
412
  self.chunks[chunk_id] = ChunkState(
@@ -370,18 +414,17 @@ class ChunkTracker(CheckpointTracker):
370
414
  shard_name=shard_name,
371
415
  shard_url=shard_url,
372
416
  start_index=start_idx,
373
- chunk_size=10000, # Default - should match your chunk size
417
+ chunk_size=10000, # Default
374
418
  status="pending",
375
419
  )
376
420
 
377
421
  chunk = self.chunks[chunk_id]
378
422
 
379
423
  # Get item index
380
- if "item_index" in table.column_names:
381
- item_index = table["item_index"][i].as_py()
424
+ if "item_index" in batch_dict:
425
+ item_index = batch_dict["item_index"][i]
382
426
  else:
383
- # Try to extract from item_key
384
- item_key = table["item_key"][i].as_py()
427
+ item_key = batch_dict["item_key"][i]
385
428
  try:
386
429
  item_index = int(item_key.split("_")[-1])
387
430
  except:
@@ -390,62 +433,70 @@ class ChunkTracker(CheckpointTracker):
390
433
  if item_index is None:
391
434
  continue
392
435
 
393
- # CRITICAL: Validate that this item belongs to this chunk
436
+ # Validate index belongs to chunk
394
437
  if (
395
438
  item_index < chunk.start_index
396
439
  or item_index >= chunk.start_index + chunk.chunk_size
397
440
  ):
398
- logger.warning(
399
- f"Item index {item_index} doesn't belong to chunk {chunk_id} "
400
- f"(boundaries: {chunk.start_index}-{chunk.start_index + chunk.chunk_size - 1})"
401
- )
402
441
  continue
403
442
 
404
- # Store the absolute index for now
405
443
  chunk_indices[chunk_id].add(item_index)
406
444
 
407
- # Convert absolute indices to relative and mark as processed
408
- for chunk_id, abs_indices in chunk_indices.items():
409
- if chunk_id not in self.chunks:
410
- continue
445
+ # Process accumulated indices periodically to avoid memory buildup
446
+ if len(chunk_indices) > 100:
447
+ self._process_chunk_indices(chunk_indices)
448
+ chunk_indices.clear()
411
449
 
412
- chunk = self.chunks[chunk_id]
450
+ # Process remaining indices
451
+ if chunk_indices:
452
+ self._process_chunk_indices(chunk_indices)
413
453
 
414
- # Convert to relative indices and group into ranges
415
- rel_indices = []
416
- for abs_idx in sorted(abs_indices):
417
- rel_idx = abs_idx - chunk.start_index
418
- if 0 <= rel_idx < chunk.chunk_size:
419
- rel_indices.append(rel_idx)
420
-
421
- # Group consecutive indices into ranges
422
- if rel_indices:
423
- ranges = []
424
- start = rel_indices[0]
425
- end = rel_indices[0]
426
-
427
- for idx in rel_indices[1:]:
428
- if idx == end + 1:
429
- end = idx
430
- else:
431
- ranges.append((start, end))
432
- start = idx
433
- end = idx
434
-
435
- ranges.append((start, end))
436
-
437
- # Mark ranges as processed
438
- for start_idx, end_idx in ranges:
439
- chunk.add_processed_range(start_idx, end_idx)
440
-
441
- logger.info(f"Synced {len(chunk_indices)} chunks with processed items")
442
- self.save()
454
+ logger.info("Sync with storage completed")
455
+ self.save()
456
+
457
+ def _process_chunk_indices(self, chunk_indices: Dict[str, Set[int]]):
458
+ """Process a batch of chunk indices."""
459
+ for chunk_id, abs_indices in chunk_indices.items():
460
+ if chunk_id not in self.chunks:
461
+ continue
462
+
463
+ chunk = self.chunks[chunk_id]
464
+
465
+ # Skip if already completed
466
+ if chunk.status == "completed":
467
+ continue
468
+
469
+ # Convert to relative indices and group into ranges
470
+ rel_indices = []
471
+ for abs_idx in sorted(abs_indices):
472
+ rel_idx = abs_idx - chunk.start_index
473
+ if 0 <= rel_idx < chunk.chunk_size:
474
+ rel_indices.append(rel_idx)
475
+
476
+ # Group consecutive indices into ranges
477
+ if rel_indices:
478
+ ranges = []
479
+ start = rel_indices[0]
480
+ end = rel_indices[0]
481
+
482
+ for idx in rel_indices[1:]:
483
+ if idx == end + 1:
484
+ end = idx
485
+ else:
486
+ ranges.append((start, end))
487
+ start = idx
488
+ end = idx
489
+
490
+ ranges.append((start, end))
491
+
492
+ # Mark ranges as processed
493
+ for start_idx, end_idx in ranges:
494
+ chunk.add_processed_range(start_idx, end_idx)
443
495
 
444
496
  def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
445
497
  """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
446
498
  if chunk_id not in self.chunks:
447
499
  logger.error(f"Unknown chunk: {chunk_id}")
448
- logger.debug(f"Known chunks: {list(self.chunks.keys())}")
449
500
  return
450
501
 
451
502
  chunk = self.chunks[chunk_id]
@@ -466,9 +517,9 @@ class ChunkTracker(CheckpointTracker):
466
517
  # Add the relative range
467
518
  chunk.add_processed_range(relative_start, relative_end)
468
519
 
469
- # If chunk is now complete, update completed set
520
+ # If chunk is now complete, increment counter
470
521
  if chunk.status == "completed":
471
- self.completed_chunks.add(chunk_id)
522
+ self._completed_count += 1
472
523
 
473
524
  self.save()
474
525
  logger.debug(
@@ -482,25 +533,6 @@ class ChunkTracker(CheckpointTracker):
482
533
  if not chunk_state:
483
534
  return None
484
535
 
485
- # During startup or if no worker is assigned, treat all unprocessed as available
486
- if not hasattr(self, "_startup_complete"):
487
- self._startup_complete = False
488
-
489
- if not self._startup_complete or (
490
- not chunk_state.assigned_to or chunk_state.completed_at is None
491
- ):
492
- # Return all unprocessed ranges
493
- logger.debug(
494
- f"Returning all unprocessed ranges. Status {self._startup_complete=} {chunk_state=}"
495
- )
496
- return {
497
- "chunk_id": chunk_id,
498
- "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
499
- "status": chunk_state.status,
500
- }
501
-
502
- # Normal operation - only return ranges not being worked on
503
- # This would need more complex tracking of which ranges each worker is processing
504
536
  return {
505
537
  "chunk_id": chunk_id,
506
538
  "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),