caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +65 -42
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +13 -3
- caption_flow/orchestrator.py +1049 -264
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +567 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.1.dist-info/METADATA +370 -0
- caption_flow-0.2.1.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,32 +1,81 @@
|
|
1
|
-
"""Chunk tracking
|
1
|
+
"""Chunk tracking using CheckpointTracker base class."""
|
2
2
|
|
3
|
-
import
|
3
|
+
from collections import defaultdict
|
4
4
|
import logging
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Set, Dict, List, Optional, Any
|
6
|
+
from typing import Set, Dict, List, Optional, Any, Tuple
|
7
7
|
from datetime import datetime
|
8
|
-
from dataclasses import dataclass, asdict
|
8
|
+
from dataclasses import dataclass, asdict, field
|
9
|
+
|
10
|
+
from .checkpoint_tracker import CheckpointTracker
|
9
11
|
|
10
12
|
logger = logging.getLogger(__name__)
|
11
13
|
|
12
14
|
|
13
15
|
@dataclass
|
14
16
|
class ChunkState:
|
15
|
-
"""State of a chunk."""
|
17
|
+
"""State of a chunk with item-level tracking."""
|
16
18
|
|
17
19
|
chunk_id: str
|
18
20
|
shard_name: str
|
21
|
+
shard_url: str
|
19
22
|
start_index: int
|
20
23
|
chunk_size: int
|
21
24
|
status: str # pending, assigned, completed, failed
|
25
|
+
|
26
|
+
processed_ranges: List[Tuple[int, int]] = field(default_factory=list) # [(start, end), ...]
|
27
|
+
processed_count: int = 0
|
28
|
+
|
22
29
|
completed_at: Optional[datetime] = None
|
23
30
|
assigned_to: Optional[str] = None
|
24
31
|
assigned_at: Optional[datetime] = None
|
25
32
|
|
33
|
+
def add_processed_range(self, start: int, end: int):
    """Record that items [start, end] (chunk-relative, inclusive) were processed.

    Overlapping or adjacent ranges are merged, ``processed_count`` is
    recomputed, and the chunk auto-completes once every item is covered.
    """
    # Normalize everything to tuples before sorting: ranges loaded from a
    # JSON checkpoint arrive as lists, and mixed list/tuple comparison
    # would raise TypeError.
    candidate_ranges = sorted([tuple(r) for r in self.processed_ranges] + [(start, end)])

    merged: List[Tuple[int, int]] = []
    for lo, hi in candidate_ranges:
        if merged and lo <= merged[-1][1] + 1:
            # Overlaps or touches the previous range — extend it.
            prev_lo, prev_hi = merged[-1]
            merged[-1] = (prev_lo, max(prev_hi, hi))
        else:
            merged.append((lo, hi))

    self.processed_ranges = merged

    # Recompute total processed items from the merged (disjoint) ranges.
    self.processed_count = sum(hi - lo + 1 for lo, hi in merged)

    # Auto-complete once every item in the chunk has been processed.
    if self.processed_count >= self.chunk_size:
        self.status = "completed"
        self.completed_at = datetime.utcnow()
|
57
|
+
|
58
|
+
def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
    """Return the inclusive gaps in [0, chunk_size - 1] not yet processed.

    Assumes ``processed_ranges`` is sorted and non-overlapping, as
    maintained by ``add_processed_range``.
    """
    if not self.processed_ranges:
        # Nothing processed: the whole chunk is one gap.
        return [(0, self.chunk_size - 1)]

    gaps: List[Tuple[int, int]] = []
    cursor = 0  # first index not yet accounted for

    for lo, hi in self.processed_ranges:
        if cursor < lo:
            gaps.append((cursor, lo - 1))
        cursor = max(cursor, hi + 1)

    # Tail gap after the last processed range, if any.
    if cursor < self.chunk_size:
        gaps.append((cursor, self.chunk_size - 1))

    return gaps
|
75
|
+
|
26
76
|
def to_dict(self):
|
27
77
|
"""Convert to dictionary for JSON serialization."""
|
28
78
|
d = asdict(self)
|
29
|
-
# Convert datetime objects to ISO format strings
|
30
79
|
if d["completed_at"]:
|
31
80
|
d["completed_at"] = d["completed_at"].isoformat()
|
32
81
|
if d["assigned_at"]:
|
@@ -36,90 +85,52 @@ class ChunkState:
|
|
36
85
|
@classmethod
def from_dict(cls, d: Dict):
    """Build a ChunkState from a checkpoint dictionary (inverse of to_dict)."""
    # ISO-format timestamp strings back into datetime objects.
    for ts_key in ("completed_at", "assigned_at"):
        if d.get(ts_key):
            d[ts_key] = datetime.fromisoformat(d[ts_key])

    # Checkpoints written before item-level tracking lack these fields.
    d.setdefault("processed_ranges", [])
    d.setdefault("processed_count", 0)

    return cls(**d)
|
45
96
|
|
46
97
|
|
47
|
-
class ChunkTracker:
|
98
|
+
class ChunkTracker(CheckpointTracker):
|
48
99
|
"""Tracks chunk processing state persistently."""
|
49
100
|
|
50
101
|
def __init__(self, checkpoint_file: Path):
    """Initialize the tracker and load any existing checkpoint.

    Args:
        checkpoint_file: Path of the checkpoint file managed by the
            CheckpointTracker base class.
    """
    # NOTE(review): the collections are created BEFORE calling the base
    # constructor — presumably CheckpointTracker.__init__ loads the
    # checkpoint and invokes _deserialize_state, which populates these;
    # confirm against checkpoint_tracker.py before reordering.
    self.chunks: Dict[str, ChunkState] = {}
    self.completed_chunks: Set[str] = set()
    super().__init__(checkpoint_file)
|
91
105
|
|
92
|
-
|
93
|
-
|
106
|
+
def _get_default_state(self) -> Dict[str, Any]:
|
107
|
+
"""Return default state structure for new checkpoints."""
|
108
|
+
return {"chunks": {}}
|
94
109
|
|
95
|
-
|
96
|
-
|
97
|
-
|
110
|
+
def _deserialize_state(self, data: Dict[str, Any]) -> None:
    """Rebuild in-memory chunk state from a loaded checkpoint dict."""
    # Reconstruct every ChunkState keyed by chunk id.
    self.chunks = {
        chunk_id: ChunkState.from_dict(raw_state)
        for chunk_id, raw_state in data.get("chunks", {}).items()
    }
    # Completed chunks are additionally indexed in a set for O(1) lookups.
    self.completed_chunks = {
        chunk_id for chunk_id, state in self.chunks.items() if state.status == "completed"
    }

    logger.info(
        f"Loaded {len(self.chunks)} chunks from checkpoint, "
        f"{len(self.completed_chunks)} completed"
    )
|
107
126
|
|
108
|
-
|
109
|
-
|
110
|
-
|
127
|
+
def _serialize_state(self) -> Dict[str, Any]:
|
128
|
+
"""Serialize instance state for saving."""
|
129
|
+
return {"chunks": {chunk_id: chunk.to_dict() for chunk_id, chunk in self.chunks.items()}}
|
111
130
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
try:
|
116
|
-
with open(self.checkpoint_file, "w") as f:
|
117
|
-
json.dump(data, f, indent=2)
|
118
|
-
logger.info("Saved checkpoint using fallback direct write")
|
119
|
-
except Exception as fallback_error:
|
120
|
-
logger.error(f"Fallback save also failed: {fallback_error}")
|
121
|
-
|
122
|
-
def add_chunk(self, chunk_id: str, shard_name: str, start_index: int, chunk_size: int) -> bool:
|
131
|
+
def add_chunk(
|
132
|
+
self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
|
133
|
+
) -> bool:
|
123
134
|
"""Add a new chunk. Returns False if chunk already exists and is completed."""
|
124
135
|
if chunk_id in self.completed_chunks:
|
125
136
|
logger.debug(f"Chunk {chunk_id} already completed, skipping")
|
@@ -129,11 +140,12 @@ class ChunkTracker:
|
|
129
140
|
self.chunks[chunk_id] = ChunkState(
|
130
141
|
chunk_id=chunk_id,
|
131
142
|
shard_name=shard_name,
|
143
|
+
shard_url=shard_url, # Now included
|
132
144
|
start_index=start_index,
|
133
145
|
chunk_size=chunk_size,
|
134
146
|
status="pending",
|
135
147
|
)
|
136
|
-
self.
|
148
|
+
self.save()
|
137
149
|
|
138
150
|
return True
|
139
151
|
|
@@ -144,7 +156,7 @@ class ChunkTracker:
|
|
144
156
|
chunk.status = "assigned"
|
145
157
|
chunk.assigned_to = worker_id
|
146
158
|
chunk.assigned_at = datetime.utcnow()
|
147
|
-
self.
|
159
|
+
self.save()
|
148
160
|
|
149
161
|
def mark_completed(self, chunk_id: str):
|
150
162
|
"""Mark chunk as completed."""
|
@@ -153,7 +165,7 @@ class ChunkTracker:
|
|
153
165
|
chunk.status = "completed"
|
154
166
|
chunk.completed_at = datetime.utcnow()
|
155
167
|
self.completed_chunks.add(chunk_id)
|
156
|
-
self.
|
168
|
+
self.save()
|
157
169
|
logger.info(f"Chunk {chunk_id} marked as completed")
|
158
170
|
|
159
171
|
def mark_failed(self, chunk_id: str):
|
@@ -163,7 +175,7 @@ class ChunkTracker:
|
|
163
175
|
chunk.status = "pending" # Reset to pending for retry
|
164
176
|
chunk.assigned_to = None
|
165
177
|
chunk.assigned_at = None
|
166
|
-
self.
|
178
|
+
self.save()
|
167
179
|
|
168
180
|
def mark_pending(self, chunk_id: str):
|
169
181
|
"""Mark chunk as pending (for manual reset)."""
|
@@ -172,7 +184,7 @@ class ChunkTracker:
|
|
172
184
|
chunk.status = "pending"
|
173
185
|
chunk.assigned_to = None
|
174
186
|
chunk.assigned_at = None
|
175
|
-
self.
|
187
|
+
self.save()
|
176
188
|
|
177
189
|
def release_worker_chunks(self, worker_id: str):
|
178
190
|
"""Release all chunks assigned to a worker."""
|
@@ -183,7 +195,7 @@ class ChunkTracker:
|
|
183
195
|
chunk.assigned_to = None
|
184
196
|
chunk.assigned_at = None
|
185
197
|
released_chunks.append(chunk_id)
|
186
|
-
self.
|
198
|
+
self.save()
|
187
199
|
return released_chunks
|
188
200
|
|
189
201
|
def get_pending_chunks(self, shard_name: Optional[str] = None) -> List[str]:
|
@@ -206,14 +218,17 @@ class ChunkTracker:
|
|
206
218
|
|
207
219
|
def get_stats(self) -> Dict[str, int]:
    """Return chunk counts by status, layered on top of the base tracker stats."""
    stats = super().get_stats()

    # Snapshot statuses once instead of re-scanning per status.
    statuses = [chunk.status for chunk in self.chunks.values()]

    stats["total"] = len(self.chunks)
    stats["pending"] = statuses.count("pending")
    stats["assigned"] = statuses.count("assigned")
    # Completed chunks are tracked in a dedicated set, not by status scan.
    stats["completed"] = len(self.completed_chunks)
    stats["failed"] = statuses.count("failed")

    return stats
|
217
232
|
|
218
233
|
def get_shards_summary(self) -> Dict[str, Dict[str, Any]]:
|
219
234
|
"""Get summary of all shards and their chunk status."""
|
@@ -221,8 +236,20 @@ class ChunkTracker:
|
|
221
236
|
|
222
237
|
for chunk_id, chunk_state in self.chunks.items():
|
223
238
|
shard_name = chunk_state.shard_name
|
224
|
-
|
225
|
-
|
239
|
+
|
240
|
+
# For virtual HF dataset shards, normalize the shard name
|
241
|
+
if shard_name.startswith("hf_dataset:"):
|
242
|
+
parts = shard_name.split(":")
|
243
|
+
if len(parts) >= 4 and parts[2] == "chunk":
|
244
|
+
# Use just the dataset identifier as the shard name
|
245
|
+
normalized_shard_name = ":".join(parts[:2])
|
246
|
+
else:
|
247
|
+
normalized_shard_name = shard_name
|
248
|
+
else:
|
249
|
+
normalized_shard_name = shard_name
|
250
|
+
|
251
|
+
if normalized_shard_name not in shards:
|
252
|
+
shards[normalized_shard_name] = {
|
226
253
|
"total_chunks": 0,
|
227
254
|
"completed_chunks": 0,
|
228
255
|
"pending_chunks": 0,
|
@@ -232,20 +259,20 @@ class ChunkTracker:
|
|
232
259
|
"chunks": [],
|
233
260
|
}
|
234
261
|
|
235
|
-
shards[
|
236
|
-
shards[
|
262
|
+
shards[normalized_shard_name]["chunks"].append(chunk_state)
|
263
|
+
shards[normalized_shard_name]["total_chunks"] += 1
|
237
264
|
|
238
265
|
if chunk_state.status == "completed":
|
239
|
-
shards[
|
266
|
+
shards[normalized_shard_name]["completed_chunks"] += 1
|
240
267
|
elif chunk_state.status == "pending":
|
241
|
-
shards[
|
242
|
-
shards[
|
268
|
+
shards[normalized_shard_name]["pending_chunks"] += 1
|
269
|
+
shards[normalized_shard_name]["is_complete"] = False
|
243
270
|
elif chunk_state.status == "assigned":
|
244
|
-
shards[
|
245
|
-
shards[
|
271
|
+
shards[normalized_shard_name]["assigned_chunks"] += 1
|
272
|
+
shards[normalized_shard_name]["is_complete"] = False
|
246
273
|
elif chunk_state.status == "failed":
|
247
|
-
shards[
|
248
|
-
shards[
|
274
|
+
shards[normalized_shard_name]["failed_chunks"] += 1
|
275
|
+
shards[normalized_shard_name]["is_complete"] = False
|
249
276
|
|
250
277
|
return shards
|
251
278
|
|
@@ -257,109 +284,166 @@ class ChunkTracker:
|
|
257
284
|
incomplete.add(chunk_state.shard_name)
|
258
285
|
return incomplete
|
259
286
|
|
260
|
-
def
|
261
|
-
"""
|
262
|
-
|
287
|
+
async def sync_with_storage(self, storage_manager):
    """Sync chunk state with storage to detect processed items.

    Scans the captions parquet file referenced by ``storage_manager`` and,
    for every row carrying a chunk_id, records that item's index as
    processed on the matching ChunkState (recreating missing chunk entries
    from the chunk_id where possible). Consecutive indices are collapsed
    into ranges before being applied, then the checkpoint is saved.

    Args:
        storage_manager: Object exposing ``captions_path`` (a Path to the
            captions parquet file).
    """
    logger.info("Syncing chunk state with storage...")

    if storage_manager.captions_path.exists():
        import pyarrow.parquet as pq

        # Read all relevant columns
        columns = ["job_id", "chunk_id", "item_key"]
        # Check if item_index column exists (new format)
        table_metadata = pq.read_metadata(storage_manager.captions_path)
        if "item_index" in table_metadata.schema.names:
            columns.append("item_index")

        table = pq.read_table(storage_manager.captions_path, columns=columns)

        # Build lookup of chunk_id -> processed absolute indices
        chunk_indices = defaultdict(set)

        for i in range(len(table)):
            chunk_id = table["chunk_id"][i].as_py()
            if not chunk_id:
                continue

            # Get the chunk to find its boundaries
            if chunk_id not in self.chunks:
                # Try to recreate chunk from chunk_id (format: <shard>_chunk_<start>)
                parts = chunk_id.rsplit("_chunk_", 1)
                if len(parts) != 2:
                    continue

                shard_name = parts[0]
                try:
                    start_idx = int(parts[1])
                except ValueError:
                    continue

                # Infer shard URL and create chunk with default size
                if shard_name.replace("_", "/") in chunk_id or "_" in shard_name:
                    # HF dataset
                    dataset_path = shard_name.replace("_", "/")
                    shard_url = f"hf_dataset:{dataset_path}:chunk:{start_idx}"
                else:
                    # WebDataset
                    shard_url = f"unknown://{shard_name}.tar"

                self.chunks[chunk_id] = ChunkState(
                    chunk_id=chunk_id,
                    shard_name=shard_name,
                    shard_url=shard_url,
                    start_index=start_idx,
                    chunk_size=10000,  # Default - should match your chunk size
                    status="pending",
                )

            chunk = self.chunks[chunk_id]

            # Get item index
            if "item_index" in table.column_names:
                item_index = table["item_index"][i].as_py()
            else:
                # Try to extract from item_key (expects a trailing "_<index>")
                item_key = table["item_key"][i].as_py()
                try:
                    item_index = int(item_key.split("_")[-1])
                except (AttributeError, ValueError):
                    # FIX: was a bare `except:` that swallowed everything
                    # (including KeyboardInterrupt). Only a missing key
                    # (None -> AttributeError) or a non-numeric suffix
                    # (ValueError) should skip the row.
                    continue

            if item_index is None:
                continue

            # CRITICAL: Validate that this item belongs to this chunk
            if (
                item_index < chunk.start_index
                or item_index >= chunk.start_index + chunk.chunk_size
            ):
                logger.warning(
                    f"Item index {item_index} doesn't belong to chunk {chunk_id} "
                    f"(boundaries: {chunk.start_index}-{chunk.start_index + chunk.chunk_size - 1})"
                )
                continue

            # Store the absolute index for now
            chunk_indices[chunk_id].add(item_index)

        # Convert absolute indices to relative and mark as processed
        for chunk_id, abs_indices in chunk_indices.items():
            if chunk_id not in self.chunks:
                continue

            chunk = self.chunks[chunk_id]

            # Convert to relative indices, discarding any out of bounds
            rel_indices = []
            for abs_idx in sorted(abs_indices):
                rel_idx = abs_idx - chunk.start_index
                if 0 <= rel_idx < chunk.chunk_size:
                    rel_indices.append(rel_idx)

            # Group consecutive indices into ranges
            if rel_indices:
                ranges = []
                start = rel_indices[0]
                end = rel_indices[0]

                for idx in rel_indices[1:]:
                    if idx == end + 1:
                        end = idx
                    else:
                        ranges.append((start, end))
                        start = idx
                        end = idx

                ranges.append((start, end))

                # Mark ranges as processed
                for start_idx, end_idx in ranges:
                    chunk.add_processed_range(start_idx, end_idx)

        logger.info(f"Synced {len(chunk_indices)} chunks with processed items")
        self.save()
