diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,266 @@
1
+ from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
2
+ from ..models.model_manager import ModelManager
3
+ from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
4
+ from ..prompters import SDPrompter
5
+ from ..schedulers import EnhancedDDIMScheduler
6
+ from .sd_image import SDImagePipeline
7
+ from .dancer import lets_dance
8
+ from typing import List
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+
13
+
14
def lets_dance_with_long_video(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    ipadapter_kwargs_list = None,
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device="cuda",
    animatediff_batch_size=16,
    animatediff_stride=8,
):
    """Run ``lets_dance`` over a long frame sequence in overlapping windows.

    The frames of ``sample`` are processed in windows of at most
    ``animatediff_batch_size`` frames, with a new window starting every
    ``animatediff_stride`` frames. Each window is moved to ``device``, passed
    to ``lets_dance`` and the result is moved back to the CPU. Frames that are
    covered by more than one window are blended with a weighted running
    average: the weight is largest at the centre of a window and decreases
    linearly towards its ends (never below ``1e-2``).

    Args:
        unet: Forwarded to ``lets_dance`` as its first positional argument.
        motion_modules: Forwarded to ``lets_dance``.
        controlnet: Forwarded to ``lets_dance``.
        sample: Tensor whose first dimension indexes frames.
        timestep: Forwarded to ``lets_dance`` unchanged.
        encoder_hidden_states: Forwarded to ``lets_dance`` unchanged.
        ipadapter_kwargs_list: Forwarded to ``lets_dance``. ``None`` (the
            default) is replaced by a fresh empty dict on every call.
        controlnet_frames: Optional tensor that is sliced per window along
            its second dimension before being forwarded.
        unet_batch_size: Forwarded to ``lets_dance``.
        controlnet_batch_size: Forwarded to ``lets_dance``.
        cross_frame_attention: Forwarded to ``lets_dance``.
        tiled: Forwarded to ``lets_dance``.
        tile_size: Forwarded to ``lets_dance``.
        tile_stride: Forwarded to ``lets_dance``.
        device: Device each window is moved to before calling ``lets_dance``.
        animatediff_batch_size: Maximum number of frames per window.
        animatediff_stride: Distance in frames between window starts.

    Returns:
        A tensor stacking the blended per-frame results along dimension 0.
    """
    # Use a fresh dict per call instead of a shared mutable default.
    if ipadapter_kwargs_list is None:
        ipadapter_kwargs_list = {}

    num_frames = sample.shape[0]
    # One (running average, accumulated weight) pair per frame.
    hidden_states_output = [
        (torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0)
        for _ in range(num_frames)
    ]

    for batch_id in range(0, num_frames, animatediff_stride):
        batch_id_ = min(batch_id + animatediff_batch_size, num_frames)

        # Process this window on the target device, keep the result on the CPU.
        hidden_states_batch = lets_dance(
            unet, motion_modules, controlnet,
            sample[batch_id: batch_id_].to(device),
            timestep,
            encoder_hidden_states,
            ipadapter_kwargs_list=ipadapter_kwargs_list,
            controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
            unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
            cross_frame_attention=cross_frame_attention,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
        ).cpu()

        # Blend the window into the per-frame running averages.
        window_center = (batch_id + batch_id_ - 1) / 2
        # The small epsilon keeps the divisor non-zero for one-frame windows.
        window_half_width = (batch_id_ - batch_id - 1 + 1e-2) / 2
        for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
            bias = max(1 - abs(i - window_center) / window_half_width, 1e-2)
            hidden_states, num = hidden_states_output[i]
            hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
            hidden_states_output[i] = (hidden_states, num + bias)

        # The last frame has been covered; later windows would only repeat the tail.
        if batch_id_ == num_frames:
            break

    # Drop the accumulated weights and stack the per-frame results.
    hidden_states = torch.stack([h for h, _ in hidden_states_output])
    return hidden_states
