caption-flow 0.1.0-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +65 -42
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +13 -3
- caption_flow/orchestrator.py +1049 -264
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +567 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.1.dist-info/METADATA +370 -0
- caption_flow-0.2.1.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt +0 -0
caption_flow/utils/dataset_loader.py

```diff
@@ -9,6 +9,8 @@ import json
 
 import webdataset as wds
 from huggingface_hub import HfFileSystem, get_token, hf_hub_url
+from datasets import load_dataset, Dataset
+from .image_processor import ImageProcessor
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -16,42 +18,143 @@ logger = logging.getLogger(__name__)
 
 class DatasetLoader:
     """Handles loading datasets from various sources."""
 
-    def __init__(
+    def __init__(
+        self,
+        dataset_path: str,
+        dataset_type: str = "huggingface",
+        split: str = "train",
+        image_column: str = "image",
+    ):
         """
         Initialize dataset loader.
 
         Args:
             dataset_path: Path to dataset (HF repo, local dir, etc.)
             dataset_type: Type of dataset ("huggingface", "webdataset", "local")
+            split: Split to use for HuggingFace datasets (default: "train")
+            image_column: Column name containing image data or URLs (default: "image")
         """
         self.dataset_path = dataset_path
         self.dataset_type = dataset_type
+        self.split = split
+        self.image_column = image_column
         self.token = get_token()
+        self.dataset_format = None  # Will be detected: "webdataset" or "huggingface_datasets"
+        self._hf_dataset = None  # Cache for HuggingFace dataset
+        self._hf_total_items = None  # Cache for total items count
 
         if not self.token and dataset_type == "huggingface":
             logger.warning("No HuggingFace token found; run `huggingface-cli login`")
 
+        # Detect the actual format if it's a HuggingFace dataset
+        if dataset_type == "huggingface":
+            self.dataset_format = self._detect_dataset_format()
+            logger.info(f"Detected dataset format: {self.dataset_format}")
+
+    def _detect_dataset_format(self) -> str:
+        """Detect whether it's WebDataset or HuggingFace datasets format."""
+        fs = HfFileSystem(token=self.token)
+
+        # Check for .tar files (WebDataset)
+        tar_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar"))
+        if tar_files:
+            return "webdataset"
+
+        # Check for parquet files (HuggingFace datasets)
+        parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
+        if parquet_files:
+            return "huggingface_datasets"
+
+        # Check for dataset_info.json or dataset_dict.json
+        if fs.exists(f"datasets/{self.dataset_path}/dataset_info.json") or fs.exists(
+            f"datasets/{self.dataset_path}/dataset_dict.json"
+        ):
+            return "huggingface_datasets"
+
+        logger.warning(f"Could not detect dataset format for {self.dataset_path}")
+        return "unknown"
+
     def get_shard_list(self) -> List[str]:
         """Get list of all shards in the dataset."""
         if self.dataset_type == "huggingface":
-
+            if self.dataset_format == "webdataset":
+                return self._get_hf_webdataset_shards()
+            elif self.dataset_format == "huggingface_datasets":
+                return self._get_hf_dataset_shards()
+            else:
+                logger.error(f"Unknown dataset format: {self.dataset_format}")
+                return []
         elif self.dataset_type == "local":
             return self._get_local_shards()
         else:
             raise ValueError(f"Unknown dataset type: {self.dataset_type}")
 
-    def
-        """Get shard URLs from HuggingFace
-        logger.info(f"Getting shard list from HuggingFace: {self.dataset_path}")
+    def _get_hf_webdataset_shards(self) -> List[str]:
+        """Get shard URLs from HuggingFace WebDataset."""
+        logger.info(f"Getting WebDataset shard list from HuggingFace: {self.dataset_path}")
 
-        fs = HfFileSystem()
+        fs = HfFileSystem(token=self.token)
         files = [fs.resolve_path(p) for p in fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar")]
 
         urls = [hf_hub_url(f.repo_id, f.path_in_repo, repo_type="dataset") for f in files]
 
-        logger.info(f"Found {len(urls)} shards")
+        logger.info(f"Found {len(urls)} WebDataset shards")
         return sorted(urls)
 
+    def _get_hf_dataset_shards(self) -> List[str]:
+        """Get virtual 'shards' for HuggingFace datasets format."""
+        logger.info(f"Getting HuggingFace dataset info: {self.dataset_path}")
+
+        # For HuggingFace datasets, we'll create virtual shards based on chunks
+        # Each "shard" will be a range of indices
+        try:
+            # First, try to get available splits
+            try:
+                from datasets import get_dataset_split_names
+
+                available_splits = get_dataset_split_names(self.dataset_path, token=self.token)
+                logger.info(f"Available splits: {available_splits}")
+
+                if self.split not in available_splits:
+                    logger.warning(
+                        f"Requested split '{self.split}' not found. "
+                        f"Available splits: {available_splits}. "
+                        f"Using first available split: '{available_splits[0]}'"
+                    )
+                    self.split = available_splits[0]
+            except Exception as e:
+                logger.warning(f"Could not get split names: {e}")
+
+            # Load dataset info without downloading data
+            dataset_info = load_dataset(
+                self.dataset_path, split=self.split, streaming=True, token=self.token
+            )
+
+            # Try to get the total size
+            # For streaming datasets, we might need to iterate to count
+            # This is expensive, so we'll use a default chunk size instead
+            chunk_size = 10000  # Default chunk size for virtual shards
+
+            # Create virtual shard identifiers
+            # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
+            virtual_shards = []
+
+            # We'll create a reasonable number of virtual shards
+            # Without knowing the total size, we'll create them on-demand
+            # For now, create initial batch of virtual shards
+            for i in range(10):  # Start with 10 virtual shards
+                shard_id = f"hf_dataset:{self.dataset_path}:chunk:{i * chunk_size}"
+                virtual_shards.append(shard_id)
+
+            logger.info(
+                f"Created {len(virtual_shards)} initial virtual shards for HuggingFace dataset"
+            )
+            return virtual_shards
+
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset info: {e}")
+            return []
+
     def _get_local_shards(self) -> List[str]:
         """Get shard files from local directory."""
         path = Path(self.dataset_path)
```
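The detection and virtual-shard scheme above can be exercised on its own. The sketch below is illustrative only (the repo id is a placeholder; the package does this inside `_detect_dataset_format` and `_get_hf_dataset_shards`): a repo whose files match `**/*.tar` is treated as WebDataset, while a parquet-backed repo is split into `hf_dataset:<path>:chunk:<start_idx>` identifiers.

```python
# Illustrative sketch, not part of caption_flow; "user/some-dataset" is a placeholder.
from huggingface_hub import HfFileSystem, get_token

repo = "user/some-dataset"
fs = HfFileSystem(token=get_token())

if fs.glob(f"hf://datasets/{repo}/**/*.tar"):
    # WebDataset layout: real .tar shards, later resolved to hf_hub_url() download URLs
    layout = "webdataset"
elif fs.glob(f"hf://datasets/{repo}/**/*.parquet"):
    # datasets layout: "shards" are virtual index ranges over the streamed split
    layout = "huggingface_datasets"
    chunk_size = 10000
    virtual_shards = [f"hf_dataset:{repo}:chunk:{i * chunk_size}" for i in range(10)]
else:
    layout = "unknown"
```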
```diff
@@ -73,7 +176,14 @@ class DatasetLoader:
         if processed_keys is None:
             processed_keys = set()
 
-        if
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            raise ValueError(
+                "Virtual HuggingFace dataset shards should use iterate_shard() directly, "
+                "not load_shard()"
+            )
+
+        if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
             # Use curl with auth token for HuggingFace
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
             ds = wds.DataPipeline(
```
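The `pipe:curl` URL that `load_shard` keeps building is standard webdataset convention: any `pipe:` URL is executed as a shell command and its stdout is read as a tar stream. A stand-alone sketch, assuming a valid HF token and a placeholder shard URL, and using `wds.WebDataset` for brevity where the package assembles a `wds.DataPipeline`:

```python
# Sketch only; the shard URL is a placeholder and a valid HuggingFace token is assumed.
import shlex

import webdataset as wds
from huggingface_hub import get_token

shard_url = "https://huggingface.co/datasets/user/some-dataset/resolve/main/00000.tar"
token = get_token() or ""
url_cmd = (
    f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(token)}' "
    f"{shlex.quote(shard_url)} || true"
)

# Yields (key, raw image bytes) pairs straight from the authenticated tar stream
for key, image_bytes in wds.WebDataset(url_cmd).to_tuple("__key__", "jpg;png"):
    print(key, len(image_bytes))
    break
```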
```diff
@@ -93,6 +203,19 @@ class DatasetLoader:
 
         return ds
 
+    def _parse_virtual_shard(self, shard_url: str) -> Tuple[str, int, int]:
+        """Parse virtual shard identifier."""
+        # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
+        parts = shard_url.split(":")
+        if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
+            raise ValueError(f"Invalid virtual shard format: {shard_url}")
+
+        dataset_path = parts[1]
+        start_idx = int(parts[3])
+        chunk_size = 10000  # Default chunk size
+
+        return dataset_path, start_idx, chunk_size
+
     def iterate_shard(
         self, shard_url: str, processed_keys: Optional[set] = None
     ) -> Generator[Tuple[str, str, bytes], None, None]:
```
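A quick round-trip of the virtual shard identifier handled by `_parse_virtual_shard` above; the function and repo name below are illustrative mirrors of that logic, not the packaged code.

```python
# Illustrative mirror of _parse_virtual_shard; "user/some-dataset" is a placeholder.
def parse_virtual_shard(shard_url: str):
    parts = shard_url.split(":")
    if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
        raise ValueError(f"Invalid virtual shard format: {shard_url}")
    return parts[1], int(parts[3]), 10000  # dataset_path, start_idx, default chunk_size


assert parse_virtual_shard("hf_dataset:user/some-dataset:chunk:20000") == (
    "user/some-dataset",
    20000,
    10000,
)
```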
```diff
@@ -102,85 +225,456 @@ class DatasetLoader:
         Yields:
             Tuple of (key, url, image_bytes)
         """
-
-
-
-
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            yield from self._iterate_hf_dataset_shard(shard_url, processed_keys)
+        else:
+            # Regular WebDataset shard
+            ds = self.load_shard(shard_url, processed_keys)
+            for key, url, image_data in ds:
+                yield key, url, image_data
 
-    def
-        """
-        count = 0
+    def _create_dataset_at_position(self, dataset_path: str, split: str, start_idx: int):
+        """Create a dataset iterator positioned at start_idx using state_dict if available."""
         try:
-
-
+            # Load dataset in streaming mode
+            dataset = load_dataset(
+                dataset_path,
+                split=split,
+                streaming=True,
+                token=self.token,
+            )
+
+            # Check if the dataset supports state_dict (newer versions of datasets library)
+            if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
+                # Try to use the dataset's native state management
+                try:
+                    # Get current state
+                    state = dataset.state_dict()
+
+                    # Modify the state to skip to start_idx
+                    if "epoch" in state:
+                        state["epoch"] = 0
+                    if "num_examples_since_previous_state" in state:
+                        state["num_examples_since_previous_state"] = start_idx
+
+                    # For newer datasets with examples_iterable state
+                    if "examples_iterable" in state:
+                        if isinstance(state["examples_iterable"], dict):
+                            if "shard_example_idx" in state["examples_iterable"]:
+                                state["examples_iterable"]["shard_example_idx"] = start_idx
+
+                    # Load the modified state
+                    dataset.load_state_dict(state)
+                    logger.info(f"Positioned dataset at index {start_idx} using state_dict")
+                    return dataset
+                except Exception as e:
+                    logger.debug(f"Could not use state_dict approach: {e}")
+
+            # Fall back to skip() for large skips
+            if start_idx > 0:
+                logger.info(f"Using skip() to position dataset at index {start_idx}")
+                dataset = dataset.skip(start_idx)
+
+            return dataset
+
         except Exception as e:
-            logger.
-
+            logger.warning(f"Error creating positioned dataset: {e}")
+            return None
+
+    def _iterate_hf_dataset_shard_with_metadata(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+        """Iterate over a virtual HuggingFace dataset shard with metadata."""
+        if processed_keys is None:
+            processed_keys = set()
 
+        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
 
-
-
+        logger.info(
+            f"Loading HuggingFace dataset with metadata: {dataset_path} (split: {self.split})"
+        )
 
-
-
-
-
+        try:
+            # Try optimized approach for large skips
+            if start_idx > 100:
+                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
+                if dataset:
+                    items_processed = 0
+
+                    for item in dataset:
+                        # Stop after processing chunk_size items
+                        if items_processed >= chunk_size:
+                            break
+
+                        # Generate a unique key for this item
+                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                        if key in processed_keys:
+                            items_processed += 1
+                            continue
+
+                        try:
+                            # Extract image data
+                            if self.image_column in item:
+                                img_data = item[self.image_column]
+
+                                # Process image to bytes
+                                image_bytes = ImageProcessor.process_image_data(img_data)
+
+                                if image_bytes:
+                                    # Extract all metadata (excluding the image column)
+                                    metadata = {
+                                        k: v for k, v in item.items() if k != self.image_column
+                                    }
+
+                                    # URL is virtual for HF datasets
+                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                                    items_processed += 1
+                                    yield key, url, image_bytes, metadata
+                                else:
+                                    logger.warning(
+                                        f"Failed to process image for item at index {start_idx + items_processed}"
+                                    )
+                                    items_processed += 1
+                                    continue
+                            else:
+                                logger.warning(
+                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                                    f"Available columns: {list(item.keys())}"
+                                )
+                                items_processed += 1
+
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing item at index {start_idx + items_processed}: {e}"
+                            )
+                            items_processed += 1
+                            continue
+
+                    return
+
+            # Fall back to regular approach for small skips or if StatefulDataLoader not available
+            dataset = load_dataset(
+                dataset_path,
+                split=self.split,
+                streaming=True,
+                token=self.token,
+            )
 
-
-
-
+            # Skip to start index if needed
+            if start_idx > 0:
+                dataset = dataset.skip(start_idx)
+
+            items_processed = 0
+
+            for item in dataset:
+                # Stop after processing chunk_size items
+                if items_processed >= chunk_size:
+                    break
+
+                # Generate a unique key for this item
+                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                if key in processed_keys:
+                    items_processed += 1
+                    continue
+
+                try:
+                    # Extract image data
+                    if self.image_column in item:
+                        img_data = item[self.image_column]
+
+                        # Process image to bytes
+                        image_bytes = ImageProcessor.process_image_data(img_data)
+
+                        if image_bytes:
+                            # Extract all metadata (excluding the image column)
+                            metadata = {k: v for k, v in item.items() if k != self.image_column}
+
+                            # URL is virtual for HF datasets
+                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                            items_processed += 1
+                            yield key, url, image_bytes, metadata
+                        else:
+                            logger.warning(
+                                f"Failed to process image for item at index {start_idx + items_processed}"
+                            )
+                            items_processed += 1
+                            continue
+                    else:
+                        logger.warning(
+                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                            f"Available columns: {list(item.keys())}"
+                        )
+                        items_processed += 1
+
+                except Exception as e:
+                    logger.error(
+                        f"Error processing item at index {start_idx + items_processed}: {e}"
+                    )
+                    items_processed += 1
+                    continue
 
-
-
-
-        try:
-            data = json.loads(self.checkpoint_path.read_text())
-            self.completed_shards = set(data.get("completed_shards", []))
-            self.partial_shards = data.get("partial_shards", {})
-            logger.info(
-                f"Loaded checkpoint: {len(self.completed_shards)} completed, "
-                f"{len(self.partial_shards)} partial shards"
-            )
-        except Exception as e:
-            logger.error(f"Failed to load checkpoint: {e}")
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset: {e}")
+            return
 
-    def
-
-
-
-
-
+    def _iterate_hf_dataset_shard(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes], None, None]:
+        """Iterate over a virtual HuggingFace dataset shard."""
+        if processed_keys is None:
+            processed_keys = set()
+
+        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
+
+        # IMPORTANT: Check if start_idx is beyond dataset bounds
+        if self._hf_total_items is not None and start_idx >= self._hf_total_items:
+            logger.warning(
+                f"Virtual shard starts at index {start_idx} but dataset only has "
+                f"{self._hf_total_items} items. Skipping this shard."
+            )
+            return
+
+        logger.info(
+            f"Loading HuggingFace dataset in streaming mode: {dataset_path} "
+            f"(split: {self.split}, start: {start_idx}, chunk_size: {chunk_size})"
+        )
+
+        try:
+            # Try optimized approach for large skips
+            if start_idx > 100:
+                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
+                if dataset:
+                    items_processed = 0
+
+                    for item in dataset:
+                        # Stop after processing chunk_size items
+                        if items_processed >= chunk_size:
+                            logger.info(f"Completed chunk: processed {items_processed} items")
+                            break
+
+                        # Also stop if we've reached the dataset end
+                        if (
+                            self._hf_total_items
+                            and (start_idx + items_processed) >= self._hf_total_items
+                        ):
+                            logger.info(
+                                f"Reached dataset end at item {start_idx + items_processed} "
+                                f"(total: {self._hf_total_items})"
+                            )
+                            break
+
+                        # Generate a unique key for this item
+                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                        if key in processed_keys:
+                            items_processed += 1
+                            continue
+
+                        try:
+                            # Extract image data
+                            if self.image_column in item:
+                                img_data = item[self.image_column]
+
+                                # Delegate image processing to ImageProcessor
+                                image_bytes = ImageProcessor.process_image_data(img_data)
+
+                                if image_bytes:
+                                    # URL is virtual for HF datasets
+                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                                    items_processed += 1
+                                    yield key, url, image_bytes
+                                else:
+                                    logger.warning(
+                                        f"Failed to process image for item at index {start_idx + items_processed}"
+                                    )
+                                    items_processed += 1
+                                    continue
+                            else:
+                                logger.warning(
+                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                                    f"Available columns: {list(item.keys())}"
+                                )
+                                items_processed += 1
+
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing item at index {start_idx + items_processed}: {e}"
+                            )
+                            items_processed += 1
+                            continue
+
+                    logger.info(
+                        f"Virtual shard complete: processed {items_processed} items "
+                        f"(start_idx: {start_idx})"
+                    )
+                    return
+
+            # Fall back to regular approach for small skips or if StatefulDataLoader not available
+            dataset = load_dataset(
+                dataset_path,
+                split=self.split,
+                streaming=True,
+                token=self.token,
+            )
 
-
-
-
+            # Use dataset.skip() for efficient skipping
+            if start_idx > 0:
+                dataset = dataset.skip(start_idx)
+                logger.info(f"Skipped to index {start_idx}")
+
+            items_processed = 0
+
+            # Now enumerate starts from 0 after skip
+            for item in dataset:
+                # Stop after processing chunk_size items
+                if items_processed >= chunk_size:
+                    logger.info(f"Completed chunk: processed {items_processed} items")
+                    break
+
+                # Also stop if we've reached the dataset end
+                if self._hf_total_items and (start_idx + items_processed) >= self._hf_total_items:
+                    logger.info(
+                        f"Reached dataset end at item {start_idx + items_processed} "
+                        f"(total: {self._hf_total_items})"
+                    )
+                    break
+
+                # Generate a unique key for this item - ensure proper formatting
+                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                if key in processed_keys:
+                    items_processed += 1
+                    continue
+
+                try:
+                    # Extract image data - check configured column name
+                    if self.image_column in item:
+                        img_data = item[self.image_column]
+
+                        # Delegate image processing to ImageProcessor
+                        image_bytes = ImageProcessor.process_image_data(img_data)
+
+                        if image_bytes:
+                            # URL is virtual for HF datasets
+                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                            items_processed += 1
+                            yield key, url, image_bytes
+                        else:
+                            logger.warning(
+                                f"Failed to process image for item at index {start_idx + items_processed}"
+                            )
+                            items_processed += 1
+                            continue
+                    else:
+                        logger.warning(
+                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                            f"Available columns: {list(item.keys())}"
+                        )
+                        items_processed += 1
+
+                except Exception as e:
+                    logger.error(
+                        f"Error processing item at index {start_idx + items_processed}: {e}"
+                    )
+                    items_processed += 1
+                    continue
+
+            logger.info(
+                f"Virtual shard complete: processed {items_processed} items "
+                f"(start_idx: {start_idx})"
+            )
 
-
-
-
-        if shard_name in self.partial_shards:
-            del self.partial_shards[shard_name]
-        self.save()
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset: {e}")
+            return
 
-    def
-
-
-
+    def iterate_shard_with_metadata(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+        """
+        Iterate over items in a shard, including metadata.
 
-
-
-
-
+        Yields:
+            Tuple of (key, url, image_bytes, metadata_dict)
+        """
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            yield from self._iterate_hf_dataset_shard_with_metadata(shard_url, processed_keys)
+        else:
+            # Regular WebDataset shard - no metadata by default
+            for key, url, image_data in self.iterate_shard(shard_url, processed_keys):
+                yield key, url, image_data, {}
 
-
-
+    def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
+        """Count items in a shard (can be slow for large shards)."""
+        if shard_url.startswith("hf_dataset:"):
+            # For virtual shards, return the chunk size
+            _, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
+
+            # CRITICAL: Cap chunk size by dataset bounds
+            if self._hf_total_items is not None:
+                # If start index is beyond dataset, return 0
+                if start_idx >= self._hf_total_items:
+                    logger.warning(
+                        f"Virtual shard starts at {start_idx} but dataset has "
+                        f"only {self._hf_total_items} items"
+                    )
+                    return 0
+
+                # Otherwise, return the minimum of chunk_size and remaining items
+                remaining_items = self._hf_total_items - start_idx
+                actual_size = min(chunk_size, remaining_items)
+                logger.debug(
+                    f"Virtual shard at {start_idx}: chunk_size={chunk_size}, "
+                    f"remaining={remaining_items}, actual={actual_size}"
+                )
+                return actual_size
+            else:
+                # If we don't know total size, return chunk_size
+                return chunk_size
+        else:
+            # Regular WebDataset counting
+            count = 0
+            try:
+                for _ in self.iterate_shard(shard_url, processed_keys):
+                    count += 1
+            except Exception as e:
+                logger.error(f"Error counting shard {shard_url}: {e}")
+            return count
+
+    def get_dataset_info(self) -> Dict[str, Any]:
+        """Get information about the dataset."""
+        info = {
+            "dataset_path": self.dataset_path,
+            "dataset_type": self.dataset_type,
+            "dataset_format": self.dataset_format,
+        }
 
-
+        if self.dataset_format == "huggingface_datasets":
+            try:
+                # Try to get more info about the dataset
+                dataset_info = load_dataset(
+                    self.dataset_path, split=self.split, streaming=True, token=self.token
+                )
+                # Get features info
+                if hasattr(dataset_info, "features"):
+                    info["features"] = str(dataset_info.features)
+
+                # Try to get total size (might not work for all datasets)
+                try:
+                    # This might be expensive for large datasets
+                    total_examples = len(
+                        load_dataset(self.dataset_path, split=self.split, token=self.token)
+                    )
+                    info["total_examples"] = total_examples
+                    self._hf_total_items = total_examples
+                except:
+                    info["total_examples"] = "unknown"
 
-
-
-        return shard_name in self.completed_shards
+            except Exception as e:
+                logger.error(f"Error getting dataset info: {e}")
 
-
-        """Get list of shards that still need processing."""
-        return [s for s in all_shards if Path(s).stem not in self.completed_shards]
+        return info
```