diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,61 @@
1
+ from .base_prompter import BasePrompter, tokenize_long_prompt
2
+ from ..models.model_manager import ModelManager
3
+ from ..models import SDXLTextEncoder, SDXLTextEncoder2
4
+ from transformers import CLIPTokenizer
5
+ import torch, os
6
+
7
+
8
+
9
class SDXLPrompter(BasePrompter):
    """Prompt encoder for SDXL built on two CLIP tokenizer / text-encoder pairs.

    The tokenizers are loaded in ``__init__``. The text encoders are not: attach
    them with ``fetch_models`` before calling ``encode_prompt``.
    """

    def __init__(
        self,
        tokenizer_path=None,
        tokenizer_2_path=None
    ):
        """Load both tokenizers, defaulting to the configs bundled with the package.

        Args:
            tokenizer_path: directory for the first CLIP tokenizer, or None to use
                ``tokenizer_configs/stable_diffusion/tokenizer`` inside the package.
            tokenizer_2_path: directory for the second CLIP tokenizer, or None to use
                ``tokenizer_configs/stable_diffusion_xl/tokenizer_2`` inside the package.
        """
        # Package root: two levels up from this file.
        base_path = os.path.dirname(os.path.dirname(__file__))
        if tokenizer_path is None:
            tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
        if tokenizer_2_path is None:
            tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
        # Populated later by fetch_models().
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None


    def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_encoder_2: SDXLTextEncoder2 = None):
        """Attach the two text encoders used by ``encode_prompt``."""
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2


    def encode_prompt(
        self,
        prompt,
        clip_skip=1,
        clip_skip_2=2,
        positive=True,
        device="cuda"
    ):
        """Encode ``prompt`` with both text encoders and merge the results.

        Args:
            prompt: prompt text; first passed through ``self.process_prompt``.
            clip_skip: ``clip_skip`` forwarded to the first text encoder.
            clip_skip_2: ``clip_skip`` forwarded to the second text encoder.
            positive: forwarded to ``self.process_prompt``.
            device: device the token ids are moved to before encoding.

        Returns:
            ``(add_text_embeds, prompt_emb)``:

            - ``add_text_embeds`` is the first row of the first value returned by
              ``text_encoder_2``.
            - ``prompt_emb`` is the two encoders' token embeddings concatenated on the
              last dim. The leading (chunk) rows are flattened into one sequence, and
              the batch dimension is 1.
        """
        prompt = self.process_prompt(prompt, positive=positive)

        # 1: first tokenizer / text encoder.
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        prompt_emb_1 = self.text_encoder(input_ids, clip_skip=clip_skip)

        # 2: second tokenizer / text encoder.
        input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
        add_text_embeds, prompt_emb_2 = self.text_encoder_2(input_ids_2, clip_skip=clip_skip_2)

        # Merge. The two tokenizers may yield a different number of rows (chunks)
        # for the same prompt, so keep only the rows both encoders produced.
        common_rows = min(prompt_emb_1.shape[0], prompt_emb_2.shape[0])
        prompt_emb_1 = prompt_emb_1[:common_rows]
        prompt_emb_2 = prompt_emb_2[:common_rows]
        prompt_emb = torch.cat([prompt_emb_1, prompt_emb_2], dim=-1)

        # For very long prompts, only the first chunk (presumably the first 77
        # tokens) is used to compute `add_text_embeds`.
        add_text_embeds = add_text_embeds[0:1]
        prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0] * prompt_emb.shape[1], -1))
        return add_text_embeds, prompt_emb
