InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (83)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/extensions/dype_extension.py (new file)
@@ -0,0 +1,91 @@
+"""DyPE extension for FLUX denoising pipeline."""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from invokeai.backend.flux.dype.base import DyPEConfig
+from invokeai.backend.flux.dype.embed import DyPEEmbedND
+
+if TYPE_CHECKING:
+    from invokeai.backend.flux.model import Flux
+
+
+@dataclass
+class DyPEExtension:
+    """Extension for Dynamic Position Extrapolation in FLUX models.
+
+    This extension manages the patching of the FLUX model's position embedder
+    and updates the step state during denoising.
+
+    Usage:
+        1. Create extension with config and target dimensions
+        2. Call patch_model() to replace pe_embedder with DyPE version
+        3. Call update_step_state() before each denoising step
+        4. Call restore_model() after denoising to restore original embedder
+    """
+
+    config: DyPEConfig
+    target_height: int
+    target_width: int
+
+    def patch_model(self, model: "Flux") -> tuple[DyPEEmbedND, object]:
+        """Patch the model's position embedder with DyPE version.
+
+        Args:
+            model: The FLUX model to patch
+
+        Returns:
+            Tuple of (new DyPE embedder, original embedder for restoration)
+        """
+        original_embedder = model.pe_embedder
+
+        dype_embedder = DyPEEmbedND.from_embednd(
+            embed_nd=original_embedder,
+            dype_config=self.config,
+        )
+
+        # Set initial state
+        dype_embedder.set_step_state(
+            sigma=1.0,
+            height=self.target_height,
+            width=self.target_width,
+        )
+
+        # Replace the embedder
+        model.pe_embedder = dype_embedder
+
+        return dype_embedder, original_embedder
+
+    def update_step_state(
+        self,
+        embedder: DyPEEmbedND,
+        timestep: float,
+        timestep_index: int,
+        total_steps: int,
+    ) -> None:
+        """Update the step state in the DyPE embedder.
+
+        This should be called before each denoising step to update the
+        current noise level for timestep-dependent scaling.
+
+        Args:
+            embedder: The DyPE embedder to update
+            timestep: Current timestep value (sigma/noise level)
+            timestep_index: Current step index (0-based)
+            total_steps: Total number of denoising steps
+        """
+        embedder.set_step_state(
+            sigma=timestep,
+            height=self.target_height,
+            width=self.target_width,
+        )
+
+    @staticmethod
+    def restore_model(model: "Flux", original_embedder: object) -> None:
+        """Restore the original position embedder.
+
+        Args:
+            model: The FLUX model to restore
+            original_embedder: The original embedder saved from patch_model()
+        """
+        model.pe_embedder = original_embedder
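
The Usage notes in the docstring above describe a patch → step-update → restore lifecycle. Here is a minimal sketch of how a caller could drive it, assuming a loaded Flux model (flux_model), a DyPEConfig instance (dype_config), and a list of sigmas are already in hand; those names are illustrative and not part of the diff.

    # Hypothetical usage sketch (not part of the diff)
    extension = DyPEExtension(config=dype_config, target_height=2048, target_width=2048)
    dype_embedder, original_embedder = extension.patch_model(flux_model)
    try:
        for i, sigma in enumerate(sigmas):
            extension.update_step_state(
                embedder=dype_embedder,
                timestep=sigma,
                timestep_index=i,
                total_steps=len(sigmas),
            )
            # ... run one denoising step with flux_model here ...
    finally:
        DyPEExtension.restore_model(flux_model, original_embedder)
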
invokeai/backend/flux/schedulers.py (new file)
@@ -0,0 +1,62 @@
+"""Flow Matching scheduler definitions and mapping.
+
+This module provides the scheduler types and mapping for Flow Matching models
+(Flux and Z-Image), supporting multiple schedulers from the diffusers library.
+"""
+
+from typing import Literal, Type
+
+from diffusers import (
+    FlowMatchEulerDiscreteScheduler,
+    FlowMatchHeunDiscreteScheduler,
+)
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+# Note: FlowMatchLCMScheduler may not be available in all diffusers versions
+try:
+    from diffusers import FlowMatchLCMScheduler
+
+    _HAS_LCM = True
+except ImportError:
+    _HAS_LCM = False
+
+# Scheduler name literal type for type checking
+FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+# Human-readable labels for the UI
+FLUX_SCHEDULER_LABELS: dict[str, str] = {
+    "euler": "Euler",
+    "heun": "Heun (2nd order)",
+    "lcm": "LCM",
+}
+
+# Mapping from scheduler names to scheduler classes
+FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+    "euler": FlowMatchEulerDiscreteScheduler,
+    "heun": FlowMatchHeunDiscreteScheduler,
+}
+
+if _HAS_LCM:
+    FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
+
+
+# Z-Image scheduler types (same schedulers as Flux, both use Flow Matching)
+# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
+# can be used for experimentation.
+ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+# Human-readable labels for the UI
+ZIMAGE_SCHEDULER_LABELS: dict[str, str] = {
+    "euler": "Euler",
+    "heun": "Heun (2nd order)",
+    "lcm": "LCM",
+}
+
+# Mapping from scheduler names to scheduler classes (same as Flux)
+ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+    "euler": FlowMatchEulerDiscreteScheduler,
+    "heun": FlowMatchHeunDiscreteScheduler,
+}
+
+if _HAS_LCM:
+    ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
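
The maps above are keyed by the scheduler names exposed in the UI. A small illustrative helper (not InvokeAI code) shows how a scheduler class could be resolved by name, falling back to Euler when the requested entry (e.g. "lcm" on an older diffusers release) is absent.

    # Illustrative helper, assuming FLUX_SCHEDULER_MAP from the module above
    from diffusers import FlowMatchEulerDiscreteScheduler

    def get_flux_scheduler(name: str):
        # Fall back to Euler when the name is unknown or LCM is unavailable
        scheduler_cls = FLUX_SCHEDULER_MAP.get(name, FlowMatchEulerDiscreteScheduler)
        return scheduler_cls()

    scheduler = get_flux_scheduler("heun")  # -> FlowMatchHeunDiscreteScheduler instance
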
invokeai/backend/flux/util.py
@@ -5,7 +5,7 @@ from typing import Literal
 
 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
-from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
+from invokeai.backend.model_manager.taxonomy import AnyVariant, Flux2VariantType, FluxVariantType
 
 
 @dataclass
@@ -46,6 +46,8 @@ _flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
     FluxVariantType.Dev: 512,
     FluxVariantType.DevFill: 512,
     FluxVariantType.Schnell: 256,
+    Flux2VariantType.Klein4B: 512,
+    Flux2VariantType.Klein9B: 512,
 }
 
 
@@ -117,6 +119,38 @@ _flux_transformer_params: dict[AnyVariant, FluxParams] = {
         qkv_bias=True,
         guidance_embed=True,
     ),
+    # Flux2 Klein 4B uses Qwen3 4B text encoder with stacked embeddings from layers [9, 18, 27]
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 2560 = 7680)
+    Flux2VariantType.Klein4B: FluxParams(
+        in_channels=64,
+        vec_in_dim=2560,  # Qwen3-4B hidden size (used for pooled output)
+        context_in_dim=7680,  # 3 layers * 2560 = 7680 for Qwen3-4B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
+    # Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
+    Flux2VariantType.Klein9B: FluxParams(
+        in_channels=64,
+        vec_in_dim=4096,  # Qwen3-8B hidden size (used for pooled output)
+        context_in_dim=12288,  # 3 layers * 4096 = 12288 for Qwen3-8B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
 }
 
 
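
The context_in_dim arithmetic in the comments above follows from concatenating the Qwen3 hidden states of three layers along the feature dimension. A shape-level sketch under that assumption (an illustration, not code from the diff):

    import torch

    hidden_size = 2560  # Qwen3-4B; 4096 for Qwen3-8B
    batch, seq_len = 1, 77
    # One hidden-state tensor per tapped layer [9, 18, 27]
    layer_outputs = [torch.randn(batch, seq_len, hidden_size) for _ in range(3)]
    stacked = torch.cat(layer_outputs, dim=-1)
    assert stacked.shape[-1] == 3 * hidden_size  # 7680 -> context_in_dim for Klein 4B
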
invokeai/backend/flux2/__init__.py (new file)
@@ -0,0 +1,4 @@
+"""FLUX.2 backend modules.
+
+This package contains modules specific to FLUX.2 models (e.g., Klein).
+"""
invokeai/backend/flux2/denoise.py (new file)
@@ -0,0 +1,280 @@
+"""Flux2 Klein Denoising Function.
+
+This module provides the denoising function for FLUX.2 Klein models,
+which use Qwen3 as the text encoder instead of CLIP+T5.
+"""
+
+import math
+from typing import Any, Callable
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+
+
+def denoise(
+    model: torch.nn.Module,
+    # model input
+    img: torch.Tensor,
+    img_ids: torch.Tensor,
+    txt: torch.Tensor,
+    txt_ids: torch.Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    step_callback: Callable[[PipelineIntermediateState], None],
+    cfg_scale: list[float],
+    # Negative conditioning for CFG
+    neg_txt: torch.Tensor | None = None,
+    neg_txt_ids: torch.Tensor | None = None,
+    # Scheduler for stepping (e.g., FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler)
+    scheduler: Any = None,
+    # Dynamic shifting parameter for FLUX.2 Klein (computed from image resolution)
+    mu: float | None = None,
+    # Inpainting extension for merging latents during denoising
+    inpaint_extension: RectifiedFlowInpaintExtension | None = None,
+    # Reference image conditioning (multi-reference image editing)
+    img_cond_seq: torch.Tensor | None = None,
+    img_cond_seq_ids: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """Denoise latents using a FLUX.2 Klein transformer model.
+
+    This is a simplified denoise function for FLUX.2 Klein models that uses
+    the diffusers Flux2Transformer2DModel interface.
+
+    Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
+    CFG is applied externally using negative conditioning when cfg_scale != 1.0.
+
+    Args:
+        model: The Flux2Transformer2DModel from diffusers.
+        img: Packed latent image tensor of shape (B, seq_len, channels).
+        img_ids: Image position IDs tensor.
+        txt: Text encoder hidden states (Qwen3 embeddings).
+        txt_ids: Text position IDs tensor.
+        timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
+        step_callback: Callback function for progress updates.
+        cfg_scale: List of CFG scale values per step.
+        neg_txt: Negative text embeddings for CFG (optional).
+        neg_txt_ids: Negative text position IDs (optional).
+        scheduler: Optional diffusers scheduler (Euler, Heun, LCM). If None, uses manual Euler.
+        mu: Dynamic shifting parameter computed from image resolution. Required when scheduler
+            has use_dynamic_shifting=True.
+
+    Returns:
+        Denoised latent tensor.
+    """
+    total_steps = len(timesteps) - 1
+
+    # Store original sequence length for extracting output later (before concatenating reference images)
+    original_seq_len = img.shape[1]
+
+    # Concatenate reference image conditioning if provided (multi-reference image editing)
+    if img_cond_seq is not None and img_cond_seq_ids is not None:
+        img = torch.cat([img, img_cond_seq], dim=1)
+        img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)
+
+    # Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
+    # We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
+    guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
+
+    # Use scheduler if provided
+    use_scheduler = scheduler is not None
+    if use_scheduler:
+        # Set up scheduler with sigmas and mu for dynamic shifting
+        # Convert timesteps (0-1 range) to sigmas for the scheduler
+        # The scheduler will apply dynamic shifting internally using mu (if enabled in scheduler config)
+        sigmas = np.array(timesteps[:-1], dtype=np.float32)  # Exclude final 0.0
+
+        # Pass mu if provided - it will only be used if scheduler has use_dynamic_shifting=True
+        if mu is not None:
+            scheduler.set_timesteps(sigmas=sigmas.tolist(), mu=mu, device=img.device)
+        else:
+            scheduler.set_timesteps(sigmas=sigmas.tolist(), device=img.device)
+        num_scheduler_steps = len(scheduler.timesteps)
+        is_heun = hasattr(scheduler, "state_in_first_order")
+        user_step = 0
+
+        pbar = tqdm(total=total_steps, desc="Denoising")
+        for step_index in range(num_scheduler_steps):
+            timestep = scheduler.timesteps[step_index]
+            # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # Track if we're in first or second order step (for Heun)
+            in_first_order = scheduler.state_in_first_order if is_heun else True
+
+            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
+            output = model(
+                hidden_states=img,
+                encoder_hidden_states=txt,
+                timestep=t_vec,
+                img_ids=img_ids,
+                txt_ids=txt_ids,
+                guidance=guidance,
+                return_dict=False,
+            )
+
+            # Extract the sample from the output (return_dict=False returns tuple)
+            pred = output[0] if isinstance(output, tuple) else output
+
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            # Apply CFG if scale is not 1.0
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_txt is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_output = model(
+                    hidden_states=img,
+                    encoder_hidden_states=neg_txt,
+                    timestep=t_vec,
+                    img_ids=img_ids,
+                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
+                    guidance=guidance,
+                    return_dict=False,
+                )
+
+                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Use scheduler.step() for the update
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
+
+            # Get t_prev for inpainting (next sigma value)
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            # Apply inpainting merge at each step
+            if inpaint_extension is not None:
+                # Separate the generated latents from the reference conditioning
+                gen_img = img[:, :original_seq_len, :]
+                ref_img = img[:, original_seq_len:, :]
+
+                # Merge only the generated part
+                gen_img = inpaint_extension.merge_intermediate_latents_with_init_latents(gen_img, t_prev)
+
+                # Concatenate back together
+                img = torch.cat([gen_img, ref_img], dim=1)
+
+            # For Heun, only increment user step after second-order step completes
+            if is_heun:
+                if not in_first_order:
+                    user_step += 1
+                    if user_step <= total_steps:
+                        pbar.update(1)
+                    preview_img = img - t_curr * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+                            preview_img, 0.0
+                        )
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=2,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=preview_img,
+                        ),
+                    )
+            else:
+                user_step += 1
+                if user_step <= total_steps:
+                    pbar.update(1)
+                preview_img = img - t_curr * pred
+                if inpaint_extension is not None:
+                    preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                # Extract only the generated image portion for preview (exclude reference images)
+                callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
+                step_callback(
+                    PipelineIntermediateState(
+                        step=user_step,
+                        order=1,
+                        total_steps=total_steps,
+                        timestep=int(t_curr * 1000),
+                        latents=callback_latents,
+                    ),
+                )
+
+        pbar.close()
+    else:
+        # Manual Euler stepping (original behavior)
+        for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
+            output = model(
+                hidden_states=img,
+                encoder_hidden_states=txt,
+                timestep=t_vec,
+                img_ids=img_ids,
+                txt_ids=txt_ids,
+                guidance=guidance,
+                return_dict=False,
+            )
+
+            # Extract the sample from the output (return_dict=False returns tuple)
+            pred = output[0] if isinstance(output, tuple) else output
+
+            step_cfg_scale = cfg_scale[step_index]
+
+            # Apply CFG if scale is not 1.0
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_txt is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_output = model(
+                    hidden_states=img,
+                    encoder_hidden_states=neg_txt,
+                    timestep=t_vec,
+                    img_ids=img_ids,
+                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
+                    guidance=guidance,
+                    return_dict=False,
+                )
+
+                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Euler step
+            preview_img = img - t_curr * pred
+            img = img + (t_prev - t_curr) * pred
+
+            # Apply inpainting merge at each step
+            if inpaint_extension is not None:
+                # Separate the generated latents from the reference conditioning
+                gen_img = img[:, :original_seq_len, :]
+                ref_img = img[:, original_seq_len:, :]
+
+                # Merge only the generated part
+                gen_img = inpaint_extension.merge_intermediate_latents_with_init_latents(gen_img, t_prev)
+
+                # Concatenate back together
+                img = torch.cat([gen_img, ref_img], dim=1)
+
+                # Handling preview images
+                preview_gen = preview_img[:, :original_seq_len, :]
+                preview_gen = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_gen, 0.0)
+
+            # Extract only the generated image portion for preview (exclude reference images)
+            callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
+            step_callback(
+                PipelineIntermediateState(
+                    step=step_index + 1,
+                    order=1,
+                    total_steps=total_steps,
+                    timestep=int(t_curr),
+                    latents=callback_latents,
+                ),
+            )
+
+    # Extract only the generated image portion (exclude concatenated reference images)
+    if img_cond_seq is not None:
+        img = img[:, :original_seq_len, :]
+
+    return img
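
As a quick check on the two update rules used in denoise() above (the external CFG mix and the manual Euler step), here is a minimal numeric sketch with made-up values:

    import torch

    pred = torch.tensor([1.0])      # conditional prediction
    neg_pred = torch.tensor([0.5])  # negative / unconditional prediction
    step_cfg_scale = 4.0
    pred = neg_pred + step_cfg_scale * (pred - neg_pred)  # tensor([2.5])

    img = torch.tensor([0.8])
    t_curr, t_prev = 1.0, 0.75
    img = img + (t_prev - t_curr) * pred  # 0.8 - 0.25 * 2.5 = tensor([0.1750])
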