diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,274 @@
1
+ from ..models.hunyuan_dit import HunyuanDiT
2
+ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
3
+ from ..models.sdxl_vae_encoder import SDXLVAEEncoder
4
+ from ..models.sdxl_vae_decoder import SDXLVAEDecoder
5
+ from ..models import ModelManager
6
+ from ..prompters import HunyuanDiTPrompter
7
+ from ..schedulers import EnhancedDDIMScheduler
8
+ from .base import BasePipeline
9
+ import torch
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+
13
+
14
+
15
class ImageSizeManager:
    """Stateless helper that builds 2D rotary position embeddings for a patch grid.

    Sizes may be passed either as a single ``int`` (treated as a square) or as
    a ``(height, width)`` pair.
    """

    def __init__(self):
        pass

    def _to_tuple(self, x):
        """Return ``(x, x)`` when ``x`` is an ``int``; otherwise return ``x`` unchanged."""
        return (x, x) if isinstance(x, int) else x

    def get_fill_resize_and_crop(self, src, tgt):
        """Fit ``src`` inside ``tgt`` while keeping its aspect ratio, centred.

        Args:
            src: Size of the grid to place, ``int`` or ``(h, w)``.
            tgt: Size of the reference (base) grid, ``int`` or ``(h, w)``.

        Returns:
            ``((top, left), (bottom, right))`` corners of the fitted ``src``
            region expressed in ``tgt`` coordinates.
        """
        target_h, target_w = self._to_tuple(tgt)
        source_h, source_w = self._to_tuple(src)

        target_ratio = target_h / target_w  # aspect ratio of the base resolution
        source_ratio = source_h / source_w  # aspect ratio of the requested resolution

        if source_ratio > target_ratio:
            # Requested grid is relatively taller: its height fills the base.
            fitted_h = target_h
            fitted_w = int(round(target_h / source_h * source_w))
        else:
            # Otherwise the width fills the base and the height is scaled down.
            fitted_w = target_w
            fitted_h = int(round(target_w / source_w * source_h))

        top = int(round((target_h - fitted_h) / 2.0))
        left = int(round((target_w - fitted_w) / 2.0))
        return (top, left), (top + fitted_h, left + fitted_w)

    def get_meshgrid(self, start, *args):
        """Build a float32 coordinate grid.

        Call forms:
            ``get_meshgrid(size)``: coordinates from 0 up to ``size``, step 1.
            ``get_meshgrid(start, stop)``: from ``start`` up to ``stop``, step 1.
            ``get_meshgrid(start, stop, num)``: ``num`` evenly spaced samples
            from ``start`` up to (excluding) ``stop``.

        Returns:
            Array of shape ``[2, num[0], num[1]]``. Channel 0 holds the
            second-axis (width) coordinate and channel 1 the first-axis
            (height) coordinate, because the width axis is passed to
            ``np.meshgrid`` first.

        Raises:
            ValueError: If more than two extra positional arguments are given.
        """
        extra = len(args)
        if extra > 2:
            raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

        if extra == 0:
            # ``start`` is really the grid size.
            count = self._to_tuple(start)
            origin, end = (0, 0), count
        elif extra == 1:
            # Unit step between ``start`` and ``stop``.
            origin = self._to_tuple(start)
            end = self._to_tuple(args[0])
            count = (end[0] - origin[0], end[1] - origin[1])
        else:
            # Explicit sample count: top-left corner, bottom-right corner, target size.
            origin = self._to_tuple(start)
            end = self._to_tuple(args[0])
            count = self._to_tuple(args[1])

        rows = np.linspace(origin[0], end[0], count[0], endpoint=False, dtype=np.float32)
        cols = np.linspace(origin[1], end[1], count[1], endpoint=False, dtype=np.float32)
        # Width coordinates go first, so channel 0 of the result varies along columns.
        return np.stack(np.meshgrid(cols, rows), axis=0)

    def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
        """Build the sampling grid for ``start``/``args`` and embed it.

        ``start`` and ``args`` are forwarded to :meth:`get_meshgrid`; the
        result of :meth:`get_2d_rotary_pos_embed_from_grid` is returned.
        """
        mesh = self.get_meshgrid(start, *args)
        # Insert a singleton axis after the channel axis: [2, 1, rows, cols].
        mesh = mesh.reshape([2, 1, *mesh.shape[1:]])
        return self.get_2d_rotary_pos_embed_from_grid(embed_dim, mesh, use_real=use_real)

    def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
        """Embed both grid channels with half of ``embed_dim`` each.

        Returns:
            ``(cos, sin)`` tensors when ``use_real`` is true, otherwise a
            single complex tensor; the two per-channel embeddings are
            concatenated along dimension 1.
        """
        assert embed_dim % 4 == 0

        half_dim = embed_dim // 2
        first, second = (
            self.get_1d_rotary_pos_embed(half_dim, grid[channel].reshape(-1), use_real=use_real)
            for channel in (0, 1)
        )

        if not use_real:
            return torch.cat([first, second], dim=1)
        cos = torch.cat([first[0], second[0]], dim=1)
        sin = torch.cat([first[1], second[1]], dim=1)
        return cos, sin

    def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
        """Compute rotary embeddings for a 1D sequence of positions.

        Args:
            dim: Embedding width; ``dim // 2`` frequencies are generated.
            pos: Number of positions (``int``) or a numpy array of positions.
            theta: Base of the geometric frequency progression.
            use_real: Return ``(cos, sin)`` of shape ``[S, dim]`` instead of a
                complex tensor of shape ``[S, dim // 2]``.
        """
        if isinstance(pos, int):
            pos = np.arange(pos)
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)  # [dim/2]
        positions = torch.from_numpy(pos).to(inv_freq.device)  # [S]
        angles = torch.outer(positions, inv_freq).float()  # [S, dim/2]
        if not use_real:
            # Unit-magnitude complex numbers carrying the rotation angle.
            return torch.polar(torch.ones_like(angles), angles)
        return (
            angles.cos().repeat_interleave(2, dim=1),
            angles.sin().repeat_interleave(2, dim=1),
        )

    def calc_rope(self, height, width):
        """Return the ``(cos, sin)`` rotary embedding for an image of ``height`` x ``width``.

        The image size is divided by 8 and by the patch size (2) to obtain the
        grid, which is then fitted into a 512-pixel base grid.
        NOTE(review): the factor 8 is presumably the VAE downsampling ratio and
        88 the attention head width -- confirm against the DiT definition.
        """
        patch_size = 2
        head_size = 88
        grid_size = (height // 8 // patch_size, width // 8 // patch_size)
        base_size = 512 // 8 // patch_size
        start, stop = self.get_fill_resize_and_crop(grid_size, base_size)
        return self.get_2d_rotary_pos_embed(head_size, start, stop, grid_size)
122
+
123
+
124
+
125
class HunyuanDiTImagePipeline(BasePipeline):
    """Image generation pipeline driven by a Hunyuan-DiT denoiser.

    The pipeline combines two text encoders (CLIP and T5), the DiT denoiser and
    an SDXL VAE encoder/decoder pair, and samples with a v-prediction
    ``EnhancedDDIMScheduler``.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
        self.prompter = HunyuanDiTPrompter()
        self.image_size_manager = ImageSizeManager()
        # Model slots; populated by fetch_models().
        self.text_encoder: HunyuanDiTCLIPTextEncoder = None
        self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
        self.dit: HunyuanDiT = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None

    def denoising_model(self):
        """Return the model that predicts noise during sampling (the DiT)."""
        return self.dit

    def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
        """Pull every required model out of ``model_manager`` and wire up the prompter.

        Args:
            model_manager: Source of the already-loaded models.
            prompt_refiner_classes: Refiner classes forwarded to the prompter.
        """
        # Main models
        self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
        self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
        self.dit = model_manager.fetch_model("hunyuan_dit")
        self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
        self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
        """Create a pipeline on the manager's device/dtype and load its models."""
        pipe = HunyuanDiTImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, prompt_refiner_classes)
        return pipe

    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Encode an image tensor into VAE latents."""
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode VAE latents and convert the result with ``vae_output_to_image``."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image

    def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
        """Encode ``prompt`` with both text encoders.

        Returns:
            Dict with ``text_emb``, ``text_emb_mask``, ``text_emb_t5`` and
            ``text_emb_mask_t5``, ready to be splatted into the DiT call.
        """
        text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
            prompt,
            clip_skip=clip_skip,
            clip_skip_2=clip_skip_2,
            positive=positive,
            device=self.device
        )
        return {
            "text_emb": text_emb,
            "text_emb_mask": text_emb_mask,
            "text_emb_t5": text_emb_t5,
            "text_emb_mask_t5": text_emb_mask_t5
        }

    def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
        """Build the size/position conditioning passed to the DiT alongside the latents.

        Returns:
            Dict with ``size_emb`` (``[width, height, width, height, 0, 0]``
            repeated per batch item), ``freq_cis_img`` (rotary ``(cos, sin)``
            cast to the pipeline dtype/device) and the tiling parameters.
        """
        batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
        if tiled:
            # In tiled mode the conditioning describes a single tile rather than
            # the full image. NOTE(review): the factor 16 is taken from the
            # original code -- confirm it matches the DiT's tiling logic.
            height, width = tile_size * 16, tile_size * 16
        image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
        freqs_cis_img = self.image_size_manager.calc_rope(height, width)
        image_meta_size = torch.stack([image_meta_size] * batch_size)
        return {
            "size_emb": image_meta_size,
            "freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
            "tiled": tiled,
            "tile_size": tile_size,
            "tile_stride": tile_stride
        }

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=1,
        input_image=None,
        reference_strengths=[0.4],
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image from ``prompt`` (optionally starting from ``input_image``).

        Args:
            prompt: Positive text prompt.
            negative_prompt: Prompt used for the unconditional branch.
            cfg_scale: Classifier-free guidance scale; ``1.0`` skips the
                negative branch entirely.
            clip_skip, clip_skip_2: Forwarded to the prompter.
            input_image: Optional starting image for image-to-image.
            reference_strengths: Accepted for interface compatibility; not
                used by this method.
            denoising_strength: Forwarded to the scheduler's ``set_timesteps``.
            height, width: Output size in pixels; latents are 1/8 of this.
            num_inference_steps: Number of scheduler steps.
            tiled, tile_size, tile_stride: Tiling options for VAE and DiT.
            progress_bar_cmd: Iterable wrapper used for the denoising loop.
            progress_bar_st: Optional object with a ``progress(float)`` method.

        Returns:
            The decoded image produced by ``decode_image``.
        """
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
        if input_image is not None:
            # The VAE is fed float32 input; its output is cast back to the pipeline dtype.
            image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
            latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise.clone()

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
        if cfg_scale != 1.0:
            # The negative prompt must be flagged as such so that the prompter
            # treats it the same way the other pipelines do.
            prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)

        # Prepare positional id (forward tile_stride so the caller's value is honoured)
        extra_input = self.prepare_extra_input(latents, tiled, tile_size, tile_stride)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)

            # Positive side
            noise_pred_posi = self.dit(
                latents, timestep=timestep, **prompt_emb_posi, **extra_input,
            )
            if cfg_scale != 1.0:
                # Negative side
                noise_pred_nega = self.dit(
                    latents, timestep=timestep, **prompt_emb_nega, **extra_input,
                )
                # Classifier-free guidance
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image (in float32)
        image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image
