caption-flow 0.2.2-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -73,12 +73,12 @@ class CheckpointTracker(ABC):
  logger.debug(f"Saved checkpoint to {self.checkpoint_path}")

  except Exception as e:
- logger.error(f"Error saving checkpoint: {e}", exc_info=True)
+ # logger.error(f"Error saving checkpoint: {e}", exc_info=True)
  # Try direct write as fallback
  try:
      with open(self.checkpoint_path, "w") as f:
          json.dump(data, f, indent=2)
- logger.info("Saved checkpoint using fallback direct write")
+ # logger.info("Saved checkpoint using fallback direct write")
  except Exception as fallback_error:
      logger.error(f"Fallback save also failed: {fallback_error}")

@@ -10,6 +10,7 @@ from dataclasses import dataclass, asdict, field
  from .checkpoint_tracker import CheckpointTracker

  logger = logging.getLogger(__name__)
+ # logger.setLevel(logging.DEBUG)


  @dataclass
@@ -58,11 +59,15 @@ class ChunkState:
  def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
      """Get ranges that haven't been processed yet."""
      if not self.processed_ranges:
+ logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
      return [(0, self.chunk_size - 1)]

      unprocessed = []
      current = 0

+ logger.info(
+ f"Processing {len(self.processed_ranges)} processed ranges for chunk {self.chunk_id}"
+ )
  for start, end in self.processed_ranges:
      if current < start:
          unprocessed.append((current, start - 1))
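The hunk above only adds logging around the existing range arithmetic. For readers skimming the diff, here is a minimal standalone sketch of that complement computation; the function name, the sorting, and the tail handling after the loop are assumptions, since only the early return and the gap test are visible in the hunk:

    from typing import List, Tuple

    def unprocessed_ranges(
        processed_ranges: List[Tuple[int, int]], chunk_size: int
    ) -> List[Tuple[int, int]]:
        """Complement of processed (start, end) ranges over [0, chunk_size - 1], bounds inclusive."""
        if not processed_ranges:
            return [(0, chunk_size - 1)]

        unprocessed: List[Tuple[int, int]] = []
        current = 0
        for start, end in sorted(processed_ranges):
            if current < start:
                # The gap before this processed range is still unprocessed
                unprocessed.append((current, start - 1))
            current = max(current, end + 1)

        # Assumed tail handling: anything after the last processed range
        if current <= chunk_size - 1:
            unprocessed.append((current, chunk_size - 1))
        return unprocessed

    # unprocessed_ranges([(0, 9), (20, 29)], 40) -> [(10, 19), (30, 39)]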
@@ -132,6 +137,11 @@ class ChunkTracker(CheckpointTracker):
  self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
  ) -> bool:
  """Add a new chunk. Returns False if chunk already exists and is completed."""
+ if chunk_id in self.chunks:
+ logger.debug(
+ f"Chunk {chunk_id} already exists with status: {self.chunks[chunk_id].status}, not creating"
+ )
+ return False
  if chunk_id in self.completed_chunks:
      logger.debug(f"Chunk {chunk_id} already completed, skipping")
      return False
@@ -166,7 +176,7 @@ class ChunkTracker(CheckpointTracker):
  chunk.completed_at = datetime.utcnow()
  self.completed_chunks.add(chunk_id)
  self.save()
- logger.info(f"Chunk {chunk_id} marked as completed")
+ logger.debug(f"Chunk {chunk_id} marked as completed")

  def mark_failed(self, chunk_id: str):
  """Mark chunk as failed."""
@@ -207,6 +217,49 @@ class ChunkTracker(CheckpointTracker):
  pending.append(chunk_id)
  return pending

+ def get_processed_indices_for_chunk(
+ self, chunk_id: str, processed_job_ids: Set[str]
+ ) -> List[Tuple[int, int]]:
+ """Convert processed job_ids back to ranges for a chunk."""
+ # Extract indices from job_ids like "data-0000:chunk:0:idx:42"
+ processed_indices = []
+ # this will be slow as shit, but it's simple for now, Proof of Concept.
+ for job_id in processed_job_ids:
+ test_chunk_id = chunk_id.replace("_", ":")
+ if test_chunk_id in job_id:
+ parts = job_id.split(":")
+ logger.debug(
+ f"Found matching job_id {job_id} for chunk {chunk_id} with {len(parts)=} and {parts[3]=}"
+ )
+ if len(parts) >= 5 and parts[3] == "idx":
+ idx = int(parts[4])
+ processed_indices.append(idx)
+
+ # Convert to ranges
+ if not processed_indices:
+ # logger.warning(
+ # f"Chunk {chunk_id} had no pre-processed ranges discovered, will process all elements"
+ # )
+ return []
+ else:
+ logger.debug(f"Chunk {chunk_id} has {len(processed_indices)} pre-processed indices")
+
+ processed_indices.sort()
+ ranges = []
+ start = processed_indices[0]
+ end = processed_indices[0]
+
+ for idx in processed_indices[1:]:
+ if idx == end + 1:
+ end = idx
+ else:
+ ranges.append((start, end))
+ start = idx
+ end = idx
+
+ ranges.append((start, end))
+ return ranges
+
  def is_shard_complete(self, shard_name: str) -> bool:
  """Check if all chunks for a shard are complete."""
  shard_chunks = [chunk for chunk in self.chunks.values() if chunk.shard_name == shard_name]
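The new get_processed_indices_for_chunk helper first collects the matching indices from job ids, then collapses the sorted indices into inclusive (start, end) ranges. A minimal sketch of just that merge step, with a hypothetical input, independent of the job-id parsing above:

    def indices_to_ranges(indices):
        """Collapse a collection of integer indices into sorted, inclusive (start, end) ranges."""
        if not indices:
            return []
        ordered = sorted(indices)
        ranges = []
        start = end = ordered[0]
        for idx in ordered[1:]:
            if idx == end + 1:
                end = idx          # extend the current contiguous run
            else:
                ranges.append((start, end))
                start = end = idx  # begin a new run
        ranges.append((start, end))
        return ranges

    # indices_to_ranges({0, 1, 2, 5, 6, 9}) -> [(0, 2), (5, 6), (9, 9)]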
@@ -236,20 +289,8 @@ class ChunkTracker(CheckpointTracker):

  for chunk_id, chunk_state in self.chunks.items():
  shard_name = chunk_state.shard_name
-
- # For virtual HF dataset shards, normalize the shard name
- if shard_name.startswith("hf_dataset:"):
- parts = shard_name.split(":")
- if len(parts) >= 4 and parts[2] == "chunk":
- # Use just the dataset identifier as the shard name
- normalized_shard_name = ":".join(parts[:2])
- else:
- normalized_shard_name = shard_name
- else:
- normalized_shard_name = shard_name
-
- if normalized_shard_name not in shards:
- shards[normalized_shard_name] = {
+ if shard_name not in shards:
+ shards[shard_name] = {
  "total_chunks": 0,
  "completed_chunks": 0,
  "pending_chunks": 0,
@@ -259,20 +300,20 @@ class ChunkTracker(CheckpointTracker):
  "chunks": [],
  }

- shards[normalized_shard_name]["chunks"].append(chunk_state)
- shards[normalized_shard_name]["total_chunks"] += 1
+ shards[shard_name]["chunks"].append(chunk_state)
+ shards[shard_name]["total_chunks"] += 1

  if chunk_state.status == "completed":
- shards[normalized_shard_name]["completed_chunks"] += 1
+ shards[shard_name]["completed_chunks"] += 1
  elif chunk_state.status == "pending":
- shards[normalized_shard_name]["pending_chunks"] += 1
- shards[normalized_shard_name]["is_complete"] = False
+ shards[shard_name]["pending_chunks"] += 1
+ shards[shard_name]["is_complete"] = False
  elif chunk_state.status == "assigned":
- shards[normalized_shard_name]["assigned_chunks"] += 1
- shards[normalized_shard_name]["is_complete"] = False
+ shards[shard_name]["assigned_chunks"] += 1
+ shards[shard_name]["is_complete"] = False
  elif chunk_state.status == "failed":
- shards[normalized_shard_name]["failed_chunks"] += 1
- shards[normalized_shard_name]["is_complete"] = False
+ shards[shard_name]["failed_chunks"] += 1
+ shards[shard_name]["is_complete"] = False

  return shards

@@ -322,13 +363,7 @@ class ChunkTracker(CheckpointTracker):
  continue

  # Infer shard URL and create chunk with default size
- if shard_name.replace("_", "/") in chunk_id or "_" in shard_name:
- # HF dataset
- dataset_path = shard_name.replace("_", "/")
- shard_url = f"hf_dataset:{dataset_path}:chunk:{start_idx}"
- else:
- # WebDataset
- shard_url = f"unknown://{shard_name}.tar"
+ shard_url = f"unknown://{shard_name}.tar"

  self.chunks[chunk_id] = ChunkState(
  chunk_id=chunk_id,
@@ -410,6 +445,7 @@ class ChunkTracker(CheckpointTracker):
  """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
  if chunk_id not in self.chunks:
  logger.error(f"Unknown chunk: {chunk_id}")
+ logger.debug(f"Known chunks: {list(self.chunks.keys())}")
  return

  chunk = self.chunks[chunk_id]
@@ -450,8 +486,13 @@ class ChunkTracker(CheckpointTracker):
  if not hasattr(self, "_startup_complete"):
  self._startup_complete = False

- if not self._startup_complete or not chunk_state.assigned_to:
+ if not self._startup_complete or (
+ not chunk_state.assigned_to or chunk_state.completed_at is None
+ ):
  # Return all unprocessed ranges
+ logger.debug(
+ f"Returning all unprocessed ranges. Status {self._startup_complete=} {chunk_state=}"
+ )
  return {
  "chunk_id": chunk_id,
  "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
@@ -9,8 +9,6 @@ import json

  import webdataset as wds
  from huggingface_hub import HfFileSystem, get_token, hf_hub_url
- from datasets import load_dataset, Dataset
- from .image_processor import ImageProcessor

  logger = logging.getLogger(__name__)

@@ -24,6 +22,7 @@ class DatasetLoader:
  dataset_type: str = "huggingface",
  split: str = "train",
  image_column: str = "image",
+ cache_dir: Optional[Path] = None,
  ):
  """
  Initialize dataset loader.
@@ -40,8 +39,6 @@ class DatasetLoader:
  self.image_column = image_column
  self.token = get_token()
  self.dataset_format = None # Will be detected: "webdataset" or "huggingface_datasets"
- self._hf_dataset = None # Cache for HuggingFace dataset
- self._hf_total_items = None # Cache for total items count

  if not self.token and dataset_type == "huggingface":
  logger.warning("No HuggingFace token found; run `huggingface-cli login`")
@@ -60,27 +57,18 @@ class DatasetLoader:
  if tar_files:
  return "webdataset"

- # Check for parquet files (HuggingFace datasets)
+ # Check for .parquet files (Huggingface Arrow DB)
  parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
  if parquet_files:
  return "huggingface_datasets"

- # Check for dataset_info.json or dataset_dict.json
- if fs.exists(f"datasets/{self.dataset_path}/dataset_info.json") or fs.exists(
- f"datasets/{self.dataset_path}/dataset_dict.json"
- ):
- return "huggingface_datasets"
-
- logger.warning(f"Could not detect dataset format for {self.dataset_path}")
- return "unknown"
+ raise AssertionError(f"Could not detect dataset format for {self.dataset_path}")

  def get_shard_list(self) -> List[str]:
  """Get list of all shards in the dataset."""
  if self.dataset_type == "huggingface":
  if self.dataset_format == "webdataset":
  return self._get_hf_webdataset_shards()
- elif self.dataset_format == "huggingface_datasets":
- return self._get_hf_dataset_shards()
  else:
  logger.error(f"Unknown dataset format: {self.dataset_format}")
  return []
@@ -101,60 +89,6 @@ class DatasetLoader:
  logger.info(f"Found {len(urls)} WebDataset shards")
  return sorted(urls)

- def _get_hf_dataset_shards(self) -> List[str]:
- """Get virtual 'shards' for HuggingFace datasets format."""
- logger.info(f"Getting HuggingFace dataset info: {self.dataset_path}")
-
- # For HuggingFace datasets, we'll create virtual shards based on chunks
- # Each "shard" will be a range of indices
- try:
- # First, try to get available splits
- try:
- from datasets import get_dataset_split_names
-
- available_splits = get_dataset_split_names(self.dataset_path, token=self.token)
- logger.info(f"Available splits: {available_splits}")
-
- if self.split not in available_splits:
- logger.warning(
- f"Requested split '{self.split}' not found. "
- f"Available splits: {available_splits}. "
- f"Using first available split: '{available_splits[0]}'"
- )
- self.split = available_splits[0]
- except Exception as e:
- logger.warning(f"Could not get split names: {e}")
-
- # Load dataset info without downloading data
- dataset_info = load_dataset(
- self.dataset_path, split=self.split, streaming=True, token=self.token
- )
-
- # Try to get the total size
- # For streaming datasets, we might need to iterate to count
- # This is expensive, so we'll use a default chunk size instead
- chunk_size = 10000 # Default chunk size for virtual shards
-
- # Create virtual shard identifiers
- # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
- virtual_shards = []
-
- # We'll create a reasonable number of virtual shards
- # Without knowing the total size, we'll create them on-demand
- # For now, create initial batch of virtual shards
- for i in range(10): # Start with 10 virtual shards
- shard_id = f"hf_dataset:{self.dataset_path}:chunk:{i * chunk_size}"
- virtual_shards.append(shard_id)
-
- logger.info(
- f"Created {len(virtual_shards)} initial virtual shards for HuggingFace dataset"
- )
- return virtual_shards
-
- except Exception as e:
- logger.error(f"Error loading HuggingFace dataset info: {e}")
- return []
-
  def _get_local_shards(self) -> List[str]:
  """Get shard files from local directory."""
  path = Path(self.dataset_path)
@@ -176,13 +110,6 @@ class DatasetLoader:
  if processed_keys is None:
  processed_keys = set()

- # Check if this is a virtual HuggingFace dataset shard
- if shard_url.startswith("hf_dataset:"):
- raise ValueError(
- "Virtual HuggingFace dataset shards should use iterate_shard() directly, "
- "not load_shard()"
- )
-
  if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
  # Use curl with auth token for HuggingFace
  url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
@@ -203,229 +130,57 @@ class DatasetLoader:

  return ds

- def _parse_virtual_shard(self, shard_url: str) -> Tuple[str, int, int]:
- """Parse virtual shard identifier."""
- # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
- parts = shard_url.split(":")
- if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
- raise ValueError(f"Invalid virtual shard format: {shard_url}")
-
- dataset_path = parts[1]
- start_idx = int(parts[3])
- chunk_size = 10000 # Default chunk size
-
- return dataset_path, start_idx, chunk_size
-
  def iterate_shard(
  self,
  shard_url: str,
  processed_keys: Optional[set] = None,
  unprocessed_ranges: Optional[List[Tuple[int, int]]] = None,
- ) -> Generator[Tuple[str, str, bytes], None, None]:
+ ) -> Generator[Dict[str, Any], None, None]:
  """
- Iterate over items in a shard.
+ Iterate over items in a shard, returning full sample dictionaries.

  Args:
  shard_url: URL or identifier of the shard
  processed_keys: Set of already processed keys to skip
- unprocessed_ranges: Specific ranges to process (for HF datasets)
+ unprocessed_ranges: Specific ranges to process (for range-based processing)

  Yields:
- Tuple of (key, url, image_bytes)
+ Dictionary containing the full WebDataset sample
  """
- if shard_url.startswith("hf_dataset:"):
- raise ValueError(
- "Virtual HuggingFace dataset shards should use iterate_shard_with_metadata()"
- )
- else:
- # Regular WebDataset shard
- ds = self.load_shard(shard_url, processed_keys)
- for key, url, image_data in ds:
- yield key, url, image_data
-
- def _create_dataset_at_position(self, dataset_path: str, split: str, start_idx: int):
- """Create a dataset iterator positioned at start_idx using state_dict if available."""
- try:
- # Load dataset in streaming mode
- dataset = load_dataset(
- dataset_path,
- split=split,
- streaming=True,
- token=self.token,
- )
-
- # Check if the dataset supports state_dict (newer versions of datasets library)
- if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
- # Try to use the dataset's native state management
- try:
- # Get current state
- state = dataset.state_dict()
-
- # Modify the state to skip to start_idx
- if "epoch" in state:
- state["epoch"] = 0
- if "num_examples_since_previous_state" in state:
- state["num_examples_since_previous_state"] = start_idx
-
- # For newer datasets with examples_iterable state
- if "examples_iterable" in state:
- if isinstance(state["examples_iterable"], dict):
- if "shard_example_idx" in state["examples_iterable"]:
- state["examples_iterable"]["shard_example_idx"] = start_idx
-
- # Load the modified state
- dataset.load_state_dict(state)
- logger.info(f"Positioned dataset at index {start_idx} using state_dict")
- return dataset
- except Exception as e:
- logger.debug(f"Could not use state_dict approach: {e}")
-
- # Fall back to skip() for large skips
- if start_idx > 0:
- logger.info(f"Using skip() to position dataset at index {start_idx}")
- dataset = dataset.skip(start_idx)
-
- return dataset
-
- except Exception as e:
- logger.warning(f"Error creating positioned dataset: {e}")
- return None
-
- def _iterate_hf_dataset_shard_with_metadata(
- self, shard_url: str, processed_keys: Optional[set] = None
- ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
- """Iterate over a virtual HuggingFace dataset shard with metadata."""
  if processed_keys is None:
  processed_keys = set()

- dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
-
- logger.info(
- f"Loading HuggingFace dataset with metadata: {dataset_path} (split: {self.split})"
- )
-
- try:
- # For HF datasets, we iterate through the full chunk range
- # The actual range filtering happens in the shard processor
- items_processed = 0
- current_abs_idx = start_idx
-
- while items_processed < chunk_size:
- # Create a fresh dataset iterator for each batch
- # This avoids issues with stateful iterators
- batch_size = min(1000, chunk_size - items_processed) # Process in smaller batches
-
- dataset = load_dataset(
- dataset_path,
- split=self.split,
- streaming=True,
- token=self.token,
- )
-
- # Skip to current position
- if current_abs_idx > 0:
- dataset = dataset.skip(current_abs_idx)
-
- batch_processed = 0
- for item in dataset:
- if batch_processed >= batch_size or items_processed >= chunk_size:
- break
-
- # Generate key
- key = f"{dataset_path.replace('/', '_')}_{current_abs_idx:08d}"
-
- if key in processed_keys:
- current_abs_idx += 1
- batch_processed += 1
- items_processed += 1
- continue
-
- try:
- if self.image_column in item:
- img_data = item[self.image_column]
- image_bytes = ImageProcessor.process_image_data(img_data)
-
- if image_bytes:
- metadata = {k: v for k, v in item.items() if k != self.image_column}
- url = f"hf://{dataset_path}#{current_abs_idx}"
-
- yield key, url, image_bytes, metadata
-
- current_abs_idx += 1
- batch_processed += 1
- items_processed += 1
- else:
- logger.warning(
- f"No image column '{self.image_column}' at index {current_abs_idx}"
- )
- current_abs_idx += 1
- batch_processed += 1
- items_processed += 1
-
- except Exception as e:
- logger.error(f"Error processing item at index {current_abs_idx}: {e}")
- current_abs_idx += 1
- batch_processed += 1
- items_processed += 1
- continue
-
- except Exception as e:
- logger.error(f"Error loading HuggingFace dataset: {e}")
- return
-
- def iterate_shard_with_metadata(
- self, shard_url: str, processed_keys: Optional[set] = None
- ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
- """
- Iterate over items in a shard, including metadata.
-
- Yields:
- Tuple of (key, url, image_bytes, metadata_dict)
- """
- # Check if this is a virtual HuggingFace dataset shard
- if shard_url.startswith("hf_dataset:"):
- yield from self._iterate_hf_dataset_shard_with_metadata(shard_url, processed_keys)
+ if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
+ # Use curl with auth token for HuggingFace
+ url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
+ ds = wds.DataPipeline(
+ wds.SimpleShardList(url_cmd),
+ wds.tarfile_to_samples(),
+ wds.select(lambda x: x.get("__key__", "") not in processed_keys),
+ )
  else:
- # Regular WebDataset shard - no metadata by default
- for key, url, image_data in self.iterate_shard(shard_url, processed_keys):
- yield key, url, image_data, {}
+ # Local file access
+ ds = wds.DataPipeline(
+ wds.SimpleShardList(shard_url),
+ wds.tarfile_to_samples(),
+ wds.select(lambda x: x.get("__key__", "") not in processed_keys),
+ )
+
+ # Return full samples as dictionaries
+ for sample in ds:
+ # Ensure it's a dict and has required fields
+ if isinstance(sample, dict) and "__key__" in sample:
+ yield sample

  def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
  """Count items in a shard (can be slow for large shards)."""
- if shard_url.startswith("hf_dataset:"):
- # For virtual shards, return the chunk size
- _, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
-
- # CRITICAL: Cap chunk size by dataset bounds
- if self._hf_total_items is not None:
- # If start index is beyond dataset, return 0
- if start_idx >= self._hf_total_items:
- logger.warning(
- f"Virtual shard starts at {start_idx} but dataset has "
- f"only {self._hf_total_items} items"
- )
- return 0
-
- # Otherwise, return the minimum of chunk_size and remaining items
- remaining_items = self._hf_total_items - start_idx
- actual_size = min(chunk_size, remaining_items)
- logger.debug(
- f"Virtual shard at {start_idx}: chunk_size={chunk_size}, "
- f"remaining={remaining_items}, actual={actual_size}"
- )
- return actual_size
- else:
- # If we don't know total size, return chunk_size
- return chunk_size
- else:
- # Regular WebDataset counting
- count = 0
- try:
- for _ in self.iterate_shard(shard_url, processed_keys):
- count += 1
- except Exception as e:
- logger.error(f"Error counting shard {shard_url}: {e}")
- return count
+ count = 0
+ try:
+ for _ in self.iterate_shard(shard_url, processed_keys):
+ count += 1
+ except Exception as e:
+ logger.error(f"Error counting shard {shard_url}: {e}")
+ return count

  def get_dataset_info(self) -> Dict[str, Any]:
  """Get information about the dataset."""
@@ -436,27 +191,32 @@ class DatasetLoader:
  }

  if self.dataset_format == "huggingface_datasets":
- try:
- # Try to get more info about the dataset
- dataset_info = load_dataset(
- self.dataset_path, split=self.split, streaming=True, token=self.token
- )
- # Get features info
- if hasattr(dataset_info, "features"):
- info["features"] = str(dataset_info.features)
-
- # Try to get total size (might not work for all datasets)
+ # Include cached metadata if available
+ if hasattr(self, "_hf_metadata"):
+ info.update(self._hf_metadata)
+ else:
+
  try:
- # This might be expensive for large datasets
- total_examples = len(
- load_dataset(self.dataset_path, split=self.split, token=self.token)
+ # Try to get more info about the dataset
+ dataset_info = load_dataset(
+ self.dataset_path, split=self.split, streaming=True, token=self.token
  )
- info["total_examples"] = total_examples
- self._hf_total_items = total_examples
- except:
- info["total_examples"] = "unknown"
+ # Get features info
+ if hasattr(dataset_info, "features"):
+ info["features"] = str(dataset_info.features)
+
+ # Try to get total size (might not work for all datasets)
+ try:
+ # This might be expensive for large datasets
+ total_examples = len(
+ load_dataset(self.dataset_path, split=self.split, token=self.token)
+ )
+ info["total_examples"] = total_examples
+ self._hf_total_items = total_examples
+ except:
+ info["total_examples"] = "unknown"

- except Exception as e:
- logger.error(f"Error getting dataset info: {e}")
+ except Exception as e:
+ logger.error(f"Error getting dataset info: {e}")

  return info