diffsynth-engine 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. diffsynth_engine/__init__.py +28 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/models/components/vae.json +254 -0
  30. diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
  31. diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
  32. diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
  33. diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
  34. diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
  35. diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
  36. diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
  37. diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
  38. diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
  39. diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
  40. diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
  41. diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
  42. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  43. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  44. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  45. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  46. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  47. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  48. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  49. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  50. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  51. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  52. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  53. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  54. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  55. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  56. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  57. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  58. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  59. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  60. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  61. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  62. diffsynth_engine/kernels/__init__.py +0 -0
  63. diffsynth_engine/models/__init__.py +7 -0
  64. diffsynth_engine/models/base.py +64 -0
  65. diffsynth_engine/models/basic/__init__.py +0 -0
  66. diffsynth_engine/models/basic/attention.py +217 -0
  67. diffsynth_engine/models/basic/lora.py +293 -0
  68. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  69. diffsynth_engine/models/basic/timestep.py +81 -0
  70. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  71. diffsynth_engine/models/basic/unet_helper.py +244 -0
  72. diffsynth_engine/models/components/__init__.py +0 -0
  73. diffsynth_engine/models/components/clip.py +56 -0
  74. diffsynth_engine/models/components/t5.py +222 -0
  75. diffsynth_engine/models/components/vae.py +392 -0
  76. diffsynth_engine/models/flux/__init__.py +14 -0
  77. diffsynth_engine/models/flux/flux_dit.py +476 -0
  78. diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
  79. diffsynth_engine/models/flux/flux_vae.py +78 -0
  80. diffsynth_engine/models/sd/__init__.py +12 -0
  81. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  82. diffsynth_engine/models/sd/sd_unet.py +293 -0
  83. diffsynth_engine/models/sd/sd_vae.py +38 -0
  84. diffsynth_engine/models/sd3/__init__.py +14 -0
  85. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  86. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  87. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  88. diffsynth_engine/models/sdxl/__init__.py +13 -0
  89. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  90. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  91. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  92. diffsynth_engine/models/utils.py +54 -0
  93. diffsynth_engine/models/wan/__init__.py +0 -0
  94. diffsynth_engine/models/wan/wan_dit.py +497 -0
  95. diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
  96. diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
  97. diffsynth_engine/models/wan/wan_vae.py +771 -0
  98. diffsynth_engine/pipelines/__init__.py +18 -0
  99. diffsynth_engine/pipelines/base.py +253 -0
  100. diffsynth_engine/pipelines/flux_image.py +512 -0
  101. diffsynth_engine/pipelines/sd_image.py +352 -0
  102. diffsynth_engine/pipelines/sdxl_image.py +395 -0
  103. diffsynth_engine/pipelines/wan_video.py +524 -0
  104. diffsynth_engine/tokenizers/__init__.py +6 -0
  105. diffsynth_engine/tokenizers/base.py +157 -0
  106. diffsynth_engine/tokenizers/clip.py +288 -0
  107. diffsynth_engine/tokenizers/t5.py +194 -0
  108. diffsynth_engine/tokenizers/wan.py +74 -0
  109. diffsynth_engine/utils/__init__.py +0 -0
  110. diffsynth_engine/utils/constants.py +34 -0
  111. diffsynth_engine/utils/download.py +135 -0
  112. diffsynth_engine/utils/env.py +7 -0
  113. diffsynth_engine/utils/flag.py +46 -0
  114. diffsynth_engine/utils/fp8_linear.py +64 -0
  115. diffsynth_engine/utils/gguf.py +415 -0
  116. diffsynth_engine/utils/loader.py +17 -0
  117. diffsynth_engine/utils/lock.py +56 -0
  118. diffsynth_engine/utils/logging.py +12 -0
  119. diffsynth_engine/utils/offload.py +44 -0
  120. diffsynth_engine/utils/parallel.py +390 -0
  121. diffsynth_engine/utils/prompt.py +9 -0
  122. diffsynth_engine/utils/video.py +40 -0
  123. diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
  124. diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
  125. diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
  126. diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
  127. diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,524 @@
