diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
diffsynth/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
from typing_extensions import Literal, TypeAlias
|
|
2
|
+
|
|
3
|
+
from ..models.sd_text_encoder import SDTextEncoder
|
|
4
|
+
from ..models.sd_unet import SDUNet
|
|
5
|
+
from ..models.sd_vae_encoder import SDVAEEncoder
|
|
6
|
+
from ..models.sd_vae_decoder import SDVAEDecoder
|
|
7
|
+
|
|
8
|
+
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
9
|
+
from ..models.sdxl_unet import SDXLUNet
|
|
10
|
+
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
|
11
|
+
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
|
12
|
+
|
|
13
|
+
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
|
14
|
+
from ..models.sd3_dit import SD3DiT
|
|
15
|
+
from ..models.sd3_vae_decoder import SD3VAEDecoder
|
|
16
|
+
from ..models.sd3_vae_encoder import SD3VAEEncoder
|
|
17
|
+
|
|
18
|
+
from ..models.sd_controlnet import SDControlNet
|
|
19
|
+
|
|
20
|
+
from ..models.sd_motion import SDMotionModel
|
|
21
|
+
from ..models.sdxl_motion import SDXLMotionModel
|
|
22
|
+
|
|
23
|
+
from ..models.svd_image_encoder import SVDImageEncoder
|
|
24
|
+
from ..models.svd_unet import SVDUNet
|
|
25
|
+
from ..models.svd_vae_decoder import SVDVAEDecoder
|
|
26
|
+
from ..models.svd_vae_encoder import SVDVAEEncoder
|
|
27
|
+
|
|
28
|
+
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
29
|
+
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
30
|
+
|
|
31
|
+
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
32
|
+
from ..models.hunyuan_dit import HunyuanDiT
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
model_loader_configs = [
|
|
37
|
+
# These configs are provided for detecting model type automatically.
|
|
38
|
+
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
|
39
|
+
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
|
|
40
|
+
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
|
|
41
|
+
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
|
|
42
|
+
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
|
|
43
|
+
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
|
|
44
|
+
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
|
|
45
|
+
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
|
|
46
|
+
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
|
|
47
|
+
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
|
|
48
|
+
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
|
|
49
|
+
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
|
|
50
|
+
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
|
|
51
|
+
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
|
|
52
|
+
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
|
|
53
|
+
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
54
|
+
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
55
|
+
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
|
|
56
|
+
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
|
|
57
|
+
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
|
|
58
|
+
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
|
|
59
|
+
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
|
|
60
|
+
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
|
|
61
|
+
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
|
62
|
+
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
|
63
|
+
]
|
|
64
|
+
huggingface_model_loader_configs = [
|
|
65
|
+
# These configs are provided for detecting model type automatically.
|
|
66
|
+
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
|
|
67
|
+
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
|
|
68
|
+
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
|
|
69
|
+
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
|
|
70
|
+
]
|
|
71
|
+
patch_model_loader_configs = [
|
|
72
|
+
# These configs are provided for detecting model type automatically.
|
|
73
|
+
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
|
74
|
+
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
|
75
|
+
]
|
|
76
|
+
|
|
77
|
+
preset_models_on_huggingface = {
|
|
78
|
+
"HunyuanDiT": [
|
|
79
|
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
|
80
|
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
|
81
|
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
|
82
|
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
|
83
|
+
],
|
|
84
|
+
"stable-video-diffusion-img2vid-xt": [
|
|
85
|
+
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
|
86
|
+
],
|
|
87
|
+
"ExVideo-SVD-128f-v1": [
|
|
88
|
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
|
89
|
+
],
|
|
90
|
+
}
|
|
91
|
+
preset_models_on_modelscope = {
|
|
92
|
+
# Hunyuan DiT
|
|
93
|
+
"HunyuanDiT": [
|
|
94
|
+
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
|
95
|
+
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
|
96
|
+
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
|
97
|
+
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
|
98
|
+
],
|
|
99
|
+
# Stable Video Diffusion
|
|
100
|
+
"stable-video-diffusion-img2vid-xt": [
|
|
101
|
+
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
|
102
|
+
],
|
|
103
|
+
# ExVideo
|
|
104
|
+
"ExVideo-SVD-128f-v1": [
|
|
105
|
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
|
106
|
+
],
|
|
107
|
+
# Stable Diffusion
|
|
108
|
+
"StableDiffusion_v15": [
|
|
109
|
+
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
|
110
|
+
],
|
|
111
|
+
"DreamShaper_8": [
|
|
112
|
+
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
|
113
|
+
],
|
|
114
|
+
"AingDiffusion_v12": [
|
|
115
|
+
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
|
116
|
+
],
|
|
117
|
+
"Flat2DAnimerge_v45Sharp": [
|
|
118
|
+
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
|
119
|
+
],
|
|
120
|
+
# Textual Inversion
|
|
121
|
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
|
122
|
+
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
|
123
|
+
],
|
|
124
|
+
# Stable Diffusion XL
|
|
125
|
+
"StableDiffusionXL_v1": [
|
|
126
|
+
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
|
127
|
+
],
|
|
128
|
+
"BluePencilXL_v200": [
|
|
129
|
+
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
|
130
|
+
],
|
|
131
|
+
"StableDiffusionXL_Turbo": [
|
|
132
|
+
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
|
133
|
+
],
|
|
134
|
+
# Stable Diffusion 3
|
|
135
|
+
"StableDiffusion3": [
|
|
136
|
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
|
137
|
+
],
|
|
138
|
+
"StableDiffusion3_without_T5": [
|
|
139
|
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
|
140
|
+
],
|
|
141
|
+
# ControlNet
|
|
142
|
+
"ControlNet_v11f1p_sd15_depth": [
|
|
143
|
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
|
144
|
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
145
|
+
],
|
|
146
|
+
"ControlNet_v11p_sd15_softedge": [
|
|
147
|
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
|
148
|
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
|
149
|
+
],
|
|
150
|
+
"ControlNet_v11f1e_sd15_tile": [
|
|
151
|
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
|
152
|
+
],
|
|
153
|
+
"ControlNet_v11p_sd15_lineart": [
|
|
154
|
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
|
155
|
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
|
156
|
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
|
157
|
+
],
|
|
158
|
+
# AnimateDiff
|
|
159
|
+
"AnimateDiff_v2": [
|
|
160
|
+
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
|
161
|
+
],
|
|
162
|
+
"AnimateDiff_xl_beta": [
|
|
163
|
+
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
|
164
|
+
],
|
|
165
|
+
# RIFE
|
|
166
|
+
"RIFE": [
|
|
167
|
+
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
|
168
|
+
],
|
|
169
|
+
# Beautiful Prompt
|
|
170
|
+
"BeautifulPrompt": [
|
|
171
|
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
172
|
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
173
|
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
174
|
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
175
|
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
176
|
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
177
|
+
],
|
|
178
|
+
# Translator
|
|
179
|
+
"opus-mt-zh-en": [
|
|
180
|
+
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
|
181
|
+
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
|
182
|
+
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
|
183
|
+
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
|
184
|
+
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
|
185
|
+
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
|
186
|
+
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
|
187
|
+
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
|
188
|
+
],
|
|
189
|
+
# IP-Adapter
|
|
190
|
+
"IP-Adapter-SD": [
|
|
191
|
+
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
|
192
|
+
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
|
193
|
+
],
|
|
194
|
+
"IP-Adapter-SDXL": [
|
|
195
|
+
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
|
196
|
+
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
|
197
|
+
],
|
|
198
|
+
# Kolors
|
|
199
|
+
"Kolors": [
|
|
200
|
+
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
|
201
|
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
|
202
|
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
203
|
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
204
|
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
205
|
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
206
|
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
207
|
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
208
|
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
209
|
+
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
|
210
|
+
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
|
211
|
+
],
|
|
212
|
+
"SDXL-vae-fp16-fix": [
|
|
213
|
+
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
|
214
|
+
],
|
|
215
|
+
}
|
|
216
|
+
Preset_model_id: TypeAlias = Literal[
|
|
217
|
+
"HunyuanDiT",
|
|
218
|
+
"stable-video-diffusion-img2vid-xt",
|
|
219
|
+
"ExVideo-SVD-128f-v1",
|
|
220
|
+
"StableDiffusion_v15",
|
|
221
|
+
"DreamShaper_8",
|
|
222
|
+
"AingDiffusion_v12",
|
|
223
|
+
"Flat2DAnimerge_v45Sharp",
|
|
224
|
+
"TextualInversion_VeryBadImageNegative_v1.3",
|
|
225
|
+
"StableDiffusionXL_v1",
|
|
226
|
+
"BluePencilXL_v200",
|
|
227
|
+
"StableDiffusionXL_Turbo",
|
|
228
|
+
"ControlNet_v11f1p_sd15_depth",
|
|
229
|
+
"ControlNet_v11p_sd15_softedge",
|
|
230
|
+
"ControlNet_v11f1e_sd15_tile",
|
|
231
|
+
"ControlNet_v11p_sd15_lineart",
|
|
232
|
+
"AnimateDiff_v2",
|
|
233
|
+
"AnimateDiff_xl_beta",
|
|
234
|
+
"RIFE",
|
|
235
|
+
"BeautifulPrompt",
|
|
236
|
+
"opus-mt-zh-en",
|
|
237
|
+
"IP-Adapter-SD",
|
|
238
|
+
"IP-Adapter-SDXL",
|
|
239
|
+
"StableDiffusion3",
|
|
240
|
+
"StableDiffusion3_without_T5",
|
|
241
|
+
"Kolors",
|
|
242
|
+
"SDXL-vae-fp16-fix",
|
|
243
|
+
]
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from .processors import Processor_id
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ControlNetConfigUnit:
    """Declarative spec for one ControlNet: which image preprocessor to run,
    where the model weights live, and how strongly to apply its output."""

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        self.scale = scale                # conditioning strength multiplier
        self.model_path = model_path      # path to the ControlNet weights file
        self.processor_id = processor_id  # e.g. "canny", "depth", "tile"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ControlNetUnit:
    """A ready-to-use ControlNet bundle: an instantiated preprocessor, the
    loaded model, and the strength with which its residuals are applied."""

    def __init__(self, processor, model, scale=1.0):
        self.scale = scale          # conditioning strength multiplier
        self.model = model          # loaded ControlNet model
        self.processor = processor  # callable image preprocessor (Annotator)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MultiControlNetManager:
    """Runs several ControlNets side by side and sums their scaled residuals.

    Holds parallel lists of preprocessors, models, and per-model scales, all
    unpacked from the given ``ControlNetUnit`` objects.
    """

    def __init__(self, controlnet_units=None):
        # Fix: the original signature used a mutable default (`=[]`), which is
        # a shared-object pitfall; `None` sentinel keeps callers unchanged.
        units = [] if controlnet_units is None else controlnet_units
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def process_image(self, image, processor_id=None):
        """Apply one or all registered preprocessors to `image`.

        Returns a float tensor of shape (num_processors, C, H, W) with pixel
        values scaled to [0, 1]. If `processor_id` is given, only that
        processor (by index) is applied.
        """
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        # Convert each processed PIL-style image (H, W, C) to a CHW tensor.
        processed_image = torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        """Run every ControlNet on its matching conditioning and return the
        element-wise sum of their residual stacks, each scaled by its unit's
        `scale`. Returns None when no models are registered."""
        res_stack = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing_extensions import Literal, TypeAlias
