diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,121 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
class InterpolationModeRunner:
    """Render guide frames by blending the styles of the neighbouring keyframes.

    Every target frame is stylized twice with ``PyramidPatchMatcher`` (once
    from the keyframe on its left, once from the keyframe on its right) and
    the two results are mixed with weights derived from the frame distances.
    """

    def __init__(self):
        pass

    def get_index_dict(self, index_style):
        """Map each keyframe index to its position inside ``index_style``."""
        return {index: position for position, index in enumerate(index_style)}

    def get_weight(self, l, m, r):
        """Return the blending weights ``(weight_l, weight_r)`` for frame ``m``.

        The weight of a keyframe grows as ``m`` gets closer to it. When
        ``l == m == r`` both weights are 0.5.
        """
        dist_to_r, dist_to_l = abs(m - r), abs(m - l)
        total = dist_to_r + dist_to_l
        if total == 0:
            return 0.5, 0.5
        return dist_to_r / total, dist_to_l / total

    def get_task_group(self, index_style, n):
        """Split the ``n`` frames into groups of ``(l, m, r)`` tasks.

        ``m`` is the frame to render, ``l`` and ``r`` are the keyframes used
        as style sources. Frames before the first keyframe and from the last
        keyframe onwards use that single keyframe on both sides. A group may
        be empty (for example when two keyframe indices are equal).
        """
        index_style = sorted(index_style)
        first, last = index_style[0], index_style[-1]
        task_group = []
        # Frames that precede the first keyframe.
        if first > 0:
            task_group.append([(first, m, first) for m in range(first)])
        # Frames between two consecutive keyframes.
        for l, r in zip(index_style[:-1], index_style[1:]):
            task_group.append([(l, m, r) for m in range(l, r)])
        # The last keyframe and everything after it.
        task_group.append([(last, m, last) for m in range(last, n)])
        return task_group

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        """Render all frames and, if ``save_path`` is given, write them as PNGs.

        Output files are named ``"%05d.png" % m`` inside ``save_path``. When
        ``save_path`` is None the frames are still computed but not written.
        """
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=False,
            use_pairwise_patch_error=True,
            **ebsynth_config
        )
        # Plan the work.
        index_dict = self.get_index_dict(index_style)
        task_group = self.get_task_group(index_style, len(frames_guide))
        # Execute it group by group.
        for tasks in task_group:
            if not tasks:
                # An empty group has nothing to render; min()/max() below
                # would raise ValueError on it.
                continue
            index_start = min(task[1] for task in tasks)
            index_end = max(task[1] for task in tasks)
            for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
                tasks_batch = tasks[batch_id: min(batch_id + batch_size, len(tasks))]
                source_guide, target_guide, source_style = [], [], []
                for l, m, r in tasks_batch:
                    # Even slots hold the l -> m pair, odd slots the r -> m pair.
                    for key_index in (l, r):
                        source_guide.append(frames_guide[key_index])
                        target_guide.append(frames_guide[m])
                        source_style.append(frames_style[index_dict[key_index]])
                source_guide = np.stack(source_guide)
                target_guide = np.stack(target_guide)
                source_style = np.stack(source_style)
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                if save_path is not None:
                    for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
                        weight_l, weight_r = self.get_weight(l, m, r)
                        frame = frame_l * weight_l + frame_r * weight_r
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
85
+
86
+
87
class InterpolationModeSingleFrameRunner:
    """Render every guide frame from a single stylized keyframe.

    Frames are processed in overlapping windows of ``batch_size`` frames.
    """

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        """Stylize all frames and, if ``save_path`` is given, write them as PNGs.

        ``frames_style[0]`` is used as the style image and
        ``frames_guide[index_style[0]]`` as its guide. Output files are named
        ``"%05d.png" % frame_id`` inside ``save_path``; when ``save_path`` is
        None the frames are computed but not written.

        Raises:
            ValueError: if ``batch_size`` is not larger than twice
                ``ebsynth_config["tracking_window_size"]``.
        """
        # Validate the input before doing any work.
        tracking_window_size = ebsynth_config["tracking_window_size"]
        if tracking_window_size * 2 >= batch_size:
            raise ValueError("batch_size should be larger than track_window_size * 2")
        frame_style = frames_style[0]
        frame_guide = frames_guide[index_style[0]]
        patch_match_engine = PyramidPatchMatcher(
            image_height=frame_style.shape[0],
            image_width=frame_style.shape[1],
            channel=3,
            **ebsynth_config
        )
        # frame_id is the next frame that still has to be emitted.
        frame_id, n = 0, len(frames_guide)
        step = batch_size - tracking_window_size * 2
        for window_start in tqdm(range(0, n, step), desc=f"Rendering frames 0...{n}"):
            if window_start + batch_size > n:
                # Last window: align it with the end of the sequence.
                l, r = max(n - batch_size, 0), n
            else:
                l, r = window_start, window_start + batch_size
            source_guide = np.stack([frame_guide] * (r - l))
            target_guide = np.stack([frames_guide[index] for index in range(l, r)])
            source_style = np.stack([frame_style] * (r - l))
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for index, frame in zip(range(l, r), target_style):
                if index == frame_id:
                    if save_path is not None:
                        # os.path.join(None, ...) would raise TypeError, so
                        # writing is skipped when no output directory is set.
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
                    frame_id += 1
                # Leave the trailing frames of a non-final window to the next
                # window, which starts earlier than this window ends.
                if r < n and r - frame_id <= tracking_window_size:
                    break