1
+ import torch
2
+ import numpy as np
3
+ from einops import rearrange
4
+ from dataclasses import dataclass
5
+ from functools import partial
6
+ from typing import Callable, List, Tuple, Optional
7
+ from tqdm import tqdm
8
+ from PIL import Image
9
+
10
+ from diffsynth_engine.algorithm.noise_scheduler.flow_match import RecifitedFlowScheduler
11
+ from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
12
+ from diffsynth_engine.models.wan.wan_dit import WanDiT
13
+ from diffsynth_engine.models.wan.wan_text_encoder import WanTextEncoder
14
+ from diffsynth_engine.models.wan.wan_vae import WanVideoVAE
15
+ from diffsynth_engine.models.wan.wan_image_encoder import WanImageEncoder
16
+ from diffsynth_engine.models.basic.lora import LoRAContext
17
+ from diffsynth_engine.tokenizers import WanT5Tokenizer
18
+ from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
19
+ from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
20
+ from diffsynth_engine.utils.download import fetch_model
21
+ from diffsynth_engine.utils.parallel import ParallelModel, shard_model
22
+ from diffsynth_engine.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
@dataclass
class WanModelConfig:
    """Checkpoint paths, per-component dtypes and parallelism settings for WanVideoPipeline."""

    # Checkpoint paths; any left as None is downloaded via fetch_model()
    # in WanVideoPipeline.from_pretrained.
    model_path: Optional[str] = None
    vae_path: Optional[str] = None
    t5_path: Optional[str] = None
    # Only required for image-to-video (i2v) models; when None no image
    # encoder is built and only text-to-video is available.
    image_encoder_path: Optional[str] = None

    # Per-component compute dtypes (VAE kept in fp32, the rest in bf16).
    vae_dtype: torch.dtype = torch.float32
    dit_dtype: torch.dtype = torch.bfloat16
    t5_dtype: torch.dtype = torch.bfloat16
    image_encoder_dtype: torch.dtype = torch.bfloat16

    # Attention implementation name forwarded to WanDiT.from_state_dict.
    dit_attn_impl: Optional[str] = "auto"
    # Wrap the DiT blocks with fully sharded data parallel (FSDP).
    dit_fsdp: bool = False

    # Sequence-parallel (Ulysses/ring) and tensor-parallel degrees.
    # Mutual-exclusion and product constraints are validated in
    # WanVideoPipeline.from_pretrained when parallelism > 1.
    sp_ulysses_degree: Optional[int] = None
    sp_ring_degree: Optional[int] = None
    tp_degree: Optional[int] = None
