diffsynth-engine 0.5.1.dev2__py3-none-any.whl → 0.5.1.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,567 @@
1
+ import json
2
+ from typing import List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm
11
+ from diffsynth_engine.models.wan.wan_dit import (
12
+ WanDiT,
13
+ DiTBlock,
14
+ CrossAttention,
15
+ sinusoidal_embedding_1d,
16
+ precompute_freqs_cis_3d,
17
+ modulate,
18
+ )
19
+ from diffsynth_engine.utils.constants import WAN2_2_DIT_S2V_14B_CONFIG_FILE
20
+ from diffsynth_engine.utils.gguf import gguf_inference
21
+ from diffsynth_engine.utils.fp8_linear import fp8_inference
22
+ from diffsynth_engine.utils.parallel import (
23
+ cfg_parallel,
24
+ cfg_parallel_unshard,
25
+ sequence_parallel,
26
+ sequence_parallel_unshard,
27
+ )
28
+
29
+
30
def rope_precompute(x: torch.Tensor, grid_sizes: List[List[torch.Tensor]], freqs: torch.Tensor):
    """Build complex RoPE rotations for a sequence made of one or more concatenated 3-D token grids.

    Only the shape of ``x`` (and the initial contents of any positions not covered by ``grid_sizes``)
    is used; ``x`` itself is never modified.

    Args:
        x: tensor unpacked as ``(b, s, n, c)``; ``c`` must be even. The result has shape ``(b, s, n, c // 2)``.
        grid_sizes: one entry per grid, in sequence order. Each entry is ``[origin, end, extent]`` where
            ``origin = (f_o, h_o, w_o)``, ``end = (f, h, w)`` and ``extent = (t_f, t_h, t_w)``. The grid spans
            ``(f - f_o) * (h - h_o) * (w - w_o)`` tokens, and per-axis positions are sampled evenly with
            ``np.linspace`` over ``[origin, origin + extent - 1]``. A negative ``f_o`` samples the positions
            ``-f_o .. -f_o - t_f + 1`` and conjugates them (assuming ``freqs[0]`` holds unit-modulus
            rotations, this encodes negated temporal positions).
        freqs: indexable by 0, 1, 2 -> per-axis (temporal, height, width) complex frequency tables whose last
            dims sum to ``c // 2``.

    Returns:
        complex128 tensor of shape ``(b, s, n, c // 2)`` with the grids written consecutively from index 0.
    """
    b, s, n, c = x.shape
    c = c // 2
    # copy=True: for a contiguous float64 ``x``, ``.to`` would otherwise return ``x`` itself and the complex
    # view below would alias it, so the slice assignments would silently overwrite the caller's tensor.
    output = torch.view_as_complex(x.reshape(b, s, n, c, 2).to(torch.float64, copy=True))
    prev_seq = 0
    for grid_size in grid_sizes:
        f_o, h_o, w_o = (int(v) for v in grid_size[0])
        f, h, w = (int(v) for v in grid_size[1])
        t_f, t_h, t_w = (int(v) for v in grid_size[2])
        seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
        seq_len = seq_f * seq_h * seq_w
        # Sample seq_f integer positions evenly over the grid's temporal extent.
        if f_o >= 0:
            f_sam = np.linspace(f_o, t_f + f_o - 1, seq_f).astype(int).tolist()
        else:
            f_sam = np.linspace(-f_o, -t_f - f_o + 1, seq_f).astype(int).tolist()
        h_sam = np.linspace(h_o, t_h + h_o - 1, seq_h).astype(int).tolist()
        w_sam = np.linspace(w_o, t_w + w_o - 1, seq_w).astype(int).tolist()
        freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()
        freqs_i = torch.cat(
            [
                freqs_0.view(seq_f, 1, 1, -1).expand(seq_f, seq_h, seq_w, -1),
                freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),
                freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)
        # Same rotation for every batch element and head (broadcast over dims 0 and 2).
        output[:, prev_seq : prev_seq + seq_len] = freqs_i
        prev_seq += seq_len
    return output
