caption-flow 0.2.4__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,4 @@
1
1
  """Utility modules for CaptionFlow."""
2
2
 
3
- from .dataset_loader import DatasetLoader
4
- from .shard_tracker import ShardTracker
5
3
  from .chunk_tracker import ChunkTracker
6
4
  from .caption_utils import CaptionUtils
@@ -1,16 +1,16 @@
1
- """Chunk tracking using CheckpointTracker base class."""
1
+ """Chunk tracking using CheckpointTracker base class with memory optimization."""
2
2
 
3
3
  from collections import defaultdict
4
4
  import logging
5
5
  from pathlib import Path
6
6
  from typing import Set, Dict, List, Optional, Any, Tuple
7
- from datetime import datetime
7
+ from datetime import datetime, timedelta
8
8
  from dataclasses import dataclass, asdict, field
9
9
 
10
10
  from .checkpoint_tracker import CheckpointTracker
11
11
 
12
12
  logger = logging.getLogger(__name__)
13
- # logger.setLevel(logging.DEBUG)
13
+ logger.setLevel(logging.DEBUG)
14
14
 
15
15
 
16
16
  @dataclass
@@ -53,11 +53,22 @@ class ChunkState:
53
53
 
54
54
  # Auto-complete if all items processed
55
55
  if self.processed_count >= self.chunk_size:
56
- self.status = "completed"
57
- self.completed_at = datetime.utcnow()
56
+ self.mark_completed()
57
+
58
+ def mark_completed(self):
59
+ """Mark chunk as completed and clear unnecessary data to save memory."""
60
+ self.status = "completed"
61
+ self.completed_at = datetime.utcnow()
62
+ # Clear processed_ranges since we don't need them after completion
63
+ self.processed_ranges = []
64
+ self.assigned_to = None
65
+ self.assigned_at = None
58
66
 
59
67
  def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
60
68
  """Get ranges that haven't been processed yet."""
69
+ if self.status == "completed":
70
+ return []
71
+
61
72
  if not self.processed_ranges:
62
73
  logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
63
74
  return [(0, self.chunk_size - 1)]
@@ -101,37 +112,89 @@ class ChunkState:
101
112
 
102
113
 
103
114
  class ChunkTracker(CheckpointTracker):
104
- """Tracks chunk processing state persistently."""
105
-
106
- def __init__(self, checkpoint_file: Path):
115
+ """Tracks chunk processing state persistently with memory optimization."""
116
+
117
+ def __init__(
118
+ self,
119
+ checkpoint_file: Path,
120
+ max_completed_chunks_in_memory: int = 1000,
121
+ archive_after_hours: int = 24,
122
+ ):
107
123
  self.chunks: Dict[str, ChunkState] = {}
108
- self.completed_chunks: Set[str] = set()
124
+ self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
125
+ self.archive_after_hours = archive_after_hours
126
+ self._completed_count = 0 # Track count without storing all IDs
109
127
  super().__init__(checkpoint_file)
110
128
 
111
129
  def _get_default_state(self) -> Dict[str, Any]:
112
130
  """Return default state structure for new checkpoints."""
113
- return {"chunks": {}}
131
+ return {"chunks": {}, "completed_count": 0}
114
132
 
115
133
  def _deserialize_state(self, data: Dict[str, Any]) -> None:
116
134
  """Deserialize loaded data into instance state."""
117
135
  self.chunks = {}
118
- self.completed_chunks = set()
136
+ self._completed_count = data.get("completed_count", 0)
119
137
 
120
138
  # Load chunk states
139
+ completed_chunks = 0
121
140
  for chunk_id, chunk_data in data.get("chunks", {}).items():
122
141
  chunk_state = ChunkState.from_dict(chunk_data)
123
142
  self.chunks[chunk_id] = chunk_state
124
143
  if chunk_state.status == "completed":
125
- self.completed_chunks.add(chunk_id)
144
+ completed_chunks += 1
126
145
 
127
146
  logger.info(
128
147
  f"Loaded {len(self.chunks)} chunks from checkpoint, "
129
- f"{len(self.completed_chunks)} completed"
148
+ f"{completed_chunks} completed in memory, "
149
+ f"{self._completed_count} total completed"
130
150
  )
131
151
 
132
152
  def _serialize_state(self) -> Dict[str, Any]:
133
153
  """Serialize instance state for saving."""
