diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,195 @@
1
+ import torch
2
+ from .sd_unet import SDUNet
3
+ from .sdxl_unet import SDXLUNet
4
+ from .sd_text_encoder import SDTextEncoder
5
+ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
6
+ from .sd3_dit import SD3DiT
7
+ from .hunyuan_dit import HunyuanDiT
8
+
9
+
10
+
11
class LoRAFromCivitai:
    """Base loader for Civitai-format LoRA checkpoints (``.lora_up``/``.lora_down``
    weight pairs).

    Subclasses configure the class for a concrete architecture by populating
    ``supported_model_classes``, ``lora_prefix``, ``renamed_lora_prefix`` and
    ``special_keys``.
    """

    def __init__(self):
        # Model classes this loader can patch; parallel to `lora_prefix`.
        self.supported_model_classes = []
        # Key prefixes identifying which sub-model a LoRA tensor targets.
        self.lora_prefix = []
        # Optional replacement text for a stripped prefix (e.g. a second text encoder).
        self.renamed_lora_prefix = {}
        # Substring renames applied while rebuilding parameter names.
        self.special_keys = {}

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0,
                           device="cuda", torch_dtype=torch.float16):
        """Collapse LoRA up/down factor pairs into dense weight deltas.

        Args:
            state_dict: raw LoRA checkpoint containing ``.lora_up``/``.lora_down`` pairs.
            lora_prefix: only keys starting with this prefix are converted.
            alpha: scale applied to each delta.
            device: device used for the matmul. Default stays ``"cuda"`` for
                backward compatibility; pass ``"cpu"`` on CPU-only machines.
            torch_dtype: dtype used for the matmul.

        Returns:
            Dict mapping rebuilt parameter names to delta tensors on CPU.
        """
        renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
        state_dict_ = {}
        for key in state_dict:
            # Iterate only the "up" factors; the matching "down" factor is
            # looked up by key substitution.
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch_dtype)
            if len(weight_up.shape) == 4:
                # Conv LoRA: factors stored with trailing 1x1 spatial dims.
                # Drop them, multiply in float32, then restore the dims.
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            # Rebuild the target parameter name: strip/replace the LoRA prefix,
            # turn remaining underscores into dots, then apply the
            # architecture-specific substring renames.
            target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
            for special_key in self.special_keys:
                target_name = target_name.replace(special_key, self.special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
        """Fuse a LoRA into `model` in place by adding deltas to its weights.

        Args:
            model: target model; must expose ``state_dict``/``load_state_dict``
                and a ``state_dict_converter`` classmethod.
            state_dict_lora: raw LoRA checkpoint.
            lora_prefix: prefix selecting the tensors meant for this model.
            alpha: LoRA strength.
            model_resource: ``"diffusers"`` or ``"civitai"``; selects which
                converter maps the rebuilt names onto this model's names.
        """
        state_dict_model = model.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
        if model_resource == "diffusers":
            state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
        elif model_resource == "civitai":
            state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
            for name in state_dict_lora:
                # Cast each delta to the parameter's own dtype/device before adding.
                state_dict_model[name] += state_dict_lora[name].to(
                    dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
            model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        """Detect whether this loader can apply `state_dict_lora` to `model`.

        Tries every configured (prefix, model class) pair and both converter
        resources; succeeds when a non-empty conversion maps entirely onto
        existing parameter names.

        Returns:
            ``(lora_prefix, model_resource)`` on success, otherwise ``None``.
        """
        for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            # Probe the conversion on the model's own device/dtype so matching
            # also works on machines without CUDA (the old hard-coded "cuda"
            # default made this silently fail on CPU-only hosts).
            probe_param = next(iter(state_dict_model.values()))
            for model_resource in ["diffusers", "civitai"]:
                try:
                    state_dict_lora_ = self.convert_state_dict(
                        state_dict_lora, lora_prefix=lora_prefix, alpha=1.0,
                        device=probe_param.device, torch_dtype=probe_param.dtype)
                    converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
                        else model.__class__.state_dict_converter().from_civitai
                    state_dict_lora_ = converter_fn(state_dict_lora_)
                    if len(state_dict_lora_) == 0:
                        continue
                    for name in state_dict_lora_:
                        if name not in state_dict_model:
                            break
                    else:
                        return lora_prefix, model_resource
                except Exception:
                    # A failed conversion just means this combination does not
                    # match; keep trying the remaining ones.
                    pass
        return None
78
+
79
+
80
+
81
class SDLoRAFromCivitai(LoRAFromCivitai):
    """Civitai LoRA loader configured for Stable Diffusion 1.x (UNet + text encoder)."""

    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDUNet, SDTextEncoder]
        self.lora_prefix = ["lora_unet_", "lora_te_"]
        # Substring renames that undo the underscore-to-dot conversion for
        # compound identifiers, and map onto the SD checkpoint layout.
        rename_pairs = (
            ("down.blocks", "down_blocks"),
            ("up.blocks", "up_blocks"),
            ("mid.block", "mid_block"),
            ("proj.in", "proj_in"),
            ("proj.out", "proj_out"),
            ("transformer.blocks", "transformer_blocks"),
            ("to.q", "to_q"),
            ("to.k", "to_k"),
            ("to.v", "to_v"),
            ("to.out", "to_out"),
            ("text.model", "text_model"),
            ("self.attn.q.proj", "self_attn.q_proj"),
            ("self.attn.k.proj", "self_attn.k_proj"),
            ("self.attn.v.proj", "self_attn.v_proj"),
            ("self.attn.out.proj", "self_attn.out_proj"),
            ("input.blocks", "model.diffusion_model.input_blocks"),
            ("middle.block", "model.diffusion_model.middle_block"),
            ("output.blocks", "model.diffusion_model.output_blocks"),
        )
        self.special_keys = dict(rename_pairs)
106
+
107
+
108
class SDXLLoRAFromCivitai(LoRAFromCivitai):
    """Civitai LoRA loader configured for SDXL (UNet + both text encoders)."""

    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
        self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
        # The second text encoder's prefix is replaced with "2" so its keys can
        # be told apart from the first encoder's after prefix stripping.
        self.renamed_lora_prefix = {"lora_te2_": "2"}
        # Substring renames that undo the underscore-to-dot conversion for
        # compound identifiers, and map onto the SDXL checkpoint layout.
        rename_pairs = (
            ("down.blocks", "down_blocks"),
            ("up.blocks", "up_blocks"),
            ("mid.block", "mid_block"),
            ("proj.in", "proj_in"),
            ("proj.out", "proj_out"),
            ("transformer.blocks", "transformer_blocks"),
            ("to.q", "to_q"),
            ("to.k", "to_k"),
            ("to.v", "to_v"),
            ("to.out", "to_out"),
            ("text.model", "conditioner.embedders.0.transformer.text_model"),
            ("self.attn.q.proj", "self_attn.q_proj"),
            ("self.attn.k.proj", "self_attn.k_proj"),
            ("self.attn.v.proj", "self_attn.v_proj"),
            ("self.attn.out.proj", "self_attn.out_proj"),
            ("input.blocks", "model.diffusion_model.input_blocks"),
            ("middle.block", "model.diffusion_model.middle_block"),
            ("output.blocks", "model.diffusion_model.output_blocks"),
            ("2conditioner.embedders.0.transformer.text_model.encoder.layers", "text_model.encoder.layers"),
        )
        self.special_keys = dict(rename_pairs)
135
+
136
+
137
+
138
class GeneralLoRAFromPeft:
    """Loader for PEFT-format LoRA checkpoints (``.lora_A``/``.lora_B`` pairs)
    whose keys already mirror the target model's parameter names."""

    def __init__(self):
        # Model classes this loader can patch.
        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT]

    def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16):
        """Collapse PEFT LoRA A/B factor pairs into dense weight deltas.

        Args:
            state_dict: raw PEFT LoRA checkpoint containing ``.lora_A``/``.lora_B`` pairs.
            alpha: scale applied to each delta.
            device: device used for the matmul (pass ``"cpu"`` on CPU-only machines).
            torch_dtype: dtype used for the matmul.

        Returns:
            Dict mapping target parameter names to delta tensors on CPU.
        """
        state_dict_ = {}
        for key in state_dict:
            # Iterate only the "B" factors; the matching "A" factor is looked
            # up by key substitution.
            if ".lora_B." not in key:
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
            if len(weight_up.shape) == 4:
                # Conv LoRA: drop the trailing 1x1 spatial dims, multiply,
                # then restore them.
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            # Rebuild the parameter name by removing the "lora_B" segment and
            # the adapter-name segment that follows it (e.g. "default").
            keys = key.split(".")
            keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            target_name = ".".join(keys)
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        """Fuse a PEFT LoRA into `model` in place by adding deltas to its weights.

        `lora_prefix` and `model_resource` are unused; they exist so this
        loader shares the Civitai loaders' call signature.
        """
        state_dict_model = model.state_dict()
        # Run the conversion on the model's own device/dtype.
        param = next(iter(state_dict_model.values()))
        state_dict_lora = self.convert_state_dict(
            state_dict_lora, alpha=alpha, device=param.device, torch_dtype=param.dtype)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
            for name in state_dict_lora:
                # Cast each delta to the parameter's own dtype/device before adding.
                state_dict_model[name] += state_dict_lora[name].to(
                    dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
            model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        """Detect whether `state_dict_lora` is a PEFT LoRA for `model`.

        Returns:
            ``("", "")`` on success (no prefix/resource needed), otherwise ``None``.
        """
        for model_class in self.supported_model_classes:
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            # Probe the conversion on the model's own device/dtype; the old
            # code used the hard-coded "cuda" default, so matching silently
            # failed (bare except) on CPU-only hosts.
            probe_param = next(iter(state_dict_model.values()))
            try:
                state_dict_lora_ = self.convert_state_dict(
                    state_dict_lora, alpha=1.0,
                    device=probe_param.device, torch_dtype=probe_param.dtype)
                if len(state_dict_lora_) == 0:
                    continue
                for name in state_dict_lora_:
                    if name not in state_dict_model:
                        break
                else:
                    return "", ""
            except Exception:
                # A failed conversion just means this model class doesn't
                # match; keep trying the remaining ones.
                pass
        return None