hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +90 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +3 -3
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text.py +40 -0
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +207 -0
- hcpdiff/easy/cfg/sdxl_train.py +147 -0
- hcpdiff/easy/cfg/t2i.py +228 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +118 -128
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +60 -47
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.2.dist-info/METADATA +299 -0
- hcpdiff-2.2.dist-info/RECORD +115 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
- hcpdiff-2.2.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,147 @@
|
|
1
|
+
import torch
|
2
|
+
from rainbowneko.ckpt_manager import ckpt_saver, NekoPluginSaver, LAYERS_TRAINABLE, SafeTensorFormat
|
3
|
+
from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
|
4
|
+
from rainbowneko.utils import ConstantLR
|
5
|
+
|
6
|
+
from hcpdiff.easy import SDXL_auto_loader
|
7
|
+
from hcpdiff.models import SDXLWrapper
|
8
|
+
from hcpdiff.models.lora_layers_patch import LoraLayer
|
9
|
+
from hcpdiff.ckpt_manager import LoraWebuiFormat
|
10
|
+
|
11
|
+
@neko_cfg
def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5,
                    dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL'):
    """Build a training config that fine-tunes the whole SDXL denoiser (UNet).

    Args:
        base_model: Checkpoint path passed to ``SDXL_auto_loader`` as ``ckpt_path``.
        train_steps: Total number of training steps.
        dataset: Training data config, used verbatim as ``data_train``.
        save_step: Checkpoint-saving interval, in steps.
        lr: Learning rate of the single ``denoiser`` parameter group.
        dtype: Value stored as ``mixed_precision`` (e.g. ``'fp16'``).
        low_vram: If True, use ``bitsandbytes.optim.AdamW8bit`` instead of ``torch.optim.AdamW``.
        warmup_steps: Warmup steps of the ``ConstantLR`` scheduler.
        name: Stored as ``model.name``.

    Returns:
        A config dict layered on top of ``train_base`` and ``tuning_base``.
    """
    # NOTE(review): inside a @neko_cfg body, calls appear to be recorded as config nodes rather
    # than executed; ``_partial_=True`` presumably defers construction until the trainer supplies
    # the missing arguments (e.g. model parameters for the optimizer) — confirm in rainbowneko.
    if low_vram:
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True)
    else:
        optimizer = torch.optim.AdamW(_partial_=True)

    # Project-level base configs, imported lazily inside the config function.
    from cfgs.train.py import train_base, tuning_base

    return dict(
        _base_=[train_base, tuning_base],
        mixed_precision=dtype,

        # Train only the denoiser, as one parameter group with weight decay 1e-2.
        model_part=CfgWDModelParser([
            dict(
                lr=lr,
                layers=['denoiser'], # train UNet
            )
        ], weight_decay=1e-2),

        # Save the trainable layers of the denoiser as safetensors.
        ckpt_saver=dict(
            SDXL=ckpt_saver(
                ckpt_type='safetensors',
                target_module='denoiser',
                layers=LAYERS_TRAINABLE,
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            # Constant learning rate after an optional warmup.
            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            ## Easy config
            # Build the SDXL wrapper from the components loaded (auto-detected format) from base_model.
            wrapper=SDXLWrapper.from_pretrained(
                _partial_=True,
                models=SDXL_auto_loader(ckpt_path=base_model, _partial_=True),
            ),
        ),

        data_train=dataset,
    )
|
65
|
+
|
66
|
+
@neko_cfg
def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
                    with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL',
                    save_webui_format=False):
    """Build a LoRA training config for the SDXL denoiser.

    LoRA is applied to the denoiser's attention (``attn``, ``attn1``, ``attn2``, ...) and
    feed-forward (``ff``) modules. With ``with_conv`` it is also applied to ``resnets``,
    ``proj_in``, ``proj_out`` and ``conv`` modules.

    Args:
        base_model: Checkpoint path passed to ``SDXL_auto_loader`` as ``ckpt_path``.
        train_steps: Total number of training steps.
        dataset: Training data config, used verbatim as ``data_train``.
        save_step: Checkpoint-saving interval, in steps.
        lr: Learning rate of the LoRA plugin.
        rank: LoRA rank.
        alpha: LoRA alpha; defaults to ``rank`` when None.
        with_conv: Also wrap convolution / projection / resnet modules.
        dtype: Value stored as ``mixed_precision``.
        low_vram: If True, use ``bitsandbytes.optim.AdamW8bit`` instead of ``torch.optim.AdamW``.
        warmup_steps: Warmup steps of the ``ConstantLR`` scheduler.
        name: Stored as ``model.name``.
        save_webui_format: Save LoRA weights with ``LoraWebuiFormat`` instead of plain ``SafeTensorFormat``.

    Returns:
        A config dict based on ``SD_FT`` that trains only the ``lora1`` plugin.
    """
    with disable_neko_cfg:
        # Default alpha to the rank.
        if alpha is None:
            alpha = rank

        # Module-name patterns (``re:`` prefix, presumably marking a regex) selecting the
        # denoiser sub-modules that get wrapped with LoRA layers.
        if with_conv:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
                r're:denoiser.*\.resnets$',
                r're:denoiser.*\.proj_in$',
                r're:denoiser.*\.proj_out$',
                r're:denoiser.*\.conv$',
            ]
        else:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
            ]

    # NOTE(review): keep the ``_partial_=True`` calls below outside ``disable_neko_cfg`` — they are
    # config-style calls (same pattern as in ``SDXL_finetuning``) and presumably fail if executed
    # directly, since e.g. AdamW has no ``_partial_`` argument — confirm.
    if low_vram:
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True, betas=(0.9, 0.99))
    else:
        optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))

    # File format for the saved LoRA weights: WebUI-compatible or plain safetensors.
    if save_webui_format:
        lora_format = LoraWebuiFormat()
    else:
        lora_format = SafeTensorFormat()

    from cfgs.train.py.examples import SD_FT

    return dict(
        _base_=[SD_FT],
        mixed_precision=dtype,

        # No full-model parameter groups; presumably only the LoRA plugin below is trained.
        model_part=None,
        model_plugin=CfgWDPluginParser(cfg_plugin=dict(
            lora1=LoraLayer.wrap_model(
                _partial_=True,
                lr=lr,
                rank=rank,
                alpha=alpha,
                layers=lora_layers
            )
        ), weight_decay=0.1),

        ckpt_saver=dict(
            # presumably replaces (instead of merging with) the saver config inherited from SD_FT
            _replace_ = True,
            lora_unet=NekoPluginSaver(
                format=lora_format,
                target_plugin='lora1',
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            # Constant learning rate after an optional warmup.
            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            wrapper=SDXLWrapper.from_pretrained(
                models=SDXL_auto_loader(ckpt_path=base_model, _partial_=True),
                _partial_=True,
            ),
        ),

        data_train=dataset,
    )
|
hcpdiff/easy/cfg/t2i.py
ADDED
@@ -0,0 +1,228 @@
|
|
1
|
+
import torch
|
2
|
+
from rainbowneko.infer.workflow import (Actions, PrepareAction, LoopAction, LoadModelAction)
|
3
|
+
from rainbowneko.ckpt_manager import NekoModelLoader
|
4
|
+
from rainbowneko.parser import neko_cfg, disable_neko_cfg
|
5
|
+
from typing import Union, List
|
6
|
+
|
7
|
+
from hcpdiff.ckpt_manager import HCPLoraLoader
|
8
|
+
from hcpdiff.easy import Diffusers_SD, SD15_auto_loader, SDXL_auto_loader
|
9
|
+
from hcpdiff.workflow import (BuildModelsAction, PrepareDiffusionAction, XformersEnableAction, VaeOptimizeAction, TextHookAction,
|
10
|
+
AttnMultTextEncodeAction, SeedAction, MakeTimestepsAction, MakeLatentAction, DiffusionStepAction,
|
11
|
+
time_iter, DecodeAction, SaveImageAction, LatentResizeAction)
|
12
|
+
|
13
|
+
# Default negative prompt (common low-quality / artifact tags). It is the default of the
# ``negative_prompt`` parameter of ``text``, ``text_SDXL`` and the ``SD15_*`` / ``SDXL_*`` presets.
negative_prompt = 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'
|
14
|
+
|
15
|
+
## Easy config
|
16
|
+
@neko_cfg
def build_model(pretrained_model='ckpts/any5', noise_sampler=Diffusers_SD.dpmpp_2m_karras) -> Actions:
    """Workflow actions that set device/dtype and build an SD1.5 model.

    Args:
        pretrained_model: Checkpoint path passed to ``SD15_auto_loader`` as ``ckpt_path``.
        noise_sampler: Sampler passed to the loader; defaults to DPM++ 2M Karras.

    Returns:
        An ``Actions`` group: ``PrepareAction`` followed by ``BuildModelsAction``.
    """
    return Actions([
        # Device and dtype are fixed here (CUDA, fp16); they are not parameters of this helper.
        PrepareAction(device='cuda', dtype=torch.float16),
        BuildModelsAction(
            model_loader=SD15_auto_loader(
                _partial_=True,
                ckpt_path=pretrained_model,
                noise_sampler=noise_sampler
            )
        ),
    ])
|
28
|
+
|
29
|
+
@neko_cfg
def load_parts(info: List[str]) -> Actions:
    """Workflow actions that load partial model weights from each path in ``info``.

    Two ``LoadModelAction`` steps are emitted per path:
    * the ``denoiser.``-prefixed weights, mapped onto the workflow's ``denoiser``;
    * the ``TE.``-prefixed weights, mapped onto the workflow's ``TE``.

    Args:
        info: Checkpoint paths; the index ``i`` makes the loader names unique
            (``part_unet_{i}`` / ``part_TE_{i}``).

    Returns:
        An ``Actions`` group with ``2 * len(info)`` load actions.
    """
    acts = []
    for i, path in enumerate(info):
        # key_map_in presumably maps workflow state keys onto the action's inputs
        # ('<state key> -> <input name>').
        part_unet = LoadModelAction(cfg={
            f'part_unet_{i}':NekoModelLoader(
                path=path,
                state_prefix='denoiser.'
            )
        }, key_map_in=('denoiser -> model', 'in_preview -> in_preview'))
        part_TE = LoadModelAction(cfg={
            f'part_TE_{i}':NekoModelLoader(
                path=path,
                state_prefix='TE.',
            )
        }, key_map_in=('TE -> model', 'in_preview -> in_preview'))

        # Run the list appends as plain Python (presumably so they are not turned into config calls).
        with disable_neko_cfg:
            acts.append(part_unet)
            acts.append(part_TE)

    return Actions(acts)
|
51
|
+
|
52
|
+
@neko_cfg
def load_lora(info: List[List]) -> Actions:
    """Workflow actions that load LoRA weights for the denoiser and the text encoder.

    Args:
        info: One entry per LoRA file; ``item[0]`` is the path and ``item[1]`` is the
            ``alpha`` passed to ``HCPLoraLoader``.

    Returns:
        An ``Actions`` group with two load actions per entry. The ``denoiser.``-prefixed
        weights go to ``denoiser`` and the ``TE.``-prefixed weights go to ``TE``.
    """
    lora_acts = []
    for i, item in enumerate(info):
        # key_map_in presumably maps workflow state keys onto the action's inputs.
        lora_unet = LoadModelAction(cfg={
            f'lora_unet_{i}':HCPLoraLoader(
                path=item[0],
                state_prefix='denoiser.',
                alpha=item[1],
            )
        }, key_map_in=('denoiser -> model', 'in_preview -> in_preview'))
        lora_TE = LoadModelAction(cfg={
            f'lora_TE_{i}':HCPLoraLoader(
                path=item[0],
                state_prefix='TE.',
                alpha=item[1],
            )
        }, key_map_in=('TE -> model', 'in_preview -> in_preview'))

        # Run the list appends as plain Python (presumably so they are not turned into config calls).
        with disable_neko_cfg:
            lora_acts.append(lora_unet)
            lora_acts.append(lora_TE)

    return Actions(lora_acts)
|
76
|
+
|
77
|
+
@neko_cfg
def optimize_model() -> Actions:
    """Workflow actions that prepare the model for inference.

    The actions are, in order:
    1. ``PrepareDiffusionAction``;
    2. ``XformersEnableAction``;
    3. ``VaeOptimizeAction`` with ``slicing=True``.
    """
    return Actions([
        PrepareDiffusionAction(),
        XformersEnableAction(),
        VaeOptimizeAction(slicing=True),
    ])
|
84
|
+
|
85
|
+
@neko_cfg
def text(prompt, negative_prompt=negative_prompt, bs=4, N_repeats=1, layer_skip=1) -> Actions:
    """Workflow actions that hook the SD1.5 text encoder and encode the prompts.

    Args:
        prompt: Positive prompt(s).
        negative_prompt: Negative prompt(s); defaults to the module-level tag list.
        bs: Batch size passed to ``AttnMultTextEncodeAction``.
        N_repeats: Passed to ``TextHookAction`` (presumably the number of token-window
            repeats for long prompts — confirm).
        layer_skip: Passed to ``TextHookAction`` (presumably a CLIP-skip style layer count).
    """
    return Actions([
        TextHookAction(N_repeats=N_repeats, layer_skip=layer_skip),
        AttnMultTextEncodeAction(
            prompt=prompt,
            negative_prompt=negative_prompt,
            bs=bs
        ),
    ])
|
95
|
+
|
96
|
+
@neko_cfg
def build_model_SDXL(pretrained_model='ckpts/any5', noise_sampler=Diffusers_SD.dpmpp_2m_karras) -> Actions:
    """Workflow actions that set device/dtype and build an SDXL model.

    Args:
        pretrained_model: Checkpoint path passed to ``SDXL_auto_loader`` as ``ckpt_path``.
            NOTE(review): the default ``'ckpts/any5'`` is the same as the SD1.5 ``build_model``
            default; the ``SDXL_*`` presets always pass this argument explicitly.
        noise_sampler: Sampler passed to the loader; defaults to DPM++ 2M Karras.
    """
    return Actions([
        # Device and dtype are fixed here (CUDA, fp16).
        PrepareAction(device='cuda', dtype=torch.float16),
        ## Easy config
        BuildModelsAction(
            model_loader=SDXL_auto_loader(
                _partial_=True,
                ckpt_path=pretrained_model,
                noise_sampler=noise_sampler
            )
        ),
    ])
|
109
|
+
|
110
|
+
@neko_cfg
def text_SDXL(prompt, negative_prompt=negative_prompt, bs=4, N_repeats=1, layer_skip=1) -> Actions:
    """SDXL variant of ``text``.

    It is identical except that ``TextHookAction`` is created with ``TE_final_norm=False``
    (presumably skipping the text encoder's final norm).

    Args:
        prompt: Positive prompt(s).
        negative_prompt: Negative prompt(s); defaults to the module-level tag list.
        bs: Batch size passed to ``AttnMultTextEncodeAction``.
        N_repeats: Passed to ``TextHookAction``.
        layer_skip: Passed to ``TextHookAction``.
    """
    return Actions([
        TextHookAction(N_repeats=N_repeats, layer_skip=layer_skip, TE_final_norm=False),
        AttnMultTextEncodeAction(
            prompt=prompt,
            negative_prompt=negative_prompt,
            bs=bs
        ),
    ])
|
120
|
+
|
121
|
+
@neko_cfg
def config_diffusion(width=512, height=512, seed=None, N_steps=20, strength: float = None) -> Actions:
    """Workflow actions that seed the RNG, create the timesteps and create the initial latent.

    Args:
        width, height: Image size passed to ``MakeLatentAction``.
        seed: Passed to ``SeedAction`` (None presumably means a random seed).
        N_steps: Number of sampling steps for ``MakeTimestepsAction``.
        strength: Passed to ``MakeTimestepsAction`` (presumably img2img denoising strength;
            None for plain text-to-image — confirm).
    """
    return Actions([
        SeedAction(seed),
        MakeTimestepsAction(N_steps=N_steps, strength=strength),
        MakeLatentAction(width=width, height=height)
    ])
|
128
|
+
|
129
|
+
@neko_cfg
def diffusion(guidance_scale=7.0) -> Actions:
    """Workflow actions for the denoising loop.

    Runs ``DiffusionStepAction`` once per item produced by ``time_iter`` (presumably the
    timesteps created by ``MakeTimestepsAction``).

    Args:
        guidance_scale: Passed to ``DiffusionStepAction`` (presumably the classifier-free
            guidance scale).
    """
    return Actions([
        LoopAction(
            iterator=time_iter,
            actions=[
                DiffusionStepAction(guidance_scale=guidance_scale)
            ]
        )
    ])
|
139
|
+
|
140
|
+
@neko_cfg
def decode(save_root='output_pipe/') -> Actions:
    """Workflow actions that decode the latents and save the images as PNG under ``save_root``."""
    return Actions([
        DecodeAction(),
        SaveImageAction(save_root=save_root, image_type='png'),
    ])
|
146
|
+
|
147
|
+
@neko_cfg
def resize(width=1024, height=1024) -> Actions:
    """Workflow action that resizes the current latent to ``width`` x ``height`` via ``LatentResizeAction``."""
    return Actions([
        LatentResizeAction(width=width, height=height)
    ])
|
152
|
+
|
153
|
+
@neko_cfg
def SD15_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=512, height=512,
             seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """Complete SD1.5 text-to-image workflow config.

    Stages: build model -> optimize -> encode text -> seed/timesteps/latent -> denoising
    loop -> decode and save PNGs under ``save_root``. Arguments are forwarded to the matching
    helper (``build_model``, ``text``, ``config_diffusion``, ``diffusion``, ``decode``).

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        optimize_model(),
        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
|
164
|
+
|
165
|
+
@neko_cfg
def SD15_t2i_parts(pretrained_model, parts, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=512, height=512,
                   seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """SD1.5 text-to-image workflow config that also loads partial weights.

    Same stages as ``SD15_t2i``, with ``load_parts(parts)`` inserted after the model is built.

    Args:
        parts: Checkpoint paths; see ``load_parts``.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        load_parts(parts),
        optimize_model(),
        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
|
177
|
+
|
178
|
+
@neko_cfg
def SD15_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4,
                  width=512, height=512, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """SD1.5 text-to-image workflow config with LoRA weights applied.

    Same stages as ``SD15_t2i``, with ``load_lora(info=lora_info)`` inserted after the model is built.

    Args:
        lora_info: List of ``[path, alpha]`` entries; see ``load_lora``.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        load_lora(info=lora_info),
        optimize_model(),
        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
|
190
|
+
|
191
|
+
@neko_cfg
def SDXL_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=1024, height=1024,
             seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """Complete SDXL text-to-image workflow config.

    Mirrors ``SD15_t2i`` but uses ``build_model_SDXL`` and ``text_SDXL``, and defaults to
    1024x1024 output.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        optimize_model(),
        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
|
202
|
+
|
203
|
+
@neko_cfg
def SDXL_t2i_parts(pretrained_model, parts, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=1024, height=1024,
                   seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """SDXL text-to-image workflow config that also loads partial weights.

    Same stages as ``SDXL_t2i``, with ``load_parts(parts)`` inserted after the model is built.

    Args:
        parts: Checkpoint paths; see ``load_parts``.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        load_parts(parts),
        optimize_model(),
        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
|
215
|
+
|
216
|
+
|
217
|
+
@neko_cfg
def SDXL_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4,
                  width=1024, height=1024, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """SDXL text-to-image workflow config with LoRA weights applied.

    Same stages as ``SDXL_t2i``, with ``load_lora(info=lora_info)`` inserted after the model is built.

    Args:
        lora_info: List of ``[path, alpha]`` entries; see ``load_lora``.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        load_lora(info=lora_info),
        optimize_model(),
        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
|
@@ -0,0 +1,31 @@
|
|
1
|
+
from hcpdiff.data.handler import ControlNetHandler, StableDiffusionHandler
|
2
|
+
from hcpdiff.models import ControlNetPlugin
|
3
|
+
from rainbowneko.data import SyncHandler
|
4
|
+
from rainbowneko.parser import neko_cfg
|
5
|
+
|
6
|
+
@neko_cfg
def ControlNet_SD15(lr=1e-4):
    """Partial ``ControlNetPlugin`` config for an SD1.5 denoiser.

    The plugin reads from the ``from_layers`` hooks and feeds its outputs into
    ``down_blocks.0``-``down_blocks.3``, ``mid_block`` and a pre-hook on
    ``up_blocks.3.resnets.2``.

    Args:
        lr: Learning rate of the plugin.
    """
    return ControlNetPlugin(
        _partial_=True,
        lr=lr,
        from_layers=[
            'pre_hook:',  # empty layer name: presumably a pre-hook on the host model itself — confirm
            'pre_hook:conv_in', # to make forward inside autocast
        ],
        to_layers=[
            'down_blocks.0',
            'down_blocks.1',
            'down_blocks.2',
            'down_blocks.3',
            'mid_block',
            'pre_hook:up_blocks.3.resnets.2',
        ]
    )
|
24
|
+
|
25
|
+
@neko_cfg
def make_controlnet_handler(bucket=None, encoder_attention_mask=False, erase=0.15, dropout=0.0, shuffle=0.0, word_names=None):
    """Build a data handler that processes a diffusion sample and its ControlNet condition.

    The two sub-handlers (``diffusion`` and ``cnet``) are combined with ``SyncHandler`` and
    share the same ``bucket``.

    Args:
        bucket: Bucket config passed to both ``StableDiffusionHandler`` and ``ControlNetHandler``.
        encoder_attention_mask: Forwarded to ``StableDiffusionHandler``.
        erase: Forwarded to ``StableDiffusionHandler``.
        dropout: Forwarded to ``StableDiffusionHandler``.
        shuffle: Forwarded to ``StableDiffusionHandler``.
        word_names: Forwarded to ``StableDiffusionHandler``. ``None`` (the default) means an
            empty dict. A fresh dict is created on every call, so handlers built from
            separate calls never share one mutable default object.
    """
    # Previously the default was a literal ``{}``, evaluated once at definition time and
    # shared by every call that relied on it.
    if word_names is None:
        word_names = {}
    return SyncHandler(
        diffusion=StableDiffusionHandler(bucket=bucket, encoder_attention_mask=encoder_attention_mask, erase=erase, dropout=dropout, shuffle=shuffle,
                                         word_names=word_names),
        cnet=ControlNetHandler(bucket=bucket)
    )
|
@@ -0,0 +1,79 @@
|
|
1
|
+
import torch
|
2
|
+
from hcpdiff.ckpt_manager import DiffusersSD15Format, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSD15Format, OfficialSDXLFormat
|
3
|
+
from rainbowneko.ckpt_manager import NekoLoader, LocalCkptSource
|
4
|
+
from hcpdiff.utils import auto_tokenizer_cls, auto_text_encoder_cls, get_pipe_name
|
5
|
+
from hcpdiff.models.wrapper import SDXLWrapper, SD15Wrapper, PixArtWrapper
|
6
|
+
from hcpdiff.models.compose import SDXLTextEncoder
|
7
|
+
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
8
|
+
|
9
|
+
def SD15_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
                     tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
    """Load Stable Diffusion 1.x components, auto-detecting the checkpoint layout.

    If ``StableDiffusionPipeline.load_config`` can read a pipeline config from ``ckpt_path``,
    the checkpoint is loaded with ``DiffusersSD15Format``. If it raises ``EnvironmentError``
    (an ``OSError``, e.g. no ``config.json`` or a single-file checkpoint), the loader falls
    back to ``OfficialSD15Format``.

    Args:
        ckpt_path: Local checkpoint path (diffusers folder or single checkpoint file).
        denoiser, TE, vae, noise_sampler, tokenizer: Optional components forwarded to ``NekoLoader.load``.
        revision: Forwarded to ``NekoLoader.load``.
        dtype: Forwarded to ``NekoLoader.load``.
        **kwargs: Extra keyword arguments forwarded to ``NekoLoader.load``.

    Returns:
        Whatever ``NekoLoader.load`` returns for the detected format.
    """
    # Only the probe goes inside ``try``. Errors raised while constructing the format or loader
    # must propagate instead of silently triggering the official-format fallback.
    try:
        StableDiffusionPipeline.load_config(ckpt_path)  # probe only; the parsed config is unused
    except EnvironmentError:
        ckpt_format = OfficialSD15Format()
    else:
        ckpt_format = DiffusersSD15Format()

    loader = NekoLoader(
        format=ckpt_format,
        source=LocalCkptSource(),
    )
    return loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer,
                       revision=revision, dtype=dtype, **kwargs)
|
25
|
+
|
26
|
+
def SDXL_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
                     tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
    """Load Stable Diffusion XL components, auto-detecting the checkpoint layout.

    If ``StableDiffusionXLPipeline.load_config`` can read a pipeline config from ``ckpt_path``,
    the checkpoint is loaded with ``DiffusersSDXLFormat``. If it raises ``EnvironmentError``
    (an ``OSError``, e.g. no ``config.json`` or a single-file checkpoint), the loader falls
    back to ``OfficialSDXLFormat``.

    Args:
        ckpt_path: Local checkpoint path (diffusers folder or single checkpoint file).
        denoiser, TE, vae, noise_sampler, tokenizer: Optional components forwarded to ``NekoLoader.load``.
        revision: Forwarded to ``NekoLoader.load``.
        dtype: Forwarded to ``NekoLoader.load``.
        **kwargs: Extra keyword arguments forwarded to ``NekoLoader.load``.

    Returns:
        Whatever ``NekoLoader.load`` returns for the detected format.
    """
    # Only the probe goes inside ``try``. Errors raised while constructing the format or loader
    # must propagate instead of silently triggering the official-format fallback.
    try:
        StableDiffusionXLPipeline.load_config(ckpt_path)  # probe only; the parsed config is unused
    except EnvironmentError:
        ckpt_format = OfficialSDXLFormat()
    else:
        ckpt_format = DiffusersSDXLFormat()

    loader = NekoLoader(
        format=ckpt_format,
        source=LocalCkptSource(),
    )
    return loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer,
                       revision=revision, dtype=dtype, **kwargs)