|
|
2
|
+
import warnings
|
|
3
|
+
with warnings.catch_warnings():
|
|
4
|
+
warnings.simplefilter("ignore")
|
|
5
|
+
from controlnet_aux.processor import (
|
|
6
|
+
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
Processor_id: TypeAlias = Literal[
|
|
11
|
+
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
class Annotator:
    """Uniform callable wrapper around a controlnet_aux detector.

    Given a processor id, constructs the matching detector (downloaded from
    `model_path` where needed) and exposes it as ``annotator(image) -> image``
    with the output resized back to the input's original size.
    """

    # Detectors that are loaded from pretrained weights and moved to `device`.
    _PRETRAINED_DETECTORS = {
        "depth": MidasDetector,
        "softedge": HEDdetector,
        "lineart": LineartDetector,
        "lineart_anime": LineartAnimeDetector,
        "openpose": OpenposeDetector,
    }

    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
        if processor_id == "canny":
            # Canny needs no weights; construct directly.
            self.processor = CannyDetector()
        elif processor_id == "tile":
            # Tile ControlNet consumes the raw image: no preprocessing at all.
            self.processor = None
        elif processor_id in self._PRETRAINED_DETECTORS:
            detector_cls = self._PRETRAINED_DETECTORS[processor_id]
            self.processor = detector_cls.from_pretrained(model_path).to(device)
        else:
            raise ValueError(f"Unsupported processor_id: {processor_id}")

        self.processor_id = processor_id
        # Resolution at which detection runs; falls back to the image's
        # shorter side when left as None.
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        width, height = image.size
        if self.processor_id == "openpose":
            # Full-body pose: include hands and face keypoints as well.
            kwargs = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }
        else:
            kwargs = {}
        if self.processor is not None:
            detect_resolution = self.detect_resolution
            if detect_resolution is None:
                detect_resolution = min(width, height)
            image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
        # Restore the caller's original dimensions.
        image = image.resize((width, height))
        return image
|
|
51
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .video import VideoData, save_video, save_frames
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import torch, os
|
|
2
|
+
from torchvision import transforms
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from PIL import Image
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TextImageDataset(torch.utils.data.Dataset):
    """Text-to-image training dataset backed by a folder with a
    ``train/metadata.csv`` file (columns: ``file_name``, ``text``).

    Samples are drawn pseudo-randomly and the epoch length is fixed by
    `steps_per_epoch` rather than by the number of images.
    """

    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        # Build the augmentation/normalization pipeline step by step.
        steps = [transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR)]
        if center_crop:
            steps.append(transforms.CenterCrop((height, width)))
        else:
            steps.append(transforms.RandomCrop((height, width)))
        if random_flip:
            steps.append(transforms.RandomHorizontalFlip())
        else:
            # Identity transform keeps the pipeline shape consistent.
            steps.append(transforms.Lambda(lambda x: x))
        steps.append(transforms.ToTensor())
        # Map pixel values to [-1, 1].
        steps.append(transforms.Normalize([0.5], [0.5]))
        self.image_processor = transforms.Compose(steps)

    def __getitem__(self, index):
        # Random base index offset by `index` so a fixed torch seed still
        # yields different samples per step.
        data_id = torch.randint(0, len(self.path), (1,))[0]
        data_id = (data_id + index) % len(self.path)  # For fixed seed.
        caption = self.text[data_id]
        pixels = Image.open(self.path[data_id]).convert("RGB")
        pixels = self.image_processor(pixels)
        return {"text": caption, "image": pixels}

    def __len__(self):
        # Epoch length is virtual, decoupled from the dataset size.
        return self.steps_per_epoch
