diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,122 @@
1
+ from .svd_image_encoder import SVDImageEncoder
2
+ from transformers import CLIPImageProcessor
3
+ import torch
4
+
5
+
6
class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
    """Image encoder that accepts raw images and preprocesses them itself.

    The network is the ``SVDImageEncoder`` base class built with a fixed
    configuration; a ``CLIPImageProcessor`` with default settings converts
    the incoming images to pixel tensors before the base forward pass.
    """

    def __init__(self):
        super().__init__(
            embed_dim=1664,
            encoder_intermediate_size=8192,
            projection_dim=1280,
            num_encoder_layers=48,
            num_heads=16,
            head_dim=104,
        )
        self.image_processor = CLIPImageProcessor()

    def forward(self, image):
        """Preprocess ``image`` and return the base encoder's output.

        Args:
            image: Whatever ``CLIPImageProcessor`` accepts through its
                ``images=`` argument.

        Returns:
            The result of ``SVDImageEncoder.forward`` on the pixel tensor.
        """
        processed = self.image_processor(images=image, return_tensors="pt")
        # The class embedding serves as the reference for where the model
        # lives and which precision it runs in.
        reference = self.embeddings.class_embedding
        pixel_values = processed.pixel_values.to(device=reference.device, dtype=reference.dtype)
        return super().forward(pixel_values)