134
- return {"chunks": {chunk_id: chunk.to_dict() for chunk_id, chunk in self.chunks.items()}}
154
+ return {
155
+ "chunks": {chunk_id: chunk.to_dict() for chunk_id, chunk in self.chunks.items()},
156
+ "completed_count": self._completed_count,
157
+ }
158
+
159
+ def _archive_old_completed_chunks(self):
160
+ """Remove old completed chunks from memory to prevent unbounded growth."""
161
+ if not self.archive_after_hours:
162
+ return
163
+
164
+ cutoff_time = datetime.utcnow() - timedelta(hours=self.archive_after_hours)
165
+ chunks_to_remove = []
166
+
167
+ for chunk_id, chunk in self.chunks.items():
168
+ if (
169
+ chunk.status == "completed"
170
+ and chunk.completed_at
171
+ and chunk.completed_at < cutoff_time
172
+ ):
173
+ chunks_to_remove.append(chunk_id)
174
+
175
+ if chunks_to_remove:
176
+ for chunk_id in chunks_to_remove:
177
+ del self.chunks[chunk_id]
178
+ logger.info(f"Archived {len(chunks_to_remove)} old completed chunks from memory")
179
+ self.save()
180
+
181
+ def _limit_completed_chunks_in_memory(self):
182
+ """Keep only the most recent completed chunks in memory."""
183
+ completed_chunks = [
184
+ (cid, c) for cid, c in self.chunks.items() if c.status == "completed" and c.completed_at
185
+ ]
186
+
187
+ if len(completed_chunks) > self.max_completed_chunks_in_memory:
188
+ # Sort by completion time, oldest first
189
+ completed_chunks.sort(key=lambda x: x[1].completed_at)
190
+
191
+ # Remove oldest chunks
192
+ to_remove = len(completed_chunks) - self.max_completed_chunks_in_memory
193
+ for chunk_id, _ in completed_chunks[:to_remove]:
194
+ del self.chunks[chunk_id]
195
+
196
+ logger.info(f"Removed {to_remove} oldest completed chunks from memory")
197
+ self.save()
135
198
 
