InvokeAI 6.10.0rc2-py3-none-any.whl → 6.11.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +50 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/ideal_size.py +6 -1
  11. invokeai/app/invocations/metadata.py +4 -0
  12. invokeai/app/invocations/metadata_linked.py +47 -0
  13. invokeai/app/invocations/model.py +1 -0
  14. invokeai/app/invocations/z_image_denoise.py +8 -3
  15. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  16. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  17. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  18. invokeai/app/services/config/config_default.py +3 -1
  19. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  20. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  21. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  22. invokeai/app/services/model_records/model_records_base.py +4 -2
  23. invokeai/app/services/shared/invocation_context.py +15 -0
  24. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  25. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  26. invokeai/app/util/step_callback.py +42 -0
  27. invokeai/backend/flux/denoise.py +239 -204
  28. invokeai/backend/flux/dype/__init__.py +18 -0
  29. invokeai/backend/flux/dype/base.py +226 -0
  30. invokeai/backend/flux/dype/embed.py +116 -0
  31. invokeai/backend/flux/dype/presets.py +141 -0
  32. invokeai/backend/flux/dype/rope.py +110 -0
  33. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  34. invokeai/backend/flux/util.py +35 -1
  35. invokeai/backend/flux2/__init__.py +4 -0
  36. invokeai/backend/flux2/denoise.py +261 -0
  37. invokeai/backend/flux2/ref_image_extension.py +294 -0
  38. invokeai/backend/flux2/sampling_utils.py +209 -0
  39. invokeai/backend/model_manager/configs/factory.py +19 -1
  40. invokeai/backend/model_manager/configs/main.py +395 -3
  41. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  42. invokeai/backend/model_manager/configs/vae.py +104 -2
  43. invokeai/backend/model_manager/load/load_default.py +0 -1
  44. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  45. invokeai/backend/model_manager/load/model_loaders/flux.py +1007 -2
  46. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +0 -1
  47. invokeai/backend/model_manager/load/model_loaders/z_image.py +121 -28
  48. invokeai/backend/model_manager/starter_models.py +128 -0
  49. invokeai/backend/model_manager/taxonomy.py +31 -4
  50. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  51. invokeai/backend/util/vae_working_memory.py +0 -2
  52. invokeai/frontend/web/dist/assets/App-ClpIJstk.js +161 -0
  53. invokeai/frontend/web/dist/assets/{browser-ponyfill-BP0RxJ4G.js → browser-ponyfill-Cw07u5G1.js} +1 -1
  54. invokeai/frontend/web/dist/assets/{index-B44qKjrs.js → index-DSKM8iGj.js} +69 -69
  55. invokeai/frontend/web/dist/index.html +1 -1
  56. invokeai/frontend/web/dist/locales/en.json +58 -5
  57. invokeai/frontend/web/dist/locales/it.json +2 -1
  58. invokeai/version/invokeai_version.py +1 -1
  59. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/METADATA +7 -1
  60. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/RECORD +66 -49
  61. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/WHEEL +1 -1
  62. invokeai/frontend/web/dist/assets/App-DllqPQ3j.js +0 -161
  63. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/entry_points.txt +0 -0
  64. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE +0 -0
  65. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  66. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  67. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/util.py

@@ -5,7 +5,7 @@ from typing import Literal
 
 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
-from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
+from invokeai.backend.model_manager.taxonomy import AnyVariant, Flux2VariantType, FluxVariantType
 
 
 @dataclass
@@ -46,6 +46,8 @@ _flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
     FluxVariantType.Dev: 512,
     FluxVariantType.DevFill: 512,
     FluxVariantType.Schnell: 256,
+    Flux2VariantType.Klein4B: 512,
+    Flux2VariantType.Klein9B: 512,
 }
 
 
@@ -117,6 +119,38 @@ _flux_transformer_params: dict[AnyVariant, FluxParams] = {
         qkv_bias=True,
         guidance_embed=True,
     ),
+    # Flux2 Klein 4B uses Qwen3 4B text encoder with stacked embeddings from layers [9, 18, 27]
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 2560 = 7680)
+    Flux2VariantType.Klein4B: FluxParams(
+        in_channels=64,
+        vec_in_dim=2560,  # Qwen3-4B hidden size (used for pooled output)
+        context_in_dim=7680,  # 3 layers * 2560 = 7680 for Qwen3-4B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
+    # Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
+    Flux2VariantType.Klein9B: FluxParams(
+        in_channels=64,
+        vec_in_dim=4096,  # Qwen3-8B hidden size (used for pooled output)
+        context_in_dim=12288,  # 3 layers * 4096 = 12288 for Qwen3-8B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
 }
 
 
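Aside: the new Klein entries above reuse the FLUX.1 transformer geometry and differ mainly in the text-encoder-dependent widths noted in the comments. A minimal sketch of that arithmetic, using only the values stated in those comments (the names below are illustrative, not InvokeAI APIs):

# Hypothetical helper illustrating the dimension arithmetic from the comments above.
QWEN3_HIDDEN_SIZE = {"Klein4B": 2560, "Klein9B": 4096}  # Qwen3-4B / Qwen3-8B hidden sizes
NUM_STACKED_LAYERS = 3  # hidden states from layers [9, 18, 27] are stacked

def klein_context_in_dim(variant: str) -> int:
    """context_in_dim = stacked layer count * encoder hidden size."""
    return NUM_STACKED_LAYERS * QWEN3_HIDDEN_SIZE[variant]

assert klein_context_in_dim("Klein4B") == 7680    # 3 * 2560
assert klein_context_in_dim("Klein9B") == 12288   # 3 * 4096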
invokeai/backend/flux2/__init__.py (new file)

@@ -0,0 +1,4 @@
+"""FLUX.2 backend modules.
+
+This package contains modules specific to FLUX.2 models (e.g., Klein).
+"""
invokeai/backend/flux2/denoise.py (new file)

@@ -0,0 +1,261 @@
+"""Flux2 Klein Denoising Function.
+
+This module provides the denoising function for FLUX.2 Klein models,
+which use Qwen3 as the text encoder instead of CLIP+T5.
+"""
+
+import math
+from typing import Any, Callable
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+
+
+def denoise(
+    model: torch.nn.Module,
+    # model input
+    img: torch.Tensor,
+    img_ids: torch.Tensor,
+    txt: torch.Tensor,
+    txt_ids: torch.Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    step_callback: Callable[[PipelineIntermediateState], None],
+    cfg_scale: list[float],
+    # Negative conditioning for CFG
+    neg_txt: torch.Tensor | None = None,
+    neg_txt_ids: torch.Tensor | None = None,
+    # Scheduler for stepping (e.g., FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler)
+    scheduler: Any = None,
+    # Dynamic shifting parameter for FLUX.2 Klein (computed from image resolution)
+    mu: float | None = None,
+    # Inpainting extension for merging latents during denoising
+    inpaint_extension: RectifiedFlowInpaintExtension | None = None,
+    # Reference image conditioning (multi-reference image editing)
+    img_cond_seq: torch.Tensor | None = None,
+    img_cond_seq_ids: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """Denoise latents using a FLUX.2 Klein transformer model.
+
+    This is a simplified denoise function for FLUX.2 Klein models that uses
+    the diffusers Flux2Transformer2DModel interface.
+
+    Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
+    CFG is applied externally using negative conditioning when cfg_scale != 1.0.
+
+    Args:
+        model: The Flux2Transformer2DModel from diffusers.
+        img: Packed latent image tensor of shape (B, seq_len, channels).
+        img_ids: Image position IDs tensor.
+        txt: Text encoder hidden states (Qwen3 embeddings).
+        txt_ids: Text position IDs tensor.
+        timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
+        step_callback: Callback function for progress updates.
+        cfg_scale: List of CFG scale values per step.
+        neg_txt: Negative text embeddings for CFG (optional).
+        neg_txt_ids: Negative text position IDs (optional).
+        scheduler: Optional diffusers scheduler (Euler, Heun, LCM). If None, uses manual Euler.
+        mu: Dynamic shifting parameter computed from image resolution. Required when scheduler
+            has use_dynamic_shifting=True.
+
+    Returns:
+        Denoised latent tensor.
+    """
+    total_steps = len(timesteps) - 1
+
+    # Store original sequence length for extracting output later (before concatenating reference images)
+    original_seq_len = img.shape[1]
+
+    # Concatenate reference image conditioning if provided (multi-reference image editing)
+    if img_cond_seq is not None and img_cond_seq_ids is not None:
+        img = torch.cat([img, img_cond_seq], dim=1)
+        img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)
+
+    # Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
+    # We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
+    guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
+
+    # Use scheduler if provided
+    use_scheduler = scheduler is not None
+    if use_scheduler:
+        # Set up scheduler with sigmas and mu for dynamic shifting
+        # Convert timesteps (0-1 range) to sigmas for the scheduler
+        # The scheduler will apply dynamic shifting internally using mu (if enabled in scheduler config)
+        sigmas = np.array(timesteps[:-1], dtype=np.float32)  # Exclude final 0.0
+
+        # Pass mu if provided - it will only be used if scheduler has use_dynamic_shifting=True
+        if mu is not None:
+            scheduler.set_timesteps(sigmas=sigmas.tolist(), mu=mu, device=img.device)
+        else:
+            scheduler.set_timesteps(sigmas=sigmas.tolist(), device=img.device)
+        num_scheduler_steps = len(scheduler.timesteps)
+        is_heun = hasattr(scheduler, "state_in_first_order")
+        user_step = 0
+
+        pbar = tqdm(total=total_steps, desc="Denoising")
+        for step_index in range(num_scheduler_steps):
+            timestep = scheduler.timesteps[step_index]
+            # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # Track if we're in first or second order step (for Heun)
+            in_first_order = scheduler.state_in_first_order if is_heun else True
+
+            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
+            output = model(
+                hidden_states=img,
+                encoder_hidden_states=txt,
+                timestep=t_vec,
+                img_ids=img_ids,
+                txt_ids=txt_ids,
+                guidance=guidance,
+                return_dict=False,
+            )
+
+            # Extract the sample from the output (return_dict=False returns tuple)
+            pred = output[0] if isinstance(output, tuple) else output
+
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            # Apply CFG if scale is not 1.0
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_txt is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_output = model(
+                    hidden_states=img,
+                    encoder_hidden_states=neg_txt,
+                    timestep=t_vec,
+                    img_ids=img_ids,
+                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
+                    guidance=guidance,
+                    return_dict=False,
+                )
+
+                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Use scheduler.step() for the update
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
+
+            # Get t_prev for inpainting (next sigma value)
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            # Apply inpainting merge at each step
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+
+            # For Heun, only increment user step after second-order step completes
+            if is_heun:
+                if not in_first_order:
+                    user_step += 1
+                    if user_step <= total_steps:
+                        pbar.update(1)
+                    preview_img = img - t_curr * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+                            preview_img, 0.0
+                        )
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=2,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=preview_img,
+                        ),
+                    )
+            else:
+                user_step += 1
+                if user_step <= total_steps:
+                    pbar.update(1)
+                preview_img = img - t_curr * pred
+                if inpaint_extension is not None:
+                    preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                # Extract only the generated image portion for preview (exclude reference images)
+                callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
+                step_callback(
+                    PipelineIntermediateState(
+                        step=user_step,
+                        order=1,
+                        total_steps=total_steps,
+                        timestep=int(t_curr * 1000),
+                        latents=callback_latents,
+                    ),
+                )
+
+        pbar.close()
+    else:
+        # Manual Euler stepping (original behavior)
+        for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
+            output = model(
+                hidden_states=img,
+                encoder_hidden_states=txt,
+                timestep=t_vec,
+                img_ids=img_ids,
+                txt_ids=txt_ids,
+                guidance=guidance,
+                return_dict=False,
+            )
+
+            # Extract the sample from the output (return_dict=False returns tuple)
+            pred = output[0] if isinstance(output, tuple) else output
+
+            step_cfg_scale = cfg_scale[step_index]
+
+            # Apply CFG if scale is not 1.0
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_txt is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_output = model(
+                    hidden_states=img,
+                    encoder_hidden_states=neg_txt,
+                    timestep=t_vec,
+                    img_ids=img_ids,
+                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
+                    guidance=guidance,
+                    return_dict=False,
+                )
+
+                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Euler step
+            preview_img = img - t_curr * pred
+            img = img + (t_prev - t_curr) * pred
+
+            # Apply inpainting merge at each step
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+                preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+
+            # Extract only the generated image portion for preview (exclude reference images)
+            callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
+            step_callback(
+                PipelineIntermediateState(
+                    step=step_index + 1,
+                    order=1,
+                    total_steps=total_steps,
+                    timestep=int(t_curr),
+                    latents=callback_latents,
+                ),
+            )
+
+    # Extract only the generated image portion (exclude concatenated reference images)
+    if img_cond_seq is not None:
+        img = img[:, :original_seq_len, :]
+
+    return img
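For orientation, a minimal smoke-test sketch of the manual-Euler path of denoise() above. The stub transformer and random tensors below are placeholders standing in for the real Flux2Transformer2DModel, Qwen3 embeddings, and packed VAE latents; shapes are illustrative only.

import torch
from invokeai.backend.flux2.denoise import denoise

class _StubTransformer(torch.nn.Module):
    """Stands in for diffusers' Flux2Transformer2DModel; predicts a zero velocity."""

    def forward(self, hidden_states, encoder_hidden_states, timestep, img_ids, txt_ids, guidance, return_dict=False):
        return (torch.zeros_like(hidden_states),)

num_steps = 4
# timesteps run from 1.0 down to 0.0; denoise() takes len(timesteps) - 1 steps.
timesteps = [1.0 - i / num_steps for i in range(num_steps + 1)]

latents = denoise(
    model=_StubTransformer(),
    img=torch.randn(1, 256, 128),                      # packed latents (B, seq_len, channels)
    img_ids=torch.zeros(1, 256, 4, dtype=torch.long),  # 4D (T, H, W, L) position IDs, T=0 for the generated image
    txt=torch.randn(1, 32, 7680),                      # stacked Qwen3 hidden states (Klein 4B width)
    txt_ids=torch.zeros(1, 32, 4, dtype=torch.long),
    timesteps=timesteps,
    step_callback=lambda state: None,                  # no-op progress callback
    cfg_scale=[1.0] * num_steps,                       # values != 1.0 additionally require neg_txt/neg_txt_ids
    scheduler=None,                                    # None selects the manual Euler branch
)
print(latents.shape)  # torch.Size([1, 256, 128])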
invokeai/backend/flux2/ref_image_extension.py (new file)

@@ -0,0 +1,294 @@
+"""FLUX.2 Klein Reference Image Extension for multi-reference image editing.
+
+This module provides the Flux2RefImageExtension for FLUX.2 Klein models,
+which handles encoding reference images using the FLUX.2 VAE and
+generating the appropriate position IDs for multi-reference image editing.
+
+FLUX.2 Klein has built-in support for reference image editing (unlike FLUX.1
+which requires a separate Kontext model).
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+from einops import repeat
+from PIL import Image
+
+from invokeai.app.invocations.fields import FluxKontextConditioningField
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux2.sampling_utils import pack_flux2
+from invokeai.backend.util.devices import TorchDevice
+
+# Maximum pixel counts for reference images (matches BFL FLUX.2 sampling.py)
+# Single reference image: 2024² pixels, Multiple: 1024² pixels
+MAX_PIXELS_SINGLE_REF = 2024**2  # ~4.1M pixels
+MAX_PIXELS_MULTI_REF = 1024**2  # ~1M pixels
+
+
+def resize_image_to_max_pixels(image: Image.Image, max_pixels: int) -> Image.Image:
+    """Resize image to fit within max_pixels while preserving aspect ratio.
+
+    This matches the BFL FLUX.2 sampling.py cap_pixels() behavior.
+
+    Args:
+        image: PIL Image to resize.
+        max_pixels: Maximum total pixel count (width * height).
+
+    Returns:
+        Resized PIL Image (or original if already within bounds).
+    """
+    width, height = image.size
+    pixel_count = width * height
+
+    if pixel_count <= max_pixels:
+        return image
+
+    # Calculate scale factor to fit within max_pixels (BFL approach)
+    scale = math.sqrt(max_pixels / pixel_count)
+    new_width = int(width * scale)
+    new_height = int(height * scale)
+
+    # Ensure dimensions are at least 1
+    new_width = max(1, new_width)
+    new_height = max(1, new_height)
+
+    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+
+def generate_img_ids_flux2_with_offset(
+    latent_height: int,
+    latent_width: int,
+    batch_size: int,
+    device: torch.device,
+    idx_offset: int = 0,
+    h_offset: int = 0,
+    w_offset: int = 0,
+) -> torch.Tensor:
+    """Generate tensor of image position ids with optional offsets for FLUX.2.
+
+    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
+    Position IDs use int64 (long) dtype.
+
+    Args:
+        latent_height: Height of image in latent space (before packing).
+        latent_width: Width of image in latent space (before packing).
+        batch_size: Number of images in the batch.
+        device: Device to create tensors on.
+        idx_offset: Offset for T (time/index) coordinate - use 1 for reference images.
+        h_offset: Spatial offset for H coordinate in latent space.
+        w_offset: Spatial offset for W coordinate in latent space.
+
+    Returns:
+        Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 4].
+    """
+    # After packing, the spatial dimensions are halved due to the 2x2 patch structure
+    packed_height = latent_height // 2
+    packed_width = latent_width // 2
+
+    # Convert spatial offsets from latent space to packed space
+    packed_h_offset = h_offset // 2
+    packed_w_offset = w_offset // 2
+
+    # Create base tensor for position IDs with shape [packed_height, packed_width, 4]
+    # The 4 channels represent: [T, H, W, L]
+    img_ids = torch.zeros(packed_height, packed_width, 4, device=device, dtype=torch.long)
+
+    # Set T (time/index offset) for all positions - use 1 for reference images
+    img_ids[..., 0] = idx_offset
+
+    # Set H (height/y) coordinates with offset
+    h_coords = torch.arange(packed_height, device=device, dtype=torch.long) + packed_h_offset
+    img_ids[..., 1] = h_coords[:, None]
+
+    # Set W (width/x) coordinates with offset
+    w_coords = torch.arange(packed_width, device=device, dtype=torch.long) + packed_w_offset
+    img_ids[..., 2] = w_coords[None, :]
+
+    # L (layer) coordinate stays 0
+
+    # Expand to include batch dimension: [batch_size, (packed_height * packed_width), 4]
+    img_ids = img_ids.reshape(1, packed_height * packed_width, 4)
+    img_ids = repeat(img_ids, "1 s c -> b s c", b=batch_size)
+
+    return img_ids
+
+
+class Flux2RefImageExtension:
+    """Applies FLUX.2 Klein reference image conditioning.
+
+    This extension handles encoding reference images using the FLUX.2 VAE
+    and generating the appropriate 4D position IDs for multi-reference image editing.
+
+    FLUX.2 Klein has built-in support for reference image editing, unlike FLUX.1
+    which requires a separate Kontext model.
+    """
+
+    def __init__(
+        self,
+        ref_image_conditioning: list[FluxKontextConditioningField],
+        context: InvocationContext,
+        vae_field: VAEField,
+        device: torch.device,
+        dtype: torch.dtype,
+        bn_mean: torch.Tensor | None = None,
+        bn_std: torch.Tensor | None = None,
+    ):
+        """Initialize the Flux2RefImageExtension.
+
+        Args:
+            ref_image_conditioning: List of reference image conditioning fields.
+            context: The invocation context for loading models and images.
+            vae_field: The FLUX.2 VAE field for encoding images.
+            device: Target device for tensors.
+            dtype: Target dtype for tensors.
+            bn_mean: BN running mean for normalizing latents (shape: 128).
+            bn_std: BN running std for normalizing latents (shape: 128).
+        """
+        self._context = context
+        self._device = device
+        self._dtype = dtype
+        self._vae_field = vae_field
+        self._bn_mean = bn_mean
+        self._bn_std = bn_std
+        self.ref_image_conditioning = ref_image_conditioning
+
+        # Pre-process and cache the reference image latents and ids upon initialization
+        self.ref_image_latents, self.ref_image_ids = self._prepare_ref_images()
+
+    def _bn_normalize(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply BN normalization to packed latents.
+
+        BN formula (affine=False): y = (x - mean) / std
+
+        Args:
+            x: Packed latents of shape (B, seq, 128).
+
+        Returns:
+            Normalized latents of same shape.
+        """
+        assert self._bn_mean is not None and self._bn_std is not None
+        bn_mean = self._bn_mean.to(x.device, x.dtype)
+        bn_std = self._bn_std.to(x.device, x.dtype)
+        return (x - bn_mean) / bn_std
+
+    def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encode reference images and prepare their concatenated latents and IDs with spatial tiling."""
+        all_latents = []
+        all_ids = []
+
+        # Track cumulative dimensions for spatial tiling
+        canvas_h = 0
+        canvas_w = 0
+
+        vae_info = self._context.models.load(self._vae_field.vae)
+
+        # Determine max pixels based on number of reference images (BFL FLUX.2 approach)
+        num_refs = len(self.ref_image_conditioning)
+        max_pixels = MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF
+
+        for idx, ref_image_field in enumerate(self.ref_image_conditioning):
+            image = self._context.images.get_pil(ref_image_field.image.image_name)
+            image = image.convert("RGB")
+
+            # Resize large images to max pixel count (matches BFL FLUX.2 sampling.py)
+            image = resize_image_to_max_pixels(image, max_pixels)
+
+            # Convert to tensor using torchvision transforms
+            transformation = T.Compose([T.ToTensor()])
+            image_tensor = transformation(image)
+            # Convert from [0, 1] to [-1, 1] range expected by VAE
+            image_tensor = image_tensor * 2.0 - 1.0
+            image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension
+
+            # Encode using FLUX.2 VAE
+            with vae_info.model_on_device() as (_, vae):
+                vae_dtype = next(iter(vae.parameters())).dtype
+                image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
+
+                # FLUX.2 VAE uses diffusers API
+                latent_dist = vae.encode(image_tensor, return_dict=False)[0]
+
+                # Use mode() for deterministic encoding (no sampling)
+                if hasattr(latent_dist, "mode"):
+                    ref_image_latents_unpacked = latent_dist.mode()
+                elif hasattr(latent_dist, "sample"):
+                    ref_image_latents_unpacked = latent_dist.sample()
+                else:
+                    ref_image_latents_unpacked = latent_dist
+
+            TorchDevice.empty_cache()
+
+            # Extract tensor dimensions (B, 32, H, W for FLUX.2)
+            batch_size, _, latent_height, latent_width = ref_image_latents_unpacked.shape
+
+            # Pad latents to be compatible with patch_size=2
+            pad_h = (2 - latent_height % 2) % 2
+            pad_w = (2 - latent_width % 2) % 2
+            if pad_h > 0 or pad_w > 0:
+                ref_image_latents_unpacked = F.pad(ref_image_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
+                _, _, latent_height, latent_width = ref_image_latents_unpacked.shape
+
+            # Pack the latents using FLUX.2 pack function (32 channels -> 128)
+            ref_image_latents_packed = pack_flux2(ref_image_latents_unpacked).to(self._device, self._dtype)
+
+            # Apply BN normalization to match the input latents scale
+            # This is critical - the transformer expects normalized latents
+            if self._bn_mean is not None and self._bn_std is not None:
+                ref_image_latents_packed = self._bn_normalize(ref_image_latents_packed)
+
+            # Determine spatial offsets for this reference image
+            h_offset = 0
+            w_offset = 0
+
+            if idx > 0:  # First image starts at (0, 0)
+                # Calculate potential canvas dimensions for each tiling option
+                potential_h_vertical = canvas_h + latent_height
+                potential_w_horizontal = canvas_w + latent_width
+
+                # Choose arrangement that minimizes the maximum dimension
+                if potential_h_vertical > potential_w_horizontal:
+                    # Tile horizontally (to the right)
+                    w_offset = canvas_w
+                    canvas_w = canvas_w + latent_width
+                    canvas_h = max(canvas_h, latent_height)
+                else:
+                    # Tile vertically (below)
+                    h_offset = canvas_h
+                    canvas_h = canvas_h + latent_height
+                    canvas_w = max(canvas_w, latent_width)
+            else:
+                canvas_h = latent_height
+                canvas_w = latent_width
+
+            # Generate position IDs with 4D format (T, H, W, L)
+            # Use T-coordinate offset with scale=10 like diffusers Flux2Pipeline:
+            # T = scale + scale * idx (so first ref image is T=10, second is T=20, etc.)
+            # The generated image uses T=0, so this clearly separates reference images
+            t_offset = 10 + 10 * idx  # scale=10 matches diffusers
+            ref_image_ids = generate_img_ids_flux2_with_offset(
+                latent_height=latent_height,
+                latent_width=latent_width,
+                batch_size=batch_size,
+                device=self._device,
+                idx_offset=t_offset,  # Reference images use T=10, 20, 30...
+                h_offset=h_offset,
+                w_offset=w_offset,
+            )
+
+            all_latents.append(ref_image_latents_packed)
+            all_ids.append(ref_image_ids)
+
+        # Concatenate all latents and IDs along the sequence dimension
+        concatenated_latents = torch.cat(all_latents, dim=1)
+        concatenated_ids = torch.cat(all_ids, dim=1)
+
+        return concatenated_latents, concatenated_ids
+
+    def ensure_batch_size(self, target_batch_size: int) -> None:
+        """Ensure the reference image latents and IDs match the target batch size."""
+        if self.ref_image_latents.shape[0] != target_batch_size:
+            self.ref_image_latents = self.ref_image_latents.repeat(target_batch_size, 1, 1)
+            self.ref_image_ids = self.ref_image_ids.repeat(target_batch_size, 1, 1)
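As a quick illustration of the 4D position-ID layout described above, a small sketch that calls generate_img_ids_flux2_with_offset directly (the values are chosen arbitrarily; the extension itself derives them from the encoded reference latents):

import torch
from invokeai.backend.flux2.ref_image_extension import generate_img_ids_flux2_with_offset

# Second reference image (idx=1): T = 10 + 10 * idx = 20, tiled to the right of a
# 64-wide first image, so its latent-space w_offset is 64 (32 after 2x2 packing).
ids = generate_img_ids_flux2_with_offset(
    latent_height=64,
    latent_width=64,
    batch_size=1,
    device=torch.device("cpu"),
    idx_offset=20,
    h_offset=0,
    w_offset=64,
)
print(ids.shape)           # torch.Size([1, 1024, 4]) -> (64 // 2) * (64 // 2) tokens
print(ids[0, 0].tolist())  # [20, 0, 32, 0] -> first token: T=20, H=0, W=32, L=0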