diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from ..patch_match import PyramidPatchMatcher
|
|
2
|
+
import os
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class InterpolationModeRunner:
    """Render every frame by blending patch-match propagations from neighbouring keyframes.

    Frames listed in ``index_style`` are keyframes whose stylized version is in
    ``frames_style``.  Every frame ``m`` is rendered twice -- once propagated from
    its left keyframe ``l`` and once from its right keyframe ``r`` -- and the two
    results are blended with distance-based weights.
    """

    def __init__(self):
        pass

    def get_index_dict(self, index_style):
        """Map each keyframe's frame number to its position in ``frames_style``.

        If a frame number occurs more than once, the last position wins.
        """
        return {index: position for position, index in enumerate(index_style)}

    def get_weight(self, l, m, r):
        """Return ``(weight_l, weight_r)`` for blending frame ``m`` from keyframes ``l`` and ``r``.

        The nearer keyframe gets the larger weight and the two weights sum to 1.
        When both distances are zero, each weight is 0.5.
        """
        weight_l, weight_r = abs(m - r), abs(m - l)
        total = weight_l + weight_r
        if total == 0:
            return 0.5, 0.5
        return weight_l / total, weight_r / total

    def get_task_group(self, index_style, n):
        """Split frames ``0..n-1`` into groups of ``(l, m, r)`` rendering tasks.

        Frame ``m`` is rendered from keyframes ``l`` and ``r``.  Frames before the
        first keyframe, and the last keyframe together with every frame after it,
        use that single keyframe on both sides.  Groups that would be empty (from
        duplicate keyframe numbers, or a last keyframe at or beyond ``n``) are
        dropped, so every returned group has at least one task.
        """
        index_style = sorted(index_style)
        first, last = index_style[0], index_style[-1]
        task_group = []
        # Frames before the first keyframe.
        task_group.append([(first, m, first) for m in range(first)])
        # Frames between neighbouring keyframes (left keyframe inclusive).
        for l, r in zip(index_style[:-1], index_style[1:]):
            task_group.append([(l, m, r) for m in range(l, r)])
        # The last keyframe and all frames after it.
        task_group.append([(last, m, last) for m in range(last, n)])
        # An empty group would make run() call min()/max() on an empty sequence.
        return [tasks for tasks in task_group if tasks]

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        """Render all ``len(frames_guide)`` frames and optionally save them.

        Args:
            frames_guide: guide frames, indexable by frame number.
            frames_style: stylized keyframes; ``frames_style[i]`` belongs to frame
                ``index_style[i]``.  The image size is read from ``frames_style[0].shape``.
            index_style: frame numbers of the keyframes.
            batch_size: number of frames per patch-match call (each frame adds two
                problems, one per side).
            ebsynth_config: extra keyword arguments for ``PyramidPatchMatcher``.
            save_path: directory where frames are written as ``%05d.png``; when
                ``None`` nothing is saved.
        """
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=False,
            use_pairwise_patch_error=True,
            **ebsynth_config
        )
        index_dict = self.get_index_dict(index_style)
        task_group = self.get_task_group(index_style, len(frames_guide))
        for tasks in task_group:
            frame_ids = [m for _, m, _ in tasks]
            index_start, index_end = min(frame_ids), max(frame_ids)
            for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
                tasks_batch = tasks[batch_id: min(batch_id + batch_size, len(tasks))]
                source_guide, target_guide, source_style = [], [], []
                for l, m, r in tasks_batch:
                    # Interleave two problems per frame: l -> m (even slots), r -> m (odd slots).
                    for keyframe in (l, r):
                        source_guide.append(frames_guide[keyframe])
                        target_guide.append(frames_guide[m])
                        source_style.append(frames_style[index_dict[keyframe]])
                source_guide = np.stack(source_guide)
                target_guide = np.stack(target_guide)
                source_style = np.stack(source_style)
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                if save_path is not None:
                    for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
                        weight_l, weight_r = self.get_weight(l, m, r)
                        frame = frame_l * weight_l + frame_r * weight_r
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class InterpolationModeSingleFrameRunner:
    """Render every frame from the single keyframe ``frames_style[0]``, batch by batch.

    Consecutive batches overlap so that each saved frame has at least
    ``tracking_window_size`` frames after it inside the batch it was rendered in
    (except in the last batch).
    """

    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        """Render all ``len(frames_guide)`` frames from one keyframe and optionally save them.

        Args:
            frames_guide: guide frames, indexable by frame number.
            frames_style: stylized keyframes; only ``frames_style[0]`` is used.
            index_style: frame numbers of the keyframes; only ``index_style[0]`` is used.
            batch_size: number of frames per patch-match call.
            ebsynth_config: keyword arguments for ``PyramidPatchMatcher``; must contain
                ``"tracking_window_size"``.
            save_path: directory where frames are written as ``%05d.png``; when
                ``None`` nothing is saved.

        Raises:
            ValueError: if ``batch_size`` is not larger than ``2 * tracking_window_size``.
        """
        # check input
        tracking_window_size = ebsynth_config["tracking_window_size"]
        if tracking_window_size * 2 >= batch_size:
            raise ValueError("batch_size should be larger than track_window_size * 2")
        frame_style = frames_style[0]
        frame_guide = frames_guide[index_style[0]]
        patch_match_engine = PyramidPatchMatcher(
            image_height=frame_style.shape[0],
            image_width=frame_style.shape[1],
            channel=3,
            **ebsynth_config
        )
        # run
        frame_id, n = 0, len(frames_guide)
        stride = batch_size - tracking_window_size * 2
        for i in tqdm(range(0, n, stride), desc=f"Rendering frames 0...{n}"):
            if i + batch_size > n:
                # Last window: align it with the end of the video.
                l, r = max(n - batch_size, 0), n
            else:
                l, r = i, i + batch_size
            source_guide = np.stack([frame_guide] * (r - l))
            target_guide = np.stack([frames_guide[j] for j in range(l, r)])
            source_style = np.stack([frame_style] * (r - l))
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for j, frame in zip(range(l, r), target_style):
                # Frames are emitted strictly in order; frames already emitted are skipped.
                if j == frame_id:
                    if save_path is not None:
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
                    frame_id += 1
                # Stop before frames too close to the end of a non-final window.
                if r < n and r - frame_id <= tracking_window_size:
                    break
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import numpy as np
|
|
5
|
+
from PIL import Image
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def warp(tenInput, tenFlow, device):
    """Backward-warp ``tenInput`` by the per-pixel displacement ``tenFlow``.

    Args:
        tenInput: image batch indexed as ``(batch, channel, height, width)``.
        tenFlow: flow batch indexed as ``(batch, 2, height, width)``; channel 0 is
            the horizontal and channel 1 the vertical displacement, in pixels.
        device: device on which the identity sampling grid is built.

    Returns:
        ``tenInput`` bilinearly sampled at ``pixel + flow``; samples falling outside
        the image are clamped to the border.
    """
    batch, _, height, width = tenFlow.shape
    # Identity grid in grid_sample's normalized [-1, 1] coordinates.
    ten_horizontal = torch.linspace(-1.0, 1.0, width, device=device).view(
        1, 1, 1, width).expand(batch, -1, height, -1)
    ten_vertical = torch.linspace(-1.0, 1.0, height, device=device).view(
        1, 1, height, 1).expand(batch, -1, -1, width)
    grid = torch.cat([ten_horizontal, ten_vertical], 1)

    # Pixel displacement -> normalized units (align_corners=True: one pixel is 2 / (size - 1)).
    ten_flow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                          tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (grid + ten_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased ``Conv2d`` followed by a per-channel ``PReLU``.

    The result is an ``nn.Sequential`` whose children are ``0`` (the convolution)
    and ``1`` (the activation); these indices appear in state-dict key names.
    """
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class IFBlock(nn.Module):
    """One flow-refinement stage of the RIFE network.

    Takes image features plus the current flow, works at ``1/scale`` resolution
    internally, and returns a 4-channel flow update and a 1-channel mask update
    at the input resolution.
    """

    def __init__(self, in_planes, c=64):
        super().__init__()
        half = c // 2
        # Two stride-2 convolutions: 4x spatial downsampling.
        self.conv0 = nn.Sequential(conv(in_planes, half, 3, 2, 1), conv(half, c, 3, 2, 1))
        # Four residual blocks (each applied as block(feat) + feat in forward).
        self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
        # Two stride-2 transposed convolutions undo the 4x downsampling.
        self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, half, 4, 2, 1), nn.PReLU(half), nn.ConvTranspose2d(half, 4, 4, 2, 1))
        self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, half, 4, 2, 1), nn.PReLU(half), nn.ConvTranspose2d(half, 1, 4, 2, 1))

    def forward(self, x, flow, scale=1):
        """Return ``(flow, mask)`` updates computed at ``1/scale`` resolution."""
        def resize(tensor, factor):
            return F.interpolate(tensor, scale_factor=factor, mode="bilinear", align_corners=False, recompute_scale_factor=False)

        x = resize(x, 1. / scale)
        # Flow values are pixel displacements, so they shrink with the image.
        flow = resize(flow, 1. / scale) * 1. / scale
        feat = self.conv0(torch.cat((x, flow), 1))
        for residual in (self.convblock0, self.convblock1, self.convblock2, self.convblock3):
            feat = residual(feat) + feat
        flow_out = self.conv1(feat)
        mask_out = self.conv2(feat)
        flow_out = resize(flow_out, scale) * scale
        mask_out = resize(mask_out, scale)
        return flow_out, mask_out
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class IFNet(nn.Module):
    """RIFE intermediate-flow network (inference path only).

    Three ``IFBlock`` stages refine a bidirectional flow and a blending mask at
    the scales given by ``scale_list``.  ``block_tea`` is not used by ``forward``;
    presumably it is kept so that checkpoints containing its weights still load.
    """

    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7+4, c=90)
        self.block1 = IFBlock(7+4, c=90)
        self.block2 = IFBlock(7+4, c=90)
        self.block_tea = IFBlock(10+4, c=90)

    def forward(self, x, scale_list=(4, 2, 1), training=False):
        """Estimate the middle frame between two frames.

        Args:
            x: the two frames concatenated along dim 1; the first half of the
                channels is frame 0 and the second half frame 1.
            scale_list: downscale factor for each of the three stages.
            training: only a falsy value is supported.

        Returns:
            ``(flow_list, mask, merged)``: the accumulated flow after each stage
            (channels 0-1 warp frame 0, channels 2-3 warp frame 1), the final
            blending mask after sigmoid, and the blended frame after each stage.

        Raises:
            NotImplementedError: if ``training`` is truthy.
        """
        if training:
            # Previously this fell through and failed with UnboundLocalError on img0.
            raise NotImplementedError("IFNet.forward only supports inference (training=False).")
        channel = x.shape[1] // 2
        img0 = x[:, :channel]
        img1 = x[:, channel:]
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        # Zero flow (4 channels) and mask (1 channel) with x's device and dtype.
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            # Run the block in both directions; the reverse pass sees the frames swapped,
            # so its flow halves are swapped back and its mask negated before averaging.
            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
            f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], device=x.device)
            warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
            merged.append((warped_img0, warped_img1))
        for i in range(3):
            mask_list[i] = torch.sigmoid(mask_list[i])
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
        return flow_list, mask_list[2], merged

    @staticmethod
    def state_dict_converter():
        """Return the converter used to normalize checkpoint keys for this model."""
        return IFNetStateDictConverter()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class IFNetStateDictConverter:
    """Rename checkpoint keys to the parameter names used by ``IFNet``."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict whose keys have every ``"module."`` substring removed."""
        cleaned = {}
        for name, tensor in state_dict.items():
            cleaned[name.replace("module.", "")] = tensor
        return cleaned

    def from_civitai(self, state_dict):
        """Apply the same key cleanup as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class RIFEInterpolater:
    """Insert intermediate frames between consecutive images with a RIFE model.

    Images are converted to float tensors whose sides are multiples of 32, run
    through ``self.model`` on ``self.device``, and converted back to images at
    the original size.
    """

    def __init__(self, model, device="cuda"):
        self.model = model
        self.device = device
        # Inputs are always cast to float32: IFNet does not support float16.
        self.torch_dtype = torch.float32

    @staticmethod
    def from_model_manager(model_manager):
        """Build an interpolater from ``model_manager.RIFE`` on ``model_manager.device``."""
        return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)

    def process_image(self, image):
        """Convert a PIL-style image to a ``(3, H, W)`` float tensor scaled by 1/255.

        The channel order is reversed (RGB input becomes BGR).  If either side is
        not a multiple of 32, the image is first resized so that both sides are
        rounded up to the next multiple of 32.
        """
        width, height = image.size
        if width % 32 != 0 or height % 32 != 0:
            # Round up to a multiple of 32 (the previous code produced the block count, not the size).
            width = (width + 31) // 32 * 32
            height = (height + 31) // 32 * 32
            image = image.resize((width, height))
        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        """Stack ``process_image`` outputs into an ``(N, 3, H, W)`` tensor."""
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        """Convert an ``(N, 3, H, W)`` CPU tensor (BGR, ~[0, 1]) back into a list of PIL images."""
        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    def add_interpolated_images(self, images, interpolated_images):
        """Interleave the sequences: ``interpolated_images[i]`` goes between ``images[i]`` and ``images[i + 1]``.

        The last element of ``images`` is always appended at the end.
        """
        output_images = []
        for image, interpolated_image in zip(images, interpolated_images):
            output_images.append(image)
            output_images.append(interpolated_image)
        output_images.append(images[-1])
        return output_images

    @torch.no_grad()
    def interpolate_(self, images, scale=1.0):
        """Return only the ``len(images) - 1`` frames interpolated between consecutive images.

        All pairs are processed in a single model call.
        """
        input_tensor = self.process_images(images)
        input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
        input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
        _, _, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
        output_images = self.decode_images(merged[2].cpu())
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images

    @torch.no_grad()
    def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
        """Return ``images`` with interpolated frames inserted between every consecutive pair.

        Each of the ``num_iter`` passes turns ``N`` frames into ``2N - 1``.  Pairs
        are sent to the model in batches of ``batch_size``; ``progress_bar``
        wraps the batch iterator.
        """
        # Preprocess
        processed_images = self.process_images(images)

        for _ in range(num_iter):
            # Input: each row holds a consecutive pair concatenated along channels.
            input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)

            # Interpolate
            output_tensor = []
            for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
                batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
                batch_input_tensor = input_tensor[batch_id: batch_id_]
                batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
                _, _, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
                output_tensor.append(merged[2].cpu())

            # Output
            output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
            processed_images = self.add_interpolated_images(processed_images, output_tensor)
            processed_images = torch.stack(processed_images)

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class RIFESmoother(RIFEInterpolater):
    """Temporal smoother built on RIFE.

    Each interior frame is re-estimated from its two neighbours, and that
    estimate is then blended with the frame itself through a second RIFE pass.
    """

    def __init__(self, model, device="cuda"):
        super(RIFESmoother, self).__init__(model, device=device)

    @staticmethod
    def from_model_manager(model_manager):
        """Build a smoother from ``model_manager.RIFE`` on ``model_manager.device``."""
        return RIFESmoother(model_manager.RIFE, device=model_manager.device)

    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
        """Run the model over ``input_tensor`` (pairs concatenated along dim 1) in batches.

        Returns the final-stage merged frames, concatenated on the CPU.  The input
        must contain at least one pair.
        """
        output_tensor = []
        for batch_id in range(0, input_tensor.shape[0], batch_size):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
            _, _, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
            output_tensor.append(merged[2].cpu())
        output_tensor = torch.concat(output_tensor, dim=0)
        return output_tensor

    @torch.no_grad()
    def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
        """Smooth ``rendered_frames`` ``num_iter`` times and return them as images.

        The first and last frames are never modified.  With fewer than three
        frames there is no interior frame, so the frames are returned unchanged
        (apart from the tensor round-trip).
        """
        # Preprocess
        processed_images = self.process_images(rendered_frames)

        # Without an interior frame the inputs below would be empty and
        # process_tensors would fail on torch.concat([]).
        if processed_images.shape[0] >= 3:
            for _ in range(num_iter):
                # Input: neighbours i-1 and i+1 of every interior frame i.
                input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)

                # Interpolate
                output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)

                # Blend the estimate with the frame it replaces.
                input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
                output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)

                # Add to frames
                processed_images[1:-1] = output_tensor

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != rendered_frames[0].size:
            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
        return output_images
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .model_manager import *
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from einops import rearrange
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def low_version_attention(query, key, value, attn_bias=None):
    """Scaled dot-product attention written with plain tensor operations.

    Computes ``softmax(query @ key^T / sqrt(d) + attn_bias) @ value``, where ``d``
    is the size of the last dimension of ``query``.  ``attn_bias`` is optional
    and is added to the logits before the softmax.
    """
    scaled_query = query * (1 / query.shape[-1] ** 0.5)
    logits = scaled_query @ key.transpose(-2, -1)
    if attn_bias is not None:
        logits = logits + attn_bias
    return torch.softmax(logits, dim=-1) @ value
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Attention(torch.nn.Module):
    """Multi-head attention with separate query/key/value/output projections.

    Args:
        q_dim: feature size of the query input and of the output.
        num_heads: number of attention heads.
        head_dim: feature size of each head.
        kv_dim: feature size of the key/value input; defaults to ``q_dim``.
        bias_q: whether the query projection has a bias.
        bias_kv: whether the key and value projections have a bias.
        bias_out: whether the output projection has a bias.
    """

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        if kv_dim is None:
            kv_dim = q_dim
        inner_dim = num_heads * head_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, inner_dim, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, inner_dim, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, inner_dim, bias=bias_kv)
        self.to_out = torch.nn.Linear(inner_dim, q_dim, bias=bias_out)

    def _split_heads(self, tensor, batch_size):
        """Reshape ``(batch, seq, heads * head_dim)`` into ``(batch, heads, seq, head_dim)``."""
        return tensor.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
        """Add ``scale`` times the attention of ``q`` over the IP-Adapter keys/values."""
        batch_size = q.shape[0]
        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(
            q, self._split_heads(ip_k, batch_size), self._split_heads(ip_v, batch_size)
        )
        return hidden_states + scale * ip_hidden_states

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

        Self-attention when ``encoder_hidden_states`` is ``None``, otherwise
        cross-attention over it.  ``qkv_preprocessor`` (optional) receives and
        returns the per-head ``q, k, v``; ``ipadapter_kwargs`` (optional) are
        forwarded to ``interact_with_ipadapter``.
        """
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = context.shape[0]

        q = self._split_heads(self.to_q(hidden_states), batch_size)
        k = self._split_heads(self.to_k(context), batch_size)
        v = self._split_heads(self.to_v(context), batch_size)

        if qkv_preprocessor is not None:
            q, k, v = qkv_preprocessor(q, k, v)

        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        if ipadapter_kwargs is not None:
            attended = self.interact_with_ipadapter(attended, q, **ipadapter_kwargs)
        # Merge heads back into (batch, seq, heads * head_dim).
        attended = attended.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        return self.to_out(attended.to(q.dtype))

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Attention via xformers, or via ``low_version_attention`` when a mask is given."""
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        split_pattern = "b f (n d) -> (b n) f d"

        q = rearrange(self.to_q(hidden_states), split_pattern, n=self.num_heads)
        k = rearrange(self.to_k(context), split_pattern, n=self.num_heads)
        v = rearrange(self.to_v(context), split_pattern, n=self.num_heads)

        if attn_mask is None:
            import xformers.ops as xops
            attended = xops.memory_efficient_attention(q, k, v)
        else:
            attended = low_version_attention(q, k, v, attn_bias=attn_mask)
        attended = rearrange(attended, "(b n) f d -> b f (n d)", n=self.num_heads)

        return self.to_out(attended.to(q.dtype))

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        """Dispatch to ``torch_forward``."""
        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from huggingface_hub import hf_hub_download
|
|
2
|
+
from modelscope import snapshot_download
|
|
3
|
+
import os, shutil
|
|
4
|
+
from typing_extensions import Literal, TypeAlias
|
|
5
|
+
from typing import List
|
|
6
|
+
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Fetch one file of a ModelScope repo into ``local_dir``, flattening its path.

    Nothing is downloaded when a file with the same base name already exists in
    ``local_dir``.  Otherwise ``snapshot_download`` is restricted to
    ``origin_file_path``; if the file lands in a sub-directory, it is moved to
    ``local_dir/<basename>`` and the top-level sub-directory is removed.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
        return

    target_file_path = os.path.join(local_dir, file_name)
    print(f" Start downloading {target_file_path}")
    snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
    downloaded_file_path = os.path.join(local_dir, origin_file_path)
    if downloaded_file_path == target_file_path:
        return
    shutil.move(downloaded_file_path, target_file_path)
    shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Fetch one file of a Hugging Face repo into ``local_dir``, flattening its path.

    Nothing is downloaded when a file with the same base name already exists in
    ``local_dir``.  ``hf_hub_download`` keeps the repo's sub-path under
    ``local_dir``.  The file is therefore moved to ``local_dir/<basename>``, and
    the top-level sub-directory is removed, like ``download_from_modelscope``.
    Without this flattening, the existence check above and ``download_models``
    would never find nested files.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
        return

    target_file_path = os.path.join(local_dir, file_name)
    print(f" Start downloading {target_file_path}")
    downloaded_file_path = hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
    if os.path.abspath(downloaded_file_path) != os.path.abspath(target_file_path):
        shutil.move(downloaded_file_path, target_file_path)
        top_level_dir = os.path.join(local_dir, origin_file_path.split("/")[0])
        if os.path.isdir(top_level_dir):
            shutil.rmtree(top_level_dir)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Hosting sites that preset models can be downloaded from.
Preset_model_website: TypeAlias = Literal["HuggingFace", "ModelScope"]

# Site name -> preset table (preset id -> list of (repo id, file path, local dir) entries).
website_to_preset_models = dict(
    HuggingFace=preset_models_on_huggingface,
    ModelScope=preset_models_on_modelscope,
)

# Site name -> function that downloads one file from that site.
website_to_download_fn = dict(
    HuggingFace=download_from_huggingface,
    ModelScope=download_from_modelscope,
)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def download_models(
    model_id_list: List[Preset_model_id] = [],
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download the files of the given preset models and return their local paths.

    For each preset id, the sites in ``downloading_priority`` are tried in order.
    A site is skipped if it does not list the model.  Once every file of the
    model is present in its ``local_dir``, later sites are not tried; if some
    file is still missing, the next site is tried as a fallback.

    Returns:
        The ``local_dir/<basename>`` paths of all files found after downloading,
        with no duplicates.
    """
    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    for model_id in model_id_list:
        for website in downloading_priority:
            preset_models = website_to_preset_models[website]
            if model_id not in preset_models:
                continue
            all_files_present = True
            # repo_id is kept separate from model_id so the preset id is still
            # intact when falling back to the next site.
            for repo_id, origin_file_path, local_dir in preset_models[model_id]:
                file_name = os.path.basename(origin_file_path)
                file_to_download = os.path.join(local_dir, file_name)
                # Skip files already handled for an earlier model.
                if file_to_download in downloaded_files:
                    continue
                website_to_download_fn[website](repo_id, origin_file_path, local_dir)
                if file_name in os.listdir(local_dir):
                    downloaded_files.append(file_to_download)
                else:
                    all_files_present = False
            if all_files_present:
                break
    return downloaded_files
|