caption-flow 0.3.4-py3-none-any.whl → 0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +934 -415
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +2 -3
  5. caption_flow/orchestrator.py +153 -104
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +439 -67
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +28 -22
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +30 -28
  18. caption_flow/utils/chunk_tracker.py +153 -56
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +5 -4
  25. caption_flow/workers/caption.py +265 -90
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
  28. caption_flow-0.4.0.dist-info/RECORD +33 -0
  29. caption_flow-0.3.4.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/utils/__init__.py
@@ -1,4 +1,4 @@
  """Utility modules for CaptionFlow."""
 
- from .chunk_tracker import ChunkTracker
  from .caption_utils import CaptionUtils
+ from .chunk_tracker import ChunkTracker

caption_flow/utils/auth.py
@@ -1,7 +1,7 @@
  """Authentication management."""
 
- from typing import Dict, Any, Optional
  from dataclasses import dataclass
+ from typing import Any, Dict, Optional
 
 
  @dataclass

caption_flow/utils/caption_utils.py
@@ -1,6 +1,6 @@
  """Caption processing utilities from the original vLLM script."""
 
- from typing import List, Dict
+ from typing import Dict, List
 
 
  class CaptionUtils:

caption_flow/utils/certificates.py
@@ -1,13 +1,15 @@
  """SSL certificate management."""
 
+ import datetime as _datetime
  import subprocess
+ from datetime import datetime, timedelta
  from pathlib import Path
  from typing import Optional
+
  from cryptography import x509
- from cryptography.x509.oid import NameOID
  from cryptography.hazmat.primitives import hashes, serialization
  from cryptography.hazmat.primitives.asymmetric import rsa
- from datetime import datetime, timedelta
+ from cryptography.x509.oid import NameOID
 
 
  class CertificateManager:
@@ -35,8 +37,8 @@ class CertificateManager:
  .issuer_name(issuer)
  .public_key(key.public_key())
  .serial_number(x509.random_serial_number())
- .not_valid_before(datetime.utcnow())
- .not_valid_after(datetime.utcnow() + timedelta(days=365))
+ .not_valid_before(datetime.now(_datetime.UTC))
+ .not_valid_after(datetime.now(_datetime.UTC) + timedelta(days=365))
  .add_extension(
  x509.SubjectAlternativeName(
  [
@@ -71,14 +73,15 @@ class CertificateManager:
  def generate_letsencrypt(
  self, domain: str, email: str, output_dir: Optional[Path] = None, staging: bool = False
  ) -> tuple[Path, Path]:
- """
- Generate Let's Encrypt certificate.
+ """Generate Let's Encrypt certificate.
 
  Args:
+ ----
  domain: Domain name for certificate
  email: Email for Let's Encrypt account
  output_dir: Custom output directory (uses /etc/letsencrypt by default)
  staging: Use Let's Encrypt staging server for testing
+
  """
  cmd = [
  "certbot",
@@ -133,8 +136,12 @@ class CertificateManager:
  return {
  "subject": cert.subject.rfc4514_string(),
  "issuer": cert.issuer.rfc4514_string(),
- "not_before": cert.not_valid_before,
- "not_after": cert.not_valid_after,
+ "not_before": cert.not_valid_before_utc,
+ "not_after": cert.not_valid_after_utc,
  "serial_number": cert.serial_number,
  "is_self_signed": cert.issuer == cert.subject,
  }
+
+ def inspect_certificate(self, cert_path: Path) -> dict:
+ """Inspect a certificate (alias for get_cert_info for CLI compatibility)."""
+ return self.get_cert_info(cert_path)
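
Note: the certificate changes above replace naive datetime.utcnow() timestamps with timezone-aware ones and switch to cryptography's not_valid_before_utc / not_valid_after_utc accessors. A minimal illustrative sketch of the pattern (not part of the package; utcnow() is deprecated since Python 3.12, and datetime.UTC requires Python 3.11+):

    import datetime as _datetime
    from datetime import datetime, timedelta

    # Aware timestamps for certificate validity, mirroring the builder calls above
    not_before = datetime.now(_datetime.UTC)
    not_after = not_before + timedelta(days=365)
    assert not_before.tzinfo is not None  # unlike utcnow(), this is timezone-aware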

caption_flow/utils/checkpoint_tracker.py
@@ -1,14 +1,17 @@
  """Base class for checkpoint tracking with persistent state."""
 
+ import datetime as _datetime
  import json
  import logging
+ import os
  from abc import ABC, abstractmethod
- from pathlib import Path
- from typing import Dict, Any, Optional
- from datetime import datetime
  from concurrent.futures import ThreadPoolExecutor
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Dict, Optional
 
  logger = logging.getLogger(__name__)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
  class CheckpointTracker(ABC):
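
Note: several modules in this release (checkpoint_tracker, chunk_tracker, image_processor) now read their log level from the CAPTIONFLOW_LOG_LEVEL environment variable instead of hard-coding it. A small sketch of the pattern, mirroring the added lines:

    import logging
    import os

    logger = logging.getLogger(__name__)
    # Defaults to INFO when CAPTIONFLOW_LOG_LEVEL is unset; setLevel accepts level names
    logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())

Setting CAPTIONFLOW_LOG_LEVEL=DEBUG in the worker or orchestrator environment should restore the verbose chunk-tracker logging that was previously always on.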

@@ -53,35 +56,34 @@ class CheckpointTracker(ABC):
 
  def save(self) -> None:
  """Save checkpoint to disk atomically."""
- with self.lock:
- try:
- # Prepare data with metadata
+ try:
+ # If a save is already in progress, let it finish.
+ # This prevents race conditions if save() is called rapidly.
+ if hasattr(self, "_save_future") and self._save_future and not self._save_future.done():
+ logger.warning("Previous save still in progress, skipping this save")
+ return # don't save this time,
+ logger.info("Saving chunk tracker state...")
+ # Prepare data with metadata
+ with self.lock:
  data = self._serialize_state()
- data["updated_at"] = datetime.utcnow().isoformat()
-
- # Write atomically using temp file
- tmp_file = self.checkpoint_path.with_suffix(".tmp")
- # If a save is already in progress, let it finish.
- # This prevents race conditions if save() is called rapidly.
- if (
- hasattr(self, "_save_future")
- and self._save_future
- and not self._save_future.done()
- ):
- self._save_future.result() # Wait for the previous save to complete
-
- # Use an executor to run the save operation in a background thread.
- # This makes the save call non-blocking.
- with ThreadPoolExecutor(max_workers=1) as executor:
- data_to_save = data.copy()
- self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
- except Exception as e:
- logger.error(f"Failed to submit save task: {e}", exc_info=True)
+ data["updated_at"] = datetime.now(_datetime.UTC).isoformat()
+
+ # Write atomically using temp file
+ tmp_file = self.checkpoint_path.with_suffix(".tmp")
+
+ # Use an executor to run the save operation in a background thread.
+ # This makes the save call non-blocking.
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ data_to_save = data.copy()
+ self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
+ except Exception as e:
+ logger.error(f"Failed to submit save task: {e}", exc_info=True)
 
- def _write_to_disk(self, data: Dict[str, Any]) -> None:
+ def _write_to_disk(self, data: Dict[str, Any], checkpoint_path: Optional[str] = None) -> None:
  """Write checkpoint data to disk atomically."""
  # Create a temporary file in the same directory as the checkpoint
- tmp_file = self.checkpoint_path.with_suffix(".tmp")
+ tmp_file = (checkpoint_path or self.checkpoint_path).with_suffix(".tmp")
+ logger.debug(f"Checkpoint {tmp_file=}")
 
  try:
  # Ensure the parent directory exists
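
Note: save() is now non-blocking — it skips the call when a previous background save is still running and hands the actual write to a single-worker executor. The write itself goes to a temporary file that replaces the checkpoint in one step. A hedged sketch of that atomic-write idea (helper name and JSON layout are assumptions, not the package's exact code):

    import json
    from pathlib import Path

    def write_checkpoint_atomically(data: dict, checkpoint_path: Path) -> None:
        # Write to a sibling .tmp file, then rename over the real checkpoint;
        # the rename is atomic on POSIX filesystems, so readers never see a partial file.
        tmp_file = checkpoint_path.with_suffix(".tmp")
        tmp_file.parent.mkdir(parents=True, exist_ok=True)
        tmp_file.write_text(json.dumps(data))
        tmp_file.replace(checkpoint_path)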

caption_flow/utils/chunk_tracker.py
@@ -1,17 +1,19 @@
  """Chunk tracking using CheckpointTracker base class with memory optimization."""
 
- from collections import defaultdict
+ import datetime as _datetime
  import logging
- from pathlib import Path
- from typing import Set, Dict, List, Optional, Any, Tuple
+ import os
+ from collections import defaultdict
+ from dataclasses import asdict, dataclass, field
  from datetime import datetime, timedelta
- from dataclasses import dataclass, asdict, field
+ from pathlib import Path
+ from threading import Lock
+ from typing import Any, Dict, List, Optional, Set, Tuple
 
  from .checkpoint_tracker import CheckpointTracker
- from threading import Lock
 
  logger = logging.getLogger(__name__)
- logger.setLevel(logging.DEBUG)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
  @dataclass
@@ -32,8 +34,16 @@ class ChunkState:
  assigned_to: Optional[str] = None
  assigned_at: Optional[datetime] = None
 
+ # Cache for expensive range calculations
+ _cached_merged_ranges: Optional[List[Tuple[int, int]]] = field(default=None, init=False)
+ _cached_unprocessed_ranges: Optional[List[Tuple[int, int]]] = field(default=None, init=False)
+ _cache_invalidated: bool = field(default=True, init=False)
+
  def add_processed_range(self, start: int, end: int):
  """Add a processed range and merge if needed."""
+ # Invalidate cache before modifying ranges
+ self._invalidate_cache()
+
  # Add new range
  self.processed_ranges.append((start, end))
 
@@ -58,24 +68,43 @@
 
  def mark_completed(self):
  """Mark chunk as completed and clear unnecessary data to save memory."""
+ self._invalidate_cache()
  self.status = "completed"
- self.completed_at = datetime.utcnow()
+ self.completed_at = datetime.now(_datetime.UTC)
  # Clear processed_ranges since we don't need them after completion
  # self.processed_ranges = []
  # self.assigned_to = None
  # self.assigned_at = None
 
+ def _invalidate_cache(self):
+ """Invalidate cached range calculations."""
+ self._cached_merged_ranges = None
+ self._cached_unprocessed_ranges = None
+ self._cache_invalidated = True
+
+ def _get_merged_ranges(self) -> List[Tuple[int, int]]:
+ """Get merged ranges with caching."""
+ if self._cached_merged_ranges is None:
+ self._cached_merged_ranges = self._merge_ranges(self.processed_ranges)
+ return self._cached_merged_ranges
+
  def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
  """Get ranges of unprocessed items within the chunk (relative indices)."""
  if self.status == "completed":
  return []
 
  if not self.processed_ranges:
- logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
+ if self._cache_invalidated: # Only log once per invalidation
+ logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
+ self._cache_invalidated = False
  return [(0, self.chunk_size - 1)]
 
- # Merge ranges first to ensure no overlaps
- merged_ranges = self._merge_ranges(self.processed_ranges)
+ # Use cached result if available
+ if self._cached_unprocessed_ranges is not None:
+ return self._cached_unprocessed_ranges
+
+ # Calculate and cache unprocessed ranges
+ merged_ranges = self._get_merged_ranges()
 
  unprocessed = []
  current_pos = 0
@@ -89,17 +118,23 @@ class ChunkState:
  if current_pos < self.chunk_size:
  unprocessed.append((current_pos, self.chunk_size - 1))
 
- # Log for debugging
- if not unprocessed:
- logger.info(
- f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
- )
- else:
- total_processed = sum(end - start + 1 for start, end in merged_ranges)
- total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
- logger.debug(
- f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
- )
+ # Cache the result
+ self._cached_unprocessed_ranges = unprocessed
+
+ # Log for debugging (only when cache is being computed)
+ if self._cache_invalidated:
+ if not unprocessed:
+ logger.info(
+ f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
+ )
+ else:
+ logger.debug(f"Merged ranges for chunk {self.chunk_id}: {merged_ranges}")
+ total_processed = sum(end - start + 1 for start, end in merged_ranges)
+ total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
+ logger.debug(
+ f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
+ )
+ self._cache_invalidated = False
 
  return unprocessed
 
@@ -144,6 +179,10 @@ class ChunkState:
  # Ensure processed_ranges exists
  d.setdefault("processed_ranges", [])
  d.setdefault("processed_count", 0)
+ # Remove cache fields from dict if they exist (shouldn't be serialized)
+ d.pop("_cached_merged_ranges", None)
+ d.pop("_cached_unprocessed_ranges", None)
+ d.pop("_cache_invalidated", None)
  return cls(**d)
 
 
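Note: the new ChunkState fields cache the merged processed ranges and the derived unprocessed ranges so get_unprocessed_ranges() no longer recomputes them on every call. A standalone sketch of the underlying range arithmetic (function names here are illustrative, not the package's API):

    from typing import List, Tuple

    def merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
        # Sort, then fold adjacent or overlapping (start, end) pairs together
        merged: List[Tuple[int, int]] = []
        for start, end in sorted(ranges):
            if merged and start <= merged[-1][1] + 1:
                merged[-1] = (merged[-1][0], max(merged[-1][1], end))
            else:
                merged.append((start, end))
        return merged

    def unprocessed_ranges(processed: List[Tuple[int, int]], chunk_size: int) -> List[Tuple[int, int]]:
        # The gaps between merged processed ranges, relative to the chunk
        gaps, pos = [], 0
        for start, end in merge_ranges(processed):
            if pos < start:
                gaps.append((pos, start - 1))
            pos = max(pos, end + 1)
        if pos < chunk_size:
            gaps.append((pos, chunk_size - 1))
        return gaps

    # unprocessed_ranges([(0, 4), (10, 14)], 20) -> [(5, 9), (15, 19)]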

@@ -155,12 +194,22 @@ class ChunkTracker(CheckpointTracker):
  checkpoint_file: Path,
  max_completed_chunks_in_memory: int = 1000,
  archive_after_hours: int = 24,
+ save_batch_size: int = 10,
+ auto_save_interval: int = 60,
  ):
  self.chunks: Dict[str, ChunkState] = {}
  self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
  self.archive_after_hours = archive_after_hours
  self._completed_count = 0 # Track count without storing all IDs
  self.lock = Lock()
+
+ # Batching mechanism
+ self._dirty = False
+ self._pending_changes = 0
+ self._save_batch_size = save_batch_size
+ self._auto_save_interval = auto_save_interval
+ self._last_save = datetime.now(_datetime.UTC)
+
  super().__init__(checkpoint_file)
 
  def _get_default_state(self) -> Dict[str, Any]:
@@ -169,17 +218,17 @@ class ChunkTracker(CheckpointTracker):
 
  def _deserialize_state(self, data: Dict[str, Any]) -> None:
  """Deserialize loaded data into instance state."""
- with self.lock:
- self.chunks = {}
- self._completed_count = data.get("completed_count", 0)
-
- # Load chunk states
- completed_chunks = 0
- for chunk_id, chunk_data in data.get("chunks", {}).items():
- chunk_state = ChunkState.from_dict(chunk_data)
+ self.chunks = {}
+ self._completed_count = data.get("completed_count", 0)
+
+ # Load chunk states
+ completed_chunks = 0
+ for chunk_id, chunk_data in data.get("chunks", {}).items():
+ chunk_state = ChunkState.from_dict(chunk_data)
+ with self.lock:
  self.chunks[chunk_id] = chunk_state
- if chunk_state.status == "completed":
- completed_chunks += 1
+ if chunk_state.status == "completed":
+ completed_chunks += 1
 
  logger.info(
  f"Loaded {len(self.chunks)} chunks from checkpoint, "
@@ -194,12 +243,47 @@ class ChunkTracker(CheckpointTracker):
  "completed_count": self._completed_count,
  }
 
+ def _mark_dirty(self):
+ """Mark tracker as having pending changes."""
+ self._dirty = True
+ self._pending_changes += 1
+
+ # Auto-save based on batch size or time interval
+ now = datetime.now(_datetime.UTC)
+ time_since_last_save = (now - self._last_save).total_seconds()
+
+ if (
+ self._pending_changes >= self._save_batch_size
+ or time_since_last_save >= self._auto_save_interval
+ ):
+ self._do_save()
+
+ def _do_save(self) -> bool:
+ """Internal method to perform the actual save."""
+ super().save() # Parent method returns None but triggers save
+ # Reset dirty state since save was initiated successfully
+ self._dirty = False
+ self._pending_changes = 0
+ self._last_save = datetime.now(_datetime.UTC)
+ return True
+
+ def save(self, force: bool = False) -> bool:
+ """Save state to checkpoint file, with batching optimization."""
+ if not force and not self._dirty:
+ return False
+ return self._do_save()
+
+ def flush(self):
+ """Force save any pending changes."""
+ if self._dirty:
+ self._do_save()
+
  def _archive_old_completed_chunks(self):
  """Remove old completed chunks from memory to prevent unbounded growth."""
  if not self.archive_after_hours:
  return
 
- cutoff_time = datetime.utcnow() - timedelta(hours=self.archive_after_hours)
+ cutoff_time = datetime.now(_datetime.UTC) - timedelta(hours=self.archive_after_hours)
  chunks_to_remove = []
 
  for chunk_id, chunk in self.chunks.items():
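
Note: ChunkTracker now batches checkpoint writes. State mutations call _mark_dirty() instead of save(); a write is only triggered once save_batch_size changes have accumulated or auto_save_interval seconds have passed, save(force=True) bypasses the check, and flush() writes out anything still pending. A hedged usage sketch (paths and chunk identifiers are made up):

    from pathlib import Path

    from caption_flow.utils.chunk_tracker import ChunkTracker

    tracker = ChunkTracker(
        Path("./checkpoints/chunks.json"),
        save_batch_size=10,       # write after this many pending changes
        auto_save_interval=60,    # ...or after this many seconds
    )
    tracker.add_chunk("shard0:0", "shard0", "https://example.org/shard0.tar", 0, 1000)
    # ...more mutations accumulate without touching disk every time...
    tracker.flush()  # force any pending changes out, e.g. on shutdown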

@@ -214,7 +298,7 @@ class ChunkTracker(CheckpointTracker):
  for chunk_id in chunks_to_remove:
  del self.chunks[chunk_id]
  logger.info(f"Archived {len(chunks_to_remove)} old completed chunks from memory")
- self.save()
+ self._mark_dirty()
 
  def _limit_completed_chunks_in_memory(self):
  """Keep only the most recent completed chunks in memory."""
@@ -232,7 +316,7 @@ class ChunkTracker(CheckpointTracker):
  del self.chunks[chunk_id]
 
  logger.info(f"Removed {to_remove} oldest completed chunks from memory")
- self.save()
+ self._mark_dirty()
 
  def add_chunk(
  self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
@@ -252,7 +336,7 @@ class ChunkTracker(CheckpointTracker):
  chunk_size=chunk_size,
  status="pending",
  )
- self.save()
+ self._mark_dirty()
 
  # Periodically clean up old chunks
  if len(self.chunks) % 100 == 0:
@@ -267,8 +351,8 @@ class ChunkTracker(CheckpointTracker):
  chunk = self.chunks[chunk_id]
  chunk.status = "assigned"
  chunk.assigned_to = worker_id
- chunk.assigned_at = datetime.utcnow()
- self.save()
+ chunk.assigned_at = datetime.now(_datetime.UTC)
+ self._mark_dirty()
 
  def mark_completed(self, chunk_id: str):
  """Mark chunk as completed."""
@@ -278,7 +362,7 @@ class ChunkTracker(CheckpointTracker):
  chunk.mark_completed() # This clears processed_ranges
  if not was_completed:
  self._completed_count += 1
- self.save()
+ self._mark_dirty()
  logger.debug(f"Chunk {chunk_id} marked as completed")
 
  # Check if we need to clean up
@@ -292,7 +376,7 @@ class ChunkTracker(CheckpointTracker):
  chunk.status = "pending" # Reset to pending for retry
  chunk.assigned_to = None
  chunk.assigned_at = None
- self.save()
+ self._mark_dirty()
 
  def mark_pending(self, chunk_id: str):
  """Mark chunk as pending (for manual reset)."""
@@ -303,7 +387,7 @@ class ChunkTracker(CheckpointTracker):
  chunk.status = "pending"
  chunk.assigned_to = None
  chunk.assigned_at = None
- self.save()
+ self._mark_dirty()
 
  def release_worker_chunks(self, worker_id: str):
  """Release all chunks assigned to a worker."""
@@ -314,7 +398,8 @@ class ChunkTracker(CheckpointTracker):
  chunk.assigned_to = None
  chunk.assigned_at = None
  released_chunks.append(chunk_id)
- self.save()
+ if released_chunks:
+ self._mark_dirty()
  return released_chunks
 
  def get_pending_chunks(self, shard_name: Optional[str] = None) -> List[str]:
@@ -368,7 +453,7 @@ class ChunkTracker(CheckpointTracker):
  """Get summary of all shards and their chunk status."""
  shards = {}
 
- for chunk_id, chunk_state in self.chunks.items():
+ for _chunk_id, chunk_state in self.chunks.items():
  shard_name = chunk_state.shard_name
  if shard_name not in shards:
  shards[shard_name] = {
@@ -378,9 +463,11 @@ class ChunkTracker(CheckpointTracker):
  "assigned_chunks": 0,
  "failed_chunks": 0,
  "is_complete": True,
+ "chunks": [],
  }
 
  shards[shard_name]["total_chunks"] += 1
+ shards[shard_name]["chunks"].append(chunk_state)
 
  if chunk_state.status == "completed":
  shards[shard_name]["completed_chunks"] += 1
@@ -399,7 +486,7 @@ class ChunkTracker(CheckpointTracker):
  def get_incomplete_shards(self) -> Set[str]:
  """Get set of shard names that have incomplete chunks."""
  incomplete = set()
- for chunk_id, chunk_state in self.chunks.items():
+ for _chunk_id, chunk_state in self.chunks.items():
  if chunk_state.status != "completed":
  incomplete.add(chunk_state.shard_name)
  return incomplete
@@ -411,22 +498,21 @@ class ChunkTracker(CheckpointTracker):
  if not storage_manager.captions_path.exists():
  return
 
- import pyarrow as pa
- import pyarrow.parquet as pq
+ import lance
 
  # Check if item_index column exists
- table_metadata = pq.read_metadata(storage_manager.captions_path)
+ table_metadata = lance.dataset(storage_manager.captions_path).schema
  columns = ["job_id", "chunk_id", "item_key"]
- if "item_index" in table_metadata.schema.names:
+ if "item_index" in table_metadata.names:
  columns.append("item_index")
 
  # Process in batches to avoid loading entire table
  batch_size = 10000
- parquet_file = pq.ParquetFile(storage_manager.captions_path)
+ lance_dataset = lance.dataset(storage_manager.captions_path)
 
  chunk_indices = defaultdict(set)
 
- for batch in parquet_file.iter_batches(batch_size=batch_size, columns=columns):
+ for batch in lance_dataset.to_batches(batch_size=batch_size, columns=columns):
  batch_dict = batch.to_pydict()
 
  for i in range(len(batch_dict["chunk_id"])):
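
Note: the sync-with-storage path now reads from a Lance dataset instead of a Parquet file (the storage manager was reworked in this release). A sketch of the batched read pattern mirrored from the lines above (the dataset path is made up):

    import lance

    ds = lance.dataset("./caption_data/captions.lance")
    columns = ["job_id", "chunk_id", "item_key"]
    if "item_index" in ds.schema.names:  # only present in newer datasets
        columns.append("item_index")

    for batch in ds.to_batches(batch_size=10_000, columns=columns):
        rows = batch.to_pydict()  # pyarrow.RecordBatch -> dict of column lists
        # group item indices by chunk_id here, as _process_chunk_indices expects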

@@ -491,11 +577,12 @@ class ChunkTracker(CheckpointTracker):
  self._process_chunk_indices(chunk_indices)
 
  logger.info("Sync with storage completed")
- self.save()
+ self._mark_dirty()
 
  def _process_chunk_indices(self, chunk_indices: Dict[str, Set[int]]):
  """Process a batch of chunk indices."""
  for chunk_id, abs_indices in chunk_indices.items():
+ logger.debug(f"Processing indices: {abs_indices} for chunk {chunk_id}")
  if chunk_id not in self.chunks:
  continue
 
@@ -544,27 +631,37 @@ class ChunkTracker(CheckpointTracker):
  relative_start = start_idx - chunk_state.start_index
  relative_end = end_idx - chunk_state.start_index
 
- # Ensure indices are within chunk bounds
+ # Ensure indices are within chunk bounds and maintain valid range
  relative_start = max(0, relative_start)
  relative_end = min(chunk_state.chunk_size - 1, relative_end)
 
+ # Skip invalid ranges where start > end
+ if relative_start > relative_end:
+ logger.warning(
+ f"Invalid range for chunk {chunk_id}: start={relative_start}, end={relative_end}, skipping"
+ )
+ return
+
+ # Invalidate cache before modifying ranges
+ chunk_state._invalidate_cache()
+
  # Add to processed ranges
  chunk_state.processed_ranges.append((relative_start, relative_end))
 
  # Merge overlapping ranges
  chunk_state.processed_ranges = chunk_state._merge_ranges(chunk_state.processed_ranges)
 
- logger.debug(
- f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
- )
+ # logger.debug(
+ # f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
+ # )
 
  # Check if chunk is now complete
  if chunk_state.get_unprocessed_ranges() == []:
  logger.info(f"Chunk {chunk_id} is now complete")
  chunk_state.status = "completed"
 
- # Save checkpoint after updating
- self.save()
+ # Mark as dirty, will be saved based on batching logic
+ self._mark_dirty()
 
  def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
  """Get chunk info with unprocessed item ranges."""

caption_flow/utils/image_processor.py
@@ -1,19 +1,16 @@
  """Image preprocessing utilities."""
 
- import asyncio
  import logging
+ import os
  from concurrent.futures import ProcessPoolExecutor
  from io import BytesIO
- from pathlib import Path
- from typing import List, Any, Optional, Tuple, Union
 
- import numpy as np
- import requests
  from PIL import Image
- from ..models import ProcessingItem
 
+ from ..models import ProcessingItem
 
  logger = logging.getLogger(__name__)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
  class ImageProcessor:
@@ -24,22 +21,25 @@ class ImageProcessor:
 
  @staticmethod
  def prepare_for_inference(item: ProcessingItem) -> Image.Image:
- """
- Prepare image for inference.
+ """Prepare image for inference.
 
  Args:
+ ----
  image: PIL Image to prepare
 
  Returns:
+ -------
  Prepared PIL Image
+
  """
  # We used to do a lot more hand-holding here with transparency, but oh well.
+ logger.debug(f"Preparing item for inference: {item}")
 
  if item.image is not None:
  image = item.image
  item.metadata["image_width"], item.metadata["image_height"] = image.size
  item.metadata["image_format"] = image.format or "unknown"
- item.image = None
+ # item.image = None
  return image
 
  item.image = None