caption-flow 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,8 @@ import json
 
 import webdataset as wds
 from huggingface_hub import HfFileSystem, get_token, hf_hub_url
+from datasets import load_dataset, Dataset
+from .image_processor import ImageProcessor
 
 logger = logging.getLogger(__name__)
 
@@ -16,42 +18,143 @@ logger = logging.getLogger(__name__)
 class DatasetLoader:
     """Handles loading datasets from various sources."""
 
-    def __init__(self, dataset_path: str, dataset_type: str = "huggingface"):
+    def __init__(
+        self,
+        dataset_path: str,
+        dataset_type: str = "huggingface",
+        split: str = "train",
+        image_column: str = "image",
+    ):
         """
         Initialize dataset loader.
 
         Args:
             dataset_path: Path to dataset (HF repo, local dir, etc.)
             dataset_type: Type of dataset ("huggingface", "webdataset", "local")
+            split: Split to use for HuggingFace datasets (default: "train")
+            image_column: Column name containing image data or URLs (default: "image")
         """
         self.dataset_path = dataset_path
         self.dataset_type = dataset_type
+        self.split = split
+        self.image_column = image_column
         self.token = get_token()
+        self.dataset_format = None  # Will be detected: "webdataset" or "huggingface_datasets"
+        self._hf_dataset = None  # Cache for HuggingFace dataset
+        self._hf_total_items = None  # Cache for total items count
 
         if not self.token and dataset_type == "huggingface":
             logger.warning("No HuggingFace token found; run `huggingface-cli login`")
 
+        # Detect the actual format if it's a HuggingFace dataset
+        if dataset_type == "huggingface":
+            self.dataset_format = self._detect_dataset_format()
+            logger.info(f"Detected dataset format: {self.dataset_format}")
+
+    def _detect_dataset_format(self) -> str:
+        """Detect whether it's WebDataset or HuggingFace datasets format."""
+        fs = HfFileSystem(token=self.token)
+
+        # Check for .tar files (WebDataset)
+        tar_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar"))
+        if tar_files:
+            return "webdataset"
+
+        # Check for parquet files (HuggingFace datasets)
+        parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
+        if parquet_files:
+            return "huggingface_datasets"
+
+        # Check for dataset_info.json or dataset_dict.json
+        if fs.exists(f"datasets/{self.dataset_path}/dataset_info.json") or fs.exists(
+            f"datasets/{self.dataset_path}/dataset_dict.json"
+        ):
+            return "huggingface_datasets"
+
+        logger.warning(f"Could not detect dataset format for {self.dataset_path}")
+        return "unknown"
+
     def get_shard_list(self) -> List[str]:
         """Get list of all shards in the dataset."""
         if self.dataset_type == "huggingface":
-            return self._get_hf_shards()
+            if self.dataset_format == "webdataset":
+                return self._get_hf_webdataset_shards()
+            elif self.dataset_format == "huggingface_datasets":
+                return self._get_hf_dataset_shards()
+            else:
+                logger.error(f"Unknown dataset format: {self.dataset_format}")
+                return []
         elif self.dataset_type == "local":
             return self._get_local_shards()
         else:
             raise ValueError(f"Unknown dataset type: {self.dataset_type}")
 
-    def _get_hf_shards(self) -> List[str]:
-        """Get shard URLs from HuggingFace dataset."""
-        logger.info(f"Getting shard list from HuggingFace: {self.dataset_path}")
+    def _get_hf_webdataset_shards(self) -> List[str]:
+        """Get shard URLs from HuggingFace WebDataset."""
+        logger.info(f"Getting WebDataset shard list from HuggingFace: {self.dataset_path}")
 
-        fs = HfFileSystem()
+        fs = HfFileSystem(token=self.token)
         files = [fs.resolve_path(p) for p in fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar")]
 
         urls = [hf_hub_url(f.repo_id, f.path_in_repo, repo_type="dataset") for f in files]
 
-        logger.info(f"Found {len(urls)} shards")
+        logger.info(f"Found {len(urls)} WebDataset shards")
         return sorted(urls)
 
+    def _get_hf_dataset_shards(self) -> List[str]:
+        """Get virtual 'shards' for HuggingFace datasets format."""
+        logger.info(f"Getting HuggingFace dataset info: {self.dataset_path}")
+
+        # For HuggingFace datasets, we'll create virtual shards based on chunks
+        # Each "shard" will be a range of indices
+        try:
+            # First, try to get available splits
+            try:
+                from datasets import get_dataset_split_names
+
+                available_splits = get_dataset_split_names(self.dataset_path, token=self.token)
+                logger.info(f"Available splits: {available_splits}")
+
+                if self.split not in available_splits:
+                    logger.warning(
+                        f"Requested split '{self.split}' not found. "
+                        f"Available splits: {available_splits}. "
+                        f"Using first available split: '{available_splits[0]}'"
+                    )
+                    self.split = available_splits[0]
+            except Exception as e:
+                logger.warning(f"Could not get split names: {e}")
+
+            # Load dataset info without downloading data
+            dataset_info = load_dataset(
+                self.dataset_path, split=self.split, streaming=True, token=self.token
+            )
+
+            # Try to get the total size
+            # For streaming datasets, we might need to iterate to count
+            # This is expensive, so we'll use a default chunk size instead
+            chunk_size = 10000  # Default chunk size for virtual shards
+
+            # Create virtual shard identifiers
+            # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
+            virtual_shards = []
+
+            # We'll create a reasonable number of virtual shards
+            # Without knowing the total size, we'll create them on-demand
+            # For now, create initial batch of virtual shards
+            for i in range(10):  # Start with 10 virtual shards
+                shard_id = f"hf_dataset:{self.dataset_path}:chunk:{i * chunk_size}"
+                virtual_shards.append(shard_id)
+
+            logger.info(
+                f"Created {len(virtual_shards)} initial virtual shards for HuggingFace dataset"
+            )
+            return virtual_shards
+
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset info: {e}")
+            return []
+
     def _get_local_shards(self) -> List[str]:
         """Get shard files from local directory."""
         path = Path(self.dataset_path)
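The virtual shard identifiers introduced above follow the pattern hf_dataset:<dataset_path>:chunk:<start_idx>. As a minimal sketch, the initial list produced by _get_hf_dataset_shards() would look like this for a placeholder repo (the repo name is illustrative, not from the package):

# Illustrative only: mirrors the ID scheme added in _get_hf_dataset_shards().
chunk_size = 10000
dataset_path = "someorg/somedataset"  # hypothetical HF dataset repo

virtual_shards = [f"hf_dataset:{dataset_path}:chunk:{i * chunk_size}" for i in range(10)]
# ['hf_dataset:someorg/somedataset:chunk:0',
#  'hf_dataset:someorg/somedataset:chunk:10000',
#  ...
#  'hf_dataset:someorg/somedataset:chunk:90000']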
@@ -73,7 +176,14 @@ class DatasetLoader:
         if processed_keys is None:
             processed_keys = set()
 
-        if self.dataset_type == "huggingface":
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            raise ValueError(
+                "Virtual HuggingFace dataset shards should use iterate_shard() directly, "
+                "not load_shard()"
+            )
+
+        if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
             # Use curl with auth token for HuggingFace
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
             ds = wds.DataPipeline(
@@ -93,6 +203,19 @@
 
         return ds
 
+    def _parse_virtual_shard(self, shard_url: str) -> Tuple[str, int, int]:
+        """Parse virtual shard identifier."""
+        # Format: "hf_dataset:<dataset_path>:chunk:<start_idx>"
+        parts = shard_url.split(":")
+        if len(parts) != 4 or parts[0] != "hf_dataset" or parts[2] != "chunk":
+            raise ValueError(f"Invalid virtual shard format: {shard_url}")
+
+        dataset_path = parts[1]
+        start_idx = int(parts[3])
+        chunk_size = 10000  # Default chunk size
+
+        return dataset_path, start_idx, chunk_size
+
     def iterate_shard(
         self, shard_url: str, processed_keys: Optional[set] = None
     ) -> Generator[Tuple[str, str, bytes], None, None]:
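As a quick check on the new _parse_virtual_shard() helper: splitting one of those identifiers on ":" always yields exactly four parts, since HuggingFace repo paths contain "/" but never ":". A sketch with a hypothetical repo name (the chunk size is the hard-coded default from the diff):

shard_url = "hf_dataset:someorg/somedataset:chunk:30000"  # hypothetical identifier

parts = shard_url.split(":")  # ['hf_dataset', 'someorg/somedataset', 'chunk', '30000']
dataset_path = parts[1]       # "someorg/somedataset"
start_idx = int(parts[3])     # 30000
chunk_size = 10000            # default assumed by _parse_virtual_shard()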
@@ -102,85 +225,281 @@
         Yields:
             Tuple of (key, url, image_bytes)
         """
-        ds = self.load_shard(shard_url, processed_keys)
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            yield from self._iterate_hf_dataset_shard(shard_url, processed_keys)
+        else:
+            # Regular WebDataset shard
+            ds = self.load_shard(shard_url, processed_keys)
+            for key, url, image_data in ds:
+                yield key, url, image_data
+
+    def _iterate_hf_dataset_shard_with_metadata(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+        """Iterate over a virtual HuggingFace dataset shard with metadata."""
+        if processed_keys is None:
+            processed_keys = set()
 
-        for key, url, image_data in ds:
-            yield key, url, image_data
+        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
+
+        logger.info(
+            f"Loading HuggingFace dataset with metadata: {dataset_path} (split: {self.split})"
+        )
 
-    def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
-        """Count items in a shard (can be slow for large shards)."""
-        count = 0
         try:
-            for _ in self.iterate_shard(shard_url, processed_keys):
-                count += 1
-        except Exception as e:
-            logger.error(f"Error counting shard {shard_url}: {e}")
-        return count
+            # Load dataset in streaming mode
+            dataset = load_dataset(
+                dataset_path,
+                split=self.split,
+                streaming=True,
+                token=self.token,
+            )
 
+            # Skip to start index if needed - CONSISTENT WITH OTHER METHOD
+            if start_idx > 0:
+                dataset = dataset.skip(start_idx)
+
+            items_processed = 0
+
+            for item in dataset:
+                # Stop after processing chunk_size items
+                if items_processed >= chunk_size:
+                    break
+
+                # Generate a unique key for this item - CONSISTENT FORMAT
+                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                if key in processed_keys:
+                    items_processed += 1
+                    continue
+
+                try:
+                    # Extract image data
+                    if self.image_column in item:
+                        img_data = item[self.image_column]
+
+                        # Process image to bytes
+                        image_bytes = ImageProcessor.process_image_data(img_data)
+
+                        if image_bytes:
+                            # Extract all metadata (excluding the image column)
+                            metadata = {k: v for k, v in item.items() if k != self.image_column}
+
+                            # URL is virtual for HF datasets
+                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                            items_processed += 1
+                            yield key, url, image_bytes, metadata
+                        else:
+                            logger.warning(
+                                f"Failed to process image for item at index {start_idx + items_processed}"
+                            )
+                            items_processed += 1
+                            continue
+                    else:
+                        logger.warning(
+                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                            f"Available columns: {list(item.keys())}"
+                        )
+                        items_processed += 1
+
+                except Exception as e:
+                    logger.error(
+                        f"Error processing item at index {start_idx + items_processed}: {e}"
+                    )
+                    items_processed += 1
+                    continue
 
-class ShardTracker:
-    """Tracks shard processing progress."""
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset: {e}")
+            return
 
-    def __init__(self, checkpoint_path: Path):
-        """Initialize shard tracker with checkpoint file."""
-        self.checkpoint_path = checkpoint_path
-        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+    def _iterate_hf_dataset_shard(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes], None, None]:
+        """Iterate over a virtual HuggingFace dataset shard."""
+        if processed_keys is None:
+            processed_keys = set()
 
-        self.completed_shards: set = set()
-        self.partial_shards: Dict[str, Dict[str, Any]] = {}
-        self.load()
+        dataset_path, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
 
-    def load(self):
-        """Load checkpoint from disk."""
-        if self.checkpoint_path.exists():
-            try:
-                data = json.loads(self.checkpoint_path.read_text())
-                self.completed_shards = set(data.get("completed_shards", []))
-                self.partial_shards = data.get("partial_shards", {})
-                logger.info(
-                    f"Loaded checkpoint: {len(self.completed_shards)} completed, "
-                    f"{len(self.partial_shards)} partial shards"
-                )
-            except Exception as e:
-                logger.error(f"Failed to load checkpoint: {e}")
+        # IMPORTANT: Check if start_idx is beyond dataset bounds
+        if self._hf_total_items is not None and start_idx >= self._hf_total_items:
+            logger.warning(
+                f"Virtual shard starts at index {start_idx} but dataset only has "
+                f"{self._hf_total_items} items. Skipping this shard."
+            )
+            return
 
-    def save(self):
-        """Save checkpoint to disk."""
-        data = {
-            "completed_shards": list(self.completed_shards),
-            "partial_shards": self.partial_shards,
-        }
+        logger.info(
+            f"Loading HuggingFace dataset in streaming mode: {dataset_path} "
+            f"(split: {self.split}, start: {start_idx}, chunk_size: {chunk_size})"
+        )
 
-        tmp = self.checkpoint_path.with_suffix(".tmp")
-        tmp.write_text(json.dumps(data, indent=2))
-        tmp.replace(self.checkpoint_path)
+        try:
+            # Load dataset in streaming mode
+            dataset = load_dataset(
+                dataset_path,
+                split=self.split,
+                streaming=True,
+                token=self.token,
+            )
 
-    def mark_complete(self, shard_name: str):
-        """Mark a shard as complete."""
-        self.completed_shards.add(shard_name)
-        if shard_name in self.partial_shards:
-            del self.partial_shards[shard_name]
-        self.save()
+            # Use dataset.skip() for efficient skipping
+            if start_idx > 0:
+                dataset = dataset.skip(start_idx)
+                logger.info(f"Skipped to index {start_idx}")
+
+            items_processed = 0
+
+            # Now enumerate starts from 0 after skip
+            for item in dataset:
+                # Stop after processing chunk_size items
+                if items_processed >= chunk_size:
+                    logger.info(f"Completed chunk: processed {items_processed} items")
+                    break
+
+                # Also stop if we've reached the dataset end
+                if self._hf_total_items and (start_idx + items_processed) >= self._hf_total_items:
+                    logger.info(
+                        f"Reached dataset end at item {start_idx + items_processed} "
+                        f"(total: {self._hf_total_items})"
+                    )
+                    break
+
+                # Generate a unique key for this item - ensure proper formatting
+                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                if key in processed_keys:
+                    items_processed += 1
+                    continue
+
+                try:
+                    # Extract image data - check configured column name
+                    if self.image_column in item:
+                        img_data = item[self.image_column]
+
+                        # Delegate image processing to ImageProcessor
+                        image_bytes = ImageProcessor.process_image_data(img_data)
+
+                        if image_bytes:
+                            # URL is virtual for HF datasets
+                            url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                            items_processed += 1
+                            yield key, url, image_bytes
+                        else:
+                            logger.warning(
+                                f"Failed to process image for item at index {start_idx + items_processed}"
+                            )
+                            items_processed += 1
+                            continue
+                    else:
+                        logger.warning(
+                            f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                            f"Available columns: {list(item.keys())}"
+                        )
+                        items_processed += 1
+
+                except Exception as e:
+                    logger.error(
+                        f"Error processing item at index {start_idx + items_processed}: {e}"
+                    )
+                    items_processed += 1
+                    continue
+
+            logger.info(
+                f"Virtual shard complete: processed {items_processed} items "
+                f"(start_idx: {start_idx})"
+            )
 
-    def update_partial(self, shard_name: str, processed_keys: List[str]):
-        """Update partial progress for a shard."""
-        self.partial_shards[shard_name] = {"keys": processed_keys, "count": len(processed_keys)}
-        self.save()
+        except Exception as e:
+            logger.error(f"Error loading HuggingFace dataset: {e}")
+            return
 
-    def get_processed_keys(self, shard_name: str) -> set:
-        """Get set of processed keys for a shard."""
-        if shard_name in self.completed_shards:
-            return set()  # All done
+    def iterate_shard_with_metadata(
+        self, shard_url: str, processed_keys: Optional[set] = None
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+        """
+        Iterate over items in a shard, including metadata.
+
+        Yields:
+            Tuple of (key, url, image_bytes, metadata_dict)
+        """
+        # Check if this is a virtual HuggingFace dataset shard
+        if shard_url.startswith("hf_dataset:"):
+            yield from self._iterate_hf_dataset_shard_with_metadata(shard_url, processed_keys)
+        else:
+            # Regular WebDataset shard - no metadata by default
+            for key, url, image_data in self.iterate_shard(shard_url, processed_keys):
+                yield key, url, image_data, {}
 
-        if shard_name in self.partial_shards:
-            return set(self.partial_shards[shard_name].get("keys", []))
+    def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
+        """Count items in a shard (can be slow for large shards)."""
+        if shard_url.startswith("hf_dataset:"):
+            # For virtual shards, return the chunk size
+            _, start_idx, chunk_size = self._parse_virtual_shard(shard_url)
+
+            # CRITICAL: Cap chunk size by dataset bounds
+            if self._hf_total_items is not None:
+                # If start index is beyond dataset, return 0
+                if start_idx >= self._hf_total_items:
+                    logger.warning(
+                        f"Virtual shard starts at {start_idx} but dataset has "
+                        f"only {self._hf_total_items} items"
+                    )
+                    return 0
+
+                # Otherwise, return the minimum of chunk_size and remaining items
+                remaining_items = self._hf_total_items - start_idx
+                actual_size = min(chunk_size, remaining_items)
+                logger.debug(
+                    f"Virtual shard at {start_idx}: chunk_size={chunk_size}, "
+                    f"remaining={remaining_items}, actual={actual_size}"
+                )
+                return actual_size
+            else:
+                # If we don't know total size, return chunk_size
+                return chunk_size
+        else:
+            # Regular WebDataset counting
+            count = 0
+            try:
+                for _ in self.iterate_shard(shard_url, processed_keys):
+                    count += 1
+            except Exception as e:
+                logger.error(f"Error counting shard {shard_url}: {e}")
+            return count
+
+    def get_dataset_info(self) -> Dict[str, Any]:
+        """Get information about the dataset."""
+        info = {
+            "dataset_path": self.dataset_path,
+            "dataset_type": self.dataset_type,
+            "dataset_format": self.dataset_format,
+        }
 
-        return set()
+        if self.dataset_format == "huggingface_datasets":
+            try:
+                # Try to get more info about the dataset
+                dataset_info = load_dataset(
+                    self.dataset_path, split=self.split, streaming=True, token=self.token
+                )
+                # Get features info
+                if hasattr(dataset_info, "features"):
+                    info["features"] = str(dataset_info.features)
+
+                # Try to get total size (might not work for all datasets)
+                try:
+                    # This might be expensive for large datasets
+                    total_examples = len(
+                        load_dataset(self.dataset_path, split=self.split, token=self.token)
+                    )
+                    info["total_examples"] = total_examples
+                    self._hf_total_items = total_examples
+                except:
+                    info["total_examples"] = "unknown"
 
-    def is_complete(self, shard_name: str) -> bool:
-        """Check if a shard is complete."""
-        return shard_name in self.completed_shards
+            except Exception as e:
+                logger.error(f"Error getting dataset info: {e}")
 
-    def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
-        """Get list of shards that still need processing."""
-        return [s for s in all_shards if Path(s).stem not in self.completed_shards]
+        return info
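Taken together, the 0.2.0 loader can be driven roughly as follows. This is a hedged sketch based only on the signatures visible in this diff; the repo name is a placeholder, and the caption_flow.dataset_loader import path is an assumption, not confirmed by the diff:

from caption_flow.dataset_loader import DatasetLoader  # assumed module path

loader = DatasetLoader(
    dataset_path="someorg/somedataset",  # placeholder HF repo
    dataset_type="huggingface",
    split="train",
    image_column="image",
)

# Either WebDataset .tar URLs or "hf_dataset:..." virtual shard identifiers,
# depending on the detected dataset_format.
shards = loader.get_shard_list()

for shard in shards[:1]:
    print(shard, loader.count_shard_items(shard))  # virtual shards are capped at 10000 items

    # Works for both formats; plain WebDataset shards yield an empty metadata dict.
    for key, url, image_bytes, metadata in loader.iterate_shard_with_metadata(shard):
        break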