caption-flow 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +56 -39
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +12 -2
- caption_flow/orchestrator.py +729 -217
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +392 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.0.dist-info/METADATA +369 -0
- caption_flow-0.2.0.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,8 @@ import json
|
|
9
9
|
|
10
10
|
import webdataset as wds
|
11
11
|
from huggingface_hub import HfFileSystem, get_token, hf_hub_url
|
12
|
+
from datasets import load_dataset, Dataset
|
13
|
+
from .image_processor import ImageProcessor
|
12
14
|
|
13
15
|
logger = logging.getLogger(__name__)
|
14
16
|
|
@@ -16,42 +18,143 @@ logger = logging.getLogger(__name__)
|
|
16
18
|
class DatasetLoader:
|
17
19
|
"""Handles loading datasets from various sources."""
|
18
20
|
|
19
|
-
def __init__(
    self,
    dataset_path: str,
    dataset_type: str = "huggingface",
    split: str = "train",
    image_column: str = "image",
):
    """
    Set up a loader for a single dataset source.

    Args:
        dataset_path: Location of the dataset (HF repo id, local directory, ...).
        dataset_type: One of "huggingface", "webdataset" or "local".
        split: HuggingFace split to read from (defaults to "train").
        image_column: Column holding the image payload or URL (defaults to "image").
    """
    self.dataset_path = dataset_path
    self.dataset_type = dataset_type
    self.split = split
    self.image_column = image_column
    self.token = get_token()

    # Set below for HuggingFace repos: "webdataset", "huggingface_datasets" or "unknown".
    self.dataset_format = None
    # Caches used by the HuggingFace `datasets` code path.
    self._hf_dataset = None
    self._hf_total_items = None

    is_hf_source = dataset_type == "huggingface"
    if is_hf_source and not self.token:
        logger.warning("No HuggingFace token found; run `huggingface-cli login`")

    if is_hf_source:
        # Probes the hub repo right away, so construction talks to the network.
        self.dataset_format = self._detect_dataset_format()
        logger.info(f"Detected dataset format: {self.dataset_format}")
|
53
|
+
|
54
|
+
def _detect_dataset_format(self) -> str:
    """
    Work out how a HuggingFace dataset repo stores its data.

    The repo is probed in this order: any ``.tar`` file means WebDataset;
    any ``.parquet`` file, ``dataset_info.json`` or ``dataset_dict.json``
    means the ``datasets`` library format.

    Returns:
        "webdataset", "huggingface_datasets", or "unknown" if nothing matched.
    """
    hf_fs = HfFileSystem(token=self.token)
    recursive_root = f"hf://datasets/{self.dataset_path}/**"

    # WebDataset shards win over every other signal.
    if list(hf_fs.glob(f"{recursive_root}/*.tar")):
        return "webdataset"

    if list(hf_fs.glob(f"{recursive_root}/*.parquet")):
        return "huggingface_datasets"

    marker_names = ("dataset_info.json", "dataset_dict.json")
    if any(hf_fs.exists(f"datasets/{self.dataset_path}/{name}") for name in marker_names):
        return "huggingface_datasets"

    logger.warning(f"Could not detect dataset format for {self.dataset_path}")
    return "unknown"
|
76
|
+
|
34
77
|
def get_shard_list(self) -> List[str]:
    """
    Enumerate every shard that makes up the dataset.

    HuggingFace repos are dispatched on the detected format: WebDataset repos
    yield ``.tar`` URLs, and ``datasets``-format repos yield virtual
    ``hf_dataset:`` ids. An unknown format logs an error and yields an empty
    list. Local datasets delegate to ``_get_local_shards``.

    Raises:
        ValueError: If ``dataset_type`` is neither "huggingface" nor "local".
    """
    if self.dataset_type == "local":
        return self._get_local_shards()

    if self.dataset_type != "huggingface":
        raise ValueError(f"Unknown dataset type: {self.dataset_type}")

    listers_by_format = {
        "webdataset": self._get_hf_webdataset_shards,
        "huggingface_datasets": self._get_hf_dataset_shards,
    }
    lister = listers_by_format.get(self.dataset_format)
    if lister is None:
        logger.error(f"Unknown dataset format: {self.dataset_format}")
        return []
    return lister()