65
+
66
+
67
+
68
class SDVideoPipeline(SDImagePipeline):
    """Video pipeline built on top of ``SDImagePipeline``.

    Adds per-frame VAE encode/decode helpers and a denoising loop that runs
    the UNet over overlapping frame windows through
    ``lets_dance_with_long_video``. Noise and encoded latents are created on
    the CPU; each window is moved to the pipeline device only while it is
    being processed.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
        """Create a pipeline without models; ``fetch_models`` attaches them.

        Args:
            device: Forwarded to the base pipeline constructor.
            torch_dtype: Forwarded to the base pipeline constructor.
            use_original_animatediff: Selects the ``"linear"`` beta schedule
                when true and ``"scaled_linear"`` otherwise.
        """
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
        self.prompter = SDPrompter()
        # models
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
        self.ipadapter: SDIpAdapter = None
        self.motion_modules: SDMotionModel = None


    def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=None, prompt_refiner_classes=None):
        """Attach the models held by ``model_manager`` to this pipeline.

        Args:
            model_manager: Source of all models, queried via ``fetch_model``.
            controlnet_config_units: ControlNet configurations; one
                ``ControlNetUnit`` is built per entry. ``None`` means none.
            prompt_refiner_classes: Forwarded to the prompter's
                ``load_prompt_refiners``. ``None`` means an empty list.
        """
        # Fresh lists per call instead of shared mutable defaults.
        if controlnet_config_units is None:
            controlnet_config_units = []
        if prompt_refiner_classes is None:
            prompt_refiner_classes = []

        # Main models
        self.text_encoder = model_manager.fetch_model("sd_text_encoder")
        self.unet = model_manager.fetch_model("sd_unet")
        self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
        self.prompter.fetch_models(self.text_encoder)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

        # ControlNets
        controlnet_units = [
            ControlNetUnit(
                Annotator(config.processor_id, device=self.device),
                model_manager.fetch_model("sd_controlnet", config.model_path),
                config.scale
            )
            for config in controlnet_config_units
        ]
        self.controlnet = MultiControlNetManager(controlnet_units)

        # IP-Adapters
        self.ipadapter = model_manager.fetch_model("sd_ipadapter")
        self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")

        # Motion Modules
        self.motion_modules = model_manager.fetch_model("sd_motion_modules")
        if self.motion_modules is None:
            # Without motion modules the scheduler is reset to "scaled_linear".
            self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=None, prompt_refiner_classes=None):
        """Build an ``SDVideoPipeline`` using the manager's device, dtype and models."""
        pipe = SDVideoPipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
        return pipe


    def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
        """Decode latents frame by frame and return the list of decoded frames."""
        images = [
            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            for frame_id in range(latents.shape[0])
        ]
        return images


    def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
        """Encode frames one at a time and concatenate the CPU latents along dim 0."""
        latents = []
        for image in processed_images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
            latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            # Keep per-frame latents on the CPU.
            latents.append(latent.cpu())
        latents = torch.concat(latents, dim=0)
        return latents


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        num_frames=None,
        input_frames=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_frames=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        animatediff_batch_size = 16,
        animatediff_stride = 8,
        unet_batch_size = 1,
        controlnet_batch_size = 1,
        cross_frame_attention = False,
        smoother=None,
        smoother_progress_ids=None,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate video frames from a prompt, optionally guided by input frames.

        Args:
            prompt: Text encoded as the positive condition.
            negative_prompt: Text encoded as the negative condition.
            cfg_scale: Classifier-free guidance scale; the prediction is
                ``nega + cfg_scale * (posi - nega)``.
            clip_skip: Forwarded to ``encode_prompt``.
            num_frames: Number of frames to generate. When ``None`` and
                ``input_frames`` is given, ``len(input_frames)`` is used.
            input_frames: Optional frames. They are encoded as the starting
                latents unless ``denoising_strength == 1.0`` and are passed
                to ``smoother`` as ``original_frames``.
            ipadapter_images: Optional images for the IP-Adapter encoder.
            ipadapter_scale: Scale passed to the IP-Adapter for the positive branch.
            controlnet_frames: Optional list of frames, or a list of such
                lists (one per ControlNet processor).
            denoising_strength: Forwarded to ``scheduler.set_timesteps``.
            height: Output height; the noise latents use ``height // 8``.
            width: Output width; the noise latents use ``width // 8``.
            num_inference_steps: Forwarded to ``scheduler.set_timesteps``.
            animatediff_batch_size: Forwarded to ``lets_dance_with_long_video``.
            animatediff_stride: Forwarded to ``lets_dance_with_long_video``.
            unet_batch_size: Forwarded to ``lets_dance_with_long_video``.
            controlnet_batch_size: Forwarded to ``lets_dance_with_long_video``.
            cross_frame_attention: Forwarded to ``lets_dance_with_long_video``.
            smoother: Optional callable invoked as
                ``smoother(frames, original_frames=input_frames)``.
            smoother_progress_ids: Step indices at which ``smoother`` is
                applied inside the loop. If it contains
                ``num_inference_steps`` or ``-1`` the smoother is also applied
                to the final frames. ``None`` means no steps.
            tiled: Forwarded to the VAE helpers and the UNet call.
            tile_size: Forwarded to the VAE helpers and the UNet call.
            tile_stride: Forwarded to the VAE helpers and the UNet call.
            progress_bar_cmd: Wrapper applied to iterables to show progress.
            progress_bar_st: Optional object with a ``progress(fraction)``
                method (presumably a Streamlit progress bar; confirm with callers).

        Returns:
            The list of decoded frames, or the smoother's output if it was
            applied to the final frames.
        """
        if smoother_progress_ids is None:
            smoother_progress_ids = []
        # Without an explicit frame count, generate one frame per input frame.
        if num_frames is None and input_frames is not None:
            num_frames = len(input_frames)

        # Tiler parameters, batch size ...
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
        other_kwargs = {
            "animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
            "unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
            "cross_frame_attention": cross_frame_attention,
        }

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if self.motion_modules is None:
            # Without motion modules every frame starts from the same noise map.
            noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
        else:
            noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
        if input_frames is None or denoising_strength == 1.0:
            latents = noise
        else:
            latents = self.encode_video(input_frames, **tiler_kwargs)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)

        # IP-Adapter
        if ipadapter_images is not None:
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
            # The negative branch uses an all-zero image encoding.
            ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

        # Prepare ControlNets
        if controlnet_frames is not None:
            if isinstance(controlnet_frames[0], list):
                # One list of frames per processor: stack frames along dim 1,
                # then concatenate the processors along dim 0.
                controlnet_frames = torch.concat([
                    torch.stack([
                        self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
                        for controlnet_frame in progress_bar_cmd(processor_frames)
                    ], dim=1)
                    for processor_id, processor_frames in enumerate(controlnet_frames)
                ], dim=0)
            else:
                controlnet_frames = torch.stack([
                    self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
                    for controlnet_frame in progress_bar_cmd(controlnet_frames)
                ], dim=1)
            controlnet_kwargs = {"controlnet_frames": controlnet_frames}
        else:
            controlnet_kwargs = {"controlnet_frames": None}

        # Denoise
        num_steps = len(self.scheduler.timesteps)
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep,
                **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
                device=self.device,
            )
            noise_pred_nega = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep,
                **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
                device=self.device,
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM and smoother
            if smoother is not None and progress_id in smoother_progress_ids:
                # Jump to the final prediction, smooth it in image space, then
                # derive the noise prediction that leads to the smoothed result.
                rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
                rendered_frames = self.decode_video(rendered_frames)
                rendered_frames = smoother(rendered_frames, original_frames=input_frames)
                target_latents = self.encode_video(rendered_frames)
                noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI: report completed steps so the bar reaches 1.0 after the last one.
            if progress_bar_st is not None:
                progress_bar_st.progress((progress_id + 1) / num_steps)

        # Decode image
        output_frames = self.decode_video(latents, **tiler_kwargs)

        # Post-process
        if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
            output_frames = smoother(output_frames, original_frames=input_frames)

        return output_frames
@@ -0,0 +1,191 @@
1
+ from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
2
+ from ..models.kolors_text_encoder import ChatGLMModel
3
+ from ..models.model_manager import ModelManager
4
+ from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
5
+ from ..prompters import SDXLPrompter, KolorsPrompter
6
+ from ..schedulers import EnhancedDDIMScheduler
7
+ from .base import BasePipeline
8
+ from .dancer import lets_dance_xl
9
+ from typing import List
10
+ import torch
11
+ from tqdm import tqdm
12
+
13
+
14
+
15
class SDXLImagePipeline(BasePipeline):
    """Image pipeline for SDXL models, with an alternative Kolors text encoder.

    When a Kolors text encoder is available in the model manager, the prompter
    and the scheduler are replaced by their Kolors counterparts.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        """Create a pipeline without models; ``fetch_models`` attaches them."""
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler()
        self.prompter = SDXLPrompter()
        # models
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None
        self.text_encoder_kolors: ChatGLMModel = None
        self.unet: SDXLUNet = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None
        # self.controlnet: MultiControlNetManager = None (TODO)
        self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
        self.ipadapter: SDXLIpAdapter = None


    def denoising_model(self):
        """Return the UNet used for denoising."""
        return self.unet


    def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=None, prompt_refiner_classes=None):
        """Attach the models held by ``model_manager`` to this pipeline.

        Args:
            model_manager: Source of all models, queried via ``fetch_model``.
            controlnet_config_units: Currently unused (ControlNet support is
                marked TODO); accepted for signature compatibility.
            prompt_refiner_classes: Forwarded to the prompter's
                ``load_prompt_refiners`` on the non-Kolors path. ``None``
                means an empty list.
        """
        # Fresh list per call instead of a shared mutable default.
        if prompt_refiner_classes is None:
            prompt_refiner_classes = []

        # Main models
        self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
        self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
        self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
        self.unet = model_manager.fetch_model("sdxl_unet")
        self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")

        # ControlNets (TODO)

        # IP-Adapters
        self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
        self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")

        # Kolors
        if self.text_encoder_kolors is not None:
            print("Switch to Kolors. The prompter and scheduler will be replaced.")
            self.prompter = KolorsPrompter()
            self.prompter.fetch_models(self.text_encoder_kolors)
            self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
        else:
            self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
            self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=None, prompt_refiner_classes=None):
        """Build an ``SDXLImagePipeline`` using the manager's device, dtype and models."""
        pipe = SDXLImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
        return pipe


    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Encode an image tensor into latents with the VAE encoder."""
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode latents on the pipeline device and convert the result to an image."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image


    def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
        """Encode a prompt and return the embeddings as UNet keyword arguments.

        Returns:
            A dict with keys ``"encoder_hidden_states"`` and ``"add_text_embeds"``.
        """
        add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
            prompt,
            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
            device=self.device,
            positive=positive,
        )
        return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}


    def prepare_extra_input(self, latents=None):
        """Build the ``add_time_id`` tensor from the latent size.

        The image height and width are taken as 8 times the latent's third
        and fourth dimensions; the tensor is ``[height, width, 0, 0, height, width]``.
        """
        height, width = latents.shape[2] * 8, latents.shape[3] * 8
        return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=2,
        input_image=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        ipadapter_use_instant_style=False,
        controlnet_image=None,
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image from a prompt, optionally starting from ``input_image``.

        Args:
            prompt: Text encoded as the positive condition.
            negative_prompt: Text encoded as the negative condition.
            cfg_scale: Classifier-free guidance scale. When it equals ``1.0``
                the negative branch is skipped.
            clip_skip: Forwarded to ``encode_prompt``.
            clip_skip_2: Forwarded to ``encode_prompt``.
            input_image: Optional image; when given it is encoded and noised
                to the first scheduler timestep.
            ipadapter_images: Optional images for the IP-Adapter encoder.
            ipadapter_scale: Scale passed to the IP-Adapter for the positive branch.
            ipadapter_use_instant_style: Selects ``set_less_adapter`` instead
                of ``set_full_adapter`` on the IP-Adapter.
            controlnet_image: Currently unused (ControlNet support is marked TODO).
            denoising_strength: Forwarded to ``scheduler.set_timesteps``.
            height: Output height; the noise latents use ``height // 8``.
            width: Output width; the noise latents use ``width // 8``.
            num_inference_steps: Forwarded to ``scheduler.set_timesteps``.
            tiled: Forwarded to the VAE helpers and the UNet call.
            tile_size: Forwarded to the VAE helpers and the UNet call.
            tile_stride: Forwarded to the VAE helpers and the UNet call.
            progress_bar_cmd: Wrapper applied to the timesteps to show progress.
            progress_bar_st: Optional object with a ``progress(fraction)``
                method (presumably a Streamlit progress bar; confirm with callers).

        Returns:
            The decoded image produced by ``decode_image``.
        """
        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.encode_image(image, **tiler_kwargs)
            noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)

        # IP-Adapter
        if ipadapter_images is not None:
            if ipadapter_use_instant_style:
                self.ipadapter.set_less_adapter()
            else:
                self.ipadapter.set_full_adapter()
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
            # The negative branch uses an all-zero image encoding.
            ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

        # Prepare ControlNets (TODO)
        controlnet_kwargs = {"controlnet_frames": None}

        # Prepare extra input
        extra_input = self.prepare_extra_input(latents)

        # Denoise
        num_steps = len(self.scheduler.timesteps)
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_xl(
                self.unet, motion_modules=None, controlnet=None,
                sample=latents, timestep=timestep, **extra_input,
                **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
                device=self.device,
            )
            if cfg_scale != 1.0:
                noise_pred_nega = lets_dance_xl(
                    self.unet, motion_modules=None, controlnet=None,
                    sample=latents, timestep=timestep, **extra_input,
                    **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
                    device=self.device,
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                # With a guidance scale of 1.0 the formula reduces to the positive prediction.
                noise_pred = noise_pred_posi

            # DDIM
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI: report completed steps so the bar reaches 1.0 after the last one.
            if progress_bar_st is not None:
                progress_bar_st.progress((progress_id + 1) / num_steps)

        # Decode image
        image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image
@@ -0,0 +1,223 @@
1
+ from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
2
+ from ..models.kolors_text_encoder import ChatGLMModel
3
+ from ..models.model_manager import ModelManager
4
+ from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
5
+ from ..prompters import SDXLPrompter, KolorsPrompter
6
+ from ..schedulers import EnhancedDDIMScheduler
7
+ from .sdxl_image import SDXLImagePipeline
8
+ from .dancer import lets_dance_xl
9
+ from typing import List
10
+ import torch
11
+ from tqdm import tqdm
12
+
13
+
14
+
15
+ class SDXLVideoPipeline(SDXLImagePipeline):
16
+
17
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
    """Create a pipeline without models; ``fetch_models`` attaches them.

    ``use_original_animatediff`` selects the ``"linear"`` beta schedule when
    true and ``"scaled_linear"`` otherwise.
    """
    super().__init__(device=device, torch_dtype=torch_dtype)

    if use_original_animatediff:
        beta_schedule = "linear"
    else:
        beta_schedule = "scaled_linear"
    self.scheduler = EnhancedDDIMScheduler(beta_schedule=beta_schedule)
    self.prompter = SDXLPrompter()

    # Text encoders (the Kolors encoder is an alternative to the SDXL pair).
    self.text_encoder: SDXLTextEncoder = None
    self.text_encoder_2: SDXLTextEncoder2 = None
    self.text_encoder_kolors: ChatGLMModel = None

    # Denoiser and VAE.
    self.unet: SDXLUNet = None
    self.vae_decoder: SDXLVAEDecoder = None
    self.vae_encoder: SDXLVAEEncoder = None
    # self.controlnet: MultiControlNetManager = None (TODO)

    # IP-Adapter and motion modules.
    self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
    self.ipadapter: SDXLIpAdapter = None
    self.motion_modules: SDXLMotionModel = None
32
+
33
+
34
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=None, prompt_refiner_classes=None):
    """Attach the models held by ``model_manager`` to this pipeline.

    The prompter is configured exactly once: with the Kolors text encoder
    when one is available, otherwise with the two SDXL text encoders followed
    by the prompt refiners.

    Args:
        model_manager: Source of all models, queried via ``fetch_model``.
        controlnet_config_units: Currently unused (ControlNet support is
            marked TODO); accepted for signature compatibility.
        prompt_refiner_classes: Forwarded to the prompter's
            ``load_prompt_refiners`` on the non-Kolors path. ``None`` means
            an empty list.
    """
    # Fresh list per call instead of a shared mutable default.
    if prompt_refiner_classes is None:
        prompt_refiner_classes = []

    # Main models
    self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
    self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
    self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
    self.unet = model_manager.fetch_model("sdxl_unet")
    self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
    self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")

    # ControlNets (TODO)

    # IP-Adapters
    self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
    self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")

    # Motion Modules
    self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
    if self.motion_modules is None:
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")

    # Kolors
    if self.text_encoder_kolors is not None:
        print("Switch to Kolors. The prompter will be replaced.")
        self.prompter = KolorsPrompter()
        self.prompter.fetch_models(self.text_encoder_kolors)
        # The schedulers of AnimateDiff and Kolors are incompatible. We align it with AnimateDiff.
        if self.motion_modules is None:
            self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
    else:
        # Configure the SDXL prompter once with both encoders, then load the
        # refiners onto the prompter that is actually kept.
        self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
