diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,146 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
def read_video(file_name):
    """Decode every frame of a video into memory.

    Args:
        file_name: Location of the video, passed straight to
            ``imageio.get_reader``.

    Returns:
        list: One ``numpy.ndarray`` per frame, in the order the reader
        yields them.
    """
    reader = imageio.get_reader(file_name)
    try:
        return [np.array(frame) for frame in reader]
    finally:
        # Release the reader even when iterating over the frames raises.
        reader.close()
14
+
15
+
16
def get_video_fps(file_name):
    """Return the ``"fps"`` entry of a video's metadata.

    Args:
        file_name: Location of the video, passed straight to
            ``imageio.get_reader``.

    Returns:
        The value stored under ``"fps"`` in ``reader.get_meta_data()``.

    Raises:
        KeyError: If the metadata has no ``"fps"`` entry.
    """
    reader = imageio.get_reader(file_name)
    try:
        return reader.get_meta_data()["fps"]
    finally:
        # Close the reader even when the metadata lookup fails.
        reader.close()
21
+
22
+
23
def save_video(frames_path, video_path, num_frames, fps):
    """Assemble numbered PNG frames into a video file.

    Frames are read from ``frames_path`` as ``00000.png``, ``00001.png``, ...
    up to ``num_frames - 1`` and appended in that order.

    Args:
        frames_path: Directory that holds the numbered PNG frames.
        video_path: Destination passed to ``imageio.get_writer``.
        num_frames: Number of frames to append.
        fps: Frame rate handed to the writer.

    Returns:
        ``video_path``, unchanged.
    """
    writer = imageio.get_writer(video_path, fps=fps, quality=9)
    try:
        for i in range(num_frames):
            frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
            writer.append_data(frame)
    finally:
        # Always finalize the writer, e.g. when a frame file is missing,
        # so the output handle is not left open.
        writer.close()
    return video_path
30
+
31
+
32
class LowMemoryVideo:
    """Sequence-like view of a video that decodes frames on demand.

    Only the ``imageio`` reader is kept; each ``__getitem__`` call asks the
    reader for a single frame.
    """

    def __init__(self, file_name):
        """Open ``file_name`` with ``imageio.get_reader``."""
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        """Return the frame count reported by the reader."""
        return self.reader.count_frames()

    def __getitem__(self, item):
        """Return frame ``item`` as a ``numpy.ndarray``."""
        return np.array(self.reader.get_data(item))

    def __del__(self):
        # ``__del__`` also runs for instances whose ``__init__`` raised before
        # ``self.reader`` was assigned; skip the close in that case instead of
        # raising AttributeError during garbage collection.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.close()
44
+
45
+
46
def split_file_name(file_name):
    """Break a file name into single characters and whole integers.

    Each maximal run of ASCII digits ``0``-``9`` becomes one ``int``; every
    other character is kept as a one-character string.

    Args:
        file_name: The string to split.

    Returns:
        tuple: The pieces in their original order.
    """
    pieces = []
    value = None  # integer accumulated for the current digit run, if any
    for ch in file_name:
        if "0" <= ch <= "9":
            digit = ord(ch) - ord("0")
            value = digit if value is None else value * 10 + digit
            continue
        if value is not None:
            # A non-digit ends the current run.
            pieces.append(value)
            value = None
        pieces.append(ch)
    if value is not None:
        pieces.append(value)
    return tuple(pieces)
63
+
64
+
65
def search_for_images(folder):
    """List the ``.jpg`` / ``.png`` files of a folder in natural order.

    Names are ordered by the pieces returned by ``split_file_name``, so digit
    runs compare numerically (``2.png`` sorts before ``10.png``). Names whose
    pieces are all equal fall back to plain string comparison.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        list: ``folder`` joined with each matching file name, sorted.
    """
    def sort_key(file_name):
        # Tag each piece with its type. Comparing a bare int with a str
        # raises TypeError in Python 3, which happens as soon as two names
        # differ by type at the same position ("a1.png" vs "ab.png").
        # With the tags, numbers order before characters and same-type
        # pieces compare exactly as they did before.
        return tuple(
            (0, piece, "") if isinstance(piece, int) else (1, 0, piece)
            for piece in split_file_name(file_name)
        )

    names = [name for name in os.listdir(folder) if name.endswith((".jpg", ".png"))]
    names.sort(key=lambda name: (sort_key(name), name))
    return [os.path.join(folder, name) for name in names]
