caption-flow 0.1.0-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +65 -42
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +13 -3
- caption_flow/orchestrator.py +1049 -264
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +567 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.1.dist-info/METADATA +370 -0
- caption_flow-0.2.1.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt +0 -0
caption_flow/utils/image_processor.py
@@ -1,14 +1,20 @@
 """Image preprocessing utilities."""
 
 import asyncio
+import logging
 from concurrent.futures import ProcessPoolExecutor
+from io import BytesIO
 from pathlib import Path
-from typing import List, Any
+from typing import List, Any, Optional, Tuple, Union
 
 import numpy as np
+import requests
 from PIL import Image
 
 
+logger = logging.getLogger(__name__)
+
+
 class ImageProcessor:
     """Handles image loading and preprocessing."""
 
@@ -46,6 +52,120 @@ class ImageProcessor:
 
         return arr
 
+    @staticmethod
+    def process_image_data(img_data: Union[str, bytes, Image.Image]) -> Optional[bytes]:
+        """
+        Process various types of image data into bytes.
+
+        Args:
+            img_data: Can be a URL string, bytes, or PIL Image
+
+        Returns:
+            Image data as bytes, or None if processing failed
+        """
+        try:
+            if isinstance(img_data, str):
+                # It's a URL - download the image
+                try:
+                    # Download with timeout
+                    response = requests.get(
+                        img_data,
+                        timeout=30,
+                        headers={"User-Agent": "Mozilla/5.0 (captionflow-dataset-loader)"},
+                    )
+                    response.raise_for_status()
+                    image_data = response.content
+
+                    # Verify it's an image by trying to open it
+                    img = Image.open(BytesIO(image_data))
+                    img.verify()  # Verify it's a valid image
+
+                    return image_data
+
+                except Exception as e:
+                    logger.error(f"Failed to download image from {img_data}: {e}")
+                    return None
+
+            elif hasattr(img_data, "__class__") and "Image" in str(img_data.__class__):
+                # It's a PIL Image object
+                import io
+
+                # Save as PNG bytes
+                img_bytes = io.BytesIO()
+                # Convert to RGB
+                img_data = img_data.convert("RGB")
+                img_data.save(img_bytes, format="PNG")
+                return img_bytes.getvalue()
+
+            elif isinstance(img_data, bytes):
+                # Already bytes - validate it's an image
+                try:
+                    img = Image.open(BytesIO(img_data))
+                    img.verify()
+                    return img_data
+                except Exception as e:
+                    logger.error(f"Invalid image data: {e}")
+                    return None
+
+            else:
+                logger.warning(f"Unknown image data type: {type(img_data)}")
+                return None
+
+        except Exception as e:
+            logger.error(f"Error processing image data: {e}")
+            import traceback
+
+            logger.error(traceback.format_exc())
+            return None
+
+    @staticmethod
+    def prepare_for_inference(image: Image.Image) -> Image.Image:
+        """
+        Prepare image for inference, handling transparency and mostly black/white images.
+
+        Args:
+            image: PIL Image to prepare
+
+        Returns:
+            Prepared PIL Image
+        """
+        # Convert to RGBA to handle transparency
+        img_rgba = image.convert("RGBA")
+        rgb_img = img_rgba.convert("RGB")
+        np_img = np.array(rgb_img)
+
+        # Calculate percentage of pixels that are (0,0,0) or (255,255,255)
+        total_pixels = np_img.shape[0] * np_img.shape[1]
+        black_pixels = np.all(np_img == [0, 0, 0], axis=-1).sum()
+        white_pixels = np.all(np_img == [255, 255, 255], axis=-1).sum()
+        black_pct = black_pixels / total_pixels
+        white_pct = white_pixels / total_pixels
+
+        threshold = 0.90  # 90% threshold
+
+        is_mostly_black = black_pct >= threshold
+        is_mostly_white = white_pct >= threshold
+
+        if is_mostly_black or is_mostly_white:
+            # Replace background with opposite color for better contrast
+            bg_color = (255, 255, 255) if is_mostly_black else (0, 0, 0)
+            background = Image.new("RGB", img_rgba.size, bg_color)
+            # Use alpha channel as mask if present
+            if img_rgba.mode == "RGBA":
+                background.paste(img_rgba.convert("RGB"), mask=img_rgba.split()[3])
+            else:
+                background.paste(img_rgba.convert("RGB"))
+
+            color_type = "black" if is_mostly_black else "white"
+            pct = black_pct if is_mostly_black else white_pct
+            logger.debug(
+                f"Image is {pct*100:.1f}% {color_type}; background replaced with {bg_color}"
+            )
+
+            return background
+        else:
+            return rgb_img
+
     def shutdown(self):
         """Shutdown the executor."""
         self.executor.shutdown(wait=True)
