diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SVDVAEEncoder(SDVAEEncoder):
    """SD VAE encoder variant used for SVD checkpoints.

    Identical to ``SDVAEEncoder`` except for the latent scaling factor and the
    state-dict converter it hands out.
    """

    def __init__(self):
        super().__init__()
        # Assigned after the parent constructor runs, so this value wins over
        # anything the parent may have set.
        self.scaling_factor = 0.13025

    @staticmethod
    def state_dict_converter():
        """Return a fresh converter for SVD-format checkpoint keys."""
        converter = SVDVAEEncoderStateDictConverter()
        return converter
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    """Translate SVD checkpoint keys into ``SVDVAEEncoder`` parameter names.

    Diffusers-format checkpoints are delegated unchanged to the SD VAE encoder
    converter. For civitai/original-format checkpoints, ``from_civitai`` keeps
    only the keys under the ``conditioner.embedders.3.encoder.`` prefix that
    appear in an explicit lookup table, and renames them.
    """

    def __init__(self):
        super().__init__()

    def from_diffusers(self, state_dict):
        """Delegate diffusers-format conversion to the parent converter."""
        return super().from_diffusers(state_dict)

    def from_civitai(self, state_dict):
        """Return a new dict containing only the recognised tensors, renamed.

        Keys not listed in the table below are skipped. Tensors whose target
        name contains ``transformer_blocks`` (the mid-block attention q/k/v and
        output projections) are passed through ``squeeze()``. This presumably
        turns 1x1-conv weights into linear weights; confirm against the target
        module's parameter shapes.
        """
        rename_dict = {
            "conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias",
            "conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight",
            "conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias",
            "conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
            "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
            "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
            "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
            "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
            "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
            "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
            "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
            "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
            "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
            "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
            "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
            "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
            "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
            "conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias",
            "conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight",
            "conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias",
            "conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight",
        }
        converted = {}
        for old_name in state_dict:
            new_name = rename_dict.get(old_name)
            if new_name is None:
                # Not an encoder weight we know how to place; drop it.
                continue
            tensor = state_dict[old_name]
            converted[new_name] = tensor.squeeze() if "transformer_blocks" in new_name else tensor
        return converted
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from einops import rearrange, repeat
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TileWorker:
    """Apply a spatial model to a tensor tile-by-tile and blend the results.

    The input ``(b, c, h, w)`` is cut into overlapping ``tile_size x tile_size``
    patches spaced ``tile_stride`` apart (``torch.nn.Unfold``). Each patch is
    run through ``forward_fn``. The outputs are then stitched back together
    (``torch.nn.Fold``) as a weighted average. The weights fade toward every
    tile border, which hides seams in the overlap regions.
    """

    def __init__(self):
        pass


    def mask(self, height, width, border_width):
        """Return a ``(height, width)`` blending-weight mask.

        Each entry is its 1-based distance to the nearest edge divided by
        ``border_width``, clipped to ``[0, 1]``. Interior entries are 1. Entries
        within ``border_width`` of an edge ramp down; for ``border_width > 0``
        the smallest value is ``1 / border_width``, which is never zero.
        """
        x = torch.arange(height).repeat(width, 1).T
        y = torch.arange(width).repeat(height, 1)
        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
        mask = (mask / border_width).clip(0, 1)
        return mask


    @staticmethod
    def _pad_amount(size, tile_size, tile_stride):
        """Return how many pixels to append to an axis of length ``size`` so the tile grid covers it exactly.

        Without padding, ``Unfold`` silently drops the trailing strip when
        ``(size - tile_size) % tile_stride != 0``, and ``Fold`` later turns
        that strip into 0 / 0 = NaN. ``Unfold`` also raises when
        ``size < tile_size``.
        """
        if size <= tile_size:
            return tile_size - size
        return (tile_stride - (size - tile_size) % tile_stride) % tile_stride


    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
        """Cut ``(b, c, h, w)`` into tiles shaped ``(b, c, tile_size, tile_size, tile_num)``.

        Tiles are ordered row-major by their top-left corner. Pixels beyond the
        last full tile on either axis are not included; ``tiled_forward`` pads
        the input first so that this cannot happen.
        """
        batch_size, channel, _, _ = model_input.shape
        model_input = model_input.to(device=tile_device, dtype=tile_dtype)
        unfold_operator = torch.nn.Unfold(
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        # Unfold returns (b, c * tile_size * tile_size, tile_num), channel-major.
        model_input = unfold_operator(model_input)
        model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))

        return model_input


    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
        """Call ``y = forward_fn(x)`` on groups of at most ``tile_batch_size`` tiles.

        Each group is moved to the inference device/dtype, and its tile axis is
        folded into the batch axis for ``forward_fn``. The results are moved
        back to the tile device/dtype and concatenated along the tile axis.
        """
        tile_num = model_input.shape[-1]
        model_output_stack = []

        for tile_id in range(0, tile_num, tile_batch_size):

            # process input
            tile_id_ = min(tile_id + tile_batch_size, tile_num)
            x = model_input[:, :, :, :, tile_id: tile_id_]
            x = x.to(device=inference_device, dtype=inference_dtype)
            x = rearrange(x, "b c h w n -> (n b) c h w")

            # process output
            y = forward_fn(x)
            y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
            y = y.to(device=tile_device, dtype=tile_dtype)
            model_output_stack.append(y)

        model_output = torch.concat(model_output_stack, dim=-1)
        return model_output


    def io_scale(self, model_output, tile_size):
        """Return the spatial resize factor applied by ``forward_fn``.

        Only a single scale, shared by height and width, is supported.
        """
        io_scale = model_output.shape[2] / tile_size
        return io_scale


    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
        """Inverse of ``tile``: blend ``(b, c, t, t, n)`` tiles back into ``(b, c, height, width)``.

        Every tile is weighted by ``mask``. The folded sum is then divided by
        the folded mask, so each output pixel is a weighted mean of all the
        tiles that cover it.
        """
        mask = self.mask(tile_size, tile_size, border_width)
        mask = mask.to(device=tile_device, dtype=tile_dtype)
        mask = rearrange(mask, "h w -> 1 1 h w 1")
        model_output = model_output * mask

        fold_operator = torch.nn.Fold(
            output_size=(height, width),
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
        model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
        model_output = fold_operator(model_output) / fold_operator(mask)

        return model_output


    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        """Run ``forward_fn`` over ``model_input`` tile by tile and return the blended result.

        Args:
            forward_fn: maps ``(n, c, tile_size, tile_size)`` to
                ``(n, c', tile_size * s, tile_size * s)`` for one scale ``s``
                shared by both spatial axes.
            model_input: ``(b, c, h, w)`` tensor. Its device and dtype are
                used for inference and for the returned tensor.
            tile_size, tile_stride: tile edge and step, in input pixels.
                ``tile_stride`` is expected to be <= ``tile_size``; a larger
                stride leaves pixels between tiles uncovered.
            tile_batch_size: number of tiles passed to ``forward_fn`` per call.
            tile_device, tile_dtype: where tiles are stored and blended.
            border_width: width of the blending ramp; defaults to
                ``int(tile_stride * 0.5)``.

        Returns:
            ``(b, c', int(h * s), int(w * s))`` on ``model_input``'s device and dtype.
        """
        # Prepare
        inference_device, inference_dtype = model_input.device, model_input.dtype
        height, width = model_input.shape[2], model_input.shape[3]
        border_width = int(tile_stride*0.5) if border_width is None else border_width

        # Pad the bottom/right edges so the tile grid covers every pixel.
        # Aligned inputs are left untouched.
        pad_height = self._pad_amount(height, tile_size, tile_stride)
        pad_width = self._pad_amount(width, tile_size, tile_stride)
        if pad_height > 0 or pad_width > 0:
            model_input = model_input.to(device=tile_device, dtype=tile_dtype)
            model_input = torch.nn.functional.pad(model_input, (0, pad_width, 0, pad_height), mode="replicate")
        padded_height, padded_width = height + pad_height, width + pad_width

        # tile
        model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)

        # inference
        model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)

        # resize: map every pixel quantity into output space
        io_scale = self.io_scale(model_output, tile_size)
        output_height, output_width = int(height*io_scale), int(width*io_scale)
        padded_height, padded_width = int(padded_height*io_scale), int(padded_width*io_scale)
        tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
        border_width = int(border_width*io_scale)

        # untile on the padded canvas, then drop the padded strip
        model_output = self.untile(model_output, padded_height, padded_width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
        model_output = model_output[:, :, :output_height, :output_width]

        # Done!
        model_output = model_output.to(device=inference_device, dtype=inference_dtype)
        return model_output
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .sd_image import SDImagePipeline
|
|
2
|
+
from .sd_video import SDVideoPipeline
|
|
3
|
+
from .sdxl_image import SDXLImagePipeline
|
|
4
|
+
from .sdxl_video import SDXLVideoPipeline
|
|
5
|
+
from .sd3_image import SD3ImagePipeline
|
|
6
|
+
from .hunyuan_image import HunyuanDiTImagePipeline
|
|
7
|
+
from .svd_video import SVDVideoPipeline
|
|
8
|
+
from .pipeline_runner import SDVideoPipelineRunner
|
|
9
|
+
# Kolors image generation reuses the SDXL image pipeline implementation;
# this alias exposes it under a Kolors-specific name.
KolorsImagePipeline = SDXLImagePipeline
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BasePipeline(torch.nn.Module):
    """Base class for the diffusion pipelines.

    Stores the target device and dtype, and converts between PIL images and
    tensors whose values lie in ``[-1, 1]``.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__()
        # Stored on the instance; this base class does not use them itself.
        self.device = device
        self.torch_dtype = torch_dtype


    def preprocess_image(self, image):
        """Convert an image into a ``(1, C, H, W)`` float tensor scaled to ``[-1, 1]``.

        Assumes ``np.array(image)`` yields an ``(H, W, C)`` array with values
        in ``[0, 255]`` (e.g. an RGB PIL image). The permute below requires
        exactly three axes.
        """
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image


    def preprocess_images(self, images):
        """Apply ``preprocess_image`` to each image and return the list of tensors."""
        return [self.preprocess_image(image) for image in images]


    def vae_output_to_image(self, vae_output):
        """Convert the first item of a ``(B, C, H, W)`` tensor in ``[-1, 1]`` into a PIL image."""
        image = vae_output[0].cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image


    def vae_output_to_video(self, vae_output):
        """Convert a decoded video tensor in ``[-1, 1]`` into a list of PIL frames.

        ``vae_output`` is laid out as ``(C, T, H, W)``. Each of the ``T``
        frames becomes one ``(H, W, C)`` uint8 image.
        """
        # Move channels last and time first so that iterating yields (H, W, C)
        # frames. A three-axis permute would raise on this 4-D tensor.
        video = vae_output.cpu().permute(1, 2, 3, 0).numpy()
        video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
        return video
|
|
34
|
+
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
|
|
3
|
+
from ..models.sd_unet import PushBlock, PopBlock
|
|
4
|
+
from ..controlnets import MultiControlNetManager
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    ipadapter_kwargs_list = {},
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run one SD UNet forward pass, optionally with AnimateDiff motion modules and ControlNet.

    Args:
        unet: provides ``time_proj``, ``time_embedding``, ``conv_in``,
            ``blocks``, ``conv_norm_out``, ``conv_act`` and ``conv_out``.
        motion_modules: if given, ``motion_modules.motion_modules[i]`` is
            applied after every UNet block whose index is a key of
            ``motion_modules.call_block_id`` (which maps it to ``i``).
        controlnet: called on chunks of the batch. Its residuals are merged
            after UNet block ``controlnet_insert_block_id``. Only used when
            ``controlnet_frames`` is also given.
        sample: input latents. Dim 0 is the batch/frame axis that all the
            chunking below iterates over.
        timestep: passed to ``unet.time_proj``.
        encoder_hidden_states: text conditioning. It is repeated when its
            dim 0 differs from ``sample``'s (see step 0).
        ipadapter_kwargs_list: dict keyed by UNet block index; only read via
            ``.get``, never mutated, so the shared ``{}`` default is harmless.
        controlnet_frames: conditioning frames, sliced as ``[:, a:b]``, so
            their frame axis is dim 1.
        unet_batch_size: chunk size (along dim 0) for UNet blocks other than
            PushBlock/PopBlock.
        controlnet_batch_size: chunk size (along dim 0) for ControlNet calls.
        cross_frame_attention: forwarded to the chunked UNet blocks.
        tiled, tile_size, tile_stride: forwarded to ControlNet and to the
            chunked UNet blocks.
        device: where CPU-offloaded tensors are moved back to before use.
        vram_limit_level: when ``>= 1``, skip-connection and ControlNet
            residuals are kept on the CPU between uses to save accelerator
            memory.

    Returns:
        The output of ``unet.conv_out``.
    """
    # 0. Text embedding alignment (only for video processing)
    # NOTE(review): `.repeat` with four sizes on a 3-D (1, seq, dim) tensor
    # returns a 4-D (N, 1, seq, dim) tensor; presumably the downstream
    # attention code tolerates the extra axis — confirm before relying on it.
    if encoder_hidden_states.shape[0] != sample.shape[0]:
        encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)

    # 1. ControlNet
    # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
    # (Those are presumably the caller's sliding-window settings; they are not parameters here.)
    # I leave it here because I intend to do something interesting on the ControlNets.
    controlnet_insert_block_id = 30
    if controlnet is not None and controlnet_frames is not None:
        res_stacks = []
        # process controlnet frames with batch
        for batch_id in range(0, sample.shape[0], controlnet_batch_size):
            batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
            res_stack = controlnet(
                sample[batch_id: batch_id_],
                timestep,
                encoder_hidden_states[batch_id: batch_id_],
                controlnet_frames[:, batch_id: batch_id_],
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            if vram_limit_level >= 1:
                # Park this chunk's residuals on the CPU until they are merged.
                res_stack = [res.cpu() for res in res_stack]
            res_stacks.append(res_stack)
        # concat the residual
        # Re-join the per-chunk residuals position by position along dim 0.
        additional_res_stack = []
        for i in range(len(res_stacks[0])):
            res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
            additional_res_stack.append(res)
    else:
        additional_res_stack = None

    # 2. time
    time_emb = unet.time_proj(timestep).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. pre-process
    # NOTE(review): height and width are not used below.
    height, width = sample.shape[2], sample.shape[3]
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    # res_stack collects skip-connection tensors: PushBlock appends, PopBlock consumes.
    res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]

    # 4. blocks
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet
        if isinstance(block, PushBlock):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
            if vram_limit_level>=1:
                # Offload the newly pushed skip tensor.
                res_stack[-1] = res_stack[-1].cpu()
        elif isinstance(block, PopBlock):
            if vram_limit_level>=1:
                # Bring the skip tensor about to be consumed back onto `device`.
                res_stack[-1] = res_stack[-1].to(device)
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
        else:
            # Other blocks run in chunks of unet_batch_size along dim 0. Only
            # hidden_states and text_emb are sliced; time_emb and res_stack
            # are passed whole, and the block's other outputs are discarded.
            hidden_states_input = hidden_states
            hidden_states_output = []
            for batch_id in range(0, sample.shape[0], unet_batch_size):
                batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
                hidden_states, _, _, _ = block(
                    hidden_states_input[batch_id: batch_id_],
                    time_emb,
                    text_emb[batch_id: batch_id_],
                    res_stack,
                    cross_frame_attention=cross_frame_attention,
                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
                hidden_states_output.append(hidden_states)
            hidden_states = torch.concat(hidden_states_output, dim=0)
        # 4.2 AnimateDiff
        if motion_modules is not None:
            if block_id in motion_modules.call_block_id:
                motion_module_id = motion_modules.call_block_id[block_id]
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                    hidden_states, time_emb, text_emb, res_stack,
                    batch_size=1
                )
        # 4.3 ControlNet
        # After block `controlnet_insert_block_id` (presumably the middle
        # block — confirm against the UNet layout), the last ControlNet
        # residual is added to hidden_states. The remaining residuals are
        # added element-wise to the skip stack; zip stops at the shorter list.
        if block_id == controlnet_insert_block_id and additional_res_stack is not None:
            hidden_states += additional_res_stack.pop().to(device)
            if vram_limit_level>=1:
                res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
            else:
                res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def lets_dance_xl(
    unet: SDXLUNet,
    motion_modules: SDXLMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    add_time_id = None,
    add_text_embeds = None,
    timestep = None,
    encoder_hidden_states = None,
    ipadapter_kwargs_list = None,
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run one SDXL UNet forward pass, optionally with motion modules.

    Args:
        unet: provides ``time_proj``, ``time_embedding``, ``add_time_proj``,
            ``add_time_embedding``, ``conv_in``, ``text_intermediate_proj``
            (may be ``None``), ``blocks``, ``conv_norm_out``, ``conv_act``
            and ``conv_out``.
        motion_modules: if given, ``motion_modules.motion_modules[i]`` is
            applied after every UNet block whose index is a key of
            ``motion_modules.call_block_id`` (which maps it to ``i``).
        sample: input latents.
        add_time_id: extra conditioning passed through ``unet.add_time_proj``.
            The result is reshaped to ``(add_text_embeds.shape[0], -1)``.
        add_text_embeds: concatenated (last dim) with the projected
            ``add_time_id`` before ``unet.add_time_embedding``.
        timestep: passed to ``unet.time_proj``.
        encoder_hidden_states: text conditioning. It is passed through
            ``unet.text_intermediate_proj`` when that is not ``None``.
        ipadapter_kwargs_list: optional dict keyed by UNet block index;
            ``None`` means no per-block IP-Adapter kwargs.
        tiled, tile_size, tile_stride: forwarded to every UNet block.

    The parameters controlnet, controlnet_frames, unet_batch_size,
    controlnet_batch_size, cross_frame_attention, device and
    vram_limit_level are accepted for signature parity with ``lets_dance``,
    but this function does not use them.

    Returns:
        The output of ``unet.conv_out``.
    """
    # Use a fresh dict instead of a shared mutable default argument.
    if ipadapter_kwargs_list is None:
        ipadapter_kwargs_list = {}

    # 1. Time embedding: timestep embedding plus the added conditioning built
    #    from add_text_embeds and the projected add_time_id.
    t_emb = unet.time_proj(timestep).to(sample.dtype)
    t_emb = unet.time_embedding(t_emb)

    time_embeds = unet.add_time_proj(add_time_id)
    time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
    add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
    add_embeds = add_embeds.to(sample.dtype)
    add_embeds = unet.add_time_embedding(add_embeds)

    time_emb = t_emb + add_embeds

    # 2. Pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
    # res_stack collects skip-connection tensors for the UNet's push/pop blocks.
    res_stack = [hidden_states]

    # 3. Blocks
    for block_id, block in enumerate(unet.blocks):
        hidden_states, time_emb, text_emb, res_stack = block(
            hidden_states, time_emb, text_emb, res_stack,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
        )
        # AnimateDiff motion module after selected blocks
        if motion_modules is not None and block_id in motion_modules.call_block_id:
            motion_module_id = motion_modules.call_block_id[block_id]
            hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                hidden_states, time_emb, text_emb, res_stack,
                batch_size=1
            )

    # 4. Output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
|