66
+
67
+
68
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=None, prompt_refiner_classes=None):
    """Build an ``SDXLVideoPipeline`` using the manager's device, dtype and models.

    Args:
        model_manager: Provides ``device``, ``torch_dtype`` and the models.
        controlnet_config_units: Forwarded to ``fetch_models``. ``None``
            means an empty list.
        prompt_refiner_classes: Forwarded to ``fetch_models``. ``None``
            means an empty list.

    Returns:
        The pipeline with its models attached.
    """
    # Normalise here so ``fetch_models`` always receives real lists and no
    # mutable default object is shared between calls.
    if controlnet_config_units is None:
        controlnet_config_units = []
    if prompt_refiner_classes is None:
        prompt_refiner_classes = []

    pipe = SDXLVideoPipeline(
        device=model_manager.device,
        torch_dtype=model_manager.torch_dtype,
    )
    pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
    return pipe
76
+
77
+
78
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
    """Decode latents one frame at a time and return the decoded frames as a list."""
    frames = []
    for index in range(latents.shape[0]):
        # Slice rather than index so the frame dimension is kept.
        single_frame_latent = latents[index: index + 1]
        frame = self.decode_image(
            single_frame_latent,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        frames.append(frame)
    return frames
84
+
85
+
86
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
    """Encode frames one at a time and concatenate the CPU latents along dim 0."""

    def encode_frame(frame):
        # Preprocess on the pipeline device, encode, then park the latent on the CPU.
        tensor = self.preprocess_image(frame).to(device=self.device, dtype=self.torch_dtype)
        encoded = self.encode_image(tensor, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return encoded.cpu()

    per_frame_latents = [encode_frame(frame) for frame in processed_images]
    return torch.concat(per_frame_latents, dim=0)
94
+
95
+
96
@torch.no_grad()
def __call__(
    self,
    prompt,
    negative_prompt="",
    cfg_scale=7.5,
    clip_skip=1,
    num_frames=None,
    input_frames=None,
    ipadapter_images=None,
    ipadapter_scale=1.0,
    ipadapter_use_instant_style=False,
    controlnet_frames=None,
    denoising_strength=1.0,
    height=512,
    width=512,
    num_inference_steps=20,
    animatediff_batch_size=16,
    animatediff_stride=8,
    unet_batch_size=1,
    controlnet_batch_size=1,
    cross_frame_attention=False,
    smoother=None,
    smoother_progress_ids=None,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
):
    """Run the denoising loop and return the decoded frames.

    Args:
        prompt: Positive prompt, encoded with ``self.encode_prompt``.
        negative_prompt: Negative prompt, encoded the same way.
        cfg_scale: Guidance weight; the prediction is
            ``nega + cfg_scale * (posi - nega)``.
        clip_skip: Forwarded to ``self.encode_prompt``.
        num_frames: Number of latent frames. When ``None`` it is taken from
            ``len(input_frames)`` (assumes one latent per input frame).
        input_frames: Optional source frames. They are encoded as the
            starting latents when ``denoising_strength != 1.0`` and are also
            handed to ``smoother`` as ``original_frames``.
        ipadapter_images: Optional images for ``self.ipadapter_image_encoder``.
        ipadapter_scale: ``scale`` passed to ``self.ipadapter`` for the
            positive branch.
        ipadapter_use_instant_style: Selects ``set_less_adapter()`` instead
            of ``set_full_adapter()`` on ``self.ipadapter``.
        controlnet_frames: Optional frames (or a list of per-processor frame
            lists) run through ``self.controlnet.process_image``.
            NOTE(review): the UNet call below passes ``controlnet=None``;
            confirm whether these frames are actually consumed downstream.
        denoising_strength: Forwarded to ``self.scheduler.set_timesteps``;
            a value of exactly ``1.0`` starts from pure noise.
        height: Output height; the latent height is ``height // 8``.
        width: Output width; the latent width is ``width // 8``.
        num_inference_steps: Forwarded to ``self.scheduler.set_timesteps``.
        animatediff_batch_size: Accepted but not referenced in this method.
        animatediff_stride: Accepted but not referenced in this method.
        unet_batch_size: Accepted but not referenced in this method.
        controlnet_batch_size: Accepted but not referenced in this method.
        cross_frame_attention: Accepted but not referenced in this method.
        smoother: Optional callable ``smoother(frames, original_frames=...)``.
        smoother_progress_ids: Step indices at which ``smoother`` is applied
            inside the loop. ``num_inference_steps`` or ``-1`` requests a
            pass on the final decoded frames. ``None`` means "never".
        tiled: Tiling option for the VAE encode/decode and the UNet call.
        tile_size: Tiling option for the VAE encode/decode and the UNet call.
        tile_stride: Tiling option for the VAE encode/decode and the UNet call.
        progress_bar_cmd: Wrapper applied to iterables to display progress.
        progress_bar_st: Optional object exposing ``progress(fraction)``
            (presumably a Streamlit progress bar; confirm against callers).

    Returns:
        The list produced by ``self.decode_video``, optionally post-processed
        by ``smoother``.

    Raises:
        ValueError: If both ``num_frames`` and ``input_frames`` are ``None``.
    """
    # Avoid a shared mutable default; an empty tuple behaves the same for `in`.
    if smoother_progress_ids is None:
        smoother_progress_ids = ()

    # The noise tensors below need a concrete frame count.
    if num_frames is None:
        if input_frames is None:
            raise ValueError("num_frames must be specified when input_frames is not provided")
        num_frames = len(input_frames)

    # Tiler parameters, batch size ...
    tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    # Prepare scheduler
    self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

    # Prepare latent tensors
    if self.motion_modules is None:
        # Without motion modules, one noise sample is repeated for every frame.
        noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
    else:
        noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
    if input_frames is None or denoising_strength == 1.0:
        latents = noise
    else:
        latents = self.encode_video(input_frames, **tiler_kwargs)
        latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
    latents = latents.to(self.device)  # will be deleted for supporting long videos

    # Encode prompts
    prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
    prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)

    # IP-Adapter
    if ipadapter_images is not None:
        if ipadapter_use_instant_style:
            self.ipadapter.set_less_adapter()
        else:
            self.ipadapter.set_full_adapter()
        ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
        ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
        # The negative branch uses an all-zero image encoding.
        ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
    else:
        ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

    # Prepare ControlNets
    if controlnet_frames is not None:
        if isinstance(controlnet_frames[0], list):
            # One list of frames per processor: stack each along dim 1,
            # then concatenate the processors along dim 0.
            controlnet_frames_ = []
            for processor_id, processor_frames in enumerate(controlnet_frames):
                controlnet_frames_.append(
                    torch.stack([
                        self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
                        for controlnet_frame in progress_bar_cmd(processor_frames)
                    ], dim=1)
                )
            controlnet_frames = torch.concat(controlnet_frames_, dim=0)
        else:
            controlnet_frames = torch.stack([
                self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
                for controlnet_frame in progress_bar_cmd(controlnet_frames)
            ], dim=1)
        controlnet_kwargs = {"controlnet_frames": controlnet_frames}
    else:
        controlnet_kwargs = {"controlnet_frames": None}

    # Prepare extra input
    extra_input = self.prepare_extra_input(latents)

    # Denoise
    for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
        timestep = timestep.unsqueeze(0).to(self.device)

        # Classifier-free guidance
        noise_pred_posi = lets_dance_xl(
            self.unet, motion_modules=self.motion_modules, controlnet=None,
            sample=latents, timestep=timestep,
            **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
            device=self.device,
        )
        noise_pred_nega = lets_dance_xl(
            self.unet, motion_modules=self.motion_modules, controlnet=None,
            sample=latents, timestep=timestep,
            **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
            device=self.device,
        )
        noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

        # DDIM and smoother
        if smoother is not None and progress_id in smoother_progress_ids:
            rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
            # Honour the caller's tiling options here as well as in the final decode.
            rendered_frames = self.decode_video(rendered_frames, **tiler_kwargs)
            rendered_frames = smoother(rendered_frames, original_frames=input_frames)
            target_latents = self.encode_video(rendered_frames, **tiler_kwargs)
            # encode_video returns CPU tensors while `latents` lives on
            # self.device; align them before mixing both in the scheduler.
            target_latents = target_latents.to(latents.device)
            noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
        latents = self.scheduler.step(noise_pred, timestep, latents)

        # UI
        if progress_bar_st is not None:
            # Report completed steps so the bar reaches 1.0 after the last one.
            progress_bar_st.progress((progress_id + 1) / len(self.scheduler.timesteps))

    # Decode image
    output_frames = self.decode_video(latents, **tiler_kwargs)

    # Post-process
    if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
        output_frames = smoother(output_frames, original_frames=input_frames)

    return output_frames