@@ -0,0 +1,242 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
def warp(tenInput, tenFlow, device):
    """Sample ``tenInput`` at positions displaced by ``tenFlow``.

    A base grid spanning [-1, 1] along both spatial axes is built with the
    spatial size of ``tenFlow``; the flow is divided by half of
    (``tenInput`` size - 1) per axis, added to the grid, and the result is
    passed to ``grid_sample`` (bilinear, border padding, align_corners=True).

    Channel 0 of ``tenFlow`` is the horizontal displacement and channel 1 the
    vertical one. The base grid is created on ``device``.
    """
    batch, height, width = tenFlow.shape[0], tenFlow.shape[2], tenFlow.shape[3]
    # The original kept a "cache" dict that was re-created on each call, so the
    # grid was always rebuilt; it is simply built directly here.
    horizontal = torch.linspace(-1.0, 1.0, width, device=device).view(
        1, 1, 1, width).expand(batch, -1, height, -1)
    vertical = torch.linspace(-1.0, 1.0, height, device=device).view(
        1, 1, height, 1).expand(batch, -1, -1, width)
    base_grid = torch.cat([horizontal, vertical], 1).to(device)

    # Rescale the flow so it is expressed in the grid's [-1, 1] units.
    flow_x = tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0)
    flow_y = tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)
    normalized_flow = torch.cat([flow_x, flow_y], 1)

    grid = (base_grid + normalized_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=grid, mode='bilinear', padding_mode='border', align_corners=True)
