diffsynth 1.0.0 (diffsynth-1.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
diffsynth/pipelines/svd_video.py
@@ -0,0 +1,297 @@
+from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
+from ..schedulers import ContinuousODEScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+from einops import rearrange, repeat
+
+
+
+class SVDVideoPipeline(BasePipeline):
+
+    def __init__(self, device="cuda", torch_dtype=torch.float16):
+        super().__init__(device=device, torch_dtype=torch_dtype)
+        self.scheduler = ContinuousODEScheduler()
+        # models
+        self.image_encoder: SVDImageEncoder = None
+        self.unet: SVDUNet = None
+        self.vae_encoder: SVDVAEEncoder = None
+        self.vae_decoder: SVDVAEDecoder = None
+
+
+    def fetch_models(self, model_manager: ModelManager):
+        self.image_encoder = model_manager.fetch_model("svd_image_encoder")
+        self.unet = model_manager.fetch_model("svd_unet")
+        self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
+        self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
+
+
+    @staticmethod
+    def from_model_manager(model_manager: ModelManager, **kwargs):
+        pipe = SVDVideoPipeline(
+            device=model_manager.device,
+            torch_dtype=model_manager.torch_dtype
+        )
+        pipe.fetch_models(model_manager)
+        return pipe
+
+
+    def encode_image_with_clip(self, image):
+        image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+        image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
+        image = (image + 1.0) / 2.0
+        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+        image = (image - mean) / std
+        image_emb = self.image_encoder(image)
+        return image_emb
+
+
+    def encode_image_with_vae(self, image, noise_aug_strength):
+        image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+        noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device)
+        image = image + noise_aug_strength * noise
+        image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
+        return image_emb
+
+
+    def encode_video_with_vae(self, video):
+        video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0)
+        video = rearrange(video, "T C H W -> 1 C T H W")
+        video = video.to(device=self.device, dtype=self.torch_dtype)
+        latents = self.vae_encoder.encode_video(video)
+        latents = rearrange(latents[0], "C T H W -> T C H W")
+        return latents
+
+
+    def tensor2video(self, frames):
+        frames = rearrange(frames, "C T H W -> T H W C")
+        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+        frames = [Image.fromarray(frame) for frame in frames]
+        return frames
+
+
+    def calculate_noise_pred(
+        self,
+        latents,
+        timestep,
+        add_time_id,
+        cfg_scales,
+        image_emb_vae_posi, image_emb_clip_posi,
+        image_emb_vae_nega, image_emb_clip_nega
+    ):
+        # Positive side
+        noise_pred_posi = self.unet(
+            torch.cat([latents, image_emb_vae_posi], dim=1),
+            timestep, image_emb_clip_posi, add_time_id
+        )
+        # Negative side
+        noise_pred_nega = self.unet(
+            torch.cat([latents, image_emb_vae_nega], dim=1),
+            timestep, image_emb_clip_nega, add_time_id
+        )
+
+        # Classifier-free guidance
+        noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
+
+        return noise_pred
+
+
+    def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
+        if post_normalize:
+            mean, std = latents.mean(), latents.std()
+            latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
+        latents = latents * contrast_enhance_scale
+        return latents
+
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        input_image=None,
+        input_video=None,
+        mask_frames=[],
+        mask_frame_ids=[],
+        min_cfg_scale=1.0,
+        max_cfg_scale=3.0,
+        denoising_strength=1.0,
+        num_frames=25,
+        height=576,
+        width=1024,
+        fps=7,
+        motion_bucket_id=127,
+        noise_aug_strength=0.02,
+        num_inference_steps=20,
+        post_normalize=True,
+        contrast_enhance_scale=1.2,
+        progress_bar_cmd=tqdm,
+        progress_bar_st=None,
+    ):
+        # Prepare scheduler
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
+
+        # Prepare latent tensors
+        noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
+        if denoising_strength == 1.0:
+            latents = noise.clone()
+        else:
+            latents = self.encode_video_with_vae(input_video)
+            latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
+
+        # Prepare mask frames
+        if len(mask_frames) > 0:
+            mask_latents = self.encode_video_with_vae(mask_frames)
+
+        # Encode image
+        image_emb_clip_posi = self.encode_image_with_clip(input_image)
+        image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
+        image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames)
+        image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
+
+        # Prepare classifier-free guidance
+        cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames)
+        cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+
+        # Prepare positional id
+        add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device)
+
+        # Denoise
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+
+            # Mask frames
+            for frame_id, mask_frame_id in enumerate(mask_frame_ids):
+                latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
+
+            # Fetch model output
+            noise_pred = self.calculate_noise_pred(
+                latents, timestep, add_time_id, cfg_scales,
+                image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega
+            )
+
+            # Forward Euler
+            latents = self.scheduler.step(noise_pred, timestep, latents)
+
+            # Update progress bar
+            if progress_bar_st is not None:
+                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode image
+        latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
+        video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
+        video = self.tensor2video(video)
+
+        return video
+
+
+
+class SVDCLIPImageProcessor:
+    def __init__(self):
+        pass
+
+    def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True):
+        h, w = input.shape[-2:]
+        factors = (h / size[0], w / size[1])
+
+        # First, we have to determine sigma
+        # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+        sigmas = (
+            max((factors[0] - 1.0) / 2.0, 0.001),
+            max((factors[1] - 1.0) / 2.0, 0.001),
+        )
+
+        # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+        # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+        # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+        ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+        # Make sure it is odd
+        if (ks[0] % 2) == 0:
+            ks = ks[0] + 1, ks[1]
+
+        if (ks[1] % 2) == 0:
+            ks = ks[0], ks[1] + 1
+
+        input = self._gaussian_blur2d(input, ks, sigmas)
+
+        output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+        return output
+
+
+    def _compute_padding(self, kernel_size):
+        """Compute padding tuple."""
+        # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
+        # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+        if len(kernel_size) < 2:
+            raise AssertionError(kernel_size)
+        computed = [k - 1 for k in kernel_size]
+
+        # for even kernels we need to do asymmetric padding :(
+        out_padding = 2 * len(kernel_size) * [0]
+
+        for i in range(len(kernel_size)):
+            computed_tmp = computed[-(i + 1)]
+
+            pad_front = computed_tmp // 2
+            pad_rear = computed_tmp - pad_front
+
+            out_padding[2 * i + 0] = pad_front
+            out_padding[2 * i + 1] = pad_rear
+
+        return out_padding
+
+
+    def _filter2d(self, input, kernel):
+        # prepare kernel
+        b, c, h, w = input.shape
+        tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+        tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+        height, width = tmp_kernel.shape[-2:]
+
+        padding_shape: list[int] = self._compute_padding([height, width])
+        input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+        # kernel and input tensor reshape to align element-wise or batch-wise params
+        tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+        input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+        # convolve the tensor with the kernel.
+        output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+        out = output.view(b, c, h, w)
+        return out
+
+
+    def _gaussian(self, window_size: int, sigma):
+        if isinstance(sigma, float):
+            sigma = torch.tensor([[sigma]])
+
+        batch_size = sigma.shape[0]
+
+        x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+        if window_size % 2 == 0:
+            x = x + 0.5
+
+        gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+        return gauss / gauss.sum(-1, keepdim=True)
+
+
+    def _gaussian_blur2d(self, input, kernel_size, sigma):
+        if isinstance(sigma, tuple):
+            sigma = torch.tensor([sigma], dtype=input.dtype)
+        else:
+            sigma = sigma.to(dtype=input.dtype)
+
+        ky, kx = int(kernel_size[0]), int(kernel_size[1])
+        bs = sigma.shape[0]
+        kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1))
+        kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1))
+        out_x = self._filter2d(input, kernel_x[..., None, :])
+        out = self._filter2d(out_x, kernel_y[..., None])
+
+        return out
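A minimal usage sketch for the SVDVideoPipeline defined above. It is not taken from the package itself; the ModelManager constructor arguments, the load_models call, and the checkpoint path are assumptions about diffsynth/models/model_manager.py, which is not shown in this hunk:

    import torch
    from PIL import Image
    from diffsynth import ModelManager, SVDVideoPipeline  # assumes these names are re-exported in diffsynth/__init__.py

    # Hypothetical loading step: the path and the load_models signature are placeholders.
    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
    model_manager.load_models(["models/stable_video_diffusion/svd_xt.safetensors"])

    pipe = SVDVideoPipeline.from_model_manager(model_manager)
    frames = pipe(
        input_image=Image.open("input.png").resize((1024, 576)),
        num_frames=25, height=576, width=1024,
        fps=7, motion_bucket_id=127, noise_aug_strength=0.02,
        num_inference_steps=20,
    )
    # frames is a list of PIL.Image objects produced by tensor2video()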
diffsynth/processors/FastBlend.py
@@ -0,0 +1,142 @@
+from PIL import Image
+import cupy as cp
+import numpy as np
+from tqdm import tqdm
+from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
+from ..extensions.FastBlend.runners.fast import TableManager
+from .base import VideoProcessor
+
+
+class FastBlendSmoother(VideoProcessor):
+    def __init__(
+        self,
+        inference_mode="fast", batch_size=8, window_size=60,
+        minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0
+    ):
+        self.inference_mode = inference_mode
+        self.batch_size = batch_size
+        self.window_size = window_size
+        self.ebsynth_config = {
+            "minimum_patch_size": minimum_patch_size,
+            "threads_per_block": threads_per_block,
+            "num_iter": num_iter,
+            "gpu_id": gpu_id,
+            "guide_weight": guide_weight,
+            "initialize": initialize,
+            "tracking_window_size": tracking_window_size
+        }
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        # TODO: fetch GPU ID from model_manager
+        return FastBlendSmoother(**kwargs)
+
+    def inference_fast(self, frames_guide, frames_style):
+        table_manager = TableManager()
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **self.ebsynth_config
+        )
+        # left part
+        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
+        table_l = table_manager.remapping_table_to_blending_table(table_l)
+        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
+        # right part
+        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
+        table_r = table_manager.remapping_table_to_blending_table(table_r)
+        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
+        # merge
+        frames = []
+        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+            weight_m = -1
+            weight = weight_l + weight_m + weight_r
+            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+            frames.append(frame)
+        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
+        frames = [Image.fromarray(frame) for frame in frames]
+        return frames
+
+    def inference_balanced(self, frames_guide, frames_style):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **self.ebsynth_config
+        )
+        output_frames = []
+        # tasks
+        n = len(frames_style)
+        tasks = []
+        for target in range(n):
+            for source in range(target - self.window_size, target + self.window_size + 1):
+                if source >= 0 and source < n and source != target:
+                    tasks.append((source, target))
+        # run
+        frames = [(None, 1) for i in range(n)]
+        for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
+            tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))]
+            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+            for (source, target), result in zip(tasks_batch, target_style):
+                frame, weight = frames[target]
+                if frame is None:
+                    frame = frames_style[target]
+                frames[target] = (
+                    frame * (weight / (weight + 1)) + result / (weight + 1),
+                    weight + 1
+                )
+                if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size):
+                    frame = frame.clip(0, 255).astype("uint8")
+                    output_frames.append(Image.fromarray(frame))
+                    frames[target] = (None, 1)
+        return output_frames
+
+    def inference_accurate(self, frames_guide, frames_style):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            use_mean_target_style=True,
+            **self.ebsynth_config
+        )
+        output_frames = []
+        # run
+        n = len(frames_style)
+        for target in tqdm(range(n), desc="Accurate Mode"):
+            l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
+            remapped_frames = []
+            for i in range(l, r, self.batch_size):
+                j = min(i + self.batch_size, r)
+                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+                target_guide = np.stack([frames_guide[target]] * (j - i))
+                source_style = np.stack([frames_style[source] for source in range(i, j)])
+                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+                remapped_frames.append(target_style)
+            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+            frame = frame.clip(0, 255).astype("uint8")
+            output_frames.append(Image.fromarray(frame))
+        return output_frames
+
+    def release_vram(self):
+        mempool = cp.get_default_memory_pool()
+        pinned_mempool = cp.get_default_pinned_memory_pool()
+        mempool.free_all_blocks()
+        pinned_mempool.free_all_blocks()
+
+    def __call__(self, rendered_frames, original_frames=None, **kwargs):
+        rendered_frames = [np.array(frame) for frame in rendered_frames]
+        original_frames = [np.array(frame) for frame in original_frames]
+        if self.inference_mode == "fast":
+            output_frames = self.inference_fast(original_frames, rendered_frames)
+        elif self.inference_mode == "balanced":
+            output_frames = self.inference_balanced(original_frames, rendered_frames)
+        elif self.inference_mode == "accurate":
+            output_frames = self.inference_accurate(original_frames, rendered_frames)
+        else:
+            raise ValueError("inference_mode must be fast, balanced or accurate")
+        self.release_vram()
+        return output_frames
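A short sketch of running FastBlendSmoother on its own, outside a pipeline. The frame paths are placeholders; the call follows __call__ above, and the PatchMatch step runs on the GPU through cupy:

    from PIL import Image
    from diffsynth.processors.FastBlend import FastBlendSmoother  # module path as listed in this wheel

    smoother = FastBlendSmoother(inference_mode="fast", batch_size=8, window_size=60)
    rendered = [Image.open(f"rendered/{i:04d}.png") for i in range(64)]   # stylized frames (placeholder paths)
    original = [Image.open(f"original/{i:04d}.png") for i in range(64)]   # guide frames (placeholder paths)
    smoothed = smoother(rendered, original_frames=original)               # returns a list of PIL images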
diffsynth/processors/PILEditor.py
@@ -0,0 +1,28 @@
+from PIL import ImageEnhance
+from .base import VideoProcessor
+
+
+class ContrastEditor(VideoProcessor):
+    def __init__(self, rate=1.5):
+        self.rate = rate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return ContrastEditor(**kwargs)
+
+    def __call__(self, rendered_frames, **kwargs):
+        rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames]
+        return rendered_frames
+
+
+class SharpnessEditor(VideoProcessor):
+    def __init__(self, rate=1.5):
+        self.rate = rate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return SharpnessEditor(**kwargs)
+
+    def __call__(self, rendered_frames, **kwargs):
+        rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames]
+        return rendered_frames
diffsynth/processors/RIFE.py
@@ -0,0 +1,77 @@
+import torch
+import numpy as np
+from PIL import Image
+from .base import VideoProcessor
+
+
+class RIFESmoother(VideoProcessor):
+    def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True):
+        self.model = model
+        self.device = device
+
+        # IFNet only does not support float16
+        self.torch_dtype = torch.float32
+
+        # Other parameters
+        self.scale = scale
+        self.batch_size = batch_size
+        self.interpolate = interpolate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs)
+
+    def process_image(self, image):
+        width, height = image.size
+        if width % 32 != 0 or height % 32 != 0:
+            width = (width + 31) // 32 * 32
+            height = (height + 31) // 32 * 32
+            image = image.resize((width, height))
+        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+        return image
+
+    def process_images(self, images):
+        images = [self.process_image(image) for image in images]
+        images = torch.stack(images)
+        return images
+
+    def decode_images(self, images):
+        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+        images = [Image.fromarray(image) for image in images]
+        return images
+
+    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+        output_tensor = []
+        for batch_id in range(0, input_tensor.shape[0], batch_size):
+            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+            batch_input_tensor = input_tensor[batch_id: batch_id_]
+            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+            flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+            output_tensor.append(merged[2].cpu())
+        output_tensor = torch.concat(output_tensor, dim=0)
+        return output_tensor
+
+    @torch.no_grad()
+    def __call__(self, rendered_frames, **kwargs):
+        # Preprocess
+        processed_images = self.process_images(rendered_frames)
+
+        # Input
+        input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+        # Interpolate
+        output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+
+        if self.interpolate:
+            # Blend
+            input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+            output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+            processed_images[1:-1] = output_tensor
+        else:
+            processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2
+
+        # To images
+        output_images = self.decode_images(processed_images)
+        if output_images[0].size != rendered_frames[0].size:
+            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+        return output_images
diffsynth/processors/__init__.py (file without changes)
diffsynth/processors/base.py
@@ -0,0 +1,6 @@
+class VideoProcessor:
+    def __init__(self):
+        pass
+
+    def __call__(self):
+        raise NotImplementedError
diffsynth/processors/sequencial_processor.py
@@ -0,0 +1,41 @@
+from .base import VideoProcessor
+
+
+class AutoVideoProcessor(VideoProcessor):
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def from_model_manager(model_manager, processor_type, **kwargs):
+        if processor_type == "FastBlend":
+            from .FastBlend import FastBlendSmoother
+            return FastBlendSmoother.from_model_manager(model_manager, **kwargs)
+        elif processor_type == "Contrast":
+            from .PILEditor import ContrastEditor
+            return ContrastEditor.from_model_manager(model_manager, **kwargs)
+        elif processor_type == "Sharpness":
+            from .PILEditor import SharpnessEditor
+            return SharpnessEditor.from_model_manager(model_manager, **kwargs)
+        elif processor_type == "RIFE":
+            from .RIFE import RIFESmoother
+            return RIFESmoother.from_model_manager(model_manager, **kwargs)
+        else:
+            raise ValueError(f"invalid processor_type: {processor_type}")
+
+
+class SequencialProcessor(VideoProcessor):
+    def __init__(self, processors=[]):
+        self.processors = processors
+
+    @staticmethod
+    def from_model_manager(model_manager, configs):
+        processors = [
+            AutoVideoProcessor.from_model_manager(model_manager, config["processor_type"], **config["config"])
+            for config in configs
+        ]
+        return SequencialProcessor(processors)
+
+    def __call__(self, rendered_frames, **kwargs):
+        for processor in self.processors:
+            rendered_frames = processor(rendered_frames, **kwargs)
+        return rendered_frames
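The configs argument of SequencialProcessor.from_model_manager is a list of dicts, each carrying a processor_type understood by AutoVideoProcessor and a config dict forwarded to that processor's constructor. A hypothetical chain (the RIFE entry additionally assumes a RIFE model has been loaded so that model_manager.RIFE exists):

    configs = [
        {"processor_type": "FastBlend", "config": {"inference_mode": "fast", "window_size": 60}},
        {"processor_type": "Contrast",  "config": {"rate": 1.2}},
        {"processor_type": "RIFE",      "config": {"interpolate": True}},
    ]
    processor = SequencialProcessor.from_model_manager(model_manager, configs)
    frames = processor(rendered_frames, original_frames=original_frames)  # kwargs are passed through to every stage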
diffsynth/prompters/__init__.py
@@ -0,0 +1,6 @@
+from .prompt_refiners import Translator, BeautifulPrompt
+from .sd_prompter import SDPrompter
+from .sdxl_prompter import SDXLPrompter
+from .sd3_prompter import SD3Prompter
+from .hunyuan_dit_prompter import HunyuanDiTPrompter
+from .kolors_prompter import KolorsPrompter
diffsynth/prompters/base_prompter.py
@@ -0,0 +1,57 @@
+from ..models.model_manager import ModelManager
+import torch
+
+
+
+def tokenize_long_prompt(tokenizer, prompt, max_length=None):
+    # Get model_max_length from self.tokenizer
+    length = tokenizer.model_max_length if max_length is None else max_length
+
+    # To avoid the warning. set self.tokenizer.model_max_length to +oo.
+    tokenizer.model_max_length = 99999999
+
+    # Tokenize it!
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+    # Determine the real length.
+    max_length = (input_ids.shape[1] + length - 1) // length * length
+
+    # Restore tokenizer.model_max_length
+    tokenizer.model_max_length = length
+
+    # Tokenize it again with fixed length.
+    input_ids = tokenizer(
+        prompt,
+        return_tensors="pt",
+        padding="max_length",
+        max_length=max_length,
+        truncation=True
+    ).input_ids
+
+    # Reshape input_ids to fit the text encoder.
+    num_sentence = input_ids.shape[1] // length
+    input_ids = input_ids.reshape((num_sentence, length))
+
+    return input_ids
+
+
+
+class BasePrompter:
+    def __init__(self, refiners=[]):
+        self.refiners = refiners
+
+
+    def load_prompt_refiners(self, model_nameger: ModelManager, refiner_classes=[]):
+        for refiner_class in refiner_classes:
+            refiner = refiner_class.from_model_manager(model_nameger)
+            self.refiners.append(refiner)
+
+
+    @torch.no_grad()
+    def process_prompt(self, prompt, positive=True):
+        if isinstance(prompt, list):
+            prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
+        else:
+            for refiner in self.refiners:
+                prompt = refiner(prompt, positive=positive)
+        return prompt
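tokenize_long_prompt pads the token sequence up to the next multiple of the tokenizer's model_max_length and reshapes it into fixed-length rows, so a prompt longer than one context window is split into several rows for the text encoder. A small illustration, assuming a CLIP tokenizer from transformers; the exact row count depends on the prompt:

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # model_max_length == 77
    input_ids = tokenize_long_prompt(tokenizer, "a very long prompt " * 30)
    print(input_ids.shape)  # e.g. torch.Size([2, 77]): two 77-token rows fed to the text encoder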