InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
@@ -219,7 +219,16 @@ async def reidentify_model(
  result = ModelConfigFactory.from_model_on_disk(mod)
  if result.config is None:
      raise InvalidModelException("Unable to identify model format")
- result.config.key = config.key # retain the same key
+
+ # Retain user-editable fields from the original config
+ result.config.key = config.key
+ result.config.name = config.name
+ result.config.description = config.description
+ result.config.cover_image = config.cover_image
+ result.config.trigger_phrases = config.trigger_phrases
+ result.config.source = config.source
+ result.config.source_type = config.source_type
+
  new_config = ApiDependencies.invoker.services.model_manager.store.replace_model(config.key, result.config)
  return new_config
  except UnknownModelException as e:
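
Illustrative sketch (not part of the diff): the hunk above replaces the single "result.config.key = config.key" carry-over with a block that preserves every user-editable field from the stored config before the re-probed config replaces it. A minimal stand-alone version of that pattern might look like the following; the helper name and SimpleNamespace stand-ins are hypothetical, not InvokeAI API.

from types import SimpleNamespace

# Fields a user may have edited and that re-identification should not clobber.
_USER_EDITABLE_FIELDS = (
    "key",
    "name",
    "description",
    "cover_image",
    "trigger_phrases",
    "source",
    "source_type",
)


def carry_over_user_fields(old_config, new_config):
    """Copy user-editable metadata from the stored config onto the re-probed one."""
    for field in _USER_EDITABLE_FIELDS:
        setattr(new_config, field, getattr(old_config, field))
    return new_config


# Toy usage with stand-in config objects:
old = SimpleNamespace(key="abc", name="My LoRA", description="tuned on my photos",
                      cover_image=None, trigger_phrases=["mylora"], source="local",
                      source_type="path")
new = SimpleNamespace(key=None, name="probed-name", description=None, cover_image=None,
                      trigger_phrases=None, source=None, source_type=None)
carry_over_user_fields(old, new)
assert new.name == "My LoRA" and new.key == "abc"
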
@@ -905,15 +914,48 @@ class StarterModelResponse(BaseModel):
  def get_is_installed(
      starter_model: StarterModel | StarterModelWithoutDependencies, installed_models: list[AnyModelConfig]
  ) -> bool:
+     from invokeai.backend.model_manager.taxonomy import ModelType
+
      for model in installed_models:
+         # Check if source matches exactly
          if model.source == starter_model.source:
              return True
+         # Check if name (or previous names), base and type match
          if (
              (model.name == starter_model.name or model.name in starter_model.previous_names)
              and model.base == starter_model.base
              and model.type == starter_model.type
          ):
              return True
+
+     # Special handling for Qwen3Encoder models - check by type and variant
+     # This allows renamed models to still be detected as installed
+     if starter_model.type == ModelType.Qwen3Encoder:
+         from invokeai.backend.model_manager.taxonomy import Qwen3VariantType
+
+         # Determine expected variant from source pattern
+         expected_variant: Qwen3VariantType | None = None
+         if "klein-9B" in starter_model.source or "qwen3_8b" in starter_model.source.lower():
+             expected_variant = Qwen3VariantType.Qwen3_8B
+         elif (
+             "klein-4B" in starter_model.source
+             or "qwen3_4b" in starter_model.source.lower()
+             or "Z-Image" in starter_model.source
+         ):
+             expected_variant = Qwen3VariantType.Qwen3_4B
+
+         if expected_variant is not None:
+             for model in installed_models:
+                 if model.type == ModelType.Qwen3Encoder and hasattr(model, "variant"):
+                     model_variant = model.variant
+                     # Handle both enum and string values
+                     if isinstance(model_variant, Qwen3VariantType):
+                         if model_variant == expected_variant:
+                             return True
+                     elif isinstance(model_variant, str):
+                         if model_variant == expected_variant.value:
+                             return True
+
      return False

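
Illustrative sketch (not part of the diff): the Qwen3Encoder branch above infers which encoder variant a starter model's source string refers to, then treats any installed Qwen3 encoder of that variant as satisfying the starter entry, even if the model was renamed. A self-contained version of the inference step follows; the enum values and example source strings are stand-ins, while the real Qwen3VariantType lives in invokeai.backend.model_manager.taxonomy.

from enum import Enum
from typing import Optional


class Qwen3Variant(str, Enum):  # stand-in for Qwen3VariantType
    QWEN3_8B = "qwen3_8b"
    QWEN3_4B = "qwen3_4b"


def infer_variant_from_source(source: str) -> Optional[Qwen3Variant]:
    """Guess the Qwen3 encoder variant implied by a starter-model source string."""
    if "klein-9B" in source or "qwen3_8b" in source.lower():
        return Qwen3Variant.QWEN3_8B
    if "klein-4B" in source or "qwen3_4b" in source.lower() or "Z-Image" in source:
        return Qwen3Variant.QWEN3_4B
    return None


# Hypothetical source strings, for illustration only:
assert infer_variant_from_source("example/flux2-klein-9B-text-encoder") is Qwen3Variant.QWEN3_8B
assert infer_variant_from_source("example/Z-Image-qwen3-encoder") is Qwen3Variant.QWEN3_4B
assert infer_variant_from_source("example/unrelated-model") is None
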
@@ -532,7 +532,7 @@ def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, An
  case UIType.VAEModel:
      ui_model_type = [ModelType.VAE]
  case UIType.FluxVAEModel:
-     ui_model_base = [BaseModelType.Flux]
+     ui_model_base = [BaseModelType.Flux, BaseModelType.Flux2]
      ui_model_type = [ModelType.VAE]
  case UIType.LoRAModel:
      ui_model_type = [ModelType.LoRA]
@@ -0,0 +1,499 @@
+ """Flux2 Klein Denoise Invocation.
+
+ Run denoising process with a FLUX.2 Klein transformer model.
+ Uses Qwen3 conditioning instead of CLIP+T5.
+ """
+
+ from contextlib import ExitStack
+ from typing import Callable, Iterator, Optional, Tuple
+
+ import torch
+ import torchvision.transforms as tv_transforms
+ from torchvision.transforms.functional import resize as tv_resize
+
+ from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+ from invokeai.app.invocations.fields import (
+     DenoiseMaskField,
+     FieldDescriptions,
+     FluxConditioningField,
+     FluxKontextConditioningField,
+     Input,
+     InputField,
+     LatentsField,
+ )
+ from invokeai.app.invocations.model import TransformerField, VAEField
+ from invokeai.app.invocations.primitives import LatentsOutput
+ from invokeai.app.services.shared.invocation_context import InvocationContext
+ from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
+ from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
+ from invokeai.backend.flux2.denoise import denoise
+ from invokeai.backend.flux2.ref_image_extension import Flux2RefImageExtension
+ from invokeai.backend.flux2.sampling_utils import (
+     compute_empirical_mu,
+     generate_img_ids_flux2,
+     get_noise_flux2,
+     get_schedule_flux2,
+     pack_flux2,
+     unpack_flux2,
+ )
+ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
+ from invokeai.backend.patches.layer_patcher import LayerPatcher
+ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+ from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
+ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
+ from invokeai.backend.util.devices import TorchDevice
+
+
+ @invocation(
+     "flux2_denoise",
+     title="FLUX2 Denoise",
+     tags=["image", "flux", "flux2", "klein", "denoise"],
+     category="image",
+     version="1.3.0",
+     classification=Classification.Prototype,
+ )
+ class Flux2DenoiseInvocation(BaseInvocation):
+     """Run denoising process with a FLUX.2 Klein transformer model.
+
+     This node is designed for FLUX.2 Klein models which use Qwen3 as the text encoder.
+     It does not support ControlNet, IP-Adapters, or regional prompting.
+     """
+
+     latents: Optional[LatentsField] = InputField(
+         default=None,
+         description=FieldDescriptions.latents,
+         input=Input.Connection,
+     )
+     denoise_mask: Optional[DenoiseMaskField] = InputField(
+         default=None,
+         description=FieldDescriptions.denoise_mask,
+         input=Input.Connection,
+     )
+     denoising_start: float = InputField(
+         default=0.0,
+         ge=0,
+         le=1,
+         description=FieldDescriptions.denoising_start,
+     )
+     denoising_end: float = InputField(
+         default=1.0,
+         ge=0,
+         le=1,
+         description=FieldDescriptions.denoising_end,
+     )
+     add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
+     transformer: TransformerField = InputField(
+         description=FieldDescriptions.flux_model,
+         input=Input.Connection,
+         title="Transformer",
+     )
+     positive_text_conditioning: FluxConditioningField = InputField(
+         description=FieldDescriptions.positive_cond,
+         input=Input.Connection,
+     )
+     negative_text_conditioning: Optional[FluxConditioningField] = InputField(
+         default=None,
+         description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
+         input=Input.Connection,
+     )
+     cfg_scale: float = InputField(
+         default=1.0,
+         description=FieldDescriptions.cfg_scale,
+         title="CFG Scale",
+     )
+     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
+     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
+     num_steps: int = InputField(
+         default=4,
+         description="Number of diffusion steps. Use 4 for distilled models, 28+ for base models.",
+     )
+     scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
+         default="euler",
+         description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
+         "'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
+         ui_choice_labels=FLUX_SCHEDULER_LABELS,
+     )
+     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+     vae: VAEField = InputField(
+         description="FLUX.2 VAE model (required for BN statistics).",
+         input=Input.Connection,
+     )
+     kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
+         default=None,
+         description="FLUX Kontext conditioning (reference images for multi-reference image editing).",
+         input=Input.Connection,
+         title="Reference Images",
+     )
+
+     def _get_bn_stats(self, context: InvocationContext) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
+         """Extract BN statistics from the FLUX.2 VAE.
+
+         The FLUX.2 VAE uses batch normalization on the patchified 128-channel representation.
+         IMPORTANT: BFL FLUX.2 VAE uses affine=False, so there are NO learnable weight/bias.
+
+         BN formula (affine=False): y = (x - mean) / std
+         Inverse: x = y * std + mean
+
+         Returns:
+             Tuple of (bn_mean, bn_std) tensors of shape (128,), or None if BN layer not found.
+         """
+         with context.models.load(self.vae.vae).model_on_device() as (_, vae):
+             # Ensure VAE is in eval mode to prevent BN stats from being updated
+             vae.eval()
+
+             # Try to find the BN layer - it may be at different locations depending on model format
+             bn_layer = None
+             if hasattr(vae, "bn"):
+                 bn_layer = vae.bn
+             elif hasattr(vae, "batch_norm"):
+                 bn_layer = vae.batch_norm
+             elif hasattr(vae, "encoder") and hasattr(vae.encoder, "bn"):
+                 bn_layer = vae.encoder.bn
+
+             if bn_layer is None:
+                 return None
+
+             # Verify running statistics are initialized
+             if bn_layer.running_mean is None or bn_layer.running_var is None:
+                 return None
+
+             # Get BN running statistics from VAE
+             bn_mean = bn_layer.running_mean.clone() # Shape: (128,)
+             bn_var = bn_layer.running_var.clone() # Shape: (128,)
+             bn_eps = bn_layer.eps if hasattr(bn_layer, "eps") else 1e-4 # BFL uses 1e-4
+             bn_std = torch.sqrt(bn_var + bn_eps)
+
+             return bn_mean, bn_std
+
+     def _bn_normalize(
+         self,
+         x: torch.Tensor,
+         bn_mean: torch.Tensor,
+         bn_std: torch.Tensor,
+     ) -> torch.Tensor:
+         """Apply BN normalization to packed latents.
+
+         BN formula (affine=False): y = (x - mean) / std
+
+         Args:
+             x: Packed latents of shape (B, seq, 128).
+             bn_mean: BN running mean of shape (128,).
+             bn_std: BN running std of shape (128,).
+
+         Returns:
+             Normalized latents of same shape.
+         """
+         # x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
+         bn_mean = bn_mean.to(x.device, x.dtype)
+         bn_std = bn_std.to(x.device, x.dtype)
+         return (x - bn_mean) / bn_std
+
+     def _bn_denormalize(
+         self,
+         x: torch.Tensor,
+         bn_mean: torch.Tensor,
+         bn_std: torch.Tensor,
+     ) -> torch.Tensor:
+         """Apply BN denormalization to packed latents (inverse of normalization).
+
+         Inverse BN (affine=False): x = y * std + mean
+
+         Args:
+             x: Packed latents of shape (B, seq, 128).
+             bn_mean: BN running mean of shape (128,).
+             bn_std: BN running std of shape (128,).
+
+         Returns:
+             Denormalized latents of same shape.
+         """
+         # x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
+         bn_mean = bn_mean.to(x.device, x.dtype)
+         bn_std = bn_std.to(x.device, x.dtype)
+         return x * bn_std + bn_mean
+
+     @torch.no_grad()
+     def invoke(self, context: InvocationContext) -> LatentsOutput:
+         latents = self._run_diffusion(context)
+         latents = latents.detach().to("cpu")
+
+         name = context.tensors.save(tensor=latents)
+         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
+
+     def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
+         inference_dtype = torch.bfloat16
+         device = TorchDevice.choose_torch_device()
+
+         # Get BN statistics from VAE for latent denormalization (optional)
+         # BFL FLUX.2 VAE uses affine=False, so only mean/std are needed
+         # Some VAE formats (e.g. diffusers) may not expose BN stats directly
+         bn_stats = self._get_bn_stats(context)
+         bn_mean, bn_std = bn_stats if bn_stats is not None else (None, None)
+
+         # Load the input latents, if provided
+         init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
+         if init_latents is not None:
+             init_latents = init_latents.to(device=device, dtype=inference_dtype)
+
+         # Prepare input noise (FLUX.2 uses 32 channels)
+         noise = get_noise_flux2(
+             num_samples=1,
+             height=self.height,
+             width=self.width,
+             device=device,
+             dtype=inference_dtype,
+             seed=self.seed,
+         )
+         b, _c, latent_h, latent_w = noise.shape
+         packed_h = latent_h // 2
+         packed_w = latent_w // 2
+
+         # Load the conditioning data
+         pos_cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+         assert len(pos_cond_data.conditionings) == 1
+         pos_flux_conditioning = pos_cond_data.conditionings[0]
+         assert isinstance(pos_flux_conditioning, FLUXConditioningInfo)
+         pos_flux_conditioning = pos_flux_conditioning.to(dtype=inference_dtype, device=device)
+
+         # Qwen3 stacked embeddings (stored in t5_embeds field for compatibility)
+         txt = pos_flux_conditioning.t5_embeds
+
+         # Generate text position IDs (4D format for FLUX.2: T, H, W, L)
+         # FLUX.2 uses 4D position coordinates for its rotary position embeddings
+         # IMPORTANT: Position IDs must be int64 (long) dtype
+         # Diffusers uses: T=0, H=0, W=0, L=0..seq_len-1
+         seq_len = txt.shape[1]
+         txt_ids = torch.zeros(1, seq_len, 4, device=device, dtype=torch.long)
+         txt_ids[..., 3] = torch.arange(seq_len, device=device, dtype=torch.long) # L coordinate varies
+
+         # Load negative conditioning if provided
+         neg_txt = None
+         neg_txt_ids = None
+         if self.negative_text_conditioning is not None:
+             neg_cond_data = context.conditioning.load(self.negative_text_conditioning.conditioning_name)
+             assert len(neg_cond_data.conditionings) == 1
+             neg_flux_conditioning = neg_cond_data.conditionings[0]
+             assert isinstance(neg_flux_conditioning, FLUXConditioningInfo)
+             neg_flux_conditioning = neg_flux_conditioning.to(dtype=inference_dtype, device=device)
+             neg_txt = neg_flux_conditioning.t5_embeds
+             # For text tokens: T=0, H=0, W=0, L=0..seq_len-1 (only L varies per token)
+             neg_seq_len = neg_txt.shape[1]
+             neg_txt_ids = torch.zeros(1, neg_seq_len, 4, device=device, dtype=torch.long)
+             neg_txt_ids[..., 3] = torch.arange(neg_seq_len, device=device, dtype=torch.long)
+
+         # Validate transformer config
+         transformer_config = context.models.get_config(self.transformer.transformer)
+         assert transformer_config.base == BaseModelType.Flux2 and transformer_config.type == ModelType.Main
+
+         # Calculate the timestep schedule using FLUX.2 specific schedule
+         # This matches diffusers' Flux2Pipeline implementation
+         # Note: Schedule shifting is handled by the scheduler via mu parameter
+         image_seq_len = packed_h * packed_w
+         timesteps = get_schedule_flux2(
+             num_steps=self.num_steps,
+             image_seq_len=image_seq_len,
+         )
+         # Compute mu for dynamic schedule shifting (used by FlowMatchEulerDiscreteScheduler)
+         mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=self.num_steps)
+
+         # Clip the timesteps schedule based on denoising_start and denoising_end
+         timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
+
+         # Prepare input latent image
+         if init_latents is not None:
+             if self.add_noise:
+                 t_0 = timesteps[0]
+                 x = t_0 * noise + (1.0 - t_0) * init_latents
+             else:
+                 x = init_latents
+         else:
+             if self.denoising_start > 1e-5:
+                 raise ValueError("denoising_start should be 0 when initial latents are not provided.")
+             x = noise
+
+         # If len(timesteps) == 1, then short-circuit
+         if len(timesteps) <= 1:
+             return x
+
+         # Generate image position IDs (FLUX.2 uses 4D coordinates)
+         # Position IDs use int64 dtype like diffusers
+         img_ids = generate_img_ids_flux2(h=latent_h, w=latent_w, batch_size=b, device=device)
+
+         # Prepare inpaint mask
+         inpaint_mask = self._prep_inpaint_mask(context, x)
+
+         # Pack all latent tensors
+         init_latents_packed = pack_flux2(init_latents) if init_latents is not None else None
+         inpaint_mask_packed = pack_flux2(inpaint_mask) if inpaint_mask is not None else None
+         noise_packed = pack_flux2(noise)
+         x = pack_flux2(x)
+
+         # Apply BN normalization BEFORE denoising (as per diffusers Flux2KleinPipeline)
+         # BN normalization: y = (x - mean) / std
+         # This transforms latents to normalized space for the transformer
+         # IMPORTANT: Also normalize init_latents and noise for inpainting to maintain consistency
+         if bn_mean is not None and bn_std is not None:
+             x = self._bn_normalize(x, bn_mean, bn_std)
+             if init_latents_packed is not None:
+                 init_latents_packed = self._bn_normalize(init_latents_packed, bn_mean, bn_std)
+             noise_packed = self._bn_normalize(noise_packed, bn_mean, bn_std)
+
+         # Verify packed dimensions
+         assert packed_h * packed_w == x.shape[1]
+
+         # Prepare inpaint extension
+         inpaint_extension: Optional[RectifiedFlowInpaintExtension] = None
+         if inpaint_mask_packed is not None:
+             assert init_latents_packed is not None
+             inpaint_extension = RectifiedFlowInpaintExtension(
+                 init_latents=init_latents_packed,
+                 inpaint_mask=inpaint_mask_packed,
+                 noise=noise_packed,
+             )
+
+         # Prepare CFG scale list
+         num_steps = len(timesteps) - 1
+         cfg_scale_list = [self.cfg_scale] * num_steps
+
+         # Check if we're doing inpainting (have a mask or a clipped schedule)
+         is_inpainting = self.denoise_mask is not None or self.denoising_start > 1e-5
+
+         # Create scheduler with FLUX.2 Klein configuration
+         # For inpainting/img2img, use manual Euler stepping to preserve the exact timestep schedule
+         # For txt2img, use the scheduler with dynamic shifting for optimal results
+         scheduler = None
+         if self.scheduler in FLUX_SCHEDULER_MAP and not is_inpainting:
+             # Only use scheduler for txt2img - use manual Euler for inpainting to preserve exact timesteps
+             scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
+             scheduler = scheduler_class(
+                 num_train_timesteps=1000,
+                 shift=3.0,
+                 use_dynamic_shifting=True,
+                 base_shift=0.5,
+                 max_shift=1.15,
+                 base_image_seq_len=256,
+                 max_image_seq_len=4096,
+                 time_shift_type="exponential",
+             )
+
+         # Prepare reference image extension for FLUX.2 Klein built-in editing
+         ref_image_extension = None
+         if self.kontext_conditioning:
+             ref_image_extension = Flux2RefImageExtension(
+                 context=context,
+                 ref_image_conditioning=self.kontext_conditioning
+                 if isinstance(self.kontext_conditioning, list)
+                 else [self.kontext_conditioning],
+                 vae_field=self.vae,
+                 device=device,
+                 dtype=inference_dtype,
+                 bn_mean=bn_mean,
+                 bn_std=bn_std,
+             )
+
+         with ExitStack() as exit_stack:
+             # Load the transformer model
+             (cached_weights, transformer) = exit_stack.enter_context(
+                 context.models.load(self.transformer.transformer).model_on_device()
+             )
+             config = transformer_config
+
+             # Determine if the model is quantized
+             if config.format in [ModelFormat.Diffusers]:
+                 model_is_quantized = False
+             elif config.format in [
+                 ModelFormat.BnbQuantizedLlmInt8b,
+                 ModelFormat.BnbQuantizednf4b,
+                 ModelFormat.GGUFQuantized,
+             ]:
+                 model_is_quantized = True
+             else:
+                 model_is_quantized = False
+
+             # Apply LoRA models to the transformer
+             exit_stack.enter_context(
+                 LayerPatcher.apply_smart_model_patches(
+                     model=transformer,
+                     patches=self._lora_iterator(context),
+                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
+                     dtype=inference_dtype,
+                     cached_weights=cached_weights,
+                     force_sidecar_patching=model_is_quantized,
+                 )
+             )
+
+             # Prepare reference image conditioning if provided
+             img_cond_seq = None
+             img_cond_seq_ids = None
+             if ref_image_extension is not None:
+                 # Ensure batch sizes match
+                 ref_image_extension.ensure_batch_size(x.shape[0])
+                 img_cond_seq, img_cond_seq_ids = (
+                     ref_image_extension.ref_image_latents,
+                     ref_image_extension.ref_image_ids,
+                 )
+
+             x = denoise(
+                 model=transformer,
+                 img=x,
+                 img_ids=img_ids,
+                 txt=txt,
+                 txt_ids=txt_ids,
+                 timesteps=timesteps,
+                 step_callback=self._build_step_callback(context),
+                 cfg_scale=cfg_scale_list,
+                 neg_txt=neg_txt,
+                 neg_txt_ids=neg_txt_ids,
+                 scheduler=scheduler,
+                 mu=mu,
+                 inpaint_extension=inpaint_extension,
+                 img_cond_seq=img_cond_seq,
+                 img_cond_seq_ids=img_cond_seq_ids,
+             )
+
+         # Apply BN denormalization if BN stats are available
+         # The diffusers Flux2KleinPipeline applies: latents = latents * bn_std + bn_mean
+         # This transforms latents from normalized space to VAE's expected input space
+         if bn_mean is not None and bn_std is not None:
+             x = self._bn_denormalize(x, bn_mean, bn_std)
+
+         x = unpack_flux2(x.float(), self.height, self.width)
+         return x
+
+     def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> Optional[torch.Tensor]:
+         """Prepare the inpaint mask."""
+         if self.denoise_mask is None:
+             return None
+
+         mask = context.tensors.load(self.denoise_mask.mask_name)
+         mask = 1.0 - mask
+
+         _, _, latent_height, latent_width = latents.shape
+         mask = tv_resize(
+             img=mask,
+             size=[latent_height, latent_width],
+             interpolation=tv_transforms.InterpolationMode.BILINEAR,
+             antialias=False,
+         )
+
+         mask = mask.to(device=latents.device, dtype=latents.dtype)
+         return mask.expand_as(latents)
+
+     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+         """Iterate over LoRA models to apply."""
+         for lora in self.transformer.loras:
+             lora_info = context.models.load(lora.lora)
+             assert isinstance(lora_info.model, ModelPatchRaw)
+             yield (lora_info.model, lora.weight)
+             del lora_info
+
+     def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+         """Build a callback for step progress updates."""
+
+         def step_callback(state: PipelineIntermediateState) -> None:
+             latents = state.latents.float()
+             state.latents = unpack_flux2(latents, self.height, self.width).squeeze()
+             context.util.flux2_step_callback(state)
+
+         return step_callback
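
Illustrative sketch (not part of the diff): the BN handling in this new node follows the formulas spelled out in the docstrings above. Packed latents are normalized with the VAE's running statistics before denoising (y = (x - mean) / std) and denormalized afterwards (x = y * std + mean); with affine=False there is no learned weight or bias to apply. Below is a minimal round-trip with synthetic statistics, assuming only torch; in the node itself the statistics come from the FLUX.2 VAE's BatchNorm layer.

import torch

bn = torch.nn.BatchNorm1d(128, eps=1e-4, affine=False)
bn.eval()  # use running statistics; never update them at inference time
# Pretend these running stats were accumulated during VAE training.
bn.running_mean.uniform_(-0.5, 0.5)
bn.running_var.uniform_(0.5, 1.5)

bn_mean = bn.running_mean.clone()
bn_std = torch.sqrt(bn.running_var + bn.eps)

x = torch.randn(2, 4096, 128)   # packed latents: (B, seq, 128)
y = (x - bn_mean) / bn_std      # normalize before the transformer
x_back = y * bn_std + bn_mean   # denormalize before VAE decode

assert torch.allclose(x, x_back, atol=1e-5)
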