63
+
64
+
65
class FramePackMotioner(nn.Module):
    """Pack a history of motion latents into multi-resolution tokens plus matching RoPE rotations.

    The most recent frames are patchified finely (``proj``, kernel 1x2x2), older frames more coarsely
    (``proj_2x``, kernel 2x4x4, and ``proj_4x``, kernel 4x8x8), so a long history costs few tokens.
    """

    def __init__(
        self,
        inner_dim: int = 1024,
        num_heads: int = 16,
        zip_frame_buckets: Optional[List[int]] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            inner_dim: token width produced by the patch convolutions.
            num_heads: attention heads; ``inner_dim // num_heads`` must be even (it is the RoPE head dim).
            zip_frame_buckets: three frame counts sampled for the 1x, 2x and 4x patch operations, from the
                nearest to the farthest frames. ``None`` means ``[1, 2, 16]``.
            device: device for the convolution weights.
            dtype: dtype for the convolution weights.
        """
        super().__init__()
        if zip_frame_buckets is None:
            zip_frame_buckets = [1, 2, 16]
        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), device=device, dtype=dtype)
        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), device=device, dtype=dtype)
        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), device=device, dtype=dtype)
        # Copy so the caller's sequence (or a shared default) is never aliased by the module.
        self.zip_frame_buckets = list(zip_frame_buckets)

        self.inner_dim = inner_dim
        self.num_heads = num_heads

        assert (inner_dim % num_heads) == 0 and (inner_dim // num_heads) % 2 == 0
        head_dim = inner_dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)

    def forward(self, motion_latents: torch.Tensor):
        """Turn motion-history latents into tokens and their RoPE rotations.

        Args:
            motion_latents: tensor unpacked as ``(b, _, f, h, w)``; it is copied into a 16-channel buffer, so
                it is expected to have 16 channels. Only the last ``sum(zip_frame_buckets)`` frames are used;
                a shorter history is zero-padded at the front (oldest side).

        Returns:
            ``(motion_tokens, motion_rope_emb)``: tokens of shape ``(b, seq, inner_dim)`` ordered
            1x bucket, 2x bucket, 4x bucket, and the complex rotations from ``rope_precompute`` for them.
        """
        b, _, f, h, w = motion_latents.shape
        total_frames = sum(self.zip_frame_buckets)
        # Right-align the history in a zero buffer holding exactly total_frames frames.
        padded = torch.zeros((b, 16, total_frames, h, w), device=motion_latents.device, dtype=motion_latents.dtype)
        overlap_frame = min(total_frames, f)
        if overlap_frame > 0:
            padded[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]

        # Split oldest -> newest (16, 2, 1 frames with the default buckets).
        clean_latents_4x, clean_latents_2x, clean_latents_post = padded.split(self.zip_frame_buckets[::-1], dim=2)

        def to_tokens(conv: nn.Module, latents: torch.Tensor) -> torch.Tensor:
            return rearrange(conv(latents), "b c f h w -> b (f h w) c").contiguous()

        motion_tokens = torch.cat(
            [
                to_tokens(self.proj, clean_latents_post),
                to_tokens(self.proj_2x, clean_latents_2x),
                to_tokens(self.proj_4x, clean_latents_4x),
            ],
            dim=1,
        )

        def get_grid_sizes(i: int):  # rope, 0: post, 1: 2x, 2: 4x
            # Temporal positions are negative: bucket i covers the frames just before the nearer buckets.
            start_time_id = -sum(self.zip_frame_buckets[: (i + 1)])
            end_time_id = start_time_id + self.zip_frame_buckets[i] // (2**i)
            return [
                [
                    torch.tensor([start_time_id, 0, 0]),
                    torch.tensor([end_time_id, h // (2 ** (i + 1)), w // (2 ** (i + 1))]),
                    torch.tensor([self.zip_frame_buckets[i], h // 2, w // 2]),
                ]
            ]

        motion_rope_emb = rope_precompute(
            x=rearrange(motion_tokens, "b s (n d) -> b s n d", n=self.num_heads),
            grid_sizes=get_grid_sizes(0) + get_grid_sizes(1) + get_grid_sizes(2),
            freqs=self.freqs,
        )

        return motion_tokens, motion_rope_emb
127
+
128
+
129
class CausalConv1d(nn.Module):
    """1-D convolution padded on the left only, so output step t depends only on inputs at steps <= t.

    The left padding equals the receptive-field span ``(kernel_size - 1) * dilation``; with ``stride=1``
    the output therefore has the same length as the input.
    """

    def __init__(
        self,
        chan_in: int,
        chan_out: int,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "replicate",
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            chan_in / chan_out: input / output channels.
            kernel_size, stride, dilation: forwarded to ``nn.Conv1d``.
            pad_mode: ``F.pad`` mode used for the left padding.
            device, dtype: placement of the convolution weights.
        """
        super().__init__()
        self.pad_mode = pad_mode
        # The dilated kernel spans (kernel_size - 1) * dilation steps; padding only kernel_size - 1 would
        # shrink the output and break length preservation whenever dilation > 1.
        self.time_causal_padding = ((kernel_size - 1) * dilation, 0)
        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, device=device, dtype=dtype
        )

    def forward(self, x: torch.Tensor):
        """Apply the causal convolution to ``x`` of shape ``(batch, chan_in, time)``."""
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
151
+
152
+
153
class MotionEncoder(nn.Module):
    """Causal 1-D conv encoder producing a global and a per-group (local) feature stream.

    The local branch expands the input into ``num_heads`` channel groups that are encoded as independent
    sequences, then appends one learned padding token per time step. The global branch encodes the whole
    signal and applies ``final_linear``. Both branches share ``norm1``-``norm3`` and ``conv2``/``conv3``.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        num_heads: int = 8,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda:0",
    ):
        super().__init__()
        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, device=device, dtype=dtype)
        self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1, device=device, dtype=dtype)
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, device=device, dtype=dtype)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, device=device, dtype=dtype)
        self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.act = nn.SiLU()
        # Allocated on the same device/dtype as every other layer so the torch.cat in forward() never mixes
        # devices or dtypes (previously this defaulted to CPU float32).
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim, device=device, dtype=dtype))
        self.final_linear = nn.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)

    def _encode_tail(self, x: torch.Tensor) -> torch.Tensor:
        """Run the shared stride-2 conv2/conv3 stages on ``(batch, time, channels)`` and return the same layout."""
        x = self.conv2(rearrange(x, "b t c -> b c t"))
        x = self.act(self.norm2(rearrange(x, "b c t -> b t c")))
        x = self.conv3(rearrange(x, "b t c -> b c t"))
        return self.act(self.norm3(rearrange(x, "b c t -> b t c")))

    def forward(self, x: torch.Tensor):
        """Encode ``x`` of shape ``(b, t, in_dim)``.

        Returns:
            ``(x_global, x_local)`` with time downsampled by the two stride-2 convolutions:
            ``x_global`` is ``(b, t', 1, hidden_dim)`` and ``x_local`` is ``(b, t', num_heads + 1, hidden_dim)``
            (the last slot is the learned padding token).
        """
        x = rearrange(x, "b t c -> b c t")
        b = x.shape[0]

        # Local branch: one independent sequence per channel group.
        x_local = self.conv1_local(x)
        x_local = rearrange(x_local, "b (n c) t -> (b n) t c", n=self.num_heads)
        x_local = self._encode_tail(self.act(self.norm1(x_local)))
        x_local = rearrange(x_local, "(b n) t c -> b t n c", b=b)
        padding = self.padding_tokens.repeat(b, x_local.shape[1], 1, 1)
        x_local = torch.cat([x_local, padding], dim=-2)

        # Global branch: the un-split signal (yields a singleton token axis after the final rearrange).
        x_global = self.conv1_global(x)
        x_global = rearrange(x_global, "b c t -> b t c")
        x_global = self._encode_tail(self.act(self.norm1(x_global)))
        x_global = self.final_linear(x_global)
        x_global = rearrange(x_global, "(b n) t c -> b t n c", b=b)
        return x_global, x_local
209
+
210
+
211
class CausalAudioEncoder(nn.Module):
    """Blend per-layer audio features with learned softmax-free weights, then run ``MotionEncoder``.

    Each of the ``num_layers`` feature layers gets a scalar weight passed through SiLU; the layers are
    combined as a normalised weighted sum before encoding.
    """

    def __init__(
        self,
        dim: int = 1024,
        num_layers: int = 25,
        out_dim: int = 2048,
        num_token: int = 4,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda:0",
    ):
        super().__init__()
        self.encoder = MotionEncoder(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, device=device, dtype=dtype)
        initial = torch.ones((1, num_layers, 1, 1), device=device, dtype=dtype) * 0.01
        self.weights = nn.Parameter(initial)
        self.act = nn.SiLU()

    def forward(self, features: torch.Tensor):
        """Encode ``features`` laid out as ``(b, num_layers, dim, video_length)``.

        Returns whatever ``self.encoder`` returns for the ``(b, video_length, dim)`` weighted features.
        """
        layer_weights = self.act(self.weights)
        normaliser = layer_weights.sum(dim=1, keepdim=True)
        blended = (features * layer_weights / normaliser).sum(dim=1)  # b dim f
        return self.encoder(blended.transpose(1, 2))  # encoder input: b f dim
233
+
234
+
235
class AudioInjector(nn.Module):
    """Container for the per-block audio cross-attention and AdaLN layers of the S2V DiT.

    ``injected_block_id`` maps a DiT block index to the position of its ``CrossAttention`` in ``injector``
    and its ``AdaLayerNorm`` in ``injector_adain_layers``.
    """

    def __init__(
        self,
        dim=5120,
        num_heads=40,
        inject_layers=None,
        adain_dim=5120,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            dim: model width for the cross-attention layers.
            num_heads: attention heads for the cross-attention layers.
            inject_layers: DiT block indices that receive audio injection; ``None`` means
                ``[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]``.
            adain_dim: width of the AdaLayerNorm layers.
            device, dtype: placement of all sub-layers.
        """
        super().__init__()
        if inject_layers is None:
            inject_layers = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]
        self.injected_block_id = {block_idx: i for i, block_idx in enumerate(inject_layers)}

        num_injections = len(inject_layers)
        self.injector = nn.ModuleList(
            [
                CrossAttention(
                    dim=dim,
                    num_heads=num_heads,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_injections)
            ]
        )
        self.injector_adain_layers = nn.ModuleList(
            [AdaLayerNorm(dim=adain_dim, device=device, dtype=dtype) for _ in range(num_injections)]
        )