71
+
72
+
73
def read_images(folder):
    """Load every image that ``search_for_images`` finds in ``folder``.

    Returns:
        list: One ``numpy.ndarray`` per image, in the order returned by
        ``search_for_images``.
    """
    frames = []
    for path in search_for_images(folder):
        frames.append(np.array(Image.open(path)))
    return frames
77
+
78
+
79
class LowMemoryImageFolder:
    """Sequence-like view of an image folder that opens files on demand."""

    def __init__(self, folder, file_list=None):
        """Record the image paths to serve.

        Args:
            folder: Directory that contains the images.
            file_list: Optional file names relative to ``folder``. When given,
                they are used in the given order; otherwise the folder is
                scanned with ``search_for_images``.
        """
        if file_list is not None:
            self.file_list = [os.path.join(folder, name) for name in file_list]
        else:
            self.file_list = search_for_images(folder)

    def __len__(self):
        """Return the number of recorded image paths."""
        return len(self.file_list)

    def __getitem__(self, item):
        """Open image ``item`` and return it as a ``numpy.ndarray``."""
        path = self.file_list[item]
        return np.array(Image.open(path))

    def __del__(self):
        # Nothing is kept open between calls, so there is nothing to release.
        pass
94
+
95
+
96
class VideoData:
    """Uniform frame access over either a video file or an image folder.

    An optional length override and an optional target shape can be set after
    construction; frames whose size differs from the target shape are resized
    when they are read.
    """

    def __init__(self, video_file, image_folder, **kwargs):
        """Pick the backing source.

        Args:
            video_file: Video location, or ``None``. Takes precedence over
                ``image_folder`` when both are given.
            image_folder: Image directory, or ``None``.
            **kwargs: Forwarded to the chosen backing class.

        Raises:
            ValueError: If both ``video_file`` and ``image_folder`` are ``None``.
        """
        if video_file is None and image_folder is None:
            raise ValueError("Cannot open video or image folder")
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        else:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        # Overrides; ``None`` means "use what the source provides".
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        """Return all frames (after any resize) as a list."""
        return [self.__getitem__(i) for i in range(self.__len__())]

    def set_length(self, length):
        """Override the value reported by ``__len__``."""
        self.length = length

    def set_shape(self, height, width):
        """Set the frame size that ``__getitem__`` resizes to."""
        self.height = height
        self.width = width

    def __len__(self):
        """Return the length override if set, else the source length."""
        if self.length is not None:
            return self.length
        return len(self.data)

    def shape(self):
        """Return ``(height, width)``: the override, or frame 0's size."""
        if self.height is None or self.width is None:
            height, width, _ = self.__getitem__(0).shape
            return height, width
        return self.height, self.width

    def __getitem__(self, item):
        """Return frame ``item``, resized to the target shape when one is set."""
        frame = self.data[item]
        height, width, _ = frame.shape
        has_target = self.height is not None and self.width is not None
        if has_target and (self.height != height or self.width != width):
            # PIL's resize takes (width, height).
            resized = Image.fromarray(frame).resize((self.width, self.height))
            frame = np.array(resized)
        return frame

    def __del__(self):
        pass