|
diffsynth/data/video.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
import imageio, os
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LowMemoryVideo:
    """Lazily decodes frames from a video file: only the requested frame is
    ever held in memory."""

    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        # Number of frames as reported by the imageio reader.
        return self.reader.count_frames()

    def __getitem__(self, item):
        frame = self.reader.get_data(item)
        return Image.fromarray(np.array(frame)).convert("RGB")

    def __del__(self):
        # Release the underlying file handle when this object is collected.
        self.reader.close()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def split_file_name(file_name):
    """Tokenize `file_name` into a natural-sort key.

    Runs of ASCII digits become a single int; every other character is kept
    as-is, e.g. ``"img12.png" -> ('i', 'm', 'g', 12, '.', 'p', 'n', 'g')``.
    Sorting by these tuples orders frame files numerically (2 before 10).

    Improvement: the original hand-rolled character state machine is replaced
    by a regex tokenization with identical output.
    """
    import re
    result = []
    # "[0-9]+" captures each digit run whole; "[^0-9]" yields single chars.
    for token in re.findall(r"[0-9]+|[^0-9]", file_name):
        # ASCII check (not str.isdigit) to match the original ord() range test.
        if "0" <= token[0] <= "9":
            result.append(int(token))
        else:
            result.append(token)
    return tuple(result)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def search_for_images(folder):
    """Return full paths of every .jpg/.png file in `folder`, ordered
    naturally: digit runs in names compare as numbers (2 before 10), with the
    raw file name as tiebreaker."""
    candidates = [name for name in os.listdir(folder) if name.endswith((".jpg", ".png"))]
    # Decorate-sort-undecorate on (natural key, name).
    ordered = sorted((split_file_name(name), name) for name in candidates)
    return [os.path.join(folder, name) for _, name in ordered]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class LowMemoryImageFolder:
    """Treats a folder of images as a frame sequence; each frame is decoded
    only when indexed, never cached."""

    def __init__(self, folder, file_list=None):
        if file_list is None:
            # Discover images and order them naturally.
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, name) for name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        path = self.file_list[item]
        return Image.open(path).convert("RGB")

    def __del__(self):
        # Nothing to release: files are opened per access, not held open.
        pass
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def crop_and_resize(image, height, width):
    """Center-crop *image* to the height:width aspect ratio, then resize.

    Expects an RGB (H, W, C) image; returns a PIL image of exactly
    (width, height).
    """
    array = np.array(image)
    source_height, source_width, _ = array.shape
    if source_height / source_width < height / width:
        # Source is wider than the target ratio: trim the left/right edges.
        cropped_width = int(source_height / height * width)
        offset = (source_width - cropped_width) // 2
        array = array[:, offset: offset + cropped_width]
    else:
        # Source is taller than the target ratio: trim the top/bottom edges.
        cropped_height = int(source_width / width * height)
        offset = (source_height - cropped_height) // 2
        array = array[offset: offset + cropped_height, :]
    return Image.fromarray(array).resize((width, height))