264
+
265
+
266
class DiTBlockS2V(nn.Module):
    """DiT block that modulates two token segments with different timestep embeddings.

    Re-uses (does not copy) the sub-modules of an existing ``DiTBlock``. In ``forward``, the first
    ``x_seq_len`` tokens are shifted/scaled/gated with ``t_mod`` and the remaining tokens with ``t_mod_0``.
    ``x`` is updated in place and also returned.
    """

    _SHARED_ATTRS = (
        "dim",
        "num_heads",
        "ffn_dim",
        "self_attn",
        "cross_attn",
        "norm1",
        "norm2",
        "norm3",
        "ffn",
        "modulation",
    )

    def __init__(self, dit_block: DiTBlock):
        super().__init__()
        for attr_name in self._SHARED_ATTRS:
            setattr(self, attr_name, getattr(dit_block, attr_name))

    def forward(self, x, x_seq_len, context, t_mod, t_mod_0, freqs):
        # Six modulation tensors each: shift/scale/gate for attention (0-2) and for the FFN (3-5).
        mod_video = (self.modulation + t_mod).chunk(6, dim=1)
        mod_rest = (self.modulation + t_mod_0).chunk(6, dim=1)
        head = slice(None, x_seq_len)
        tail = slice(x_seq_len, None)

        def split_modulate(hidden, shift_idx, scale_idx):
            front = modulate(hidden[:, head], mod_video[shift_idx], mod_video[scale_idx])
            back = modulate(hidden[:, tail], mod_rest[shift_idx], mod_rest[scale_idx])
            return torch.cat([front, back], dim=1)

        def split_gate(update, gate_idx):
            return torch.cat([update[:, head] * mod_video[gate_idx], update[:, tail] * mod_rest[gate_idx]], dim=1)

        attn_out = self.self_attn(split_modulate(self.norm1(x), 0, 1), freqs)
        x += split_gate(attn_out, 2)
        x += self.cross_attn(self.norm3(x), context)
        ffn_out = self.ffn(split_modulate(self.norm2(x), 3, 4))
        x += split_gate(ffn_out, 5)
        return x