15
+
16
+
17
class IpAdapterImageProjModel(torch.nn.Module):
    """Project an image embedding into a short sequence of context tokens.

    A single linear layer expands each embedding to
    ``clip_extra_context_tokens * cross_attention_dim`` features, which are
    then split into tokens and layer-normalised along the feature axis.
    """

    def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
        """
        Args:
            cross_attention_dim: Feature size of every produced token.
            clip_embeddings_dim: Feature size of the incoming embedding.
            clip_extra_context_tokens: Number of tokens produced per embedding.
        """
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        flat_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, flat_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens shaped ``(-1, tokens, cross_attention_dim)``."""
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        projected = self.proj(image_embeds)
        return self.norm(projected.reshape(token_shape))
29
+
30
+
31
class IpAdapterModule(torch.nn.Module):
    """Pair of bias-free linear projections producing a key and a value."""

    def __init__(self, input_dim, output_dim):
        """
        Args:
            input_dim: Feature size of the incoming hidden states.
            output_dim: Feature size of both the key and the value output.
        """
        super().__init__()
        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, hidden_states):
        """Return the ``(key, value)`` projections of ``hidden_states``."""
        return self.to_k_ip(hidden_states), self.to_v_ip(hidden_states)
41
+
42
+
43
class SDXLIpAdapter(torch.nn.Module):
    """Collection of key/value projection modules plus an image projector.

    ``call_block_id`` maps ``(block_id, transformer_id)`` pairs to an index
    into ``ipadapter_modules``; ``forward`` evaluates exactly the modules
    listed there.
    NOTE(review): ``block_id`` / ``transformer_id`` presumably address
    attention layers of the SDXL UNet -- confirm against the consumer.
    """

    def __init__(self):
        super().__init__()
        # (number of modules, output width) groups, in creation order.
        width_groups = ((4, 640), (50, 1280), (6, 640), (10, 1280))
        projections = [
            IpAdapterModule(2048, width)
            for count, width in width_groups
            for _ in range(count)
        ]
        self.ipadapter_modules = torch.nn.ModuleList(projections)
        self.image_proj = IpAdapterImageProjModel()
        self.set_full_adapter()

    @staticmethod
    def _attention_slots():
        """Return every ``(block_id, transformer_id)`` pair in module order."""
        # (block_id, number of transformer slots); block 21 is deliberately
        # listed last, so its slots receive the highest module indices.
        slot_counts = (
            (7, 2), (10, 2),
            (15, 10), (18, 10),
            (25, 10), (28, 10), (31, 10),
            (35, 2), (38, 2), (41, 2),
            (21, 10),
        )
        return [
            (block_id, transformer_id)
            for block_id, count in slot_counts
            for transformer_id in range(count)
        ]

    def set_full_adapter(self):
        """Route every slot to its projection module."""
        self.call_block_id = {
            slot: index for index, slot in enumerate(self._attention_slots())
        }

    def set_less_adapter(self):
        """Route only the slots whose module index lies in ``[34, 44)``."""
        self.call_block_id = {
            slot: index
            for index, slot in enumerate(self._attention_slots())
            if 34 <= index < 44
        }

    def forward(self, hidden_states, scale=1.0):
        """Compute key/value tensors for every active slot.

        Args:
            hidden_states: Input passed to ``image_proj``.
            scale: Value stored unchanged under ``"scale"`` in every entry.

        Returns:
            Nested dict ``{block_id: {transformer_id: {"ip_k", "ip_v",
            "scale"}}}``.
        """
        tokens = self.image_proj(hidden_states)
        # Fold all leading dimensions into one token axis with batch size 1.
        tokens = tokens.view(1, -1, tokens.shape[-1])
        ip_kv_dict = {}
        for (block_id, transformer_id), module_index in self.call_block_id.items():
            ip_k, ip_v = self.ipadapter_modules[module_index](tokens)
            ip_kv_dict.setdefault(block_id, {})[transformer_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale,
            }
        return ip_kv_dict

    @staticmethod
    def state_dict_converter():
        """Return the converter used to remap checkpoint keys for this model."""
        return SDXLIpAdapterStateDictConverter()
102
+
103
+
104
class SDXLIpAdapterStateDictConverter:
    """Remap IP-Adapter checkpoint keys onto ``SDXLIpAdapter`` parameter names."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Flatten a two-section checkpoint into a single state dict.

        Args:
            state_dict: Mapping with an ``"ip_adapter"`` and an
                ``"image_proj"`` sub-mapping of parameter name -> value.

        Returns:
            A new dict whose keys are prefixed with ``ipadapter_modules.<n>.``
            or ``image_proj.``.
        """
        converted = {}
        for source_name, value in state_dict["ip_adapter"].items():
            index_text, *remainder = source_name.split(".")
            # The leading integer is halved to obtain the module index
            # (presumably the checkpoint numbers only every other layer --
            # TODO confirm).
            module_index = str(int(index_text) // 2)
            converted[".".join(["ipadapter_modules", module_index] + remainder)] = value
        for source_name, value in state_dict["image_proj"].items():
            converted["image_proj." + source_name] = value
        return converted

    def from_civitai(self, state_dict):
        """Convert a civitai-style checkpoint; the layout is handled identically."""
        return self.from_diffusers(state_dict)
122
+
@@ -0,0 +1,104 @@
1
+ from .sd_motion import TemporalBlock
2
+ import torch
3
+
4
+
5
+
6
class SDXLMotionModel(torch.nn.Module):
    """Container for fifteen ``TemporalBlock`` motion modules.

    ``call_block_id`` maps a host block id to the index of the motion module
    in ``motion_modules``.
    NOTE(review): the host ids presumably refer to SDXL UNet block positions
    -- confirm against the pipeline that reads this table.
    """

    def __init__(self):
        super().__init__()
        # Channel width of each motion module, in creation order.
        channel_plan = [320] * 2 + [640] * 2 + [1280] * 2 + [1280] * 3 + [640] * 3 + [320] * 3
        # Every block is built as TemporalBlock(8, channels // 8, channels);
        # presumably 8 heads with channels // 8 features each -- confirm
        # against the TemporalBlock signature.
        self.motion_modules = torch.nn.ModuleList(
            [TemporalBlock(8, channels // 8, channels, eps=1e-6) for channels in channel_plan]
        )
        host_blocks = (0, 2, 7, 10, 15, 18, 25, 28, 31, 35, 38, 41, 44, 46, 48)
        self.call_block_id = {
            block_id: module_index for module_index, block_id in enumerate(host_blocks)
        }

    def forward(self):
        """No-op; this module only holds the motion modules."""
        pass

    @staticmethod
    def state_dict_converter():
        """Return the converter used to remap checkpoint keys for this model."""
        return SDMotionModelStateDictConverter()
55
+
56
+
57
class SDMotionModelStateDictConverter:
    """Remap motion-module checkpoint keys onto ``motion_modules.<n>.*`` names."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename and renumber the temporal-transformer weights.

        Keys are visited group by group (``down_blocks.``, ``mid_block.``,
        ``up_blocks.``), alphabetically within each group.  Each time the
        part of the key up to and including ``temporal_transformer`` changes,
        the target module index advances by one.  Keys outside the three
        groups are ignored.

        Args:
            state_dict: Mapping of source parameter name -> value.

        Returns:
            A new dict keyed by the converted parameter names.

        Raises:
            ValueError: A selected key has no ``temporal_transformer`` part.
            KeyError: The path after ``temporal_transformer`` is not known.
        """
        block = "transformer_blocks.0"
        key_map = {"norm": "norm", "proj_in": "proj_in", "proj_out": "proj_out"}
        for source_index, (attn_name, pe_name) in enumerate((("attn1", "pe1"), ("attn2", "pe2"))):
            source = f"{block}.attention_blocks.{source_index}"
            for leaf in ("to_q", "to_k", "to_v"):
                key_map[f"{source}.{leaf}"] = f"{block}.{attn_name}.{leaf}"
            key_map[f"{source}.to_out.0"] = f"{block}.{attn_name}.to_out"
            key_map[f"{source}.pos_encoder"] = f"{block}.{pe_name}"
        key_map[f"{block}.norms.0"] = f"{block}.norm1"
        key_map[f"{block}.norms.1"] = f"{block}.norm2"
        key_map[f"{block}.ff.net.0.proj"] = f"{block}.act_fn.proj"
        key_map[f"{block}.ff.net.2"] = f"{block}.ff"
        key_map[f"{block}.ff_norm"] = f"{block}.norm3"

        ordered_keys = []
        for group in ("down_blocks.", "mid_block.", "up_blocks."):
            ordered_keys.extend(sorted(key for key in state_dict if key.startswith(group)))

        converted = {}
        current_owner = ""
        module_index = -1
        for key in ordered_keys:
            parts = key.split(".")
            split_at = parts.index("temporal_transformer") + 1
            owner = ".".join(parts[:split_at])
            if owner != current_owner:
                current_owner = owner
                module_index += 1
            target = ["motion_modules", str(module_index), key_map[".".join(parts[split_at:-1])]]
            # Positional-encoder entries map straight to the renamed path;
            # every other entry keeps its final component (e.g. weight/bias).
            if "pos_encoder" not in parts:
                target.append(parts[-1])
            converted[".".join(target)] = state_dict[key]
        return converted

    def from_civitai(self, state_dict):
        """Convert a civitai-style checkpoint; the layout is handled identically."""
        return self.from_diffusers(state_dict)