InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/image_to_latents.py +23 -5
  6. invokeai/app/invocations/latents_to_image.py +2 -25
  7. invokeai/app/invocations/metadata.py +9 -1
  8. invokeai/app/invocations/model.py +8 -0
  9. invokeai/app/invocations/primitives.py +12 -0
  10. invokeai/app/invocations/prompt_template.py +57 -0
  11. invokeai/app/invocations/z_image_control.py +112 -0
  12. invokeai/app/invocations/z_image_denoise.py +610 -0
  13. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  14. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  15. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  16. invokeai/app/invocations/z_image_model_loader.py +135 -0
  17. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  18. invokeai/app/services/model_install/model_install_common.py +14 -1
  19. invokeai/app/services/model_install/model_install_default.py +119 -19
  20. invokeai/app/services/model_records/model_records_base.py +12 -0
  21. invokeai/app/services/model_records/model_records_sql.py +17 -0
  22. invokeai/app/services/shared/graph.py +132 -77
  23. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  24. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  25. invokeai/app/util/step_callback.py +3 -0
  26. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  27. invokeai/backend/model_manager/configs/factory.py +26 -1
  28. invokeai/backend/model_manager/configs/lora.py +43 -1
  29. invokeai/backend/model_manager/configs/main.py +113 -0
  30. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  31. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  32. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  33. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  34. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  35. invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
  36. invokeai/backend/model_manager/load/model_util.py +6 -1
  37. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  38. invokeai/backend/model_manager/model_on_disk.py +3 -0
  39. invokeai/backend/model_manager/starter_models.py +70 -0
  40. invokeai/backend/model_manager/taxonomy.py +5 -0
  41. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  42. invokeai/backend/patches/layer_patcher.py +34 -16
  43. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  44. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  45. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  46. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  47. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  48. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
  49. invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
  50. invokeai/backend/quantization/gguf/loaders.py +47 -12
  51. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  52. invokeai/backend/util/devices.py +25 -0
  53. invokeai/backend/util/hotfixes.py +2 -2
  54. invokeai/backend/z_image/__init__.py +16 -0
  55. invokeai/backend/z_image/extensions/__init__.py +1 -0
  56. invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
  57. invokeai/backend/z_image/text_conditioning.py +74 -0
  58. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  59. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  60. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  61. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  62. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  63. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
  64. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
  65. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
  66. invokeai/frontend/web/dist/index.html +1 -1
  67. invokeai/frontend/web/dist/locales/de.json +24 -6
  68. invokeai/frontend/web/dist/locales/en.json +70 -1
  69. invokeai/frontend/web/dist/locales/es.json +0 -5
  70. invokeai/frontend/web/dist/locales/fr.json +0 -6
  71. invokeai/frontend/web/dist/locales/it.json +17 -64
  72. invokeai/frontend/web/dist/locales/ja.json +379 -44
  73. invokeai/frontend/web/dist/locales/ru.json +0 -6
  74. invokeai/frontend/web/dist/locales/vi.json +7 -54
  75. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  76. invokeai/version/invokeai_version.py +1 -1
  77. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
  78. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
  79. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  80. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  81. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
  82. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
  83. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
  84. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  85. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  86. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
invokeai/app/invocations/z_image_denoise.py (new file)
@@ -0,0 +1,610 @@
+ import math
+ from contextlib import ExitStack
+ from typing import Callable, Iterator, Optional, Tuple
+
+ import einops
+ import torch
+ import torchvision.transforms as tv_transforms
+ from PIL import Image
+ from torchvision.transforms.functional import resize as tv_resize
+ from tqdm import tqdm
+
+ from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+ from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+ from invokeai.app.invocations.fields import (
+     DenoiseMaskField,
+     FieldDescriptions,
+     Input,
+     InputField,
+     LatentsField,
+     ZImageConditioningField,
+ )
+ from invokeai.app.invocations.model import TransformerField, VAEField
+ from invokeai.app.invocations.primitives import LatentsOutput
+ from invokeai.app.invocations.z_image_control import ZImageControlField
+ from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
+ from invokeai.app.services.shared.invocation_context import InvocationContext
+ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
+ from invokeai.backend.patches.layer_patcher import LayerPatcher
+ from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_TRANSFORMER_PREFIX
+ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+ from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
+ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
+ from invokeai.backend.util.devices import TorchDevice
+ from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
+ from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
+ from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+ from invokeai.backend.z_image.z_image_controlnet_extension import (
+     ZImageControlNetExtension,
+     z_image_forward_with_control,
+ )
+ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
+
+
+ @invocation(
+     "z_image_denoise",
+     title="Denoise - Z-Image",
+     tags=["image", "z-image"],
+     category="image",
+     version="1.2.0",
+     classification=Classification.Prototype,
+ )
+ class ZImageDenoiseInvocation(BaseInvocation):
+     """Run the denoising process with a Z-Image model.
+
+     Supports regional prompting by connecting multiple conditioning inputs with masks.
+     """
+
+     # If latents is provided, this means we are doing image-to-image.
+     latents: Optional[LatentsField] = InputField(
+         default=None, description=FieldDescriptions.latents, input=Input.Connection
+     )
+     # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
+     denoise_mask: Optional[DenoiseMaskField] = InputField(
+         default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
+     )
+     denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
+     denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
+     transformer: TransformerField = InputField(
+         description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
+     )
+     positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
+         description=FieldDescriptions.positive_cond, input=Input.Connection
+     )
+     negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
+         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
+     )
+     # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
+     guidance_scale: float = InputField(
+         default=1.0,
+         ge=1.0,
+         description="Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). "
+         "Values > 1.0 amplify guidance.",
+         title="Guidance Scale",
+     )
+     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
+     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
+     # Z-Image-Turbo uses 8 steps by default
+     steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
+     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+     # Z-Image Control support
+     control: Optional[ZImageControlField] = InputField(
+         default=None,
+         description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
+         input=Input.Connection,
+     )
+     # VAE for encoding control images (required when using control)
+     vae: Optional[VAEField] = InputField(
+         default=None,
+         description=FieldDescriptions.vae + " Required for control conditioning.",
+         input=Input.Connection,
+     )
+
+     @torch.no_grad()
+     def invoke(self, context: InvocationContext) -> LatentsOutput:
+         latents = self._run_diffusion(context)
+         latents = latents.detach().to("cpu")
+
+         name = context.tensors.save(tensor=latents)
+         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
+
+     def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
+         """Prepare the inpaint mask."""
+         if self.denoise_mask is None:
+             return None
+         mask = context.tensors.load(self.denoise_mask.mask_name)
+
+         # Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve
+         mask = 1.0 - mask
+
+         _, _, latent_height, latent_width = latents.shape
+         mask = tv_resize(
+             img=mask,
+             size=[latent_height, latent_width],
+             interpolation=tv_transforms.InterpolationMode.BILINEAR,
+             antialias=False,
+         )
+
+         mask = mask.to(device=latents.device, dtype=latents.dtype)
+         return mask
+
+     def _load_text_conditioning(
+         self,
+         context: InvocationContext,
+         cond_field: ZImageConditioningField | list[ZImageConditioningField],
+         img_height: int,
+         img_width: int,
+         dtype: torch.dtype,
+         device: torch.device,
+     ) -> list[ZImageTextConditioning]:
+         """Load Z-Image text conditioning with optional regional masks.
+
+         Args:
+             context: The invocation context.
+             cond_field: Single conditioning field or list of fields.
+             img_height: Height of the image token grid (H // patch_size).
+             img_width: Width of the image token grid (W // patch_size).
+             dtype: Target dtype.
+             device: Target device.
+
+         Returns:
+             List of ZImageTextConditioning objects with embeddings and masks.
+         """
+         # Normalize to a list
+         cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
+
+         text_conditionings: list[ZImageTextConditioning] = []
+         for cond in cond_list:
+             # Load the text embeddings
+             cond_data = context.conditioning.load(cond.conditioning_name)
+             assert len(cond_data.conditionings) == 1
+             z_image_conditioning = cond_data.conditionings[0]
+             assert isinstance(z_image_conditioning, ZImageConditioningInfo)
+             z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
+             prompt_embeds = z_image_conditioning.prompt_embeds
+
+             # Load the mask, if provided
+             mask: torch.Tensor | None = None
+             if cond.mask is not None:
+                 mask = context.tensors.load(cond.mask.tensor_name)
+                 mask = mask.to(device=device)
+                 mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
+                     mask, img_height, img_width, dtype, device
+                 )
+
+             text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
+
+         return text_conditionings
+
+     def _get_noise(
+         self,
+         batch_size: int,
+         num_channels_latents: int,
+         height: int,
+         width: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         seed: int,
+     ) -> torch.Tensor:
+         """Generate initial noise tensor."""
+         # Generate noise as float32 on CPU for maximum compatibility,
+         # then cast to target dtype/device
+         rand_device = "cpu"
+         rand_dtype = torch.float32
+
+         return torch.randn(
+             batch_size,
+             num_channels_latents,
+             int(height) // LATENT_SCALE_FACTOR,
+             int(width) // LATENT_SCALE_FACTOR,
+             device=rand_device,
+             dtype=rand_dtype,
+             generator=torch.Generator(device=rand_device).manual_seed(seed),
+         ).to(device=device, dtype=dtype)
+
+     def _calculate_shift(
+         self,
+         image_seq_len: int,
+         base_image_seq_len: int = 256,
+         max_image_seq_len: int = 4096,
+         base_shift: float = 0.5,
+         max_shift: float = 1.15,
+     ) -> float:
+         """Calculate timestep shift based on image sequence length.
+
+         Based on diffusers ZImagePipeline.calculate_shift method.
+         """
+         m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
+         b = base_shift - m * base_image_seq_len
+         mu = image_seq_len * m + b
+         return mu
+
+     def _get_sigmas(self, mu: float, num_steps: int) -> list[float]:
+         """Generate sigma schedule with time shift.
+
+         Based on FlowMatchEulerDiscreteScheduler with shift.
+         Generates num_steps + 1 sigma values (including terminal 0.0).
+         """
+         import math
+
+         def time_shift(mu: float, sigma: float, t: float) -> float:
+             """Apply time shift to a single timestep value."""
+             if t <= 0:
+                 return 0.0
+             if t >= 1:
+                 return 1.0
+             return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+         # Generate linearly spaced values from 1 to 0 (excluding endpoints for safety)
+         # then apply time shift
+         sigmas = []
+         for i in range(num_steps + 1):
+             t = 1.0 - i / num_steps  # Goes from 1.0 to 0.0
+             sigma = time_shift(mu, 1.0, t)
+             sigmas.append(sigma)
+
+         return sigmas
+
+     def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
+         device = TorchDevice.choose_torch_device()
+         inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
+
+         transformer_info = context.models.load(self.transformer.transformer)
+
+         # Calculate image token grid dimensions
+         patch_size = 2  # Z-Image uses patch_size=2
+         latent_height = self.height // LATENT_SCALE_FACTOR
+         latent_width = self.width // LATENT_SCALE_FACTOR
+         img_token_height = latent_height // patch_size
+         img_token_width = latent_width // patch_size
+         img_seq_len = img_token_height * img_token_width
+
+         # Load positive conditioning with regional masks
+         pos_text_conditionings = self._load_text_conditioning(
+             context=context,
+             cond_field=self.positive_conditioning,
+             img_height=img_token_height,
+             img_width=img_token_width,
+             dtype=inference_dtype,
+             device=device,
+         )
+
+         # Create regional prompting extension
+         regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
+             text_conditionings=pos_text_conditionings,
+             img_seq_len=img_seq_len,
+         )
+
+         # Get the concatenated prompt embeddings for the transformer
+         pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds
+
+         # Load negative conditioning if provided and guidance_scale != 1.0
+         # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
+         # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
+         # This matches FLUX's convention where 1.0 means "no CFG"
+         neg_prompt_embeds: torch.Tensor | None = None
+         do_classifier_free_guidance = (
+             not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
+         )
+         if do_classifier_free_guidance:
+             assert self.negative_conditioning is not None
+             # Load all negative conditionings and concatenate embeddings
+             # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
+             neg_text_conditionings = self._load_text_conditioning(
+                 context=context,
+                 cond_field=self.negative_conditioning,
+                 img_height=img_token_height,
+                 img_width=img_token_width,
+                 dtype=inference_dtype,
+                 device=device,
+             )
+             # Concatenate all negative embeddings
+             neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
+
+         # Calculate shift based on image sequence length
+         mu = self._calculate_shift(img_seq_len)
+
+         # Generate sigma schedule with time shift
+         sigmas = self._get_sigmas(mu, self.steps)
+
+         # Apply denoising_start and denoising_end clipping
+         if self.denoising_start > 0 or self.denoising_end < 1:
+             # Calculate start and end indices based on denoising range
+             total_sigmas = len(sigmas)
+             start_idx = int(self.denoising_start * (total_sigmas - 1))
+             end_idx = int(self.denoising_end * (total_sigmas - 1)) + 1
+             sigmas = sigmas[start_idx:end_idx]
+
+         total_steps = len(sigmas) - 1
+
+         # Load input latents if provided (image-to-image)
+         init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
+         if init_latents is not None:
+             init_latents = init_latents.to(device=device, dtype=inference_dtype)
+
+         # Generate initial noise
+         num_channels_latents = 16  # Z-Image uses 16 latent channels
+         noise = self._get_noise(
+             batch_size=1,
+             num_channels_latents=num_channels_latents,
+             height=self.height,
+             width=self.width,
+             dtype=inference_dtype,
+             device=device,
+             seed=self.seed,
+         )
+
+         # Prepare input latent image
+         if init_latents is not None:
+             s_0 = sigmas[0]
+             latents = s_0 * noise + (1.0 - s_0) * init_latents
+         else:
+             if self.denoising_start > 1e-5:
+                 raise ValueError("denoising_start should be 0 when initial latents are not provided.")
+             latents = noise
+
+         # Short-circuit if no denoising steps
+         if total_steps <= 0:
+             return latents
+
+         # Prepare inpaint extension
+         inpaint_mask = self._prep_inpaint_mask(context, latents)
+         inpaint_extension: RectifiedFlowInpaintExtension | None = None
+         if inpaint_mask is not None:
+             if init_latents is None:
+                 raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
+             inpaint_extension = RectifiedFlowInpaintExtension(
+                 init_latents=init_latents,
+                 inpaint_mask=inpaint_mask,
+                 noise=noise,
+             )
+
+         step_callback = self._build_step_callback(context)
+         step_callback(
+             PipelineIntermediateState(
+                 step=0,
+                 order=1,
+                 total_steps=total_steps,
+                 timestep=int(sigmas[0] * 1000),
+                 latents=latents,
+             ),
+         )
+
+         with ExitStack() as exit_stack:
+             # Get transformer config to determine if it's quantized
+             transformer_config = context.models.get_config(self.transformer.transformer)
+
+             # Determine if the model is quantized.
+             # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
+             # slower inference than direct patching, but is agnostic to the quantization format.
+             if transformer_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
+                 model_is_quantized = False
+             elif transformer_config.format in [ModelFormat.GGUFQuantized]:
+                 model_is_quantized = True
+             else:
+                 raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")
+
+             # Load transformer - always use base transformer, control is handled via extension
+             (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+
+             # Prepare control extension if control is provided
+             control_extension: ZImageControlNetExtension | None = None
+
+             if self.control is not None:
+                 # Load control adapter using context manager (proper GPU memory management)
+                 control_model_info = context.models.load(self.control.control_model)
+                 (_, control_adapter) = exit_stack.enter_context(control_model_info.model_on_device())
+                 assert isinstance(control_adapter, ZImageControlAdapter)
+
+                 # Get control_in_dim from adapter config (16 for V1, 33 for V2.0)
+                 adapter_config = control_adapter.config
+                 control_in_dim = adapter_config.get("control_in_dim", 16)
+                 num_control_blocks = adapter_config.get("num_control_blocks", 6)
+
+                 # Log control configuration for debugging
+                 version = "V2.0" if control_in_dim > 16 else "V1"
+                 context.util.signal_progress(
+                     f"Using Z-Image ControlNet {version} (Extension): control_in_dim={control_in_dim}, "
+                     f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
+                 )
+
+                 # Load and prepare control image - must be VAE-encoded!
+                 if self.vae is None:
+                     raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
+
+                 control_image = context.images.get_pil(self.control.image_name)
+
+                 # Resize control image to match output dimensions
+                 control_image = control_image.convert("RGB")
+                 control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)
+
+                 # Convert to tensor format for VAE encoding
+                 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+                 control_image_tensor = image_resized_to_grid_as_tensor(control_image)
+                 if control_image_tensor.dim() == 3:
+                     control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")
+
+                 # Encode control image through VAE to get latents
+                 vae_info = context.models.load(self.vae.vae)
+                 control_latents = ZImageImageToLatentsInvocation.vae_encode(
+                     vae_info=vae_info,
+                     image_tensor=control_image_tensor,
+                 )
+
+                 # Move to inference device/dtype
+                 control_latents = control_latents.to(device=device, dtype=inference_dtype)
+
+                 # Add frame dimension: [B, C, H, W] -> [C, 1, H, W] (single image)
+                 control_latents = control_latents.squeeze(0).unsqueeze(1)
+
+                 # Prepare control_cond based on control_in_dim
+                 # V1: 16 channels (just control latents)
+                 # V2.0: 33 channels = 16 control + 16 reference + 1 mask
+                 #   - Channels 0-15: control image latents (from VAE encoding)
+                 #   - Channels 16-31: reference/inpaint image latents (zeros for pure control)
+                 #   - Channel 32: inpaint mask (1.0 = don't inpaint, 0.0 = inpaint region)
+                 # For pure control (no inpainting), we set mask=1 to tell model "use control, don't inpaint"
+                 c, f, h, w = control_latents.shape
+                 if c < control_in_dim:
+                     padding_channels = control_in_dim - c
+                     if padding_channels == 17:
+                         # V2.0: 16 reference channels (zeros) + 1 mask channel (ones)
+                         ref_padding = torch.zeros(
+                             (16, f, h, w),
+                             device=device,
+                             dtype=inference_dtype,
+                         )
+                         # Mask channel = 1.0 means "don't inpaint this region, use control signal"
+                         mask_channel = torch.ones(
+                             (1, f, h, w),
+                             device=device,
+                             dtype=inference_dtype,
+                         )
+                         control_latents = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
+                     else:
+                         # Generic padding with zeros for other cases
+                         zero_padding = torch.zeros(
+                             (padding_channels, f, h, w),
+                             device=device,
+                             dtype=inference_dtype,
+                         )
+                         control_latents = torch.cat([control_latents, zero_padding], dim=0)
+
+                 # Create control extension (adapter is already on device from model_on_device)
+                 control_extension = ZImageControlNetExtension(
+                     control_adapter=control_adapter,
+                     control_cond=control_latents,
+                     weight=self.control.control_context_scale,
+                     begin_step_percent=self.control.begin_step_percent,
+                     end_step_percent=self.control.end_step_percent,
+                 )
+
+             # Apply LoRA models to the transformer.
+             # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+             exit_stack.enter_context(
+                 LayerPatcher.apply_smart_model_patches(
+                     model=transformer,
+                     patches=self._lora_iterator(context),
+                     prefix=Z_IMAGE_LORA_TRANSFORMER_PREFIX,
+                     dtype=inference_dtype,
+                     cached_weights=cached_weights,
+                     force_sidecar_patching=model_is_quantized,
+                 )
+             )
+
+             # Apply regional prompting patch if we have regional masks
+             exit_stack.enter_context(
+                 patch_transformer_for_regional_prompting(
+                     transformer=transformer,
+                     regional_attn_mask=regional_extension.regional_attn_mask,
+                     img_seq_len=img_seq_len,
+                 )
+             )
+
+             # Denoising loop
+             for step_idx in tqdm(range(total_steps)):
+                 sigma_curr = sigmas[step_idx]
+                 sigma_prev = sigmas[step_idx + 1]
+
+                 # Timestep tensor for Z-Image model
+                 # The model expects t=0 at start (noise) and t=1 at end (clean)
+                 # Sigma goes from 1 (noise) to 0 (clean), so model_t = 1 - sigma
+                 model_t = 1.0 - sigma_curr
+                 timestep = torch.tensor([model_t], device=device, dtype=inference_dtype).expand(latents.shape[0])
+
+                 # Run transformer for positive prediction
+                 # Z-Image transformer expects: x as list of [C, 1, H, W] tensors, t, cap_feats as list
+                 # Prepare latent input: [B, C, H, W] -> [B, C, 1, H, W] -> list of [C, 1, H, W]
+                 latent_model_input = latents.to(transformer.dtype)
+                 latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
+                 latent_model_input_list = list(latent_model_input.unbind(dim=0))
+
+                 # Determine if control should be applied at this step
+                 apply_control = control_extension is not None and control_extension.should_apply(step_idx, total_steps)
+
+                 # Run forward pass - use custom forward with control if extension is active
+                 if apply_control:
+                     model_out_list, _ = z_image_forward_with_control(
+                         transformer=transformer,
+                         x=latent_model_input_list,
+                         t=timestep,
+                         cap_feats=[pos_prompt_embeds],
+                         control_extension=control_extension,
+                     )
+                 else:
+                     model_output = transformer(
+                         x=latent_model_input_list,
+                         t=timestep,
+                         cap_feats=[pos_prompt_embeds],
+                     )
+                     model_out_list = model_output[0]  # Extract list of tensors from tuple
+
+                 noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
+                 noise_pred_cond = noise_pred_cond.squeeze(2)  # Remove frame dimension
+                 noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation
+
+                 # Apply CFG if enabled
+                 if do_classifier_free_guidance and neg_prompt_embeds is not None:
+                     if apply_control:
+                         model_out_list_uncond, _ = z_image_forward_with_control(
+                             transformer=transformer,
+                             x=latent_model_input_list,
+                             t=timestep,
+                             cap_feats=[neg_prompt_embeds],
+                             control_extension=control_extension,
+                         )
+                     else:
+                         model_output_uncond = transformer(
+                             x=latent_model_input_list,
+                             t=timestep,
+                             cap_feats=[neg_prompt_embeds],
+                         )
+                         model_out_list_uncond = model_output_uncond[0]  # Extract list of tensors from tuple
+
+                     noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
+                     noise_pred_uncond = noise_pred_uncond.squeeze(2)
+                     noise_pred_uncond = -noise_pred_uncond
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+                 else:
+                     noise_pred = noise_pred_cond
+
+                 # Euler step
+                 latents_dtype = latents.dtype
+                 latents = latents.to(dtype=torch.float32)
+                 latents = latents + (sigma_prev - sigma_curr) * noise_pred
+                 latents = latents.to(dtype=latents_dtype)
+
+                 if inpaint_extension is not None:
+                     latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)
+
+                 step_callback(
+                     PipelineIntermediateState(
+                         step=step_idx + 1,
+                         order=1,
+                         total_steps=total_steps,
+                         timestep=int(sigma_curr * 1000),
+                         latents=latents,
+                     ),
+                 )
+
+         return latents
+
+     def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+         def step_callback(state: PipelineIntermediateState) -> None:
+             context.util.sd_step_callback(state, BaseModelType.ZImage)
+
+         return step_callback
+
+     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+         """Iterate over LoRA models to apply to the transformer."""
+         for lora in self.transformer.loras:
+             lora_info = context.models.load(lora.lora)
+             if not isinstance(lora_info.model, ModelPatchRaw):
+                 raise TypeError(
+                     f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
+                     "The LoRA model may be corrupted or incompatible."
+                 )
+             yield (lora_info.model, lora.weight)
+             del lora_info
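
For reference, the scheduling and update math introduced by this node is small enough to check in isolation. The sketch below reimplements the shifted flow-matching sigma schedule (`_calculate_shift` / `_get_sigmas`) and the per-step Euler update as standalone functions. It is illustrative only and is not part of the wheel; the helper names are mine, and it assumes the node's defaults (base_shift=0.5, max_shift=1.15, patch_size=2, 8 steps) and a latent scale factor of 8.

import math

def calculate_shift(image_seq_len: int, base_len: int = 256, max_len: int = 4096,
                    base_shift: float = 0.5, max_shift: float = 1.15) -> float:
    # Linearly interpolate the shift parameter mu against the number of image tokens,
    # mirroring _calculate_shift above.
    m = (max_shift - base_shift) / (max_len - base_len)
    return image_seq_len * m + (base_shift - m * base_len)

def shifted_sigmas(mu: float, num_steps: int) -> list[float]:
    # Mirror _get_sigmas: t runs linearly from 1.0 to 0.0 over num_steps + 1 points,
    # then each interior point is warped by exp(mu) / (exp(mu) + (1/t - 1)).
    def time_shift(t: float) -> float:
        if t <= 0.0:
            return 0.0
        if t >= 1.0:
            return 1.0
        return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0))

    return [time_shift(1.0 - i / num_steps) for i in range(num_steps + 1)]

def euler_step(latents, noise_pred, sigma_curr: float, sigma_prev: float):
    # One step of the denoising loop: noise_pred is the (optionally CFG-combined,
    # negated) model output, and the latents move by the sigma difference.
    return latents + (sigma_prev - sigma_curr) * noise_pred

# Example: a 1024x1024 request -> 128x128 latent (scale factor 8, an assumption)
# -> 64x64 token grid (patch_size=2) -> 4096 image tokens, which lands at max_shift.
mu = calculate_shift(64 * 64)
sigmas = shifted_sigmas(mu, num_steps=8)
print([round(s, 3) for s in sigmas])  # 9 values, decreasing from 1.0 to 0.0

The same packing logic applies to the V2.0 control conditioning described in the comments above: in a hypothetical pure-control case, the 16 VAE-encoded control channels are concatenated with 16 zeroed reference channels and a mask channel of ones to reach control_in_dim=33.

import torch

control_latents = torch.zeros(16, 1, 128, 128)  # [C, F, H, W] control image latents (placeholder values)
ref_padding = torch.zeros(16, 1, 128, 128)      # reference/inpaint latents: zeros when not inpainting
mask_channel = torch.ones(1, 1, 128, 128)       # 1.0 = "don't inpaint, follow the control signal"
control_cond = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
assert control_cond.shape[0] == 33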