caption-flow 0.2.2-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +1 -0
- caption_flow/models.py +108 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1715
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage.py +411 -407
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +73 -32
- caption_flow/utils/dataset_loader.py +58 -298
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +5 -265
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.3.dist-info}/METADATA +29 -27
- caption_flow-0.2.3.dist-info/RECORD +35 -0
- caption_flow-0.2.2.dist-info/RECORD +0 -29
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.3.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.3.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.3.dist-info}/top_level.txt +0 -0
caption_flow/utils/checkpoint_tracker.py
@@ -73,12 +73,12 @@ class CheckpointTracker(ABC):
             logger.debug(f"Saved checkpoint to {self.checkpoint_path}")
 
         except Exception as e:
-            logger.error(f"Error saving checkpoint: {e}", exc_info=True)
+            # logger.error(f"Error saving checkpoint: {e}", exc_info=True)
             # Try direct write as fallback
             try:
                 with open(self.checkpoint_path, "w") as f:
                     json.dump(data, f, indent=2)
-                logger.info("Saved checkpoint using fallback direct write")
+                # logger.info("Saved checkpoint using fallback direct write")
             except Exception as fallback_error:
                 logger.error(f"Fallback save also failed: {fallback_error}")
 
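The hunk above only comments out two log calls; the save strategy itself (a primary save attempt followed by a plain json.dump as a last resort) is unchanged. A minimal standalone sketch of that pattern, assuming the primary path is an atomic write-to-temp-then-rename (the diff does not show how the primary save is actually implemented, only the fallback):

import json
import os
import tempfile
from pathlib import Path


def save_checkpoint(checkpoint_path: Path, data: dict) -> None:
    """Sketch: assumed atomic write, with the direct-write fallback taken from the hunk above."""
    try:
        # Assumed primary path: write to a temp file next to the target, then replace atomically.
        fd, tmp_path = tempfile.mkstemp(dir=checkpoint_path.parent, suffix=".tmp")
        with os.fdopen(fd, "w") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, checkpoint_path)
    except Exception:
        # Fallback shown in the diff: plain direct write.
        with open(checkpoint_path, "w") as f:
            json.dump(data, f, indent=2)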
caption_flow/utils/chunk_tracker.py
@@ -10,6 +10,7 @@ from dataclasses import dataclass, asdict, field
 from .checkpoint_tracker import CheckpointTracker
 
 logger = logging.getLogger(__name__)
+# logger.setLevel(logging.DEBUG)
 
 
 @dataclass
@@ -58,11 +59,15 @@ class ChunkState:
     def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
         """Get ranges that haven't been processed yet."""
         if not self.processed_ranges:
+            logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
            return [(0, self.chunk_size - 1)]
 
         unprocessed = []
         current = 0
 
+        logger.info(
+            f"Processing {len(self.processed_ranges)} processed ranges for chunk {self.chunk_id}"
+        )
         for start, end in self.processed_ranges:
             if current < start:
                 unprocessed.append((current, start - 1))
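For reference, the gap computation that get_unprocessed_ranges performs can be illustrated in isolation. A minimal sketch assuming inclusive (start, end) ranges over a chunk of chunk_size items; the helper name is hypothetical, not part of the package:

from typing import List, Tuple


def unprocessed_gaps(processed: List[Tuple[int, int]], chunk_size: int) -> List[Tuple[int, int]]:
    """Return the inclusive index ranges of a chunk not covered by `processed`."""
    if not processed:
        return [(0, chunk_size - 1)]
    gaps = []
    current = 0
    for start, end in sorted(processed):
        if current < start:
            gaps.append((current, start - 1))  # everything before this processed run is a gap
        current = max(current, end + 1)
    if current <= chunk_size - 1:
        gaps.append((current, chunk_size - 1))  # tail of the chunk after the last run
    return gaps


# unprocessed_gaps([(0, 4), (10, 14)], 20) -> [(5, 9), (15, 19)]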
@@ -132,6 +137,11 @@ class ChunkTracker(CheckpointTracker):
         self, chunk_id: str, shard_name: str, shard_url: str, start_index: int, chunk_size: int
     ) -> bool:
         """Add a new chunk. Returns False if chunk already exists and is completed."""
+        if chunk_id in self.chunks:
+            logger.debug(
+                f"Chunk {chunk_id} already exists with status: {self.chunks[chunk_id].status}, not creating"
+            )
+            return False
         if chunk_id in self.completed_chunks:
             logger.debug(f"Chunk {chunk_id} already completed, skipping")
             return False
@@ -166,7 +176,7 @@ class ChunkTracker(CheckpointTracker):
             chunk.completed_at = datetime.utcnow()
             self.completed_chunks.add(chunk_id)
             self.save()
-            logger.
+            logger.debug(f"Chunk {chunk_id} marked as completed")
 
     def mark_failed(self, chunk_id: str):
         """Mark chunk as failed."""
@@ -207,6 +217,49 @@ class ChunkTracker(CheckpointTracker):
                 pending.append(chunk_id)
         return pending
 
+    def get_processed_indices_for_chunk(
+        self, chunk_id: str, processed_job_ids: Set[str]
+    ) -> List[Tuple[int, int]]:
+        """Convert processed job_ids back to ranges for a chunk."""
+        # Extract indices from job_ids like "data-0000:chunk:0:idx:42"
+        processed_indices = []
+        # this will be slow as shit, but it's simple for now, Proof of Concept.
+        for job_id in processed_job_ids:
+            test_chunk_id = chunk_id.replace("_", ":")
+            if test_chunk_id in job_id:
+                parts = job_id.split(":")
+                logger.debug(
+                    f"Found matching job_id {job_id} for chunk {chunk_id} with {len(parts)=} and {parts[3]=}"
+                )
+                if len(parts) >= 5 and parts[3] == "idx":
+                    idx = int(parts[4])
+                    processed_indices.append(idx)
+
+        # Convert to ranges
+        if not processed_indices:
+            # logger.warning(
+            #     f"Chunk {chunk_id} had no pre-processed ranges discovered, will process all elements"
+            # )
+            return []
+        else:
+            logger.debug(f"Chunk {chunk_id} has {len(processed_indices)} pre-processed indices")
+
+        processed_indices.sort()
+        ranges = []
+        start = processed_indices[0]
+        end = processed_indices[0]
+
+        for idx in processed_indices[1:]:
+            if idx == end + 1:
+                end = idx
+            else:
+                ranges.append((start, end))
+                start = idx
+                end = idx
+
+        ranges.append((start, end))
+        return ranges
+
     def is_shard_complete(self, shard_name: str) -> bool:
         """Check if all chunks for a shard are complete."""
         shard_chunks = [chunk for chunk in self.chunks.values() if chunk.shard_name == shard_name]
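The core of the new get_processed_indices_for_chunk helper is the collapse of individual item indices into contiguous ranges. A minimal sketch of just that step, with the job-ID parsing omitted and a hypothetical helper name:

from typing import Iterable, List, Tuple


def indices_to_ranges(indices: Iterable[int]) -> List[Tuple[int, int]]:
    """Collapse item indices into sorted, inclusive, contiguous (start, end) ranges."""
    items = sorted(set(indices))
    if not items:
        return []
    ranges = []
    start = end = items[0]
    for idx in items[1:]:
        if idx == end + 1:
            end = idx  # extend the current run
        else:
            ranges.append((start, end))  # close the run and start a new one
            start = end = idx
    ranges.append((start, end))
    return ranges


# indices_to_ranges({42, 43, 44, 50}) -> [(42, 44), (50, 50)]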
@@ -236,20 +289,8 @@ class ChunkTracker(CheckpointTracker):
 
         for chunk_id, chunk_state in self.chunks.items():
             shard_name = chunk_state.shard_name
-
-
-            if shard_name.startswith("hf_dataset:"):
-                parts = shard_name.split(":")
-                if len(parts) >= 4 and parts[2] == "chunk":
-                    # Use just the dataset identifier as the shard name
-                    normalized_shard_name = ":".join(parts[:2])
-                else:
-                    normalized_shard_name = shard_name
-            else:
-                normalized_shard_name = shard_name
-
-            if normalized_shard_name not in shards:
-                shards[normalized_shard_name] = {
+            if shard_name not in shards:
+                shards[shard_name] = {
                     "total_chunks": 0,
                     "completed_chunks": 0,
                     "pending_chunks": 0,
@@ -259,20 +300,20 @@ class ChunkTracker(CheckpointTracker):
                     "chunks": [],
                 }
 
-            shards[
-            shards[
+            shards[shard_name]["chunks"].append(chunk_state)
+            shards[shard_name]["total_chunks"] += 1
 
             if chunk_state.status == "completed":
-                shards[
+                shards[shard_name]["completed_chunks"] += 1
             elif chunk_state.status == "pending":
-                shards[
-                shards[
+                shards[shard_name]["pending_chunks"] += 1
+                shards[shard_name]["is_complete"] = False
             elif chunk_state.status == "assigned":
-                shards[
-                shards[
+                shards[shard_name]["assigned_chunks"] += 1
+                shards[shard_name]["is_complete"] = False
             elif chunk_state.status == "failed":
-                shards[
-                shards[
+                shards[shard_name]["failed_chunks"] += 1
+                shards[shard_name]["is_complete"] = False
 
         return shards
 
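After this change the summary is keyed directly by shard_name, with no hf_dataset: normalization. An illustrative shape of one entry follows; the counter names come from the hunk above, but the shard name and counts are made up, and the initialisation of assigned_chunks, failed_chunks, and is_complete is assumed to happen in the dict literal that the diff truncates:

summary = {
    "data-0000": {
        "total_chunks": 12,
        "completed_chunks": 9,
        "pending_chunks": 2,
        "assigned_chunks": 1,
        "failed_chunks": 0,
        "is_complete": False,
        "chunks": [],  # ChunkState objects for this shard
    }
}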
@@ -322,13 +363,7 @@ class ChunkTracker(CheckpointTracker):
                     continue
 
                 # Infer shard URL and create chunk with default size
-
-                    # HF dataset
-                    dataset_path = shard_name.replace("_", "/")
-                    shard_url = f"hf_dataset:{dataset_path}:chunk:{start_idx}"
-                else:
-                    # WebDataset
-                    shard_url = f"unknown://{shard_name}.tar"
+                shard_url = f"unknown://{shard_name}.tar"
 
                 self.chunks[chunk_id] = ChunkState(
                     chunk_id=chunk_id,
@@ -410,6 +445,7 @@ class ChunkTracker(CheckpointTracker):
         """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
         if chunk_id not in self.chunks:
             logger.error(f"Unknown chunk: {chunk_id}")
+            logger.debug(f"Known chunks: {list(self.chunks.keys())}")
             return
 
         chunk = self.chunks[chunk_id]
@@ -450,8 +486,13 @@ class ChunkTracker(CheckpointTracker):
         if not hasattr(self, "_startup_complete"):
             self._startup_complete = False
 
-        if not self._startup_complete or
+        if not self._startup_complete or (
+            not chunk_state.assigned_to or chunk_state.completed_at is None
+        ):
             # Return all unprocessed ranges
+            logger.debug(
+                f"Returning all unprocessed ranges. Status {self._startup_complete=} {chunk_state=}"
+            )
             return {
                 "chunk_id": chunk_id,
                 "unprocessed_ranges": chunk_state.get_unprocessed_ranges(),
caption_flow/utils/dataset_loader.py
@@ -9,8 +9,6 @@ import json
 
 import webdataset as wds
 from huggingface_hub import HfFileSystem, get_token, hf_hub_url
-from datasets import load_dataset, Dataset
-from .image_processor import ImageProcessor
 
 logger = logging.getLogger(__name__)
 
@@ -24,6 +22,7 @@ class DatasetLoader:
         dataset_type: str = "huggingface",
         split: str = "train",
         image_column: str = "image",
+        cache_dir: Optional[Path] = None,
     ):
         """
         Initialize dataset loader.
@@ -40,8 +39,6 @@ class DatasetLoader:
         self.image_column = image_column
         self.token = get_token()
         self.dataset_format = None  # Will be detected: "webdataset" or "huggingface_datasets"
-        self._hf_dataset = None  # Cache for HuggingFace dataset
-        self._hf_total_items = None  # Cache for total items count
 
         if not self.token and dataset_type == "huggingface":
             logger.warning("No HuggingFace token found; run `huggingface-cli login`")
@@ -60,27 +57,18 @@ class DatasetLoader:
             if tar_files:
                 return "webdataset"
 
-            # Check for parquet files (
+            # Check for .parquet files (Huggingface Arrow DB)
             parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
             if parquet_files:
                 return "huggingface_datasets"
 
-
-            if fs.exists(f"datasets/{self.dataset_path}/dataset_info.json") or fs.exists(
-                f"datasets/{self.dataset_path}/dataset_dict.json"
-            ):
-                return "huggingface_datasets"
-
-            logger.warning(f"Could not detect dataset format for {self.dataset_path}")
-            return "unknown"
+            raise AssertionError(f"Could not detect dataset format for {self.dataset_path}")
 
     def get_shard_list(self) -> List[str]:
         """Get list of all shards in the dataset."""
         if self.dataset_type == "huggingface":
             if self.dataset_format == "webdataset":
                 return self._get_hf_webdataset_shards()
-            elif self.dataset_format == "huggingface_datasets":
-                return self._get_hf_dataset_shards()
             else:
                 logger.error(f"Unknown dataset format: {self.dataset_format}")
                 return []
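With the dataset_info.json probe removed, format detection reduces to two globs and a hard failure. A minimal standalone sketch of that order, written as a hypothetical free function rather than the actual method; the parquet glob is taken from the context lines above, while the tar glob pattern is assumed to mirror it:

from huggingface_hub import HfFileSystem


def detect_dataset_format(dataset_path: str) -> str:
    """Sketch: tar files -> webdataset, parquet files -> huggingface_datasets, else fail loudly."""
    fs = HfFileSystem()
    if fs.glob(f"hf://datasets/{dataset_path}/**/*.tar"):
        return "webdataset"
    if fs.glob(f"hf://datasets/{dataset_path}/**/*.parquet"):
        return "huggingface_datasets"
    raise AssertionError(f"Could not detect dataset format for {dataset_path}")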
@@ -101,60 +89,6 @@ class DatasetLoader:
         logger.info(f"Found {len(urls)} WebDataset shards")
         return sorted(urls)
 
-    def _get_hf_dataset_shards(self) -> List[str]:
-        """Get virtual 'shards' for HuggingFace datasets format."""
-        logger.info(f"Getting HuggingFace dataset info: {self.dataset_path}")
-
-        # For HuggingFace datasets, we'll create virtual shards based on chunks
-        # Each "shard" will be a range of indices
-        try:
-            # First, try to get available splits
-            try:
-                from datasets import get_dataset_split_names
-
-                available_splits = get_dataset_split_names(self.dataset_path, token=self.token)
-                logger.info(f"Available splits: {available_splits}")
-
-                if self.split not in available_splits:
-                    logger.warning(
-                        f"Requested split '{self.split}' not found. "
-                        f"Available splits: {available_splits}. "
-                        f"Using first available split: '{available_splits[0]}'"
-                    )
-                    self.split = available_splits[0]
-            except Exception as e:
-                logger.warning(f"Could not get split names: {e}")
-
-            # Load dataset info without downloading data
-            dataset_info = load_dataset(
-                self.dataset_path, split=self.split, streaming=True, token=self.token
-            )
-
-            # Try to get the total size
-            # For streaming datasets, we might need to iterate to count
-            # This is expensive, so we'll use a default chunk size instead
-            chunk_size = 10000  # Default chunk size for virtual shards
-
-            # Create virtual shard identifiers
-            # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
-            virtual_shards = []
-
-            # We'll create a reasonable number of virtual shards
-            # Without knowing the total size, we'll create them on-demand
-            # For now, create initial batch of virtual shards
-            for i in range(10):  # Start with 10 virtual shards
-                shard_id = f"hf_dataset:{self.dataset_path}:chunk:{i * chunk_size}"
-                virtual_shards.append(shard_id)
-
-            logger.info(
-                f"Created {len(virtual_shards)} initial virtual shards for HuggingFace dataset"
-            )
-            return virtual_shards
-
-        except Exception as e:
-            logger.error(f"Error loading HuggingFace dataset info: {e}")
-            return []
-
     def _get_local_shards(self) -> List[str]:
         """Get shard files from local directory."""
         path = Path(self.dataset_path)
@@ -176,13 +110,6 @@ class DatasetLoader:
         if processed_keys is None:
             processed_keys = set()
 
-        # Check if this is a virtual HuggingFace dataset shard
-        if shard_url.startswith("hf_dataset:"):
-            raise ValueError(
-                "Virtual HuggingFace dataset shards should use iterate_shard() directly, "
-                "not load_shard()"
-            )
-
         if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
             # Use curl with auth token for HuggingFace
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
@@ -203,229 +130,57 @@ class DatasetLoader:
 
         return ds
 
-    def _parse_virtual_shard(self, shard_url: str) -> Tuple[str, int, int]:
-        """Parse virtual shard identifier."""
-        # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
-        parts = shard_url.split(":")
-        if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
-            raise ValueError(f"Invalid virtual shard format: {shard_url}")
-
-        dataset_path = parts[1]
-        start_idx = int(parts[3])
-        chunk_size = 10000  # Default chunk size
-
-        return dataset_path, start_idx, chunk_size
-
     def iterate_shard(
         self,
         shard_url: str,
         processed_keys: Optional[set] = None,
         unprocessed_ranges: Optional[List[Tuple[int, int]]] = None,
-    ) -> Generator[
+    ) -> Generator[Dict[str, Any], None, None]:
         """
-        Iterate over items in a shard.
+        Iterate over items in a shard, returning full sample dictionaries.
 
         Args:
             shard_url: URL or identifier of the shard
             processed_keys: Set of already processed keys to skip
-            unprocessed_ranges: Specific ranges to process (for
+            unprocessed_ranges: Specific ranges to process (for range-based processing)
 
         Yields:
-
+            Dictionary containing the full WebDataset sample
         """
-        if shard_url.startswith("hf_dataset:"):
-            raise ValueError(
-                "Virtual HuggingFace dataset shards should use iterate_shard_with_metadata()"
-            )
-        else:
-            # Regular WebDataset shard
-            ds = self.load_shard(shard_url, processed_keys)
-            for key, url, image_data in ds:
-                yield key, url, image_data
-
-    def _create_dataset_at_position(self, dataset_path: str, split: str, start_idx: int):
-        """Create a dataset iterator positioned at start_idx using state_dict if available."""
-        try:
-            # Load dataset in streaming mode
-            dataset = load_dataset(
-                dataset_path,
-                split=split,
-                streaming=True,
-                token=self.token,
-            )
-
-            # Check if the dataset supports state_dict (newer versions of datasets library)
-            if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
-                # Try to use the dataset's native state management
-                try:
-                    # Get current state
-                    state = dataset.state_dict()
-
-                    # Modify the state to skip to start_idx
-                    if "epoch" in state:
-                        state["epoch"] = 0
-                    if "num_examples_since_previous_state" in state:
-                        state["num_examples_since_previous_state"] = start_idx
-
-                    # For newer datasets with examples_iterable state
-                    if "examples_iterable" in state:
-                        if isinstance(state["examples_iterable"], dict):
-                            if "shard_example_idx" in state["examples_iterable"]:
-                                state["examples_iterable"]["shard_example_idx"] = start_idx
-
-                    # Load the modified state
-                    dataset.load_state_dict(state)
-                    logger.info(f"Positioned dataset at index {start_idx} using state_dict")
-                    return dataset
-                except Exception as e:
-                    logger.debug(f"Could not use state_dict approach: {e}")
-
-            # Fall back to skip() for large skips
-            if start_idx > 0:
-                logger.info(f"Using skip() to position dataset at index {start_idx}")
-                dataset = dataset.skip(start_idx)
-
-            return dataset
-
-        except Exception as e:
-            logger.warning(f"Error creating positioned dataset: {e}")
-            return None
-
-    def _iterate_hf_dataset_shard_with_metadata(
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """Iterate over a virtual HuggingFace dataset shard with metadata."""
         if processed_keys is None:
             processed_keys = set()
 
-
-
-
-
-
-
-
-
-        # The actual range filtering happens in the shard processor
-        items_processed = 0
-        current_abs_idx = start_idx
-
-        while items_processed < chunk_size:
-            # Create a fresh dataset iterator for each batch
-            # This avoids issues with stateful iterators
-            batch_size = min(1000, chunk_size - items_processed)  # Process in smaller batches
-
-            dataset = load_dataset(
-                dataset_path,
-                split=self.split,
-                streaming=True,
-                token=self.token,
-            )
-
-            # Skip to current position
-            if current_abs_idx > 0:
-                dataset = dataset.skip(current_abs_idx)
-
-            batch_processed = 0
-            for item in dataset:
-                if batch_processed >= batch_size or items_processed >= chunk_size:
-                    break
-
-                # Generate key
-                key = f"{dataset_path.replace('/', '_')}_{current_abs_idx:08d}"
-
-                if key in processed_keys:
-                    current_abs_idx += 1
-                    batch_processed += 1
-                    items_processed += 1
-                    continue
-
-                try:
-                    if self.image_column in item:
-                        img_data = item[self.image_column]
-                        image_bytes = ImageProcessor.process_image_data(img_data)
-
-                        if image_bytes:
-                            metadata = {k: v for k, v in item.items() if k != self.image_column}
-                            url = f"hf://{dataset_path}#{current_abs_idx}"
-
-                            yield key, url, image_bytes, metadata
-
-                        current_abs_idx += 1
-                        batch_processed += 1
-                        items_processed += 1
-                    else:
-                        logger.warning(
-                            f"No image column '{self.image_column}' at index {current_abs_idx}"
-                        )
-                        current_abs_idx += 1
-                        batch_processed += 1
-                        items_processed += 1
-
-                except Exception as e:
-                    logger.error(f"Error processing item at index {current_abs_idx}: {e}")
-                    current_abs_idx += 1
-                    batch_processed += 1
-                    items_processed += 1
-                    continue
-
-        except Exception as e:
-            logger.error(f"Error loading HuggingFace dataset: {e}")
-            return
-
-    def iterate_shard_with_metadata(
-        self, shard_url: str, processed_keys: Optional[set] = None
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """
-        Iterate over items in a shard, including metadata.
-
-        Yields:
-            Tuple of (key, url, image_bytes, metadata_dict)
-        """
-        # Check if this is a virtual HuggingFace dataset shard
-        if shard_url.startswith("hf_dataset:"):
-            yield from self._iterate_hf_dataset_shard_with_metadata(shard_url, processed_keys)
+        if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
+            # Use curl with auth token for HuggingFace
+            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(url_cmd),
+                wds.tarfile_to_samples(),
+                wds.select(lambda x: x.get("__key__", "") not in processed_keys),
+            )
         else:
-            #
-
-
+            # Local file access
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(shard_url),
+                wds.tarfile_to_samples(),
+                wds.select(lambda x: x.get("__key__", "") not in processed_keys),
+            )
+
+        # Return full samples as dictionaries
+        for sample in ds:
+            # Ensure it's a dict and has required fields
+            if isinstance(sample, dict) and "__key__" in sample:
+                yield sample
 
     def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
         """Count items in a shard (can be slow for large shards)."""
-
-
-        _
-
-
-
-
-            if start_idx >= self._hf_total_items:
-                logger.warning(
-                    f"Virtual shard starts at {start_idx} but dataset has "
-                    f"only {self._hf_total_items} items"
-                )
-                return 0
-
-            # Otherwise, return the minimum of chunk_size and remaining items
-            remaining_items = self._hf_total_items - start_idx
-            actual_size = min(chunk_size, remaining_items)
-            logger.debug(
-                f"Virtual shard at {start_idx}: chunk_size={chunk_size}, "
-                f"remaining={remaining_items}, actual={actual_size}"
-            )
-            return actual_size
-        else:
-            # If we don't know total size, return chunk_size
-            return chunk_size
-        else:
-            # Regular WebDataset counting
-            count = 0
-            try:
-                for _ in self.iterate_shard(shard_url, processed_keys):
-                    count += 1
-            except Exception as e:
-                logger.error(f"Error counting shard {shard_url}: {e}")
-            return count
+        count = 0
+        try:
+            for _ in self.iterate_shard(shard_url, processed_keys):
+                count += 1
+        except Exception as e:
+            logger.error(f"Error counting shard {shard_url}: {e}")
+        return count
 
     def get_dataset_info(self) -> Dict[str, Any]:
         """Get information about the dataset."""
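After this hunk, iterate_shard yields raw WebDataset sample dicts instead of (key, url, image_data) tuples, so callers pick fields out of the dict themselves. A hypothetical consumer sketch; the constructor arguments and the jpg/png/txt keys are assumptions about a typical image/caption WebDataset, not taken from the diff:

loader = DatasetLoader("some-user/some-webdataset", dataset_type="huggingface")

for shard_url in loader.get_shard_list():
    for sample in loader.iterate_shard(shard_url):
        key = sample["__key__"]  # always present, per the check in iterate_shard
        image_bytes = sample.get("jpg") or sample.get("png") or sample.get("webp")
        caption = (sample.get("txt") or b"").decode("utf-8", errors="replace")
        if image_bytes is None:
            continue  # sample had no recognised image entry
        # ... hand (key, image_bytes, caption) to the caption worker here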
@@ -436,27 +191,32 @@ class DatasetLoader:
         }
 
         if self.dataset_format == "huggingface_datasets":
-
-
-
-
-
-            # Get features info
-            if hasattr(dataset_info, "features"):
-                info["features"] = str(dataset_info.features)
-
-            # Try to get total size (might not work for all datasets)
+            # Include cached metadata if available
+            if hasattr(self, "_hf_metadata"):
+                info.update(self._hf_metadata)
+            else:
+
             try:
-                #
-
-
+                # Try to get more info about the dataset
+                dataset_info = load_dataset(
+                    self.dataset_path, split=self.split, streaming=True, token=self.token
                 )
-
-
-
-
+                # Get features info
+                if hasattr(dataset_info, "features"):
+                    info["features"] = str(dataset_info.features)
+
+                # Try to get total size (might not work for all datasets)
+                try:
+                    # This might be expensive for large datasets
+                    total_examples = len(
+                        load_dataset(self.dataset_path, split=self.split, token=self.token)
+                    )
+                    info["total_examples"] = total_examples
+                    self._hf_total_items = total_examples
+                except:
+                    info["total_examples"] = "unknown"
 
-
-
+            except Exception as e:
+                logger.error(f"Error getting dataset info: {e}")
 
         return info