@@ -0,0 +1,298 @@
1
+ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
2
+ import numpy as np
3
+ import cupy as cp
4
+ import cv2
5
+
6
+
7
class PatchMatcher:
    """PatchMatch-style nearest-neighbor-field (NNF) solver for one resolution.

    The NNF is an integer array of shape ``(batch, height, width, 2)``; channel
    0 is clamped to ``[0, height - 1]`` and channel 1 to ``[0, width - 1]``.
    Images are 4-D arrays padded by ``pad_size`` on axes 1 and 2 before they
    are handed to the CUDA kernels imported from ``cupy_kernels``.
    """

    def __init__(
        self, height, width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        random_search_steps=3, random_search_range=4,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0
    ):
        """Store the configuration and derive patch sizes and launch geometry.

        Args:
            height: Unpadded image height at this level.
            width: Unpadded image width at this level.
            channel: Number of image channels.
            minimum_patch_size: Patch size used in the last iteration.
            threads_per_block: Edge length of the square CUDA block.
            num_iter: Number of iterations; iteration ``k`` uses patch size
                ``minimum_patch_size + 2 * (num_iter - 1 - k)``.
            gpu_id: CUDA device used by ``estimate_nnf``.
            guide_weight: Multiplier applied to the guide error.
            random_search_steps: Random-search rounds per iteration.
            random_search_range: Maximum absolute random offset per axis.
            use_mean_target_style: Replace the target style with the batch
                mean of the remapped source styles when computing the error.
            use_pairwise_patch_error: Use the pairwise kernel for the style
                error (the batch is treated as consecutive pairs).
            tracking_window_size: Largest batch-axis shift tried by ``track``.
        """
        self.height = height
        self.width = width
        self.channel = channel
        self.minimum_patch_size = minimum_patch_size
        self.threads_per_block = threads_per_block
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.guide_weight = guide_weight
        self.random_search_steps = random_search_steps
        self.random_search_range = random_search_range
        self.use_mean_target_style = use_mean_target_style
        self.use_pairwise_patch_error = use_pairwise_patch_error
        self.tracking_window_size = tracking_window_size

        # Largest patch first, shrinking by 2 each iteration.
        self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
        self.pad_size = self.patch_size_list[0] // 2
        # ``estimate_nnf`` leaves ``patch_size`` at the last list entry.
        # Start from the same value so ``apply_nnf_to_image`` can be called
        # before any estimation instead of failing with AttributeError.
        self.patch_size = self.patch_size_list[-1]
        # Ceil-divide so the grid covers every pixel.
        self.grid = (
            (height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block
        )
        self.block = (threads_per_block, threads_per_block)

    def pad_image(self, image):
        """Pad axes 1 and 2 of a 4-D image batch by ``pad_size`` on each side."""
        return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))

    def unpad_image(self, image):
        """Remove the border added by ``pad_image``."""
        pad = self.pad_size
        # Explicit stop indices: ``-pad`` would be ``-0`` for a zero pad and
        # slice everything away.
        return image[:, pad: image.shape[1] - pad, pad: image.shape[2] - pad, :]

    def apply_nnf_to_image(self, nnf, source):
        """Run ``remapping_kernel`` on a padded source batch.

        Returns:
            A float32 array of shape
            ``(batch, height + 2 * pad_size, width + 2 * pad_size, channel)``
            filled by the kernel (presumably the source remapped through
            ``nnf`` - the kernel is defined in ``cupy_kernels``).
        """
        batch_size = source.shape[0]
        target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
        remapping_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
        )
        return target

    def get_patch_error(self, source, nnf, target):
        """Run ``patch_error_kernel``; returns a ``(batch, height, width)`` float32 error map."""
        batch_size = source.shape[0]
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
        )
        return error

    def get_pairwise_patch_error(self, source, nnf):
        """Run ``pairwise_patch_error_kernel`` on even/odd batch entries.

        Even-indexed and odd-indexed entries form the two sides of each pair.
        The per-pair error is repeated so the result has one map per entry of
        an even-sized batch.
        """
        batch_size = source.shape[0]//2
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        # Strided slices are copied before being passed to the kernel.
        source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
        source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
        pairwise_patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
        )
        error = error.repeat(2, axis=0)
        return error

    def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
        """Return ``guide_error * guide_weight + style_error`` for ``nnf``."""
        error_guide = self.get_patch_error(source_guide, nnf, target_guide)
        if self.use_mean_target_style:
            # Compare against the batch mean of the remapped styles instead of
            # the target style that was passed in.
            target_style = self.apply_nnf_to_image(nnf, source_style)
            target_style = target_style.mean(axis=0, keepdims=True)
            target_style = target_style.repeat(source_guide.shape[0], axis=0)
        if self.use_pairwise_patch_error:
            error_style = self.get_pairwise_patch_error(source_style, nnf)
        else:
            error_style = self.get_patch_error(source_style, nnf, target_style)
        error = error_guide * self.guide_weight + error_style
        return error

    def clamp_bound(self, nnf):
        """Clip both NNF channels to the image bounds, in place, and return ``nnf``."""
        nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
        nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
        return nnf

    def random_step(self, nnf, r):
        """Return ``nnf`` plus a uniform random offset in ``[-r, r]`` per entry, clamped."""
        batch_size = nnf.shape[0]
        step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
        upd_nnf = self.clamp_bound(nnf + step)
        return upd_nnf

    def neighboor_step(self, nnf, d):
        """Build a candidate NNF from one of the four adjacent pixels.

        Args:
            nnf: Current field.
            d: Direction. 0 and 2 shift along axis 1 and adjust channel 0 by
                +1 / -1; 1 and 3 shift along axis 2 and adjust channel 1 by
                +1 / -1. The border row/column is duplicated by the shift.

        Raises:
            ValueError: If ``d`` is not one of 0, 1, 2, 3.
        """
        if d==0:
            upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
            upd_nnf[:, :, :, 0] += 1
        elif d==1:
            upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
            upd_nnf[:, :, :, 1] += 1
        elif d==2:
            upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
            upd_nnf[:, :, :, 0] -= 1
        elif d==3:
            upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
            upd_nnf[:, :, :, 1] -= 1
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError("neighboor_step direction must be 0, 1, 2 or 3")
        upd_nnf = self.clamp_bound(upd_nnf)
        return upd_nnf

    def shift_nnf(self, nnf, d):
        """Shift the field along the batch axis by ``d`` entries.

        For ``d > 0`` entry ``i`` receives entry ``i + d`` and the tail is
        filled with the last entry; for ``d < 0`` entry ``i`` receives entry
        ``i + d`` and the head is filled with the first entry. Callers in this
        class only pass non-zero ``d``.
        """
        if d>0:
            d = min(nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
        else:
            d = max(-nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
        return upd_nnf

    def track_step(self, nnf, d):
        """Batch-axis shift; even and odd entries are shifted separately in pairwise mode."""
        if self.use_pairwise_patch_error:
            upd_nnf = cp.zeros_like(nnf)
            upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
            upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
        else:
            upd_nnf = self.shift_nnf(nnf, d)
        return upd_nnf

    def C(self, n, m):
        """Return ``n! // (m! * (n - m)!)`` using integer arithmetic."""
        # not used
        c = 1
        for i in range(1, n+1):
            c *= i
        for i in range(1, m+1):
            c //= i
        for i in range(1, n-m+1):
            c //= i
        return c

    def bezier_step(self, nnf, r):
        """Blend batch-shifted copies of ``nnf`` with binomial weights."""
        # not used
        n = r * 2 - 1
        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
            if d>0:
                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
            elif d<0:
                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
        return upd_nnf

    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
        """Accept entries of ``upd_nnf`` whose error is strictly lower.

        ``nnf`` and ``err`` are modified in place and returned.
        """
        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
        upd_idx = (upd_err < err)
        nnf[upd_idx] = upd_nnf[upd_idx]
        err[upd_idx] = upd_err[upd_idx]
        return nnf, err

    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
        """Try the four neighbor candidates in a random order."""
        for d in cp.random.permutation(4):
            upd_nnf = self.neighboor_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
        """Try ``random_search_steps`` random candidates."""
        for i in range(self.random_search_steps):
            upd_nnf = self.random_step(nnf, self.random_search_range)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
        """Try batch-axis shifts of ``+d`` and ``-d`` for ``d`` in 1..``tracking_window_size``."""
        for d in range(1, self.tracking_window_size + 1):
            upd_nnf = self.track_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
            upd_nnf = self.track_step(nnf, -d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
        """One refinement round: propagation, random search, then tracking."""
        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
        return nnf, err

    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
        """Refine ``nnf`` over ``num_iter`` rounds with shrinking patch sizes.

        Args:
            source_guide, target_guide, source_style: Unpadded image batches
                at this level's resolution.
            nnf: Initial field; refined in place by the update steps.

        Returns:
            tuple: ``(nnf, target_style)`` where ``target_style`` is the
            unpadded output of ``apply_nnf_to_image`` for the final field.
        """
        with cp.cuda.Device(self.gpu_id):
            source_guide = self.pad_image(source_guide)
            target_guide = self.pad_image(target_guide)
            source_style = self.pad_image(source_style)
            for it in range(self.num_iter):
                # The kernels read the current patch size from the instance.
                self.patch_size = self.patch_size_list[it]
                target_style = self.apply_nnf_to_image(nnf, source_style)
                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
        return nnf, target_style
203
+
204
+
205
class PyramidPatchMatcher:
    """Coarse-to-fine NNF estimation over a stack of ``PatchMatcher`` levels.

    Level 0 is the coarsest; each following level doubles the divisor-based
    resolution until the last level works at the full image size.
    """

    def __init__(
        self, image_height, image_width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0,
        initialize="identity"
    ):
        """Create one ``PatchMatcher`` per pyramid level.

        Args:
            image_height: Full-resolution image height.
            image_width: Full-resolution image width.
            channel: Number of image channels.
            minimum_patch_size: Smallest patch size used by each level.
            threads_per_block, num_iter, gpu_id, guide_weight,
            use_mean_target_style, use_pairwise_patch_error,
            tracking_window_size: Forwarded to every ``PatchMatcher``.
            initialize: ``"identity"`` or ``"random"``; how the coarsest NNF
                is created.
        """
        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
        # Keep at least the full-resolution level. For images smaller than
        # twice the largest patch the logarithm truncates to 0 (or below),
        # which left no matcher at all and made ``estimate_nnf`` fail on an
        # unbound ``nnf``.
        self.pyramid_level = max(1, int(np.log2(min(image_height, image_width) / maximum_patch_size)))
        self.pyramid_heights = []
        self.pyramid_widths = []
        self.patch_matchers = []
        self.minimum_patch_size = minimum_patch_size
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.initialize = initialize
        for level in range(self.pyramid_level):
            height = image_height//(2**(self.pyramid_level - 1 - level))
            width = image_width//(2**(self.pyramid_level - 1 - level))
            self.pyramid_heights.append(height)
            self.pyramid_widths.append(width)
            self.patch_matchers.append(PatchMatcher(
                height, width, channel, minimum_patch_size=minimum_patch_size,
                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
                tracking_window_size=tracking_window_size
            ))

    def resample_image(self, images, level):
        """Resize a device image batch to the resolution of ``level``.

        The batch is copied to the host, resized image by image with
        ``cv2.INTER_AREA``, and uploaded again as float32.
        """
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        images = images.get()
        images_resample = []
        for image in images:
            # cv2.resize takes (width, height).
            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
            images_resample.append(image_resample)
        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
        return images_resample

    def initialize_nnf(self, batch_size):
        """Create the int32 NNF for the coarsest level.

        ``"random"`` draws uniform coordinates; ``"identity"`` maps every
        pixel to itself.

        Raises:
            NotImplementedError: For any other ``initialize`` value.
        """
        if self.initialize == "random":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
            ], axis=3)
        elif self.initialize == "identity":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            # Explicit int32 so the field has the same dtype as the random
            # initializer and the resize path of ``update_nnf`` rather than
            # the platform default integer.
            nnf = cp.stack([
                cp.repeat(cp.arange(height, dtype=cp.int32), width).reshape(height, width),
                cp.tile(cp.arange(width, dtype=cp.int32), height).reshape(height, width)
            ], axis=2)
            nnf = cp.stack([nnf] * batch_size)
        else:
            raise NotImplementedError()
        return nnf

    def update_nnf(self, nnf, level):
        """Upscale the previous level's NNF to the resolution of ``level``.

        Every entry is duplicated 2x2 and its coordinates doubled; entries on
        odd rows / columns then point one pixel further along the matching
        axis, so an identity field stays an identity field. If the level is
        not exactly twice the previous size, the field is resized on the host
        with bilinear interpolation. The result is int32 and clamped to the
        level's bounds.
        """
        # upscale
        nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
        # Axes 1 and 2 are the spatial axes (axis 0 is the batch).
        nnf[:, 1::2, :, 0] += 1
        nnf[:, :, 1::2, 1] += 1
        # check if scale is 2
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        if height != nnf.shape[1] or width != nnf.shape[2]:
            nnf = nnf.get().astype(np.float32)
            nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
            nnf = cp.array(np.stack(nnf), dtype=cp.int32)
        else:
            # Same dtype as the resize path, without leaving the device.
            nnf = nnf.astype(np.int32, copy=False)
        nnf = self.patch_matchers[level].clamp_bound(nnf)
        return nnf

    def apply_nnf_to_image(self, nnf, image):
        """Pad ``image`` and pass it through the finest level's ``apply_nnf_to_image``.

        The returned array is still padded.
        """
        with cp.cuda.Device(self.gpu_id):
            image = self.patch_matchers[-1].pad_image(image)
            image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
        return image

    def estimate_nnf(self, source_guide, target_guide, source_style):
        """Estimate the NNF from coarse to fine.

        Args:
            source_guide, target_guide, source_style: Image batches; anything
                that is not already a CuPy array is uploaded as float32.

        Returns:
            tuple: Host copies of the finest-level NNF and of the remapped
            style returned by the finest ``PatchMatcher``.
        """
        with cp.cuda.Device(self.gpu_id):
            if not isinstance(source_guide, cp.ndarray):
                source_guide = cp.array(source_guide, dtype=cp.float32)
            if not isinstance(target_guide, cp.ndarray):
                target_guide = cp.array(target_guide, dtype=cp.float32)
            if not isinstance(source_style, cp.ndarray):
                source_style = cp.array(source_style, dtype=cp.float32)
            for level in range(self.pyramid_level):
                nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
                source_guide_ = self.resample_image(source_guide, level)
                target_guide_ = self.resample_image(target_guide, level)
                source_style_ = self.resample_image(source_style, level)
                nnf, target_style = self.patch_matchers[level].estimate_nnf(
                    source_guide_, target_guide_, source_style_, nnf
                )
        return nnf.get(), target_style.get()
@@ -0,0 +1,4 @@
1
+ from .accurate import AccurateModeRunner
2
+ from .fast import FastModeRunner
3
+ from .balanced import BalancedModeRunner
4
+ from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
@@ -0,0 +1,35 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
class AccurateModeRunner:
    """Blend each frame with every neighbor inside a sliding window."""

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
        """Remap all frames of each window onto its center frame and average them.

        Args:
            frames_guide: Indexable guide frames.
            frames_style: Indexable style frames; its length sets the number
                of output frames.
            batch_size: Number of source frames matched per engine call.
            window_size: Frames taken on each side of the target (the target
                itself is included).
            ebsynth_config: Extra keyword arguments for ``PyramidPatchMatcher``.
            desc: Progress-bar label.
            save_path: Directory for ``%05d.png`` outputs; nothing is written
                when it is ``None``.
        """
        style_shape = frames_style[0].shape
        engine = PyramidPatchMatcher(
            image_height=style_shape[0],
            image_width=style_shape[1],
            channel=3,
            use_mean_target_style=True,
            **ebsynth_config
        )
        total = len(frames_style)
        for target in tqdm(range(total), desc=desc):
            # Half-open window [lo, hi) clipped to the sequence.
            lo = max(target - window_size, 0)
            hi = min(target + window_size + 1, total)
            remapped = []
            for start in range(lo, hi, batch_size):
                stop = min(start + batch_size, hi)
                sources = range(start, stop)
                source_guide = np.stack([frames_guide[index] for index in sources])
                target_guide = np.stack([frames_guide[target]] * (stop - start))
                source_style = np.stack([frames_style[index] for index in sources])
                _, target_style = engine.estimate_nnf(source_guide, target_guide, source_style)
                remapped.append(target_style)
            # Plain mean over every remapped frame of the window.
            blended = np.concatenate(remapped, axis=0).mean(axis=0)
            blended = blended.clip(0, 255).astype("uint8")
            if save_path is not None:
                Image.fromarray(blended).save(os.path.join(save_path, "%05d.png" % target))
@@ -0,0 +1,46 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
class BalancedModeRunner:
    """Blend each frame with its window neighbors using a running average."""

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
        """Remap every neighbor onto each target and average incrementally.

        All ``(source, target)`` pairs with ``source != target`` inside the
        window are processed in batches. A target is written out as soon as
        its last neighbor has been folded into the running average.

        Args:
            frames_guide: Indexable guide frames.
            frames_style: Indexable style frames; its length sets the number
                of targets.
            batch_size: Number of pairs matched per engine call.
            window_size: Neighbors considered on each side of a target.
            ebsynth_config: Extra keyword arguments for ``PyramidPatchMatcher``.
            desc: Progress-bar label.
            save_path: Directory for ``%05d.png`` outputs; nothing is written
                when it is ``None``.
        """
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # tasks
        n = len(frames_style)
        tasks = []
        for target in range(n):
            for source in range(target - window_size, target + window_size + 1):
                if source >= 0 and source < n and source != target:
                    tasks.append((source, target))
        # run
        # Per target: (running average or None, number of frames averaged).
        # The count starts at 1 because the target's own style frame seeds
        # the average.
        frames = [(None, 1) for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for (source, target), result in zip(tasks_batch, target_style):
                frame, weight = frames[target]
                if frame is None:
                    frame = frames_style[target]
                blended = frame * (weight / (weight + 1)) + result / (weight + 1)
                weight += 1
                # Window size for this target, clipped to the sequence.
                if weight == min(n, target + window_size + 1) - max(0, target - window_size):
                    # Save the average that includes this last result. The
                    # earlier code wrote the pre-blend frame and silently
                    # dropped the final neighbor.
                    output = blended.clip(0, 255).astype("uint8")
                    if save_path is not None:
                        Image.fromarray(output).save(os.path.join(save_path, "%05d.png" % target))
                    # Free the accumulated frame once the target is done.
                    frames[target] = (None, 1)
                else:
                    frames[target] = (blended, weight)
@@ -0,0 +1,141 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import functools, os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
class TableManager:
    """Builds and queries the per-frame level tables used by the fast mode."""

    def __init__(self):
        pass

    def task_list(self, n):
        """Enumerate the ``source -> target`` remapping tasks for ``n`` frames.

        For each source index, the zero bits of the index are set one after
        another from the lowest bit upward; every intermediate value below
        ``n`` becomes a target, tagged with ``level = bit position + 1``.

        Returns:
            list: Dicts with keys ``"source"``, ``"target"`` and ``"level"``,
            ordered by level (generation order is kept within a level).
        """
        max_level = 1
        while (1 << max_level) <= n:
            max_level += 1
        tasks = []
        for source in range(n):
            target = source
            for level in range(max_level):
                bit = 1 << level
                if source & bit:
                    continue
                target |= bit
                if target >= n:
                    break
                tasks.append({
                    "source": source,
                    "target": target,
                    "level": level + 1
                })
        # list.sort is stable, so tasks of equal level keep their order.
        tasks.sort(key=lambda task: task["level"])
        return tasks

    def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
        """Run every task of ``task_list`` and average the results per target and level.

        Returns:
            list: For each frame, a list of ``(frame, weight)`` tuples indexed
            by level; level 0 is the frame's own style with weight 1.
        """
        frame_count = len(frames_guide)
        tasks = self.task_list(frame_count)
        remapping_table = [[(frames_style[index], 1)] for index in range(frame_count)]
        for start in tqdm(range(0, len(tasks), batch_size), desc=desc):
            batch = tasks[start: min(start + batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in batch])
            source_style = np.stack([frames_style[task["source"]] for task in batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, remapped in zip(batch, target_style):
                row = remapping_table[task["target"]]
                level = task["level"]
                if len(row) == level:
                    # First result for this level opens a new slot.
                    row.append((remapped, 1))
                    continue
                # Otherwise fold the result into the running average.
                frame, weight = row[level]
                row[level] = (
                    frame * (weight / (weight + 1)) + remapped / (weight + 1),
                    weight + 1
                )
        return remapping_table

    def remapping_table_to_blending_table(self, table):
        """Accumulate the levels of each row, in place.

        Entry ``j`` becomes the mean of (already accumulated) entry ``j - 1``
        and entry ``j``, with the two weights added. Returns ``table``.
        """
        for row in table:
            for level in range(1, len(row)):
                lower_frame, lower_weight = row[level - 1]
                upper_frame, upper_weight = row[level]
                row[level] = ((lower_frame + upper_frame) / 2, lower_weight + upper_weight)
        return table

    def tree_query(self, leftbound, rightbound):
        """Cover ``[leftbound, rightbound]`` with ``(index, level)`` nodes.

        Starting at ``rightbound``, each node takes the highest level whose
        ``2 ** level`` span still fits at or after ``leftbound`` (limited by
        the trailing one-bits of the index), then the walk jumps left by that
        span.

        Returns:
            list: ``(node_index, node_level)`` tuples, right to left.
        """
        nodes = []
        index = rightbound
        while index >= leftbound:
            level = 0
            while (1 << level) & index and index - (1 << (level + 1)) + 1 >= leftbound:
                level += 1
            nodes.append((index, level))
            index -= 1 << level
        return nodes

    def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
        """Combine table nodes into one weighted frame per target.

        For each target the nodes covering ``[target - window_size, target]``
        are looked up; the target's own node seeds the result and every other
        node is remapped onto the target and blended in by weight.

        Returns:
            list: One ``(frame, weight)`` tuple per target.
        """
        total = len(blending_table)
        tasks = []
        frames_result = []
        for target in range(total):
            for source, level in self.tree_query(max(target - window_size, 0), target):
                if source == target:
                    frames_result.append(blending_table[target][level])
                else:
                    tasks.append({
                        "source": source,
                        "target": target,
                        "level": level
                    })
        for start in tqdm(range(0, len(tasks), batch_size), desc=desc):
            batch = tasks[start: min(start + batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in batch])
            source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, remapped in zip(batch, target_style):
                source, target, level = task["source"], task["target"], task["level"]
                current_frame, current_weight = frames_result[target]
                added_weight = blending_table[source][level][1]
                total_weight = current_weight + added_weight
                merged = current_frame * (current_weight / total_weight) + remapped * (added_weight / total_weight)
                frames_result[target] = (merged, total_weight)
        return frames_result
107
+
108
+
109
class FastModeRunner:
    """Window blending built on the level tables of ``TableManager``."""

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
        """Blend every frame with its neighbors on both sides.

        The table pipeline runs once on the sequence as given (covering the
        frames before each target) and once on the reversed sequence
        (covering the frames after it); the two results are then merged.

        Args:
            frames_guide: Object whose ``raw_data()`` returns the guide frames.
            frames_style: Object whose ``raw_data()`` returns the style frames.
            batch_size: Number of tasks matched per engine call.
            window_size: Neighbors considered on each side of a target.
            ebsynth_config: Extra keyword arguments for ``PyramidPatchMatcher``.
            save_path: Directory for ``%05d.png`` outputs; nothing is written
                when it is ``None``.
        """
        frames_guide = frames_guide.raw_data()
        frames_style = frames_style.raw_data()
        manager = TableManager()
        engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        left = manager.build_remapping_table(frames_guide, frames_style, engine, batch_size, desc="Fast Mode Step 1/4")
        left = manager.remapping_table_to_blending_table(left)
        left = manager.process_window_sum(frames_guide, left, engine, window_size, batch_size, desc="Fast Mode Step 2/4")
        # right part: same pipeline on the reversed sequence, flipped back at the end
        right = manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], engine, batch_size, desc="Fast Mode Step 3/4")
        right = manager.remapping_table_to_blending_table(right)
        right = manager.process_window_sum(frames_guide[::-1], right, engine, window_size, batch_size, desc="Fast Mode Step 4/4")
        right = right[::-1]
        # merge
        merged = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(left, frames_style, right):
            # Both halves include the frame's own style once, so it is
            # subtracted once here.
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            merged.append(
                frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            )
        merged = [frame.clip(0, 255).astype("uint8") for frame in merged]
        if save_path is not None:
            for target, frame in enumerate(merged):
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))