|
80
|
+
|
|
81
|
+
class VideoData:
    """Unified, lazily-loaded frame source backed by a video file or an image folder.

    Frames are returned as RGB PIL images and are optionally center-cropped
    and resized to a fixed (height, width).
    """

    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        # Exactly one of video_file / image_folder selects the backend.
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None  # None means "use the backend's frame count"
        self.set_shape(height, width)

    def raw_data(self):
        """Materialize every frame into a list (may be memory-heavy)."""
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        # Override the reported frame count (e.g. to truncate a video).
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        """Return (height, width) of the frames this object yields."""
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            # Frames are PIL images, which expose (width, height) via .size —
            # __getitem__ below reads .size on the same value.  The previous
            # code read a nonexistent .shape attribute here, raising
            # AttributeError whenever height/width were not set.
            width, height = self.__getitem__(0).size
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        """Write every frame to *folder* as 0.png, 1.png, ..."""
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))
+
|
|
138
|
+
def save_video(frames, save_path, fps, quality=9):
    """Encode a sequence of frames into a video file at *save_path*."""
    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
    for frame in tqdm(frames, desc="Saving video"):
        writer.append_data(np.array(frame))
    writer.close()
|
+
|
|
145
|
+
def save_frames(frames, save_path):
    """Write each frame to *save_path* as <index>.png."""
    os.makedirs(save_path, exist_ok=True)
    for index, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{index}.png"))
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from einops import repeat
|
|
3
|
+
from PIL import Image
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ResidualDenseBlock(torch.nn.Module):
    """Residual dense block: five 3x3 convs with dense connections and a
    0.2-scaled residual connection to the input."""

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        # conv1..conv4 each consume the input plus all previous growth maps;
        # conv5 projects back to num_feat channels.
        for idx in range(1, 6):
            in_channels = num_feat + (idx - 1) * num_grow_ch
            out_channels = num_feat if idx == 5 else num_grow_ch
            setattr(self, f"conv{idx}", torch.nn.Conv2d(in_channels, out_channels, 3, 1, 1))
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Dense connectivity: each conv sees the input plus every
        # intermediate activation produced so far.
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, dim=1))))
        residual = self.conv5(torch.cat(features, dim=1))
        # 0.2-scaled residual connection.
        return residual * 0.2 + x