46
+
47
+
48
class WanLoRAConverter(LoRAStateDictConverter):
    """Converts external LoRA checkpoints into the internal ``{"dit": {key: lora_args}}`` layout.

    Two on-disk layouts are supported:

    * diffsynth: ``<name>.lora_A.default.weight`` / ``<name>.lora_B.default.weight``
    * civitai:   ``diffusion_model.<name>.lora_A.weight`` / ``...lora_B.weight``

    Both previously duplicated the same parsing loop; it now lives in
    ``_collect_lora_args`` so the two formats cannot drift apart.
    """

    @staticmethod
    def _collect_lora_args(state_dict, down_suffix, up_suffix, strip_prefix=""):
        """Gather per-layer LoRA tensors keyed by the base parameter name.

        Args:
            state_dict: raw checkpoint mapping of parameter name -> tensor.
            down_suffix: suffix identifying the LoRA "down" (A) matrices.
            up_suffix: suffix of the matching "up" (B) matrices.
            strip_prefix: optional prefix removed from the final key.

        Returns:
            dict mapping base key -> {"up", "down", "rank", "alpha"}.
        """
        dit_dict = {}
        for key, param in state_dict.items():
            if down_suffix not in key:
                continue
            lora_args = {}
            lora_args["up"] = state_dict[key.replace(down_suffix, up_suffix)]
            lora_args["down"] = param
            # up weight is (out_features, rank), so dim 1 is the LoRA rank
            lora_args["rank"] = lora_args["up"].shape[1]
            # alpha is optional in checkpoints; fall back to rank (scale 1.0)
            alpha_key = key.replace(down_suffix, ".alpha")
            if alpha_key in state_dict:
                lora_args["alpha"] = state_dict[alpha_key]
            else:
                lora_args["alpha"] = lora_args["rank"]
            base_key = key.replace(down_suffix, "")
            if strip_prefix:
                base_key = base_key.replace(strip_prefix, "")
            dit_dict[base_key] = lora_args
        return dit_dict

    def _from_diffsynth(self, state_dict):
        """Parse a diffsynth-format LoRA state dict."""
        dit_dict = self._collect_lora_args(
            state_dict,
            down_suffix=".lora_A.default.weight",
            up_suffix=".lora_B.default.weight",
        )
        return {"dit": dit_dict}

    def _from_civitai(self, state_dict):
        """Parse a civitai-format LoRA state dict (keys carry a ``diffusion_model.`` prefix)."""
        dit_dict = self._collect_lora_args(
            state_dict,
            down_suffix=".lora_A.weight",
            up_suffix=".lora_B.weight",
            strip_prefix="diffusion_model.",
        )
        return {"dit": dit_dict}

    def convert(self, state_dict):
        """Detect the checkpoint layout by a sentinel key and dispatch accordingly."""
        if "diffusion_model.blocks.0.cross_attn.k.lora_A.weight" in state_dict:
            state_dict = self._from_civitai(state_dict)
            logger.info("use civitai format state dict")
        else:
            state_dict = self._from_diffsynth(state_dict)
            logger.info("use diffsynth format state dict")
        return state_dict
