diffsynth-engine 0.5.1.dev2__py3-none-any.whl → 0.5.1.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,685 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.distributed as dist
7
+ from torchvision.transforms.functional import pil_to_tensor
8
+ from typing import Callable, List, Optional
9
+ from tqdm import tqdm
10
+ from PIL import Image
11
+
12
+ from diffsynth_engine.configs import WanSpeech2VideoPipelineConfig, WanS2VStateDicts
13
+ from diffsynth_engine.models.wan.wan_s2v_dit import WanS2VDiT
14
+ from diffsynth_engine.models.wan.wan_text_encoder import WanTextEncoder
15
+ from diffsynth_engine.models.wan.wan_audio_encoder import (
16
+ Wav2Vec2Model,
17
+ Wav2Vec2Config,
18
+ get_audio_embed_bucket_fps,
19
+ extract_audio_feat,
20
+ )
21
+ from diffsynth_engine.models.wan.wan_vae import WanVideoVAE
22
+ from diffsynth_engine.pipelines.wan_video import WanVideoPipeline
23
+ from diffsynth_engine.models.basic.lora import LoRAContext
24
+ from diffsynth_engine.tokenizers import WanT5Tokenizer
25
+ from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
26
+ from diffsynth_engine.utils.download import fetch_model
27
+ from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
28
+ from diffsynth_engine.utils.image import resize_and_center_crop
29
+ from diffsynth_engine.utils.video import read_n_frames
30
+ from diffsynth_engine.utils.parallel import ParallelWrapper
31
+ from diffsynth_engine.utils import logging
32
+
33
+
34
# Module-level logger obtained from diffsynth_engine's logging utilities.
logger = logging.get_logger(__name__)
35
+
36
+
37
def get_face_mask(
    ref_image: "Image.Image",
    speaker_end_sec: List[List[int]],
    speaker_bbox: List[List[int]],
    num_frames_total: int,
    fps=16,
    temporal_scale=4,
    spatial_scale=16,
    dtype=torch.bfloat16,
):
    """Build a down-sampled per-frame mask that is 0 over non-speaking faces and 1 elsewhere.

    The speaking timeline is a list of ``[speaker_id, end_sec]`` pairs: speaker
    ``speaker_id`` is active from the previous entry's end time (0 for the first
    entry) until ``end_sec``. During each interval the boxes of every *other*
    speaker are marked, and the final mask is inverted.

    Args:
        ref_image: Reference image; only its ``height`` and ``width`` are read.
        speaker_end_sec: ``[speaker_id, end_sec]`` pairs in chronological order.
        speaker_bbox: One ``[x_min, y_min, x_max, y_max]`` box per speaker, where the
            list position is the speaker id, in ``ref_image`` pixel coordinates.
        num_frames_total: Number of video frames covered by the mask.
        fps: Frames per second used to convert seconds into frame indices.
        temporal_scale: Temporal down-sampling factor applied to the mask.
        spatial_scale: Spatial down-sampling factor applied to the mask.
        dtype: dtype of the mask tensors.

    Returns:
        Tensor of shape ``[1, num_frames_total // temporal_scale,
        height // spatial_scale, width // spatial_scale]``.
    """
    mask_height, mask_width = ref_image.height, ref_image.width
    face_mask = torch.zeros([1, num_frames_total, mask_height, mask_width], dtype=dtype)
    prev_time = 0
    for speaker_id, end_time in speaker_end_sec:
        start_frame = int(prev_time * fps)
        end_frame = int(end_time * fps)
        mask = torch.zeros([mask_height, mask_width], dtype=dtype)
        for bbox_id, bbox in enumerate(speaker_bbox):
            if bbox_id == speaker_id:
                continue
            x_min, y_min, x_max, y_max = bbox
            # Clip every coordinate to the image: a negative start index would
            # otherwise wrap around in slicing and silently drop the box.
            x_min, x_max = max(0, x_min), min(mask_width, x_max)
            y_min, y_max = max(0, y_min), min(mask_height, y_max)
            mask[y_min:y_max, x_min:x_max] = 1
        face_mask[0, start_frame:end_frame] = mask[None]
        prev_time = end_time
        # Stop once the timeline runs past the last frame.
        if end_frame > num_frames_total:
            break

    face_mask_resized = F.interpolate(
        face_mask[None],
        size=(
            num_frames_total // temporal_scale,
            mask_height // spatial_scale,
            mask_width // spatial_scale,
        ),
        mode="nearest",
    )[0]
    return 1 - face_mask_resized
