InvokeAI 6.9.0rc3-py3-none-any.whl → 6.10.0rc1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (86)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/image_to_latents.py +23 -5
  6. invokeai/app/invocations/latents_to_image.py +2 -25
  7. invokeai/app/invocations/metadata.py +9 -1
  8. invokeai/app/invocations/model.py +8 -0
  9. invokeai/app/invocations/primitives.py +12 -0
  10. invokeai/app/invocations/prompt_template.py +57 -0
  11. invokeai/app/invocations/z_image_control.py +112 -0
  12. invokeai/app/invocations/z_image_denoise.py +610 -0
  13. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  14. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  15. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  16. invokeai/app/invocations/z_image_model_loader.py +135 -0
  17. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  18. invokeai/app/services/model_install/model_install_common.py +14 -1
  19. invokeai/app/services/model_install/model_install_default.py +119 -19
  20. invokeai/app/services/model_records/model_records_base.py +12 -0
  21. invokeai/app/services/model_records/model_records_sql.py +17 -0
  22. invokeai/app/services/shared/graph.py +132 -77
  23. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  24. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  25. invokeai/app/util/step_callback.py +3 -0
  26. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  27. invokeai/backend/model_manager/configs/factory.py +26 -1
  28. invokeai/backend/model_manager/configs/lora.py +43 -1
  29. invokeai/backend/model_manager/configs/main.py +113 -0
  30. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  31. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  32. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  33. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  34. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  35. invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
  36. invokeai/backend/model_manager/load/model_util.py +6 -1
  37. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  38. invokeai/backend/model_manager/model_on_disk.py +3 -0
  39. invokeai/backend/model_manager/starter_models.py +70 -0
  40. invokeai/backend/model_manager/taxonomy.py +5 -0
  41. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  42. invokeai/backend/patches/layer_patcher.py +34 -16
  43. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  44. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  45. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  46. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  47. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  48. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
  49. invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
  50. invokeai/backend/quantization/gguf/loaders.py +47 -12
  51. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  52. invokeai/backend/util/devices.py +25 -0
  53. invokeai/backend/util/hotfixes.py +2 -2
  54. invokeai/backend/z_image/__init__.py +16 -0
  55. invokeai/backend/z_image/extensions/__init__.py +1 -0
  56. invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
  57. invokeai/backend/z_image/text_conditioning.py +74 -0
  58. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  59. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  60. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  61. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  62. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  63. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
  64. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
  65. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
  66. invokeai/frontend/web/dist/index.html +1 -1
  67. invokeai/frontend/web/dist/locales/de.json +24 -6
  68. invokeai/frontend/web/dist/locales/en.json +70 -1
  69. invokeai/frontend/web/dist/locales/es.json +0 -5
  70. invokeai/frontend/web/dist/locales/fr.json +0 -6
  71. invokeai/frontend/web/dist/locales/it.json +17 -64
  72. invokeai/frontend/web/dist/locales/ja.json +379 -44
  73. invokeai/frontend/web/dist/locales/ru.json +0 -6
  74. invokeai/frontend/web/dist/locales/vi.json +7 -54
  75. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  76. invokeai/version/invokeai_version.py +1 -1
  77. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
  78. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
  79. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  80. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  81. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
  82. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
  83. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
  84. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  85. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  86. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
invokeai/backend/z_image/z_image_control_transformer.py (new file)
@@ -0,0 +1,643 @@
+ # Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/videox_fun/models/z_image_transformer2d_control.py
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ # Apache License 2.0
+
+ """
+ Z-Image Control Transformer for InvokeAI.
+
+ This module provides the ZImageControlTransformer2DModel which extends the base
+ ZImageTransformer2DModel with control conditioning capabilities (Canny, HED, Depth, Pose, MLSD).
+ """
+
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ import torch.nn as nn
+ from diffusers.configuration_utils import register_to_config
+ from diffusers.models.transformers.transformer_z_image import (
+     SEQ_MULTI_OF,
+     ZImageTransformer2DModel,
+     ZImageTransformerBlock,
+ )
+ from diffusers.utils import is_torch_version
+ from torch.nn.utils.rnn import pad_sequence
+
+
+ class ZImageControlTransformerBlock(ZImageTransformerBlock):
+     """Control-specific transformer block with skip connections for hint generation.
+
+     This block extends ZImageTransformerBlock with before_proj and after_proj layers
+     that create skip connections for the control signal. The hints are accumulated
+     across blocks and used to condition the main transformer.
+     """
+
+     def __init__(
+         self,
+         layer_id: int,
+         dim: int,
+         n_heads: int,
+         n_kv_heads: int,
+         norm_eps: float,
+         qk_norm: bool,
+         modulation: bool = True,
+         block_id: int = 0,
+     ):
+         super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
+         self.block_id = block_id
+         if block_id == 0:
+             self.before_proj = nn.Linear(dim, dim)
+             nn.init.zeros_(self.before_proj.weight)
+             nn.init.zeros_(self.before_proj.bias)
+         self.after_proj = nn.Linear(dim, dim)
+         nn.init.zeros_(self.after_proj.weight)
+         nn.init.zeros_(self.after_proj.bias)
+
+     def forward(
+         self,
+         c: torch.Tensor,
+         x: torch.Tensor,
+         attn_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         adaln_input: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if self.block_id == 0:
+             c = self.before_proj(c) + x
+             all_c: list[torch.Tensor] = []
+         else:
+             all_c = list(torch.unbind(c))
+             c = all_c.pop(-1)
+
+         c = super().forward(c, attn_mask=attn_mask, freqs_cis=freqs_cis, adaln_input=adaln_input)
+         c_skip = self.after_proj(c)
+         all_c += [c_skip, c]
+         c = torch.stack(all_c)
+         return c
+
+
+ class BaseZImageTransformerBlock(ZImageTransformerBlock):
+     """Modified transformer block that accepts control hints.
+
+     This block extends ZImageTransformerBlock to add control hints to the
+     hidden states at specific positions in the network.
+     """
+
+     def __init__(
+         self,
+         layer_id: int,
+         dim: int,
+         n_heads: int,
+         n_kv_heads: int,
+         norm_eps: float,
+         qk_norm: bool,
+         modulation: bool = True,
+         block_id: Optional[int] = 0,
+     ):
+         super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
+         self.block_id = block_id
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         adaln_input: Optional[torch.Tensor] = None,
+         hints: Optional[tuple[torch.Tensor, ...]] = None,
+         context_scale: float = 1.0,
+     ) -> torch.Tensor:
+         hidden_states = super().forward(
+             hidden_states,
+             attn_mask=attn_mask,
+             freqs_cis=freqs_cis,
+             adaln_input=adaln_input,
+         )
+         if self.block_id is not None and hints is not None:
+             hidden_states = hidden_states + hints[self.block_id] * context_scale
+         return hidden_states
+
+
+ class ZImageControlTransformer2DModel(ZImageTransformer2DModel):
+     """Z-Image Control Transformer for spatial conditioning.
+
+     This model extends ZImageTransformer2DModel with control layers that process
+     a control image (e.g., Canny edges, depth map) and inject control signals
+     into the main transformer at every other layer.
+
+     The control model supports 5 modes: Canny, HED, Depth, Pose, MLSD.
+     Recommended control_context_scale: 0.65-0.80.
+
+     Args:
+         control_layers_places: List of layer indices where control is applied.
+             Defaults to every other layer [0, 2, 4, ...].
+         control_in_dim: Input dimension for control context. Defaults to in_channels.
+         All other args are passed to ZImageTransformer2DModel.
+     """
+
+     @register_to_config
+     def __init__(
+         self,
+         control_layers_places: Optional[List[int]] = None,
+         control_in_dim: Optional[int] = None,
+         all_patch_size: tuple[int, ...] = (2,),
+         all_f_patch_size: tuple[int, ...] = (1,),
+         in_channels: int = 16,
+         dim: int = 3840,
+         n_layers: int = 30,
+         n_refiner_layers: int = 2,
+         n_heads: int = 30,
+         n_kv_heads: int = 30,
+         norm_eps: float = 1e-5,
+         qk_norm: bool = True,
+         cap_feat_dim: int = 2560,
+         rope_theta: float = 256.0,
+         t_scale: float = 1000.0,
+         axes_dims: tuple[int, ...] = (32, 48, 48),
+         axes_lens: tuple[int, ...] = (1024, 512, 512),
+     ):
+         super().__init__(
+             all_patch_size=all_patch_size,
+             all_f_patch_size=all_f_patch_size,
+             in_channels=in_channels,
+             dim=dim,
+             n_layers=n_layers,
+             n_refiner_layers=n_refiner_layers,
+             n_heads=n_heads,
+             n_kv_heads=n_kv_heads,
+             norm_eps=norm_eps,
+             qk_norm=qk_norm,
+             cap_feat_dim=cap_feat_dim,
+             rope_theta=rope_theta,
+             t_scale=t_scale,
+             axes_dims=axes_dims,
+             axes_lens=axes_lens,
+         )
+
+         # Control layer configuration
+         self.control_layers_places = (
+             list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places
+         )
+         self.control_in_dim = in_channels if control_in_dim is None else control_in_dim
+
+         assert 0 in self.control_layers_places
+         self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}
+
+         # Replace standard layers with control-aware layers
+         del self.layers
+         self.layers = nn.ModuleList(
+             [
+                 BaseZImageTransformerBlock(
+                     i,
+                     dim,
+                     n_heads,
+                     n_kv_heads,
+                     norm_eps,
+                     qk_norm,
+                     block_id=self.control_layers_mapping[i] if i in self.control_layers_places else None,
+                 )
+                 for i in range(n_layers)
+             ]
+         )
+
+         # Control transformer blocks
+         self.control_layers = nn.ModuleList(
+             [
+                 ZImageControlTransformerBlock(
+                     i,
+                     dim,
+                     n_heads,
+                     n_kv_heads,
+                     norm_eps,
+                     qk_norm,
+                     block_id=i,
+                 )
+                 for i in range(len(self.control_layers_places))
+             ]
+         )
+
+         # Control patch embeddings
+         all_x_embedder = {}
+         for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size, strict=True):
+             x_embedder = nn.Linear(
+                 f_patch_size * patch_size * patch_size * self.control_in_dim,
+                 dim,
+                 bias=True,
+             )
+             all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+         self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
+
+         # Control noise refiner
+         self.control_noise_refiner = nn.ModuleList(
+             [
+                 ZImageTransformerBlock(
+                     1000 + layer_id,
+                     dim,
+                     n_heads,
+                     n_kv_heads,
+                     norm_eps,
+                     qk_norm,
+                     modulation=True,
+                 )
+                 for layer_id in range(n_refiner_layers)
+             ]
+         )
+
+     def patchify(
+         self,
+         all_image: List[torch.Tensor],
+         patch_size: int,
+         f_patch_size: int,
+         cap_seq_len: int,
+     ) -> tuple[List[torch.Tensor], List[tuple], List[torch.Tensor], List[torch.Tensor]]:
+         """Patchify images without embedding.
+
+         This method extracts patches from images for control context processing.
+         Unlike patchify_and_embed, this only processes images without caption features.
+
+         Args:
+             all_image: List of image tensors [C, F, H, W]
+             patch_size: Spatial patch size (height and width)
+             f_patch_size: Frame patch size
+             cap_seq_len: Caption sequence length (for position ID offset)
+
+         Returns:
+             Tuple of:
+                 - all_image_out: List of patchified image tensors
+                 - all_image_size: List of (F, H, W) tuples
+                 - all_image_pos_ids: List of position ID tensors
+                 - all_image_pad_mask: List of padding mask tensors
+         """
+         pH = pW = patch_size
+         pF = f_patch_size
+         device = all_image[0].device
+
+         all_image_out = []
+         all_image_size = []
+         all_image_pos_ids = []
+         all_image_pad_mask = []
+
+         # Calculate padded caption length for position offset
+         cap_padding_len = (-cap_seq_len) % SEQ_MULTI_OF
+         cap_padded_len = cap_seq_len + cap_padding_len
+
+         for image in all_image:
+             C, F, H, W = image.size()
+             all_image_size.append((F, H, W))
+             F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
+
+             # Patchify: [C, F, H, W] -> [(F*H*W)/(patch), patch_elements * C]
+             image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
+             image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+
+             image_ori_len = len(image)
+             image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
+
+             # Create position IDs
+             image_ori_pos_ids = self.create_coordinate_grid(
+                 size=(F_tokens, H_tokens, W_tokens),
+                 start=(cap_padded_len + 1, 0, 0),
+                 device=device,
+             ).flatten(0, 2)
+             image_padding_pos_ids = (
+                 self.create_coordinate_grid(
+                     size=(1, 1, 1),
+                     start=(0, 0, 0),
+                     device=device,
+                 )
+                 .flatten(0, 2)
+                 .repeat(image_padding_len, 1)
+             )
+             image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
+             all_image_pos_ids.append(image_padded_pos_ids)
+
+             # Padding mask
+             all_image_pad_mask.append(
+                 torch.cat(
+                     [
+                         torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
+                         torch.ones((image_padding_len,), dtype=torch.bool, device=device),
+                     ],
+                     dim=0,
+                 )
+             )
+
+             # Padded feature
+             image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
+             all_image_out.append(image_padded_feat)
+
+         return all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask
+
+     def forward_control(
+         self,
+         x: torch.Tensor,
+         cap_feats: torch.Tensor,
+         control_context: List[torch.Tensor],
+         kwargs: Dict[str, Any],
+         t: torch.Tensor,
+         patch_size: int = 2,
+         f_patch_size: int = 1,
+     ) -> tuple[torch.Tensor, ...]:
+         """Process control context and generate hints for the main transformer.
+
+         Args:
+             x: Unified image+caption embeddings from main path
+             cap_feats: Caption feature embeddings
+             control_context: List of control images (VAE-encoded latents)
+             kwargs: Additional kwargs including attn_mask, freqs_cis
+             t: Timestep embeddings
+             patch_size: Spatial patch size
+             f_patch_size: Frame patch size
+
+         Returns:
+             Tuple of hint tensors to be added at each control layer position
+         """
+         bsz = len(control_context)
+         device = control_context[0].device
+
+         # Patchify control context
+         (
+             control_context_patches,
+             x_size,
+             x_pos_ids,
+             x_inner_pad_mask,
+         ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats.size(1))
+
+         # Embed control context
+         x_item_seqlens = [len(_) for _ in control_context_patches]
+         assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
+         x_max_item_seqlen = max(x_item_seqlens)
+
+         control_context_cat = torch.cat(control_context_patches, dim=0)
+         control_context_cat = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context_cat)
+
+         # Match t_embedder output dtype
+         adaln_input = t.type_as(control_context_cat)
+         control_context_cat[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+         control_context_list = list(control_context_cat.split(x_item_seqlens, dim=0))
+         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+
+         control_context_padded = pad_sequence(control_context_list, batch_first=True, padding_value=0.0)
+         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+         x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+         for i, seq_len in enumerate(x_item_seqlens):
+             x_attn_mask[i, :seq_len] = 1
+
+         # Refine control context
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             for layer in self.control_noise_refiner:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 control_context_padded = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(layer),
+                     control_context_padded,
+                     x_attn_mask,
+                     x_freqs_cis,
+                     adaln_input,
+                     **ckpt_kwargs,
+                 )
+         else:
+             for layer in self.control_noise_refiner:
+                 control_context_padded = layer(control_context_padded, x_attn_mask, x_freqs_cis, adaln_input)
+
+         # Unify with caption features
+         cap_item_seqlens = [cap_feats.size(1)] * bsz  # Assume same length for batch
+         control_context_unified = []
+         for i in range(bsz):
+             x_len = x_item_seqlens[i]
+             cap_len = cap_item_seqlens[i]
+             control_context_unified.append(torch.cat([control_context_padded[i][:x_len], cap_feats[i][:cap_len]]))
+         control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
+         c = control_context_unified
+
+         # Process through control layers
+         for layer in self.control_layers:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, **static_kwargs):
+                     def custom_forward(*inputs):
+                         return module(*inputs, **static_kwargs)
+
+                     return custom_forward
+
+                 ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 c = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(
+                         layer,
+                         x=x,
+                         attn_mask=kwargs["attn_mask"],
+                         freqs_cis=kwargs["freqs_cis"],
+                         adaln_input=kwargs["adaln_input"],
+                     ),
+                     c,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 c = layer(
+                     c,
+                     x=x,
+                     attn_mask=kwargs["attn_mask"],
+                     freqs_cis=kwargs["freqs_cis"],
+                     adaln_input=kwargs["adaln_input"],
+                 )
+
+         hints = torch.unbind(c)[:-1]
+         return hints
+
+     def forward(
+         self,
+         x: List[torch.Tensor],
+         t: torch.Tensor,
+         cap_feats: List[torch.Tensor],
+         patch_size: int = 2,
+         f_patch_size: int = 1,
+         control_context: Optional[List[torch.Tensor]] = None,
+         control_context_scale: float = 1.0,
+     ) -> tuple[List[torch.Tensor], dict]:
+         """Forward pass with control conditioning.
+
+         Args:
+             x: List of image tensors [B, C, 1, H, W]
+             t: Timestep tensor
+             cap_feats: List of caption feature tensors
+             patch_size: Spatial patch size (default 2)
+             f_patch_size: Frame patch size (default 1)
+             control_context: List of control image latents (VAE-encoded)
+             control_context_scale: Strength of control signal (0.65-0.80 recommended)
+
+         Returns:
+             Tuple of (output tensors, empty dict)
+         """
+         assert patch_size in self.all_patch_size
+         assert f_patch_size in self.all_f_patch_size
+
+         if control_context is None:
+             # Fall back to base model behavior without control
+             return super().forward(x, t, cap_feats, patch_size, f_patch_size)
+
+         bsz = len(x)
+         device = x[0].device
+         t = t * self.t_scale
+         t = self.t_embedder(t)
+
+         (
+             x,
+             cap_feats,
+             x_size,
+             x_pos_ids,
+             cap_pos_ids,
+             x_inner_pad_mask,
+             cap_inner_pad_mask,
+         ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
+
+         # Image embedding and refinement
+         x_item_seqlens = [len(_) for _ in x]
+         assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
+         x_max_item_seqlen = max(x_item_seqlens)
+
+         x = torch.cat(x, dim=0)
+         x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+
+         adaln_input = t.type_as(x)
+         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+         x = list(x.split(x_item_seqlens, dim=0))
+         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+
+         x = pad_sequence(x, batch_first=True, padding_value=0.0)
+         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+         x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+         for i, seq_len in enumerate(x_item_seqlens):
+             x_attn_mask[i, :seq_len] = 1
+
+         # Noise refiner
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             for layer in self.noise_refiner:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 x = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(layer),
+                     x,
+                     x_attn_mask,
+                     x_freqs_cis,
+                     adaln_input,
+                     **ckpt_kwargs,
+                 )
+         else:
+             for layer in self.noise_refiner:
+                 x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
+
+         # Caption embedding and refinement
+         cap_item_seqlens = [len(_) for _ in cap_feats]
+         assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
+         cap_max_item_seqlen = max(cap_item_seqlens)
+
+         cap_feats = torch.cat(cap_feats, dim=0)
+         cap_feats = self.cap_embedder(cap_feats)
+         cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
+         cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
+
+         cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
+         cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+         cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
+         for i, seq_len in enumerate(cap_item_seqlens):
+             cap_attn_mask[i, :seq_len] = 1
+
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             for layer in self.context_refiner:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 cap_feats = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(layer),
+                     cap_feats,
+                     cap_attn_mask,
+                     cap_freqs_cis,
+                     **ckpt_kwargs,
+                 )
+         else:
+             for layer in self.context_refiner:
+                 cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
+
+         # Unified processing
+         unified = []
+         unified_freqs_cis = []
+         for i in range(bsz):
+             x_len = x_item_seqlens[i]
+             cap_len = cap_item_seqlens[i]
+             unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
+             unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
+         unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=True)]
+         unified_max_item_seqlen = max(unified_item_seqlens)
+
+         unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
+         unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+         unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
+         for i, seq_len in enumerate(unified_item_seqlens):
+             unified_attn_mask[i, :seq_len] = 1
+
+         # Generate control hints
+         kwargs = {
+             "attn_mask": unified_attn_mask,
+             "freqs_cis": unified_freqs_cis,
+             "adaln_input": adaln_input,
+         }
+         hints = self.forward_control(
+             unified,
+             cap_feats,
+             control_context,
+             kwargs,
+             t=t,
+             patch_size=patch_size,
+             f_patch_size=f_patch_size,
+         )
+
+         # Main transformer with control hints
+         for layer in self.layers:
+             layer_kwargs = {
+                 "attn_mask": unified_attn_mask,
+                 "freqs_cis": unified_freqs_cis,
+                 "adaln_input": adaln_input,
+                 "hints": hints,
+                 "context_scale": control_context_scale,
+             }
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, **static_kwargs):
+                     def custom_forward(*inputs):
+                         return module(*inputs, **static_kwargs)
+
+                     return custom_forward
+
+                 ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+                 unified = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(layer, **layer_kwargs),
+                     unified,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 unified = layer(unified, **layer_kwargs)
+
+         # Final layer and unpatchify
+         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
+         unified = list(unified.unbind(dim=0))
+         x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
+
+         x = torch.stack(x)
+         return x, {}
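
The control-hint mechanism in the new transformer above is easier to follow in isolation. The sketch below is a standalone toy in plain PyTorch, not code from the wheel: ToyBlock, the dimensions, and the module lists are illustrative assumptions. It mimics how ZImageControlTransformerBlock accumulates one projected skip ("hint") per control block and how BaseZImageTransformerBlock then adds hints[block_id] * context_scale at every other main layer.

# Standalone toy illustration of the control-hint mechanism (not InvokeAI code;
# dimensions, module names, and ToyBlock are illustrative assumptions).
import torch
import torch.nn as nn

dim, seq_len, n_control_blocks = 8, 4, 3


class ToyBlock(nn.Module):
    """Stand-in for a transformer block: any residual sequence-to-sequence map."""

    def __init__(self) -> None:
        super().__init__()
        self.mix = nn.Linear(dim, dim)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return h + self.mix(h)


control_blocks = nn.ModuleList(ToyBlock() for _ in range(n_control_blocks))
after_proj = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_control_blocks))
main_blocks = nn.ModuleList(ToyBlock() for _ in range(2 * n_control_blocks))

x = torch.randn(1, seq_len, dim)  # unified image+caption stream (main path)
c = torch.randn(1, seq_len, dim)  # embedded control context

# Control path: like ZImageControlTransformerBlock, the running state `c` is
# refined block by block and each block emits one projected skip as a hint.
# (before_proj is omitted here; it is zero-initialized in the real model.)
hints: list[torch.Tensor] = []
c = c + x
for block, proj in zip(control_blocks, after_proj):
    c = block(c)
    hints.append(proj(c))

# Main path: like BaseZImageTransformerBlock, hints are injected at the control
# layer positions (every other layer), scaled by control_context_scale.
context_scale = 0.7  # the docstring above recommends 0.65-0.80
h = x
for i, block in enumerate(main_blocks):
    h = block(h)
    if i % 2 == 0:  # layers 0, 2, 4, ... carry a block_id and receive a hint
        h = h + hints[i // 2] * context_scale

print(h.shape)  # torch.Size([1, 4, 8])

In the real model the same accumulation is expressed by stacking [c_skip, c] with torch.stack inside each control block and dropping the running state at the end (hints = torch.unbind(c)[:-1]), which keeps the checkpointed control blocks working with a single tensor argument.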