diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,81 @@
import torch
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
from .sd_unet import ResnetBlock, UpSampler
from .tiler import TileWorker


class SD3VAEDecoder(torch.nn.Module):
    """VAE decoder for Stable Diffusion 3: maps 16-channel latents to RGB images."""

    def __init__(self):
        super().__init__()
        # Latent scale/shift constants differ from SD 1.x.
        self.scaling_factor = 1.5305
        self.shift_factor = 0.0609
        # SD3 latents carry 16 channels (SD 1.x uses 4).
        self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1)

        blocks = [
            # UNetMidBlock2D: resnet -> attention -> resnet, all at 512 channels.
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ]
        # Four UpDecoderBlock2D stages; every stage but the last ends with an upsampler.
        stage_channels = [(512, 512), (512, 512), (512, 256), (256, 128)]
        for stage, (c_in, c_out) in enumerate(stage_channels):
            blocks.append(ResnetBlock(c_in, c_out, eps=1e-6))
            blocks.append(ResnetBlock(c_out, c_out, eps=1e-6))
            blocks.append(ResnetBlock(c_out, c_out, eps=1e-6))
            if stage < len(stage_channels) - 1:
                blocks.append(UpSampler(c_out))
        self.blocks = torch.nn.ModuleList(blocks)

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        """Decode `sample` tile by tile to bound peak memory; delegates to TileWorker."""
        return TileWorker().tiled_forward(
            self.forward,
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype,
        )

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Decode latents `sample` into images; set `tiled=True` to decode in tiles."""
        # The tiler wraps the whole decoder — it is not applied per layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process: undo the latent scaling/shift, then project to 512 channels.
        x = sample / self.scaling_factor + self.shift_factor
        x = self.conv_in(x)
        time_emb, text_emb, res_stack = None, None, None

        # 2. blocks: each block threads the (unused-here) embeddings through.
        for block in self.blocks:
            x, time_emb, text_emb, res_stack = block(x, time_emb, text_emb, res_stack)

        # 3. output head: norm -> SiLU -> 3-channel conv.
        return self.conv_out(self.conv_act(self.conv_norm_out(x)))

    @staticmethod
    def state_dict_converter():
        """Return the converter used to remap external checkpoints to this module."""
        return SDVAEDecoderStateDictConverter()
@@ -0,0 +1,95 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
from .tiler import TileWorker
from einops import rearrange


class SD3VAEEncoder(torch.nn.Module):
    """VAE encoder for Stable Diffusion 3: maps RGB images to 16-channel latents."""

    def __init__(self):
        super().__init__()
        # Latent scale/shift constants differ from SD 1.x.
        self.scaling_factor = 1.5305
        self.shift_factor = 0.0609
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        blocks = []
        # Four DownEncoderBlock2D stages; every stage but the last ends with a downsampler.
        stage_channels = [(128, 128), (128, 256), (256, 512), (512, 512)]
        for stage, (c_in, c_out) in enumerate(stage_channels):
            blocks.append(ResnetBlock(c_in, c_out, eps=1e-6))
            blocks.append(ResnetBlock(c_out, c_out, eps=1e-6))
            if stage < len(stage_channels) - 1:
                blocks.append(DownSampler(c_out, padding=0, extra_padding=True))
        # UNetMidBlock2D: resnet -> attention -> resnet, all at 512 channels.
        blocks += [
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ]
        self.blocks = torch.nn.ModuleList(blocks)

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        # 32 output channels; forward() keeps only the first 16 (see below).
        self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        """Encode `sample` tile by tile to bound peak memory; delegates to TileWorker."""
        return TileWorker().tiled_forward(
            self.forward,
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype,
        )

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Encode images `sample` into latents; set `tiled=True` to encode in tiles."""
        # The tiler wraps the whole encoder — it is not applied per layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process: project the image to 128 channels.
        x = self.conv_in(sample)
        time_emb, text_emb, res_stack = None, None, None

        # 2. blocks: each block threads the (unused-here) embeddings through.
        for block in self.blocks:
            x, time_emb, text_emb, res_stack = block(x, time_emb, text_emb, res_stack)

        # 3. output head: norm -> SiLU -> conv, keep the first 16 of 32 channels,
        # then apply the SD3 latent shift/scale.
        x = self.conv_out(self.conv_act(self.conv_norm_out(x)))
        x = x[:, :16]
        return (x - self.shift_factor) * self.scaling_factor

    def encode_video(self, sample, batch_size=8):
        """Encode a (B, C, T, H, W) video tensor `batch_size` frames at a time.

        Frames are folded into the batch dimension, encoded with `forward`, and
        the results concatenated back along the time axis.
        """
        num_videos = sample.shape[0]
        num_frames = sample.shape[2]
        encoded = []

        for start in range(0, num_frames, batch_size):
            stop = min(start + batch_size, num_frames)
            frames = rearrange(sample[:, :, start:stop], "B C T H W -> (B T) C H W")
            latents = self(frames)
            encoded.append(rearrange(latents, "(B T) C H W -> B C T H W", B=num_videos))

        return torch.concat(encoded, dim=2)

    @staticmethod
    def state_dict_converter():
        """Return the converter used to remap external checkpoints to this module."""
        return SDVAEEncoderStateDictConverter()