80
+
81
+
82
def transform_bbox(
    bboxes: List[List[int]],
    original_height: int,
    original_width: int,
    target_height: int,
    target_width: int,
) -> List[List[int]]:
    """Map boxes from the original image into a resized-and-center-cropped image.

    The modelled transform resizes the image so its shorter edge becomes
    ``min(target_height, target_width)``, then center-crops it to
    ``target_height`` x ``target_width``.
    NOTE(review): assumes this matches ``resize_and_center_crop`` — confirm.

    Args:
        bboxes: ``[x_min, y_min, x_max, y_max]`` boxes in original pixel coordinates.
        original_height: Height of the original image.
        original_width: Width of the original image.
        target_height: Height of the cropped output image.
        target_width: Width of the cropped output image.

    Returns:
        One integer ``[x_min, y_min, x_max, y_max]`` box per input box, clipped to the
        crop. A box that ends up with no positive area becomes ``[0, 0, 0, 0]``.
    """
    if not bboxes:
        return []

    # --- 1. Resize: the geometry depends only on the image sizes, so compute it once.
    resize_size = min(target_height, target_width)
    if original_width < original_height:
        # Width is the smaller edge.
        scale_factor = resize_size / original_width
        resized_w = resize_size
        resized_h = int(original_height * scale_factor)
    else:
        # Height is the smaller edge, or the edges are equal.
        scale_factor = resize_size / original_height
        resized_h = resize_size
        resized_w = int(original_width * scale_factor)

    # --- 2. Center crop: top-left corner of the crop in resized coordinates.
    crop_offset_x = (resized_w - target_width) / 2.0
    crop_offset_y = (resized_h - target_height) / 2.0

    transformed_bboxes = []
    for x_min, y_min, x_max, y_max in bboxes:
        # Scale, translate into the crop's coordinate system, then clip to
        # [0, target_width] x [0, target_height].
        final_x_min = max(0, x_min * scale_factor - crop_offset_x)
        final_y_min = max(0, y_min * scale_factor - crop_offset_y)
        final_x_max = min(target_width, x_max * scale_factor - crop_offset_x)
        final_y_max = min(target_height, y_max * scale_factor - crop_offset_y)

        if final_x_min >= final_x_max or final_y_min >= final_y_max:
            # The box lies completely outside the crop.
            transformed_bboxes.append([0, 0, 0, 0])
        else:
            transformed_bboxes.append([int(final_x_min), int(final_y_min), int(final_x_max), int(final_y_max)])
    return transformed_bboxes
