InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,294 @@
1
+ """FLUX.2 Klein Reference Image Extension for multi-reference image editing.
2
+
3
+ This module provides the Flux2RefImageExtension for FLUX.2 Klein models,
4
+ which handles encoding reference images using the FLUX.2 VAE and
5
+ generating the appropriate position IDs for multi-reference image editing.
6
+
7
+ FLUX.2 Klein has built-in support for reference image editing (unlike FLUX.1
8
+ which requires a separate Kontext model).
9
+ """
10
+
11
+ import math
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torchvision.transforms as T
16
+ from einops import repeat
17
+ from PIL import Image
18
+
19
+ from invokeai.app.invocations.fields import FluxKontextConditioningField
20
+ from invokeai.app.invocations.model import VAEField
21
+ from invokeai.app.services.shared.invocation_context import InvocationContext
22
+ from invokeai.backend.flux2.sampling_utils import pack_flux2
23
+ from invokeai.backend.util.devices import TorchDevice
24
+
25
+ # Maximum pixel counts for reference images (matches BFL FLUX.2 sampling.py)
26
+ # Single reference image: 2024² pixels, Multiple: 1024² pixels
27
+ MAX_PIXELS_SINGLE_REF = 2024**2 # ~4.1M pixels
28
+ MAX_PIXELS_MULTI_REF = 1024**2 # ~1M pixels
29
+
30
+
31
def resize_image_to_max_pixels(image: Image.Image, max_pixels: int) -> Image.Image:
    """Downscale an image so its total pixel count fits within max_pixels.

    Aspect ratio is preserved; images already within the budget are returned
    unchanged. Matches the BFL FLUX.2 sampling.py cap_pixels() behavior.

    Args:
        image: PIL Image to resize.
        max_pixels: Maximum total pixel count (width * height).

    Returns:
        The resized PIL Image, or the original if it is already small enough.
    """
    width, height = image.size
    total_pixels = width * height

    if total_pixels <= max_pixels:
        return image

    # Uniform scale factor so that (w*s) * (h*s) <= max_pixels (BFL approach).
    factor = math.sqrt(max_pixels / total_pixels)
    # Clamp to at least 1 px per side so degenerate inputs stay valid.
    target_size = (max(1, int(width * factor)), max(1, int(height * factor)))

    return image.resize(target_size, Image.Resampling.LANCZOS)
59
+
60
+
61
def generate_img_ids_flux2_with_offset(
    latent_height: int,
    latent_width: int,
    batch_size: int,
    device: torch.device,
    idx_offset: int = 0,
    h_offset: int = 0,
    w_offset: int = 0,
) -> torch.Tensor:
    """Generate FLUX.2 image position ids with optional T/H/W offsets.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position
    embeddings. Position ids must be int64 (long); floating-point ids can
    produce NaN in the rotary embeddings.

    Args:
        latent_height: Height of the image in latent space (before packing).
        latent_width: Width of the image in latent space (before packing).
        batch_size: Number of images in the batch.
        device: Device to create tensors on.
        idx_offset: Value for the T (time/index) coordinate — reference images
            use a non-zero value to separate them from the generated image (T=0).
        h_offset: Spatial offset for the H coordinate, in latent space.
        w_offset: Spatial offset for the W coordinate, in latent space.

    Returns:
        Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 4].
    """
    # After packing, the spatial dimensions are halved due to the 2x2 patch structure.
    packed_height = latent_height // 2
    packed_width = latent_width // 2

    # Offsets are given in latent space; convert to packed (token) space.
    packed_h_offset = h_offset // 2
    packed_w_offset = w_offset // 2

    # Base tensor [packed_height, packed_width, 4]; channels are (T, H, W, L).
    img_ids = torch.zeros(packed_height, packed_width, 4, device=device, dtype=torch.long)

    # T (time/index): constant over the whole image.
    img_ids[..., 0] = idx_offset

    # H (row) coordinates, shifted by the packed offset.
    h_coords = torch.arange(packed_height, device=device, dtype=torch.long) + packed_h_offset
    img_ids[..., 1] = h_coords[:, None]

    # W (column) coordinates, shifted by the packed offset.
    w_coords = torch.arange(packed_width, device=device, dtype=torch.long) + packed_w_offset
    img_ids[..., 2] = w_coords[None, :]

    # L (layer) coordinate stays 0.

    # Flatten the grid to a token sequence and broadcast over the batch.
    # Use expand (a broadcast view) instead of einops.repeat for consistency
    # with generate_img_ids_flux2 and to avoid materializing per-batch copies;
    # downstream concatenation (torch.cat) copies the data anyway.
    img_ids = img_ids.reshape(1, packed_height * packed_width, 4)
    return img_ids.expand(batch_size, -1, -1)
