diffsynth-engine 0.5.1.dev2__py3-none-any.whl → 0.5.1.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +2 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json +13 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +36 -0
- diffsynth_engine/models/basic/attention.py +7 -4
- diffsynth_engine/models/wan/wan_audio_encoder.py +306 -0
- diffsynth_engine/models/wan/wan_dit.py +6 -2
- diffsynth_engine/models/wan/wan_s2v_dit.py +567 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/wan_s2v.py +685 -0
- diffsynth_engine/utils/constants.py +1 -0
- diffsynth_engine/utils/image.py +7 -0
- diffsynth_engine/utils/video.py +26 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/METADATA +3 -1
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/RECORD +18 -14
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,567 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from einops import rearrange
|
|
9
|
+
|
|
10
|
+
from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm
|
|
11
|
+
from diffsynth_engine.models.wan.wan_dit import (
|
|
12
|
+
WanDiT,
|
|
13
|
+
DiTBlock,
|
|
14
|
+
CrossAttention,
|
|
15
|
+
sinusoidal_embedding_1d,
|
|
16
|
+
precompute_freqs_cis_3d,
|
|
17
|
+
modulate,
|
|
18
|
+
)
|
|
19
|
+
from diffsynth_engine.utils.constants import WAN2_2_DIT_S2V_14B_CONFIG_FILE
|
|
20
|
+
from diffsynth_engine.utils.gguf import gguf_inference
|
|
21
|
+
from diffsynth_engine.utils.fp8_linear import fp8_inference
|
|
22
|
+
from diffsynth_engine.utils.parallel import (
|
|
23
|
+
cfg_parallel,
|
|
24
|
+
cfg_parallel_unshard,
|
|
25
|
+
sequence_parallel,
|
|
26
|
+
sequence_parallel_unshard,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def rope_precompute(x: torch.Tensor, grid_sizes: List[List[torch.Tensor]], freqs: torch.Tensor):
    """Assemble complex RoPE factors for a token sequence built from consecutive 3D grids.

    ``x`` only provides the output shape: it is reshaped to ``(b, s, n, c // 2, 2)``, cast to float64
    and viewed as complex, after which every position covered by ``grid_sizes`` is overwritten.
    Positions not covered by any grid keep that complex view of ``x``'s values.

    Each entry of ``grid_sizes`` is ``[start, end, span]``, three integer ``(f, h, w)`` tensors. An
    entry contributes ``(end_f - start_f) * (end_h - start_h) * (end_w - start_w)`` consecutive tokens,
    laid out f-major, then h, then w. Along each axis, ``end - start`` indices are sampled evenly with
    ``np.linspace`` between ``start`` and ``start + span - 1`` and used to index ``freqs[0]`` (time),
    ``freqs[1]`` (height) and ``freqs[2]`` (width). A negative temporal start is handled by sampling
    the mirrored positive indices and conjugating those factors.

    Returns:
        A complex128 tensor of shape ``(b, s, n, c // 2)``.
    """
    batch, seq, heads, channels = x.shape
    half = channels // 2
    rope = torch.view_as_complex(x.reshape(batch, seq, heads, half, 2).to(torch.float64))

    offset = 0
    for grid in grid_sizes:
        f0, h0, w0 = (int(v) for v in grid[0])
        f1, h1, w1 = (int(v) for v in grid[1])
        span_f, span_h, span_w = (int(v) for v in grid[2])

        n_f, n_h, n_w = f1 - f0, h1 - h0, w1 - w0
        n_tok = int(n_f * n_h * n_w)

        if f0 >= 0:
            f_idx = np.linspace(f0, span_f + f0 - 1, n_f).astype(int).tolist()
            f_freq = freqs[0][f_idx]
        else:
            # Negative time offsets: index with the mirrored positive positions, then conjugate
            # so the rotation runs the other way.
            f_idx = np.linspace(-f0, -span_f - f0 + 1, n_f).astype(int).tolist()
            f_freq = freqs[0][f_idx].conj()
        h_idx = np.linspace(h0, span_h + h0 - 1, n_h).astype(int).tolist()
        w_idx = np.linspace(w0, span_w + w0 - 1, n_w).astype(int).tolist()

        full = (n_f, n_h, n_w, -1)
        parts = [
            f_freq.view(n_f, 1, 1, -1).expand(*full),
            freqs[1][h_idx].view(1, n_h, 1, -1).expand(*full),
            freqs[2][w_idx].view(1, 1, n_w, -1).expand(*full),
        ]
        grid_freqs = torch.cat(parts, dim=-1).reshape(n_tok, 1, -1)

        # Broadcast the (n_tok, 1, c // 2) factors over batch and heads.
        rope[:, offset : offset + n_tok] = grid_freqs
        offset += n_tok
    return rope
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class FramePackMotioner(nn.Module):
    """Pack previously generated ("motion") latents into a short token sequence, FramePack style.

    The most recent latent frames are patchified at full resolution and older ones with 2x / 4x coarser
    spatio-temporal kernels. ``forward`` returns the packed tokens together with matching complex RoPE
    factors whose temporal positions are negative (i.e. in the past relative to the current clip).
    """

    def __init__(
        self,
        inner_dim: int = 1024,
        num_heads: int = 16,
        zip_frame_buckets: Optional[List[int]] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            inner_dim: Token width produced by the three projection convolutions.
            num_heads: Attention head count; RoPE factors are built for head size ``inner_dim // num_heads``.
            zip_frame_buckets: Three frame counts, from the nearest to the farthest frames, packed with the
                1x, 2x and 4x projections respectively. ``None`` means ``[1, 2, 16]``.
            device: Device for the projection weights.
            dtype: Dtype for the projection weights.
        """
        super().__init__()
        # A fresh list per instance (the previous list default was shared by all instances).
        zip_frame_buckets = [1, 2, 16] if zip_frame_buckets is None else list(zip_frame_buckets)

        # Input latents are assumed to have 16 channels (hard-coded here and in forward).
        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), device=device, dtype=dtype)
        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), device=device, dtype=dtype)
        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), device=device, dtype=dtype)
        self.zip_frame_buckets = zip_frame_buckets

        self.inner_dim = inner_dim
        self.num_heads = num_heads

        assert (inner_dim % num_heads) == 0 and (inner_dim // num_heads) % 2 == 0
        head_dim = inner_dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)

    def forward(self, motion_latents: torch.Tensor):
        """Pack ``motion_latents`` of shape ``(b, c, f, h, w)`` into tokens.

        The last ``sum(zip_frame_buckets)`` frames are used, with missing leading frames zero-filled.

        Returns:
            ``(tokens, rope)``, where ``tokens`` is ``(b, s, inner_dim)`` and ``rope`` is the complex RoPE
            tensor produced by ``rope_precompute`` for those ``s`` tokens.
        """
        b, _, f, h, w = motion_latents.shape
        padd_latents = torch.zeros(
            (b, 16, sum(self.zip_frame_buckets), h, w), device=motion_latents.device, dtype=motion_latents.dtype
        )
        # Right-align the available motion frames; earlier slots stay zero when f is short.
        overlap_frame = min(padd_latents.shape[2], f)
        if overlap_frame > 0:
            padd_latents[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]

        # Farthest frames first in time: with buckets [1, 2, 16] this yields 16, 2, 1 frames.
        clean_latents_4x, clean_latents_2x, clean_latents_post = padd_latents[
            :, :, -sum(self.zip_frame_buckets) :
        ].split(self.zip_frame_buckets[::-1], dim=2)

        clean_latents_post = rearrange(self.proj(clean_latents_post), "b c f h w -> b (f h w) c").contiguous()
        clean_latents_2x = rearrange(self.proj_2x(clean_latents_2x), "b c f h w -> b (f h w) c").contiguous()
        clean_latents_4x = rearrange(self.proj_4x(clean_latents_4x), "b c f h w -> b (f h w) c").contiguous()
        motion_latents = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)

        def get_grid_sizes(i: int):  # rope, 0: post, 1: 2x, 2: 4x
            # Temporal positions count backwards from the current clip; bucket i is downsampled
            # by 2**i in time and by 2**(i + 1) spatially.
            start_time_id = -sum(self.zip_frame_buckets[: (i + 1)])
            end_time_id = start_time_id + self.zip_frame_buckets[i] // (2**i)
            return [
                [
                    torch.tensor([start_time_id, 0, 0]),
                    torch.tensor([end_time_id, h // (2 ** (i + 1)), w // (2 ** (i + 1))]),
                    torch.tensor([self.zip_frame_buckets[i], h // 2, w // 2]),
                ]
            ]

        motion_rope_emb = rope_precompute(
            x=rearrange(motion_latents, "b s (n d) -> b s n d", n=self.num_heads),
            grid_sizes=get_grid_sizes(0) + get_grid_sizes(1) + get_grid_sizes(2),
            freqs=self.freqs,
        )

        return motion_latents, motion_rope_emb
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class CausalConv1d(nn.Module):
    """1D convolution over time that never looks at future steps.

    Input of shape ``(batch, channels, time)`` is left-padded by ``(kernel_size - 1) * dilation`` steps
    using ``pad_mode``. Output step ``t`` therefore depends only on input steps ``<= t``, and with
    ``stride=1`` the output has the same length as the input.
    """

    def __init__(
        self,
        chan_in: int,
        chan_out: int,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "replicate",
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.pad_mode = pad_mode
        # A dilated kernel spans (kernel_size - 1) * dilation + 1 steps; pad all but one of them on the
        # left. The previous `kernel_size - 1` under-padded whenever dilation > 1, shortening the output.
        self.time_causal_padding = ((kernel_size - 1) * dilation, 0)
        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, device=device, dtype=dtype
        )

    def forward(self, x: torch.Tensor):
        """Left-pad ``x`` causally, then apply the convolution."""
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class MotionEncoder(nn.Module):
    """Causal temporal encoder producing a global and a per-token ("local") feature stream.

    Both branches apply a causal conv, then two stride-2 causal convs, each followed by LayerNorm + SiLU.
    The local branch first expands to ``num_heads`` parallel streams, and then appends one learned
    padding token. The global branch ends with a linear projection.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        num_heads: int = 8,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda:0",
    ):
        super().__init__()
        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, device=device, dtype=dtype)
        self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1, device=device, dtype=dtype)
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, device=device, dtype=dtype)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, device=device, dtype=dtype)
        self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.act = nn.SiLU()
        # Created on the same device/dtype as every other parameter. Previously it defaulted to CPU
        # float32, which breaks the torch.cat with bf16/CUDA features in forward.
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim, device=device, dtype=dtype))
        self.final_linear = nn.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)

    def _conv_norm_act(self, x: torch.Tensor, conv: nn.Module, norm: nn.Module) -> torch.Tensor:
        """Apply ``conv`` over time to a channels-last ``(B, T, C)`` tensor, then ``norm`` and SiLU."""
        x = conv(rearrange(x, "b t c -> b c t"))
        return self.act(norm(rearrange(x, "b c t -> b t c")))

    def forward(self, x: torch.Tensor):
        """Encode ``x`` of shape ``(b, t, in_dim)``.

        Returns:
            ``(x_global, x_local)``. ``x_global`` is ``(b, t', 1, hidden_dim)``. ``x_local`` is
            ``(b, t', num_heads + 1, hidden_dim)``; its last token is the learned padding token.
            ``t'`` is the time length after the two stride-2 convolutions.
        """
        b = x.shape[0]
        x_original = rearrange(x, "b t c -> b c t")

        # Local branch: num_heads parallel streams folded into the batch dimension.
        x = rearrange(self.conv1_local(x_original), "b (n c) t -> (b n) t c", n=self.num_heads)
        x = self.act(self.norm1(x))
        x = self._conv_norm_act(x, self.conv2, self.norm2)
        x = self._conv_norm_act(x, self.conv3, self.norm3)
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x_local = torch.cat([x, padding], dim=-2)

        # Global branch: a single stream sharing conv2/conv3 and the norms with the local branch.
        x = rearrange(self.conv1_global(x_original), "b c t -> b t c")
        x = self.act(self.norm1(x))
        x = self._conv_norm_act(x, self.conv2, self.norm2)
        x = self._conv_norm_act(x, self.conv3, self.norm3)
        x = self.final_linear(x)
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
        return x, x_local
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class CausalAudioEncoder(nn.Module):
    """Mix stacked audio-feature layers with learned weights and encode them with ``MotionEncoder``.

    Args:
        dim: Channel width of each audio feature layer.
        num_layers: Number of stacked feature layers to mix.
        out_dim: Hidden width of the inner ``MotionEncoder``.
        num_token: Number of local audio tokens (the ``MotionEncoder`` head count).
        dtype: Parameter dtype.
        device: Parameter device.
    """

    def __init__(
        self,
        dim: int = 1024,
        num_layers: int = 25,
        out_dim: int = 2048,
        num_token: int = 4,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda:0",
    ):
        super().__init__()
        self.encoder = MotionEncoder(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, device=device, dtype=dtype)
        # One mixing logit per layer, initialised to 0.01 and passed through SiLU before use.
        self.weights = nn.Parameter(torch.ones((1, num_layers, 1, 1), device=device, dtype=dtype) * 0.01)
        self.act = nn.SiLU()

    def forward(self, features: torch.Tensor):
        """Encode ``features`` (presumably ``(b, num_layers, dim, frames)``).

        The layers are averaged with weights ``silu(self.weights)`` normalised to sum to one, moved to
        ``(b, frames, dim)``, and fed to the ``MotionEncoder``, whose output tuple is returned unchanged.
        """
        layer_weights = self.act(self.weights)
        normaliser = layer_weights.sum(dim=1, keepdim=True)
        mixed = ((features * layer_weights) / normaliser).sum(dim=1)
        channels_last = mixed.permute(0, 2, 1)
        return self.encoder(channels_last)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class AudioInjector(nn.Module):
    """Container for the per-layer audio cross-attention and AdaLayerNorm modules.

    ``injected_block_id`` maps a transformer block index (an entry of ``inject_layers``) to the index of
    its modules in ``injector`` and ``injector_adain_layers``.
    """

    def __init__(
        self,
        dim: int = 5120,
        num_heads: int = 40,
        inject_layers: Optional[List[int]] = None,
        adain_dim: int = 5120,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            dim: Model width used by the cross-attention modules.
            num_heads: Attention heads per cross-attention module.
            inject_layers: Block indices that receive audio injection.
                ``None`` means ``[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]``.
            adain_dim: Width of each AdaLayerNorm.
            device: Parameter device.
            dtype: Parameter dtype.
        """
        super().__init__()
        if inject_layers is None:
            inject_layers = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]
        # block index -> position of that block's modules in the ModuleLists below
        self.injected_block_id = {block_idx: i for i, block_idx in enumerate(inject_layers)}

        num_injectors = len(inject_layers)
        self.injector = nn.ModuleList(
            [
                CrossAttention(
                    dim=dim,
                    num_heads=num_heads,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_injectors)
            ]
        )
        self.injector_adain_layers = nn.ModuleList(
            [AdaLayerNorm(dim=adain_dim, device=device, dtype=dtype) for _ in range(num_injectors)]
        )
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class DiTBlockS2V(nn.Module):
    """Speech-to-video wrapper around a pretrained Wan ``DiTBlock``.

    Shares the wrapped block's submodules and parameters. In ``forward``, the first ``x_seq_len`` tokens
    are modulated and gated with ``t_mod``, and all remaining tokens (appended by the caller) with
    ``t_mod_0``. Cross-attention is applied to every token. ``x`` is updated in place and returned.
    """

    # Attributes taken over from the wrapped block, in registration order.
    _SHARED_ATTRS = (
        "dim",
        "num_heads",
        "ffn_dim",
        "self_attn",
        "cross_attn",
        "norm1",
        "norm2",
        "norm3",
        "ffn",
        "modulation",
    )

    def __init__(self, dit_block: DiTBlock):
        super().__init__()
        for attr in self._SHARED_ATTRS:
            setattr(self, attr, getattr(dit_block, attr))

    @staticmethod
    def _modulate_split(x_norm, x_seq_len, shift, scale, shift_0, scale_0):
        """Modulate tokens ``[:x_seq_len]`` with (shift, scale) and the rest with (shift_0, scale_0)."""
        head = modulate(x_norm[:, :x_seq_len], shift, scale)
        tail = modulate(x_norm[:, x_seq_len:], shift_0, scale_0)
        return torch.cat([head, tail], dim=1)

    @staticmethod
    def _gate_split(y, x_seq_len, gate, gate_0):
        """Scale tokens ``[:x_seq_len]`` by ``gate`` and the rest by ``gate_0``."""
        return torch.cat([y[:, :x_seq_len] * gate, y[:, x_seq_len:] * gate_0], dim=1)

    def forward(self, x, x_seq_len, context, t_mod, t_mod_0, freqs):
        # msa: multi-head self-attention, mlp: multi-layer perceptron
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + t_mod).chunk(6, dim=1)
        shift_msa_0, scale_msa_0, gate_msa_0, shift_mlp_0, scale_mlp_0, gate_mlp_0 = (
            self.modulation + t_mod_0
        ).chunk(6, dim=1)

        attn_in = self._modulate_split(self.norm1(x), x_seq_len, shift_msa, scale_msa, shift_msa_0, scale_msa_0)
        attn_out = self.self_attn(attn_in, freqs)
        x += self._gate_split(attn_out, x_seq_len, gate_msa, gate_msa_0)

        x += self.cross_attn(self.norm3(x), context)

        mlp_in = self._modulate_split(self.norm2(x), x_seq_len, shift_mlp, scale_mlp, shift_mlp_0, scale_mlp_0)
        mlp_out = self.ffn(mlp_in)
        x += self._gate_split(mlp_out, x_seq_len, gate_mlp, gate_mlp_0)
        return x
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class WanS2VDiT(WanDiT):
    """Wan 2.2 speech-to-video (S2V) diffusion transformer.

    Extends ``WanDiT`` with:
      * ``cond_encoder``: a patch-embedding conv for the pose condition, added to the video patches;
      * ``casual_audio_encoder``: a causal encoder over stacked audio features (the attribute name is
        kept as-is because it is a checkpoint key);
      * ``audio_injector``: cross-attention modules applied after selected blocks;
      * ``trainable_cond_mask``: a 3-entry embedding added to every token. Tokens of the noisy video get
        entry 0, tokens appended after them (reference and, if present, packed motion) get entry 1, and
        packed motion tokens get entry 2;
      * ``frame_packer``: a FramePack-style packer for previously generated motion latents.
    Every inherited block is wrapped in ``DiTBlockS2V``.
    """

    def __init__(
        self,
        cond_dim: int = 16,
        audio_dim: int = 1024,
        num_audio_token: int = 4,
        audio_inject_layers: Optional[List[int]] = None,
        num_heads: int = 40,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        *args,
        **kwargs,
    ):
        """
        Args:
            cond_dim: Channel count of the pose condition latents.
            audio_dim: Channel width of each audio feature layer.
            num_audio_token: Number of local audio tokens per latent frame.
            audio_inject_layers: Block indices that receive audio injection.
                ``None`` means ``[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]``.
            num_heads: Attention heads (also forwarded to ``WanDiT``).
            device: Parameter device.
            dtype: Parameter dtype.
            *args, **kwargs: Forwarded to ``WanDiT``.
        """
        if audio_inject_layers is None:
            audio_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]
        super().__init__(*args, num_heads=num_heads, device=device, dtype=dtype, **kwargs)
        self.num_heads = num_heads
        self.cond_encoder = nn.Conv3d(
            cond_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=dtype
        )
        self.casual_audio_encoder = CausalAudioEncoder(
            dim=audio_dim, out_dim=self.dim, num_token=num_audio_token, device=device, dtype=dtype
        )
        self.audio_injector = AudioInjector(
            dim=self.dim,
            num_heads=num_heads,
            inject_layers=audio_inject_layers,
            adain_dim=self.dim,
            device=device,
            dtype=dtype,
        )
        self.trainable_cond_mask = nn.Embedding(3, self.dim, device=device, dtype=dtype)
        self.frame_packer = FramePackMotioner(
            inner_dim=self.dim,
            num_heads=num_heads,
            zip_frame_buckets=[1, 2, 16],
            device=device,
            dtype=dtype,
        )
        # Re-wrap the inherited blocks so they can modulate video and appended tokens separately.
        self.blocks = nn.ModuleList([DiTBlockS2V(block) for block in self.blocks])

    @staticmethod
    def get_model_config(model_type: str):
        """Load the JSON config registered for ``model_type``.

        Raises:
            ValueError: If ``model_type`` has no registered config file.
        """
        MODEL_CONFIG_FILES = {
            "wan2.2-s2v-14b": WAN2_2_DIT_S2V_14B_CONFIG_FILE,
        }
        if model_type not in MODEL_CONFIG_FILES:
            raise ValueError(f"Unsupported model type: {model_type}")

        config_file = MODEL_CONFIG_FILES[model_type]
        with open(config_file, "r", encoding="utf-8") as f:
            return json.load(f)

    def inject_motion(
        self,
        x: torch.Tensor,
        x_seq_len: int,
        rope_embs: torch.Tensor,
        motion_latents: torch.Tensor,
        drop_motion_frames: bool = False,
    ):
        """Append packed motion tokens (unless dropped) and add the token-type embedding.

        The token-type ids are 0 for ``x[:, :x_seq_len]``, 1 for the rest of the incoming ``x``, and 2
        for appended motion tokens. ``x`` is updated in place when ``drop_motion_frames`` is True.

        Returns:
            ``(x, rope_embs)`` with the motion tokens and their RoPE factors concatenated along dim 1.
        """
        b, s, _ = x.shape
        mask_input = torch.zeros([b, s], dtype=torch.long, device=x.device)
        mask_input[:, x_seq_len:] = 1

        if not drop_motion_frames:
            motion, motion_rope_emb = self.frame_packer(motion_latents)
            x = torch.cat([x, motion], dim=1)
            rope_embs = torch.cat([rope_embs, motion_rope_emb], dim=1)
            mask_input = torch.cat(
                [
                    mask_input,
                    2 * torch.ones([b, motion.shape[1]], device=mask_input.device, dtype=mask_input.dtype),
                ],
                dim=1,
            )
        x += self.trainable_cond_mask(mask_input).to(x.dtype)
        return x, rope_embs

    def patchify_x_with_pose(self, x: torch.Tensor, pose: torch.Tensor):
        """Patch-embed ``x``, add the patch-embedded ``pose`` and flatten the result to tokens.

        Returns:
            ``(tokens, grid_size)``, where ``tokens`` is ``(b, f*h*w, dim)`` and ``grid_size`` is the
            post-patchify ``(f, h, w)``.
        """
        x = self.patch_embedding(x) + self.cond_encoder(pose)
        grid_size = x.shape[2:]
        x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
        return x, grid_size

    def forward(
        self,
        x: torch.Tensor,  # b c tx h w
        context: torch.Tensor,  # b s c
        timestep: torch.Tensor,  # b
        ref_latents: torch.Tensor,  # b c 1 h w
        motion_latents: torch.Tensor,  # b c tm h w
        pose_cond: torch.Tensor,  # b c tx h w
        audio_input: torch.Tensor,  # b c d tx
        num_motion_frames: int = 73,
        num_motion_latents: int = 19,
        drop_motion_frames: bool = False,  # !(ref_as_first_frame || clip_idx)
        audio_mask: Optional[torch.Tensor] = None,  # b c tx h w
        void_audio_input: Optional[torch.Tensor] = None,
    ):
        """Predict the denoising output for the noisy video latents ``x``.

        The token sequence is [noisy video (+ pose) | reference frame | packed motion]. Only the noisy
        video tokens are returned, unpatchified back to latent layout. ``audio_mask`` (with
        ``void_audio_input``) blends audio-conditioned and "void"-audio residuals per token.
        """
        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = x.shape[0] > 1
        # NOTE(review): cfg_parallel only shards x, context and audio_input; timestep, ref_latents,
        # motion_latents and pose_cond are presumably shaped to broadcast against the shard — confirm
        # with the pipeline.
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel((x, context, audio_input), use_cfg=use_cfg),
        ):
            audio_emb_global, merged_audio_emb, void_audio_emb_global, void_merged_audio_emb, audio_mask = (
                self.get_audio_emb(audio_input, num_motion_frames, num_motion_latents, audio_mask, void_audio_input)
            )
            t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))  # (s, d)
            t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
            # Reference/motion tokens are always modulated as if at timestep 0 (clean).
            t_mod_0 = self.time_projection(
                self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, torch.zeros([1]).to(t)))
            ).unflatten(1, (6, self.dim))
            context = self.text_embedding(context)
            x, (f, h, w) = self.patchify_x_with_pose(x, pose_cond)
            ref, _ = self.patchify(ref_latents)
            x = torch.cat([x, ref], dim=1)
            # The reference frame gets temporal RoPE index 30 (grid [30, 31) on the time axis), i.e. it
            # uses self.freqs[0][30]; the reason for 30 is not documented here.
            freqs = rope_precompute(
                x=rearrange(x, "b s (n d) -> b s n d", n=self.num_heads),
                grid_sizes=[
                    [
                        torch.tensor([0, 0, 0]),
                        torch.tensor([f, h, w]),
                        torch.tensor([f, h, w]),
                    ],  # grid size of x
                    [
                        torch.tensor([30, 0, 0]),
                        torch.tensor([31, h, w]),
                        torch.tensor([1, h, w]),
                    ],  # grid size of ref
                ],
                freqs=self.freqs,
            )

            x_seq_len = f * h * w
            x, freqs = self.inject_motion(
                x=x,
                x_seq_len=x_seq_len,
                rope_embs=freqs,
                motion_latents=motion_latents,
                drop_motion_frames=drop_motion_frames,
            )

            # Only the video tokens are sharded; f must be divisible by the ulysses world size.
            x_img, freqs_img = x[:, :x_seq_len], freqs[:, :x_seq_len]
            x_ref_motion, freqs_ref_motion = x[:, x_seq_len:], freqs[:, x_seq_len:]
            with sequence_parallel(
                tensors=(
                    x_img,
                    freqs_img,
                    audio_emb_global,
                    merged_audio_emb,
                    audio_mask,
                    void_audio_emb_global,
                    void_merged_audio_emb,
                ),
                seq_dims=(1, 1, 1, 1, 1, 1, 1),
            ):
                x_seq_len_local = x_img.shape[1]
                x = torch.concat([x_img, x_ref_motion], dim=1)
                freqs = torch.concat([freqs_img, freqs_ref_motion], dim=1)
                for idx, block in enumerate(self.blocks):
                    x = block(
                        x=x, x_seq_len=x_seq_len_local, context=context, t_mod=t_mod, t_mod_0=t_mod_0, freqs=freqs
                    )
                    if idx in self.audio_injector.injected_block_id:
                        x = self.inject_audio(
                            x=x,
                            x_seq_len=x_seq_len_local,
                            block_idx=idx,
                            audio_emb_global=audio_emb_global,
                            merged_audio_emb=merged_audio_emb,
                            audio_mask=audio_mask,
                            void_audio_emb_global=void_audio_emb_global,
                            void_merged_audio_emb=void_merged_audio_emb,
                        )

                x = x[:, :x_seq_len_local]
                x = self.head(x, t)
                (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(x_seq_len,))
            x = self.unpatchify(x, (f, h, w))
            (x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
            return x

    def _encode_audio(self, audio_input: torch.Tensor, num_motion_frames: int, num_motion_latents: int):
        """Encode audio features, prepending ``num_motion_frames`` copies of the first frame.

        The first ``num_motion_latents`` encoded steps (presumably those covering the prepended frames)
        are dropped.

        Returns:
            ``(audio_emb_global, audio_emb_local)``.
        """
        padded = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, num_motion_frames), audio_input], dim=-1)
        emb_global, emb_local = self.casual_audio_encoder(padded)
        return emb_global[:, num_motion_latents:], emb_local[:, num_motion_latents:, :]

    def get_audio_emb(
        self,
        audio_input: torch.Tensor,
        num_motion_frames: int = 73,
        num_motion_latents: int = 19,
        audio_mask: Optional[torch.Tensor] = None,
        void_audio_input: Optional[torch.Tensor] = None,
    ):
        """Encode the conditioning audio and, when ``audio_mask`` is given, the "void" audio too.

        Returns:
            ``(audio_emb_global, merged_audio_emb, void_audio_emb_global, void_merged_audio_emb,
            audio_mask)``. The void embeddings are None without ``audio_mask``. The returned
            ``audio_mask`` is flattened from ``(b, c, f, h, w)`` to ``(b, f*h*w, c)``.

        Raises:
            ValueError: If ``audio_mask`` is given without ``void_audio_input``.
        """
        void_audio_emb_global, void_merged_audio_emb = None, None
        if audio_mask is not None:
            if void_audio_input is None:
                raise ValueError("void_audio_input is required when audio_mask is given")
            audio_mask = rearrange(audio_mask, "b c f h w -> b (f h w) c").contiguous()
            void_audio_emb_global, void_merged_audio_emb = self._encode_audio(
                void_audio_input, num_motion_frames, num_motion_latents
            )

        audio_emb_global, merged_audio_emb = self._encode_audio(audio_input, num_motion_frames, num_motion_latents)
        return audio_emb_global, merged_audio_emb, void_audio_emb_global, void_merged_audio_emb, audio_mask

    def inject_audio(
        self,
        x: torch.Tensor,
        x_seq_len: int,
        block_idx: int,
        audio_emb_global: torch.Tensor,
        merged_audio_emb: torch.Tensor,
        audio_mask: Optional[torch.Tensor] = None,
        void_audio_emb_global: Optional[torch.Tensor] = None,
        void_merged_audio_emb: Optional[torch.Tensor] = None,
    ):
        """Add the audio cross-attention residual of block ``block_idx`` to ``x[:, :x_seq_len]`` in place.

        The video tokens are split into ``merged_audio_emb.shape[1]`` equal groups (presumably one per
        latent frame). Each group is AdaLayerNorm-modulated by its global audio embedding and
        cross-attends to its local audio tokens. With ``audio_mask``, the residual is blended with the
        residual computed from the void audio: ``cond * mask + uncond * (1 - mask)``.

        Raises:
            ValueError: If ``audio_mask`` is given without both void embeddings.
        """
        if audio_mask is not None and (void_audio_emb_global is None or void_merged_audio_emb is None):
            raise ValueError("void_audio_emb_global and void_merged_audio_emb are required when audio_mask is given")

        audio_attn_id = self.audio_injector.injected_block_id[block_idx]
        adain_layer = self.audio_injector.injector_adain_layers[audio_attn_id]
        attn_layer = self.audio_injector.injector[audio_attn_id]
        num_latents_per_clip = merged_audio_emb.shape[1]

        x_input = x[:, :x_seq_len]  # b (f h w) c
        x_input = rearrange(x_input, "b (t n) c -> (b t) n c", t=num_latents_per_clip)

        def calc_x_residual(emb_global: torch.Tensor, emb_local: torch.Tensor) -> torch.Tensor:
            # AdaLN modulation from the (single) global audio token of each group.
            emb_global = rearrange(emb_global, "b t n c -> (b t) n c")
            x_adain = adain_layer(x_input, emb=emb_global[:, 0])
            emb_local = rearrange(emb_local, "b t n c -> (b t) n c", t=num_latents_per_clip)
            residual = attn_layer(x=x_adain, y=emb_local)
            return rearrange(residual, "(b t) n c -> b (t n) c", t=num_latents_per_clip)

        x_cond_residual = calc_x_residual(audio_emb_global, merged_audio_emb)
        if audio_mask is not None:
            x_uncond_residual = calc_x_residual(void_audio_emb_global, void_merged_audio_emb)
            x[:, :x_seq_len] += x_cond_residual * audio_mask + x_uncond_residual * (1 - audio_mask)
        else:
            x[:, :x_seq_len] += x_cond_residual

        return x
|
|
@@ -3,6 +3,7 @@ from .flux_image import FluxImagePipeline
|
|
|
3
3
|
from .sdxl_image import SDXLImagePipeline
|
|
4
4
|
from .sd_image import SDImagePipeline
|
|
5
5
|
from .wan_video import WanVideoPipeline
|
|
6
|
+
from .wan_s2v import WanSpeech2VideoPipeline
|
|
6
7
|
from .qwen_image import QwenImagePipeline
|
|
7
8
|
from .hunyuan3d_shape import Hunyuan3DShapePipeline
|
|
8
9
|
|
|
@@ -13,6 +14,7 @@ __all__ = [
|
|
|
13
14
|
"SDXLImagePipeline",
|
|
14
15
|
"SDImagePipeline",
|
|
15
16
|
"WanVideoPipeline",
|
|
17
|
+
"WanSpeech2VideoPipeline",
|
|
16
18
|
"QwenImagePipeline",
|
|
17
19
|
"Hunyuan3DShapePipeline",
|
|
18
20
|
]
|