caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +2 -1
- caption_flow/models.py +108 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1595
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage.py +415 -406
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +94 -35
- caption_flow/utils/dataset_loader.py +64 -522
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +4 -200
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/METADATA +29 -27
- caption_flow-0.2.3.dist-info/RECORD +35 -0
- caption_flow-0.2.1.dist-info/RECORD +0 -29
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/top_level.txt +0 -0
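All hunks shown below come from caption_flow/utils/dataset_loader.py (+64 −522). The headline change for consumers of that module is the `iterate_shard()` contract: 0.2.1 yielded `(key, url, image_bytes)` tuples (with a separate metadata variant), while 0.2.3 yields full WebDataset sample dictionaries keyed by `__key__`. A minimal before/after sketch, assuming a `dataset_path` constructor argument and typical WebDataset extension keys such as `jpg`/`png` (neither is spelled out in this diff):

```python
from pathlib import Path

from caption_flow.utils.dataset_loader import DatasetLoader

# Constructor arguments other than dataset_type/split/image_column/cache_dir are
# assumed from attribute names seen in the diff, not from a documented API.
loader = DatasetLoader(
    dataset_path="someuser/some-webdataset",  # hypothetical repo id
    dataset_type="huggingface",
    split="train",
    image_column="image",
    cache_dir=Path("/tmp/caption_flow_cache"),  # new optional parameter in 0.2.3
)

shard_url = loader.get_shard_list()[0]

# 0.2.1: for key, url, image_bytes in loader.iterate_shard(shard_url): ...
# 0.2.3: full WebDataset sample dicts, pre-filtered to entries carrying "__key__"
for sample in loader.iterate_shard(shard_url, processed_keys=set()):
    key = sample["__key__"]
    image_bytes = sample.get("jpg") or sample.get("png")  # extension keys assumed
```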
@@ -9,8 +9,6 @@ import json
 
 import webdataset as wds
 from huggingface_hub import HfFileSystem, get_token, hf_hub_url
-from datasets import load_dataset, Dataset
-from .image_processor import ImageProcessor
 
 logger = logging.getLogger(__name__)
 
@@ -24,6 +22,7 @@ class DatasetLoader:
         dataset_type: str = "huggingface",
         split: str = "train",
         image_column: str = "image",
+        cache_dir: Optional[Path] = None,
     ):
         """
         Initialize dataset loader.
@@ -40,8 +39,6 @@ class DatasetLoader:
         self.image_column = image_column
         self.token = get_token()
         self.dataset_format = None  # Will be detected: "webdataset" or "huggingface_datasets"
-        self._hf_dataset = None  # Cache for HuggingFace dataset
-        self._hf_total_items = None  # Cache for total items count
 
         if not self.token and dataset_type == "huggingface":
            logger.warning("No HuggingFace token found; run `huggingface-cli login`")
@@ -60,27 +57,18 @@ class DatasetLoader:
         if tar_files:
             return "webdataset"
 
-        # Check for parquet files (
+        # Check for .parquet files (Huggingface Arrow DB)
         parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
         if parquet_files:
             return "huggingface_datasets"
 
-
-        if fs.exists(f"datasets/{self.dataset_path}/dataset_info.json") or fs.exists(
-            f"datasets/{self.dataset_path}/dataset_dict.json"
-        ):
-            return "huggingface_datasets"
-
-        logger.warning(f"Could not detect dataset format for {self.dataset_path}")
-        return "unknown"
+        raise AssertionError(f"Could not detect dataset format for {self.dataset_path}")
 
     def get_shard_list(self) -> List[str]:
         """Get list of all shards in the dataset."""
         if self.dataset_type == "huggingface":
             if self.dataset_format == "webdataset":
                 return self._get_hf_webdataset_shards()
-            elif self.dataset_format == "huggingface_datasets":
-                return self._get_hf_dataset_shards()
             else:
                 logger.error(f"Unknown dataset format: {self.dataset_format}")
                 return []
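Because format detection now raises instead of returning the sentinel string `"unknown"`, callers that previously checked for that value need an exception path. A defensive sketch; which call site actually triggers detection is not visible in this hunk, so `get_shard_list()` stands in for it, and `loader` is the instance from the first sketch:

```python
import logging

logger = logging.getLogger(__name__)

try:
    shards = loader.get_shard_list()  # detection assumed to run before/inside shard listing
except AssertionError as exc:
    # 0.2.1 logged a warning and returned "unknown"; 0.2.3 raises when a repo
    # contains neither .tar nor .parquet files.
    logger.warning("Could not detect dataset format: %s", exc)
    shards = []
```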
@@ -101,60 +89,6 @@ class DatasetLoader:
         logger.info(f"Found {len(urls)} WebDataset shards")
         return sorted(urls)
 
-    def _get_hf_dataset_shards(self) -> List[str]:
-        """Get virtual 'shards' for HuggingFace datasets format."""
-        logger.info(f"Getting HuggingFace dataset info: {self.dataset_path}")
-
-        # For HuggingFace datasets, we'll create virtual shards based on chunks
-        # Each "shard" will be a range of indices
-        try:
-            # First, try to get available splits
-            try:
-                from datasets import get_dataset_split_names
-
-                available_splits = get_dataset_split_names(self.dataset_path, token=self.token)
-                logger.info(f"Available splits: {available_splits}")
-
-                if self.split not in available_splits:
-                    logger.warning(
-                        f"Requested split '{self.split}' not found. "
-                        f"Available splits: {available_splits}. "
-                        f"Using first available split: '{available_splits[0]}'"
-                    )
-                    self.split = available_splits[0]
-            except Exception as e:
-                logger.warning(f"Could not get split names: {e}")
-
-            # Load dataset info without downloading data
-            dataset_info = load_dataset(
-                self.dataset_path, split=self.split, streaming=True, token=self.token
-            )
-
-            # Try to get the total size
-            # For streaming datasets, we might need to iterate to count
-            # This is expensive, so we'll use a default chunk size instead
-            chunk_size = 10000  # Default chunk size for virtual shards
-
-            # Create virtual shard identifiers
-            # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
-            virtual_shards = []
-
-            # We'll create a reasonable number of virtual shards
-            # Without knowing the total size, we'll create them on-demand
-            # For now, create initial batch of virtual shards
-            for i in range(10):  # Start with 10 virtual shards
-                shard_id = f"hf_dataset:{self.dataset_path}:chunk:{i * chunk_size}"
-                virtual_shards.append(shard_id)
-
-            logger.info(
-                f"Created {len(virtual_shards)} initial virtual shards for HuggingFace dataset"
-            )
-            return virtual_shards
-
-        except Exception as e:
-            logger.error(f"Error loading HuggingFace dataset info: {e}")
-            return []
-
     def _get_local_shards(self) -> List[str]:
         """Get shard files from local directory."""
         path = Path(self.dataset_path)
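The virtual-shard addressing removed here used string identifiers of the form `hf_dataset:<dataset_path>:chunk:<start_idx>`. For anyone with old checkpoints or chunk-tracker state that still contains such identifiers, this reference sketch mirrors the removed parsing logic; it is not part of the 0.2.3 API:

```python
def parse_virtual_shard(shard_id: str) -> tuple[str, int]:
    """Split a retired 0.2.1 virtual shard id into (dataset_path, start_idx)."""
    parts = shard_id.split(":")
    if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
        raise ValueError(f"Invalid virtual shard format: {shard_id}")
    return parts[1], int(parts[3])


dataset_path, start_idx = parse_virtual_shard("hf_dataset:someuser/some-dataset:chunk:20000")
```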
@@ -176,13 +110,6 @@ class DatasetLoader:
         if processed_keys is None:
             processed_keys = set()
 
-        # Check if this is a virtual HuggingFace dataset shard
-        if shard_url.startswith("hf_dataset:"):
-            raise ValueError(
-                "Virtual HuggingFace dataset shards should use iterate_shard() directly, "
-                "not load_shard()"
-            )
-
         if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
             # Use curl with auth token for HuggingFace
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
@@ -203,447 +130,57 @@ class DatasetLoader:
 
         return ds
 
-    def _parse_virtual_shard(self, shard_url: str) -> Tuple[str, int, int]:
-        """Parse virtual shard identifier."""
-        # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
-        parts = shard_url.split(":")
-        if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
-            raise ValueError(f"Invalid virtual shard format: {shard_url}")
-
-        dataset_path = parts[1]
-        start_idx = int(parts[3])
-        chunk_size = 10000  # Default chunk size
-
-        return dataset_path, start_idx, chunk_size
-
     def iterate_shard(
-        self,
-
+        self,
+        shard_url: str,
+        processed_keys: Optional[set] = None,
+        unprocessed_ranges: Optional[List[Tuple[int, int]]] = None,
+    ) -> Generator[Dict[str, Any], None, None]:
         """
-        Iterate over items in a shard.
+        Iterate over items in a shard, returning full sample dictionaries.
+
+        Args:
+            shard_url: URL or identifier of the shard
+            processed_keys: Set of already processed keys to skip
+            unprocessed_ranges: Specific ranges to process (for range-based processing)
 
         Yields:
-
+            Dictionary containing the full WebDataset sample
         """
-        # Check if this is a virtual HuggingFace dataset shard
-        if shard_url.startswith("hf_dataset:"):
-            yield from self._iterate_hf_dataset_shard(shard_url, processed_keys)
-        else:
-            # Regular WebDataset shard
-            ds = self.load_shard(shard_url, processed_keys)
-            for key, url, image_data in ds:
-                yield key, url, image_data
-
-    def _create_dataset_at_position(self, dataset_path: str, split: str, start_idx: int):
-        """Create a dataset iterator positioned at start_idx using state_dict if available."""
-        try:
-            # Load dataset in streaming mode
-            dataset = load_dataset(
-                dataset_path,
-                split=split,
-                streaming=True,
-                token=self.token,
-            )
-
-            # Check if the dataset supports state_dict (newer versions of datasets library)
-            if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
-                # Try to use the dataset's native state management
-                try:
-                    # Get current state
-                    state = dataset.state_dict()
-
-                    # Modify the state to skip to start_idx
-                    if "epoch" in state:
-                        state["epoch"] = 0
-                    if "num_examples_since_previous_state" in state:
-                        state["num_examples_since_previous_state"] = start_idx
-
-                    # For newer datasets with examples_iterable state
-                    if "examples_iterable" in state:
-                        if isinstance(state["examples_iterable"], dict):
-                            if "shard_example_idx" in state["examples_iterable"]:
-                                state["examples_iterable"]["shard_example_idx"] = start_idx
-
-                    # Load the modified state
-                    dataset.load_state_dict(state)
-                    logger.info(f"Positioned dataset at index {start_idx} using state_dict")
-                    return dataset
-                except Exception as e:
-                    logger.debug(f"Could not use state_dict approach: {e}")
-
-            # Fall back to skip() for large skips
-            if start_idx > 0:
-                logger.info(f"Using skip() to position dataset at index {start_idx}")
-                dataset = dataset.skip(start_idx)
-
-            return dataset
-
-        except Exception as e:
-            logger.warning(f"Error creating positioned dataset: {e}")
-            return None
-
-    def _iterate_hf_dataset_shard_with_metadata(
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """Iterate over a virtual HuggingFace dataset shard with metadata."""
         if processed_keys is None:
             processed_keys = set()
 
-
-
-
-
-
-
-
-            # Try optimized approach for large skips
-            if start_idx > 100:
-                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
-                if dataset:
-                    items_processed = 0
-
-                    for item in dataset:
-                        # Stop after processing chunk_size items
-                        if items_processed >= chunk_size:
-                            break
-
-                        # Generate a unique key for this item
-                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                        if key in processed_keys:
-                            items_processed += 1
-                            continue
-
-                        try:
-                            # Extract image data
-                            if self.image_column in item:
-                                img_data = item[self.image_column]
-
-                                # Process image to bytes
-                                image_bytes = ImageProcessor.process_image_data(img_data)
-
-                                if image_bytes:
-                                    # Extract all metadata (excluding the image column)
-                                    metadata = {
-                                        k: v for k, v in item.items() if k != self.image_column
-                                    }
-
-                                    # URL is virtual for HF datasets
-                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                                    items_processed += 1
-                                    yield key, url, image_bytes, metadata
-                                else:
-                                    logger.warning(
-                                        f"Failed to process image for item at index {start_idx + items_processed}"
-                                    )
-                                    items_processed += 1
-                                    continue
-                            else:
-                                logger.warning(
-                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                                    f"Available columns: {list(item.keys())}"
-                                )
-                                items_processed += 1
-
-                        except Exception as e:
-                            logger.error(
-                                f"Error processing item at index {start_idx + items_processed}: {e}"
-                            )
-                            items_processed += 1
-                            continue
-
-                    return
-
-            # Fall back to regular approach for small skips or if StatefulDataLoader not available
-            dataset = load_dataset(
-                dataset_path,
-                split=self.split,
-                streaming=True,
-                token=self.token,
-            )
-
-            # Skip to start index if needed
-            if start_idx > 0:
-                dataset = dataset.skip(start_idx)
-
-            items_processed = 0
-
-            for item in dataset:
-                # Stop after processing chunk_size items
-                if items_processed >= chunk_size:
-                    break
-
-                # Generate a unique key for this item
-                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                if key in processed_keys:
-                    items_processed += 1
-                    continue
-
-                try:
-                    # Extract image data
-                    if self.image_column in item:
-                        img_data = item[self.image_column]
-
-                        # Process image to bytes
-                        image_bytes = ImageProcessor.process_image_data(img_data)
-
-                        if image_bytes:
-                            # Extract all metadata (excluding the image column)
-                            metadata = {k: v for k, v in item.items() if k != self.image_column}
-
-                            # URL is virtual for HF datasets
-                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                            items_processed += 1
-                            yield key, url, image_bytes, metadata
-                        else:
-                            logger.warning(
-                                f"Failed to process image for item at index {start_idx + items_processed}"
-                            )
-                            items_processed += 1
-                            continue
-                    else:
-                        logger.warning(
-                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                            f"Available columns: {list(item.keys())}"
-                        )
-                        items_processed += 1
-
-                except Exception as e:
-                    logger.error(
-                        f"Error processing item at index {start_idx + items_processed}: {e}"
-                    )
-                    items_processed += 1
-                    continue
-
-        except Exception as e:
-            logger.error(f"Error loading HuggingFace dataset: {e}")
-            return
-
-    def _iterate_hf_dataset_shard(
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Iterate over a virtual HuggingFace dataset shard."""
-        if processed_keys is None:
-            processed_keys = set()
-
-        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
-
-        # IMPORTANT: Check if start_idx is beyond dataset bounds
-        if self._hf_total_items is not None and start_idx >= self._hf_total_items:
-            logger.warning(
-                f"Virtual shard starts at index {start_idx} but dataset only has "
-                f"{self._hf_total_items} items. Skipping this shard."
-            )
-            return
-
-        logger.info(
-            f"Loading HuggingFace dataset in streaming mode: {dataset_path} "
-            f"(split: {self.split}, start: {start_idx}, chunk_size: {chunk_size})"
-        )
-
-        try:
-            # Try optimized approach for large skips
-            if start_idx > 100:
-                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
-                if dataset:
-                    items_processed = 0
-
-                    for item in dataset:
-                        # Stop after processing chunk_size items
-                        if items_processed >= chunk_size:
-                            logger.info(f"Completed chunk: processed {items_processed} items")
-                            break
-
-                        # Also stop if we've reached the dataset end
-                        if (
-                            self._hf_total_items
-                            and (start_idx + items_processed) >= self._hf_total_items
-                        ):
-                            logger.info(
-                                f"Reached dataset end at item {start_idx + items_processed} "
-                                f"(total: {self._hf_total_items})"
-                            )
-                            break
-
-                        # Generate a unique key for this item
-                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                        if key in processed_keys:
-                            items_processed += 1
-                            continue
-
-                        try:
-                            # Extract image data
-                            if self.image_column in item:
-                                img_data = item[self.image_column]
-
-                                # Delegate image processing to ImageProcessor
-                                image_bytes = ImageProcessor.process_image_data(img_data)
-
-                                if image_bytes:
-                                    # URL is virtual for HF datasets
-                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                                    items_processed += 1
-                                    yield key, url, image_bytes
-                                else:
-                                    logger.warning(
-                                        f"Failed to process image for item at index {start_idx + items_processed}"
-                                    )
-                                    items_processed += 1
-                                    continue
-                            else:
-                                logger.warning(
-                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                                    f"Available columns: {list(item.keys())}"
-                                )
-                                items_processed += 1
-
-                        except Exception as e:
-                            logger.error(
-                                f"Error processing item at index {start_idx + items_processed}: {e}"
-                            )
-                            items_processed += 1
-                            continue
-
-                    logger.info(
-                        f"Virtual shard complete: processed {items_processed} items "
-                        f"(start_idx: {start_idx})"
-                    )
-                    return
-
-            # Fall back to regular approach for small skips or if StatefulDataLoader not available
-            dataset = load_dataset(
-                dataset_path,
-                split=self.split,
-                streaming=True,
-                token=self.token,
+        if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
+            # Use curl with auth token for HuggingFace
+            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(url_cmd),
+                wds.tarfile_to_samples(),
+                wds.select(lambda x: x.get("__key__", "") not in processed_keys),
             )
-
-            #
-
-
-
-
-            items_processed = 0
-
-            # Now enumerate starts from 0 after skip
-            for item in dataset:
-                # Stop after processing chunk_size items
-                if items_processed >= chunk_size:
-                    logger.info(f"Completed chunk: processed {items_processed} items")
-                    break
-
-                # Also stop if we've reached the dataset end
-                if self._hf_total_items and (start_idx + items_processed) >= self._hf_total_items:
-                    logger.info(
-                        f"Reached dataset end at item {start_idx + items_processed} "
-                        f"(total: {self._hf_total_items})"
-                    )
-                    break
-
-                # Generate a unique key for this item - ensure proper formatting
-                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
-
-                if key in processed_keys:
-                    items_processed += 1
-                    continue
-
-                try:
-                    # Extract image data - check configured column name
-                    if self.image_column in item:
-                        img_data = item[self.image_column]
-
-                        # Delegate image processing to ImageProcessor
-                        image_bytes = ImageProcessor.process_image_data(img_data)
-
-                        if image_bytes:
-                            # URL is virtual for HF datasets
-                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
-                            items_processed += 1
-                            yield key, url, image_bytes
-                        else:
-                            logger.warning(
-                                f"Failed to process image for item at index {start_idx + items_processed}"
-                            )
-                            items_processed += 1
-                            continue
-                    else:
-                        logger.warning(
-                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
-                            f"Available columns: {list(item.keys())}"
-                        )
-                        items_processed += 1
-
-                except Exception as e:
-                    logger.error(
-                        f"Error processing item at index {start_idx + items_processed}: {e}"
-                    )
-                    items_processed += 1
-                    continue
-
-            logger.info(
-                f"Virtual shard complete: processed {items_processed} items "
-                f"(start_idx: {start_idx})"
+        else:
+            # Local file access
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(shard_url),
+                wds.tarfile_to_samples(),
+                wds.select(lambda x: x.get("__key__", "") not in processed_keys),
             )
 
-
-
-
-
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """
-        Iterate over items in a shard, including metadata.
-
-        Yields:
-            Tuple of (key, url, image_bytes, metadata_dict)
-        """
-        # Check if this is a virtual HuggingFace dataset shard
-        if shard_url.startswith("hf_dataset:"):
-            yield from self._iterate_hf_dataset_shard_with_metadata(shard_url, processed_keys)
-        else:
-            # Regular WebDataset shard - no metadata by default
-            for key, url, image_data in self.iterate_shard(shard_url, processed_keys):
-                yield key, url, image_data, {}
+        # Return full samples as dictionaries
+        for sample in ds:
+            # Ensure it's a dict and has required fields
+            if isinstance(sample, dict) and "__key__" in sample:
+                yield sample
 
     def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
         """Count items in a shard (can be slow for large shards)."""
-
-
-
-
-
-
-
-            if start_idx >= self._hf_total_items:
-                logger.warning(
-                    f"Virtual shard starts at {start_idx} but dataset has "
-                    f"only {self._hf_total_items} items"
-                )
-                return 0
-
-            # Otherwise, return the minimum of chunk_size and remaining items
-            remaining_items = self._hf_total_items - start_idx
-            actual_size = min(chunk_size, remaining_items)
-            logger.debug(
-                f"Virtual shard at {start_idx}: chunk_size={chunk_size}, "
-                f"remaining={remaining_items}, actual={actual_size}"
-            )
-            return actual_size
-        else:
-            # If we don't know total size, return chunk_size
-            return chunk_size
-        else:
-            # Regular WebDataset counting
-            count = 0
-            try:
-                for _ in self.iterate_shard(shard_url, processed_keys):
-                    count += 1
-            except Exception as e:
-                logger.error(f"Error counting shard {shard_url}: {e}")
-            return count
+        count = 0
+        try:
+            for _ in self.iterate_shard(shard_url, processed_keys):
+                count += 1
+        except Exception as e:
+            logger.error(f"Error counting shard {shard_url}: {e}")
+        return count
 
     def get_dataset_info(self) -> Dict[str, Any]:
         """Get information about the dataset."""
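For reference, the pipeline shape that the new `iterate_shard()` builds can be reproduced standalone with the `webdataset` primitives used in the added lines above; the shard path and key below are illustrative only:

```python
import webdataset as wds

processed_keys = {"sample_000123"}  # keys already captioned, e.g. loaded from storage

ds = wds.DataPipeline(
    wds.SimpleShardList("/data/shards/shard-000000.tar"),  # hypothetical local shard
    wds.tarfile_to_samples(),
    wds.select(lambda s: s.get("__key__", "") not in processed_keys),
)

for sample in ds:
    # Each sample is a dict: "__key__", "__url__", plus one entry per file extension.
    print(sample["__key__"], [k for k in sample if not k.startswith("__")])
```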
@@ -654,27 +191,32 @@ class DatasetLoader:
         }
 
         if self.dataset_format == "huggingface_datasets":
-
-
-
-
-
-            # Get features info
-            if hasattr(dataset_info, "features"):
-                info["features"] = str(dataset_info.features)
-
-            # Try to get total size (might not work for all datasets)
+            # Include cached metadata if available
+            if hasattr(self, "_hf_metadata"):
+                info.update(self._hf_metadata)
+            else:
+
                 try:
-                    #
-
-
+                    # Try to get more info about the dataset
+                    dataset_info = load_dataset(
+                        self.dataset_path, split=self.split, streaming=True, token=self.token
                     )
-
-
-
-
+                    # Get features info
+                    if hasattr(dataset_info, "features"):
+                        info["features"] = str(dataset_info.features)
+
+                    # Try to get total size (might not work for all datasets)
+                    try:
+                        # This might be expensive for large datasets
+                        total_examples = len(
+                            load_dataset(self.dataset_path, split=self.split, token=self.token)
+                        )
+                        info["total_examples"] = total_examples
+                        self._hf_total_items = total_examples
+                    except:
+                        info["total_examples"] = "unknown"
 
-
-
+                except Exception as e:
+                    logger.error(f"Error getting dataset info: {e}")
 
         return info