caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +2 -1
- caption_flow/models.py +108 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1595
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage.py +415 -406
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +94 -35
- caption_flow/utils/dataset_loader.py +64 -522
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +4 -200
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/METADATA +29 -27
- caption_flow-0.2.3.dist-info/RECORD +35 -0
- caption_flow-0.2.1.dist-info/RECORD +0 -29
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/top_level.txt +0 -0
caption_flow/processors/webdataset.py (new file)
@@ -0,0 +1,782 @@
"""WebDataset processor implementation."""

import logging
import threading
from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
from collections import deque, defaultdict
from pathlib import Path
import json
import io
from datetime import datetime
from PIL import Image
from caption_flow.storage import StorageManager

from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
from ..utils import DatasetLoader, ChunkTracker

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
    """Orchestrator processor for WebDataset shards."""

    def __init__(self):
        logger.debug("Initializing WebDatasetOrchestratorProcessor")
        self.dataset_loader: Optional[DatasetLoader] = None
        self.chunk_tracker: Optional[ChunkTracker] = None
        self.chunk_size: int = 1000

        # Work unit management
        self.work_units: Dict[str, WorkUnit] = {}
        self.pending_units: Deque[str] = deque()
        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> unit_ids
        self.lock = threading.Lock()

        # Shard processing state
        self.all_shards: List[str] = []
        self.current_shard_index = 0
        self.current_shard_items = 0

        # Background thread for creating work units
        self.unit_creation_thread: Optional[threading.Thread] = None
        self.stop_creation = threading.Event()

    def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
        """Initialize WebDataset processor."""
        logger.debug("Initializing orchestrator with config: %s", config.config)
        cfg = config.config

        # Dataset configuration
        dataset_cfg = cfg.get("dataset", {})
        dataset_path = dataset_cfg.get("dataset_path")
        dataset_type = dataset_cfg.get("dataset_type", "huggingface")
        dataset_split = dataset_cfg.get("dataset_split", "train")
        image_column = dataset_cfg.get("dataset_image_column", "image")

        # Chunk settings
        self.chunk_size = cfg.get("chunk_size", 1000)
        self.min_buffer = cfg.get("min_chunk_buffer", 10)
        self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)

        logger.debug(
            "Chunk size: %d, min_buffer: %d, buffer_multiplier: %d",
            self.chunk_size,
            self.min_buffer,
            self.buffer_multiplier,
        )

        # Initialize dataset loader
        if dataset_path:
            checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logger.debug("Checkpoint dir: %s", checkpoint_dir)

            self.dataset_loader = DatasetLoader(
                dataset_path=dataset_path,
                dataset_type=dataset_type,
                split=dataset_split,
                image_column=image_column,
                cache_dir=checkpoint_dir,
            )
            logger.debug("DatasetLoader initialized")

            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
            logger.debug("ChunkTracker initialized at %s", checkpoint_dir / "chunks.json")

            # Get all shards
            self.all_shards = self.dataset_loader.get_shard_list()
            logger.debug("All shards: %s", self.all_shards)

            # Restore existing state from chunk tracker
            self._restore_state(storage=storage)

            # Start background unit creation
            self.unit_creation_thread = threading.Thread(
                target=self._create_units_background, daemon=True
            )
            self.unit_creation_thread.start()
            logger.debug("Unit creation thread started")
        else:
            logger.error("No dataset_path provided in config")

    def _restore_state(self, storage: StorageManager) -> None:
        """Restore state from chunk tracker."""
        logger.debug("Restoring state from chunk tracker")
        if not self.chunk_tracker:
            return

        shards_summary = self.chunk_tracker.get_shards_summary()

        # Get all processed job_ids from storage
        all_processed_jobs = storage.get_all_processed_job_ids()

        with self.lock:
            for shard_name, shard_info in shards_summary.items():
                for chunk_state in shard_info["chunks"]:
                    # Calculate actual unprocessed ranges based on what's in storage
                    chunk_range = (
                        chunk_state.start_index,
                        chunk_state.start_index + chunk_state.chunk_size - 1,
                    )

                    # Get processed indices for this chunk
                    processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
                        chunk_state.chunk_id, all_processed_jobs
                    )

                    # Calculate unprocessed ranges
                    unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)

                    if unprocessed_ranges:
                        # Create work unit for unprocessed items
                        logger.debug(f"Creating WorkUnit for chunk {chunk_state}")
                        unit = WorkUnit(
                            unit_id=chunk_state.chunk_id,
                            chunk_id=chunk_state.chunk_id,
                            source_id=shard_name,
                            data={
                                "shard_url": chunk_state.shard_url,
                                "start_index": chunk_state.start_index,
                                "chunk_size": chunk_state.chunk_size,
                                "unprocessed_ranges": unprocessed_ranges,
                            },
                            metadata={
                                "shard_name": shard_name,
                                "chunk_index": chunk_state.start_index // self.chunk_size,
                            },
                        )

                        self.work_units[unit.unit_id] = unit
                        self.pending_units.append(unit.unit_id)

    def _create_units_background(self) -> None:
        """Background thread to create work units on demand."""
        logger.info("Starting work unit creation thread")

        shard_iter = iter(self.all_shards)
        current_shard_url = None
        current_shard_name = None
        current_shard_items = 0
        current_index = 0

        while not self.stop_creation.is_set():
            # Check if we need more units
            with self.lock:
                pending_count = len(self.pending_units)
                assigned_count = sum(len(units) for units in self.assigned_units.values())
                worker_count = max(1, len(self.assigned_units))

                target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
                units_needed = max(0, target_buffer - (pending_count + assigned_count))
                logger.debug(
                    "pending_count=%d assigned_count=%d worker_count=%d target_buffer=%d units_needed=%d",
                    pending_count,
                    assigned_count,
                    worker_count,
                    target_buffer,
                    units_needed,
                )

            if units_needed == 0:
                threading.Event().wait(5)
                continue

            # Create units as needed
            units_created = 0

            while units_created < units_needed and not self.stop_creation.is_set():
                # Load next shard if needed
                if current_shard_url is None or current_index >= current_shard_items:
                    try:
                        current_shard_url = next(shard_iter)
                        current_shard_name = Path(current_shard_url).stem

                        logger.debug("Loading shard: %s", current_shard_url)
                        # Count items in shard
                        current_shard_items = sum(
                            1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
                        )
                        logger.info(
                            f"Processing shard {current_shard_name} with {current_shard_items} items"
                        )
                        current_index = 0

                    except StopIteration:
                        logger.info("All shards processed")
                        break
                    except Exception as e:
                        logger.error("Error loading shard: %s", e)
                        break

                # Create work unit
                if current_shard_url and current_index < current_shard_items:
                    chunk_size = min(self.chunk_size, current_shard_items - current_index)
                    unit_id = f"{current_shard_name}:chunk:{current_index // self.chunk_size}"

                    with self.lock:
                        # Check if this unit already exists in work_units
                        if unit_id in self.work_units:
                            logger.debug(
                                f"Unit {unit_id} already exists in work_units, skipping creation"
                            )
                            current_index += self.chunk_size
                            continue

                        # Check if chunk is already completed or has no unprocessed items
                        if self.chunk_tracker:
                            chunk_state = self.chunk_tracker.chunks.get(unit_id)

                            if chunk_state:
                                # Check if completed
                                if chunk_state.status == "completed":
                                    logger.debug(f"Unit {unit_id} already completed, skipping")
                                    current_index += self.chunk_size
                                    continue

                                # Check if has unprocessed items
                                unprocessed_ranges = chunk_state.get_unprocessed_ranges()
                                if not unprocessed_ranges:
                                    logger.debug(
                                        f"Unit {unit_id} has no unprocessed items, skipping"
                                    )
                                    current_index += self.chunk_size
                                    continue

                                # If chunk exists but has unprocessed items, use those ranges
                                logger.debug(
                                    f"Existing chunk {unit_id} has unprocessed ranges: {unprocessed_ranges}"
                                )

                                unit = WorkUnit(
                                    unit_id=unit_id,
                                    chunk_id=unit_id,
                                    source_id=current_shard_name,
                                    data={
                                        "shard_url": current_shard_url,
                                        "start_index": current_index,
                                        "chunk_size": chunk_size,
                                        "unprocessed_ranges": [
                                            (
                                                r[0] + chunk_state.start_index,
                                                r[1] + chunk_state.start_index,
                                            )
                                            for r in unprocessed_ranges
                                        ],  # Convert relative to absolute
                                    },
                                    metadata={
                                        "shard_name": current_shard_name,
                                        "chunk_index": current_index // self.chunk_size,
                                    },
                                )
                            else:
                                # New chunk
                                logger.debug(
                                    "Creating new work unit: unit_id=%s shard=%s start_index=%d chunk_size=%d",
                                    unit_id,
                                    current_shard_name,
                                    current_index,
                                    chunk_size,
                                )

                                unit = WorkUnit(
                                    unit_id=unit_id,
                                    chunk_id=unit_id,
                                    source_id=current_shard_name,
                                    data={
                                        "shard_url": current_shard_url,
                                        "start_index": current_index,
                                        "chunk_size": chunk_size,
                                        "unprocessed_ranges": [
                                            (current_index, current_index + chunk_size - 1)
                                        ],
                                    },
                                    metadata={
                                        "shard_name": current_shard_name,
                                        "chunk_index": current_index // self.chunk_size,
                                    },
                                )
                        else:
                            # No chunk tracker, create normally
                            unit = WorkUnit(
                                unit_id=unit_id,
                                chunk_id=unit_id,
                                source_id=current_shard_name,
                                data={
                                    "shard_url": current_shard_url,
                                    "start_index": current_index,
                                    "chunk_size": chunk_size,
                                    "unprocessed_ranges": [
                                        (current_index, current_index + chunk_size - 1)
                                    ],
                                },
                                metadata={
                                    "shard_name": current_shard_name,
                                    "chunk_index": current_index // self.chunk_size,
                                },
                            )

                        self.work_units[unit_id] = unit
                        self.pending_units.append(unit_id)
                        logger.debug("Added work unit %s to pending_units", unit_id)

                        if self.chunk_tracker:
                            added_chunk = self.chunk_tracker.add_chunk(
                                unit_id,
                                current_shard_name,
                                current_shard_url,
                                current_index,
                                chunk_size,
                            )
                            if added_chunk:
                                logger.debug("Added chunk to chunk_tracker: %s", unit_id)
                            else:
                                logger.debug("Chunk already exists in chunk_tracker: %s", unit_id)

                        units_created += 1

                    current_index += self.chunk_size

            if units_created > 0:
                logger.debug(f"Created {units_created} work units")

    def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
        """Get available work units for a worker."""
        logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
        assigned = []

        with self.lock:
            # Get new units if needed
            while len(assigned) < count and self.pending_units:
                unit_id = self.pending_units.popleft()
                unit = self.work_units.get(unit_id)

                if unit:
                    self.assigned_units[worker_id].add(unit_id)
                    assigned.append(unit)
                    logger.debug("Assigning new unit %s to worker %s", unit_id, worker_id)

                    if self.chunk_tracker:
                        self.chunk_tracker.mark_assigned(unit_id, worker_id)

        logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
        return assigned

    def _has_unprocessed_items(self, unit: WorkUnit) -> bool:
        """Check if a work unit has unprocessed items."""
        if not self.chunk_tracker:
            logger.debug("No chunk_tracker, assuming unit %s has unprocessed items", unit.unit_id)
            return True

        chunk_info = self.chunk_tracker.get_chunk_with_unprocessed_items(unit.unit_id)
        has_unprocessed = bool(chunk_info and chunk_info.get("unprocessed_ranges"))
        logger.debug("Unit %s has unprocessed items: %s", unit.unit_id, has_unprocessed)
        return has_unprocessed

    def mark_completed(self, unit_id: str, worker_id: str) -> None:
        """Mark a work unit as completed."""
        logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
        with self.lock:
            if unit_id in self.work_units:
                self.assigned_units[worker_id].discard(unit_id)
                logger.debug(
                    "Removed unit %s from assigned_units for worker %s", unit_id, worker_id
                )

                if self.chunk_tracker:
                    self.chunk_tracker.mark_completed(unit_id)
                    logger.debug("Marked unit %s as completed in chunk_tracker", unit_id)

    def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
        """Mark a work unit as failed."""
        logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
        with self.lock:
            if unit_id in self.work_units:
                self.assigned_units[worker_id].discard(unit_id)
                self.pending_units.append(unit_id)
                logger.debug("Returned unit %s to pending_units", unit_id)

                if self.chunk_tracker:
                    self.chunk_tracker.mark_failed(unit_id)
                    logger.debug("Marked unit %s as failed in chunk_tracker", unit_id)

    def release_assignments(self, worker_id: str) -> None:
        """Release all assignments for a disconnected worker."""
        logger.debug("Releasing assignments for worker %s", worker_id)
        with self.lock:
            unit_ids = list(self.assigned_units.get(worker_id, []))

            for unit_id in unit_ids:
                if unit_id in self.work_units:
                    unit = self.work_units[unit_id]

                    # Update unprocessed ranges based on what's been processed
                    if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
                        chunk_state = self.chunk_tracker.chunks[unit_id]
                        unprocessed_ranges = chunk_state.get_unprocessed_ranges()

                        # Convert relative ranges back to absolute
                        absolute_ranges = []
                        for start, end in unprocessed_ranges:
                            abs_start = chunk_state.start_index + start
                            abs_end = chunk_state.start_index + end
                            absolute_ranges.append((abs_start, abs_end))

                        # Update the work unit's data
                        unit.data["unprocessed_ranges"] = absolute_ranges

                        logger.debug(
                            f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
                        )

                    self.pending_units.append(unit_id)
                    logger.debug("Returned unit %s to pending_units", unit_id)

            if worker_id in self.assigned_units:
                del self.assigned_units[worker_id]
                logger.debug("Deleted worker %s from assigned_units", worker_id)

        if self.chunk_tracker:
            self.chunk_tracker.release_worker_chunks(worker_id)
            logger.debug("Released worker %s chunks in chunk_tracker", worker_id)

    def _subtract_ranges(
        self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
    ) -> List[Tuple[int, int]]:
        """Subtract processed ranges from total ranges."""
        if not processed_ranges:
            return total_ranges

        # Create a set of all processed indices
        processed_indices = set()
        for start, end in processed_ranges:
            processed_indices.update(range(start, end + 1))

        # Find unprocessed ranges
        unprocessed_ranges = []
        for start, end in total_ranges:
            current_start = None
            for i in range(start, end + 1):
                if i not in processed_indices:
                    if current_start is None:
                        current_start = i
                else:
                    if current_start is not None:
                        unprocessed_ranges.append((current_start, i - 1))
                        current_start = None

            if current_start is not None:
                unprocessed_ranges.append((current_start, end))

        return unprocessed_ranges

    def get_stats(self) -> Dict[str, Any]:
        """Get processor statistics."""
        with self.lock:
            stats = {
                "total_units": len(self.work_units),
                "pending_units": len(self.pending_units),
                "assigned_units": sum(len(units) for units in self.assigned_units.values()),
                "total_shards": len(self.all_shards),
                "workers": len(self.assigned_units),
            }
            logger.debug("Stats: %s", stats)
            return stats

    def handle_result(self, result: WorkResult) -> Dict[str, Any]:
        """Handle WebDataset-specific result processing."""
        # logger.debug("Handling result for unit %s", result.unit_id)
        base_result = super().handle_result(result)

        # Track processed items if we have chunk tracker
        if self.chunk_tracker:
            if "item_indices" not in result.metadata:
                result.metadata["item_indices"] = [result.metadata.get("_item_index")]
            indices = result.metadata["item_indices"]
            logger.debug("Result metadata item_indices: %s", indices)

            # Group consecutive indices into ranges
            if indices:
                indices.sort()
                ranges = []
                start = indices[0]
                end = indices[0]

                for i in range(1, len(indices)):
                    if indices[i] == end + 1:
                        end = indices[i]
                    else:
                        ranges.append((start, end))
                        start = indices[i]
                        end = indices[i]

                ranges.append((start, end))

                # Mark ranges as processed
                for start_idx, end_idx in ranges:
                    logger.debug(f"Marking chunk as processed: {result.to_repr()}")
                    self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
                    logger.debug(
                        "Marked items processed for unit %s: %d-%d",
                        result.unit_id,
                        start_idx,
                        end_idx,
                    )
        else:
            logger.error(
                f"No chunk tracker? {self.chunk_tracker} or no item_indices in {result.metadata}"
            )

        return base_result

    def update_from_storage(self, processed_job_ids: Set[str]) -> None:
        """Update work units based on what's been processed."""
        logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")

        with self.lock:
            for unit_id, unit in self.work_units.items():
                # Extract chunk info from unit
                start_index = unit.data["start_index"]
                chunk_size = unit.data["chunk_size"]
                shard_name = unit.metadata["shard_name"]
                chunk_index = unit.metadata["chunk_index"]

                # Find processed indices for this chunk
                processed_indices = []
                for job_id in processed_job_ids:
                    # Parse job_id format: "data-0000:chunk:0:idx:42"
                    parts = job_id.split(":")
                    if (
                        len(parts) == 5
                        and parts[0] == shard_name
                        and parts[1] == "chunk"
                        and int(parts[2]) == chunk_index
                        and parts[3] == "idx"
                    ):

                        idx = int(parts[4])
                        if start_index <= idx < start_index + chunk_size:
                            processed_indices.append(idx)

                if processed_indices:
                    # Convert to ranges
                    processed_indices.sort()
                    processed_ranges = []
                    start = processed_indices[0]
                    end = processed_indices[0]

                    for idx in processed_indices[1:]:
                        if idx == end + 1:
                            end = idx
                        else:
                            processed_ranges.append((start, end))
                            start = idx
                            end = idx

                    processed_ranges.append((start, end))

                    # Calculate unprocessed ranges
                    total_range = [(start_index, start_index + chunk_size - 1)]
                    unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)

                    # Update unit
                    unit.data["unprocessed_ranges"] = unprocessed_ranges

                    logger.debug(
                        f"Updated unit {unit_id}: {len(processed_indices)} processed, "
                        f"unprocessed ranges: {unprocessed_ranges}"
                    )


