diffsynth-engine 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +28 -0
- diffsynth_engine/algorithm/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
- diffsynth_engine/algorithm/sampler/__init__.py +19 -0
- diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
- diffsynth_engine/conf/models/components/vae.json +254 -0
- diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
- diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
- diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
- diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
- diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
- diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
- diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
- diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
- diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
- diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
- diffsynth_engine/kernels/__init__.py +0 -0
- diffsynth_engine/models/__init__.py +7 -0
- diffsynth_engine/models/base.py +64 -0
- diffsynth_engine/models/basic/__init__.py +0 -0
- diffsynth_engine/models/basic/attention.py +217 -0
- diffsynth_engine/models/basic/lora.py +293 -0
- diffsynth_engine/models/basic/relative_position_emb.py +56 -0
- diffsynth_engine/models/basic/timestep.py +81 -0
- diffsynth_engine/models/basic/transformer_helper.py +88 -0
- diffsynth_engine/models/basic/unet_helper.py +244 -0
- diffsynth_engine/models/components/__init__.py +0 -0
- diffsynth_engine/models/components/clip.py +56 -0
- diffsynth_engine/models/components/t5.py +222 -0
- diffsynth_engine/models/components/vae.py +392 -0
- diffsynth_engine/models/flux/__init__.py +14 -0
- diffsynth_engine/models/flux/flux_dit.py +476 -0
- diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
- diffsynth_engine/models/flux/flux_vae.py +78 -0
- diffsynth_engine/models/sd/__init__.py +12 -0
- diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
- diffsynth_engine/models/sd/sd_unet.py +293 -0
- diffsynth_engine/models/sd/sd_vae.py +38 -0
- diffsynth_engine/models/sd3/__init__.py +14 -0
- diffsynth_engine/models/sd3/sd3_dit.py +302 -0
- diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
- diffsynth_engine/models/sd3/sd3_vae.py +43 -0
- diffsynth_engine/models/sdxl/__init__.py +13 -0
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
- diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
- diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
- diffsynth_engine/models/utils.py +54 -0
- diffsynth_engine/models/wan/__init__.py +0 -0
- diffsynth_engine/models/wan/wan_dit.py +497 -0
- diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
- diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
- diffsynth_engine/models/wan/wan_vae.py +771 -0
- diffsynth_engine/pipelines/__init__.py +18 -0
- diffsynth_engine/pipelines/base.py +253 -0
- diffsynth_engine/pipelines/flux_image.py +512 -0
- diffsynth_engine/pipelines/sd_image.py +352 -0
- diffsynth_engine/pipelines/sdxl_image.py +395 -0
- diffsynth_engine/pipelines/wan_video.py +524 -0
- diffsynth_engine/tokenizers/__init__.py +6 -0
- diffsynth_engine/tokenizers/base.py +157 -0
- diffsynth_engine/tokenizers/clip.py +288 -0
- diffsynth_engine/tokenizers/t5.py +194 -0
- diffsynth_engine/tokenizers/wan.py +74 -0
- diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine/utils/constants.py +34 -0
- diffsynth_engine/utils/download.py +135 -0
- diffsynth_engine/utils/env.py +7 -0
- diffsynth_engine/utils/flag.py +46 -0
- diffsynth_engine/utils/fp8_linear.py +64 -0
- diffsynth_engine/utils/gguf.py +415 -0
- diffsynth_engine/utils/loader.py +17 -0
- diffsynth_engine/utils/lock.py +56 -0
- diffsynth_engine/utils/logging.py +12 -0
- diffsynth_engine/utils/offload.py +44 -0
- diffsynth_engine/utils/parallel.py +390 -0
- diffsynth_engine/utils/prompt.py +9 -0
- diffsynth_engine/utils/video.py +40 -0
- diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
- diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
- diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
- diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
- diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from einops import rearrange
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from functools import partial
|
|
6
|
+
from typing import Callable, List, Tuple, Optional
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from PIL import Image
|
|
9
|
+
|
|
10
|
+
from diffsynth_engine.algorithm.noise_scheduler.flow_match import RecifitedFlowScheduler
|
|
11
|
+
from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
|
|
12
|
+
from diffsynth_engine.models.wan.wan_dit import WanDiT
|
|
13
|
+
from diffsynth_engine.models.wan.wan_text_encoder import WanTextEncoder
|
|
14
|
+
from diffsynth_engine.models.wan.wan_vae import WanVideoVAE
|
|
15
|
+
from diffsynth_engine.models.wan.wan_image_encoder import WanImageEncoder
|
|
16
|
+
from diffsynth_engine.models.basic.lora import LoRAContext
|
|
17
|
+
from diffsynth_engine.tokenizers import WanT5Tokenizer
|
|
18
|
+
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
|
|
19
|
+
from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
|
|
20
|
+
from diffsynth_engine.utils.download import fetch_model
|
|
21
|
+
from diffsynth_engine.utils.parallel import ParallelModel, shard_model
|
|
22
|
+
from diffsynth_engine.utils import logging
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
logger = logging.get_logger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class WanModelConfig:
    """Paths, dtypes and parallelism settings for a Wan video pipeline.

    All paths are optional; missing ones are fetched from the model hub in
    ``WanVideoPipeline.from_pretrained``.
    """

    # Checkpoint paths (safetensors); image_encoder_path enables i2v mode.
    model_path: Optional[str] = None
    vae_path: Optional[str] = None
    t5_path: Optional[str] = None
    image_encoder_path: Optional[str] = None

    # Per-component compute dtypes; the VAE stays in fp32 for stability.
    vae_dtype: torch.dtype = torch.float32
    dit_dtype: torch.dtype = torch.bfloat16
    t5_dtype: torch.dtype = torch.bfloat16
    image_encoder_dtype: torch.dtype = torch.bfloat16

    # Attention backend selector for the DiT ("auto" lets the model pick).
    dit_attn_impl: Optional[str] = "auto"
    # Wrap DiT blocks with fully-sharded data parallel.
    dit_fsdp: bool = False

    # Sequence-parallel (Ulysses/ring) and tensor-parallel degrees; mutually
    # exclusive combinations are validated in from_pretrained.
    sp_ulysses_degree: Optional[int] = None
    sp_ring_degree: Optional[int] = None
    tp_degree: Optional[int] = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class WanLoRAConverter(LoRAStateDictConverter):
    """Convert LoRA state dicts (diffsynth or civitai layout) into the
    per-module ``{"up", "down", "rank", "alpha"}`` format consumed by the DiT.
    """

    @staticmethod
    def _extract(state_dict, down_suffix, up_suffix, strip_prefix=""):
        """Shared implementation for both layouts.

        For every LoRA ``down`` weight, pair it with the matching ``up``
        weight and optional per-module ``alpha`` (defaulting to the rank
        when absent), keyed by the module path.
        """
        dit_dict = {}
        for key, param in state_dict.items():
            if down_suffix not in key:
                continue
            lora_args = {
                "up": state_dict[key.replace(down_suffix, up_suffix)],
                "down": param,
            }
            # up weight has shape (out_features, rank)
            lora_args["rank"] = lora_args["up"].shape[1]
            alpha_key = key.replace(down_suffix, ".alpha")
            lora_args["alpha"] = state_dict.get(alpha_key, lora_args["rank"])
            module_key = key.replace(down_suffix, "")
            if strip_prefix:
                module_key = module_key.replace(strip_prefix, "")
            dit_dict[module_key] = lora_args
        return dit_dict

    def _from_diffsynth(self, state_dict):
        # diffsynth layout: "<module>.lora_A.default.weight" / ".lora_B.default.weight"
        return {"dit": self._extract(state_dict, ".lora_A.default.weight", ".lora_B.default.weight")}

    def _from_civitai(self, state_dict):
        # civitai layout: "diffusion_model.<module>.lora_A.weight" / ".lora_B.weight"
        return {
            "dit": self._extract(
                state_dict, ".lora_A.weight", ".lora_B.weight", strip_prefix="diffusion_model."
            )
        }

    def convert(self, state_dict):
        """Detect the source layout by a probe key and dispatch accordingly."""
        if "diffusion_model.blocks.0.cross_attn.k.lora_A.weight" in state_dict:
            state_dict = self._from_civitai(state_dict)
            logger.info("use civitai format state dict")
        else:
            state_dict = self._from_diffsynth(state_dict)
            logger.info("use diffsynth format state dict")
        return state_dict
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class WanVideoPipeline(BasePipeline):
    """Text/image-to-video generation pipeline for Wan DiT models.

    Bundles a T5 tokenizer/encoder, the Wan DiT denoiser, the video VAE and
    an optional CLIP image encoder (i2v), driven by a flow-match scheduler
    and Euler sampler.
    """

    # Class-level converter used by BasePipeline's LoRA loading machinery.
    lora_converter = WanLoRAConverter()
