caption-flow 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,832 @@
1
+ """HuggingFace Datasets processor implementation."""
2
+
3
+ import logging
4
+ import threading
5
+ import re
6
+ import requests
7
+ from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
8
+ from collections import deque, defaultdict
9
+ from pathlib import Path
10
+ import json
11
+ import io
12
+ from datetime import datetime
13
+ from PIL import Image
14
+ from datasets import (
15
+ Dataset,
16
+ get_dataset_config_names,
17
+ get_dataset_split_names,
18
+ load_dataset_builder,
19
+ )
20
+ from huggingface_hub import hf_hub_download, get_token
21
+ from caption_flow.storage import StorageManager
22
+
23
+ from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
24
+ from ..utils import ChunkTracker
25
+ from ..models import JobId
26
+
27
+ logger = logging.getLogger(__name__)
28
+ logger.setLevel(logging.DEBUG)
29
+
30
+
31
+ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
32
+ """Orchestrator processor for HuggingFace datasets."""
33
+
34
+ def __init__(self):
35
+ logger.debug("Initializing HuggingFaceDatasetOrchestratorProcessor")
36
+ self.dataset_name: Optional[str] = None
37
+ self.config: Optional[str] = None
38
+ self.split: Optional[str] = None
39
+ self.chunk_tracker: Optional[ChunkTracker] = None
40
+ self.chunk_size: int = 1000
41
+ self.token = get_token()
42
+
43
+ # Shard information
44
+ self.shard_info: Dict[int, Dict[str, Any]] = {}
45
+ self.total_items: int = 0
46
+
47
+ # Work unit management
48
+ self.work_units: Dict[str, WorkUnit] = {}
49
+ self.pending_units: Deque[str] = deque()
50
+ self.assigned_units: Dict[str, Set[str]] = defaultdict(set) # worker_id -> unit_ids
51
+ self.lock = threading.Lock()
52
+
53
+ # Background thread for creating work units
54
+ self.unit_creation_thread: Optional[threading.Thread] = None
55
+ self.stop_creation = threading.Event()
56
+
57
+ def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
58
+ """Initialize HuggingFace dataset processor."""
59
+ logger.debug("Initializing orchestrator with config: %s", config.config)
60
+ cfg = config.config
61
+
62
+ # Dataset configuration
63
+ dataset_cfg = cfg.get("dataset", {})
64
+ self.dataset_name = dataset_cfg.get("dataset_path")
65
+ if not self.dataset_name:
66
+ raise ValueError("dataset_path is required in config")
67
+
68
+ # Auto-detect config if not provided
69
+ provided_config = dataset_cfg.get("dataset_config")
70
+ self.config = self._detect_config(provided_config)
71
+
72
+ # Auto-detect split if not provided
73
+ provided_split = dataset_cfg.get("dataset_split")
74
+ self.split = self._detect_split(provided_split)
75
+
76
+ logger.info(
77
+ f"Using dataset: {self.dataset_name}, config: {self.config}, split: {self.split}"
78
+ )
79
+
80
+ # Chunk settings
81
+ self.chunk_size = cfg.get("chunk_size", 1000)
82
+ self.min_buffer = cfg.get("min_chunk_buffer", 10)
83
+ self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
84
+
85
+ # Initialize chunk tracking
86
+ checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
87
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
88
+ self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
89
+
90
+ # Discover shards
91
+ self._discover_shards()
92
+
93
+ # Restore existing state
94
+ self._restore_state(storage=storage)
95
+
96
+ # Start background unit creation
97
+ self.unit_creation_thread = threading.Thread(
98
+ target=self._create_units_background, daemon=True
99
+ )
100
+ self.unit_creation_thread.start()
101
+ logger.debug("Unit creation thread started")
102
+
103
+ def _detect_config(self, provided_config: Optional[str]) -> str:
104
+ """Auto-detect config if not provided."""
105
+ if provided_config:
106
+ return provided_config
107
+
108
+ try:
109
+ configs = get_dataset_config_names(self.dataset_name, token=self.token)
110
+ if not configs:
111
+ return "default"
112
+
113
+ # Prefer common config names
114
+ preferred = ["default", "en", "train", "main"]
115
+ for pref in preferred:
116
+ if pref in configs:
117
+ logger.info(f"Auto-selected config: {pref}")
118
+ return pref
119
+
120
+ # Otherwise use first available
121
+ logger.info(f"Auto-selected first available config: {configs[0]}")
122
+ return configs[0]
123
+ except Exception as e:
124
+ logger.warning(f"Error detecting config: {e}, using 'default'")
125
+ return "default"
126
+
127
+ def _detect_split(self, provided_split: Optional[str]) -> str:
128
+ """Auto-detect split if not provided."""
129
+ if provided_split:
130
+ return provided_split
131
+
132
+ try:
133
+ splits = get_dataset_split_names(
134
+ self.dataset_name, config_name=self.config, token=self.token
135
+ )
136
+ if not splits:
137
+ logger.warning("No splits found, using 'train'")
138
+ return "train"
139
+
140
+ # Prefer training splits
141
+ preferred = ["train", "training", "test", "validation", "dev"]
142
+ for pref in preferred:
143
+ if pref in splits:
144
+ logger.info(f"Auto-selected split: {pref}")
145
+ return pref
146
+
147
+ # Otherwise use first available
148
+ logger.info(f"Auto-selected first available split: {splits[0]}")
149
+ return splits[0]
150
+ except Exception as e:
151
+ logger.warning(f"Error detecting split: {e}, using 'train'")
152
+ return "train"
153
+
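For reference, the two helpers used above are plain `datasets` utilities; a standalone sketch of what the auto-detection would resolve for a given repository (the dataset id below is a placeholder) could look like:

    from datasets import get_dataset_config_names, get_dataset_split_names

    dataset_id = "user/my-dataset"  # placeholder repository id
    configs = get_dataset_config_names(dataset_id)
    splits = get_dataset_split_names(dataset_id, config_name=configs[0] if configs else None)
    # The processor prefers "default"/"en"/"train"/"main" among configs and
    # "train"/"training"/"test"/"validation"/"dev" among splits, otherwise the first entry.
    print(configs, splits)
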
154
+ def _extract_filename_from_url(self, url: str) -> str:
155
+ """Extract filename from HF URL format."""
156
+ # Format: hf://datasets/user/dataset@hash/filename
157
+ match = re.search(r"@[a-f0-9]+/(.+)$", url)
158
+ if match:
159
+ return match.group(1)
160
+ # Fallback: just get last part
161
+ return url.split("/")[-1]
162
+
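As a quick illustration of the regex above (the repository name and revision hash are made up), the capture group keeps the repo-relative path that follows the `@<revision>/` marker:

    import re

    url = "hf://datasets/user/my-dataset@0123456789ab/data/train-00000-of-00002.parquet"
    match = re.search(r"@[a-f0-9]+/(.+)$", url)
    print(match.group(1))  # data/train-00000-of-00002.parquet
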
163
+ def _discover_shards(self):
164
+ """Discover all shards and their sizes."""
165
+ logger.info("Discovering shards...")
166
+
167
+ # Load dataset builder to get file info
168
+ builder = load_dataset_builder(self.dataset_name, self.config)
169
+
170
+ # Get data files for our split
171
+ data_files = []
172
+ if hasattr(builder.config, "data_files"):
173
+ if isinstance(builder.config.data_files, dict):
174
+ files = builder.config.data_files.get(self.split, [])
175
+ if isinstance(files, str):
176
+ files = [files]
177
+ data_files = files
178
+
179
+ if not data_files:
180
+ raise ValueError(f"No data files found for split '{self.split}'")
181
+
182
+ logger.info(f"Found {len(data_files)} data files")
183
+
184
+ # Get info about each shard
185
+ cumulative_offset = 0
186
+ for i, file_url in enumerate(data_files):
187
+ filename = self._extract_filename_from_url(file_url)
188
+ logger.info(f"Discovering shard {i}: {filename}")
189
+
190
+ # We don't download shards here - workers will do that
191
+ # For now, store the info we have
192
+ self.shard_info[i] = {
193
+ "shard_id": i,
194
+ "file_url": file_url,
195
+ "filename": filename,
196
+ "start_offset": cumulative_offset,
197
+ # Size will be determined when first worker needs it
198
+ "size": None,
199
+ "end_offset": None,
200
+ }
201
+
202
+ # Try to get size from builder info if available
203
+ if hasattr(builder.info, "splits") and self.split in builder.info.splits:
204
+ split_info = builder.info.splits[self.split]
205
+ if split_info.num_examples and len(data_files) == 1:
206
+ # Single shard case
207
+ self.shard_info[i]["size"] = split_info.num_examples
208
+ self.shard_info[i]["end_offset"] = (
209
+ cumulative_offset + split_info.num_examples - 1
210
+ )
211
+ cumulative_offset += split_info.num_examples
212
+
213
+ # If we couldn't get sizes, we'll need to load shards on demand
214
+ if self.shard_info[0]["size"] is None:
215
+ logger.warning("Shard sizes not available from metadata, will load on demand")
216
+ else:
217
+ self.total_items = cumulative_offset
218
+ logger.info(f"Total items across all shards: {self.total_items}")
219
+
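Shard discovery relies on the dataset builder exposing per-split data files; a minimal sketch of inspecting those URLs outside the processor (the dataset id is a placeholder, and `data_files` is only populated for Hub-hosted file-based datasets such as parquet) might be:

    from datasets import load_dataset_builder

    builder = load_dataset_builder("user/my-dataset")  # placeholder dataset id
    data_files = getattr(builder.config, "data_files", None) or {}
    train_files = data_files.get("train", [])
    if isinstance(train_files, str):
        train_files = [train_files]
    for url in train_files:
        print(url)  # hf://datasets/user/my-dataset@<revision>/<shard>.parquet
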
220
+ def _get_shard_size(self, shard_id: int) -> int:
221
+ """Get size of a shard, loading it if necessary."""
222
+ if self.shard_info[shard_id]["size"] is not None:
223
+ return self.shard_info[shard_id]["size"]
224
+
225
+ # Need to load the shard to get its size
226
+ logger.info(f"Loading shard {shard_id} to determine size...")
227
+ filename = self.shard_info[shard_id]["filename"]
228
+
229
+ local_path = hf_hub_download(
230
+ repo_id=self.dataset_name, filename=filename, repo_type="dataset", token=self.token
231
+ )
232
+
233
+ # Load just to get size
234
+ dataset = Dataset.from_parquet(local_path)
235
+ size = len(dataset)
236
+
237
+ # Update shard info
238
+ self.shard_info[shard_id]["size"] = size
239
+
240
+ # Update offsets for this and subsequent shards
241
+ for sid in range(shard_id, len(self.shard_info)):
242
+ if sid > shard_id:
243
+ self.shard_info[sid]["start_offset"] = self.shard_info[sid - 1]["end_offset"] + 1
244
+ self.shard_info[sid]["end_offset"] = (
245
+ self.shard_info[sid]["start_offset"] + self.shard_info[sid]["size"] - 1
246
+ )
247
+
248
+ # Update total items
249
+ if all(s["size"] is not None for s in self.shard_info.values()):
250
+ self.total_items = sum(s["size"] for s in self.shard_info.values())
251
+ logger.info(f"Total items: {self.total_items}")
252
+
253
+ return size
254
+
255
+ def _restore_state(self, storage: StorageManager) -> None:
256
+ """Restore state from chunk tracker."""
257
+ logger.debug("Restoring state from chunk tracker")
258
+ if not self.chunk_tracker:
259
+ return
260
+
261
+ all_processed_jobs = storage.get_all_processed_job_ids()
262
+
263
+ with self.lock:
264
+ for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
265
+ # Calculate actual unprocessed ranges
266
+ chunk_range = (
267
+ chunk_state.start_index,
268
+ chunk_state.start_index + chunk_state.chunk_size - 1,
269
+ )
270
+
271
+ # Get processed indices for this chunk
272
+ processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
273
+ chunk_id, all_processed_jobs
274
+ )
275
+
276
+ # Calculate unprocessed ranges
277
+ unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
278
+
279
+ if unprocessed_ranges:
280
+ # Find which shard(s) this chunk belongs to
281
+ shard_ids = []
282
+ for sid, sinfo in self.shard_info.items():
283
+ # Need size to check
284
+ if sinfo["size"] is None:
285
+ self._get_shard_size(sid)
286
+
287
+ if (
288
+ sinfo["start_offset"]
289
+ <= chunk_state.start_index + chunk_state.chunk_size - 1
290
+ and sinfo["end_offset"] >= chunk_state.start_index
291
+ ):
292
+ shard_ids.append(sid)
293
+ logger.info(f"Found shard {sid} for chunk {chunk_id}: {sinfo}")
294
+
295
+ chunk_index = chunk_state.start_index // self.chunk_size
296
+ shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
297
+ unit = WorkUnit(
298
+ unit_id=chunk_id,
299
+ chunk_id=chunk_id,
300
+ source_id=shard_name,
301
+ data={
302
+ "dataset_name": self.dataset_name,
303
+ "config": self.config,
304
+ "split": self.split,
305
+ "start_index": chunk_state.start_index,
306
+ "chunk_size": chunk_state.chunk_size,
307
+ "unprocessed_ranges": unprocessed_ranges,
308
+ "shard_ids": shard_ids,
309
+ },
310
+ metadata={
311
+ "dataset": self.dataset_name,
312
+ "shard_name": shard_name,
313
+ "chunk_index": chunk_index,
314
+ },
315
+ )
316
+
317
+ self.work_units[unit.unit_id] = unit
318
+ self.pending_units.append(unit.unit_id)
319
+
320
+ def _create_units_background(self) -> None:
321
+ """Background thread to create work units on demand."""
322
+ logger.info("Starting work unit creation thread")
323
+
324
+ current_index = 0
325
+
326
+ while not self.stop_creation.is_set():
327
+ # Check if we need more units
328
+ with self.lock:
329
+ pending_count = len(self.pending_units)
330
+ assigned_count = sum(len(units) for units in self.assigned_units.values())
331
+ worker_count = max(1, len(self.assigned_units))
332
+
333
+ target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
334
+ units_needed = max(0, target_buffer - (pending_count + assigned_count))
335
+
336
+ if units_needed == 0:
337
+ self.stop_creation.wait(5)  # sleep briefly, but wake early if shutdown is requested
338
+ continue
339
+
340
+ # Make sure we know total items
341
+ if self.total_items == 0:
342
+ # Load all shard sizes
343
+ for sid in range(len(self.shard_info)):
344
+ self._get_shard_size(sid)
345
+
346
+ # Create units as needed
347
+ units_created = 0
348
+
349
+ while units_created < units_needed and current_index < self.total_items:
350
+ chunk_size = min(self.chunk_size, self.total_items - current_index)
351
+ chunk_id = current_index // self.chunk_size
352
+
353
+ with self.lock:
354
+ shard_ids = []
355
+ for sid, sinfo in self.shard_info.items():
356
+ if (
357
+ sinfo["start_offset"] <= current_index + chunk_size - 1
358
+ and sinfo["end_offset"] >= current_index
359
+ ):
360
+ shard_ids.append(sid)
361
+ shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
362
+
363
+ job_id_obj = JobId(
364
+ shard_id=shard_name, chunk_id=chunk_id, sample_id=current_index
365
+ )
366
+ unit_id = (
367
+ job_id_obj.get_chunk_str()
368
+ )  # just the chunk part, e.g. "pixel-images:chunk:0"
369
+ if unit_id in self.work_units:
370
+ current_index += self.chunk_size
371
+ continue
372
+
373
+ # Check if chunk is already completed
374
+ if self.chunk_tracker:
375
+ chunk_state = self.chunk_tracker.chunks.get(unit_id)
376
+ if chunk_state and chunk_state.status == "completed":
377
+ current_index += self.chunk_size
378
+ continue
379
+
380
+ # Find which shard(s) this chunk belongs to
381
+
382
+ unit = WorkUnit(
383
+ unit_id=unit_id,
384
+ chunk_id=unit_id,
385
+ source_id=shard_name,
386
+ data={
387
+ "dataset_name": self.dataset_name,
388
+ "config": self.config,
389
+ "split": self.split,
390
+ "start_index": current_index,
391
+ "chunk_size": chunk_size,
392
+ "unprocessed_ranges": [(current_index, current_index + chunk_size - 1)],
393
+ "shard_ids": shard_ids,
394
+ },
395
+ metadata={
396
+ "dataset": self.dataset_name,
397
+ "shard_name": shard_name,
398
+ "chunk_index": chunk_id,
399
+ },
400
+ )
401
+ logger.debug(f"Created WorkUnit: {unit}")
402
+
403
+ self.work_units[unit_id] = unit
404
+ self.pending_units.append(unit_id)
405
+
406
+ if self.chunk_tracker:
407
+ self.chunk_tracker.add_chunk(
408
+ unit_id,
409
+ self.dataset_name,
410
+ "", # No shard URL
411
+ current_index,
412
+ chunk_size,
413
+ )
414
+
415
+ units_created += 1
416
+
417
+ current_index += self.chunk_size
418
+
419
+ if units_created > 0:
420
+ logger.debug(f"Created {units_created} work units")
421
+
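The unit ids reuse the project's JobId scheme; judging from the comments in this file, the identifiers follow a `<shard>:chunk:<n>` / `<shard>:chunk:<n>:idx:<i>` pattern. A rough illustration of that naming (plain strings, not the actual JobId model) would be:

    shard_name = "data-0000"   # stem of the parquet shard filename
    chunk_index = 3            # start_index // chunk_size
    sample_index = 3042        # absolute index into the dataset

    unit_id = f"{shard_name}:chunk:{chunk_index}"   # "data-0000:chunk:3"
    job_id = f"{unit_id}:idx:{sample_index}"        # "data-0000:chunk:3:idx:3042"
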
422
+ def _subtract_ranges(
423
+ self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
424
+ ) -> List[Tuple[int, int]]:
425
+ """Subtract processed ranges from total ranges."""
426
+ if not processed_ranges:
427
+ return total_ranges
428
+
429
+ # Create a set of all processed indices
430
+ processed_indices = set()
431
+ for start, end in processed_ranges:
432
+ processed_indices.update(range(start, end + 1))
433
+
434
+ # Find unprocessed ranges
435
+ unprocessed_ranges = []
436
+ for start, end in total_ranges:
437
+ current_start = None
438
+ for i in range(start, end + 1):
439
+ if i not in processed_indices:
440
+ if current_start is None:
441
+ current_start = i
442
+ else:
443
+ if current_start is not None:
444
+ unprocessed_ranges.append((current_start, i - 1))
445
+ current_start = None
446
+
447
+ if current_start is not None:
448
+ unprocessed_ranges.append((current_start, end))
449
+
450
+ return unprocessed_ranges
451
+
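A small worked example of the range subtraction (values are arbitrary):

    # A chunk covering indices 0-999 where 0-99 and 500-599 are already processed:
    total = [(0, 999)]
    processed = [(0, 99), (500, 599)]
    # _subtract_ranges(total, processed) -> [(100, 499), (600, 999)]
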
452
+ def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
453
+ """Get available work units for a worker."""
454
+ logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
455
+ assigned = []
456
+
457
+ with self.lock:
458
+ while len(assigned) < count and self.pending_units:
459
+ unit_id = self.pending_units.popleft()
460
+ unit = self.work_units.get(unit_id)
461
+
462
+ if unit:
463
+ self.assigned_units[worker_id].add(unit_id)
464
+ assigned.append(unit)
465
+ logger.debug("Assigning unit %s to worker %s", unit_id, worker_id)
466
+
467
+ if self.chunk_tracker:
468
+ self.chunk_tracker.mark_assigned(unit_id, worker_id)
469
+
470
+ logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
471
+ return assigned
472
+
473
+ def mark_completed(self, unit_id: str, worker_id: str) -> None:
474
+ """Mark a work unit as completed."""
475
+ logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
476
+ with self.lock:
477
+ if unit_id in self.work_units:
478
+ self.assigned_units[worker_id].discard(unit_id)
479
+
480
+ if self.chunk_tracker:
481
+ self.chunk_tracker.mark_completed(unit_id)
482
+
483
+ def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
484
+ """Mark a work unit as failed."""
485
+ logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
486
+ with self.lock:
487
+ if unit_id in self.work_units:
488
+ self.assigned_units[worker_id].discard(unit_id)
489
+ self.pending_units.append(unit_id)
490
+
491
+ if self.chunk_tracker:
492
+ self.chunk_tracker.mark_failed(unit_id)
493
+
494
+ def release_assignments(self, worker_id: str) -> None:
495
+ """Release all assignments for a disconnected worker."""
496
+ logger.debug("Releasing assignments for worker %s", worker_id)
497
+ with self.lock:
498
+ unit_ids = list(self.assigned_units.get(worker_id, []))
499
+
500
+ for unit_id in unit_ids:
501
+ if unit_id in self.work_units:
502
+ self.pending_units.append(unit_id)
503
+
504
+ if worker_id in self.assigned_units:
505
+ del self.assigned_units[worker_id]
506
+
507
+ if self.chunk_tracker:
508
+ self.chunk_tracker.release_worker_chunks(worker_id)
509
+
510
+ def update_from_storage(self, processed_job_ids: Set[str]) -> None:
511
+ """Update work units based on what's been processed."""
512
+ logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")
513
+
514
+ with self.lock:
515
+ for unit_id, unit in self.work_units.items():
516
+ # Extract chunk info from unit
517
+ logger.debug(f"Checking unit {unit_id} for updates")
518
+ logger.debug(f"Unit data: {unit.data}")
519
+ logger.debug(f"Unit metadata: {unit.metadata}")
520
+ start_index = unit.data["start_index"]
521
+ chunk_size = unit.data["chunk_size"]
522
+ shard_name = unit.metadata["shard_name"]
523
+ chunk_index = unit.metadata["chunk_index"]
524
+
525
+ # Find processed indices for this chunk
526
+ processed_indices = []
527
+ for job_id in processed_job_ids:
528
+ # Parse job_id format: "data-0000:chunk:0:idx:42"
529
+ parsed = JobId.from_str(job_id=job_id)
530
+ if parsed.shard_id == shard_name and int(parsed.chunk_id) == chunk_index:
531
+ idx = int(parsed.sample_id)
532
+ if start_index <= idx < start_index + chunk_size:
533
+ processed_indices.append(idx)
534
+
535
+ if processed_indices:
536
+ # Convert to ranges
537
+ processed_indices.sort()
538
+ processed_ranges = []
539
+ start = processed_indices[0]
540
+ end = processed_indices[0]
541
+
542
+ for idx in processed_indices[1:]:
543
+ if idx == end + 1:
544
+ end = idx
545
+ else:
546
+ processed_ranges.append((start, end))
547
+ start = idx
548
+ end = idx
549
+
550
+ processed_ranges.append((start, end))
551
+
552
+ # Calculate unprocessed ranges
553
+ total_range = [(start_index, start_index + chunk_size - 1)]
554
+ unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
555
+
556
+ # Update unit
557
+ unit.data["unprocessed_ranges"] = unprocessed_ranges
558
+
559
+ logger.debug(
560
+ f"Updated unit {unit_id}: {len(processed_indices)} processed, "
561
+ f"unprocessed ranges: {unprocessed_ranges}"
562
+ )
563
+
564
+ def get_stats(self) -> Dict[str, Any]:
565
+ """Get processor statistics."""
566
+ with self.lock:
567
+ stats = {
568
+ "dataset": self.dataset_name,
569
+ "config": self.config,
570
+ "split": self.split,
571
+ "total_units": len(self.work_units),
572
+ "pending_units": len(self.pending_units),
573
+ "assigned_units": sum(len(units) for units in self.assigned_units.values()),
574
+ "total_shards": len(self.shard_info),
575
+ "total_items": self.total_items,
576
+ "workers": len(self.assigned_units),
577
+ }
578
+ return stats
579
+
580
+ def handle_result(self, result: WorkResult) -> Dict[str, Any]:
581
+ """Handle result processing."""
582
+ base_result = super().handle_result(result)
583
+
584
+ # Track processed items
585
+ if self.chunk_tracker:
586
+ if "item_indices" not in result.metadata:
587
+ result.metadata["item_indices"] = [result.metadata.get("_item_index")]
588
+ indices = result.metadata["item_indices"]
589
+
590
+ if indices:
591
+ indices.sort()
592
+ ranges = []
593
+ start = indices[0]
594
+ end = indices[0]
595
+
596
+ for i in range(1, len(indices)):
597
+ if indices[i] == end + 1:
598
+ end = indices[i]
599
+ else:
600
+ ranges.append((start, end))
601
+ start = indices[i]
602
+ end = indices[i]
603
+
604
+ ranges.append((start, end))
605
+
606
+ for start_idx, end_idx in ranges:
607
+ self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
608
+
609
+ return base_result
610
+
611
+
612
+ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
613
+ """Worker processor for HuggingFace datasets."""
614
+
615
+ def __init__(self):
616
+ logger.debug("Initializing HuggingFaceDatasetWorkerProcessor")
617
+ self.dataset_config: Dict[str, Any] = {}
618
+ self.token = get_token()
619
+ self.shard_cache: Dict[int, Dataset] = {} # Cache loaded shards
620
+ self.image_column: Optional[str] = None
621
+ self.url_column: Optional[str] = None
622
+
623
+ def initialize(self, config: ProcessorConfig) -> None:
624
+ """Initialize processor."""
625
+ logger.debug("Initializing worker with config: %s", config.config)
626
+ self.dataset_config = config.config.get("dataset", {})
627
+
628
+ # Determine if this is an image URL dataset or binary image dataset
629
+ self.image_column = self.dataset_config.get("dataset_image_column", "image")
630
+ self.url_column = self.dataset_config.get("dataset_url_column", "image_url")
631
+ self.dataset_path = self.dataset_config.get("dataset_path", None)
632
+
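Taken together with the orchestrator's `initialize`, a configuration dict covering the keys read in this file might look as follows (the dataset id and column names are placeholders; `dataset_config` and `dataset_split` may be omitted to trigger auto-detection):

    processor_config = {
        "dataset": {
            "dataset_path": "user/my-dataset",   # required
            "dataset_config": None,              # auto-detected when omitted
            "dataset_split": None,               # auto-detected when omitted
            "dataset_image_column": "image",     # column holding binary/Image-feature data
            "dataset_url_column": "image_url",   # or a column holding downloadable URLs
        },
        "chunk_size": 1000,
        "min_chunk_buffer": 10,
        "chunk_buffer_multiplier": 3,
        "checkpoint_dir": "./checkpoints",
    }
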
633
+ def _load_shard(self, dataset_name: str, shard_filename: str, shard_id: int) -> Dataset:
634
+ """Load a shard if not already cached."""
635
+ if shard_id in self.shard_cache:
636
+ return self.shard_cache[shard_id]
637
+
638
+ logger.info(f"Loading shard {shard_id}: {shard_filename}")
639
+
640
+ local_path = hf_hub_download(
641
+ repo_id=dataset_name, filename=shard_filename, repo_type="dataset", token=self.token
642
+ )
643
+
644
+ dataset = Dataset.from_parquet(local_path)
645
+ self.shard_cache[shard_id] = dataset
646
+
647
+ return dataset
648
+
649
+ def _extract_filename_from_url(self, url: str) -> str:
650
+ """Extract filename from HF URL format."""
651
+ match = re.search(r"@[a-f0-9]+/(.+)$", url)
652
+ if match:
653
+ return match.group(1)
654
+ return url.split("/")[-1]
655
+
656
+ def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
657
+ """Process a work unit, yielding items to be captioned."""
658
+ logger.debug("Processing unit: %s", unit.unit_id)
659
+
660
+ dataset_name = unit.data["dataset_name"]
661
+ config = unit.data["config"]
662
+ split = unit.data["split"]
663
+ start_index = unit.data["start_index"]
664
+ chunk_size = unit.data["chunk_size"]
665
+ unprocessed_ranges = unit.data.get(
666
+ "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
667
+ )
668
+ shard_ids = unit.data.get("shard_ids", [])
669
+
670
+ logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")
671
+
672
+ # Need to get shard info - should be passed in unit data
673
+ # For now, we'll need to load dataset builder to get file info
674
+ from datasets import load_dataset_builder
675
+
676
+ builder = load_dataset_builder(dataset_name, config)
677
+
678
+ data_files = []
679
+ if hasattr(builder.config, "data_files"):
680
+ if isinstance(builder.config.data_files, dict):
681
+ files = builder.config.data_files.get(split, [])
682
+ if isinstance(files, str):
683
+ files = [files]
684
+ data_files = files
685
+
686
+ # Build shard info
687
+ shard_info = {}
688
+ cumulative_offset = 0
689
+
690
+ for i, file_url in enumerate(data_files):
691
+ if i not in shard_ids:
692
+ # This shard isn't part of the unit, but it must still be loaded to learn its size
693
+ # for the offset math - inefficient; the orchestrator should pass shard sizes instead
694
+ filename = self._extract_filename_from_url(file_url)
695
+ dataset = self._load_shard(dataset_name, filename, i)
696
+ size = len(dataset)
697
+ cumulative_offset += size
698
+ continue
699
+
700
+ filename = self._extract_filename_from_url(file_url)
701
+ dataset = self._load_shard(dataset_name, filename, i)
702
+
703
+ shard_info[i] = {
704
+ "dataset": dataset,
705
+ "start_offset": cumulative_offset,
706
+ "end_offset": cumulative_offset + len(dataset) - 1,
707
+ "columns": dataset.column_names,
708
+ }
709
+ cumulative_offset += len(dataset)
710
+
711
+ # Create set of indices to process
712
+ indices_to_process = set()
713
+ for start, end in unprocessed_ranges:
714
+ indices_to_process.update(range(start, end + 1))
715
+
716
+ processed_indices = []
717
+
718
+ # Process items
719
+ for global_idx in sorted(indices_to_process):
720
+ # Find which shard contains this index
721
+ shard_id = None
722
+ local_idx = None
723
+
724
+ for sid, sinfo in shard_info.items():
725
+ if sinfo["start_offset"] <= global_idx <= sinfo["end_offset"]:
726
+ shard_id = sid
727
+ local_idx = global_idx - sinfo["start_offset"]
728
+ break
729
+
730
+ if shard_id is None:
731
+ logger.warning(f"Could not find shard for global index {global_idx}")
732
+ continue
733
+
734
+ try:
735
+ # Get item from shard
736
+ item = shard_info[shard_id]["dataset"][local_idx]
737
+
738
+ # Check if this is a URL dataset or binary image dataset
739
+ image = None
740
+ image_url = None
741
+
742
+ # Try URL column first
743
+ if self.url_column and self.url_column in item:
744
+ image_url = item[self.url_column]
745
+ # Download image from URL
746
+ try:
747
+ response = requests.get(image_url, timeout=30)
748
+ response.raise_for_status()
749
+ image = Image.open(io.BytesIO(response.content))
750
+ except Exception as e:
751
+ logger.error(f"Error downloading image from {image_url}: {e}")
752
+ continue
753
+
754
+ # Try binary image column
755
+ elif self.image_column and self.image_column in item:
756
+ image_data = item[self.image_column]
757
+ if isinstance(image_data, Image.Image):
758
+ image = image_data
759
+ elif isinstance(image_data, dict) and "bytes" in image_data:
760
+ # Handle datasets Image feature
761
+ image = Image.open(io.BytesIO(image_data["bytes"]))
762
+ elif isinstance(image_data, bytes):
763
+ image = Image.open(io.BytesIO(image_data))
764
+
765
+ if image is None:
766
+ logger.warning(f"No image found for item at index {global_idx}")
767
+ continue
768
+
769
+ # Build job ID
770
+ chunk_index = unit.metadata["chunk_index"]
771
+ shard_name = unit.metadata["shard_name"]
772
+ job_id_obj = JobId(
773
+ shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(global_idx)
774
+ )
775
+ job_id = job_id_obj.get_sample_str()
776
+
777
+ # Clean metadata
778
+ clean_metadata = {
779
+ k: v
780
+ for k, v in item.items()
781
+ if k not in [self.image_column, self.url_column] and not k.startswith("_")
782
+ }
783
+
784
+ clean_metadata.update(
785
+ {
786
+ "_item_index": global_idx,
787
+ "_chunk_relative_index": global_idx - start_index,
788
+ "_job_id": job_id,
789
+ "_shard_id": shard_id,
790
+ "_local_index": local_idx,
791
+ "_url": image_url,
792
+ }
793
+ )
794
+
795
+ yield {
796
+ "image": image,
797
+ "item_key": str(global_idx),
798
+ "item_index": global_idx,
799
+ "metadata": clean_metadata,
800
+ "job_id": job_id,
801
+ }
802
+
803
+ processed_indices.append(global_idx)
804
+
805
+ except Exception as e:
806
+ logger.error(f"Error processing item at index {global_idx}: {e}")
807
+
808
+ # Store processed indices in context
809
+ context["_processed_indices"] = processed_indices
810
+ logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)
811
+
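The image handling above accepts three shapes of column value; the same fallbacks as a standalone sketch:

    import io
    from PIL import Image

    def decode_image(value):
        """Mirror the fallbacks used in process_unit for an image column value."""
        if isinstance(value, Image.Image):                 # already a decoded PIL image
            return value
        if isinstance(value, dict) and "bytes" in value:   # datasets Image feature dict
            return Image.open(io.BytesIO(value["bytes"]))
        if isinstance(value, bytes):                       # raw encoded bytes
            return Image.open(io.BytesIO(value))
        return None
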
812
+ def prepare_result(
813
+ self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
814
+ ) -> WorkResult:
815
+ """Prepare result."""
816
+ logger.debug("Preparing result for unit %s", unit.unit_id)
817
+ result = super().prepare_result(unit, outputs, processing_time_ms)
818
+
819
+ # Add processed indices to metadata
820
+ if outputs and "_processed_indices" in outputs[0].get("metadata", {}):
821
+ result.metadata["item_indices"] = outputs[0]["metadata"]["_processed_indices"]
822
+
823
+ return result
824
+
825
+ def get_dataset_info(self) -> Dict[str, Any]:
826
+ """Get dataset information."""
827
+ return {
828
+ "dataset_path": self.dataset_config.get("dataset_path"),
829
+ "dataset_type": "huggingface",
830
+ "config": self.dataset_config.get("dataset_config"),
831
+ "split": self.dataset_config.get("dataset_split"),
832
+ }
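
A minimal sketch of driving the worker side directly, assuming `ProcessorConfig` is a thin wrapper exposing the config dict shown earlier (that constructor, the dataset id, shard name and indices are all assumptions for illustration):

    worker = HuggingFaceDatasetWorkerProcessor()
    worker.initialize(ProcessorConfig(config=processor_config))  # assumed constructor

    unit = WorkUnit(
        unit_id="data-0000:chunk:0",
        chunk_id="data-0000:chunk:0",
        source_id="data-0000",
        data={
            "dataset_name": "user/my-dataset",
            "config": "default",
            "split": "train",
            "start_index": 0,
            "chunk_size": 1000,
            "unprocessed_ranges": [(0, 999)],
            "shard_ids": [0],
        },
        metadata={"dataset": "user/my-dataset", "shard_name": "data-0000", "chunk_index": 0},
    )

    context = {}
    for item in worker.process_unit(unit, context):
        print(item["job_id"], item["image"].size)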