diffsynth-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,297 @@
+from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
+from ..schedulers import ContinuousODEScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+from einops import rearrange, repeat
+
+
+
+class SVDVideoPipeline(BasePipeline):
+
+    def __init__(self, device="cuda", torch_dtype=torch.float16):
+        super().__init__(device=device, torch_dtype=torch_dtype)
+        self.scheduler = ContinuousODEScheduler()
+        # models
+        self.image_encoder: SVDImageEncoder = None
+        self.unet: SVDUNet = None
+        self.vae_encoder: SVDVAEEncoder = None
+        self.vae_decoder: SVDVAEDecoder = None
+
+
+    def fetch_models(self, model_manager: ModelManager):
+        self.image_encoder = model_manager.fetch_model("svd_image_encoder")
+        self.unet = model_manager.fetch_model("svd_unet")
+        self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
+        self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
+
+
+    @staticmethod
+    def from_model_manager(model_manager: ModelManager, **kwargs):
+        pipe = SVDVideoPipeline(
+            device=model_manager.device,
+            torch_dtype=model_manager.torch_dtype
+        )
+        pipe.fetch_models(model_manager)
+        return pipe
+
+
+    def encode_image_with_clip(self, image):
+        image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+        image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
+        image = (image + 1.0) / 2.0
+        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+        image = (image - mean) / std
+        image_emb = self.image_encoder(image)
+        return image_emb
+
+
+    def encode_image_with_vae(self, image, noise_aug_strength):
+        image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+        noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device)
+        image = image + noise_aug_strength * noise
+        image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
+        return image_emb
+
+
+    def encode_video_with_vae(self, video):
+        video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0)
+        video = rearrange(video, "T C H W -> 1 C T H W")
+        video = video.to(device=self.device, dtype=self.torch_dtype)
+        latents = self.vae_encoder.encode_video(video)
+        latents = rearrange(latents[0], "C T H W -> T C H W")
+        return latents
+
+
+    def tensor2video(self, frames):
+        frames = rearrange(frames, "C T H W -> T H W C")
+        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+        frames = [Image.fromarray(frame) for frame in frames]
+        return frames
+
+
+    def calculate_noise_pred(
+        self,
+        latents,
+        timestep,
+        add_time_id,
+        cfg_scales,
+        image_emb_vae_posi, image_emb_clip_posi,
+        image_emb_vae_nega, image_emb_clip_nega
+    ):
+        # Positive side
+        noise_pred_posi = self.unet(
+            torch.cat([latents, image_emb_vae_posi], dim=1),
+            timestep, image_emb_clip_posi, add_time_id
+        )
+        # Negative side
+        noise_pred_nega = self.unet(
+            torch.cat([latents, image_emb_vae_nega], dim=1),
+            timestep, image_emb_clip_nega, add_time_id
+        )
+
+        # Classifier-free guidance
+        noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
+
+        return noise_pred
+
+
+    def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
+        if post_normalize:
+            mean, std = latents.mean(), latents.std()
+            latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
+        latents = latents * contrast_enhance_scale
+        return latents
+
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        input_image=None,
+        input_video=None,
+        mask_frames=[],
+        mask_frame_ids=[],
+        min_cfg_scale=1.0,
+        max_cfg_scale=3.0,
+        denoising_strength=1.0,
+        num_frames=25,
+        height=576,
+        width=1024,
+        fps=7,
+        motion_bucket_id=127,
+        noise_aug_strength=0.02,
+        num_inference_steps=20,
+        post_normalize=True,
+        contrast_enhance_scale=1.2,
+        progress_bar_cmd=tqdm,
+        progress_bar_st=None,
+    ):
+        # Prepare scheduler
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
+
+        # Prepare latent tensors
+        noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
+        if denoising_strength == 1.0:
+            latents = noise.clone()
+        else:
+            latents = self.encode_video_with_vae(input_video)
+            latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
+
+        # Prepare mask frames
+        if len(mask_frames) > 0:
+            mask_latents = self.encode_video_with_vae(mask_frames)
+
+        # Encode image
+        image_emb_clip_posi = self.encode_image_with_clip(input_image)
+        image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
+        image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames)
+        image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
+
+        # Prepare classifier-free guidance
+        cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames)
+        cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+
+        # Prepare positional id
+        add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device)
+
+        # Denoise
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+
+            # Mask frames
+            for frame_id, mask_frame_id in enumerate(mask_frame_ids):
+                latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
+
+            # Fetch model output
+            noise_pred = self.calculate_noise_pred(
+                latents, timestep, add_time_id, cfg_scales,
+                image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega
+            )
+
+            # Forward Euler
+            latents = self.scheduler.step(noise_pred, timestep, latents)
+
+            # Update progress bar
+            if progress_bar_st is not None:
+                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode image
+        latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
+        video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
+        video = self.tensor2video(video)
+
+        return video
+
+
+
+class SVDCLIPImageProcessor:
+    def __init__(self):
+        pass
+
+    def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True):
+        h, w = input.shape[-2:]
+        factors = (h / size[0], w / size[1])
+
+        # First, we have to determine sigma
+        # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+        sigmas = (
+            max((factors[0] - 1.0) / 2.0, 0.001),
+            max((factors[1] - 1.0) / 2.0, 0.001),
+        )
+
+        # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+        # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+        # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+        ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+        # Make sure it is odd
+        if (ks[0] % 2) == 0:
+            ks = ks[0] + 1, ks[1]
+
+        if (ks[1] % 2) == 0:
+            ks = ks[0], ks[1] + 1
+
+        input = self._gaussian_blur2d(input, ks, sigmas)
+
+        output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+        return output
+
+
+    def _compute_padding(self, kernel_size):
+        """Compute padding tuple."""
+        # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
+        # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+        if len(kernel_size) < 2:
+            raise AssertionError(kernel_size)
+        computed = [k - 1 for k in kernel_size]
+
+        # for even kernels we need to do asymmetric padding :(
+        out_padding = 2 * len(kernel_size) * [0]
+
+        for i in range(len(kernel_size)):
+            computed_tmp = computed[-(i + 1)]
+
+            pad_front = computed_tmp // 2
+            pad_rear = computed_tmp - pad_front
+
+            out_padding[2 * i + 0] = pad_front
+            out_padding[2 * i + 1] = pad_rear
+
+        return out_padding
+
+
+    def _filter2d(self, input, kernel):
+        # prepare kernel
+        b, c, h, w = input.shape
+        tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+        tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+        height, width = tmp_kernel.shape[-2:]
+
+        padding_shape: list[int] = self._compute_padding([height, width])
+        input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+        # kernel and input tensor reshape to align element-wise or batch-wise params
+        tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+        input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+        # convolve the tensor with the kernel.
+        output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+        out = output.view(b, c, h, w)
+        return out
+
+
+    def _gaussian(self, window_size: int, sigma):
+        if isinstance(sigma, float):
+            sigma = torch.tensor([[sigma]])
+
+        batch_size = sigma.shape[0]
+
+        x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+        if window_size % 2 == 0:
+            x = x + 0.5
+
+        gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+        return gauss / gauss.sum(-1, keepdim=True)
+
+
+    def _gaussian_blur2d(self, input, kernel_size, sigma):
+        if isinstance(sigma, tuple):
+            sigma = torch.tensor([sigma], dtype=input.dtype)
+        else:
+            sigma = sigma.to(dtype=input.dtype)
+
+        ky, kx = int(kernel_size[0]), int(kernel_size[1])
+        bs = sigma.shape[0]
+        kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1))
+        kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1))
+        out_x = self._filter2d(input, kernel_x[..., None, :])
+        out = self._filter2d(out_x, kernel_y[..., None])
+
+        return out
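For orientation, a minimal usage sketch of the SVDVideoPipeline added above. It is not part of the package diff: the top-level re-exports, the ModelManager constructor arguments, the load_models call, and the checkpoint path are assumptions about the surrounding API; only the from_model_manager and __call__ signatures come from the code shown here.

import torch
from PIL import Image
from diffsynth import ModelManager, SVDVideoPipeline  # assumption: re-exported by diffsynth/__init__.py

# Assumption: ModelManager accepts a dtype/device and a list of local checkpoint paths.
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(["models/stable_video_diffusion/svd_xt.safetensors"])  # hypothetical path

pipe = SVDVideoPipeline.from_model_manager(model_manager)
frames = pipe(
    input_image=Image.open("input.png").resize((1024, 576)),
    num_frames=25, height=576, width=1024,
    fps=7, motion_bucket_id=127, noise_aug_strength=0.02,
    num_inference_steps=20,
)
# tensor2video() returns a list of PIL images, so they can be saved directly.
for i, frame in enumerate(frames):
    frame.save(f"frame_{i:03d}.png")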
@@ -0,0 +1,142 @@
+from PIL import Image
+import cupy as cp
+import numpy as np
+from tqdm import tqdm
+from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
+from ..extensions.FastBlend.runners.fast import TableManager
+from .base import VideoProcessor
+
+
+class FastBlendSmoother(VideoProcessor):
+    def __init__(
+        self,
+        inference_mode="fast", batch_size=8, window_size=60,
+        minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0
+    ):
+        self.inference_mode = inference_mode
+        self.batch_size = batch_size
+        self.window_size = window_size
+        self.ebsynth_config = {
+            "minimum_patch_size": minimum_patch_size,
+            "threads_per_block": threads_per_block,
+            "num_iter": num_iter,
+            "gpu_id": gpu_id,
+            "guide_weight": guide_weight,
+            "initialize": initialize,
+            "tracking_window_size": tracking_window_size
+        }
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        # TODO: fetch GPU ID from model_manager
+        return FastBlendSmoother(**kwargs)
+
+    def inference_fast(self, frames_guide, frames_style):
+        table_manager = TableManager()
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **self.ebsynth_config
+        )
+        # left part
+        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
+        table_l = table_manager.remapping_table_to_blending_table(table_l)
+        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
+        # right part
+        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
+        table_r = table_manager.remapping_table_to_blending_table(table_r)
+        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
+        # merge
+        frames = []
+        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+            weight_m = -1
+            weight = weight_l + weight_m + weight_r
+            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+            frames.append(frame)
+        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
+        frames = [Image.fromarray(frame) for frame in frames]
+        return frames
+
+    def inference_balanced(self, frames_guide, frames_style):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **self.ebsynth_config
+        )
+        output_frames = []
+        # tasks
+        n = len(frames_style)
+        tasks = []
+        for target in range(n):
+            for source in range(target - self.window_size, target + self.window_size + 1):
+                if source >= 0 and source < n and source != target:
+                    tasks.append((source, target))
+        # run
+        frames = [(None, 1) for i in range(n)]
+        for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
+            tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))]
+            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+            for (source, target), result in zip(tasks_batch, target_style):
+                frame, weight = frames[target]
+                if frame is None:
+                    frame = frames_style[target]
+                frames[target] = (
+                    frame * (weight / (weight + 1)) + result / (weight + 1),
+                    weight + 1
+                )
+                if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size):
+                    frame = frame.clip(0, 255).astype("uint8")
+                    output_frames.append(Image.fromarray(frame))
+                    frames[target] = (None, 1)
+        return output_frames
+
+    def inference_accurate(self, frames_guide, frames_style):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            use_mean_target_style=True,
+            **self.ebsynth_config
+        )
+        output_frames = []
+        # run
+        n = len(frames_style)
+        for target in tqdm(range(n), desc="Accurate Mode"):
+            l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
+            remapped_frames = []
+            for i in range(l, r, self.batch_size):
+                j = min(i + self.batch_size, r)
+                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+                target_guide = np.stack([frames_guide[target]] * (j - i))
+                source_style = np.stack([frames_style[source] for source in range(i, j)])
+                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+                remapped_frames.append(target_style)
+            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+            frame = frame.clip(0, 255).astype("uint8")
+            output_frames.append(Image.fromarray(frame))
+        return output_frames
+
+    def release_vram(self):
+        mempool = cp.get_default_memory_pool()
+        pinned_mempool = cp.get_default_pinned_memory_pool()
+        mempool.free_all_blocks()
+        pinned_mempool.free_all_blocks()
+
+    def __call__(self, rendered_frames, original_frames=None, **kwargs):
+        rendered_frames = [np.array(frame) for frame in rendered_frames]
+        original_frames = [np.array(frame) for frame in original_frames]
+        if self.inference_mode == "fast":
+            output_frames = self.inference_fast(original_frames, rendered_frames)
+        elif self.inference_mode == "balanced":
+            output_frames = self.inference_balanced(original_frames, rendered_frames)
+        elif self.inference_mode == "accurate":
+            output_frames = self.inference_accurate(original_frames, rendered_frames)
+        else:
+            raise ValueError("inference_mode must be fast, balanced or accurate")
+        self.release_vram()
+        return output_frames
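A hedged usage sketch for FastBlendSmoother (not part of the diff; the frame paths are illustrative). The smoother expects two aligned lists of PIL frames, the original video as the guide and the re-rendered video as the style, matching the __call__ signature above.

from PIL import Image
from diffsynth.processors.FastBlend import FastBlendSmoother

original_frames = [Image.open(f"original/{i:03d}.png") for i in range(25)]   # guide video (illustrative paths)
rendered_frames = [Image.open(f"rendered/{i:03d}.png") for i in range(25)]   # stylized video to be deflickered

smoother = FastBlendSmoother(inference_mode="fast", batch_size=8, window_size=60)
smoothed_frames = smoother(rendered_frames, original_frames=original_frames)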
@@ -0,0 +1,28 @@
+from PIL import ImageEnhance
+from .base import VideoProcessor
+
+
+class ContrastEditor(VideoProcessor):
+    def __init__(self, rate=1.5):
+        self.rate = rate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return ContrastEditor(**kwargs)
+
+    def __call__(self, rendered_frames, **kwargs):
+        rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames]
+        return rendered_frames
+
+
+class SharpnessEditor(VideoProcessor):
+    def __init__(self, rate=1.5):
+        self.rate = rate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return SharpnessEditor(**kwargs)
+
+    def __call__(self, rendered_frames, **kwargs):
+        rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames]
+        return rendered_frames
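The two editors above are thin, stateless wrappers around PIL.ImageEnhance. A short chaining sketch (the frames variable is assumed to come from one of the pipelines or smoothers above):

from diffsynth.processors.PILEditor import ContrastEditor, SharpnessEditor

frames = ContrastEditor(rate=1.2)(frames)     # frames: list of PIL images (assumed)
frames = SharpnessEditor(rate=1.5)(frames)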
@@ -0,0 +1,77 @@
+import torch
+import numpy as np
+from PIL import Image
+from .base import VideoProcessor
+
+
+class RIFESmoother(VideoProcessor):
+    def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True):
+        self.model = model
+        self.device = device
+
+        # IFNet does not support float16
+        self.torch_dtype = torch.float32
+
+        # Other parameters
+        self.scale = scale
+        self.batch_size = batch_size
+        self.interpolate = interpolate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs)
+
+    def process_image(self, image):
+        width, height = image.size
+        if width % 32 != 0 or height % 32 != 0:
+            width = (width + 31) // 32 * 32
+            height = (height + 31) // 32 * 32
+            image = image.resize((width, height))
+        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+        return image
+
+    def process_images(self, images):
+        images = [self.process_image(image) for image in images]
+        images = torch.stack(images)
+        return images
+
+    def decode_images(self, images):
+        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+        images = [Image.fromarray(image) for image in images]
+        return images
+
+    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+        output_tensor = []
+        for batch_id in range(0, input_tensor.shape[0], batch_size):
+            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+            batch_input_tensor = input_tensor[batch_id: batch_id_]
+            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+            flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+            output_tensor.append(merged[2].cpu())
+        output_tensor = torch.concat(output_tensor, dim=0)
+        return output_tensor
+
+    @torch.no_grad()
+    def __call__(self, rendered_frames, **kwargs):
+        # Preprocess
+        processed_images = self.process_images(rendered_frames)
+
+        # Input
+        input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+        # Interpolate
+        output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+
+        if self.interpolate:
+            # Blend
+            input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+            output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+            processed_images[1:-1] = output_tensor
+        else:
+            processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2
+
+        # To images
+        output_images = self.decode_images(processed_images)
+        if output_images[0].size != rendered_frames[0].size:
+            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+        return output_images
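A hedged sketch for RIFESmoother: for each interior frame it takes the two neighbouring frames, predicts the in-between frame with the RIFE model, and blends the result with the original frame to suppress flicker. The model_manager handle and its .RIFE attribute are assumed to exist elsewhere in the package, as from_model_manager above implies.

from diffsynth.processors.RIFE import RIFESmoother

# Assumption: model_manager was built elsewhere and exposes the loaded RIFE model as model_manager.RIFE.
smoother = RIFESmoother.from_model_manager(model_manager, scale=1.0, batch_size=4, interpolate=True)
smoothed_frames = smoother(rendered_frames)   # rendered_frames: list of PIL images (assumed)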
File without changes
@@ -0,0 +1,41 @@
+from .base import VideoProcessor
+
+
+class AutoVideoProcessor(VideoProcessor):
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def from_model_manager(model_manager, processor_type, **kwargs):
+        if processor_type == "FastBlend":
+            from .FastBlend import FastBlendSmoother
+            return FastBlendSmoother.from_model_manager(model_manager, **kwargs)
+        elif processor_type == "Contrast":
+            from .PILEditor import ContrastEditor
+            return ContrastEditor.from_model_manager(model_manager, **kwargs)
+        elif processor_type == "Sharpness":
+            from .PILEditor import SharpnessEditor
+            return SharpnessEditor.from_model_manager(model_manager, **kwargs)
+        elif processor_type == "RIFE":
+            from .RIFE import RIFESmoother
+            return RIFESmoother.from_model_manager(model_manager, **kwargs)
+        else:
+            raise ValueError(f"invalid processor_type: {processor_type}")
+
+
+class SequencialProcessor(VideoProcessor):
+    def __init__(self, processors=[]):
+        self.processors = processors
+
+    @staticmethod
+    def from_model_manager(model_manager, configs):
+        processors = [
+            AutoVideoProcessor.from_model_manager(model_manager, config["processor_type"], **config["config"])
+            for config in configs
+        ]
+        return SequencialProcessor(processors)
+
+    def __call__(self, rendered_frames, **kwargs):
+        for processor in self.processors:
+            rendered_frames = processor(rendered_frames, **kwargs)
+        return rendered_frames
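A hedged sketch of driving the post-processing chain through SequencialProcessor: each config entry is dispatched by AutoVideoProcessor.from_model_manager to one of the processors defined above. The model_manager and frame lists are assumed from the earlier sketches; note that FastBlend needs original_frames, which __call__ forwards via **kwargs.

from diffsynth.processors.sequencial_processor import SequencialProcessor

configs = [
    {"processor_type": "FastBlend", "config": {"inference_mode": "fast", "window_size": 60}},
    {"processor_type": "Contrast", "config": {"rate": 1.2}},
    {"processor_type": "Sharpness", "config": {"rate": 1.2}},
]
post_processor = SequencialProcessor.from_model_manager(model_manager, configs)  # model_manager assumed from above
frames = post_processor(rendered_frames, original_frames=original_frames)        # frame lists assumed from above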
@@ -0,0 +1,6 @@
+from .prompt_refiners import Translator, BeautifulPrompt
+from .sd_prompter import SDPrompter
+from .sdxl_prompter import SDXLPrompter
+from .sd3_prompter import SD3Prompter
+from .hunyuan_dit_prompter import HunyuanDiTPrompter
+from .kolors_prompter import KolorsPrompter
@@ -0,0 +1,57 @@
+from ..models.model_manager import ModelManager
+import torch
+
+
+
+def tokenize_long_prompt(tokenizer, prompt, max_length=None):
+    # Get model_max_length from the tokenizer
+    length = tokenizer.model_max_length if max_length is None else max_length
+
+    # To avoid the warning, set tokenizer.model_max_length to +oo.
+    tokenizer.model_max_length = 99999999
+
+    # Tokenize it!
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+    # Determine the real length.
+    max_length = (input_ids.shape[1] + length - 1) // length * length
+
+    # Restore tokenizer.model_max_length
+    tokenizer.model_max_length = length
+
+    # Tokenize it again with fixed length.
+    input_ids = tokenizer(
+        prompt,
+        return_tensors="pt",
+        padding="max_length",
+        max_length=max_length,
+        truncation=True
+    ).input_ids
+
+    # Reshape input_ids to fit the text encoder.
+    num_sentence = input_ids.shape[1] // length
+    input_ids = input_ids.reshape((num_sentence, length))
+
+    return input_ids
+
+
+
+class BasePrompter:
+    def __init__(self, refiners=[]):
+        self.refiners = refiners
+
+
+    def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
+        for refiner_class in refiner_classes:
+            refiner = refiner_class.from_model_manager(model_manager)
+            self.refiners.append(refiner)
+
+
+    @torch.no_grad()
+    def process_prompt(self, prompt, positive=True):
+        if isinstance(prompt, list):
+            prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
+        else:
+            for refiner in self.refiners:
+                prompt = refiner(prompt, positive=positive)
+        return prompt