136
199
  def add_chunk(
137
200
  self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
@@ -139,23 +202,24 @@ class ChunkTracker(CheckpointTracker):
139
202
  """Add a new chunk. Returns False if chunk already exists and is completed."""
140
203
  if chunk_id in self.chunks:
141
204
  logger.debug(
142
- f"Chunk {chunk_id} already exists with status: {self.chunks[chunk_id].status}, not creating"
205
+ f"Chunk {chunk_id} already exists with status: {self.chunks[chunk_id].status}"
143
206
  )
144
207
  return False
145
- if chunk_id in self.completed_chunks:
146
- logger.debug(f"Chunk {chunk_id} already completed, skipping")
147
- return False
148
208
 
149
- if chunk_id not in self.chunks:
150
- self.chunks[chunk_id] = ChunkState(
151
- chunk_id=chunk_id,
152
- shard_name=shard_name,
153
- shard_url=shard_url, # Now included
154
- start_index=start_index,
155
- chunk_size=chunk_size,
156
- status="pending",
157
- )
158
- self.save()
209
+ self.chunks[chunk_id] = ChunkState(
210
+ chunk_id=chunk_id,
211
+ shard_name=shard_name,
212
+ shard_url=shard_url,
213
+ start_index=start_index,
214
+ chunk_size=chunk_size,
215
+ status="pending",
216
+ )
217
+ self.save()
218
+
219
+ # Periodically clean up old chunks
220
+ if len(self.chunks) % 100 == 0:
221
+ self._archive_old_completed_chunks()
222
+ self._limit_completed_chunks_in_memory()
159
223
 
160
224
  return True
161
225
 
@@ -172,12 +236,17 @@ class ChunkTracker(CheckpointTracker):
172
236
  """Mark chunk as completed."""
173
237
  if chunk_id in self.chunks:
174
238
  chunk = self.chunks[chunk_id]
175
- chunk.status = "completed"
176
- chunk.completed_at = datetime.utcnow()
177
- self.completed_chunks.add(chunk_id)
239
+ was_completed = chunk.status == "completed"
240
+ chunk.mark_completed() # This clears processed_ranges
241
+ if not was_completed:
242
+ self._completed_count += 1
178
243
  self.save()
179
244
  logger.debug(f"Chunk {chunk_id} marked as completed")
180
245
 
246
+ # Check if we need to clean up
247
+ if self._completed_count % 50 == 0:
248
+ self._limit_completed_chunks_in_memory()
249
+
181
250
  def mark_failed(self, chunk_id: str):
182
251
  """Mark chunk as failed."""
183
252
  if chunk_id in self.chunks:
@@ -191,6 +260,8 @@ class ChunkTracker(CheckpointTracker):
191
260
  """Mark chunk as pending (for manual reset)."""
192
261
  if chunk_id in self.chunks:
193
262
  chunk = self.chunks[chunk_id]
263
+ if chunk.status == "completed":
264
+ self._completed_count -= 1
194
265
  chunk.status = "pending"
195
266
  chunk.assigned_to = None
196
267
  chunk.assigned_at = None
@@ -217,48 +288,13 @@ class ChunkTracker(CheckpointTracker):
217
288
  pending.append(chunk_id)
218
289
  return pending
219
290
 
220
- def get_processed_indices_for_chunk(
221
- self, chunk_id: str, processed_job_ids: Set[str]
222
- ) -> List[Tuple[int, int]]:
223
- """Convert processed job_ids back to ranges for a chunk."""
224
- # Extract indices from job_ids like "data-0000:chunk:0:idx:42"
225
- processed_indices = []
226
- # this will be slow as shit, but it's simple for now, Proof of Concept.
227
- for job_id in processed_job_ids:
228
- test_chunk_id = chunk_id.replace("_", ":")
229
- if test_chunk_id in job_id:
230
- parts = job_id.split(":")
231
- logger.debug(
232
- f"Found matching job_id {job_id} for chunk {chunk_id} with {len(parts)=} and {parts[3]=}"
233
- )
234
- if len(parts) >= 5 and parts[3] == "idx":
235
- idx = int(parts[4])
236
- processed_indices.append(idx)
237
-
238
- # Convert to ranges
239
- if not processed_indices:
240
- # logger.warning(
241
- # f"Chunk {chunk_id} had no pre-processed ranges discovered, will process all elements"
242
- # )
243
- return []
244
- else:
245
- logger.debug(f"Chunk {chunk_id} has {len(processed_indices)} pre-processed indices")
246
-
247
- processed_indices.sort()
248
- ranges = []
249
- start = processed_indices[0]
250
- end = processed_indices[0]
251
-
252
- for idx in processed_indices[1:]:
253
- if idx == end + 1:
254
- end = idx
255
- else:
256
- ranges.append((start, end))
257
- start = idx
258
- end = idx
259
-
260
- ranges.append((start, end))
261
- return ranges
291
+ def is_chunk_completed(self, chunk_id: str) -> bool:
292
+ """Check if a chunk is completed (works even if chunk is archived)."""
293
+ if chunk_id in self.chunks:
294
+ return self.chunks[chunk_id].status == "completed"
295
+ # If not in memory, we can't know for sure without loading from disk
296
+ # Could implement a separate completed chunks index if needed
297
+ return False
262
298
 
263
299
  def is_shard_complete(self, shard_name: str) -> bool:
264
300
  """Check if all chunks for a shard are complete."""
@@ -272,13 +308,20 @@ class ChunkTracker(CheckpointTracker):
272
308
  def get_stats(self) -> Dict[str, int]:
273
309
  """Get chunk statistics."""
274
310
  base_stats = super().get_stats()
311
+
312
+ # Count chunks by status in memory
313
+ status_counts = defaultdict(int)
314
+ for chunk in self.chunks.values():
315
+ status_counts[chunk.status] += 1
316
+
275
317
  base_stats.update(
276
318
  {
277
- "total": len(self.chunks),
278
- "pending": sum(1 for c in self.chunks.values() if c.status == "pending"),
279
- "assigned": sum(1 for c in self.chunks.values() if c.status == "assigned"),
280
- "completed": len(self.completed_chunks),
281
- "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
319
+ "total_in_memory": len(self.chunks),
320
+ "pending": status_counts["pending"],
321
+ "assigned": status_counts["assigned"],
322
+ "completed_in_memory": status_counts["completed"],
323
+ "failed": status_counts["failed"],
324
+ "total_completed": self._completed_count,
282
325
  }
283
326
  )
284
327
  return base_stats
@@ -297,10 +340,8 @@ class ChunkTracker(CheckpointTracker):
297
340
  "assigned_chunks": 0,
298
341
  "failed_chunks": 0,
299
342
  "is_complete": True,
300
- "chunks": [],
301
343
  }
302
344
 
303
- shards[shard_name]["chunks"].append(chunk_state)
304
345
  shards[shard_name]["total_chunks"] += 1
305
346
 
306
347
  if chunk_state.status == "completed":
@@ -326,32 +367,37 @@ class ChunkTracker(CheckpointTracker):
326
367
  return incomplete
327
368
 
328
369
  async def sync_with_storage(self, storage_manager):
