InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/dependencies.py +2 -0
- invokeai/app/api/routers/model_manager.py +91 -2
- invokeai/app/api/routers/workflows.py +9 -0
- invokeai/app/invocations/fields.py +19 -0
- invokeai/app/invocations/image_to_latents.py +23 -5
- invokeai/app/invocations/latents_to_image.py +2 -25
- invokeai/app/invocations/metadata.py +9 -1
- invokeai/app/invocations/model.py +8 -0
- invokeai/app/invocations/primitives.py +12 -0
- invokeai/app/invocations/prompt_template.py +57 -0
- invokeai/app/invocations/z_image_control.py +112 -0
- invokeai/app/invocations/z_image_denoise.py +610 -0
- invokeai/app/invocations/z_image_image_to_latents.py +102 -0
- invokeai/app/invocations/z_image_latents_to_image.py +103 -0
- invokeai/app/invocations/z_image_lora_loader.py +153 -0
- invokeai/app/invocations/z_image_model_loader.py +135 -0
- invokeai/app/invocations/z_image_text_encoder.py +197 -0
- invokeai/app/services/model_install/model_install_common.py +14 -1
- invokeai/app/services/model_install/model_install_default.py +119 -19
- invokeai/app/services/model_records/model_records_base.py +12 -0
- invokeai/app/services/model_records/model_records_sql.py +17 -0
- invokeai/app/services/shared/graph.py +132 -77
- invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
- invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
- invokeai/app/util/step_callback.py +3 -0
- invokeai/backend/model_manager/configs/controlnet.py +47 -1
- invokeai/backend/model_manager/configs/factory.py +26 -1
- invokeai/backend/model_manager/configs/lora.py +43 -1
- invokeai/backend/model_manager/configs/main.py +113 -0
- invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
- invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
- invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
- invokeai/backend/model_manager/load/model_util.py +6 -1
- invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
- invokeai/backend/model_manager/model_on_disk.py +3 -0
- invokeai/backend/model_manager/starter_models.py +70 -0
- invokeai/backend/model_manager/taxonomy.py +5 -0
- invokeai/backend/model_manager/util/select_hf_files.py +23 -8
- invokeai/backend/patches/layer_patcher.py +34 -16
- invokeai/backend/patches/layers/lora_layer_base.py +2 -1
- invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
- invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
- invokeai/backend/patches/lora_conversions/formats.py +5 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
- invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
- invokeai/backend/quantization/gguf/loaders.py +47 -12
- invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
- invokeai/backend/util/devices.py +25 -0
- invokeai/backend/util/hotfixes.py +2 -2
- invokeai/backend/z_image/__init__.py +16 -0
- invokeai/backend/z_image/extensions/__init__.py +1 -0
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
- invokeai/backend/z_image/text_conditioning.py +74 -0
- invokeai/backend/z_image/z_image_control_adapter.py +238 -0
- invokeai/backend/z_image/z_image_control_transformer.py +643 -0
- invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
- invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
- invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/de.json +24 -6
- invokeai/frontend/web/dist/locales/en.json +70 -1
- invokeai/frontend/web/dist/locales/es.json +0 -5
- invokeai/frontend/web/dist/locales/fr.json +0 -6
- invokeai/frontend/web/dist/locales/it.json +17 -64
- invokeai/frontend/web/dist/locales/ja.json +379 -44
- invokeai/frontend/web/dist/locales/ru.json +0 -6
- invokeai/frontend/web/dist/locales/vi.json +7 -54
- invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
- invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
- invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,643 @@
|
|
|
1
|
+
# Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/videox_fun/models/z_image_transformer2d_control.py
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
# Apache License 2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Z-Image Control Transformer for InvokeAI.
|
|
7
|
+
|
|
8
|
+
This module provides the ZImageControlTransformer2DModel which extends the base
|
|
9
|
+
ZImageTransformer2DModel with control conditioning capabilities (Canny, HED, Depth, Pose, MLSD).
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
from diffusers.configuration_utils import register_to_config
|
|
17
|
+
from diffusers.models.transformers.transformer_z_image import (
|
|
18
|
+
SEQ_MULTI_OF,
|
|
19
|
+
ZImageTransformer2DModel,
|
|
20
|
+
ZImageTransformerBlock,
|
|
21
|
+
)
|
|
22
|
+
from diffusers.utils import is_torch_version
|
|
23
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ZImageControlTransformerBlock(ZImageTransformerBlock):
    """Control-branch transformer block that produces one hint per block.

    Compared to a plain ``ZImageTransformerBlock`` this block owns a
    zero-initialised ``after_proj`` linear layer. The first block
    (``block_id == 0``) additionally owns a zero-initialised ``before_proj``
    linear layer. The state passed between consecutive control blocks is a
    single stacked tensor: along its leading axis it holds every hint produced
    so far, followed by the current hidden state as the last entry.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation: bool = True,
        block_id: int = 0,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id

        is_first_block = block_id == 0
        if is_first_block:
            # Projects the embedded control context before it is merged with the main-path tokens.
            self.before_proj = nn.Linear(dim, dim)
            nn.init.zeros_(self.before_proj.weight)
            nn.init.zeros_(self.before_proj.bias)

        # Produces this block's hint from its output hidden state.
        self.after_proj = nn.Linear(dim, dim)
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def forward(
        self,
        c: torch.Tensor,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Advance the control stream by one block and append this block's hint.

        Args:
            c: For the first block, the embedded control context. For later blocks,
                the stacked state ``[hint_0, ..., hint_{k-1}, hidden]`` returned by
                the previous control block.
            x: Main-path hidden states. Only used by the first block, where they are
                added to ``before_proj(c)``.
            attn_mask: Attention mask forwarded to the base block.
            freqs_cis: Rotary frequencies forwarded to the base block.
            adaln_input: Optional modulation input forwarded to the base block.

        Returns:
            ``torch.stack([hint_0, ..., hint_k, hidden])``: all hints so far plus this
            block's hint, followed by the updated hidden state.
        """
        if self.block_id != 0:
            # Split the carried state into the hints collected so far and the live hidden state.
            *collected, hidden = torch.unbind(c)
        else:
            collected = []
            hidden = self.before_proj(c) + x

        hidden = super().forward(hidden, attn_mask=attn_mask, freqs_cis=freqs_cis, adaln_input=adaln_input)
        collected.extend((self.after_proj(hidden), hidden))
        return torch.stack(collected)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class BaseZImageTransformerBlock(ZImageTransformerBlock):
    """Main-path transformer block that can add a control hint to its output.

    A block created with an integer ``block_id`` adds
    ``hints[block_id] * context_scale`` to the output of the standard block
    computation. A block whose ``block_id`` is ``None``, or one called without
    ``hints``, returns the standard block output unchanged.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation: bool = True,
        block_id: Optional[int] = 0,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
        hints: Optional[tuple[torch.Tensor, ...]] = None,
        context_scale: float = 1.0,
    ) -> torch.Tensor:
        """Run the base block, then optionally inject the matching control hint.

        Args:
            hidden_states: Input hidden states.
            attn_mask: Attention mask forwarded to the base block.
            freqs_cis: Rotary frequencies forwarded to the base block.
            adaln_input: Optional modulation input forwarded to the base block.
            hints: Hint tensors produced by the control branch, indexed by ``block_id``.
            context_scale: Multiplier applied to the hint before it is added.

        Returns:
            The base block output, plus the scaled hint when one applies to this block.
        """
        out = super().forward(
            hidden_states,
            attn_mask=attn_mask,
            freqs_cis=freqs_cis,
            adaln_input=adaln_input,
        )
        if hints is None or self.block_id is None:
            return out
        return out + hints[self.block_id] * context_scale
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class ZImageControlTransformer2DModel(ZImageTransformer2DModel):
    """Z-Image Control Transformer for spatial conditioning.

    This model extends ZImageTransformer2DModel with control layers that process
    a control image (e.g., Canny edges, depth map) and inject control signals
    into the main transformer at the configured layers (every other layer by default).

    The control model supports 5 modes: Canny, HED, Depth, Pose, MLSD.
    Recommended control_context_scale: 0.65-0.80.

    Args:
        control_layers_places: List of layer indices where control is applied.
            Defaults to every other layer [0, 2, 4, ...]. Must contain 0.
        control_in_dim: Input dimension for control context. Defaults to in_channels.
        All other args are passed to ZImageTransformer2DModel.

    Raises:
        ValueError: If ``control_layers_places`` does not contain layer 0.
    """

    @register_to_config
    def __init__(
        self,
        control_layers_places: Optional[List[int]] = None,
        control_in_dim: Optional[int] = None,
        all_patch_size: tuple[int, ...] = (2,),
        all_f_patch_size: tuple[int, ...] = (1,),
        in_channels: int = 16,
        dim: int = 3840,
        n_layers: int = 30,
        n_refiner_layers: int = 2,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        cap_feat_dim: int = 2560,
        rope_theta: float = 256.0,
        t_scale: float = 1000.0,
        axes_dims: tuple[int, ...] = (32, 48, 48),
        axes_lens: tuple[int, ...] = (1024, 512, 512),
    ):
        super().__init__(
            all_patch_size=all_patch_size,
            all_f_patch_size=all_f_patch_size,
            in_channels=in_channels,
            dim=dim,
            n_layers=n_layers,
            n_refiner_layers=n_refiner_layers,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            norm_eps=norm_eps,
            qk_norm=qk_norm,
            cap_feat_dim=cap_feat_dim,
            rope_theta=rope_theta,
            t_scale=t_scale,
            axes_dims=axes_dims,
            axes_lens=axes_lens,
        )

        # Control layer configuration
        self.control_layers_places = (
            list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places
        )
        self.control_in_dim = in_channels if control_in_dim is None else control_in_dim

        # Explicit check rather than `assert`, so the validation still runs under `python -O`.
        if 0 not in self.control_layers_places:
            raise ValueError(f"control_layers_places must include layer 0, got {self.control_layers_places}")
        # Maps a main-transformer layer index to the index of the hint that layer consumes.
        self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}

        # Replace standard layers with control-aware layers. Layers that are not control
        # positions get block_id=None and therefore never add a hint.
        del self.layers
        self.layers = nn.ModuleList(
            [
                BaseZImageTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    block_id=self.control_layers_mapping.get(i),
                )
                for i in range(n_layers)
            ]
        )

        # Control transformer blocks: one per control position, each producing one hint.
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    block_id=i,
                )
                for i in range(len(self.control_layers_places))
            ]
        )

        # Control patch embeddings, keyed like the base model's embedders ("{patch}-{f_patch}").
        all_x_embedder = {}
        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size, strict=True):
            x_embedder = nn.Linear(
                f_patch_size * patch_size * patch_size * self.control_in_dim,
                dim,
                bias=True,
            )
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)

        # Control noise refiner
        self.control_noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    1000 + layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

    def _run_with_optional_checkpoint(self, module: nn.Module, hidden: torch.Tensor, /, *args: Any, **kwargs: Any):
        """Call ``module(hidden, *args, **kwargs)``, using activation checkpointing when training.

        Checkpointing is used only when autograd is enabled and ``self.gradient_checkpointing``
        is set. In that case ``hidden`` and ``*args`` are passed to the checkpoint as its inputs,
        and ``**kwargs`` are bound in a closure.
        """
        if not (torch.is_grad_enabled() and self.gradient_checkpointing):
            return module(hidden, *args, **kwargs)

        from torch.utils.checkpoint import checkpoint

        def custom_forward(*inputs):
            return module(*inputs, **kwargs)

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        return checkpoint(custom_forward, hidden, *args, **ckpt_kwargs)

    def patchify(
        self,
        all_image: List[torch.Tensor],
        patch_size: int,
        f_patch_size: int,
        cap_seq_len: int,
    ) -> tuple[List[torch.Tensor], List[tuple], List[torch.Tensor], List[torch.Tensor]]:
        """Patchify images without embedding.

        This method extracts patches from images for control context processing.
        Unlike patchify_and_embed, this only processes images without caption features.

        Args:
            all_image: List of image tensors [C, F, H, W]. F, H and W must be divisible by
                f_patch_size, patch_size and patch_size respectively.
            patch_size: Spatial patch size (height and width)
            f_patch_size: Frame patch size
            cap_seq_len: Caption sequence length (for position ID offset)

        Returns:
            Tuple of:
                - all_image_out: List of patchified image tensors, each padded to a multiple
                  of SEQ_MULTI_OF by repeating the last patch
                - all_image_size: List of (F, H, W) tuples
                - all_image_pos_ids: List of position ID tensors
                - all_image_pad_mask: List of padding mask tensors (True marks padding)
        """
        pH = pW = patch_size
        pF = f_patch_size
        device = all_image[0].device

        all_image_out = []
        all_image_size = []
        all_image_pos_ids = []
        all_image_pad_mask = []

        # Calculate padded caption length for position offset
        cap_padding_len = (-cap_seq_len) % SEQ_MULTI_OF
        cap_padded_len = cap_seq_len + cap_padding_len

        for image in all_image:
            C, F, H, W = image.size()
            all_image_size.append((F, H, W))
            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

            # Patchify: [C, F, H, W] -> [(F*H*W)/(patch), patch_elements * C].
            # reshape (not view) so non-contiguous control latents are accepted as well.
            image = image.reshape(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)

            image_ori_len = len(image)
            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF

            # Create position IDs; image positions start after the (padded) caption.
            image_ori_pos_ids = self.create_coordinate_grid(
                size=(F_tokens, H_tokens, W_tokens),
                start=(cap_padded_len + 1, 0, 0),
                device=device,
            ).flatten(0, 2)
            image_padding_pos_ids = (
                self.create_coordinate_grid(
                    size=(1, 1, 1),
                    start=(0, 0, 0),
                    device=device,
                )
                .flatten(0, 2)
                .repeat(image_padding_len, 1)
            )
            image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
            all_image_pos_ids.append(image_padded_pos_ids)

            # Padding mask
            all_image_pad_mask.append(
                torch.cat(
                    [
                        torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
                        torch.ones((image_padding_len,), dtype=torch.bool, device=device),
                    ],
                    dim=0,
                )
            )

            # Padded feature
            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
            all_image_out.append(image_padded_feat)

        return all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask

    def forward_control(
        self,
        x: torch.Tensor,
        cap_feats: torch.Tensor,
        control_context: List[torch.Tensor],
        kwargs: Dict[str, Any],
        t: torch.Tensor,
        patch_size: int = 2,
        f_patch_size: int = 1,
    ) -> tuple[torch.Tensor, ...]:
        """Process control context and generate hints for the main transformer.

        Args:
            x: Padded unified (image + caption) hidden states from the main path. The first
                control block adds them to its projected input.
            cap_feats: Refined caption features, batch-first and padded to a common length.
                Every item's caption is taken to be ``cap_feats.size(1)`` tokens long.
            control_context: List of control images (VAE-encoded latents)
            kwargs: Must provide ``attn_mask``, ``freqs_cis`` and ``adaln_input`` for the
                unified sequence.
            t: Timestep embeddings
            patch_size: Spatial patch size
            f_patch_size: Frame patch size

        Returns:
            Tuple with one hint tensor per control block, in control-block order.
        """
        bsz = len(control_context)
        device = control_context[0].device

        # Patchify control context
        (
            control_context_patches,
            x_size,
            x_pos_ids,
            x_inner_pad_mask,
        ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats.size(1))

        # Embed control context
        x_item_seqlens = [len(_) for _ in control_context_patches]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)

        control_context_cat = torch.cat(control_context_patches, dim=0)
        control_context_cat = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context_cat)

        # Match t_embedder output dtype
        adaln_input = t.type_as(control_context_cat)
        control_context_cat[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        control_context_list = list(control_context_cat.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

        control_context_padded = pad_sequence(control_context_list, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        # Refine control context
        for layer in self.control_noise_refiner:
            control_context_padded = self._run_with_optional_checkpoint(
                layer, control_context_padded, x_attn_mask, x_freqs_cis, adaln_input
            )

        # Unify with caption features
        cap_item_seqlens = [cap_feats.size(1)] * bsz  # Assume same length for batch
        control_context_unified = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            control_context_unified.append(torch.cat([control_context_padded[i][:x_len], cap_feats[i][:cap_len]]))
        control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
        c = control_context_unified

        # Process through control layers; `c` becomes the stacked [hints..., hidden] state.
        for layer in self.control_layers:
            c = self._run_with_optional_checkpoint(
                layer,
                c,
                x=x,
                attn_mask=kwargs["attn_mask"],
                freqs_cis=kwargs["freqs_cis"],
                adaln_input=kwargs["adaln_input"],
            )

        # Drop the trailing hidden state; keep only the hints.
        hints = torch.unbind(c)[:-1]
        return hints

    def forward(
        self,
        x: List[torch.Tensor],
        t: torch.Tensor,
        cap_feats: List[torch.Tensor],
        patch_size: int = 2,
        f_patch_size: int = 1,
        control_context: Optional[List[torch.Tensor]] = None,
        control_context_scale: float = 1.0,
    ) -> Any:
        """Forward pass with control conditioning.

        Args:
            x: List of per-image latent tensors. NOTE(review): these are presumably [C, F, H, W],
                the same layout ``patchify`` expects for control tensors; confirm against
                ``patchify_and_embed``.
            t: Timestep tensor
            cap_feats: List of caption feature tensors
            patch_size: Spatial patch size (default 2)
            f_patch_size: Frame patch size (default 1)
            control_context: List of control image latents (VAE-encoded). Each must share its
                trailing (F, H, W) dimensions with the corresponding entry of ``x``.
            control_context_scale: Strength of control signal (0.65-0.80 recommended)

        Returns:
            With ``control_context``: a tuple ``(output, {})``. ``output`` is the per-image
            predictions combined with ``torch.stack``, so all images must share a size.
            Without ``control_context``: whatever ``ZImageTransformer2DModel.forward`` returns
            for these arguments, unchanged. NOTE(review): that may not have the same structure
            as the control path; confirm against the installed diffusers version.

        Raises:
            ValueError: If a control latent's spatial shape does not match its latent.
        """
        assert patch_size in self.all_patch_size
        assert f_patch_size in self.all_f_patch_size

        if control_context is None:
            # Fall back to base model behavior without control
            return super().forward(x, t, cap_feats, patch_size, f_patch_size)

        # Control tokens are added position-by-position to the image tokens (before_proj(c) + x
        # in the first control block), so spatial shapes must agree. Without this check a mismatch
        # fails deep in the control branch with an opaque shape error, or silently misaligns tokens
        # when only the token counts agree. The zip is deliberately non-strict: batch sizes are
        # not validated here.
        for idx, (latent, control) in enumerate(zip(x, control_context, strict=False)):
            if latent.shape[-3:] != control.shape[-3:]:
                raise ValueError(
                    f"control_context[{idx}] has spatial shape {tuple(control.shape[-3:])}, "
                    f"expected {tuple(latent.shape[-3:])} to match x[{idx}]"
                )

        bsz = len(x)
        device = x[0].device
        t = t * self.t_scale
        t = self.t_embedder(t)

        (
            x,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

        # Image embedding and refinement
        x_item_seqlens = [len(_) for _ in x]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)

        x = torch.cat(x, dim=0)
        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)

        adaln_input = t.type_as(x)
        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        x = list(x.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

        x = pad_sequence(x, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        # Noise refiner
        for layer in self.noise_refiner:
            x = self._run_with_optional_checkpoint(layer, x, x_attn_mask, x_freqs_cis, adaln_input)

        # Caption embedding and refinement
        cap_item_seqlens = [len(_) for _ in cap_feats]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
        cap_max_item_seqlen = max(cap_item_seqlens)

        cap_feats = torch.cat(cap_feats, dim=0)
        cap_feats = self.cap_embedder(cap_feats)
        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))

        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(cap_item_seqlens):
            cap_attn_mask[i, :seq_len] = 1

        for layer in self.context_refiner:
            cap_feats = self._run_with_optional_checkpoint(layer, cap_feats, cap_attn_mask, cap_freqs_cis)

        # Unified processing: image tokens first, then caption tokens, per item.
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=True)]
        unified_max_item_seqlen = max(unified_item_seqlens)

        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(unified_item_seqlens):
            unified_attn_mask[i, :seq_len] = 1

        # Generate control hints
        kwargs = {
            "attn_mask": unified_attn_mask,
            "freqs_cis": unified_freqs_cis,
            "adaln_input": adaln_input,
        }
        hints = self.forward_control(
            unified,
            cap_feats,
            control_context,
            kwargs,
            t=t,
            patch_size=patch_size,
            f_patch_size=f_patch_size,
        )

        # Main transformer with control hints
        for layer in self.layers:
            unified = self._run_with_optional_checkpoint(
                layer,
                unified,
                attn_mask=unified_attn_mask,
                freqs_cis=unified_freqs_cis,
                adaln_input=adaln_input,
                hints=hints,
                context_scale=control_context_scale,
            )

        # Final layer and unpatchify
        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
        unified = list(unified.unbind(dim=0))
        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

        x = torch.stack(x)
        return x, {}
|