InvokeAI 6.10.0rc1 → 6.11.0 (py3-none-any wheel)
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +77 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/flux_model_loader.py +2 -5
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +244 -84
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +58 -2
- invokeai/backend/flux/denoise.py +338 -118
- invokeai/backend/flux/dype/__init__.py +31 -0
- invokeai/backend/flux/dype/base.py +260 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +148 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +280 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
- invokeai/backend/model_manager/starter_models.py +141 -4
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
- invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +85 -6
- invokeai/frontend/web/dist/locales/it.json +135 -15
- invokeai/frontend/web/dist/locales/ru.json +11 -11
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/backend/flux2/ref_image_extension.py (new file)
@@ -0,0 +1,294 @@
"""FLUX.2 Klein Reference Image Extension for multi-reference image editing.

This module provides the Flux2RefImageExtension for FLUX.2 Klein models,
which handles encoding reference images using the FLUX.2 VAE and
generating the appropriate position IDs for multi-reference image editing.

FLUX.2 Klein has built-in support for reference image editing (unlike FLUX.1
which requires a separate Kontext model).
"""

import math

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import repeat
from PIL import Image

from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux2.sampling_utils import pack_flux2
from invokeai.backend.util.devices import TorchDevice

# Maximum pixel counts for reference images (matches BFL FLUX.2 sampling.py)
# Single reference image: 2024² pixels, Multiple: 1024² pixels
MAX_PIXELS_SINGLE_REF = 2024**2  # ~4.1M pixels
MAX_PIXELS_MULTI_REF = 1024**2  # ~1M pixels


def resize_image_to_max_pixels(image: Image.Image, max_pixels: int) -> Image.Image:
    """Resize image to fit within max_pixels while preserving aspect ratio.

    This matches the BFL FLUX.2 sampling.py cap_pixels() behavior.

    Args:
        image: PIL Image to resize.
        max_pixels: Maximum total pixel count (width * height).

    Returns:
        Resized PIL Image (or original if already within bounds).
    """
    width, height = image.size
    pixel_count = width * height

    if pixel_count <= max_pixels:
        return image

    # Calculate scale factor to fit within max_pixels (BFL approach)
    scale = math.sqrt(max_pixels / pixel_count)
    new_width = int(width * scale)
    new_height = int(height * scale)

    # Ensure dimensions are at least 1
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def generate_img_ids_flux2_with_offset(
    latent_height: int,
    latent_width: int,
    batch_size: int,
    device: torch.device,
    idx_offset: int = 0,
    h_offset: int = 0,
    w_offset: int = 0,
) -> torch.Tensor:
    """Generate tensor of image position ids with optional offsets for FLUX.2.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
    Position IDs use int64 (long) dtype.

    Args:
        latent_height: Height of image in latent space (before packing).
        latent_width: Width of image in latent space (before packing).
        batch_size: Number of images in the batch.
        device: Device to create tensors on.
        idx_offset: Offset for T (time/index) coordinate - use 1 for reference images.
        h_offset: Spatial offset for H coordinate in latent space.
        w_offset: Spatial offset for W coordinate in latent space.

    Returns:
        Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 4].
    """
    # After packing, the spatial dimensions are halved due to the 2x2 patch structure
    packed_height = latent_height // 2
    packed_width = latent_width // 2

    # Convert spatial offsets from latent space to packed space
    packed_h_offset = h_offset // 2
    packed_w_offset = w_offset // 2

    # Create base tensor for position IDs with shape [packed_height, packed_width, 4]
    # The 4 channels represent: [T, H, W, L]
    img_ids = torch.zeros(packed_height, packed_width, 4, device=device, dtype=torch.long)

    # Set T (time/index offset) for all positions - use 1 for reference images
    img_ids[..., 0] = idx_offset

    # Set H (height/y) coordinates with offset
    h_coords = torch.arange(packed_height, device=device, dtype=torch.long) + packed_h_offset
    img_ids[..., 1] = h_coords[:, None]

    # Set W (width/x) coordinates with offset
    w_coords = torch.arange(packed_width, device=device, dtype=torch.long) + packed_w_offset
    img_ids[..., 2] = w_coords[None, :]

    # L (layer) coordinate stays 0

    # Expand to include batch dimension: [batch_size, (packed_height * packed_width), 4]
    img_ids = img_ids.reshape(1, packed_height * packed_width, 4)
    img_ids = repeat(img_ids, "1 s c -> b s c", b=batch_size)

    return img_ids


class Flux2RefImageExtension:
    """Applies FLUX.2 Klein reference image conditioning.

    This extension handles encoding reference images using the FLUX.2 VAE
    and generating the appropriate 4D position IDs for multi-reference image editing.

    FLUX.2 Klein has built-in support for reference image editing, unlike FLUX.1
    which requires a separate Kontext model.
    """

    def __init__(
        self,
        ref_image_conditioning: list[FluxKontextConditioningField],
        context: InvocationContext,
        vae_field: VAEField,
        device: torch.device,
        dtype: torch.dtype,
        bn_mean: torch.Tensor | None = None,
        bn_std: torch.Tensor | None = None,
    ):
        """Initialize the Flux2RefImageExtension.

        Args:
            ref_image_conditioning: List of reference image conditioning fields.
            context: The invocation context for loading models and images.
            vae_field: The FLUX.2 VAE field for encoding images.
            device: Target device for tensors.
            dtype: Target dtype for tensors.
            bn_mean: BN running mean for normalizing latents (shape: 128).
            bn_std: BN running std for normalizing latents (shape: 128).
        """
        self._context = context
        self._device = device
        self._dtype = dtype
        self._vae_field = vae_field
        self._bn_mean = bn_mean
        self._bn_std = bn_std
        self.ref_image_conditioning = ref_image_conditioning

        # Pre-process and cache the reference image latents and ids upon initialization
        self.ref_image_latents, self.ref_image_ids = self._prepare_ref_images()

    def _bn_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Apply BN normalization to packed latents.

        BN formula (affine=False): y = (x - mean) / std

        Args:
            x: Packed latents of shape (B, seq, 128).

        Returns:
            Normalized latents of same shape.
        """
        assert self._bn_mean is not None and self._bn_std is not None
        bn_mean = self._bn_mean.to(x.device, x.dtype)
        bn_std = self._bn_std.to(x.device, x.dtype)
        return (x - bn_mean) / bn_std

    def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode reference images and prepare their concatenated latents and IDs with spatial tiling."""
        all_latents = []
        all_ids = []

        # Track cumulative dimensions for spatial tiling
        canvas_h = 0
        canvas_w = 0

        vae_info = self._context.models.load(self._vae_field.vae)

        # Determine max pixels based on number of reference images (BFL FLUX.2 approach)
        num_refs = len(self.ref_image_conditioning)
        max_pixels = MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF

        for idx, ref_image_field in enumerate(self.ref_image_conditioning):
            image = self._context.images.get_pil(ref_image_field.image.image_name)
            image = image.convert("RGB")

            # Resize large images to max pixel count (matches BFL FLUX.2 sampling.py)
            image = resize_image_to_max_pixels(image, max_pixels)

            # Convert to tensor using torchvision transforms
            transformation = T.Compose([T.ToTensor()])
            image_tensor = transformation(image)
            # Convert from [0, 1] to [-1, 1] range expected by VAE
            image_tensor = image_tensor * 2.0 - 1.0
            image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension

            # Encode using FLUX.2 VAE
            with vae_info.model_on_device() as (_, vae):
                vae_dtype = next(iter(vae.parameters())).dtype
                image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

                # FLUX.2 VAE uses diffusers API
                latent_dist = vae.encode(image_tensor, return_dict=False)[0]

                # Use mode() for deterministic encoding (no sampling)
                if hasattr(latent_dist, "mode"):
                    ref_image_latents_unpacked = latent_dist.mode()
                elif hasattr(latent_dist, "sample"):
                    ref_image_latents_unpacked = latent_dist.sample()
                else:
                    ref_image_latents_unpacked = latent_dist

            TorchDevice.empty_cache()

            # Extract tensor dimensions (B, 32, H, W for FLUX.2)
            batch_size, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pad latents to be compatible with patch_size=2
            pad_h = (2 - latent_height % 2) % 2
            pad_w = (2 - latent_width % 2) % 2
            if pad_h > 0 or pad_w > 0:
                ref_image_latents_unpacked = F.pad(ref_image_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
                _, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pack the latents using FLUX.2 pack function (32 channels -> 128)
            ref_image_latents_packed = pack_flux2(ref_image_latents_unpacked).to(self._device, self._dtype)

            # Apply BN normalization to match the input latents scale
            # This is critical - the transformer expects normalized latents
            if self._bn_mean is not None and self._bn_std is not None:
                ref_image_latents_packed = self._bn_normalize(ref_image_latents_packed)

            # Determine spatial offsets for this reference image
            h_offset = 0
            w_offset = 0

            if idx > 0:  # First image starts at (0, 0)
                # Calculate potential canvas dimensions for each tiling option
                potential_h_vertical = canvas_h + latent_height
                potential_w_horizontal = canvas_w + latent_width

                # Choose arrangement that minimizes the maximum dimension
                if potential_h_vertical > potential_w_horizontal:
                    # Tile horizontally (to the right)
                    w_offset = canvas_w
                    canvas_w = canvas_w + latent_width
                    canvas_h = max(canvas_h, latent_height)
                else:
                    # Tile vertically (below)
                    h_offset = canvas_h
                    canvas_h = canvas_h + latent_height
                    canvas_w = max(canvas_w, latent_width)
            else:
                canvas_h = latent_height
                canvas_w = latent_width

            # Generate position IDs with 4D format (T, H, W, L)
            # Use T-coordinate offset with scale=10 like diffusers Flux2Pipeline:
            # T = scale + scale * idx (so first ref image is T=10, second is T=20, etc.)
            # The generated image uses T=0, so this clearly separates reference images
            t_offset = 10 + 10 * idx  # scale=10 matches diffusers
            ref_image_ids = generate_img_ids_flux2_with_offset(
                latent_height=latent_height,
                latent_width=latent_width,
                batch_size=batch_size,
                device=self._device,
                idx_offset=t_offset,  # Reference images use T=10, 20, 30...
                h_offset=h_offset,
                w_offset=w_offset,
            )

            all_latents.append(ref_image_latents_packed)
            all_ids.append(ref_image_ids)

        # Concatenate all latents and IDs along the sequence dimension
        concatenated_latents = torch.cat(all_latents, dim=1)
        concatenated_ids = torch.cat(all_ids, dim=1)

        return concatenated_latents, concatenated_ids

    def ensure_batch_size(self, target_batch_size: int) -> None:
        """Ensure the reference image latents and IDs match the target batch size."""
        if self.ref_image_latents.shape[0] != target_batch_size:
            self.ref_image_latents = self.ref_image_latents.repeat(target_batch_size, 1, 1)
            self.ref_image_ids = self.ref_image_ids.repeat(target_batch_size, 1, 1)
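To make the tiling behavior concrete, here is a minimal sketch that reproduces the greedy canvas layout from _prepare_ref_images for two reference images and builds their position IDs with generate_img_ids_flux2_with_offset. It assumes InvokeAI 6.11.0 is installed, uses hypothetical latent sizes, and is an illustration of the scheme above, not part of the released code.

import torch

from invokeai.backend.flux2.ref_image_extension import generate_img_ids_flux2_with_offset

device = torch.device("cpu")
refs = [(96, 64), (96, 64)]  # hypothetical (latent_height, latent_width) per reference image

canvas_h = canvas_w = 0
all_ids = []
for idx, (lh, lw) in enumerate(refs):
    h_offset = w_offset = 0
    if idx == 0:
        canvas_h, canvas_w = lh, lw
    elif canvas_h + lh > canvas_w + lw:
        # Stacking below would make the canvas taller than widening it, so tile to the right.
        w_offset = canvas_w
        canvas_w += lw
        canvas_h = max(canvas_h, lh)
    else:
        # Otherwise tile below.
        h_offset = canvas_h
        canvas_h += lh
        canvas_w = max(canvas_w, lw)

    all_ids.append(
        generate_img_ids_flux2_with_offset(
            latent_height=lh,
            latent_width=lw,
            batch_size=1,
            device=device,
            idx_offset=10 + 10 * idx,  # reference images get T=10, 20, ...; the generated image keeps T=0
            h_offset=h_offset,
            w_offset=w_offset,
        )
    )

ids = torch.cat(all_ids, dim=1)
# Two tall 96x64 latents tile side by side: 2 * (48 * 32) = 3072 tokens in total.
print(ids.shape)             # torch.Size([1, 3072, 4])
print(ids[..., 0].unique())  # tensor([10, 20])
print(ids[..., 2].max())     # 63: the second image's W coordinates start at packed offset 32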
invokeai/backend/flux2/sampling_utils.py (new file)
@@ -0,0 +1,209 @@
"""FLUX.2 Klein Sampling Utilities.

FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel VAE
used by FLUX.1. This module provides sampling utilities adapted for FLUX.2.
"""

import math

import torch
from einops import rearrange


def get_noise_flux2(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
) -> torch.Tensor:
    """Generate noise for FLUX.2 Klein (32 channels).

    FLUX.2 uses a 32-channel VAE, so noise must have 32 channels.
    The spatial dimensions are calculated to allow for packing.

    Args:
        num_samples: Batch size.
        height: Target image height in pixels.
        width: Target image width in pixels.
        device: Target device.
        dtype: Target dtype.
        seed: Random seed.

    Returns:
        Noise tensor of shape (num_samples, 32, latent_h, latent_w).
    """
    # We always generate noise on the same device and dtype then cast to ensure consistency.
    rand_device = "cpu"
    rand_dtype = torch.float16

    # FLUX.2 uses 32 latent channels
    # Latent dimensions: height/8, width/8 (from VAE downsampling)
    # Must be divisible by 2 for packing (patchify step)
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    return torch.randn(
        num_samples,
        32,  # FLUX.2 uses 32 latent channels (vs 16 for FLUX.1)
        latent_h,
        latent_w,
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)


def pack_flux2(x: torch.Tensor) -> torch.Tensor:
    """Pack latent image to flattened array of patch embeddings for FLUX.2.

    This performs the patchify + pack operation in one step:
    1. Patchify: Group 2x2 spatial patches into channels (C*4)
    2. Pack: Flatten spatial dimensions to sequence

    For 32-channel input: (B, 32, H, W) -> (B, H/2*W/2, 128)

    Args:
        x: Latent tensor of shape (B, 32, H, W).

    Returns:
        Packed tensor of shape (B, H/2*W/2, 128).
    """
    # Same operation as FLUX.1 pack, but input has 32 channels -> output has 128
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


def unpack_flux2(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Unpack flat array of patch embeddings back to latent image for FLUX.2.

    This reverses the pack_flux2 operation:
    1. Unpack: Restore spatial dimensions from sequence
    2. Unpatchify: Restore 32 channels from 128

    Args:
        x: Packed tensor of shape (B, H/2*W/2, 128).
        height: Target image height in pixels.
        width: Target image width in pixels.

    Returns:
        Latent tensor of shape (B, 32, H, W).
    """
    # Calculate latent dimensions
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    # Packed dimensions (after patchify)
    packed_h = latent_h // 2
    packed_w = latent_w // 2

    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=packed_h,
        w=packed_w,
        ph=2,
        pw=2,
    )


def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute empirical mu for FLUX.2 schedule shifting.

    This matches the diffusers Flux2Pipeline implementation.
    The mu value controls how much the schedule is shifted towards higher timesteps.

    Args:
        image_seq_len: Number of image tokens (packed_h * packed_w).
        num_steps: Number of denoising steps.

    Returns:
        The empirical mu value.
    """
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    if image_seq_len > 4300:
        mu = a2 * image_seq_len + b2
        return float(mu)

    m_200 = a2 * image_seq_len + b2
    m_10 = a1 * image_seq_len + b1

    a = (m_200 - m_10) / 190.0
    b = m_200 - 200.0 * a
    mu = a * num_steps + b

    return float(mu)


def get_schedule_flux2(
    num_steps: int,
    image_seq_len: int,
) -> list[float]:
    """Get linear timestep schedule for FLUX.2.

    Returns a linear sigma schedule from 1.0 to 1/num_steps.
    The actual schedule shifting is handled by the FlowMatchEulerDiscreteScheduler
    using the mu parameter and use_dynamic_shifting=True.

    Args:
        num_steps: Number of denoising steps.
        image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused,
            but kept for API compatibility. The scheduler computes shifting internally.

    Returns:
        List of linear sigmas from 1.0 to 1/num_steps, plus final 0.0.
    """
    import numpy as np

    # Create linear sigmas from 1.0 to 1/num_steps
    # The scheduler will apply dynamic shifting using mu parameter
    sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
    sigmas_list = [float(s) for s in sigmas]

    # Add final 0.0 for the last step (scheduler needs n+1 timesteps for n steps)
    sigmas_list.append(0.0)

    return sigmas_list


def generate_img_ids_flux2(h: int, w: int, batch_size: int, device: torch.device) -> torch.Tensor:
    """Generate tensor of image position ids for FLUX.2.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
    This is different from FLUX.1 which uses 3D coordinates.

    IMPORTANT: Position IDs must use int64 (long) dtype like diffusers, not bfloat16.
    Using floating point dtype for position IDs can cause NaN in rotary embeddings.

    Args:
        h: Height of image in latent space.
        w: Width of image in latent space.
        batch_size: Batch size.
        device: Device.

    Returns:
        Image position ids tensor of shape (batch_size, h/2*w/2, 4) with int64 dtype.
    """
    # After packing, spatial dims are h/2 x w/2
    packed_h = h // 2
    packed_w = w // 2

    # Create coordinate grids - 4D: (T, H, W, L)
    # T = time/batch index, H = height, W = width, L = layer/channel
    # Use int64 (long) dtype like diffusers
    img_ids = torch.zeros(packed_h, packed_w, 4, device=device, dtype=torch.long)

    # T (time/batch) coordinate - set to 0 (already initialized)
    # H coordinates
    img_ids[..., 1] = torch.arange(packed_h, device=device, dtype=torch.long)[:, None]
    # W coordinates
    img_ids[..., 2] = torch.arange(packed_w, device=device, dtype=torch.long)[None, :]
    # L (layer) coordinate - set to 0 (already initialized)

    # Flatten and expand for batch
    img_ids = img_ids.reshape(1, packed_h * packed_w, 4)
    img_ids = img_ids.expand(batch_size, -1, -1)

    return img_ids
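As a quick sanity check on the shape arithmetic above, the following sketch generates FLUX.2 noise for a 1024x1024 image, packs and unpacks it, and computes the empirical mu for the resulting token count. It assumes InvokeAI 6.11.0 is installed; the resolution and step count are arbitrary choices for illustration.

import torch

from invokeai.backend.flux2.sampling_utils import (
    compute_empirical_mu,
    get_noise_flux2,
    get_schedule_flux2,
    pack_flux2,
    unpack_flux2,
)

height = width = 1024
noise = get_noise_flux2(
    num_samples=1, height=height, width=width, device=torch.device("cpu"), dtype=torch.float32, seed=0
)
assert noise.shape == (1, 32, 128, 128)  # latent_h = latent_w = 2 * ceil(1024 / 16) = 128

packed = pack_flux2(noise)
assert packed.shape == (1, 4096, 128)  # (128/2) * (128/2) tokens, each 32 * 2 * 2 channels wide

# pack/unpack are exact inverses (rearrange only permutes elements).
assert torch.equal(unpack_flux2(packed, height=height, width=width), noise)

# Linear sigmas plus the trailing 0.0; shifting is applied later by the scheduler using mu.
sigmas = get_schedule_flux2(num_steps=28, image_seq_len=packed.shape[1])
assert len(sigmas) == 29 and sigmas[0] == 1.0 and sigmas[-1] == 0.0

print(compute_empirical_mu(image_seq_len=packed.shape[1], num_steps=28))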