310
+
311
+
312
class WanS2VDiT(WanDiT):
    """Wan 2.2 speech-to-video (S2V) DiT.

    Extends ``WanDiT`` with:
      * ``cond_encoder``: patch-embeds the pose condition, added to the video patch embedding;
      * ``casual_audio_encoder``: encodes layer-weighted audio features (the attribute name is kept as-is
        because it determines the parameter names in ``state_dict``);
      * ``audio_injector``: audio cross-attention applied after selected blocks;
      * ``frame_packer``: packs motion-history latents into extra tokens;
      * ``trainable_cond_mask``: embedding added to every token, marking it as noisy video (0),
        reference (1) or motion (2).
    The base blocks are wrapped in ``DiTBlockS2V``.
    """

    def __init__(
        self,
        cond_dim: int = 16,
        audio_dim: int = 1024,
        num_audio_token: int = 4,
        audio_inject_layers: Optional[List[int]] = None,
        num_heads: int = 40,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        *args,
        **kwargs,
    ):
        """
        Args:
            cond_dim: channels of the pose condition.
            audio_dim: width of each audio feature layer.
            num_audio_token: channel groups (local tokens) produced by the audio encoder.
            audio_inject_layers: DiT block indices followed by audio injection; ``None`` means
                ``[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]``.
            num_heads, device, dtype, *args, **kwargs: forwarded to ``WanDiT``.
        """
        if audio_inject_layers is None:
            audio_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]
        super().__init__(num_heads=num_heads, device=device, dtype=dtype, *args, **kwargs)
        self.num_heads = num_heads
        self.cond_encoder = nn.Conv3d(
            cond_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=dtype
        )
        self.casual_audio_encoder = CausalAudioEncoder(
            dim=audio_dim, out_dim=self.dim, num_token=num_audio_token, device=device, dtype=dtype
        )
        self.audio_injector = AudioInjector(
            dim=self.dim,
            num_heads=num_heads,
            inject_layers=audio_inject_layers,
            adain_dim=self.dim,
            device=device,
            dtype=dtype,
        )
        self.trainable_cond_mask = nn.Embedding(3, self.dim, device=device, dtype=dtype)
        self.frame_packer = FramePackMotioner(
            inner_dim=self.dim,
            num_heads=num_heads,
            zip_frame_buckets=[1, 2, 16],
            device=device,
            dtype=dtype,
        )
        # Wrap each base block; the wrappers share (not copy) the original sub-modules.
        self.blocks = nn.ModuleList([DiTBlockS2V(block) for block in self.blocks])

    @staticmethod
    def get_model_config(model_type: str):
        """Load the JSON config for ``model_type``.

        Raises:
            ValueError: ``model_type`` is not a supported S2V model.
        """
        MODEL_CONFIG_FILES = {
            "wan2.2-s2v-14b": WAN2_2_DIT_S2V_14B_CONFIG_FILE,
        }
        if model_type not in MODEL_CONFIG_FILES:
            raise ValueError(f"Unsupported model type: {model_type}")

        config_file = MODEL_CONFIG_FILES[model_type]
        with open(config_file, "r") as f:
            config = json.load(f)
        return config

    def inject_motion(
        self,
        x: torch.Tensor,
        x_seq_len: int,
        rope_embs: torch.Tensor,
        motion_latents: torch.Tensor,
        drop_motion_frames: bool = False,
    ):
        """Append packed motion tokens (unless ``drop_motion_frames``) and add the token-type embedding.

        Token types: 0 for the first ``x_seq_len`` tokens (noisy video), 1 for the remaining existing
        tokens (reference), 2 for appended motion tokens. ``x`` is updated in place when no motion tokens
        are appended.

        Returns:
            ``(x, rope_embs)`` with motion tokens and their RoPE rotations appended when used.
        """
        b, s, _ = x.shape
        mask_input = torch.zeros([b, s], dtype=torch.long, device=x.device)
        mask_input[:, x_seq_len:] = 1

        if not drop_motion_frames:
            motion, motion_rope_emb = self.frame_packer(motion_latents)
            x = torch.cat([x, motion], dim=1)
            rope_embs = torch.cat([rope_embs, motion_rope_emb], dim=1)
            mask_input = torch.cat(
                [
                    mask_input,
                    2 * torch.ones([b, motion.shape[1]], device=mask_input.device, dtype=mask_input.dtype),
                ],
                dim=1,
            )
        x += self.trainable_cond_mask(mask_input).to(x.dtype)
        return x, rope_embs

    def patchify_x_with_pose(self, x: torch.Tensor, pose: torch.Tensor):
        """Patch-embed ``x`` plus the pose condition; return ``(tokens, (f, h, w))`` of the patch grid."""
        x = self.patch_embedding(x) + self.cond_encoder(pose)
        grid_size = x.shape[2:]
        x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
        return x, grid_size  # x, grid_size: (f, h, w)

    def forward(
        self,
        x: torch.Tensor,  # b c tx h w
        context: torch.Tensor,  # b s c
        timestep: torch.Tensor,  # b
        ref_latents: torch.Tensor,  # b c 1 h w
        motion_latents: torch.Tensor,  # b c tm h w
        pose_cond: torch.Tensor,  # b c tx h w
        audio_input: torch.Tensor,  # b c d tx
        num_motion_frames: int = 73,
        num_motion_latents: int = 19,
        drop_motion_frames: bool = False,  # !(ref_as_first_frame || clip_idx)
        audio_mask: Optional[torch.Tensor] = None,  # b c tx h w
        void_audio_input: Optional[torch.Tensor] = None,
    ):
        """Run the S2V DiT on one clip.

        Token layout: noisy video tokens, then reference-frame tokens, then (optionally) motion tokens.
        Only the noisy video tokens reach the output head.

        Returns:
            The unpatchified head output for the noisy video tokens.
        """
        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = x.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel((x, context, audio_input), use_cfg=use_cfg),
        ):
            audio_emb_global, merged_audio_emb, void_audio_emb_global, void_merged_audio_emb, audio_mask = (
                self.get_audio_emb(audio_input, num_motion_frames, num_motion_latents, audio_mask, void_audio_input)
            )
            t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))  # (s, d)
            t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
            # Reference/motion tokens are always modulated as if at timestep 0.
            t_mod_0 = self.time_projection(
                self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, torch.zeros([1]).to(t)))
            ).unflatten(1, (6, self.dim))
            context = self.text_embedding(context)
            x, (f, h, w) = self.patchify_x_with_pose(x, pose_cond)
            ref, _ = self.patchify(ref_latents)
            x = torch.cat([x, ref], dim=1)
            # Video tokens use temporal positions 0..f-1. The single reference frame is pinned to temporal
            # position 30 (origin 30, end 31, extent 1), i.e. it receives self.freqs[0][30].
            freqs = rope_precompute(
                x=rearrange(x, "b s (n d) -> b s n d", n=self.num_heads),
                grid_sizes=[
                    [
                        torch.tensor([0, 0, 0]),
                        torch.tensor([f, h, w]),
                        torch.tensor([f, h, w]),
                    ],  # grid size of x
                    [
                        torch.tensor([30, 0, 0]),
                        torch.tensor([31, h, w]),
                        torch.tensor([1, h, w]),
                    ],  # grid size of ref
                ],
                freqs=self.freqs,
            )

            x_seq_len = f * h * w
            x, freqs = self.inject_motion(
                x=x,
                x_seq_len=x_seq_len,
                rope_embs=freqs,
                motion_latents=motion_latents,
                drop_motion_frames=drop_motion_frames,
            )

            # f must be divisible by ulysses world size
            x_img, freqs_img = x[:, :x_seq_len], freqs[:, :x_seq_len]
            x_ref_motion, freqs_ref_motion = x[:, x_seq_len:], freqs[:, x_seq_len:]
            with sequence_parallel(
                tensors=(
                    x_img,
                    freqs_img,
                    audio_emb_global,
                    merged_audio_emb,
                    audio_mask,
                    void_audio_emb_global,
                    void_merged_audio_emb,
                ),
                seq_dims=(1, 1, 1, 1, 1, 1, 1),
            ):
                # Only the video tokens are sharded; every rank keeps the full reference/motion tokens.
                x_seq_len_local = x_img.shape[1]
                x = torch.concat([x_img, x_ref_motion], dim=1)
                freqs = torch.concat([freqs_img, freqs_ref_motion], dim=1)
                for idx, block in enumerate(self.blocks):
                    x = block(
                        x=x, x_seq_len=x_seq_len_local, context=context, t_mod=t_mod, t_mod_0=t_mod_0, freqs=freqs
                    )
                    if idx in self.audio_injector.injected_block_id:
                        x = self.inject_audio(
                            x=x,
                            x_seq_len=x_seq_len_local,
                            block_idx=idx,
                            audio_emb_global=audio_emb_global,
                            merged_audio_emb=merged_audio_emb,
                            audio_mask=audio_mask,
                            void_audio_emb_global=void_audio_emb_global,
                            void_merged_audio_emb=void_merged_audio_emb,
                        )

                x = x[:, :x_seq_len_local]
                x = self.head(x, t)
                (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(x_seq_len,))
            x = self.unpatchify(x, (f, h, w))
            (x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
            return x

    def _encode_audio(self, audio: torch.Tensor, num_motion_frames: int, num_motion_latents: int):
        """Left-pad ``audio`` with ``num_motion_frames`` copies of its first frame, encode it, and drop the
        first ``num_motion_latents`` encoded steps. Returns ``(audio_emb_global, merged_audio_emb)``."""
        audio = torch.cat([audio[..., 0:1].repeat(1, 1, 1, num_motion_frames), audio], dim=-1)
        audio_emb_global, audio_emb = self.casual_audio_encoder(audio)
        return audio_emb_global[:, num_motion_latents:], audio_emb[:, num_motion_latents:, :]

    def get_audio_emb(
        self,
        audio_input: torch.Tensor,
        num_motion_frames: int = 73,
        num_motion_latents: int = 19,
        audio_mask: Optional[torch.Tensor] = None,
        void_audio_input: Optional[torch.Tensor] = None,
    ):
        """Encode the conditioning audio and, when ``audio_mask`` is given, the void (unconditional) audio.

        Returns:
            ``(audio_emb_global, merged_audio_emb, void_audio_emb_global, void_merged_audio_emb, audio_mask)``;
            the void entries are ``None`` and ``audio_mask`` is returned unchanged (``None``) when no mask is
            given, otherwise the mask is flattened from ``b c f h w`` to ``b (f h w) c``.

        Raises:
            ValueError: ``audio_mask`` is given without ``void_audio_input``.
        """
        void_audio_emb_global, void_merged_audio_emb = None, None
        if audio_mask is not None:
            if void_audio_input is None:
                raise ValueError("void_audio_input must be provided when audio_mask is given")
            audio_mask = rearrange(audio_mask, "b c f h w -> b (f h w) c").contiguous()
            void_audio_emb_global, void_merged_audio_emb = self._encode_audio(
                void_audio_input, num_motion_frames, num_motion_latents
            )

        audio_emb_global, merged_audio_emb = self._encode_audio(audio_input, num_motion_frames, num_motion_latents)
        return audio_emb_global, merged_audio_emb, void_audio_emb_global, void_merged_audio_emb, audio_mask

    def inject_audio(
        self,
        x: torch.Tensor,
        x_seq_len: int,
        block_idx: int,
        audio_emb_global: torch.Tensor,
        merged_audio_emb: torch.Tensor,
        audio_mask: Optional[torch.Tensor] = None,
        void_audio_emb_global: Optional[torch.Tensor] = None,
        void_merged_audio_emb: Optional[torch.Tensor] = None,
    ):
        """Add the audio cross-attention residual for block ``block_idx`` to ``x[:, :x_seq_len]`` in place.

        The video tokens are grouped into ``t = merged_audio_emb.shape[1]`` chunks, normalised by the block's
        AdaLayerNorm conditioned on the first global-audio token, and cross-attended against that chunk's
        audio tokens. With ``audio_mask``, the residual is ``cond * mask + uncond * (1 - mask)`` where
        ``uncond`` uses the void-audio embeddings.

        Raises:
            ValueError: ``audio_mask`` is given without both void-audio embeddings.
        """
        audio_attn_id = self.audio_injector.injected_block_id[block_idx]
        adain = self.audio_injector.injector_adain_layers[audio_attn_id]
        cross_attn = self.audio_injector.injector[audio_attn_id]
        num_latents_per_clip = merged_audio_emb.shape[1]

        x_input = rearrange(x[:, :x_seq_len], "b (t n) c -> (b t) n c", t=num_latents_per_clip)

        def audio_residual(global_emb: torch.Tensor, merged_emb: torch.Tensor) -> torch.Tensor:
            global_emb = rearrange(global_emb, "b t n c -> (b t) n c")
            x_adain = adain(x_input, emb=global_emb[:, 0])
            merged_emb = rearrange(merged_emb, "b t n c -> (b t) n c", t=num_latents_per_clip)
            residual = cross_attn(x=x_adain, y=merged_emb)
            return rearrange(residual, "(b t) n c -> b (t n) c", t=num_latents_per_clip)

        x_cond_residual = audio_residual(audio_emb_global, merged_audio_emb)
        if audio_mask is None:
            x[:, :x_seq_len] += x_cond_residual
        else:
            if void_audio_emb_global is None or void_merged_audio_emb is None:
                raise ValueError("audio_mask requires void_audio_emb_global and void_merged_audio_emb")
            x_uncond_residual = audio_residual(void_audio_emb_global, void_merged_audio_emb)
            x[:, :x_seq_len] += x_cond_residual * audio_mask + x_uncond_residual * (1 - audio_mask)

        return x
@@ -3,6 +3,7 @@ from .flux_image import FluxImagePipeline
3
3
  from .sdxl_image import SDXLImagePipeline
4
4
  from .sd_image import SDImagePipeline
5
5
  from .wan_video import WanVideoPipeline
6
+ from .wan_s2v import WanSpeech2VideoPipeline
6
7
  from .qwen_image import QwenImagePipeline
7
8
  from .hunyuan3d_shape import Hunyuan3DShapePipeline
8
9
 
@@ -13,6 +14,7 @@ __all__ = [
13
14
  "SDXLImagePipeline",
14
15
  "SDImagePipeline",
15
16
  "WanVideoPipeline",
17
+ "WanSpeech2VideoPipeline",
16
18
  "QwenImagePipeline",
17
19
  "Hunyuan3DShapePipeline",
18
20
  ]