117
+
118
+
119
class Flux2RefImageExtension:
    """Applies FLUX.2 Klein reference image conditioning.

    Encodes each reference image with the FLUX.2 VAE, packs the resulting
    32-channel latents into token sequences, and builds the matching 4D
    (T, H, W, L) position ids for multi-reference image editing.

    FLUX.2 Klein has built-in support for reference image editing, unlike FLUX.1
    which requires a separate Kontext model.
    """

    def __init__(
        self,
        ref_image_conditioning: list[FluxKontextConditioningField],
        context: InvocationContext,
        vae_field: VAEField,
        device: torch.device,
        dtype: torch.dtype,
        bn_mean: torch.Tensor | None = None,
        bn_std: torch.Tensor | None = None,
    ):
        """Initialize the extension and eagerly encode all reference images.

        Args:
            ref_image_conditioning: List of reference image conditioning fields.
            context: The invocation context for loading models and images.
            vae_field: The FLUX.2 VAE field for encoding images.
            device: Target device for the packed latents and position ids.
            dtype: Target dtype for the packed latents.
            bn_mean: BN running mean for normalizing latents (shape: 128).
                If None (together with bn_std), normalization is skipped.
            bn_std: BN running std for normalizing latents (shape: 128).
        """
        self._context = context
        self._device = device
        self._dtype = dtype
        self._vae_field = vae_field
        self._bn_mean = bn_mean
        self._bn_std = bn_std
        self.ref_image_conditioning = ref_image_conditioning

        # Encode once up front and cache the results; the latents/ids are
        # reused unchanged across all denoising steps.
        self.ref_image_latents, self.ref_image_ids = self._prepare_ref_images()

    def _bn_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Apply BN normalization to packed latents.

        BN formula (affine=False): y = (x - mean) / std

        Args:
            x: Packed latents of shape (B, seq, 128).

        Returns:
            Normalized latents of same shape.
        """
        assert self._bn_mean is not None and self._bn_std is not None
        # Move the stats onto the latents' device/dtype so broadcasting works.
        bn_mean = self._bn_mean.to(x.device, x.dtype)
        bn_std = self._bn_std.to(x.device, x.dtype)
        return (x - bn_mean) / bn_std

    def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode reference images and prepare their concatenated latents and IDs with spatial tiling.

        Each reference image is: resized to a pixel budget, VAE-encoded,
        packed (2x2 patchify), optionally BN-normalized, and assigned 4D
        position ids that place it on a virtual canvas next to the previous
        references. All sequences are concatenated along the token dimension.

        NOTE(review): raises if ref_image_conditioning is empty (torch.cat of
        an empty list) — presumably callers only construct this extension when
        at least one reference image exists; confirm at call sites.
        """
        all_latents = []
        all_ids = []

        # Cumulative canvas extents (in latent units) for spatial tiling.
        canvas_h = 0
        canvas_w = 0

        vae_info = self._context.models.load(self._vae_field.vae)

        # A single reference gets a larger pixel budget than multiple ones
        # (BFL FLUX.2 approach).
        num_refs = len(self.ref_image_conditioning)
        max_pixels = MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF

        for idx, ref_image_field in enumerate(self.ref_image_conditioning):
            image = self._context.images.get_pil(ref_image_field.image.image_name)
            image = image.convert("RGB")

            # Resize large images to the max pixel count (matches BFL FLUX.2 sampling.py).
            image = resize_image_to_max_pixels(image, max_pixels)

            # PIL -> float tensor in [0, 1].
            transformation = T.Compose([T.ToTensor()])
            image_tensor = transformation(image)
            # Convert from [0, 1] to the [-1, 1] range expected by the VAE.
            image_tensor = image_tensor * 2.0 - 1.0
            image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension.

            # Encode using the FLUX.2 VAE while it is resident on the device.
            with vae_info.model_on_device() as (_, vae):
                # Match the VAE's parameter dtype to avoid dtype-mismatch errors.
                vae_dtype = next(iter(vae.parameters())).dtype
                image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

                # FLUX.2 VAE uses the diffusers encode() API, which returns a
                # latent distribution (or, defensively, a plain tensor).
                latent_dist = vae.encode(image_tensor, return_dict=False)[0]

                # Prefer mode() for deterministic encoding (no sampling).
                if hasattr(latent_dist, "mode"):
                    ref_image_latents_unpacked = latent_dist.mode()
                elif hasattr(latent_dist, "sample"):
                    ref_image_latents_unpacked = latent_dist.sample()
                else:
                    ref_image_latents_unpacked = latent_dist

            # Free VAE activation memory before the next (potentially large) step.
            TorchDevice.empty_cache()

            # Latents are (B, 32, H, W) for FLUX.2.
            batch_size, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pad latents so H and W are even (required by the 2x2 patchify).
            pad_h = (2 - latent_height % 2) % 2
            pad_w = (2 - latent_width % 2) % 2
            if pad_h > 0 or pad_w > 0:
                ref_image_latents_unpacked = F.pad(ref_image_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
                _, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pack the latents using the FLUX.2 pack function (32 channels -> 128).
            ref_image_latents_packed = pack_flux2(ref_image_latents_unpacked).to(self._device, self._dtype)

            # Apply BN normalization to match the input latents' scale.
            # This is critical - the transformer expects normalized latents.
            if self._bn_mean is not None and self._bn_std is not None:
                ref_image_latents_packed = self._bn_normalize(ref_image_latents_packed)

            # Spatial offsets for this reference image on the virtual canvas.
            h_offset = 0
            w_offset = 0

            if idx > 0:  # First image starts at (0, 0).
                # Canvas size if we grew vertically vs. horizontally.
                potential_h_vertical = canvas_h + latent_height
                potential_w_horizontal = canvas_w + latent_width

                # Grow along whichever axis keeps the canvas closer to square.
                if potential_h_vertical > potential_w_horizontal:
                    # Tile horizontally (to the right).
                    w_offset = canvas_w
                    canvas_w = canvas_w + latent_width
                    canvas_h = max(canvas_h, latent_height)
                else:
                    # Tile vertically (below).
                    h_offset = canvas_h
                    canvas_h = canvas_h + latent_height
                    canvas_w = max(canvas_w, latent_width)
            else:
                canvas_h = latent_height
                canvas_w = latent_width

            # Generate position IDs with 4D format (T, H, W, L).
            # Use a T-coordinate offset with scale=10 like diffusers Flux2Pipeline:
            # T = scale + scale * idx (first ref image is T=10, second T=20, ...).
            # The generated image uses T=0, so this clearly separates references.
            t_offset = 10 + 10 * idx  # scale=10 matches diffusers
            ref_image_ids = generate_img_ids_flux2_with_offset(
                latent_height=latent_height,
                latent_width=latent_width,
                batch_size=batch_size,
                device=self._device,
                idx_offset=t_offset,  # Reference images use T=10, 20, 30...
                h_offset=h_offset,
                w_offset=w_offset,
            )

            all_latents.append(ref_image_latents_packed)
            all_ids.append(ref_image_ids)

        # Concatenate all latents and IDs along the sequence (token) dimension.
        concatenated_latents = torch.cat(all_latents, dim=1)
        concatenated_ids = torch.cat(all_ids, dim=1)

        return concatenated_latents, concatenated_ids

    def ensure_batch_size(self, target_batch_size: int) -> None:
        """Ensure the reference image latents and IDs match the target batch size."""
        # NOTE(review): repeat() tiles the existing batch dim, so this assumes
        # the cached batch size is 1 (or divides the target) — confirm callers.
        if self.ref_image_latents.shape[0] != target_batch_size:
            self.ref_image_latents = self.ref_image_latents.repeat(target_batch_size, 1, 1)
            self.ref_image_ids = self.ref_image_ids.repeat(target_batch_size, 1, 1)
@@ -0,0 +1,209 @@
1
+ """FLUX.2 Klein Sampling Utilities.
2
+
3
+ FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel VAE
4
+ used by FLUX.1. This module provides sampling utilities adapted for FLUX.2.
5
+ """
6
+
7
+ import math
8
+
9
+ import torch
10
+ from einops import rearrange
11
+
12
+
13
def get_noise_flux2(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
) -> torch.Tensor:
    """Draw initial latent noise for FLUX.2 Klein (32 channels).

    The FLUX.2 VAE has 32 latent channels (vs 16 for FLUX.1), and latent
    dimensions are rounded up to even values so the 2x2 patch packing works.

    Noise is always sampled on the CPU in float16 with a seeded generator so a
    given seed yields identical noise regardless of the target device/dtype;
    the result is then cast to the requested device and dtype.

    Args:
        num_samples: Batch size.
        height: Target image height in pixels.
        width: Target image width in pixels.
        device: Target device.
        dtype: Target dtype.
        seed: Random seed.

    Returns:
        Noise tensor of shape (num_samples, 32, latent_h, latent_w).
    """
    # Even latent dims: VAE downsamples 8x, and packing needs divisibility by 2.
    grid_h = 2 * math.ceil(height / 16)
    grid_w = 2 * math.ceil(width / 16)

    generator = torch.Generator(device="cpu").manual_seed(seed)
    noise = torch.randn(
        (num_samples, 32, grid_h, grid_w),  # 32 latent channels for FLUX.2
        device="cpu",
        dtype=torch.float16,
        generator=generator,
    )
    return noise.to(device=device, dtype=dtype)
56
+
57
+
58
def pack_flux2(x: torch.Tensor) -> torch.Tensor:
    """Pack a FLUX.2 latent image into a flat sequence of patch embeddings.

    Performs patchify + pack in one step: each 2x2 spatial patch is folded
    into the channel dimension, and the spatial grid is flattened into a
    token sequence. For 32-channel input: (B, 32, H, W) -> (B, H/2*W/2, 128).

    Args:
        x: Latent tensor of shape (B, 32, H, W).

    Returns:
        Packed tensor of shape (B, H/2*W/2, 128).
    """
    batch, channels, height, width = x.shape
    patches_h, patches_w = height // 2, width // 2

    # (B, C, h, 2, w, 2) -> (B, h, w, C, 2, 2) -> (B, h*w, C*4),
    # i.e. the feature axis is ordered (channel, patch-row, patch-col).
    patched = x.view(batch, channels, patches_h, 2, patches_w, 2)
    patched = patched.permute(0, 2, 4, 1, 3, 5)
    return patched.reshape(batch, patches_h * patches_w, channels * 4)
75
+
76
+
77
def unpack_flux2(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Unpack a flat patch-embedding sequence back into a FLUX.2 latent image.

    Reverses pack_flux2: the token sequence is restored to a spatial grid and
    each 128-feature token is unfolded into a 2x2 patch of 32 channels.

    Args:
        x: Packed tensor of shape (B, H/2*W/2, 128).
        height: Target image height in pixels.
        width: Target image width in pixels.

    Returns:
        Latent tensor of shape (B, 32, H, W).
    """
    # Latent dims implied by the pixel dimensions (must match get_noise_flux2).
    grid_h = 2 * math.ceil(height / 16)
    grid_w = 2 * math.ceil(width / 16)

    # Packed (token-grid) dimensions after the 2x2 patchify.
    patches_h = grid_h // 2
    patches_w = grid_w // 2

    batch = x.shape[0]
    channels = x.shape[2] // 4  # feature axis is (channel, patch-row, patch-col)

    # (B, h*w, C*4) -> (B, h, w, C, 2, 2) -> (B, C, h, 2, w, 2) -> (B, C, H, W)
    patched = x.view(batch, patches_h, patches_w, channels, 2, 2)
    patched = patched.permute(0, 3, 1, 4, 2, 5)
    return patched.reshape(batch, channels, grid_h, grid_w)
108
+
109
+
110
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute the empirical mu for FLUX.2 schedule shifting.

    Matches the diffusers Flux2Pipeline implementation. Mu controls how far
    the sigma schedule is shifted towards higher timesteps.

    Args:
        image_seq_len: Number of image tokens (packed_h * packed_w).
        num_steps: Number of denoising steps.

    Returns:
        The empirical mu value.
    """
    # Two empirical linear fits in sequence length: one anchored at 10 steps,
    # one at 200 steps (constants from diffusers Flux2Pipeline).
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    long_seq_mu = a2 * image_seq_len + b2
    if image_seq_len > 4300:
        # Long sequences use the 200-step fit directly, ignoring num_steps.
        return float(long_seq_mu)

    short_steps_mu = a1 * image_seq_len + b1

    # Linearly interpolate in step count between the 10- and 200-step anchors.
    slope = (long_seq_mu - short_steps_mu) / 190.0
    intercept = long_seq_mu - 200.0 * slope
    return float(slope * num_steps + intercept)
138
+
139
+
140
def get_schedule_flux2(
    num_steps: int,
    image_seq_len: int,
) -> list[float]:
    """Get the linear sigma schedule for FLUX.2.

    Returns a linear sigma ramp from 1.0 down to 1/num_steps, with a final
    0.0 appended (a scheduler needs n+1 boundary sigmas for n steps). The
    actual schedule shifting is handled by the FlowMatchEulerDiscreteScheduler
    using the mu parameter and use_dynamic_shifting=True.

    Args:
        num_steps: Number of denoising steps; must be >= 1.
        image_seq_len: Number of image tokens (packed_h * packed_w). Currently
            unused, but kept for API compatibility. The scheduler computes
            shifting internally.

    Returns:
        List of num_steps + 1 sigmas: linear from 1.0 to 1/num_steps, then 0.0.

    Raises:
        ValueError: If num_steps < 1 (previously this surfaced as an obscure
            ZeroDivisionError inside the linspace computation).
    """
    import numpy as np

    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Linear sigmas from 1.0 to 1/num_steps; the scheduler applies dynamic
    # shifting via its mu parameter.
    sigmas_list = [float(s) for s in np.linspace(1.0, 1 / num_steps, num_steps)]

    # Final 0.0 boundary for the last step.
    sigmas_list.append(0.0)

    return sigmas_list
169
+
170
+
171
def generate_img_ids_flux2(h: int, w: int, batch_size: int, device: torch.device) -> torch.Tensor:
    """Build 4D (T, H, W, L) position ids for a FLUX.2 latent image.

    FLUX.2 rotary position embeddings use four coordinates, unlike FLUX.1's
    three. For the generated image T and L are 0, while H and W enumerate the
    packed (post 2x2-patchify) grid.

    IMPORTANT: Position ids must be int64 (long) like diffusers, not bfloat16 —
    floating-point ids can cause NaN in the rotary embeddings.

    Args:
        h: Height of image in latent space.
        w: Width of image in latent space.
        batch_size: Batch size.
        device: Device.

    Returns:
        Image position ids tensor of shape (batch_size, h/2*w/2, 4) with int64 dtype.
    """
    # Packed grid after the 2x2 patchify: h/2 x w/2 tokens.
    rows, cols = h // 2, w // 2

    row_ids = torch.arange(rows, device=device, dtype=torch.long)
    col_ids = torch.arange(cols, device=device, dtype=torch.long)
    grid_r, grid_c = torch.meshgrid(row_ids, col_ids, indexing="ij")
    zero_plane = torch.zeros(rows, cols, device=device, dtype=torch.long)

    # Per-token (T, H, W, L); T and L are zero for the generated image.
    ids = torch.stack((zero_plane, grid_r, grid_c, zero_plane), dim=-1)

    # Flatten to a token sequence and broadcast over the batch (view, no copy).
    ids = ids.reshape(1, rows * cols, 4)
    return ids.expand(batch_size, -1, -1)