diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +12 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
- diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
- diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +16 -1
- diffsynth_engine/configs/controlnet.py +13 -0
- diffsynth_engine/configs/pipeline.py +37 -11
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +105 -43
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +16 -30
- diffsynth_engine/models/flux/flux_dit.py +49 -62
- diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
- diffsynth_engine/models/flux/flux_vae.py +20 -2
- diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
- diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
- diffsynth_engine/models/sd/sd_unet.py +1 -1
- diffsynth_engine/models/sd3/sd3_dit.py +1 -1
- diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
- diffsynth_engine/models/vae/vae.py +1 -1
- diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
- diffsynth_engine/models/wan/wan_dit.py +65 -28
- diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
- diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
- diffsynth_engine/models/wan/wan_vae.py +2 -2
- diffsynth_engine/pipelines/base.py +73 -7
- diffsynth_engine/pipelines/flux_image.py +139 -120
- diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
- diffsynth_engine/pipelines/qwen_image.py +272 -87
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +25 -14
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/download.py +4 -2
- diffsynth_engine/utils/env.py +2 -0
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/loader.py +25 -6
- diffsynth_engine/utils/parallel.py +62 -29
- diffsynth_engine/utils/video.py +3 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/qwen_image/qwen_image_dit.py

@@ -1,15 +1,20 @@
 import torch
 import torch.nn as nn
-from typing import Any, Dict, Tuple, Union, Optional
+from typing import Any, Dict, List, Tuple, Union, Optional
 from einops import rearrange
 
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
-from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm,
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
-from diffsynth_engine.utils.parallel import
+from diffsynth_engine.utils.parallel import (
+    cfg_parallel,
+    cfg_parallel_unshard,
+    sequence_parallel,
+    sequence_parallel_unshard,
+)
 
 
 class QwenImageDiTStateDictConverter(StateDictConverter):
@@ -139,7 +144,7 @@ class QwenFeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * 4)
         self.net = nn.ModuleList([])
-        self.net.append(
+        self.net.append(GELU(dim, inner_dim, approximate="tanh", device=device, dtype=dtype))
         self.net.append(nn.Dropout(dropout))
         self.net.append(nn.Linear(inner_dim, dim_out, device=device, dtype=dtype))
 
@@ -150,8 +155,8 @@ class QwenFeedForward(nn.Module):
 
 
 def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
-    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d) -> (b, s, h, d/2, 2)
+    x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)  # (b, s, h, d/2, 2) -> (b, s, h, d)
     return x_out.type_as(x)
 
 
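
The two rewritten lines above switch apply_rotary_emb_qwen to a (b, s, h, d) head layout and broadcast the complex rotary phasor across heads via freqs_cis.unsqueeze(1). A small self-contained sketch of that rotation with made-up shapes, checked against the usual real-valued cos/sin rotary formula:

```python
import torch

# Illustrative shapes only; the real tensors come from the DiT's q/k projections.
b, s, h, d = 1, 4, 2, 8
x = torch.randn(b, s, h, d)
angles = torch.randn(s, d // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # (s, d/2) complex phasors

# complex rotation as in the hunk: pair adjacent channels, multiply, flatten back
x_c = torch.view_as_complex(x.float().reshape(b, s, h, d // 2, 2))  # (b, s, h, d/2)
out = torch.view_as_real(x_c * freqs_cis.unsqueeze(1)).flatten(3)   # (b, s, h, d)

# equivalent real-valued rotary formula on the same channel pairs
x0, x1 = x[..., 0::2], x[..., 1::2]
cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]
ref = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1).flatten(3)
assert torch.allclose(out, ref, atol=1e-5)
```
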
@@ -162,7 +167,6 @@ class QwenDoubleStreamAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -184,44 +188,42 @@ class QwenDoubleStreamAttention(nn.Module):
 
         self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
     def forward(
         self,
         image: torch.FloatTensor,
         text: torch.FloatTensor,
-
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
         txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
 
-        img_q = rearrange(img_q, "b s (h d) -> b h
-        img_k = rearrange(img_k, "b s (h d) -> b h
-        img_v = rearrange(img_v, "b s (h d) -> b h
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)
 
-        txt_q = rearrange(txt_q, "b s (h d) -> b h
-        txt_k = rearrange(txt_k, "b s (h d) -> b h
-        txt_v = rearrange(txt_v, "b s (h d) -> b h
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)
 
         img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
         txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
 
-        if
-            img_freqs, txt_freqs =
+        if rotary_emb is not None:
+            img_freqs, txt_freqs = rotary_emb
             img_q = apply_rotary_emb_qwen(img_q, img_freqs)
             img_k = apply_rotary_emb_qwen(img_k, img_freqs)
             txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
             txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
 
-        joint_q = torch.cat([txt_q, img_q], dim=
-        joint_k = torch.cat([txt_k, img_k], dim=
-        joint_v = torch.cat([txt_v, img_v], dim=
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)
 
-
-
-        joint_v = joint_v.transpose(1, 2)
-
-        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, **self.attn_kwargs)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
 
         joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
 
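
With this hunk, QwenDoubleStreamAttention no longer stores attn_kwargs on the module: the joint q/k/v stay in (b, s, h, d) layout, and an optional additive attn_mask plus per-call attn_kwargs are handed to attention_ops.attention. A rough stand-in for that joint attention step using plain PyTorch SDPA (illustrative only; attention_ops.attention may dispatch to other backends):

```python
import torch
import torch.nn.functional as F

def joint_attention_sketch(txt_q, txt_k, txt_v, img_q, img_k, img_v, attn_mask=None):
    # all inputs are (b, s, h, d); text tokens are concatenated before image tokens
    q = torch.cat([txt_q, img_q], dim=1)
    k = torch.cat([txt_k, img_k], dim=1)
    v = torch.cat([txt_v, img_v], dim=1)
    # SDPA wants (b, h, s, d); attn_mask may be an additive (b, 1, s, s) float bias
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=attn_mask
    )
    return out.transpose(1, 2)  # back to (b, s, h, d), matching the rearrange that follows
```
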
@@ -241,7 +243,6 @@ class QwenImageTransformerBlock(nn.Module):
         num_attention_heads: int,
         attention_head_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -261,7 +262,6 @@ class QwenImageTransformerBlock(nn.Module):
             dim_b=dim,
             num_heads=num_attention_heads,
             head_dim=attention_head_dim,
-            attn_kwargs=attn_kwargs,
             device=device,
             dtype=dtype,
         )
@@ -285,7 +285,9 @@ class QwenImageTransformerBlock(nn.Module):
         image: torch.Tensor,
         text: torch.Tensor,
         temb: torch.Tensor,
-
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
         txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
@@ -299,7 +301,9 @@ class QwenImageTransformerBlock(nn.Module):
         img_attn_out, txt_attn_out = self.attn(
             image=img_modulated,
             text=txt_modulated,
-
+            rotary_emb=rotary_emb,
+            attn_mask=attn_mask,
+            attn_kwargs=attn_kwargs,
         )
 
         image = image + img_gate * img_attn_out
@@ -327,7 +331,6 @@ class QwenImageDiT(PreTrainedModel):
     def __init__(
         self,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -348,7 +351,6 @@ class QwenImageDiT(PreTrainedModel):
                 dim=3072,
                 num_attention_heads=24,
                 attention_head_dim=128,
-                attn_kwargs=attn_kwargs,
                 device=device,
                 dtype=dtype,
             )
@@ -368,13 +370,75 @@ class QwenImageDiT(PreTrainedModel):
         )
         return hidden_states
 
+    def process_entity_masks(
+        self,
+        text: torch.Tensor,
+        text_seq_lens: torch.LongTensor,
+        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
+        video_fhw: List[Tuple[int, int, int]],
+        entity_text: List[torch.Tensor],
+        entity_seq_lens: List[torch.LongTensor],
+        entity_masks: List[torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+    ):
+        entity_seq_lens = [seq_lens.max().item() for seq_lens in entity_seq_lens]
+        text_seq_lens = entity_seq_lens + [text_seq_lens.max().item()]
+        entity_text = [
+            self.txt_in(self.txt_norm(text[:, :seq_len])) for text, seq_len in zip(entity_text, entity_seq_lens)
+        ]
+        text = torch.cat(entity_text + [text], dim=1)
+
+        entity_txt_freqs = [self.pos_embed(video_fhw, seq_len, device)[1] for seq_len in entity_seq_lens]
+        img_freqs, txt_freqs = rotary_emb
+        txt_freqs = torch.cat(entity_txt_freqs + [txt_freqs], dim=0)
+        rotary_emb = (img_freqs, txt_freqs)
+
+        global_mask = torch.ones_like(entity_masks[0], device=device, dtype=dtype)
+        patched_masks = [self.patchify(mask) for mask in entity_masks + [global_mask]]
+        batch_size, image_seq_len = patched_masks[0].shape[:2]
+        total_seq_len = sum(text_seq_lens) + image_seq_len
+        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), device=device, dtype=torch.bool)
+
+        # text-image attention mask
+        img_start, img_end = sum(text_seq_lens), total_seq_len
+        cumsum = [0]
+        for seq_len in text_seq_lens:
+            cumsum.append(cumsum[-1] + seq_len)
+        for i, patched_mask in enumerate(patched_masks):
+            txt_start, txt_end = cumsum[i], cumsum[i + 1]
+            mask = torch.sum(patched_mask, dim=-1) > 0
+            mask = mask.unsqueeze(1).repeat(1, text_seq_lens[i], 1)
+            # text-to-image attention
+            attention_mask[:, txt_start:txt_end, img_start:img_end] = mask
+            # image-to-text attention
+            attention_mask[:, img_start:img_end, txt_start:txt_end] = mask.transpose(1, 2)
+        # entity text tokens should not attend to each other
+        for i in range(len(text_seq_lens)):
+            for j in range(len(text_seq_lens)):
+                if i == j:
+                    continue
+                i_start, i_end = cumsum[i], cumsum[i + 1]
+                j_start, j_end = cumsum[j], cumsum[j + 1]
+                attention_mask[:, i_start:i_end, j_start:j_end] = False
+
+        attn_mask = torch.zeros_like(attention_mask, device=device, dtype=dtype)
+        attn_mask[~attention_mask] = -torch.inf
+        attn_mask = attn_mask.unsqueeze(1)
+        return text, rotary_emb, attn_mask
+
     def forward(
         self,
         image: torch.Tensor,
         edit: torch.Tensor = None,
-        text: torch.Tensor = None,
         timestep: torch.LongTensor = None,
-
+        text: torch.Tensor = None,
+        text_seq_lens: torch.LongTensor = None,
+        context_latents: Optional[torch.Tensor] = None,
+        entity_text: Optional[List[torch.Tensor]] = None,
+        entity_seq_lens: Optional[List[torch.LongTensor]] = None,
+        entity_masks: Optional[List[torch.Tensor]] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
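
process_entity_masks above finishes by converting a boolean allow-matrix (may token i attend to token j?) into an additive bias: allowed positions stay 0, blocked positions become -inf, and a head dimension is added for broadcasting. The same conversion in isolation, with made-up sizes:

```python
import torch

allowed = torch.ones((1, 6, 6), dtype=torch.bool)  # (batch, seq, seq), True = may attend
allowed[:, 0:2, 2:4] = False                       # e.g. block one entity prompt from another

bias = torch.zeros_like(allowed, dtype=torch.bfloat16)
bias[~allowed] = -torch.inf                        # softmax weight becomes exactly zero
bias = bias.unsqueeze(1)                           # (batch, 1, seq, seq), broadcast over heads
```
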
@@ -385,37 +449,72 @@ class QwenImageDiT(PreTrainedModel):
             cfg_parallel(
                 (
                     image,
-                    edit,
-                    text,
+                    *(edit if edit is not None else ()),
                     timestep,
-
+                    text,
+                    text_seq_lens,
+                    *(entity_text if entity_text is not None else ()),
+                    *(entity_seq_lens if entity_seq_lens is not None else ()),
+                    *(entity_masks if entity_masks is not None else ()),
+                    context_latents,
                 ),
                 use_cfg=use_cfg,
             ),
         ):
             conditioning = self.time_text_embed(timestep, image.dtype)
             video_fhw = [(1, h // 2, w // 2)]  # frame, height, width
-
+            text_seq_len = text_seq_lens.max().item()
             image = self.patchify(image)
             image_seq_len = image.shape[1]
+            if context_latents is not None:
+                context_latents = context_latents.to(dtype=image.dtype)
+                context_latents = self.patchify(context_latents)
+                image = torch.cat([image, context_latents], dim=1)
+                video_fhw += [(1, h // 2, w // 2)]
             if edit is not None:
-
-
-
-
+                for img in edit:
+                    img = img.to(dtype=image.dtype)
+                    edit_h, edit_w = img.shape[-2:]
+                    img = self.patchify(img)
+                    image = torch.cat([image, img], dim=1)
+                    video_fhw += [(1, edit_h // 2, edit_w // 2)]
 
-
+            rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)
 
             image = self.img_in(image)
-            text = self.txt_in(self.txt_norm(text[:, :
+            text = self.txt_in(self.txt_norm(text[:, :text_seq_len]))
 
-
-
-
-
-
-
+            attn_mask = None
+            if entity_text is not None:
+                text, rotary_emb, attn_mask = self.process_entity_masks(
+                    text,
+                    text_seq_lens,
+                    rotary_emb,
+                    video_fhw,
+                    entity_text,
+                    entity_seq_lens,
+                    entity_masks,
+                    image.device,
+                    image.dtype,
+                )
 
+            # warning: Eligen does not work with sequence parallel because long context attention does not support attention masks
+            img_freqs, txt_freqs = rotary_emb
+            with sequence_parallel((image, text, img_freqs, txt_freqs), seq_dims=(1, 1, 0, 0)):
+                rotary_emb = (img_freqs, txt_freqs)
+                for block in self.transformer_blocks:
+                    text, image = block(
+                        image=image,
+                        text=text,
+                        temb=conditioning,
+                        rotary_emb=rotary_emb,
+                        attn_mask=attn_mask,
+                        attn_kwargs=attn_kwargs,
+                    )
+                image = self.norm_out(image, conditioning)
+                image = self.proj_out(image)
+                (image,) = sequence_parallel_unshard((image,), seq_dims=(1,), seq_lens=(image_seq_len,))
+            image = image[:, :image_seq_len]
             image = self.unpatchify(image, h, w)
 
             (image,) = cfg_parallel_unshard((image,), use_cfg=use_cfg)
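
The rebuilt cfg_parallel call above flattens the optional edit/entity lists into one flat tuple via conditional star-unpacking, so a None argument simply contributes nothing. The pattern on its own, with stand-in values rather than tensors:

```python
image, timestep, text = "img", "t", "txt"  # stand-ins for tensors
edit = ["edit0", "edit1"]                  # may also be None
entity_text = None

args = (
    image,
    *(edit if edit is not None else ()),
    timestep,
    text,
    *(entity_text if entity_text is not None else ()),
)
print(args)  # ('img', 'edit0', 'edit1', 't', 'txt')
```
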
@@ -428,14 +527,8 @@ class QwenImageDiT(PreTrainedModel):
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        model = cls(
-            device="meta",
-            dtype=dtype,
-            num_layers=num_layers,
-            attn_kwargs=attn_kwargs,
-        )
+        model = cls(device="meta", dtype=dtype, num_layers=num_layers)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
@@ -445,5 +538,5 @@ class QwenImageDiT(PreTrainedModel):
         for block in self.transformer_blocks:
             block.compile(*args, **kwargs)
 
-    def
-        return
+    def get_fsdp_module_cls(self):
+        return {QwenImageTransformerBlock}

diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py

@@ -11,12 +11,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
     def __init__(
         self,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(num_layers=num_layers,
+        super().__init__(num_layers=num_layers, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -43,6 +42,7 @@ class QwenImageDiTFBCache(QwenImageDiT):
         text: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         txt_seq_lens: torch.LongTensor = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -72,7 +72,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
             # first block
             original_hidden_states = image
             text, image = self.transformer_blocks[0](
-                image=image,
+                image=image,
+                text=text,
+                temb=conditioning,
+                image_rotary_emb=image_rotary_emb,
+                attn_kwargs=attn_kwargs,
             )
             first_hidden_states_residual = image - original_hidden_states
 
@@ -94,7 +98,13 @@ class QwenImageDiTFBCache(QwenImageDiT):
                 first_hidden_states = image.clone()
 
                 for block in self.transformer_blocks[1:]:
-                    text, image = block(
+                    text, image = block(
+                        image=image,
+                        text=text,
+                        temb=conditioning,
+                        image_rotary_emb=image_rotary_emb,
+                        attn_kwargs=attn_kwargs,
+                    )
 
                 previous_residual = image - first_hidden_states
                 self.previous_residual = previous_residual
@@ -114,14 +124,12 @@ class QwenImageDiTFBCache(QwenImageDiT):
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         relative_l1_threshold: float = 0.05,
     ):
         model = cls(
             device="meta",
             dtype=dtype,
             num_layers=num_layers,
-            attn_kwargs=attn_kwargs,
             relative_l1_threshold=relative_l1_threshold,
         )
         model = model.requires_grad_(False)

diffsynth_engine/models/wan/wan_audio_encoder.py

@@ -223,7 +223,6 @@ class Wav2Vec2StateDictConverter:
 
 class Wav2Vec2Model(PreTrainedModel):
     converter = Wav2Vec2StateDictConverter()
-    _supports_parallelization = False
 
     def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
         super().__init__()
@@ -267,9 +266,13 @@ def linear_interpolation(features: torch.Tensor, input_fps: int, output_fps: int
     return output_features.transpose(1, 2)  # [1, output_len, 512]
 
 
-def extract_audio_feat(
+def extract_audio_feat(
+    audio_input: torch.Tensor, model: Wav2Vec2Model, dtype=torch.float32, device="cuda:0"
+) -> torch.Tensor:
     video_rate = 30
-    input_values = (audio_input - audio_input.mean(dim=1, keepdim=True)) / torch.sqrt(
+    input_values = (audio_input - audio_input.mean(dim=1, keepdim=True)) / torch.sqrt(
+        audio_input.var(dim=1, keepdim=True) + 1e-7
+    )
     feat = torch.cat(model(input_values.to(device)))
     feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
     return feat.to(dtype)  # Encoding for the motion