InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/dependencies.py +2 -0
- invokeai/app/api/routers/model_manager.py +91 -2
- invokeai/app/api/routers/workflows.py +9 -0
- invokeai/app/invocations/fields.py +19 -0
- invokeai/app/invocations/image_to_latents.py +23 -5
- invokeai/app/invocations/latents_to_image.py +2 -25
- invokeai/app/invocations/metadata.py +9 -1
- invokeai/app/invocations/model.py +8 -0
- invokeai/app/invocations/primitives.py +12 -0
- invokeai/app/invocations/prompt_template.py +57 -0
- invokeai/app/invocations/z_image_control.py +112 -0
- invokeai/app/invocations/z_image_denoise.py +610 -0
- invokeai/app/invocations/z_image_image_to_latents.py +102 -0
- invokeai/app/invocations/z_image_latents_to_image.py +103 -0
- invokeai/app/invocations/z_image_lora_loader.py +153 -0
- invokeai/app/invocations/z_image_model_loader.py +135 -0
- invokeai/app/invocations/z_image_text_encoder.py +197 -0
- invokeai/app/services/model_install/model_install_common.py +14 -1
- invokeai/app/services/model_install/model_install_default.py +119 -19
- invokeai/app/services/model_records/model_records_base.py +12 -0
- invokeai/app/services/model_records/model_records_sql.py +17 -0
- invokeai/app/services/shared/graph.py +132 -77
- invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
- invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
- invokeai/app/util/step_callback.py +3 -0
- invokeai/backend/model_manager/configs/controlnet.py +47 -1
- invokeai/backend/model_manager/configs/factory.py +26 -1
- invokeai/backend/model_manager/configs/lora.py +43 -1
- invokeai/backend/model_manager/configs/main.py +113 -0
- invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
- invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
- invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
- invokeai/backend/model_manager/load/model_util.py +6 -1
- invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
- invokeai/backend/model_manager/model_on_disk.py +3 -0
- invokeai/backend/model_manager/starter_models.py +70 -0
- invokeai/backend/model_manager/taxonomy.py +5 -0
- invokeai/backend/model_manager/util/select_hf_files.py +23 -8
- invokeai/backend/patches/layer_patcher.py +34 -16
- invokeai/backend/patches/layers/lora_layer_base.py +2 -1
- invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
- invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
- invokeai/backend/patches/lora_conversions/formats.py +5 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
- invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
- invokeai/backend/quantization/gguf/loaders.py +47 -12
- invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
- invokeai/backend/util/devices.py +25 -0
- invokeai/backend/util/hotfixes.py +2 -2
- invokeai/backend/z_image/__init__.py +16 -0
- invokeai/backend/z_image/extensions/__init__.py +1 -0
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
- invokeai/backend/z_image/text_conditioning.py +74 -0
- invokeai/backend/z_image/z_image_control_adapter.py +238 -0
- invokeai/backend/z_image/z_image_control_transformer.py +643 -0
- invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
- invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
- invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/de.json +24 -6
- invokeai/frontend/web/dist/locales/en.json +70 -1
- invokeai/frontend/web/dist/locales/es.json +0 -5
- invokeai/frontend/web/dist/locales/fr.json +0 -6
- invokeai/frontend/web/dist/locales/it.json +17 -64
- invokeai/frontend/web/dist/locales/ja.json +379 -44
- invokeai/frontend/web/dist/locales/ru.json +0 -6
- invokeai/frontend/web/dist/locales/vi.json +7 -54
- invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
- invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
- invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from contextlib import ExitStack
|
|
3
|
+
from typing import Callable, Iterator, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import einops
|
|
6
|
+
import torch
|
|
7
|
+
import torchvision.transforms as tv_transforms
|
|
8
|
+
from PIL import Image
|
|
9
|
+
from torchvision.transforms.functional import resize as tv_resize
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
|
13
|
+
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
|
14
|
+
from invokeai.app.invocations.fields import (
|
|
15
|
+
DenoiseMaskField,
|
|
16
|
+
FieldDescriptions,
|
|
17
|
+
Input,
|
|
18
|
+
InputField,
|
|
19
|
+
LatentsField,
|
|
20
|
+
ZImageConditioningField,
|
|
21
|
+
)
|
|
22
|
+
from invokeai.app.invocations.model import TransformerField, VAEField
|
|
23
|
+
from invokeai.app.invocations.primitives import LatentsOutput
|
|
24
|
+
from invokeai.app.invocations.z_image_control import ZImageControlField
|
|
25
|
+
from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
|
|
26
|
+
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
27
|
+
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
|
|
28
|
+
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
|
29
|
+
from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_TRANSFORMER_PREFIX
|
|
30
|
+
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
|
31
|
+
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
|
|
32
|
+
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
|
33
|
+
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
|
|
34
|
+
from invokeai.backend.util.devices import TorchDevice
|
|
35
|
+
from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
|
|
36
|
+
from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
|
|
37
|
+
from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
|
|
38
|
+
from invokeai.backend.z_image.z_image_controlnet_extension import (
|
|
39
|
+
ZImageControlNetExtension,
|
|
40
|
+
z_image_forward_with_control,
|
|
41
|
+
)
|
|
42
|
+
from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@invocation(
    "z_image_denoise",
    title="Denoise - Z-Image",
    tags=["image", "z-image"],
    category="image",
    version="1.2.0",
    classification=Classification.Prototype,
)
class ZImageDenoiseInvocation(BaseInvocation):
    """Run the denoising process with a Z-Image model.

    Supports regional prompting by connecting multiple conditioning inputs with masks.
    """

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None, description=FieldDescriptions.latents, input=Input.Connection
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
    )
    denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
    )
    positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
        default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
    )
    # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
    guidance_scale: float = InputField(
        default=1.0,
        ge=1.0,
        description="Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). "
        "Values > 1.0 amplify guidance.",
        title="Guidance Scale",
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    # Z-Image-Turbo uses 8 steps by default
    steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
    # Z-Image Control support
    control: Optional[ZImageControlField] = InputField(
        default=None,
        description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
        input=Input.Connection,
    )
    # VAE for encoding control images (required when using control)
    vae: Optional[VAEField] = InputField(
        default=None,
        description=FieldDescriptions.vae + " Required for control conditioning.",
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run denoising, move the result to CPU, save it, and return a LatentsOutput that references it."""
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask.

        Loads ``self.denoise_mask``, inverts it, and resizes it bilinearly to the spatial size of ``latents``
        (its last two dimensions). It is then cast to the latents' device and dtype.

        Returns:
            The prepared mask, or None when no denoise mask is connected.
        """
        if self.denoise_mask is None:
            return None
        mask = context.tensors.load(self.denoise_mask.mask_name)

        # Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve
        mask = 1.0 - mask

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)
        return mask

    def _load_text_conditioning(
        self,
        context: InvocationContext,
        cond_field: ZImageConditioningField | list[ZImageConditioningField],
        img_height: int,
        img_width: int,
        dtype: torch.dtype,
        device: torch.device,
        include_masks: bool = True,
    ) -> list[ZImageTextConditioning]:
        """Load Z-Image text conditioning with optional regional masks.

        Args:
            context: The invocation context.
            cond_field: Single conditioning field or list of fields.
            img_height: Height of the image token grid (H // patch_size).
            img_width: Width of the image token grid (W // patch_size).
            dtype: Target dtype.
            device: Target device.
            include_masks: When False, regional masks are neither loaded nor preprocessed. Every returned
                conditioning then has ``mask=None``. Use this when only the prompt embeddings are needed.

        Returns:
            List of ZImageTextConditioning objects with embeddings and (optionally) masks.
        """
        # Normalize to a list
        cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field

        text_conditionings: list[ZImageTextConditioning] = []
        for cond in cond_list:
            # Load the text embeddings; each stored conditioning is expected to hold exactly one entry.
            cond_data = context.conditioning.load(cond.conditioning_name)
            assert len(cond_data.conditionings) == 1
            z_image_conditioning = cond_data.conditionings[0]
            assert isinstance(z_image_conditioning, ZImageConditioningInfo)
            z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
            prompt_embeds = z_image_conditioning.prompt_embeds

            # Load the mask, if provided and requested
            mask: torch.Tensor | None = None
            if include_masks and cond.mask is not None:
                mask = context.tensors.load(cond.mask.tensor_name)
                mask = mask.to(device=device)
                mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
                    mask, img_height, img_width, dtype, device
                )

            text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))

        return text_conditionings

    def _get_noise(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int,
    ) -> torch.Tensor:
        """Generate the initial noise tensor.

        The shape is ``(batch_size, num_channels_latents, height // LATENT_SCALE_FACTOR, width // LATENT_SCALE_FACTOR)``.
        """
        # Generate noise as float32 on CPU for maximum compatibility,
        # then cast to target dtype/device
        rand_device = "cpu"
        rand_dtype = torch.float32

        return torch.randn(
            batch_size,
            num_channels_latents,
            int(height) // LATENT_SCALE_FACTOR,
            int(width) // LATENT_SCALE_FACTOR,
            device=rand_device,
            dtype=rand_dtype,
            generator=torch.Generator(device=rand_device).manual_seed(seed),
        ).to(device=device, dtype=dtype)

    def _calculate_shift(
        self,
        image_seq_len: int,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
    ) -> float:
        """Calculate the timestep shift ``mu`` from the image sequence length.

        ``mu`` is linear in ``image_seq_len``. It passes through ``(base_image_seq_len, base_shift)`` and
        ``(max_image_seq_len, max_shift)``, and it extrapolates linearly outside that range.

        Based on diffusers ZImagePipeline.calculate_shift method.
        """
        m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
        b = base_shift - m * base_image_seq_len
        mu = image_seq_len * m + b
        return mu

    def _get_sigmas(self, mu: float, num_steps: int) -> list[float]:
        """Generate the sigma schedule with time shift.

        Based on FlowMatchEulerDiscreteScheduler with shift.
        Generates ``num_steps + 1`` sigma values, from 1.0 down to the terminal 0.0 (both endpoints included).
        """

        def time_shift(mu: float, sigma: float, t: float) -> float:
            """Apply the exponential time shift to a single timestep value in [0, 1]."""
            # The endpoints are clamped explicitly: at t == 0 the formula would divide by zero.
            if t <= 0:
                return 0.0
            if t >= 1:
                return 1.0
            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

        # Linearly spaced t from 1.0 to 0.0 (endpoints included), then time-shifted.
        sigmas = []
        for i in range(num_steps + 1):
            t = 1.0 - i / num_steps  # Goes from 1.0 to 0.0
            sigmas.append(time_shift(mu, 1.0, t))

        return sigmas

    @staticmethod
    def _pad_control_latents(
        control_latents: torch.Tensor,
        control_in_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Pad ``[C, F, H, W]`` control latents along the channel dim up to ``control_in_dim`` channels.

        V1 adapters take only the control latents. V2.0 adapters (33 channels) expect the following layout:
          - Channels 0-15: control image latents (from VAE encoding)
          - Channels 16-31: reference/inpaint image latents (zeros for pure control)
          - Channel 32: inpaint mask (1.0 = don't inpaint, 0.0 = inpaint region)
        For pure control (no inpainting), the mask is set to 1 to tell the model "use control, don't inpaint".
        Any other shortfall is filled with zeros. Latents that already have enough channels are returned unchanged.
        """
        c, f, h, w = control_latents.shape
        padding_channels = control_in_dim - c
        if padding_channels <= 0:
            return control_latents

        if padding_channels == 17:
            # V2.0: 16 reference channels (zeros) + 1 mask channel (ones)
            ref_padding = torch.zeros((16, f, h, w), device=device, dtype=dtype)
            # Mask channel = 1.0 means "don't inpaint this region, use control signal"
            mask_channel = torch.ones((1, f, h, w), device=device, dtype=dtype)
            return torch.cat([control_latents, ref_padding, mask_channel], dim=0)

        # Generic padding with zeros for other cases
        zero_padding = torch.zeros((padding_channels, f, h, w), device=device, dtype=dtype)
        return torch.cat([control_latents, zero_padding], dim=0)

    def _prepare_control_extension(
        self,
        context: InvocationContext,
        control: ZImageControlField,
        exit_stack: ExitStack,
        device: torch.device,
        inference_dtype: torch.dtype,
    ) -> ZImageControlNetExtension:
        """Load the control adapter, VAE-encode the control image, and build the control extension.

        The adapter is moved to the device via ``exit_stack``, so it stays there until the stack unwinds.

        Raises:
            ValueError: If no VAE is connected (control images must be VAE-encoded).
        """
        # Validate before loading the adapter so a missing VAE fails fast without moving any weights.
        vae = self.vae
        if vae is None:
            raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")

        # Load control adapter using context manager (proper GPU memory management)
        control_model_info = context.models.load(control.control_model)
        (_, control_adapter) = exit_stack.enter_context(control_model_info.model_on_device())
        assert isinstance(control_adapter, ZImageControlAdapter)

        # Get control_in_dim from adapter config (16 for V1, 33 for V2.0)
        adapter_config = control_adapter.config
        control_in_dim = adapter_config.get("control_in_dim", 16)
        num_control_blocks = adapter_config.get("num_control_blocks", 6)

        # Log control configuration for debugging
        version = "V2.0" if control_in_dim > 16 else "V1"
        context.util.signal_progress(
            f"Using Z-Image ControlNet {version} (Extension): control_in_dim={control_in_dim}, "
            f"num_blocks={num_control_blocks}, scale={control.control_context_scale}"
        )

        # Load the control image and resize it to the output dimensions.
        control_image = context.images.get_pil(control.image_name)
        control_image = control_image.convert("RGB")
        control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)

        # Convert to tensor format for VAE encoding
        from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

        control_image_tensor = image_resized_to_grid_as_tensor(control_image)
        if control_image_tensor.dim() == 3:
            control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")

        # Encode control image through VAE to get latents
        vae_info = context.models.load(vae.vae)
        control_latents = ZImageImageToLatentsInvocation.vae_encode(
            vae_info=vae_info,
            image_tensor=control_image_tensor,
        )
        control_latents = control_latents.to(device=device, dtype=inference_dtype)

        # Add frame dimension: [B, C, H, W] -> [C, 1, H, W] (single image)
        control_latents = control_latents.squeeze(0).unsqueeze(1)
        control_latents = self._pad_control_latents(control_latents, control_in_dim, device, inference_dtype)

        # Create control extension (adapter is already on device from model_on_device)
        return ZImageControlNetExtension(
            control_adapter=control_adapter,
            control_cond=control_latents,
            weight=control.control_context_scale,
            begin_step_percent=control.begin_step_percent,
            end_step_percent=control.end_step_percent,
        )

    def _predict_velocity(
        self,
        transformer,
        latent_model_input_list: list[torch.Tensor],
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
        control_extension: ZImageControlNetExtension | None,
    ) -> torch.Tensor:
        """Run one transformer forward pass and return the prediction as a float32 ``[B, C, H, W]`` tensor.

        The Z-Image transformer takes ``x`` as a list of ``[C, 1, H, W]`` tensors, ``t``, and ``cap_feats`` as a
        list. It returns a tuple whose first element is the list of per-sample outputs. When
        ``control_extension`` is not None, ``z_image_forward_with_control`` runs the pass instead.
        """
        if control_extension is not None:
            model_out_list, _ = z_image_forward_with_control(
                transformer=transformer,
                x=latent_model_input_list,
                t=timestep,
                cap_feats=[prompt_embeds],
                control_extension=control_extension,
            )
        else:
            model_output = transformer(
                x=latent_model_input_list,
                t=timestep,
                cap_feats=[prompt_embeds],
            )
            model_out_list = model_output[0]  # Extract list of tensors from tuple

        prediction = torch.stack([out.float() for out in model_out_list], dim=0)
        prediction = prediction.squeeze(2)  # Remove frame dimension
        return -prediction  # Z-Image uses v-prediction with negation

    def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
        """Run the full Z-Image denoising loop and return the final latents (on the inference device)."""
        device = TorchDevice.choose_torch_device()
        inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)

        transformer_info = context.models.load(self.transformer.transformer)

        # Calculate image token grid dimensions
        patch_size = 2  # Z-Image uses patch_size=2
        latent_height = self.height // LATENT_SCALE_FACTOR
        latent_width = self.width // LATENT_SCALE_FACTOR
        img_token_height = latent_height // patch_size
        img_token_width = latent_width // patch_size
        img_seq_len = img_token_height * img_token_width

        # Load positive conditioning with regional masks
        pos_text_conditionings = self._load_text_conditioning(
            context=context,
            cond_field=self.positive_conditioning,
            img_height=img_token_height,
            img_width=img_token_width,
            dtype=inference_dtype,
            device=device,
        )

        # Create regional prompting extension
        regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
            text_conditionings=pos_text_conditionings,
            img_seq_len=img_seq_len,
        )

        # Get the concatenated prompt embeddings for the transformer
        pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds

        # Load negative conditioning if provided and guidance_scale != 1.0
        # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
        # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
        # This matches FLUX's convention where 1.0 means "no CFG"
        neg_prompt_embeds: torch.Tensor | None = None
        do_classifier_free_guidance = (
            not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
        )
        if do_classifier_free_guidance:
            assert self.negative_conditioning is not None
            # Masks are ignored for negative conditioning (regional negative prompting is not fully supported),
            # so skip loading/preprocessing them entirely.
            neg_text_conditionings = self._load_text_conditioning(
                context=context,
                cond_field=self.negative_conditioning,
                img_height=img_token_height,
                img_width=img_token_width,
                dtype=inference_dtype,
                device=device,
                include_masks=False,
            )
            # Concatenate all negative embeddings
            neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)

        # Time-shifted sigma schedule; the shift depends on the image sequence length.
        mu = self._calculate_shift(img_seq_len)
        sigmas = self._get_sigmas(mu, self.steps)

        # Apply denoising_start and denoising_end clipping (indices are truncated toward zero).
        if self.denoising_start > 0 or self.denoising_end < 1:
            last_idx = len(sigmas) - 1
            start_idx = int(self.denoising_start * last_idx)
            end_idx = int(self.denoising_end * last_idx) + 1
            sigmas = sigmas[start_idx:end_idx]

        total_steps = len(sigmas) - 1

        # Load input latents if provided (image-to-image)
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)

        # Generate initial noise
        num_channels_latents = 16  # Z-Image uses 16 latent channels
        noise = self._get_noise(
            batch_size=1,
            num_channels_latents=num_channels_latents,
            height=self.height,
            width=self.width,
            dtype=inference_dtype,
            device=device,
            seed=self.seed,
        )

        # Prepare input latent image: linear interpolation between noise and init latents at the first sigma.
        if init_latents is not None:
            s_0 = sigmas[0]
            latents = s_0 * noise + (1.0 - s_0) * init_latents
        else:
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            latents = noise

        # Short-circuit if no denoising steps
        if total_steps <= 0:
            return latents

        # Prepare inpaint extension
        inpaint_mask = self._prep_inpaint_mask(context, latents)
        inpaint_extension: RectifiedFlowInpaintExtension | None = None
        if inpaint_mask is not None:
            if init_latents is None:
                raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
            inpaint_extension = RectifiedFlowInpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        step_callback = self._build_step_callback(context)
        step_callback(
            PipelineIntermediateState(
                step=0,
                order=1,
                total_steps=total_steps,
                timestep=int(sigmas[0] * 1000),
                latents=latents,
            ),
        )

        with ExitStack() as exit_stack:
            # Get transformer config to determine if it's quantized
            transformer_config = context.models.get_config(self.transformer.transformer)

            # Determine if the model is quantized.
            # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
            # slower inference than direct patching, but is agnostic to the quantization format.
            if transformer_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
                model_is_quantized = False
            elif transformer_config.format in [ModelFormat.GGUFQuantized]:
                model_is_quantized = True
            else:
                raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")

            # Load transformer - always use base transformer, control is handled via extension
            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())

            # Prepare control extension if control is provided
            control_extension: ZImageControlNetExtension | None = None
            if self.control is not None:
                control_extension = self._prepare_control_extension(
                    context=context,
                    control=self.control,
                    exit_stack=exit_stack,
                    device=device,
                    inference_dtype=inference_dtype,
                )

            # Apply LoRA models to the transformer.
            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=Z_IMAGE_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                    cached_weights=cached_weights,
                    force_sidecar_patching=model_is_quantized,
                )
            )

            # Apply regional prompting patch if we have regional masks
            exit_stack.enter_context(
                patch_transformer_for_regional_prompting(
                    transformer=transformer,
                    regional_attn_mask=regional_extension.regional_attn_mask,
                    img_seq_len=img_seq_len,
                )
            )

            # Denoising loop
            for step_idx in tqdm(range(total_steps)):
                sigma_curr = sigmas[step_idx]
                sigma_prev = sigmas[step_idx + 1]

                # Timestep tensor for Z-Image model
                # The model expects t=0 at start (noise) and t=1 at end (clean)
                # Sigma goes from 1 (noise) to 0 (clean), so model_t = 1 - sigma
                model_t = 1.0 - sigma_curr
                timestep = torch.tensor([model_t], device=device, dtype=inference_dtype).expand(latents.shape[0])

                # Prepare latent input: [B, C, H, W] -> [B, C, 1, H, W] -> list of [C, 1, H, W]
                latent_model_input = latents.to(transformer.dtype).unsqueeze(2)  # Add frame dimension
                latent_model_input_list = list(latent_model_input.unbind(dim=0))

                # Control is only applied within the extension's configured step window.
                step_control = (
                    control_extension
                    if control_extension is not None and control_extension.should_apply(step_idx, total_steps)
                    else None
                )

                noise_pred_cond = self._predict_velocity(
                    transformer, latent_model_input_list, timestep, pos_prompt_embeds, step_control
                )

                # Apply CFG if enabled
                if do_classifier_free_guidance and neg_prompt_embeds is not None:
                    noise_pred_uncond = self._predict_velocity(
                        transformer, latent_model_input_list, timestep, neg_prompt_embeds, step_control
                    )
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond

                # Euler step, computed in float32 and cast back to the latents' dtype.
                latents_dtype = latents.dtype
                latents = latents.to(dtype=torch.float32)
                latents = latents + (sigma_prev - sigma_curr) * noise_pred
                latents = latents.to(dtype=latents_dtype)

                if inpaint_extension is not None:
                    latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)

                step_callback(
                    PipelineIntermediateState(
                        step=step_idx + 1,
                        order=1,
                        total_steps=total_steps,
                        timestep=int(sigma_curr * 1000),
                        latents=latents,
                    ),
                )

        return latents

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        """Return a callback that forwards intermediate states to ``context.util.sd_step_callback`` as Z-Image."""

        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, BaseModelType.ZImage)

        return step_callback

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply to the transformer.

        Yields:
            ``(patch, weight)`` pairs, in the order of ``self.transformer.loras``.

        Raises:
            TypeError: If a loaded LoRA model is not a ModelPatchRaw.
        """
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
                    "The LoRA model may be corrupted or incompatible."
                )
            yield (lora_info.model, lora.weight)
            del lora_info