caption-flow 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +937 -416
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +5 -3
- caption_flow/orchestrator.py +186 -116
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +440 -68
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +66 -25
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +41 -19
- caption_flow/utils/chunk_tracker.py +200 -65
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +12 -6
- caption_flow/workers/caption.py +272 -91
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
- caption_flow-0.4.0.dist-info/RECORD +33 -0
- caption_flow-0.3.3.dist-info/RECORD +0 -33
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/utils/__init__.py
CHANGED
caption_flow/utils/auth.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1
1
|
"""SSL certificate management."""
|
2
2
|
|
3
|
+
import datetime as _datetime
|
3
4
|
import subprocess
|
5
|
+
from datetime import datetime, timedelta
|
4
6
|
from pathlib import Path
|
5
7
|
from typing import Optional
|
8
|
+
|
6
9
|
from cryptography import x509
|
7
|
-
from cryptography.x509.oid import NameOID
|
8
10
|
from cryptography.hazmat.primitives import hashes, serialization
|
9
11
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
10
|
-
from
|
12
|
+
from cryptography.x509.oid import NameOID
|
11
13
|
|
12
14
|
|
13
15
|
class CertificateManager:
|
@@ -35,8 +37,8 @@ class CertificateManager:
|
|
35
37
|
.issuer_name(issuer)
|
36
38
|
.public_key(key.public_key())
|
37
39
|
.serial_number(x509.random_serial_number())
|
38
|
-
.not_valid_before(datetime.
|
39
|
-
.not_valid_after(datetime.
|
40
|
+
.not_valid_before(datetime.now(_datetime.UTC))
|
41
|
+
.not_valid_after(datetime.now(_datetime.UTC) + timedelta(days=365))
|
40
42
|
.add_extension(
|
41
43
|
x509.SubjectAlternativeName(
|
42
44
|
[
|
@@ -71,14 +73,15 @@ class CertificateManager:
|
|
71
73
|
def generate_letsencrypt(
|
72
74
|
self, domain: str, email: str, output_dir: Optional[Path] = None, staging: bool = False
|
73
75
|
) -> tuple[Path, Path]:
|
74
|
-
"""
|
75
|
-
Generate Let's Encrypt certificate.
|
76
|
+
"""Generate Let's Encrypt certificate.
|
76
77
|
|
77
78
|
Args:
|
79
|
+
----
|
78
80
|
domain: Domain name for certificate
|
79
81
|
email: Email for Let's Encrypt account
|
80
82
|
output_dir: Custom output directory (uses /etc/letsencrypt by default)
|
81
83
|
staging: Use Let's Encrypt staging server for testing
|
84
|
+
|
82
85
|
"""
|
83
86
|
cmd = [
|
84
87
|
"certbot",
|
@@ -133,8 +136,12 @@ class CertificateManager:
|
|
133
136
|
return {
|
134
137
|
"subject": cert.subject.rfc4514_string(),
|
135
138
|
"issuer": cert.issuer.rfc4514_string(),
|
136
|
-
"not_before": cert.
|
137
|
-
"not_after": cert.
|
139
|
+
"not_before": cert.not_valid_before_utc,
|
140
|
+
"not_after": cert.not_valid_after_utc,
|
138
141
|
"serial_number": cert.serial_number,
|
139
142
|
"is_self_signed": cert.issuer == cert.subject,
|
140
143
|
}
|
144
|
+
|
145
|
+
def inspect_certificate(self, cert_path: Path) -> dict:
|
146
|
+
"""Inspect a certificate (alias for get_cert_info for CLI compatibility)."""
|
147
|
+
return self.get_cert_info(cert_path)
|
@@ -1,13 +1,17 @@
|
|
1
1
|
"""Base class for checkpoint tracking with persistent state."""
|
2
2
|
|
3
|
+
import datetime as _datetime
|
3
4
|
import json
|
4
5
|
import logging
|
6
|
+
import os
|
5
7
|
from abc import ABC, abstractmethod
|
6
|
-
from
|
7
|
-
from typing import Dict, Any, Optional
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
8
9
|
from datetime import datetime
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Any, Dict, Optional
|
9
12
|
|
10
13
|
logger = logging.getLogger(__name__)
|
14
|
+
logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
|
11
15
|
|
12
16
|
|
13
17
|
class CheckpointTracker(ABC):
|
@@ -53,34 +57,52 @@ class CheckpointTracker(ABC):
|
|
53
57
|
def save(self) -> None:
|
54
58
|
"""Save checkpoint to disk atomically."""
|
55
59
|
try:
|
60
|
+
# If a save is already in progress, let it finish.
|
61
|
+
# This prevents race conditions if save() is called rapidly.
|
62
|
+
if hasattr(self, "_save_future") and self._save_future and not self._save_future.done():
|
63
|
+
logger.warning("Previous save still in progress, skipping this save")
|
64
|
+
return # don't save this time,
|
65
|
+
logger.info("Saving chunk tracker state...")
|
56
66
|
# Prepare data with metadata
|
57
|
-
|
58
|
-
|
67
|
+
with self.lock:
|
68
|
+
data = self._serialize_state()
|
69
|
+
data["updated_at"] = datetime.now(_datetime.UTC).isoformat()
|
59
70
|
|
60
71
|
# Write atomically using temp file
|
61
72
|
tmp_file = self.checkpoint_path.with_suffix(".tmp")
|
62
73
|
|
74
|
+
# Use an executor to run the save operation in a background thread.
|
75
|
+
# This makes the save call non-blocking.
|
76
|
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
77
|
+
data_to_save = data.copy()
|
78
|
+
self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
|
79
|
+
except Exception as e:
|
80
|
+
logger.error(f"Failed to submit save task: {e}", exc_info=True)
|
81
|
+
|
82
|
+
def _write_to_disk(self, data: Dict[str, Any], checkpoint_path: Optional[str] = None) -> None:
|
83
|
+
"""Write checkpoint data to disk atomically."""
|
84
|
+
# Create a temporary file in the same directory as the checkpoint
|
85
|
+
tmp_file = (checkpoint_path or self.checkpoint_path).with_suffix(".tmp")
|
86
|
+
logger.debug(f"Checkpoint {tmp_file=}")
|
87
|
+
|
88
|
+
try:
|
89
|
+
# Ensure the parent directory exists
|
90
|
+
self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
91
|
+
|
63
92
|
with open(tmp_file, "w") as f:
|
64
93
|
json.dump(data, f, indent=2)
|
65
94
|
|
66
|
-
#
|
67
|
-
if not tmp_file.exists():
|
68
|
-
raise IOError(f"Failed to create temporary file: {tmp_file}")
|
69
|
-
|
70
|
-
# Move atomically
|
95
|
+
# Atomically replace the checkpoint file
|
71
96
|
tmp_file.replace(self.checkpoint_path)
|
72
|
-
|
73
97
|
logger.debug(f"Saved checkpoint to {self.checkpoint_path}")
|
74
|
-
|
75
98
|
except Exception as e:
|
76
|
-
|
77
|
-
# Try
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
logger.error(f"Fallback save also failed: {fallback_error}")
|
99
|
+
logger.error(f"Failed to save checkpoint atomically: {e}", exc_info=True)
|
100
|
+
# Try to clean up the temp file if it exists
|
101
|
+
if tmp_file.exists():
|
102
|
+
try:
|
103
|
+
tmp_file.unlink()
|
104
|
+
except:
|
105
|
+
pass
|
84
106
|
|
85
107
|
def get_stats(self) -> Dict[str, Any]:
|
86
108
|
"""Get statistics about tracked items. Override for custom stats."""
|
@@ -1,16 +1,19 @@
|
|
1
1
|
"""Chunk tracking using CheckpointTracker base class with memory optimization."""
|
2
2
|
|
3
|
-
|
3
|
+
import datetime as _datetime
|
4
4
|
import logging
|
5
|
-
|
6
|
-
from
|
5
|
+
import os
|
6
|
+
from collections import defaultdict
|
7
|
+
from dataclasses import asdict, dataclass, field
|
7
8
|
from datetime import datetime, timedelta
|
8
|
-
from
|
9
|
+
from pathlib import Path
|
10
|
+
from threading import Lock
|
11
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
9
12
|
|
10
13
|
from .checkpoint_tracker import CheckpointTracker
|
11
14
|
|
12
15
|
logger = logging.getLogger(__name__)
|
13
|
-
logger.setLevel(
|
16
|
+
logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
|
14
17
|
|
15
18
|
|
16
19
|
@dataclass
|
@@ -31,8 +34,16 @@ class ChunkState:
|
|
31
34
|
assigned_to: Optional[str] = None
|
32
35
|
assigned_at: Optional[datetime] = None
|
33
36
|
|
37
|
+
# Cache for expensive range calculations
|
38
|
+
_cached_merged_ranges: Optional[List[Tuple[int, int]]] = field(default=None, init=False)
|
39
|
+
_cached_unprocessed_ranges: Optional[List[Tuple[int, int]]] = field(default=None, init=False)
|
40
|
+
_cache_invalidated: bool = field(default=True, init=False)
|
41
|
+
|
34
42
|
def add_processed_range(self, start: int, end: int):
|
35
43
|
"""Add a processed range and merge if needed."""
|
44
|
+
# Invalidate cache before modifying ranges
|
45
|
+
self._invalidate_cache()
|
46
|
+
|
36
47
|
# Add new range
|
37
48
|
self.processed_ranges.append((start, end))
|
38
49
|
|
@@ -57,38 +68,98 @@ class ChunkState:
|
|
57
68
|
|
58
69
|
def mark_completed(self):
|
59
70
|
"""Mark chunk as completed and clear unnecessary data to save memory."""
|
71
|
+
self._invalidate_cache()
|
60
72
|
self.status = "completed"
|
61
|
-
self.completed_at = datetime.
|
73
|
+
self.completed_at = datetime.now(_datetime.UTC)
|
62
74
|
# Clear processed_ranges since we don't need them after completion
|
63
|
-
self.processed_ranges = []
|
64
|
-
self.assigned_to = None
|
65
|
-
self.assigned_at = None
|
75
|
+
# self.processed_ranges = []
|
76
|
+
# self.assigned_to = None
|
77
|
+
# self.assigned_at = None
|
78
|
+
|
79
|
+
def _invalidate_cache(self):
|
80
|
+
"""Invalidate cached range calculations."""
|
81
|
+
self._cached_merged_ranges = None
|
82
|
+
self._cached_unprocessed_ranges = None
|
83
|
+
self._cache_invalidated = True
|
84
|
+
|
85
|
+
def _get_merged_ranges(self) -> List[Tuple[int, int]]:
|
86
|
+
"""Get merged ranges with caching."""
|
87
|
+
if self._cached_merged_ranges is None:
|
88
|
+
self._cached_merged_ranges = self._merge_ranges(self.processed_ranges)
|
89
|
+
return self._cached_merged_ranges
|
66
90
|
|
67
91
|
def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
|
68
|
-
"""Get ranges
|
92
|
+
"""Get ranges of unprocessed items within the chunk (relative indices)."""
|
69
93
|
if self.status == "completed":
|
70
94
|
return []
|
71
95
|
|
72
96
|
if not self.processed_ranges:
|
73
|
-
|
97
|
+
if self._cache_invalidated: # Only log once per invalidation
|
98
|
+
logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
|
99
|
+
self._cache_invalidated = False
|
74
100
|
return [(0, self.chunk_size - 1)]
|
75
101
|
|
76
|
-
|
77
|
-
|
102
|
+
# Use cached result if available
|
103
|
+
if self._cached_unprocessed_ranges is not None:
|
104
|
+
return self._cached_unprocessed_ranges
|
78
105
|
|
79
|
-
|
80
|
-
|
81
|
-
)
|
82
|
-
for start, end in self.processed_ranges:
|
83
|
-
if current < start:
|
84
|
-
unprocessed.append((current, start - 1))
|
85
|
-
current = max(current, end + 1)
|
106
|
+
# Calculate and cache unprocessed ranges
|
107
|
+
merged_ranges = self._get_merged_ranges()
|
86
108
|
|
87
|
-
|
88
|
-
|
109
|
+
unprocessed = []
|
110
|
+
current_pos = 0
|
111
|
+
|
112
|
+
for start, end in merged_ranges:
|
113
|
+
if current_pos < start:
|
114
|
+
unprocessed.append((current_pos, start - 1))
|
115
|
+
current_pos = max(current_pos, end + 1)
|
116
|
+
|
117
|
+
# Add any remaining range
|
118
|
+
if current_pos < self.chunk_size:
|
119
|
+
unprocessed.append((current_pos, self.chunk_size - 1))
|
120
|
+
|
121
|
+
# Cache the result
|
122
|
+
self._cached_unprocessed_ranges = unprocessed
|
123
|
+
|
124
|
+
# Log for debugging (only when cache is being computed)
|
125
|
+
if self._cache_invalidated:
|
126
|
+
if not unprocessed:
|
127
|
+
logger.info(
|
128
|
+
f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
|
129
|
+
)
|
130
|
+
else:
|
131
|
+
logger.debug(f"Merged ranges for chunk {self.chunk_id}: {merged_ranges}")
|
132
|
+
total_processed = sum(end - start + 1 for start, end in merged_ranges)
|
133
|
+
total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
|
134
|
+
logger.debug(
|
135
|
+
f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
|
136
|
+
)
|
137
|
+
self._cache_invalidated = False
|
89
138
|
|
90
139
|
return unprocessed
|
91
140
|
|
141
|
+
def _merge_ranges(self, ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
|
142
|
+
"""Merge overlapping or adjacent ranges."""
|
143
|
+
if not ranges:
|
144
|
+
return []
|
145
|
+
|
146
|
+
# Sort ranges by start index, ensuring all are tuples
|
147
|
+
sorted_ranges = sorted([tuple(r) for r in ranges])
|
148
|
+
merged = [sorted_ranges[0]]
|
149
|
+
|
150
|
+
for current_start, current_end in sorted_ranges[1:]:
|
151
|
+
last_start, last_end = merged[-1]
|
152
|
+
|
153
|
+
# Check if ranges overlap or are adjacent
|
154
|
+
if current_start <= last_end + 1:
|
155
|
+
# Merge the ranges
|
156
|
+
merged[-1] = (last_start, max(last_end, current_end))
|
157
|
+
else:
|
158
|
+
# Add as new range
|
159
|
+
merged.append((current_start, current_end))
|
160
|
+
|
161
|
+
return merged
|
162
|
+
|
92
163
|
def to_dict(self):
|
93
164
|
"""Convert to dictionary for JSON serialization."""
|
94
165
|
d = asdict(self)
|
@@ -108,6 +179,10 @@ class ChunkState:
|
|
108
179
|
# Ensure processed_ranges exists
|
109
180
|
d.setdefault("processed_ranges", [])
|
110
181
|
d.setdefault("processed_count", 0)
|
182
|
+
# Remove cache fields from dict if they exist (shouldn't be serialized)
|
183
|
+
d.pop("_cached_merged_ranges", None)
|
184
|
+
d.pop("_cached_unprocessed_ranges", None)
|
185
|
+
d.pop("_cache_invalidated", None)
|
111
186
|
return cls(**d)
|
112
187
|
|
113
188
|
|
@@ -119,11 +194,22 @@ class ChunkTracker(CheckpointTracker):
|
|
119
194
|
checkpoint_file: Path,
|
120
195
|
max_completed_chunks_in_memory: int = 1000,
|
121
196
|
archive_after_hours: int = 24,
|
197
|
+
save_batch_size: int = 10,
|
198
|
+
auto_save_interval: int = 60,
|
122
199
|
):
|
123
200
|
self.chunks: Dict[str, ChunkState] = {}
|
124
201
|
self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
|
125
202
|
self.archive_after_hours = archive_after_hours
|
126
203
|
self._completed_count = 0 # Track count without storing all IDs
|
204
|
+
self.lock = Lock()
|
205
|
+
|
206
|
+
# Batching mechanism
|
207
|
+
self._dirty = False
|
208
|
+
self._pending_changes = 0
|
209
|
+
self._save_batch_size = save_batch_size
|
210
|
+
self._auto_save_interval = auto_save_interval
|
211
|
+
self._last_save = datetime.now(_datetime.UTC)
|
212
|
+
|
127
213
|
super().__init__(checkpoint_file)
|
128
214
|
|
129
215
|
def _get_default_state(self) -> Dict[str, Any]:
|
@@ -139,7 +225,8 @@ class ChunkTracker(CheckpointTracker):
|
|
139
225
|
completed_chunks = 0
|
140
226
|
for chunk_id, chunk_data in data.get("chunks", {}).items():
|
141
227
|
chunk_state = ChunkState.from_dict(chunk_data)
|
142
|
-
self.
|
228
|
+
with self.lock:
|
229
|
+
self.chunks[chunk_id] = chunk_state
|
143
230
|
if chunk_state.status == "completed":
|
144
231
|
completed_chunks += 1
|
145
232
|
|
@@ -156,12 +243,47 @@ class ChunkTracker(CheckpointTracker):
|
|
156
243
|
"completed_count": self._completed_count,
|
157
244
|
}
|
158
245
|
|
246
|
+
def _mark_dirty(self):
|
247
|
+
"""Mark tracker as having pending changes."""
|
248
|
+
self._dirty = True
|
249
|
+
self._pending_changes += 1
|
250
|
+
|
251
|
+
# Auto-save based on batch size or time interval
|
252
|
+
now = datetime.now(_datetime.UTC)
|
253
|
+
time_since_last_save = (now - self._last_save).total_seconds()
|
254
|
+
|
255
|
+
if (
|
256
|
+
self._pending_changes >= self._save_batch_size
|
257
|
+
or time_since_last_save >= self._auto_save_interval
|
258
|
+
):
|
259
|
+
self._do_save()
|
260
|
+
|
261
|
+
def _do_save(self) -> bool:
|
262
|
+
"""Internal method to perform the actual save."""
|
263
|
+
super().save() # Parent method returns None but triggers save
|
264
|
+
# Reset dirty state since save was initiated successfully
|
265
|
+
self._dirty = False
|
266
|
+
self._pending_changes = 0
|
267
|
+
self._last_save = datetime.now(_datetime.UTC)
|
268
|
+
return True
|
269
|
+
|
270
|
+
def save(self, force: bool = False) -> bool:
|
271
|
+
"""Save state to checkpoint file, with batching optimization."""
|
272
|
+
if not force and not self._dirty:
|
273
|
+
return False
|
274
|
+
return self._do_save()
|
275
|
+
|
276
|
+
def flush(self):
|
277
|
+
"""Force save any pending changes."""
|
278
|
+
if self._dirty:
|
279
|
+
self._do_save()
|
280
|
+
|
159
281
|
def _archive_old_completed_chunks(self):
|
160
282
|
"""Remove old completed chunks from memory to prevent unbounded growth."""
|
161
283
|
if not self.archive_after_hours:
|
162
284
|
return
|
163
285
|
|
164
|
-
cutoff_time = datetime.
|
286
|
+
cutoff_time = datetime.now(_datetime.UTC) - timedelta(hours=self.archive_after_hours)
|
165
287
|
chunks_to_remove = []
|
166
288
|
|
167
289
|
for chunk_id, chunk in self.chunks.items():
|
@@ -176,7 +298,7 @@ class ChunkTracker(CheckpointTracker):
|
|
176
298
|
for chunk_id in chunks_to_remove:
|
177
299
|
del self.chunks[chunk_id]
|
178
300
|
logger.info(f"Archived {len(chunks_to_remove)} old completed chunks from memory")
|
179
|
-
self.
|
301
|
+
self._mark_dirty()
|
180
302
|
|
181
303
|
def _limit_completed_chunks_in_memory(self):
|
182
304
|
"""Keep only the most recent completed chunks in memory."""
|
@@ -194,7 +316,7 @@ class ChunkTracker(CheckpointTracker):
|
|
194
316
|
del self.chunks[chunk_id]
|
195
317
|
|
196
318
|
logger.info(f"Removed {to_remove} oldest completed chunks from memory")
|
197
|
-
self.
|
319
|
+
self._mark_dirty()
|
198
320
|
|
199
321
|
def add_chunk(
|
200
322
|
self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
|
@@ -214,7 +336,7 @@ class ChunkTracker(CheckpointTracker):
|
|
214
336
|
chunk_size=chunk_size,
|
215
337
|
status="pending",
|
216
338
|
)
|
217
|
-
self.
|
339
|
+
self._mark_dirty()
|
218
340
|
|
219
341
|
# Periodically clean up old chunks
|
220
342
|
if len(self.chunks) % 100 == 0:
|
@@ -229,8 +351,8 @@ class ChunkTracker(CheckpointTracker):
|
|
229
351
|
chunk = self.chunks[chunk_id]
|
230
352
|
chunk.status = "assigned"
|
231
353
|
chunk.assigned_to = worker_id
|
232
|
-
chunk.assigned_at = datetime.
|
233
|
-
self.
|
354
|
+
chunk.assigned_at = datetime.now(_datetime.UTC)
|
355
|
+
self._mark_dirty()
|
234
356
|
|
235
357
|
def mark_completed(self, chunk_id: str):
|
236
358
|
"""Mark chunk as completed."""
|
@@ -240,7 +362,7 @@ class ChunkTracker(CheckpointTracker):
|
|
240
362
|
chunk.mark_completed() # This clears processed_ranges
|
241
363
|
if not was_completed:
|
242
364
|
self._completed_count += 1
|
243
|
-
self.
|
365
|
+
self._mark_dirty()
|
244
366
|
logger.debug(f"Chunk {chunk_id} marked as completed")
|
245
367
|
|
246
368
|
# Check if we need to clean up
|
@@ -254,7 +376,7 @@ class ChunkTracker(CheckpointTracker):
|
|
254
376
|
chunk.status = "pending" # Reset to pending for retry
|
255
377
|
chunk.assigned_to = None
|
256
378
|
chunk.assigned_at = None
|
257
|
-
self.
|
379
|
+
self._mark_dirty()
|
258
380
|
|
259
381
|
def mark_pending(self, chunk_id: str):
|
260
382
|
"""Mark chunk as pending (for manual reset)."""
|
@@ -265,7 +387,7 @@ class ChunkTracker(CheckpointTracker):
|
|
265
387
|
chunk.status = "pending"
|
266
388
|
chunk.assigned_to = None
|
267
389
|
chunk.assigned_at = None
|
268
|
-
self.
|
390
|
+
self._mark_dirty()
|
269
391
|
|
270
392
|
def release_worker_chunks(self, worker_id: str):
|
271
393
|
"""Release all chunks assigned to a worker."""
|
@@ -276,7 +398,8 @@ class ChunkTracker(CheckpointTracker):
|
|
276
398
|
chunk.assigned_to = None
|
277
399
|
chunk.assigned_at = None
|
278
400
|
released_chunks.append(chunk_id)
|
279
|
-
|
401
|
+
if released_chunks:
|
402
|
+
self._mark_dirty()
|
280
403
|
return released_chunks
|
281
404
|
|
282
405
|
def get_pending_chunks(self, shard_name: Optional[str] = None) -> List[str]:
|
@@ -330,7 +453,7 @@ class ChunkTracker(CheckpointTracker):
|
|
330
453
|
"""Get summary of all shards and their chunk status."""
|
331
454
|
shards = {}
|
332
455
|
|
333
|
-
for
|
456
|
+
for _chunk_id, chunk_state in self.chunks.items():
|
334
457
|
shard_name = chunk_state.shard_name
|
335
458
|
if shard_name not in shards:
|
336
459
|
shards[shard_name] = {
|
@@ -340,9 +463,11 @@ class ChunkTracker(CheckpointTracker):
|
|
340
463
|
"assigned_chunks": 0,
|
341
464
|
"failed_chunks": 0,
|
342
465
|
"is_complete": True,
|
466
|
+
"chunks": [],
|
343
467
|
}
|
344
468
|
|
345
469
|
shards[shard_name]["total_chunks"] += 1
|
470
|
+
shards[shard_name]["chunks"].append(chunk_state)
|
346
471
|
|
347
472
|
if chunk_state.status == "completed":
|
348
473
|
shards[shard_name]["completed_chunks"] += 1
|
@@ -361,7 +486,7 @@ class ChunkTracker(CheckpointTracker):
|
|
361
486
|
def get_incomplete_shards(self) -> Set[str]:
|
362
487
|
"""Get set of shard names that have incomplete chunks."""
|
363
488
|
incomplete = set()
|
364
|
-
for
|
489
|
+
for _chunk_id, chunk_state in self.chunks.items():
|
365
490
|
if chunk_state.status != "completed":
|
366
491
|
incomplete.add(chunk_state.shard_name)
|
367
492
|
return incomplete
|
@@ -373,22 +498,21 @@ class ChunkTracker(CheckpointTracker):
|
|
373
498
|
if not storage_manager.captions_path.exists():
|
374
499
|
return
|
375
500
|
|
376
|
-
import
|
377
|
-
import pyarrow.parquet as pq
|
501
|
+
import lance
|
378
502
|
|
379
503
|
# Check if item_index column exists
|
380
|
-
table_metadata =
|
504
|
+
table_metadata = lance.dataset(storage_manager.captions_path).schema
|
381
505
|
columns = ["job_id", "chunk_id", "item_key"]
|
382
|
-
if "item_index" in table_metadata.
|
506
|
+
if "item_index" in table_metadata.names:
|
383
507
|
columns.append("item_index")
|
384
508
|
|
385
509
|
# Process in batches to avoid loading entire table
|
386
510
|
batch_size = 10000
|
387
|
-
|
511
|
+
lance_dataset = lance.dataset(storage_manager.captions_path)
|
388
512
|
|
389
513
|
chunk_indices = defaultdict(set)
|
390
514
|
|
391
|
-
for batch in
|
515
|
+
for batch in lance_dataset.to_batches(batch_size=batch_size, columns=columns):
|
392
516
|
batch_dict = batch.to_pydict()
|
393
517
|
|
394
518
|
for i in range(len(batch_dict["chunk_id"])):
|
@@ -453,11 +577,12 @@ class ChunkTracker(CheckpointTracker):
|
|
453
577
|
self._process_chunk_indices(chunk_indices)
|
454
578
|
|
455
579
|
logger.info("Sync with storage completed")
|
456
|
-
self.
|
580
|
+
self._mark_dirty()
|
457
581
|
|
458
582
|
def _process_chunk_indices(self, chunk_indices: Dict[str, Set[int]]):
|
459
583
|
"""Process a batch of chunk indices."""
|
460
584
|
for chunk_id, abs_indices in chunk_indices.items():
|
585
|
+
logger.debug(f"Processing indices: {abs_indices} for chunk {chunk_id}")
|
461
586
|
if chunk_id not in self.chunks:
|
462
587
|
continue
|
463
588
|
|
@@ -494,39 +619,49 @@ class ChunkTracker(CheckpointTracker):
|
|
494
619
|
for start_idx, end_idx in ranges:
|
495
620
|
chunk.add_processed_range(start_idx, end_idx)
|
496
621
|
|
497
|
-
def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
|
498
|
-
"""Mark a range of items as processed within a chunk
|
622
|
+
def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int) -> None:
|
623
|
+
"""Mark a range of items as processed within a chunk."""
|
499
624
|
if chunk_id not in self.chunks:
|
500
|
-
logger.
|
625
|
+
logger.warning(f"Chunk {chunk_id} not found in tracker")
|
501
626
|
return
|
502
627
|
|
503
|
-
|
628
|
+
chunk_state = self.chunks[chunk_id]
|
504
629
|
|
505
|
-
# Convert absolute indices to chunk-relative
|
506
|
-
relative_start = start_idx -
|
507
|
-
relative_end = end_idx -
|
630
|
+
# Convert absolute indices to chunk-relative indices
|
631
|
+
relative_start = start_idx - chunk_state.start_index
|
632
|
+
relative_end = end_idx - chunk_state.start_index
|
508
633
|
|
509
|
-
#
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
634
|
+
# Ensure indices are within chunk bounds and maintain valid range
|
635
|
+
relative_start = max(0, relative_start)
|
636
|
+
relative_end = min(chunk_state.chunk_size - 1, relative_end)
|
637
|
+
|
638
|
+
# Skip invalid ranges where start > end
|
639
|
+
if relative_start > relative_end:
|
640
|
+
logger.warning(
|
641
|
+
f"Invalid range for chunk {chunk_id}: start={relative_start}, end={relative_end}, skipping"
|
515
642
|
)
|
516
643
|
return
|
517
644
|
|
518
|
-
#
|
519
|
-
|
645
|
+
# Invalidate cache before modifying ranges
|
646
|
+
chunk_state._invalidate_cache()
|
520
647
|
|
521
|
-
#
|
522
|
-
|
523
|
-
self._completed_count += 1
|
648
|
+
# Add to processed ranges
|
649
|
+
chunk_state.processed_ranges.append((relative_start, relative_end))
|
524
650
|
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
)
|
651
|
+
# Merge overlapping ranges
|
652
|
+
chunk_state.processed_ranges = chunk_state._merge_ranges(chunk_state.processed_ranges)
|
653
|
+
|
654
|
+
# logger.debug(
|
655
|
+
# f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
|
656
|
+
# )
|
657
|
+
|
658
|
+
# Check if chunk is now complete
|
659
|
+
if chunk_state.get_unprocessed_ranges() == []:
|
660
|
+
logger.info(f"Chunk {chunk_id} is now complete")
|
661
|
+
chunk_state.status = "completed"
|
662
|
+
|
663
|
+
# Mark as dirty, will be saved based on batching logic
|
664
|
+
self._mark_dirty()
|
530
665
|
|
531
666
|
def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
|
532
667
|
"""Get chunk info with unprocessed item ranges."""
|