hcpdiff 2.2.1__py3-none-any.whl → 2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/ckpt.py +21 -17
- hcpdiff/ckpt_manager/format/diffusers.py +4 -4
- hcpdiff/ckpt_manager/format/sd_single.py +3 -3
- hcpdiff/ckpt_manager/loader.py +11 -4
- hcpdiff/diffusion/noise/__init__.py +0 -1
- hcpdiff/diffusion/sampler/VP.py +27 -0
- hcpdiff/diffusion/sampler/__init__.py +2 -3
- hcpdiff/diffusion/sampler/base.py +106 -44
- hcpdiff/diffusion/sampler/diffusers.py +11 -17
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
- hcpdiff/easy/cfg/sd15_train.py +33 -22
- hcpdiff/easy/cfg/sdxl_train.py +32 -23
- hcpdiff/evaluate/__init__.py +3 -1
- hcpdiff/evaluate/evaluator.py +76 -0
- hcpdiff/evaluate/metrics/__init__.py +1 -0
- hcpdiff/evaluate/metrics/clip_score.py +23 -0
- hcpdiff/evaluate/previewer.py +29 -12
- hcpdiff/loss/base.py +9 -26
- hcpdiff/loss/weighting.py +36 -18
- hcpdiff/models/lora_base_patch.py +26 -0
- hcpdiff/models/wrapper/sd.py +17 -19
- hcpdiff/trainer_ac.py +7 -5
- hcpdiff/trainer_ac_single.py +1 -6
- hcpdiff/utils/__init__.py +2 -1
- hcpdiff/utils/torch_utils.py +25 -0
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +27 -7
- hcpdiff/workflow/io.py +20 -3
- hcpdiff/workflow/text.py +6 -1
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/METADATA +2 -2
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/RECORD +41 -37
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/WHEEL +1 -1
- hcpdiff/diffusion/noise/zero_terminal.py +0 -39
- hcpdiff/diffusion/sampler/ddpm.py +0 -20
- hcpdiff/diffusion/sampler/edm.py +0 -22
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/entry_points.txt +0 -0
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/top_level.txt +0 -0
hcpdiff/easy/cfg/sdxl_train.py
CHANGED
@@ -1,15 +1,14 @@
 import torch
-from
-from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
-from rainbowneko.utils import ConstantLR
-
+from hcpdiff.ckpt_manager import LoraWebuiFormat
 from hcpdiff.easy import SDXL_auto_loader
 from hcpdiff.models import SDXLWrapper
 from hcpdiff.models.lora_layers_patch import LoraLayer
-from
+from rainbowneko.ckpt_manager import ckpt_saver, NekoPluginSaver, LAYERS_TRAINABLE, SafeTensorFormat, NekoOptimizerSaver
+from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
+from rainbowneko.utils import ConstantLR
 
 @neko_cfg
-def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5,
+def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, save_optimizer=False, lr: float = 1e-5,
                     dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL'):
     if low_vram:
         from bitsandbytes.optim import AdamW8bit
@@ -17,6 +16,17 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
     else:
         optimizer = torch.optim.AdamW(_partial_=True)
 
+    ckpt_saver_dict = dict(
+        SDXL=ckpt_saver(
+            ckpt_type='safetensors',
+            target_module='denoiser',
+            layers=LAYERS_TRAINABLE,
+        )
+    )
+
+    if save_optimizer:
+        ckpt_saver_dict['optimizer'] = NekoOptimizerSaver()
+
     from cfgs.train.py import train_base, tuning_base
 
     return dict(
@@ -30,13 +40,7 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
                 )
             ], weight_decay=1e-2),
 
-        ckpt_saver=
-            SDXL=ckpt_saver(
-                ckpt_type='safetensors',
-                target_module='denoiser',
-                layers=LAYERS_TRAINABLE,
-            )
-        ),
+        ckpt_saver=ckpt_saver_dict,
 
         train=dict(
             train_steps=train_steps,
@@ -64,9 +68,9 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
     )
 
 @neko_cfg
-def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4,
-                    with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
-                    save_webui_format=False):
+def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, save_optimizer=False, lr: float = 1e-4, rank: int = 4,
+                    alpha: float = None, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
+                    name: str = 'SDXL', save_webui_format=False):
     with disable_neko_cfg:
         if alpha is None:
             alpha = rank
@@ -97,6 +101,17 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
     else:
         lora_format = SafeTensorFormat()
 
+    ckpt_saver_dict = dict(
+        _replace_=True,
+        lora_unet=NekoPluginSaver(
+            format=lora_format,
+            target_plugin='lora1',
+        )
+    )
+
+    if save_optimizer:
+        ckpt_saver_dict['optimizer'] = NekoOptimizerSaver()
+
     from cfgs.train.py.examples import SD_FT
 
     return dict(
@@ -114,13 +129,7 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
             )
         ), weight_decay=0.1),
 
-        ckpt_saver=
-            _replace_ = True,
-            lora_unet=NekoPluginSaver(
-                format=lora_format,
-                target_plugin='lora1',
-            )
-        ),
+        ckpt_saver=ckpt_saver_dict,
 
         train=dict(
            train_steps=train_steps,
hcpdiff/evaluate/evaluator.py
ADDED
@@ -0,0 +1,76 @@
+from pathlib import Path
+
+import torch
+from accelerate.hooks import remove_hook_from_module
+from rainbowneko.evaluate import WorkflowEvaluator, MetricGroup
+from rainbowneko.utils import to_cuda
+
+from hcpdiff.models.wrapper import SD15Wrapper
+
+class HCPEvaluator(WorkflowEvaluator):
+
+    @torch.no_grad()
+    def evaluate(self, step: int, model: SD15Wrapper, prefix='eval/'):
+        if step%self.interval != 0 or not self.trainer.is_local_main_process:
+            return
+
+        # record training layers
+        training_layers = [layer for layer in model.modules() if layer.training]
+
+        model.eval()
+        self.trainer.loggers.info(f'Preview')
+
+        N_repeats = model.text_enc_hook.N_repeats
+        clip_skip = model.text_enc_hook.clip_skip
+        clip_final_norm = model.text_enc_hook.clip_final_norm
+        use_attention_mask = model.text_enc_hook.use_attention_mask
+
+        preview_root = Path(self.trainer.exp_dir)/'imgs'
+        preview_root.mkdir(parents=True, exist_ok=True)
+
+        states = self.workflow_runner.run(model=model, in_preview=True, te_hook=model.text_enc_hook,
+                                          device=self.device, dtype=self.dtype, preview_root=preview_root, preview_step=step,
+                                          world_size=self.trainer.world_size, local_rank=self.trainer.local_rank,
+                                          emb_hook=self.trainer.cfgs.emb_pt.embedding_hook if self.trainer.pt_trainable else None)
+
+        # get metrics
+        metric = states['_metric']
+
+        v_metric = metric.finish(self.trainer.accelerator.gather, self.trainer.is_local_main_process)
+        if not isinstance(v_metric, dict):
+            v_metric = {'metric':v_metric}
+
+        log_data = {
+            "eval/Step":{
+                "format":"{}",
+                "data":[step],
+            }
+        }
+        log_data.update(MetricGroup.format(v_metric, prefix=prefix))
+        self.trainer.loggers.log(log_data, step, force=True)
+
+        # restore model states
+        if model.vae is not None:
+            model.vae.disable_tiling()
+            model.vae.disable_slicing()
+            remove_hook_from_module(model.vae, recurse=True)
+            if 'vae_encode_raw' in states:
+                model.vae.encode = states['vae_encode_raw']
+                model.vae.decode = states['vae_decode_raw']
+
+        if 'emb_hook' in states and not self.trainer.pt_trainable:
+            states['emb_hook'].remove()
+
+        if self.trainer.pt_trainable:
+            self.trainer.cfgs.emb_pt.embedding_hook.N_repeats = N_repeats
+
+        model.tokenizer.N_repeats = N_repeats
+        model.text_enc_hook.N_repeats = N_repeats
+        model.text_enc_hook.clip_skip = clip_skip
+        model.text_enc_hook.clip_final_norm = clip_final_norm
+        model.text_enc_hook.use_attention_mask = use_attention_mask
+
+        to_cuda(model)
+
+        for layer in training_layers:
+            layer.train()
hcpdiff/evaluate/metrics/__init__.py
ADDED
@@ -0,0 +1 @@
+from .clip_score import CLIPScoreMetric
hcpdiff/evaluate/metrics/clip_score.py
ADDED
@@ -0,0 +1,23 @@
+from torchmetrics.multimodal.clip_score import CLIPScore, _clip_score_update
+from torch import Tensor
+from typing import List
+
+class CLIPScoreMetric(CLIPScore):
+    def update(self, images: Tensor | List[Tensor], text: str | list[str]) -> None:
+        """Update CLIP score on a batch of images and text.
+
+        Args:
+            images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors, in the [-1, 1] range
+            text: Either a single caption or a list of captions
+
+        Raises:
+            ValueError:
+                If not all images have format [C, H, W]
+            ValueError:
+                If the number of images and captions do not match
+
+        """
+        images = (images+1)/2  # [-1,1] -> [0,1]
+        score, n_samples = _clip_score_update(images, text, self.model, self.processor)
+        self.score += score.sum(0)
+        self.n_samples += n_samples
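The only behavioral difference from torchmetrics' stock CLIPScore is the rescale from the [-1, 1] range produced by the diffusion pipeline to the [0, 1] range the metric expects. A standalone usage sketch (the CLIP checkpoint name and tensor shapes are illustrative):

    import torch
    from hcpdiff.evaluate.metrics import CLIPScoreMetric

    metric = CLIPScoreMetric(model_name_or_path="openai/clip-vit-base-patch16")
    images = torch.rand(2, 3, 224, 224)*2 - 1      # fake batch already in [-1, 1]
    metric.update(images, ["a photo of a cat", "a photo of a dog"])
    print(metric.compute())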
hcpdiff/evaluate/previewer.py
CHANGED
@@ -6,32 +6,49 @@ from rainbowneko.utils import to_cuda
 
 from hcpdiff.models.wrapper import SD15Wrapper
 from accelerate.hooks import remove_hook_from_module
+from typing import Dict
+from types import ModuleType
 
 class HCPPreviewer(WorkflowPreviewer):
+    def __init__(self, parser, cfgs_raw, workflow: str | ModuleType | Dict, ds_name=None, interval=100, trainer=None,
+                 mixed_precision=None, seed=42, **cfgs):
+        super().__init__(parser, cfgs_raw, workflow, ds_name=ds_name, interval=interval, trainer=trainer,
+                         mixed_precision=mixed_precision, seed=seed, **cfgs)
+        if trainer is None:
+            self.pt_trainable = False
+        else:
+            self.emb_pt = trainer.cfgs.emb_pt
+            self.pt_trainable = trainer.pt_trainable
 
     @torch.no_grad()
-    def evaluate(self, step: int,
-        if step%self.interval != 0 or not self.
+    def evaluate(self, step: int, prefix='eval/'):
+        if step%self.interval != 0 or not self.is_local_main_process:
             return
 
         # record training layers
-
+        if self.model_wrapper is not None:
+            training_layers = [layer for layer in self.model_raw.modules() if layer.training]
+            self.model_wrapper.eval()
+            model = self.model_raw
+        else:
+            training_layers = []
+            model = None
 
-
-
+        if self.loggers is not None:
+            self.loggers.info(f'Preview')
 
         N_repeats = model.text_enc_hook.N_repeats
         clip_skip = model.text_enc_hook.clip_skip
         clip_final_norm = model.text_enc_hook.clip_final_norm
         use_attention_mask = model.text_enc_hook.use_attention_mask
 
-        preview_root = Path(self.
+        preview_root = Path(self.exp_dir)/'imgs'
         preview_root.mkdir(parents=True, exist_ok=True)
 
         states = self.workflow_runner.run(model=model, in_preview=True, te_hook=model.text_enc_hook,
-                                          device=self.device, dtype=self.
-                                          world_size=self.
-                                          emb_hook=self.
+                                          device=self.device, dtype=self.weight_dtype, preview_root=preview_root, preview_step=step,
+                                          world_size=self.world_size, local_rank=self.local_rank,
+                                          emb_hook=self.emb_pt.embedding_hook if self.pt_trainable else None)
 
         # restore model states
         if model.vae is not None:
@@ -42,11 +59,11 @@ class HCPPreviewer(WorkflowPreviewer):
                 model.vae.encode = states['vae_encode_raw']
                 model.vae.decode = states['vae_decode_raw']
 
-        if 'emb_hook' in states and not self.
+        if 'emb_hook' in states and not self.pt_trainable:
            states['emb_hook'].remove()
 
-        if self.
-            self.
+        if self.pt_trainable:
+            self.emb_pt.embedding_hook.N_repeats = N_repeats
 
         model.tokenizer.N_repeats = N_repeats
         model.text_enc_hook.N_repeats = N_repeats
hcpdiff/loss/base.py
CHANGED
@@ -6,36 +6,19 @@ class DiffusionLossContainer(LossContainer):
     def __init__(self, loss, weight=1.0, key_map=None):
         key_map = key_map or getattr(loss, '_key_map', None) or ('pred.model_pred -> 0', 'pred.target -> 1')
         super().__init__(loss, weight, key_map)
-        self.target_type = getattr(loss, 'target_type',
+        self.target_type = getattr(loss, 'target_type', None)
 
-    def get_target(self,
+    def get_target(self, model_pred, x_0, noise, x_t, timesteps, noise_sampler, **kwargs):
         # Get target
-
-            target = noise
-        elif self.target_type == "x0":
-            target = x_0
-        elif self.target_type == "velocity":
-            target = noise_sampler.eps_to_velocity(noise, x_t, sigma)
-        else:
-            raise ValueError(f"Unsupport target_type {self.target_type}")
+        target = noise_sampler.get_target(x_0, x_t, timesteps, eps=noise, target_type=self.target_type)
 
-        #
-
-
-        # model_pred, _ = model_pred.chunk(2, dim=1)
-
-        # Convert pred_type to target_type
-        if pred_type != self.target_type:
-            cvt_func = getattr(noise_sampler, f'{pred_type}_to_{self.target_type}', None)
-            if cvt_func is None:
-                raise ValueError(f"Unsupport pred_type {pred_type} with target_type {self.target_type}")
-            else:
-                model_pred = cvt_func(model_pred, x_t, sigma)
-        return model_pred, target
+        # Convert pred_type for target_type
+        pred = noise_sampler.pred_for_target(model_pred, x_t, timesteps, eps=noise, target_type=self.target_type)
+        return pred, target
 
     def forward(self, pred:Dict[str,Any], inputs:Dict[str,Any]) -> Tensor:
-
-        pred['model_pred'] =
+        pred_cvt, target = self.get_target(**pred)
+        pred['model_pred'] = pred_cvt
         pred['target'] = target
-        loss = super().forward(pred, inputs)
+        loss = super().forward(pred, inputs) # [B,*,*,*]
         return loss.mean()
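The per-target branching that used to live here now delegates to the sampler (noise_sampler.get_target and noise_sampler.pred_for_target), so a loss can declare a target_type and let the sampler convert both target and prediction into that space. A sketch of the relations this relies on, written for a plain DDPM parameterization (x_t = sqrt(abar)*x_0 + sqrt(1-abar)*eps); this is an assumption for illustration, not hcpdiff's actual sampler code:

    import torch

    def get_target(x_0, eps, alpha_bar, target_type='eps'):
        # alpha_bar: cumulative product of (1 - beta_t), shaped to broadcast over x_0
        if target_type == 'eps':
            return eps
        elif target_type == 'x0':
            return x_0
        elif target_type == 'velocity':
            # v-prediction target: v = sqrt(abar)*eps - sqrt(1-abar)*x_0
            return alpha_bar.sqrt()*eps - (1 - alpha_bar).sqrt()*x_0
        raise ValueError(f'unsupported target_type {target_type}')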
hcpdiff/loss/weighting.py
CHANGED
@@ -7,6 +7,11 @@ class LossWeight(nn.Module):
         super().__init__()
         self.loss = loss
 
+    def get_c_out(self, pred):
+        t = pred['timesteps']
+        noise_sampler = pred['noise_sampler']
+        return noise_sampler.sigma_scheduler.c_out(t)
+
     def get_weight(self, pred, inputs):
         '''
 
@@ -25,13 +30,19 @@ class LossWeight(nn.Module):
 
 class SNRWeight(LossWeight):
     def get_weight(self, pred, inputs):
-
-
-
-
-
+        noise_sampler = pred['noise_sampler']
+        c_out = self.get_c_out(pred)
+        target_type = getattr(self.loss, 'target_type', None) or noise_sampler.target_type
+        if target_type == 'eps':
+            w_snr = 1
+        elif target_type == "x0":
+            w_snr = (1./c_out**2).float()
+        elif target_type == "velocity":
+            w_snr = (1./(1-c_out)**2).float()
         else:
-            raise ValueError(f"{self.__class__.__name__} is not support for target_type {
+            raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
+
+        return w_snr.view(-1, 1, 1, 1)
 
 class MinSNRWeight(LossWeight):
     def __init__(self, loss: DiffusionLossContainer, gamma: float = 1.):
@@ -39,13 +50,18 @@ class MinSNRWeight(LossWeight):
         self.gamma = gamma
 
     def get_weight(self, pred, inputs):
-
-
-
-
-        w_snr = (
+        noise_sampler = pred['noise_sampler']
+        c_out = self.get_c_out(pred)
+        target_type = getattr(self.loss, 'target_type', None) or noise_sampler.target_type
+        if target_type == 'eps':
+            w_snr = (self.gamma*c_out**2).clip(max=1).float()
+        elif target_type == "x0":
+            w_snr = (1./c_out**2).clip(max=self.gamma).float()
+        elif target_type == "velocity":
+            w_v = 1/(1-c_out)**2
+            w_snr = (self.gamma*c_out**2/w_v).clip(max=w_v).float()
         else:
-            raise ValueError(f"{self.__class__.__name__} is not support for target_type {
+            raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
 
         return w_snr.view(-1, 1, 1, 1)
 
@@ -55,12 +71,14 @@ class EDMWeight(LossWeight):
         self.gamma = gamma
 
     def get_weight(self, pred, inputs):
-
-
-
-
-        w_snr =
+        c_out = self.get_c_out(pred)
+        noise_sampler = pred['noise_sampler']
+        target_type = getattr(self.loss, 'target_type', None) or noise_sampler.target_type
+        if target_type == 'edm':
+            w_snr = 1
+        elif target_type == "x0":
+            w_snr = (1./c_out**2).float()
         else:
-            raise ValueError(f"{self.__class__.__name__} is not support for target_type {
+            raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
 
         return w_snr.view(-1, 1, 1, 1)
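With the scheduler's c_out exposed through get_c_out, the eps-prediction branch of MinSNRWeight reduces to the usual Min-SNR-gamma weight min(SNR, gamma)/SNR, provided c_out(t) behaves like the noise scale so that SNR(t) = 1/c_out(t)^2 (an assumption about the scheduler, not stated in this diff). A quick numeric check of that identity:

    import torch

    gamma = 5.0
    c_out = torch.tensor([0.1, 0.5, 1.0, 2.0])     # stand-in values for c_out(t)
    snr = 1/c_out**2
    w_eps = (gamma*c_out**2).clip(max=1)           # the eps branch above
    print(torch.allclose(w_eps, torch.minimum(snr, torch.tensor(gamma))/snr))  # True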
hcpdiff/models/lora_base_patch.py
CHANGED
@@ -34,6 +34,32 @@ class LoraPatchContainer(PatchPluginContainer):
 
         return self[name].post_forward(x, self._host.weight, weight_, self._host.bias, bias_)
 
+    @property
+    def weight(self):
+        weight_ = None
+        for name in self.plugin_names:
+            if weight_ is None:
+                weight_ = self[name].get_weight()
+            else:
+                weight_ = weight_+self[name].get_weight()
+        return self._host.weight + weight_
+
+    @property
+    def bias(self):
+        bias_ = None
+        for name in self.plugin_names:
+            if bias_ is None:
+                bias_ = self[name].get_bias()
+            else:
+                bias_ = bias_+self[name].get_bias()
+
+        if self._host.bias is not None:
+            if bias_ is None:
+                bias_ = self._host.bias
+            else:
+                bias_ = self._host.bias + bias_
+        return bias_
+
 class LoraBlock(PatchPluginBlock):
     container_cls = LoraPatchContainer
     wrapable_classes = (nn.Linear, nn.Conv2d)
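The new weight and bias properties let callers read the effective parameters of a patched layer, i.e. the host layer's parameters with every plugin's delta already merged in. A conceptual sketch of what the properties compute (get_weight above returns a plugin's delta; the loop below is a stand-in, not the hcpdiff API):

    import torch
    import torch.nn as nn

    host = nn.Linear(4, 4)
    plugin_deltas = [torch.zeros(4, 4), 0.1*torch.eye(4)]    # pretend LoRA updates (B @ A * scale)
    effective_weight = host.weight + sum(plugin_deltas)      # what LoraPatchContainer.weight now reports
    print(effective_weight.shape)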
hcpdiff/models/wrapper/sd.py
CHANGED
@@ -17,7 +17,7 @@ from ..cfg_context import CFGContext
 
 class SD15Wrapper(BaseWrapper):
     def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
-
+                 TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
         super().__init__()
         self.key_mapper_in = self.build_mapper(key_map_in, None, (
             'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
@@ -31,8 +31,6 @@ class SD15Wrapper(BaseWrapper):
         self.tokenizer = tokenizer
         self.min_attnmask = min_attnmask
 
-        self.pred_type = pred_type
-
         self.TE_hook_cfg = TEHookCFG.create(TE_hook_cfg)
         self.cfg_context = cfg_context
         self.tokenizer.N_repeats = self.TE_hook_cfg.tokenizer_repeats
@@ -93,8 +91,9 @@ class SD15Wrapper(BaseWrapper):
                       plugin_input={}, **kwargs):
         # input prepare
         x_0 = self.get_latents(image)
-        x_t, noise,
-        x_t_in = x_t*self.noise_sampler.c_in(
+        x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype).view(-1,1,1,1)
+        t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
 
         if neg_prompt_ids:
             prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
@@ -104,15 +103,14 @@ class SD15Wrapper(BaseWrapper):
             position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
 
         # model forward
-        x_t_in,
-        encoder_hidden_states = self.forward_TE(prompt_ids,
+        x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
+        encoder_hidden_states = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
                                                 plugin_input=plugin_input, **kwargs)
-        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states,
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, attn_mask=attn_mask, position_ids=position_ids,
                                            plugin_input=plugin_input, **kwargs)
         model_pred = self.cfg_context.post(model_pred)
 
-        return dict(model_pred=model_pred, noise=noise,
-                    noise_sampler=self.noise_sampler)
+        return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
 
     def forward(self, ds_name=None, **kwargs):
         model_args, model_kwargs = self.get_map_data(self.key_mapper_in, kwargs, ds_name)
@@ -156,8 +154,8 @@ class SD15Wrapper(BaseWrapper):
 
 class SDXLWrapper(SD15Wrapper):
     def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
-
-        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask,
+                 TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
         self.key_mapper_in = self.build_mapper(key_map_in, None, (
             'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
             'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input', 'coord -> crop_info'))
@@ -195,8 +193,9 @@ class SDXLWrapper(SD15Wrapper):
                       crop_info=None, plugin_input={}):
         # input prepare
         x_0 = self.get_latents(image)
-        x_t, noise,
-        x_t_in = x_t*self.noise_sampler.c_in(
+        x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype).view(-1,1,1,1)
+        t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
 
         if neg_prompt_ids:
             prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
@@ -206,13 +205,12 @@ class SDXLWrapper(SD15Wrapper):
             position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
 
         # model forward
-        x_t_in,
-        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids,
+        x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
+        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
                                                                plugin_input=plugin_input)
         added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
-        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states,
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, added_cond_kwargs=added_cond_kwargs,
                                            attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
         model_pred = self.cfg_context.post(model_pred)
 
-        return dict(model_pred=model_pred, noise=noise,
-                    noise_sampler=self.noise_sampler)
+        return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
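Compared with 2.2.1, the training forward now draws a random timestep via add_noise_rand_t and applies the scheduler's preconditioning explicitly: the noised latent is scaled by c_in(t) and the value fed to the UNet as a timestep is c_noise(t). A schematic of that flow with a toy scheduler standing in for noise_sampler.sigma_scheduler (the c_in/c_noise formulas below are EDM-style placeholders, not hcpdiff's):

    import torch

    class ToyScheduler:
        def c_in(self, sigma):
            return 1/torch.sqrt(1 + sigma**2)   # input scaling
        def c_noise(self, sigma):
            return sigma.log()                  # conditioning value passed as the "timestep"

    x_0 = torch.randn(2, 4, 8, 8)
    sigma = torch.rand(2)*0.9 + 0.1
    noise = torch.randn_like(x_0)
    x_t = x_0 + sigma.view(-1, 1, 1, 1)*noise                 # toy stand-in for add_noise_rand_t
    sched = ToyScheduler()
    x_t_in = x_t*sched.c_in(sigma).view(-1, 1, 1, 1)          # what the denoiser actually sees
    t_in = sched.c_noise(sigma)
    print(x_t_in.shape, t_in.shape)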
hcpdiff/trainer_ac.py
CHANGED
@@ -4,8 +4,8 @@ import warnings
 import torch
 from rainbowneko.parser import load_config_with_cli
 from rainbowneko.ckpt_manager import NekoSaver
-from rainbowneko.train import Trainer
-from rainbowneko.utils import xformers_available, is_dict
+from rainbowneko.train.trainer import Trainer
+from rainbowneko.utils import xformers_available, is_dict, weight_dtype_map
 from hcpdiff.ckpt_manager import EmbFormat
 
 class HCPTrainer(Trainer):
@@ -17,7 +17,7 @@ class HCPTrainer(Trainer):
             warnings.warn("xformers is not available. Make sure it is installed correctly")
 
         if self.model_wrapper.vae is not None:
-            self.vae_dtype =
+            self.vae_dtype = weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
             self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
 
         if self.cfgs.model.gradient_checkpointing:
@@ -44,10 +44,12 @@ class HCPTrainer(Trainer):
 
     def save_model(self, from_raw=False):
         NekoSaver.save_all(
-            self.model_raw,
-            plugin_groups={**self.all_plugin, 'embs': self.train_pts},
             cfg=self.ckpt_saver,
+            model=self.model_raw,
+            plugin_groups=self.all_plugin,
+            embs=self.train_pts,
             model_ema=getattr(self, "ema_model", None),
+            optimizer=self.optimizer,
             name_template=f'{{}}-{self.real_step}',
         )
 
hcpdiff/trainer_ac_single.py
CHANGED
@@ -1,12 +1,7 @@
 import argparse
-import sys
-from functools import partial
-
-import torch
-from accelerate import Accelerator
-from loguru import logger
 
 from rainbowneko.train.trainer import TrainerSingleCard
+
 from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
 
 class HCPTrainerSingleCard(TrainerSingleCard, HCPTrainer):
hcpdiff/utils/torch_utils.py
ADDED
@@ -0,0 +1,25 @@
+import torch
+
+def invert_func(func, y, x_min=0.0, x_max=1.0, tol=1e-5, max_iter=100):
+    """
+    y: [B]
+    :return: x [B]
+    """
+    y = y.to(dtype=torch.float32)
+    left = torch.full_like(y, x_min)
+    right = torch.full_like(y, x_max)
+
+    for _ in range(max_iter):
+        mid = (left+right)/2
+        val = func(mid)
+
+        too_large = val>y
+        too_small = ~too_large
+
+        left = torch.where(too_small, mid, left)
+        right = torch.where(too_large, mid, right)
+
+        if torch.all(torch.abs(val-y)<tol):
+            break
+
+    return (left+right)/2
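invert_func is a vectorized bisection search: it assumes func is monotonically increasing on [x_min, x_max] and returns, element-wise, the x with func(x) close to y. A quick usage sketch (importing from the new module added above):

    import torch
    from hcpdiff.utils.torch_utils import invert_func

    y = torch.tensor([0.04, 0.25, 0.81])
    x = invert_func(lambda v: v**2, y)      # invert y = x**2 on [0, 1]; expect roughly [0.2, 0.5, 0.9]
    print(x)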
hcpdiff/workflow/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
-    X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter
+    X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter, DiffusionActions
 from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
 from .vae import EncodeAction, DecodeAction
 from .io import BuildModelsAction, SaveImageAction, LoadImageAction