93
+
94
+
95
class WanVideoPipeline(BasePipeline):
    """Text-to-video / image-to-video pipeline for Wan 2.1 models.

    Combines a T5-based text encoder, a DiT denoiser, a causal video VAE and
    (for i2v checkpoints) an image encoder, driven by a rectified-flow noise
    schedule with a flow-match Euler sampler.
    """

    # Maps external LoRA checkpoints (diffsynth/civitai layouts) to the
    # internal format consumed by BasePipeline.load_loras.
    lora_converter = WanLoRAConverter()

    def __init__(
        self,
        config: WanModelConfig,
        tokenizer: WanT5Tokenizer,
        text_encoder: WanTextEncoder,
        dit: WanDiT,
        vae: WanVideoVAE,
        image_encoder: WanImageEncoder,
        batch_cfg: bool = False,
        device="cuda",
        dtype=torch.bfloat16,
    ):
        super().__init__(device=device, dtype=dtype)
        self.noise_scheduler = RecifitedFlowScheduler(shift=5.0, sigma_min=0.001, sigma_max=0.999)
        self.sampler = FlowMatchEulerSampler()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.dit = dit
        self.vae = vae
        self.image_encoder = image_encoder
        # When True, positive/negative CFG branches run as one batch of 2.
        self.batch_cfg = batch_cfg
        self.config = config
        # Models managed by load_models_to_device / offloading; the image
        # encoder is loaded on demand in __call__ instead.
        self.model_names = ["text_encoder", "dit", "vae"]

    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
        """Load LoRAs onto the DiT; rejects configurations (TP, FSDP+fused) that cannot take LoRA weights."""
        assert self.config.tp_degree is None, (
            "load LoRA is not allowed when tensor parallel is enabled; "
            "set tp_degree=None during pipeline initialization"
        )
        assert not (self.config.dit_fsdp and fused), (
            "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
            "either load LoRA with fused=False or set dit_fsdp=False during pipeline initialization"
        )
        super().load_loras(lora_list, fused, save_original_weight)

    def unload_loras(self):
        """Remove all loaded LoRA weights from the DiT and text encoder."""
        self.dit.unload_loras()
        self.text_encoder.unload_loras()

    def denoising_model(self):
        """Return the model that predicts noise (the DiT)."""
        return self.dit

    def encode_prompt(self, prompt):
        """Tokenize and encode a prompt; padded positions are zeroed via the attention mask."""
        ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids = ids.to(self.device)
        mask = mask.to(self.device)
        prompt_emb = self.text_encoder(ids, mask)
        # Zero embeddings wherever the mask is 0 (padding tokens).
        prompt_emb = prompt_emb.masked_fill(mask.unsqueeze(-1).expand_as(prompt_emb) == 0, 0)
        return prompt_emb

    def encode_image(self, image, num_frames, height, width):
        """Encode the conditioning image for i2v generation.

        Returns:
            (clip_context, y): image-encoder features and the VAE-latent
            conditioning tensor with its frame mask prepended.
        """
        image = self.preprocess_image(image.resize((width, height), Image.Resampling.LANCZOS)).to(
            self.device, self.config.image_encoder_dtype
        )
        clip_context = self.image_encoder.encode_image([image])
        # Frame mask at latent spatial resolution (spatial factor 8): 1 for
        # the conditioned first frame, 0 for frames to be generated.
        msk = torch.ones(
            1, num_frames, height // 8, width // 8, device=self.device, dtype=self.config.image_encoder_dtype
        )
        msk[:, 1:] = 0
        # Repeat the first-frame mask 4x, then fold frames into groups of 4 —
        # presumably matching the VAE's 4x temporal compression (cf. the
        # (num_frames - 1) // 4 + 1 latent frame count in __call__).
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
        msk = msk.transpose(1, 2)[0]
        # VAE-encode the image followed by zero frames for the rest of the clip.
        y = self.vae.encode(
            [
                torch.concat(
                    [
                        image.transpose(0, 1),
                        torch.zeros(3, num_frames - 1, height, width).to(image.device, self.config.vae_dtype),
                    ],
                    dim=1,
                )
            ],
            device=self.device,
        )[0]
        y = torch.concat([msk, y]).to(dtype=self.dtype)
        return clip_context, torch.unsqueeze(y, 0)

    def tensor2video(self, frames):
        """Convert a (C, T, H, W) tensor in [-1, 1] to a list of PIL images."""
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames

    def encode_video(self, videos: torch.Tensor, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """VAE-encode pixel video to latents (VAE dtype in, DiT dtype out)."""
        videos = videos.to(dtype=self.config.vae_dtype, device=self.device)
        latents = self.vae.encode(videos, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        latents = latents.to(dtype=self.config.dit_dtype, device=self.device)
        return latents

    def decode_video(
        self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16), progress_callback=None
    ) -> List[torch.Tensor]:
        """VAE-decode latents into a list of pixel-space video tensors."""
        latents = latents.to(dtype=self.config.vae_dtype, device=self.device)
        videos = self.vae.decode(
            latents,
            device=self.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
            progress_callback=progress_callback,
        )
        videos = [video.to(dtype=self.config.dit_dtype, device=self.device) for video in videos]
        return videos

    def predict_noise_with_cfg(
        self,
        latents: torch.Tensor,
        image_clip_feature: torch.Tensor,
        image_y: torch.Tensor,
        timestep: torch.Tensor,
        positive_prompt_emb: torch.Tensor,
        negative_prompt_emb: torch.Tensor,
        cfg_scale: float,
        batch_cfg: bool,
    ):
        """Predict noise with classifier-free guidance.

        cfg_scale <= 1.0 disables guidance (single positive pass). Otherwise
        guidance is computed either as two sequential passes (lower memory)
        or as one batched pass of size 2 (batch_cfg=True).
        """
        if cfg_scale <= 1.0:
            return self.predict_noise(
                latents=latents,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                timestep=timestep,
                context=positive_prompt_emb,
            )
        if not batch_cfg:
            # cfg by predicting noise one pass at a time
            positive_noise_pred = self.predict_noise(
                latents=latents,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                timestep=timestep,
                context=positive_prompt_emb,
            )
            negative_noise_pred = self.predict_noise(
                latents=latents,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                timestep=timestep,
                context=negative_prompt_emb,
            )
            noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
            return noise_pred
        else:
            # cfg by predicting noise in one batch of 2 (positive first)
            prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0)
            latents = torch.cat([latents, latents], dim=0)
            timestep = torch.cat([timestep, timestep], dim=0)
            if image_y is not None:
                image_y = torch.cat([image_y, image_y], dim=0)
            if image_clip_feature is not None:
                image_clip_feature = torch.cat([image_clip_feature, image_clip_feature], dim=0)
            # Relies on the DiT output unpacking into exactly 2 along dim 0.
            positive_noise_pred, negative_noise_pred = self.predict_noise(
                latents=latents,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                timestep=timestep,
                context=prompt_emb,
            )
            noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
            return noise_pred

    def predict_noise(self, latents, image_clip_feature, image_y, timestep, context):
        """Run a single DiT forward pass; image_clip_feature/image_y may be None for t2v."""
        latents = latents.to(dtype=self.config.dit_dtype, device=self.device)

        noise_pred = self.dit(
            x=latents,
            timestep=timestep,
            context=context,
            clip_feature=image_clip_feature,
            y=image_y,
        )
        return noise_pred

    def prepare_latents(
        self,
        latents,
        input_video,
        denoising_strength,
        num_inference_steps,
        tiled=True,
        tile_size=(34, 34),
        tile_stride=(18, 16),
    ):
        """Prepare initial latents and the (possibly truncated) sigma/timestep schedule.

        With input_video (video-to-video): encode it to latents, start the
        schedule at a step determined by denoising_strength, and add noise at
        that sigma. Without it (t2v/i2v): use the incoming noise directly.

        Returns:
            (init_latents, latents, sigmas, timesteps)
        """
        if input_video is not None:
            total_steps = num_inference_steps
            sigmas, timesteps = self.noise_scheduler.schedule(total_steps)
            # Truncate the schedule: higher denoising_strength -> earlier start.
            t_start = max(total_steps - int(num_inference_steps * denoising_strength), 1)
            sigma_start, sigmas = sigmas[t_start - 1], sigmas[t_start - 1 :]
            timesteps = timesteps[t_start - 1 :]

            noise = latents
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2)
            latents = self.encode_video(input_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(
                dtype=latents.dtype, device=latents.device
            )
            init_latents = latents.clone()
            latents = self.sampler.add_noise(latents, noise, sigma_start)
        else:
            sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps)
            init_latents = latents.clone()

        return init_latents, latents, sigmas, timesteps

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        input_image=None,
        input_video=None,
        denoising_strength=1.0,
        seed=None,
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        tiled=True,
        tile_size=(34, 34),
        tile_stride=(18, 16),
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """Generate a video and return it as a list of PIL image frames."""
        assert height % 16 == 0 and width % 16 == 0, "height and width must be divisible by 16"
        assert (num_frames - 1) % 4 == 0, "num_frames must be 4X+1"

        # Initialize noise on CPU for seed reproducibility, then move to device.
        # Latent shape: 16 channels, (num_frames - 1) // 4 + 1 frames, H/8 x W/8.
        noise = self.generate_noise(
            (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), seed=seed, device="cpu", dtype=torch.float32
        ).to(self.device)
        init_latents, latents, sigmas, timesteps = self.prepare_latents(
            noise,
            input_video,
            denoising_strength,
            num_inference_steps,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas)
        # Encode prompts (negative prompt only needed when CFG is active)
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt)
        prompt_emb_nega = None if cfg_scale <= 1.0 else self.encode_prompt(negative_prompt)

        # Encode the conditioning image (i2v models only)
        if input_image is not None and self.image_encoder is not None:
            self.load_models_to_device(["image_encoder", "vae"])
            image_clip_feature, image_y = self.encode_image(input_image, num_frames, height, width)
        else:
            image_clip_feature, image_y = None, None

        # Denoise
        self.load_models_to_device(["dit"])
        for i, timestep in enumerate(tqdm(timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.config.dit_dtype, device=self.device)
            # Classifier-free guidance
            noise_pred = self.predict_noise_with_cfg(
                latents=latents,
                timestep=timestep,
                positive_prompt_emb=prompt_emb_posi,
                negative_prompt_emb=prompt_emb_nega,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                cfg_scale=cfg_scale,
                batch_cfg=self.batch_cfg,
            )
            # Scheduler step
            latents = self.sampler.step(latents, noise_pred, i)
            if progress_callback is not None:
                progress_callback(i + 1, len(timesteps), "DENOISING")

        # Decode the single-sample batch and convert to PIL frames
        self.load_models_to_device(["vae"])
        frames = self.decode_video(
            latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, progress_callback=progress_callback
        )
        frames = self.tensor2video(frames[0])
        return frames

    @classmethod
    def from_pretrained(
        cls,
        model_path_or_config: str | WanModelConfig,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        batch_cfg: bool = False,
        offload_mode: str | None = None,
        parallelism: int = 1,
        use_cfg_parallel: bool = False,
    ) -> "WanVideoPipeline":
        """Build a pipeline from a checkpoint path or a WanModelConfig.

        Missing checkpoint paths are fetched from the model hub. The DiT
        variant (1.3b/14b, t2v/i2v) is inferred from the state dict, and the
        DiT is optionally wrapped for sequence/tensor parallelism.
        """
        cls.validate_offload_mode(offload_mode)

        if isinstance(model_path_or_config, str):
            model_config = WanModelConfig(model_path=model_path_or_config)
        else:
            model_config = model_path_or_config

        # Fill in any missing checkpoint paths from the hub.
        if model_config.model_path is None:
            model_config.model_path = fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors")
        if model_config.t5_path is None:
            model_config.t5_path = fetch_model("muse/wan2.1-umt5", path="umt5.safetensors")
        if model_config.vae_path is None:
            model_config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")

        logger.info(f"loading state dict from {model_config.model_path} ...")
        dit_state_dict = cls.load_model_checkpoint(model_config.model_path, device="cpu", dtype=model_config.dit_dtype)

        logger.info(f"loading state dict from {model_config.t5_path} ...")
        t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=model_config.t5_dtype)

        logger.info(f"loading state dict from {model_config.vae_path} ...")
        vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=model_config.vae_dtype)

        # With offloading, build models on CPU and move them on demand.
        init_device = "cpu" if offload_mode else device
        tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
        text_encoder = WanTextEncoder.from_state_dict(t5_state_dict, device=init_device, dtype=model_config.t5_dtype)

        vae = WanVideoVAE.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)

        image_encoder = None
        if model_config.image_encoder_path is not None:
            logger.info(f"loading state dict from {model_config.image_encoder_path} ...")
            image_encoder_state_dict = cls.load_model_checkpoint(
                model_config.image_encoder_path,
                device="cpu",
                dtype=model_config.image_encoder_dtype,
            )
            image_encoder = WanImageEncoder.from_state_dict(
                image_encoder_state_dict,
                device=init_device,
                dtype=model_config.image_encoder_dtype,
            )

        # Determine the Wan video model type by DiT params: block 39 only
        # exists in the 14b variants; an image encoder implies i2v.
        model_type = None
        if "blocks.39.self_attn.norm_q.weight" in dit_state_dict:
            if image_encoder is not None:
                model_type = "14b-i2v"
            else:
                model_type = "14b-t2v"
        else:
            model_type = "1.3b-t2v"

        if parallelism > 1:
            assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
            # CFG parallelism implies batched CFG split across ranks.
            batch_cfg = True if use_cfg_parallel else batch_cfg
            cfg_degree = 2 if use_cfg_parallel else 1
            sp_ulysses_degree = model_config.sp_ulysses_degree
            sp_ring_degree = model_config.sp_ring_degree
            tp_degree = model_config.tp_degree

            # Resolve the parallel layout: TP excludes SP and FSDP; if no SP
            # degrees are given, default all non-CFG parallelism to Ulysses.
            if tp_degree is not None:
                assert sp_ulysses_degree is None and sp_ring_degree is None, (
                    "not allowed to enable sequence parallel and tensor parallel together; "
                    "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
                )
                assert model_config.dit_fsdp is False, (
                    "not allowed to enable fully sharded data parallel and tensor parallel together; "
                    "either set dit_fsdp=False or set tp_degree=None during pipeline initialization"
                )
                assert parallelism == cfg_degree * tp_degree, (
                    f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
                )
                sp_ulysses_degree = 1
                sp_ring_degree = 1
            elif sp_ulysses_degree is None and sp_ring_degree is None:
                # use ulysses if not specified
                sp_ulysses_degree = parallelism // cfg_degree
                sp_ring_degree = 1
                tp_degree = 1
            elif sp_ulysses_degree is not None and sp_ring_degree is not None:
                assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
                    f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
                    f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
                )
                tp_degree = 1
            else:
                raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")

            with LoRAContext():
                dit = WanDiT.from_state_dict(
                    dit_state_dict,
                    model_type=model_type,
                    device="cpu",
                    dtype=model_config.dit_dtype,
                    attn_impl=model_config.dit_attn_impl,
                    use_usp=(sp_ulysses_degree * sp_ring_degree > 1),
                )
                dit = ParallelModel(
                    dit,
                    cfg_degree=cfg_degree,
                    sp_ulysses_degree=sp_ulysses_degree,
                    sp_ring_degree=sp_ring_degree,
                    tp_degree=tp_degree,
                    shard_fn=partial(shard_model, wrap_module_names=["blocks"]) if model_config.dit_fsdp else None,
                    device="cuda",
                )
        else:
            with LoRAContext():
                dit = WanDiT.from_state_dict(
                    dit_state_dict,
                    model_type=model_type,
                    device=init_device,
                    dtype=model_config.dit_dtype,
                    attn_impl=model_config.dit_attn_impl,
                )

        pipe = cls(
            config=model_config,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            dit=dit,
            vae=vae,
            image_encoder=image_encoder,
            batch_cfg=batch_cfg,
            device=device,
            dtype=dtype,
        )
        pipe.eval()
        if offload_mode == "cpu_offload":
            pipe.enable_cpu_offload()
        elif offload_mode == "sequential_cpu_offload":
            pipe.enable_sequential_cpu_offload()
        return pipe

    def __del__(self):
        # Explicitly drop the DiT (largest model) on pipeline destruction.
        del self.dit
