caption-flow 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +937 -416
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +5 -3
  5. caption_flow/orchestrator.py +186 -116
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +440 -68
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +66 -25
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +41 -19
  18. caption_flow/utils/chunk_tracker.py +200 -65
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +12 -6
  25. caption_flow/workers/caption.py +272 -91
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
  28. caption_flow-0.4.0.dist-info/RECORD +33 -0
  29. caption_flow-0.3.3.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0

caption_flow/utils/__init__.py
@@ -1,4 +1,4 @@
  """Utility modules for CaptionFlow."""

- from .chunk_tracker import ChunkTracker
  from .caption_utils import CaptionUtils
+ from .chunk_tracker import ChunkTracker

caption_flow/utils/auth.py
@@ -1,7 +1,7 @@
  """Authentication management."""

- from typing import Dict, Any, Optional
  from dataclasses import dataclass
+ from typing import Any, Dict, Optional


  @dataclass

caption_flow/utils/caption_utils.py
@@ -1,6 +1,6 @@
  """Caption processing utilities from the original vLLM script."""

- from typing import List, Dict
+ from typing import Dict, List


  class CaptionUtils:

caption_flow/utils/certificates.py
@@ -1,13 +1,15 @@
  """SSL certificate management."""

+ import datetime as _datetime
  import subprocess
+ from datetime import datetime, timedelta
  from pathlib import Path
  from typing import Optional
+
  from cryptography import x509
- from cryptography.x509.oid import NameOID
  from cryptography.hazmat.primitives import hashes, serialization
  from cryptography.hazmat.primitives.asymmetric import rsa
- from datetime import datetime, timedelta
+ from cryptography.x509.oid import NameOID


  class CertificateManager:
@@ -35,8 +37,8 @@ class CertificateManager:
  .issuer_name(issuer)
  .public_key(key.public_key())
  .serial_number(x509.random_serial_number())
- .not_valid_before(datetime.utcnow())
- .not_valid_after(datetime.utcnow() + timedelta(days=365))
+ .not_valid_before(datetime.now(_datetime.UTC))
+ .not_valid_after(datetime.now(_datetime.UTC) + timedelta(days=365))
  .add_extension(
  x509.SubjectAlternativeName(
  [
@@ -71,14 +73,15 @@ class CertificateManager:
  def generate_letsencrypt(
  self, domain: str, email: str, output_dir: Optional[Path] = None, staging: bool = False
  ) -> tuple[Path, Path]:
- """
- Generate Let's Encrypt certificate.
+ """Generate Let's Encrypt certificate.

  Args:
+ ----
  domain: Domain name for certificate
  email: Email for Let's Encrypt account
  output_dir: Custom output directory (uses /etc/letsencrypt by default)
  staging: Use Let's Encrypt staging server for testing
+
  """
  cmd = [
  "certbot",
@@ -133,8 +136,12 @@ class CertificateManager:
  return {
  "subject": cert.subject.rfc4514_string(),
  "issuer": cert.issuer.rfc4514_string(),
- "not_before": cert.not_valid_before,
- "not_after": cert.not_valid_after,
+ "not_before": cert.not_valid_before_utc,
+ "not_after": cert.not_valid_after_utc,
  "serial_number": cert.serial_number,
  "is_self_signed": cert.issuer == cert.subject,
  }
+
+ def inspect_certificate(self, cert_path: Path) -> dict:
+ """Inspect a certificate (alias for get_cert_info for CLI compatibility)."""
+ return self.get_cert_info(cert_path)
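
Note: the recurring change across these files is the move from naive datetime.utcnow() to timezone-aware timestamps, and from cert.not_valid_before/not_valid_after to their _utc counterparts. A minimal sketch of the same idiom, assuming Python 3.11+ (for datetime.UTC) and a cryptography release that exposes the aware accessors:

import datetime as _datetime
from datetime import datetime, timedelta

# Aware "now" replaces the deprecated naive datetime.utcnow()
not_before = datetime.now(_datetime.UTC)
not_after = not_before + timedelta(days=365)

# On a cryptography x509.Certificate, prefer the aware accessors
#   cert.not_valid_before_utc / cert.not_valid_after_utc
# over the deprecated naive cert.not_valid_before / cert.not_valid_after.
print(not_before.isoformat(), not_after.isoformat())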

caption_flow/utils/checkpoint_tracker.py
@@ -1,13 +1,17 @@
  """Base class for checkpoint tracking with persistent state."""

+ import datetime as _datetime
  import json
  import logging
+ import os
  from abc import ABC, abstractmethod
- from pathlib import Path
- from typing import Dict, Any, Optional
+ from concurrent.futures import ThreadPoolExecutor
  from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Dict, Optional

  logger = logging.getLogger(__name__)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())


  class CheckpointTracker(ABC):
@@ -53,34 +57,52 @@ class CheckpointTracker(ABC):
  def save(self) -> None:
  """Save checkpoint to disk atomically."""
  try:
+ # If a save is already in progress, let it finish.
+ # This prevents race conditions if save() is called rapidly.
+ if hasattr(self, "_save_future") and self._save_future and not self._save_future.done():
+ logger.warning("Previous save still in progress, skipping this save")
+ return  # don't save this time,
+ logger.info("Saving chunk tracker state...")
  # Prepare data with metadata
- data = self._serialize_state()
- data["updated_at"] = datetime.utcnow().isoformat()
+ with self.lock:
+ data = self._serialize_state()
+ data["updated_at"] = datetime.now(_datetime.UTC).isoformat()

  # Write atomically using temp file
  tmp_file = self.checkpoint_path.with_suffix(".tmp")

+ # Use an executor to run the save operation in a background thread.
+ # This makes the save call non-blocking.
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ data_to_save = data.copy()
+ self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
+ except Exception as e:
+ logger.error(f"Failed to submit save task: {e}", exc_info=True)
+
+ def _write_to_disk(self, data: Dict[str, Any], checkpoint_path: Optional[str] = None) -> None:
+ """Write checkpoint data to disk atomically."""
+ # Create a temporary file in the same directory as the checkpoint
+ tmp_file = (checkpoint_path or self.checkpoint_path).with_suffix(".tmp")
+ logger.debug(f"Checkpoint {tmp_file=}")
+
+ try:
+ # Ensure the parent directory exists
+ self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+
  with open(tmp_file, "w") as f:
  json.dump(data, f, indent=2)

- # Ensure temp file was created
- if not tmp_file.exists():
- raise IOError(f"Failed to create temporary file: {tmp_file}")
-
- # Move atomically
+ # Atomically replace the checkpoint file
  tmp_file.replace(self.checkpoint_path)
-
  logger.debug(f"Saved checkpoint to {self.checkpoint_path}")
-
  except Exception as e:
- # logger.error(f"Error saving checkpoint: {e}", exc_info=True)
- # Try direct write as fallback
- try:
- with open(self.checkpoint_path, "w") as f:
- json.dump(data, f, indent=2)
- # logger.info("Saved checkpoint using fallback direct write")
- except Exception as fallback_error:
- logger.error(f"Fallback save also failed: {fallback_error}")
+ logger.error(f"Failed to save checkpoint atomically: {e}", exc_info=True)
+ # Try to clean up the temp file if it exists
+ if tmp_file.exists():
+ try:
+ tmp_file.unlink()
+ except:
+ pass

  def get_stats(self) -> Dict[str, Any]:
  """Get statistics about tracked items. Override for custom stats."""

caption_flow/utils/chunk_tracker.py
@@ -1,16 +1,19 @@
  """Chunk tracking using CheckpointTracker base class with memory optimization."""

- from collections import defaultdict
+ import datetime as _datetime
  import logging
- from pathlib import Path
- from typing import Set, Dict, List, Optional, Any, Tuple
+ import os
+ from collections import defaultdict
+ from dataclasses import asdict, dataclass, field
  from datetime import datetime, timedelta
- from dataclasses import dataclass, asdict, field
+ from pathlib import Path
+ from threading import Lock
+ from typing import Any, Dict, List, Optional, Set, Tuple

  from .checkpoint_tracker import CheckpointTracker

  logger = logging.getLogger(__name__)
- logger.setLevel(logging.DEBUG)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())


  @dataclass
@@ -31,8 +34,16 @@ class ChunkState:
  assigned_to: Optional[str] = None
  assigned_at: Optional[datetime] = None

+ # Cache for expensive range calculations
+ _cached_merged_ranges: Optional[List[Tuple[int, int]]] = field(default=None, init=False)
+ _cached_unprocessed_ranges: Optional[List[Tuple[int, int]]] = field(default=None, init=False)
+ _cache_invalidated: bool = field(default=True, init=False)
+
  def add_processed_range(self, start: int, end: int):
  """Add a processed range and merge if needed."""
+ # Invalidate cache before modifying ranges
+ self._invalidate_cache()
+
  # Add new range
  self.processed_ranges.append((start, end))

@@ -57,38 +68,98 @@ class ChunkState:

  def mark_completed(self):
  """Mark chunk as completed and clear unnecessary data to save memory."""
+ self._invalidate_cache()
  self.status = "completed"
- self.completed_at = datetime.utcnow()
+ self.completed_at = datetime.now(_datetime.UTC)
  # Clear processed_ranges since we don't need them after completion
- self.processed_ranges = []
- self.assigned_to = None
- self.assigned_at = None
+ # self.processed_ranges = []
+ # self.assigned_to = None
+ # self.assigned_at = None
+
+ def _invalidate_cache(self):
+ """Invalidate cached range calculations."""
+ self._cached_merged_ranges = None
+ self._cached_unprocessed_ranges = None
+ self._cache_invalidated = True
+
+ def _get_merged_ranges(self) -> List[Tuple[int, int]]:
+ """Get merged ranges with caching."""
+ if self._cached_merged_ranges is None:
+ self._cached_merged_ranges = self._merge_ranges(self.processed_ranges)
+ return self._cached_merged_ranges

  def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
- """Get ranges that haven't been processed yet."""
+ """Get ranges of unprocessed items within the chunk (relative indices)."""
  if self.status == "completed":
  return []

  if not self.processed_ranges:
- logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
+ if self._cache_invalidated:  # Only log once per invalidation
+ logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
+ self._cache_invalidated = False
  return [(0, self.chunk_size - 1)]

- unprocessed = []
- current = 0
+ # Use cached result if available
+ if self._cached_unprocessed_ranges is not None:
+ return self._cached_unprocessed_ranges

- logger.info(
- f"Processing {len(self.processed_ranges)} processed ranges for chunk {self.chunk_id}"
- )
- for start, end in self.processed_ranges:
- if current < start:
- unprocessed.append((current, start - 1))
- current = max(current, end + 1)
+ # Calculate and cache unprocessed ranges
+ merged_ranges = self._get_merged_ranges()

- if current < self.chunk_size:
- unprocessed.append((current, self.chunk_size - 1))
+ unprocessed = []
+ current_pos = 0
+
+ for start, end in merged_ranges:
+ if current_pos < start:
+ unprocessed.append((current_pos, start - 1))
+ current_pos = max(current_pos, end + 1)
+
+ # Add any remaining range
+ if current_pos < self.chunk_size:
+ unprocessed.append((current_pos, self.chunk_size - 1))
+
+ # Cache the result
+ self._cached_unprocessed_ranges = unprocessed
+
+ # Log for debugging (only when cache is being computed)
+ if self._cache_invalidated:
+ if not unprocessed:
+ logger.info(
+ f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
+ )
+ else:
+ logger.debug(f"Merged ranges for chunk {self.chunk_id}: {merged_ranges}")
+ total_processed = sum(end - start + 1 for start, end in merged_ranges)
+ total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
+ logger.debug(
+ f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
+ )
+ self._cache_invalidated = False

  return unprocessed

+ def _merge_ranges(self, ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
+ """Merge overlapping or adjacent ranges."""
+ if not ranges:
+ return []
+
+ # Sort ranges by start index, ensuring all are tuples
+ sorted_ranges = sorted([tuple(r) for r in ranges])
+ merged = [sorted_ranges[0]]
+
+ for current_start, current_end in sorted_ranges[1:]:
+ last_start, last_end = merged[-1]
+
+ # Check if ranges overlap or are adjacent
+ if current_start <= last_end + 1:
+ # Merge the ranges
+ merged[-1] = (last_start, max(last_end, current_end))
+ else:
+ # Add as new range
+ merged.append((current_start, current_end))
+
+ return merged
+
  def to_dict(self):
  """Convert to dictionary for JSON serialization."""
  d = asdict(self)
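
Note: _merge_ranges and the cached get_unprocessed_ranges are standard interval bookkeeping: sort the processed ranges, coalesce overlapping or adjacent ones, then walk the merged list to find the gaps. A standalone sketch of the same arithmetic (inclusive ranges, hypothetical chunk size of 10):

from typing import List, Tuple


def merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Coalesce overlapping or adjacent inclusive ranges."""
    merged: List[Tuple[int, int]] = []
    for start, end in sorted(ranges):
        if merged and start <= merged[-1][1] + 1:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def unprocessed_ranges(processed: List[Tuple[int, int]], size: int) -> List[Tuple[int, int]]:
    """Gaps not covered by the processed ranges, as inclusive (start, end) pairs."""
    gaps, cursor = [], 0
    for start, end in merge_ranges(processed):
        if cursor < start:
            gaps.append((cursor, start - 1))
        cursor = max(cursor, end + 1)
    if cursor < size:
        gaps.append((cursor, size - 1))
    return gaps


print(merge_ranges([(0, 2), (3, 5), (8, 9)]))             # [(0, 5), (8, 9)]
print(unprocessed_ranges([(0, 2), (3, 5), (8, 9)], 10))   # [(6, 7)]
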
@@ -108,6 +179,10 @@ class ChunkState:
  # Ensure processed_ranges exists
  d.setdefault("processed_ranges", [])
  d.setdefault("processed_count", 0)
+ # Remove cache fields from dict if they exist (shouldn't be serialized)
+ d.pop("_cached_merged_ranges", None)
+ d.pop("_cached_unprocessed_ranges", None)
+ d.pop("_cache_invalidated", None)
  return cls(**d)


@@ -119,11 +194,22 @@ class ChunkTracker(CheckpointTracker):
  checkpoint_file: Path,
  max_completed_chunks_in_memory: int = 1000,
  archive_after_hours: int = 24,
+ save_batch_size: int = 10,
+ auto_save_interval: int = 60,
  ):
  self.chunks: Dict[str, ChunkState] = {}
  self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
  self.archive_after_hours = archive_after_hours
  self._completed_count = 0  # Track count without storing all IDs
+ self.lock = Lock()
+
+ # Batching mechanism
+ self._dirty = False
+ self._pending_changes = 0
+ self._save_batch_size = save_batch_size
+ self._auto_save_interval = auto_save_interval
+ self._last_save = datetime.now(_datetime.UTC)
+
  super().__init__(checkpoint_file)

  def _get_default_state(self) -> Dict[str, Any]:
@@ -139,7 +225,8 @@ class ChunkTracker(CheckpointTracker):
  completed_chunks = 0
  for chunk_id, chunk_data in data.get("chunks", {}).items():
  chunk_state = ChunkState.from_dict(chunk_data)
- self.chunks[chunk_id] = chunk_state
+ with self.lock:
+ self.chunks[chunk_id] = chunk_state
  if chunk_state.status == "completed":
  completed_chunks += 1

@@ -156,12 +243,47 @@ class ChunkTracker(CheckpointTracker):
  "completed_count": self._completed_count,
  }

+ def _mark_dirty(self):
+ """Mark tracker as having pending changes."""
+ self._dirty = True
+ self._pending_changes += 1
+
+ # Auto-save based on batch size or time interval
+ now = datetime.now(_datetime.UTC)
+ time_since_last_save = (now - self._last_save).total_seconds()
+
+ if (
+ self._pending_changes >= self._save_batch_size
+ or time_since_last_save >= self._auto_save_interval
+ ):
+ self._do_save()
+
+ def _do_save(self) -> bool:
+ """Internal method to perform the actual save."""
+ super().save()  # Parent method returns None but triggers save
+ # Reset dirty state since save was initiated successfully
+ self._dirty = False
+ self._pending_changes = 0
+ self._last_save = datetime.now(_datetime.UTC)
+ return True
+
+ def save(self, force: bool = False) -> bool:
+ """Save state to checkpoint file, with batching optimization."""
+ if not force and not self._dirty:
+ return False
+ return self._do_save()
+
+ def flush(self):
+ """Force save any pending changes."""
+ if self._dirty:
+ self._do_save()
+
  def _archive_old_completed_chunks(self):
  """Remove old completed chunks from memory to prevent unbounded growth."""
  if not self.archive_after_hours:
  return

- cutoff_time = datetime.utcnow() - timedelta(hours=self.archive_after_hours)
+ cutoff_time = datetime.now(_datetime.UTC) - timedelta(hours=self.archive_after_hours)
  chunks_to_remove = []

  for chunk_id, chunk in self.chunks.items():
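
Note: _mark_dirty turns most per-chunk updates into cheap bookkeeping; an actual save only happens once save_batch_size changes accumulate or auto_save_interval seconds elapse, and flush() forces out whatever is pending. A minimal sketch of that batching policy, assuming Python 3.11+ for datetime.UTC (names are illustrative):

import datetime as _datetime
from datetime import datetime


class BatchedSaver:
    """Sketch: coalesce many small state changes into occasional checkpoint writes."""

    def __init__(self, save_batch_size: int = 10, auto_save_interval: int = 60):
        self._dirty = False
        self._pending = 0
        self._batch_size = save_batch_size
        self._interval = auto_save_interval
        self._last_save = datetime.now(_datetime.UTC)

    def mark_dirty(self) -> None:
        self._dirty = True
        self._pending += 1
        elapsed = (datetime.now(_datetime.UTC) - self._last_save).total_seconds()
        if self._pending >= self._batch_size or elapsed >= self._interval:
            self.flush()

    def flush(self) -> None:
        if not self._dirty:
            return
        self._write()  # whatever actually persists the state
        self._dirty = False
        self._pending = 0
        self._last_save = datetime.now(_datetime.UTC)

    def _write(self) -> None:
        print("checkpoint written")


saver = BatchedSaver(save_batch_size=3)
for _ in range(7):
    saver.mark_dirty()   # writes after the 3rd and 6th change
saver.flush()            # writes the remaining pending change
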
@@ -176,7 +298,7 @@ class ChunkTracker(CheckpointTracker):
  for chunk_id in chunks_to_remove:
  del self.chunks[chunk_id]
  logger.info(f"Archived {len(chunks_to_remove)} old completed chunks from memory")
- self.save()
+ self._mark_dirty()

  def _limit_completed_chunks_in_memory(self):
  """Keep only the most recent completed chunks in memory."""
@@ -194,7 +316,7 @@ class ChunkTracker(CheckpointTracker):
  del self.chunks[chunk_id]

  logger.info(f"Removed {to_remove} oldest completed chunks from memory")
- self.save()
+ self._mark_dirty()

  def add_chunk(
  self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
@@ -214,7 +336,7 @@ class ChunkTracker(CheckpointTracker):
  chunk_size=chunk_size,
  status="pending",
  )
- self.save()
+ self._mark_dirty()

  # Periodically clean up old chunks
  if len(self.chunks) % 100 == 0:
@@ -229,8 +351,8 @@ class ChunkTracker(CheckpointTracker):
  chunk = self.chunks[chunk_id]
  chunk.status = "assigned"
  chunk.assigned_to = worker_id
- chunk.assigned_at = datetime.utcnow()
- self.save()
+ chunk.assigned_at = datetime.now(_datetime.UTC)
+ self._mark_dirty()

  def mark_completed(self, chunk_id: str):
  """Mark chunk as completed."""
@@ -240,7 +362,7 @@ class ChunkTracker(CheckpointTracker):
  chunk.mark_completed()  # This clears processed_ranges
  if not was_completed:
  self._completed_count += 1
- self.save()
+ self._mark_dirty()
  logger.debug(f"Chunk {chunk_id} marked as completed")

  # Check if we need to clean up
@@ -254,7 +376,7 @@ class ChunkTracker(CheckpointTracker):
  chunk.status = "pending"  # Reset to pending for retry
  chunk.assigned_to = None
  chunk.assigned_at = None
- self.save()
+ self._mark_dirty()

  def mark_pending(self, chunk_id: str):
  """Mark chunk as pending (for manual reset)."""
@@ -265,7 +387,7 @@ class ChunkTracker(CheckpointTracker):
  chunk.status = "pending"
  chunk.assigned_to = None
  chunk.assigned_at = None
- self.save()
+ self._mark_dirty()

  def release_worker_chunks(self, worker_id: str):
  """Release all chunks assigned to a worker."""
@@ -276,7 +398,8 @@ class ChunkTracker(CheckpointTracker):
  chunk.assigned_to = None
  chunk.assigned_at = None
  released_chunks.append(chunk_id)
- self.save()
+ if released_chunks:
+ self._mark_dirty()
  return released_chunks

  def get_pending_chunks(self, shard_name: Optional[str] = None) -> List[str]:
@@ -330,7 +453,7 @@ class ChunkTracker(CheckpointTracker):
  """Get summary of all shards and their chunk status."""
  shards = {}

- for chunk_id, chunk_state in self.chunks.items():
+ for _chunk_id, chunk_state in self.chunks.items():
  shard_name = chunk_state.shard_name
  if shard_name not in shards:
  shards[shard_name] = {
@@ -340,9 +463,11 @@ class ChunkTracker(CheckpointTracker):
  "assigned_chunks": 0,
  "failed_chunks": 0,
  "is_complete": True,
+ "chunks": [],
  }

  shards[shard_name]["total_chunks"] += 1
+ shards[shard_name]["chunks"].append(chunk_state)

  if chunk_state.status == "completed":
  shards[shard_name]["completed_chunks"] += 1
@@ -361,7 +486,7 @@ class ChunkTracker(CheckpointTracker):
  def get_incomplete_shards(self) -> Set[str]:
  """Get set of shard names that have incomplete chunks."""
  incomplete = set()
- for chunk_id, chunk_state in self.chunks.items():
+ for _chunk_id, chunk_state in self.chunks.items():
  if chunk_state.status != "completed":
  incomplete.add(chunk_state.shard_name)
  return incomplete
@@ -373,22 +498,21 @@ class ChunkTracker(CheckpointTracker):
  if not storage_manager.captions_path.exists():
  return

- import pyarrow as pa
- import pyarrow.parquet as pq
+ import lance

  # Check if item_index column exists
- table_metadata = pq.read_metadata(storage_manager.captions_path)
+ table_metadata = lance.dataset(storage_manager.captions_path).schema
  columns = ["job_id", "chunk_id", "item_key"]
- if "item_index" in table_metadata.schema.names:
+ if "item_index" in table_metadata.names:
  columns.append("item_index")

  # Process in batches to avoid loading entire table
  batch_size = 10000
- parquet_file = pq.ParquetFile(storage_manager.captions_path)
+ lance_dataset = lance.dataset(storage_manager.captions_path)

  chunk_indices = defaultdict(set)

- for batch in parquet_file.iter_batches(batch_size=batch_size, columns=columns):
+ for batch in lance_dataset.to_batches(batch_size=batch_size, columns=columns):
  batch_dict = batch.to_pydict()

  for i in range(len(batch_dict["chunk_id"])):
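
Note: storage scanning now goes through Lance instead of Parquet; the dataset schema replaces pq.read_metadata, and to_batches(batch_size=..., columns=...) replaces ParquetFile.iter_batches. A rough sketch of the read pattern, assuming the pylance package and a throwaway dataset (lance.write_dataset and the column layout here are assumptions for the example, not the package's real schema):

import lance
import pyarrow as pa

# Toy stand-in for the captions dataset
table = pa.table(
    {
        "job_id": ["j1", "j1", "j2"],
        "chunk_id": ["c0", "c0", "c1"],
        "item_key": ["a", "b", "c"],
        "item_index": [0, 1, 57],
    }
)
lance.write_dataset(table, "/tmp/captions_example.lance", mode="overwrite")

ds = lance.dataset("/tmp/captions_example.lance")
columns = ["job_id", "chunk_id", "item_key"]
if "item_index" in ds.schema.names:  # same column check as the diff above
    columns.append("item_index")

for batch in ds.to_batches(batch_size=10000, columns=columns):
    rows = batch.to_pydict()
    print(rows["chunk_id"], rows.get("item_index"))
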
@@ -453,11 +577,12 @@ class ChunkTracker(CheckpointTracker):
  self._process_chunk_indices(chunk_indices)

  logger.info("Sync with storage completed")
- self.save()
+ self._mark_dirty()

  def _process_chunk_indices(self, chunk_indices: Dict[str, Set[int]]):
  """Process a batch of chunk indices."""
  for chunk_id, abs_indices in chunk_indices.items():
+ logger.debug(f"Processing indices: {abs_indices} for chunk {chunk_id}")
  if chunk_id not in self.chunks:
  continue

@@ -494,39 +619,49 @@ class ChunkTracker(CheckpointTracker):
  for start_idx, end_idx in ranges:
  chunk.add_processed_range(start_idx, end_idx)

- def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
- """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
+ def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int) -> None:
+ """Mark a range of items as processed within a chunk."""
  if chunk_id not in self.chunks:
- logger.error(f"Unknown chunk: {chunk_id}")
+ logger.warning(f"Chunk {chunk_id} not found in tracker")
  return

- chunk = self.chunks[chunk_id]
+ chunk_state = self.chunks[chunk_id]

- # Convert absolute indices to chunk-relative
- relative_start = start_idx - chunk.start_index
- relative_end = end_idx - chunk.start_index
+ # Convert absolute indices to chunk-relative indices
+ relative_start = start_idx - chunk_state.start_index
+ relative_end = end_idx - chunk_state.start_index

- # Validate boundaries
- if relative_start < 0 or relative_end >= chunk.chunk_size:
- logger.error(
- f"Invalid indices for chunk {chunk_id}: "
- f"absolute {start_idx}-{end_idx} (relative {relative_start}-{relative_end}) "
- f"outside chunk bounds [{chunk.start_index}, {chunk.start_index + chunk.chunk_size - 1}]"
+ # Ensure indices are within chunk bounds and maintain valid range
+ relative_start = max(0, relative_start)
+ relative_end = min(chunk_state.chunk_size - 1, relative_end)
+
+ # Skip invalid ranges where start > end
+ if relative_start > relative_end:
+ logger.warning(
+ f"Invalid range for chunk {chunk_id}: start={relative_start}, end={relative_end}, skipping"
  )
  return

- # Add the relative range
- chunk.add_processed_range(relative_start, relative_end)
+ # Invalidate cache before modifying ranges
+ chunk_state._invalidate_cache()

- # If chunk is now complete, increment counter
- if chunk.status == "completed":
- self._completed_count += 1
+ # Add to processed ranges
+ chunk_state.processed_ranges.append((relative_start, relative_end))

- self.save()
- logger.debug(
- f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} "
- f"(relative indices: {relative_start}-{relative_end})"
- )
+ # Merge overlapping ranges
+ chunk_state.processed_ranges = chunk_state._merge_ranges(chunk_state.processed_ranges)
+
+ # logger.debug(
+ #     f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
+ # )
+
+ # Check if chunk is now complete
+ if chunk_state.get_unprocessed_ranges() == []:
+ logger.info(f"Chunk {chunk_id} is now complete")
+ chunk_state.status = "completed"
+
+ # Mark as dirty, will be saved based on batching logic
+ self._mark_dirty()

  def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
  """Get chunk info with unprocessed item ranges."""