caption_flow/utils/prompt_template.py (new file)
@@ -0,0 +1,137 @@
+"""Prompt template system for dynamic column substitution."""
+
+import re
+import logging
+from typing import Dict, Any, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class PromptTemplate:
+    """Handles prompt templates with column substitution."""
+
+    # Pattern to match {column:column_name} or {col:column_name}
+    COLUMN_PATTERN = re.compile(r"\{(?:column|col):([\w-]+)\}")
+
+    def __init__(self, template: str):
+        """
+        Initialize with a prompt template.
+
+        Args:
+            template: Prompt template string, e.g.
+                "describe this image. tags: {column:user_tags}"
+        """
+        self.template = template
+        self.required_columns = self._extract_columns()
+
+    def _extract_columns(self) -> List[str]:
+        """Extract required column names from template."""
+        matches = self.COLUMN_PATTERN.findall(self.template)
+        return list(set(matches))  # Remove duplicates
+
+    def format(self, item_data: Dict[str, Any]) -> str:
+        """
+        Format the template with actual column values.
+
+        Args:
+            item_data: Dictionary containing column values from dataset
+
+        Returns:
+            Formatted prompt string
+        """
+        prompt = self.template
+
+        # Replace all column references
+        for match in self.COLUMN_PATTERN.finditer(self.template):
+            full_match = match.group(0)  # e.g., {column:user_tags}
+            column_name = match.group(1)  # e.g., user_tags
+
+            # Get column value with fallback
+            value = item_data.get(column_name, "")
+
+            # Handle different value types
+            if value is None:
+                value = ""
+            elif isinstance(value, list):
+                # Join list items with commas
+                value = ", ".join(str(v) for v in value if v)
+            elif not isinstance(value, str):
+                value = str(value)
+
+            # Replace in prompt
+            prompt = prompt.replace(full_match, value)
+
+        return prompt.strip()
+
+    def validate_columns(self, available_columns: List[str]) -> List[str]:
+        """
+        Validate that required columns are available.
+
+        Returns:
+            List of missing column names
+        """
+        missing = []
+        for col in self.required_columns:
+            if col not in available_columns:
+                missing.append(col)
+        return missing
+
+
+class PromptTemplateManager:
+    """Manages multiple prompt templates."""
+
+    def __init__(self, prompts: List[str]):
+        """
+        Initialize with list of prompt strings (which may contain templates).
+
+        Args:
+            prompts: List of prompt strings
+        """
+        self.templates = [PromptTemplate(p) for p in prompts]
+        self._all_required_columns = None
+
+    @property
+    def required_columns(self) -> List[str]:
+        """Get all required columns across all templates."""
+        if self._all_required_columns is None:
+            cols = set()
+            for template in self.templates:
+                cols.update(template.required_columns)
+            self._all_required_columns = list(cols)
+        return self._all_required_columns
+
+    def format_all(self, item_data: Dict[str, Any]) -> List[str]:
+        """
+        Format all templates with item data.
+
+        Args:
+            item_data: Dictionary containing column values
+
+        Returns:
+            List of formatted prompts
+        """
+        formatted = []
+        for template in self.templates:
+            try:
+                prompt = template.format(item_data)
+                formatted.append(prompt)
+            except Exception as e:
+                logger.error(f"Error formatting prompt template '{template.template}': {e}")
+                # Fall back to raw template
+                formatted.append(template.template)
+
+        return formatted
+
+    def validate_all(self, available_columns: List[str]) -> Dict[str, List[str]]:
+        """
+        Validate all templates against available columns.
+
+        Returns:
+            Dict mapping template string to list of missing columns
+        """
+        issues = {}
+        for template in self.templates:
+            missing = template.validate_columns(available_columns)
+            if missing:
+                issues[template.template] = missing
+        return issues
caption_flow/utils/shard_processor.py (new file)
@@ -0,0 +1,315 @@
+"""Shard processing abstraction for different dataset types."""
+
+import io
+import logging
+import time
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Generator, Tuple, Optional, Dict, Any
+from dataclasses import dataclass
+from threading import Event
+import shlex
+
+import webdataset as wds
+from PIL import Image
+
+from .dataset_loader import DatasetLoader
+
+logger = logging.getLogger(__name__)
+
+
+class ShardProcessor(ABC):
+    """Abstract base for processing dataset shards."""
+
+    @abstractmethod
+    def iterate_chunk(
+        self,
+        chunk,
+        dataset_loader: Optional[DatasetLoader],
+        should_stop: Event,
+        connected: Event,
+    ) -> Generator[Tuple[str, str, bytes], None, None]:
+        """
+        Iterate through items in a chunk.
+
+        Yields:
+            Tuple of (key, url, image_data)
+        """
+        pass
+
+
+class HFDatasetShardProcessor(ShardProcessor):
+    """Processor for HuggingFace virtual dataset shards."""
+
+    def iterate_chunk(
+        self,
+        chunk,
+        dataset_loader: Optional[DatasetLoader],
+        should_stop: Event,
+        connected: Event,
+    ) -> Generator[Tuple[str, str, bytes], None, None]:
+        """Process HuggingFace virtual shard chunk."""
+        if not dataset_loader:
+            logger.error("No dataset loader configured for HuggingFace dataset shard")
+            return
+
+        # Get unprocessed ranges
+        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
+
+        logger.info(
+            f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+        )
+
+        items_processed = 0
+        current_idx = 0
+
+        # Construct proper virtual shard URL
+        parts = chunk.shard_url.split("_chunk_")
+        if len(parts) == 2:
+            base_path = parts[0]
+            virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
+        else:
+            virtual_shard_url = chunk.shard_url
+
+        logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
+
+        # Iterate through the virtual shard
+        for key, url, image_data in dataset_loader.iterate_shard(virtual_shard_url):
+            # Check if we should stop
+            if should_stop.is_set() or not connected.is_set():
+                logger.info(f"Stopping chunk processing early due to disconnect")
+                break
+
+            # Check if current index is in any unprocessed range
+            in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
+
+            if not in_range:
+                current_idx += 1
+                continue  # Skip already processed items
+
+            # Check if we've processed enough for this chunk
+            if current_idx >= chunk.chunk_size:
+                break
+
+            items_processed += 1
+            current_idx += 1
+            yield key, url, image_data
+
+        logger.info(
+            f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
+            f"from ranges {unprocessed_ranges}"
+        )
+
+    def iterate_chunk_with_metadata(
+        self,
+        chunk,
+        dataset_loader: Optional[DatasetLoader],
+        should_stop: Event,
+        connected: Event,
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+        """
+        Process HuggingFace virtual shard chunk with metadata.
+
+        Yields:
+            Tuple of (key, url, image_data, metadata)
+        """
+        if not dataset_loader:
+            logger.error("No dataset loader configured for HuggingFace dataset shard")
+            return
+
+        # Get unprocessed ranges
+        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
+
+        logger.info(
+            f"Processing HF dataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+        )
+
+        items_processed = 0
+        current_idx = 0
+
+        # Construct proper virtual shard URL
+        parts = chunk.shard_url.split("_chunk_")
+        if len(parts) == 2:
+            base_path = parts[0]
+            virtual_shard_url = f"{base_path}:chunk:{chunk.start_index}"
+        else:
+            virtual_shard_url = chunk.shard_url
+
+        logger.debug(f"Using virtual shard URL: {virtual_shard_url}")
+
+        # Use the new iterate method that includes metadata
+        for key, url, image_data, metadata in dataset_loader.iterate_shard_with_metadata(
+            virtual_shard_url
+        ):
+            # Check if we should stop
+            if should_stop.is_set() or not connected.is_set():
+                logger.info(f"Stopping chunk processing early due to disconnect")
+                break
+
+            # Check if current index is in any unprocessed range
+            in_range = any(start <= current_idx <= end for start, end in unprocessed_ranges)
+
+            if not in_range:
+                current_idx += 1
+                continue  # Skip already processed items
+
+            # Check if we've processed enough for this chunk
+            if current_idx >= chunk.chunk_size:
+                break
+
+            items_processed += 1
+            current_idx += 1
+            yield key, url, image_data, metadata
+
+        logger.info(
+            f"HF dataset chunk {chunk.chunk_id}: yielded {items_processed} items "
+            f"from ranges {unprocessed_ranges}"
+        )
+
+
+class WebDatasetShardProcessor(ShardProcessor):
+    """Processor for WebDataset tar shards with range support."""
+
+    def __init__(self, hf_token: Optional[str] = None, dataset_type: str = "local"):
+        self.hf_token = hf_token
+        self.dataset_type = dataset_type
+
+    def iterate_chunk(
+        self,
+        chunk,
+        dataset_loader: Optional[DatasetLoader],
+        should_stop: Event,
+        connected: Event,
+    ) -> Generator[Tuple[str, str, bytes], None, None]:
+        """Process WebDataset shard chunk with unprocessed ranges."""
+        # Get unprocessed ranges
+        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
+
+        logger.info(
+            f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+        )
+
+        # Create WebDataset pipeline
+        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
+            # Use curl with auth for HuggingFace WebDataset
+            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(url_cmd),
+                wds.tarfile_to_samples(),
+                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
+            )
+        else:
+            # Local file
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(chunk.shard_url),
+                wds.tarfile_to_samples(),
+                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
+            )
+
+        # Process items
+        current_idx = 0
+        items_yielded = 0
+
+        for key, image_data in ds:
+            # Check if we should stop
+            if should_stop.is_set() or not connected.is_set():
+                logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
+                break
+
+            # Calculate relative index within chunk
+            relative_idx = current_idx - chunk.start_index
+
+            # Skip items before chunk start
+            if current_idx < chunk.start_index:
+                current_idx += 1
+                continue
+
+            # Stop if beyond chunk
+            if relative_idx >= chunk.chunk_size:
+                break
+
+            # Check if current index is in any unprocessed range
+            in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
+
+            if in_range:
+                items_yielded += 1
+                yield key, chunk.shard_url, image_data
+
+            current_idx += 1
+
+        logger.info(
+            f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
+            f"from ranges {unprocessed_ranges}"
+        )
+
+    def iterate_chunk_with_metadata(
+        self,
+        chunk,
+        dataset_loader: Optional[DatasetLoader],
+        should_stop: Event,
+        connected: Event,
+    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
+        """Process WebDataset shard chunk with metadata and range support."""
+        # Get unprocessed ranges
+        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
+
+        logger.info(
+            f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
+        )
+
+        # Create WebDataset pipeline
+        if self.dataset_type == "huggingface" and not chunk.shard_url.startswith("hf_dataset:"):
+            # Use curl with auth for HuggingFace WebDataset
+            url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(url_cmd),
+                wds.tarfile_to_samples(),
+                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
+            )
+        else:
+            # Local file
+            ds = wds.DataPipeline(
+                wds.SimpleShardList(chunk.shard_url),
+                wds.tarfile_to_samples(),
+                wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
+            )
+
+        # Process items
+        absolute_idx = 0  # Absolute index in the shard
+        items_yielded = 0
+
+        for key, image_data in ds:
+            # Check if we should stop
+            if should_stop.is_set() or not connected.is_set():
+                logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
+                break
+
+            # Skip items before chunk start
+            if absolute_idx < chunk.start_index:
+                absolute_idx += 1
+                continue
+
+            # Calculate relative index within chunk
+            relative_idx = absolute_idx - chunk.start_index
+
+            # Stop if beyond chunk
+            if relative_idx >= chunk.chunk_size:
+                break
+
+            # Check if current index is in any unprocessed range
+            in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
+
+            if in_range:
+                # Create metadata with the relative index
+                metadata = {
+                    "_chunk_relative_index": relative_idx,
+                }
+                items_yielded += 1
+                yield key, chunk.shard_url, image_data, metadata
+
+            absolute_idx += 1
+
+        logger.info(
+            f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
+            f"from ranges {unprocessed_ranges}"
+        )
caption_flow/utils/shard_tracker.py (new file)
@@ -0,0 +1,87 @@
+"""Shard tracking using CheckpointTracker base class."""
+
+from pathlib import Path
+from typing import Dict, Any, List, Set
+
+from .checkpoint_tracker import CheckpointTracker
+
+
+class ShardTracker(CheckpointTracker):
+    """Tracks shard processing progress."""
+
+    def __init__(self, checkpoint_path: Path):
+        """Initialize shard tracker with checkpoint file."""
+        self.completed_shards: Set[str] = set()
+        self.partial_shards: Dict[str, Dict[str, Any]] = {}
+        super().__init__(checkpoint_path)
+
+    def _get_default_state(self) -> Dict[str, Any]:
+        """Return default state structure for new checkpoints."""
+        return {"completed_shards": [], "partial_shards": {}}
+
+    def _deserialize_state(self, data: Dict[str, Any]) -> None:
+        """Deserialize loaded data into instance state."""
+        self.completed_shards = set(data.get("completed_shards", []))
+        self.partial_shards = data.get("partial_shards", {})
+
+    def _serialize_state(self) -> Dict[str, Any]:
+        """Serialize instance state for saving."""
+        return {
+            "completed_shards": list(self.completed_shards),
+            "partial_shards": self.partial_shards,
+        }
+
+    def mark_complete(self, shard_name: str) -> None:
+        """Mark a shard as complete."""
+        self.completed_shards.add(shard_name)
+        if shard_name in self.partial_shards:
+            del self.partial_shards[shard_name]
+        self.save()
+
+    def update_partial(self, shard_name: str, processed_keys: List[str]) -> None:
+        """Update partial progress for a shard."""
+        self.partial_shards[shard_name] = {"keys": processed_keys, "count": len(processed_keys)}
+        self.save()
+
+    def get_processed_keys(self, shard_name: str) -> Set[str]:
+        """Get set of processed keys for a shard."""
+        if shard_name in self.completed_shards:
+            return set()  # All done
+
+        if shard_name in self.partial_shards:
+            return set(self.partial_shards[shard_name].get("keys", []))
+
+        return set()
+
+    def is_complete(self, shard_name: str) -> bool:
+        """Check if a shard is complete."""
+        return shard_name in self.completed_shards
+
+    def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
+        """Get list of shards that still need processing."""
+        remaining = []
+        for s in all_shards:
+            # Extract shard name properly for both regular and virtual shards
+            if s.startswith("hf_dataset:"):
+                shard_name = s  # Use full virtual shard ID
+            else:
+                shard_name = Path(s).stem
+
+            if shard_name not in self.completed_shards:
+                remaining.append(s)
+
+        return remaining
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get shard tracking statistics."""
+        base_stats = super().get_stats()
+        base_stats.update(
+            {
+                "completed_shards": len(self.completed_shards),
+                "partial_shards": len(self.partial_shards),
+                "total_partial_keys": sum(
+                    len(data.get("keys", [])) for data in self.partial_shards.values()
+                ),
+            }
+        )
+        return base_stats