145
+
146
+
147
+ def restrict_size_below_area(
148
+ height: int | None, width: int | None, ref_image: Image.Image, target_area: int = 1024 * 704, divisor: int = 64
149
+ ):
150
+ if height is not None and width is not None:
151
+ return height, width
152
+
153
+ height, width = ref_image.height, ref_image.width
154
+ if height * width <= target_area:
155
+ # If the original image area is already less than or equal to the target,
156
+ # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target.
157
+ max_upper_area = target_area
158
+ min_scale = 0.1
159
+ max_scale = 1.0
160
+ else:
161
+ # Resize to fit within the target area and then pad to multiples of `divisor`
162
+ max_upper_area = target_area # Maximum allowed total pixel count after padding
163
+ d = divisor - 1
164
+ b = d * (height + width)
165
+ a = height * width
166
+ c = d**2 - max_upper_area
167
+
168
+ # Calculate scale boundaries using quadratic equation
169
+ min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a) # Scale when maximum padding is applied
170
+ max_scale = math.sqrt(max_upper_area / (height * width)) # Scale without any padding
171
+
172
+ # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area
173
+ for i in range(100):
174
+ scale = max_scale - (max_scale - min_scale) * i / 100
175
+ new_height, new_width = int(height * scale), int(width * scale)
176
+
177
+ # Pad to make dimensions divisible by 64
178
+ pad_height = (64 - new_height % 64) % 64
179
+ pad_width = (64 - new_width % 64) % 64
180
+ padded_height, padded_width = new_height + pad_height, new_width + pad_width
181
+
182
+ if padded_height * padded_width <= max_upper_area:
183
+ return padded_height, padded_width
184
+
185
+ # Fallback: calculate target dimensions based on aspect ratio and divisor alignment
186
+ aspect_ratio = width / height
187
+ target_width = int((target_area * aspect_ratio) ** 0.5 // divisor * divisor)
188
+ target_height = int((target_area / aspect_ratio) ** 0.5 // divisor * divisor)
189
+
190
+ # Ensure the result is not larger than the original resolution
191
+ if target_width >= width or target_height >= height:
192
+ target_width = int(width // divisor * divisor)
193
+ target_height = int(height // divisor * divisor)
194
+
195
+ return target_height, target_width
196
+
197
+
198
class WanSpeech2VideoPipeline(WanVideoPipeline):
    """Speech-to-video pipeline built on the Wan2.2 S2V DiT.

    A video is generated clip by clip from a reference image, a driving audio
    track and a text prompt. Each clip after the first is conditioned on the
    tail of the frames generated so far ("motion frames"). Optional inputs are:

    - a pose video, encoded into per-clip pose latents;
    - a multi-speaker timeline (per-speaker face boxes plus a "void" audio track),
      which is turned into an audio mask and passed to the DiT.
    """

    def __init__(
        self,
        config: WanSpeech2VideoPipelineConfig,
        tokenizer: WanT5Tokenizer,
        text_encoder: WanTextEncoder,
        audio_encoder: Wav2Vec2Model,
        dit: WanS2VDiT,
        vae: WanVideoVAE,
    ):
        # S2V uses a single DiT (no dit2) and no separate image encoder.
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            dit=dit,
            dit2=None,
            vae=vae,
            image_encoder=None,
        )
        self.audio_encoder = audio_encoder
        # Names of the sub-models this pipeline manages; the same names are
        # passed to load_models_to_device below.
        self.model_names = ["audio_encoder", "text_encoder", "dit", "vae"]

    def encode_ref_and_motion(
        self,
        ref_image: Image.Image | None,
        height: int,
        width: int,
        num_motion_frames: int,
        ref_as_first_frame: bool,
    ):
        """Encode the reference image and build the initial motion-frame buffer.

        Args:
            ref_image: Reference image (``__call__`` passes it already resized to
                ``height`` x ``width``).
            height: Frame height in pixels.
            width: Frame width in pixels.
            num_motion_frames: Length of the motion-frame buffer in frames.
            ref_as_first_frame: If true, the last 6 frames of the buffer are set to
                the reference image instead of zeros.

        Returns:
            ``(ref_latents, motion_latents, motion_frames)``. ``motion_frames`` is the
            pixel-space buffer of shape ``[1, 3, num_motion_frames, height, width]``.
        """
        self.load_models_to_device(["vae"])

        ref_frame = self.preprocess_image(ref_image)
        # Insert a time axis at dim 2, then drop a leading size-1 (batch) axis.
        ref_frame = torch.stack([ref_frame], dim=2).squeeze(0)
        ref_latents = self.encode_video([ref_frame]).to(dtype=self.dtype, device=self.device)

        # They fix channel and motion frame length.
        motion_frames = torch.zeros([1, 3, num_motion_frames, height, width], dtype=self.dtype, device=self.device)
        if ref_as_first_frame:
            motion_frames[:, :, -6:] = ref_frame
        motion_latents = self.encode_video(motion_frames).to(dtype=self.dtype, device=self.device)

        return ref_latents, motion_latents, motion_frames

    def encode_pose(
        self,
        pose_video: List[Image.Image],
        pose_video_fps: int,
        num_clips: int,
        num_frames_per_clip: int,
        height: int,
        width: int,
    ):
        """Encode a pose video into one latent tensor per clip.

        The video is read at ``self.config.fps`` (up to ``num_clips * num_frames_per_clip``
        frames), mapped to [-1, 1], resized/cropped to ``height`` x ``width``, zero-padded
        to the full length and split into ``num_clips`` chunks. Each chunk is encoded
        with its first frame duplicated in front, and the first latent frame is then
        dropped.

        Returns:
            List with one latent tensor per clip, moved to CPU.
        """
        self.load_models_to_device(["vae"])
        max_num_pose_frames = num_frames_per_clip * num_clips
        pose_video = read_n_frames(pose_video, pose_video_fps, max_num_pose_frames, target_fps=self.config.fps)
        pose_frames = torch.stack([pil_to_tensor(frame) for frame in pose_video])
        pose_frames = pose_frames / 255.0 * 2 - 1.0
        pose_frames = resize_and_center_crop(pose_frames, height, width).permute(1, 0, 2, 3)[None]
        pose_frames_padding = torch.zeros([1, 3, max_num_pose_frames - pose_frames.shape[2], height, width])
        pose_frames = torch.cat([pose_frames, pose_frames_padding], dim=2)
        pose_frames_all_clips = torch.chunk(pose_frames, num_clips, dim=2)

        pose_latents_all_clips = []
        for pose_frames_per_clip in pose_frames_all_clips:
            pose_frames_per_clip = torch.cat([pose_frames_per_clip[:, :, 0:1], pose_frames_per_clip], dim=2)
            pose_latents_per_clip = self.encode_video([pose_frames_per_clip.squeeze(0)])[:, :, 1:].cpu()
            pose_latents_all_clips.append(pose_latents_per_clip)
        return pose_latents_all_clips

    def encode_audio(self, audio: torch.Tensor, num_frames_per_clip: int, num_clips: int):
        """Encode the driving audio into bucketed per-frame embeddings.

        Returns:
            ``(audio_embed_bucket, num_clips)``. The number of clips is capped at the
            number the audio can cover, as reported by ``get_audio_embed_bucket_fps``.
        """
        self.load_models_to_device(["audio_encoder"])
        audio_embed_bucket, max_num_clips = get_audio_embed_bucket_fps(
            audio_embed=extract_audio_feat(audio, self.audio_encoder, device=self.device),
            num_frames_per_batch=num_frames_per_clip,
            fps=self.config.fps,
        )
        audio_embed_bucket = audio_embed_bucket[None].to(self.device, self.dtype)
        # Move the frame axis last so clips can be sliced with [..., start:end].
        audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
        return audio_embed_bucket, min(max_num_clips, num_clips)

    def encode_void_audio(self, void_audio: torch.Tensor, num_frames_per_clip: int):
        """Encode the "void" audio track and keep only the first ``num_frames_per_clip`` frames."""
        self.load_models_to_device(["audio_encoder"])
        void_audio_embed_bucket, _ = get_audio_embed_bucket_fps(
            audio_embed=extract_audio_feat(void_audio, self.audio_encoder, device=self.device),
            num_frames_per_batch=num_frames_per_clip,
            fps=self.config.fps,
        )
        void_audio_embed_bucket = void_audio_embed_bucket[None].to(self.device, self.dtype)
        void_audio_embed_bucket = void_audio_embed_bucket.permute(0, 2, 3, 1)
        return void_audio_embed_bucket[..., :num_frames_per_clip]

    def predict_noise_with_cfg(
        self,
        model: WanS2VDiT,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        positive_prompt_emb: torch.Tensor,
        negative_prompt_emb: torch.Tensor,
        cfg_scale: float,
        batch_cfg: bool,
        ref_latents: torch.Tensor,
        motion_latents: torch.Tensor,
        pose_cond: torch.Tensor,
        audio_input: torch.Tensor,
        num_motion_frames: int,
        num_motion_latents: int,
        drop_motion_frames: bool,
        audio_mask: torch.Tensor | None,
        void_audio_input: torch.Tensor | None,
    ):
        """Predict noise with classifier-free guidance.

        The unconditional branch uses the negative prompt and zeroed audio. With
        ``cfg_scale <= 1`` only the conditional branch is run. With ``batch_cfg``
        both branches run as one batch of size 2.
        """
        # Conditioning shared by the positive and negative branches.
        shared_kwargs = dict(
            model=model,
            timestep=timestep,
            ref_latents=ref_latents,
            motion_latents=motion_latents,
            pose_cond=pose_cond,
            num_motion_frames=num_motion_frames,
            num_motion_latents=num_motion_latents,
            drop_motion_frames=drop_motion_frames,
            audio_mask=audio_mask,
            void_audio_input=void_audio_input,
        )
        if cfg_scale <= 1.0:
            return self.predict_noise(
                latents=latents, context=positive_prompt_emb, audio_input=audio_input, **shared_kwargs
            )

        silent_audio = 0.0 * audio_input
        if not batch_cfg:
            positive_noise_pred = self.predict_noise(
                latents=latents, context=positive_prompt_emb, audio_input=audio_input, **shared_kwargs
            )
            negative_noise_pred = self.predict_noise(
                latents=latents, context=negative_prompt_emb, audio_input=silent_audio, **shared_kwargs
            )
        else:
            positive_noise_pred, negative_noise_pred = self.predict_noise(
                latents=torch.cat([latents, latents], dim=0),
                context=torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0),
                audio_input=torch.cat([audio_input, silent_audio], dim=0),
                **shared_kwargs,
            )
        return negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)

    def predict_noise(
        self,
        model: WanS2VDiT,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        ref_latents: torch.Tensor,
        motion_latents: torch.Tensor,
        pose_cond: torch.Tensor,
        audio_input: torch.Tensor,
        num_motion_frames: int,
        num_motion_latents: int,
        drop_motion_frames: bool,
        audio_mask: torch.Tensor | None = None,
        void_audio_input: torch.Tensor | None = None,
    ):
        """Run the DiT once, after casting ``latents`` to the model dtype and pipeline device."""
        latents = latents.to(dtype=self.config.model_dtype, device=self.device)

        noise_pred = model(
            x=latents,
            context=context,
            timestep=timestep,
            ref_latents=ref_latents,
            motion_latents=motion_latents,
            pose_cond=pose_cond,
            audio_input=audio_input,
            num_motion_frames=num_motion_frames,
            num_motion_latents=num_motion_latents,
            drop_motion_frames=drop_motion_frames,
            audio_mask=audio_mask,
            void_audio_input=void_audio_input,
        )
        return noise_pred

    @torch.no_grad()
    def __call__(
        self,
        audio: torch.Tensor,
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float | None = None,
        num_inference_steps: int | None = None,
        seed: int | None = None,
        height: int | None = None,
        width: int | None = None,
        num_frames_per_clip: int = 80,
        ref_image: Image.Image | None = None,
        pose_video: List[Image.Image] | None = None,
        pose_video_fps: int | None = None,
        void_audio: torch.Tensor | None = None,
        num_clips: int = 1,
        ref_as_first_frame: bool = False,
        speaker_bbox: List[List[int]] | None = None,
        speaker_end_sec: List[List[int]] | None = None,
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """Generate a talking video driven by ``audio``.

        Args:
            audio: Driving audio passed to the audio encoder.
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt used for classifier-free guidance.
            cfg_scale: Guidance scale; defaults to ``self.config.cfg_scale``.
            num_inference_steps: Denoising steps; defaults to ``self.config.num_inference_steps``.
            seed: Base seed; clip ``k`` uses ``seed + k``. Required when
                ``torch.distributed`` is initialized. Otherwise a random base seed
                is drawn when omitted.
            height, width: Output size; derived from ``ref_image`` unless both are given.
            num_frames_per_clip: Frames generated per clip.
            ref_image: Reference image (required).
            pose_video: Optional pose frames used as extra conditioning.
            pose_video_fps: Frame rate of ``pose_video``.
            void_audio: "Void" audio track; required when ``speaker_end_sec`` is non-empty.
            num_clips: Requested number of clips, capped by the audio length.
            ref_as_first_frame: Seed the motion buffer with the reference image.
            speaker_bbox: Per-speaker ``[x_min, y_min, x_max, y_max]`` boxes in
                ``ref_image`` pixels.
            speaker_end_sec: ``[speaker_id, end_sec]`` speaking timeline.
            progress_callback: Called as ``progress_callback(current, total, status)``.

        Returns:
            ``self.vae_output_to_image`` applied to the frames of all clips,
            concatenated along the time axis.

        Raises:
            ValueError: If parallelism is enabled without a seed, or a speaker
                timeline is given without ``void_audio``.
        """
        assert ref_image is not None, "ref_image must be provided"
        speaker_bbox = [] if speaker_bbox is None else speaker_bbox
        speaker_end_sec = [] if speaker_end_sec is None else speaker_end_sec
        use_speaker_mask = len(speaker_end_sec) > 0
        cfg_scale = self.config.cfg_scale if cfg_scale is None else cfg_scale
        num_inference_steps = self.config.num_inference_steps if num_inference_steps is None else num_inference_steps
        original_height, original_width = ref_image.height, ref_image.width
        height, width = restrict_size_below_area(height, width, ref_image)

        # Initialize noise
        if dist.is_initialized() and seed is None:
            raise ValueError("must provide a seed when parallelism is enabled")
        if seed is None:
            # Draw one base seed so every clip can still use its own offset (seed + clip_idx).
            seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        if use_speaker_mask and void_audio is None:
            raise ValueError("void_audio must be provided when speaker_end_sec is not empty")

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt)
        prompt_emb_nega = self.encode_prompt(negative_prompt)

        # Encode ref image, previous video and audio
        num_motion_frames = 73
        num_motion_latents = (num_motion_frames + 3) // 4
        ref_image = resize_and_center_crop(ref_image, height, width)
        ref_latents, motion_latents, motion_frames = self.encode_ref_and_motion(
            ref_image, height, width, num_motion_frames, ref_as_first_frame
        )
        audio_emb, num_clips = self.encode_audio(audio, num_frames_per_clip, num_clips)

        void_audio_emb, audio_mask = None, None
        if use_speaker_mask:
            void_audio_emb = self.encode_void_audio(void_audio, num_frames_per_clip)
            speaker_bbox = transform_bbox(speaker_bbox, original_height, original_width, height, width)
            audio_mask = get_face_mask(
                ref_image=ref_image,
                speaker_end_sec=speaker_end_sec,
                speaker_bbox=speaker_bbox,
                num_frames_total=num_clips * num_frames_per_clip,
                fps=self.config.fps,
                dtype=self.dtype,
            ).to(self.device)

        pose_latents_all_clips = None
        if pose_video is not None:
            pose_latents_all_clips = self.encode_pose(
                pose_video, pose_video_fps, num_clips, num_frames_per_clip, height, width
            )

        num_latents_per_clip = num_frames_per_clip // 4
        hide_progress = dist.is_initialized() and dist.get_rank() != 0
        output_frames_all_clips = []
        for clip_idx in range(num_clips):
            noise = self.generate_noise(
                (
                    1,
                    self.vae.z_dim,
                    num_latents_per_clip,
                    height // self.upsampling_factor,
                    width // self.upsampling_factor,
                ),
                seed=seed + clip_idx,
                device="cpu",
                dtype=torch.float32,
            ).to(self.device)
            _, latents, sigmas, timesteps = self.prepare_latents(
                latents=noise,
                input_video=None,
                denoising_strength=None,
                num_inference_steps=num_inference_steps,
            )
            # Initialize sampler
            self.sampler.initialize(sigmas=sigmas)

            # Index audio emb, pose latents and audio mask for the current clip
            audio_emb_curr_clip = audio_emb[
                ..., (clip_idx * num_frames_per_clip) : ((clip_idx + 1) * num_frames_per_clip)
            ]
            if pose_latents_all_clips is not None:
                pose_latents_curr_clip = pose_latents_all_clips[clip_idx]
            else:
                pose_latents_curr_clip = torch.zeros_like(latents)
            pose_latents_curr_clip = pose_latents_curr_clip.to(dtype=self.dtype, device=self.device)
            audio_mask_curr_clip = None
            if audio_mask is not None:
                audio_mask_curr_clip = audio_mask[
                    None, :, (clip_idx * num_latents_per_clip) : ((clip_idx + 1) * num_latents_per_clip)
                ]

            # Denoise
            drop_motion_frames = (not ref_as_first_frame) and clip_idx == 0
            for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
                self.load_models_to_device(["dit"])

                timestep = timestep[None].to(dtype=self.dtype, device=self.device)
                # Classifier-free guidance
                noise_pred = self.predict_noise_with_cfg(
                    model=self.dit,
                    latents=latents,
                    timestep=timestep,
                    positive_prompt_emb=prompt_emb_posi,
                    negative_prompt_emb=prompt_emb_nega,
                    cfg_scale=cfg_scale,
                    batch_cfg=self.config.batch_cfg,
                    ref_latents=ref_latents,
                    motion_latents=motion_latents,
                    pose_cond=pose_latents_curr_clip,
                    audio_input=audio_emb_curr_clip,
                    num_motion_frames=num_motion_frames,
                    num_motion_latents=num_motion_latents,
                    drop_motion_frames=drop_motion_frames,
                    audio_mask=audio_mask_curr_clip,
                    void_audio_input=void_audio_emb,
                )
                # Scheduler
                latents = self.sampler.step(latents, noise_pred, i)
                if progress_callback is not None:
                    progress_callback(i + 1, len(timesteps), "DENOISING")

            # Decode with the conditioning prefix in front, then keep only the new frames.
            prefix_latents = ref_latents if drop_motion_frames else motion_latents
            decode_latents = torch.cat([prefix_latents, latents], dim=2)
            self.load_models_to_device(["vae"])
            output_frames_curr_clip = torch.stack(
                self.decode_video(decode_latents, progress_callback=progress_callback)
            )
            output_frames_curr_clip = output_frames_curr_clip[:, :, -(num_frames_per_clip):]
            if drop_motion_frames:
                output_frames_curr_clip = output_frames_curr_clip[:, :, 3:]
            output_frames_all_clips.append(output_frames_curr_clip.cpu())

            # Slide the motion-frame window forward for the next clip.
            if clip_idx < num_clips - 1:
                num_new_frames = output_frames_curr_clip.shape[2]
                if num_new_frames <= num_motion_frames:
                    motion_frames = torch.cat([motion_frames[:, :, num_new_frames:], output_frames_curr_clip], dim=2)
                else:
                    motion_frames = output_frames_curr_clip[:, :, -num_motion_frames:]
                # Cast like encode_ref_and_motion does for the first clip.
                motion_latents = self.encode_video(motion_frames).to(dtype=self.dtype, device=self.device)

        output_frames_all_clips = torch.cat(output_frames_all_clips, dim=2)
        return self.vae_output_to_image(output_frames_all_clips)

    @classmethod
    def from_pretrained(
        cls, model_path_or_config: str | WanSpeech2VideoPipelineConfig
    ) -> "WanSpeech2VideoPipeline":
        """Load all state dicts from disk and build the pipeline.

        Args:
            model_path_or_config: Either a DiT checkpoint path (wrapped into a default
                ``WanSpeech2VideoPipelineConfig``) or a full config. Missing T5, VAE and
                audio-encoder paths are fetched with ``fetch_model``.
        """
        if isinstance(model_path_or_config, str):
            config = WanSpeech2VideoPipelineConfig(model_path=model_path_or_config)
        else:
            config = model_path_or_config

        logger.info(f"loading dit state dict from {config.model_path} ...")
        dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)

        if config.t5_path is None:
            config.t5_path = fetch_model("muse/wan2.1-umt5", path="umt5.safetensors")
        if config.vae_path is None:
            config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
        if config.audio_encoder_path is None:
            config.audio_encoder_path = fetch_model(
                "Wan-AI/Wan2.2-S2V-14B", path="wav2vec2-large-xlsr-53-english/model.safetensors"
            )

        logger.info(f"loading t5 state dict from {config.t5_path} ...")
        t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)

        logger.info(f"loading vae state dict from {config.vae_path} ...")
        vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)

        logger.info(f"loading audio encoder state dict from {config.audio_encoder_path} ...")
        wav2vec_state_dict = cls.load_model_checkpoint(
            config.audio_encoder_path, device="cpu", dtype=config.audio_encoder_dtype
        )

        state_dicts = WanS2VStateDicts(
            model=dit_state_dict,
            t5=t5_state_dict,
            vae=vae_state_dict,
            audio_encoder=wav2vec_state_dict,
        )
        return cls.from_state_dict(state_dicts, config)

    @classmethod
    def from_state_dict(
        cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig
    ) -> "WanSpeech2VideoPipeline":
        """Build the pipeline, wrapping it in ``ParallelWrapper`` when ``config.parallelism > 1``."""
        if config.parallelism > 1:
            pipe = ParallelWrapper(
                cfg_degree=config.cfg_degree,
                sp_ulysses_degree=config.sp_ulysses_degree,
                sp_ring_degree=config.sp_ring_degree,
                tp_degree=config.tp_degree,
                use_fsdp=config.use_fsdp,
            )
            pipe.load_module(cls._from_state_dict, state_dicts=state_dicts, config=config)
        else:
            pipe = cls._from_state_dict(state_dicts, config)
        return pipe

    @classmethod
    def _from_state_dict(
        cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig
    ) -> "WanSpeech2VideoPipeline":
        """Instantiate every sub-model from its state dict and assemble the pipeline.

        Note: ``boundary``, ``shift``, ``cfg_scale``, ``num_inference_steps`` and ``fps``
        on ``config`` are overwritten with the DiT model config's defaults.
        """
        # default params from model config
        vae_type = "wan2.1-vae"
        dit_type = "wan2.2-s2v-14b"
        vae_config: dict = WanVideoVAE.get_model_config(vae_type)
        model_config: dict = WanS2VDiT.get_model_config(dit_type)
        config.boundary = model_config.pop("boundary", -1.0)
        config.shift = model_config.pop("shift", 5.0)
        config.cfg_scale = model_config.pop("cfg_scale", 5.0)
        config.num_inference_steps = model_config.pop("num_inference_steps", 50)
        config.fps = model_config.pop("fps", 16)

        # With offloading, models are first materialized on CPU.
        init_device = "cpu" if config.offload_mode is not None else config.device
        tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
        text_encoder = WanTextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype)
        vae = WanVideoVAE.from_state_dict(state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype)
        audio_encoder = Wav2Vec2Model.from_state_dict(
            state_dicts.audio_encoder, config=Wav2Vec2Config(), device=init_device, dtype=config.audio_encoder_dtype
        )

        with LoRAContext():
            attn_kwargs = {
                "attn_impl": config.dit_attn_impl,
                "sparge_smooth_k": config.sparge_smooth_k,
                "sparge_cdfthreshd": config.sparge_cdfthreshd,
                "sparge_simthreshd1": config.sparge_simthreshd1,
                "sparge_pvthreshd": config.sparge_pvthreshd,
            }
            dit = WanS2VDiT.from_state_dict(
                state_dicts.model,
                config=model_config,
                device=init_device,
                dtype=config.model_dtype,
                attn_kwargs=attn_kwargs,
            )
            if config.use_fp8_linear:
                enable_fp8_linear(dit)

        pipe = cls(
            config=config,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            dit=dit,
            vae=vae,
            audio_encoder=audio_encoder,
        )
        pipe.eval()

        if config.offload_mode is not None:
            pipe.enable_cpu_offload(config.offload_mode)

        if config.model_dtype == torch.float8_e4m3fn:
            pipe.dtype = torch.bfloat16  # compute dtype
            pipe.enable_fp8_autocast(
                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
            )

        if config.t5_dtype == torch.float8_e4m3fn:
            pipe.dtype = torch.bfloat16  # compute dtype
            pipe.enable_fp8_autocast(
                model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
            )

        if config.use_torch_compile:
            pipe.compile()
        return pipe
@@ -34,6 +34,7 @@ WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit
34
34
WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
# DiT model config for the Wan2.2 speech-to-video 14B model ("wan2.2-s2v-14b").
WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json")

WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json")
WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json")