hcpdiff 2.2__py3-none-any.whl → 2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/ckpt.py +21 -17
- hcpdiff/ckpt_manager/format/diffusers.py +4 -4
- hcpdiff/ckpt_manager/format/sd_single.py +3 -3
- hcpdiff/ckpt_manager/loader.py +11 -4
- hcpdiff/diffusion/noise/__init__.py +0 -1
- hcpdiff/diffusion/sampler/VP.py +27 -0
- hcpdiff/diffusion/sampler/__init__.py +2 -3
- hcpdiff/diffusion/sampler/base.py +106 -44
- hcpdiff/diffusion/sampler/diffusers.py +11 -17
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
- hcpdiff/easy/cfg/sd15_train.py +35 -24
- hcpdiff/easy/cfg/sdxl_train.py +34 -25
- hcpdiff/evaluate/__init__.py +3 -1
- hcpdiff/evaluate/evaluator.py +76 -0
- hcpdiff/evaluate/metrics/__init__.py +1 -0
- hcpdiff/evaluate/metrics/clip_score.py +23 -0
- hcpdiff/evaluate/previewer.py +29 -12
- hcpdiff/loss/base.py +9 -26
- hcpdiff/loss/weighting.py +36 -18
- hcpdiff/models/lora_base_patch.py +26 -0
- hcpdiff/models/text_emb_ex.py +4 -0
- hcpdiff/models/wrapper/sd.py +17 -19
- hcpdiff/trainer_ac.py +7 -12
- hcpdiff/trainer_ac_single.py +1 -6
- hcpdiff/trainer_deepspeed.py +47 -0
- hcpdiff/utils/__init__.py +2 -1
- hcpdiff/utils/torch_utils.py +25 -0
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +27 -7
- hcpdiff/workflow/io.py +20 -3
- hcpdiff/workflow/text.py +6 -1
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/METADATA +8 -4
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/RECORD +43 -39
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/WHEEL +1 -1
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/entry_points.txt +1 -0
- hcpdiff/diffusion/noise/zero_terminal.py +0 -39
- hcpdiff/diffusion/sampler/ddpm.py +0 -20
- hcpdiff/diffusion/sampler/edm.py +0 -22
- hcpdiff/train_deepspeed.py +0 -69
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/top_level.txt +0 -0
hcpdiff/ckpt_manager/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
1
|
from .format import EmbFormat, DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSDXLFormat, \
|
2
2
|
OfficialSD15Format, LoraWebuiFormat
|
3
|
-
from .ckpt import EmbSaver
|
3
|
+
from .ckpt import EmbSaver
|
4
4
|
from .loader import HCPLoraLoader
|
hcpdiff/ckpt_manager/ckpt.py
CHANGED
@@ -1,24 +1,28 @@
|
|
1
|
-
from rainbowneko.ckpt_manager import NekoSaver, CkptFormat, LocalCkptSource, PKLFormat
|
2
|
-
from torch import
|
1
|
+
from rainbowneko.ckpt_manager import NekoSaver, CkptFormat, LocalCkptSource, PKLFormat, LAYERS_ALL, LAYERS_TRAINABLE
|
2
|
+
from torch import Tensor
|
3
3
|
from typing import Dict, Any
|
4
4
|
|
5
5
|
class EmbSaver(NekoSaver):
|
6
|
-
def __init__(self, format: CkptFormat, source: LocalCkptSource,
|
7
|
-
|
8
|
-
|
6
|
+
def __init__(self, format: CkptFormat=None, source: LocalCkptSource=None, layers='all', key_map=None, prefix=None):
|
7
|
+
if format is None:
|
8
|
+
format = PKLFormat()
|
9
|
+
if source is None:
|
10
|
+
source = LocalCkptSource()
|
11
|
+
key_map = key_map or ('name -> name', 'embs -> embs', 'name_template -> name_template')
|
12
|
+
super().__init__(format, source, layers=layers, key_map=key_map)
|
9
13
|
self.prefix = prefix
|
10
14
|
|
11
|
-
def
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
self.
|
15
|
+
def _save_to(self, name, embs: Dict[str, Tensor], name_template=None):
|
16
|
+
for pt_name, pt in embs.items():
|
17
|
+
if self.layers == LAYERS_ALL:
|
18
|
+
pass
|
19
|
+
elif self.layers == LAYERS_TRAINABLE:
|
20
|
+
if not pt.requires_grad:
|
21
|
+
continue
|
22
|
+
elif pt_name not in self.layers:
|
23
|
+
continue
|
24
|
+
|
25
|
+
self.save((pt_name, pt), pt_name, prefix=self.prefix)
|
16
26
|
if name_template is not None:
|
17
27
|
pt_name = name_template.format(pt_name)
|
18
|
-
self.save(
|
19
|
-
|
20
|
-
def easy_emb_saver():
|
21
|
-
return EmbSaver(
|
22
|
-
format=PKLFormat(),
|
23
|
-
source=LocalCkptSource(),
|
24
|
-
)
|
28
|
+
self.save((pt_name, pt), pt_name, prefix=self.prefix)
|
@@ -3,7 +3,7 @@ from diffusers import ModelMixin, AutoencoderKL, UNet2DConditionModel, PixArtTra
|
|
3
3
|
from rainbowneko.ckpt_manager.format import CkptFormat
|
4
4
|
from transformers import CLIPTextModel, AutoTokenizer, T5EncoderModel
|
5
5
|
|
6
|
-
from hcpdiff.diffusion.sampler import
|
6
|
+
from hcpdiff.diffusion.sampler import VPSampler, DDPMDiscreteSigmaScheduler
|
7
7
|
from hcpdiff.models.compose import SDXLTokenizer, SDXLTextEncoder
|
8
8
|
|
9
9
|
class DiffusersModelFormat(CkptFormat):
|
@@ -23,7 +23,7 @@ class DiffusersSD15Format(CkptFormat):
|
|
23
23
|
pretrained_model, subfolder="unet", revision=revision, torch_dtype=dtype
|
24
24
|
)
|
25
25
|
vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
|
26
|
-
noise_sampler = noise_sampler or
|
26
|
+
noise_sampler = noise_sampler or VPSampler(DDPMDiscreteSigmaScheduler())
|
27
27
|
|
28
28
|
TE = TE or CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
|
29
29
|
tokenizer = tokenizer or AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
|
@@ -37,7 +37,7 @@ class DiffusersSDXLFormat(CkptFormat):
|
|
37
37
|
pretrained_model, subfolder="unet", revision=revision, torch_dtype=dtype
|
38
38
|
)
|
39
39
|
vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
|
40
|
-
noise_sampler = noise_sampler or
|
40
|
+
noise_sampler = noise_sampler or VPSampler(DDPMDiscreteSigmaScheduler())
|
41
41
|
|
42
42
|
TE = TE or SDXLTextEncoder.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
|
43
43
|
tokenizer = tokenizer or SDXLTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
|
@@ -51,7 +51,7 @@ class DiffusersPixArtFormat(CkptFormat):
|
|
51
51
|
pretrained_model, subfolder="transformer", revision=revision, torch_dtype=dtype
|
52
52
|
)
|
53
53
|
vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
|
54
|
-
noise_sampler = noise_sampler or
|
54
|
+
noise_sampler = noise_sampler or VPSampler(DDPMDiscreteSigmaScheduler())
|
55
55
|
|
56
56
|
TE = TE or T5EncoderModel.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
|
57
57
|
tokenizer = tokenizer or AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
|
@@ -2,7 +2,7 @@ import torch
|
|
2
2
|
from diffusers import AutoencoderKL, StableDiffusionPipeline, StableDiffusionXLPipeline
|
3
3
|
from rainbowneko.ckpt_manager.format import CkptFormat
|
4
4
|
|
5
|
-
from hcpdiff.diffusion.sampler import
|
5
|
+
from hcpdiff.diffusion.sampler import VPSampler, DDPMDiscreteSigmaScheduler
|
6
6
|
from hcpdiff.models.compose import SDXLTextEncoder, SDXLTokenizer
|
7
7
|
|
8
8
|
class OfficialSD15Format(CkptFormat):
|
@@ -14,7 +14,7 @@ class OfficialSD15Format(CkptFormat):
|
|
14
14
|
pipe = StableDiffusionPipeline.from_single_file(
|
15
15
|
pretrained_model, revision=revision, torch_dtype=dtype, **pipe_args
|
16
16
|
)
|
17
|
-
noise_sampler = noise_sampler or
|
17
|
+
noise_sampler = noise_sampler or VPSampler(DDPMDiscreteSigmaScheduler())
|
18
18
|
return dict(denoiser=pipe.unet, TE=pipe.text_encoder, vae=pipe.vae, noise_sampler=noise_sampler, tokenizer=pipe.tokenizer)
|
19
19
|
|
20
20
|
class OfficialSDXLFormat(CkptFormat):
|
@@ -34,7 +34,7 @@ class OfficialSDXLFormat(CkptFormat):
|
|
34
34
|
pretrained_model, revision=revision, torch_dtype=dtype, **pipe_args
|
35
35
|
)
|
36
36
|
|
37
|
-
noise_sampler = noise_sampler or
|
37
|
+
noise_sampler = noise_sampler or VPSampler(DDPMDiscreteSigmaScheduler())
|
38
38
|
TE = SDXLTextEncoder([('clip_L', pipe.text_encoder), ('clip_bigG', pipe.text_encoder_2)])
|
39
39
|
tokenizer = SDXLTokenizer([('clip_L', pipe.tokenizer), ('clip_bigG', pipe.tokenizer_2)])
|
40
40
|
|
hcpdiff/ckpt_manager/loader.py
CHANGED
@@ -14,17 +14,24 @@ def get_lora_rank_and_cls(lora_state):
|
|
14
14
|
|
15
15
|
class HCPLoraLoader(NekoPluginLoader):
|
16
16
|
def __init__(self, format: CkptFormat=None, source: LocalCkptSource=None, path: str = None, layers='all', target_plugin=None,
|
17
|
-
state_prefix=None, base_model_alpha=0.0, load_ema=False, module_to_load='', **plugin_kwargs):
|
17
|
+
state_prefix=None, base_model_alpha=0.0, load_ema=False, module_to_load='', key_map=None, **plugin_kwargs):
|
18
|
+
key_map = key_map or ('name -> name', 'model -> model')
|
18
19
|
super().__init__(format, source, path=path, layers=layers, target_plugin=target_plugin, state_prefix=state_prefix,
|
19
|
-
base_model_alpha=base_model_alpha, load_ema=load_ema, **plugin_kwargs)
|
20
|
+
base_model_alpha=base_model_alpha, load_ema=load_ema, key_map=key_map, **plugin_kwargs)
|
20
21
|
self.module_to_load = module_to_load
|
21
22
|
|
22
|
-
def
|
23
|
+
def _load_to(self, name, model):
|
23
24
|
# get model to load plugin and its named_modules
|
24
25
|
model = model if self.module_to_load == '' else eval(f"model.{self.module_to_load}")
|
25
26
|
|
26
27
|
named_modules = {k:v for k, v in model.named_modules()}
|
27
|
-
|
28
|
+
state_dict = self.load(self.path, map_location='cpu')
|
29
|
+
if 'base' in state_dict or 'base_ema' in state_dict:
|
30
|
+
plugin_state = state_dict['base_ema' if self.load_ema else 'base']
|
31
|
+
elif 'plugin' in state_dict or 'plugin_ema' in state_dict:
|
32
|
+
plugin_state = state_dict['plugin_ema' if self.load_ema else 'plugin']
|
33
|
+
else:
|
34
|
+
plugin_state = state_dict
|
28
35
|
|
29
36
|
# filter layers to load
|
30
37
|
if self.layers != 'all':
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from .base import Sampler
|
2
|
+
|
3
|
+
class VPSampler(Sampler):
|
4
|
+
# closed-form: \alpha(t)^2 + \sigma(t)^2 = 1
|
5
|
+
def velocity_to_eps(self, v_pred, x_t, t):
|
6
|
+
alpha = self.sigma_scheduler.alpha(t)
|
7
|
+
sigma = self.sigma_scheduler.sigma(t)
|
8
|
+
return alpha*v_pred+sigma*x_t
|
9
|
+
|
10
|
+
def eps_to_velocity(self, eps, x_t, t, x_0=None):
|
11
|
+
alpha = self.sigma_scheduler.alpha(t)
|
12
|
+
sigma = self.sigma_scheduler.sigma(t)
|
13
|
+
if x_0 is None:
|
14
|
+
x_0 = self.eps_to_x0(eps, x_t, t)
|
15
|
+
return alpha*eps-sigma*x_0
|
16
|
+
|
17
|
+
def velocity_to_x0(self, v_pred, x_t, t):
|
18
|
+
alpha = self.sigma_scheduler.alpha(t)
|
19
|
+
sigma = self.sigma_scheduler.sigma(t)
|
20
|
+
return alpha*x_t-sigma*v_pred
|
21
|
+
|
22
|
+
def x0_to_velocity(self, x_0, x_t, t, eps=None):
|
23
|
+
alpha = self.sigma_scheduler.alpha(t)
|
24
|
+
sigma = self.sigma_scheduler.sigma(t)
|
25
|
+
if eps is None:
|
26
|
+
eps = self.x0_to_eps(x_0, x_t, t)
|
27
|
+
return alpha*eps-sigma*x_0
|
@@ -1,72 +1,134 @@
|
|
1
1
|
from typing import Tuple
|
2
|
+
|
2
3
|
import torch
|
3
|
-
from .sigma_scheduler import SigmaScheduler
|
4
|
-
from diffusers import DDPMScheduler
|
5
4
|
|
6
|
-
|
7
|
-
def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator = None):
|
8
|
-
self.sigma_scheduler = sigma_scheduler
|
9
|
-
self.generator = generator
|
5
|
+
from .sigma_scheduler import SigmaScheduler
|
10
6
|
|
11
|
-
|
12
|
-
|
7
|
+
try:
|
8
|
+
from diffusers.utils import randn_tensor
|
9
|
+
except:
|
10
|
+
# new version of diffusers
|
11
|
+
from diffusers.utils.torch_utils import randn_tensor
|
13
12
|
|
14
|
-
|
15
|
-
|
13
|
+
class BaseSampler:
|
14
|
+
def __init__(self, sigma_scheduler: SigmaScheduler, pred_type='eps', target_type='eps', generator: torch.Generator = None):
|
15
|
+
'''
|
16
|
+
Some losses can only be calculated in a specific space. Such as SSIM in x0 space.
|
17
|
+
The model pred need convert to target space.
|
16
18
|
|
17
|
-
|
18
|
-
|
19
|
+
:param pred_type: ['x0', 'eps', 'velocity', ..., None] The output space of the model
|
20
|
+
:param target_type: ['x0', 'eps', 'velocity', ..., None] The space to calculate the loss
|
21
|
+
'''
|
19
22
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
+
self.sigma_scheduler = sigma_scheduler
|
24
|
+
self.generator = generator
|
25
|
+
self.pred_type = pred_type
|
26
|
+
self.target_type = target_type
|
23
27
|
|
24
28
|
def get_timesteps(self, N_steps, device='cuda'):
|
25
|
-
|
29
|
+
times = torch.linspace(0., 1., N_steps, device=device)
|
30
|
+
return self.sigma_scheduler.scale_t(times)
|
26
31
|
|
27
32
|
def make_nosie(self, shape, device='cuda', dtype=torch.float32):
|
28
|
-
return
|
33
|
+
return randn_tensor(shape, generator=self.generator, device=device, dtype=dtype)
|
29
34
|
|
30
35
|
def init_noise(self, shape, device='cuda', dtype=torch.float32):
|
31
|
-
sigma = self.sigma_scheduler.
|
36
|
+
sigma = self.sigma_scheduler.sigma_end
|
32
37
|
return self.make_nosie(shape, device, dtype)*sigma
|
33
38
|
|
34
|
-
def add_noise(self, x,
|
39
|
+
def add_noise(self, x, t) -> Tuple[torch.Tensor, torch.Tensor]:
|
35
40
|
noise = self.make_nosie(x.shape, device=x.device)
|
36
|
-
|
41
|
+
alpha = self.sigma_scheduler.alpha(t).view(-1, 1, 1, 1).to(x.device)
|
42
|
+
sigma = self.sigma_scheduler.sigma(t).view(-1, 1, 1, 1).to(x.device)
|
43
|
+
noisy_x = alpha*x+sigma*noise
|
37
44
|
return noisy_x.to(dtype=x.dtype), noise.to(dtype=x.dtype)
|
38
45
|
|
39
46
|
def add_noise_rand_t(self, x):
|
40
47
|
bs = x.shape[0]
|
41
48
|
# timesteps: [0, 1]
|
42
|
-
|
43
|
-
sigma = sigma.view(-1, 1, 1, 1).to(x.device)
|
49
|
+
timesteps = self.sigma_scheduler.sample(shape=(bs,))
|
44
50
|
timesteps = timesteps.to(x.device)
|
45
|
-
noisy_x, noise = self.add_noise(x,
|
51
|
+
noisy_x, noise = self.add_noise(x, timesteps)
|
46
52
|
|
47
53
|
# Sample a random timestep for each image
|
48
|
-
|
49
|
-
return noisy_x, noise, sigma, timesteps
|
54
|
+
return noisy_x, noise, timesteps
|
50
55
|
|
51
56
|
def denoise(self, x, sigma, eps=None, generator=None):
|
52
57
|
raise NotImplementedError
|
53
58
|
|
54
|
-
def
|
55
|
-
|
56
|
-
|
57
|
-
def
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
def
|
70
|
-
|
71
|
-
|
72
|
-
|
59
|
+
def get_target(self, x0, x_t, t, eps=None, target_type=None):
|
60
|
+
raise x0
|
61
|
+
|
62
|
+
def pred_for_target(self, pred, x_t, t, eps=None, target_type=None):
|
63
|
+
return self.sigma_scheduler.c_skip(t)*x_t+self.sigma_scheduler.c_out(t)*pred
|
64
|
+
|
65
|
+
class Sampler(BaseSampler):
|
66
|
+
'''
|
67
|
+
Some losses can only be calculated in a specific space. Such as SSIM in x0 space.
|
68
|
+
The model pred need convert to target space.
|
69
|
+
|
70
|
+
:param pred_type: ['x0', 'eps', 'velocity', ..., None] The output space of the model
|
71
|
+
:param target_type: ['x0', 'eps', 'velocity', ..., None] The space to calculate the loss
|
72
|
+
'''
|
73
|
+
|
74
|
+
def get_target(self, x_0, x_t, t, eps=None, target_type=None):
|
75
|
+
'''
|
76
|
+
target_type can be specified by the loss. If not specified use self.target_type as default.
|
77
|
+
'''
|
78
|
+
target_type = target_type or self.target_type
|
79
|
+
if target_type == 'x0':
|
80
|
+
raise x_0
|
81
|
+
elif target_type == 'eps':
|
82
|
+
return eps if eps is not None else self.x0_to_eps(eps, x_t, t)
|
83
|
+
elif target_type == 'velocity':
|
84
|
+
return self.x0_to_velocity(x_0, x_t, t, eps)
|
85
|
+
else:
|
86
|
+
return (x_0-self.sigma_scheduler.c_skip(t)*x_t)/self.sigma_scheduler.c_out(t)
|
87
|
+
|
88
|
+
def pred_for_target(self, pred, x_t, t, eps=None, target_type=None):
|
89
|
+
'''
|
90
|
+
target_type can be specified by the loss. If not specified use self.target_type as default.
|
91
|
+
'''
|
92
|
+
target_type = target_type or self.target_type
|
93
|
+
if self.pred_type == target_type:
|
94
|
+
return pred
|
95
|
+
else:
|
96
|
+
cvt_func = getattr(self, f'{self.pred_type}_to_{target_type}', None)
|
97
|
+
if cvt_func is None:
|
98
|
+
if target_type == 'x0':
|
99
|
+
return self.sigma_scheduler.c_skip(t)*x_t+self.sigma_scheduler.c_out(t)*pred
|
100
|
+
else:
|
101
|
+
raise ValueError(f'pred_type "{self.pred_type}" can not be convert for target_type "{target_type}"')
|
102
|
+
else:
|
103
|
+
return cvt_func(pred, x_t, t)
|
104
|
+
|
105
|
+
# convert targets
|
106
|
+
def x0_to_eps(self, x_0, x_t, t):
|
107
|
+
return (x_t-self.sigma_scheduler.alpha(t)*x_0)/self.sigma_scheduler.sigma(t)
|
108
|
+
|
109
|
+
def x0_to_velocity(self, x_0, x_t, t, eps=None):
|
110
|
+
d_alpha, d_sigma = self.sigma_scheduler.velocity(t)
|
111
|
+
if eps is None:
|
112
|
+
eps = self.x0_to_eps(x_0, x_t, t)
|
113
|
+
return d_alpha*x_0+d_sigma*eps
|
114
|
+
|
115
|
+
def eps_to_x0(self, eps, x_t, t):
|
116
|
+
return (x_t-self.sigma_scheduler.sigma(t)*eps)/self.sigma_scheduler.alpha(t)
|
117
|
+
|
118
|
+
def eps_to_velocity(self, eps, x_t, t, x_0=None):
|
119
|
+
d_alpha, d_sigma = self.sigma_scheduler.velocity(t)
|
120
|
+
if x_0 is None:
|
121
|
+
x_0 = self.eps_to_x0(eps, x_t, t)
|
122
|
+
return d_alpha*x_0+d_sigma*eps
|
123
|
+
|
124
|
+
def velocity_to_eps(self, v_pred, x_t, t):
|
125
|
+
alpha = self.sigma_scheduler.alpha(t)
|
126
|
+
sigma = self.sigma_scheduler.sigma(t)
|
127
|
+
d_alpha, d_sigma = self.sigma_scheduler.velocity(t)
|
128
|
+
return (alpha*v_pred-d_alpha*x_t)/(d_sigma*alpha-d_alpha*sigma)
|
129
|
+
|
130
|
+
def velocity_to_x0(self, v_pred, x_t, t):
|
131
|
+
alpha = self.sigma_scheduler.alpha(t)
|
132
|
+
sigma = self.sigma_scheduler.sigma(t)
|
133
|
+
d_alpha, d_sigma = self.sigma_scheduler.velocity(t)
|
134
|
+
return (sigma*v_pred-d_sigma*x_t)/(d_alpha*sigma-d_sigma*alpha)
|
@@ -18,31 +18,24 @@ class DiffusersSampler(BaseSampler):
|
|
18
18
|
self.scheduler = scheduler
|
19
19
|
self.eta = eta
|
20
20
|
|
21
|
-
|
22
|
-
|
21
|
+
self.sigma_scheduler.c_in = self.c_in
|
22
|
+
|
23
|
+
def c_in(self, t):
|
24
|
+
one = torch.ones_like(t)
|
23
25
|
if hasattr(self.scheduler, '_step_index'):
|
24
26
|
self.scheduler._step_index = None
|
25
|
-
return self.scheduler.scale_model_input(one,
|
26
|
-
|
27
|
-
def c_out(self, sigma):
|
28
|
-
return -sigma
|
29
|
-
|
30
|
-
def c_skip(self, sigma):
|
31
|
-
if self.c_in(sigma) == 1.: # DDPM model
|
32
|
-
return (sigma**2+1).sqrt() # 1/sqrt(alpha_)
|
33
|
-
else: # EDM model
|
34
|
-
return 1.
|
27
|
+
return self.scheduler.scale_model_input(one, t)
|
35
28
|
|
36
29
|
def get_timesteps(self, N_steps, device='cuda'):
|
37
30
|
self.scheduler.set_timesteps(N_steps, device=device)
|
38
|
-
return self.scheduler.timesteps
|
31
|
+
return self.scheduler.timesteps / self.sigma_scheduler.num_timesteps # Normalize timesteps to [0, 1]
|
39
32
|
|
40
33
|
def init_noise(self, shape, device='cuda', dtype=torch.float32):
|
41
34
|
return randn_tensor(shape, generator=self.generator, device=device, dtype=dtype)*self.scheduler.init_noise_sigma
|
42
35
|
|
43
|
-
def add_noise(self, x,
|
36
|
+
def add_noise(self, x, t):
|
44
37
|
noise = randn_tensor(x.shape, generator=self.generator, device=x.device, dtype=x.dtype)
|
45
|
-
return self.scheduler.add_noise(x, noise,
|
38
|
+
return self.scheduler.add_noise(x, noise, t), noise
|
46
39
|
|
47
40
|
def prepare_extra_step_kwargs(self, scheduler, generator, eta):
|
48
41
|
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
@@ -61,6 +54,7 @@ class DiffusersSampler(BaseSampler):
|
|
61
54
|
extra_step_kwargs["generator"] = generator
|
62
55
|
return extra_step_kwargs
|
63
56
|
|
64
|
-
def denoise(self, x_t,
|
57
|
+
def denoise(self, x_t, t, eps=None, generator=None):
|
58
|
+
t_in = self.sigma_scheduler.c_noise(t)
|
65
59
|
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator, self.eta)
|
66
|
-
return self.scheduler.step(eps,
|
60
|
+
return self.scheduler.step(eps, t_in, x_t, **extra_step_kwargs).prev_sample
|
@@ -1,3 +1,5 @@
|
|
1
1
|
from .base import SigmaScheduler
|
2
2
|
from .ddpm import DDPMDiscreteSigmaScheduler, DDPMContinuousSigmaScheduler, TimeSigmaScheduler
|
3
|
-
from .edm import EDMSigmaScheduler,
|
3
|
+
from .edm import EDMSigmaScheduler, EDMTimeRescaleScheduler
|
4
|
+
from .flow import FlowSigmaScheduler
|
5
|
+
from .zero_terminal import ZeroTerminalScheduler
|
@@ -3,12 +3,87 @@ from typing import Union, Tuple
|
|
3
3
|
import torch
|
4
4
|
|
5
5
|
class SigmaScheduler:
|
6
|
+
def scale_t(self, t):
|
7
|
+
return t
|
6
8
|
|
7
|
-
def
|
9
|
+
def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
|
8
10
|
'''
|
11
|
+
x(t) = \alpha(t)*x(0) + \sigma(t)*eps
|
9
12
|
:param t: 0-1, rate of time step
|
10
13
|
'''
|
11
14
|
raise NotImplementedError
|
12
15
|
|
13
|
-
def
|
16
|
+
def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
|
17
|
+
'''
|
18
|
+
x(t) = \alpha(t)*x(0) + \sigma(t)*eps
|
19
|
+
:param t: 0-1, rate of time step
|
20
|
+
'''
|
21
|
+
raise NotImplementedError
|
22
|
+
|
23
|
+
def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=True) -> Tuple[torch.Tensor, torch.Tensor]:
|
24
|
+
'''
|
25
|
+
v(t) = dx(t)/dt = d\alpha(t)/dt * x(0) + d\sigma(t)/dt *eps
|
26
|
+
:param t: 0-1, rate of time step
|
27
|
+
:return: d\alpha(t)/dt, d\sigma(t)/dt
|
28
|
+
'''
|
29
|
+
d_alpha = (self.alpha(t+dt)-self.alpha(t))/dt
|
30
|
+
d_sigma = (self.sigma(t+dt)-self.sigma(t))/dt
|
31
|
+
if normlize:
|
32
|
+
norm = torch.sqrt(d_alpha**2+d_sigma**2)
|
33
|
+
return d_alpha/norm, d_sigma/norm
|
34
|
+
else:
|
35
|
+
return d_alpha, d_sigma
|
36
|
+
|
37
|
+
@property
|
38
|
+
def sigma_start(self):
|
39
|
+
return self.sigma(0)
|
40
|
+
|
41
|
+
@property
|
42
|
+
def sigma_end(self):
|
43
|
+
return self.sigma(1)
|
44
|
+
|
45
|
+
@property
|
46
|
+
def alpha_start(self):
|
47
|
+
return self.alpha(0)
|
48
|
+
|
49
|
+
@property
|
50
|
+
def alpha_end(self):
|
51
|
+
return self.alpha(1)
|
52
|
+
|
53
|
+
def alpha_to_sigma(self, alpha):
|
14
54
|
raise NotImplementedError
|
55
|
+
|
56
|
+
def sigma_to_alpha(self, sigma):
|
57
|
+
raise NotImplementedError
|
58
|
+
|
59
|
+
def c_in(self, t: Union[float, torch.Tensor]):
|
60
|
+
if isinstance(t, float):
|
61
|
+
return 1.
|
62
|
+
else:
|
63
|
+
return torch.ones_like(t, dtype=torch.float32)
|
64
|
+
|
65
|
+
def c_skip(self, t: Union[float, torch.Tensor]):
|
66
|
+
'''
|
67
|
+
\hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
|
68
|
+
:param t: 0-1, rate of time step
|
69
|
+
'''
|
70
|
+
return 1./self.alpha(t)
|
71
|
+
|
72
|
+
def c_out(self, t: Union[float, torch.Tensor]):
|
73
|
+
'''
|
74
|
+
\hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
|
75
|
+
:param t: 0-1, rate of time step
|
76
|
+
'''
|
77
|
+
return -self.sigma(t)/self.alpha(t)
|
78
|
+
|
79
|
+
def c_noise(self, t: Union[float, torch.Tensor]):
|
80
|
+
return t
|
81
|
+
|
82
|
+
def sample(self, min_t=0.0, max_t=1.0, shape=(1,)) -> torch.Tensor:
|
83
|
+
if isinstance(min_t, float):
|
84
|
+
min_t = torch.full(shape, min_t)
|
85
|
+
if isinstance(max_t, float):
|
86
|
+
max_t = torch.full(shape, max_t)
|
87
|
+
|
88
|
+
t = torch.lerp(min_t, max_t, torch.rand_like(min_t))
|
89
|
+
return t
|