329
- """Sync chunk state with storage to detect processed items."""
370
+ """Sync chunk state with storage to detect processed items - memory efficient version."""
330
371
  logger.info("Syncing chunk state with storage...")
331
372
 
332
- if storage_manager.captions_path.exists():
333
- import pyarrow.parquet as pq
373
+ if not storage_manager.captions_path.exists():
374
+ return
375
+
376
+ import pyarrow as pa
377
+ import pyarrow.parquet as pq
378
+
379
+ # Check if item_index column exists
380
+ table_metadata = pq.read_metadata(storage_manager.captions_path)
381
+ columns = ["job_id", "chunk_id", "item_key"]
382
+ if "item_index" in table_metadata.schema.names:
383
+ columns.append("item_index")
334
384
 
335
- # Read all relevant columns
336
- columns = ["job_id", "chunk_id", "item_key"]
337
- # Check if item_index column exists (new format)
338
- table_metadata = pq.read_metadata(storage_manager.captions_path)
339
- if "item_index" in table_metadata.schema.names:
340
- columns.append("item_index")
385
+ # Process in batches to avoid loading entire table
386
+ batch_size = 10000
387
+ parquet_file = pq.ParquetFile(storage_manager.captions_path)
341
388
 
342
- table = pq.read_table(storage_manager.captions_path, columns=columns)
389
+ chunk_indices = defaultdict(set)
343
390
 
344
- # Build lookup of chunk_id -> processed indices
345
- chunk_indices = defaultdict(set)
391
+ for batch in parquet_file.iter_batches(batch_size=batch_size, columns=columns):
392
+ batch_dict = batch.to_pydict()
346
393
 
347
- for i in range(len(table)):
348
- chunk_id = table["chunk_id"][i].as_py()
394
+ for i in range(len(batch_dict["chunk_id"])):
395
+ chunk_id = batch_dict["chunk_id"][i]
349
396
  if not chunk_id:
350
397
  continue
351
398
 
352
- # Get the chunk to find its boundaries
399
+ # Get or create chunk
353
400
  if chunk_id not in self.chunks:
354
- # Try to recreate chunk from chunk_id
355
401
  parts = chunk_id.rsplit("_chunk_", 1)
356
402
  if len(parts) != 2:
357
403
  continue
@@ -362,7 +408,6 @@ class ChunkTracker(CheckpointTracker):
362
408
  except ValueError:
363
409
  continue
364
410
 
365
- # Infer shard URL and create chunk with default size
366
411
  shard_url = f"unknown://{shard_name}.tar"
367
412
 
368
413
  self.chunks[chunk_id] = ChunkState(
@@ -370,18 +415,17 @@ class ChunkTracker(CheckpointTracker):
370
415
  shard_name=shard_name,
371
416
  shard_url=shard_url,
372
417
  start_index=start_idx,
373
- chunk_size=10000, # Default - should match your chunk size
418
+ chunk_size=10000, # Default
374
419
  status="pending",
375
420
  )
376
421
 
377
422
  chunk = self.chunks[chunk_id]
378
423
 
379
424
  # Get item index
380
- if "item_index" in table.column_names:
381
- item_index = table["item_index"][i].as_py()
425
+ if "item_index" in batch_dict:
426
+ item_index = batch_dict["item_index"][i]
382
427
  else:
383
- # Try to extract from item_key
384
- item_key = table["item_key"][i].as_py()
428
+ item_key = batch_dict["item_key"][i]
385
429
  try:
386
430
  item_index = int(item_key.split("_")[-1])
387
431
  except:
@@ -390,62 +434,70 @@ class ChunkTracker(CheckpointTracker):
390
434
  if item_index is None:
391
435
  continue
392
436
 
393
- # CRITICAL: Validate that this item belongs to this chunk
437
+ # Validate index belongs to chunk
394
438
  if (
395
439
  item_index < chunk.start_index
396
440
  or item_index >= chunk.start_index + chunk.chunk_size
397
441
  ):
398
- logger.warning(
399
- f"Item index {item_index} doesn't belong to chunk {chunk_id} "
400
- f"(boundaries: {chunk.start_index}-{chunk.start_index + chunk.chunk_size - 1})"
401
- )
402
442
  continue
403
443
 
404
- # Store the absolute index for now
405
444
  chunk_indices[chunk_id].add(item_index)
406
445
 
407
- # Convert absolute indices to relative and mark as processed
408
- for chunk_id, abs_indices in chunk_indices.items():
409
- if chunk_id not in self.chunks:
410
- continue
446
+ # Process accumulated indices periodically to avoid memory buildup
447
+ if len(chunk_indices) > 100:
448
+ self._process_chunk_indices(chunk_indices)
449
+ chunk_indices.clear()
411
450
 
