InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in their public registry.
Files changed (104)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/flux_denoise.py +15 -1
  6. invokeai/app/invocations/image_to_latents.py +23 -5
  7. invokeai/app/invocations/latents_to_image.py +2 -25
  8. invokeai/app/invocations/metadata.py +9 -1
  9. invokeai/app/invocations/metadata_linked.py +47 -0
  10. invokeai/app/invocations/model.py +8 -0
  11. invokeai/app/invocations/pbr_maps.py +59 -0
  12. invokeai/app/invocations/primitives.py +12 -0
  13. invokeai/app/invocations/prompt_template.py +57 -0
  14. invokeai/app/invocations/z_image_control.py +112 -0
  15. invokeai/app/invocations/z_image_denoise.py +770 -0
  16. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  17. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  18. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  19. invokeai/app/invocations/z_image_model_loader.py +135 -0
  20. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  21. invokeai/app/services/config/config_default.py +3 -1
  22. invokeai/app/services/model_install/model_install_common.py +14 -1
  23. invokeai/app/services/model_install/model_install_default.py +119 -19
  24. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  25. invokeai/app/services/model_records/model_records_base.py +12 -0
  26. invokeai/app/services/model_records/model_records_sql.py +17 -0
  27. invokeai/app/services/shared/graph.py +132 -77
  28. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  29. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  30. invokeai/app/util/step_callback.py +3 -0
  31. invokeai/backend/flux/denoise.py +196 -11
  32. invokeai/backend/flux/schedulers.py +62 -0
  33. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  34. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  35. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  36. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  37. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  38. invokeai/backend/model_manager/configs/factory.py +26 -1
  39. invokeai/backend/model_manager/configs/lora.py +79 -1
  40. invokeai/backend/model_manager/configs/main.py +113 -0
  41. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  42. invokeai/backend/model_manager/load/model_cache/model_cache.py +104 -2
  43. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  44. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  45. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  46. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  47. invokeai/backend/model_manager/load/model_loaders/flux.py +13 -6
  48. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  49. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  50. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  51. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  52. invokeai/backend/model_manager/load/model_loaders/z_image.py +969 -0
  53. invokeai/backend/model_manager/load/model_util.py +6 -1
  54. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  55. invokeai/backend/model_manager/model_on_disk.py +3 -0
  56. invokeai/backend/model_manager/starter_models.py +79 -0
  57. invokeai/backend/model_manager/taxonomy.py +5 -0
  58. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  59. invokeai/backend/patches/layer_patcher.py +34 -16
  60. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  61. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  62. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  63. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  64. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  65. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +189 -0
  66. invokeai/backend/quantization/gguf/ggml_tensor.py +38 -4
  67. invokeai/backend/quantization/gguf/loaders.py +47 -12
  68. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  69. invokeai/backend/util/devices.py +25 -0
  70. invokeai/backend/util/hotfixes.py +2 -2
  71. invokeai/backend/z_image/__init__.py +16 -0
  72. invokeai/backend/z_image/extensions/__init__.py +1 -0
  73. invokeai/backend/z_image/extensions/regional_prompting_extension.py +205 -0
  74. invokeai/backend/z_image/text_conditioning.py +74 -0
  75. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  76. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  77. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  78. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  79. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  80. invokeai/frontend/web/dist/assets/App-BBELGD-n.js +161 -0
  81. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-4xPFTMT3.js} +1 -1
  82. invokeai/frontend/web/dist/assets/index-vCDSQboA.js +530 -0
  83. invokeai/frontend/web/dist/index.html +1 -1
  84. invokeai/frontend/web/dist/locales/de.json +24 -6
  85. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  86. invokeai/frontend/web/dist/locales/en.json +78 -3
  87. invokeai/frontend/web/dist/locales/es.json +0 -5
  88. invokeai/frontend/web/dist/locales/fr.json +0 -6
  89. invokeai/frontend/web/dist/locales/it.json +17 -64
  90. invokeai/frontend/web/dist/locales/ja.json +379 -44
  91. invokeai/frontend/web/dist/locales/ru.json +0 -6
  92. invokeai/frontend/web/dist/locales/vi.json +7 -54
  93. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  94. invokeai/version/invokeai_version.py +1 -1
  95. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/METADATA +4 -4
  96. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/RECORD +102 -71
  97. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  98. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  99. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/WHEEL +0 -0
  100. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/entry_points.txt +0 -0
  101. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE +0 -0
  102. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  103. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  104. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/top_level.txt +0 -0
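The headline change in this release is Z-Image support, and the largest new file, invokeai/app/invocations/z_image_denoise.py (item 15), is reproduced in full below. Its sampling schedule is a linear flow-matching sigma ramp warped by a resolution-dependent time shift. The following standalone sketch mirrors the _calculate_shift and _get_sigmas logic from that file (with the shift exponent fixed at 1.0, as the node does) so the schedule can be inspected in isolation; the constants and the 1024x1024 / 8-step example are the node's defaults, and everything else here is illustrative only.

# Standalone sketch of the Z-Image sigma schedule (illustrative; mirrors the node's defaults).
import math


def calculate_shift(
    image_seq_len: int,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    # Interpolate the shift parameter mu linearly with the number of image tokens.
    m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    return image_seq_len * m + (base_shift - m * base_image_seq_len)


def get_sigmas(mu: float, num_steps: int) -> list[float]:
    # Linear ramp from 1.0 down to 0.0, warped by an exponential time shift (exponent fixed at 1.0).
    def time_shift(t: float) -> float:
        if t <= 0:
            return 0.0
        if t >= 1:
            return 1.0
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1))

    return [time_shift(1.0 - i / num_steps) for i in range(num_steps + 1)]


# Node defaults: 1024x1024 output -> 128x128 latent grid, patch_size=2 -> 64x64 = 4096 image tokens,
# which is exactly max_image_seq_len, so mu lands on max_shift = 1.15.
mu = calculate_shift(64 * 64)
print([round(s, 2) for s in get_sigmas(mu, 8)])
# approximately [1.0, 0.96, 0.9, 0.84, 0.76, 0.65, 0.51, 0.31, 0.0]

Compared to an unshifted linear ramp, the shift keeps more of the 8 steps in the high-noise region and takes larger jumps near the end of denoising.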
invokeai/app/invocations/z_image_denoise.py (new file)
@@ -0,0 +1,770 @@
+ import inspect
+ import math
+ from contextlib import ExitStack
+ from typing import Callable, Iterator, Optional, Tuple
+
+ import einops
+ import torch
+ import torchvision.transforms as tv_transforms
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+ from PIL import Image
+ from torchvision.transforms.functional import resize as tv_resize
+ from tqdm import tqdm
+
+ from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+ from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+ from invokeai.app.invocations.fields import (
+     DenoiseMaskField,
+     FieldDescriptions,
+     Input,
+     InputField,
+     LatentsField,
+     ZImageConditioningField,
+ )
+ from invokeai.app.invocations.model import TransformerField, VAEField
+ from invokeai.app.invocations.primitives import LatentsOutput
+ from invokeai.app.invocations.z_image_control import ZImageControlField
+ from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
+ from invokeai.app.services.shared.invocation_context import InvocationContext
+ from invokeai.backend.flux.schedulers import ZIMAGE_SCHEDULER_LABELS, ZIMAGE_SCHEDULER_MAP, ZIMAGE_SCHEDULER_NAME_VALUES
+ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
+ from invokeai.backend.patches.layer_patcher import LayerPatcher
+ from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_TRANSFORMER_PREFIX
+ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+ from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
+ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
+ from invokeai.backend.util.devices import TorchDevice
+ from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
+ from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
+ from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+ from invokeai.backend.z_image.z_image_controlnet_extension import (
+     ZImageControlNetExtension,
+     z_image_forward_with_control,
+ )
+ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
+
+
+ @invocation(
+     "z_image_denoise",
+     title="Denoise - Z-Image",
+     tags=["image", "z-image"],
+     category="image",
+     version="1.4.0",
+     classification=Classification.Prototype,
+ )
+ class ZImageDenoiseInvocation(BaseInvocation):
+     """Run the denoising process with a Z-Image model.
+
+     Supports regional prompting by connecting multiple conditioning inputs with masks.
+     """
+
+     # If latents is provided, this means we are doing image-to-image.
+     latents: Optional[LatentsField] = InputField(
+         default=None, description=FieldDescriptions.latents, input=Input.Connection
+     )
+     # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
+     denoise_mask: Optional[DenoiseMaskField] = InputField(
+         default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
+     )
+     denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
+     denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
+     add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
+     transformer: TransformerField = InputField(
+         description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
+     )
+     positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
+         description=FieldDescriptions.positive_cond, input=Input.Connection
+     )
+     negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
+         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
+     )
+     # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
+     guidance_scale: float = InputField(
+         default=1.0,
+         ge=1.0,
+         description="Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). "
+         "Values > 1.0 amplify guidance.",
+         title="Guidance Scale",
+     )
+     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
+     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
+     # Z-Image-Turbo uses 8 steps by default
+     steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
+     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+     # Z-Image Control support
+     control: Optional[ZImageControlField] = InputField(
+         default=None,
+         description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
+         input=Input.Connection,
+     )
+     # VAE for encoding control images (required when using control)
+     vae: Optional[VAEField] = InputField(
+         default=None,
+         description=FieldDescriptions.vae + " Required for control conditioning.",
+         input=Input.Connection,
+     )
+     # Scheduler selection for the denoising process
+     scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField(
+         default="euler",
+         description="Scheduler (sampler) for the denoising process. Euler is the default and recommended for "
+         "Z-Image-Turbo. Heun is 2nd-order (better quality, 2x slower). LCM is optimized for few steps.",
+         ui_choice_labels=ZIMAGE_SCHEDULER_LABELS,
+     )
+
+     @torch.no_grad()
+     def invoke(self, context: InvocationContext) -> LatentsOutput:
+         latents = self._run_diffusion(context)
+         latents = latents.detach().to("cpu")
+
+         name = context.tensors.save(tensor=latents)
+         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
+
+     def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
+         """Prepare the inpaint mask."""
+         if self.denoise_mask is None:
+             return None
+         mask = context.tensors.load(self.denoise_mask.mask_name)
+
+         # Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve
+         mask = 1.0 - mask
+
+         _, _, latent_height, latent_width = latents.shape
+         mask = tv_resize(
+             img=mask,
+             size=[latent_height, latent_width],
+             interpolation=tv_transforms.InterpolationMode.BILINEAR,
+             antialias=False,
+         )
+
+         mask = mask.to(device=latents.device, dtype=latents.dtype)
+         return mask
+
+     def _load_text_conditioning(
+         self,
+         context: InvocationContext,
+         cond_field: ZImageConditioningField | list[ZImageConditioningField],
+         img_height: int,
+         img_width: int,
+         dtype: torch.dtype,
+         device: torch.device,
+     ) -> list[ZImageTextConditioning]:
+         """Load Z-Image text conditioning with optional regional masks.
+
+         Args:
+             context: The invocation context.
+             cond_field: Single conditioning field or list of fields.
+             img_height: Height of the image token grid (H // patch_size).
+             img_width: Width of the image token grid (W // patch_size).
+             dtype: Target dtype.
+             device: Target device.
+
+         Returns:
+             List of ZImageTextConditioning objects with embeddings and masks.
+         """
+         # Normalize to a list
+         cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
+
+         text_conditionings: list[ZImageTextConditioning] = []
+         for cond in cond_list:
+             # Load the text embeddings
+             cond_data = context.conditioning.load(cond.conditioning_name)
+             assert len(cond_data.conditionings) == 1
+             z_image_conditioning = cond_data.conditionings[0]
+             assert isinstance(z_image_conditioning, ZImageConditioningInfo)
+             z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
+             prompt_embeds = z_image_conditioning.prompt_embeds
+
+             # Load the mask, if provided
+             mask: torch.Tensor | None = None
+             if cond.mask is not None:
+                 mask = context.tensors.load(cond.mask.tensor_name)
+                 mask = mask.to(device=device)
+                 mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
+                     mask, img_height, img_width, dtype, device
+                 )
+
+             text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
+
+         return text_conditionings
+
+     def _get_noise(
+         self,
+         batch_size: int,
+         num_channels_latents: int,
+         height: int,
+         width: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         seed: int,
+     ) -> torch.Tensor:
+         """Generate initial noise tensor."""
+         # Generate noise as float32 on CPU for maximum compatibility,
+         # then cast to target dtype/device
+         rand_device = "cpu"
+         rand_dtype = torch.float32
+
+         return torch.randn(
+             batch_size,
+             num_channels_latents,
+             int(height) // LATENT_SCALE_FACTOR,
+             int(width) // LATENT_SCALE_FACTOR,
+             device=rand_device,
+             dtype=rand_dtype,
+             generator=torch.Generator(device=rand_device).manual_seed(seed),
+         ).to(device=device, dtype=dtype)
+
+     def _calculate_shift(
+         self,
+         image_seq_len: int,
+         base_image_seq_len: int = 256,
+         max_image_seq_len: int = 4096,
+         base_shift: float = 0.5,
+         max_shift: float = 1.15,
+     ) -> float:
+         """Calculate timestep shift based on image sequence length.
+
+         Based on diffusers ZImagePipeline.calculate_shift method.
+         """
+         m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
+         b = base_shift - m * base_image_seq_len
+         mu = image_seq_len * m + b
+         return mu
+
+     def _get_sigmas(self, mu: float, num_steps: int) -> list[float]:
+         """Generate sigma schedule with time shift.
+
+         Based on FlowMatchEulerDiscreteScheduler with shift.
+         Generates num_steps + 1 sigma values (including terminal 0.0).
+         """
+         import math
+
+         def time_shift(mu: float, sigma: float, t: float) -> float:
+             """Apply time shift to a single timestep value."""
+             if t <= 0:
+                 return 0.0
+             if t >= 1:
+                 return 1.0
+             return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+         # Generate linearly spaced values from 1 to 0 (excluding endpoints for safety)
+         # then apply time shift
+         sigmas = []
+         for i in range(num_steps + 1):
+             t = 1.0 - i / num_steps  # Goes from 1.0 to 0.0
+             sigma = time_shift(mu, 1.0, t)
+             sigmas.append(sigma)
+
+         return sigmas
+
+     def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
+         device = TorchDevice.choose_torch_device()
+         inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
+
+         transformer_info = context.models.load(self.transformer.transformer)
+
+         # Calculate image token grid dimensions
+         patch_size = 2  # Z-Image uses patch_size=2
+         latent_height = self.height // LATENT_SCALE_FACTOR
+         latent_width = self.width // LATENT_SCALE_FACTOR
+         img_token_height = latent_height // patch_size
+         img_token_width = latent_width // patch_size
+         img_seq_len = img_token_height * img_token_width
+
+         # Load positive conditioning with regional masks
+         pos_text_conditionings = self._load_text_conditioning(
+             context=context,
+             cond_field=self.positive_conditioning,
+             img_height=img_token_height,
+             img_width=img_token_width,
+             dtype=inference_dtype,
+             device=device,
+         )
+
+         # Create regional prompting extension
+         regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
+             text_conditionings=pos_text_conditionings,
+             img_seq_len=img_seq_len,
+         )
+
+         # Get the concatenated prompt embeddings for the transformer
+         pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds
+
+         # Load negative conditioning if provided and guidance_scale != 1.0
+         # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
+         # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
+         # This matches FLUX's convention where 1.0 means "no CFG"
+         neg_prompt_embeds: torch.Tensor | None = None
+         do_classifier_free_guidance = (
+             not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
+         )
+         if do_classifier_free_guidance:
+             assert self.negative_conditioning is not None
+             # Load all negative conditionings and concatenate embeddings
+             # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
+             neg_text_conditionings = self._load_text_conditioning(
+                 context=context,
+                 cond_field=self.negative_conditioning,
+                 img_height=img_token_height,
+                 img_width=img_token_width,
+                 dtype=inference_dtype,
+                 device=device,
+             )
+             # Concatenate all negative embeddings
+             neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
+
+         # Calculate shift based on image sequence length
+         mu = self._calculate_shift(img_seq_len)
+
+         # Generate sigma schedule with time shift
+         sigmas = self._get_sigmas(mu, self.steps)
+
+         # Apply denoising_start and denoising_end clipping
+         if self.denoising_start > 0 or self.denoising_end < 1:
+             # Calculate start and end indices based on denoising range
+             total_sigmas = len(sigmas)
+             start_idx = int(self.denoising_start * (total_sigmas - 1))
+             end_idx = int(self.denoising_end * (total_sigmas - 1)) + 1
+             sigmas = sigmas[start_idx:end_idx]
+
+         total_steps = len(sigmas) - 1
+
+         # Load input latents if provided (image-to-image)
+         init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
+         if init_latents is not None:
+             init_latents = init_latents.to(device=device, dtype=inference_dtype)
+
+         # Generate initial noise
+         num_channels_latents = 16  # Z-Image uses 16 latent channels
+         noise = self._get_noise(
+             batch_size=1,
+             num_channels_latents=num_channels_latents,
+             height=self.height,
+             width=self.width,
+             dtype=inference_dtype,
+             device=device,
+             seed=self.seed,
+         )
+
+         # Prepare input latent image
+         if init_latents is not None:
+             if self.add_noise:
+                 # Noise the init_latents by the appropriate amount for the first timestep.
+                 s_0 = sigmas[0]
+                 latents = s_0 * noise + (1.0 - s_0) * init_latents
+             else:
+                 latents = init_latents
+         else:
+             if self.denoising_start > 1e-5:
+                 raise ValueError("denoising_start should be 0 when initial latents are not provided.")
+             latents = noise
+
+         # Short-circuit if no denoising steps
+         if total_steps <= 0:
+             return latents
+
+         # Prepare inpaint extension
+         inpaint_mask = self._prep_inpaint_mask(context, latents)
+         inpaint_extension: RectifiedFlowInpaintExtension | None = None
+         if inpaint_mask is not None:
+             if init_latents is None:
+                 raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
+             inpaint_extension = RectifiedFlowInpaintExtension(
+                 init_latents=init_latents,
+                 inpaint_mask=inpaint_mask,
+                 noise=noise,
+             )
+
+         step_callback = self._build_step_callback(context)
+
+         # Initialize the diffusers scheduler if not using built-in Euler
+         scheduler: SchedulerMixin | None = None
+         use_scheduler = self.scheduler != "euler"
+
+         if use_scheduler:
+             scheduler_class = ZIMAGE_SCHEDULER_MAP[self.scheduler]
+             scheduler = scheduler_class(
+                 num_train_timesteps=1000,
+                 shift=1.0,
+             )
+             # Set timesteps - LCM should use num_inference_steps (it has its own sigma schedule),
+             # while other schedulers can use custom sigmas if supported
+             is_lcm = self.scheduler == "lcm"
+             set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
+             if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
+                 # Convert sigmas list to tensor for scheduler
+                 scheduler.set_timesteps(sigmas=sigmas, device=device)
+             else:
+                 # LCM or scheduler doesn't support custom sigmas - use num_inference_steps
+                 scheduler.set_timesteps(num_inference_steps=total_steps, device=device)
+
+             # For Heun scheduler, the number of actual steps may differ
+             num_scheduler_steps = len(scheduler.timesteps)
+         else:
+             num_scheduler_steps = total_steps
+
+         with ExitStack() as exit_stack:
+             # Get transformer config to determine if it's quantized
+             transformer_config = context.models.get_config(self.transformer.transformer)
+
+             # Determine if the model is quantized.
+             # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
+             # slower inference than direct patching, but is agnostic to the quantization format.
+             if transformer_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
+                 model_is_quantized = False
+             elif transformer_config.format in [ModelFormat.GGUFQuantized]:
+                 model_is_quantized = True
+             else:
+                 raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")
+
+             # Load transformer - always use base transformer, control is handled via extension
+             (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+
+             # Prepare control extension if control is provided
+             control_extension: ZImageControlNetExtension | None = None
+
+             if self.control is not None:
+                 # Load control adapter using context manager (proper GPU memory management)
+                 control_model_info = context.models.load(self.control.control_model)
+                 (_, control_adapter) = exit_stack.enter_context(control_model_info.model_on_device())
+                 assert isinstance(control_adapter, ZImageControlAdapter)
+
+                 # Get control_in_dim from adapter config (16 for V1, 33 for V2.0)
+                 adapter_config = control_adapter.config
+                 control_in_dim = adapter_config.get("control_in_dim", 16)
+                 num_control_blocks = adapter_config.get("num_control_blocks", 6)
+
+                 # Log control configuration for debugging
+                 version = "V2.0" if control_in_dim > 16 else "V1"
+                 context.util.signal_progress(
+                     f"Using Z-Image ControlNet {version} (Extension): control_in_dim={control_in_dim}, "
+                     f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
+                 )
+
+                 # Load and prepare control image - must be VAE-encoded!
+                 if self.vae is None:
+                     raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
+
+                 control_image = context.images.get_pil(self.control.image_name)
+
+                 # Resize control image to match output dimensions
+                 control_image = control_image.convert("RGB")
+                 control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)
+
+                 # Convert to tensor format for VAE encoding
+                 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+                 control_image_tensor = image_resized_to_grid_as_tensor(control_image)
+                 if control_image_tensor.dim() == 3:
+                     control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")
+
+                 # Encode control image through VAE to get latents
+                 vae_info = context.models.load(self.vae.vae)
+                 control_latents = ZImageImageToLatentsInvocation.vae_encode(
+                     vae_info=vae_info,
+                     image_tensor=control_image_tensor,
+                 )
+
+                 # Move to inference device/dtype
+                 control_latents = control_latents.to(device=device, dtype=inference_dtype)
+
+                 # Add frame dimension: [B, C, H, W] -> [C, 1, H, W] (single image)
+                 control_latents = control_latents.squeeze(0).unsqueeze(1)
+
+                 # Prepare control_cond based on control_in_dim
+                 # V1: 16 channels (just control latents)
+                 # V2.0: 33 channels = 16 control + 16 reference + 1 mask
+                 #   - Channels 0-15: control image latents (from VAE encoding)
+                 #   - Channels 16-31: reference/inpaint image latents (zeros for pure control)
+                 #   - Channel 32: inpaint mask (1.0 = don't inpaint, 0.0 = inpaint region)
+                 # For pure control (no inpainting), we set mask=1 to tell model "use control, don't inpaint"
+                 c, f, h, w = control_latents.shape
+                 if c < control_in_dim:
+                     padding_channels = control_in_dim - c
+                     if padding_channels == 17:
+                         # V2.0: 16 reference channels (zeros) + 1 mask channel (ones)
+                         ref_padding = torch.zeros(
+                             (16, f, h, w),
+                             device=device,
+                             dtype=inference_dtype,
+                         )
+                         # Mask channel = 1.0 means "don't inpaint this region, use control signal"
+                         mask_channel = torch.ones(
+                             (1, f, h, w),
+                             device=device,
+                             dtype=inference_dtype,
+                         )
+                         control_latents = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
+                     else:
+                         # Generic padding with zeros for other cases
+                         zero_padding = torch.zeros(
+                             (padding_channels, f, h, w),
+                             device=device,
+                             dtype=inference_dtype,
+                         )
+                         control_latents = torch.cat([control_latents, zero_padding], dim=0)
+
+                 # Create control extension (adapter is already on device from model_on_device)
+                 control_extension = ZImageControlNetExtension(
+                     control_adapter=control_adapter,
+                     control_cond=control_latents,
+                     weight=self.control.control_context_scale,
+                     begin_step_percent=self.control.begin_step_percent,
+                     end_step_percent=self.control.end_step_percent,
+                 )
+
+             # Apply LoRA models to the transformer.
+             # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+             exit_stack.enter_context(
+                 LayerPatcher.apply_smart_model_patches(
+                     model=transformer,
+                     patches=self._lora_iterator(context),
+                     prefix=Z_IMAGE_LORA_TRANSFORMER_PREFIX,
+                     dtype=inference_dtype,
+                     cached_weights=cached_weights,
+                     force_sidecar_patching=model_is_quantized,
+                 )
+             )
+
+             # Apply regional prompting patch if we have regional masks
+             exit_stack.enter_context(
+                 patch_transformer_for_regional_prompting(
+                     transformer=transformer,
+                     regional_attn_mask=regional_extension.regional_attn_mask,
+                     img_seq_len=img_seq_len,
+                 )
+             )
+
+             # Denoising loop - supports both built-in Euler and diffusers schedulers
+             # Track user-facing step for progress (accounts for Heun's double steps)
+             user_step = 0
+
+             if use_scheduler and scheduler is not None:
+                 # Use diffusers scheduler for stepping
+                 # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps)
+                 # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps
+                 pbar = tqdm(total=total_steps, desc="Denoising")
+                 for step_index in range(num_scheduler_steps):
+                     sched_timestep = scheduler.timesteps[step_index]
+                     # Convert scheduler timestep (0-1000) to normalized sigma (0-1)
+                     sigma_curr = sched_timestep.item() / scheduler.config.num_train_timesteps
+
+                     # For Heun scheduler, track if we're in first or second order step
+                     is_heun = hasattr(scheduler, "state_in_first_order")
+                     in_first_order = scheduler.state_in_first_order if is_heun else True
+
+                     # Timestep tensor for Z-Image model
+                     # The model expects t=0 at start (noise) and t=1 at end (clean)
+                     model_t = 1.0 - sigma_curr
+                     timestep = torch.tensor([model_t], device=device, dtype=inference_dtype).expand(latents.shape[0])
+
+                     # Run transformer for positive prediction
+                     latent_model_input = latents.to(transformer.dtype)
+                     latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
+                     latent_model_input_list = list(latent_model_input.unbind(dim=0))
+
+                     # Determine if control should be applied at this step
+                     apply_control = control_extension is not None and control_extension.should_apply(
+                         user_step, total_steps
+                     )
+
+                     # Run forward pass
+                     if apply_control:
+                         model_out_list, _ = z_image_forward_with_control(
+                             transformer=transformer,
+                             x=latent_model_input_list,
+                             t=timestep,
+                             cap_feats=[pos_prompt_embeds],
+                             control_extension=control_extension,
+                         )
+                     else:
+                         model_output = transformer(
+                             x=latent_model_input_list,
+                             t=timestep,
+                             cap_feats=[pos_prompt_embeds],
+                         )
+                         model_out_list = model_output[0]
+
+                     noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
+                     noise_pred_cond = noise_pred_cond.squeeze(2)
+                     noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation
+
+                     # Apply CFG if enabled
+                     if do_classifier_free_guidance and neg_prompt_embeds is not None:
+                         if apply_control:
+                             model_out_list_uncond, _ = z_image_forward_with_control(
+                                 transformer=transformer,
+                                 x=latent_model_input_list,
+                                 t=timestep,
+                                 cap_feats=[neg_prompt_embeds],
+                                 control_extension=control_extension,
+                             )
+                         else:
+                             model_output_uncond = transformer(
+                                 x=latent_model_input_list,
+                                 t=timestep,
+                                 cap_feats=[neg_prompt_embeds],
+                             )
+                             model_out_list_uncond = model_output_uncond[0]
+
+                         noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
+                         noise_pred_uncond = noise_pred_uncond.squeeze(2)
+                         noise_pred_uncond = -noise_pred_uncond
+                         noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+                     else:
+                         noise_pred = noise_pred_cond
+
+                     # Use scheduler.step() for the update
+                     step_output = scheduler.step(model_output=noise_pred, timestep=sched_timestep, sample=latents)
+                     latents = step_output.prev_sample
+
+                     # Get sigma_prev for inpainting (next sigma value)
+                     if step_index + 1 < len(scheduler.sigmas):
+                         sigma_prev = scheduler.sigmas[step_index + 1].item()
+                     else:
+                         sigma_prev = 0.0
+
+                     if inpaint_extension is not None:
+                         latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)
+
+                     # For Heun, only increment user step after second-order step completes
+                     if is_heun:
+                         if not in_first_order:
+                             user_step += 1
+                             # Only call step_callback if we haven't exceeded total_steps
+                             if user_step <= total_steps:
+                                 pbar.update(1)
+                                 step_callback(
+                                     PipelineIntermediateState(
+                                         step=user_step,
+                                         order=2,
+                                         total_steps=total_steps,
+                                         timestep=int(sigma_curr * 1000),
+                                         latents=latents,
+                                     ),
+                                 )
+                     else:
+                         # For LCM and other first-order schedulers
+                         user_step += 1
+                         # Only call step_callback if we haven't exceeded total_steps
+                         # (LCM scheduler may have more internal steps than user-facing steps)
+                         if user_step <= total_steps:
+                             pbar.update(1)
+                             step_callback(
+                                 PipelineIntermediateState(
+                                     step=user_step,
+                                     order=1,
+                                     total_steps=total_steps,
+                                     timestep=int(sigma_curr * 1000),
+                                     latents=latents,
+                                 ),
+                             )
+                 pbar.close()
+             else:
+                 # Original Euler implementation (default, optimized for Z-Image)
+                 for step_idx in tqdm(range(total_steps)):
+                     sigma_curr = sigmas[step_idx]
+                     sigma_prev = sigmas[step_idx + 1]
+
+                     # Timestep tensor for Z-Image model
+                     # The model expects t=0 at start (noise) and t=1 at end (clean)
+                     # Sigma goes from 1 (noise) to 0 (clean), so model_t = 1 - sigma
+                     model_t = 1.0 - sigma_curr
+                     timestep = torch.tensor([model_t], device=device, dtype=inference_dtype).expand(latents.shape[0])
+
+                     # Run transformer for positive prediction
+                     # Z-Image transformer expects: x as list of [C, 1, H, W] tensors, t, cap_feats as list
+                     # Prepare latent input: [B, C, H, W] -> [B, C, 1, H, W] -> list of [C, 1, H, W]
+                     latent_model_input = latents.to(transformer.dtype)
+                     latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
+                     latent_model_input_list = list(latent_model_input.unbind(dim=0))
+
+                     # Determine if control should be applied at this step
+                     apply_control = control_extension is not None and control_extension.should_apply(
+                         step_idx, total_steps
+                     )
+
+                     # Run forward pass - use custom forward with control if extension is active
+                     if apply_control:
+                         model_out_list, _ = z_image_forward_with_control(
+                             transformer=transformer,
+                             x=latent_model_input_list,
+                             t=timestep,
+                             cap_feats=[pos_prompt_embeds],
+                             control_extension=control_extension,
+                         )
+                     else:
+                         model_output = transformer(
+                             x=latent_model_input_list,
+                             t=timestep,
+                             cap_feats=[pos_prompt_embeds],
+                         )
+                         model_out_list = model_output[0]  # Extract list of tensors from tuple
+
+                     noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
+                     noise_pred_cond = noise_pred_cond.squeeze(2)  # Remove frame dimension
+                     noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation
+
+                     # Apply CFG if enabled
+                     if do_classifier_free_guidance and neg_prompt_embeds is not None:
+                         if apply_control:
+                             model_out_list_uncond, _ = z_image_forward_with_control(
+                                 transformer=transformer,
+                                 x=latent_model_input_list,
+                                 t=timestep,
+                                 cap_feats=[neg_prompt_embeds],
+                                 control_extension=control_extension,
+                             )
+                         else:
+                             model_output_uncond = transformer(
+                                 x=latent_model_input_list,
+                                 t=timestep,
+                                 cap_feats=[neg_prompt_embeds],
+                             )
+                             model_out_list_uncond = model_output_uncond[0]  # Extract list of tensors from tuple
+
+                         noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
+                         noise_pred_uncond = noise_pred_uncond.squeeze(2)
+                         noise_pred_uncond = -noise_pred_uncond
+                         noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+                     else:
+                         noise_pred = noise_pred_cond
+
+                     # Euler step
+                     latents_dtype = latents.dtype
+                     latents = latents.to(dtype=torch.float32)
+                     latents = latents + (sigma_prev - sigma_curr) * noise_pred
+                     latents = latents.to(dtype=latents_dtype)
+
+                     if inpaint_extension is not None:
+                         latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)
+
+                     step_callback(
+                         PipelineIntermediateState(
+                             step=step_idx + 1,
+                             order=1,
+                             total_steps=total_steps,
+                             timestep=int(sigma_curr * 1000),
+                             latents=latents,
+                         ),
+                     )
+
+         return latents
+
+     def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+         def step_callback(state: PipelineIntermediateState) -> None:
+             context.util.sd_step_callback(state, BaseModelType.ZImage)
+
+         return step_callback
+
+     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+         """Iterate over LoRA models to apply to the transformer."""
+         for lora in self.transformer.loras:
+             lora_info = context.models.load(lora.lora)
+             if not isinstance(lora_info.model, ModelPatchRaw):
+                 raise TypeError(
+                     f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
+                     "The LoRA model may be corrupted or incompatible."
+                 )
+             yield (lora_info.model, lora.weight)
+             del lora_info
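For reference, the per-step update in the default Euler path above reduces to a plain rectified-flow step on the negated model output, with classifier-free guidance blended in only when guidance_scale differs from 1.0. The sketch below distills that update; the function and its tensor arguments are hypothetical stand-ins for the transformer outputs, not part of the InvokeAI API.

import torch


def euler_cfg_step(
    latents: torch.Tensor,
    model_out_cond: torch.Tensor,
    model_out_uncond: torch.Tensor | None,
    sigma_curr: float,
    sigma_prev: float,
    guidance_scale: float,
) -> torch.Tensor:
    # Z-Image's output is negated to obtain the velocity prediction (mirrors `noise_pred = -model_out` above).
    pred_cond = -model_out_cond.float()
    if model_out_uncond is not None and guidance_scale != 1.0:
        # CFG: pred = pred_uncond + scale * (pred_cond - pred_uncond); at scale=1.0 this collapses to pred_cond.
        pred_uncond = -model_out_uncond.float()
        pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
    else:
        pred = pred_cond
    # Euler step along the predicted velocity; sigma_prev < sigma_curr, so this moves toward the clean image.
    return (latents.float() + (sigma_prev - sigma_curr) * pred).to(latents.dtype)

At the default guidance_scale of 1.0 the node skips the unconditional forward pass entirely, which is why 1.0 is recommended for Z-Image-Turbo in the field descriptions above.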