@@ -0,0 +1,105 @@
1
+ import os, torch, json
2
+ from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
3
+ from ..processors.sequencial_processor import SequencialProcessor
4
+ from ..data import VideoData, save_frames, save_video
5
+
6
+
7
+
8
class SDVideoPipelineRunner:
    """Runs an ``SDVideoPipeline`` end to end from a configuration dict.

    When ``in_streamlit`` is true, progress messages and the final video are
    rendered through Streamlit; otherwise the runner is silent.
    """

    def __init__(self, in_streamlit=False):
        self.in_streamlit = in_streamlit

    def _report(self, message):
        """Show ``message`` in the Streamlit UI; do nothing outside Streamlit."""
        if self.in_streamlit:
            # Imported lazily so that Streamlit is only required in UI mode.
            import streamlit as st
            st.markdown(message)

    def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
        """Load models, build the video pipeline and register textual inversions.

        Args:
            model_list: Passed to ``ModelManager.load_models``.
            textual_inversion_folder: Directory scanned for ``.pt``, ``.bin``,
                ``.pth`` and ``.safetensors`` files.
            device: Device for the model manager.
            lora_alphas: Accepted for config compatibility; not used here.
            controlnet_units: Iterable of dicts with ``processor_id``,
                ``model_path`` and ``scale`` keys.

        Returns:
            ``(model_manager, pipe)``.
        """
        # Load models
        model_manager = ModelManager(torch_dtype=torch.float16, device=device)
        model_manager.load_models(model_list)
        controlnet_config_units = [
            ControlNetConfigUnit(
                processor_id=unit["processor_id"],
                model_path=unit["model_path"],
                scale=unit["scale"]
            ) for unit in controlnet_units
        ]
        pipe = SDVideoPipeline.from_model_manager(model_manager, controlnet_config_units)
        textual_inversion_paths = [
            os.path.join(textual_inversion_folder, file_name)
            for file_name in os.listdir(textual_inversion_folder)
            if file_name.endswith((".pt", ".bin", ".pth", ".safetensors"))
        ]
        pipe.prompter.load_textual_inversions(textual_inversion_paths)
        return model_manager, pipe

    def load_smoother(self, model_manager, smoother_configs):
        """Build the post-processing smoother from its configuration."""
        smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
        return smoother

    def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
        """Seed torch, run ``pipe`` and move the models back to the CPU.

        Returns:
            Whatever ``pipe`` returns for the given inputs.
        """
        torch.manual_seed(seed)
        if self.in_streamlit:
            import streamlit as st
            progress_bar_st = st.progress(0.0)
            output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
            progress_bar_st.progress(1.0)
        else:
            output_video = pipe(**pipeline_inputs, smoother=smoother)
        model_manager.to("cpu")
        return output_video

    def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
        """Return the frames ``[start_frame_id, end_frame_id)`` of a video.

        ``None`` for either bound means "from the first frame" / "to the end".
        """
        video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
        if start_frame_id is None:
            start_frame_id = 0
        if end_frame_id is None:
            end_frame_id = len(video)
        frames = [video[i] for i in range(start_frame_id, end_frame_id)]
        return frames

    def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
        """Load input/ControlNet frames and store them in ``pipeline_inputs``.

        ``num_frames``, ``width`` and ``height`` are derived from the loaded
        input frames. The dict is modified in place and also returned.
        """
        pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
        pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
        pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
        if len(data["controlnet_frames"]) > 0:
            pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
        return pipeline_inputs

    def save_output(self, video, output_folder, fps, config):
        """Write frames, ``video.mp4`` and ``config.json`` into ``output_folder``."""
        os.makedirs(output_folder, exist_ok=True)
        save_frames(video, os.path.join(output_folder, "frames"))
        save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
        # Loaded frames are dropped from the config before it is dumped to JSON.
        config["pipeline"]["pipeline_inputs"]["input_frames"] = []
        config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
        with open(os.path.join(output_folder, "config.json"), 'w') as file:
            json.dump(config, file, indent=4)

    def run(self, config):
        """Execute the whole workflow described by ``config``.

        Steps: load videos, load models, optionally load the smoother (when
        ``"smoother_configs"`` is present), synthesize, save, and -- in
        Streamlit mode -- display the resulting video.
        """
        self._report("Loading videos ...")
        config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
        self._report("Loading videos ... done!")

        self._report("Loading models ...")
        model_manager, pipe = self.load_pipeline(**config["models"])
        self._report("Loading models ... done!")

        if "smoother_configs" in config:
            self._report("Loading smoother ...")
            smoother = self.load_smoother(model_manager, config["smoother_configs"])
            self._report("Loading smoother ... done!")
        else:
            smoother = None

        self._report("Synthesizing videos ...")
        output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
        self._report("Synthesizing videos ... done!")

        self._report("Saving videos ...")
        self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
        self._report("Saving videos ... done!")
        self._report("Finished!")

        if self.in_streamlit:
            import streamlit as st
            # Only read the file when it is actually displayed, and close it afterwards.
            with open(os.path.join(config["data"]["output_folder"], "video.mp4"), 'rb') as video_file:
                st.video(video_file.read())
