caption-flow 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,20 @@
  """Image preprocessing utilities."""

  import asyncio
+ import logging
  from concurrent.futures import ProcessPoolExecutor
+ from io import BytesIO
  from pathlib import Path
- from typing import List, Any
+ from typing import List, Any, Optional, Tuple, Union

  import numpy as np
+ import requests
  from PIL import Image


+ logger = logging.getLogger(__name__)
+
+
  class ImageProcessor:
      """Handles image loading and preprocessing."""

@@ -46,6 +52,120 @@ class ImageProcessor:

          return arr

+     @staticmethod
+     def process_image_data(img_data: Union[str, bytes, Image.Image]) -> Optional[bytes]:
+         """
+         Process various types of image data into bytes.
+
+         Args:
+             img_data: Can be a URL string, bytes, or PIL Image
+
+         Returns:
+             Image data as bytes, or None if processing failed
+         """
+         try:
+             if isinstance(img_data, str):
+                 # It's a URL - download the image
+                 try:
+                     # Download with timeout
+                     response = requests.get(
+                         img_data,
+                         timeout=30,
+                         headers={"User-Agent": "Mozilla/5.0 (captionflow-dataset-loader)"},
+                     )
+                     response.raise_for_status()
+                     image_data = response.content
+
+                     # Verify it's an image by trying to open it
+                     img = Image.open(BytesIO(image_data))
+                     img.verify()  # Verify it's a valid image
+
+                     return image_data
+
+                 except Exception as e:
+                     logger.error(f"Failed to download image from {img_data}: {e}")
+                     return None
+
+             elif hasattr(img_data, "__class__") and "Image" in str(img_data.__class__):
+                 # It's a PIL Image object
+                 import io
+
+                 # Save as PNG bytes
+                 img_bytes = io.BytesIO()
+                 # Convert to RGB
+                 img_data = img_data.convert("RGB")
+                 img_data.save(img_bytes, format="PNG")
+                 return img_bytes.getvalue()
+
+             elif isinstance(img_data, bytes):
+                 # Already bytes - validate it's an image
+                 try:
+                     img = Image.open(BytesIO(img_data))
+                     img.verify()
+                     return img_data
+                 except Exception as e:
+                     logger.error(f"Invalid image data: {e}")
+                     return None
+
+             else:
+                 logger.warning(f"Unknown image data type: {type(img_data)}")
+                 return None
+
+         except Exception as e:
+             logger.error(f"Error processing image data: {e}")
+             import traceback
+
+             logger.error(traceback.format_exc())
+             return None
+
+     @staticmethod
+     def prepare_for_inference(image: Image.Image) -> Image.Image:
+         """
+         Prepare image for inference, handling transparency and mostly black/white images.
+
+         Args:
+             image: PIL Image to prepare
+
+         Returns:
+             Prepared PIL Image
+         """
+         # Convert to RGBA to handle transparency
+         img_rgba = image.convert("RGBA")
+         rgb_img = img_rgba.convert("RGB")
+         np_img = np.array(rgb_img)
+
+         # Calculate percentage of pixels that are (0,0,0) or (255,255,255)
+         total_pixels = np_img.shape[0] * np_img.shape[1]
+         black_pixels = np.all(np_img == [0, 0, 0], axis=-1).sum()
+         white_pixels = np.all(np_img == [255, 255, 255], axis=-1).sum()
+         black_pct = black_pixels / total_pixels
+         white_pct = white_pixels / total_pixels
+
+         threshold = 0.90  # 90% threshold
+
+         is_mostly_black = black_pct >= threshold
+         is_mostly_white = white_pct >= threshold
+
+         if is_mostly_black or is_mostly_white:
+             # Replace background with opposite color for better contrast
+             bg_color = (255, 255, 255) if is_mostly_black else (0, 0, 0)
+             background = Image.new("RGB", img_rgba.size, bg_color)
+             # Use alpha channel as mask if present
+             if img_rgba.mode == "RGBA":
+                 background.paste(img_rgba.convert("RGB"), mask=img_rgba.split()[3])
+             else:
+                 background.paste(img_rgba.convert("RGB"))
+
+             color_type = "black" if is_mostly_black else "white"
+             pct = black_pct if is_mostly_black else white_pct
+             logger.debug(
+                 f"Image is {pct*100:.1f}% {color_type}; background replaced with {bg_color}"
+             )
+
+             return background
+         else:
+             return rgb_img
+
      def shutdown(self):
          """Shutdown the executor."""
          self.executor.shutdown(wait=True)
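
The two new static methods normalize arbitrary image inputs (URL string, raw bytes, or PIL Image) into validated bytes, and composite low-contrast images onto an opposite-color background before inference. A minimal usage sketch follows; the import path caption_flow.image is an assumption, since the diff does not show the module's filename, and the URL branch (downloaded with requests) is omitted here.

# Hypothetical usage sketch; "caption_flow.image" is an assumed import path.
from io import BytesIO

from PIL import Image

from caption_flow.image import ImageProcessor  # assumed import path

# From raw bytes: validated with PIL and returned unchanged.
buf = BytesIO()
Image.new("RGB", (64, 64), (255, 255, 255)).save(buf, format="PNG")
png_bytes = ImageProcessor.process_image_data(buf.getvalue())

# From a PIL Image: converted to RGB and re-encoded as PNG bytes.
png_bytes_2 = ImageProcessor.process_image_data(Image.new("RGB", (64, 64)))

# prepare_for_inference flattens transparency and, for images that are
# >= 90% black or white, composites them onto the opposite-color background.
ready = ImageProcessor.prepare_for_inference(Image.open(BytesIO(png_bytes)))
print(ready.size, ready.mode)  # (64, 64) RGB
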
@@ -0,0 +1,137 @@
+ """Prompt template system for dynamic column substitution."""
+
+ import re
+ import logging
+ from typing import Dict, Any, List, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ class PromptTemplate:
+     """Handles prompt templates with column substitution."""
+
+     # Pattern to match {column:column_name} or {col:column_name}
+     COLUMN_PATTERN = re.compile(r"\{(?:column|col):([\w-]+)\}")
+
+     def __init__(self, template: str):
+         """
+         Initialize with a prompt template.
+
+         Args:
+             template: Prompt template string, e.g.
+                 "describe this image. tags: {column:user_tags}"
+         """
+         self.template = template
+         self.required_columns = self._extract_columns()
+
+     def _extract_columns(self) -> List[str]:
+         """Extract required column names from template."""
+         matches = self.COLUMN_PATTERN.findall(self.template)
+         return list(set(matches))  # Remove duplicates
+
+     def format(self, item_data: Dict[str, Any]) -> str:
+         """
+         Format the template with actual column values.
+
+         Args:
+             item_data: Dictionary containing column values from dataset
+
+         Returns:
+             Formatted prompt string
+         """
+         prompt = self.template
+
+         # Replace all column references
+         for match in self.COLUMN_PATTERN.finditer(self.template):
+             full_match = match.group(0)  # e.g., {column:user_tags}
+             column_name = match.group(1)  # e.g., user_tags
+
+             # Get column value with fallback
+             value = item_data.get(column_name, "")
+
+             # Handle different value types
+             if value is None:
+                 value = ""
+             elif isinstance(value, list):
+                 # Join list items with commas
+                 value = ", ".join(str(v) for v in value if v)
+             elif not isinstance(value, str):
+                 value = str(value)
+
+             # Replace in prompt
+             prompt = prompt.replace(full_match, value)
+
+         return prompt.strip()
+
+     def validate_columns(self, available_columns: List[str]) -> List[str]:
+         """
+         Validate that required columns are available.
+
+         Returns:
+             List of missing column names
+         """
+         missing = []
+         for col in self.required_columns:
+             if col not in available_columns:
+                 missing.append(col)
+         return missing
+
+
+ class PromptTemplateManager:
+     """Manages multiple prompt templates."""
+
+     def __init__(self, prompts: List[str]):
+         """
+         Initialize with list of prompt strings (which may contain templates).
+
+         Args:
+             prompts: List of prompt strings
+         """
+         self.templates = [PromptTemplate(p) for p in prompts]
+         self._all_required_columns = None
+
+     @property
+     def required_columns(self) -> List[str]:
+         """Get all required columns across all templates."""
+         if self._all_required_columns is None:
+             cols = set()
+             for template in self.templates:
+                 cols.update(template.required_columns)
+             self._all_required_columns = list(cols)
+         return self._all_required_columns
+
+     def format_all(self, item_data: Dict[str, Any]) -> List[str]:
+         """
+         Format all templates with item data.
+
+         Args:
+             item_data: Dictionary containing column values
+
+         Returns:
+             List of formatted prompts
+         """
+         formatted = []
+         for template in self.templates:
+             try:
+                 prompt = template.format(item_data)
+                 formatted.append(prompt)
+             except Exception as e:
+                 logger.error(f"Error formatting prompt template '{template.template}': {e}")
+                 # Fall back to raw template
+                 formatted.append(template.template)
+
+         return formatted
+
+     def validate_all(self, available_columns: List[str]) -> Dict[str, List[str]]:
+         """
+         Validate all templates against available columns.
+
+         Returns:
+             Dict mapping template string to list of missing columns
+         """
+         issues = {}
+         for template in self.templates:
+             missing = template.validate_columns(available_columns)
+             if missing:
+                 issues[template.template] = missing
+         return issues
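
The new template system substitutes {column:...} / {col:...} placeholders with per-item dataset values, joining lists with commas and falling back to empty strings. A brief sketch of how the two classes compose; the import path caption_flow.prompt_template is an assumption based on the module docstring.

# Hypothetical usage sketch; the import path is assumed, not shown in the diff.
from caption_flow.prompt_template import PromptTemplateManager  # assumed path

manager = PromptTemplateManager(
    [
        "describe this image. tags: {column:user_tags}",
        "write a caption in the style of {col:style}",
    ]
)

# Columns referenced anywhere in the templates (order not guaranteed).
print(sorted(manager.required_columns))  # ['style', 'user_tags']

# Missing columns are reported per template before any work starts.
print(manager.validate_all(["user_tags"]))
# {'write a caption in the style of {col:style}': ['style']}

# Lists are joined with commas; None values fall back to "".
item = {"user_tags": ["outdoor", "sunset"], "style": None}
for prompt in manager.format_all(item):
    print(prompt)
# describe this image. tags: outdoor, sunset
# write a caption in the style of
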
@@ -0,0 +1,315 @@
+ """Shard processing abstraction for different dataset types."""
+
+ import io
+ import logging
+ import time
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import Generator, Tuple, Optional, Dict, Any
+ from dataclasses import dataclass
+ from threading import Event
+ import shlex
+
+ import webdataset as wds
+ from PIL import Image
+
+ from .dataset_loader import DatasetLoader
+
+ logger = logging.getLogger(__name__)
+
+
+ class ShardProcessor(ABC):
+     """Abstract base for processing dataset shards."""
+
+     @abstractmethod
+     def iterate_chunk(
+         self,
+         chunk,
+         dataset_loader: Optional[DatasetLoader],
+         should_stop: Event,
+         connected: Event,
+     ) -> Generator[Tuple[str, str, bytes], None, None]:
+         """
+         Iterate through items in a chunk.
+
+         Yields:
+             Tuple of (key, url, image_data)
+         """
+         pass
+
+
+ class HFDatasetShardProcessor(ShardProcessor):
+     """Processor for HuggingFace virtual dataset shards."""
+
+     def iterate_chunk(
+         self,
+         chunk,
+         dataset_loader: Optional[DatasetLoader],
+         should_stop: Event,
+         connected: Event,
+     ) -> Generator[Tuple[str, str, bytes], None, None]:
+         """Process HuggingFace virtual shard chunk."""
+         if not dataset_loader:
+             logger.error("No dataset loader configured for HuggingFace dataset shard")
+             return
+
+         # Get unprocessed ranges
+         unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
+
+         logger.info(
+             f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+         )
+
+         items_processed = 0
+         current_idx = 0
+
+         # Construct proper virtual shard URL
+         parts = chunk.shard_url.split("_chunk_")
+         if len(parts) == 2:
+             base_path = parts[0]
+             virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
+         else:
+             virtual_shard_url = chunk.shard_url
+
+         logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
+
+         # Iterate through the virtual shard
+         for key, url, image_data in dataset_loader.iterate_shard(virtual_shard_url):
+             # Check if we should stop
+             if should_stop.is_set() or not connected.is_set():
+                 logger.info(f"Stopping chunk processing early due to disconnect")
+                 break
+
+             # Check if current index is in any unprocessed range
+             in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
+
+             if not in_range:
+                 current_idx += 1
+                 continue  # Skip already processed items
+
+             # Check if we've processed enough for this chunk
+             if current_idx >= chunk.chunk_size:
+                 break
+
+             items_processed += 1
+             current_idx += 1
+             yield key, url, image_data
+
+         logger.info(
+             f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
+             f"from ranges {unprocessed_ranges}"
+         )
+
+     def iterate_chunk_with_metadata(
+         self,
+         chunk,
+         dataset_loader: Optional[DatasetLoader],
+         should_stop: Event,
+         connected: Event,
+     ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+         """
+         Process HuggingFace virtual shard chunk with metadata.
+
+         Yields:
+             Tuple of (key, url, image_data, metadata)
+         """
+         if not dataset_loader:
+             logger.error("No dataset loader configured for HuggingFace dataset shard")
+             return
+
+         # Get unprocessed ranges
+         unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
+
+         logger.info(
+             f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+         )
+
+         items_processed = 0
+         current_idx = 0
+
+         # Construct proper virtual shard URL
+         parts = chunk.shard_url.split("_chunk_")
+         if len(parts) == 2:
+             base_path = parts[0]
+             virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
+         else:
+             virtual_shard_url = chunk.shard_url
+
+         logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
+
+         # Use the new iterate method that includes metadata
+         for key, url, image_data, metadata in dataset_loader.iterate_shard_with_metadata(
+             virtual_shard_url
+         ):
+             # Check if we should stop
+             if should_stop.is_set() or not connected.is_set():
+                 logger.info(f"Stopping chunk processing early due to disconnect")
+                 break
+
+             # Check if current index is in any unprocessed range
+             in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
+
+             if not in_range:
+                 current_idx += 1
+                 continue  # Skip already processed items
+
+             # Check if we've processed enough for this chunk
+             if current_idx >= chunk.chunk_size:
+                 break
+
+             items_processed += 1
+             current_idx += 1
+             yield key, url, image_data, metadata
+
+         logger.info(
+             f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
+             f"from ranges {unprocessed_ranges}"
+         )
+
+
+ class WebDatasetShardProcessor(ShardProcessor):
+     """Processor for WebDataset tar shards with range support."""
+
+     def __init__(self, hf_token: Optional[str] = None, dataset_type: str = "local"):
+         self.hf_token = hf_token
+         self.dataset_type = dataset_type
+
+     def iterate_chunk(
+         self,
+         chunk,
+         dataset_loader: Optional[DatasetLoader],
+         should_stop: Event,
+         connected: Event,
+     ) -> Generator[Tuple[str, str, bytes], None, None]:
+         """Process WebDataset shard chunk with unprocessed ranges."""
+         # Get unprocessed ranges
+         unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
+
+         logger.info(
+             f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+         )
+
+         # Create WebDataset pipeline
+         if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
+             # Use curl with auth for HuggingFace WebDataset
+             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
+             ds = wds.DataPipeline(
+                 wds.SimpleShardList(url_cmd),
+                 wds.tarfile_to_samples(),
+                 wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
+             )
+         else:
+             # Local file
+             ds = wds.DataPipeline(
+                 wds.SimpleShardList(chunk.shard_url),
+                 wds.tarfile_to_samples(),
+                 wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
+             )
+
+         # Process items
+         current_idx = 0
+         items_yielded = 0
+
+         for key, image_data in ds:
+             # Check if we should stop
+             if should_stop.is_set() or not connected.is_set():
+                 logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
+                 break
+
+             # Calculate relative index within chunk
+             relative_idx = current_idx - chunk.start_index
+
+             # Skip items before chunk start
+             if current_idx < chunk.start_index:
+                 current_idx += 1
+                 continue
+
+             # Stop if beyond chunk
+             if relative_idx >= chunk.chunk_size:
+                 break
+
+             # Check if current index is in any unprocessed range
+             in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
+
+             if in_range:
+                 items_yielded += 1
+                 yield key, chunk.shard_url, image_data
+
+             current_idx += 1
+
+         logger.info(
+             f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
+             f"from ranges {unprocessed_ranges}"
+         )
+
+     def iterate_chunk_with_metadata(
+         self,
+         chunk,
+         dataset_loader: Optional[DatasetLoader],
+         should_stop: Event,
+         connected: Event,
+     ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+         """Process WebDataset shard chunk with metadata and range support."""
+         # Get unprocessed ranges
+         unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
+
+         logger.info(
+             f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+         )
+
+         # Create WebDataset pipeline
+         if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
+             # Use curl with auth for HuggingFace WebDataset
+             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
+             ds = wds.DataPipeline(
+                 wds.SimpleShardList(url_cmd),
+                 wds.tarfile_to_samples(),
+                 wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
+             )
+         else:
+             # Local file
+             ds = wds.DataPipeline(
+                 wds.SimpleShardList(chunk.shard_url),
+                 wds.tarfile_to_samples(),
+                 wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
+             )
+
+         # Process items
+         absolute_idx = 0  # Absolute index in the shard
+         items_yielded = 0
+
+         for key, image_data in ds:
+             # Check if we should stop
+             if should_stop.is_set() or not connected.is_set():
+                 logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
+                 break
+
+             # Skip items before chunk start
+             if absolute_idx < chunk.start_index:
+                 absolute_idx += 1
+                 continue
+
+             # Calculate relative index within chunk
+             relative_idx = absolute_idx - chunk.start_index
+
+             # Stop if beyond chunk
+             if relative_idx >= chunk.chunk_size:
+                 break
+
+             # Check if current index is in any unprocessed range
+             in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
+
+             if in_range:
+                 # Create metadata with the relative index
+                 metadata = {
+                     "_chunk_relative_index": relative_idx,
+                 }
+                 items_yielded += 1
+                 yield key, chunk.shard_url, image_data, metadata
+
+             absolute_idx += 1
+
+         logger.info(
+             f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
+             f"from ranges {unprocessed_ranges}"
+         )
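
Both processors treat chunk as a duck-typed work unit carrying chunk_id, shard_url, start_index, chunk_size, and optionally unprocessed_ranges, and yield only the items whose chunk-relative index falls inside an unprocessed range. A driver sketch for the local WebDataset case follows; the WorkChunk dataclass, the shard path, and the caption_flow.shard_processor import path are illustrative assumptions, since the real chunk object and module layout are defined elsewhere in the package.

# Hypothetical driver sketch; WorkChunk and the paths below are stand-ins.
from dataclasses import dataclass, field
from threading import Event
from typing import List, Tuple

from caption_flow.shard_processor import WebDatasetShardProcessor  # assumed path


@dataclass
class WorkChunk:  # illustrative stand-in for the package's chunk object
    chunk_id: str
    shard_url: str
    start_index: int
    chunk_size: int
    unprocessed_ranges: List[Tuple[int, int]] = field(default_factory=list)


chunk = WorkChunk(
    chunk_id="shard-00000:0",
    shard_url="/data/shard-00000.tar",  # local WebDataset tar, assumed to exist
    start_index=0,
    chunk_size=100,
    unprocessed_ranges=[(0, 9), (50, 59)],  # only these relative indices are yielded
)

should_stop, connected = Event(), Event()
connected.set()  # worker considers itself connected

processor = WebDatasetShardProcessor(dataset_type="local")
for key, shard_url, image_bytes in processor.iterate_chunk(
    chunk, dataset_loader=None, should_stop=should_stop, connected=connected
):
    print(key, len(image_bytes))
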
@@ -0,0 +1,87 @@
+ """Shard tracking using CheckpointTracker base class."""
+
+ from pathlib import Path
+ from typing import Dict, Any, List, Set
+
+ from .checkpoint_tracker import CheckpointTracker
+
+
+ class ShardTracker(CheckpointTracker):
+     """Tracks shard processing progress."""
+
+     def __init__(self, checkpoint_path: Path):
+         """Initialize shard tracker with checkpoint file."""
+         self.completed_shards: Set[str] = set()
+         self.partial_shards: Dict[str, Dict[str, Any]] = {}
+         super().__init__(checkpoint_path)
+
+     def _get_default_state(self) -> Dict[str, Any]:
+         """Return default state structure for new checkpoints."""
+         return {"completed_shards": [], "partial_shards": {}}
+
+     def _deserialize_state(self, data: Dict[str, Any]) -> None:
+         """Deserialize loaded data into instance state."""
+         self.completed_shards = set(data.get("completed_shards", []))
+         self.partial_shards = data.get("partial_shards", {})
+
+     def _serialize_state(self) -> Dict[str, Any]:
+         """Serialize instance state for saving."""
+         return {
+             "completed_shards": list(self.completed_shards),
+             "partial_shards": self.partial_shards,
+         }
+
+     def mark_complete(self, shard_name: str) -> None:
+         """Mark a shard as complete."""
+         self.completed_shards.add(shard_name)
+         if shard_name in self.partial_shards:
+             del self.partial_shards[shard_name]
+         self.save()
+
+     def update_partial(self, shard_name: str, processed_keys: List[str]) -> None:
+         """Update partial progress for a shard."""
+         self.partial_shards[shard_name] = {"keys": processed_keys, "count": len(processed_keys)}
+         self.save()
+
+     def get_processed_keys(self, shard_name: str) -> Set[str]:
+         """Get set of processed keys for a shard."""
+         if shard_name in self.completed_shards:
+             return set()  # All done
+
+         if shard_name in self.partial_shards:
+             return set(self.partial_shards[shard_name].get("keys", []))
+
+         return set()
+
+     def is_complete(self, shard_name: str) -> bool:
+         """Check if a shard is complete."""
+         return shard_name in self.completed_shards
+
+     def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
+         """Get list of shards that still need processing."""
+         remaining = []
+         for s in all_shards:
+             # Extract shard name properly for both regular and virtual shards
+             if s.startswith("hf_dataset:"):
+                 shard_name = s  # Use full virtual shard ID
+             else:
+                 shard_name = Path(s).stem
+
+             if shard_name not in self.completed_shards:
+                 remaining.append(s)
+
+         return remaining
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get shard tracking statistics."""
+         base_stats = super().get_stats()
+         base_stats.update(
+             {
+                 "completed_shards": len(self.completed_shards),
+                 "partial_shards": len(self.partial_shards),
+                 "total_partial_keys": sum(
+                     len(data.get("keys", [])) for data in self.partial_shards.values()
+                 ),
+             }
+         )
+         return base_stats
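
ShardTracker persists progress through the CheckpointTracker base class, which is not shown in this diff and is assumed here to load state on init and provide save() and get_stats(). A hedged usage sketch, with the import path and checkpoint location as illustrative assumptions:

# Hypothetical usage sketch; the import path and checkpoint path are assumed.
from pathlib import Path

from caption_flow.shard_tracker import ShardTracker  # assumed module path

tracker = ShardTracker(Path("checkpoints/shards.json"))

# Record partial progress, then mark the shard done.
tracker.update_partial("shard-00000", ["sample-0001", "sample-0002"])
print(tracker.get_processed_keys("shard-00000"))  # {'sample-0001', 'sample-0002'}

tracker.mark_complete("shard-00000")
print(tracker.is_complete("shard-00000"))  # True

# Virtual HF shards keep their full ID; tar shards are matched by file stem.
todo = tracker.get_remaining_shards(
    ["/data/shard-00000.tar", "/data/shard-00001.tar", "hf_dataset:foo:chunk:0"]
)
print(todo)  # ['/data/shard-00001.tar', 'hf_dataset:foo:chunk:0']

print(tracker.get_stats())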