diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
|
|
2
|
+
from ..models.model_manager import ModelManager
|
|
3
|
+
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
4
|
+
from ..prompters import SDPrompter
|
|
5
|
+
from ..schedulers import EnhancedDDIMScheduler
|
|
6
|
+
from .sd_image import SDImagePipeline
|
|
7
|
+
from .dancer import lets_dance
|
|
8
|
+
from typing import List
|
|
9
|
+
import torch
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def lets_dance_with_long_video(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample=None,
    timestep=None,
    encoder_hidden_states=None,
    ipadapter_kwargs_list=None,
    controlnet_frames=None,
    unet_batch_size=1,
    controlnet_batch_size=1,
    cross_frame_attention=False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device="cuda",
    animatediff_batch_size=16,
    animatediff_stride=8,
):
    """Run ``lets_dance`` over a long frame sequence in overlapping windows.

    Frames are processed in windows of ``animatediff_batch_size`` frames whose
    start positions advance by ``animatediff_stride``. Each window's output is
    moved to the CPU and blended into a running per-frame weighted average.
    Within a window, a frame's weight is a triangular ramp that peaks at the
    window centre and is floored at 1e-2. Overlapping windows therefore
    cross-fade instead of leaving hard seams.

    Args:
        unet, motion_modules, controlnet: models forwarded to ``lets_dance``.
        sample: latent frames, indexed along dim 0.
        timestep, encoder_hidden_states: forwarded unchanged to ``lets_dance``.
        ipadapter_kwargs_list: IP-Adapter kwargs forwarded to ``lets_dance``.
            ``None`` (the default) is treated as an empty dict.
        controlnet_frames: optional ControlNet conditioning, sliced along dim 1
            to match each window.
        unet_batch_size, controlnet_batch_size, cross_frame_attention,
        tiled, tile_size, tile_stride, device: forwarded to ``lets_dance``.
        animatediff_batch_size: number of frames per window.
        animatediff_stride: offset between consecutive window starts.

    Returns:
        A CPU tensor that stacks the blended per-frame outputs along dim 0.

    Note:
        If ``animatediff_stride`` exceeds ``animatediff_batch_size``, some
        frames are never covered by a window and stay zero.
    """
    # A fresh dict per call avoids sharing one mutable default object.
    if ipadapter_kwargs_list is None:
        ipadapter_kwargs_list = {}

    num_frames = sample.shape[0]
    # Per-frame (running weighted average, accumulated weight), kept on the CPU.
    hidden_states_output = [
        (torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0)
        for _ in range(num_frames)
    ]

    for batch_id in range(0, num_frames, animatediff_stride):
        batch_id_ = min(batch_id + animatediff_batch_size, num_frames)

        # Process this window on `device`, then bring the result back to the CPU.
        hidden_states_batch = lets_dance(
            unet, motion_modules, controlnet,
            sample[batch_id: batch_id_].to(device),
            timestep,
            encoder_hidden_states,
            ipadapter_kwargs_list=ipadapter_kwargs_list,
            controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
            unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
            cross_frame_attention=cross_frame_attention,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device,
        ).cpu()

        # Blend into the running average with a triangular weight that peaks at
        # the window centre and never drops below 1e-2. The extra 1e-2 in the
        # half-width avoids a division by zero for single-frame windows.
        for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
            bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
            hidden_states, num = hidden_states_output[i]
            hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
            hidden_states_output[i] = (hidden_states, num + bias)

        # This window already reached the last frame; any later window would
        # only cover a suffix of it.
        if batch_id_ == num_frames:
            break

    # output
    hidden_states = torch.stack([h for h, _ in hidden_states_output])
    return hidden_states
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SDVideoPipeline(SDImagePipeline):
    """Stable Diffusion video pipeline built on ``SDImagePipeline``.

    Adds optional motion modules, sliding-window denoising of long clips via
    ``lets_dance_with_long_video``, per-frame ControlNet conditioning,
    IP-Adapter conditioning, and an optional frame ``smoother`` that runs at
    selected denoising steps.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
        """Create an empty pipeline; models are attached later by ``fetch_models``.

        ``use_original_animatediff`` selects the DDIM beta schedule: "linear"
        when True, "scaled_linear" otherwise.
        """
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
        self.prompter = SDPrompter()
        # models (populated by fetch_models)
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
        self.ipadapter: SDIpAdapter = None
        self.motion_modules: SDMotionModel = None


    def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Pull the SD models, ControlNets, IP-Adapter and motion modules from ``model_manager``.

        Builds one ``ControlNetUnit`` per entry of ``controlnet_config_units``.
        If no motion modules are returned (``None``), the scheduler is reset to
        the "scaled_linear" schedule.
        """
        # Main models
        self.text_encoder = model_manager.fetch_model("sd_text_encoder")
        self.unet = model_manager.fetch_model("sd_unet")
        self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
        self.prompter.fetch_models(self.text_encoder)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

        # ControlNets
        controlnet_units = []
        for config in controlnet_config_units:
            controlnet_unit = ControlNetUnit(
                Annotator(config.processor_id, device=self.device),
                model_manager.fetch_model("sd_controlnet", config.model_path),
                config.scale
            )
            controlnet_units.append(controlnet_unit)
        self.controlnet = MultiControlNetManager(controlnet_units)

        # IP-Adapters
        self.ipadapter = model_manager.fetch_model("sd_ipadapter")
        self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")

        # Motion Modules
        self.motion_modules = model_manager.fetch_model("sd_motion_modules")
        if self.motion_modules is None:
            self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Create a pipeline on the manager's device and dtype, then load its models."""
        pipe = SDVideoPipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
        return pipe


    def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
        """Decode latent frames one at a time and return a list of images."""
        images = [
            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            for frame_id in range(latents.shape[0])
        ]
        return images


    def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
        """Encode frames one at a time and return their latents concatenated on the CPU."""
        latents = []
        for image in processed_images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
            latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            # Offload each latent immediately to bound device memory use.
            latents.append(latent.cpu())
        latents = torch.concat(latents, dim=0)
        return latents


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        num_frames=None,
        input_frames=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_frames=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        animatediff_batch_size=16,
        animatediff_stride=8,
        unet_batch_size=1,
        controlnet_batch_size=1,
        cross_frame_attention=False,
        smoother=None,
        smoother_progress_ids=[],
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate, or re-render, a video clip.

        Args:
            prompt, negative_prompt: prompts for classifier-free guidance.
            cfg_scale: guidance scale (nega + cfg_scale * (posi - nega)).
            clip_skip: forwarded to ``encode_prompt``.
            num_frames: number of frames; inferred from ``input_frames`` when omitted.
            input_frames: optional source frames. They are encoded and noised to
                the first timestep unless ``denoising_strength == 1.0``, and
                are also passed to ``smoother`` as ``original_frames``.
            ipadapter_images, ipadapter_scale: optional IP-Adapter conditioning.
            controlnet_frames: per-frame ControlNet inputs. Either a list of
                frames, or a list of such lists (one per ControlNet processor).
            denoising_strength: forwarded to ``scheduler.set_timesteps``.
            height, width: output size in pixels (latents use ``// 8`` of it).
            num_inference_steps: number of scheduler steps.
            animatediff_batch_size, animatediff_stride: sliding-window size and
                stride over frames.
            unet_batch_size, controlnet_batch_size, cross_frame_attention:
                forwarded to the denoiser.
            smoother: optional callable ``smoother(frames, original_frames=...)``.
            smoother_progress_ids: step indices at which the smoother runs.
                Include ``num_inference_steps`` or -1 to also smooth the output.
            tiled, tile_size, tile_stride: VAE tiling options, applied to every
                VAE encode/decode in this call.
            progress_bar_cmd: iterable wrapper used to display progress.
            progress_bar_st: optional object with a ``progress(fraction)`` method.

        Returns:
            The frames produced by ``decode_video``, optionally post-processed
            by ``smoother``.

        Raises:
            ValueError: if neither ``num_frames`` nor ``input_frames`` is given.
        """
        # Clip length: fall back to the number of input frames.
        if num_frames is None:
            if input_frames is None:
                raise ValueError("num_frames must be given when input_frames is None.")
            num_frames = len(input_frames)

        # Tiler parameters, batch size ...
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
        other_kwargs = {
            "animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
            "unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
            "cross_frame_attention": cross_frame_attention,
        }

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors (created on the CPU)
        if self.motion_modules is None:
            # Without motion modules, every frame starts from the same noise.
            noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
        else:
            noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
        if input_frames is None or denoising_strength == 1.0:
            latents = noise
        else:
            latents = self.encode_video(input_frames, **tiler_kwargs)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)

        # IP-Adapter: the negative branch uses an all-zero image encoding.
        if ipadapter_images is not None:
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
            ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

        # Prepare ControlNets: stack processed frames along dim 1.
        if controlnet_frames is not None:
            if isinstance(controlnet_frames[0], list):
                # One frame list per ControlNet processor.
                controlnet_frames_ = []
                for processor_id in range(len(controlnet_frames)):
                    controlnet_frames_.append(
                        torch.stack([
                            self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
                            for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
                        ], dim=1)
                    )
                controlnet_frames = torch.concat(controlnet_frames_, dim=0)
            else:
                controlnet_frames = torch.stack([
                    self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
                    for controlnet_frame in progress_bar_cmd(controlnet_frames)
                ], dim=1)
            controlnet_kwargs = {"controlnet_frames": controlnet_frames}
        else:
            controlnet_kwargs = {"controlnet_frames": None}

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep,
                **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
                device=self.device,
            )
            noise_pred_nega = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep,
                **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
                device=self.device,
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM and smoother. At the selected steps: estimate the final
            # frames, smooth them in pixel space, and turn the smoothed latents
            # back into a noise prediction. The VAE passes honour the
            # user's tiling settings.
            if smoother is not None and progress_id in smoother_progress_ids:
                rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
                rendered_frames = self.decode_video(rendered_frames, **tiler_kwargs)
                rendered_frames = smoother(rendered_frames, original_frames=input_frames)
                target_latents = self.encode_video(rendered_frames, **tiler_kwargs)
                noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI: report the fraction of completed steps (1.0 after the last one).
            if progress_bar_st is not None:
                progress_bar_st.progress((progress_id + 1) / len(self.scheduler.timesteps))

        # Decode image
        output_frames = self.decode_video(latents, **tiler_kwargs)

        # Post-process
        if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
            output_frames = smoother(output_frames, original_frames=input_frames)

        return output_frames
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
2
|
+
from ..models.kolors_text_encoder import ChatGLMModel
|
|
3
|
+
from ..models.model_manager import ModelManager
|
|
4
|
+
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
5
|
+
from ..prompters import SDXLPrompter, KolorsPrompter
|
|
6
|
+
from ..schedulers import EnhancedDDIMScheduler
|
|
7
|
+
from .base import BasePipeline
|
|
8
|
+
from .dancer import lets_dance_xl
|
|
9
|
+
from typing import List
|
|
10
|
+
import torch
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SDXLImagePipeline(BasePipeline):
    """SDXL text-to-image / image-to-image pipeline.

    If the model manager provides a Kolors text encoder, the SDXL prompter and
    scheduler are replaced with Kolors equivalents. ControlNet support is not
    implemented yet (TODO).
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        """Create an empty pipeline; models are attached later by ``fetch_models``."""
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler()
        self.prompter = SDXLPrompter()
        # models (populated by fetch_models)
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None
        self.text_encoder_kolors: ChatGLMModel = None
        self.unet: SDXLUNet = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None
        # self.controlnet: MultiControlNetManager = None (TODO)
        self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
        self.ipadapter: SDXLIpAdapter = None


    def denoising_model(self):
        """Return the denoising network (the SDXL UNet)."""
        return self.unet


    def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Pull the SDXL (or Kolors) models from ``model_manager``.

        ``controlnet_config_units`` is accepted for interface compatibility but
        is currently unused. Prompt refiners are loaded into whichever prompter
        is finally in use.
        """
        # Main models
        self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
        self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
        self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
        self.unet = model_manager.fetch_model("sdxl_unet")
        self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")

        # ControlNets (TODO)

        # IP-Adapters
        self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
        self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")

        # Kolors
        if self.text_encoder_kolors is not None:
            print("Switch to Kolors. The prompter and scheduler will be replaced.")
            self.prompter = KolorsPrompter()
            self.prompter.fetch_models(self.text_encoder_kolors)
            self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
        else:
            self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Create a pipeline on the manager's device and dtype, then load its models."""
        pipe = SDXLImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
        return pipe


    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Encode a preprocessed image tensor into latents with the VAE encoder."""
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode latents on ``self.device`` and convert the VAE output to an image."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image


    def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
        """Encode ``prompt`` and return the kwargs expected by the denoiser.

        Returns:
            {"encoder_hidden_states": ..., "add_text_embeds": ...}
        """
        add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
            prompt,
            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
            device=self.device,
            positive=positive,
        )
        return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}


    def prepare_extra_input(self, latents=None):
        """Build the ``add_time_id`` conditioning from the latent size.

        The tensor is [h, w, 0, 0, h, w], where h and w are the pixel sizes
        (8x the latent sizes). This is presumably SDXL's (original size, crop
        offset, target size) conditioning; confirm against the UNet.
        """
        height, width = latents.shape[2] * 8, latents.shape[3] * 8
        return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=2,
        input_image=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        ipadapter_use_instant_style=False,
        controlnet_image=None,
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image, or re-render ``input_image`` when it is given.

        Args:
            prompt, negative_prompt: prompts for classifier-free guidance.
            cfg_scale: guidance scale. At 1.0 the negative pass is skipped.
            clip_skip, clip_skip_2: forwarded to ``encode_prompt``.
            input_image: optional source image for image-to-image. Noise is
                drawn with the encoded latents' shape.
            ipadapter_images, ipadapter_scale: optional IP-Adapter conditioning.
            ipadapter_use_instant_style: use the "less" adapter instead of the
                full adapter.
            controlnet_image: currently unused (ControlNet support is TODO).
            denoising_strength: forwarded to ``scheduler.set_timesteps``.
            height, width: output size in pixels when no ``input_image`` is given.
            num_inference_steps: number of scheduler steps.
            tiled, tile_size, tile_stride: VAE tiling options.
            progress_bar_cmd: iterable wrapper used to display progress.
            progress_bar_st: optional object with a ``progress(fraction)`` method.

        Returns:
            The decoded image from ``decode_image``.
        """
        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.encode_image(image, **tiler_kwargs)
            # Match the encoded latents' shape, so an input image whose size
            # differs from height/width does not break add_noise.
            noise = torch.randn(latents.shape, device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)

        # IP-Adapter: the negative branch uses an all-zero image encoding.
        if ipadapter_images is not None:
            if ipadapter_use_instant_style:
                self.ipadapter.set_less_adapter()
            else:
                self.ipadapter.set_full_adapter()
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
            ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

        # Prepare ControlNets (TODO)
        controlnet_kwargs = {"controlnet_frames": None}

        # Prepare extra input
        extra_input = self.prepare_extra_input(latents)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_xl(
                self.unet, motion_modules=None, controlnet=None,
                sample=latents, timestep=timestep, **extra_input,
                **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
                device=self.device,
            )
            if cfg_scale != 1.0:
                noise_pred_nega = lets_dance_xl(
                    self.unet, motion_modules=None, controlnet=None,
                    sample=latents, timestep=timestep, **extra_input,
                    **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
                    device=self.device,
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            # DDIM
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI: report the fraction of completed steps (1.0 after the last one).
            if progress_bar_st is not None:
                progress_bar_st.progress((progress_id + 1) / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
|
|
2
|
+
from ..models.kolors_text_encoder import ChatGLMModel
|
|
3
|
+
from ..models.model_manager import ModelManager
|
|
4
|
+
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
5
|
+
from ..prompters import SDXLPrompter, KolorsPrompter
|
|
6
|
+
from ..schedulers import EnhancedDDIMScheduler
|
|
7
|
+
from .sdxl_image import SDXLImagePipeline
|
|
8
|
+
from .dancer import lets_dance_xl
|
|
9
|
+
from typing import List
|
|
10
|
+
import torch
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SDXLVideoPipeline(SDXLImagePipeline):
|
|
16
|
+
|
|
17
|
+
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
    """Set up an SDXL video pipeline with empty model slots.

    Args:
        device: forwarded to the base pipeline.
        torch_dtype: forwarded to the base pipeline.
        use_original_animatediff: choose the "linear" DDIM beta schedule when
            True, or "scaled_linear" when False.
    """
    super().__init__(device=device, torch_dtype=torch_dtype)
    schedule_name = "linear" if use_original_animatediff else "scaled_linear"
    self.scheduler = EnhancedDDIMScheduler(beta_schedule=schedule_name)
    self.prompter = SDXLPrompter()

    # Model slots, filled in later by fetch_models().
    self.text_encoder: SDXLTextEncoder = None
    self.text_encoder_2: SDXLTextEncoder2 = None
    self.text_encoder_kolors: ChatGLMModel = None
    self.unet: SDXLUNet = None
    self.vae_decoder: SDXLVAEDecoder = None
    self.vae_encoder: SDXLVAEEncoder = None
    # self.controlnet: MultiControlNetManager = None (TODO)
    self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
    self.ipadapter: SDXLIpAdapter = None
    self.motion_modules: SDXLMotionModel = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
    """Pull the SDXL (or Kolors) models and motion modules from ``model_manager``.

    ``controlnet_config_units`` is accepted for interface compatibility but is
    currently unused (ControlNet support is TODO). Prompt refiners are loaded
    only after the final prompter has been chosen, so they are not lost when
    the SDXL prompter is replaced by the Kolors one.
    """
    # Main models
    self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
    self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
    self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
    self.unet = model_manager.fetch_model("sdxl_unet")
    self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
    self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")

    # ControlNets (TODO)

    # IP-Adapters
    self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
    self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")

    # Motion Modules
    self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
    if self.motion_modules is None:
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")

    # Kolors
    if self.text_encoder_kolors is not None:
        print("Switch to Kolors. The prompter will be replaced.")
        self.prompter = KolorsPrompter()
        self.prompter.fetch_models(self.text_encoder_kolors)
        # The AnimateDiff and Kolors schedulers are incompatible. When motion
        # modules are loaded, keep the AnimateDiff scheduler; use the Kolors
        # scheduler only when there are no motion modules.
        if self.motion_modules is None:
            self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
    else:
        self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
    # Load refiners into the prompter that will actually be used.
    self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
    """Create an ``SDXLVideoPipeline`` matching ``model_manager`` and load its models.

    The new pipeline uses ``model_manager.device`` and
    ``model_manager.torch_dtype``. All arguments are forwarded unchanged to
    ``fetch_models``.

    Returns:
        The populated ``SDXLVideoPipeline`` instance.
    """
    target_device = model_manager.device
    target_dtype = model_manager.torch_dtype
    pipeline = SDXLVideoPipeline(device=target_device, torch_dtype=target_dtype)
    pipeline.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
    return pipeline
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
    """Decode a stack of latents into a list of frames, one frame at a time.

    Args:
        latents: Indexable tensor whose first dimension enumerates frames.
        tiled, tile_size, tile_stride: Forwarded to ``self.decode_image``.

    Returns:
        A list with one ``self.decode_image`` result per frame, in order.
    """
    frames = []
    for index in range(latents.shape[0]):
        # Slice (not index) to keep the leading batch dimension of size 1.
        single_latent = latents[index: index + 1]
        decoded = self.decode_image(single_latent, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        frames.append(decoded)
    return frames
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
    """Encode an iterable of frames into latents concatenated along dim 0.

    Each frame goes through ``self.preprocess_image``, is cast to
    ``self.device`` / ``self.torch_dtype``, and is encoded with
    ``self.encode_image``. Each latent is moved to CPU before concatenation,
    so the returned tensor lives on CPU.

    Args:
        processed_images: Iterable of frames accepted by ``self.preprocess_image``.
        tiled, tile_size, tile_stride: Forwarded to ``self.encode_image``.

    Returns:
        A CPU tensor with the per-frame latents concatenated along dim 0.
    """
    tile_options = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
    encoded_frames = []
    for frame in processed_images:
        frame_tensor = self.preprocess_image(frame).to(device=self.device, dtype=self.torch_dtype)
        encoded_frames.append(self.encode_image(frame_tensor, **tile_options).cpu())
    return torch.concat(encoded_frames, dim=0)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@torch.no_grad()
def __call__(
    self,
    prompt,
    negative_prompt="",
    cfg_scale=7.5,
    clip_skip=1,
    num_frames=None,
    input_frames=None,
    ipadapter_images=None,
    ipadapter_scale=1.0,
    ipadapter_use_instant_style=False,
    controlnet_frames=None,
    denoising_strength=1.0,
    height=512,
    width=512,
    num_inference_steps=20,
    animatediff_batch_size = 16,
    animatediff_stride = 8,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    smoother=None,
    smoother_progress_ids=None,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
):
    """Generate (or re-render) a video with classifier-free guidance.

    Args:
        prompt: Positive prompt, encoded with ``self.encode_prompt(..., positive=True)``.
        negative_prompt: Negative prompt, encoded with ``positive=False``.
        cfg_scale: Guidance weight in ``nega + cfg_scale * (posi - nega)``.
        clip_skip: Forwarded to ``self.encode_prompt``.
        num_frames: Number of frames to generate. If ``None``, it is taken
            from ``len(input_frames)``. One of the two must be given.
        input_frames: Optional source frames. When given and
            ``denoising_strength != 1.0``, they are encoded and noised to form
            the initial latents. They are also passed to ``smoother`` as
            ``original_frames``.
        ipadapter_images: Optional IP-Adapter reference images. If ``None``,
            the IP-Adapter is not used.
        ipadapter_scale: Scale passed to ``self.ipadapter`` for the positive branch.
        ipadapter_use_instant_style: Selects ``set_less_adapter()`` instead
            of ``set_full_adapter()`` on the IP-Adapter.
        controlnet_frames: Either a list of frames, or a list of per-processor
            frame lists (detected by checking whether the first element is a
            list). Each frame goes through ``self.controlnet.process_image``.
        denoising_strength: Forwarded to ``self.scheduler.set_timesteps``.
        height, width: Output size; latent noise has spatial size
            ``(height // 8, width // 8)``.
        num_inference_steps: Number of scheduler steps.
        animatediff_batch_size, animatediff_stride, unet_batch_size,
        controlnet_batch_size, cross_frame_attention: Not referenced by
            this implementation.
        smoother: Optional callable ``smoother(frames, original_frames=...)``.
            It is applied at denoising steps whose index is in
            ``smoother_progress_ids``, and to the final output if that list
            contains ``num_inference_steps`` or ``-1``.
        smoother_progress_ids: Step indices at which to apply ``smoother``.
            ``None`` means "never".
        tiled, tile_size, tile_stride: Tiling options forwarded to VAE
            encode/decode and ``lets_dance_xl``.
        progress_bar_cmd: Wrapper applied to iterables for progress display.
        progress_bar_st: Optional object with a ``progress(fraction)`` method,
            updated once per step (presumably a Streamlit progress bar; not
            verifiable here).

    Returns:
        The frames produced by ``self.decode_video`` (possibly post-processed
        by ``smoother``).

    Raises:
        ValueError: If both ``num_frames`` and ``input_frames`` are ``None``.
    """
    # Avoid a shared mutable default.
    if smoother_progress_ids is None:
        smoother_progress_ids = []

    # Infer the frame count from the source video when not given explicitly.
    # This also keeps the noise shape consistent with the encoded input.
    if num_frames is None:
        if input_frames is None:
            raise ValueError("num_frames must be specified when input_frames is None.")
        num_frames = len(input_frames)

    # Tiler parameters, batch size ...
    tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    # Prepare scheduler
    self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

    # Prepare latent tensors
    if self.motion_modules is None:
        # Without motion modules, every frame starts from the same noise sample.
        noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
    else:
        noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
    if input_frames is None or denoising_strength == 1.0:
        latents = noise
    else:
        latents = self.encode_video(input_frames, **tiler_kwargs)
        latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
    latents = latents.to(self.device) # will be deleted for supporting long videos

    # Encode prompts
    prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
    prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)

    # IP-Adapter
    if ipadapter_images is not None:
        if ipadapter_use_instant_style:
            self.ipadapter.set_less_adapter()
        else:
            self.ipadapter.set_full_adapter()
        ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
        ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
        # The negative branch uses an all-zero image embedding.
        ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
    else:
        ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

    # Prepare ControlNets
    # NOTE(review): `self.controlnet` is not assigned in `fetch_models`
    # (ControlNets are a TODO there), and `lets_dance_xl` below receives
    # `controlnet=None`. Confirm this path is supported before relying on it.
    if controlnet_frames is not None:
        if isinstance(controlnet_frames[0], list):
            controlnet_frames_ = []
            for processor_id, processor_frames in enumerate(controlnet_frames):
                controlnet_frames_.append(
                    torch.stack([
                        self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
                        for controlnet_frame in progress_bar_cmd(processor_frames)
                    ], dim=1)
                )
            controlnet_frames = torch.concat(controlnet_frames_, dim=0)
        else:
            controlnet_frames = torch.stack([
                self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
                for controlnet_frame in progress_bar_cmd(controlnet_frames)
            ], dim=1)
        controlnet_kwargs = {"controlnet_frames": controlnet_frames}
    else:
        controlnet_kwargs = {"controlnet_frames": None}

    # Prepare extra input
    extra_input = self.prepare_extra_input(latents)

    # Denoise
    for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
        timestep = timestep.unsqueeze(0).to(self.device)

        # Classifier-free guidance
        noise_pred_posi = lets_dance_xl(
            self.unet, motion_modules=self.motion_modules, controlnet=None,
            sample=latents, timestep=timestep,
            **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
            device=self.device,
        )
        noise_pred_nega = lets_dance_xl(
            self.unet, motion_modules=self.motion_modules, controlnet=None,
            sample=latents, timestep=timestep,
            **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
            device=self.device,
        )
        noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

        # DDIM and smoother
        if smoother is not None and progress_id in smoother_progress_ids:
            rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
            rendered_frames = self.decode_video(rendered_frames)
            rendered_frames = smoother(rendered_frames, original_frames=input_frames)
            # encode_video returns CPU tensors, but `latents` lives on
            # self.device. Align devices before mixing them in return_to_timestep.
            target_latents = self.encode_video(rendered_frames).to(device=latents.device)
            noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
        latents = self.scheduler.step(noise_pred, timestep, latents)

        # UI
        if progress_bar_st is not None:
            progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

    # Decode image
    output_frames = self.decode_video(latents, **tiler_kwargs)

    # Post-process
    if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
        output_frames = smoother(output_frames, original_frames=input_frames)

    return output_frames
|