@@ -0,0 +1,6 @@
1
# Tokenizer package: re-export the concrete tokenizers and the shared base class.
from .base import BaseTokenizer
from .clip import CLIPTokenizer
from .t5 import T5TokenizerFast
from .wan import WanT5Tokenizer

__all__ = ["BaseTokenizer", "CLIPTokenizer", "T5TokenizerFast", "WanT5Tokenizer"]
@@ -0,0 +1,157 @@
1
# Modified from transformers.tokenization_utils_base
from typing import Dict, List, Union, overload


# Default filename under which tokenizer configuration is stored.
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
6
+
7
+
8
class BaseTokenizer:
    """Minimal tokenizer interface with special-token bookkeeping.

    Subclasses implement the actual tokenize/encode/decode logic; this base
    class only manages the special-token attributes and their id lookups.
    """

    SPECIAL_TOKENS_ATTRIBUTES = [
        "bos_token",
        "eos_token",
        "unk_token",
        "pad_token",
    ]

    def __init__(self, **kwargs):
        # Special tokens default to undefined; the *_token_id properties
        # raise ValueError until a token string is assigned.
        self.bos_token = None
        self.eos_token = None
        self.unk_token = None
        self.pad_token = None

        for name, token in kwargs.items():
            if token is None or name not in self.SPECIAL_TOKENS_ATTRIBUTES:
                continue
            if not isinstance(token, str):
                raise TypeError(f"Special token {name} has to be str but got: {type(token)}")
            setattr(self, name, token)

        self.model_max_length = kwargs.pop("model_max_length", None)
        self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False)

    def _special_token_id(self, attr: str) -> int:
        # Shared lookup for the *_token_id properties below.
        token = getattr(self, attr)
        if token is None:
            raise ValueError(f"Special token {attr} is not defined")
        return self.convert_tokens_to_ids(token)

    @property
    def bos_token_id(self) -> int:
        return self._special_token_id("bos_token")

    @property
    def eos_token_id(self) -> int:
        return self._special_token_id("eos_token")

    @property
    def unk_token_id(self) -> int:
        return self._special_token_id("unk_token")

    @property
    def pad_token_id(self) -> int:
        return self._special_token_id("pad_token")

    @property
    def special_tokens_map(self) -> Dict[str, str]:
        """
        `Dict[str, str]`: A dictionary mapping special token class attributes (`bos_token`, `unk_token`, etc.)
        to their values (`'<bos>'`, `'<unk>'`, etc.).
        """
        return {attr: getattr(self, attr) for attr in self.SPECIAL_TOKENS_ATTRIBUTES if getattr(self, attr)}

    @property
    def all_special_tokens(self) -> List[str]:
        """
        `List[str]`: A list of the unique special tokens (`'<bos>'`, `'<unk>'`, ..., etc.).
        """
        return [*self.special_tokens_map.values()]

    @property
    def all_special_ids(self) -> List[int]:
        """
        `List[int]`: List the ids of the special tokens(`'<bos>'`, `'<unk>'`, etc.) mapped to class attributes.
        """
        return self.convert_tokens_to_ids(self.all_special_tokens)

    @overload
    def tokenize(self, texts: str) -> List[str]: ...

    @overload
    def tokenize(self, texts: List[str]) -> List[List[str]]: ...

    def tokenize(self, texts: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
        raise NotImplementedError()

    def encode(self, texts: str) -> List[int]:
        raise NotImplementedError()

    def batch_encode(self, texts: List[str]) -> List[List[int]]:
        raise NotImplementedError()

    def decode(
        self, ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None
    ) -> str:
        raise NotImplementedError()

    def batch_decode(
        self, ids: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None
    ) -> List[str]:
        raise NotImplementedError()

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        raise NotImplementedError()

    @overload
    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...

    @overload
    def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ...

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        raise NotImplementedError()

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        raise NotImplementedError()

    @staticmethod
    def clean_up_tokenization(text: str) -> str:
        """
        Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms.

        Args:
            text (`str`): The text to clean up.

        Returns:
            `str`: The cleaned-up string.
        """
        # Applied in order; each pair mirrors one step of the original
        # replacement chain.
        replacements = (
            (" .", "."),
            (" ?", "?"),
            (" !", "!"),
            (" ,", ","),
            (" ' ", "'"),
            (" n't", "n't"),
            (" 'm", "'m"),
            (" 's", "'s"),
            (" 've", "'ve"),
            (" 're", "'re"),
        )
        for needle, substitute in replacements:
            text = text.replace(needle, substitute)
        return text