|
42
91
|
|
43
|
-
def _get_hf_webdataset_shards(self) -> List[str]:
    """
    List download URLs for every ``.tar`` shard in a HuggingFace WebDataset repo.

    Returns:
        The ``hf_hub_url`` of each shard, sorted lexicographically.
    """
    logger.info(f"Getting WebDataset shard list from HuggingFace: {self.dataset_path}")

    hf_fs = HfFileSystem(token=self.token)
    shard_urls = []
    for tar_path in hf_fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar"):
        resolved = hf_fs.resolve_path(tar_path)
        shard_urls.append(
            hf_hub_url(resolved.repo_id, resolved.path_in_repo, repo_type="dataset")
        )

    logger.info(f"Found {len(shard_urls)} WebDataset shards")
    shard_urls.sort()
    return shard_urls
|
54
103
|
|
104
|
+
def _get_hf_dataset_shards(self) -> List[str]:
    """
    Build virtual shard ids for a HuggingFace repo in ``datasets`` format.

    Such repos have no physical shards, so rows are grouped into fixed
    ranges of 10,000 (the size ``_parse_virtual_shard`` reports back). Each
    range is identified as ``"hf_dataset:<dataset_path>:chunk:<start_idx>"``.

    The row count may already be cached in ``self._hf_total_items`` or be
    readable from the repo's split metadata. In that case exactly enough
    shards are created to cover every row, and the count is cached.
    Otherwise the fixed batch of 10 shards (100,000 rows) is returned and a
    warning is logged, because rows beyond that would never be scheduled.

    If ``self.split`` is not one of the repo's splits, it is replaced with
    the first available split.

    Returns:
        The virtual shard ids, or an empty list if the dataset could not be
        opened.
    """
    chunk_size = 10000  # Must stay in sync with _parse_virtual_shard
    fallback_shard_count = 10

    logger.info(f"Getting HuggingFace dataset info: {self.dataset_path}")

    try:
        # Fall back to the first split if the requested one does not exist.
        try:
            from datasets import get_dataset_split_names

            available_splits = get_dataset_split_names(self.dataset_path, token=self.token)
            logger.info(f"Available splits: {available_splits}")

            # Guard against an empty split list (used to raise IndexError).
            if available_splits and self.split not in available_splits:
                logger.warning(
                    f"Requested split '{self.split}' not found. "
                    f"Available splits: {available_splits}. "
                    f"Using first available split: '{available_splits[0]}'"
                )
                self.split = available_splits[0]
        except Exception as e:
            logger.warning(f"Could not get split names: {e}")

        # Opening the stream checks that the dataset/split is reachable
        # without downloading data; a failure is handled by the outer except.
        load_dataset(self.dataset_path, split=self.split, streaming=True, token=self.token)

        total_items = self._hf_total_items
        if total_items is None:
            # Split sizes published in the repo metadata avoid a full scan.
            try:
                from datasets import load_dataset_builder

                builder = load_dataset_builder(self.dataset_path, token=self.token)
                split_info = (builder.info.splits or {}).get(self.split)
                if split_info is not None and split_info.num_examples:
                    total_items = split_info.num_examples
                    self._hf_total_items = total_items
            except Exception as e:
                logger.warning(f"Could not read dataset size from metadata: {e}")

        if total_items is not None:
            # Ceiling division: one shard per started chunk.
            shard_count = (total_items + chunk_size - 1) // chunk_size
        else:
            shard_count = fallback_shard_count
            logger.warning(
                f"Dataset size unknown; only the first {shard_count * chunk_size} "
                f"items will be covered by virtual shards"
            )

        virtual_shards = [
            f"hf_dataset:{self.dataset_path}:chunk:{i * chunk_size}"
            for i in range(shard_count)
        ]

        logger.info(
            f"Created {len(virtual_shards)} initial virtual shards for HuggingFace dataset"
        )
        return virtual_shards

    except Exception as e:
        logger.error(f"Error loading HuggingFace dataset info: {e}")
        return []
