diffsynth-engine 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. diffsynth_engine/__init__.py +28 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/models/components/vae.json +254 -0
  30. diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
  31. diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
  32. diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
  33. diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
  34. diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
  35. diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
  36. diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
  37. diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
  38. diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
  39. diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
  40. diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
  41. diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
  42. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  43. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  44. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  45. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  46. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  47. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  48. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  49. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  50. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  51. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  52. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  53. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  54. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  55. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  56. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  57. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  58. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  59. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  60. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  61. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  62. diffsynth_engine/kernels/__init__.py +0 -0
  63. diffsynth_engine/models/__init__.py +7 -0
  64. diffsynth_engine/models/base.py +64 -0
  65. diffsynth_engine/models/basic/__init__.py +0 -0
  66. diffsynth_engine/models/basic/attention.py +217 -0
  67. diffsynth_engine/models/basic/lora.py +293 -0
  68. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  69. diffsynth_engine/models/basic/timestep.py +81 -0
  70. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  71. diffsynth_engine/models/basic/unet_helper.py +244 -0
  72. diffsynth_engine/models/components/__init__.py +0 -0
  73. diffsynth_engine/models/components/clip.py +56 -0
  74. diffsynth_engine/models/components/t5.py +222 -0
  75. diffsynth_engine/models/components/vae.py +392 -0
  76. diffsynth_engine/models/flux/__init__.py +14 -0
  77. diffsynth_engine/models/flux/flux_dit.py +476 -0
  78. diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
  79. diffsynth_engine/models/flux/flux_vae.py +78 -0
  80. diffsynth_engine/models/sd/__init__.py +12 -0
  81. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  82. diffsynth_engine/models/sd/sd_unet.py +293 -0
  83. diffsynth_engine/models/sd/sd_vae.py +38 -0
  84. diffsynth_engine/models/sd3/__init__.py +14 -0
  85. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  86. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  87. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  88. diffsynth_engine/models/sdxl/__init__.py +13 -0
  89. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  90. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  91. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  92. diffsynth_engine/models/utils.py +54 -0
  93. diffsynth_engine/models/wan/__init__.py +0 -0
  94. diffsynth_engine/models/wan/wan_dit.py +497 -0
  95. diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
  96. diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
  97. diffsynth_engine/models/wan/wan_vae.py +771 -0
  98. diffsynth_engine/pipelines/__init__.py +18 -0
  99. diffsynth_engine/pipelines/base.py +253 -0
  100. diffsynth_engine/pipelines/flux_image.py +512 -0
  101. diffsynth_engine/pipelines/sd_image.py +352 -0
  102. diffsynth_engine/pipelines/sdxl_image.py +395 -0
  103. diffsynth_engine/pipelines/wan_video.py +524 -0
  104. diffsynth_engine/tokenizers/__init__.py +6 -0
  105. diffsynth_engine/tokenizers/base.py +157 -0
  106. diffsynth_engine/tokenizers/clip.py +288 -0
  107. diffsynth_engine/tokenizers/t5.py +194 -0
  108. diffsynth_engine/tokenizers/wan.py +74 -0
  109. diffsynth_engine/utils/__init__.py +0 -0
  110. diffsynth_engine/utils/constants.py +34 -0
  111. diffsynth_engine/utils/download.py +135 -0
  112. diffsynth_engine/utils/env.py +7 -0
  113. diffsynth_engine/utils/flag.py +46 -0
  114. diffsynth_engine/utils/fp8_linear.py +64 -0
  115. diffsynth_engine/utils/gguf.py +415 -0
  116. diffsynth_engine/utils/loader.py +17 -0
  117. diffsynth_engine/utils/lock.py +56 -0
  118. diffsynth_engine/utils/logging.py +12 -0
  119. diffsynth_engine/utils/offload.py +44 -0
  120. diffsynth_engine/utils/parallel.py +390 -0
  121. diffsynth_engine/utils/prompt.py +9 -0
  122. diffsynth_engine/utils/video.py +40 -0
  123. diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
  124. diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
  125. diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
  126. diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
  127. diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
diffsynth_engine/__init__.py
@@ -0,0 +1,28 @@
+from .pipelines import (
+    FluxImagePipeline,
+    SDXLImagePipeline,
+    SDImagePipeline,
+    WanVideoPipeline,
+    FluxModelConfig,
+    SDXLModelConfig,
+    SDModelConfig,
+    WanModelConfig,
+)
+from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
+from .utils.video import load_video, save_video
+
+__all__ = [
+    "FluxImagePipeline",
+    "SDXLImagePipeline",
+    "SDImagePipeline",
+    "WanVideoPipeline",
+    "FluxModelConfig",
+    "SDXLModelConfig",
+    "SDModelConfig",
+    "WanModelConfig",
+    "fetch_model",
+    "fetch_modelscope_model",
+    "fetch_civitai_model",
+    "load_video",
+    "save_video",
+]
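Everything a caller needs is re-exported at the package root: the four pipelines, their model configs, the model-fetching helpers, and the video I/O utilities. A one-line editorial sketch (not part of the diff) using only the names listed above:

from diffsynth_engine import FluxImagePipeline, WanVideoPipeline, fetch_model, load_video, save_video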
diffsynth_engine/algorithm/__init__.py (file without changes)
diffsynth_engine/algorithm/noise_scheduler/__init__.py
@@ -0,0 +1,21 @@
+from .stable_diffusion.linear import ScaledLinearScheduler
+from .stable_diffusion.beta import BetaScheduler
+from .stable_diffusion.karras import KarrasScheduler
+from .stable_diffusion.exponential import ExponentialScheduler
+from .stable_diffusion.ddim import DDIMScheduler
+from .stable_diffusion.sgm_uniform import SGMUniformScheduler
+from .flow_match.recifited_flow import RecifitedFlowScheduler
+from .flow_match.flow_ddim import FlowDDIMScheduler
+from .flow_match.flow_beta import FlowBetaScheduler
+
+__all__ = [
+    "ScaledLinearScheduler",
+    "BetaScheduler",
+    "KarrasScheduler",
+    "ExponentialScheduler",
+    "DDIMScheduler",
+    "SGMUniformScheduler",
+    "RecifitedFlowScheduler",
+    "FlowDDIMScheduler",
+    "FlowBetaScheduler",
+]
diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py
@@ -0,0 +1,10 @@
+import torch
+
+
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+
+class BaseScheduler:
+    def schedule(self, num_inference_steps: int):
+        raise NotImplementedError()
diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py
@@ -0,0 +1,5 @@
+from .recifited_flow import RecifitedFlowScheduler
+from .flow_ddim import FlowDDIMScheduler
+from .flow_beta import FlowBetaScheduler
+
+__all__ = ["RecifitedFlowScheduler", "FlowDDIMScheduler", "FlowBetaScheduler"]
diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py
@@ -0,0 +1,28 @@
+import torch
+import numpy as np
+import scipy.stats as stats
+
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler
+
+
+class FlowBetaScheduler(RecifitedFlowScheduler):
+    def __init__(self):
+        super().__init__()
+        self.alpha = 0.6
+        self.beta = 0.6
+
+    def schedule(self, num_inference_steps: int, mu: float | None = None, sigmas: torch.Tensor | None = None):
+        pseudo_timestep_range = 10000
+        inner_sigmas = torch.arange(1, pseudo_timestep_range + 1, 1) / pseudo_timestep_range
+        inner_sigmas = self._time_shift(mu, 1.0, inner_sigmas)
+        sigma_min = inner_sigmas[0]
+        sigma_max = inner_sigmas[-1]
+
+        timesteps = 1 - np.linspace(0, 1, num_inference_steps)
+        timesteps = [stats.beta.ppf(x, self.alpha, self.beta) for x in timesteps]
+        sigmas = [sigma_min + (x * (sigma_max - sigma_min)) for x in timesteps]
+        sigmas = torch.FloatTensor(sigmas)
+        timesteps = self._sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py
@@ -0,0 +1,25 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler
+
+
+class FlowDDIMScheduler(RecifitedFlowScheduler):
+    def __init__(self, shift=1.0, num_train_timesteps=1000, use_dynamic_shifting=False):
+        super().__init__(shift, num_train_timesteps, use_dynamic_shifting)
+        self.pseudo_timestep_range = 10000
+
+    def schedule(self, num_inference_steps: int, mu: float | None = None, sigmas: torch.Tensor | None = None):
+        inner_sigmas = torch.arange(1, self.pseudo_timestep_range + 1, 1) / self.pseudo_timestep_range
+        inner_sigmas = self._time_shift(mu, 1.0, inner_sigmas)
+        sigmas = []
+        ss = max(len(inner_sigmas) // num_inference_steps, 1)
+        for i in range(1, len(inner_sigmas), ss):
+            sigmas.append(float(inner_sigmas[i]))
+        sigmas = sigmas[::-1]
+        sigmas = torch.FloatTensor(sigmas)
+
+        timesteps = self._sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py
@@ -0,0 +1,50 @@
+import torch
+import math
+
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero, BaseScheduler
+
+
+class RecifitedFlowScheduler(BaseScheduler):
+    def __init__(
+        self,
+        shift=1.0,
+        sigma_min=0.001,
+        sigma_max=1.0,
+        num_train_timesteps=1000,
+        use_dynamic_shifting=False,
+    ):
+        self.shift = shift
+        self.sigma_min = sigma_min
+        self.sigma_max = sigma_max
+        self.num_train_timesteps = num_train_timesteps
+        self.use_dynamic_shifting = use_dynamic_shifting
+
+    def _sigma_to_t(self, sigma):
+        return sigma * self.num_train_timesteps
+
+    def _t_to_sigma(self, t):
+        return t / self.num_train_timesteps
+
+    def _time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    def _shift_sigma(self, sigma: torch.Tensor, shift: float):
+        return shift * sigma / (1 + (shift - 1) * sigma)
+
+    def schedule(
+        self,
+        num_inference_steps: int,
+        mu: float | None = None,
+        sigma_min: float | None = None,
+        sigma_max: float | None = None,
+    ):
+        sigma_min = self.sigma_min if sigma_min is None else sigma_min
+        sigma_max = self.sigma_max if sigma_max is None else sigma_max
+        sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
+        if self.use_dynamic_shifting:
+            sigmas = self._time_shift(mu, 1.0, sigmas)  # FLUX
+        else:
+            sigmas = self._shift_sigma(sigmas, self.shift)
+        timesteps = sigmas * self.num_train_timesteps
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
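The scheduler above is the core flow-match schedule: sigmas run linearly from sigma_max down to sigma_min, then get either the dynamic time shift (FLUX-style, driven by `mu`) or the static shift `shift * sigma / (1 + (shift - 1) * sigma)`, and timesteps are simply the sigmas scaled by `num_train_timesteps` before `append_zero` adds the terminal 0.0. A minimal editorial sketch, not part of the package, assuming only the import path exposed by the noise_scheduler `__init__.py` above:

from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler

# Static shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
scheduler = RecifitedFlowScheduler(shift=3.0)
sigmas, timesteps = scheduler.schedule(num_inference_steps=5)
# sigmas has 6 entries: 5 shifted values from sigma_max (1.0) down to sigma_min (0.001), plus a trailing 0.0;
# timesteps are the 5 shifted sigmas scaled by num_train_timesteps (default 1000)
assert len(sigmas) == 6 and sigmas[-1] == 0.0
assert len(timesteps) == 5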
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py
@@ -0,0 +1,26 @@
+import torch
+import numpy as np
+import scipy.stats as stats
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class BetaScheduler(ScaledLinearScheduler):
+    """
+    Implemented based on: https://arxiv.org/abs/2407.12173
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.alpha = 0.6
+        self.beta = 0.6
+
+    def schedule(self, num_inference_steps: int):
+        timesteps = 1 - np.linspace(0, 1, num_inference_steps)
+        timesteps = [stats.beta.ppf(x, self.alpha, self.beta) for x in timesteps]
+        sigmas = [self.sigma_min + (x * (self.sigma_max - self.sigma_min)) for x in timesteps]
+        sigmas = torch.FloatTensor(sigmas).to(self.device)
+        timesteps = self.sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py
@@ -0,0 +1,25 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class DDIMScheduler(ScaledLinearScheduler):
+    """
+    Implemented based on: https://arxiv.org/pdf/2010.02502.pdf
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def schedule(self, num_inference_steps: int):
+        inner_sigmas = self.get_sigmas()
+        sigmas = []
+        ss = max(len(inner_sigmas) // num_inference_steps, 1)
+        for i in range(1, len(inner_sigmas), ss):
+            sigmas.append(float(inner_sigmas[i]))
+        sigmas = sigmas[::-1]
+        sigmas = torch.FloatTensor(sigmas).to(self.device)
+        timesteps = self.sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py
@@ -0,0 +1,19 @@
+import torch
+import math
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class ExponentialScheduler(ScaledLinearScheduler):
+    def __init__(self):
+        super().__init__()
+
+    def schedule(self, num_inference_steps: int):
+        """Constructs an exponential noise schedule."""
+        sigmas = torch.linspace(
+            math.log(self.sigma_max), math.log(self.sigma_min), num_inference_steps, device=self.device
+        ).exp()
+        timesteps = self.sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py
@@ -0,0 +1,21 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class KarrasScheduler(ScaledLinearScheduler):
+    def __init__(self):
+        super().__init__()
+        self.rho = 7.0
+        self.device = "cpu"
+
+    def schedule(self, num_inference_steps: int):
+        """Constructs the noise schedule of Karras et al. (2022)."""
+        ramp = torch.linspace(0, 1, num_inference_steps, device=self.device)
+        min_inv_rho = self.sigma_min ** (1 / self.rho)
+        max_inv_rho = self.sigma_max ** (1 / self.rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+        timesteps = self.sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas).to(self.device)
+        return sigmas, timesteps
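For reference, the Karras et al. (2022) schedule interpolates linearly in sigma**(1/rho) space with rho = 7, so the first and last sigmas land exactly on sigma_max and sigma_min before the terminal zero is appended. A small editorial sketch, assuming only the re-export from the noise_scheduler `__init__.py` above:

from diffsynth_engine.algorithm.noise_scheduler import KarrasScheduler

scheduler = KarrasScheduler()
sigmas, timesteps = scheduler.schedule(num_inference_steps=10)
# sigmas[0] == sigma_max (~14.61), sigmas[9] == sigma_min (~0.029), sigmas[10] == 0.0
print([round(float(s), 4) for s in sigmas])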
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py
@@ -0,0 +1,77 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import BaseScheduler, append_zero
+
+
+def linear_beta_schedule(beta_start: float = 0.00085, beta_end: float = 0.0120, num_train_steps: int = 1000):
+    """
+    DDPM Schedule
+    """
+    return torch.linspace(beta_start, beta_end, num_train_steps)
+
+
+def scaled_linear_beta_schedule(beta_start: float = 0.00085, beta_end: float = 0.0120, num_train_steps: int = 1000):
+    """
+    Stable Diffusion Schedule
+    """
+    return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_steps) ** 2
+
+
+class ScaledLinearScheduler(BaseScheduler):
+    def __init__(self):
+        self.device = "cpu"
+        self.num_train_steps = 1000
+        self.beta_start = 0.00085
+        self.beta_end = 0.0120
+        self.sigmas = self.get_sigmas()
+        self.log_sigmas = self.sigmas.log()
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def get_sigmas(self):
+        # Stable Diffusion sigmas
+        # len(sigmas) == 1000, sigma_min = sigmas[0] == 0.0292, sigma_max = sigmas[-1] == 14.6146
+        betas = scaled_linear_beta_schedule(
+            beta_start=self.beta_start, beta_end=self.beta_end, num_train_steps=self.num_train_steps
+        )
+        alphas = 1.0 - betas
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
+        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        return sigmas
+
+    def sigma_to_t(self, sigma):
+        """
+        Find where sigma.log() falls in self.log_sigmas (low and high), then interpolate between them to get t.
+        """
+        log_sigma = sigma.log()
+        dists = log_sigma - self.log_sigmas[:, None]
+        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
+        w = (low - log_sigma) / (low - high)
+        w = w.clamp(0, 1)
+        t = (1 - w) * low_idx + w * high_idx
+        return t.view(sigma.shape)
+
+    def t_to_sigma(self, t):
+        """
+        Floor and ceil t to get low_idx and high_idx, interpolate the corresponding log_sigmas, then exponentiate to get sigma.
+        """
+        t = t.float()
+        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp()
+
+    def schedule(self, num_inference_steps: int):
+        """
+        Uniformly sample timesteps for inference
+        """
+        timesteps = torch.linspace(self.num_train_steps - 1, 0, num_inference_steps, device=self.sigmas.device)
+        sigmas = append_zero(self.t_to_sigma(timesteps))
+        return sigmas, timesteps
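ScaledLinearScheduler precomputes the 1000 Stable Diffusion training sigmas from the scaled-linear beta schedule and converts between sigma and t by interpolating in log-sigma space, so `sigma_to_t` and `t_to_sigma` are approximate inverses of each other. A brief editorial sketch, assuming only the re-export from the noise_scheduler `__init__.py` above:

import torch
from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler

scheduler = ScaledLinearScheduler()
print(float(scheduler.sigma_min), float(scheduler.sigma_max))  # ~0.0292, ~14.6146, matching the comment in get_sigmas

sigmas, timesteps = scheduler.schedule(num_inference_steps=20)
# 20 timesteps spaced uniformly from 999 down to 0; the matching sigmas get a trailing 0.0
assert len(timesteps) == 20 and len(sigmas) == 21

# Round trip through the log-sigma interpolation
t = scheduler.sigma_to_t(torch.tensor([1.0]))
print(float(scheduler.t_to_sigma(t)))  # ~1.0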
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py
@@ -0,0 +1,17 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class SGMUniformScheduler(ScaledLinearScheduler):
+    def __init__(self):
+        super().__init__()
+
+    def schedule(self, num_inference_steps: int):
+        # assumes sigma_min and sigma_max keep their default values
+        timesteps = torch.linspace(999, 0, num_inference_steps + 1)[:-1]
+        sigmas = [self.t_to_sigma(timestep) for timestep in timesteps]
+        sigmas = torch.FloatTensor(sigmas).to(self.device)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/sampler/__init__.py
@@ -0,0 +1,19 @@
+from .stable_diffusion.ddpm import DDPMSampler
+from .stable_diffusion.euler import EulerSampler
+from .stable_diffusion.euler_ancestral import EulerAncestralSampler
+from .stable_diffusion.dpmpp_2m import DPMSolverPlusPlus2MSampler
+from .stable_diffusion.dpmpp_2m_sde import DPMSolverPlusPlus2MSDESampler
+from .stable_diffusion.dpmpp_3m_sde import DPMSolverPlusPlus3MSDESampler
+from .stable_diffusion.deis import DEISSampler
+from .flow_match.flow_match_euler import FlowMatchEulerSampler
+
+__all__ = [
+    "DDPMSampler",
+    "EulerSampler",
+    "EulerAncestralSampler",
+    "DPMSolverPlusPlus2MSampler",
+    "DPMSolverPlusPlus2MSDESampler",
+    "DPMSolverPlusPlus3MSDESampler",
+    "DEISSampler",
+    "FlowMatchEulerSampler",
+]
diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py
@@ -0,0 +1,22 @@
+import torch
+
+
+class FlowMatchEulerSampler:
+    def initialize(self, init_latents, timesteps, sigmas, mask=None):
+        self.init_latents = init_latents
+        self.timesteps = timesteps
+        self.sigmas = sigmas
+        self.mask = mask
+
+    def step(self, latents, model_outputs, i):
+        if self.mask is not None:
+            model_outputs = model_outputs * self.mask + self.init_latents * (1 - self.mask)
+
+        dt = self.sigmas[i + 1] - self.sigmas[i]
+        latents = latents.to(dtype=torch.float32)
+        latents = latents + model_outputs * dt
+        latents = latents.to(dtype=model_outputs.dtype)
+        return latents
+
+    def add_noise(self, latents, noise, sigma):
+        return (1 - sigma) * latents + noise * sigma
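FlowMatchEulerSampler integrates the predicted velocity with a plain Euler step, x <- x + v * (sigma[i+1] - sigma[i]), optionally blending in `init_latents` through the mask for inpainting-style use. The sketch below is editorial and shows how it pairs with the flow-match scheduler in a bare denoising loop; `velocity_model` is a hypothetical stand-in for the real DiT that the pipelines provide, not an API of this package:

import torch
from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler

def velocity_model(latents, timestep):
    # Hypothetical placeholder for the diffusion transformer's velocity prediction.
    return torch.zeros_like(latents)

scheduler = RecifitedFlowScheduler(shift=3.0)
sigmas, timesteps = scheduler.schedule(num_inference_steps=20)

sampler = FlowMatchEulerSampler()
latents = torch.randn(1, 16, 64, 64)  # pure noise at sigma_max
sampler.initialize(init_latents=latents, timesteps=timesteps, sigmas=sigmas, mask=None)

for i, t in enumerate(timesteps):
    v = velocity_model(latents, t)
    latents = sampler.step(latents, v, i)  # Euler update with dt = sigmas[i + 1] - sigmas[i]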
diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py
@@ -0,0 +1,54 @@
+import torch
+import torchsde
+
+
+class BatchedBrownianTree:
+    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+    def __init__(self, x, t0, t1, seed=None, **kwargs):
+        t0, t1, self.sign = self.sort(t0, t1)
+        w0 = kwargs.get("w0", torch.zeros_like(x))
+        if seed is None:
+            seed = torch.randint(0, 2**63 - 1, []).item()
+        self.batched = True
+        try:
+            assert len(seed) == x.shape[0]
+            w0 = w0[0]
+        except TypeError:
+            seed = [seed]
+            self.batched = False
+        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+    @staticmethod
+    def sort(a, b):
+        return (a, b, 1) if a < b else (b, a, -1)
+
+    def __call__(self, t0, t1):
+        t0, t1, sign = self.sort(t0, t1)
+        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+        return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+    """A noise sampler backed by a torchsde.BrownianTree.
+
+    Args:
+        x (Tensor): The tensor whose shape, device and dtype to use to generate
+            random samples.
+        sigma_min (float): The low end of the valid interval.
+        sigma_max (float): The high end of the valid interval.
+        seed (int or List[int]): The random seed. If a list of seeds is
+            supplied instead of a single integer, then the noise sampler will
+            use one BrownianTree per batch item, each with its own seed.
+        transform (callable): A function that maps sigma to the sampler's
+            internal timestep.
+    """
+
+    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+        self.transform = transform
+        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+        self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+    def __call__(self, sigma, sigma_next):
+        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
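BrownianTreeNoiseSampler gives the SDE samplers reproducible, batch-seeded noise: each call returns the Brownian increment between two sigmas, rescaled by 1/sqrt(|t1 - t0|) so it has roughly unit variance, and repeated queries over the same interval with the same seed are deterministic. A small editorial sketch, assuming torchsde is installed and importing the module by the path shown in the file list:

import torch
from diffsynth_engine.algorithm.sampler.stable_diffusion.brownian_tree import BrownianTreeNoiseSampler

x = torch.randn(2, 4, 8, 8)  # shape/device/dtype template; one seed per batch item below
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=[1, 2])

eps = noise_sampler(torch.tensor(14.6), torch.tensor(10.0))
print(eps.shape)  # torch.Size([2, 4, 8, 8]); querying the same interval again returns the same increment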
diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py
@@ -0,0 +1,32 @@
+import torch
+from .epsilon import EpsilonSampler
+
+
+class DDPMSampler(EpsilonSampler):
+    def _step_function(self, x, sigma, sigma_prev, noise):
+        alpha_cumprod = 1 / ((sigma * sigma) + 1)
+        alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
+        alpha = alpha_cumprod / alpha_cumprod_prev
+
+        mu = (1.0 / alpha) ** 0.5 * (x - (1 - alpha) * noise / (1 - alpha_cumprod) ** 0.5)
+        if sigma_prev > 0:
+            # Caution: this randn tensor needs to be controlled by `torch.manual_seed`.
+            mu += ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)) ** 0.5 * torch.randn_like(x)
+        return mu
+
+    def step(self, latents, model_outputs, i):
+        sigma = self.sigmas[i]
+        sigma_next = self.sigmas[i + 1]
+        latents = self._scaling(sigma, latents)
+
+        denoised = self._to_denoised(sigma, model_outputs, latents)
+        latents = self._step_function(
+            latents / (1.0 + sigma**2.0) ** 0.5, sigma, sigma_next, (latents - denoised) / sigma
+        )
+
+        latents *= (1.0 + sigma_next**2.0) ** 0.5
+
+        return self._unscaling(self.sigmas[i + 1], latents)
+
+    def step2(self, latents, model_outputs, i):
+        return self._step_function(latents, self.sigmas[i], self.sigmas[i + 1], model_outputs)
diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py
@@ -0,0 +1,125 @@
+import torch
+
+from diffsynth_engine.algorithm.sampler.stable_diffusion.epsilon import EpsilonSampler
+
+
+class DEISSampler(EpsilonSampler):
+    """
+    Following the WebUI Forge implementation, deis_mode only supports tab and rhoab.
+    """
+
+    def initialize(self, init_latents, timesteps, sigmas, mask):
+        super().initialize(init_latents, timesteps, sigmas, mask)
+        self.max_order = 3
+        self.sigmas = sigmas
+        self.timesteps = timesteps
+        self.lower_order_nums = 0
+        self.coeff_list = get_deis_coeff_list(self.sigmas, self.max_order)
+        self.coeff_buffer = []
+
+    def step(self, latents, model_outputs, i):
+        s, s_next = self.sigmas[i], self.sigmas[i + 1]
+        denoised = latents - model_outputs * s
+
+        d = (latents - denoised) / s
+        order = min(self.max_order, i + 1)
+        if self.sigmas[i + 1] <= 0:
+            order = 1
+        if order == 1:
+            x_next = latents + (s_next - s) * d
+        elif order == 2:
+            coeff, coeff_prev1 = self.coeff_list[i]
+            x_next = latents + coeff * d + coeff_prev1 * self.coeff_buffer[-1]
+        elif order == 3:
+            coeff, coeff_prev1, coeff_prev2 = self.coeff_list[i]
+            x_next = latents + coeff * d + coeff_prev1 * self.coeff_buffer[-1] + coeff_prev2 * self.coeff_buffer[-2]
+        elif order == 4:
+            coeff, coeff_prev1, coeff_prev2, coeff_prev3 = self.coeff_list[i]
+            x_next = (
+                latents
+                + coeff * d
+                + coeff_prev1 * self.coeff_buffer[-1]
+                + coeff_prev2 * self.coeff_buffer[-2]
+                + coeff_prev3 * self.coeff_buffer[-3]
+            )
+
+        if len(self.coeff_buffer) == self.max_order - 1:
+            for k in range(self.max_order - 2):
+                self.coeff_buffer[k] = self.coeff_buffer[k + 1]
+            self.coeff_buffer[-1] = d
+        else:
+            self.coeff_buffer.append(d)
+        return x_next
+
+
+# Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
+# under Apache 2 license
+# A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
+#############################
+### Utils for DEIS solver ###
+#############################
+# ----------------------------------------------------------------------------
+# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
+
+
+def vp_sigma_inv(beta_d, beta_min, sigma):
+    return ((beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min) / beta_d
+
+
+def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
+    vp_beta_d = (
+        2
+        * (torch.log(torch.tensor(sigma_min) ** 2 + 1) / epsilon_s - torch.log(torch.tensor(sigma_max) ** 2 + 1))
+        / (epsilon_s - 1)
+    )
+    vp_beta_min = torch.log(torch.tensor(sigma_max) ** 2 + 1) - 0.5 * vp_beta_d
+    t_steps = vp_sigma_inv(vp_beta_d, vp_beta_min, edm_steps)
+    return t_steps, vp_beta_min, vp_beta_d + vp_beta_min
+
+
+def cal_poly(prev_t, j, taus):
+    poly = 1
+    for k in range(prev_t.shape[0]):
+        if k == j:
+            continue
+        poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
+    return poly
+
+
+def t2alpha_fn(beta_0, beta_1, t):
+    return torch.exp(-0.5 * t**2 * (beta_1 - beta_0) - t * beta_0)
+
+
+def cal_intergrand(beta_0, beta_1, taus):
+    with torch.inference_mode(mode=False):
+        taus = taus.clone()
+        beta_0 = beta_0.clone()
+        beta_1 = beta_1.clone()
+        with torch.enable_grad():
+            taus.requires_grad_(True)
+            alpha = t2alpha_fn(beta_0, beta_1, taus)
+            log_alpha = alpha.log()
+            log_alpha.sum().backward()
+            d_log_alpha_dtau = taus.grad
+    integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
+    return integrand
+
+
+def get_deis_coeff_list(t_steps, max_order, N=10000):
+    t_steps, beta_0, beta_1 = edm2t(t_steps)
+    C = []
+    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
+        order = min(i + 1, max_order)
+        if order == 1:
+            C.append([])
+        else:
+            taus = torch.linspace(t_cur, t_next, N).to(t_next.device)
+            dtau = (t_next - t_cur) / N
+            prev_t = t_steps[[i - k for k in range(order)]]
+            coeff_temp = []
+            integrand = cal_intergrand(beta_0, beta_1, taus)
+            for j in range(order):
+                poly = cal_poly(prev_t, j, taus)
+                coeff_temp.append(torch.sum(integrand * poly) * dtau)
+            C.append(coeff_temp)
+    return C
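`get_deis_coeff_list` precomputes, for every transition between sigmas, the polynomial extrapolation coefficients that `DEISSampler.step` combines with its buffered derivatives: entry 0 is empty (the first step falls back to Euler), entry 1 carries two coefficients, and later entries carry `max_order` of them. A quick editorial sketch, with the module path taken from the file list above:

import torch
from diffsynth_engine.algorithm.sampler.stable_diffusion.deis import get_deis_coeff_list

sigmas = torch.linspace(14.6, 0.03, 10)
coeffs = get_deis_coeff_list(sigmas, max_order=3)
print([len(c) for c in coeffs])  # [0, 2, 3, 3, 3, 3, 3, 3, 3] -- one entry per transition between sigmas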
diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py
@@ -0,0 +1,29 @@
+from diffsynth_engine.algorithm.sampler.stable_diffusion.epsilon import EpsilonSampler
+
+
+class DPMSolverPlusPlus2MSampler(EpsilonSampler):
+    """
+    DPM Solver++ 2M sampler
+    """
+
+    def initialize(self, init_latents, timesteps, sigmas, mask):
+        super().initialize(init_latents, timesteps, sigmas, mask)
+        self.old_denoised = None
+
+    def step(self, latents, model_outputs, i):
+        s_prev, s, s_next = self.sigmas[i - 1], self.sigmas[i], self.sigmas[i + 1]
+        t_prev, t, t_next = self._sigma_to_t(s_prev), self._sigma_to_t(s), self._sigma_to_t(s_next)
+        h = t_next - t
+        x = self._scaling(latents, s)
+        denoised = self._to_denoised(s, model_outputs, x)
+        if self.old_denoised is None or s_next == 0:
+            self.old_denoised = denoised
+            return (s_next / s) * x - (-h).expm1() * denoised
+        h_last = t - t_prev
+        r = h_last / h
+        denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * self.old_denoised
+        x = (s_next / s) * x - (-h).expm1() * denoised_d
+        return self._unscaling(x, s_next)
+
+    def _sigma_to_t(self, sigma):
+        return sigma.log().neg()