24
+
25
+
26
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased ``Conv2d`` followed by a per-channel ``PReLU``."""
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    activation = nn.PReLU(out_planes)
    return nn.Sequential(convolution, activation)
32
+
33
+
34
class IFBlock(nn.Module):
    """One refinement stage that predicts a 4-channel flow and a 1-channel mask.

    The input is downscaled by ``scale``, encoded by two stride-2 convolutions,
    passed through four residual conv blocks and decoded by two transposed
    convolution heads (one for the flow, one for the mask).
    """

    def __init__(self, in_planes, c=64):
        super().__init__()
        half = c // 2
        self.conv0 = nn.Sequential(conv(in_planes, half, 3, 2, 1), conv(half, c, 3, 2, 1))
        self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
        # Flow head (4 output channels).
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(c, half, 4, 2, 1),
            nn.PReLU(half),
            nn.ConvTranspose2d(half, 4, 4, 2, 1),
        )
        # Mask head (1 output channel).
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(c, half, 4, 2, 1),
            nn.PReLU(half),
            nn.ConvTranspose2d(half, 1, 4, 2, 1),
        )

    def forward(self, x, flow, scale=1):
        """Return ``(flow, mask)`` predicted at the resolution of the inputs."""

        def resize(tensor, factor):
            return F.interpolate(tensor, scale_factor=factor, mode="bilinear", align_corners=False, recompute_scale_factor=False)

        x = resize(x, 1. / scale)
        # Flow magnitudes are rescaled together with the resolution.
        flow = resize(flow, 1. / scale) * 1. / scale
        feat = self.conv0(torch.cat((x, flow), 1))
        for residual_block in (self.convblock0, self.convblock1, self.convblock2, self.convblock3):
            feat = residual_block(feat) + feat
        flow = self.conv1(feat)
        mask = self.conv2(feat)
        flow = resize(flow, scale) * scale
        mask = resize(mask, scale)
        return flow, mask
58
+
59
+
60
class IFNet(nn.Module):
    """Three-stage flow network used to synthesize a frame between two images.

    ``forward`` expects both images concatenated along the channel axis and
    returns the per-stage flows, the final sigmoid mask and the per-stage
    blended images.
    """

    def __init__(self):
        super().__init__()
        self.block0 = IFBlock(7+4, c=90)
        self.block1 = IFBlock(7+4, c=90)
        self.block2 = IFBlock(7+4, c=90)
        self.block_tea = IFBlock(10+4, c=90)

    def forward(self, x, scale_list=[4, 2, 1], training=False):
        """Run the three stages and return ``(flow_list, mask, merged)``.

        ``img0``/``img1`` are only bound when ``training == False``; no other
        input layout is handled here.
        """
        if training == False:
            half = x.shape[1] // 2
            img0, img1 = x[:, :half], x[:, half:]
        flow_list, mask_list, warped_pairs = [], [], []
        warped_img0, warped_img1 = img0, img1
        # Zero-initialised flow (4 channels) and mask (1 channel).
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        stages = [self.block0, self.block1, self.block2]
        for i, stage in enumerate(stages):
            scale = scale_list[i]
            # Each stage is evaluated in both directions and the two
            # predictions are averaged.
            forward_input = torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1)
            backward_input = torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1)
            swapped_flow = torch.cat((flow[:, 2:4], flow[:, :2]), 1)
            f0, m0 = stage(forward_input, flow, scale=scale)
            f1, m1 = stage(backward_input, swapped_flow, scale=scale)
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], device=x.device)
            warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
            warped_pairs.append((warped_img0, warped_img1))
        # Turn mask logits into blend weights and mix both warped images.
        mask_list = [torch.sigmoid(logits) for logits in mask_list[:3]] + mask_list[3:]
        merged = [
            first * weight + second * (1 - weight)
            for (first, second), weight in zip(warped_pairs[:3], mask_list[:3])
        ] + warped_pairs[3:]
        return flow_list, mask_list[2], merged

    @staticmethod
    def state_dict_converter():
        """Return the converter used to adapt checkpoint key names."""
        return IFNetStateDictConverter()
105
+
106
+
107
class IFNetStateDictConverter:
    """Adapt checkpoint key names by removing every ``"module."`` substring."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict whose keys have ``"module."`` removed."""
        converted = {}
        for name, value in state_dict.items():
            converted[name.replace("module.", "")] = value
        return converted

    def from_civitai(self, state_dict):
        """Same conversion as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
117
+
118
+
119
class RIFEInterpolater:
    """Insert model-generated frames between consecutive input images."""

    def __init__(self, model, device="cuda"):
        self.model = model
        self.device = device
        # IFNet does not support float16 (per the original note), so float32 is used.
        self.torch_dtype = torch.float32

    @staticmethod
    def from_model_manager(model_manager):
        """Build an interpolater from ``model_manager.RIFE`` and its device."""
        return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)

    @staticmethod
    def _round_up(value, multiple=32):
        """Return the smallest multiple of ``multiple`` that is >= ``value``."""
        return (value + multiple - 1) // multiple * multiple

    def process_image(self, image):
        """Convert an image to a float tensor of shape (3, H, W) in [0, 1].

        If either side is not a multiple of 32 the image is first resized so
        both sides are rounded up to the next multiple of 32. The first three
        channels are taken in reversed order.
        """
        width, height = image.size
        if width % 32 != 0 or height % 32 != 0:
            # The original divided by 32 without multiplying back, which shrank
            # the image to about 1/32 of its size instead of rounding it up.
            width = self._round_up(width)
            height = self._round_up(height)
            image = image.resize((width, height))
        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        """Stack the processed images into one (N, 3, H, W) tensor."""
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        """Convert an (N, 3, H, W) tensor back into a list of PIL images."""
        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    def add_interpolated_images(self, images, interpolated_images):
        """Interleave originals and interpolated frames, ending on the last original."""
        output_images = []
        for image, interpolated_image in zip(images, interpolated_images):
            output_images.append(image)
            output_images.append(interpolated_image)
        output_images.append(images[-1])
        return output_images

    @torch.no_grad()
    def interpolate_(self, images, scale=1.0):
        """Return only the in-between frames for each consecutive image pair.

        All pairs are sent through the model in a single batch. The outputs
        are resized to the size of ``images[0]`` if they differ.
        """
        input_tensor = self.process_images(images)
        input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
        input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
        flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
        output_images = self.decode_images(merged[2].cpu())
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images

    @torch.no_grad()
    def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
        """Return the input frames with interpolated frames inserted between them.

        Each of the ``num_iter`` passes inserts one frame between every pair
        of neighbouring frames. Pairs are processed in chunks of
        ``batch_size``; ``progress_bar`` wraps the chunk iterator. The outputs
        are resized to the size of ``images[0]`` if they differ.
        """
        # Preprocess
        processed_images = self.process_images(images)

        for _ in range(num_iter):
            # Pair every frame with its successor along the channel axis.
            input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)

            # Interpolate
            output_tensor = []
            for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
                batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
                batch_input_tensor = input_tensor[batch_id: batch_id_]
                batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
                flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
                output_tensor.append(merged[2].cpu())

            # Output
            output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
            processed_images = self.add_interpolated_images(processed_images, output_tensor)
            processed_images = torch.stack(processed_images)

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images
198
+
199
+
200
class RIFESmoother(RIFEInterpolater):
    """Smooth a frame sequence by blending inner frames with model output."""

    def __init__(self, model, device="cuda"):
        super().__init__(model, device=device)

    @staticmethod
    def from_model_manager(model_manager):
        """Build a smoother from ``model_manager.RIFE`` and its device."""
        return RIFESmoother(model_manager.RIFE, device=model_manager.device)

    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
        """Run the model over ``input_tensor`` in chunks of ``batch_size``.

        Returns the last merged output of every chunk, concatenated on CPU.
        """
        total = input_tensor.shape[0]
        chunks = []
        for start in range(0, total, batch_size):
            stop = min(start + batch_size, total)
            chunk = input_tensor[start: stop].to(device=self.device, dtype=self.torch_dtype)
            _, _, merged = self.model(chunk, [4/scale, 2/scale, 1/scale])
            chunks.append(merged[2].cpu())
        return torch.concat(chunks, dim=0)

    @torch.no_grad()
    def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
        """Return smoothed copies of ``rendered_frames``.

        In each of the ``num_iter`` passes every inner frame is replaced by
        the model's blend of that frame with the frame the model predicts
        from its two neighbours. The first and last frame are never replaced.
        The outputs are resized to the size of ``rendered_frames[0]`` if they
        differ.
        """
        frames = self.process_images(rendered_frames)

        for _ in range(num_iter):
            # Predict every inner frame from its left and right neighbour.
            neighbours = torch.cat((frames[:-2], frames[2:]), dim=1)
            predicted = self.process_tensors(neighbours, scale=scale, batch_size=batch_size)

            # Blend each inner frame with its prediction.
            blend_input = torch.cat((frames[1:-1], predicted), dim=1)
            blended = self.process_tensors(blend_input, scale=scale, batch_size=batch_size)

            # Write the blended frames back in place.
            frames[1:-1] = blended

        output_images = self.decode_images(frames)
        target_size = rendered_frames[0].size
        if output_images[0].size != target_size:
            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
        return output_images
File without changes
@@ -0,0 +1 @@
1
+ from .model_manager import *
@@ -0,0 +1,89 @@
1
+ import torch
2
+ from einops import rearrange
3
+
4
+
5
def low_version_attention(query, key, value, attn_bias=None):
    """Scaled dot-product attention written with basic tensor operations.

    The query is scaled by ``1 / sqrt(query.shape[-1])``; ``attn_bias``, if
    given, is added to the scores before the softmax over the last axis.
    """
    scaled_query = query * (1 / query.shape[-1] ** 0.5)
    scores = torch.matmul(scaled_query, key.transpose(-2, -1))
    if attn_bias is not None:
        scores = scores + attn_bias
    weights = scores.softmax(-1)
    return weights @ value
13
+
14
+
15
class Attention(torch.nn.Module):
    """Multi-head attention with separate query/key/value/output projections.

    ``forward`` always dispatches to ``torch_forward``; ``xformers_forward``
    is an alternative implementation that must be called explicitly.
    """

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        inner_dim = head_dim * num_heads
        if kv_dim is None:
            # Self-attention: keys and values share the query width.
            kv_dim = q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, inner_dim, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, inner_dim, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, inner_dim, bias=bias_kv)
        self.to_out = torch.nn.Linear(inner_dim, q_dim, bias=bias_out)

    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
        """Add ``scale`` times the attention of ``q`` over ``ip_k``/``ip_v``."""
        batch_size = q.shape[0]
        head_shape = (batch_size, -1, self.num_heads, self.head_dim)
        ip_k = ip_k.view(*head_shape).transpose(1, 2)
        ip_v = ip_v.view(*head_shape).transpose(1, 2)
        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
        return hidden_states + scale * ip_hidden_states

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

        Without ``encoder_hidden_states`` this is self-attention.
        ``qkv_preprocessor``, if given, receives and returns the per-head
        ``(q, k, v)``; ``ipadapter_kwargs``, if given, is forwarded to
        ``interact_with_ipadapter``.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)
        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        if qkv_preprocessor is not None:
            q, k, v = qkv_preprocessor(q, k, v)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        if ipadapter_kwargs is not None:
            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
        # Merge the heads back into a single feature axis.
        merged_dim = self.num_heads * self.head_dim
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, merged_dim)
        hidden_states = hidden_states.to(q.dtype)

        return self.to_out(hidden_states)

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Attention via xformers, or ``low_version_attention`` when masked."""
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        split_pattern = "b f (n d) -> (b n) f d"
        q = rearrange(self.to_q(hidden_states), split_pattern, n=self.num_heads)
        k = rearrange(self.to_k(encoder_hidden_states), split_pattern, n=self.num_heads)
        v = rearrange(self.to_v(encoder_hidden_states), split_pattern, n=self.num_heads)

        if attn_mask is None:
            import xformers.ops as xops
            hidden_states = xops.memory_efficient_attention(q, k, v)
        else:
            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
        hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)

        hidden_states = hidden_states.to(q.dtype)
        return self.to_out(hidden_states)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        """Delegate to ``torch_forward`` with the same arguments."""
        return self.torch_forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attn_mask=attn_mask,
            ipadapter_kwargs=ipadapter_kwargs,
            qkv_preprocessor=qkv_preprocessor,
        )