|
|
97
|
+
|
|
98
|
+
    def __init__(
        self,
        config: WanModelConfig,
        tokenizer: WanT5Tokenizer,
        text_encoder: WanTextEncoder,
        dit: WanDiT,
        vae: WanVideoVAE,
        image_encoder: WanImageEncoder,
        batch_cfg: bool = False,
        device="cuda",
        dtype=torch.bfloat16,
    ):
        """Wire together the pipeline components.

        Args:
            config: paths, dtypes and parallelism settings.
            image_encoder: CLIP image encoder; None for pure t2v models.
            batch_cfg: run positive/negative CFG branches in a single batch.
        """
        super().__init__(device=device, dtype=dtype)
        # Flow-match schedule with shift=5.0 and clamped sigma range.
        self.noise_scheduler = RecifitedFlowScheduler(shift=5.0, sigma_min=0.001, sigma_max=0.999)
        self.sampler = FlowMatchEulerSampler()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.dit = dit
        self.vae = vae
        self.image_encoder = image_encoder
        self.batch_cfg = batch_cfg
        self.config = config
        # Components managed by load_models_to_device / offloading.
        self.model_names = ["text_encoder", "dit", "vae"]
|
|
121
|
+
|
|
122
|
+
    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
        """Apply LoRA weights to the pipeline.

        Args:
            lora_list: list of (lora_path, scale) pairs.
            fused: fuse the LoRA deltas into the base weights.
            save_original_weight: keep original weights so LoRAs can be unloaded.

        Raises:
            AssertionError: if tensor parallel is enabled, or fused loading is
                requested while FSDP is enabled.
        """
        # NOTE(review): assert-based validation is stripped under `python -O`;
        # kept as-is to preserve the exception type callers may rely on.
        assert self.config.tp_degree is None, (
            "load LoRA is not allowed when tensor parallel is enabled; "
            "set tp_degree=None during pipeline initialization"
        )
        assert not (self.config.dit_fsdp and fused), (
            "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
            "either load LoRA with fused=False or set dit_fsdp=False during pipeline initialization"
        )
        super().load_loras(lora_list, fused, save_original_weight)
