caption-flow 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
@@ -0,0 +1,782 @@
+ """WebDataset processor implementation."""
+
+ import logging
+ import threading
+ from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
+ from collections import deque, defaultdict
+ from pathlib import Path
+ import json
+ import io
+ from datetime import datetime
+ from PIL import Image
+ from caption_flow.storage import StorageManager
+
+ from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
+ from ..utils import DatasetLoader, ChunkTracker
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)
+
+
+ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
+     """Orchestrator processor for WebDataset shards."""
+
+     def __init__(self):
+         logger.debug("Initializing WebDatasetOrchestratorProcessor")
+         self.dataset_loader: Optional[DatasetLoader] = None
+         self.chunk_tracker: Optional[ChunkTracker] = None
+         self.chunk_size: int = 1000
+
+         # Work unit management
+         self.work_units: Dict[str, WorkUnit] = {}
+         self.pending_units: Deque[str] = deque()
+         self.assigned_units: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> unit_ids
+         self.lock = threading.Lock()
+
+         # Shard processing state
+         self.all_shards: List[str] = []
+         self.current_shard_index = 0
+         self.current_shard_items = 0
+
+         # Background thread for creating work units
+         self.unit_creation_thread: Optional[threading.Thread] = None
+         self.stop_creation = threading.Event()
+
+     def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
+         """Initialize WebDataset processor."""
+         logger.debug("Initializing orchestrator with config: %s", config.config)
+         cfg = config.config
+
+         # Dataset configuration
+         dataset_cfg = cfg.get("dataset", {})
+         dataset_path = dataset_cfg.get("dataset_path")
+         dataset_type = dataset_cfg.get("dataset_type", "huggingface")
+         dataset_split = dataset_cfg.get("dataset_split", "train")
+         image_column = dataset_cfg.get("dataset_image_column", "image")
+
+         # Chunk settings
+         self.chunk_size = cfg.get("chunk_size", 1000)
+         self.min_buffer = cfg.get("min_chunk_buffer", 10)
+         self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
+
+         logger.debug(
+             "Chunk size: %d, min_buffer: %d, buffer_multiplier: %d",
+             self.chunk_size,
+             self.min_buffer,
+             self.buffer_multiplier,
+         )
+
+         # Initialize dataset loader
+         if dataset_path:
+             checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
+             checkpoint_dir.mkdir(parents=True, exist_ok=True)
+             logger.debug("Checkpoint dir: %s", checkpoint_dir)
+
+             self.dataset_loader = DatasetLoader(
+                 dataset_path=dataset_path,
+                 dataset_type=dataset_type,
+                 split=dataset_split,
+                 image_column=image_column,
+                 cache_dir=checkpoint_dir,
+             )
+             logger.debug("DatasetLoader initialized")
+
+             self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
+             logger.debug("ChunkTracker initialized at %s", checkpoint_dir / "chunks.json")
+
+             # Get all shards
+             self.all_shards = self.dataset_loader.get_shard_list()
+             logger.debug("All shards: %s", self.all_shards)
+
+             # Restore existing state from chunk tracker
+             self._restore_state(storage=storage)
+
+             # Start background unit creation
+             self.unit_creation_thread = threading.Thread(
+                 target=self._create_units_background, daemon=True
+             )
+             self.unit_creation_thread.start()
+             logger.debug("Unit creation thread started")
+         else:
+             logger.error("No dataset_path provided in config")
+
+     def _restore_state(self, storage: StorageManager) -> None:
+         """Restore state from chunk tracker."""
+         logger.debug("Restoring state from chunk tracker")
+         if not self.chunk_tracker:
+             return
+
+         shards_summary = self.chunk_tracker.get_shards_summary()
+
+         # Get all processed job_ids from storage
+         all_processed_jobs = storage.get_all_processed_job_ids()
+
+         with self.lock:
+             for shard_name, shard_info in shards_summary.items():
+                 for chunk_state in shard_info["chunks"]:
+                     # Calculate actual unprocessed ranges based on what's in storage
+                     chunk_range = (
+                         chunk_state.start_index,
+                         chunk_state.start_index + chunk_state.chunk_size - 1,
+                     )
+
+                     # Get processed indices for this chunk
+                     processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
+                         chunk_state.chunk_id, all_processed_jobs
+                     )
+
+                     # Calculate unprocessed ranges
+                     unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
+
+                     if unprocessed_ranges:
+                         # Create work unit for unprocessed items
+                         logger.debug(f"Creating WorkUnit for chunk {chunk_state}")
+                         unit = WorkUnit(
+                             unit_id=chunk_state.chunk_id,
+                             chunk_id=chunk_state.chunk_id,
+                             source_id=shard_name,
+                             data={
+                                 "shard_url": chunk_state.shard_url,
+                                 "start_index": chunk_state.start_index,
+                                 "chunk_size": chunk_state.chunk_size,
+                                 "unprocessed_ranges": unprocessed_ranges,
+                             },
+                             metadata={
+                                 "shard_name": shard_name,
+                                 "chunk_index": chunk_state.start_index // self.chunk_size,
+                             },
+                         )
+
+                         self.work_units[unit.unit_id] = unit
+                         self.pending_units.append(unit.unit_id)
+
+     def _create_units_background(self) -> None:
+         """Background thread to create work units on demand."""
+         logger.info("Starting work unit creation thread")
+
+         shard_iter = iter(self.all_shards)
+         current_shard_url = None
+         current_shard_name = None
+         current_shard_items = 0
+         current_index = 0
+
+         while not self.stop_creation.is_set():
+             # Check if we need more units
+             with self.lock:
+                 pending_count = len(self.pending_units)
+                 assigned_count = sum(len(units) for units in self.assigned_units.values())
+                 worker_count = max(1, len(self.assigned_units))
+
+                 target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
+                 units_needed = max(0, target_buffer - (pending_count + assigned_count))
+                 logger.debug(
+                     "pending_count=%d assigned_count=%d worker_count=%d target_buffer=%d units_needed=%d",
+                     pending_count,
+                     assigned_count,
+                     worker_count,
+                     target_buffer,
+                     units_needed,
+                 )
+
+             if units_needed == 0:
+                 threading.Event().wait(5)
+                 continue
+
+             # Create units as needed
+             units_created = 0
+
+             while units_created < units_needed and not self.stop_creation.is_set():
+                 # Load next shard if needed
+                 if current_shard_url is None or current_index >= current_shard_items:
+                     try:
+                         current_shard_url = next(shard_iter)
+                         current_shard_name = Path(current_shard_url).stem
+
+                         logger.debug("Loading shard: %s", current_shard_url)
+                         # Count items in shard
+                         current_shard_items = sum(
+                             1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
+                         )
+                         logger.info(
+                             f"Processing shard {current_shard_name} with {current_shard_items} items"
+                         )
+                         current_index = 0
+
+                     except StopIteration:
+                         logger.info("All shards processed")
+                         break
+                     except Exception as e:
+                         logger.error("Error loading shard: %s", e)
+                         break
+
+                 # Create work unit
+                 if current_shard_url and current_index < current_shard_items:
+                     chunk_size = min(self.chunk_size, current_shard_items - current_index)
+                     unit_id = f"{current_shard_name}:chunk:{current_index // self.chunk_size}"
+
+                     with self.lock:
+                         # Check if this unit already exists in work_units
+                         if unit_id in self.work_units:
+                             logger.debug(
+                                 f"Unit {unit_id} already exists in work_units, skipping creation"
+                             )
+                             current_index += self.chunk_size
+                             continue
+
+                         # Check if chunk is already completed or has no unprocessed items
+                         if self.chunk_tracker:
+                             chunk_state = self.chunk_tracker.chunks.get(unit_id)
+
+                             if chunk_state:
+                                 # Check if completed
+                                 if chunk_state.status == "completed":
+                                     logger.debug(f"Unit {unit_id} already completed, skipping")
+                                     current_index += self.chunk_size
+                                     continue
+
+                                 # Check if has unprocessed items
+                                 unprocessed_ranges = chunk_state.get_unprocessed_ranges()
+                                 if not unprocessed_ranges:
+                                     logger.debug(
+                                         f"Unit {unit_id} has no unprocessed items, skipping"
+                                     )
+                                     current_index += self.chunk_size
+                                     continue
+
+                                 # If chunk exists but has unprocessed items, use those ranges
+                                 logger.debug(
+                                     f"Existing chunk {unit_id} has unprocessed ranges: {unprocessed_ranges}"
+                                 )
+
+                                 unit = WorkUnit(
+                                     unit_id=unit_id,
+                                     chunk_id=unit_id,
+                                     source_id=current_shard_name,
+                                     data={
+                                         "shard_url": current_shard_url,
+                                         "start_index": current_index,
+                                         "chunk_size": chunk_size,
+                                         "unprocessed_ranges": [
+                                             (
+                                                 r[0] + chunk_state.start_index,
+                                                 r[1] + chunk_state.start_index,
+                                             )
+                                             for r in unprocessed_ranges
+                                         ],  # Convert relative to absolute
+                                     },
+                                     metadata={
+                                         "shard_name": current_shard_name,
+                                         "chunk_index": current_index // self.chunk_size,
+                                     },
+                                 )
+                             else:
+                                 # New chunk
+                                 logger.debug(
+                                     "Creating new work unit: unit_id=%s shard=%s start_index=%d chunk_size=%d",
+                                     unit_id,
+                                     current_shard_name,
+                                     current_index,
+                                     chunk_size,
+                                 )
+
+                                 unit = WorkUnit(
+                                     unit_id=unit_id,
+                                     chunk_id=unit_id,
+                                     source_id=current_shard_name,
+                                     data={
+                                         "shard_url": current_shard_url,
+                                         "start_index": current_index,
+                                         "chunk_size": chunk_size,
+                                         "unprocessed_ranges": [
+                                             (current_index, current_index + chunk_size - 1)
+                                         ],
+                                     },
+                                     metadata={
+                                         "shard_name": current_shard_name,
+                                         "chunk_index": current_index // self.chunk_size,
+                                     },
+                                 )
+                         else:
+                             # No chunk tracker, create normally
+                             unit = WorkUnit(
+                                 unit_id=unit_id,
+                                 chunk_id=unit_id,
+                                 source_id=current_shard_name,
+                                 data={
+                                     "shard_url": current_shard_url,
+                                     "start_index": current_index,
+                                     "chunk_size": chunk_size,
+                                     "unprocessed_ranges": [
+                                         (current_index, current_index + chunk_size - 1)
+                                     ],
+                                 },
+                                 metadata={
+                                     "shard_name": current_shard_name,
+                                     "chunk_index": current_index // self.chunk_size,
+                                 },
+                             )
+
+                         self.work_units[unit_id] = unit
+                         self.pending_units.append(unit_id)
+                         logger.debug("Added work unit %s to pending_units", unit_id)
+
+                         if self.chunk_tracker:
+                             added_chunk = self.chunk_tracker.add_chunk(
+                                 unit_id,
+                                 current_shard_name,
+                                 current_shard_url,
+                                 current_index,
+                                 chunk_size,
+                             )
+                             if added_chunk:
+                                 logger.debug("Added chunk to chunk_tracker: %s", unit_id)
+                             else:
+                                 logger.debug("Chunk already exists in chunk_tracker: %s", unit_id)
+
+                         units_created += 1
+
+                 current_index += self.chunk_size
+
+             if units_created > 0:
+                 logger.debug(f"Created {units_created} work units")
+
+     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
+         """Get available work units for a worker."""
+         logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
+         assigned = []
+
+         with self.lock:
+             # Get new units if needed
+             while len(assigned) < count and self.pending_units:
+                 unit_id = self.pending_units.popleft()
+                 unit = self.work_units.get(unit_id)
+
+                 if unit:
+                     self.assigned_units[worker_id].add(unit_id)
+                     assigned.append(unit)
+                     logger.debug("Assigning new unit %s to worker %s", unit_id, worker_id)
+
+                     if self.chunk_tracker:
+                         self.chunk_tracker.mark_assigned(unit_id, worker_id)
+
+         logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
+         return assigned
+
+     def _has_unprocessed_items(self, unit: WorkUnit) -> bool:
+         """Check if a work unit has unprocessed items."""
+         if not self.chunk_tracker:
+             logger.debug("No chunk_tracker, assuming unit %s has unprocessed items", unit.unit_id)
+             return True
+
+         chunk_info = self.chunk_tracker.get_chunk_with_unprocessed_items(unit.unit_id)
+         has_unprocessed = bool(chunk_info and chunk_info.get("unprocessed_ranges"))
+         logger.debug("Unit %s has unprocessed items: %s", unit.unit_id, has_unprocessed)
+         return has_unprocessed
+
+     def mark_completed(self, unit_id: str, worker_id: str) -> None:
+         """Mark a work unit as completed."""
+         logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
+         with self.lock:
+             if unit_id in self.work_units:
+                 self.assigned_units[worker_id].discard(unit_id)
+                 logger.debug(
+                     "Removed unit %s from assigned_units for worker %s", unit_id, worker_id
+                 )
+
+                 if self.chunk_tracker:
+                     self.chunk_tracker.mark_completed(unit_id)
+                     logger.debug("Marked unit %s as completed in chunk_tracker", unit_id)
+
+     def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
+         """Mark a work unit as failed."""
+         logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
+         with self.lock:
+             if unit_id in self.work_units:
+                 self.assigned_units[worker_id].discard(unit_id)
+                 self.pending_units.append(unit_id)
+                 logger.debug("Returned unit %s to pending_units", unit_id)
+
+                 if self.chunk_tracker:
+                     self.chunk_tracker.mark_failed(unit_id)
+                     logger.debug("Marked unit %s as failed in chunk_tracker", unit_id)
+
+     def release_assignments(self, worker_id: str) -> None:
+         """Release all assignments for a disconnected worker."""
+         logger.debug("Releasing assignments for worker %s", worker_id)
+         with self.lock:
+             unit_ids = list(self.assigned_units.get(worker_id, []))
+
+             for unit_id in unit_ids:
+                 if unit_id in self.work_units:
+                     unit = self.work_units[unit_id]
+
+                     # Update unprocessed ranges based on what's been processed
+                     if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
+                         chunk_state = self.chunk_tracker.chunks[unit_id]
+                         unprocessed_ranges = chunk_state.get_unprocessed_ranges()
+
+                         # Convert relative ranges back to absolute
+                         absolute_ranges = []
+                         for start, end in unprocessed_ranges:
+                             abs_start = chunk_state.start_index + start
+                             abs_end = chunk_state.start_index + end
+                             absolute_ranges.append((abs_start, abs_end))
+
+                         # Update the work unit's data
+                         unit.data["unprocessed_ranges"] = absolute_ranges
+
+                         logger.debug(
+                             f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
+                         )
+
+                     self.pending_units.append(unit_id)
+                     logger.debug("Returned unit %s to pending_units", unit_id)
+
+             if worker_id in self.assigned_units:
+                 del self.assigned_units[worker_id]
+                 logger.debug("Deleted worker %s from assigned_units", worker_id)
+
+             if self.chunk_tracker:
+                 self.chunk_tracker.release_worker_chunks(worker_id)
+                 logger.debug("Released worker %s chunks in chunk_tracker", worker_id)
+
+     def _subtract_ranges(
+         self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
+     ) -> List[Tuple[int, int]]:
+         """Subtract processed ranges from total ranges."""
+         if not processed_ranges:
+             return total_ranges
+
+         # Create a set of all processed indices
+         processed_indices = set()
+         for start, end in processed_ranges:
+             processed_indices.update(range(start, end + 1))
+
+         # Find unprocessed ranges
+         unprocessed_ranges = []
+         for start, end in total_ranges:
+             current_start = None
+             for i in range(start, end + 1):
+                 if i not in processed_indices:
+                     if current_start is None:
+                         current_start = i
+                 else:
+                     if current_start is not None:
+                         unprocessed_ranges.append((current_start, i - 1))
+                         current_start = None
+
+             if current_start is not None:
+                 unprocessed_ranges.append((current_start, end))
+
+         return unprocessed_ranges
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get processor statistics."""
+         with self.lock:
+             stats = {
+                 "total_units": len(self.work_units),
+                 "pending_units": len(self.pending_units),
+                 "assigned_units": sum(len(units) for units in self.assigned_units.values()),
+                 "total_shards": len(self.all_shards),
+                 "workers": len(self.assigned_units),
+             }
+             logger.debug("Stats: %s", stats)
+             return stats
+
+     def handle_result(self, result: WorkResult) -> Dict[str, Any]:
+         """Handle WebDataset-specific result processing."""
+         # logger.debug("Handling result for unit %s", result.unit_id)
+         base_result = super().handle_result(result)
+
+         # Track processed items if we have chunk tracker
+         if self.chunk_tracker:
+             if "item_indices" not in result.metadata:
+                 result.metadata["item_indices"] = [result.metadata.get("_item_index")]
+             indices = result.metadata["item_indices"]
+             logger.debug("Result metadata item_indices: %s", indices)
+
+             # Group consecutive indices into ranges
+             if indices:
+                 indices.sort()
+                 ranges = []
+                 start = indices[0]
+                 end = indices[0]
+
+                 for i in range(1, len(indices)):
+                     if indices[i] == end + 1:
+                         end = indices[i]
+                     else:
+                         ranges.append((start, end))
+                         start = indices[i]
+                         end = indices[i]
+
+                 ranges.append((start, end))
+
+                 # Mark ranges as processed
+                 for start_idx, end_idx in ranges:
+                     logger.debug(f"Marking chunk as processed: {result.to_repr()}")
+                     self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
+                     logger.debug(
+                         "Marked items processed for unit %s: %d-%d",
+                         result.unit_id,
+                         start_idx,
+                         end_idx,
+                     )
+             else:
+                 logger.error(
+                     f"No chunk tracker? {self.chunk_tracker} or no item_indices in {result.metadata}"
+                 )
+
+         return base_result
+
+     def update_from_storage(self, processed_job_ids: Set[str]) -> None:
+         """Update work units based on what's been processed."""
+         logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")
+
+         with self.lock:
+             for unit_id, unit in self.work_units.items():
+                 # Extract chunk info from unit
+                 start_index = unit.data["start_index"]
+                 chunk_size = unit.data["chunk_size"]
+                 shard_name = unit.metadata["shard_name"]
+                 chunk_index = unit.metadata["chunk_index"]
+
+                 # Find processed indices for this chunk
+                 processed_indices = []
+                 for job_id in processed_job_ids:
+                     # Parse job_id format: "data-0000:chunk:0:idx:42"
+                     parts = job_id.split(":")
+                     if (
+                         len(parts) == 5
+                         and parts[0] == shard_name
+                         and parts[1] == "chunk"
+                         and int(parts[2]) == chunk_index
+                         and parts[3] == "idx"
+                     ):
+
+                         idx = int(parts[4])
+                         if start_index <= idx < start_index + chunk_size:
+                             processed_indices.append(idx)
+
+                 if processed_indices:
+                     # Convert to ranges
+                     processed_indices.sort()
+                     processed_ranges = []
+                     start = processed_indices[0]
+                     end = processed_indices[0]
+
+                     for idx in processed_indices[1:]:
+                         if idx == end + 1:
+                             end = idx
+                         else:
+                             processed_ranges.append((start, end))
+                             start = idx
+                             end = idx
+
+                     processed_ranges.append((start, end))
+
+                     # Calculate unprocessed ranges
+                     total_range = [(start_index, start_index + chunk_size - 1)]
+                     unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
+
+                     # Update unit
+                     unit.data["unprocessed_ranges"] = unprocessed_ranges
+
+                     logger.debug(
+                         f"Updated unit {unit_id}: {len(processed_indices)} processed, "
+                         f"unprocessed ranges: {unprocessed_ranges}"
+                     )
+
+
+ class WebDatasetWorkerProcessor(WorkerProcessor):
+     """Worker processor for WebDataset shards."""
+
+     def __init__(self):
+         logger.debug("Initializing WebDatasetWorkerProcessor")
+         self.dataset_loader: Optional[DatasetLoader] = None
+         self.dataset_config: Dict[str, Any] = {}
+         self.dataset_name: Optional[str] = None
+
+     def initialize(self, config: ProcessorConfig) -> None:
+         """Initialize WebDataset processor."""
+         logger.debug("Initializing worker with config: %s", config.config)
+         cfg = config.config["dataset"]
+
+         # Store config
+         self.dataset_config = cfg
+
+         # Initialize dataset loader
+         dataset_path = cfg.get("dataset_path")
+         self.dataset_path = dataset_path
+         dataset_type = cfg.get("dataset_type", "huggingface")
+         dataset_split = cfg.get("dataset_split", "train")
+         image_column = cfg.get("dataset_image_column", "image")
+
+         if dataset_path:
+             self.dataset_loader = DatasetLoader(
+                 dataset_path=dataset_path,
+                 dataset_type=dataset_type,
+                 split=dataset_split,
+                 image_column=image_column,
+             )
+             logger.debug("DatasetLoader initialized for worker")
+         else:
+             logger.error("No dataset_path provided in worker config")
+
+     def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+         """Process a WebDataset chunk, yielding items to be captioned."""
+         logger.debug("Processing unit: %s", unit.unit_id)
+         if not self.dataset_loader:
+             logger.error("Dataset loader not initialized")
+             return
+
+         shard_name = unit.metadata["shard_name"]
+         chunk_index = unit.metadata["chunk_index"]
+         shard_url = unit.data["shard_url"]
+         start_index = unit.data["start_index"]
+         chunk_size = unit.data["chunk_size"]
+         unprocessed_ranges = unit.data.get(
+             "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
+         )
+
+         logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")
+
+         # Create set of indices to process
+         indices_to_process = set()
+         for start, end in unprocessed_ranges:
+             indices_to_process.update(range(start, end + 1))
+         logger.debug("Indices to process: %s", indices_to_process)
+
+         processed_indices = []
+
+         # Iterate through shard
+         for idx, (key, url, image_data, metadata) in enumerate(
+             self._iterate_shard_with_metadata(shard_url)
+         ):
+             # Skip if not in our chunk range
+             if idx < start_index or idx >= start_index + chunk_size:
+                 # logger.debug(f"Skipping idx={idx} not in chunk range")
+                 continue
+
+             # Skip if already processed
+             if idx not in indices_to_process:
+                 logger.debug(f"Skipping idx={idx} already processed")
+                 continue
+
+             try:
+                 # Load image
+                 image = Image.open(io.BytesIO(image_data))
+                 job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
+
+                 # Clean metadata - remove sensitive and redundant fields
+                 clean_metadata = {
+                     k: v
+                     for k, v in metadata.items()
+                     if k not in ["url", "_shard_url", "shard_name"]  # Remove these fields
+                 }
+
+                 # Add only necessary index information
+                 clean_metadata.update(
+                     {
+                         "_item_index": idx,
+                         "_chunk_relative_index": idx - start_index,
+                         "_job_id": job_id,
+                     }
+                 )
+
+                 # Prepare item for captioning
+                 # logger.debug("Yielding item idx=%d key=%s", idx, key)
+                 yield {
+                     "image": image,
+                     "item_key": key,
+                     "item_index": idx,
+                     "metadata": clean_metadata,
+                     "job_id": job_id,
+                 }
+
+                 processed_indices.append(idx)
+
+             except Exception as e:
+                 logger.error(f"Error processing item {key}: {e}")
+
+         # Store processed indices in context for result preparation
+         context["_processed_indices"] = processed_indices
+         logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)
+
+     def _iterate_shard_with_metadata(
+         self, shard_url: str
+     ) -> Iterator[Tuple[str, str, bytes, Dict]]:
+         """Iterate through a shard with metadata."""
+         logger.debug("Iterating shard with metadata: %s", shard_url)
+
+         if not self.dataset_loader:
+             logger.error("Dataset loader not initialized")
+             return
+
+         # Use the DatasetLoader that returns full samples
+         for sample in self.dataset_loader.iterate_shard(shard_url):
+             if not isinstance(sample, dict):
+                 logger.warning("Unexpected sample format: %s", type(sample))
+                 continue
+
+             key = sample.get("__key__", "unknown")
+             url = sample.get("__url__", "")  # Don't use shard_url as default
+
+             # Find image data
+             image_data = None
+             image_ext = None
+             for ext in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]:
+                 if ext in sample:
+                     image_data = sample[ext]
+                     image_ext = ext
+                     break
+
+             if not image_data:
+                 logger.debug(
+                     "No image data found for item key=%s, available keys: %s",
+                     key,
+                     list(sample.keys()),
+                 )
+                 continue
+
+             # Extract metadata (all non-system and non-image keys)
+             metadata = {
+                 k: v
+                 for k, v in sample.items()
+                 if not k.startswith("__") and k not in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]
+             }
+
+             # Add image format but not URLs
+             if image_ext:
+                 metadata["_image_format"] = image_ext
+
+             yield key, url, image_data, metadata
+
+     def prepare_result(
+         self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
+     ) -> WorkResult:
+         """Prepare WebDataset-specific result."""
+         logger.debug("Preparing result for unit %s", unit.unit_id)
+         result = super().prepare_result(unit, outputs, processing_time_ms)
+
+         # Add processed indices to metadata if available
+         if outputs and "_processed_indices" in outputs[0].get("metadata", {}):
+             result.metadata["item_indices"] = outputs[0]["metadata"]["_processed_indices"]
+             logger.debug(
+                 "Added item_indices to result metadata: %s", result.metadata["item_indices"]
+             )
+
+         return result
+
+     def get_dataset_info(self) -> Dict[str, Any]:
+         """Get dataset information."""
+         if self.dataset_loader:
+             info = self.dataset_loader.get_dataset_info()
+             logger.debug("Dataset info: %s", info)
+             return info
+         info = {
+             "dataset_path": self.dataset_config.get("dataset_path"),
+             "dataset_type": self.dataset_config.get("type", "huggingface"),
+         }
+         logger.debug("Dataset info (no loader): %s", info)
+         return info
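
The progress tracking in the new module revolves around converting processed job IDs into contiguous index ranges and subtracting them from a chunk's full range (see update_from_storage and _subtract_ranges above). The standalone sketch below is illustrative only and is not part of the caption-flow package: the helper names are invented here, and the grouping is written more compactly than in the module, but it produces the same ranges for the same inputs.

# Illustrative sketch -- not part of the caption-flow package.
# Mirrors the range bookkeeping used by update_from_storage() and _subtract_ranges() above.
from typing import List, Tuple


def group_into_ranges(indices: List[int]) -> List[Tuple[int, int]]:
    """Group item indices into contiguous, inclusive (start, end) ranges."""
    ranges: List[Tuple[int, int]] = []
    for idx in sorted(indices):
        if ranges and idx == ranges[-1][1] + 1:
            ranges[-1] = (ranges[-1][0], idx)  # extend the current run
        else:
            ranges.append((idx, idx))  # start a new run
    return ranges


def subtract_ranges(
    total: List[Tuple[int, int]], processed: List[Tuple[int, int]]
) -> List[Tuple[int, int]]:
    """Return the parts of `total` not covered by `processed` (all ranges inclusive)."""
    done = set()
    for start, end in processed:
        done.update(range(start, end + 1))
    remaining: List[Tuple[int, int]] = []
    for start, end in total:
        run_start = None
        for i in range(start, end + 1):
            if i not in done:
                if run_start is None:
                    run_start = i
            elif run_start is not None:
                remaining.append((run_start, i - 1))
                run_start = None
        if run_start is not None:
            remaining.append((run_start, end))
    return remaining


# Job IDs follow "<shard>:chunk:<chunk_index>:idx:<item_index>", e.g. "data-0000:chunk:0:idx:42".
job_ids = ["data-0000:chunk:0:idx:0", "data-0000:chunk:0:idx:1", "data-0000:chunk:0:idx:5"]
processed_indices = [int(job_id.split(":")[4]) for job_id in job_ids]

chunk_range = [(0, 9)]  # a 10-item chunk starting at absolute index 0
print(group_into_ranges(processed_indices))  # [(0, 1), (5, 5)]
print(subtract_ranges(chunk_range, group_into_ranges(processed_indices)))  # [(2, 4), (6, 9)]

In the module itself, the resulting absolute ranges are stored on the work unit's data under "unprocessed_ranges", and the worker's process_unit only yields items whose shard index falls inside those ranges.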