|
42
|
+
|
43
|
+
def PixArt_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
                       tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
    """Load PixArt components from a local checkpoint.

    Unlike the SD loaders there is no layout probing: ``DiffusersPixArtFormat`` is always used.

    Args:
        ckpt_path: Local checkpoint path.
        denoiser, TE, vae, noise_sampler, tokenizer: Optional components forwarded to ``NekoLoader.load``.
        revision: Forwarded to ``NekoLoader.load``.
        dtype: Forwarded to ``NekoLoader.load``.
        **kwargs: Extra keyword arguments forwarded to ``NekoLoader.load``.

    Returns:
        Whatever ``NekoLoader.load`` returns.
    """
    pixart_loader = NekoLoader(format=DiffusersPixArtFormat(), source=LocalCkptSource())
    return pixart_loader.load(
        ckpt_path,
        denoiser=denoiser,
        TE=TE,
        vae=vae,
        noise_sampler=noise_sampler,
        tokenizer=tokenizer,
        revision=revision,
        dtype=dtype,
        **kwargs,
    )
|
52
|
+
|
53
|
+
def auto_load_wrapper(pretrained_model, denoiser=None, TE=None, vae=None, noise_sampler=None, tokenizer=None, revision=None,
                      dtype=torch.float32, **kwargs):
    """Load a diffusers-layout model and wrap it in the matching model wrapper.

    The architecture is chosen as follows:
    * ``SDXLWrapper`` / ``DiffusersSDXLFormat`` if the text-encoder class is ``SDXLTextEncoder``.
      The class is ``type(TE)`` when ``TE`` is given, otherwise it is detected from
      ``pretrained_model``.
    * ``PixArtWrapper`` / ``DiffusersPixArtFormat`` if the pipeline name contains ``'PixArt'``.
    * ``SD15Wrapper`` / ``DiffusersSD15Format`` otherwise.

    Args:
        pretrained_model: Local model path.
        denoiser, TE, vae, noise_sampler, tokenizer: Optional components forwarded to ``NekoLoader.load``.
        revision: Used for text-encoder detection and forwarded to ``NekoLoader.load``.
        dtype: Forwarded to ``NekoLoader.load``.
        **kwargs: Forwarded only to ``wrapper_cls.build_from_pretrained``, not to the loader.

    Returns:
        The wrapper built by ``wrapper_cls.build_from_pretrained(models, **kwargs)``.
    """
    if TE is not None:
        text_encoder_cls = type(TE)
    else:
        text_encoder_cls = auto_text_encoder_cls(pretrained_model, revision)

    pipe_name = get_pipe_name(pretrained_model)

    # Named ``ckpt_format`` (not ``format``) so the builtin ``format`` is not shadowed.
    if text_encoder_cls == SDXLTextEncoder:
        wrapper_cls = SDXLWrapper
        ckpt_format = DiffusersSDXLFormat()
    elif 'PixArt' in pipe_name:
        wrapper_cls = PixArtWrapper
        ckpt_format = DiffusersPixArtFormat()
    else:
        wrapper_cls = SD15Wrapper
        ckpt_format = DiffusersSD15Format()

    loader = NekoLoader(
        format=ckpt_format,
        source=LocalCkptSource(),
    )
    models = loader.load(pretrained_model, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
                         dtype=dtype)

    return wrapper_cls.build_from_pretrained(models, **kwargs)