|
|
132
|
+
|
|
133
|
+
    def unload_loras(self):
        """Remove any LoRA weights previously applied to the DiT and text encoder."""
        self.dit.unload_loras()
        self.text_encoder.unload_loras()
|
|
136
|
+
|
|
137
|
+
    def denoising_model(self):
        """Return the denoising network (the DiT)."""
        return self.dit
|
|
139
|
+
|
|
140
|
+
def encode_prompt(self, prompt):
|
|
141
|
+
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
|
|
142
|
+
ids = ids.to(self.device)
|
|
143
|
+
mask = mask.to(self.device)
|
|
144
|
+
prompt_emb = self.text_encoder(ids, mask)
|
|
145
|
+
prompt_emb = prompt_emb.masked_fill(mask.unsqueeze(-1).expand_as(prompt_emb) == 0, 0)
|
|
146
|
+
return prompt_emb
|
|
147
|
+
|
|
148
|
+
    def encode_image(self, image, num_frames, height, width):
        """Encode a conditioning image for image-to-video generation.

        Returns:
            (clip_context, y): CLIP features of the resized image, and the
            VAE latents of [image, zero frames] concatenated with a temporal
            mask, with a leading batch dimension added.
        """
        image = self.preprocess_image(image.resize((width, height), Image.Resampling.LANCZOS)).to(
            self.device, self.config.image_encoder_dtype
        )
        clip_context = self.image_encoder.encode_image([image])
        # Temporal mask: 1 for the known first frame, 0 for frames to generate.
        msk = torch.ones(
            1, num_frames, height // 8, width // 8, device=self.device, dtype=self.config.image_encoder_dtype
        )
        msk[:, 1:] = 0
        # The first frame's mask is repeated 4x so the mask length is a
        # multiple of 4, matching the VAE's 4-frame temporal compression
        # (see the (num_frames - 1) // 4 + 1 latent frame count in __call__).
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
        # -> (4, latent_frames, height // 8, width // 8)
        msk = msk.transpose(1, 2)[0]
        # Encode the image followed by zero padding for the remaining frames.
        y = self.vae.encode(
            [
                torch.concat(
                    [
                        image.transpose(0, 1),
                        torch.zeros(3, num_frames - 1, height, width).to(image.device, self.config.vae_dtype),
                    ],
                    dim=1,
                )
            ],
            device=self.device,
        )[0]
        # Stack mask channels on top of the latent channels.
        y = torch.concat([msk, y]).to(dtype=self.dtype)
        return clip_context, torch.unsqueeze(y, 0)
