caption-flow 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,8 +9,6 @@ import json
 
 import webdataset as wds
 from huggingface_hub import HfFileSystem, get_token, hf_hub_url
-from datasets import load_dataset, Dataset
-from .image_processor import ImageProcessor
 
 logger = logging.getLogger(__name__)
 
@@ -24,6 +22,7 @@ class DatasetLoader:
         dataset_type: str = "huggingface",
         split: str = "train",
         image_column: str = "image",
+        cache_dir: Optional[Path] = None,
     ):
         """
         Initialize dataset loader.
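The only signature change in this hunk is the new `cache_dir` keyword. A minimal construction sketch, assuming the import path below and that the first positional argument is the dataset path; only the parameters visible in the hunk are confirmed by the diff:

```python
from pathlib import Path

from caption_flow.utils.dataset_loader import DatasetLoader  # assumed module path

loader = DatasetLoader(
    "username/my-dataset",        # assumed first parameter: dataset path / repo id
    dataset_type="huggingface",   # default shown in the diff
    split="train",                # default shown in the diff
    image_column="image",         # default shown in the diff
    cache_dir=Path("./cache"),    # new in 0.2.3; how it is used is not shown in this hunk
)
```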
@@ -40,8 +39,6 @@ class DatasetLoader:
         self.image_column = image_column
         self.token = get_token()
         self.dataset_format = None  # Will be detected: "webdataset" or "huggingface_datasets"
-        self._hf_dataset = None  # Cache for HuggingFace dataset
-        self._hf_total_items = None  # Cache for total items count
 
         if not self.token and dataset_type == "huggingface":
             logger.warning("No HuggingFace token found; run `huggingface-cli login`")
@@ -60,27 +57,18 @@ class DatasetLoader:
         if tar_files:
             return "webdataset"
 
-        # Check for parquet files (HuggingFace datasets)
+        # Check for .parquet files (Huggingface Arrow DB)
         parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
         if parquet_files:
             return "huggingface_datasets"
 
-        # Check for dataset_info.json or dataset_dict.json
-        if fs.exists(f"datasets/{self.dataset_path}/dataset_info.json") or fs.exists(
-            f"datasets/{self.dataset_path}/dataset_dict.json"
-        ):
-            return "huggingface_datasets"
-
-        logger.warning(f"Could not detect dataset format for {self.dataset_path}")
-        return "unknown"
+        raise AssertionError(f"Could not detect dataset format for {self.dataset_path}")
 
     def get_shard_list(self) -> List[str]:
         """Get list of all shards in the dataset."""
         if self.dataset_type == "huggingface":
             if self.dataset_format == "webdataset":
                 return self._get_hf_webdataset_shards()
-            elif self.dataset_format == "huggingface_datasets":
-                return self._get_hf_dataset_shards()
             else:
                 logger.error(f"Unknown dataset format: {self.dataset_format}")
                 return []
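Two behavioral changes are visible in this hunk: format detection no longer probes for `dataset_info.json` / `dataset_dict.json` and now raises an `AssertionError` instead of warning and returning `"unknown"`, and `get_shard_list()` drops its `huggingface_datasets` branch. A hedged sketch of a call site that wants the old soft-failure behavior, shown here around `get_shard_list()` although the raise may equally surface wherever the detection helper is invoked (its call site is not part of this diff):

```python
import logging

logger = logging.getLogger(__name__)

try:
    shards = loader.get_shard_list()  # `loader` from the sketch above
except AssertionError as exc:
    # 0.2.1 would have logged a warning and produced an empty shard list here.
    logger.warning("Dataset format could not be detected: %s", exc)
    shards = []
```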
@@ -101,60 +89,6 @@ class DatasetLoader:
         logger.info(f"Found {len(urls)} WebDataset shards")
         return sorted(urls)
 
-    def _get_hf_dataset_shards(self) -> List[str]:
-        """Get virtual 'shards' for HuggingFace datasets format."""
-        logger.info(f"Getting HuggingFace dataset info: {self.dataset_path}")
-
-        # For HuggingFace datasets, we'll create virtual shards based on chunks
-        # Each "shard" will be a range of indices
-        try:
-            # First, try to get available splits
-            try:
-                from datasets import get_dataset_split_names
-
-                available_splits = get_dataset_split_names(self.dataset_path, token=self.token)
-                logger.info(f"Available splits: {available_splits}")
-
-                if self.split not in available_splits:
-                    logger.warning(
-                        f"Requested split '{self.split}' not found. "
-                        f"Available splits: {available_splits}. "
-                        f"Using first available split: '{available_splits[0]}'"
-                    )
-                    self.split = available_splits[0]
-            except Exception as e:
-                logger.warning(f"Could not get split names: {e}")
-
-            # Load dataset info without downloading data
-            dataset_info = load_dataset(
-                self.dataset_path, split=self.split, streaming=True, token=self.token
-            )
-
-            # Try to get the total size
-            # For streaming datasets, we might need to iterate to count
-            # This is expensive, so we'll use a default chunk size instead
-            chunk_size = 10000  # Default chunk size for virtual shards
-
-            # Create virtual shard identifiers
-            # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
-            virtual_shards = []
-
-            # We'll create a reasonable number of virtual shards
-            # Without knowing the total size, we'll create them on-demand
-            # For now, create initial batch of virtual shards
-            for i in range(10):  # Start with 10 virtual shards
-                shard_id = f"hf_dataset:{self.dataset_path}:chunk:{i * chunk_size}"
-                virtual_shards.append(shard_id)
-
-            logger.info(
-                f"Created {len(virtual_shards)} initial virtual shards for HuggingFace dataset"
-            )
-            return virtual_shards
-
-        except Exception as e:
-            logger.error(f"Error loading HuggingFace dataset info: {e}")
-            return []
-
     def _get_local_shards(self) -> List[str]:
         """Get shard files from local directory."""
         path = Path(self.dataset_path)
@@ -176,13 +110,6 @@
         if processed_keys is None:
             processed_keys = set()
 
-        # Check if this is a virtual HuggingFace dataset shard
-        if shard_url.startswith("hf_dataset:"):
-            raise ValueError(
-                "Virtual HuggingFace dataset shards should use iterate_shard() directly, "
-                "not load_shard()"
-            )
-
         if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
             # Use curl with auth token for HuggingFace
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
@@ -203,447 +130,57 @@
 
         return ds
 
-    def _parse_virtual_shard(self, shard_url: str) -> Tuple[str, int, int]:
-        """Parse virtual shard identifier."""
-        # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
-        parts = shard_url.split(":")
-        if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
-            raise ValueError(f"Invalid virtual shard format: {shard_url}")
-
-        dataset_path = parts[1]
-        start_idx = int(parts[3])
-        chunk_size = 10000  # Default chunk size
-
-        return dataset_path, start_idx, chunk_size
-
     def iterate_shard(
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
+        self,
+        shard_url: str,
+        processed_keys: Optional[set] = None,
+        unprocessed_ranges: Optional[List[Tuple[int, int]]] = None,
+    ) -> Generator[Dict[str, Any], None, None]:
         """
-        Iterate over items in a shard.
+        Iterate over items in a shard, returning full sample dictionaries.
+
+        Args:
+            shard_url: URL or identifier of the shard
+            processed_keys: Set of already processed keys to skip
+            unprocessed_ranges: Specific ranges to process (for range-based processing)
 
         Yields:
-            Tuple of (key, url, image_bytes)
+            Dictionary containing the full WebDataset sample
         """
-        # Check if this is a virtual HuggingFace dataset shard
-        if shard_url.startswith("hf_dataset:"):
-            yield from self._iterate_hf_dataset_shard(shard_url, processed_keys)
-        else:
-            # Regular WebDataset shard
-            ds = self.load_shard(shard_url, processed_keys)
-            for key, url, image_data in ds:
-                yield key, url, image_data
-
-    def _create_dataset_at_position(self, dataset_path: str, split: str, start_idx: int):
-        """Create a dataset iterator positioned at start_idx using state_dict if available."""
-        try:
-            # Load dataset in streaming mode
-            dataset = load_dataset(
-                dataset_path,
-                split=split,
-                streaming=True,
-                token=self.token,
-            )
-
-            # Check if the dataset supports state_dict (newer versions of datasets library)
-            if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
-                # Try to use the dataset's native state management
-                try:
-                    # Get current state
-                    state = dataset.state_dict()
-
-                    # Modify the state to skip to start_idx
-                    if "epoch" in state:
-                        state["epoch"] = 0
-                    if "num_examples_since_previous_state" in state:
-                        state["num_examples_since_previous_state"] = start_idx
-
-                    # For newer datasets with examples_iterable state
-                    if "examples_iterable" in state:
-                        if isinstance(state["examples_iterable"], dict):
-                            if "shard_example_idx" in state["examples_iterable"]:
-                                state["examples_iterable"]["shard_example_idx"] = start_idx
-
-                    # Load the modified state
-                    dataset.load_state_dict(state)
-                    logger.info(f"Positioned dataset at index {start_idx} using state_dict")
-                    return dataset
-                except Exception as e:
-                    logger.debug(f"Could not use state_dict approach: {e}")
-
-            # Fall back to skip() for large skips
-            if start_idx > 0:
-                logger.info(f"Using skip() to position dataset at index {start_idx}")
-                dataset = dataset.skip(start_idx)
-
-            return dataset
-
-        except Exception as e:
-            logger.warning(f"Error creating positioned dataset: {e}")
-            return None
-
-    def _iterate_hf_dataset_shard_with_metadata(
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """Iterate over a virtual HuggingFace dataset shard with metadata."""
         if processed_keys is None:
             processed_keys = set()
 
-        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
-
-        logger.info(
-            f"Loading HuggingFace dataset with metadata: {dataset_path} (split: {self.split})"
-        )
-
-        try:
-            # Try optimized approach for large skips
-            if start_idx > 100:
-                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
-                if dataset:
-                    items_processed = 0
-
-                    for item in dataset:
-                        # Stop after processing chunk_size items
-                        if items_processed >= chunk_size:
-                            break
-
-                        # Generate a unique key for this item
-                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                        if key in processed_keys:
-                            items_processed += 1
-                            continue
-
-                        try:
-                            # Extract image data
-                            if self.image_column in item:
-                                img_data = item[self.image_column]
-
-                                # Process image to bytes
-                                image_bytes = ImageProcessor.process_image_data(img_data)
-
-                                if image_bytes:
-                                    # Extract all metadata (excluding the image column)
-                                    metadata = {
-                                        k: v for k, v in item.items() if k != self.image_column
-                                    }
-
-                                    # URL is virtual for HF datasets
-                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                                    items_processed += 1
-                                    yield key, url, image_bytes, metadata
-                                else:
-                                    logger.warning(
-                                        f"Failed to process image for item at index {start_idx + items_processed}"
-                                    )
-                                    items_processed += 1
-                                    continue
-                            else:
-                                logger.warning(
-                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                                    f"Available columns: {list(item.keys())}"
-                                )
-                                items_processed += 1
-
-                        except Exception as e:
-                            logger.error(
-                                f"Error processing item at index {start_idx + items_processed}: {e}"
-                            )
-                            items_processed += 1
-                            continue
-
-                    return
-
-            # Fall back to regular approach for small skips or if StatefulDataLoader not available
-            dataset = load_dataset(
-                dataset_path,
-                split=self.split,
-                streaming=True,
-                token=self.token,
-            )
-
-            # Skip to start index if needed
-            if start_idx > 0:
-                dataset = dataset.skip(start_idx)
-
-            items_processed = 0
-
-            for item in dataset:
-                # Stop after processing chunk_size items
-                if items_processed >= chunk_size:
-                    break
-
-                # Generate a unique key for this item
-                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                if key in processed_keys:
-                    items_processed += 1
-                    continue
-
-                try:
-                    # Extract image data
-                    if self.image_column in item:
-                        img_data = item[self.image_column]
-
-                        # Process image to bytes
-                        image_bytes = ImageProcessor.process_image_data(img_data)
-
-                        if image_bytes:
-                            # Extract all metadata (excluding the image column)
-                            metadata = {k: v for k, v in item.items() if k != self.image_column}
-
-                            # URL is virtual for HF datasets
-                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                            items_processed += 1
-                            yield key, url, image_bytes, metadata
-                        else:
-                            logger.warning(
-                                f"Failed to process image for item at index {start_idx + items_processed}"
-                            )
-                            items_processed += 1
-                            continue
-                    else:
-                        logger.warning(
-                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                            f"Available columns: {list(item.keys())}"
-                        )
-                        items_processed += 1
-
-                except Exception as e:
-                    logger.error(
-                        f"Error processing item at index {start_idx + items_processed}: {e}"
-                    )
-                    items_processed += 1
-                    continue
-
-        except Exception as e:
-            logger.error(f"Error loading HuggingFace dataset: {e}")
-            return
-
-    def _iterate_hf_dataset_shard(
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Iterate over a virtual HuggingFace dataset shard."""
-        if processed_keys is None:
-            processed_keys = set()
-
-        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
-
-        # IMPORTANT: Check if start_idx is beyond dataset bounds
-        if self._hf_total_items is not None and start_idx >= self._hf_total_items:
-            logger.warning(
-                f"Virtual shard starts at index {start_idx} but dataset only has "
-                f"{self._hf_total_items} items. Skipping this shard."
-            )
-            return
-
-        logger.info(
-            f"Loading HuggingFace dataset in streaming mode: {dataset_path} "
-            f"(split: {self.split}, start: {start_idx}, chunk_size: {chunk_size})"
-        )
-
-        try:
-            # Try optimized approach for large skips
-            if start_idx > 100:
-                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
-                if dataset:
-                    items_processed = 0
-
-                    for item in dataset:
-                        # Stop after processing chunk_size items
-                        if items_processed >= chunk_size:
-                            logger.info(f"Completed chunk: processed {items_processed} items")
-                            break
-
-                        # Also stop if we've reached the dataset end
-                        if (
-                            self._hf_total_items
-                            and (start_idx + items_processed) >= self._hf_total_items
-                        ):
-                            logger.info(
-                                f"Reached dataset end at item {start_idx + items_processed} "
-                                f"(total: {self._hf_total_items})"
-                            )
-                            break
-
-                        # Generate a unique key for this item
-                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                        if key in processed_keys:
-                            items_processed += 1
-                            continue
-
-                        try:
-                            # Extract image data
-                            if self.image_column in item:
-                                img_data = item[self.image_column]
-
-                                # Delegate image processing to ImageProcessor
-                                image_bytes = ImageProcessor.process_image_data(img_data)
-
-                                if image_bytes:
-                                    # URL is virtual for HF datasets
-                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                                    items_processed += 1
-                                    yield key, url, image_bytes
-                                else:
-                                    logger.warning(
-                                        f"Failed to process image for item at index {start_idx + items_processed}"
-                                    )
-                                    items_processed += 1
-                                    continue
-                            else:
-                                logger.warning(
-                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                                    f"Available columns: {list(item.keys())}"
-                                )
-                                items_processed += 1
-
-                        except Exception as e:
-                            logger.error(
-                                f"Error processing item at index {start_idx + items_processed}: {e}"
-                            )
-                            items_processed += 1
-                            continue
-
-                    logger.info(
-                        f"Virtual shard complete: processed {items_processed} items "
-                        f"(start_idx: {start_idx})"
-                    )
-                    return
-
-            # Fall back to regular approach for small skips or if StatefulDataLoader not available
-            dataset = load_dataset(
-                dataset_path,
-                split=self.split,
-                streaming=True,
-                token=self.token,
+        if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
+            # Use curl with auth token for HuggingFace
+            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(url_cmd),
+                wds.tarfile_to_samples(),
+                wds.select(lambda x: x.get("__key__", "") not in processed_keys),
             )
-
-            # Use dataset.skip() for efficient skipping
-            if start_idx > 0:
-                dataset = dataset.skip(start_idx)
-                logger.info(f"Skipped to index {start_idx}")
-
-            items_processed = 0
-
-            # Now enumerate starts from 0 after skip
-            for item in dataset:
-                # Stop after processing chunk_size items
-                if items_processed >= chunk_size:
-                    logger.info(f"Completed chunk: processed {items_processed} items")
-                    break
-
-                # Also stop if we've reached the dataset end
-                if self._hf_total_items and (start_idx + items_processed) >= self._hf_total_items:
-                    logger.info(
-                        f"Reached dataset end at item {start_idx + items_processed} "
-                        f"(total: {self._hf_total_items})"
-                    )
-                    break
-
-                # Generate a unique key for this item - ensure proper formatting
-                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                if key in processed_keys:
-                    items_processed += 1
-                    continue
-
-                try:
-                    # Extract image data - check configured column name
-                    if self.image_column in item:
-                        img_data = item[self.image_column]
-
-                        # Delegate image processing to ImageProcessor
-                        image_bytes = ImageProcessor.process_image_data(img_data)
-
-                        if image_bytes:
-                            # URL is virtual for HF datasets
-                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                            items_processed += 1
-                            yield key, url, image_bytes
-                        else:
-                            logger.warning(
-                                f"Failed to process image for item at index {start_idx + items_processed}"
-                            )
-                            items_processed += 1
-                            continue
-                    else:
-                        logger.warning(
-                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                            f"Available columns: {list(item.keys())}"
-                        )
-                        items_processed += 1
-
-                except Exception as e:
-                    logger.error(
-                        f"Error processing item at index {start_idx + items_processed}: {e}"
-                    )
-                    items_processed += 1
-                    continue
-
-            logger.info(
-                f"Virtual shard complete: processed {items_processed} items "
-                f"(start_idx: {start_idx})"
+        else:
+            # Local file access
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(shard_url),
+                wds.tarfile_to_samples(),
+                wds.select(lambda x: x.get("__key__", "") not in processed_keys),
             )
 
-        except Exception as e:
-            logger.error(f"Error loading HuggingFace dataset: {e}")
-            return
-
-    def iterate_shard_with_metadata(
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """
-        Iterate over items in a shard, including metadata.
-
-        Yields:
-            Tuple of (key, url, image_bytes, metadata_dict)
-        """
-        # Check if this is a virtual HuggingFace dataset shard
-        if shard_url.startswith("hf_dataset:"):
-            yield from self._iterate_hf_dataset_shard_with_metadata(shard_url, processed_keys)
-        else:
-            # Regular WebDataset shard - no metadata by default
-            for key, url, image_data in self.iterate_shard(shard_url, processed_keys):
-                yield key, url, image_data, {}
+        # Return full samples as dictionaries
+        for sample in ds:
+            # Ensure it's a dict and has required fields
+            if isinstance(sample, dict) and "__key__" in sample:
+                yield sample
 
     def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
         """Count items in a shard (can be slow for large shards)."""
-        if shard_url.startswith("hf_dataset:"):
-            # For virtual shards, return the chunk size
-            _, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
-
-            # CRITICAL: Cap chunk size by dataset bounds
-            if self._hf_total_items is not None:
-                # If start index is beyond dataset, return 0
-                if start_idx >= self._hf_total_items:
-                    logger.warning(
-                        f"Virtual shard starts at {start_idx} but dataset has "
-                        f"only {self._hf_total_items} items"
-                    )
-                    return 0
-
-                # Otherwise, return the minimum of chunk_size and remaining items
-                remaining_items = self._hf_total_items - start_idx
-                actual_size = min(chunk_size, remaining_items)
-                logger.debug(
-                    f"Virtual shard at {start_idx}: chunk_size={chunk_size}, "
-                    f"remaining={remaining_items}, actual={actual_size}"
-                )
-                return actual_size
-            else:
-                # If we don't know total size, return chunk_size
-                return chunk_size
-        else:
-            # Regular WebDataset counting
-            count = 0
-            try:
-                for _ in self.iterate_shard(shard_url, processed_keys):
-                    count += 1
-            except Exception as e:
-                logger.error(f"Error counting shard {shard_url}: {e}")
-            return count
+        count = 0
+        try:
+            for _ in self.iterate_shard(shard_url, processed_keys):
+                count += 1
+        except Exception as e:
+            logger.error(f"Error counting shard {shard_url}: {e}")
+        return count
 
     def get_dataset_info(self) -> Dict[str, Any]:
         """Get information about the dataset."""
@@ -654,27 +191,32 @@
         }
 
         if self.dataset_format == "huggingface_datasets":
-            try:
-                # Try to get more info about the dataset
-                dataset_info = load_dataset(
-                    self.dataset_path, split=self.split, streaming=True, token=self.token
-                )
-                # Get features info
-                if hasattr(dataset_info, "features"):
-                    info["features"] = str(dataset_info.features)
-
-                # Try to get total size (might not work for all datasets)
+            # Include cached metadata if available
+            if hasattr(self, "_hf_metadata"):
+                info.update(self._hf_metadata)
+            else:
+
                 try:
-                    # This might be expensive for large datasets
-                    total_examples = len(
-                        load_dataset(self.dataset_path, split=self.split, token=self.token)
+                    # Try to get more info about the dataset
+                    dataset_info = load_dataset(
+                        self.dataset_path, split=self.split, streaming=True, token=self.token
                     )
-                    info["total_examples"] = total_examples
-                    self._hf_total_items = total_examples
-                except:
-                    info["total_examples"] = "unknown"
+                    # Get features info
+                    if hasattr(dataset_info, "features"):
+                        info["features"] = str(dataset_info.features)
+
+                    # Try to get total size (might not work for all datasets)
+                    try:
+                        # This might be expensive for large datasets
+                        total_examples = len(
+                            load_dataset(self.dataset_path, split=self.split, token=self.token)
+                        )
+                        info["total_examples"] = total_examples
+                        self._hf_total_items = total_examples
+                    except:
+                        info["total_examples"] = "unknown"
 
-            except Exception as e:
-                logger.error(f"Error getting dataset info: {e}")
+                except Exception as e:
+                    logger.error(f"Error getting dataset info: {e}")
 
         return info
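In 0.2.3, `get_dataset_info()` prefers a cached `_hf_metadata` attribute and only falls back to probing the hub; where `_hf_metadata` is populated is not shown in this diff. Note also that the fallback still calls `load_dataset` even though the module-level `from datasets import load_dataset` was removed in the first hunk, so that import presumably happens elsewhere in 0.2.3; the diff alone does not settle this. A reading sketch:

```python
info = loader.get_dataset_info()  # `loader` from the sketch above

# Only keys visible in this diff; both are populated on a best-effort basis.
print(info.get("features"))
print(info.get("total_examples", "unknown"))
```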