|
hcpdiff/easy/sampler.py
ADDED
@@ -0,0 +1,46 @@
|
|
1
|
+
from hcpdiff.diffusion.sampler import DiffusersSampler
|
2
|
+
from diffusers import DPMSolverMultistepScheduler, DDIMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
|
3
|
+
|
4
|
+
class Diffusers_SD:
    """Ready-made ``DiffusersSampler`` instances for SD-family models.

    Every sampler uses the same scaled-linear beta schedule (``beta_start=0.00085``,
    ``beta_end=0.012``) and differs only in the wrapped diffusers scheduler:

    * ``dpmpp_2m``        - DPM-Solver++ multistep
    * ``dpmpp_2m_karras`` - DPM-Solver++ multistep with Karras sigmas
    * ``ddim``            - DDIM
    * ``euler``           - Euler discrete
    * ``euler_a``         - Euler ancestral discrete

    NOTE(review): these are class attributes created once at import time and shared by every
    user; if the wrapped schedulers keep per-run state, concurrent use may interfere — confirm.
    """
    # Beta schedule shared by all samplers; deleted at the end so it is not a class attribute.
    _sd_betas = dict(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')

    dpmpp_2m = DiffusersSampler(
        DPMSolverMultistepScheduler(**_sd_betas, algorithm_type='dpmsolver++')
    )
    dpmpp_2m_karras = DiffusersSampler(
        DPMSolverMultistepScheduler(**_sd_betas, algorithm_type='dpmsolver++', use_karras_sigmas=True)
    )
    ddim = DiffusersSampler(DDIMScheduler(**_sd_betas))
    euler = DiffusersSampler(EulerDiscreteScheduler(**_sd_betas))
    euler_a = DiffusersSampler(EulerAncestralDiscreteScheduler(**_sd_betas))

    del _sd_betas
|
@@ -0,0 +1 @@
|
|
1
|
+
from .previewer import HCPPreviewer
|
@@ -0,0 +1,60 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from rainbowneko.evaluate.preview import WorkflowPreviewer
|
5
|
+
from rainbowneko.utils import to_cuda
|
6
|
+
|
7
|
+
from hcpdiff.models.wrapper import SD15Wrapper
|
8
|
+
from accelerate.hooks import remove_hook_from_module
|
9
|
+
|
10
|
+
class HCPPreviewer(WorkflowPreviewer):
    """Previewer that runs the preview workflow on a diffusion model wrapper during training.

    Before the workflow runs, the model is switched to eval mode and the text-encoder hook
    settings are recorded. Afterwards, even if the workflow raises, the VAE helpers,
    text-encoder hook settings and per-module training flags are restored.
    """

    @torch.no_grad()
    def evaluate(self, step: int, model: SD15Wrapper, prefix='eval/'):
        """Render preview images for ``step`` into ``<exp_dir>/imgs`` and restore the model.

        :param step: current training step; previews run only when ``step % self.interval == 0``
            and only on the local main process.
        :param model: model wrapper being trained.
        :param prefix: not used by this implementation.
        """
        if step%self.interval != 0 or not self.trainer.is_local_main_process:
            return

        # Record which modules are in training mode so exactly those are switched back afterwards.
        training_layers = [layer for layer in model.modules() if layer.training]

        model.eval()
        self.trainer.loggers.info('Preview')

        # Snapshot text-encoder hook settings; they are written back after the preview.
        te_hook = model.text_enc_hook
        N_repeats = te_hook.N_repeats
        clip_skip = te_hook.clip_skip
        clip_final_norm = te_hook.clip_final_norm
        use_attention_mask = te_hook.use_attention_mask

        preview_root = Path(self.trainer.exp_dir)/'imgs'
        preview_root.mkdir(parents=True, exist_ok=True)

        states = {}
        try:
            states = self.workflow_runner.run(model=model, in_preview=True, te_hook=te_hook,
                                              device=self.device, dtype=self.dtype, preview_root=preview_root, preview_step=step,
                                              world_size=self.trainer.world_size, local_rank=self.trainer.local_rank,
                                              emb_hook=self.trainer.cfgs.emb_pt.embedding_hook if self.trainer.pt_trainable else None)
        finally:
            # Restore even when the workflow fails; otherwise training would continue in eval
            # mode with preview-specific VAE and text-encoder settings.
            self._restore_model_states(model, states or {}, N_repeats, clip_skip, clip_final_norm, use_attention_mask)
            for layer in training_layers:
                layer.train()

    def _restore_model_states(self, model: SD15Wrapper, states, N_repeats, clip_skip, clip_final_norm, use_attention_mask):
        """Undo changes the preview workflow may have made to ``model`` and move it back with ``to_cuda``.

        :param states: dict returned by the workflow runner (empty if the run failed).
        """
        if model.vae is not None:
            model.vae.disable_tiling()
            model.vae.disable_slicing()
            remove_hook_from_module(model.vae, recurse=True)
            # Put back the original encode/decode methods if the workflow replaced them.
            if 'vae_encode_raw' in states:
                model.vae.encode = states['vae_encode_raw']
                model.vae.decode = states['vae_decode_raw']

        # When prompt tuning is not trainable, no embedding hook was supplied by the trainer,
        # so a hook reported in ``states`` belongs to the preview and is removed here.
        if 'emb_hook' in states and not self.trainer.pt_trainable:
            states['emb_hook'].remove()

        if self.trainer.pt_trainable:
            self.trainer.cfgs.emb_pt.embedding_hook.N_repeats = N_repeats

        model.tokenizer.N_repeats = N_repeats
        model.text_enc_hook.N_repeats = N_repeats
        model.text_enc_hook.clip_skip = clip_skip
        model.text_enc_hook.clip_final_norm = clip_final_norm
        model.text_enc_hook.use_attention_mask = use_attention_mask

        to_cuda(model)
|
hcpdiff/loss/__init__.py
CHANGED
hcpdiff/loss/base.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
from rainbowneko.train.loss import LossContainer
|
2
|
+
from typing import Dict, Any
|
3
|
+
from torch import Tensor
|
4
|
+
|
5
|
+
class DiffusionLossContainer(LossContainer):
    """Loss container that puts a diffusion model's prediction and target in the same space.

    The wrapped ``loss`` may define ``target_type`` ('eps', 'x0' or 'velocity'; default 'eps')
    and ``_key_map``. Before the loss is evaluated, the model prediction (of type ``pred_type``)
    is converted into ``target_type`` with the noise sampler's ``{pred_type}_to_{target_type}``
    method, and the matching target is built.
    """

    def __init__(self, loss, weight=1.0, key_map=None):
        # Priority: explicit key_map, then the loss's own _key_map, then the default mapping of
        # pred['model_pred'] / pred['target'] to the loss's positional arguments.
        key_map = key_map or getattr(loss, '_key_map', None) or ('pred.model_pred -> 0', 'pred.target -> 1')
        super().__init__(loss, weight, key_map)
        self.target_type = getattr(loss, 'target_type', 'eps')

    def get_target(self, pred_type, model_pred, x_0, noise, x_t, sigma, noise_sampler, **kwargs):
        """Return ``(model_pred, target)`` expressed in ``self.target_type`` space.

        :raises ValueError: if ``target_type`` is unknown, or the noise sampler has no
            ``{pred_type}_to_{target_type}`` conversion.
        """
        # Build the target in the requested space.
        if self.target_type == "eps":
            target = noise
        elif self.target_type == "x0":
            target = x_0
        elif self.target_type == "velocity":
            target = noise_sampler.eps_to_velocity(noise, x_t, sigma)
        else:
            raise ValueError(f"Unsupport target_type {self.target_type}")

        # TODO: move dropping of extra predicted-variance channels
        #       (model_pred.shape[1] == 2*target.shape[1]) into the model wrapper.

        # Convert the prediction into the target's space when they differ.
        if pred_type != self.target_type:
            cvt_func = getattr(noise_sampler, f'{pred_type}_to_{self.target_type}', None)
            if cvt_func is None:
                raise ValueError(f"Unsupport pred_type {pred_type} with target_type {self.target_type}")
            model_pred = cvt_func(model_pred, x_t, sigma)
        return model_pred, target

    def forward(self, pred:Dict[str,Any], inputs:Dict[str,Any]) -> Tensor:
        """Evaluate the weighted loss and reduce it to a scalar mean."""
        model_pred, target = self.get_target(**pred)
        # Use a shallow copy instead of writing into the caller's dict. If the same ``pred``
        # dict is handed to several containers, an in-place overwrite would make the next
        # container convert an already-converted prediction.
        pred = {**pred, 'model_pred': model_pred, 'target': target}
        loss = super().forward(pred, inputs) * self.weight # [B,*,*,*]
        return loss.mean()
|
hcpdiff/loss/gw.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
import torch
|
2
|
+
from torch import nn
|
3
|
+
from torch.nn import functional as F
|
4
|
+
|
5
|
+
class GWLoss(nn.Module):
    """Gradient-weighted L1 loss (unreduced).

    The element-wise L1 error ``|pred - target|`` is scaled by ``(1 + 4*dx) * (1 + 4*dy)``.
    Here ``dx`` / ``dy`` are the absolute differences between the Sobel responses of
    ``pred`` and ``target`` along x / y, so errors where edges disagree weigh more.
    """

    def __init__(self):
        super().__init__()

        sobel_x = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
        sobel_y = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]
        # Register the kernels directly as buffers so they follow ``module.to(...)``.
        # They must not be assigned as plain attributes first: ``register_buffer`` raises
        # ``KeyError: attribute '...' already exists`` for a name that is already a non-buffer attribute.
        self.register_buffer('sobel_x', torch.tensor(sobel_x, dtype=torch.float32))
        self.register_buffer('sobel_y', torch.tensor(sobel_y, dtype=torch.float32))

    def forward(self, pred, target):
        '''
        Compute the per-element gradient-weighted L1 loss.

        :param pred: [B,C,H,W]
        :param target: [B,C,H,W]
        :return: [B,C,H,W] unreduced loss
        '''
        c = pred.shape[1]

        # Depthwise Sobel filtering: the same 3x3 kernel is applied to every channel (groups=c).
        # Match pred's device *and* dtype so non-float32 inputs (e.g. fp16/fp64) also work.
        sobel_x = self.sobel_x.expand(c, 1, 3, 3).to(device=pred.device, dtype=pred.dtype)
        sobel_y = self.sobel_y.expand(c, 1, 3, 3).to(device=pred.device, dtype=pred.dtype)
        Ix1 = F.conv2d(pred, sobel_x, stride=1, padding=1, groups=c)
        Ix2 = F.conv2d(target, sobel_x, stride=1, padding=1, groups=c)
        Iy1 = F.conv2d(pred, sobel_y, stride=1, padding=1, groups=c)
        Iy2 = F.conv2d(target, sobel_y, stride=1, padding=1, groups=c)

        dx = torch.abs(Ix1 - Ix2)
        dy = torch.abs(Iy1 - Iy2)
        loss = (1 + 4 * dx) * (1 + 4 * dy) * torch.abs(pred - target)
        return loss
|