diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import imageio, os
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def read_video(file_name):
    """Decode every frame of a video into a list of numpy arrays.

    Parameters:
        file_name: path of a video file readable by imageio.

    Returns:
        list of numpy frame arrays, in playback order.
    """
    reader = imageio.get_reader(file_name)
    try:
        # Materialize all frames; for large videos prefer LowMemoryVideo.
        video = [np.array(frame) for frame in reader]
    finally:
        # Close the reader even if decoding fails part-way through.
        reader.close()
    return video
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_video_fps(file_name):
    """Return the frame rate ("fps") recorded in the video's metadata."""
    reader = imageio.get_reader(file_name)
    try:
        fps = reader.get_meta_data()["fps"]
    finally:
        # Close the reader even when the metadata lookup raises.
        reader.close()
    return fps
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def save_video(frames_path, video_path, num_frames, fps, quality=9):
    """Encode numbered frame images into a video file.

    Parameters:
        frames_path: folder containing frames named 00000.png, 00001.png, ...
        video_path: output video file path.
        num_frames: number of frames to encode.
        fps: output frame rate.
        quality: imageio writer quality setting (default 9, the previous
            hard-coded value, so existing callers are unaffected).

    Returns:
        video_path, for convenient chaining.
    """
    writer = imageio.get_writer(video_path, fps=fps, quality=quality)
    try:
        for i in range(num_frames):
            frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
            writer.append_data(frame)
    finally:
        # Finalize the container even if a frame is missing or unreadable.
        writer.close()
    return video_path
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class LowMemoryVideo:
    """Frame accessor that decodes one frame at a time from a video file,
    avoiding loading the whole video into memory."""

    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        # Number of frames reported by the reader (may scan the file).
        return self.reader.count_frames()

    def __getitem__(self, item):
        # Decode frame ``item`` and return it as a numpy array.
        return np.array(self.reader.get_data(item))

    def __del__(self):
        # Release the underlying reader when the wrapper is garbage-collected.
        self.reader.close()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def split_file_name(file_name):
    """Build a natural-sort key from a file name.

    Consecutive ASCII digits collapse into a single int; every other
    character is kept as-is, so e.g. "frame10.png" sorts after "frame2.png".

    Returns:
        tuple of str and int elements.
    """
    parts = []
    digits = ""
    for ch in file_name:
        if "0" <= ch <= "9":
            # Accumulate the current run of digits.
            digits += ch
        else:
            if digits:
                parts.append(int(digits))
                digits = ""
            parts.append(ch)
    if digits:
        # Flush a trailing digit run.
        parts.append(int(digits))
    return tuple(parts)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def search_for_images(folder):
    """Return full paths of .jpg/.png files in ``folder``, naturally sorted.

    Sorting uses split_file_name as the primary key and the raw file name as
    a tie-breaker, matching the original (key, name) tuple sort.
    """
    names = [name for name in os.listdir(folder) if name.endswith((".jpg", ".png"))]
    names.sort(key=lambda name: (split_file_name(name), name))
    return [os.path.join(folder, name) for name in names]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def read_images(folder):
    """Load every image found in ``folder`` as a numpy array, in natural order."""
    return [np.array(Image.open(path)) for path in search_for_images(folder)]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class LowMemoryImageFolder:
    """Frame accessor over a folder of images, loading one file per access."""

    def __init__(self, folder, file_list=None):
        if file_list is None:
            # Discover and naturally sort the images ourselves.
            self.file_list = search_for_images(folder)
        else:
            # Caller supplied the file names; resolve them against the folder.
            self.file_list = [os.path.join(folder, name) for name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return np.array(Image.open(self.file_list[item]))

    def __del__(self):
        # Nothing to release: files are opened and closed per access.
        pass
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class VideoData:
    """Uniform frame-access wrapper over either a video file or an image folder.

    Improvement over the original: ``video_file`` and ``image_folder`` now
    default to ``None``, so callers may pass just one source by keyword
    (backward-compatible — existing positional callers are unaffected).
    """

    def __init__(self, video_file=None, image_folder=None, **kwargs):
        # video_file takes priority when both are given (original behavior).
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        # Optional overrides; None means "use the underlying data's value".
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        """Materialize every frame (resized if a shape override is set)."""
        return [self[i] for i in range(len(self))]

    def set_length(self, length):
        """Override the reported number of frames."""
        self.length = length

    def set_shape(self, height, width):
        """Force frames to be resized to (height, width) on access."""
        self.height = height
        self.width = width

    def __len__(self):
        return len(self.data) if self.length is None else self.length

    def shape(self):
        """Return (height, width), probing the first frame if no override set."""
        if self.height is not None and self.width is not None:
            return self.height, self.width
        height, width, _ = self[0].shape
        return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        height, width, _ = frame.shape
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                # PIL resize takes (width, height) order.
                frame = Image.fromarray(frame).resize((self.width, self.height))
                frame = np.array(frame)
        return frame

    def __del__(self):
        pass
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
|
|
2
|
+
import numpy as np
|
|
3
|
+
import cupy as cp
|
|
4
|
+
import cv2
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PatchMatcher:
    """Single-resolution PatchMatch solver on the GPU (CuPy).

    Estimates a nearest-neighbor field (NNF) mapping each target pixel to a
    source patch, minimizing a weighted sum of guide-image and style-image
    patch errors, then uses the NNF to remap the style image.
    """

    def __init__(
        self, height, width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        random_search_steps=3, random_search_range=4,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0
    ):
        self.height = height
        self.width = width
        self.channel = channel
        self.minimum_patch_size = minimum_patch_size
        self.threads_per_block = threads_per_block
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.guide_weight = guide_weight
        self.random_search_steps = random_search_steps
        self.random_search_range = random_search_range
        self.use_mean_target_style = use_mean_target_style
        self.use_pairwise_patch_error = use_pairwise_patch_error
        self.tracking_window_size = tracking_window_size

        # One patch size per sweep, largest first (coarse-to-fine within the level).
        self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
        self.pad_size = self.patch_size_list[0] // 2
        # CUDA launch geometry: one thread per pixel, tiled by threads_per_block.
        self.grid = (
            (height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block
        )
        self.block = (threads_per_block, threads_per_block)

    def pad_image(self, image):
        # Zero-pad the two spatial axes so border-centered patches fit.
        return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))

    def unpad_image(self, image):
        # Inverse of pad_image: crop the padding off the spatial axes.
        return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]

    def apply_nnf_to_image(self, nnf, source):
        """Remap ``source`` through the NNF via the CUDA remapping kernel.

        Note: reads self.patch_size, which is set per-sweep in estimate_nnf.
        """
        batch_size = source.shape[0]
        target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
        remapping_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
        )
        return target

    def get_patch_error(self, source, nnf, target):
        """Per-pixel patch distance between NNF-remapped source and target."""
        batch_size = source.shape[0]
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
        )
        return error

    def get_pairwise_patch_error(self, source, nnf):
        """Patch error between consecutive frame pairs (0,1), (2,3), ...

        Each pair shares one error map, repeated for both members so the
        output batch size matches the input.
        """
        batch_size = source.shape[0]//2
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        # .copy() makes the strided views contiguous before the kernel launch.
        source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
        source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
        pairwise_patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
        )
        error = error.repeat(2, axis=0)
        return error

    def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
        """Total matching cost = guide error * guide_weight + style error."""
        error_guide = self.get_patch_error(source_guide, nnf, target_guide)
        if self.use_mean_target_style:
            # Replace the per-frame style target with the batch mean.
            target_style = self.apply_nnf_to_image(nnf, source_style)
            target_style = target_style.mean(axis=0, keepdims=True)
            target_style = target_style.repeat(source_guide.shape[0], axis=0)
        if self.use_pairwise_patch_error:
            error_style = self.get_pairwise_patch_error(source_style, nnf)
        else:
            error_style = self.get_patch_error(source_style, nnf, target_style)
        error = error_guide * self.guide_weight + error_style
        return error

    def clamp_bound(self, nnf):
        # Keep NNF coordinates inside the image bounds (channel 0 = row, 1 = col).
        nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
        nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
        return nnf

    def random_step(self, nnf, r):
        """Random-search proposal: jitter every NNF entry by up to ±r."""
        batch_size = nnf.shape[0]
        step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
        upd_nnf = self.clamp_bound(nnf + step)
        return upd_nnf

    def neighboor_step(self, nnf, d):
        """Propagation proposal: shift the NNF one pixel from one of the four
        neighbor directions (d in 0..3) and offset the coordinate to match."""
        if d==0:
            upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
            upd_nnf[:, :, :, 0] += 1
        elif d==1:
            upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
            upd_nnf[:, :, :, 1] += 1
        elif d==2:
            upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
            upd_nnf[:, :, :, 0] -= 1
        elif d==3:
            upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
            upd_nnf[:, :, :, 1] -= 1
        upd_nnf = self.clamp_bound(upd_nnf)
        return upd_nnf

    def shift_nnf(self, nnf, d):
        """Shift the NNF along the batch axis by d frames, repeating the edge
        frame so the batch size is unchanged."""
        if d>0:
            d = min(nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
        else:
            d = max(-nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
        return upd_nnf

    def track_step(self, nnf, d):
        """Temporal-tracking proposal: borrow NNFs from frames d steps away."""
        if self.use_pairwise_patch_error:
            # Preserve pair structure: shift even and odd frames independently.
            upd_nnf = cp.zeros_like(nnf)
            upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
            upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
        else:
            upd_nnf = self.shift_nnf(nnf, d)
        return upd_nnf

    def C(self, n, m):
        # not used
        # Binomial coefficient C(n, m) via factorial ratios.
        c = 1
        for i in range(1, n+1):
            c *= i
        for i in range(1, m+1):
            c //= i
        for i in range(1, n-m+1):
            c //= i
        return c

    def bezier_step(self, nnf, r):
        # not used
        # Blend NNFs of neighboring frames with binomial weights.
        n = r * 2 - 1
        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
            if d>0:
                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
            elif d<0:
                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
        return upd_nnf

    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
        """Accept proposed NNF entries wherever they lower the error."""
        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
        upd_idx = (upd_err < err)
        nnf[upd_idx] = upd_nnf[upd_idx]
        err[upd_idx] = upd_err[upd_idx]
        return nnf, err

    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
        # Try all four neighbor directions in random order.
        for d in cp.random.permutation(4):
            upd_nnf = self.neighboor_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for i in range(self.random_search_steps):
            upd_nnf = self.random_step(nnf, self.random_search_range)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
        # Temporal pass: test NNFs borrowed from up to tracking_window_size
        # frames away, in both directions.
        for d in range(1, self.tracking_window_size + 1):
            upd_nnf = self.track_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
            upd_nnf = self.track_step(nnf, -d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
        """One full PatchMatch sweep: propagation, random search, tracking."""
        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
        return nnf, err

    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
        """Refine ``nnf`` for num_iter sweeps, shrinking the patch size each
        sweep, and return the final NNF plus the remapped style image."""
        with cp.cuda.Device(self.gpu_id):
            source_guide = self.pad_image(source_guide)
            target_guide = self.pad_image(target_guide)
            source_style = self.pad_image(source_style)
            for it in range(self.num_iter):
                # Current sweep's patch size, read by the kernel wrappers above.
                self.patch_size = self.patch_size_list[it]
                target_style = self.apply_nnf_to_image(nnf, source_style)
                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
            return nnf, target_style
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class PyramidPatchMatcher:
    """Coarse-to-fine wrapper around PatchMatcher.

    Solves the NNF on an image pyramid, upscaling the field from each level
    to initialize the next, and returns numpy results from GPU computation.
    """

    def __init__(
        self, image_height, image_width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0,
        initialize="identity"
    ):
        # Use enough pyramid levels that the coarsest level can still hold
        # the largest patch PatchMatcher will use.
        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
        self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
        self.pyramid_heights = []
        self.pyramid_widths = []
        self.patch_matchers = []
        self.minimum_patch_size = minimum_patch_size
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        # "identity" or "random" — see initialize_nnf.
        self.initialize = initialize
        for level in range(self.pyramid_level):
            # Level 0 is the coarsest; each level doubles the resolution.
            height = image_height//(2**(self.pyramid_level - 1 - level))
            width = image_width//(2**(self.pyramid_level - 1 - level))
            self.pyramid_heights.append(height)
            self.pyramid_widths.append(width)
            self.patch_matchers.append(PatchMatcher(
                height, width, channel, minimum_patch_size=minimum_patch_size,
                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
                tracking_window_size=tracking_window_size
            ))

    def resample_image(self, images, level):
        """Downscale the batch to this pyramid level's resolution (CPU resize
        via OpenCV, then back to the GPU as float32)."""
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        images = images.get()
        images_resample = []
        for image in images:
            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
            images_resample.append(image_resample)
        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
        return images_resample

    def initialize_nnf(self, batch_size):
        """Build the coarsest-level NNF: random coordinates or the identity map."""
        if self.initialize == "random":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
            ], axis=3)
        elif self.initialize == "identity":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.repeat(cp.arange(height), width).reshape(height, width),
                cp.tile(cp.arange(width), height).reshape(height, width)
            ], axis=2)
            nnf = cp.stack([nnf] * batch_size)
        else:
            raise NotImplementedError()
        return nnf

    def update_nnf(self, nnf, level):
        """Upscale the previous level's NNF to initialize ``level``."""
        # upscale
        nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
        # NOTE(review): the two masks below are built from nnf.shape[0] (the
        # batch axis) even though they select rows/columns along axes 1 and 2;
        # shape[1] and shape[2] look intended — confirm against upstream
        # before changing.
        nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
        nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
        # check if scale is 2
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        # NOTE(review): this comparison also reads shape[0]/shape[1], although
        # after the repeat above nnf's spatial axes are 1 and 2 — verify.
        if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
            # Fall back to an interpolated resize when the level sizes are not
            # an exact factor of two apart.
            nnf = nnf.get().astype(np.float32)
            nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
            nnf = cp.array(np.stack(nnf), dtype=cp.int32)
            nnf = self.patch_matchers[level].clamp_bound(nnf)
        return nnf

    def apply_nnf_to_image(self, nnf, image):
        """Remap ``image`` with a finest-level NNF."""
        with cp.cuda.Device(self.gpu_id):
            image = self.patch_matchers[-1].pad_image(image)
            image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
        return image

    def estimate_nnf(self, source_guide, target_guide, source_style):
        """Run PatchMatch over the whole pyramid.

        Returns:
            (nnf, target_style) as numpy arrays at the finest resolution.
        """
        with cp.cuda.Device(self.gpu_id):
            # Accept numpy input; move to the GPU as float32 if needed.
            if not isinstance(source_guide, cp.ndarray):
                source_guide = cp.array(source_guide, dtype=cp.float32)
            if not isinstance(target_guide, cp.ndarray):
                target_guide = cp.array(target_guide, dtype=cp.float32)
            if not isinstance(source_style, cp.ndarray):
                source_style = cp.array(source_style, dtype=cp.float32)
            for level in range(self.pyramid_level):
                # Fresh NNF at the coarsest level; upscaled NNF afterwards.
                nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
                source_guide_ = self.resample_image(source_guide, level)
                target_guide_ = self.resample_image(target_guide, level)
                source_style_ = self.resample_image(source_style, level)
                nnf, target_style = self.patch_matchers[level].estimate_nnf(
                    source_guide_, target_guide_, source_style_, nnf
                )
            return nnf.get(), target_style.get()
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from ..patch_match import PyramidPatchMatcher
|
|
2
|
+
import os
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AccurateModeRunner:
    """FastBlend "accurate" mode: for every frame, remaps each neighbor inside
    the window onto it and averages the remapped results."""

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
        matcher = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=True,
            **ebsynth_config
        )
        num_frames = len(frames_style)
        for target in tqdm(range(num_frames), desc=desc):
            # Window of neighbors contributing to this target frame.
            left = max(target - window_size, 0)
            right = min(target + window_size + 1, num_frames)
            remapped_batches = []
            for batch_start in range(left, right, batch_size):
                batch_end = min(batch_start + batch_size, right)
                source_guide = np.stack([frames_guide[k] for k in range(batch_start, batch_end)])
                target_guide = np.stack([frames_guide[target]] * (batch_end - batch_start))
                source_style = np.stack([frames_style[k] for k in range(batch_start, batch_end)])
                _, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)
                remapped_batches.append(target_style)
            # Average all remapped neighbors into the final frame.
            blended = np.concatenate(remapped_batches, axis=0).mean(axis=0)
            blended = blended.clip(0, 255).astype("uint8")
            if save_path is not None:
                Image.fromarray(blended).save(os.path.join(save_path, "%05d.png" % target))
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from ..patch_match import PyramidPatchMatcher
|
|
2
|
+
import os
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BalancedModeRunner:
    """FastBlend "balanced" mode: accumulates remapped neighbors into each
    frame with running-average weights, saving a frame as soon as every
    neighbor in its window has contributed."""

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
        matcher = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # Enumerate every (source, target) remapping task inside the window.
        num_frames = len(frames_style)
        tasks = []
        for target in range(num_frames):
            for source in range(target - window_size, target + window_size + 1):
                if 0 <= source < num_frames and source != target:
                    tasks.append((source, target))
        # Running (accumulated frame, weight) per target; None = not started.
        frames = [(None, 1) for _ in range(num_frames)]
        for batch_start in tqdm(range(0, len(tasks), batch_size), desc=desc):
            batch = tasks[batch_start: min(batch_start + batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[source] for source, target in batch])
            target_guide = np.stack([frames_guide[target] for source, target in batch])
            source_style = np.stack([frames_style[source] for source, target in batch])
            _, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)
            for (source, target), result in zip(batch, target_style):
                frame, weight = frames[target]
                if frame is None:
                    frame = frames_style[target]
                frames[target] = (
                    frame * (weight / (weight + 1)) + result / (weight + 1),
                    weight + 1
                )
                if weight + 1 == min(num_frames, target + window_size + 1) - max(0, target - window_size):
                    # NOTE(review): this finalizes the pre-update local
                    # ``frame`` (the last contribution in frames[target] is not
                    # included) — preserved exactly; confirm whether
                    # frames[target][0] was intended.
                    frame = frame.clip(0, 255).astype("uint8")
                    if save_path is not None:
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
                    frames[target] = (None, 1)
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from ..patch_match import PyramidPatchMatcher
|
|
2
|
+
import functools, os
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TableManager:
|
|
9
|
+
    def __init__(self):
        # Stateless helper; all work happens in the methods below.
        pass