|
|
27
|
+
class RRDB(torch.nn.Module):
    """Residual-in-Residual Dense Block: three chained ResidualDenseBlocks
    with a 0.2-scaled outer residual connection."""

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x
+
|
|
42
|
+
class RRDBNet(torch.nn.Module):
    """ESRGAN generator: a trunk of RRDB blocks followed by two 2x
    upsampling stages (4x total upscale).

    Attribute names match the reference checkpoints so state dicts load
    unchanged.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        blocks = [RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)]
        self.body = torch.nn.Sequential(*blocks)
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        shallow = self.conv_first(x)
        deep = self.conv_body(self.body(shallow))
        feat = shallow + deep
        # Two 2x upsampling stages: einops.repeat duplicates every pixel
        # into a 2x2 block (nearest-neighbor upscaling).
        feat = self.lrelu(self.conv_up1(repeat(feat, "B C H W -> B C (H 2) (W 2)")))
        feat = self.lrelu(self.conv_up2(repeat(feat, "B C H W -> B C (H 2) (W 2)")))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))
|
|
70
|
+
class ESRGAN(torch.nn.Module):
    """Thin inference wrapper around an ESRGAN RRDBNet upscaler."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def from_pretrained(model_path):
        """Load EMA weights from a checkpoint and return an eval-mode ESRGAN."""
        model = RRDBNet()
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
        model.load_state_dict(state_dict)
        model.eval()
        return ESRGAN(model)

    def process_image(self, image):
        """Convert a PIL image (or array-like) to a float CHW tensor in [0, 1]."""
        array = np.array(image, dtype=np.float32) / 255
        return torch.Tensor(array).permute(2, 0, 1)

    def process_images(self, images):
        """Stack a list of images into a single BCHW tensor."""
        return torch.stack([self.process_image(image) for image in images])

    def decode_images(self, images):
        """Convert a BCHW tensor in [0, 1] back to 8-bit PIL images."""
        arrays = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        return [Image.fromarray(array) for array in arrays]

    @torch.no_grad()
    def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
        """Upscale a list of images, running the model in mini-batches.

        Batches are moved to the model's device/dtype, and results are
        gathered back on the CPU before decoding to PIL images.
        """
        input_tensor = self.process_images(images)

        outputs = []
        total = input_tensor.shape[0]
        for start in progress_bar(range(0, total, batch_size)):
            stop = min(start + batch_size, total)
            batch = input_tensor[start: stop].to(
                device=self.model.conv_first.weight.device,
                dtype=self.model.conv_first.weight.dtype)
            outputs.append(self.model(batch).cpu())

        output_tensor = torch.concat(outputs, dim=0)
        return self.decode_images(output_tensor)
|