|
408
|
+
|
409
|
+
def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
    """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
    chunk = self.chunks.get(chunk_id)
    if chunk is None:
        logger.error(f"Unknown chunk: {chunk_id}")
        return

    # Translate dataset-absolute indices into chunk-relative ones.
    relative_start = start_idx - chunk.start_index
    relative_end = end_idx - chunk.start_index

    # Reject ranges falling outside this chunk's boundaries.
    if relative_start < 0 or relative_end >= chunk.chunk_size:
        logger.error(
            f"Invalid indices for chunk {chunk_id}: "
            f"absolute {start_idx}-{end_idx} (relative {relative_start}-{relative_end}) "
            f"outside chunk bounds [{chunk.start_index}, {chunk.start_index + chunk.chunk_size - 1}]"
        )
        return

    chunk.add_processed_range(relative_start, relative_end)

    # add_processed_range flips the status once every item is covered;
    # mirror that into the completed set.
    if chunk.status == "completed":
        self.completed_chunks.add(chunk_id)

    self.save()
    logger.debug(
        f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} "
        f"(relative indices: {relative_start}-{relative_end})"
    )
|
311
442
|
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
parts = chunk_id.rsplit("_chunk_", 1)
|
320
|
-
if len(parts) == 2:
|
321
|
-
shard_name = parts[0]
|
322
|
-
try:
|
323
|
-
start_idx = int(parts[1])
|
324
|
-
# We don't know exact chunk size, but mark it as completed
|
325
|
-
self.chunks[chunk_id] = ChunkState(
|
326
|
-
chunk_id=chunk_id,
|
327
|
-
shard_name=shard_name,
|
328
|
-
start_index=start_idx,
|
329
|
-
chunk_size=1000, # Default chunk size
|
330
|
-
status="completed",
|
331
|
-
completed_at=datetime.utcnow(),
|
332
|
-
)
|
333
|
-
self.completed_chunks.add(chunk_id)
|
334
|
-
except ValueError:
|
335
|
-
logger.warning(f"Could not parse chunk_id: {chunk_id}")
|
336
|
-
|
337
|
-
logger.info(f"Found {len(existing_chunk_ids)} completed chunks in storage")
|
338
|
-
|
339
|
-
# Also check by job_id pattern if chunk_id column doesn't exist
|
340
|
-
else:
|
341
|
-
for job_id in existing_job_ids:
|
342
|
-
# Extract chunk_id from job_id (format: chunk_id_item_key)
|
343
|
-
if "_chunk_" in job_id:
|
344
|
-
parts = job_id.split("_")
|
345
|
-
# Find the chunk part
|
346
|
-
for i, part in enumerate(parts):
|
347
|
-
if part == "chunk" and i + 1 < len(parts):
|
348
|
-
try:
|
349
|
-
# Reconstruct chunk_id
|
350
|
-
chunk_idx = int(parts[i + 1])
|
351
|
-
shard_parts = parts[:i]
|
352
|
-
chunk_id = f"{'_'.join(shard_parts)}_chunk_{chunk_idx}"
|
353
|
-
|
354
|
-
if chunk_id not in self.completed_chunks:
|
355
|
-
self.completed_chunks.add(chunk_id)
|
356
|
-
logger.debug(
|
357
|
-
f"Marked chunk {chunk_id} as completed from job_id"
|
358
|
-
)
|
359
|
-
break
|
360
|
-
except ValueError:
|
361
|
-
continue
|
362
|
-
|
363
|
-
logger.info(f"Inferred {len(self.completed_chunks)} completed chunks from job_ids")
|
364
|
-
|
365
|
-
self.save_checkpoint()
|
443
|
+
def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
    """Return a chunk's serialized state plus its still-unprocessed ranges."""
    chunk = self.chunks.get(chunk_id)
    if chunk is None:
        return None
    return {
        "chunk": chunk.to_dict(),
        "unprocessed_ranges": chunk.get_unprocessed_ranges(),
    }
|