InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/denoise.py
@@ -1,10 +1,13 @@
+ import inspect
  import math
  from typing import Callable

  import torch
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
  from tqdm import tqdm

  from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
+ from invokeai.backend.flux.extensions.dype_extension import DyPEExtension
  from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
  from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
  from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
@@ -35,149 +38,366 @@ def denoise(
  # extra img tokens (sequence-wise) - for Kontext conditioning
  img_cond_seq: torch.Tensor | None = None,
  img_cond_seq_ids: torch.Tensor | None = None,
+ # DyPE extension for high-resolution generation
+ dype_extension: DyPEExtension | None = None,
+ # Optional scheduler for alternative sampling methods
+ scheduler: SchedulerMixin | None = None,
  ):
- # step 0 is the initial state
- total_steps = len(timesteps) - 1
- step_callback(
- PipelineIntermediateState(
- step=0,
- order=1,
- total_steps=total_steps,
- timestep=int(timesteps[0]),
- latents=img,
- ),
- )
+ # Determine if we're using a diffusers scheduler or the built-in Euler method
+ use_scheduler = scheduler is not None
+
+ if use_scheduler:
+ # Initialize scheduler with timesteps
+ # The timesteps list contains values in [0, 1] range (sigmas)
+ # LCM should use num_inference_steps (it has its own sigma schedule),
+ # while other schedulers can use custom sigmas if supported
+ is_lcm = scheduler.__class__.__name__ == "FlowMatchLCMScheduler"
+ set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
+ if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
+ # Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
+ scheduler.set_timesteps(sigmas=timesteps, device=img.device)
+ else:
+ # LCM or scheduler doesn't support custom sigmas - use num_inference_steps
+ # The schedule will be computed by the scheduler itself
+ num_inference_steps = len(timesteps) - 1
+ scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
+
+ # For schedulers like Heun, the number of actual steps may differ
+ # (Heun doubles timesteps internally)
+ num_scheduler_steps = len(scheduler.timesteps)
+ # For user-facing step count, use the original number of denoising steps
+ total_steps = len(timesteps) - 1
+ else:
+ total_steps = len(timesteps) - 1
+ num_scheduler_steps = total_steps
+
  # guidance_vec is ignored for schnell.
  guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

  # Store original sequence length for slicing predictions
  original_seq_len = img.shape[1]

- for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
- t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+ # DyPE: Patch model with DyPE-aware position embedder
+ dype_embedder = None
+ original_pe_embedder = None
+ if dype_extension is not None:
+ dype_embedder, original_pe_embedder = dype_extension.patch_model(model)

- # Run ControlNet models.
- controlnet_residuals: list[ControlNetFluxOutput] = []
- for controlnet_extension in controlnet_extensions:
- controlnet_residuals.append(
- controlnet_extension.run_controlnet(
- timestep_index=step_index,
- total_num_timesteps=total_steps,
- img=img,
- img_ids=img_ids,
+ try:
+ # Track the actual step for user-facing progress (accounts for Heun's double steps)
+ user_step = 0
+
+ if use_scheduler:
+ # Use diffusers scheduler for stepping
+ # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps)
+ # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps
+ pbar = tqdm(total=total_steps, desc="Denoising")
+ for step_index in range(num_scheduler_steps):
+ timestep = scheduler.timesteps[step_index]
+ # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+ t_curr = timestep.item() / scheduler.config.num_train_timesteps
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+ # DyPE: Update step state for timestep-dependent scaling
+ if dype_extension is not None and dype_embedder is not None:
+ dype_extension.update_step_state(
+ embedder=dype_embedder,
+ timestep=t_curr,
+ timestep_index=user_step,
+ total_steps=total_steps,
+ )
+
+ # For Heun scheduler, track if we're in first or second order step
+ is_heun = hasattr(scheduler, "state_in_first_order")
+ in_first_order = scheduler.state_in_first_order if is_heun else True
+
+ # Run ControlNet models
+ controlnet_residuals: list[ControlNetFluxOutput] = []
+ for controlnet_extension in controlnet_extensions:
+ controlnet_residuals.append(
+ controlnet_extension.run_controlnet(
+ timestep_index=user_step,
+ total_num_timesteps=total_steps,
+ img=img,
+ img_ids=img_ids,
+ txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+ txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+ y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ )
+ )
+
+ merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
+ # Prepare input for model
+ img_input = img
+ img_input_ids = img_ids
+
+ if img_cond is not None:
+ img_input = torch.cat((img_input, img_cond), dim=-1)
+
+ if img_cond_seq is not None:
+ assert img_cond_seq_ids is not None
+ img_input = torch.cat((img_input, img_cond_seq), dim=1)
+ img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+
+ pred = model(
+ img=img_input,
+ img_ids=img_input_ids,
  txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
  txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
  y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
  timesteps=t_vec,
  guidance=guidance_vec,
+ timestep_index=user_step,
+ total_num_timesteps=total_steps,
+ controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+ controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+ ip_adapter_extensions=pos_ip_adapter_extensions,
+ regional_prompting_extension=pos_regional_prompting_extension,
  )
- )

- # Merge the ControlNet residuals from multiple ControlNets.
- # TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
- # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
- # tensors. Calculating the sum materializes each tensor into its own instance.
- merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+ if img_cond_seq is not None:
+ pred = pred[:, :original_seq_len]

- # Prepare input for model - concatenate fresh each step
- img_input = img
- img_input_ids = img_ids
+ # Get CFG scale for current user step
+ step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]

- # Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
- if img_cond is not None:
- img_input = torch.cat((img_input, img_cond), dim=-1)
+ if not math.isclose(step_cfg_scale, 1.0):
+ if neg_regional_prompting_extension is None:
+ raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")

- # Add sequence-wise conditioning (for Kontext)
- if img_cond_seq is not None:
- assert img_cond_seq_ids is not None, (
- "You need to provide either both or neither of the sequence conditioning"
- )
- img_input = torch.cat((img_input, img_cond_seq), dim=1)
- img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
-
- pred = model(
- img=img_input,
- img_ids=img_input_ids,
- txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
- txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
- y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
- timesteps=t_vec,
- guidance=guidance_vec,
- timestep_index=step_index,
- total_num_timesteps=total_steps,
- controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
- controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
- ip_adapter_extensions=pos_ip_adapter_extensions,
- regional_prompting_extension=pos_regional_prompting_extension,
- )
-
- # Slice prediction to only include the main image tokens
- if img_cond_seq is not None:
- pred = pred[:, :original_seq_len]
-
- step_cfg_scale = cfg_scale[step_index]
-
- # If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
- if not math.isclose(step_cfg_scale, 1.0):
- # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
- # on systems with sufficient VRAM.
-
- if neg_regional_prompting_extension is None:
- raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
-
- # For negative prediction with Kontext, we need to include the reference images
- # to maintain consistency between positive and negative passes. Without this,
- # CFG would create artifacts as the attention mechanism would see different
- # spatial structures in each pass
- neg_img_input = img
- neg_img_input_ids = img_ids
-
- # Add channel-wise conditioning for negative pass if present
+ neg_img_input = img
+ neg_img_input_ids = img_ids
+
+ if img_cond is not None:
+ neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+
+ if img_cond_seq is not None:
+ neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
+ neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
+
+ neg_pred = model(
+ img=neg_img_input,
+ img_ids=neg_img_input_ids,
+ txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+ txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+ y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ timestep_index=user_step,
+ total_num_timesteps=total_steps,
+ controlnet_double_block_residuals=None,
+ controlnet_single_block_residuals=None,
+ ip_adapter_extensions=neg_ip_adapter_extensions,
+ regional_prompting_extension=neg_regional_prompting_extension,
+ )
+
+ if img_cond_seq is not None:
+ neg_pred = neg_pred[:, :original_seq_len]
+ pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+ # Use scheduler.step() for the update
+ step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+ img = step_output.prev_sample
+
+ # Get t_prev for inpainting (next sigma value)
+ if step_index + 1 < len(scheduler.sigmas):
+ t_prev = scheduler.sigmas[step_index + 1].item()
+ else:
+ t_prev = 0.0
+
+ if inpaint_extension is not None:
+ img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+
+ # For Heun, only increment user step after second-order step completes
+ if is_heun:
+ if not in_first_order:
+ # Second order step completed
+ user_step += 1
+ # Only call step_callback if we haven't exceeded total_steps
+ if user_step <= total_steps:
+ pbar.update(1)
+ preview_img = img - t_curr * pred
+ if inpaint_extension is not None:
+ preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+ preview_img, 0.0
+ )
+ step_callback(
+ PipelineIntermediateState(
+ step=user_step,
+ order=2,
+ total_steps=total_steps,
+ timestep=int(t_curr * 1000),
+ latents=preview_img,
+ ),
+ )
+ else:
+ # For LCM and other first-order schedulers
+ user_step += 1
+ # Only call step_callback if we haven't exceeded total_steps
+ # (LCM scheduler may have more internal steps than user-facing steps)
+ if user_step <= total_steps:
+ pbar.update(1)
+ preview_img = img - t_curr * pred
+ if inpaint_extension is not None:
+ preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+ preview_img, 0.0
+ )
+ step_callback(
+ PipelineIntermediateState(
+ step=user_step,
+ order=1,
+ total_steps=total_steps,
+ timestep=int(t_curr * 1000),
+ latents=preview_img,
+ ),
+ )
+
+ pbar.close()
+ return img
+
+ # Original Euler implementation (when scheduler is None)
+ for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
+ # DyPE: Update step state for timestep-dependent scaling
+ if dype_extension is not None and dype_embedder is not None:
+ dype_extension.update_step_state(
+ embedder=dype_embedder,
+ timestep=t_curr,
+ timestep_index=step_index,
+ total_steps=total_steps,
+ )
+
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+ # Run ControlNet models.
+ controlnet_residuals: list[ControlNetFluxOutput] = []
+ for controlnet_extension in controlnet_extensions:
+ controlnet_residuals.append(
+ controlnet_extension.run_controlnet(
+ timestep_index=step_index,
+ total_num_timesteps=total_steps,
+ img=img,
+ img_ids=img_ids,
+ txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+ txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+ y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ )
+ )
+
+ # Merge the ControlNet residuals from multiple ControlNets.
+ # TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
+ # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
+ # tensors. Calculating the sum materializes each tensor into its own instance.
+ merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
+ # Prepare input for model - concatenate fresh each step
+ img_input = img
+ img_input_ids = img_ids
+
+ # Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
  if img_cond is not None:
- neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+ img_input = torch.cat((img_input, img_cond), dim=-1)

- # Add sequence-wise conditioning (Kontext) for negative pass
- # This ensures reference images are processed consistently
+ # Add sequence-wise conditioning (for Kontext)
  if img_cond_seq is not None:
- neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
- neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
-
- neg_pred = model(
- img=neg_img_input,
- img_ids=neg_img_input_ids,
- txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
- txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
- y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+ assert img_cond_seq_ids is not None, (
+ "You need to provide either both or neither of the sequence conditioning"
+ )
+ img_input = torch.cat((img_input, img_cond_seq), dim=1)
+ img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+
+ pred = model(
+ img=img_input,
+ img_ids=img_input_ids,
+ txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+ txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+ y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
  timesteps=t_vec,
  guidance=guidance_vec,
  timestep_index=step_index,
  total_num_timesteps=total_steps,
- controlnet_double_block_residuals=None,
- controlnet_single_block_residuals=None,
- ip_adapter_extensions=neg_ip_adapter_extensions,
- regional_prompting_extension=neg_regional_prompting_extension,
+ controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+ controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+ ip_adapter_extensions=pos_ip_adapter_extensions,
+ regional_prompting_extension=pos_regional_prompting_extension,
  )

- # Slice negative prediction to match main image tokens
+ # Slice prediction to only include the main image tokens
  if img_cond_seq is not None:
- neg_pred = neg_pred[:, :original_seq_len]
- pred = neg_pred + step_cfg_scale * (pred - neg_pred)
-
- preview_img = img - t_curr * pred
- img = img + (t_prev - t_curr) * pred
-
- if inpaint_extension is not None:
- img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
- preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
-
- step_callback(
- PipelineIntermediateState(
- step=step_index + 1,
- order=1,
- total_steps=total_steps,
- timestep=int(t_curr),
- latents=preview_img,
- ),
- )
-
- return img
+ pred = pred[:, :original_seq_len]
+
+ step_cfg_scale = cfg_scale[step_index]
+
+ # If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
+ if not math.isclose(step_cfg_scale, 1.0):
+ # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
+ # on systems with sufficient VRAM.
+
+ if neg_regional_prompting_extension is None:
+ raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+ # For negative prediction with Kontext, we need to include the reference images
+ # to maintain consistency between positive and negative passes. Without this,
+ # CFG would create artifacts as the attention mechanism would see different
+ # spatial structures in each pass
+ neg_img_input = img
+ neg_img_input_ids = img_ids
+
+ # Add channel-wise conditioning for negative pass if present
+ if img_cond is not None:
+ neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+
+ # Add sequence-wise conditioning (Kontext) for negative pass
+ # This ensures reference images are processed consistently
+ if img_cond_seq is not None:
+ neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
+ neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
+
+ neg_pred = model(
+ img=neg_img_input,
+ img_ids=neg_img_input_ids,
+ txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+ txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+ y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ timestep_index=step_index,
+ total_num_timesteps=total_steps,
+ controlnet_double_block_residuals=None,
+ controlnet_single_block_residuals=None,
+ ip_adapter_extensions=neg_ip_adapter_extensions,
+ regional_prompting_extension=neg_regional_prompting_extension,
+ )
+
+ # Slice negative prediction to match main image tokens
+ if img_cond_seq is not None:
+ neg_pred = neg_pred[:, :original_seq_len]
+ pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+ preview_img = img - t_curr * pred
+ img = img + (t_prev - t_curr) * pred
+
+ if inpaint_extension is not None:
+ img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+ preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+
+ step_callback(
+ PipelineIntermediateState(
+ step=step_index + 1,
+ order=1,
+ total_steps=total_steps,
+ timestep=int(t_curr),
+ latents=preview_img,
+ ),
+ )
+
+ return img
+
+ finally:
+ # DyPE: Restore original position embedder
+ if original_pe_embedder is not None:
+ DyPEExtension.restore_model(model, original_pe_embedder)
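
Note: the new scheduler parameter accepts any diffusers flow-matching scheduler (SchedulerMixin). The following is a minimal sketch of the capability check performed above, assuming a recent diffusers release that exports FlowMatchEulerDiscreteScheduler and whose set_timesteps accepts a sigmas keyword; the sigma values below are made up for illustration.

    import inspect

    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler

    # Hypothetical sigma schedule in [0, 1], ending at 0 like InvokeAI's time-shifted schedule.
    sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]

    scheduler = FlowMatchEulerDiscreteScheduler()

    # Mirror the check in denoise(): pass the custom sigma schedule when the scheduler
    # supports it, otherwise fall back to a plain step count and let the scheduler
    # build its own schedule.
    if "sigmas" in inspect.signature(scheduler.set_timesteps).parameters:
        scheduler.set_timesteps(sigmas=sigmas, device=torch.device("cpu"))
    else:
        scheduler.set_timesteps(num_inference_steps=len(sigmas) - 1, device=torch.device("cpu"))

    # scheduler.timesteps is in the 0-1000 range; denoise() divides by
    # scheduler.config.num_train_timesteps to recover the normalized t fed to the FLUX model.
    print(scheduler.timesteps)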
invokeai/backend/flux/dype/__init__.py (new file)
@@ -0,0 +1,31 @@
+ """Dynamic Position Extrapolation (DyPE) for FLUX models.
+
+ DyPE enables high-resolution image generation (4K+) with pretrained FLUX models
+ by dynamically scaling RoPE position embeddings during the denoising process.
+
+ Based on: https://github.com/wildminder/ComfyUI-DyPE
+ """
+
+ from invokeai.backend.flux.dype.base import DyPEConfig
+ from invokeai.backend.flux.dype.embed import DyPEEmbedND
+ from invokeai.backend.flux.dype.presets import (
+ DYPE_PRESET_4K,
+ DYPE_PRESET_AUTO,
+ DYPE_PRESET_LABELS,
+ DYPE_PRESET_MANUAL,
+ DYPE_PRESET_OFF,
+ DyPEPreset,
+ get_dype_config_for_resolution,
+ )
+
+ __all__ = [
+ "DyPEConfig",
+ "DyPEEmbedND",
+ "DyPEPreset",
+ "DYPE_PRESET_OFF",
+ "DYPE_PRESET_MANUAL",
+ "DYPE_PRESET_AUTO",
+ "DYPE_PRESET_4K",
+ "DYPE_PRESET_LABELS",
+ "get_dype_config_for_resolution",
+ ]
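
This __init__.py re-exports the DyPE public API from invokeai.backend.flux.dype, so the imports below follow directly from the file above. The call at the end is a hedged sketch only: the signature of get_dype_config_for_resolution is not shown in this diff, and the preset/width/height keyword arguments are assumptions for illustration.

    from invokeai.backend.flux.dype import (
        DYPE_PRESET_AUTO,
        DyPEConfig,
        get_dype_config_for_resolution,
    )

    # Hypothetical call - the argument names are assumed, not taken from this diff;
    # the real signature lives in invokeai/backend/flux/dype/presets.py.
    config = get_dype_config_for_resolution(preset=DYPE_PRESET_AUTO, width=3840, height=2160)
    assert isinstance(config, DyPEConfig)  # assumed return type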