@@ -0,0 +1,3 @@
1
+ from .ddim import EnhancedDDIMScheduler
2
+ from .continuous_ode import ContinuousODEScheduler
3
+ from .flow_match import FlowMatchScheduler
@@ -0,0 +1,59 @@
1
+ import torch
2
+
3
+
4
class ContinuousODEScheduler():
    """Scheduler over a continuous sigma range, stepping with an Euler update in sigma.

    Sigmas follow ``(max^(1/rho) + ramp * (min^(1/rho) - max^(1/rho))) ** rho``.
    The model-facing timestep for each sigma is ``0.25 * log(sigma)``. Samples
    handed to and returned from this scheduler are normalised, i.e. divided by
    ``sqrt(sigma^2 + 1)``.
    """

    def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = rho
        self.set_timesteps(num_inference_steps)


    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
        """Build ``self.sigmas`` and ``self.timesteps`` (CPU tensors, one entry per step).

        A ``denoising_strength`` below 1 starts the ramp at
        ``1 - denoising_strength``, so the schedule begins below ``sigma_max``.
        """
        ramp = torch.linspace(1 - denoising_strength, 1, num_inference_steps)
        min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
        max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
        self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
        self.timesteps = torch.log(self.sigmas) * 0.25


    def _timestep_id(self, timestep):
        """Return the index of the schedule entry whose timestep is closest to ``timestep``.

        ``self.timesteps`` is a CPU tensor. A tensor ``timestep`` is moved to the
        CPU first, so a timestep living on another device does not cause a device
        mismatch.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        return torch.argmin((self.timesteps - timestep).abs())


    def step(self, model_output, timestep, sample, to_final=False):
        """Advance ``sample`` from ``timestep`` to the next entry of the schedule.

        Returns the normalised sample at the next sigma. On the last step, or when
        ``to_final`` is set, it returns the model's clean-sample estimate instead.
        ``sample`` itself is left unmodified.
        """
        timestep_id = self._timestep_id(timestep)
        sigma = self.sigmas[timestep_id]
        # Undo the sqrt(sigma^2 + 1) normalisation. Out of place, so the caller's
        # tensor is not mutated.
        sample = sample * (sigma*sigma + 1).sqrt()
        estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample
        if to_final or timestep_id + 1 >= len(self.timesteps):
            prev_sample = estimated_sample
        else:
            sigma_ = self.sigmas[timestep_id + 1]
            # Euler update in sigma: x + (x - x0_hat) / sigma * (sigma_ - sigma).
            derivative = 1 / sigma * (sample - estimated_sample)
            prev_sample = sample + derivative * (sigma_ - sigma)
            # Re-normalise for the next sigma; prev_sample is a fresh tensor here.
            prev_sample /= (sigma_*sigma_ + 1).sqrt()
        return prev_sample


    def return_to_timestep(self, timestep, sample, sample_stablized):
        # This scheduler doesn't support this function; it returns None.
        pass


    def add_noise(self, original_samples, noise, timestep):
        """Return ``(original_samples + sigma * noise) / sqrt(sigma^2 + 1)`` at ``timestep``."""
        sigma = self.sigmas[self._timestep_id(timestep)]
        sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
        return sample


    def training_target(self, sample, noise, timestep):
        """Regression target for training at ``timestep``, given clean ``sample`` and ``noise``."""
        sigma = self.sigmas[self._timestep_id(timestep)]
        target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
        return target


    def training_weight(self, timestep):
        """Loss weight ``sqrt(1 + sigma^2) / sigma`` at ``timestep``."""
        sigma = self.sigmas[self._timestep_id(timestep)]
        weight = (1 + sigma*sigma).sqrt() / sigma
        return weight
@@ -0,0 +1,79 @@
1
+ import torch, math
2
+
3
+
4
class EnhancedDDIMScheduler():
    """DDIM sampler over a discrete ``num_train_timesteps`` noise schedule.

    Supports ``prediction_type`` ``"epsilon"`` and ``"v_prediction"``. Timestep
    arguments may be tensors (the first element is used) or plain Python numbers.
    """

    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"):
        self.num_train_timesteps = num_train_timesteps
        if beta_schedule == "scaled_linear":
            # Linear in sqrt(beta), then squared.
            betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
        elif beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")
        # Kept as Python floats so the per-step maths can use `math.sqrt`.
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
        self.set_timesteps(10)
        self.prediction_type = prediction_type


    def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
        """Spread ``num_inference_steps`` integer timesteps from the max timestep down to 0.

        The max timestep is ``round(num_train_timesteps * denoising_strength) - 1``
        (at least 0). The number of steps is capped at ``max_timestep + 1``.
        """
        # The timesteps are aligned to 999...0, which is different from other implementations,
        # but I think this implementation is more reasonable in theory.
        max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            self.timesteps = torch.Tensor([max_timestep])
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])


    @staticmethod
    def _timestep_to_int(timestep):
        """Integer train timestep used to index ``alphas_cumprod``.

        Accepts a tensor (its first element is used) or a plain Python number.
        """
        if isinstance(timestep, torch.Tensor):
            return int(timestep.flatten().tolist()[0])
        return int(timestep)


    def _alpha_prod(self, timestep):
        """Return ``alphas_cumprod`` at ``timestep``."""
        return self.alphas_cumprod[self._timestep_to_int(timestep)]


    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        """Deterministic DDIM update from ``alpha_prod_t`` to ``alpha_prod_t_prev``.

        Raises:
            NotImplementedError: for an unsupported ``prediction_type``.
        """
        if self.prediction_type == "epsilon":
            weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
            weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
            prev_sample = sample * weight_x + model_output * weight_e
        elif self.prediction_type == "v_prediction":
            weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
            weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
            prev_sample = sample * weight_x + model_output * weight_e
        else:
            raise NotImplementedError(f"{self.prediction_type} is not implemented")
        return prev_sample


    def step(self, model_output, timestep, sample, to_final=False):
        """Move ``sample`` from ``timestep`` to the next inference timestep.

        On the last step, or when ``to_final`` is set, the target is
        ``alpha_prod = 1.0``, i.e. the clean-sample estimate.
        """
        alpha_prod_t = self._alpha_prod(timestep)
        if isinstance(timestep, torch.Tensor):
            # self.timesteps is a CPU tensor.
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        if to_final or timestep_id + 1 >= len(self.timesteps):
            alpha_prod_t_prev = 1.0
        else:
            timestep_prev = int(self.timesteps[timestep_id + 1])
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)


    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Return the noise implied by noisy ``sample`` and clean ``sample_stablized``.

        NOTE(review): this always yields an epsilon estimate, regardless of
        ``prediction_type``.
        """
        alpha_prod_t = self._alpha_prod(timestep)
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
        return noise_pred


    def add_noise(self, original_samples, noise, timestep):
        """Return ``sqrt(a) * original_samples + sqrt(1 - a) * noise`` with ``a = alphas_cumprod[timestep]``."""
        alpha_prod = self._alpha_prod(timestep)
        sqrt_alpha_prod = math.sqrt(alpha_prod)
        sqrt_one_minus_alpha_prod = math.sqrt(1 - alpha_prod)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples


    def training_target(self, sample, noise, timestep):
        """Training target for ``prediction_type``: ``noise`` for epsilon, ``v`` for v_prediction.

        Raises:
            NotImplementedError: for an unsupported ``prediction_type``, matching ``denoise``.
        """
        if self.prediction_type == "epsilon":
            return noise
        elif self.prediction_type == "v_prediction":
            alpha_prod = self._alpha_prod(timestep)
            sqrt_alpha_prod = math.sqrt(alpha_prod)
            sqrt_one_minus_alpha_prod = math.sqrt(1 - alpha_prod)
            target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
            return target
        else:
            raise NotImplementedError(f"{self.prediction_type} is not implemented")