|
157
|
+
|
55
158
|
def _get_local_shards(self) -> List[str]:
|
56
159
|
"""Get shard files from local directory."""
|
57
160
|
path = Path(self.dataset_path)
|
@@ -73,7 +176,14 @@ class DatasetLoader:
|
|
73
176
|
if processed_keys is None:
|
74
177
|
processed_keys = set()
|
75
178
|
|
76
|
-
if
|
179
|
+
# Check if this is a virtual HuggingFace dataset shard
|
180
|
+
if shard_url.startswith("hf_dataset:"):
|
181
|
+
raise ValueError(
|
182
|
+
"Virtual HuggingFace dataset shards should use iterate_shard() directly, "
|
183
|
+
"not load_shard()"
|
184
|
+
)
|
185
|
+
|
186
|
+
if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
|
77
187
|
# Use curl with auth token for HuggingFace
|
78
188
|
url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
|
79
189
|
ds = wds.DataPipeline(
|
@@ -93,6 +203,19 @@ class DatasetLoader:
|
|
93
203
|
|
94
204
|
return ds
|
95
205
|
|
206
|
+
def _parse_virtual_shard(self, shard_url: str) -> Tuple[str, int, int]:
    """
    Split a virtual shard id into its components.

    Virtual shard ids look like ``"hf_dataset:<dataset_path>:chunk:<start_idx>"``.
    The id is parsed from the right, so a dataset path that itself contains
    ``:`` is still recovered intact.

    Args:
        shard_url: The virtual shard id.

    Returns:
        ``(dataset_path, start_idx, chunk_size)``. ``chunk_size`` is always
        10,000, the size used when the ids are generated.

    Raises:
        ValueError: If the id does not have the expected shape, the dataset
            path is empty, or the start index is not a non-negative integer.
    """
    prefix = "hf_dataset:"
    head, sep, start_text = shard_url.rpartition(":chunk:")
    if not sep or not head.startswith(prefix):
        raise ValueError(f"Invalid virtual shard format: {shard_url}")

    dataset_path = head[len(prefix):]
    # ASCII digits only: rejects signs, whitespace and negative offsets.
    if not dataset_path or not (start_text.isascii() and start_text.isdigit()):
        raise ValueError(f"Invalid virtual shard format: {shard_url}")

    start_idx = int(start_text)
    chunk_size = 10000  # Default chunk size
    return dataset_path, start_idx, chunk_size
|
218
|
+
|
96
219
|
def iterate_shard(
    self, shard_url: str, processed_keys: Optional[set] = None
) -> Generator[Tuple[str, str, bytes], None, None]:
    """
    Stream the items of one shard.

    Virtual ``hf_dataset:`` shard ids are read through the HuggingFace
    ``datasets`` path. Anything else is treated as a WebDataset shard and
    read via ``load_shard``.

    Args:
        shard_url: Shard URL/path, or a virtual ``hf_dataset:`` shard id.
        processed_keys: Keys to skip; forwarded unchanged to the reader.

    Yields:
        Tuple of (key, url, image_bytes)
    """
    if not shard_url.startswith("hf_dataset:"):
        # Regular WebDataset shard: re-yield each record as a 3-tuple.
        for key, url, image_data in self.load_shard(shard_url, processed_keys):
            yield key, url, image_data
        return

    yield from self._iterate_hf_dataset_shard(shard_url, processed_keys)
|
236
|
+
|
237
|
+
def _iterate_hf_dataset_shard_with_metadata(
    self, shard_url: str, processed_keys: Optional[set] = None
) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
    """
    Stream one virtual HuggingFace shard, yielding each row's metadata too.

    Reads rows ``[start_idx, start_idx + chunk_size)`` of ``self.split`` in
    streaming mode. The range is capped at ``self._hf_total_items`` when that
    is known, and the shard is skipped entirely if it starts past the end.

    Skipped rows:
    - rows whose key is in ``processed_keys``;
    - rows without ``self.image_column`` (logged);
    - rows whose image ``ImageProcessor.process_image_data`` returns nothing
      for (logged);
    - rows that raise while being processed (logged).

    Args:
        shard_url: Virtual shard id, ``"hf_dataset:<path>:chunk:<start>"``.
        processed_keys: Keys of rows to skip.

    Yields:
        ``(key, url, image_bytes, metadata)``:
        - ``key`` is ``"<path with / replaced by _>_<row index, 8-digit zero-padded>"``;
        - ``url`` is ``"hf://<path>#<row index>"``;
        - ``metadata`` holds every column except the image column.

    If loading or streaming the dataset fails, the error is logged and
    iteration simply stops.
    """
    import itertools

    if processed_keys is None:
        processed_keys = set()

    dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)

    # Same early exit as the non-metadata iterator: nothing to read past the end.
    if self._hf_total_items is not None and start_idx >= self._hf_total_items:
        logger.warning(
            f"Virtual shard starts at index {start_idx} but dataset only has "
            f"{self._hf_total_items} items. Skipping this shard."
        )
        return

    # Never read past the known end of the dataset.
    row_limit = chunk_size
    if self._hf_total_items is not None:
        row_limit = min(chunk_size, self._hf_total_items - start_idx)

    logger.info(
        f"Loading HuggingFace dataset with metadata: {dataset_path} (split: {self.split})"
    )

    key_prefix = dataset_path.replace("/", "_")

    try:
        dataset = load_dataset(
            dataset_path,
            split=self.split,
            streaming=True,
            token=self.token,
        )
        if start_idx > 0:
            dataset = dataset.skip(start_idx)

        # islice stops after row_limit rows without pulling an extra row
        # from the stream.
        for offset, item in enumerate(itertools.islice(dataset, row_limit)):
            row_idx = start_idx + offset
            key = f"{key_prefix}_{row_idx:08d}"
            if key in processed_keys:
                continue

            # Keep the per-row try narrow so it never wraps the yield.
            try:
                if self.image_column not in item:
                    logger.warning(
                        f"No image column '{self.image_column}' found in item at index {row_idx}. "
                        f"Available columns: {list(item.keys())}"
                    )
                    continue

                image_bytes = ImageProcessor.process_image_data(item[self.image_column])
                if not image_bytes:
                    logger.warning(f"Failed to process image for item at index {row_idx}")
                    continue

                metadata = {k: v for k, v in item.items() if k != self.image_column}
                url = f"hf://{dataset_path}#{row_idx}"  # virtual URL for HF datasets
            except Exception as e:
                logger.error(f"Error processing item at index {row_idx}: {e}")
                continue

            yield key, url, image_bytes, metadata

    except Exception as e:
        logger.error(f"Error loading HuggingFace dataset: {e}")
        return