@@ -0,0 +1,66 @@
1
+ from huggingface_hub import hf_hub_download
2
+ from modelscope import snapshot_download
3
+ import os, shutil
4
+ from typing_extensions import Literal, TypeAlias
5
+ from typing import List
6
+ from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
7
+
8
+
9
def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Fetch one file of a ModelScope repository into ``local_dir``.

    Nothing is downloaded when a file with the same base name already exists
    in ``local_dir``. After the download the file is moved directly into
    ``local_dir`` and the top-level directory it was downloaded into is
    removed.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
        return
    print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
    snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
    downloaded_file_path = os.path.join(local_dir, origin_file_path)
    target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
    if downloaded_file_path == target_file_path:
        # The file already sits directly in local_dir.
        return
    shutil.move(downloaded_file_path, target_file_path)
    leftover_dir = os.path.join(local_dir, origin_file_path.split("/")[0])
    shutil.rmtree(leftover_dir)
22
+
23
+
24
def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Fetch one file of a Hugging Face repository into ``local_dir``.

    Nothing is downloaded when a file with the same base name already exists
    in ``local_dir``.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
        return
    print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
32
+
33
+
34
# Names of the hosting sites a preset model can be fetched from.
Preset_model_website: TypeAlias = Literal["HuggingFace", "ModelScope"]