@@ -0,0 +1,132 @@
1
+ from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
2
+ from ..prompters import SD3Prompter
3
+ from ..schedulers import FlowMatchScheduler
4
+ from .base import BasePipeline
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+
9
+
10
class SD3ImagePipeline(BasePipeline):
    """Stable Diffusion 3 image pipeline sampled with a ``FlowMatchScheduler``.

    Uses up to three text encoders (the third is optional), a DiT denoiser and
    a VAE encoder/decoder pair.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler()
        self.prompter = SD3Prompter()
        # Model slots; filled in by fetch_models().
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: SD3TextEncoder2 = None
        self.text_encoder_3: SD3TextEncoder3 = None
        self.dit: SD3DiT = None
        self.vae_decoder: SD3VAEDecoder = None
        self.vae_encoder: SD3VAEEncoder = None

    def denoising_model(self):
        """Return the denoiser used in the sampling loop."""
        return self.dit

    def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
        """Fetch all models from ``model_manager`` and hand them to the prompter.

        The third text encoder is fetched only when ``model_manager.model``
        contains ``"sd3_text_encoder_3"``.
        """
        fetch = model_manager.fetch_model
        self.text_encoder_1 = fetch("sd3_text_encoder_1")
        self.text_encoder_2 = fetch("sd3_text_encoder_2")
        if "sd3_text_encoder_3" in model_manager.model:
            self.text_encoder_3 = fetch("sd3_text_encoder_3")
        self.dit = fetch("sd3_dit")
        self.vae_decoder = fetch("sd3_vae_decoder")
        self.vae_encoder = fetch("sd3_vae_encoder")
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
        """Construct a pipeline matching the manager's device/dtype and load its models."""
        pipeline = SD3ImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipeline.fetch_models(model_manager, prompt_refiner_classes)
        return pipeline

    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Run the VAE encoder on ``image`` and return the latents."""
        return self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Run the VAE decoder on ``latent`` and convert the output to an image."""
        decoded = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return self.vae_output_to_image(decoded)

    def encode_prompt(self, prompt, positive=True):
        """Encode ``prompt``; returns a dict with ``prompt_emb`` and ``pooled_prompt_emb``."""
        encoded = self.prompter.encode_prompt(
            prompt, device=self.device, positive=positive
        )
        prompt_emb, pooled_prompt_emb = encoded
        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}

    def prepare_extra_input(self, latents=None):
        """This pipeline needs no extra denoiser inputs; always returns ``{}``."""
        return {}

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        input_image=None,
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=128,
        tile_stride=64,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image for ``prompt`` with classifier-free guidance.

        When ``input_image`` is given its latents are noised to the first
        scheduler timestep; otherwise sampling starts from pure noise.
        Returns the image produced by ``decode_image``.
        """
        # Options shared by every tiled-capable call below.
        tiler_kwargs = dict(tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        # Configure the sampling schedule.
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Initial latents: 16 channels at 1/8 of the pixel resolution.
        tensor_kwargs = dict(device=self.device, dtype=self.torch_dtype)
        latent_shape = (1, 16, height // 8, width // 8)
        if input_image is None:
            latents = torch.randn(latent_shape, **tensor_kwargs)
        else:
            image = self.preprocess_image(input_image).to(**tensor_kwargs)
            latents = self.encode_image(image, **tiler_kwargs)
            noise = torch.randn(latent_shape, **tensor_kwargs)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Text conditioning for both guidance branches.
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Sampling loop.
        for step_index, raw_timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            model_timestep = raw_timestep.unsqueeze(0).to(self.device)

            # Conditional prediction first, then unconditional.
            pred_cond = self.dit(
                latents, timestep=model_timestep, **prompt_emb_posi, **tiler_kwargs,
            )
            pred_uncond = self.dit(
                latents, timestep=model_timestep, **prompt_emb_nega, **tiler_kwargs,
            )
            guidance = pred_cond - pred_uncond
            noise_pred = pred_uncond + cfg_scale * guidance

            # Scheduler update.
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[step_index], latents)

            # Optional UI progress callback.
            if progress_bar_st is not None:
                progress_bar_st.progress(step_index / len(self.scheduler.timesteps))

        # Latents -> image.
        return self.decode_image(latents, **tiler_kwargs)
@@ -0,0 +1,173 @@
1
+ from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
2
+ from ..models.model_manager import ModelManager
3
+ from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
4
+ from ..prompters import SDPrompter
5
+ from ..schedulers import EnhancedDDIMScheduler
6
+ from .base import BasePipeline
7
+ from .dancer import lets_dance
8
+ from typing import List
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+
13
+
14
class SDImagePipeline(BasePipeline):
    """Stable Diffusion image pipeline with optional ControlNet and IP-Adapter support.

    Denoising is performed by ``lets_dance`` around the UNet and sampled with
    an ``EnhancedDDIMScheduler``.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler()
        self.prompter = SDPrompter()
        # Model slots; populated by fetch_models().
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
        self.ipadapter: SDIpAdapter = None

    def denoising_model(self):
        """Return the model that predicts noise during sampling (the UNet)."""
        return self.unet

    def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Fetch all models from ``model_manager`` and assemble the helpers.

        Args:
            model_manager: Source of the already-loaded models.
            controlnet_config_units: One entry per ControlNet to attach.
            prompt_refiner_classes: Refiner classes forwarded to the prompter.
        """
        # Main models
        self.text_encoder = model_manager.fetch_model("sd_text_encoder")
        self.unet = model_manager.fetch_model("sd_unet")
        self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
        self.prompter.fetch_models(self.text_encoder)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

        # ControlNets
        controlnet_units = [
            ControlNetUnit(
                Annotator(config.processor_id, device=self.device),
                model_manager.fetch_model("sd_controlnet", config.model_path),
                config.scale
            )
            for config in controlnet_config_units
        ]
        self.controlnet = MultiControlNetManager(controlnet_units)

        # IP-Adapters
        self.ipadapter = model_manager.fetch_model("sd_ipadapter")
        self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Create a pipeline on the manager's device/dtype and load its models."""
        pipe = SDImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        # Forward the caller's refiner classes instead of discarding them.
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=prompt_refiner_classes)
        return pipe

    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Encode an image tensor into VAE latents."""
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode VAE latents and convert the result with ``vae_output_to_image``."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image

    def encode_prompt(self, prompt, clip_skip=1, positive=True):
        """Encode ``prompt``; returns ``{"encoder_hidden_states": embedding}``."""
        prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
        return {"encoder_hidden_states": prompt_emb}

    def prepare_extra_input(self, latents=None):
        """This pipeline needs no extra denoiser inputs; always returns ``{}``."""
        return {}

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        input_image=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_image=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image for ``prompt`` with classifier-free guidance.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Prompt used for the unconditional branch.
            cfg_scale: Classifier-free guidance scale.
            clip_skip: Forwarded to the prompter.
            input_image: Optional starting image for image-to-image.
            ipadapter_images: Optional reference images for the IP-Adapter.
            ipadapter_scale: IP-Adapter strength on the positive branch.
            controlnet_image: Optional conditioning image for the ControlNets.
            denoising_strength: Forwarded to the scheduler's ``set_timesteps``.
            height, width: Output size in pixels; latents are 1/8 of this.
            num_inference_steps: Number of scheduler steps.
            tiled, tile_size, tile_stride: Tiling options for VAE and UNet.
            progress_bar_cmd: Iterable wrapper used for the denoising loop.
            progress_bar_st: Optional object with a ``progress(float)`` method.

        Returns:
            The decoded image produced by ``decode_image``.
        """
        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.encode_image(image, **tiler_kwargs)
            noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)

        # IP-Adapter
        if ipadapter_images is not None:
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
            # The negative branch receives an all-zero image encoding.
            ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

        # Prepare ControlNets
        if controlnet_image is not None:
            controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
            # lets_dance takes "controlnet_frames"; a single image gets an extra axis at dim 1.
            controlnet_image = controlnet_image.unsqueeze(1)
            controlnet_kwargs = {"controlnet_frames": controlnet_image}
        else:
            controlnet_kwargs = {"controlnet_frames": None}

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance(
                self.unet, motion_modules=None, controlnet=self.controlnet,
                sample=latents, timestep=timestep,
                **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
                device=self.device,
            )
            noise_pred_nega = lets_dance(
                self.unet, motion_modules=None, controlnet=self.controlnet,
                sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
                device=self.device,
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image