|
|
174
|
+
|
|
175
|
+
def tensor2video(self, frames):
|
|
176
|
+
frames = rearrange(frames, "C T H W -> T H W C")
|
|
177
|
+
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
|
178
|
+
frames = [Image.fromarray(frame) for frame in frames]
|
|
179
|
+
return frames
|
|
180
|
+
|
|
181
|
+
def encode_video(self, videos: torch.Tensor, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
182
|
+
videos = videos.to(dtype=self.config.vae_dtype, device=self.device)
|
|
183
|
+
latents = self.vae.encode(videos, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
184
|
+
latents = latents.to(dtype=self.config.dit_dtype, device=self.device)
|
|
185
|
+
return latents
|
|
186
|
+
|
|
187
|
+
def decode_video(
|
|
188
|
+
self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16), progress_callback=None
|
|
189
|
+
) -> List[torch.Tensor]:
|
|
190
|
+
latents = latents.to(dtype=self.config.vae_dtype, device=self.device)
|
|
191
|
+
videos = self.vae.decode(
|
|
192
|
+
latents,
|
|
193
|
+
device=self.device,
|
|
194
|
+
tiled=tiled,
|
|
195
|
+
tile_size=tile_size,
|
|
196
|
+
tile_stride=tile_stride,
|
|
197
|
+
progress_callback=progress_callback,
|
|
198
|
+
)
|
|
199
|
+
videos = [video.to(dtype=self.config.dit_dtype, device=self.device) for video in videos]
|
|
200
|
+
return videos
|
|
201
|
+
|
|
202
|
+
    def predict_noise_with_cfg(
        self,
        latents: torch.Tensor,
        image_clip_feature: torch.Tensor,
        image_y: torch.Tensor,
        timestep: torch.Tensor,
        positive_prompt_emb: torch.Tensor,
        negative_prompt_emb: torch.Tensor,
        cfg_scale: float,
        batch_cfg: bool,
    ):
        """Predict noise with classifier-free guidance.

        With cfg_scale <= 1.0, CFG is disabled and a single conditional pass
        is run. Otherwise the positive/negative branches are run either as
        two sequential forward passes (lower peak memory) or as one batched
        pass (batch_cfg=True, required for cfg-parallel execution), and the
        results are combined as neg + cfg_scale * (pos - neg).
        """
        if cfg_scale <= 1.0:
            return self.predict_noise(
                latents=latents,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                timestep=timestep,
                context=positive_prompt_emb,
            )
        if not batch_cfg:
            # cfg by predict noise one by one
            positive_noise_pred = self.predict_noise(
                latents=latents,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                timestep=timestep,
                context=positive_prompt_emb,
            )
            negative_noise_pred = self.predict_noise(
                latents=latents,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                timestep=timestep,
                context=negative_prompt_emb,
            )
            noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
            return noise_pred
        else:
            # cfg by predict noise in one batch
            # Positive branch first: the DiT's batched output is unpacked as
            # (positive, negative) below, so ordering here matters.
            prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0)
            latents = torch.cat([latents, latents], dim=0)
            timestep = torch.cat([timestep, timestep], dim=0)
            if image_y is not None:
                image_y = torch.cat([image_y, image_y], dim=0)
            if image_clip_feature is not None:
                image_clip_feature = torch.cat([image_clip_feature, image_clip_feature], dim=0)
            positive_noise_pred, negative_noise_pred = self.predict_noise(
                latents=latents,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                timestep=timestep,
                context=prompt_emb,
            )
            noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
            return noise_pred
