diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,139 @@
1
+ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
2
+
3
+
4
class SVDVAEEncoder(SDVAEEncoder):
    """VAE encoder variant for SVD checkpoints.

    The network itself is inherited unchanged from ``SDVAEEncoder``; this
    subclass only overrides the latent ``scaling_factor`` and supplies its own
    state-dict converter.
    """

    @staticmethod
    def state_dict_converter():
        """Return a new converter that maps external checkpoint keys to this model."""
        converter = SVDVAEEncoderStateDictConverter()
        return converter

    def __init__(self):
        super(SVDVAEEncoder, self).__init__()
        # Overrides whatever scaling factor the parent constructor set.
        self.scaling_factor = 0.13025
12
+
13
+
14
class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    """Translate checkpoint state dicts into the key layout of ``SVDVAEEncoder``."""

    def __init__(self):
        super().__init__()

    def from_diffusers(self, state_dict):
        """Delegate to the parent converter without modification."""
        return super().from_diffusers(state_dict)

    def from_civitai(self, state_dict):
        """Pick the VAE-encoder tensors out of ``state_dict`` and rename them.

        Only keys under ``conditioner.embedders.3.encoder.`` that appear in the
        rename table are kept; every other entry is dropped.  Tensors that land
        in a ``transformer_blocks`` module are passed through ``squeeze()``.

        Returns:
            A new dict ``{renamed_key: tensor}`` in the iteration order of
            ``state_dict``.
        """
        encoder_prefix = "conditioner.embedders.3.encoder.encoder."
        resnet_layers = ("conv1", "conv2", "norm1", "norm2")
        tensor_kinds = ("bias", "weight")

        # (source module, target module) pairs; each one expands to a
        # ".bias" and a ".weight" entry below.
        module_pairs = [
            ("conv_in", "conv_in"),
            ("conv_out", "conv_out"),
            ("norm_out", "conv_norm_out"),
        ]

        # Down path: stage i holds two resnets (target blocks 3i and 3i+1)
        # followed by a downsampler (target block 3i+2); the last stage has no
        # downsampler.  Only the first resnet of stages 1 and 2 carries a
        # shortcut convolution.
        for stage in range(4):
            for unit in range(2):
                source = f"down.{stage}.block.{unit}"
                target = f"blocks.{3 * stage + unit}"
                module_pairs.extend(
                    (f"{source}.{layer}", f"{target}.{layer}") for layer in resnet_layers
                )
                if unit == 0 and stage in (1, 2):
                    module_pairs.append(
                        (f"{source}.nin_shortcut", f"{target}.conv_shortcut")
                    )
            if stage < 3:
                module_pairs.append(
                    (f"down.{stage}.downsample.conv", f"blocks.{3 * stage + 2}.conv")
                )

        # Middle section: resnet (11), attention (12), resnet (13).
        for source, target in (("mid.block_1", "blocks.11"), ("mid.block_2", "blocks.13")):
            module_pairs.extend(
                (f"{source}.{layer}", f"{target}.{layer}") for layer in resnet_layers
            )
        module_pairs.append(("mid.attn_1.norm", "blocks.12.norm"))
        for source, target in (("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("proj_out", "to_out")):
            module_pairs.append(
                (f"mid.attn_1.{source}", f"blocks.12.transformer_blocks.0.{target}")
            )

        rename_dict = {
            f"{encoder_prefix}{source}.{kind}": f"{target}.{kind}"
            for source, target in module_pairs
            for kind in tensor_kinds
        }
        # quant_conv lives one level above the inner "encoder." prefix.
        for kind in tensor_kinds:
            rename_dict[f"conditioner.embedders.3.encoder.quant_conv.{kind}"] = f"quant_conv.{kind}"

        converted = {}
        for name, param in state_dict.items():
            target_name = rename_dict.get(name)
            if target_name is None:
                continue
            if "transformer_blocks" in target_name:
                # Drop singleton dimensions of the attention projection tensors.
                param = param.squeeze()
            converted[target_name] = param
        return converted
@@ -0,0 +1,106 @@
1
+ import torch
2
+ from einops import rearrange, repeat
3
+
4
+
5
class TileWorker:
    """Apply a tensor-to-tensor function tile by tile and blend the results.

    The input is cut into overlapping square tiles, ``forward_fn`` is run on
    batches of tiles, and the per-tile outputs are folded back into one tensor
    using a border-weighted average.
    """

    def __init__(self):
        pass

    def mask(self, height, width, border_width):
        """Build a (height, width) blending weight map.

        Each entry is its 1-based distance to the nearest edge divided by
        ``border_width`` and clipped to [0, 1]: the centre area is 1 and the
        border ramps up through values in (0, 1].
        """
        row_index = torch.arange(height).unsqueeze(1).expand(height, width)
        col_index = torch.arange(width).unsqueeze(0).expand(height, width)
        edge_distance = torch.stack([
            row_index + 1,
            height - row_index,
            col_index + 1,
            width - col_index,
        ])
        nearest_edge = edge_distance.min(dim=0).values
        return (nearest_edge / border_width).clip(0, 1)

    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
        """Convert a (b, c, h, w) tensor to (b, c, tile_size, tile_size, tile_num).

        The tensor is moved to ``tile_device`` / ``tile_dtype`` before it is
        unfolded.
        """
        batch_size, channel, _height, _width = model_input.shape
        staged = model_input.to(device=tile_device, dtype=tile_dtype)
        columns = torch.nn.functional.unfold(
            staged,
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride),
        )
        return columns.view((batch_size, channel, tile_size, tile_size, -1))

    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
        """Call ``forward_fn`` on the tiles, ``tile_batch_size`` tiles at a time.

        Tiles are moved to the inference device/dtype for the call and the
        results are moved back to the tile device/dtype.  Returns the outputs
        concatenated along the last (tile) axis.
        """
        tile_num = model_input.shape[-1]
        collected = []

        for start in range(0, tile_num, tile_batch_size):
            stop = min(start + tile_batch_size, tile_num)

            # Fold the tile axis into the batch axis for the forward call.
            chunk = model_input[:, :, :, :, start: stop]
            chunk = chunk.to(device=inference_device, dtype=inference_dtype)
            chunk = rearrange(chunk, "b c h w n -> (n b) c h w")

            # Restore the trailing tile axis on the result.
            result = forward_fn(chunk)
            result = rearrange(result, "(n b) c h w -> b c h w n", n=stop - start)
            collected.append(result.to(device=tile_device, dtype=tile_dtype))

        return torch.concat(collected, dim=-1)

    def io_scale(self, model_output, tile_size):
        """Return the size ratio between output tiles and input tiles.

        Only the height axis is measured; the same scale is assumed for width.
        """
        return model_output.shape[2] / tile_size

    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
        """Reverse ``tile``: fold (b, c, h, w, n) tiles into one (b, c, height, width) tensor.

        Overlapping regions are averaged with the weights from ``mask``: the
        weighted tiles are summed and divided by the summed weights.
        """
        blend = self.mask(tile_size, tile_size, border_width)
        blend = blend.to(device=tile_device, dtype=tile_dtype)
        weighted = model_output * rearrange(blend, "h w -> 1 1 h w 1")

        tile_count = weighted.shape[-1]
        weight_columns = repeat(blend, "h w -> 1 (h w) n", n=tile_count)
        value_columns = rearrange(weighted, "b c h w n -> b (c h w) n")

        fold_kwargs = dict(
            output_size=(height, width),
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride),
        )
        summed_values = torch.nn.functional.fold(value_columns, **fold_kwargs)
        summed_weights = torch.nn.functional.fold(weight_columns, **fold_kwargs)
        return summed_values / summed_weights

    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        """Run ``forward_fn`` over ``model_input`` in tiles and return the blended result.

        ``border_width`` defaults to half of ``tile_stride``.  If ``forward_fn``
        changes the spatial size of a tile, the output height, width, tile
        size, stride and border width are scaled by the same factor.  The
        result is returned on the device and dtype of ``model_input``.
        """
        inference_device = model_input.device
        inference_dtype = model_input.dtype
        in_height, in_width = model_input.shape[2], model_input.shape[3]
        if border_width is None:
            border_width = int(tile_stride * 0.5)

        # Tile, then run the function on every tile.
        tiles = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
        tile_outputs = self.tiled_inference(
            forward_fn, tiles, tile_batch_size,
            inference_device, inference_dtype, tile_device, tile_dtype,
        )

        # Rescale the geometry to the output resolution.
        scale = self.io_scale(tile_outputs, tile_size)
        out_height = int(in_height * scale)
        out_width = int(in_width * scale)
        out_tile_size = int(tile_size * scale)
        out_tile_stride = int(tile_stride * scale)
        out_border_width = int(border_width * scale)

        # Blend the tiles back together.
        blended = self.untile(
            tile_outputs, out_height, out_width,
            out_tile_size, out_tile_stride, out_border_width,
            tile_device, tile_dtype,
        )
        return blended.to(device=inference_device, dtype=inference_dtype)