@@ -0,0 +1,51 @@
1
+ import torch
2
+
3
+
4
+
5
class FlowMatchScheduler():
    """Flow-matching scheduler with a shifted sigma schedule.

    Raw sigmas run linearly from a start value down to ``sigma_min``. They are
    then remapped by ``shift * s / (1 + (shift - 1) * s)``. Timesteps are the
    remapped sigmas multiplied by ``num_train_timesteps``.
    """

    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.set_timesteps(num_inference_steps)


    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
        """Rebuild ``self.sigmas`` / ``self.timesteps`` for ``num_inference_steps`` steps.

        ``denoising_strength`` linearly interpolates the first raw sigma between
        ``sigma_min`` (strength 0) and ``sigma_max`` (strength 1).
        """
        span = self.sigma_max - self.sigma_min
        first_sigma = self.sigma_min + span * denoising_strength
        raw = torch.linspace(first_sigma, self.sigma_min, num_inference_steps)
        shifted = self.shift * raw / (1 + (self.shift - 1) * raw)
        self.sigmas = shifted
        self.timesteps = shifted * self.num_train_timesteps


    def _nearest_index(self, timestep):
        """Index of the schedule entry whose timestep is closest to ``timestep``.

        ``self.timesteps`` is on the CPU, so a tensor ``timestep`` is moved there first.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        return torch.argmin((self.timesteps - timestep).abs())


    def step(self, model_output, timestep, sample, to_final=False):
        """Euler step ``sample + model_output * (next_sigma - sigma)``.

        ``next_sigma`` is 0 on the last step or when ``to_final`` is set.
        """
        idx = self._nearest_index(timestep)
        sigma = self.sigmas[idx]
        is_last = to_final or idx + 1 >= len(self.timesteps)
        next_sigma = 0 if is_last else self.sigmas[idx + 1]
        return sample + model_output * (next_sigma - sigma)


    def return_to_timestep(self, timestep, sample, sample_stablized):
        # Not supported by this scheduler.
        return None


    def add_noise(self, original_samples, noise, timestep):
        """Linearly blend data and noise: ``(1 - sigma) * x + sigma * noise``."""
        sigma = self.sigmas[self._nearest_index(timestep)]
        return (1 - sigma) * original_samples + sigma * noise


    def training_target(self, sample, noise, timestep):
        """Velocity target ``noise - sample`` (independent of ``timestep``)."""
        return noise - sample
File without changes
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
@@ -0,0 +1,16 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "do_basic_tokenize": true,
4
+ "do_lower_case": true,
5
+ "mask_token": "[MASK]",
6
+ "name_or_path": "hfl/chinese-roberta-wwm-ext",
7
+ "never_split": null,
8
+ "pad_token": "[PAD]",
9
+ "sep_token": "[SEP]",
10
+ "special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
11
+ "strip_accents": null,
12
+ "tokenize_chinese_chars": true,
13
+ "tokenizer_class": "BertTokenizer",
14
+ "unk_token": "[UNK]",
15
+ "model_max_length": 77
16
+ }