|
|
257
|
+
|
|
258
|
+
def predict_noise(self, latents, image_clip_feature, image_y, timestep, context):
|
|
259
|
+
latents = latents.to(dtype=self.config.dit_dtype, device=self.device)
|
|
260
|
+
|
|
261
|
+
noise_pred = self.dit(
|
|
262
|
+
x=latents,
|
|
263
|
+
timestep=timestep,
|
|
264
|
+
context=context,
|
|
265
|
+
clip_feature=image_clip_feature,
|
|
266
|
+
y=image_y,
|
|
267
|
+
)
|
|
268
|
+
return noise_pred
|
|
269
|
+
|
|
270
|
+
    def prepare_latents(
        self,
        latents,
        input_video,
        denoising_strength,
        num_inference_steps,
        tiled=True,
        tile_size=(34, 34),
        tile_stride=(18, 16),
    ):
        """Prepare starting latents and the sigma/timestep schedule.

        `latents` is pure noise. When `input_video` is given (video-to-video),
        the video is VAE-encoded, noised to the sigma corresponding to
        `denoising_strength`, and the schedule is truncated to the remaining
        steps; otherwise the noise is used as-is over the full schedule.

        Returns:
            (init_latents, latents, sigmas, timesteps)
        """
        if input_video is not None:
            total_steps = num_inference_steps
            sigmas, timesteps = self.noise_scheduler.schedule(total_steps)
            # Skip early steps proportionally to (1 - denoising_strength),
            # always keeping at least one step.
            t_start = max(total_steps - int(num_inference_steps * denoising_strength), 1)
            sigma_start, sigmas = sigmas[t_start - 1], sigmas[t_start - 1 :]
            timesteps = timesteps[t_start - 1 :]

            # The incoming "latents" are the noise to blend with the video.
            noise = latents
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2)
            latents = self.encode_video(input_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(
                dtype=latents.dtype, device=latents.device
            )
            init_latents = latents.clone()
            latents = self.sampler.add_noise(latents, noise, sigma_start)
        else:
            sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps)
            init_latents = latents.clone()

        return init_latents, latents, sigmas, timesteps
