diffsynth-engine 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +28 -0
- diffsynth_engine/algorithm/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
- diffsynth_engine/algorithm/sampler/__init__.py +19 -0
- diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
- diffsynth_engine/conf/models/components/vae.json +254 -0
- diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
- diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
- diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
- diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
- diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
- diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
- diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
- diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
- diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
- diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
- diffsynth_engine/kernels/__init__.py +0 -0
- diffsynth_engine/models/__init__.py +7 -0
- diffsynth_engine/models/base.py +64 -0
- diffsynth_engine/models/basic/__init__.py +0 -0
- diffsynth_engine/models/basic/attention.py +217 -0
- diffsynth_engine/models/basic/lora.py +293 -0
- diffsynth_engine/models/basic/relative_position_emb.py +56 -0
- diffsynth_engine/models/basic/timestep.py +81 -0
- diffsynth_engine/models/basic/transformer_helper.py +88 -0
- diffsynth_engine/models/basic/unet_helper.py +244 -0
- diffsynth_engine/models/components/__init__.py +0 -0
- diffsynth_engine/models/components/clip.py +56 -0
- diffsynth_engine/models/components/t5.py +222 -0
- diffsynth_engine/models/components/vae.py +392 -0
- diffsynth_engine/models/flux/__init__.py +14 -0
- diffsynth_engine/models/flux/flux_dit.py +476 -0
- diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
- diffsynth_engine/models/flux/flux_vae.py +78 -0
- diffsynth_engine/models/sd/__init__.py +12 -0
- diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
- diffsynth_engine/models/sd/sd_unet.py +293 -0
- diffsynth_engine/models/sd/sd_vae.py +38 -0
- diffsynth_engine/models/sd3/__init__.py +14 -0
- diffsynth_engine/models/sd3/sd3_dit.py +302 -0
- diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
- diffsynth_engine/models/sd3/sd3_vae.py +43 -0
- diffsynth_engine/models/sdxl/__init__.py +13 -0
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
- diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
- diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
- diffsynth_engine/models/utils.py +54 -0
- diffsynth_engine/models/wan/__init__.py +0 -0
- diffsynth_engine/models/wan/wan_dit.py +497 -0
- diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
- diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
- diffsynth_engine/models/wan/wan_vae.py +771 -0
- diffsynth_engine/pipelines/__init__.py +18 -0
- diffsynth_engine/pipelines/base.py +253 -0
- diffsynth_engine/pipelines/flux_image.py +512 -0
- diffsynth_engine/pipelines/sd_image.py +352 -0
- diffsynth_engine/pipelines/sdxl_image.py +395 -0
- diffsynth_engine/pipelines/wan_video.py +524 -0
- diffsynth_engine/tokenizers/__init__.py +6 -0
- diffsynth_engine/tokenizers/base.py +157 -0
- diffsynth_engine/tokenizers/clip.py +288 -0
- diffsynth_engine/tokenizers/t5.py +194 -0
- diffsynth_engine/tokenizers/wan.py +74 -0
- diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine/utils/constants.py +34 -0
- diffsynth_engine/utils/download.py +135 -0
- diffsynth_engine/utils/env.py +7 -0
- diffsynth_engine/utils/flag.py +46 -0
- diffsynth_engine/utils/fp8_linear.py +64 -0
- diffsynth_engine/utils/gguf.py +415 -0
- diffsynth_engine/utils/loader.py +17 -0
- diffsynth_engine/utils/lock.py +56 -0
- diffsynth_engine/utils/logging.py +12 -0
- diffsynth_engine/utils/offload.py +44 -0
- diffsynth_engine/utils/parallel.py +390 -0
- diffsynth_engine/utils/prompt.py +9 -0
- diffsynth_engine/utils/video.py +40 -0
- diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
- diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
- diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
- diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
- diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
diffsynth_engine/__init__.py
@@ -0,0 +1,28 @@
+from .pipelines import (
+    FluxImagePipeline,
+    SDXLImagePipeline,
+    SDImagePipeline,
+    WanVideoPipeline,
+    FluxModelConfig,
+    SDXLModelConfig,
+    SDModelConfig,
+    WanModelConfig,
+)
+from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
+from .utils.video import load_video, save_video
+
+__all__ = [
+    "FluxImagePipeline",
+    "SDXLImagePipeline",
+    "SDImagePipeline",
+    "WanVideoPipeline",
+    "FluxModelConfig",
+    "SDXLModelConfig",
+    "SDModelConfig",
+    "WanModelConfig",
+    "fetch_model",
+    "fetch_modelscope_model",
+    "fetch_civitai_model",
+    "load_video",
+    "save_video",
+]
diffsynth_engine/algorithm/__init__.py
File without changes
diffsynth_engine/algorithm/noise_scheduler/__init__.py
@@ -0,0 +1,21 @@
+from .stable_diffusion.linear import ScaledLinearScheduler
+from .stable_diffusion.beta import BetaScheduler
+from .stable_diffusion.karras import KarrasScheduler
+from .stable_diffusion.exponential import ExponentialScheduler
+from .stable_diffusion.ddim import DDIMScheduler
+from .stable_diffusion.sgm_uniform import SGMUniformScheduler
+from .flow_match.recifited_flow import RecifitedFlowScheduler
+from .flow_match.flow_ddim import FlowDDIMScheduler
+from .flow_match.flow_beta import FlowBetaScheduler
+
+__all__ = [
+    "ScaledLinearScheduler",
+    "BetaScheduler",
+    "KarrasScheduler",
+    "ExponentialScheduler",
+    "DDIMScheduler",
+    "SGMUniformScheduler",
+    "RecifitedFlowScheduler",
+    "FlowDDIMScheduler",
+    "FlowBetaScheduler",
+]
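For reference, a minimal usage sketch (not part of the packaged files), assuming diffsynth-engine and PyTorch are installed. The constructor defaults and schedule() signatures are taken from the scheduler files shown below; append_zero, defined in base_scheduler.py (not shown here), appears to append a final 0.0 sigma.

from diffsynth_engine.algorithm.noise_scheduler import KarrasScheduler, RecifitedFlowScheduler

# Stable Diffusion style schedule: 20 sigmas descending from sigma_max, plus a trailing 0.0
sigmas, timesteps = KarrasScheduler().schedule(num_inference_steps=20)
print(len(sigmas), len(timesteps))  # 21 vs. 20: append_zero adds the trailing sigma

# Flow-matching schedule with a static shift (see recifited_flow.py below)
flow_sigmas, flow_timesteps = RecifitedFlowScheduler(shift=3.0).schedule(num_inference_steps=20)
print(sigmas[:3], flow_sigmas[:3])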
diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py
@@ -0,0 +1,28 @@
+import torch
+import numpy as np
+import scipy.stats as stats
+
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler
+
+
+class FlowBetaScheduler(RecifitedFlowScheduler):
+    def __init__(self):
+        super().__init__()
+        self.alpha = 0.6
+        self.beta = 0.6
+
+    def schedule(self, num_inference_steps: int, mu: float | None = None, sigmas: torch.Tensor | None = None):
+        pseudo_timestep_range = 10000
+        inner_sigmas = torch.arange(1, pseudo_timestep_range + 1, 1) / pseudo_timestep_range
+        inner_sigmas = self._time_shift(mu, 1.0, inner_sigmas)
+        sigma_min = inner_sigmas[0]
+        sigma_max = inner_sigmas[-1]
+
+        timesteps = 1 - np.linspace(0, 1, num_inference_steps)
+        timesteps = [stats.beta.ppf(x, self.alpha, self.beta) for x in timesteps]
+        sigmas = [sigma_min + (x * (sigma_max - sigma_min)) for x in timesteps]
+        sigmas = torch.FloatTensor(sigmas)
+        timesteps = self._sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py
@@ -0,0 +1,25 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler
+
+
+class FlowDDIMScheduler(RecifitedFlowScheduler):
+    def __init__(self, shift=1.0, num_train_timesteps=1000, use_dynamic_shifting=False):
+        super().__init__(shift, num_train_timesteps, use_dynamic_shifting)
+        self.pseudo_timestep_range = 10000
+
+    def schedule(self, num_inference_steps: int, mu: float | None = None, sigmas: torch.Tensor | None = None):
+        inner_sigmas = torch.arange(1, self.pseudo_timestep_range + 1, 1) / self.pseudo_timestep_range
+        inner_sigmas = self._time_shift(mu, 1.0, inner_sigmas)
+        sigmas = []
+        ss = max(len(inner_sigmas) // num_inference_steps, 1)
+        for i in range(1, len(inner_sigmas), ss):
+            sigmas.append(float(inner_sigmas[i]))
+        sigmas = sigmas[::-1]
+        sigmas = torch.FloatTensor(sigmas)
+
+        timesteps = self._sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py
@@ -0,0 +1,50 @@
+import torch
+import math
+
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero, BaseScheduler
+
+
+class RecifitedFlowScheduler(BaseScheduler):
+    def __init__(
+        self,
+        shift=1.0,
+        sigma_min=0.001,
+        sigma_max=1.0,
+        num_train_timesteps=1000,
+        use_dynamic_shifting=False,
+    ):
+        self.shift = shift
+        self.sigma_min = sigma_min
+        self.sigma_max = sigma_max
+        self.num_train_timesteps = num_train_timesteps
+        self.use_dynamic_shifting = use_dynamic_shifting
+
+    def _sigma_to_t(self, sigma):
+        return sigma * self.num_train_timesteps
+
+    def _t_to_sigma(self, t):
+        return t / self.num_train_timesteps
+
+    def _time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    def _shift_sigma(self, sigma: torch.Tensor, shift: float):
+        return shift * sigma / (1 + (shift - 1) * sigma)
+
+    def schedule(
+        self,
+        num_inference_steps: int,
+        mu: float | None = None,
+        sigma_min: float | None = None,
+        sigma_max: float | None = None,
+    ):
+        sigma_min = self.sigma_min if sigma_min is None else sigma_min
+        sigma_max = self.sigma_max if sigma_max is None else sigma_max
+        sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
+        if self.use_dynamic_shifting:
+            sigmas = self._time_shift(mu, 1.0, sigmas)  # FLUX
+        else:
+            sigmas = self._shift_sigma(sigmas, self.shift)
+        timesteps = sigmas * self.num_train_timesteps
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
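A standalone check (not part of the packaged files) of the two shift formulas above: with the exponent fixed at 1.0, the dynamic time shift with exp(mu) equal to the shift factor reproduces the static _shift_sigma curve, since exp(mu) / (exp(mu) + (1/t - 1)) simplifies to shift * t / (1 + (shift - 1) * t).

import math
import torch

def shift_sigma(sigma, shift):            # same formula as RecifitedFlowScheduler._shift_sigma
    return shift * sigma / (1 + (shift - 1) * sigma)

def time_shift(mu, exponent, t):          # same formula as RecifitedFlowScheduler._time_shift
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** exponent)

sigmas = torch.linspace(1.0, 0.001, 10)
static = shift_sigma(sigmas, shift=3.0)
dynamic = time_shift(math.log(3.0), 1.0, sigmas)
assert torch.allclose(static, dynamic)
print(static)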
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py
File without changes
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py
@@ -0,0 +1,26 @@
+import torch
+import numpy as np
+import scipy.stats as stats
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class BetaScheduler(ScaledLinearScheduler):
+    """
+    Implemented based on: https://arxiv.org/abs/2407.12173
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.alpha = 0.6
+        self.beta = 0.6
+
+    def schedule(self, num_inference_steps: int):
+        timesteps = 1 - np.linspace(0, 1, num_inference_steps)
+        timesteps = [stats.beta.ppf(x, self.alpha, self.beta) for x in timesteps]
+        sigmas = [self.sigma_min + (x * (self.sigma_max - self.sigma_min)) for x in timesteps]
+        sigmas = torch.FloatTensor(sigmas).to(self.device)
+        timesteps = self.sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
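The scheduler above warps uniformly spaced steps through the inverse CDF of a Beta(0.6, 0.6) distribution, whose U-shaped density concentrates steps near both ends of the sigma range. A standalone sketch (not part of the packaged files), using the sigma endpoints quoted in linear.py:

import numpy as np
import scipy.stats as stats

num_inference_steps = 10
sigma_min, sigma_max = 0.0292, 14.6146   # scaled-linear endpoints quoted in linear.py

timesteps = 1 - np.linspace(0, 1, num_inference_steps)
warped = [stats.beta.ppf(x, 0.6, 0.6) for x in timesteps]
sigmas = [sigma_min + x * (sigma_max - sigma_min) for x in warped]
print(np.round(sigmas, 3))  # spacing is densest near sigma_max and sigma_min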
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py
@@ -0,0 +1,25 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class DDIMScheduler(ScaledLinearScheduler):
+    """
+    Implemented based on: https://arxiv.org/pdf/2010.02502.pdf
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def schedule(self, num_inference_steps: int):
+        inner_sigmas = self.get_sigmas()
+        sigmas = []
+        ss = max(len(inner_sigmas) // num_inference_steps, 1)
+        for i in range(1, len(inner_sigmas), ss):
+            sigmas.append(float(inner_sigmas[i]))
+        sigmas = sigmas[::-1]
+        sigmas = torch.FloatTensor(sigmas).to(self.device)
+        timesteps = self.sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py
@@ -0,0 +1,19 @@
+import torch
+import math
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class ExponentialScheduler(ScaledLinearScheduler):
+    def __init__(self):
+        super().__init__()
+
+    def schedule(self, num_inference_steps: int):
+        """Constructs an exponential noise schedule."""
+        sigmas = torch.linspace(
+            math.log(self.sigma_max), math.log(self.sigma_min), num_inference_steps, device=self.device
+        ).exp()
+        timesteps = self.sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py
@@ -0,0 +1,21 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class KarrasScheduler(ScaledLinearScheduler):
+    def __init__(self):
+        super().__init__()
+        self.rho = 7.0
+        self.device = "cpu"
+
+    def schedule(self, num_inference_steps: int):
+        """Constructs the noise schedule of Karras et al. (2022)."""
+        ramp = torch.linspace(0, 1, num_inference_steps, device=self.device)
+        min_inv_rho = self.sigma_min ** (1 / self.rho)
+        max_inv_rho = self.sigma_max ** (1 / self.rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+        timesteps = self.sigma_to_t(sigmas)
+        sigmas = append_zero(sigmas).to(self.device)
+        return sigmas, timesteps
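The schedule above interpolates linearly in sigma**(1/rho) space with rho = 7, which spaces steps more finely at low sigma. A standalone sketch (not part of the packaged files), using the Stable Diffusion sigma range quoted in linear.py:

import torch

rho = 7.0
sigma_min, sigma_max = 0.0292, 14.6146

ramp = torch.linspace(0, 1, 10)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # decreases from sigma_max to sigma_min, with finer steps at low sigma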
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py
@@ -0,0 +1,77 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import BaseScheduler, append_zero
+
+
+def linear_beta_schedule(beta_start: float = 0.00085, beta_end: float = 0.0120, num_train_steps: int = 1000):
+    """
+    DDPM Schedule
+    """
+    return torch.linspace(beta_start, beta_end, num_train_steps)
+
+
+def scaled_linear_beta_schedule(beta_start: float = 0.00085, beta_end: float = 0.0120, num_train_steps: int = 1000):
+    """
+    Stable Diffusion Schedule
+    """
+    return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_steps) ** 2
+
+
+class ScaledLinearScheduler(BaseScheduler):
+    def __init__(self):
+        self.device = "cpu"
+        self.num_train_steps = 1000
+        self.beta_start = 0.00085
+        self.beta_end = 0.0120
+        self.sigmas = self.get_sigmas()
+        self.log_sigmas = self.sigmas.log()
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def get_sigmas(self):
+        # Stable Diffusion Sigmas
+        # len(sigmas) == 1000, sigma_min=sigmas[0] == 0.0292, sigma_max=sigmas[-1] == 14.6146
+        betas = scaled_linear_beta_schedule(
+            beta_start=self.beta_start, beta_end=self.beta_end, num_train_steps=self.num_train_steps
+        )
+        alphas = 1.0 - betas
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
+        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        return sigmas
+
+    def sigma_to_t(self, sigma):
+        """
+        Find where sigma.log() falls in self.log_sigmas (low and high), then weight-interpolate to get t
+        """
+        log_sigma = sigma.log()
+        dists = log_sigma - self.log_sigmas[:, None]
+        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
+        w = (low - log_sigma) / (low - high)
+        w = w.clamp(0, 1)
+        t = (1 - w) * low_idx + w * high_idx
+        return t.view(sigma.shape)
+
+    def t_to_sigma(self, t):
+        """
+        Floor and ceil t to get low_idx and high_idx, take the log_sigmas at those indices, weight-interpolate, and exp to get sigma
+        """
+        t = t.float()
+        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp()
+
+    def schedule(self, num_inference_steps: int):
+        """
+        Uniformly sample timesteps for inference
+        """
+        timesteps = torch.linspace(self.num_train_steps - 1, 0, num_inference_steps, device=self.sigmas.device)
+        sigmas = append_zero(self.t_to_sigma(timesteps))
+        return sigmas, timesteps
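A standalone check (not part of the packaged files) of the sigma range quoted in get_sigmas() above; sigma_to_t and t_to_sigma then interpolate piecewise-linearly between these 1000 log-sigmas:

import torch

betas = torch.linspace(0.00085**0.5, 0.0120**0.5, 1000) ** 2      # scaled_linear_beta_schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
print(round(float(sigmas[0]), 4), round(float(sigmas[-1]), 4))    # ~0.0292 and ~14.6146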
diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py
@@ -0,0 +1,17 @@
+import torch
+
+from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
+from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero
+
+
+class SGMUniformScheduler(ScaledLinearScheduler):
+    def __init__(self):
+        super().__init__()
+
+    def schedule(self, num_inference_steps: int):
+        # assumes sigma_min and sigma_max are the default values
+        timesteps = torch.linspace(999, 0, num_inference_steps + 1)[:-1]
+        sigmas = [self.t_to_sigma(timestep) for timestep in timesteps]
+        sigmas = torch.FloatTensor(sigmas).to(self.device)
+        sigmas = append_zero(sigmas)
+        return sigmas, timesteps
diffsynth_engine/algorithm/sampler/__init__.py
@@ -0,0 +1,19 @@
+from .stable_diffusion.ddpm import DDPMSampler
+from .stable_diffusion.euler import EulerSampler
+from .stable_diffusion.euler_ancestral import EulerAncestralSampler
+from .stable_diffusion.dpmpp_2m import DPMSolverPlusPlus2MSampler
+from .stable_diffusion.dpmpp_2m_sde import DPMSolverPlusPlus2MSDESampler
+from .stable_diffusion.dpmpp_3m_sde import DPMSolverPlusPlus3MSDESampler
+from .stable_diffusion.deis import DEISSampler
+from .flow_match.flow_match_euler import FlowMatchEulerSampler
+
+__all__ = [
+    "DDPMSampler",
+    "EulerSampler",
+    "EulerAncestralSampler",
+    "DPMSolverPlusPlus2MSampler",
+    "DPMSolverPlusPlus2MSDESampler",
+    "DPMSolverPlusPlus3MSDESampler",
+    "DEISSampler",
+    "FlowMatchEulerSampler",
+]
diffsynth_engine/algorithm/sampler/flow_match/__init__.py
File without changes
diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py
@@ -0,0 +1,22 @@
+import torch
+
+
+class FlowMatchEulerSampler:
+    def initialize(self, init_latents, timesteps, sigmas, mask=None):
+        self.init_latents = init_latents
+        self.timesteps = timesteps
+        self.sigmas = sigmas
+        self.mask = mask
+
+    def step(self, latents, model_outputs, i):
+        if self.mask is not None:
+            model_outputs = model_outputs * self.mask + self.init_latents * (1 - self.mask)
+
+        dt = self.sigmas[i + 1] - self.sigmas[i]
+        latents = latents.to(dtype=torch.float32)
+        latents = latents + model_outputs * dt
+        latents = latents.to(dtype=model_outputs.dtype)
+        return latents
+
+    def add_noise(self, latents, noise, sigma):
+        return (1 - sigma) * latents + noise * sigma
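The sampler above is a plain Euler integrator: each step moves the latents along the predicted velocity by dt = sigmas[i + 1] - sigmas[i], and add_noise linearly interpolates between data and noise. A standalone sketch (not part of the packaged files) using the velocity implied by add_noise (noise - data), which recovers the data exactly once sigma has been integrated from 1 down to 0:

import torch

data = torch.randn(4)
noise = torch.randn(4)
sigmas = torch.linspace(1.0, 0.0, 11)            # schedule from 1 down to 0

x = (1 - sigmas[0]) * data + noise * sigmas[0]   # add_noise at sigma=1 -> pure noise
for i in range(len(sigmas) - 1):
    velocity = noise - data                      # derivative of add_noise w.r.t. sigma
    x = x + velocity * (sigmas[i + 1] - sigmas[i])
assert torch.allclose(x, data, atol=1e-5)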
diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py
File without changes
diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py
@@ -0,0 +1,54 @@
+import torch
+import torchsde
+
+
+class BatchedBrownianTree:
+    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+    def __init__(self, x, t0, t1, seed=None, **kwargs):
+        t0, t1, self.sign = self.sort(t0, t1)
+        w0 = kwargs.get("w0", torch.zeros_like(x))
+        if seed is None:
+            seed = torch.randint(0, 2**63 - 1, []).item()
+        self.batched = True
+        try:
+            assert len(seed) == x.shape[0]
+            w0 = w0[0]
+        except TypeError:
+            seed = [seed]
+            self.batched = False
+        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+    @staticmethod
+    def sort(a, b):
+        return (a, b, 1) if a < b else (b, a, -1)
+
+    def __call__(self, t0, t1):
+        t0, t1, sign = self.sort(t0, t1)
+        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+        return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+    """A noise sampler backed by a torchsde.BrownianTree.
+
+    Args:
+        x (Tensor): The tensor whose shape, device and dtype to use to generate
+            random samples.
+        sigma_min (float): The low end of the valid interval.
+        sigma_max (float): The high end of the valid interval.
+        seed (int or List[int]): The random seed. If a list of seeds is
+            supplied instead of a single integer, then the noise sampler will
+            use one BrownianTree per batch item, each with its own seed.
+        transform (callable): A function that maps sigma to the sampler's
+            internal timestep.
+    """
+
+    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+        self.transform = transform
+        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+        self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+    def __call__(self, sigma, sigma_next):
+        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
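A minimal usage sketch (not part of the packaged files), assuming diffsynth-engine and torchsde are installed; the division by sqrt(|t1 - t0|) in __call__ scales each increment to roughly unit variance:

import torch
from diffsynth_engine.algorithm.sampler.stable_diffusion.brownian_tree import BrownianTreeNoiseSampler

x = torch.zeros(1, 4, 64, 64)                    # defines shape/device/dtype of the noise
sampler = BrownianTreeNoiseSampler(x, sigma_min=0.0292, sigma_max=14.6146, seed=42)
eps = sampler(torch.tensor(10.0), torch.tensor(8.0))
print(eps.shape, float(eps.std()))               # same shape as x, std close to 1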
diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py
@@ -0,0 +1,32 @@
+import torch
+from .epsilon import EpsilonSampler
+
+
+class DDPMSampler(EpsilonSampler):
+    def _step_function(self, x, sigma, sigma_prev, noise):
+        alpha_cumprod = 1 / ((sigma * sigma) + 1)
+        alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
+        alpha = alpha_cumprod / alpha_cumprod_prev
+
+        mu = (1.0 / alpha) ** 0.5 * (x - (1 - alpha) * noise / (1 - alpha_cumprod) ** 0.5)
+        if sigma_prev > 0:
+            # Caution: this randn tensor needs to be controlled by `torch.manual_seed`.
+            mu += ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)) ** 0.5 * torch.randn_like(x)
+        return mu
+
+    def step(self, latents, model_outputs, i):
+        sigma = self.sigmas[i]
+        sigma_next = self.sigmas[i + 1]
+        latents = self._scaling(sigma, latents)
+
+        denoised = self._to_denoised(sigma, model_outputs, latents)
+        latents = self._step_function(
+            latents / (1.0 + sigma**2.0) ** 0.5, sigma, sigma_next, (latents - denoised) / sigma
+        )
+
+        latents *= (1.0 + sigma_next**2.0) ** 0.5
+
+        return self._unscaling(self.sigmas[i + 1], latents)
+
+    def step2(self, latents, model_outputs, i):
+        return self._step_function(latents, self.sigmas[i], self.sigmas[i + 1], model_outputs)
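The sampler above works in the k-diffusion sigma parameterization and recovers alpha_cumprod as 1 / (sigma**2 + 1), which inverts the sigma definition in linear.py, sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod). A standalone check (not part of the packaged files):

import torch

betas = torch.linspace(0.00085**0.5, 0.0120**0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

recovered = 1 / (sigmas * sigmas + 1)
assert torch.allclose(recovered, alphas_cumprod, atol=1e-5)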
diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py
@@ -0,0 +1,125 @@
+import torch
+
+from diffsynth_engine.algorithm.sampler.stable_diffusion.epsilon import EpsilonSampler
+
+
+class DEISSampler(EpsilonSampler):
+    """
+    According to the implementation of the webui forge, deis_mode only supports tab and rhoab.
+    """
+
+    def initialize(self, init_latents, timesteps, sigmas, mask):
+        super().initialize(init_latents, timesteps, sigmas, mask)
+        self.max_order = 3
+        self.sigmas = sigmas
+        self.timesteps = timesteps
+        self.lower_order_nums = 0
+        self.coeff_list = get_deis_coeff_list(self.sigmas, self.max_order)
+        self.coeff_buffer = []
+
+    def step(self, latents, model_outputs, i):
+        s, s_next = self.sigmas[i], self.sigmas[i + 1]
+        denoised = latents - model_outputs * s
+
+        d = (latents - denoised) / s
+        order = min(self.max_order, i + 1)
+        if self.sigmas[i + 1] <= 0:
+            order = 1
+        if order == 1:
+            x_next = latents + (s_next - s) * d
+        elif order == 2:
+            coeff, coeff_prev1 = self.coeff_list[i]
+            x_next = latents + coeff * d + coeff_prev1 * self.coeff_buffer[-1]
+        elif order == 3:
+            coeff, coeff_prev1, coeff_prev2 = self.coeff_list[i]
+            x_next = latents + coeff * d + coeff_prev1 * self.coeff_buffer[-1] + coeff_prev2 * self.coeff_buffer[-2]
+        elif order == 4:
+            coeff, coeff_prev1, coeff_prev2, coeff_prev3 = self.coeff_list[i]
+            x_next = (
+                latents
+                + coeff * d
+                + coeff_prev1 * self.coeff_buffer[-1]
+                + coeff_prev2 * self.coeff_buffer[-2]
+                + coeff_prev3 * self.coeff_buffer[-3]
+            )
+
+        if len(self.coeff_buffer) == self.max_order - 1:
+            for k in range(self.max_order - 2):
+                self.coeff_buffer[k] = self.coeff_buffer[k + 1]
+            self.coeff_buffer[-1] = d
+        else:
+            self.coeff_buffer.append(d)
+        return x_next
+
+
+# Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
+# under Apache 2 license
+# A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
+#############################
+### Utils for DEIS solver ###
+#############################
+# ----------------------------------------------------------------------------
+# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
+
+
+def vp_sigma_inv(beta_d, beta_min, sigma):
+    return ((beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min) / beta_d
+
+
+def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
+    vp_beta_d = (
+        2
+        * (torch.log(torch.tensor(sigma_min) ** 2 + 1) / epsilon_s - torch.log(torch.tensor(sigma_max) ** 2 + 1))
+        / (epsilon_s - 1)
+    )
+    vp_beta_min = torch.log(torch.tensor(sigma_max) ** 2 + 1) - 0.5 * vp_beta_d
+    t_steps = vp_sigma_inv(vp_beta_d, vp_beta_min, edm_steps)
+    return t_steps, vp_beta_min, vp_beta_d + vp_beta_min
+
+
+def cal_poly(prev_t, j, taus):
+    poly = 1
+    for k in range(prev_t.shape[0]):
+        if k == j:
+            continue
+        poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
+    return poly
+
+
+def t2alpha_fn(beta_0, beta_1, t):
+    return torch.exp(-0.5 * t**2 * (beta_1 - beta_0) - t * beta_0)
+
+
+def cal_intergrand(beta_0, beta_1, taus):
+    with torch.inference_mode(mode=False):
+        taus = taus.clone()
+        beta_0 = beta_0.clone()
+        beta_1 = beta_1.clone()
+        with torch.enable_grad():
+            taus.requires_grad_(True)
+            alpha = t2alpha_fn(beta_0, beta_1, taus)
+            log_alpha = alpha.log()
+            log_alpha.sum().backward()
+            d_log_alpha_dtau = taus.grad
+    integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
+    return integrand
+
+
+def get_deis_coeff_list(t_steps, max_order, N=10000):
+    t_steps, beta_0, beta_1 = edm2t(t_steps)
+    C = []
+    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
+        order = min(i + 1, max_order)
+        if order == 1:
+            C.append([])
+        else:
+            taus = torch.linspace(t_cur, t_next, N).to(t_next.device)
+            dtau = (t_next - t_cur) / N
+            prev_t = t_steps[[i - k for k in range(order)]]
+            coeff_temp = []
+            integrand = cal_intergrand(beta_0, beta_1, taus)
+            for j in range(order):
+                poly = cal_poly(prev_t, j, taus)
+                coeff_temp.append(torch.sum(integrand * poly) * dtau)
+            C.append(coeff_temp)
+    return C
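cal_poly above builds the j-th Lagrange basis polynomial over the previous timesteps prev_t; get_deis_coeff_list then integrates each basis against the VP integrand to obtain the multistep coefficients. A standalone check (not part of the packaged files) that the basis is 1 at prev_t[j] and 0 at the other nodes:

import torch

def cal_poly(prev_t, j, taus):                   # same formula as in deis.py
    poly = 1
    for k in range(prev_t.shape[0]):
        if k == j:
            continue
        poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
    return poly

prev_t = torch.tensor([0.9, 0.6, 0.3])
values = cal_poly(prev_t, j=1, taus=prev_t)      # evaluate the basis at the nodes
assert torch.allclose(values, torch.tensor([0.0, 1.0, 0.0]), atol=1e-6)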
diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py
@@ -0,0 +1,29 @@
+from diffsynth_engine.algorithm.sampler.stable_diffusion.epsilon import EpsilonSampler
+
+
+class DPMSolverPlusPlus2MSampler(EpsilonSampler):
+    """
+    DPM Solver++ 2M sampler
+    """
+
+    def initialize(self, init_latents, timesteps, sigmas, mask):
+        super().initialize(init_latents, timesteps, sigmas, mask)
+        self.old_denoised = None
+
+    def step(self, latents, model_outputs, i):
+        s_prev, s, s_next = self.sigmas[i - 1], self.sigmas[i], self.sigmas[i + 1]
+        t_prev, t, t_next = self._sigma_to_t(s_prev), self._sigma_to_t(s), self._sigma_to_t(s_next)
+        h = t_next - t
+        x = self._scaling(latents, s)
+        denoised = self._to_denoised(s, model_outputs, x)
+        if self.old_denoised is None or s_next == 0:
+            self.old_denoised = denoised
+            return (s_next / s) * x - (-h).expm1() * denoised
+        h_last = t - t_prev
+        r = h_last / h
+        denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * self.old_denoised
+        x = (s_next / s) * x - (-h).expm1() * denoised_d
+        return self._unscaling(x, s_next)
+
+    def _sigma_to_t(self, sigma):
+        return sigma.log().neg()