# Preset model catalogue of each site, keyed by site name.
website_to_preset_models = dict(
    HuggingFace=preset_models_on_huggingface,
    ModelScope=preset_models_on_modelscope,
)

# Download routine of each site, keyed by site name.
website_to_download_fn = dict(
    HuggingFace=download_from_huggingface,
    ModelScope=download_from_modelscope,
)
46
+
47
+
48
def download_models(
    model_id_list: List[Preset_model_id] = [],
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download the files of the requested preset models.

    For every preset id the websites are tried in ``downloading_priority``
    order. A file that has already been obtained is not requested again.

    Returns:
        The list of local file paths (``local_dir`` joined with the file's
        base name) that are present after downloading.
    """
    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    for model_id in model_id_list:
        for website in downloading_priority:
            preset_models = website_to_preset_models[website]
            if model_id not in preset_models:
                continue
            download_fn = website_to_download_fn[website]
            # The repository id gets its own name: reusing ``model_id`` here
            # overwrote the preset id, so the lookup for the next website was
            # made with a repository id instead.
            for repo_id, origin_file_path, local_dir in preset_models[model_id]:
                # Check if the file is downloaded.
                file_name = os.path.basename(origin_file_path)
                file_to_download = os.path.join(local_dir, file_name)
                if file_to_download in downloaded_files:
                    continue
                # Download
                download_fn(repo_id, origin_file_path, local_dir)
                if file_name in os.listdir(local_dir):
                    downloaded_files.append(file_to_download)
    return downloaded_files