412
- chunk = self.chunks[chunk_id]
451
+ # Process remaining indices
452
+ if chunk_indices:
453
+ self._process_chunk_indices(chunk_indices)
413
454
 
414
- # Convert to relative indices and group into ranges
415
- rel_indices = []
416
- for abs_idx in sorted(abs_indices):
417
- rel_idx = abs_idx - chunk.start_index
418
- if 0 <= rel_idx < chunk.chunk_size:
419
- rel_indices.append(rel_idx)
420
-
421
- # Group consecutive indices into ranges
422
- if rel_indices:
423
- ranges = []
424
- start = rel_indices[0]
425
- end = rel_indices[0]
426
-
427
- for idx in rel_indices[1:]:
428
- if idx == end + 1:
429
- end = idx
430
- else:
431
- ranges.append((start, end))
432
- start = idx
433
- end = idx
434
-
435
- ranges.append((start, end))
436
-
437
- # Mark ranges as processed
438
- for start_idx, end_idx in ranges:
439
- chunk.add_processed_range(start_idx, end_idx)
440
-
441
- logger.info(f"Synced {len(chunk_indices)} chunks with processed items")
442
- self.save()
455
+ logger.info("Sync with storage completed")
456
+ self.save()
457
+
458
+ def _process_chunk_indices(self, chunk_indices: Dict[str, Set[int]]):
459
+ """Process a batch of chunk indices."""
460
+ for chunk_id, abs_indices in chunk_indices.items():
461
+ if chunk_id not in self.chunks:
462
+ continue
463
+
464
+ chunk = self.chunks[chunk_id]
465
+
466
+ # Skip if already completed
467
+ if chunk.status == "completed":
468
+ continue
469
+
470
+ # Convert to relative indices and group into ranges
471
+ rel_indices = []
472
+ for abs_idx in sorted(abs_indices):
473
+ rel_idx = abs_idx - chunk.start_index
474
+ if 0 <= rel_idx < chunk.chunk_size:
475
+ rel_indices.append(rel_idx)
476
+
477
+ # Group consecutive indices into ranges
478
+ if rel_indices:
479
+ ranges = []
480
+ start = rel_indices[0]
481
+ end = rel_indices[0]
482
+
483
+ for idx in rel_indices[1:]:
484
+ if idx == end + 1:
485
+ end = idx
486
+ else:
487
+ ranges.append((start, end))
488
+ start = idx
489
+ end = idx
490
+
491
+ ranges.append((start, end))
492
+
493
+ # Mark ranges as processed
494
+ for start_idx, end_idx in ranges:
495
+ chunk.add_processed_range(start_idx, end_idx)
443
496
 
444
497
  def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
445
498
  """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
446
499
  if chunk_id not in self.chunks:
447
500
  logger.error(f"Unknown chunk: {chunk_id}")
448
- logger.debug(f"Known chunks: {list(self.chunks.keys())}")
449
501
  return
450
502
 
451
503
  chunk = self.chunks[chunk_id]
@@ -466,9 +518,9 @@ class ChunkTracker(CheckpointTracker):
466
518
  # Add the relative range
467
519
  chunk.add_processed_range(relative_start, relative_end)
468
520
 
469
- # If chunk is now complete, update completed set
521
+ # If chunk is now complete, increment counter
470
522
  if chunk.status == "completed":
471
- self.completed_chunks.add(chunk_id)
523
+ self._completed_count += 1
472
524
 
473
525
  self.save()
474
526
  logger.debug(
@@ -482,25 +534,6 @@ class ChunkTracker(CheckpointTracker):
482
534
  if not chunk_state:
483
535
  return None
484
536
 
485
- # During startup or if no worker is assigned, treat all unprocessed as available
486
- if not hasattr(self, "_startup_complete"):
487
- self._startup_complete = False
488
-
489
- if not self._startup_complete or (
490
- not chunk_state.assigned_to or chunk_state.completed_at is None
491
- ):
492
- # Return all unprocessed ranges
493
- logger.debug(
494
- f"Returning all unprocessed ranges. Status {self._startup_complete=} {chunk_state=}"
495
- )
496
- return {
497
- "chunk_id": chunk_id,
498
- "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
499
- "status": chunk_state.status,
500
- }
501
-
502
- # Normal operation - only return ranges not being worked on
503
- # This would need more complex tracking of which ranges each worker is processing
504
537
  return {
505
538
  "chunk_id": chunk_id,
506
539
  "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),