InvokeAI 6.10.0rc2__py3-none-any.whl → 6.11.0rc1__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +50 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/z_image_denoise.py +8 -3
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +42 -0
- invokeai/backend/flux/denoise.py +239 -204
- invokeai/backend/flux/dype/__init__.py +18 -0
- invokeai/backend/flux/dype/base.py +226 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +141 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +261 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/load_default.py +0 -1
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/flux.py +1007 -2
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +0 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +121 -28
- invokeai/backend/model_manager/starter_models.py +128 -0
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/frontend/web/dist/assets/App-ClpIJstk.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-BP0RxJ4G.js → browser-ponyfill-Cw07u5G1.js} +1 -1
- invokeai/frontend/web/dist/assets/{index-B44qKjrs.js → index-DSKM8iGj.js} +69 -69
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en.json +58 -5
- invokeai/frontend/web/dist/locales/it.json +2 -1
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/METADATA +7 -1
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/RECORD +66 -49
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-DllqPQ3j.js +0 -161
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/util.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Literal
 
 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
-from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
+from invokeai.backend.model_manager.taxonomy import AnyVariant, Flux2VariantType, FluxVariantType
 
 
 @dataclass
@@ -46,6 +46,8 @@ _flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
     FluxVariantType.Dev: 512,
     FluxVariantType.DevFill: 512,
     FluxVariantType.Schnell: 256,
+    Flux2VariantType.Klein4B: 512,
+    Flux2VariantType.Klein9B: 512,
 }
 
 
@@ -117,6 +119,38 @@ _flux_transformer_params: dict[AnyVariant, FluxParams] = {
         qkv_bias=True,
         guidance_embed=True,
     ),
+    # Flux2 Klein 4B uses Qwen3 4B text encoder with stacked embeddings from layers [9, 18, 27]
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 2560 = 7680)
+    Flux2VariantType.Klein4B: FluxParams(
+        in_channels=64,
+        vec_in_dim=2560,  # Qwen3-4B hidden size (used for pooled output)
+        context_in_dim=7680,  # 3 layers * 2560 = 7680 for Qwen3-4B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
+    # Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
+    Flux2VariantType.Klein9B: FluxParams(
+        in_channels=64,
+        vec_in_dim=4096,  # Qwen3-8B hidden size (used for pooled output)
+        context_in_dim=12288,  # 3 layers * 4096 = 12288 for Qwen3-8B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
 }
 
 
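The new `_flux_transformer_params` entries pin down the Klein conditioning width: `context_in_dim` is three times the Qwen3 hidden size because, per the comments in the diff, hidden states from encoder layers [9, 18, 27] are stacked along the feature axis. A minimal sketch of how that dimension arises (illustrative only; dummy tensors stand in for the Qwen3 hidden states, and the number of dummy layers is arbitrary):

```python
import torch

batch, seq_len, hidden_size = 1, 77, 2560  # 2560 = Qwen3-4B hidden size from the diff above
layer_indices = [9, 18, 27]  # layers named in the diff comments

# Dummy stand-ins for the per-layer hidden states a Qwen3 text encoder would expose
# (e.g. via output_hidden_states=True); 28 entries is just enough to index layer 27.
hidden_states = [torch.randn(batch, seq_len, hidden_size) for _ in range(28)]

# Concatenate the selected layers along the last dimension.
stacked = torch.cat([hidden_states[i] for i in layer_indices], dim=-1)
assert stacked.shape == (batch, seq_len, 3 * hidden_size)  # (1, 77, 7680) == context_in_dim
```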
invokeai/backend/flux2/denoise.py
ADDED
@@ -0,0 +1,261 @@
+"""Flux2 Klein Denoising Function.
+
+This module provides the denoising function for FLUX.2 Klein models,
+which use Qwen3 as the text encoder instead of CLIP+T5.
+"""
+
+import math
+from typing import Any, Callable
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+
+
+def denoise(
+    model: torch.nn.Module,
+    # model input
+    img: torch.Tensor,
+    img_ids: torch.Tensor,
+    txt: torch.Tensor,
+    txt_ids: torch.Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    step_callback: Callable[[PipelineIntermediateState], None],
+    cfg_scale: list[float],
+    # Negative conditioning for CFG
+    neg_txt: torch.Tensor | None = None,
+    neg_txt_ids: torch.Tensor | None = None,
+    # Scheduler for stepping (e.g., FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler)
+    scheduler: Any = None,
+    # Dynamic shifting parameter for FLUX.2 Klein (computed from image resolution)
+    mu: float | None = None,
+    # Inpainting extension for merging latents during denoising
+    inpaint_extension: RectifiedFlowInpaintExtension | None = None,
+    # Reference image conditioning (multi-reference image editing)
+    img_cond_seq: torch.Tensor | None = None,
+    img_cond_seq_ids: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """Denoise latents using a FLUX.2 Klein transformer model.
+
+    This is a simplified denoise function for FLUX.2 Klein models that uses
+    the diffusers Flux2Transformer2DModel interface.
+
+    Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
+    CFG is applied externally using negative conditioning when cfg_scale != 1.0.
+
+    Args:
+        model: The Flux2Transformer2DModel from diffusers.
+        img: Packed latent image tensor of shape (B, seq_len, channels).
+        img_ids: Image position IDs tensor.
+        txt: Text encoder hidden states (Qwen3 embeddings).
+        txt_ids: Text position IDs tensor.
+        timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
+        step_callback: Callback function for progress updates.
+        cfg_scale: List of CFG scale values per step.
+        neg_txt: Negative text embeddings for CFG (optional).
+        neg_txt_ids: Negative text position IDs (optional).
+        scheduler: Optional diffusers scheduler (Euler, Heun, LCM). If None, uses manual Euler.
+        mu: Dynamic shifting parameter computed from image resolution. Required when scheduler
+            has use_dynamic_shifting=True.
+
+    Returns:
+        Denoised latent tensor.
+    """
+    total_steps = len(timesteps) - 1
+
+    # Store original sequence length for extracting output later (before concatenating reference images)
+    original_seq_len = img.shape[1]
+
+    # Concatenate reference image conditioning if provided (multi-reference image editing)
+    if img_cond_seq is not None and img_cond_seq_ids is not None:
+        img = torch.cat([img, img_cond_seq], dim=1)
+        img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)
+
+    # Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
+    # We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
+    guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
+
+    # Use scheduler if provided
+    use_scheduler = scheduler is not None
+    if use_scheduler:
+        # Set up scheduler with sigmas and mu for dynamic shifting
+        # Convert timesteps (0-1 range) to sigmas for the scheduler
+        # The scheduler will apply dynamic shifting internally using mu (if enabled in scheduler config)
+        sigmas = np.array(timesteps[:-1], dtype=np.float32)  # Exclude final 0.0
+
+        # Pass mu if provided - it will only be used if scheduler has use_dynamic_shifting=True
+        if mu is not None:
+            scheduler.set_timesteps(sigmas=sigmas.tolist(), mu=mu, device=img.device)
+        else:
+            scheduler.set_timesteps(sigmas=sigmas.tolist(), device=img.device)
+        num_scheduler_steps = len(scheduler.timesteps)
+        is_heun = hasattr(scheduler, "state_in_first_order")
+        user_step = 0
+
+        pbar = tqdm(total=total_steps, desc="Denoising")
+        for step_index in range(num_scheduler_steps):
+            timestep = scheduler.timesteps[step_index]
+            # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # Track if we're in first or second order step (for Heun)
+            in_first_order = scheduler.state_in_first_order if is_heun else True
+
+            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
+            output = model(
+                hidden_states=img,
+                encoder_hidden_states=txt,
+                timestep=t_vec,
+                img_ids=img_ids,
+                txt_ids=txt_ids,
+                guidance=guidance,
+                return_dict=False,
+            )
+
+            # Extract the sample from the output (return_dict=False returns tuple)
+            pred = output[0] if isinstance(output, tuple) else output
+
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            # Apply CFG if scale is not 1.0
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_txt is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_output = model(
+                    hidden_states=img,
+                    encoder_hidden_states=neg_txt,
+                    timestep=t_vec,
+                    img_ids=img_ids,
+                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
+                    guidance=guidance,
+                    return_dict=False,
+                )
+
+                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Use scheduler.step() for the update
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
+
+            # Get t_prev for inpainting (next sigma value)
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            # Apply inpainting merge at each step
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+
+            # For Heun, only increment user step after second-order step completes
+            if is_heun:
+                if not in_first_order:
+                    user_step += 1
+                    if user_step <= total_steps:
+                        pbar.update(1)
+                        preview_img = img - t_curr * pred
+                        if inpaint_extension is not None:
+                            preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+                                preview_img, 0.0
+                            )
+                        step_callback(
+                            PipelineIntermediateState(
+                                step=user_step,
+                                order=2,
+                                total_steps=total_steps,
+                                timestep=int(t_curr * 1000),
+                                latents=preview_img,
+                            ),
+                        )
+            else:
+                user_step += 1
+                if user_step <= total_steps:
+                    pbar.update(1)
+                    preview_img = img - t_curr * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                    # Extract only the generated image portion for preview (exclude reference images)
+                    callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=1,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=callback_latents,
+                        ),
+                    )
+
+        pbar.close()
+    else:
+        # Manual Euler stepping (original behavior)
+        for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
+            output = model(
+                hidden_states=img,
+                encoder_hidden_states=txt,
+                timestep=t_vec,
+                img_ids=img_ids,
+                txt_ids=txt_ids,
+                guidance=guidance,
+                return_dict=False,
+            )
+
+            # Extract the sample from the output (return_dict=False returns tuple)
+            pred = output[0] if isinstance(output, tuple) else output
+
+            step_cfg_scale = cfg_scale[step_index]
+
+            # Apply CFG if scale is not 1.0
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_txt is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_output = model(
+                    hidden_states=img,
+                    encoder_hidden_states=neg_txt,
+                    timestep=t_vec,
+                    img_ids=img_ids,
+                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
+                    guidance=guidance,
+                    return_dict=False,
+                )
+
+                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Euler step
+            preview_img = img - t_curr * pred
+            img = img + (t_prev - t_curr) * pred
+
+            # Apply inpainting merge at each step
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+                preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+
+            # Extract only the generated image portion for preview (exclude reference images)
+            callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
+            step_callback(
+                PipelineIntermediateState(
+                    step=step_index + 1,
+                    order=1,
+                    total_steps=total_steps,
+                    timestep=int(t_curr),
+                    latents=callback_latents,
+                ),
+            )
+
+    # Extract only the generated image portion (exclude concatenated reference images)
+    if img_cond_seq is not None:
+        img = img[:, :original_seq_len, :]
+
+    return img
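Both branches of `denoise()` share the same core arithmetic: classifier-free guidance on the predicted velocity, an x0 preview for the progress callback, and a flow-matching Euler update. Here is that update pulled out into a standalone sketch for reference (illustrative names and toy shapes, not InvokeAI code):

```python
import torch

def euler_cfg_step(
    img: torch.Tensor,       # packed latents at time t_curr
    pred: torch.Tensor,      # velocity prediction for the positive prompt
    neg_pred: torch.Tensor,  # velocity prediction for the negative prompt
    t_curr: float,
    t_prev: float,
    cfg_scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    # CFG: push the prediction away from the negative-conditioned direction.
    guided = neg_pred + cfg_scale * (pred - neg_pred)
    # Preview: project straight to t=0 with the current velocity estimate.
    preview = img - t_curr * guided
    # Euler step toward the next (smaller) timestep.
    img_next = img + (t_prev - t_curr) * guided
    return img_next, preview

# Toy example: batch=1, seq_len=16, packed channels=64.
x = torch.randn(1, 16, 64)
v_pos, v_neg = torch.randn_like(x), torch.randn_like(x)
x_next, x0_preview = euler_cfg_step(x, v_pos, v_neg, t_curr=1.0, t_prev=0.75, cfg_scale=4.0)
```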
invokeai/backend/flux2/ref_image_extension.py
ADDED
@@ -0,0 +1,294 @@
+"""FLUX.2 Klein Reference Image Extension for multi-reference image editing.
+
+This module provides the Flux2RefImageExtension for FLUX.2 Klein models,
+which handles encoding reference images using the FLUX.2 VAE and
+generating the appropriate position IDs for multi-reference image editing.
+
+FLUX.2 Klein has built-in support for reference image editing (unlike FLUX.1
+which requires a separate Kontext model).
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+from einops import repeat
+from PIL import Image
+
+from invokeai.app.invocations.fields import FluxKontextConditioningField
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux2.sampling_utils import pack_flux2
+from invokeai.backend.util.devices import TorchDevice
+
+# Maximum pixel counts for reference images (matches BFL FLUX.2 sampling.py)
+# Single reference image: 2024² pixels, Multiple: 1024² pixels
+MAX_PIXELS_SINGLE_REF = 2024**2  # ~4.1M pixels
+MAX_PIXELS_MULTI_REF = 1024**2  # ~1M pixels
+
+
+def resize_image_to_max_pixels(image: Image.Image, max_pixels: int) -> Image.Image:
+    """Resize image to fit within max_pixels while preserving aspect ratio.
+
+    This matches the BFL FLUX.2 sampling.py cap_pixels() behavior.
+
+    Args:
+        image: PIL Image to resize.
+        max_pixels: Maximum total pixel count (width * height).
+
+    Returns:
+        Resized PIL Image (or original if already within bounds).
+    """
+    width, height = image.size
+    pixel_count = width * height
+
+    if pixel_count <= max_pixels:
+        return image
+
+    # Calculate scale factor to fit within max_pixels (BFL approach)
+    scale = math.sqrt(max_pixels / pixel_count)
+    new_width = int(width * scale)
+    new_height = int(height * scale)
+
+    # Ensure dimensions are at least 1
+    new_width = max(1, new_width)
+    new_height = max(1, new_height)
+
+    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+
+def generate_img_ids_flux2_with_offset(
+    latent_height: int,
+    latent_width: int,
+    batch_size: int,
+    device: torch.device,
+    idx_offset: int = 0,
+    h_offset: int = 0,
+    w_offset: int = 0,
+) -> torch.Tensor:
+    """Generate tensor of image position ids with optional offsets for FLUX.2.
+
+    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
+    Position IDs use int64 (long) dtype.
+
+    Args:
+        latent_height: Height of image in latent space (before packing).
+        latent_width: Width of image in latent space (before packing).
+        batch_size: Number of images in the batch.
+        device: Device to create tensors on.
+        idx_offset: Offset for T (time/index) coordinate - use 1 for reference images.
+        h_offset: Spatial offset for H coordinate in latent space.
+        w_offset: Spatial offset for W coordinate in latent space.
+
+    Returns:
+        Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 4].
+    """
+    # After packing, the spatial dimensions are halved due to the 2x2 patch structure
+    packed_height = latent_height // 2
+    packed_width = latent_width // 2
+
+    # Convert spatial offsets from latent space to packed space
+    packed_h_offset = h_offset // 2
+    packed_w_offset = w_offset // 2
+
+    # Create base tensor for position IDs with shape [packed_height, packed_width, 4]
+    # The 4 channels represent: [T, H, W, L]
+    img_ids = torch.zeros(packed_height, packed_width, 4, device=device, dtype=torch.long)
+
+    # Set T (time/index offset) for all positions - use 1 for reference images
+    img_ids[..., 0] = idx_offset
+
+    # Set H (height/y) coordinates with offset
+    h_coords = torch.arange(packed_height, device=device, dtype=torch.long) + packed_h_offset
+    img_ids[..., 1] = h_coords[:, None]
+
+    # Set W (width/x) coordinates with offset
+    w_coords = torch.arange(packed_width, device=device, dtype=torch.long) + packed_w_offset
+    img_ids[..., 2] = w_coords[None, :]
+
+    # L (layer) coordinate stays 0
+
+    # Expand to include batch dimension: [batch_size, (packed_height * packed_width), 4]
+    img_ids = img_ids.reshape(1, packed_height * packed_width, 4)
+    img_ids = repeat(img_ids, "1 s c -> b s c", b=batch_size)
+
+    return img_ids
+
+
+class Flux2RefImageExtension:
+    """Applies FLUX.2 Klein reference image conditioning.
+
+    This extension handles encoding reference images using the FLUX.2 VAE
+    and generating the appropriate 4D position IDs for multi-reference image editing.
+
+    FLUX.2 Klein has built-in support for reference image editing, unlike FLUX.1
+    which requires a separate Kontext model.
+    """
+
+    def __init__(
+        self,
+        ref_image_conditioning: list[FluxKontextConditioningField],
+        context: InvocationContext,
+        vae_field: VAEField,
+        device: torch.device,
+        dtype: torch.dtype,
+        bn_mean: torch.Tensor | None = None,
+        bn_std: torch.Tensor | None = None,
+    ):
+        """Initialize the Flux2RefImageExtension.
+
+        Args:
+            ref_image_conditioning: List of reference image conditioning fields.
+            context: The invocation context for loading models and images.
+            vae_field: The FLUX.2 VAE field for encoding images.
+            device: Target device for tensors.
+            dtype: Target dtype for tensors.
+            bn_mean: BN running mean for normalizing latents (shape: 128).
+            bn_std: BN running std for normalizing latents (shape: 128).
+        """
+        self._context = context
+        self._device = device
+        self._dtype = dtype
+        self._vae_field = vae_field
+        self._bn_mean = bn_mean
+        self._bn_std = bn_std
+        self.ref_image_conditioning = ref_image_conditioning
+
+        # Pre-process and cache the reference image latents and ids upon initialization
+        self.ref_image_latents, self.ref_image_ids = self._prepare_ref_images()
+
+    def _bn_normalize(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply BN normalization to packed latents.
+
+        BN formula (affine=False): y = (x - mean) / std
+
+        Args:
+            x: Packed latents of shape (B, seq, 128).
+
+        Returns:
+            Normalized latents of same shape.
+        """
+        assert self._bn_mean is not None and self._bn_std is not None
+        bn_mean = self._bn_mean.to(x.device, x.dtype)
+        bn_std = self._bn_std.to(x.device, x.dtype)
+        return (x - bn_mean) / bn_std
+
+    def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encode reference images and prepare their concatenated latents and IDs with spatial tiling."""
+        all_latents = []
+        all_ids = []
+
+        # Track cumulative dimensions for spatial tiling
+        canvas_h = 0
+        canvas_w = 0
+
+        vae_info = self._context.models.load(self._vae_field.vae)
+
+        # Determine max pixels based on number of reference images (BFL FLUX.2 approach)
+        num_refs = len(self.ref_image_conditioning)
+        max_pixels = MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF
+
+        for idx, ref_image_field in enumerate(self.ref_image_conditioning):
+            image = self._context.images.get_pil(ref_image_field.image.image_name)
+            image = image.convert("RGB")
+
+            # Resize large images to max pixel count (matches BFL FLUX.2 sampling.py)
+            image = resize_image_to_max_pixels(image, max_pixels)
+
+            # Convert to tensor using torchvision transforms
+            transformation = T.Compose([T.ToTensor()])
+            image_tensor = transformation(image)
+            # Convert from [0, 1] to [-1, 1] range expected by VAE
+            image_tensor = image_tensor * 2.0 - 1.0
+            image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension
+
+            # Encode using FLUX.2 VAE
+            with vae_info.model_on_device() as (_, vae):
+                vae_dtype = next(iter(vae.parameters())).dtype
+                image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
+
+                # FLUX.2 VAE uses diffusers API
+                latent_dist = vae.encode(image_tensor, return_dict=False)[0]
+
+                # Use mode() for deterministic encoding (no sampling)
+                if hasattr(latent_dist, "mode"):
+                    ref_image_latents_unpacked = latent_dist.mode()
+                elif hasattr(latent_dist, "sample"):
+                    ref_image_latents_unpacked = latent_dist.sample()
+                else:
+                    ref_image_latents_unpacked = latent_dist
+
+            TorchDevice.empty_cache()
+
+            # Extract tensor dimensions (B, 32, H, W for FLUX.2)
+            batch_size, _, latent_height, latent_width = ref_image_latents_unpacked.shape
+
+            # Pad latents to be compatible with patch_size=2
+            pad_h = (2 - latent_height % 2) % 2
+            pad_w = (2 - latent_width % 2) % 2
+            if pad_h > 0 or pad_w > 0:
+                ref_image_latents_unpacked = F.pad(ref_image_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
+                _, _, latent_height, latent_width = ref_image_latents_unpacked.shape
+
+            # Pack the latents using FLUX.2 pack function (32 channels -> 128)
+            ref_image_latents_packed = pack_flux2(ref_image_latents_unpacked).to(self._device, self._dtype)
+
+            # Apply BN normalization to match the input latents scale
+            # This is critical - the transformer expects normalized latents
+            if self._bn_mean is not None and self._bn_std is not None:
+                ref_image_latents_packed = self._bn_normalize(ref_image_latents_packed)
+
+            # Determine spatial offsets for this reference image
+            h_offset = 0
+            w_offset = 0
+
+            if idx > 0:  # First image starts at (0, 0)
+                # Calculate potential canvas dimensions for each tiling option
+                potential_h_vertical = canvas_h + latent_height
+                potential_w_horizontal = canvas_w + latent_width
+
+                # Choose arrangement that minimizes the maximum dimension
+                if potential_h_vertical > potential_w_horizontal:
+                    # Tile horizontally (to the right)
+                    w_offset = canvas_w
+                    canvas_w = canvas_w + latent_width
+                    canvas_h = max(canvas_h, latent_height)
+                else:
+                    # Tile vertically (below)
+                    h_offset = canvas_h
+                    canvas_h = canvas_h + latent_height
+                    canvas_w = max(canvas_w, latent_width)
+            else:
+                canvas_h = latent_height
+                canvas_w = latent_width
+
+            # Generate position IDs with 4D format (T, H, W, L)
+            # Use T-coordinate offset with scale=10 like diffusers Flux2Pipeline:
+            # T = scale + scale * idx (so first ref image is T=10, second is T=20, etc.)
+            # The generated image uses T=0, so this clearly separates reference images
+            t_offset = 10 + 10 * idx  # scale=10 matches diffusers
+            ref_image_ids = generate_img_ids_flux2_with_offset(
+                latent_height=latent_height,
+                latent_width=latent_width,
+                batch_size=batch_size,
+                device=self._device,
+                idx_offset=t_offset,  # Reference images use T=10, 20, 30...
+                h_offset=h_offset,
+                w_offset=w_offset,
+            )
+
+            all_latents.append(ref_image_latents_packed)
+            all_ids.append(ref_image_ids)
+
+        # Concatenate all latents and IDs along the sequence dimension
+        concatenated_latents = torch.cat(all_latents, dim=1)
+        concatenated_ids = torch.cat(all_ids, dim=1)
+
+        return concatenated_latents, concatenated_ids
+
+    def ensure_batch_size(self, target_batch_size: int) -> None:
+        """Ensure the reference image latents and IDs match the target batch size."""
+        if self.ref_image_latents.shape[0] != target_batch_size:
+            self.ref_image_latents = self.ref_image_latents.repeat(target_batch_size, 1, 1)
+            self.ref_image_ids = self.ref_image_ids.repeat(target_batch_size, 1, 1)
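The tiling logic in `_prepare_ref_images()` ultimately reduces to giving each reference image a distinct T coordinate (10, 20, 30, …) and a non-overlapping (H, W) region on a shared canvas, while the generated image keeps T=0. A condensed sketch of the resulting 4D position-ID layout (illustrative only; toy sizes, and the offsets here are chosen by hand rather than by the canvas heuristic above):

```python
import torch

def make_ids(latent_h: int, latent_w: int, t: int, h_off: int = 0, w_off: int = 0) -> torch.Tensor:
    ph, pw = latent_h // 2, latent_w // 2  # 2x2 packing halves each spatial dimension
    ids = torch.zeros(ph, pw, 4, dtype=torch.long)  # channels are [T, H, W, L]
    ids[..., 0] = t                                        # T: 0 for the generated image, 10*k for reference k
    ids[..., 1] = torch.arange(ph)[:, None] + h_off // 2   # H, shifted by the tiling offset
    ids[..., 2] = torch.arange(pw)[None, :] + w_off // 2   # W, shifted by the tiling offset
    return ids.reshape(1, ph * pw, 4)                      # L stays 0; shape (batch=1, seq, 4)

gen_ids = make_ids(64, 64, t=0)               # generated-image canvas
ref1_ids = make_ids(64, 64, t=10)             # first reference image at the canvas origin
ref2_ids = make_ids(64, 32, t=20, w_off=64)   # second reference, tiled to the right for this example
all_img_ids = torch.cat([gen_ids, ref1_ids, ref2_ids], dim=1)  # concatenated along the sequence dim
```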