caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
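The bulk of this release reworks `DatasetLoader`: it now detects whether a Hugging Face repo stores WebDataset `.tar` shards or native `datasets`-format files, accepts `split` and `image_column` arguments, and represents `datasets`-format repos as virtual shards of 10,000 rows. A minimal usage sketch of the new surface; the import path and repo id are placeholders, and only the class, parameter, and method names come from the diff below:

```python
# Sketch only: the module path and repo id are assumptions, not confirmed by this diff.
from caption_flow.dataset_loader import DatasetLoader  # hypothetical import path

loader = DatasetLoader(
    "someuser/some-dataset",      # placeholder HF repo id
    dataset_type="huggingface",
    split="train",                # new in 0.2.1
    image_column="image",         # new in 0.2.1
)

# WebDataset repos yield .tar URLs; datasets-format repos yield virtual shard
# ids of the form "hf_dataset:<dataset_path>:chunk:<start_idx>".
for shard in loader.get_shard_list():
    for key, url, image_bytes in loader.iterate_shard(shard):
        ...  # hand the image bytes to a captioning worker
```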
@@ -9,6 +9,8 @@ import json
 
 import webdataset as wds
 from huggingface_hub import HfFileSystem, get_token, hf_hub_url
+from datasets import load_dataset, Dataset
+from .image_processor import ImageProcessor
 
 logger = logging.getLogger(__name__)
 
@@ -16,42 +18,143 @@ logger = logging.getLogger(__name__)
 class DatasetLoader:
     """Handles loading datasets from various sources."""
 
-    def __init__(self, dataset_path: str, dataset_type: str = "huggingface"):
+    def __init__(
+        self,
+        dataset_path: str,
+        dataset_type: str = "huggingface",
+        split: str = "train",
+        image_column: str = "image",
+    ):
         """
         Initialize dataset loader.
 
         Args:
             dataset_path: Path to dataset (HF repo, local dir, etc.)
             dataset_type: Type of dataset ("huggingface", "webdataset", "local")
+            split: Split to use for HuggingFace datasets (default: "train")
+            image_column: Column name containing image data or URLs (default: "image")
         """
         self.dataset_path = dataset_path
         self.dataset_type = dataset_type
+        self.split = split
+        self.image_column = image_column
         self.token = get_token()
+        self.dataset_format = None  # Will be detected: "webdataset" or "huggingface_datasets"
+        self._hf_dataset = None  # Cache for HuggingFace dataset
+        self._hf_total_items = None  # Cache for total items count
 
         if not self.token and dataset_type == "huggingface":
             logger.warning("No HuggingFace token found; run `huggingface-cli login`")
 
+        # Detect the actual format if it's a HuggingFace dataset
+        if dataset_type == "huggingface":
+            self.dataset_format = self._detect_dataset_format()
+            logger.info(f"Detected dataset format: {self.dataset_format}")
+
+    def _detect_dataset_format(self) -> str:
+        """Detect whether it's WebDataset or HuggingFace datasets format."""
+        fs = HfFileSystem(token=self.token)
+
+        # Check for .tar files (WebDataset)
+        tar_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar"))
+        if tar_files:
+            return "webdataset"
+
+        # Check for parquet files (HuggingFace datasets)
+        parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
+        if parquet_files:
+            return "huggingface_datasets"
+
+        # Check for dataset_info.json or dataset_dict.json
+        if fs.exists(f"datasets/{self.dataset_path}/dataset_info.json") or fs.exists(
+            f"datasets/{self.dataset_path}/dataset_dict.json"
+        ):
+            return "huggingface_datasets"
+
+        logger.warning(f"Could not detect dataset format for {self.dataset_path}")
+        return "unknown"
+
     def get_shard_list(self) -> List[str]:
         """Get list of all shards in the dataset."""
         if self.dataset_type == "huggingface":
-            return self._get_hf_shards()
+            if self.dataset_format == "webdataset":
+                return self._get_hf_webdataset_shards()
+            elif self.dataset_format == "huggingface_datasets":
+                return self._get_hf_dataset_shards()
+            else:
+                logger.error(f"Unknown dataset format: {self.dataset_format}")
+                return []
         elif self.dataset_type == "local":
             return self._get_local_shards()
         else:
             raise ValueError(f"Unknown dataset type: {self.dataset_type}")
 
-    def _get_hf_shards(self) -> List[str]:
-        """Get shard URLs from HuggingFace dataset."""
-        logger.info(f"Getting shard list from HuggingFace: {self.dataset_path}")
+    def _get_hf_webdataset_shards(self) -> List[str]:
+        """Get shard URLs from HuggingFace WebDataset."""
+        logger.info(f"Getting WebDataset shard list from HuggingFace: {self.dataset_path}")
 
-        fs = HfFileSystem()
+        fs = HfFileSystem(token=self.token)
         files = [fs.resolve_path(p) for p in fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar")]
 
         urls = [hf_hub_url(f.repo_id, f.path_in_repo, repo_type="dataset") for f in files]
 
-        logger.info(f"Found {len(urls)} shards")
+        logger.info(f"Found {len(urls)} WebDataset shards")
         return sorted(urls)
 
+    def _get_hf_dataset_shards(self) -> List[str]:
+        """Get virtual 'shards' for HuggingFace datasets format."""
+        logger.info(f"Getting HuggingFace dataset info: {self.dataset_path}")
+
+        # For HuggingFace datasets, we'll create virtual shards based on chunks
+        # Each "shard" will be a range of indices
+        try:
+            # First, try to get available splits
+            try:
+                from datasets import get_dataset_split_names
+
+                available_splits = get_dataset_split_names(self.dataset_path, token=self.token)
+                logger.info(f"Available splits: {available_splits}")
+
+                if self.split not in available_splits:
+                    logger.warning(
+                        f"Requested split '{self.split}' not found. "
+                        f"Available splits: {available_splits}. "
+                        f"Using first available split: '{available_splits[0]}'"
+                    )
+                    self.split = available_splits[0]
+            except Exception as e:
+                logger.warning(f"Could not get split names: {e}")
+
+            # Load dataset info without downloading data
+            dataset_info = load_dataset(
+                self.dataset_path, split=self.split, streaming=True, token=self.token
+            )
+
+            # Try to get the total size
+            # For streaming datasets, we might need to iterate to count
+            # This is expensive, so we'll use a default chunk size instead
+            chunk_size = 10000  # Default chunk size for virtual shards
+
+            # Create virtual shard identifiers
+            # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
+            virtual_shards = []
+
+            # We'll create a reasonable number of virtual shards
+            # Without knowing the total size, we'll create them on-demand
+            # For now, create initial batch of virtual shards
+            for i in range(10):  # Start with 10 virtual shards
+                shard_id = f"hf_dataset:{self.dataset_path}:chunk:{i * chunk_size}"
+                virtual_shards.append(shard_id)
+
+            logger.info(
+                f"Created {len(virtual_shards)} initial virtual shards for HuggingFace dataset"
+            )
+            return virtual_shards
+
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset info: {e}")
+            return []
+
     def _get_local_shards(self) -> List[str]:
         """Get shard files from local directory."""
         path = Path(self.dataset_path)
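Worth noting about the hunk above: for repos in native `datasets` format, `_get_hf_dataset_shards()` never lists files. It fabricates ten chunk identifiers of 10,000 rows each and leaves bounds-checking to the iteration and counting code later in this diff. Roughly what those identifiers look like, with a placeholder repo id:

```python
# Mirrors the loop in _get_hf_dataset_shards(); "someuser/some-dataset" is a placeholder.
chunk_size = 10000
virtual_shards = [
    f"hf_dataset:someuser/some-dataset:chunk:{i * chunk_size}" for i in range(10)
]
# ['hf_dataset:someuser/some-dataset:chunk:0',
#  'hf_dataset:someuser/some-dataset:chunk:10000',
#  ...
#  'hf_dataset:someuser/some-dataset:chunk:90000']
```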
@@ -73,7 +176,14 @@ class DatasetLoader:
         if processed_keys is None:
             processed_keys = set()
 
-        if self.dataset_type == "huggingface":
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            raise ValueError(
+                "Virtual HuggingFace dataset shards should use iterate_shard() directly, "
+                "not load_shard()"
+            )
+
+        if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
             # Use curl with auth token for HuggingFace
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
             ds = wds.DataPipeline(
@@ -93,6 +203,19 @@ class DatasetLoader:
 
         return ds
 
+    def _parse_virtual_shard(self, shard_url: str) -> Tuple[str, int, int]:
+        """Parse virtual shard identifier."""
+        # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
+        parts = shard_url.split(":")
+        if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
+            raise ValueError(f"Invalid virtual shard format: {shard_url}")
+
+        dataset_path = parts[1]
+        start_idx = int(parts[3])
+        chunk_size = 10000  # Default chunk size
+
+        return dataset_path, start_idx, chunk_size
+
     def iterate_shard(
         self, shard_url: str, processed_keys: Optional[set] = None
     ) -> Generator[Tuple[str, str, bytes], None, None]:
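The hunk above also adds `_parse_virtual_shard()`, the inverse of the identifier format shown earlier. A standalone sketch of the round trip, re-implemented here for illustration (only the format itself comes from the diff; the repo id is a placeholder):

```python
def parse_virtual_shard(shard_url: str) -> tuple[str, int, int]:
    # "hf_dataset:<dataset_path>:chunk:<start_idx>" -> (path, start index, chunk size)
    parts = shard_url.split(":")
    if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
        raise ValueError(f"Invalid virtual shard format: {shard_url}")
    return parts[1], int(parts[3]), 10000  # chunk size is fixed at 10000 in 0.2.1

assert parse_virtual_shard("hf_dataset:someuser/some-dataset:chunk:20000") == (
    "someuser/some-dataset",
    20000,
    10000,
)
```

Hugging Face repo ids do not contain colons, so splitting on ":" is unambiguous here.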
@@ -102,85 +225,456 @@ class DatasetLoader:
         Yields:
             Tuple of (key, url, image_bytes)
         """
-        ds = self.load_shard(shard_url, processed_keys)
-
-        for key, url, image_data in ds:
-            yield key, url, image_data
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            yield from self._iterate_hf_dataset_shard(shard_url, processed_keys)
+        else:
+            # Regular WebDataset shard
+            ds = self.load_shard(shard_url, processed_keys)
+            for key, url, image_data in ds:
+                yield key, url, image_data
 
-    def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
-        """Count items in a shard (can be slow for large shards)."""
-        count = 0
+    def _create_dataset_at_position(self, dataset_path: str, split: str, start_idx: int):
+        """Create a dataset iterator positioned at start_idx using state_dict if available."""
         try:
-            for _ in self.iterate_shard(shard_url, processed_keys):
-                count += 1
+            # Load dataset in streaming mode
+            dataset = load_dataset(
+                dataset_path,
+                split=split,
+                streaming=True,
+                token=self.token,
+            )
+
+            # Check if the dataset supports state_dict (newer versions of datasets library)
+            if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
+                # Try to use the dataset's native state management
+                try:
+                    # Get current state
+                    state = dataset.state_dict()
+
+                    # Modify the state to skip to start_idx
+                    if "epoch" in state:
+                        state["epoch"] = 0
+                    if "num_examples_since_previous_state" in state:
+                        state["num_examples_since_previous_state"] = start_idx
+
+                    # For newer datasets with examples_iterable state
+                    if "examples_iterable" in state:
+                        if isinstance(state["examples_iterable"], dict):
+                            if "shard_example_idx" in state["examples_iterable"]:
+                                state["examples_iterable"]["shard_example_idx"] = start_idx
+
+                    # Load the modified state
+                    dataset.load_state_dict(state)
+                    logger.info(f"Positioned dataset at index {start_idx} using state_dict")
+                    return dataset
+                except Exception as e:
+                    logger.debug(f"Could not use state_dict approach: {e}")
+
+            # Fall back to skip() for large skips
+            if start_idx > 0:
+                logger.info(f"Using skip() to position dataset at index {start_idx}")
+                dataset = dataset.skip(start_idx)
+
+            return dataset
+
         except Exception as e:
-            logger.error(f"Error counting shard {shard_url}: {e}")
-        return count
+            logger.warning(f"Error creating positioned dataset: {e}")
+            return None
+
+    def _iterate_hf_dataset_shard_with_metadata(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+        """Iterate over a virtual HuggingFace dataset shard with metadata."""
+        if processed_keys is None:
+            processed_keys = set()
 
+        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
 
-class ShardTracker:
-    """Tracks shard processing progress."""
+        logger.info(
+            f"Loading HuggingFace dataset with metadata: {dataset_path} (split: {self.split})"
+        )
 
-    def __init__(self, checkpoint_path: Path):
-        """Initialize shard tracker with checkpoint file."""
-        self.checkpoint_path = checkpoint_path
-        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            # Try optimized approach for large skips
+            if start_idx > 100:
+                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
+                if dataset:
+                    items_processed = 0
+
+                    for item in dataset:
+                        # Stop after processing chunk_size items
+                        if items_processed >= chunk_size:
+                            break
+
+                        # Generate a unique key for this item
+                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                        if key in processed_keys:
+                            items_processed += 1
+                            continue
+
+                        try:
+                            # Extract image data
+                            if self.image_column in item:
+                                img_data = item[self.image_column]
+
+                                # Process image to bytes
+                                image_bytes = ImageProcessor.process_image_data(img_data)
+
+                                if image_bytes:
+                                    # Extract all metadata (excluding the image column)
+                                    metadata = {
+                                        k: v for k, v in item.items() if k != self.image_column
+                                    }
+
+                                    # URL is virtual for HF datasets
+                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                                    items_processed += 1
+                                    yield key, url, image_bytes, metadata
+                                else:
+                                    logger.warning(
+                                        f"Failed to process image for item at index {start_idx + items_processed}"
+                                    )
+                                    items_processed += 1
+                                    continue
+                            else:
+                                logger.warning(
+                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                                    f"Available columns: {list(item.keys())}"
+                                )
+                                items_processed += 1
+
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing item at index {start_idx + items_processed}: {e}"
+                            )
+                            items_processed += 1
+                            continue
+
+                    return
+
+            # Fall back to regular approach for small skips or if StatefulDataLoader not available
+            dataset = load_dataset(
+                dataset_path,
+                split=self.split,
+                streaming=True,
+                token=self.token,
+            )
 
-        self.completed_shards: set = set()
-        self.partial_shards: Dict[str, Dict[str, Any]] = {}
-        self.load()
+            # Skip to start index if needed
+            if start_idx > 0:
+                dataset = dataset.skip(start_idx)
+
+            items_processed = 0
+
+            for item in dataset:
+                # Stop after processing chunk_size items
+                if items_processed >= chunk_size:
+                    break
+
+                # Generate a unique key for this item
+                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                if key in processed_keys:
+                    items_processed += 1
+                    continue
+
+                try:
+                    # Extract image data
+                    if self.image_column in item:
+                        img_data = item[self.image_column]
+
+                        # Process image to bytes
+                        image_bytes = ImageProcessor.process_image_data(img_data)
+
+                        if image_bytes:
+                            # Extract all metadata (excluding the image column)
+                            metadata = {k: v for k, v in item.items() if k != self.image_column}
+
+                            # URL is virtual for HF datasets
+                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                            items_processed += 1
+                            yield key, url, image_bytes, metadata
+                        else:
+                            logger.warning(
+                                f"Failed to process image for item at index {start_idx + items_processed}"
+                            )
+                            items_processed += 1
+                            continue
+                    else:
+                        logger.warning(
+                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                            f"Available columns: {list(item.keys())}"
+                        )
+                        items_processed += 1
+
+                except Exception as e:
+                    logger.error(
+                        f"Error processing item at index {start_idx + items_processed}: {e}"
+                    )
+                    items_processed += 1
+                    continue
 
-    def load(self):
-        """Load checkpoint from disk."""
-        if self.checkpoint_path.exists():
-            try:
-                data = json.loads(self.checkpoint_path.read_text())
-                self.completed_shards = set(data.get("completed_shards", []))
-                self.partial_shards = data.get("partial_shards", {})
-                logger.info(
-                    f"Loaded checkpoint: {len(self.completed_shards)} completed, "
-                    f"{len(self.partial_shards)} partial shards"
-                )
-            except Exception as e:
-                logger.error(f"Failed to load checkpoint: {e}")
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset: {e}")
+            return
 
-    def save(self):
-        """Save checkpoint to disk."""
-        data = {
-            "completed_shards": list(self.completed_shards),
-            "partial_shards": self.partial_shards,
-        }
+    def _iterate_hf_dataset_shard(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes], None, None]:
+        """Iterate over a virtual HuggingFace dataset shard."""
+        if processed_keys is None:
+            processed_keys = set()
+
+        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
+
+        # IMPORTANT: Check if start_idx is beyond dataset bounds
+        if self._hf_total_items is not None and start_idx >= self._hf_total_items:
+            logger.warning(
+                f"Virtual shard starts at index {start_idx} but dataset only has "
+                f"{self._hf_total_items} items. Skipping this shard."
+            )
+            return
+
+        logger.info(
+            f"Loading HuggingFace dataset in streaming mode: {dataset_path} "
+            f"(split: {self.split}, start: {start_idx}, chunk_size: {chunk_size})"
+        )
+
+        try:
+            # Try optimized approach for large skips
+            if start_idx > 100:
+                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
+                if dataset:
+                    items_processed = 0
+
+                    for item in dataset:
+                        # Stop after processing chunk_size items
+                        if items_processed >= chunk_size:
+                            logger.info(f"Completed chunk: processed {items_processed} items")
+                            break
+
+                        # Also stop if we've reached the dataset end
+                        if (
+                            self._hf_total_items
+                            and (start_idx + items_processed) >= self._hf_total_items
+                        ):
+                            logger.info(
+                                f"Reached dataset end at item {start_idx + items_processed} "
+                                f"(total: {self._hf_total_items})"
+                            )
+                            break
+
+                        # Generate a unique key for this item
+                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                        if key in processed_keys:
+                            items_processed += 1
+                            continue
+
+                        try:
+                            # Extract image data
+                            if self.image_column in item:
+                                img_data = item[self.image_column]
+
+                                # Delegate image processing to ImageProcessor
+                                image_bytes = ImageProcessor.process_image_data(img_data)
+
+                                if image_bytes:
+                                    # URL is virtual for HF datasets
+                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                                    items_processed += 1
+                                    yield key, url, image_bytes
+                                else:
+                                    logger.warning(
+                                        f"Failed to process image for item at index {start_idx + items_processed}"
+                                    )
+                                    items_processed += 1
+                                    continue
+                            else:
+                                logger.warning(
+                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                                    f"Available columns: {list(item.keys())}"
+                                )
+                                items_processed += 1
+
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing item at index {start_idx + items_processed}: {e}"
+                            )
+                            items_processed += 1
+                            continue
+
+                    logger.info(
+                        f"Virtual shard complete: processed {items_processed} items "
+                        f"(start_idx: {start_idx})"
+                    )
+                    return
+
+            # Fall back to regular approach for small skips or if StatefulDataLoader not available
+            dataset = load_dataset(
+                dataset_path,
+                split=self.split,
+                streaming=True,
+                token=self.token,
+            )
 
-        tmp = self.checkpoint_path.with_suffix(".tmp")
-        tmp.write_text(json.dumps(data, indent=2))
-        tmp.replace(self.checkpoint_path)
+            # Use dataset.skip() for efficient skipping
+            if start_idx > 0:
+                dataset = dataset.skip(start_idx)
+                logger.info(f"Skipped to index {start_idx}")
+
+            items_processed = 0
+
+            # Now enumerate starts from 0 after skip
+            for item in dataset:
+                # Stop after processing chunk_size items
+                if items_processed >= chunk_size:
+                    logger.info(f"Completed chunk: processed {items_processed} items")
+                    break
+
+                # Also stop if we've reached the dataset end
+                if self._hf_total_items and (start_idx + items_processed) >= self._hf_total_items:
+                    logger.info(
+                        f"Reached dataset end at item {start_idx + items_processed} "
+                        f"(total: {self._hf_total_items})"
+                    )
+                    break
+
+                # Generate a unique key for this item - ensure proper formatting
+                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                if key in processed_keys:
+                    items_processed += 1
+                    continue
+
+                try:
+                    # Extract image data - check configured column name
+                    if self.image_column in item:
+                        img_data = item[self.image_column]
+
+                        # Delegate image processing to ImageProcessor
+                        image_bytes = ImageProcessor.process_image_data(img_data)
+
+                        if image_bytes:
+                            # URL is virtual for HF datasets
+                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                            items_processed += 1
+                            yield key, url, image_bytes
+                        else:
+                            logger.warning(
+                                f"Failed to process image for item at index {start_idx + items_processed}"
+                            )
+                            items_processed += 1
+                            continue
+                    else:
+                        logger.warning(
+                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                            f"Available columns: {list(item.keys())}"
+                        )
+                        items_processed += 1
+
+                except Exception as e:
+                    logger.error(
+                        f"Error processing item at index {start_idx + items_processed}: {e}"
+                    )
+                    items_processed += 1
+                    continue
+
+            logger.info(
+                f"Virtual shard complete: processed {items_processed} items "
+                f"(start_idx: {start_idx})"
+            )
 
-    def mark_complete(self, shard_name: str):
-        """Mark a shard as complete."""
-        self.completed_shards.add(shard_name)
-        if shard_name in self.partial_shards:
-            del self.partial_shards[shard_name]
-        self.save()
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset: {e}")
+            return
 
-    def update_partial(self, shard_name: str, processed_keys: List[str]):
-        """Update partial progress for a shard."""
-        self.partial_shards[shard_name] = {"keys": processed_keys, "count": len(processed_keys)}
-        self.save()
+    def iterate_shard_with_metadata(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+        """
+        Iterate over items in a shard, including metadata.
 
-    def get_processed_keys(self, shard_name: str) -> set:
-        """Get set of processed keys for a shard."""
-        if shard_name in self.completed_shards:
-            return set()  # All done
+        Yields:
+            Tuple of (key, url, image_bytes, metadata_dict)
+        """
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            yield from self._iterate_hf_dataset_shard_with_metadata(shard_url, processed_keys)
+        else:
+            # Regular WebDataset shard - no metadata by default
+            for key, url, image_data in self.iterate_shard(shard_url, processed_keys):
+                yield key, url, image_data, {}
 
-        if shard_name in self.partial_shards:
-            return set(self.partial_shards[shard_name].get("keys", []))
+    def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
+        """Count items in a shard (can be slow for large shards)."""
+        if shard_url.startswith("hf_dataset:"):
+            # For virtual shards, return the chunk size
+            _, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
+
+            # CRITICAL: Cap chunk size by dataset bounds
+            if self._hf_total_items is not None:
+                # If start index is beyond dataset, return 0
+                if start_idx >= self._hf_total_items:
+                    logger.warning(
+                        f"Virtual shard starts at {start_idx} but dataset has "
+                        f"only {self._hf_total_items} items"
+                    )
+                    return 0
+
+                # Otherwise, return the minimum of chunk_size and remaining items
+                remaining_items = self._hf_total_items - start_idx
+                actual_size = min(chunk_size, remaining_items)
+                logger.debug(
+                    f"Virtual shard at {start_idx}: chunk_size={chunk_size}, "
+                    f"remaining={remaining_items}, actual={actual_size}"
+                )
+                return actual_size
+            else:
+                # If we don't know total size, return chunk_size
+                return chunk_size
+        else:
+            # Regular WebDataset counting
+            count = 0
+            try:
+                for _ in self.iterate_shard(shard_url, processed_keys):
+                    count += 1
+            except Exception as e:
+                logger.error(f"Error counting shard {shard_url}: {e}")
+            return count
+
+    def get_dataset_info(self) -> Dict[str, Any]:
+        """Get information about the dataset."""
+        info = {
+            "dataset_path": self.dataset_path,
+            "dataset_type": self.dataset_type,
+            "dataset_format": self.dataset_format,
+        }
 
-        return set()
+        if self.dataset_format == "huggingface_datasets":
+            try:
+                # Try to get more info about the dataset
+                dataset_info = load_dataset(
+                    self.dataset_path, split=self.split, streaming=True, token=self.token
+                )
+                # Get features info
+                if hasattr(dataset_info, "features"):
+                    info["features"] = str(dataset_info.features)
+
+                # Try to get total size (might not work for all datasets)
+                try:
+                    # This might be expensive for large datasets
+                    total_examples = len(
+                        load_dataset(self.dataset_path, split=self.split, token=self.token)
+                    )
+                    info["total_examples"] = total_examples
+                    self._hf_total_items = total_examples
+                except:
+                    info["total_examples"] = "unknown"
 
-    def is_complete(self, shard_name: str) -> bool:
-        """Check if a shard is complete."""
-        return shard_name in self.completed_shards
+            except Exception as e:
+                logger.error(f"Error getting dataset info: {e}")
 
-    def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
-        """Get list of shards that still need processing."""
-        return [s for s in all_shards if Path(s).stem not in self.completed_shards]
+        return info
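Taken together, 0.2.1 lets a caller drive both kinds of repos through one interface while keeping row-level metadata: `iterate_shard_with_metadata()` yields every non-image column as a dict, and `count_shard_items()` returns without iterating for virtual shards. A resumable-consumer sketch under the same placeholder assumptions as the first example (import path, repo id, and the bookkeeping set are not from the diff):

```python
from caption_flow.dataset_loader import DatasetLoader  # hypothetical import path

loader = DatasetLoader("someuser/some-dataset", split="train", image_column="image")
already_done: set = set()  # keys captioned in a previous run (placeholder bookkeeping)

for shard in loader.get_shard_list():
    # For virtual shards this is min(chunk_size, rows remaining), computed without iteration.
    expected = loader.count_shard_items(shard)
    for key, url, image_bytes, metadata in loader.iterate_shard_with_metadata(
        shard, processed_keys=already_done
    ):
        # metadata carries every column except the image column; WebDataset shards yield {}.
        already_done.add(key)
```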