@@ -0,0 +1,9 @@
1
+ from .sd_image import SDImagePipeline
2
+ from .sd_video import SDVideoPipeline
3
+ from .sdxl_image import SDXLImagePipeline
4
+ from .sdxl_video import SDXLVideoPipeline
5
+ from .sd3_image import SD3ImagePipeline
6
+ from .hunyuan_image import HunyuanDiTImagePipeline
7
+ from .svd_video import SVDVideoPipeline
8
+ from .pipeline_runner import SDVideoPipelineRunner
9
+ KolorsImagePipeline = SDXLImagePipeline
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+
7
class BasePipeline(torch.nn.Module):
    """Common base for pipelines: stores device/dtype and converts images to and from tensors.

    Images entering the pipeline are mapped from pixel range [0, 255] to
    [-1, 1]; VAE outputs are mapped back and clipped.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype

    @staticmethod
    def _to_uint8(array):
        """Map a numpy array from [-1, 1] to uint8 [0, 255], clipping values outside that range."""
        return ((array / 2 + 0.5).clip(0, 1) * 255).astype("uint8")

    @staticmethod
    def _video_to_uint8_frames(vae_output):
        """Return ``vae_output`` as a channel-last uint8 numpy array, one frame per leading index.

        A 4-D tensor is read as (frames, channels, height, width), i.e. the
        same batch-first layout that ``vae_output_to_image`` indexes.
        NOTE(review): that layout is an assumption -- confirm against callers.
        Any other rank keeps the original ``permute(1, 2, 0)`` behaviour.
        """
        frames = vae_output.cpu()
        if frames.dim() == 4:
            # The original 3-axis permute raised on 4-D input.
            frames = frames.permute(0, 2, 3, 1)
        else:
            frames = frames.permute(1, 2, 0)
        return BasePipeline._to_uint8(frames.numpy())

    def preprocess_image(self, image):
        """Convert one (height, width, channels) image to a (1, channels, height, width) float tensor in [-1, 1]."""
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    def preprocess_images(self, images):
        """Apply ``preprocess_image`` to every image and return the tensors as a list."""
        return [self.preprocess_image(image) for image in images]

    def vae_output_to_image(self, vae_output):
        """Convert the first item of a batch-first VAE output to a PIL image."""
        image = vae_output[0].cpu().permute(1, 2, 0).numpy()
        return Image.fromarray(self._to_uint8(image))

    def vae_output_to_video(self, vae_output):
        """Convert a VAE output to a list of PIL images, one per frame.

        See ``_video_to_uint8_frames`` for how the tensor layout is read.
        """
        frames = self._video_to_uint8_frames(vae_output)
        return [Image.fromarray(frame) for frame in frames]
34
+
@@ -0,0 +1,178 @@
1
+ import torch
2
+ from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
3
+ from ..models.sd_unet import PushBlock, PopBlock
4
+ from ..controlnets import MultiControlNetManager
5
+
6
+
7
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    ipadapter_kwargs_list = None,
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run one forward pass of ``unet`` block by block, with optional ControlNet and motion modules.

    Args:
        unet: UNet whose ``blocks`` are iterated manually.
        motion_modules: optional container; ``call_block_id`` maps a UNet block
            index to the index of the motion module to run after that block.
        controlnet: optional callable returning a list of residual tensors.
            It is only used when ``controlnet_frames`` is also given.
        sample: input tensor; its first axis is processed in batches.
        timestep: timestep passed to ``unet.time_proj`` and to ``controlnet``.
        encoder_hidden_states: text embedding; repeated along a new leading
            axis when its first dimension differs from ``sample``'s.
        ipadapter_kwargs_list: mapping from block index to extra keyword
            arguments for that block. ``None`` (the default) means no extras.
        controlnet_frames: ControlNet conditioning, sliced on its second axis
            in step with ``sample``.
        unet_batch_size: number of samples per call for ordinary blocks.
        controlnet_batch_size: number of samples per ControlNet call.
        cross_frame_attention: forwarded to ordinary blocks.
        tiled, tile_size, tile_stride: forwarded to blocks and ControlNet.
        device: device that offloaded residuals are moved back to.
        vram_limit_level: when >= 1, residual tensors are kept on the CPU
            between uses.

    Returns:
        The tensor produced by ``unet.conv_out``.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if ipadapter_kwargs_list is None:
        ipadapter_kwargs_list = {}

    # 0. Text embedding alignment (only for video processing)
    if encoder_hidden_states.shape[0] != sample.shape[0]:
        encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)

    # 1. ControlNet
    # This part is repeated on overlapping frames if the AnimateDiff batch size
    # is larger than its stride; it is kept here on purpose by the original author.
    controlnet_insert_block_id = 30
    if controlnet is not None and controlnet_frames is not None:
        res_stacks = []
        # Process the ControlNet frames batch by batch.
        for batch_id in range(0, sample.shape[0], controlnet_batch_size):
            batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
            res_stack = controlnet(
                sample[batch_id: batch_id_],
                timestep,
                encoder_hidden_states[batch_id: batch_id_],
                controlnet_frames[:, batch_id: batch_id_],
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            if vram_limit_level >= 1:
                res_stack = [res.cpu() for res in res_stack]
            res_stacks.append(res_stack)
        # Concatenate the i-th residual of every batch along the sample axis.
        additional_res_stack = [torch.concat(parts, dim=0) for parts in zip(*res_stacks)]
    else:
        additional_res_stack = None

    # 2. time
    time_emb = unet.time_proj(timestep).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states.cpu() if vram_limit_level >= 1 else hidden_states]

    # 4. blocks
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet
        if isinstance(block, PushBlock):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
            if vram_limit_level >= 1:
                # Offload the residual that was just pushed.
                res_stack[-1] = res_stack[-1].cpu()
        elif isinstance(block, PopBlock):
            if vram_limit_level >= 1:
                # Bring back the residual that is about to be popped.
                res_stack[-1] = res_stack[-1].to(device)
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
        else:
            hidden_states_input = hidden_states
            hidden_states_output = []
            for batch_id in range(0, sample.shape[0], unet_batch_size):
                batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
                hidden_states, _, _, _ = block(
                    hidden_states_input[batch_id: batch_id_],
                    time_emb,
                    text_emb[batch_id: batch_id_],
                    res_stack,
                    cross_frame_attention=cross_frame_attention,
                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
                hidden_states_output.append(hidden_states)
            hidden_states = torch.concat(hidden_states_output, dim=0)
        # 4.2 AnimateDiff
        if motion_modules is not None:
            if block_id in motion_modules.call_block_id:
                motion_module_id = motion_modules.call_block_id[block_id]
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                    hidden_states, time_emb, text_emb, res_stack,
                    batch_size=1
                )
        # 4.3 ControlNet
        if block_id == controlnet_insert_block_id and additional_res_stack is not None:
            # The last residual goes onto the hidden states; the remaining ones
            # are added element-wise to the skip-connection stack.
            hidden_states += additional_res_stack.pop().to(device)
            if vram_limit_level >= 1:
                res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
            else:
                res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
115
+
116
+
117
+
118
+
119
def lets_dance_xl(
    unet: SDXLUNet,
    motion_modules: SDXLMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    add_time_id = None,
    add_text_embeds = None,
    timestep = None,
    encoder_hidden_states = None,
    ipadapter_kwargs_list = None,
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run one forward pass of ``unet`` block by block, with optional motion modules.

    Args:
        unet: UNet whose ``blocks`` are iterated manually.
        motion_modules: optional container; ``call_block_id`` maps a UNet block
            index to the index of the motion module to run after that block.
        sample: input tensor fed to ``unet.conv_in``.
        add_time_id: input to ``unet.add_time_proj``.
        add_text_embeds: extra text embedding, concatenated with the projected
            ``add_time_id`` on the last axis.
        timestep: timestep passed to ``unet.time_proj``.
        encoder_hidden_states: text embedding; passed through
            ``unet.text_intermediate_proj`` when that attribute is not ``None``.
        ipadapter_kwargs_list: mapping from block index to extra keyword
            arguments for that block. ``None`` (the default) means no extras.
        tiled, tile_size, tile_stride: forwarded to every block.
        controlnet, controlnet_frames, unet_batch_size, controlnet_batch_size,
        cross_frame_attention, device, vram_limit_level: accepted but not read
            by this function.

    Returns:
        The tensor produced by ``unet.conv_out``.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if ipadapter_kwargs_list is None:
        ipadapter_kwargs_list = {}

    # 2. time
    t_emb = unet.time_proj(timestep).to(sample.dtype)
    t_emb = unet.time_embedding(t_emb)

    time_embeds = unet.add_time_proj(add_time_id)
    time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
    add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
    add_embeds = add_embeds.to(sample.dtype)
    add_embeds = unet.add_time_embedding(add_embeds)

    time_emb = t_emb + add_embeds

    # 3. pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
    res_stack = [hidden_states]

    # 4. blocks
    for block_id, block in enumerate(unet.blocks):
        hidden_states, time_emb, text_emb, res_stack = block(
            hidden_states, time_emb, text_emb, res_stack,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
        )
        # 4.2 AnimateDiff
        if motion_modules is not None:
            if block_id in motion_modules.call_block_id:
                motion_module_id = motion_modules.call_block_id[block_id]
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                    hidden_states, time_emb, text_emb, res_stack,
                    batch_size=1
                )

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states