|
|
300
|
+
|
|
301
|
+
    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        input_image=None,
        input_video=None,
        denoising_strength=1.0,
        seed=None,
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        tiled=True,
        tile_size=(34, 34),
        tile_stride=(18, 16),
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """Generate a video and return it as a list of PIL frames.

        Args:
            prompt / negative_prompt: text conditioning for CFG.
            input_image: optional first-frame conditioning (i2v; requires an
                image encoder).
            input_video: optional source video for video-to-video.
            denoising_strength: fraction of the schedule applied to input_video.
            num_frames: must be of the form 4k + 1 (VAE temporal compression).
            progress_callback: called as (current, total, status) each step.
        """
        assert height % 16 == 0 and width % 16 == 0, "height and width must be divisible by 16"
        assert (num_frames - 1) % 4 == 0, "num_frames must be 4X+1"

        # Initialize noise
        # Latent shape: 16 channels, (num_frames - 1) // 4 + 1 latent frames,
        # 8x spatial downsampling. Noise is generated on CPU for seed
        # reproducibility, then moved to the device.
        noise = self.generate_noise(
            (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), seed=seed, device="cpu", dtype=torch.float32
        ).to(self.device)
        init_latents, latents, sigmas, timesteps = self.prepare_latents(
            noise,
            input_video,
            denoising_strength,
            num_inference_steps,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas)
        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt)
        # The negative prompt is only needed when CFG is active.
        prompt_emb_nega = None if cfg_scale <= 1.0 else self.encode_prompt(negative_prompt)

        # Encode image
        if input_image is not None and self.image_encoder is not None:
            self.load_models_to_device(["image_encoder", "vae"])
            image_clip_feature, image_y = self.encode_image(input_image, num_frames, height, width)
        else:
            image_clip_feature, image_y = None, None

        # Denoise
        self.load_models_to_device(["dit"])
        for i, timestep in enumerate(tqdm(timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.config.dit_dtype, device=self.device)
            # Classifier-free guidance
            noise_pred = self.predict_noise_with_cfg(
                latents=latents,
                timestep=timestep,
                positive_prompt_emb=prompt_emb_posi,
                negative_prompt_emb=prompt_emb_nega,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                cfg_scale=cfg_scale,
                batch_cfg=self.batch_cfg,
            )
            # Scheduler
            latents = self.sampler.step(latents, noise_pred, i)
            if progress_callback is not None:
                progress_callback(i + 1, len(timesteps), "DENOISING")

        # Decode
        self.load_models_to_device(["vae"])
        frames = self.decode_video(
            latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, progress_callback=progress_callback
        )
        frames = self.tensor2video(frames[0])
        return frames
|
|
376
|
+
|
|
377
|
+
    @classmethod
    def from_pretrained(
        cls,
        model_path_or_config: str | WanModelConfig,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        batch_cfg: bool = False,
        offload_mode: str | None = None,
        parallelism: int = 1,
        use_cfg_parallel: bool = False,
    ) -> "WanVideoPipeline":
        """Build a WanVideoPipeline from checkpoints.

        Args:
            model_path_or_config: DiT checkpoint path, or a full WanModelConfig.
            offload_mode: None, "cpu_offload" or "sequential_cpu_offload".
            parallelism: total number of devices (1, 2, 4 or 8).
            use_cfg_parallel: dedicate a 2-way split to CFG branches
                (forces batch_cfg=True).
        """
        cls.validate_offload_mode(offload_mode)

        if isinstance(model_path_or_config, str):
            model_config = WanModelConfig(model_path=model_path_or_config)
        else:
            model_config = model_path_or_config

        # Fall back to hub downloads for any unspecified checkpoint path.
        if model_config.model_path is None:
            model_config.model_path = fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors")
        if model_config.t5_path is None:
            model_config.t5_path = fetch_model("muse/wan2.1-umt5", path="umt5.safetensors")
        if model_config.vae_path is None:
            model_config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")

        logger.info(f"loading state dict from {model_config.model_path} ...")
        dit_state_dict = cls.load_model_checkpoint(model_config.model_path, device="cpu", dtype=model_config.dit_dtype)

        logger.info(f"loading state dict from {model_config.t5_path} ...")
        t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=model_config.t5_dtype)

        logger.info(f"loading state dict from {model_config.vae_path} ...")
        vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=model_config.vae_dtype)

        # With offloading enabled, models are built on CPU and moved lazily.
        init_device = "cpu" if offload_mode else device
        tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
        text_encoder = WanTextEncoder.from_state_dict(t5_state_dict, device=init_device, dtype=model_config.t5_dtype)

        vae = WanVideoVAE.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)

        # The image encoder is only built when a path is provided (i2v mode).
        image_encoder = None
        if model_config.image_encoder_path is not None:
            logger.info(f"loading state dict from {model_config.image_encoder_path} ...")
            image_encoder_state_dict = cls.load_model_checkpoint(
                model_config.image_encoder_path,
                device="cpu",
                dtype=model_config.image_encoder_dtype,
            )
            image_encoder = WanImageEncoder.from_state_dict(
                image_encoder_state_dict,
                device=init_device,
                dtype=model_config.image_encoder_dtype,
            )

        # determine wan video model type by dit params
        # The 14B variant has 40 transformer blocks; the probe key below only
        # exists in that checkpoint.
        model_type = None
        if "blocks.39.self_attn.norm_q.weight" in dit_state_dict:
            if image_encoder is not None:
                model_type = "14b-i2v"
            else:
                model_type = "14b-t2v"
        else:
            model_type = "1.3b-t2v"

        if parallelism > 1:
            assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
            # cfg-parallel requires both branches in one batch so they can be
            # split across the cfg process group.
            batch_cfg = True if use_cfg_parallel else batch_cfg
            cfg_degree = 2 if use_cfg_parallel else 1
            sp_ulysses_degree = model_config.sp_ulysses_degree
            sp_ring_degree = model_config.sp_ring_degree
            tp_degree = model_config.tp_degree

            # Resolve the parallel layout: tensor parallel excludes sequence
            # parallel and FSDP; otherwise default to pure Ulysses.
            if tp_degree is not None:
                assert sp_ulysses_degree is None and sp_ring_degree is None, (
                    "not allowed to enable sequence parallel and tensor parallel together; "
                    "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
                )
                assert model_config.dit_fsdp is False, (
                    "not allowed to enable fully sharded data parallel and tensor parallel together; "
                    "either set dit_fsdp=False or set tp_degree=None during pipeline initialization"
                )
                assert parallelism == cfg_degree * tp_degree, (
                    f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
                )
                sp_ulysses_degree = 1
                sp_ring_degree = 1
            elif sp_ulysses_degree is None and sp_ring_degree is None:
                # use ulysses if not specified
                sp_ulysses_degree = parallelism // cfg_degree
                sp_ring_degree = 1
                tp_degree = 1
            elif sp_ulysses_degree is not None and sp_ring_degree is not None:
                assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
                    f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
                    f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
                )
                tp_degree = 1
            else:
                raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")

        with LoRAContext():
            dit = WanDiT.from_state_dict(
                dit_state_dict,
                model_type=model_type,
                device="cpu",
                dtype=model_config.dit_dtype,
                attn_impl=model_config.dit_attn_impl,
                use_usp=(sp_ulysses_degree * sp_ring_degree > 1),
            )
            dit = ParallelModel(
                dit,
                cfg_degree=cfg_degree,
                sp_ulysses_degree=sp_ulysses_degree,
                sp_ring_degree=sp_ring_degree,
                tp_degree=tp_degree,
                shard_fn=partial(shard_model, wrap_module_names=["blocks"]) if model_config.dit_fsdp else None,
                device="cuda",
            )
        else:
            with LoRAContext():
                dit = WanDiT.from_state_dict(
                    dit_state_dict,
                    model_type=model_type,
                    device=init_device,
                    dtype=model_config.dit_dtype,
                    attn_impl=model_config.dit_attn_impl,
                )

        pipe = cls(
            config=model_config,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            dit=dit,
            vae=vae,
            image_encoder=image_encoder,
            batch_cfg=batch_cfg,
            device=device,
            dtype=dtype,
        )
        pipe.eval()
        if offload_mode == "cpu_offload":
            pipe.enable_cpu_offload()
        elif offload_mode == "sequential_cpu_offload":
            pipe.enable_sequential_cpu_offload()
        return pipe
