caption-flow 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,67 @@
+"""Dataset metadata caching for efficient HuggingFace dataset handling."""
+
+import json
+import logging
+from pathlib import Path
+from typing import Dict, Any, Optional, List
+from datetime import datetime
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetMetadataCache:
+    """Caches dataset metadata to avoid repeated full iterations."""
+
+    def __init__(self, cache_dir: Path):
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.cache_file = self.cache_dir / "dataset_metadata.json"
+        self.metadata: Dict[str, Any] = {}
+        self._load_cache()
+
+    def _load_cache(self):
+        """Load cached metadata from disk."""
+        if self.cache_file.exists():
+            try:
+                with open(self.cache_file, "r") as f:
+                    self.metadata = json.load(f)
+                logger.info(f"Loaded dataset metadata cache with {len(self.metadata)} datasets")
+            except Exception as e:
+                logger.error(f"Failed to load metadata cache: {e}")
+                self.metadata = {}
+
+    def _save_cache(self):
+        """Save metadata cache to disk."""
+        try:
+            with open(self.cache_file, "w") as f:
+                json.dump(self.metadata, f, indent=2)
+            logger.debug("Saved dataset metadata cache")
+        except Exception as e:
+            logger.error(f"Failed to save metadata cache: {e}")
+
+    def get_dataset_key(self, dataset_path: str, split: str) -> str:
+        """Generate a unique key for a dataset+split combination."""
+        return f"{dataset_path}:{split}"
+
+    def get_metadata(self, dataset_path: str, split: str) -> Optional[Dict[str, Any]]:
+        """Get cached metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        return self.metadata.get(key)
+
+    def set_metadata(self, dataset_path: str, split: str, metadata: Dict[str, Any]):
+        """Cache metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        metadata["cached_at"] = datetime.utcnow().isoformat()
+        metadata["dataset_path"] = dataset_path
+        metadata["split"] = split
+        self.metadata[key] = metadata
+        self._save_cache()
+        logger.info(f"Cached metadata for {key}: {metadata.get('total_items', 0)} items")
+
+    def invalidate(self, dataset_path: str, split: str):
+        """Remove cached metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        if key in self.metadata:
+            del self.metadata[key]
+            self._save_cache()
+            logger.info(f"Invalidated metadata cache for {key}")
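
Note: the new DatasetMetadataCache keeps one JSON file per cache directory and keys entries by "dataset_path:split", so a worker can skip a full pass over a streaming dataset once fields such as total_items have been recorded. A minimal usage sketch, assuming a local ./cache directory and an illustrative org/dataset repository:

    from pathlib import Path

    cache = DatasetMetadataCache(Path("./cache"))        # "./cache" is an example directory
    meta = cache.get_metadata("org/dataset", "train")    # "org/dataset" is illustrative
    if meta is None:
        # First run: count items once (expensive for streaming datasets), then cache the result.
        cache.set_metadata("org/dataset", "train", {"total_items": 123456})
    else:
        print(f"{meta['total_items']} items, cached at {meta['cached_at']}")
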
@@ -112,10 +112,7 @@ class ImageProcessor:
             return None
 
         except Exception as e:
-            logger.error(f"Error processing image data: {e}")
-            import traceback
-
-            logger.error(traceback.format_exc())
+            logger.error(f"Error processing image data: {e}", exc_info=True)
             return None
 
     @staticmethod
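
Note: the change above replaces the manual traceback import with logging's built-in exception capture; passing exc_info=True attaches the active traceback to the log record, which is equivalent to logging traceback.format_exc() separately. A standalone example of the pattern:

    import logging

    logging.basicConfig(level=logging.ERROR)
    log = logging.getLogger("demo")  # "demo" is an arbitrary example logger name

    try:
        1 / 0
    except Exception as e:
        # exc_info=True makes the handler append the current exception's traceback
        log.error("Error processing image data: %s", e, exc_info=True)
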
@@ -7,7 +7,6 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Generator, Tuple, Optional, Dict, Any
 from dataclasses import dataclass
-from datasets import load_dataset
 from .image_processor import ImageProcessor
 from threading import Event
 import shlex
@@ -24,213 +23,22 @@ class ShardProcessor(ABC):
     """Abstract base for processing dataset shards."""
 
     @abstractmethod
-    def iterate_chunk(
+    def iterate_chunk_with_metadata(
         self,
         chunk,
         dataset_loader: Optional[DatasetLoader],
         should_stop: Event,
         connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
         """
-        Iterate through items in a chunk.
+        Iterate through items in a chunk with metadata.
 
         Yields:
-            Tuple of (key, url, image_data)
+            Tuple of (key, url, image_data, metadata)
         """
         pass
 
 
-class HFDatasetShardProcessor(ShardProcessor):
-    """Processor for HuggingFace virtual dataset shards."""
-
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Process HuggingFace virtual shard chunk."""
-        if not dataset_loader:
-            logger.error("No dataset loader configured for HuggingFace dataset shard")
-            return
-
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        items_processed = 0
-        current_idx = 0
-
-        # Construct proper virtual shard URL
-        parts = chunk.shard_url.split("_chunk_")
-        if len(parts) == 2:
-            base_path = parts[0]
-            virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
-        else:
-            virtual_shard_url = chunk.shard_url
-
-        logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
-
-        # Iterate through the virtual shard
-        for key, url, image_data in dataset_loader.iterate_shard(virtual_shard_url):
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping chunk processing early due to disconnect")
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
-
-            if not in_range:
-                current_idx += 1
-                continue  # Skip already processed items
-
-            # Check if we've processed enough for this chunk
-            if current_idx >= chunk.chunk_size:
-                break
-
-            items_processed += 1
-            current_idx += 1
-            yield key, url, image_data
-
-        logger.info(
-            f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
-            f"from ranges {unprocessed_ranges}"
-        )
-
-    def iterate_chunk_with_metadata(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
-        """
-        Process HuggingFace virtual shard chunk with metadata, range by range.
-        """
-        if not dataset_loader:
-            logger.error("No dataset loader configured for HuggingFace dataset shard")
-            return
-
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing HF dataset chunk {chunk.chunk_id} with {len(unprocessed_ranges)} ranges"
-        )
-
-        items_yielded = 0
-
-        # Process each range independently with its own iterator
-        for range_start, range_end in unprocessed_ranges:
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping chunk processing early due to disconnect")
-                break
-
-            # Calculate absolute indices for this range
-            abs_start = chunk.start_index + range_start
-            abs_end = chunk.start_index + range_end
-            range_size = range_end - range_start + 1
-
-            logger.debug(
-                f"Processing range [{range_start}, {range_end}] "
-                f"(absolute: [{abs_start}, {abs_end}])"
-            )
-
-            try:
-                # Create a fresh dataset iterator for this range
-                dataset = load_dataset(
-                    dataset_loader.dataset_path,
-                    split=dataset_loader.split,
-                    streaming=True,
-                    token=dataset_loader.token,
-                )
-
-                # Use state_dict if available for efficient positioning
-                if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
-                    try:
-                        state = dataset.state_dict()
-                        # Modify state to jump to abs_start
-                        if "num_examples_since_previous_state" in state:
-                            state["num_examples_since_previous_state"] = abs_start
-                        if "examples_iterable" in state and isinstance(
-                            state["examples_iterable"], dict
-                        ):
-                            if "shard_example_idx" in state["examples_iterable"]:
-                                state["examples_iterable"]["shard_example_idx"] = abs_start
-                        dataset.load_state_dict(state)
-                        logger.debug(f"Positioned dataset at index {abs_start} using state_dict")
-                    except Exception as e:
-                        logger.debug(f"Could not use state_dict, falling back to skip: {e}")
-                        dataset = dataset.skip(abs_start)
-                else:
-                    # Fall back to skip
-                    dataset = dataset.skip(abs_start)
-
-                # Process items in this range
-                range_items = 0
-                for item in dataset:
-                    if range_items >= range_size:
-                        break
-
-                    if should_stop.is_set() or not connected.is_set():
-                        break
-
-                    # Generate key for this item
-                    current_abs_idx = abs_start + range_items
-                    key = f"{dataset_loader.dataset_path.replace('/', '_')}_{current_abs_idx:08d}"
-
-                    try:
-                        if dataset_loader.image_column in item:
-                            img_data = item[dataset_loader.image_column]
-                            image_bytes = ImageProcessor.process_image_data(img_data)
-
-                            if image_bytes:
-                                # Extract metadata
-                                metadata = {
-                                    k: v
-                                    for k, v in item.items()
-                                    if k != dataset_loader.image_column
-                                }
-                                # Add chunk-relative index to metadata
-                                metadata["_chunk_relative_index"] = range_start + range_items
-
-                                url = f"hf://{dataset_loader.dataset_path}#{current_abs_idx}"
-
-                                items_yielded += 1
-                                range_items += 1
-
-                                yield key, url, image_bytes, metadata
-                            else:
-                                logger.warning(
-                                    f"Failed to process image at index {current_abs_idx}"
-                                )
-                                range_items += 1
-                        else:
-                            logger.warning(
-                                f"No image column '{dataset_loader.image_column}' at index {current_abs_idx}"
-                            )
-                            range_items += 1
-
-                    except Exception as e:
-                        logger.error(f"Error processing item at index {current_abs_idx}: {e}")
-                        range_items += 1
-                        continue
-
-            except Exception as e:
-                logger.error(f"Error processing range [{range_start}, {range_end}]: {e}")
-                continue
-
-        logger.info(
-            f"HF dataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
-            f"from {len(unprocessed_ranges)} ranges"
-        )
-
-
 class WebDatasetShardProcessor(ShardProcessor):
     """Processor for WebDataset tar shards with range support."""
 
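
Note: with iterate_chunk removed, ShardProcessor exposes only the metadata-aware generator, and every implementation now yields a 4-tuple. A sketch of the expected calling pattern, where processor, chunk, and dataset_loader are assumed to come from the surrounding worker code and handle_sample is a hypothetical callback:

    from threading import Event

    should_stop, connected = Event(), Event()
    connected.set()

    for key, url, image_bytes, metadata in processor.iterate_chunk_with_metadata(
        chunk, dataset_loader, should_stop, connected
    ):
        # metadata holds whatever extra per-sample fields the processor supplies
        handle_sample(key, url, image_bytes, metadata)  # hypothetical consumer
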
@@ -238,74 +46,6 @@ class WebDatasetShardProcessor(ShardProcessor):
         self.hf_token = hf_token
         self.dataset_type = dataset_type
 
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Process WebDataset shard chunk with unprocessed ranges."""
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        # Create WebDataset pipeline
-        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
-            # Use curl with auth for HuggingFace WebDataset
-            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(url_cmd),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-        else:
-            # Local file
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(chunk.shard_url),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-
-        # Process items
-        current_idx = 0
-        items_yielded = 0
-
-        for key, image_data in ds:
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
-                break
-
-            # Calculate relative index within chunk
-            relative_idx = current_idx - chunk.start_index
-
-            # Skip items before chunk start
-            if current_idx < chunk.start_index:
-                current_idx += 1
-                continue
-
-            # Stop if beyond chunk
-            if relative_idx >= chunk.chunk_size:
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
-
-            if in_range:
-                items_yielded += 1
-                yield key, chunk.shard_url, image_data
-
-            current_idx += 1
-
-        logger.info(
-            f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
-            f"from ranges {unprocessed_ranges}"
-        )
-
     def iterate_chunk_with_metadata(
         self,
         chunk,
@@ -322,7 +62,7 @@ class WebDatasetShardProcessor(ShardProcessor):
         )
 
         # Create WebDataset pipeline
-        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
+        if self.dataset_type == "huggingface":
             # Use curl with auth for HuggingFace WebDataset
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
             ds = wds.DataPipeline(
@@ -61,11 +61,7 @@ class ShardTracker(CheckpointTracker):
         """Get list of shards that still need processing."""
         remaining = []
         for s in all_shards:
-            # Extract shard name properly for both regular and virtual shards
-            if s.startswith("hf_dataset:"):
-                shard_name = s  # Use full virtual shard ID
-            else:
-                shard_name = Path(s).stem
+            shard_name = Path(s).stem
 
             if shard_name not in self.completed_shards:
                 remaining.append(s)
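
Note: the removed branch special-cased virtual "hf_dataset:" shard identifiers; with that path gone, every entry in all_shards is treated as a plain shard path, so Path(s).stem alone yields the name compared against completed_shards. A tiny illustration with a made-up path:

    from pathlib import Path

    # "shards/data-00017.tar" is an illustrative shard path, not from the package
    assert Path("shards/data-00017.tar").stem == "data-00017"
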