caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,67 @@
+"""Dataset metadata caching for efficient HuggingFace dataset handling."""
+
+import json
+import logging
+from pathlib import Path
+from typing import Dict, Any, Optional, List
+from datetime import datetime
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetMetadataCache:
+    """Caches dataset metadata to avoid repeated full iterations."""
+
+    def __init__(self, cache_dir: Path):
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.cache_file = self.cache_dir / "dataset_metadata.json"
+        self.metadata: Dict[str, Any] = {}
+        self._load_cache()
+
+    def _load_cache(self):
+        """Load cached metadata from disk."""
+        if self.cache_file.exists():
+            try:
+                with open(self.cache_file, "r") as f:
+                    self.metadata = json.load(f)
+                logger.info(f"Loaded dataset metadata cache with {len(self.metadata)} datasets")
+            except Exception as e:
+                logger.error(f"Failed to load metadata cache: {e}")
+                self.metadata = {}
+
+    def _save_cache(self):
+        """Save metadata cache to disk."""
+        try:
+            with open(self.cache_file, "w") as f:
+                json.dump(self.metadata, f, indent=2)
+            logger.debug("Saved dataset metadata cache")
+        except Exception as e:
+            logger.error(f"Failed to save metadata cache: {e}")
+
+    def get_dataset_key(self, dataset_path: str, split: str) -> str:
+        """Generate a unique key for a dataset+split combination."""
+        return f"{dataset_path}:{split}"
+
+    def get_metadata(self, dataset_path: str, split: str) -> Optional[Dict[str, Any]]:
+        """Get cached metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        return self.metadata.get(key)
+
+    def set_metadata(self, dataset_path: str, split: str, metadata: Dict[str, Any]):
+        """Cache metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        metadata["cached_at"] = datetime.utcnow().isoformat()
+        metadata["dataset_path"] = dataset_path
+        metadata["split"] = split
+        self.metadata[key] = metadata
+        self._save_cache()
+        logger.info(f"Cached metadata for {key}: {metadata.get('total_items', 0)} items")
+
+    def invalidate(self, dataset_path: str, split: str):
+        """Remove cached metadata for a dataset."""
+        key = self.get_dataset_key(dataset_path, split)
+        if key in self.metadata:
+            del self.metadata[key]
+            self._save_cache()
+            logger.info(f"Invalidated metadata cache for {key}")
@@ -112,10 +112,7 @@ class ImageProcessor:
             return None
 
         except Exception as e:
-            logger.error(f"Error processing image data: {e}")
-            import traceback
-
-            logger.error(traceback.format_exc())
+            logger.error(f"Error processing image data: {e}", exc_info=True)
             return None
 
     @staticmethod
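The manual traceback.format_exc() logging is replaced with logging's built-in exc_info=True, which attaches the active exception's traceback to the same record. A minimal equivalent sketch:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise ValueError("bad image bytes")  # stand-in for a decode failure
except Exception as e:
    # One record carries both the message and the full traceback.
    logger.error(f"Error processing image data: {e}", exc_info=True)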
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Generator, Tuple, Optional, Dict, Any
 from dataclasses import dataclass
+from .image_processor import ImageProcessor
 from threading import Event
 import shlex
 
@@ -22,84 +23,6 @@ class ShardProcessor(ABC):
     """Abstract base for processing dataset shards."""
 
     @abstractmethod
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """
-        Iterate through items in a chunk.
-
-        Yields:
-            Tuple of (key, url, image_data)
-        """
-        pass
-
-
-class HFDatasetShardProcessor(ShardProcessor):
-    """Processor for HuggingFace virtual dataset shards."""
-
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Process HuggingFace virtual shard chunk."""
-        if not dataset_loader:
-            logger.error("No dataset loader configured for HuggingFace dataset shard")
-            return
-
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        items_processed = 0
-        current_idx = 0
-
-        # Construct proper virtual shard URL
-        parts = chunk.shard_url.split("_chunk_")
-        if len(parts) == 2:
-            base_path = parts[0]
-            virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
-        else:
-            virtual_shard_url = chunk.shard_url
-
-        logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
-
-        # Iterate through the virtual shard
-        for key, url, image_data in dataset_loader.iterate_shard(virtual_shard_url):
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping chunk processing early due to disconnect")
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
-
-            if not in_range:
-                current_idx += 1
-                continue  # Skip already processed items
-
-            # Check if we've processed enough for this chunk
-            if current_idx >= chunk.chunk_size:
-                break
-
-            items_processed += 1
-            current_idx += 1
-            yield key, url, image_data
-
-        logger.info(
-            f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
-            f"from ranges {unprocessed_ranges}"
-        )
-
     def iterate_chunk_with_metadata(
         self,
         chunk,
@@ -108,63 +31,12 @@ class HFDatasetShardProcessor(ShardProcessor):
         connected: Event,
     ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
         """
-        Process HuggingFace virtual shard chunk with metadata.
+        Iterate through items in a chunk with metadata.
 
         Yields:
             Tuple of (key, url, image_data, metadata)
         """
-        if not dataset_loader:
-            logger.error("No dataset loader configured for HuggingFace dataset shard")
-            return
-
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        items_processed = 0
-        current_idx = 0
-
-        # Construct proper virtual shard URL
-        parts = chunk.shard_url.split("_chunk_")
-        if len(parts) == 2:
-            base_path = parts[0]
-            virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
-        else:
-            virtual_shard_url = chunk.shard_url
-
-        logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
-
-        # Use the new iterate method that includes metadata
-        for key, url, image_data, metadata in dataset_loader.iterate_shard_with_metadata(
-            virtual_shard_url
-        ):
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping chunk processing early due to disconnect")
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
-
-            if not in_range:
-                current_idx += 1
-                continue  # Skip already processed items
-
-            # Check if we've processed enough for this chunk
-            if current_idx >= chunk.chunk_size:
-                break
-
-            items_processed += 1
-            current_idx += 1
-            yield key, url, image_data, metadata
-
-        logger.info(
-            f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
-            f"from ranges {unprocessed_ranges}"
-        )
+        pass
 
 
 class WebDatasetShardProcessor(ShardProcessor):
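With the non-metadata iterate_chunk removed from the base class and the standalone HFDatasetShardProcessor deleted, iterate_chunk_with_metadata (whose base-class body is now just pass) becomes the single abstract entry point. A self-contained sketch of the expected generator contract, using an in-memory stand-in rather than the package's real ShardProcessor base and DatasetLoader:

from threading import Event
from typing import Any, Dict, Generator, Tuple

class InMemoryShardProcessor:
    """Stand-in; a real subclass derives from ShardProcessor and pulls items from a DatasetLoader."""

    def __init__(self, items):
        self.items = items  # list of (key, url, image_bytes, metadata) tuples

    def iterate_chunk_with_metadata(
        self, chunk, dataset_loader, should_stop: Event, connected: Event
    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
        for key, url, image_data, metadata in self.items:
            # Mirror the disconnect checks used by the real processors.
            if should_stop.is_set() or not connected.is_set():
                break
            yield key, url, image_data, metadata

should_stop, connected = Event(), Event()
connected.set()
proc = InMemoryShardProcessor([("k0", "local://0", b"\x89PNG", {"caption": "a cat"})])
for key, url, data, meta in proc.iterate_chunk_with_metadata(None, None, should_stop, connected):
    print(key, meta["caption"])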
@@ -174,74 +46,6 @@ class WebDatasetShardProcessor(ShardProcessor):
         self.hf_token = hf_token
         self.dataset_type = dataset_type
 
-    def iterate_chunk(
-        self,
-        chunk,
-        dataset_loader: Optional[DatasetLoader],
-        should_stop: Event,
-        connected: Event,
-    ) -> Generator[Tuple[str, str, bytes], None, None]:
-        """Process WebDataset shard chunk with unprocessed ranges."""
-        # Get unprocessed ranges
-        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
-
-        logger.info(
-            f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
-        )
-
-        # Create WebDataset pipeline
-        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
-            # Use curl with auth for HuggingFace WebDataset
-            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(url_cmd),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-        else:
-            # Local file
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(chunk.shard_url),
-                wds.tarfile_to_samples(),
-                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
-            )
-
-        # Process items
-        current_idx = 0
-        items_yielded = 0
-
-        for key, image_data in ds:
-            # Check if we should stop
-            if should_stop.is_set() or not connected.is_set():
-                logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
-                break
-
-            # Calculate relative index within chunk
-            relative_idx = current_idx - chunk.start_index
-
-            # Skip items before chunk start
-            if current_idx < chunk.start_index:
-                current_idx += 1
-                continue
-
-            # Stop if beyond chunk
-            if relative_idx >= chunk.chunk_size:
-                break
-
-            # Check if current index is in any unprocessed range
-            in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
-
-            if in_range:
-                items_yielded += 1
-                yield key, chunk.shard_url, image_data
-
-            current_idx += 1
-
-        logger.info(
-            f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
-            f"from ranges {unprocessed_ranges}"
-        )
-
     def iterate_chunk_with_metadata(
         self,
         chunk,
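Both removed iterate_chunk variants filtered items with the same inclusive-range membership test that the surviving metadata methods still use for resumable chunks. A small sketch of that pattern; the range values are illustrative:

def in_unprocessed_range(idx: int, unprocessed_ranges) -> bool:
    # Ranges are inclusive (start, end) pairs, e.g. [(0, 99), (150, 199)].
    return any(start <= idx <= end for start, end in unprocessed_ranges)

ranges = [(0, 2), (5, 6)]
kept = [i for i in range(8) if in_unprocessed_range(i, ranges)]
print(kept)  # [0, 1, 2, 5, 6]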
@@ -258,7 +62,7 @@ class WebDatasetShardProcessor(ShardProcessor):
         )
 
         # Create WebDataset pipeline
-        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
+        if self.dataset_type == "huggingface":
             # Use curl with auth for HuggingFace WebDataset
             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
             ds = wds.DataPipeline(
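Dropping the startswith("hf_dataset:") guard means every shard of a "huggingface"-type dataset now goes through the authenticated curl pipe; the else branch remains for local tar files. For reference, a pipeline for a local shard built with the same webdataset calls as the code above (the file name is hypothetical):

import webdataset as wds

# Local tar shard; "shard-000000.tar" is a placeholder path.
ds = wds.DataPipeline(
    wds.SimpleShardList("shard-000000.tar"),
    wds.tarfile_to_samples(),
    wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
)

for key, image_data in ds:
    print(key, len(image_data))  # raw encoded bytes; decoding is left to the caller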
@@ -61,11 +61,7 @@ class ShardTracker(CheckpointTracker):
         """Get list of shards that still need processing."""
         remaining = []
         for s in all_shards:
-            # Extract shard name properly for both regular and virtual shards
-            if s.startswith("hf_dataset:"):
-                shard_name = s  # Use full virtual shard ID
-            else:
-                shard_name = Path(s).stem
+            shard_name = Path(s).stem
 
             if shard_name not in self.completed_shards:
                 remaining.append(s)
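Shard names are now always derived with Path(s).stem, so virtual "hf_dataset:" identifiers are no longer special-cased; completed-shard bookkeeping presumably has to store the same normalized form. A quick illustration of what pathlib produces for both kinds of identifier (the example strings are hypothetical):

from pathlib import Path

print(Path("data/shards/shard-000123.tar").stem)        # shard-000123
print(Path("hf_dataset:user/name:chunk:0").stem)        # name:chunk:0 -- the prefix is dropped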
@@ -74,7 +74,7 @@ class BaseWorker(ABC):
                 await self._connect_and_run()
                 reconnect_delay = 5  # Reset delay on successful connection
             except Exception as e:
-                logger.error(f"Connection error: {e}")
+                logger.error(f"Connection error: {e}", exc_info=True)
                 self.connected.clear()
                 self.websocket = None
 
@@ -159,13 +159,13 @@ class BaseWorker(ABC):
                 except json.JSONDecodeError as e:
                     logger.error(f"Invalid message format: {e}")
                 except Exception as e:
-                    logger.error(f"Error handling message: {e}")
+                    logger.error(f"Error handling message: {e}", exc_info=True)
 
         except websockets.exceptions.ConnectionClosed as e:
             logger.info(f"Connection closed by orchestrator: {e}")
             raise
         except Exception as e:
-            logger.error(f"Message handler error: {e}")
+            logger.error(f"Message handler error: {e}", exc_info=True)
             raise
 
     async def shutdown(self):