caption-flow 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +921 -427
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +2 -3
- caption_flow/orchestrator.py +153 -104
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +463 -68
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +28 -22
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +30 -28
- caption_flow/utils/chunk_tracker.py +153 -56
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +5 -4
- caption_flow/workers/caption.py +303 -92
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA +9 -4
- caption_flow-0.4.1.dist-info/RECORD +33 -0
- caption_flow-0.3.4.dist-info/RECORD +0 -33
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/top_level.txt +0 -0
caption_flow/utils/__init__.py
CHANGED
caption_flow/utils/auth.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1
1
|
"""SSL certificate management."""
|
2
2
|
|
3
|
+
import datetime as _datetime
|
3
4
|
import subprocess
|
5
|
+
from datetime import datetime, timedelta
|
4
6
|
from pathlib import Path
|
5
7
|
from typing import Optional
|
8
|
+
|
6
9
|
from cryptography import x509
|
7
|
-
from cryptography.x509.oid import NameOID
|
8
10
|
from cryptography.hazmat.primitives import hashes, serialization
|
9
11
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
10
|
-
from
|
12
|
+
from cryptography.x509.oid import NameOID
|
11
13
|
|
12
14
|
|
13
15
|
class CertificateManager:
|
@@ -35,8 +37,8 @@ class CertificateManager:
|
|
35
37
|
.issuer_name(issuer)
|
36
38
|
.public_key(key.public_key())
|
37
39
|
.serial_number(x509.random_serial_number())
|
38
|
-
.not_valid_before(datetime.
|
39
|
-
.not_valid_after(datetime.
|
40
|
+
.not_valid_before(datetime.now(_datetime.UTC))
|
41
|
+
.not_valid_after(datetime.now(_datetime.UTC) + timedelta(days=365))
|
40
42
|
.add_extension(
|
41
43
|
x509.SubjectAlternativeName(
|
42
44
|
[
|
@@ -71,14 +73,15 @@ class CertificateManager:
|
|
71
73
|
def generate_letsencrypt(
|
72
74
|
self, domain: str, email: str, output_dir: Optional[Path] = None, staging: bool = False
|
73
75
|
) -> tuple[Path, Path]:
|
74
|
-
"""
|
75
|
-
Generate Let's Encrypt certificate.
|
76
|
+
"""Generate Let's Encrypt certificate.
|
76
77
|
|
77
78
|
Args:
|
79
|
+
----
|
78
80
|
domain: Domain name for certificate
|
79
81
|
email: Email for Let's Encrypt account
|
80
82
|
output_dir: Custom output directory (uses /etc/letsencrypt by default)
|
81
83
|
staging: Use Let's Encrypt staging server for testing
|
84
|
+
|
82
85
|
"""
|
83
86
|
cmd = [
|
84
87
|
"certbot",
|
@@ -133,8 +136,12 @@ class CertificateManager:
|
|
133
136
|
return {
|
134
137
|
"subject": cert.subject.rfc4514_string(),
|
135
138
|
"issuer": cert.issuer.rfc4514_string(),
|
136
|
-
"not_before": cert.
|
137
|
-
"not_after": cert.
|
139
|
+
"not_before": cert.not_valid_before_utc,
|
140
|
+
"not_after": cert.not_valid_after_utc,
|
138
141
|
"serial_number": cert.serial_number,
|
139
142
|
"is_self_signed": cert.issuer == cert.subject,
|
140
143
|
}
|
144
|
+
|
145
|
+
def inspect_certificate(self, cert_path: Path) -> dict:
|
146
|
+
"""Inspect a certificate (alias for get_cert_info for CLI compatibility)."""
|
147
|
+
return self.get_cert_info(cert_path)
|
@@ -1,14 +1,17 @@
|
|
1
1
|
"""Base class for checkpoint tracking with persistent state."""
|
2
2
|
|
3
|
+
import datetime as _datetime
|
3
4
|
import json
|
4
5
|
import logging
|
6
|
+
import os
|
5
7
|
from abc import ABC, abstractmethod
|
6
|
-
from pathlib import Path
|
7
|
-
from typing import Dict, Any, Optional
|
8
|
-
from datetime import datetime
|
9
8
|
from concurrent.futures import ThreadPoolExecutor
|
9
|
+
from datetime import datetime
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Any, Dict, Optional
|
10
12
|
|
11
13
|
logger = logging.getLogger(__name__)
|
14
|
+
logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
|
12
15
|
|
13
16
|
|
14
17
|
class CheckpointTracker(ABC):
|
@@ -53,35 +56,34 @@ class CheckpointTracker(ABC):
|
|
53
56
|
|
54
57
|
def save(self) -> None:
|
55
58
|
"""Save checkpoint to disk atomically."""
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
+
try:
|
60
|
+
# If a save is already in progress, let it finish.
|
61
|
+
# This prevents race conditions if save() is called rapidly.
|
62
|
+
if hasattr(self, "_save_future") and self._save_future and not self._save_future.done():
|
63
|
+
logger.warning("Previous save still in progress, skipping this save")
|
64
|
+
return # don't save this time,
|
65
|
+
logger.info("Saving chunk tracker state...")
|
66
|
+
# Prepare data with metadata
|
67
|
+
with self.lock:
|
59
68
|
data = self._serialize_state()
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
# Use an executor to run the save operation in a background thread.
|
74
|
-
# This makes the save call non-blocking.
|
75
|
-
with ThreadPoolExecutor(max_workers=1) as executor:
|
76
|
-
data_to_save = data.copy()
|
77
|
-
self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
|
78
|
-
except Exception as e:
|
79
|
-
logger.error(f"Failed to submit save task: {e}", exc_info=True)
|
69
|
+
data["updated_at"] = datetime.now(_datetime.UTC).isoformat()
|
70
|
+
|
71
|
+
# Write atomically using temp file
|
72
|
+
tmp_file = self.checkpoint_path.with_suffix(".tmp")
|
73
|
+
|
74
|
+
# Use an executor to run the save operation in a background thread.
|
75
|
+
# This makes the save call non-blocking.
|
76
|
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
77
|
+
data_to_save = data.copy()
|
78
|
+
self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
|
79
|
+
except Exception as e:
|
80
|
+
logger.error(f"Failed to submit save task: {e}", exc_info=True)
|
80
81
|
|
81
|
-
def _write_to_disk(self, data: Dict[str, Any]) -> None:
|
82
|
+
def _write_to_disk(self, data: Dict[str, Any], checkpoint_path: Optional[str] = None) -> None:
|
82
83
|
"""Write checkpoint data to disk atomically."""
|
83
84
|
# Create a temporary file in the same directory as the checkpoint
|
84
|
-
tmp_file = self.checkpoint_path.with_suffix(".tmp")
|
85
|
+
tmp_file = (checkpoint_path or self.checkpoint_path).with_suffix(".tmp")
|
86
|
+
logger.debug(f"Checkpoint {tmp_file=}")
|
85
87
|
|
86
88
|
try:
|
87
89
|
# Ensure the parent directory exists
|
@@ -1,17 +1,19 @@
|
|
1
1
|
"""Chunk tracking using CheckpointTracker base class with memory optimization."""
|
2
2
|
|
3
|
-
|
3
|
+
import datetime as _datetime
|
4
4
|
import logging
|
5
|
-
|
6
|
-
from
|
5
|
+
import os
|
6
|
+
from collections import defaultdict
|
7
|
+
from dataclasses import asdict, dataclass, field
|
7
8
|
from datetime import datetime, timedelta
|
8
|
-
from
|
9
|
+
from pathlib import Path
|
10
|
+
from threading import Lock
|
11
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
9
12
|
|
10
13
|
from .checkpoint_tracker import CheckpointTracker
|
11
|
-
from threading import Lock
|
12
14
|
|
13
15
|
logger = logging.getLogger(__name__)
|
14
|
-
logger.setLevel(
|
16
|
+
logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
|
15
17
|
|
16
18
|
|
17
19
|
@dataclass
|
@@ -32,8 +34,16 @@ class ChunkState:
|
|
32
34
|
assigned_to: Optional[str] = None
|
33
35
|
assigned_at: Optional[datetime] = None
|
34
36
|
|
37
|
+
# Cache for expensive range calculations
|
38
|
+
_cached_merged_ranges: Optional[List[Tuple[int, int]]] = field(default=None, init=False)
|
39
|
+
_cached_unprocessed_ranges: Optional[List[Tuple[int, int]]] = field(default=None, init=False)
|
40
|
+
_cache_invalidated: bool = field(default=True, init=False)
|
41
|
+
|
35
42
|
def add_processed_range(self, start: int, end: int):
|
36
43
|
"""Add a processed range and merge if needed."""
|
44
|
+
# Invalidate cache before modifying ranges
|
45
|
+
self._invalidate_cache()
|
46
|
+
|
37
47
|
# Add new range
|
38
48
|
self.processed_ranges.append((start, end))
|
39
49
|
|
@@ -58,24 +68,43 @@ class ChunkState:
|
|
58
68
|
|
59
69
|
def mark_completed(self):
|
60
70
|
"""Mark chunk as completed and clear unnecessary data to save memory."""
|
71
|
+
self._invalidate_cache()
|
61
72
|
self.status = "completed"
|
62
|
-
self.completed_at = datetime.
|
73
|
+
self.completed_at = datetime.now(_datetime.UTC)
|
63
74
|
# Clear processed_ranges since we don't need them after completion
|
64
75
|
# self.processed_ranges = []
|
65
76
|
# self.assigned_to = None
|
66
77
|
# self.assigned_at = None
|
67
78
|
|
79
|
+
def _invalidate_cache(self):
|
80
|
+
"""Invalidate cached range calculations."""
|
81
|
+
self._cached_merged_ranges = None
|
82
|
+
self._cached_unprocessed_ranges = None
|
83
|
+
self._cache_invalidated = True
|
84
|
+
|
85
|
+
def _get_merged_ranges(self) -> List[Tuple[int, int]]:
|
86
|
+
"""Get merged ranges with caching."""
|
87
|
+
if self._cached_merged_ranges is None:
|
88
|
+
self._cached_merged_ranges = self._merge_ranges(self.processed_ranges)
|
89
|
+
return self._cached_merged_ranges
|
90
|
+
|
68
91
|
def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
|
69
92
|
"""Get ranges of unprocessed items within the chunk (relative indices)."""
|
70
93
|
if self.status == "completed":
|
71
94
|
return []
|
72
95
|
|
73
96
|
if not self.processed_ranges:
|
74
|
-
|
97
|
+
if self._cache_invalidated: # Only log once per invalidation
|
98
|
+
logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
|
99
|
+
self._cache_invalidated = False
|
75
100
|
return [(0, self.chunk_size - 1)]
|
76
101
|
|
77
|
-
#
|
78
|
-
|
102
|
+
# Use cached result if available
|
103
|
+
if self._cached_unprocessed_ranges is not None:
|
104
|
+
return self._cached_unprocessed_ranges
|
105
|
+
|
106
|
+
# Calculate and cache unprocessed ranges
|
107
|
+
merged_ranges = self._get_merged_ranges()
|
79
108
|
|
80
109
|
unprocessed = []
|
81
110
|
current_pos = 0
|
@@ -89,17 +118,23 @@ class ChunkState:
|
|
89
118
|
if current_pos < self.chunk_size:
|
90
119
|
unprocessed.append((current_pos, self.chunk_size - 1))
|
91
120
|
|
92
|
-
#
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
121
|
+
# Cache the result
|
122
|
+
self._cached_unprocessed_ranges = unprocessed
|
123
|
+
|
124
|
+
# Log for debugging (only when cache is being computed)
|
125
|
+
if self._cache_invalidated:
|
126
|
+
if not unprocessed:
|
127
|
+
logger.info(
|
128
|
+
f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
|
129
|
+
)
|
130
|
+
else:
|
131
|
+
logger.debug(f"Merged ranges for chunk {self.chunk_id}: {merged_ranges}")
|
132
|
+
total_processed = sum(end - start + 1 for start, end in merged_ranges)
|
133
|
+
total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
|
134
|
+
logger.debug(
|
135
|
+
f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
|
136
|
+
)
|
137
|
+
self._cache_invalidated = False
|
103
138
|
|
104
139
|
return unprocessed
|
105
140
|
|
@@ -144,6 +179,10 @@ class ChunkState:
|
|
144
179
|
# Ensure processed_ranges exists
|
145
180
|
d.setdefault("processed_ranges", [])
|
146
181
|
d.setdefault("processed_count", 0)
|
182
|
+
# Remove cache fields from dict if they exist (shouldn't be serialized)
|
183
|
+
d.pop("_cached_merged_ranges", None)
|
184
|
+
d.pop("_cached_unprocessed_ranges", None)
|
185
|
+
d.pop("_cache_invalidated", None)
|
147
186
|
return cls(**d)
|
148
187
|
|
149
188
|
|
@@ -155,12 +194,22 @@ class ChunkTracker(CheckpointTracker):
|
|
155
194
|
checkpoint_file: Path,
|
156
195
|
max_completed_chunks_in_memory: int = 1000,
|
157
196
|
archive_after_hours: int = 24,
|
197
|
+
save_batch_size: int = 10,
|
198
|
+
auto_save_interval: int = 60,
|
158
199
|
):
|
159
200
|
self.chunks: Dict[str, ChunkState] = {}
|
160
201
|
self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
|
161
202
|
self.archive_after_hours = archive_after_hours
|
162
203
|
self._completed_count = 0 # Track count without storing all IDs
|
163
204
|
self.lock = Lock()
|
205
|
+
|
206
|
+
# Batching mechanism
|
207
|
+
self._dirty = False
|
208
|
+
self._pending_changes = 0
|
209
|
+
self._save_batch_size = save_batch_size
|
210
|
+
self._auto_save_interval = auto_save_interval
|
211
|
+
self._last_save = datetime.now(_datetime.UTC)
|
212
|
+
|
164
213
|
super().__init__(checkpoint_file)
|
165
214
|
|
166
215
|
def _get_default_state(self) -> Dict[str, Any]:
|
@@ -169,17 +218,17 @@ class ChunkTracker(CheckpointTracker):
|
|
169
218
|
|
170
219
|
def _deserialize_state(self, data: Dict[str, Any]) -> None:
|
171
220
|
"""Deserialize loaded data into instance state."""
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
221
|
+
self.chunks = {}
|
222
|
+
self._completed_count = data.get("completed_count", 0)
|
223
|
+
|
224
|
+
# Load chunk states
|
225
|
+
completed_chunks = 0
|
226
|
+
for chunk_id, chunk_data in data.get("chunks", {}).items():
|
227
|
+
chunk_state = ChunkState.from_dict(chunk_data)
|
228
|
+
with self.lock:
|
180
229
|
self.chunks[chunk_id] = chunk_state
|
181
|
-
|
182
|
-
|
230
|
+
if chunk_state.status == "completed":
|
231
|
+
completed_chunks += 1
|
183
232
|
|
184
233
|
logger.info(
|
185
234
|
f"Loaded {len(self.chunks)} chunks from checkpoint, "
|
@@ -194,12 +243,47 @@ class ChunkTracker(CheckpointTracker):
|
|
194
243
|
"completed_count": self._completed_count,
|
195
244
|
}
|
196
245
|
|
246
|
+
def _mark_dirty(self):
|
247
|
+
"""Mark tracker as having pending changes."""
|
248
|
+
self._dirty = True
|
249
|
+
self._pending_changes += 1
|
250
|
+
|
251
|
+
# Auto-save based on batch size or time interval
|
252
|
+
now = datetime.now(_datetime.UTC)
|
253
|
+
time_since_last_save = (now - self._last_save).total_seconds()
|
254
|
+
|
255
|
+
if (
|
256
|
+
self._pending_changes >= self._save_batch_size
|
257
|
+
or time_since_last_save >= self._auto_save_interval
|
258
|
+
):
|
259
|
+
self._do_save()
|
260
|
+
|
261
|
+
def _do_save(self) -> bool:
|
262
|
+
"""Internal method to perform the actual save."""
|
263
|
+
super().save() # Parent method returns None but triggers save
|
264
|
+
# Reset dirty state since save was initiated successfully
|
265
|
+
self._dirty = False
|
266
|
+
self._pending_changes = 0
|
267
|
+
self._last_save = datetime.now(_datetime.UTC)
|
268
|
+
return True
|
269
|
+
|
270
|
+
def save(self, force: bool = False) -> bool:
|
271
|
+
"""Save state to checkpoint file, with batching optimization."""
|
272
|
+
if not force and not self._dirty:
|
273
|
+
return False
|
274
|
+
return self._do_save()
|
275
|
+
|
276
|
+
def flush(self):
|
277
|
+
"""Force save any pending changes."""
|
278
|
+
if self._dirty:
|
279
|
+
self._do_save()
|
280
|
+
|
197
281
|
def _archive_old_completed_chunks(self):
|
198
282
|
"""Remove old completed chunks from memory to prevent unbounded growth."""
|
199
283
|
if not self.archive_after_hours:
|
200
284
|
return
|
201
285
|
|
202
|
-
cutoff_time = datetime.
|
286
|
+
cutoff_time = datetime.now(_datetime.UTC) - timedelta(hours=self.archive_after_hours)
|
203
287
|
chunks_to_remove = []
|
204
288
|
|
205
289
|
for chunk_id, chunk in self.chunks.items():
|
@@ -214,7 +298,7 @@ class ChunkTracker(CheckpointTracker):
|
|
214
298
|
for chunk_id in chunks_to_remove:
|
215
299
|
del self.chunks[chunk_id]
|
216
300
|
logger.info(f"Archived {len(chunks_to_remove)} old completed chunks from memory")
|
217
|
-
self.
|
301
|
+
self._mark_dirty()
|
218
302
|
|
219
303
|
def _limit_completed_chunks_in_memory(self):
|
220
304
|
"""Keep only the most recent completed chunks in memory."""
|
@@ -232,7 +316,7 @@ class ChunkTracker(CheckpointTracker):
|
|
232
316
|
del self.chunks[chunk_id]
|
233
317
|
|
234
318
|
logger.info(f"Removed {to_remove} oldest completed chunks from memory")
|
235
|
-
self.
|
319
|
+
self._mark_dirty()
|
236
320
|
|
237
321
|
def add_chunk(
|
238
322
|
self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
|
@@ -252,7 +336,7 @@ class ChunkTracker(CheckpointTracker):
|
|
252
336
|
chunk_size=chunk_size,
|
253
337
|
status="pending",
|
254
338
|
)
|
255
|
-
self.
|
339
|
+
self._mark_dirty()
|
256
340
|
|
257
341
|
# Periodically clean up old chunks
|
258
342
|
if len(self.chunks) % 100 == 0:
|
@@ -267,8 +351,8 @@ class ChunkTracker(CheckpointTracker):
|
|
267
351
|
chunk = self.chunks[chunk_id]
|
268
352
|
chunk.status = "assigned"
|
269
353
|
chunk.assigned_to = worker_id
|
270
|
-
chunk.assigned_at = datetime.
|
271
|
-
self.
|
354
|
+
chunk.assigned_at = datetime.now(_datetime.UTC)
|
355
|
+
self._mark_dirty()
|
272
356
|
|
273
357
|
def mark_completed(self, chunk_id: str):
|
274
358
|
"""Mark chunk as completed."""
|
@@ -278,7 +362,7 @@ class ChunkTracker(CheckpointTracker):
|
|
278
362
|
chunk.mark_completed() # This clears processed_ranges
|
279
363
|
if not was_completed:
|
280
364
|
self._completed_count += 1
|
281
|
-
self.
|
365
|
+
self._mark_dirty()
|
282
366
|
logger.debug(f"Chunk {chunk_id} marked as completed")
|
283
367
|
|
284
368
|
# Check if we need to clean up
|
@@ -292,7 +376,7 @@ class ChunkTracker(CheckpointTracker):
|
|
292
376
|
chunk.status = "pending" # Reset to pending for retry
|
293
377
|
chunk.assigned_to = None
|
294
378
|
chunk.assigned_at = None
|
295
|
-
self.
|
379
|
+
self._mark_dirty()
|
296
380
|
|
297
381
|
def mark_pending(self, chunk_id: str):
|
298
382
|
"""Mark chunk as pending (for manual reset)."""
|
@@ -303,7 +387,7 @@ class ChunkTracker(CheckpointTracker):
|
|
303
387
|
chunk.status = "pending"
|
304
388
|
chunk.assigned_to = None
|
305
389
|
chunk.assigned_at = None
|
306
|
-
self.
|
390
|
+
self._mark_dirty()
|
307
391
|
|
308
392
|
def release_worker_chunks(self, worker_id: str):
|
309
393
|
"""Release all chunks assigned to a worker."""
|
@@ -314,7 +398,8 @@ class ChunkTracker(CheckpointTracker):
|
|
314
398
|
chunk.assigned_to = None
|
315
399
|
chunk.assigned_at = None
|
316
400
|
released_chunks.append(chunk_id)
|
317
|
-
|
401
|
+
if released_chunks:
|
402
|
+
self._mark_dirty()
|
318
403
|
return released_chunks
|
319
404
|
|
320
405
|
def get_pending_chunks(self, shard_name: Optional[str] = None) -> List[str]:
|
@@ -368,7 +453,7 @@ class ChunkTracker(CheckpointTracker):
|
|
368
453
|
"""Get summary of all shards and their chunk status."""
|
369
454
|
shards = {}
|
370
455
|
|
371
|
-
for
|
456
|
+
for _chunk_id, chunk_state in self.chunks.items():
|
372
457
|
shard_name = chunk_state.shard_name
|
373
458
|
if shard_name not in shards:
|
374
459
|
shards[shard_name] = {
|
@@ -378,9 +463,11 @@ class ChunkTracker(CheckpointTracker):
|
|
378
463
|
"assigned_chunks": 0,
|
379
464
|
"failed_chunks": 0,
|
380
465
|
"is_complete": True,
|
466
|
+
"chunks": [],
|
381
467
|
}
|
382
468
|
|
383
469
|
shards[shard_name]["total_chunks"] += 1
|
470
|
+
shards[shard_name]["chunks"].append(chunk_state)
|
384
471
|
|
385
472
|
if chunk_state.status == "completed":
|
386
473
|
shards[shard_name]["completed_chunks"] += 1
|
@@ -399,7 +486,7 @@ class ChunkTracker(CheckpointTracker):
|
|
399
486
|
def get_incomplete_shards(self) -> Set[str]:
|
400
487
|
"""Get set of shard names that have incomplete chunks."""
|
401
488
|
incomplete = set()
|
402
|
-
for
|
489
|
+
for _chunk_id, chunk_state in self.chunks.items():
|
403
490
|
if chunk_state.status != "completed":
|
404
491
|
incomplete.add(chunk_state.shard_name)
|
405
492
|
return incomplete
|
@@ -411,22 +498,21 @@ class ChunkTracker(CheckpointTracker):
|
|
411
498
|
if not storage_manager.captions_path.exists():
|
412
499
|
return
|
413
500
|
|
414
|
-
import
|
415
|
-
import pyarrow.parquet as pq
|
501
|
+
import lance
|
416
502
|
|
417
503
|
# Check if item_index column exists
|
418
|
-
table_metadata =
|
504
|
+
table_metadata = lance.dataset(storage_manager.captions_path).schema
|
419
505
|
columns = ["job_id", "chunk_id", "item_key"]
|
420
|
-
if "item_index" in table_metadata.
|
506
|
+
if "item_index" in table_metadata.names:
|
421
507
|
columns.append("item_index")
|
422
508
|
|
423
509
|
# Process in batches to avoid loading entire table
|
424
510
|
batch_size = 10000
|
425
|
-
|
511
|
+
lance_dataset = lance.dataset(storage_manager.captions_path)
|
426
512
|
|
427
513
|
chunk_indices = defaultdict(set)
|
428
514
|
|
429
|
-
for batch in
|
515
|
+
for batch in lance_dataset.to_batches(batch_size=batch_size, columns=columns):
|
430
516
|
batch_dict = batch.to_pydict()
|
431
517
|
|
432
518
|
for i in range(len(batch_dict["chunk_id"])):
|
@@ -491,11 +577,12 @@ class ChunkTracker(CheckpointTracker):
|
|
491
577
|
self._process_chunk_indices(chunk_indices)
|
492
578
|
|
493
579
|
logger.info("Sync with storage completed")
|
494
|
-
self.
|
580
|
+
self._mark_dirty()
|
495
581
|
|
496
582
|
def _process_chunk_indices(self, chunk_indices: Dict[str, Set[int]]):
|
497
583
|
"""Process a batch of chunk indices."""
|
498
584
|
for chunk_id, abs_indices in chunk_indices.items():
|
585
|
+
logger.debug(f"Processing indices: {abs_indices} for chunk {chunk_id}")
|
499
586
|
if chunk_id not in self.chunks:
|
500
587
|
continue
|
501
588
|
|
@@ -544,27 +631,37 @@ class ChunkTracker(CheckpointTracker):
|
|
544
631
|
relative_start = start_idx - chunk_state.start_index
|
545
632
|
relative_end = end_idx - chunk_state.start_index
|
546
633
|
|
547
|
-
# Ensure indices are within chunk bounds
|
634
|
+
# Ensure indices are within chunk bounds and maintain valid range
|
548
635
|
relative_start = max(0, relative_start)
|
549
636
|
relative_end = min(chunk_state.chunk_size - 1, relative_end)
|
550
637
|
|
638
|
+
# Skip invalid ranges where start > end
|
639
|
+
if relative_start > relative_end:
|
640
|
+
logger.warning(
|
641
|
+
f"Invalid range for chunk {chunk_id}: start={relative_start}, end={relative_end}, skipping"
|
642
|
+
)
|
643
|
+
return
|
644
|
+
|
645
|
+
# Invalidate cache before modifying ranges
|
646
|
+
chunk_state._invalidate_cache()
|
647
|
+
|
551
648
|
# Add to processed ranges
|
552
649
|
chunk_state.processed_ranges.append((relative_start, relative_end))
|
553
650
|
|
554
651
|
# Merge overlapping ranges
|
555
652
|
chunk_state.processed_ranges = chunk_state._merge_ranges(chunk_state.processed_ranges)
|
556
653
|
|
557
|
-
logger.debug(
|
558
|
-
|
559
|
-
)
|
654
|
+
# logger.debug(
|
655
|
+
# f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
|
656
|
+
# )
|
560
657
|
|
561
658
|
# Check if chunk is now complete
|
562
659
|
if chunk_state.get_unprocessed_ranges() == []:
|
563
660
|
logger.info(f"Chunk {chunk_id} is now complete")
|
564
661
|
chunk_state.status = "completed"
|
565
662
|
|
566
|
-
#
|
567
|
-
self.
|
663
|
+
# Mark as dirty, will be saved based on batching logic
|
664
|
+
self._mark_dirty()
|
568
665
|
|
569
666
|
def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
|
570
667
|
"""Get chunk info with unprocessed item ranges."""
|
@@ -1,19 +1,16 @@
|
|
1
1
|
"""Image preprocessing utilities."""
|
2
2
|
|
3
|
-
import asyncio
|
4
3
|
import logging
|
4
|
+
import os
|
5
5
|
from concurrent.futures import ProcessPoolExecutor
|
6
6
|
from io import BytesIO
|
7
|
-
from pathlib import Path
|
8
|
-
from typing import List, Any, Optional, Tuple, Union
|
9
7
|
|
10
|
-
import numpy as np
|
11
|
-
import requests
|
12
8
|
from PIL import Image
|
13
|
-
from ..models import ProcessingItem
|
14
9
|
|
10
|
+
from ..models import ProcessingItem
|
15
11
|
|
16
12
|
logger = logging.getLogger(__name__)
|
13
|
+
logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
|
17
14
|
|
18
15
|
|
19
16
|
class ImageProcessor:
|
@@ -24,22 +21,25 @@ class ImageProcessor:
|
|
24
21
|
|
25
22
|
@staticmethod
|
26
23
|
def prepare_for_inference(item: ProcessingItem) -> Image.Image:
|
27
|
-
"""
|
28
|
-
Prepare image for inference.
|
24
|
+
"""Prepare image for inference.
|
29
25
|
|
30
26
|
Args:
|
27
|
+
----
|
31
28
|
image: PIL Image to prepare
|
32
29
|
|
33
30
|
Returns:
|
31
|
+
-------
|
34
32
|
Prepared PIL Image
|
33
|
+
|
35
34
|
"""
|
36
35
|
# We used to do a lot more hand-holding here with transparency, but oh well.
|
36
|
+
logger.debug(f"Preparing item for inference: {item}")
|
37
37
|
|
38
38
|
if item.image is not None:
|
39
39
|
image = item.image
|
40
40
|
item.metadata["image_width"], item.metadata["image_height"] = image.size
|
41
41
|
item.metadata["image_format"] = image.format or "unknown"
|
42
|
-
item.image = None
|
42
|
+
# item.image = None
|
43
43
|
return image
|
44
44
|
|
45
45
|
item.image = None
|