diffsynth-engine 0.5.1.dev2__py3-none-any.whl → 0.5.1.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +2 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json +13 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +36 -0
- diffsynth_engine/models/basic/attention.py +7 -4
- diffsynth_engine/models/wan/wan_audio_encoder.py +306 -0
- diffsynth_engine/models/wan/wan_dit.py +6 -2
- diffsynth_engine/models/wan/wan_s2v_dit.py +567 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/wan_s2v.py +685 -0
- diffsynth_engine/utils/constants.py +1 -0
- diffsynth_engine/utils/image.py +7 -0
- diffsynth_engine/utils/video.py +26 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/METADATA +3 -1
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/RECORD +18 -14
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,685 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
import torch.distributed as dist
|
|
7
|
+
from torchvision.transforms.functional import pil_to_tensor
|
|
8
|
+
from typing import Callable, List, Optional
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
from PIL import Image
|
|
11
|
+
|
|
12
|
+
from diffsynth_engine.configs import WanSpeech2VideoPipelineConfig, WanS2VStateDicts
|
|
13
|
+
from diffsynth_engine.models.wan.wan_s2v_dit import WanS2VDiT
|
|
14
|
+
from diffsynth_engine.models.wan.wan_text_encoder import WanTextEncoder
|
|
15
|
+
from diffsynth_engine.models.wan.wan_audio_encoder import (
|
|
16
|
+
Wav2Vec2Model,
|
|
17
|
+
Wav2Vec2Config,
|
|
18
|
+
get_audio_embed_bucket_fps,
|
|
19
|
+
extract_audio_feat,
|
|
20
|
+
)
|
|
21
|
+
from diffsynth_engine.models.wan.wan_vae import WanVideoVAE
|
|
22
|
+
from diffsynth_engine.pipelines.wan_video import WanVideoPipeline
|
|
23
|
+
from diffsynth_engine.models.basic.lora import LoRAContext
|
|
24
|
+
from diffsynth_engine.tokenizers import WanT5Tokenizer
|
|
25
|
+
from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
|
|
26
|
+
from diffsynth_engine.utils.download import fetch_model
|
|
27
|
+
from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
|
|
28
|
+
from diffsynth_engine.utils.image import resize_and_center_crop
|
|
29
|
+
from diffsynth_engine.utils.video import read_n_frames
|
|
30
|
+
from diffsynth_engine.utils.parallel import ParallelWrapper
|
|
31
|
+
from diffsynth_engine.utils import logging
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Module-level logger used by the pipeline's loading code.
logger = logging.get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_face_mask(
    ref_image: Image.Image,
    speaker_end_sec: List[List[int]],
    speaker_bbox: List[List[int]],
    num_frames_total: int,
    fps=16,
    temporal_scale=4,
    spatial_scale=16,
    dtype=torch.bfloat16,
):
    """Build a per-frame audio mask that is 0 on the faces of non-active speakers.

    ``speaker_end_sec`` is a chronological list of ``[speaker_id, end_time]`` segments
    (times in seconds; each segment starts where the previous one ended). For the
    frames of each segment, a pixel mask is set to 1 inside every box in
    ``speaker_bbox`` except the active speaker's (the list index is the speaker id).
    The mask is downsampled with nearest-neighbour interpolation by ``temporal_scale``
    in time and ``spatial_scale`` in space, then inverted.

    Args:
        ref_image: reference image; only its ``height``/``width`` are used.
        speaker_end_sec: ``[speaker_id, end_time_sec]`` pairs in chronological order.
        speaker_bbox: ``[x_min, y_min, x_max, y_max]`` per speaker, in ``ref_image`` pixels.
            Coordinates are clamped to the image bounds.
        num_frames_total: number of video frames the mask covers.
        fps: frame rate used to convert seconds to frame indices.
        temporal_scale: temporal downsampling factor.
        spatial_scale: spatial downsampling factor.
        dtype: dtype of the mask.

    Returns:
        Tensor of shape ``(1, num_frames_total // temporal_scale,
        height // spatial_scale, width // spatial_scale)``, 0 on masked faces, 1 elsewhere.
    """
    mask_height, mask_width = ref_image.height, ref_image.width
    face_mask = torch.zeros([1, num_frames_total, mask_height, mask_width], dtype=dtype)

    prev_time = 0
    for speaker_id, end_time in speaker_end_sec:
        start_frame = int(prev_time * fps)
        end_frame = int(end_time * fps)
        mask = torch.zeros([mask_height, mask_width], dtype=dtype)
        for bbox_idx, bbox in enumerate(speaker_bbox):
            if bbox_idx == speaker_id:
                continue  # the active speaker's face is left unmasked
            x_min, y_min, x_max, y_max = bbox
            # Clamp every coordinate into the image: a negative slice bound would wrap
            # around and mark the wrong region of the mask.
            x_min = min(max(0, x_min), mask_width)
            x_max = min(max(0, x_max), mask_width)
            y_min = min(max(0, y_min), mask_height)
            y_max = min(max(0, y_max), mask_height)
            mask[y_min:y_max, x_min:x_max] = 1
        face_mask[0, start_frame:end_frame] = mask[None]
        prev_time = end_time
        if end_frame > num_frames_total:
            break

    # (1, T, H, W) -> (1, 1, T, H, W) for 3-D interpolation, then back to (1, T', H', W').
    face_mask_resized = F.interpolate(
        face_mask[None],
        size=(
            num_frames_total // temporal_scale,
            mask_height // spatial_scale,
            mask_width // spatial_scale,
        ),
        mode="nearest",
    )[0]
    return 1 - face_mask_resized
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def transform_bbox(
    bboxes: List[List[int]],
    original_height: int,
    original_width: int,
    target_height: int,
    target_width: int,
) -> List[List[int]]:
    """Map bounding boxes from original-image pixels into resized-and-cropped pixels.

    The mapping models a resize that makes the image's shorter edge equal to
    ``min(target_height, target_width)``, followed by a center crop to
    ``target_height x target_width``.
    NOTE(review): this is assumed to mirror ``resize_and_center_crop``; confirm.

    Args:
        bboxes: boxes as ``[x_min, y_min, x_max, y_max]`` in original pixel coordinates.
        original_height: height of the source image.
        original_width: width of the source image.
        target_height: height of the cropped output.
        target_width: width of the cropped output.

    Returns:
        One integer ``[x_min, y_min, x_max, y_max]`` box per input box, clipped to the
        crop. A box with no positive area after clipping becomes ``[0, 0, 0, 0]``.
    """
    if not bboxes:
        return []

    # --- 1. Resize: depends only on the image sizes, so compute it once. ---
    resize_size = min(target_height, target_width)
    if original_width < original_height:
        # Width is the shorter edge.
        scale_factor = resize_size / original_width
        resized_w = resize_size
        resized_h = int(original_height * scale_factor)
    else:
        # Height is the shorter edge (or the edges are equal).
        scale_factor = resize_size / original_height
        resized_h = resize_size
        resized_w = int(original_width * scale_factor)

    # --- 2. Center crop: top-left corner of the crop window in resized coordinates. ---
    crop_offset_x = (resized_w - target_width) / 2.0
    crop_offset_y = (resized_h - target_height) / 2.0

    transformed_bboxes = []
    for x_min, y_min, x_max, y_max in bboxes:
        # Scale, then translate so that the crop's top-left corner becomes the origin.
        final_x_min = x_min * scale_factor - crop_offset_x
        final_y_min = y_min * scale_factor - crop_offset_y
        final_x_max = x_max * scale_factor - crop_offset_x
        final_y_max = y_max * scale_factor - crop_offset_y

        # --- 3. Clip to [0, target_width] x [0, target_height]. ---
        final_x_min = max(0, final_x_min)
        final_y_min = max(0, final_y_min)
        final_x_max = min(target_width, final_x_max)
        final_y_max = min(target_height, final_y_max)

        if final_x_min >= final_x_max or final_y_min >= final_y_max:
            # The box lies completely outside the crop.
            transformed_bboxes.append([0, 0, 0, 0])
        else:
            transformed_bboxes.append([int(final_x_min), int(final_y_min), int(final_x_max), int(final_y_max)])

    return transformed_bboxes
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def restrict_size_below_area(
    height: int | None, width: int | None, ref_image: Image.Image, target_area: int = 1024 * 704, divisor: int = 64
):
    """Pick an output ``(height, width)`` whose area does not exceed ``target_area``.

    If both ``height`` and ``width`` are given they are returned unchanged. Otherwise
    (including when only one of them is given) both are derived from ``ref_image``:

    1. The image is scaled by a factor <= 1.
    2. Each side is padded up to a multiple of ``divisor``.
    3. The largest scale on a 100-step grid whose padded area still fits in
       ``target_area`` is chosen.

    If no grid point fits, the function falls back to the aspect-ratio-preserving size
    for ``target_area``, rounded down to multiples of ``divisor``. If that size would
    not be smaller than the image, the image size rounded down to multiples of
    ``divisor`` is used instead.

    Returns:
        ``(height, width)``.
    """
    if height is not None and width is not None:
        return height, width

    height, width = ref_image.height, ref_image.width
    max_upper_area = target_area  # maximum allowed pixel count after padding
    if height * width <= target_area:
        # The image already fits: never upscale, only pad (shrinking if padding overflows).
        min_scale = 0.1
        max_scale = 1.0
    else:
        # Resize to fit within the target area, then pad to multiples of `divisor`.
        d = divisor - 1
        b = d * (height + width)
        a = height * width
        c = d**2 - max_upper_area

        # Lower bound of the search: the scale at which worst-case padding still fits.
        # NOTE(review): the textbook discriminant is b**2 - 4*a*c. Using 2*a*c yields a
        # smaller (more conservative) lower bound. It is kept so that the chosen
        # resolution does not change.
        min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a)
        # Upper bound: the scale that hits the target area exactly with no padding.
        max_scale = math.sqrt(max_upper_area / (height * width))

    # Walk from max_scale towards min_scale and take the first (largest) scale whose
    # padded size fits.
    for i in range(100):
        scale = max_scale - (max_scale - min_scale) * i / 100
        new_height, new_width = int(height * scale), int(width * scale)

        # Pad each side up to the next multiple of `divisor`.
        padded_height = new_height + (divisor - new_height % divisor) % divisor
        padded_width = new_width + (divisor - new_width % divisor) % divisor

        if padded_height * padded_width <= max_upper_area:
            return padded_height, padded_width

    # Fallback: derive target dimensions from the aspect ratio and align them to `divisor`.
    aspect_ratio = width / height
    target_width = int((target_area * aspect_ratio) ** 0.5 // divisor * divisor)
    target_height = int((target_area / aspect_ratio) ** 0.5 // divisor * divisor)

    # Ensure the result is not larger than the original resolution.
    if target_width >= width or target_height >= height:
        target_width = int(width // divisor * divisor)
        target_height = int(height // divisor * divisor)

    return target_height, target_width
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class WanSpeech2VideoPipeline(WanVideoPipeline):
    """Wan2.2 speech-to-video pipeline.

    Extends ``WanVideoPipeline`` with a Wav2Vec2 audio encoder and a ``WanS2VDiT``.
    The video is generated clip by clip. Each clip is conditioned on:

    - the text prompt;
    - the reference-image latents;
    - motion latents taken from the end of the previous clip;
    - optional pose latents;
    - the audio embeddings for that clip's frames.
    """

    def __init__(
        self,
        config: WanSpeech2VideoPipelineConfig,
        tokenizer: WanT5Tokenizer,
        text_encoder: WanTextEncoder,
        audio_encoder: Wav2Vec2Model,
        dit: WanS2VDiT,
        vae: WanVideoVAE,
    ):
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            dit=dit,
            dit2=None,  # this pipeline uses a single DiT
            vae=vae,
            image_encoder=None,  # the reference image reaches the DiT as VAE latents instead
        )
        self.audio_encoder = audio_encoder
        # Presumably the modules the base class moves between devices (load_models_to_device / offload).
        self.model_names = ["audio_encoder", "text_encoder", "dit", "vae"]

    def encode_ref_and_motion(
        self,
        ref_image: Image.Image | None,
        height: int,
        width: int,
        num_motion_frames: int,
        ref_as_first_frame: bool,
    ):
        """Encode the reference image and the initial motion-context frames with the VAE.

        ``ref_image`` is expected to already be ``height x width``, because it is
        broadcast into the motion frames when ``ref_as_first_frame`` is set.

        Returns:
            ``(ref_latents, motion_latents, motion_frames)``. ``motion_frames`` is a
            ``(1, 3, num_motion_frames, height, width)`` pixel tensor that is all zeros,
            except that its last 6 frames hold the reference image when
            ``ref_as_first_frame`` is true.
        """
        self.load_models_to_device(["vae"])

        ref_frame = self.preprocess_image(ref_image)
        # Insert a frame axis and drop the batch axis.
        # NOTE(review): assumes preprocess_image returns (1, 3, H, W), giving (3, 1, H, W) here.
        ref_frame = torch.stack([ref_frame], dim=2).squeeze(0)
        ref_latents = self.encode_video([ref_frame]).to(dtype=self.dtype, device=self.device)

        # The motion context has a fixed channel count and length; it starts out blank.
        motion_frames = torch.zeros([1, 3, num_motion_frames, height, width], dtype=self.dtype, device=self.device)
        if ref_as_first_frame:
            motion_frames[:, :, -6:] = ref_frame
        motion_latents = self.encode_video(motion_frames).to(dtype=self.dtype, device=self.device)

        return ref_latents, motion_latents, motion_frames

    def encode_pose(
        self,
        pose_video: List[Image.Image],
        pose_video_fps: int,
        num_clips: int,
        num_frames_per_clip: int,
        height: int,
        width: int,
    ):
        """Encode a pose-conditioning video into per-clip VAE latents.

        Processing steps:

        1. Read up to ``num_frames_per_clip * num_clips`` frames, resampled to
           ``self.config.fps``.
        2. Map pixel values from [0, 255] to [-1, 1].
        3. Resize and center-crop to ``height x width``.
        4. Zero-pad at the end to the full length.
        5. Split into ``num_clips`` chunks.

        Returns:
            A list with one latent tensor per clip, moved to the CPU.
        """
        self.load_models_to_device(["vae"])
        max_num_pose_frames = num_frames_per_clip * num_clips
        pose_video = read_n_frames(pose_video, pose_video_fps, max_num_pose_frames, target_fps=self.config.fps)
        pose_frames = torch.stack([pil_to_tensor(frame) for frame in pose_video])
        pose_frames = pose_frames / 255.0 * 2 - 1.0
        # (T, C, H, W) -> (1, C, T, H, W)
        pose_frames = resize_and_center_crop(pose_frames, height, width).permute(1, 0, 2, 3)[None]
        pose_frames_padding = torch.zeros([1, 3, max_num_pose_frames - pose_frames.shape[2], height, width])
        pose_frames = torch.cat([pose_frames, pose_frames_padding], dim=2)
        pose_frames_all_clips = torch.chunk(pose_frames, num_clips, dim=2)

        pose_latents_all_clips = []
        for pose_frames_per_clip in pose_frames_all_clips:
            # Prepend a copy of the first frame and drop the first latent afterwards, so the
            # kept latents line up with the clip's own frames (presumably because the VAE
            # encodes frame 0 on its own).
            pose_frames_per_clip = torch.cat([pose_frames_per_clip[:, :, 0:1], pose_frames_per_clip], dim=2)
            pose_latents_per_clip = self.encode_video([pose_frames_per_clip.squeeze(0)])[:, :, 1:].cpu()
            pose_latents_all_clips.append(pose_latents_per_clip)
        return pose_latents_all_clips

    def _audio_embed_bucket(self, audio: torch.Tensor, num_frames_per_clip: int):
        """Run the audio encoder and bucket its features per video frame.

        Returns:
            ``(bucket, max_num_clips)``. ``bucket`` has a leading batch dim and its
            frame axis moved last, so it can be sliced per clip.
        """
        self.load_models_to_device(["audio_encoder"])
        bucket, max_num_clips = get_audio_embed_bucket_fps(
            audio_embed=extract_audio_feat(audio, self.audio_encoder, device=self.device),
            num_frames_per_batch=num_frames_per_clip,
            fps=self.config.fps,
        )
        bucket = bucket[None].to(self.device, self.dtype).permute(0, 2, 3, 1)
        return bucket, max_num_clips

    def encode_audio(self, audio: torch.Tensor, num_frames_per_clip: int, num_clips: int):
        """Encode the driving audio.

        Returns:
            ``(audio_embed_bucket, num_clips)``. ``num_clips`` is capped by the
            ``max_num_clips`` reported by ``get_audio_embed_bucket_fps``.
        """
        audio_embed_bucket, max_num_clips = self._audio_embed_bucket(audio, num_frames_per_clip)
        return audio_embed_bucket, min(max_num_clips, num_clips)

    def encode_void_audio(self, void_audio: torch.Tensor, num_frames_per_clip: int):
        """Encode the "void" audio and keep only the first ``num_frames_per_clip`` frames."""
        void_audio_embed_bucket, _ = self._audio_embed_bucket(void_audio, num_frames_per_clip)
        return void_audio_embed_bucket[..., :num_frames_per_clip]

    def predict_noise_with_cfg(
        self,
        model: WanS2VDiT,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        positive_prompt_emb: torch.Tensor,
        negative_prompt_emb: torch.Tensor,
        cfg_scale: float,
        batch_cfg: bool,
        ref_latents: torch.Tensor,
        motion_latents: torch.Tensor,
        pose_cond: torch.Tensor,
        audio_input: torch.Tensor,
        num_motion_frames: int,
        num_motion_latents: int,
        drop_motion_frames: bool,
        audio_mask: torch.Tensor | None,
        void_audio_input: torch.Tensor | None,
    ):
        """Predict noise with classifier-free guidance.

        The unconditional branch uses ``negative_prompt_emb`` and zeroed audio.
        With ``cfg_scale <= 1`` only the conditional branch is evaluated. With
        ``batch_cfg`` both branches run in a single batched model call.
        """
        # Conditioning shared by the positive and negative branches.
        shared_kwargs = dict(
            model=model,
            timestep=timestep,
            ref_latents=ref_latents,
            motion_latents=motion_latents,
            pose_cond=pose_cond,
            num_motion_frames=num_motion_frames,
            num_motion_latents=num_motion_latents,
            drop_motion_frames=drop_motion_frames,
            audio_mask=audio_mask,
            void_audio_input=void_audio_input,
        )
        if cfg_scale <= 1.0:
            return self.predict_noise(
                latents=latents, context=positive_prompt_emb, audio_input=audio_input, **shared_kwargs
            )
        if not batch_cfg:
            positive_noise_pred = self.predict_noise(
                latents=latents, context=positive_prompt_emb, audio_input=audio_input, **shared_kwargs
            )
            negative_noise_pred = self.predict_noise(
                latents=latents, context=negative_prompt_emb, audio_input=0.0 * audio_input, **shared_kwargs
            )
        else:
            positive_noise_pred, negative_noise_pred = self.predict_noise(
                latents=torch.cat([latents, latents], dim=0),
                context=torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0),
                audio_input=torch.cat([audio_input, 0.0 * audio_input], dim=0),
                **shared_kwargs,
            )
        return negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)

    def predict_noise(
        self,
        model: WanS2VDiT,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        ref_latents: torch.Tensor,
        motion_latents: torch.Tensor,
        pose_cond: torch.Tensor,
        audio_input: torch.Tensor,
        num_motion_frames: int,
        num_motion_latents: int,
        drop_motion_frames: bool,
        audio_mask: torch.Tensor | None = None,
        void_audio_input: torch.Tensor | None = None,
    ):
        """Run one DiT forward pass.

        ``latents`` are first cast to ``config.model_dtype`` and moved to the device.
        """
        latents = latents.to(dtype=self.config.model_dtype, device=self.device)

        noise_pred = model(
            x=latents,
            context=context,
            timestep=timestep,
            ref_latents=ref_latents,
            motion_latents=motion_latents,
            pose_cond=pose_cond,
            audio_input=audio_input,
            num_motion_frames=num_motion_frames,
            num_motion_latents=num_motion_latents,
            drop_motion_frames=drop_motion_frames,
            audio_mask=audio_mask,
            void_audio_input=void_audio_input,
        )
        return noise_pred

    @torch.no_grad()
    def __call__(
        self,
        audio: torch.Tensor,
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float | None = None,
        num_inference_steps: int | None = None,
        seed: int | None = None,
        height: int | None = None,
        width: int | None = None,
        num_frames_per_clip: int = 80,
        ref_image: Image.Image | None = None,
        pose_video: List[Image.Image] | None = None,
        pose_video_fps: int | None = None,
        void_audio: torch.Tensor | None = None,
        num_clips: int = 1,
        ref_as_first_frame: bool = False,
        speaker_bbox: List[List[int]] | None = None,
        speaker_end_sec: List[List[int]] | None = None,
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """Generate a video driven by ``audio``, clip by clip.

        Args:
            audio: driving audio, passed to ``extract_audio_feat``.
            prompt: positive text prompt.
            negative_prompt: negative text prompt, used by the unconditional CFG branch.
            cfg_scale: guidance scale; defaults to ``self.config`` when ``None``.
            num_inference_steps: number of denoising steps; defaults to ``self.config``
                when ``None``.
            seed: base seed; clip ``k`` uses ``seed + k``. Required when torch.distributed
                is initialized; otherwise a random base seed is drawn when ``None``.
            height: output height; derived from ``ref_image`` unless both ``height`` and
                ``width`` are given.
            width: output width; derived the same way as ``height``.
            num_frames_per_clip: number of frames generated per clip.
            ref_image: reference image (required).
            pose_video: optional pose-conditioning frames.
            pose_video_fps: frame rate of ``pose_video``.
            void_audio: audio passed to the DiT as ``void_audio_input``, presumably for
                the masked (non-speaking) regions. Required when ``speaker_end_sec`` is
                non-empty.
            num_clips: requested number of clips; capped by what the audio covers.
            ref_as_first_frame: place the reference image into the initial motion
                context instead of dropping the motion context for the first clip.
            speaker_bbox: per-speaker ``[x_min, y_min, x_max, y_max]`` in ``ref_image``
                pixels.
            speaker_end_sec: ``[speaker_id, end_time_sec]`` segments used to build the
                audio mask.
            progress_callback: called as ``progress_callback(current, total, status)``.

        Returns:
            The frames of all clips, converted by ``vae_output_to_image``.

        Raises:
            ValueError: in any of these cases:

                - ``ref_image`` is missing;
                - ``seed`` is missing while torch.distributed is initialized;
                - ``speaker_end_sec`` is given without ``void_audio``.
        """
        if ref_image is None:
            raise ValueError("ref_image must be provided")
        speaker_bbox = [] if speaker_bbox is None else speaker_bbox
        speaker_end_sec = [] if speaker_end_sec is None else speaker_end_sec
        has_speakers = len(speaker_end_sec) > 0
        if has_speakers and void_audio is None:
            raise ValueError("void_audio must be provided when speaker_end_sec is given")

        cfg_scale = self.config.cfg_scale if cfg_scale is None else cfg_scale
        num_inference_steps = self.config.num_inference_steps if num_inference_steps is None else num_inference_steps
        original_height, original_width = ref_image.height, ref_image.width
        height, width = restrict_size_below_area(height, width, ref_image)

        # Initialize noise seed
        if dist.is_initialized() and seed is None:
            raise ValueError("must provide a seed when parallelism is enabled")
        if seed is None:
            # Per-clip seeds are derived as seed + clip_idx, so a concrete base seed is needed.
            seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
            logger.info(f"no seed provided, using random seed {seed}")

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt)
        prompt_emb_nega = self.encode_prompt(negative_prompt)

        # Encode the reference image, the initial motion context and the audio.
        # The motion-context length is fixed; (n + 3) // 4 assumes 4x temporal VAE compression.
        num_motion_frames = 73
        num_motion_latents = (num_motion_frames + 3) // 4
        ref_image = resize_and_center_crop(ref_image, height, width)
        ref_latents, motion_latents, motion_frames = self.encode_ref_and_motion(
            ref_image, height, width, num_motion_frames, ref_as_first_frame
        )
        audio_emb, num_clips = self.encode_audio(audio, num_frames_per_clip, num_clips)

        audio_mask, void_audio_emb = None, None
        if has_speakers:
            void_audio_emb = self.encode_void_audio(void_audio, num_frames_per_clip)
            speaker_bbox = transform_bbox(speaker_bbox, original_height, original_width, height, width)
            audio_mask = get_face_mask(
                ref_image=ref_image,
                speaker_end_sec=speaker_end_sec,
                speaker_bbox=speaker_bbox,
                num_frames_total=num_clips * num_frames_per_clip,
                fps=self.config.fps,
                dtype=self.dtype,
            ).to(self.device)

        pose_latents_all_clips = None
        if pose_video is not None:
            pose_latents_all_clips = self.encode_pose(
                pose_video, pose_video_fps, num_clips, num_frames_per_clip, height, width
            )

        num_latents_per_clip = num_frames_per_clip // 4
        hide_progress = dist.is_initialized() and dist.get_rank() != 0
        output_frames_all_clips = []
        for clip_idx in range(num_clips):
            noise = self.generate_noise(
                (
                    1,
                    self.vae.z_dim,
                    num_latents_per_clip,
                    height // self.upsampling_factor,
                    width // self.upsampling_factor,
                ),
                seed=seed + clip_idx,
                device="cpu",
                dtype=torch.float32,
            ).to(self.device)
            _, latents, sigmas, timesteps = self.prepare_latents(
                latents=noise,
                input_video=None,
                denoising_strength=None,
                num_inference_steps=num_inference_steps,
            )
            # Initialize sampler
            self.sampler.initialize(sigmas=sigmas)

            # Slice this clip's audio embeddings, pose latents and audio mask.
            audio_emb_curr_clip = audio_emb[..., clip_idx * num_frames_per_clip : (clip_idx + 1) * num_frames_per_clip]
            if pose_latents_all_clips is not None:
                pose_latents_curr_clip = pose_latents_all_clips[clip_idx]
            else:
                pose_latents_curr_clip = torch.zeros_like(latents)
            pose_latents_curr_clip = pose_latents_curr_clip.to(dtype=self.dtype, device=self.device)
            audio_mask_curr_clip = None
            if audio_mask is not None:
                audio_mask_curr_clip = audio_mask[
                    None, :, clip_idx * num_latents_per_clip : (clip_idx + 1) * num_latents_per_clip
                ]

            # Denoise. The first clip drops the (blank) motion context unless the
            # reference image was placed into it.
            drop_motion_frames = (not ref_as_first_frame) and clip_idx == 0
            for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
                self.load_models_to_device(["dit"])

                timestep = timestep[None].to(dtype=self.dtype, device=self.device)
                # Classifier-free guidance
                noise_pred = self.predict_noise_with_cfg(
                    model=self.dit,
                    latents=latents,
                    timestep=timestep,
                    positive_prompt_emb=prompt_emb_posi,
                    negative_prompt_emb=prompt_emb_nega,
                    cfg_scale=cfg_scale,
                    batch_cfg=self.config.batch_cfg,
                    ref_latents=ref_latents,
                    motion_latents=motion_latents,
                    pose_cond=pose_latents_curr_clip,
                    audio_input=audio_emb_curr_clip,
                    num_motion_frames=num_motion_frames,
                    num_motion_latents=num_motion_latents,
                    drop_motion_frames=drop_motion_frames,
                    audio_mask=audio_mask_curr_clip,
                    void_audio_input=void_audio_emb,
                )
                # Scheduler
                latents = self.sampler.step(latents, noise_pred, i)
                if progress_callback is not None:
                    progress_callback(i + 1, len(timesteps), "DENOISING")

            # Decode with a prefix (reference or motion latents), then keep only this clip's frames.
            prefix_latents = ref_latents if drop_motion_frames else motion_latents
            decode_latents = torch.cat([prefix_latents, latents], dim=2)
            self.load_models_to_device(["vae"])
            output_frames_curr_clip = torch.stack(self.decode_video(decode_latents, progress_callback=progress_callback))
            output_frames_curr_clip = output_frames_curr_clip[:, :, -num_frames_per_clip:]
            if drop_motion_frames:
                output_frames_curr_clip = output_frames_curr_clip[:, :, 3:]
            output_frames_all_clips.append(output_frames_curr_clip.cpu())

            if clip_idx < num_clips - 1:
                # Roll the motion context forward: keep the newest `num_motion_frames` frames.
                num_new_frames = output_frames_curr_clip.shape[2]
                if num_new_frames <= num_motion_frames:
                    motion_frames = torch.cat([motion_frames[:, :, num_new_frames:], output_frames_curr_clip], dim=2)
                else:
                    motion_frames = output_frames_curr_clip[:, :, -num_motion_frames:]
                # Cast like encode_ref_and_motion does, so every clip feeds the DiT the same dtype/device.
                motion_latents = self.encode_video(motion_frames).to(dtype=self.dtype, device=self.device)

        output_frames_all_clips = torch.cat(output_frames_all_clips, dim=2)
        return self.vae_output_to_image(output_frames_all_clips)

    @classmethod
    def from_pretrained(
        cls, model_path_or_config: str | WanSpeech2VideoPipelineConfig
    ) -> "WanSpeech2VideoPipeline":
        """Build the pipeline from a DiT checkpoint path or a full pipeline config.

        T5, VAE and audio-encoder weights are fetched from their default repositories
        when the corresponding config paths are unset. All state dicts are loaded on the
        CPU.
        """
        if isinstance(model_path_or_config, str):
            config = WanSpeech2VideoPipelineConfig(model_path=model_path_or_config)
        else:
            config = model_path_or_config

        logger.info(f"loading dit state dict from {config.model_path} ...")
        dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)

        if config.t5_path is None:
            config.t5_path = fetch_model("muse/wan2.1-umt5", path="umt5.safetensors")
        if config.vae_path is None:
            config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
        if config.audio_encoder_path is None:
            config.audio_encoder_path = fetch_model(
                "Wan-AI/Wan2.2-S2V-14B", path="wav2vec2-large-xlsr-53-english/model.safetensors"
            )

        logger.info(f"loading t5 state dict from {config.t5_path} ...")
        t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)

        logger.info(f"loading vae state dict from {config.vae_path} ...")
        vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)

        logger.info(f"loading audio encoder state dict from {config.audio_encoder_path} ...")
        wav2vec_state_dict = cls.load_model_checkpoint(
            config.audio_encoder_path, device="cpu", dtype=config.audio_encoder_dtype
        )

        state_dicts = WanS2VStateDicts(
            model=dit_state_dict,
            t5=t5_state_dict,
            vae=vae_state_dict,
            audio_encoder=wav2vec_state_dict,
        )
        return cls.from_state_dict(state_dicts, config)

    @classmethod
    def from_state_dict(cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig) -> "WanSpeech2VideoPipeline":
        """Build the pipeline from state dicts.

        When ``config.parallelism > 1`` the pipeline is wrapped in a ``ParallelWrapper``.
        """
        if config.parallelism > 1:
            pipe = ParallelWrapper(
                cfg_degree=config.cfg_degree,
                sp_ulysses_degree=config.sp_ulysses_degree,
                sp_ring_degree=config.sp_ring_degree,
                tp_degree=config.tp_degree,
                use_fsdp=config.use_fsdp,
            )
            pipe.load_module(cls._from_state_dict, state_dicts=state_dicts, config=config)
        else:
            pipe = cls._from_state_dict(state_dicts, config)
        return pipe

    @classmethod
    def _from_state_dict(cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig) -> "WanSpeech2VideoPipeline":
        """Instantiate every sub-model and assemble an (unwrapped) pipeline."""
        # Default params come from the model config. The keys are popped so that the
        # remaining dict can be passed to WanS2VDiT; note this overwrites these config fields.
        vae_type = "wan2.1-vae"
        dit_type = "wan2.2-s2v-14b"
        vae_config: dict = WanVideoVAE.get_model_config(vae_type)
        model_config: dict = WanS2VDiT.get_model_config(dit_type)
        config.boundary = model_config.pop("boundary", -1.0)
        config.shift = model_config.pop("shift", 5.0)
        config.cfg_scale = model_config.pop("cfg_scale", 5.0)
        config.num_inference_steps = model_config.pop("num_inference_steps", 50)
        config.fps = model_config.pop("fps", 16)

        # With offloading enabled, modules are created on the CPU first.
        init_device = "cpu" if config.offload_mode is not None else config.device
        tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
        text_encoder = WanTextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype)
        vae = WanVideoVAE.from_state_dict(state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype)
        audio_encoder = Wav2Vec2Model.from_state_dict(
            state_dicts.audio_encoder, config=Wav2Vec2Config(), device=init_device, dtype=config.audio_encoder_dtype
        )

        with LoRAContext():
            attn_kwargs = {
                "attn_impl": config.dit_attn_impl,
                "sparge_smooth_k": config.sparge_smooth_k,
                "sparge_cdfthreshd": config.sparge_cdfthreshd,
                "sparge_simthreshd1": config.sparge_simthreshd1,
                "sparge_pvthreshd": config.sparge_pvthreshd,
            }
            dit = WanS2VDiT.from_state_dict(
                state_dicts.model,
                config=model_config,
                device=init_device,
                dtype=config.model_dtype,
                attn_kwargs=attn_kwargs,
            )
            if config.use_fp8_linear:
                enable_fp8_linear(dit)

        pipe = cls(
            config=config,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            dit=dit,
            vae=vae,
            audio_encoder=audio_encoder,
        )
        pipe.eval()

        if config.offload_mode is not None:
            pipe.enable_cpu_offload(config.offload_mode)

        # fp8 weights: compute in bfloat16 via autocast.
        if config.model_dtype == torch.float8_e4m3fn:
            pipe.dtype = torch.bfloat16  # compute dtype
            pipe.enable_fp8_autocast(
                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
            )

        if config.t5_dtype == torch.float8_e4m3fn:
            pipe.dtype = torch.bfloat16  # compute dtype
            pipe.enable_fp8_autocast(
                model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
            )

        if config.use_torch_compile:
            pipe.compile()
        return pipe
|
|
@@ -34,6 +34,7 @@ WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit
|
|
|
34
34
|
# Path to the DiT model-config JSON for wan2.2-ti2v-5b.
WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
|
|
35
35
|
# Path to the DiT model-config JSON for wan2.2-t2v-a14b.
WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
|
|
36
36
|
# Path to the DiT model-config JSON for wan2.2-i2v-a14b.
WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
|
|
37
|
+
# Path to the DiT model-config JSON for wan2.2-s2v-14b (speech-to-video).
WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json")
|
|
37
38
|
|
|
38
39
|
# Path to the VAE model-config JSON for wan2.1-vae.
WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json")
|
|
39
40
|
# Path to the VAE model-config JSON for wan2.2-vae.
WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json")
|