class WebDatasetWorkerProcessor(WorkerProcessor):
    """Worker processor for WebDataset shards."""

    def __init__(self):
        logger.debug("Initializing WebDatasetWorkerProcessor")
        self.dataset_loader: Optional[DatasetLoader] = None
        self.dataset_config: Dict[str, Any] = {}
        self.dataset_name: Optional[str] = None

    def initialize(self, config: ProcessorConfig) -> None:
        """Initialize WebDataset processor."""
        logger.debug("Initializing worker with config: %s", config.config)
        cfg = config.config["dataset"]

        # Store config
        self.dataset_config = cfg

        # Initialize dataset loader
        dataset_path = cfg.get("dataset_path")
        self.dataset_path = dataset_path
        dataset_type = cfg.get("dataset_type", "huggingface")
        dataset_split = cfg.get("dataset_split", "train")
        image_column = cfg.get("dataset_image_column", "image")

        if dataset_path:
            self.dataset_loader = DatasetLoader(
                dataset_path=dataset_path,
                dataset_type=dataset_type,
                split=dataset_split,
                image_column=image_column,
            )
            logger.debug("DatasetLoader initialized for worker")
        else:
            logger.error("No dataset_path provided in worker config")

    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        """Process a WebDataset chunk, yielding items to be captioned."""
        logger.debug("Processing unit: %s", unit.unit_id)
        if not self.dataset_loader:
            logger.error("Dataset loader not initialized")
            return

        shard_name = unit.metadata["shard_name"]
        chunk_index = unit.metadata["chunk_index"]
        shard_url = unit.data["shard_url"]
        start_index = unit.data["start_index"]
        chunk_size = unit.data["chunk_size"]
        unprocessed_ranges = unit.data.get(
            "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
        )

        logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")

        # Create set of indices to process
        indices_to_process = set()
        for start, end in unprocessed_ranges:
            indices_to_process.update(range(start, end + 1))
        logger.debug("Indices to process: %s", indices_to_process)

        processed_indices = []

        # Iterate through shard
        for idx, (key, url, image_data, metadata) in enumerate(
            self._iterate_shard_with_metadata(shard_url)
        ):
            # Skip if not in our chunk range
            if idx < start_index or idx >= start_index + chunk_size:
                # logger.debug(f"Skipping idx={idx} not in chunk range")
                continue

            # Skip if already processed
            if idx not in indices_to_process:
                logger.debug(f"Skipping idx={idx} already processed")
                continue

            try:
                # Load image
                image = Image.open(io.BytesIO(image_data))
                job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"

                # Clean metadata - remove sensitive and redundant fields
                clean_metadata = {
                    k: v
                    for k, v in metadata.items()
                    if k not in ["url", "_shard_url", "shard_name"]  # Remove these fields
                }

                # Add only necessary index information
                clean_metadata.update(
                    {
                        "_item_index": idx,
                        "_chunk_relative_index": idx - start_index,
                        "_job_id": job_id,
                    }
                )

                # Prepare item for captioning
                # logger.debug("Yielding item idx=%d key=%s", idx, key)
                yield {
                    "image": image,
                    "item_key": key,
                    "item_index": idx,
                    "metadata": clean_metadata,
                    "job_id": job_id,
                }

                processed_indices.append(idx)

            except Exception as e:
                logger.error(f"Error processing item {key}: {e}")

        # Store processed indices in context for result preparation
        context["_processed_indices"] = processed_indices
        logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)

    def _iterate_shard_with_metadata(
        self, shard_url: str
    ) -> Iterator[Tuple[str, str, bytes, Dict]]:
        """Iterate through a shard with metadata."""
        logger.debug("Iterating shard with metadata: %s", shard_url)

        if not self.dataset_loader:
            logger.error("Dataset loader not initialized")
            return

        # Use the DatasetLoader that returns full samples
        for sample in self.dataset_loader.iterate_shard(shard_url):
            if not isinstance(sample, dict):
                logger.warning("Unexpected sample format: %s", type(sample))
                continue

            key = sample.get("__key__", "unknown")
            url = sample.get("__url__", "")  # Don't use shard_url as default

            # Find image data
            image_data = None
            image_ext = None
            for ext in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]:
                if ext in sample:
                    image_data = sample[ext]
                    image_ext = ext
                    break

            if not image_data:
                logger.debug(
                    "No image data found for item key=%s, available keys: %s",
                    key,
                    list(sample.keys()),
                )
                continue

            # Extract metadata (all non-system and non-image keys)
            metadata = {
                k: v
                for k, v in sample.items()
                if not k.startswith("__") and k not in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]
            }

            # Add image format but not URLs
            if image_ext:
                metadata["_image_format"] = image_ext

            yield key, url, image_data, metadata

    def prepare_result(
        self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
    ) -> WorkResult:
        """Prepare WebDataset-specific result."""
        logger.debug("Preparing result for unit %s", unit.unit_id)
        result = super().prepare_result(unit, outputs, processing_time_ms)

        # Add processed indices to metadata if available
        if outputs and "_processed_indices" in outputs[0].get("metadata", {}):
            result.metadata["item_indices"] = outputs[0]["metadata"]["_processed_indices"]
            logger.debug(
                "Added item_indices to result metadata: %s", result.metadata["item_indices"]
            )

        return result

    def get_dataset_info(self) -> Dict[str, Any]:
        """Get dataset information."""
        if self.dataset_loader:
            info = self.dataset_loader.get_dataset_info()
            logger.debug("Dataset info: %s", info)
            return info
        info = {
            "dataset_path": self.dataset_config.get("dataset_path"),
            "dataset_type": self.dataset_config.get("type", "huggingface"),
        }
        logger.debug("Dataset info (no loader): %s", info)
        return info