|
|
522
|
+
|
|
523
|
+
def __del__(self):
|
|
524
|
+
del self.dit
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# Modified from transformers.tokenization_utils_base
|
|
2
|
+
from typing import Dict, List, Union, overload
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseTokenizer:
    """Minimal tokenizer base class (modified from
    transformers.tokenization_utils_base) that manages special-token
    attributes and their id lookups.
    """

    # Keyword arguments recognized as special tokens in __init__.
    SPECIAL_TOKENS_ATTRIBUTES = [
        "bos_token",
        "eos_token",
        "unk_token",
        "pad_token",
    ]
|
|
15
|
+
|
|
16
|
+
def __init__(self, **kwargs):
|
|
17
|
+
self.bos_token = None
|
|
18
|
+
self.eos_token = None
|
|
19
|
+
self.unk_token = None
|
|
20
|
+
self.pad_token = None
|
|
21
|
+
|
|
22
|
+
for key, value in kwargs.items():
|
|
23
|
+
if value is None:
|
|
24
|
+
continue
|
|
25
|
+
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
|
|
26
|
+
if isinstance(value, str):
|
|
27
|
+
setattr(self, key, value)
|
|
28
|
+
else:
|
|
29
|
+
raise TypeError(f"Special token {key} has to be str but got: {type(value)}")
|
|
30
|
+
|
|
31
|
+
self.model_max_length = kwargs.pop("model_max_length", None)
|
|
32
|
+
|
|
33
|
+
self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False)
|
|
34
|
+
|
|
35
|
+
@property
|
|
36
|
+
def bos_token_id(self) -> int:
|
|
37
|
+
if self.bos_token is None:
|
|
38
|
+
raise ValueError("Special token bos_token is not defined")
|
|
39
|
+
return self.convert_tokens_to_ids(self.bos_token)
|
|
40
|
+
|
|
41
|
+
@property
|
|
42
|
+
def eos_token_id(self) -> int:
|
|
43
|
+
if self.eos_token is None:
|
|
44
|
+
raise ValueError("Special token eos_token is not defined")
|
|
45
|
+
return self.convert_tokens_to_ids(self.eos_token)
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def unk_token_id(self) -> int:
|
|
49
|
+
if self.unk_token is None:
|
|
50
|
+
raise ValueError("Special token unk_token is not defined")
|
|
51
|
+
return self.convert_tokens_to_ids(self.unk_token)
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def pad_token_id(self) -> int:
|
|
55
|
+
if self.pad_token is None:
|
|
56
|
+
raise ValueError("Special token pad_token is not defined")
|
|
57
|
+
return self.convert_tokens_to_ids(self.pad_token)
|
|
58
|
+
|
|
59
|
+
@property
|
|
60
|
+
def special_tokens_map(self) -> Dict[str, str]:
|
|
61
|
+
"""
|
|
62
|
+
`Dict[str, str]`: A dictionary mapping special token class attributes (`bos_token`, `unk_token`, etc.)
|
|
63
|
+
to their values (`'<bos>'`, `'<unk>'`, etc.).
|
|
64
|
+
"""
|
|
65
|
+
set_attr = {}
|
|
66
|
+
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
|
|
67
|
+
attr_value = getattr(self, attr)
|
|
68
|
+
if attr_value:
|
|
69
|
+
set_attr[attr] = attr_value
|
|
70
|
+
return set_attr
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def all_special_tokens(self) -> List[str]:
|
|
74
|
+
"""
|
|
75
|
+
`List[str]`: A list of the unique special tokens (`'<bos>'`, `'<unk>'`, ..., etc.).
|
|
76
|
+
"""
|
|
77
|
+
return list(self.special_tokens_map.values())
|
|
78
|
+
|
|
79
|
+
@property
|
|
80
|
+
def all_special_ids(self) -> List[int]:
|
|
81
|
+
"""
|
|
82
|
+
`List[int]`: List the ids of the special tokens(`'<bos>'`, `'<unk>'`, etc.) mapped to class attributes.
|
|
83
|
+
"""
|
|
84
|
+
return self.convert_tokens_to_ids(self.all_special_tokens)
|
|
85
|
+
|
|
86
|
+
@overload
|
|
87
|
+
def tokenize(self, texts: str) -> List[str]: ...
|
|
88
|
+
|
|
89
|
+
@overload
|
|
90
|
+
def tokenize(self, texts: List[str]) -> List[List[str]]: ...
|
|
91
|
+
|
|
92
|
+
def tokenize(self, texts: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
|
|
93
|
+
raise NotImplementedError()
|
|
94
|
+
|
|
95
|
+
def encode(self, texts: str) -> List[int]:
|
|
96
|
+
raise NotImplementedError()
|
|
97
|
+
|
|
98
|
+
def batch_encode(self, texts: List[str]) -> List[List[int]]:
|
|
99
|
+
raise NotImplementedError()
|
|
100
|
+
|
|
101
|
+
def decode(
|
|
102
|
+
self, ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None
|
|
103
|
+
) -> str:
|
|
104
|
+
raise NotImplementedError()
|
|
105
|
+
|
|
106
|
+
def batch_decode(
|
|
107
|
+
self, ids: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None
|
|
108
|
+
) -> List[str]:
|
|
109
|
+
raise NotImplementedError()
|
|
110
|
+
|
|
111
|
+
@overload
|
|
112
|
+
def convert_tokens_to_ids(self, tokens: str) -> int: ...
|
|
113
|
+
|
|
114
|
+
@overload
|
|
115
|
+
def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...
|
|
116
|
+
|
|
117
|
+
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
|
118
|
+
raise NotImplementedError()
|
|
119
|
+
|
|
120
|
+
@overload
|
|
121
|
+
def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
|
|
122
|
+
|
|
123
|
+
@overload
|
|
124
|
+
def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ...
|
|
125
|
+
|
|
126
|
+
def convert_ids_to_tokens(
|
|
127
|
+
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
|
|
128
|
+
) -> Union[str, List[str]]:
|
|
129
|
+
raise NotImplementedError()
|
|
130
|
+
|
|
131
|
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
|
132
|
+
raise NotImplementedError()
|
|
133
|
+
|
|
134
|
+
@staticmethod
|
|
135
|
+
def clean_up_tokenization(text: str) -> str:
|
|
136
|
+
"""
|
|
137
|
+
Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
text (`str`): The text to clean up.
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
`str`: The cleaned-up string.
|
|
144
|
+
"""
|
|
145
|
+
text = (
|
|
146
|
+
text.replace(" .", ".")
|
|
147
|
+
.replace(" ?", "?")
|
|
148
|
+
.replace(" !", "!")
|
|
149
|
+
.replace(" ,", ",")
|
|
150
|
+
.replace(" ' ", "'")
|
|
151
|
+
.replace(" n't", "n't")
|
|
152
|
+
.replace(" 'm", "'m")
|
|
153
|
+
.replace(" 's", "'s")
|
|
154
|
+
.replace(" 've", "'ve")
|
|
155
|
+
.replace(" 're", "'re")
|
|
156
|
+
)
|
|
157
|
+
return text
|