|
|
11
|
+
|
|
12
|
+
def task_list(self, n):
|
|
13
|
+
tasks = []
|
|
14
|
+
max_level = 1
|
|
15
|
+
while (1<<max_level)<=n:
|
|
16
|
+
max_level += 1
|
|
17
|
+
for i in range(n):
|
|
18
|
+
j = i
|
|
19
|
+
for level in range(max_level):
|
|
20
|
+
if i&(1<<level):
|
|
21
|
+
continue
|
|
22
|
+
j |= 1<<level
|
|
23
|
+
if j>=n:
|
|
24
|
+
break
|
|
25
|
+
meta_data = {
|
|
26
|
+
"source": i,
|
|
27
|
+
"target": j,
|
|
28
|
+
"level": level + 1
|
|
29
|
+
}
|
|
30
|
+
tasks.append(meta_data)
|
|
31
|
+
tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
|
|
32
|
+
return tasks
|
|
33
|
+
|
|
34
|
+
def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
|
|
35
|
+
n = len(frames_guide)
|
|
36
|
+
tasks = self.task_list(n)
|
|
37
|
+
remapping_table = [[(frames_style[i], 1)] for i in range(n)]
|
|
38
|
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
|
39
|
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
|
40
|
+
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
|
41
|
+
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
|
42
|
+
source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
|
|
43
|
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
|
44
|
+
for task, result in zip(tasks_batch, target_style):
|
|
45
|
+
target, level = task["target"], task["level"]
|
|
46
|
+
if len(remapping_table[target])==level:
|
|
47
|
+
remapping_table[target].append((result, 1))
|
|
48
|
+
else:
|
|
49
|
+
frame, weight = remapping_table[target][level]
|
|
50
|
+
remapping_table[target][level] = (
|
|
51
|
+
frame * (weight / (weight + 1)) + result / (weight + 1),
|
|
52
|
+
weight + 1
|
|
53
|
+
)
|
|
54
|
+
return remapping_table
|
|
55
|
+
|
|
56
|
+
def remapping_table_to_blending_table(self, table):
|
|
57
|
+
for i in range(len(table)):
|
|
58
|
+
for j in range(1, len(table[i])):
|
|
59
|
+
frame_1, weight_1 = table[i][j-1]
|
|
60
|
+
frame_2, weight_2 = table[i][j]
|
|
61
|
+
frame = (frame_1 + frame_2) / 2
|
|
62
|
+
weight = weight_1 + weight_2
|
|
63
|
+
table[i][j] = (frame, weight)
|
|
64
|
+
return table
|
|
65
|
+
|
|
66
|
+
def tree_query(self, leftbound, rightbound):
|
|
67
|
+
node_list = []
|
|
68
|
+
node_index = rightbound
|
|
69
|
+
while node_index>=leftbound:
|
|
70
|
+
node_level = 0
|
|
71
|
+
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
|
|
72
|
+
node_level += 1
|
|
73
|
+
node_list.append((node_index, node_level))
|
|
74
|
+
node_index -= 1<<node_level
|
|
75
|
+
return node_list
|
|
76
|
+
|
|
77
|
+
def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
|
|
78
|
+
n = len(blending_table)
|
|
79
|
+
tasks = []
|
|
80
|
+
frames_result = []
|
|
81
|
+
for target in range(n):
|
|
82
|
+
node_list = self.tree_query(max(target-window_size, 0), target)
|
|
83
|
+
for source, level in node_list:
|
|
84
|
+
if source!=target:
|
|
85
|
+
meta_data = {
|
|
86
|
+
"source": source,
|
|
87
|
+
"target": target,
|
|
88
|
+
"level": level
|
|
89
|
+
}
|
|
90
|
+
tasks.append(meta_data)
|
|
91
|
+
else:
|
|
92
|
+
frames_result.append(blending_table[target][level])
|
|
93
|
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
|
94
|
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
|
95
|
+
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
|
96
|
+
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
|
97
|
+
source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
|
|
98
|
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
|
99
|
+
for task, frame_2 in zip(tasks_batch, target_style):
|
|
100
|
+
source, target, level = task["source"], task["target"], task["level"]
|
|
101
|
+
frame_1, weight_1 = frames_result[target]
|
|
102
|
+
weight_2 = blending_table[source][level][1]
|
|
103
|
+
weight = weight_1 + weight_2
|
|
104
|
+
frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
|
|
105
|
+
frames_result[target] = (frame, weight)
|
|
106
|
+
return frames_result
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class FastModeRunner:
    """Runs FastBlend in fast mode: two windowed blending passes plus a merge."""

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
        """Blend the style frames over a sliding window and optionally save PNGs.

        frames_guide / frames_style are project frame containers exposing
        raw_data(); ebsynth_config is forwarded to PyramidPatchMatcher.
        Frames are written to save_path as 00000.png, 00001.png, ... when a
        path is given.
        """
        frames_guide = frames_guide.raw_data()
        frames_style = frames_style.raw_data()
        manager = TableManager()
        height = frames_style[0].shape[0]
        width = frames_style[0].shape[1]
        matcher = PyramidPatchMatcher(
            image_height=height,
            image_width=width,
            channel=3,
            **ebsynth_config
        )
        # Forward pass: blend each frame with the window preceding it.
        left = manager.build_remapping_table(frames_guide, frames_style, matcher, batch_size, desc="Fast Mode Step 1/4")
        left = manager.remapping_table_to_blending_table(left)
        left = manager.process_window_sum(frames_guide, left, matcher, window_size, batch_size, desc="Fast Mode Step 2/4")
        # Backward pass: same procedure on the reversed sequence, then re-reverse.
        right = manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], matcher, batch_size, desc="Fast Mode Step 3/4")
        right = manager.remapping_table_to_blending_table(right)
        right = manager.process_window_sum(frames_guide[::-1], right, matcher, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
        # Merge the two passes. weight_m = -1 subtracts one copy of the style
        # frame, which each pass seeds into its own result with weight 1.
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(left, frames_style, right):
            weight_m = -1
            total = weight_l + weight_m + weight_r
            blended = frame_l * (weight_l / total) + frame_m * (weight_m / total) + frame_r * (weight_r / total)
            frames.append(blended.clip(0, 255).astype("uint8"))
        if save_path is not None:
            for index, frame in enumerate(frames):
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % index))
|