InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/dependencies.py +2 -0
- invokeai/app/api/routers/model_manager.py +91 -2
- invokeai/app/api/routers/workflows.py +9 -0
- invokeai/app/invocations/fields.py +19 -0
- invokeai/app/invocations/image_to_latents.py +23 -5
- invokeai/app/invocations/latents_to_image.py +2 -25
- invokeai/app/invocations/metadata.py +9 -1
- invokeai/app/invocations/model.py +8 -0
- invokeai/app/invocations/primitives.py +12 -0
- invokeai/app/invocations/prompt_template.py +57 -0
- invokeai/app/invocations/z_image_control.py +112 -0
- invokeai/app/invocations/z_image_denoise.py +610 -0
- invokeai/app/invocations/z_image_image_to_latents.py +102 -0
- invokeai/app/invocations/z_image_latents_to_image.py +103 -0
- invokeai/app/invocations/z_image_lora_loader.py +153 -0
- invokeai/app/invocations/z_image_model_loader.py +135 -0
- invokeai/app/invocations/z_image_text_encoder.py +197 -0
- invokeai/app/services/model_install/model_install_common.py +14 -1
- invokeai/app/services/model_install/model_install_default.py +119 -19
- invokeai/app/services/model_records/model_records_base.py +12 -0
- invokeai/app/services/model_records/model_records_sql.py +17 -0
- invokeai/app/services/shared/graph.py +132 -77
- invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
- invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
- invokeai/app/util/step_callback.py +3 -0
- invokeai/backend/model_manager/configs/controlnet.py +47 -1
- invokeai/backend/model_manager/configs/factory.py +26 -1
- invokeai/backend/model_manager/configs/lora.py +43 -1
- invokeai/backend/model_manager/configs/main.py +113 -0
- invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
- invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
- invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
- invokeai/backend/model_manager/load/model_util.py +6 -1
- invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
- invokeai/backend/model_manager/model_on_disk.py +3 -0
- invokeai/backend/model_manager/starter_models.py +70 -0
- invokeai/backend/model_manager/taxonomy.py +5 -0
- invokeai/backend/model_manager/util/select_hf_files.py +23 -8
- invokeai/backend/patches/layer_patcher.py +34 -16
- invokeai/backend/patches/layers/lora_layer_base.py +2 -1
- invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
- invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
- invokeai/backend/patches/lora_conversions/formats.py +5 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
- invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
- invokeai/backend/quantization/gguf/loaders.py +47 -12
- invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
- invokeai/backend/util/devices.py +25 -0
- invokeai/backend/util/hotfixes.py +2 -2
- invokeai/backend/z_image/__init__.py +16 -0
- invokeai/backend/z_image/extensions/__init__.py +1 -0
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
- invokeai/backend/z_image/text_conditioning.py +74 -0
- invokeai/backend/z_image/z_image_control_adapter.py +238 -0
- invokeai/backend/z_image/z_image_control_transformer.py +643 -0
- invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
- invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
- invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/de.json +24 -6
- invokeai/frontend/web/dist/locales/en.json +70 -1
- invokeai/frontend/web/dist/locales/es.json +0 -5
- invokeai/frontend/web/dist/locales/fr.json +0 -6
- invokeai/frontend/web/dist/locales/it.json +17 -64
- invokeai/frontend/web/dist/locales/ja.json +379 -44
- invokeai/frontend/web/dist/locales/ru.json +0 -6
- invokeai/frontend/web/dist/locales/vi.json +7 -54
- invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
- invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
- invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,531 @@
|
|
|
1
|
+
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
|
2
|
+
"""Z-Image ControlNet Extension for spatial conditioning.
|
|
3
|
+
|
|
4
|
+
This module provides an extension-based approach to Z-Image ControlNet,
|
|
5
|
+
similar to how FLUX ControlNet works. Instead of duplicating the entire
|
|
6
|
+
transformer, we compute control hints separately and inject them into
|
|
7
|
+
the base transformer's forward pass.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import logging
from typing import List, Optional, Tuple

import torch
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
from torch.nn.utils.rnn import pad_sequence

from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
from invokeai.backend.z_image.z_image_patchify_utils import SEQ_MULTI_OF, patchify_control_context
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ZImageControlNetExtension:
    """Extension for Z-Image ControlNet - computes control hints without duplicating the transformer.

    This class follows the same pattern as FLUX ControlNet extensions:
    - The control adapter is loaded separately
    - Control hints are computed per step
    - Hints are injected into the transformer's layer outputs

    Attributes:
        control_adapter: The Z-Image control adapter model
        control_cond: VAE-encoded control image latents
        weight: Control strength (recommended: 0.65-0.80)
        begin_step_percent: When to start applying control (0.0 = start)
        end_step_percent: When to stop applying control (1.0 = end)
        skip_layers: Number of leading control hints that are computed but not injected
    """

    # Diagnostics go through logging instead of print() so library users control verbosity.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        control_adapter: ZImageControlAdapter,
        control_cond: torch.Tensor,
        weight: float = 0.75,
        begin_step_percent: float = 0.0,
        end_step_percent: float = 1.0,
        skip_layers: int = 0,  # Skip first N control injection layers
    ):
        self._adapter = control_adapter
        self._control_cond = control_cond
        self._weight = weight
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent
        self._skip_layers = skip_layers

        # Use the number of control blocks actually present on the loaded model (not the config):
        # the safetensors may contain more blocks than the config suggests.
        self._num_control_blocks = len(control_adapter.control_layers)

        # Control hints are injected after every other base layer (0, 2, 4, ...).
        # This matches the default configuration in the original implementation.
        self._control_places = [i * 2 for i in range(self._num_control_blocks)]

        self._logger.debug(
            "Z-Image ControlNet: %d control blocks, control_places=%s",
            self._num_control_blocks,
            self._control_places,
        )
        self._check_adapter_weights()

    def _check_adapter_weights(self) -> None:
        """Warn when the adapter has no control layers or its first after_proj looks unloaded (near-zero)."""
        if self._num_control_blocks == 0:
            # The original indexed control_layers[0] unconditionally and crashed here.
            self._logger.warning("Z-Image control adapter has no control layers; no control hints will be produced.")
            return

        first_layer = self._adapter.control_layers[0]
        if not hasattr(first_layer, "after_proj"):
            return

        # One-off sanity check at construction time; .item() syncs the device once, not per step.
        after_proj_norm = first_layer.after_proj.weight.norm().item()
        if after_proj_norm < 1e-6:
            self._logger.warning(
                "Z-Image control adapter after_proj weights are near-zero (norm=%g); "
                "weights may not be loaded correctly.",
                after_proj_norm,
            )

    @property
    def weight(self) -> float:
        """Scale applied to every injected control hint."""
        return self._weight

    @property
    def control_places(self) -> List[int]:
        """Base-transformer layer indices after which control hints are injected (0, 2, 4, ...)."""
        return self._control_places

    def should_apply(self, step_index: int, total_steps: int) -> bool:
        """Check if control should be applied at this step.

        Returns True when ``step_index / total_steps`` lies within
        ``[begin_step_percent, end_step_percent]`` (inclusive), or when ``total_steps`` is 0.
        """
        if total_steps == 0:
            return True
        step_percent = step_index / total_steps
        return self._begin_step_percent <= step_percent <= self._end_step_percent

    def _embed_and_refine_control(
        self,
        cap_feats: torch.Tensor,
        timestep_emb: torch.Tensor,
        cap_item_seqlens: List[int],
        x_freqs_cis: torch.Tensor,
        patch_size: int,
        f_patch_size: int,
    ) -> Tuple[torch.Tensor, List[int], torch.Tensor]:
        """Patchify, embed and refine the control condition, then append the caption features.

        Shared by ``prepare_control_state`` and ``compute_hints``.

        Args:
            cap_feats: Caption feature embeddings (padded); ``cap_feats.size(1)`` is passed to the patchifier.
            timestep_emb: Timestep embedding; cast to the control embedding dtype to form ``adaln_input``.
            cap_item_seqlens: Caption sequence lengths per batch item.
            x_freqs_cis: Padded RoPE frequencies of the main image path, reused for the control tokens.
            patch_size: Spatial patch size.
            f_patch_size: Frame patch size.

        Returns:
            Tuple of (control_unified, ctrl_item_seqlens, adaln_input).

        Raises:
            ValueError: If the control sequence is longer than the image RoPE table (e.g. the control
                image resolution does not match the generation resolution).
        """
        # control_cond is always a single [C, F, H, W] tensor,
        # where C = control_in_dim (16 for V1, 33 for V2.0) and F = 1 frame.
        bsz = 1
        device = self._control_cond.device

        # Position ids from the patchifier are not used: the refiner reuses the main path's
        # x_freqs_cis so that control tokens and image tokens get aligned position encodings.
        control_patches, _, _control_pos_ids, control_pad_mask = patchify_control_context(
            [self._control_cond],
            patch_size,
            f_patch_size,
            cap_feats.size(1),
        )

        ctrl_item_seqlens = [len(p) for p in control_patches]
        assert all(s % SEQ_MULTI_OF == 0 for s in ctrl_item_seqlens)
        ctrl_max_seqlen = max(ctrl_item_seqlens)

        embedder_key = f"{patch_size}-{f_patch_size}"
        control_cat = self._adapter.control_all_x_embedder[embedder_key](torch.cat(control_patches, dim=0))

        # Replace padded positions with the learned pad token (dtype-matched to the embeddings).
        adaln_input = timestep_emb.type_as(control_cat)
        control_cat[torch.cat(control_pad_mask)] = self._adapter.x_pad_token.to(dtype=control_cat.dtype)

        control_padded = pad_sequence(
            list(control_cat.split(ctrl_item_seqlens, dim=0)),
            batch_first=True,
            padding_value=0.0,
        )

        ctrl_seq_len = control_padded.shape[1]
        if ctrl_seq_len > x_freqs_cis.shape[1]:
            raise ValueError(
                f"Control sequence length ({ctrl_seq_len}) exceeds the image position-encoding length "
                f"({x_freqs_cis.shape[1]}); the control image must match the generation resolution."
            )
        ctrl_freqs_cis = x_freqs_cis[:, :ctrl_seq_len]

        ctrl_attn_mask = torch.zeros((bsz, ctrl_max_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(ctrl_item_seqlens):
            ctrl_attn_mask[i, :seq_len] = 1

        # Refine control context through control_noise_refiner.
        for layer in self._adapter.control_noise_refiner:
            control_padded = layer(control_padded, ctrl_attn_mask, ctrl_freqs_cis, adaln_input)

        # Unify control with caption features: [control tokens, caption tokens] per item.
        control_unified = pad_sequence(
            [
                torch.cat([control_padded[i][: ctrl_item_seqlens[i]], cap_feats[i][: cap_item_seqlens[i]]])
                for i in range(bsz)
            ],
            batch_first=True,
            padding_value=0.0,
        )
        return control_unified, ctrl_item_seqlens, adaln_input

    def prepare_control_state(
        self,
        base_transformer: ZImageTransformer2DModel,
        cap_feats: torch.Tensor,
        timestep_emb: torch.Tensor,
        x_item_seqlens: List[int],
        cap_item_seqlens: List[int],
        x_freqs_cis: torch.Tensor,
        patch_size: int = 2,
        f_patch_size: int = 1,
    ) -> torch.Tensor:
        """Prepare control state (control_unified) for incremental hint computation.

        This processes the control condition through patchify and noise_refiner,
        returning the control_unified tensor that will be used incrementally with
        ``compute_single_hint``. ``base_transformer`` and ``x_item_seqlens`` are accepted
        for interface compatibility and are not used.

        Side effects:
            Stores ``self._ctrl_item_seqlens`` and ``self._adaln_input`` for later use.
        """
        control_unified, ctrl_item_seqlens, adaln_input = self._embed_and_refine_control(
            cap_feats=cap_feats,
            timestep_emb=timestep_emb,
            cap_item_seqlens=cap_item_seqlens,
            x_freqs_cis=x_freqs_cis,
            patch_size=patch_size,
            f_patch_size=f_patch_size,
        )

        # Store these for compute_single_hint.
        self._ctrl_item_seqlens = ctrl_item_seqlens
        self._adaln_input = adaln_input

        self._logger.debug("Z-Image control state prepared: shape %s", tuple(control_unified.shape))
        return control_unified

    def compute_single_hint(
        self,
        control_layer_idx: int,
        control_state: torch.Tensor,
        unified_hidden_states: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute a single hint from one control layer.

        Args:
            control_layer_idx: Which control layer to use (0, 1, 2, ...)
            control_state: Current control state (stacked tensor from previous layers)
            unified_hidden_states: Current unified hidden states from main transformer
            attn_mask: Attention mask
            freqs_cis: RoPE frequencies
            adaln_input: Timestep embedding

        Returns:
            Tuple of (hint tensor, updated control_state)
        """
        layer = self._adapter.control_layers[control_layer_idx]

        # Run the control layer against the CURRENT unified hidden states.
        control_state = layer(
            control_state,
            x=unified_hidden_states,
            attn_mask=attn_mask,
            freqs_cis=freqs_cis,
            adaln_input=adaln_input,
        )

        # After a control layer the state is stacked along dim 0 as
        # [skip_0, ..., skip_n, running_state]; the newest skip is second to last.
        hint = control_state[-2]
        return hint, control_state

    def compute_hints(
        self,
        base_transformer: ZImageTransformer2DModel,
        unified_hidden_states: torch.Tensor,
        cap_feats: torch.Tensor,
        timestep_emb: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        x_item_seqlens: List[int],
        cap_item_seqlens: List[int],
        x_freqs_cis: torch.Tensor,
        patch_size: int = 2,
        f_patch_size: int = 1,
    ) -> Tuple[torch.Tensor, ...]:
        """Compute control hints using the adapter.

        This method processes the control condition through the adapter's
        control_noise_refiner and control_layers to produce hints that
        will be added to the transformer's hidden states.

        Args:
            base_transformer: The base Z-Image transformer (unused; kept for interface compatibility)
            unified_hidden_states: Combined image+caption hidden states
            cap_feats: Caption feature embeddings (padded)
            timestep_emb: Timestep embeddings (adaln_input)
            attn_mask: Unified attention mask
            freqs_cis: RoPE frequencies
            x_item_seqlens: Image sequence lengths per batch item (only logged)
            cap_item_seqlens: Caption sequence lengths per batch item
            x_freqs_cis: Padded RoPE frequencies of the image path, reused for control tokens
            patch_size: Spatial patch size
            f_patch_size: Frame patch size

        Returns:
            Tuple of hint tensors to add at each control layer position
        """
        control_unified, ctrl_item_seqlens, adaln_input = self._embed_and_refine_control(
            cap_feats=cap_feats,
            timestep_emb=timestep_emb,
            cap_item_seqlens=cap_item_seqlens,
            x_freqs_cis=x_freqs_cis,
            patch_size=patch_size,
            f_patch_size=f_patch_size,
        )

        # Shapes only: no .item() calls, so logging never forces a device sync.
        self._logger.debug(
            "Z-Image control hints: control_unified %s, unified_hidden_states %s, "
            "ctrl_item_seqlens %s, x_item_seqlens %s",
            tuple(control_unified.shape),
            tuple(unified_hidden_states.shape),
            ctrl_item_seqlens,
            x_item_seqlens,
        )

        c = control_unified
        for layer in self._adapter.control_layers:
            c = layer(
                c,
                x=unified_hidden_states,
                attn_mask=attn_mask,
                freqs_cis=freqs_cis,
                adaln_input=adaln_input,
            )

        # The stacked result is [hint_0, ..., hint_{n-1}, running_state]; drop the running state.
        hints = tuple(torch.unbind(c)[:-1])
        self._logger.debug("Z-Image control hints computed: %d hints", len(hints))
        return hints
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def z_image_forward_with_control(
    transformer: ZImageTransformer2DModel,
    x: List[torch.Tensor],
    t: torch.Tensor,
    cap_feats: List[torch.Tensor],
    control_extension: Optional[ZImageControlNetExtension] = None,
    patch_size: int = 2,
    f_patch_size: int = 1,
) -> Tuple[List[torch.Tensor], dict]:
    """Forward pass through Z-Image transformer with optional control injection.

    This function replicates the base transformer's forward pass but allows
    injecting control hints at specific layer positions. It uses the base
    transformer's weights directly without duplicating them.

    When ``control_extension`` is given, all hints are computed once from the
    unified (image + caption) hidden states *before* the main layers run. The
    hint for control block ``k`` is added, scaled by ``control_extension.weight``,
    to the output of base layer ``control_extension.control_places[k]``. The
    first ``skip_layers`` hints of the extension are not injected.

    Args:
        transformer: The base Z-Image transformer model
        x: List of image tensors [C, F, H, W]
        t: Timestep tensor
        cap_feats: List of caption feature tensors
        control_extension: Optional control extension for hint injection
        patch_size: Spatial patch size (default: 2)
        f_patch_size: Frame patch size (default: 1)

    Returns:
        Tuple of (output tensors list, empty dict for compatibility)
    """
    assert patch_size in transformer.all_patch_size
    assert f_patch_size in transformer.all_f_patch_size

    bsz = len(x)
    device = x[0].device
    t_emb = transformer.t_embedder(t * transformer.t_scale)

    # Patchify and embed using base transformer's method
    (
        x_patches,
        cap_feats_patches,
        x_size,
        x_pos_ids,
        cap_pos_ids,
        x_inner_pad_mask,
        cap_inner_pad_mask,
    ) = transformer.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

    # === X embed & refine ===
    x_item_seqlens = [len(p) for p in x_patches]
    assert all(s % SEQ_MULTI_OF == 0 for s in x_item_seqlens)
    x_max_item_seqlen = max(x_item_seqlens)

    embedder_key = f"{patch_size}-{f_patch_size}"
    x_cat = transformer.all_x_embedder[embedder_key](torch.cat(x_patches, dim=0))

    adaln_input = t_emb.type_as(x_cat)
    x_cat[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token

    x_list = list(x_cat.split(x_item_seqlens, dim=0))
    x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(p) for p in x_pos_ids], dim=0))

    x_padded = pad_sequence(x_list, batch_first=True, padding_value=0.0)
    x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
    x_freqs_cis = x_freqs_cis[:, : x_padded.shape[1]]

    x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(x_item_seqlens):
        x_attn_mask[i, :seq_len] = 1

    # Noise refiner
    for layer in transformer.noise_refiner:
        x_padded = layer(x_padded, x_attn_mask, x_freqs_cis, adaln_input)

    # === Cap embed & refine ===
    cap_item_seqlens = [len(p) for p in cap_feats_patches]
    cap_max_item_seqlen = max(cap_item_seqlens)

    cap_cat = transformer.cap_embedder(torch.cat(cap_feats_patches, dim=0))
    cap_cat[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token

    cap_list = list(cap_cat.split(cap_item_seqlens, dim=0))
    cap_freqs_cis = list(
        transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(p) for p in cap_pos_ids], dim=0)
    )

    cap_padded = pad_sequence(cap_list, batch_first=True, padding_value=0.0)
    cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
    cap_freqs_cis = cap_freqs_cis[:, : cap_padded.shape[1]]

    cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(cap_item_seqlens):
        cap_attn_mask[i, :seq_len] = 1

    # Context refiner
    for layer in transformer.context_refiner:
        cap_padded = layer(cap_padded, cap_attn_mask, cap_freqs_cis)

    # === Unified: [image tokens, caption tokens] per item ===
    unified = []
    unified_freqs_cis = []
    for i in range(bsz):
        x_len = x_item_seqlens[i]
        cap_len = cap_item_seqlens[i]
        unified.append(torch.cat([x_padded[i][:x_len], cap_padded[i][:cap_len]]))
        unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))

    unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=False)]
    unified_max_item_seqlen = max(unified_item_seqlens)

    unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
    unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)

    unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(unified_item_seqlens):
        unified_attn_mask[i, :seq_len] = 1

    # === Compute control hints if extension provided ===
    # IMPORTANT: Hints are computed ONCE using the INITIAL unified state (before main layers).
    # This matches the original VideoX-Fun architecture.
    hints: Tuple[torch.Tensor, ...] = ()
    place_to_hint: dict[int, int] = {}  # base layer index -> hint index
    control_weight = 1.0
    skip_layers = 0

    if control_extension is not None:
        hints = control_extension.compute_hints(
            base_transformer=transformer,
            unified_hidden_states=unified,  # INITIAL unified state!
            cap_feats=cap_padded,
            timestep_emb=adaln_input,
            attn_mask=unified_attn_mask,
            freqs_cis=unified_freqs_cis,
            x_item_seqlens=x_item_seqlens,
            cap_item_seqlens=cap_item_seqlens,
            x_freqs_cis=x_freqs_cis,
            patch_size=patch_size,
            f_patch_size=f_patch_size,
        )
        place_to_hint = {place: idx for idx, place in enumerate(control_extension.control_places)}
        control_weight = control_extension.weight
        skip_layers = control_extension._skip_layers

    # === Main transformer layers with pre-computed hint injection ===
    for layer_idx, layer in enumerate(transformer.layers):
        unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)

        # Inject the hint mapped to this layer, unless it is one of the first `skip_layers` hints.
        hint_idx = place_to_hint.get(layer_idx)
        if hint_idx is not None and skip_layers <= hint_idx < len(hints):
            unified = unified + hints[hint_idx] * control_weight

    # === Final layer and unpatchify ===
    unified = transformer.all_final_layer[embedder_key](unified, adaln_input)
    unified = list(unified.unbind(dim=0))
    output = transformer.unpatchify(unified, x_size, patch_size, f_patch_size)

    return output, {}
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
|
2
|
+
"""Utility functions for Z-Image patchify operations."""
|
|
3
|
+
|
|
4
|
+
from typing import List, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
# Token sequence lengths must be a multiple of this value. The value is taken from diffusers'
# transformer_z_image; patch/caption sequences are padded up to it (see the caption padding
# computation in patchify_control_context) and callers assert lengths are divisible by it.
SEQ_MULTI_OF = 32
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def create_coordinate_grid(
|
|
13
|
+
size: Tuple[int, ...],
|
|
14
|
+
start: Tuple[int, ...] | None = None,
|
|
15
|
+
device: torch.device | None = None,
|
|
16
|
+
) -> torch.Tensor:
|
|
17
|
+
"""Create a coordinate grid for position embeddings.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
size: Size of the grid (e.g., (F, H, W))
|
|
21
|
+
start: Starting coordinates (default: all zeros)
|
|
22
|
+
device: Target device
|
|
23
|
+
|
|
24
|
+
Returns:
|
|
25
|
+
Coordinate grid tensor of shape (*size, len(size))
|
|
26
|
+
"""
|
|
27
|
+
if start is None:
|
|
28
|
+
start = tuple(0 for _ in size)
|
|
29
|
+
|
|
30
|
+
axes = [
|
|
31
|
+
torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size, strict=False)
|
|
32
|
+
]
|
|
33
|
+
grids = torch.meshgrid(axes, indexing="ij")
|
|
34
|
+
return torch.stack(grids, dim=-1)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def patchify_control_context(
|
|
38
|
+
all_image: List[torch.Tensor],
|
|
39
|
+
patch_size: int,
|
|
40
|
+
f_patch_size: int,
|
|
41
|
+
cap_seq_len: int,
|
|
42
|
+
) -> Tuple[List[torch.Tensor], List[Tuple[int, int, int]], List[torch.Tensor], List[torch.Tensor]]:
|
|
43
|
+
"""Patchify control images without embedding.
|
|
44
|
+
|
|
45
|
+
This function extracts patches from control images for control context processing.
|
|
46
|
+
It handles padding and position ID creation for the control signal.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
all_image: List of control image tensors [C, F, H, W]
|
|
50
|
+
patch_size: Spatial patch size (height and width)
|
|
51
|
+
f_patch_size: Frame patch size
|
|
52
|
+
cap_seq_len: Caption sequence length (for position ID offset)
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
Tuple of:
|
|
56
|
+
- all_image_out: List of patchified image tensors
|
|
57
|
+
- all_image_size: List of (F, H, W) tuples
|
|
58
|
+
- all_image_pos_ids: List of position ID tensors
|
|
59
|
+
- all_image_pad_mask: List of padding mask tensors
|
|
60
|
+
"""
|
|
61
|
+
pH = pW = patch_size
|
|
62
|
+
pF = f_patch_size
|
|
63
|
+
device = all_image[0].device
|
|
64
|
+
|
|
65
|
+
all_image_out: List[torch.Tensor] = []
|
|
66
|
+
all_image_size: List[Tuple[int, int, int]] = []
|
|
67
|
+
all_image_pos_ids: List[torch.Tensor] = []
|
|
68
|
+
all_image_pad_mask: List[torch.Tensor] = []
|
|
69
|
+
|
|
70
|
+
# Calculate padded caption length for position offset
|
|
71
|
+
cap_padding_len = (-cap_seq_len) % SEQ_MULTI_OF
|
|
72
|
+
cap_padded_len = cap_seq_len + cap_padding_len
|
|
73
|
+
|
|
74
|
+
for image in all_image:
|
|
75
|
+
C, F, H, W = image.size()
|
|
76
|
+
all_image_size.append((F, H, W))
|
|
77
|
+
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
|
78
|
+
|
|
79
|
+
# Patchify: [C, F, H, W] -> [(F_tokens*H_tokens*W_tokens), (pF*pH*pW*C)]
|
|
80
|
+
# Step 1: Rearrange to put spatial dims together for proper patching
|
|
81
|
+
# [C, F, H, W] -> [F, H, W, C]
|
|
82
|
+
image = image.permute(1, 2, 3, 0).contiguous()
|
|
83
|
+
|
|
84
|
+
# Step 2: Split H and W into tokens and patch sizes
|
|
85
|
+
# [F, H, W, C] -> [F, H_tokens, pH, W_tokens, pW, C]
|
|
86
|
+
image = image.view(F, H_tokens, pH, W_tokens, pW, C)
|
|
87
|
+
|
|
88
|
+
# Step 3: Rearrange to group patches and features
|
|
89
|
+
# [F, H_tokens, pH, W_tokens, pW, C] -> [F, H_tokens, W_tokens, pH, pW, C]
|
|
90
|
+
image = image.permute(0, 1, 3, 2, 4, 5).contiguous()
|
|
91
|
+
|
|
92
|
+
# Step 4: For F > 1, we'd need to handle F similarly, but for F=1 this is simpler
|
|
93
|
+
# Final reshape: [F*H_tokens*W_tokens, pH*pW*C]
|
|
94
|
+
num_patches = F_tokens * H_tokens * W_tokens
|
|
95
|
+
patch_features = pF * pH * pW * C
|
|
96
|
+
image = image.reshape(num_patches, patch_features)
|
|
97
|
+
|
|
98
|
+
image_ori_len = len(image)
|
|
99
|
+
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
|
100
|
+
|
|
101
|
+
# Create position IDs
|
|
102
|
+
image_ori_pos_ids = create_coordinate_grid(
|
|
103
|
+
size=(F_tokens, H_tokens, W_tokens),
|
|
104
|
+
start=(cap_padded_len + 1, 0, 0),
|
|
105
|
+
device=device,
|
|
106
|
+
).flatten(0, 2)
|
|
107
|
+
|
|
108
|
+
image_padding_pos_ids = (
|
|
109
|
+
create_coordinate_grid(
|
|
110
|
+
size=(1, 1, 1),
|
|
111
|
+
start=(0, 0, 0),
|
|
112
|
+
device=device,
|
|
113
|
+
)
|
|
114
|
+
.flatten(0, 2)
|
|
115
|
+
.repeat(image_padding_len, 1)
|
|
116
|
+
)
|
|
117
|
+
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
|
118
|
+
all_image_pos_ids.append(image_padded_pos_ids)
|
|
119
|
+
|
|
120
|
+
# Padding mask
|
|
121
|
+
all_image_pad_mask.append(
|
|
122
|
+
torch.cat(
|
|
123
|
+
[
|
|
124
|
+
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
|
125
|
+
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
|
126
|
+
],
|
|
127
|
+
dim=0,
|
|
128
|
+
)
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
# Padded feature
|
|
132
|
+
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
|
133
|
+
all_image_out.append(image_padded_feat)
|
|
134
|
+
|
|
135
|
+
return all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask
|