diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
diffsynth/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .data import *
2
+ from .models import *
3
+ from .prompters import *
4
+ from .schedulers import *
5
+ from .pipelines import *
6
+ from .controlnets import *
File without changes
@@ -0,0 +1,243 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from ..models.sd_text_encoder import SDTextEncoder
4
+ from ..models.sd_unet import SDUNet
5
+ from ..models.sd_vae_encoder import SDVAEEncoder
6
+ from ..models.sd_vae_decoder import SDVAEDecoder
7
+
8
+ from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
9
+ from ..models.sdxl_unet import SDXLUNet
10
+ from ..models.sdxl_vae_decoder import SDXLVAEDecoder
11
+ from ..models.sdxl_vae_encoder import SDXLVAEEncoder
12
+
13
+ from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
14
+ from ..models.sd3_dit import SD3DiT
15
+ from ..models.sd3_vae_decoder import SD3VAEDecoder
16
+ from ..models.sd3_vae_encoder import SD3VAEEncoder
17
+
18
+ from ..models.sd_controlnet import SDControlNet
19
+
20
+ from ..models.sd_motion import SDMotionModel
21
+ from ..models.sdxl_motion import SDXLMotionModel
22
+
23
+ from ..models.svd_image_encoder import SVDImageEncoder
24
+ from ..models.svd_unet import SVDUNet
25
+ from ..models.svd_vae_decoder import SVDVAEDecoder
26
+ from ..models.svd_vae_encoder import SVDVAEEncoder
27
+
28
+ from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
29
+ from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
30
+
31
+ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
32
+ from ..models.hunyuan_dit import HunyuanDiT
33
+
34
+
35
+
36
+ model_loader_configs = [
37
+ # These configs are provided for detecting model type automatically.
38
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
39
+ (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
40
+ (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
41
+ (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
42
+ (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
43
+ (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
44
+ (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
45
+ (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
46
+ (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
47
+ (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
48
+ (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
49
+ (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
50
+ (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
51
+ (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
52
+ (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
53
+ (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
54
+ (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
55
+ (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
56
+ (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
57
+ (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
58
+ (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
59
+ (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
60
+ (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
61
+ (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
62
+ (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
63
+ ]
64
+ huggingface_model_loader_configs = [
65
+ # These configs are provided for detecting model type automatically.
66
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
67
+ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
68
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
69
+ ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
70
+ ]
71
+ patch_model_loader_configs = [
72
+ # These configs are provided for detecting model type automatically.
73
+ # The format is (state_dict_keys_hash_with_shape, model_names, model_classes, extra_kwargs)
74
+ ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
75
+ ]
76
+
77
+ preset_models_on_huggingface = {
78
+ "HunyuanDiT": [
79
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
80
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
81
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
82
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
83
+ ],
84
+ "stable-video-diffusion-img2vid-xt": [
85
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
86
+ ],
87
+ "ExVideo-SVD-128f-v1": [
88
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
89
+ ],
90
+ }
91
+ preset_models_on_modelscope = {
92
+ # Hunyuan DiT
93
+ "HunyuanDiT": [
94
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
95
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
96
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
97
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
98
+ ],
99
+ # Stable Video Diffusion
100
+ "stable-video-diffusion-img2vid-xt": [
101
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
102
+ ],
103
+ # ExVideo
104
+ "ExVideo-SVD-128f-v1": [
105
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
106
+ ],
107
+ # Stable Diffusion
108
+ "StableDiffusion_v15": [
109
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
110
+ ],
111
+ "DreamShaper_8": [
112
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
113
+ ],
114
+ "AingDiffusion_v12": [
115
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
116
+ ],
117
+ "Flat2DAnimerge_v45Sharp": [
118
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
119
+ ],
120
+ # Textual Inversion
121
+ "TextualInversion_VeryBadImageNegative_v1.3": [
122
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
123
+ ],
124
+ # Stable Diffusion XL
125
+ "StableDiffusionXL_v1": [
126
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
127
+ ],
128
+ "BluePencilXL_v200": [
129
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
130
+ ],
131
+ "StableDiffusionXL_Turbo": [
132
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
133
+ ],
134
+ # Stable Diffusion 3
135
+ "StableDiffusion3": [
136
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
137
+ ],
138
+ "StableDiffusion3_without_T5": [
139
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
140
+ ],
141
+ # ControlNet
142
+ "ControlNet_v11f1p_sd15_depth": [
143
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
144
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
145
+ ],
146
+ "ControlNet_v11p_sd15_softedge": [
147
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
148
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
149
+ ],
150
+ "ControlNet_v11f1e_sd15_tile": [
151
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
152
+ ],
153
+ "ControlNet_v11p_sd15_lineart": [
154
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
155
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
156
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
157
+ ],
158
+ # AnimateDiff
159
+ "AnimateDiff_v2": [
160
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
161
+ ],
162
+ "AnimateDiff_xl_beta": [
163
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
164
+ ],
165
+ # RIFE
166
+ "RIFE": [
167
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
168
+ ],
169
+ # Beautiful Prompt
170
+ "BeautifulPrompt": [
171
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
172
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
173
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
174
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
175
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
176
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
177
+ ],
178
+ # Translator
179
+ "opus-mt-zh-en": [
180
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
181
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
182
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
183
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
184
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
185
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
186
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
187
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
188
+ ],
189
+ # IP-Adapter
190
+ "IP-Adapter-SD": [
191
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
192
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
193
+ ],
194
+ "IP-Adapter-SDXL": [
195
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
196
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
197
+ ],
198
+ # Kolors
199
+ "Kolors": [
200
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
201
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
202
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
203
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
204
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
205
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
206
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
207
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
208
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
209
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
210
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
211
+ ],
212
+ "SDXL-vae-fp16-fix": [
213
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
214
+ ],
215
+ }
216
+ Preset_model_id: TypeAlias = Literal[
217
+ "HunyuanDiT",
218
+ "stable-video-diffusion-img2vid-xt",
219
+ "ExVideo-SVD-128f-v1",
220
+ "StableDiffusion_v15",
221
+ "DreamShaper_8",
222
+ "AingDiffusion_v12",
223
+ "Flat2DAnimerge_v45Sharp",
224
+ "TextualInversion_VeryBadImageNegative_v1.3",
225
+ "StableDiffusionXL_v1",
226
+ "BluePencilXL_v200",
227
+ "StableDiffusionXL_Turbo",
228
+ "ControlNet_v11f1p_sd15_depth",
229
+ "ControlNet_v11p_sd15_softedge",
230
+ "ControlNet_v11f1e_sd15_tile",
231
+ "ControlNet_v11p_sd15_lineart",
232
+ "AnimateDiff_v2",
233
+ "AnimateDiff_xl_beta",
234
+ "RIFE",
235
+ "BeautifulPrompt",
236
+ "opus-mt-zh-en",
237
+ "IP-Adapter-SD",
238
+ "IP-Adapter-SDXL",
239
+ "StableDiffusion3",
240
+ "StableDiffusion3_without_T5",
241
+ "Kolors",
242
+ "SDXL-vae-fp16-fix",
243
+ ]
@@ -0,0 +1,2 @@
1
+ from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
2
+ from .processors import Annotator
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import numpy as np
3
+ from .processors import Processor_id
4
+
5
+
6
class ControlNetConfigUnit:
    """Plain record holding the settings of one ControlNet.

    The constructor only stores its arguments; nothing is loaded or
    validated here.

    Args:
        processor_id: Identifier of the annotator to use (a ``Processor_id``).
        model_path: Location of the ControlNet weights.
        scale: Numeric weight stored alongside the model (default ``1.0``).
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        # Store all three settings verbatim, in declaration order.
        self.processor_id, self.model_path, self.scale = processor_id, model_path, scale
11
+
12
+
13
class ControlNetUnit:
    """Bundle of an instantiated processor, a ControlNet model and its scale.

    Args:
        processor: Callable applied to an input image by
            ``MultiControlNetManager.process_image``.
        model: ControlNet model invoked by ``MultiControlNetManager.__call__``.
        scale: Factor applied to every residual the model returns
            (default ``1.0``).
    """

    def __init__(self, processor, model, scale=1.0):
        # Simple value holder: keep the three members as given.
        self.processor, self.model, self.scale = processor, model, scale
18
+
19
+
20
class MultiControlNetManager:
    """Runs several ControlNet units and sums their scaled residuals.

    Args:
        controlnet_units: Iterable of objects exposing ``processor``,
            ``model`` and ``scale`` attributes. Defaults to no units.
    """

    def __init__(self, controlnet_units=()):
        # An immutable default avoids the shared-mutable-default pitfall, and
        # materialising the iterable once lets callers pass a generator
        # (the attributes below need three passes over the units).
        units = list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def process_image(self, image, processor_id=None):
        """Run the processors on ``image`` and stack the results.

        Args:
            image: Input passed unchanged to each processor.
            processor_id: Index into ``self.processors``. When ``None`` every
                processor is applied; otherwise only the selected one.

        Returns:
            A float tensor built by converting each processed image to a
            float32 array, dividing by 255, moving the last axis to the front
            and concatenating along a new leading axis (one entry per
            processor used).
        """
        if processor_id is None:
            processed_images = [processor(image) for processor in self.processors]
        else:
            processed_images = [self.processors[processor_id](image)]
        return torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_images
        ], dim=0)

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        """Evaluate every model on its conditioning and sum the results.

        Conditionings, models and scales are paired positionally with
        ``zip``, so the shortest of the three determines how many models run.

        Returns:
            A list with the element-wise sum of each model's residuals, each
            multiplied by that model's scale, or ``None`` when nothing was
            evaluated (no conditionings or no models).
        """
        res_stack = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            scaled = [
                res * scale
                for res in model(
                    sample, timestep, encoder_hidden_states, conditioning,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
            ]
            if res_stack is None:
                res_stack = scaled
            else:
                res_stack = [i + j for i, j in zip(res_stack, scaled)]
        return res_stack
@@ -0,0 +1,51 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+ import warnings
3
+ with warnings.catch_warnings():
4
+ warnings.simplefilter("ignore")
5
+ from controlnet_aux.processor import (
6
+ CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
7
+ )
8
+
9
+
10
+ Processor_id: TypeAlias = Literal[
11
+ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
12
+ ]
13
+
14
class Annotator:
    """Wraps one ``controlnet_aux`` detector behind a uniform image -> image call.

    Args:
        processor_id: One of the ``Processor_id`` literals. ``"canny"`` builds
            a ``CannyDetector`` directly, ``"tile"`` uses no detector at all,
            and the remaining ids load a detector with
            ``from_pretrained(model_path)`` and move it to ``device``.
        model_path: Location handed to ``from_pretrained``.
        detect_resolution: Resolution forwarded to the detector; when ``None``
            the shorter side of each input image is used instead.
        device: Target passed to ``.to(...)`` for pretrained detectors.

    Raises:
        ValueError: If ``processor_id`` is not one of the supported ids.
    """

    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
        # Detectors that share the from_pretrained(...).to(device) recipe,
        # checked in the same order as the supported-id list.
        pretrained_detectors = (
            ("depth", MidasDetector),
            ("softedge", HEDdetector),
            ("lineart", LineartDetector),
            ("lineart_anime", LineartAnimeDetector),
            ("openpose", OpenposeDetector),
        )
        if processor_id == "canny":
            self.processor = CannyDetector()
        else:
            for known_id, detector_cls in pretrained_detectors:
                if processor_id == known_id:
                    self.processor = detector_cls.from_pretrained(model_path).to(device)
                    break
            else:
                if processor_id != "tile":
                    raise ValueError(f"Unsupported processor_id: {processor_id}")
                # "tile" needs no detector; __call__ returns the image as is.
                self.processor = None

        self.processor_id = processor_id
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        """Annotate ``image`` and return a result with the input's size.

        ``image`` must expose ``.size`` as ``(width, height)``. Without a
        detector the input is returned unchanged.
        """
        width, height = image.size
        if self.processor is None:
            return image

        # Only the openpose detector receives the extra include_* switches.
        if self.processor_id == "openpose":
            extra = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }
        else:
            extra = {}

        short_side = min(width, height)
        detect_resolution = short_side if self.detect_resolution is None else self.detect_resolution
        annotated = self.processor(image, detect_resolution=detect_resolution, image_resolution=short_side, **extra)
        # Bring the detector output back to the original width and height.
        return annotated.resize((width, height))
51
+
@@ -0,0 +1 @@
1
+ from .video import VideoData, save_video, save_frames
@@ -0,0 +1,35 @@
1
+ import torch, os
2
+ from torchvision import transforms
3
+ import pandas as pd
4
+ from PIL import Image
5
+
6
+
7
+
8
class TextImageDataset(torch.utils.data.Dataset):
    """Text/image pairs read from ``<dataset_path>/train/metadata.csv``.

    The CSV must provide ``file_name`` and ``text`` columns; image paths are
    resolved relative to ``<dataset_path>/train``.

    Args:
        dataset_path: Root folder of the dataset.
        steps_per_epoch: Value reported by ``len()``; independent of the
            number of rows in the metadata file.
        height: Crop height.
        width: Crop width.
        center_crop: Use ``CenterCrop`` when true, ``RandomCrop`` otherwise.
        random_flip: Add ``RandomHorizontalFlip`` when true.
    """

    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch

        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        crop_size = (height, width)
        crop = transforms.CenterCrop(crop_size) if center_crop else transforms.RandomCrop(crop_size)
        # An identity Lambda keeps the pipeline shape fixed when flipping is off.
        flip = transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x)
        self.image_processor = transforms.Compose(
            [
                transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
                crop,
                flip,
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __getitem__(self, index):
        """Return ``{"text", "image"}`` for a randomly offset sample."""
        num_items = len(self.path)
        random_offset = torch.randint(0, num_items, (1,))[0]
        # The index is added to the random draw ("For fixed seed." in the
        # original code) and wrapped into the valid range.
        data_id = (random_offset + index) % num_items
        image = self.image_processor(Image.open(self.path[data_id]).convert("RGB"))
        return {"text": self.text[data_id], "image": image}

    def __len__(self):
        """Return the configured number of steps per epoch."""
        return self.steps_per_epoch
@@ -0,0 +1,148 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+
6
+
7
class LowMemoryVideo:
    """Frame-by-frame access to a video through an ``imageio`` reader.

    Frames are decoded on demand in ``__getitem__`` rather than being held
    in a list.

    Args:
        file_name: Path (or other source) accepted by ``imageio.get_reader``.
    """

    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        """Return the frame count reported by the reader."""
        return self.reader.count_frames()

    def __getitem__(self, item):
        """Return frame ``item`` as an RGB ``PIL.Image``."""
        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")

    def __del__(self):
        # __del__ also runs when __init__ raised before `reader` was assigned
        # (e.g. the file could not be opened); guard against that so cleanup
        # does not raise a secondary AttributeError.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.close()
19
+
20
+
21
def split_file_name(file_name):
    """Split a name into single characters and integer digit runs.

    Each maximal run of ASCII digits becomes one ``int``; every other
    character is kept as a one-character string.

    Example: ``"img10.png"`` -> ``('i', 'm', 'g', 10, '.', 'p', 'n', 'g')``.

    Returns:
        A tuple of ``str`` and ``int`` parts in their original order.
    """
    parts = []
    number = None  # value of the digit run being read, None outside a run
    for char in file_name:
        if "0" <= char <= "9":
            number = (number or 0) * 10 + (ord(char) - ord("0"))
            continue
        if number is not None:
            parts.append(number)
            number = None
        parts.append(char)
    if number is not None:
        parts.append(number)
    return tuple(parts)


def _natural_sort_key(file_name):
    """Build a sort key from ``split_file_name`` that never compares int with str."""
    # Tagging each part keeps int-vs-int and str-vs-str order unchanged while
    # defining int-vs-str (numbers first) instead of raising TypeError.
    tagged = [
        (0, part, "") if isinstance(part, int) else (1, 0, part)
        for part in split_file_name(file_name)
    ]
    # The raw name breaks ties such as "01.png" versus "1.png".
    return tagged, file_name


def search_for_images(folder):
    """List the ``.jpg``/``.png`` files of ``folder`` in natural order.

    Numbers embedded in the names are compared numerically, so ``2.png``
    sorts before ``10.png``. Where one name has a number and another has a
    character at the same position, the number sorts first.

    Returns:
        Paths of the matching files, each joined with ``folder``.
    """
    names = [name for name in os.listdir(folder) if name.endswith((".jpg", ".png"))]
    names.sort(key=_natural_sort_key)
    return [os.path.join(folder, name) for name in names]
46
+
47
+
48
class LowMemoryImageFolder:
    """Indexable view over image files that opens each one on access.

    Args:
        folder: Directory containing the images.
        file_list: Optional explicit file names (relative to ``folder``), used
            in the given order. When ``None`` the folder is scanned with
            ``search_for_images``.
    """

    def __init__(self, folder, file_list=None):
        if file_list is not None:
            self.file_list = [os.path.join(folder, name) for name in file_list]
        else:
            self.file_list = search_for_images(folder)

    def __len__(self):
        """Return the number of image paths."""
        return len(self.file_list)

    def __getitem__(self, item):
        """Open image ``item`` and return it converted to RGB."""
        image_path = self.file_list[item]
        return Image.open(image_path).convert("RGB")

    def __del__(self):
        # Nothing to release: no file handle is stored on the instance.
        pass
63
+
64
+
65
def crop_and_resize(image, height, width):
    """Center-crop ``image`` to the target aspect ratio, then resize it.

    Args:
        image: Image convertible with ``np.array`` to a 3-axis array whose
            first two axes are height and width.
        height: Output height.
        width: Output width.

    Returns:
        A ``PIL.Image`` resized to ``(width, height)``.
    """
    pixels = np.array(image)
    source_height, source_width, _ = pixels.shape
    if source_height / source_width < height / width:
        # Source is relatively wider than the target: trim columns evenly.
        crop_width = int(source_height / height * width)
        start = (source_width - crop_width) // 2
        pixels = pixels[:, start: start + crop_width]
    else:
        # Source is relatively taller (or equal): trim rows evenly.
        crop_height = int(source_width / width * height)
        start = (source_height - crop_height) // 2
        pixels = pixels[start: start + crop_height, :]
    return Image.fromarray(pixels).resize((width, height))
79
+
80
+
81
class VideoData:
    """Uniform frame access over a video file or a folder of images.

    Args:
        video_file: Path of a video; takes precedence over ``image_folder``.
        image_folder: Folder of images used when ``video_file`` is ``None``.
        height: Optional output height. Frames are cropped and resized only
            when both ``height`` and ``width`` are set.
        width: Optional output width.
        **kwargs: Forwarded unchanged to the backing reader's constructor.

    Raises:
        ValueError: If neither ``video_file`` nor ``image_folder`` is given.
    """

    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        """Return all frames (after any crop/resize) as a list."""
        return [self.__getitem__(i) for i in range(self.__len__())]

    def set_length(self, length):
        """Override the length reported by ``len()``."""
        self.length = length

    def set_shape(self, height, width):
        """Set the output frame size; ``None`` disables resizing."""
        self.height = height
        self.width = width

    def __len__(self):
        """Return the overridden length if set, else the source length."""
        if self.length is None:
            return len(self.data)
        return self.length

    def shape(self):
        """Return ``(height, width)`` of the frames this object yields.

        Uses the configured size when both values are set; otherwise the
        size is read from the first frame.
        """
        if self.height is not None and self.width is not None:
            return self.height, self.width
        # Frames expose ``.size`` as (width, height), exactly as __getitem__
        # reads it; they have no ``.shape`` attribute.
        width, height = self.__getitem__(0).size
        return height, width

    def __getitem__(self, item):
        """Return frame ``item``, cropped and resized if a shape is set."""
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        # Nothing to release here; this class defines no cleanup of its own.
        pass

    def save_images(self, folder):
        """Write every frame to ``folder`` as ``<index>.png``."""
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))
136
+
137
+
138
def save_video(frames, save_path, fps, quality=9):
    """Encode ``frames`` into a video file at ``save_path``.

    Args:
        frames: Iterable of images, each convertible with ``np.array``.
        save_path: Output file path handed to ``imageio.get_writer``.
        fps: Frame rate passed to the writer.
        quality: Quality setting passed to the writer.
    """
    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
    try:
        for frame in tqdm(frames, desc="Saving video"):
            writer.append_data(np.array(frame))
    finally:
        # Close the writer even when a frame fails to convert or encode, so
        # the output file handle is not leaked.
        writer.close()
144
+
145
def save_frames(frames, save_path):
    """Save each frame to ``save_path`` as ``<index>.png``.

    The folder is created if missing. Every frame must provide a
    ``save(path)`` method.
    """
    os.makedirs(save_path, exist_ok=True)
    progress = tqdm(frames, desc="Saving images")
    for index, frame in enumerate(progress):
        target = os.path.join(save_path, f"{index}.png")
        frame.save(target)
@@ -0,0 +1,118 @@
1
+ import torch
2
+ from einops import repeat
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+
7
class ResidualDenseBlock(torch.nn.Module):
    """Five-convolution dense block with a scaled residual connection.

    Each of the first four convolutions sees the block input concatenated
    (along dim 1) with the activated outputs of all earlier convolutions and
    emits ``num_grow_ch`` channels. The fifth convolution maps the full
    concatenation back to ``num_feat`` channels; its output is scaled by 0.2
    and added to the block input.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        """Build the five 3x3 convolutions (stride 1, padding 1) and the activation.

        Args:
            num_feat: Channel count of the block input and output.
            num_grow_ch: Channels added by each intermediate convolution.
        """
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, **conv_kwargs)
        self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, **conv_kwargs)
        self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, **conv_kwargs)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Run the dense block and return ``conv5(...) * 0.2 + x``."""
        # Feature maps collected so far: the input plus each activated conv output.
        dense = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            dense.append(self.lrelu(conv(torch.cat(dense, 1))))
        residual = self.conv5(torch.cat(dense, 1))
        return residual * 0.2 + x
25
+
26
+
27
class RRDB(torch.nn.Module):
    """Three ``ResidualDenseBlock`` modules in series with a scaled skip.

    The chained output of the three blocks is multiplied by 0.2 and added to
    the module input.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        """Create the three dense blocks.

        Args:
            num_feat: Channel count of the input and output feature maps.
            num_grow_ch: Growth channels forwarded to each dense block.
        """
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Apply ``rdb1``, ``rdb2``, ``rdb3`` in order and add the scaled result to ``x``."""
        hidden = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            hidden = dense_block(hidden)
        return hidden * 0.2 + x
40
+
41
+
42
class RRDBNet(torch.nn.Module):
    """RRDB-based super-resolution network.

    A first convolution lifts the input to ``num_feat`` channels, a stack of
    ``num_block`` ``RRDB`` modules plus a convolution forms the trunk (added
    back to the lifted features), and two "repeat 2x then convolve" stages
    enlarge the spatial size by a factor of four before the final
    convolutions produce ``num_out_ch`` channels.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
        """Build the network layers.

        Args:
            num_in_ch: Channels of the input tensor.
            num_out_ch: Channels of the output tensor.
            num_feat: Width of the intermediate feature maps.
            num_block: Number of ``RRDB`` modules in the trunk.
            num_grow_ch: Growth channels forwarded to every ``RRDB``.
        """
        super(RRDBNet, self).__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        # Use torch.nn directly; the original ``torch.torch.nn`` only worked
        # through an incidental self-reference attribute on the torch module.
        self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Map a ``(B, C, H, W)`` tensor to a ``(B, num_out_ch, 4H, 4W)`` tensor."""
        feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample: each repeat duplicates every pixel 2x along H and W
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up1(feat))
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up2(feat))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
68
+
69
+
70
class ESRGAN(torch.nn.Module):
    """Image upscaler that wraps a super-resolution network.

    The wrapped ``model`` must expose a ``conv_first`` layer; its weight
    decides the device and dtype that input batches are moved to.
    """

    def __init__(self, model):
        """Store the network used by ``upscale``."""
        super().__init__()
        self.model = model

    @staticmethod
    def from_pretrained(model_path):
        """Load an ``RRDBNet`` with default settings from a checkpoint file.

        Args:
            model_path: Path to a checkpoint whose ``"params_ema"`` entry
                holds the network state dict.

        Returns:
            ESRGAN: Wrapper around the loaded network, in eval mode.
        """
        model = RRDBNet()
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
        model.load_state_dict(state_dict)
        model.eval()
        return ESRGAN(model)

    def process_image(self, image):
        """Convert one image to a float32 tensor scaled by 1/255.

        The array form of ``image`` is permuted with ``(2, 0, 1)``, i.e. it
        is assumed to be height x width x channels and becomes
        channels x height x width.
        """
        image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        """Convert and stack several images into one batch tensor.

        All images must share the same size for ``torch.stack`` to succeed.
        """
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        """Turn a ``(B, C, H, W)`` tensor back into a list of PIL images.

        Values are scaled by 255, clipped to ``[0, 255]`` and truncated to
        ``uint8``.
        """
        images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255)
        # Widen to float32 before leaving torch: NumPy has no bfloat16, so
        # a direct .numpy() fails when the model runs in that dtype. The
        # conversion is exact for half/bfloat16 values and a no-op for
        # float32, so the resulting pixels are unchanged.
        images = images.float().numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    @torch.no_grad()
    def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
        """Upscale a collection of images with the wrapped model.

        Args:
            images: Iterable of equally sized images.
            batch_size: Number of images sent through the model at once.
            progress_bar: Callable wrapping the batch-index iterable (for
                example ``tqdm``); the default is the identity.

        Returns:
            list: One PIL image per input, in input order. An empty input
            yields an empty list.
        """
        images = list(images)
        if not images:
            # torch.stack raises on an empty list; nothing to upscale.
            return []

        # Preprocess
        input_tensor = self.process_images(images)

        # Run the model batch by batch on its own device and dtype
        weight = self.model.conv_first.weight
        output_tensor = []
        for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(device=weight.device, dtype=weight.dtype)
            batch_output_tensor = self.model(batch_input_tensor)
            # Move results to the CPU immediately to free device memory
            output_tensor.append(batch_output_tensor.cpu())

        # Output
        output_tensor = torch.cat(output_tensor, dim=0)

        # To images
        output_images = self.decode_images(output_tensor)
        return output_images