diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
from ..models.hunyuan_dit import HunyuanDiT
|
|
2
|
+
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
3
|
+
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
|
4
|
+
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
|
5
|
+
from ..models import ModelManager
|
|
6
|
+
from ..prompters import HunyuanDiTPrompter
|
|
7
|
+
from ..schedulers import EnhancedDDIMScheduler
|
|
8
|
+
from .base import BasePipeline
|
|
9
|
+
import torch
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ImageSizeManager:
    """Compute the 2D rotary position embedding (RoPE) tables for image tokens.

    The token grid of the requested image is fitted (aspect ratio preserved) and
    centred inside a fixed base grid derived from a base resolution, and RoPE
    frequencies are sampled over that region. Images of different sizes therefore
    share one coordinate range.
    """

    def __init__(self):
        pass

    def _to_tuple(self, x):
        """Return ``(x, x)`` when ``x`` is an int, otherwise return ``x`` unchanged."""
        if isinstance(x, int):
            return x, x
        return x

    def get_fill_resize_and_crop(self, src, tgt):
        """Fit a ``src`` grid inside a ``tgt`` grid, preserving aspect ratio, and centre it.

        Args:
            src: ``(height, width)`` of the grid to place, or an int for a square grid.
            tgt: ``(height, width)`` of the base grid, or an int for a square grid.

        Returns:
            ``((top, left), (bottom, right))``: corners, in ``tgt`` coordinates, of the
            resized and centred ``src`` region.
        """
        th, tw = self._to_tuple(tgt)
        h, w = self._to_tuple(src)

        tr = th / tw  # aspect ratio (h / w) of the base grid
        r = h / w  # aspect ratio (h / w) of the grid being placed

        if r > tr:
            # Relatively taller than the base grid: match heights, scale the width.
            resize_height = th
            resize_width = int(round(th / h * w))
        else:
            # Relatively wider (or equal): match widths, scale the height down.
            resize_width = tw
            resize_height = int(round(tw / w * h))

        crop_top = int(round((th - resize_height) / 2.0))
        crop_left = int(round((tw - resize_width) / 2.0))

        return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

    def get_meshgrid(self, start, *args):
        """Build a 2D grid of sample coordinates.

        Call forms (any value may be an int, meaning the same value on both axes):
            ``get_meshgrid(num)``              -> ``num`` points per axis starting at 0, step 1
            ``get_meshgrid(start, stop)``      -> ``stop - start`` points per axis, step 1
            ``get_meshgrid(start, stop, num)`` -> ``num`` points evenly spaced in ``[start, stop)``

        Returns:
            float32 array of shape ``[2, H, W]``. Index 0 holds the width (x)
            coordinates and index 1 holds the height (y) coordinates.

        Raises:
            ValueError: if more than two extra positional arguments are given.
        """
        if len(args) == 0:
            # start is the grid size
            num = self._to_tuple(start)
            start = (0, 0)
            stop = num
        elif len(args) == 1:
            # start is start, args[0] is stop, the step is 1
            start = self._to_tuple(start)
            stop = self._to_tuple(args[0])
            num = (stop[0] - start[0], stop[1] - start[1])
        elif len(args) == 2:
            # start is the top-left corner (e.g. (12, 0)), args[0] the bottom-right
            # corner (e.g. (20, 32)) and args[1] the number of samples per axis
            start = self._to_tuple(start)
            stop = self._to_tuple(args[0])
            num = self._to_tuple(args[1])
        else:
            raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

        # e.g. sample [12, 20) with 32 points vertically and [0, 32) with 124 points horizontally
        grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
        grid = np.meshgrid(grid_w, grid_h)  # w goes first; default 'xy' indexing -> each [H, W]
        grid = np.stack(grid, axis=0)  # [2, H, W]
        return grid

    def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
        """Return 2D RoPE tables for the grid described by ``start, *args`` (see ``get_meshgrid``).

        With ``use_real=True`` this is a ``(cos, sin)`` pair, each ``[H*W, embed_dim]``.
        Otherwise it is a complex tensor ``[H*W, embed_dim // 2]``.
        """
        grid = self.get_meshgrid(start, *args)  # [2, H, W]
        grid = grid.reshape([2, 1, *grid.shape[1:]])  # [2, 1, H, W]; same resolution as the target grid
        return self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)

    def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
        """Encode each grid axis with half of ``embed_dim`` and concatenate the results.

        Args:
            embed_dim: total rotary dimension; must be divisible by 4.
            grid: array whose ``grid[0]`` and ``grid[1]`` are flattened into positions.
            use_real: return a ``(cos, sin)`` pair instead of complex frequencies.
        """
        assert embed_dim % 4 == 0

        # NOTE(review): grid[0] holds the *width* coordinates (see get_meshgrid) even
        # though it is named emb_h. The order is kept as-is because swapping it would
        # change the embedding layout fed to the model.
        emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)
        emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)

        if use_real:
            cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
            sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
            return cos, sin
        emb = torch.cat([emb_h, emb_w], dim=1)  # complex, (H*W, D/2)
        return emb

    def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
        """Compute 1D RoPE frequencies for the positions ``pos``.

        Args:
            dim: rotary dimension; ``dim // 2`` frequencies are used.
            pos: an int ``n`` (positions ``0..n-1``) or a 1-D numpy array of positions.
            theta: frequency base.
            use_real: if True, return ``(cos, sin)``, each ``[S, dim]``, with every
                frequency repeated twice. Otherwise return complex ``[S, dim // 2]``.
        """
        if isinstance(pos, int):
            pos = np.arange(pos)
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
        t = torch.from_numpy(pos).to(freqs.device)  # type: ignore # [S]
        freqs = torch.outer(t, freqs).float()  # type: ignore # [S, D/2]
        if use_real:
            freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
            freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
            return freqs_cos, freqs_sin
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64, [S, D/2]
        return freqs_cis

    def calc_rope(self, height, width, patch_size=2, head_size=88, base_resolution=512, vae_scale_factor=8):
        """Return the ``(cos, sin)`` RoPE tables for a ``height`` x ``width`` pixel image.

        The defaults reproduce the values that were previously hard-coded.

        Args:
            height, width: image size in pixels.
            patch_size: latent pixels per token side.
            head_size: rotary dimension per attention head (must be divisible by 4).
            base_resolution: pixel size of the square base grid that the token grid
                is fitted into.
            vae_scale_factor: pixel-to-latent downsampling factor.

        Returns:
            ``(cos, sin)``, each of shape ``[th * tw, head_size]``, where
            ``th = height // vae_scale_factor // patch_size`` (and likewise ``tw``).
        """
        th = height // vae_scale_factor // patch_size
        tw = width // vae_scale_factor // patch_size
        base_size = base_resolution // vae_scale_factor // patch_size
        start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
        return self.get_2d_rotary_pos_embed(head_size, start, stop, (th, tw))
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class HunyuanDiTImagePipeline(BasePipeline):
    """Text-to-image / image-to-image pipeline for HunyuanDiT.

    Prompts are encoded by a CLIP and a T5 text encoder through ``HunyuanDiTPrompter``.
    Latents are encoded and decoded with the SDXL VAE, and denoising uses an
    ``EnhancedDDIMScheduler`` configured for v-prediction.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
        self.prompter = HunyuanDiTPrompter()
        self.image_size_manager = ImageSizeManager()
        # Models (attached by fetch_models)
        self.text_encoder: HunyuanDiTCLIPTextEncoder = None
        self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
        self.dit: HunyuanDiT = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None

    def denoising_model(self):
        """Return the denoising network (the DiT)."""
        return self.dit

    def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
        """Attach the DiT, both text encoders and the SDXL VAE held by ``model_manager``.

        ``prompt_refiner_classes`` is forwarded to the prompter.
        """
        # Main models
        self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
        self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
        self.dit = model_manager.fetch_model("hunyuan_dit")
        self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
        self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
        """Create a pipeline on ``model_manager``'s device/dtype and attach its models."""
        pipe = HunyuanDiTImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, prompt_refiner_classes)
        return pipe

    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Encode a preprocessed image tensor into latents with the VAE encoder."""
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode ``latent`` with the VAE decoder and convert it via ``vae_output_to_image``."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image

    def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
        """Encode ``prompt`` with both text encoders.

        ``positive`` is forwarded to the prompter; pass ``False`` for negative prompts.
        Returns a dict whose keys are passed as keyword arguments to the DiT.
        """
        text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
            prompt,
            clip_skip=clip_skip,
            clip_skip_2=clip_skip_2,
            positive=positive,
            device=self.device
        )
        return {
            "text_emb": text_emb,
            "text_emb_mask": text_emb_mask,
            "text_emb_t5": text_emb_t5,
            "text_emb_mask_t5": text_emb_mask_t5
        }

    def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
        """Build the size embedding, image RoPE tables and tiling options for the DiT.

        Args:
            latents: latent tensor indexed as ``[B, C, h, w]``; the pixel size is
                taken as ``(8 * h, 8 * w)``.
            tiled, tile_size, tile_stride: tiling options forwarded to the DiT.
        """
        batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
        if tiled:
            # NOTE(review): when tiling, the embeddings are built for a square of
            # tile_size * 16 pixels, presumably matching how the DiT interprets
            # tile_size in its tiled forward pass — confirm.
            height, width = tile_size * 16, tile_size * 16
        image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
        freqs_cis_img = self.image_size_manager.calc_rope(height, width)
        image_meta_size = torch.stack([image_meta_size] * batch_size)
        return {
            "size_emb": image_meta_size,
            "freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
            "tiled": tiled,
            "tile_size": tile_size,
            "tile_stride": tile_stride
        }

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=1,
        input_image=None,
        reference_strengths=[0.4],
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image.

        Args:
            prompt / negative_prompt: text conditioning; the negative prompt is only
                encoded and evaluated when ``cfg_scale != 1.0``.
            cfg_scale: classifier-free guidance scale.
            clip_skip, clip_skip_2: forwarded to the prompter.
            input_image: optional image for image-to-image; it is encoded and noised
                to the first scheduler timestep.
            reference_strengths: currently unused by this method.
            denoising_strength: forwarded to ``scheduler.set_timesteps``.
            height, width: output size in pixels (latents are ``height // 8`` x ``width // 8``).
            num_inference_steps: number of scheduler steps.
            tiled, tile_size, tile_stride: tiling options for the VAE and the DiT.
            progress_bar_cmd: wrapper used to iterate the timesteps (e.g. ``tqdm``).
            progress_bar_st: optional object with a ``progress(float)`` method.

        Returns:
            The decoded image produced by ``decode_image``.
        """
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
            latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise.clone()

        # Encode prompts. The negative prompt is encoded as non-positive so that
        # prompt refiners (applied to positive prompts) do not rewrite it.
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)

        # Prepare positional ids; forward the caller's tile_stride as well.
        extra_input = self.prepare_extra_input(latents, tiled, tile_size, tile_stride)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)

            # Positive side
            noise_pred_posi = self.dit(
                latents, timestep=timestep, **prompt_emb_posi, **extra_input,
            )
            if cfg_scale != 1.0:
                # Negative side
                noise_pred_nega = self.dit(
                    latents, timestep=timestep, **prompt_emb_nega, **extra_input,
                )
                # Classifier-free guidance
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import os, torch, json
|
|
2
|
+
from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
|
|
3
|
+
from ..processors.sequencial_processor import SequencialProcessor
|
|
4
|
+
from ..data import VideoData, save_frames, save_video
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SDVideoPipelineRunner:
    """Run ``SDVideoPipeline`` end to end from a nested config dict.

    ``run(config)`` performs these steps:
    - loads the input and ControlNet frames;
    - loads the models and, optionally, a smoother;
    - synthesizes the video;
    - writes frames, ``video.mp4`` and ``config.json`` to ``config["data"]["output_folder"]``.

    With ``in_streamlit=True``, status messages, a progress bar and the final
    video are shown through Streamlit.
    """

    # File extensions recognised as textual-inversion embeddings.
    TEXTUAL_INVERSION_EXTENSIONS = (".pt", ".bin", ".pth", ".safetensors")

    def __init__(self, in_streamlit=False):
        self.in_streamlit = in_streamlit

    def _status(self, message):
        """Write ``message`` to the Streamlit page; no-op when not running in Streamlit."""
        if self.in_streamlit:
            import streamlit as st
            st.markdown(message)

    @classmethod
    def _find_textual_inversion_paths(cls, folder):
        """Return paths of files in ``folder`` with a textual-inversion extension, in ``os.listdir`` order."""
        return [
            os.path.join(folder, file_name)
            for file_name in os.listdir(folder)
            if file_name.endswith(cls.TEXTUAL_INVERSION_EXTENSIONS)
        ]

    def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
        """Load models, build an ``SDVideoPipeline`` and register textual inversions.

        Args:
            model_list: passed to ``ModelManager.load_models``.
            textual_inversion_folder: folder scanned for textual-inversion files.
            device: device of the ``ModelManager`` (created with float16).
            lora_alphas: NOTE(review): currently unused by this method.
            controlnet_units: dicts with ``processor_id``, ``model_path`` and ``scale``.

        Returns:
            ``(model_manager, pipe)``.
        """
        # Load models
        model_manager = ModelManager(torch_dtype=torch.float16, device=device)
        model_manager.load_models(model_list)
        pipe = SDVideoPipeline.from_model_manager(
            model_manager,
            [
                ControlNetConfigUnit(
                    processor_id=unit["processor_id"],
                    model_path=unit["model_path"],
                    scale=unit["scale"]
                ) for unit in controlnet_units
            ]
        )
        textual_inversion_paths = self._find_textual_inversion_paths(textual_inversion_folder)
        pipe.prompter.load_textual_inversions(textual_inversion_paths)
        return model_manager, pipe

    def load_smoother(self, model_manager, smoother_configs):
        """Build a ``SequencialProcessor`` smoother from ``smoother_configs``."""
        smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
        return smoother

    def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
        """Seed torch, run ``pipe`` on ``pipeline_inputs`` with ``smoother`` and return its output.

        All models are moved to the CPU afterwards.
        """
        torch.manual_seed(seed)
        if self.in_streamlit:
            import streamlit as st
            progress_bar_st = st.progress(0.0)
            output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
            progress_bar_st.progress(1.0)
        else:
            output_video = pipe(**pipeline_inputs, smoother=smoother)
        # Move the models off the accelerator once synthesis is done (presumably to free its memory).
        model_manager.to("cpu")
        return output_video

    def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
        """Return frames ``[start_frame_id, end_frame_id)`` of a video file or image folder.

        ``None`` bounds default to the start and the end of the video.
        """
        video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
        if start_frame_id is None:
            start_frame_id = 0
        if end_frame_id is None:
            end_frame_id = len(video)
        frames = [video[i] for i in range(start_frame_id, end_frame_id)]
        return frames

    def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
        """Load the frames described by ``data`` into ``pipeline_inputs`` (mutated and returned).

        Sets ``input_frames``, ``num_frames``, ``width``/``height`` (from the first
        frame's ``.size``) and, when ``data`` lists any, ``controlnet_frames``.
        """
        pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
        pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
        # The first frame's .size is unpacked as (width, height).
        pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
        # A missing or empty "controlnet_frames" entry means no ControlNet frames.
        controlnet_frames = data.get("controlnet_frames", [])
        if controlnet_frames:
            pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in controlnet_frames]
        return pipeline_inputs

    def save_output(self, video, output_folder, fps, config):
        """Save frames, ``video.mp4`` and ``config.json`` into ``output_folder``.

        Note: clears ``input_frames`` and ``controlnet_frames`` in ``config`` in place
        before dumping it, because those hold loaded frame objects.
        """
        os.makedirs(output_folder, exist_ok=True)
        save_frames(video, os.path.join(output_folder, "frames"))
        save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
        config["pipeline"]["pipeline_inputs"]["input_frames"] = []
        config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
        with open(os.path.join(output_folder, "config.json"), 'w') as file:
            json.dump(config, file, indent=4)

    def run(self, config):
        """Run the full load -> synthesize -> save workflow described by ``config``."""
        self._status("Loading videos ...")
        config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
        self._status("Loading videos ... done!")
        self._status("Loading models ...")
        model_manager, pipe = self.load_pipeline(**config["models"])
        self._status("Loading models ... done!")
        if "smoother_configs" in config:
            self._status("Loading smoother ...")
            smoother = self.load_smoother(model_manager, config["smoother_configs"])
            self._status("Loading smoother ... done!")
        else:
            smoother = None
        self._status("Synthesizing videos ...")
        output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
        self._status("Synthesizing videos ... done!")
        self._status("Saving videos ...")
        self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
        self._status("Saving videos ... done!")
        self._status("Finished!")
        if self.in_streamlit:
            import streamlit as st
            # Only read the saved video when it will be displayed, and close the handle.
            with open(os.path.join(config["data"]["output_folder"], "video.mp4"), 'rb') as video_file:
                st.video(video_file.read())
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
|
|
2
|
+
from ..prompters import SD3Prompter
|
|
3
|
+
from ..schedulers import FlowMatchScheduler
|
|
4
|
+
from .base import BasePipeline
|
|
5
|
+
import torch
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SD3ImagePipeline(BasePipeline):
    """Text-to-image / image-to-image pipeline for Stable Diffusion 3.

    It uses up to three text encoders (the third is optional), an SD3 DiT, the
    SD3 VAE and a flow-matching scheduler.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler()
        self.prompter = SD3Prompter()
        # Models (attached by fetch_models)
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: SD3TextEncoder2 = None
        self.text_encoder_3: SD3TextEncoder3 = None
        self.dit: SD3DiT = None
        self.vae_decoder: SD3VAEDecoder = None
        self.vae_encoder: SD3VAEEncoder = None

    def denoising_model(self):
        """Return the denoising network (the DiT)."""
        return self.dit

    def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
        """Attach the SD3 models held by ``model_manager``; the third text encoder is optional."""
        self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
        self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
        # The third text encoder is optional. It is fetched unconditionally, and
        # fetch_model is relied upon to yield None for a model that is not loaded,
        # as for other optional models such as the IP-Adapter. The previous guard
        # tested membership of the name in ``model_manager.model``.
        self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
        self.dit = model_manager.fetch_model("sd3_dit")
        self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
        """Create a pipeline on ``model_manager``'s device/dtype and attach its models."""
        pipe = SD3ImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, prompt_refiner_classes)
        return pipe

    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Encode a preprocessed image tensor into latents with the VAE encoder."""
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode ``latent`` with the VAE decoder and convert it via ``vae_output_to_image``."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image

    def encode_prompt(self, prompt, positive=True):
        """Encode ``prompt``; returns a dict passed as keyword arguments to the DiT."""
        prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
            prompt, device=self.device, positive=positive
        )
        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}

    def prepare_extra_input(self, latents=None):
        """SD3 needs no extra DiT inputs; always returns an empty dict."""
        return {}

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        input_image=None,
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=128,
        tile_stride=64,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image.

        Args:
            prompt / negative_prompt: text conditioning; the negative prompt is only
                encoded and evaluated when ``cfg_scale != 1.0``.
            cfg_scale: classifier-free guidance scale.
            input_image: optional image for image-to-image; it is encoded and noised
                to the first scheduler timestep.
            denoising_strength: forwarded to ``scheduler.set_timesteps``.
            height, width: output size in pixels (latents are 16 x ``height // 8`` x ``width // 8``).
            num_inference_steps: number of scheduler steps.
            tiled, tile_size, tile_stride: tiling options for the VAE and the DiT.
            progress_bar_cmd: wrapper used to iterate the timesteps (e.g. ``tqdm``).
            progress_bar_st: optional object with a ``progress(float)`` method.

        Returns:
            The decoded image produced by ``decode_image``.
        """
        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.encode_image(image, **tiler_kwargs)
            noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)

        # Encode prompts (the negative prompt is only needed for guidance)
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            noise_pred_posi = self.dit(
                latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
            )
            if cfg_scale != 1.0:
                # Classifier-free guidance
                noise_pred_nega = self.dit(
                    latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                # With cfg_scale == 1 guidance reduces to the positive prediction.
                noise_pred = noise_pred_posi

            # Scheduler (flow-matching) update
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_image(latents, **tiler_kwargs)

        return image
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
2
|
+
from ..models.model_manager import ModelManager
|
|
3
|
+
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
4
|
+
from ..prompters import SDPrompter
|
|
5
|
+
from ..schedulers import EnhancedDDIMScheduler
|
|
6
|
+
from .base import BasePipeline
|
|
7
|
+
from .dancer import lets_dance
|
|
8
|
+
from typing import List
|
|
9
|
+
import torch
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SDImagePipeline(BasePipeline):
|
|
15
|
+
|
|
16
|
+
def __init__(self, device="cuda", torch_dtype=torch.float16):
    """Create an empty Stable Diffusion image pipeline.

    The scheduler and prompter are constructed immediately; every model slot
    starts as ``None`` and is filled in later by ``fetch_models``.
    """
    super().__init__(device=device, torch_dtype=torch_dtype)
    self.scheduler = EnhancedDDIMScheduler()
    self.prompter = SDPrompter()
    # Core Stable Diffusion components
    self.text_encoder: SDTextEncoder = None
    self.unet: SDUNet = None
    self.vae_encoder: SDVAEEncoder = None
    self.vae_decoder: SDVAEDecoder = None
    # Conditioning add-ons
    self.controlnet: MultiControlNetManager = None
    self.ipadapter: SDIpAdapter = None
    self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def denoising_model(self):
    """Return the network that predicts noise at each step (the UNet)."""
    denoiser = self.unet
    return denoiser
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
|
35
|
+
# Main models
|
|
36
|
+
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
|
|
37
|
+
self.unet = model_manager.fetch_model("sd_unet")
|
|
38
|
+
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
|
|
39
|
+
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
|
|
40
|
+
self.prompter.fetch_models(self.text_encoder)
|
|
41
|
+
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
|
42
|
+
|
|
43
|
+
# ControlNets
|
|
44
|
+
controlnet_units = []
|
|
45
|
+
for config in controlnet_config_units:
|
|
46
|
+
controlnet_unit = ControlNetUnit(
|
|
47
|
+
Annotator(config.processor_id, device=self.device),
|
|
48
|
+
model_manager.fetch_model("sd_controlnet", config.model_path),
|
|
49
|
+
config.scale
|
|
50
|
+
)
|
|
51
|
+
controlnet_units.append(controlnet_unit)
|
|
52
|
+
self.controlnet = MultiControlNetManager(controlnet_units)
|
|
53
|
+
|
|
54
|
+
# IP-Adapters
|
|
55
|
+
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
|
|
56
|
+
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@staticmethod
|
|
60
|
+
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
|
61
|
+
pipe = SDImagePipeline(
|
|
62
|
+
device=model_manager.device,
|
|
63
|
+
torch_dtype=model_manager.torch_dtype,
|
|
64
|
+
)
|
|
65
|
+
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[])
|
|
66
|
+
return pipe
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
|
70
|
+
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
71
|
+
return latents
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
|
75
|
+
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
76
|
+
image = self.vae_output_to_image(image)
|
|
77
|
+
return image
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def encode_prompt(self, prompt, clip_skip=1, positive=True):
|
|
81
|
+
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
|
|
82
|
+
return {"encoder_hidden_states": prompt_emb}
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def prepare_extra_input(self, latents=None):
|
|
86
|
+
return {}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@torch.no_grad()
|
|
90
|
+
def __call__(
|
|
91
|
+
self,
|
|
92
|
+
prompt,
|
|
93
|
+
negative_prompt="",
|
|
94
|
+
cfg_scale=7.5,
|
|
95
|
+
clip_skip=1,
|
|
96
|
+
input_image=None,
|
|
97
|
+
ipadapter_images=None,
|
|
98
|
+
ipadapter_scale=1.0,
|
|
99
|
+
controlnet_image=None,
|
|
100
|
+
denoising_strength=1.0,
|
|
101
|
+
height=512,
|
|
102
|
+
width=512,
|
|
103
|
+
num_inference_steps=20,
|
|
104
|
+
tiled=False,
|
|
105
|
+
tile_size=64,
|
|
106
|
+
tile_stride=32,
|
|
107
|
+
progress_bar_cmd=tqdm,
|
|
108
|
+
progress_bar_st=None,
|
|
109
|
+
):
|
|
110
|
+
# Tiler parameters
|
|
111
|
+
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
|
112
|
+
|
|
113
|
+
# Prepare scheduler
|
|
114
|
+
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
|
115
|
+
|
|
116
|
+
# Prepare latent tensors
|
|
117
|
+
if input_image is not None:
|
|
118
|
+
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
|
119
|
+
latents = self.encode_image(image, **tiler_kwargs)
|
|
120
|
+
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
|
121
|
+
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
|
122
|
+
else:
|
|
123
|
+
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
|
124
|
+
|
|
125
|
+
# Encode prompts
|
|
126
|
+
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
|
127
|
+
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
|
128
|
+
|
|
129
|
+
# IP-Adapter
|
|
130
|
+
if ipadapter_images is not None:
|
|
131
|
+
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
|
132
|
+
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
|
133
|
+
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
|
134
|
+
else:
|
|
135
|
+
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
|
136
|
+
|
|
137
|
+
# Prepare ControlNets
|
|
138
|
+
if controlnet_image is not None:
|
|
139
|
+
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
|
140
|
+
controlnet_image = controlnet_image.unsqueeze(1)
|
|
141
|
+
controlnet_kwargs = {"controlnet_frames": controlnet_image}
|
|
142
|
+
else:
|
|
143
|
+
controlnet_kwargs = {"controlnet_frames": None}
|
|
144
|
+
|
|
145
|
+
# Denoise
|
|
146
|
+
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
|
147
|
+
timestep = timestep.unsqueeze(0).to(self.device)
|
|
148
|
+
|
|
149
|
+
# Classifier-free guidance
|
|
150
|
+
noise_pred_posi = lets_dance(
|
|
151
|
+
self.unet, motion_modules=None, controlnet=self.controlnet,
|
|
152
|
+
sample=latents, timestep=timestep,
|
|
153
|
+
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
|
154
|
+
device=self.device,
|
|
155
|
+
)
|
|
156
|
+
noise_pred_nega = lets_dance(
|
|
157
|
+
self.unet, motion_modules=None, controlnet=self.controlnet,
|
|
158
|
+
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
|
159
|
+
device=self.device,
|
|
160
|
+
)
|
|
161
|
+
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
|
162
|
+
|
|
163
|
+
# DDIM
|
|
164
|
+
latents = self.scheduler.step(noise_pred, timestep, latents)
|
|
165
|
+
|
|
166
|
+
# UI
|
|
167
|
+
if progress_bar_st is not None:
|
|
168
|
+
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
|
169
|
+
|
|
170
|
+
# Decode image
|
|
171
|
+
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
172
|
+
|
|
173
|
+
return image
|