|
123
316
|
|
124
|
-
def _iterate_hf_dataset_shard(
    self, shard_url: str, processed_keys: Optional[set] = None
) -> Generator[Tuple[str, str, bytes], None, None]:
    """
    Stream one virtual HuggingFace shard.

    Reads rows ``[start_idx, start_idx + chunk_size)`` of ``self.split`` in
    streaming mode. The range is capped at ``self._hf_total_items`` when that
    is known, and the shard is skipped entirely if it starts past the end.

    Skipped rows:
    - rows whose key is in ``processed_keys``;
    - rows without ``self.image_column`` (logged);
    - rows whose image ``ImageProcessor.process_image_data`` returns nothing
      for (logged);
    - rows that raise while being processed (logged).

    Args:
        shard_url: Virtual shard id, ``"hf_dataset:<path>:chunk:<start>"``.
        processed_keys: Keys of rows to skip.

    Yields:
        ``(key, url, image_bytes)``:
        - ``key`` is ``"<path with / replaced by _>_<row index, 8-digit zero-padded>"``;
        - ``url`` is ``"hf://<path>#<row index>"``.

    If loading or streaming the dataset fails, the error is logged and
    iteration simply stops.
    """
    import itertools

    if processed_keys is None:
        processed_keys = set()

    dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)

    # IMPORTANT: Check if start_idx is beyond dataset bounds
    if self._hf_total_items is not None and start_idx >= self._hf_total_items:
        logger.warning(
            f"Virtual shard starts at index {start_idx} but dataset only has "
            f"{self._hf_total_items} items. Skipping this shard."
        )
        return

    # Never read past the known end of the dataset.
    row_limit = chunk_size
    if self._hf_total_items is not None:
        row_limit = min(chunk_size, self._hf_total_items - start_idx)

    logger.info(
        f"Loading HuggingFace dataset in streaming mode: {dataset_path} "
        f"(split: {self.split}, start: {start_idx}, chunk_size: {chunk_size})"
    )

    key_prefix = dataset_path.replace("/", "_")
    rows_read = 0

    try:
        dataset = load_dataset(
            dataset_path,
            split=self.split,
            streaming=True,
            token=self.token,
        )

        # Use dataset.skip() for efficient skipping
        if start_idx > 0:
            dataset = dataset.skip(start_idx)
            logger.info(f"Skipped to index {start_idx}")

        # islice stops after row_limit rows without pulling an extra row
        # from the stream.
        for offset, item in enumerate(itertools.islice(dataset, row_limit)):
            rows_read = offset + 1
            row_idx = start_idx + offset
            key = f"{key_prefix}_{row_idx:08d}"
            if key in processed_keys:
                continue

            # Keep the per-row try narrow so it never wraps the yield.
            try:
                if self.image_column not in item:
                    logger.warning(
                        f"No image column '{self.image_column}' found in item at index {row_idx}. "
                        f"Available columns: {list(item.keys())}"
                    )
                    continue

                image_bytes = ImageProcessor.process_image_data(item[self.image_column])
                if not image_bytes:
                    logger.warning(f"Failed to process image for item at index {row_idx}")
                    continue

                url = f"hf://{dataset_path}#{row_idx}"  # virtual URL for HF datasets
            except Exception as e:
                logger.error(f"Error processing item at index {row_idx}: {e}")
                continue

            yield key, url, image_bytes

        logger.info(
            f"Virtual shard complete: processed {rows_read} items "
            f"(start_idx: {start_idx})"
        )

    except Exception as e:
        logger.error(f"Error loading HuggingFace dataset: {e}")
        return
|
169
418
|
|
170
|
-
def iterate_shard_with_metadata(
    self, shard_url: str, processed_keys: Optional[set] = None
) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
    """
    Like ``iterate_shard``, but each item also carries a metadata dict.

    Virtual ``hf_dataset:`` shards supply their own metadata. WebDataset
    shards get a fresh empty dict for every item.

    Yields:
        Tuple of (key, url, image_bytes, metadata_dict)
    """
    if not shard_url.startswith("hf_dataset:"):
        # WebDataset shards expose no metadata, so pad each record with {}.
        for key, url, image_data in self.iterate_shard(shard_url, processed_keys):
            yield key, url, image_data, {}
        return

    yield from self._iterate_hf_dataset_shard_with_metadata(shard_url, processed_keys)
|
174
435
|
|
175
|
-
|
176
|
-
|
436
|
+
def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
    """
    Report how many items a shard holds.

    For virtual ``hf_dataset:`` shards no data is read. The answer is the
    chunk size, trimmed to the rows left before the end of the dataset when
    ``self._hf_total_items`` is known, or 0 if the shard starts past the end.
    ``processed_keys`` is ignored for these shards.

    For any other shard, every item is streamed through ``iterate_shard``
    (with ``processed_keys`` forwarded), which can be slow. If iteration
    fails, the error is logged and the count reached so far is returned.
    """
    if not shard_url.startswith("hf_dataset:"):
        # Regular WebDataset counting
        seen = 0
        try:
            for _record in self.iterate_shard(shard_url, processed_keys):
                seen += 1
        except Exception as e:
            logger.error(f"Error counting shard {shard_url}: {e}")
        return seen

    _, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
    known_total = self._hf_total_items
    if known_total is None:
        # Without a known dataset size the full chunk is assumed.
        return chunk_size

    if start_idx >= known_total:
        logger.warning(
            f"Virtual shard starts at {start_idx} but dataset has "
            f"only {known_total} items"
        )
        return 0

    remaining_items = known_total - start_idx
    actual_size = min(chunk_size, remaining_items)
    logger.debug(
        f"Virtual shard at {start_idx}: chunk_size={chunk_size}, "
        f"remaining={remaining_items}, actual={actual_size}"
    )
    return actual_size
|
472
|
+
|
473
|
+
def get_dataset_info(self) -> Dict[str, Any]:
    """
    Summarise the dataset this loader points at.

    Always includes ``dataset_path``, ``dataset_type`` and ``dataset_format``.
    For ``datasets``-format repos it also tries to add two keys:
    - ``features``: the string form of the streamed split's features;
    - ``total_examples``: the row count.

    The row count is read from the repo's split metadata when available.
    Only if that is missing is the whole split loaded (non-streaming) and
    measured, which can be very expensive. A known count is cached in
    ``self._hf_total_items`` so later shard bounds checks can use it.
    ``total_examples`` is ``"unknown"`` if neither method works.

    Errors are logged, never raised (apart from BaseException such as
    KeyboardInterrupt, which now propagates).
    """
    info: Dict[str, Any] = {
        "dataset_path": self.dataset_path,
        "dataset_type": self.dataset_type,
        "dataset_format": self.dataset_format,
    }

    if self.dataset_format != "huggingface_datasets":
        return info

    try:
        streamed = load_dataset(
            self.dataset_path, split=self.split, streaming=True, token=self.token
        )
        if hasattr(streamed, "features"):
            info["features"] = str(streamed.features)

        try:
            total_examples = None

            # Cheap path: split sizes published in the repo metadata.
            try:
                from datasets import load_dataset_builder

                builder = load_dataset_builder(self.dataset_path, token=self.token)
                split_info = (builder.info.splits or {}).get(self.split)
                if split_info is not None and split_info.num_examples:
                    total_examples = split_info.num_examples
            except Exception as e:
                logger.debug(f"Split metadata unavailable, falling back to a full load: {e}")

            if total_examples is None:
                # Expensive fallback: downloads and materialises the split.
                total_examples = len(
                    load_dataset(self.dataset_path, split=self.split, token=self.token)
                )

            info["total_examples"] = total_examples
            self._hf_total_items = total_examples
        except Exception:
            info["total_examples"] = "unknown"

    except Exception as e:
        logger.error(f"Error getting dataset info: {e}")

    return info
|