hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
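The listing above shows that 2.1 is a structural rewrite rather than an incremental patch: the 0.9.0 trainers (hcpdiff/train_ac.py, hcpdiff/train_ac_single.py) are replaced by hcpdiff/trainer_ac.py and hcpdiff/trainer_ac_single.py, the noise/ package moves under hcpdiff/diffusion/noise/, the loggers/ and vis/ packages disappear in favour of hcpdiff/evaluate/ and an expanded hcpdiff/workflow/, and the bundled YAML configs under hcpdiff-0.9.0.data/ are removed (2.1 appears to replace them with the Python configs under hcpdiff/easy/cfg/). Below is a minimal sketch, not part of either package, of how downstream code might detect which layout is installed; it relies only on module paths visible in this listing, since the symbols exported by the 2.1 modules are not shown in this diff.

import importlib.util

def has_module(name: str) -> bool:
    # Return True if the module can be located without importing the whole package.
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False

# hcpdiff 0.9.0 ships hcpdiff.train_ac and hcpdiff.noise;
# hcpdiff 2.1 ships hcpdiff.trainer_ac and hcpdiff.diffusion.noise (see the file list above).
if has_module("hcpdiff.trainer_ac"):
    layout = "2.x layout (trainer_ac, diffusion.noise, workflow, evaluate)"
elif has_module("hcpdiff.train_ac"):
    layout = "0.9.x layout (train_ac, noise, loggers, vis)"
else:
    layout = "unknown"
print(f"installed hcpdiff layout: {layout}")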
hcpdiff/noise/zero_terminal.py
DELETED
@@ -1,44 +0,0 @@
import torch
from diffusers import SchedulerMixin
from .noise_base import NoiseBase

class ZeroTerminalScheduler(NoiseBase, SchedulerMixin):
    def __init__(self, base_scheduler):
        super().__init__(base_scheduler)
        base_scheduler.betas = self.rescale_zero_terminal_snr(base_scheduler.betas)
        base_scheduler.alphas = 1.0-base_scheduler.betas
        base_scheduler.alphas_cumprod = torch.cumprod(base_scheduler.alphas, dim=0)

    def rescale_zero_terminal_snr(self, betas):
        """
        Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
        Args:
            betas (`torch.FloatTensor`):
                the betas that the scheduler is being initialized with.
        Returns:
            `torch.FloatTensor`: rescaled betas with zero terminal SNR
        """
        # Convert betas to alphas_bar_sqrt
        alphas = 1.0-betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T

        # Scale so the first timestep is back to the old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0/(alphas_bar_sqrt_0-alphas_bar_sqrt_T)

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
        alphas = alphas_bar[1:]/alphas_bar[:-1] # Revert cumprod
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1-alphas

        return betas
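For reference, the rescaling that rescale_zero_terminal_snr performs above can be written out explicitly; this is a restatement of Algorithm 1 of the paper cited in its docstring (arXiv:2305.08891), with array index 0 corresponding to the first timestep t = 1 and \(\bar\alpha_t=\prod_{s\le t}(1-\beta_s)\):

\[
\sqrt{\bar\alpha_t}\;\leftarrow\;\bigl(\sqrt{\bar\alpha_t}-\sqrt{\bar\alpha_T}\bigr)\cdot\frac{\sqrt{\bar\alpha_1}}{\sqrt{\bar\alpha_1}-\sqrt{\bar\alpha_T}},
\qquad
\beta_t = 1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}\ (t>1),
\qquad
\beta_1 = 1-\bar\alpha_1 .
\]

The shift makes \(\bar\alpha_T=0\) (zero terminal SNR) while the scale keeps \(\bar\alpha_1\) at its original value; the last two formulas are what the "Revert sqrt" and "Revert cumprod" lines compute.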
hcpdiff/train_ac.py
DELETED
@@ -1,565 +0,0 @@
"""
train_ac.py
====================
    :Name:        train with accelerate
    :Author:      Dong Ziyi
    :Affiliation: HCP Lab, SYSU
    :Created:     10/03/2023
    :Licence:     Apache-2.0
"""

import argparse
import math
import os
import time
import warnings
from functools import partial

import diffusers
import hydra
import torch
import torch.utils.checkpoint
# fix checkpoint bug for train part of model
import torch.utils.checkpoint
import torch.utils.data
import transformers
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf

from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
from hcpdiff.data import RatioBucket, DataGroup, get_sampler
from hcpdiff.loggers import LoggerGroup
from hcpdiff.models import CFGContext, DreamArtistPTContext, TEUnetWrapper, SDXLTEUnetWrapper
from hcpdiff.models.compose import ComposeEmbPTHook, ComposeTEEXHook
from hcpdiff.models.compose import SDXLTextEncoder
from hcpdiff.utils.cfg_net_tools import make_hcpdiff, make_plugin
from hcpdiff.utils.ema import ModelEMA
from hcpdiff.utils.net_utils import get_scheduler, auto_tokenizer_cls, auto_text_encoder_cls, load_emb
from hcpdiff.utils.utils import load_config_with_cli, get_cfg_range, mgcd, format_number
from hcpdiff.visualizer import Visualizer

def checkpoint_fix(function, *args, use_reentrant: bool = False, checkpoint_raw=torch.utils.checkpoint.checkpoint, **kwargs):
    return checkpoint_raw(function, *args, use_reentrant=use_reentrant, **kwargs)

torch.utils.checkpoint.checkpoint = checkpoint_fix

class Trainer:
    weight_dtype_map = {'fp32':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
    ckpt_manager_map = {'torch':CkptManagerPKL, 'safetensors':CkptManagerSafe}

    def __init__(self, cfgs_raw):
        cfgs = hydra.utils.instantiate(cfgs_raw)
        self.cfgs = cfgs

        self.init_context(cfgs_raw)
        self.build_loggers(cfgs_raw)

        self.train_TE = any([cfgs.text_encoder, cfgs.lora_text_encoder, cfgs.plugin_TE])

        self.build_ckpt_manager()
        self.build_model()
        self.make_hooks()
        self.config_model()
        self.cache_latents = False

        self.batch_size_list = []
        assert len(cfgs.data)>0, "At least one dataset is need."
        loss_weights = [dataset.keywords['loss_weight'] for name, dataset in cfgs.data.items()]
        self.train_loader_group = DataGroup([self.build_data(dataset) for name, dataset in cfgs.data.items()], loss_weights)

        if self.cache_latents:
            self.vae = self.vae.to('cpu')
        self.build_optimizer_scheduler()
        try:
            self.criterion = cfgs.train.loss.criterion(noise_scheduler=self.noise_scheduler, device=self.device)
        except:
            self.criterion = cfgs.train.loss.criterion()

        self.cfg_scale = get_cfg_range(cfgs.train.cfg_scale)
        if self.cfg_scale[1] == 1.0:
            self.cfg_context = CFGContext()
        else: # DreamArtist
            self.cfg_context = DreamArtistPTContext(self.cfg_scale, self.num_train_timesteps)

        with torch.no_grad():
            self.build_ema()

        self.load_resume()

        torch.backends.cuda.matmul.allow_tf32 = cfgs.allow_tf32

        # calculate steps and epochs
        self.steps_per_epoch = len(self.train_loader_group.loader_list[0])
        if self.cfgs.train.train_epochs is not None:
            self.cfgs.train.train_steps = self.cfgs.train.train_epochs*self.steps_per_epoch
        else:
            self.cfgs.train.train_epochs = math.ceil(self.cfgs.train.train_steps/self.steps_per_epoch)

        if self.is_local_main_process and self.cfgs.previewer is not None:
            self.previewer = self.cfgs.previewer(exp_dir=self.exp_dir, te_hook=self.text_enc_hook, unet=self.TE_unet.unet,
                                                 TE=self.TE_unet.TE, tokenizer=self.tokenizer, vae=self.vae)

        self.prepare()

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_local_main_process(self):
        return self.accelerator.is_local_main_process

    def init_context(self, cfgs_raw):
        ddp_kwargs = DistributedDataParallelKwargs(broadcast_buffers=False)
        self.accelerator = Accelerator(
            gradient_accumulation_steps=self.cfgs.train.gradient_accumulation_steps,
            mixed_precision=self.cfgs.mixed_precision,
            step_scheduler_with_optimizer=False,
            kwargs_handlers=[ddp_kwargs], # fix inplace bug in DDP while use data_class
        )

        self.local_rank = int(os.environ.get("LOCAL_RANK", -1))
        self.world_size = self.accelerator.num_processes

        set_seed(self.cfgs.seed+self.local_rank)

    def build_loggers(self, cfgs_raw):
        if self.is_local_main_process:
            self.exp_dir = self.cfgs.exp_dir.format(time=time.strftime("%Y-%m-%d-%H-%M-%S"))
            os.makedirs(os.path.join(self.exp_dir, 'ckpts/'), exist_ok=True)
            with open(os.path.join(self.exp_dir, 'cfg.yaml'), 'w', encoding='utf-8') as f:
                f.write(OmegaConf.to_yaml(cfgs_raw))
            self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=self.exp_dir) for builder in self.cfgs.logger])
        else:
            self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=None) for builder in self.cfgs.logger])

        self.min_log_step = mgcd(*([item.log_step for item in self.loggers.logger_list]))
        image_log_steps = [item.image_log_step for item in self.loggers.logger_list if item.enable_log_image]
        if len(image_log_steps)>0:
            self.min_img_log_step = mgcd(*image_log_steps)
        else:
            self.min_img_log_step = -1

        self.loggers.info(f'world_size: {self.world_size}')
        self.loggers.info(f'accumulation: {self.cfgs.train.gradient_accumulation_steps}')

        if self.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_warning()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

    def prepare(self):
        # Prepare everything with accelerator.
        prepare_name_list, prepare_obj_list = [], []
        if self.TE_unet.train_TE:
            prepare_obj_list.append(self.TE_unet)
            prepare_name_list.append('TE_unet')
        else:
            prepare_obj_list.append(self.TE_unet.unet)
            prepare_name_list.append('TE_unet.unet')

        if hasattr(self, 'optimizer'):
            prepare_obj_list.extend([self.optimizer, self.lr_scheduler] if self.lr_scheduler else [self.optimizer])
            prepare_name_list.extend(['optimizer', 'lr_scheduler'] if self.lr_scheduler else ['optimizer'])
        if hasattr(self, 'optimizer_pt'):
            prepare_obj_list.extend([self.optimizer_pt, self.lr_scheduler_pt] if self.lr_scheduler_pt else [self.optimizer_pt])
            prepare_name_list.extend(['optimizer_pt', 'lr_scheduler_pt'] if self.lr_scheduler_pt else ['optimizer_pt'])

        prepare_obj_list.extend(self.train_loader_group.loader_list)
        prepared_obj = self.accelerator.prepare(*prepare_obj_list)

        if not self.TE_unet.train_TE:
            self.TE_unet.unet = prepared_obj[0]
            prepared_obj = prepared_obj[1:]
            prepare_name_list = prepare_name_list[1:]

        ds_num = len(self.train_loader_group.loader_list)
        self.train_loader_group.loader_list = list(prepared_obj[-ds_num:])
        prepared_obj = prepared_obj[:-ds_num]

        for name, obj in zip(prepare_name_list, prepared_obj):
            setattr(self, name, obj)

        if self.cfgs.model.force_cast_precision:
            self.TE_unet.to(dtype=self.weight_dtype)

    def scale_lr(self, parameters):
        bs = sum(self.batch_size_list)
        scale_factor = bs*self.world_size*self.cfgs.train.gradient_accumulation_steps
        for param in parameters:
            if 'lr' in param:
                param['lr'] *= scale_factor

    def build_model(self):
        # Load the tokenizer
        if self.cfgs.model.get('tokenizer', None) is not None:
            self.tokenizer = self.cfgs.model.tokenizer
        else:
            tokenizer_cls = auto_tokenizer_cls(self.cfgs.model.pretrained_model_name_or_path, self.cfgs.model.revision)
            self.tokenizer = tokenizer_cls.from_pretrained(
                self.cfgs.model.pretrained_model_name_or_path, subfolder="tokenizer",
                revision=self.cfgs.model.revision, use_fast=False,
            )

        # Load scheduler and models
        self.noise_scheduler = self.cfgs.model.get('noise_scheduler', None) or \
            DDPMScheduler.from_pretrained(self.cfgs.model.pretrained_model_name_or_path, subfolder='scheduler')

        self.num_train_timesteps = len(self.noise_scheduler.timesteps)
        self.vae: AutoencoderKL = self.cfgs.model.get('vae', None) or AutoencoderKL.from_pretrained(
            self.cfgs.model.pretrained_model_name_or_path, subfolder="vae", revision=self.cfgs.model.revision)
        self.build_unet_and_TE()

    def build_unet_and_TE(self): # for easy to use colossalAI
        unet = self.cfgs.model.get('unet', None) or UNet2DConditionModel.from_pretrained(
            self.cfgs.model.pretrained_model_name_or_path, subfolder="unet", revision=self.cfgs.model.revision
        )

        if self.cfgs.model.get('text_encoder', None) is not None:
            text_encoder = self.cfgs.model.text_encoder
            text_encoder_cls = type(text_encoder)
        else:
            # import correct text encoder class
            text_encoder_cls = auto_text_encoder_cls(self.cfgs.model.pretrained_model_name_or_path, self.cfgs.model.revision)
            text_encoder = text_encoder_cls.from_pretrained(
                self.cfgs.model.pretrained_model_name_or_path, subfolder="text_encoder", revision=self.cfgs.model.revision
            )

        # Wrap unet and text_encoder to make DDP happy. Multiple DDP has soooooo many fxxking bugs!
        wrapper_cls = SDXLTEUnetWrapper if text_encoder_cls == SDXLTextEncoder else TEUnetWrapper
        self.TE_unet = wrapper_cls(unet, text_encoder, train_TE=self.train_TE)

    def build_ema(self):
        if self.cfgs.model.ema is not None:
            self.ema_unet = self.cfgs.model.ema(self.TE_unet.unet)
            if self.train_TE:
                self.ema_text_encoder = self.cfgs.model.ema(self.TE_unet.TE)

    def build_ckpt_manager(self):
        self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type]()
        if self.is_local_main_process:
            self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)

    @property
    def unet_raw(self):
        return self.TE_unet.module.unet if self.train_TE else self.TE_unet.unet.module

    @property
    def TE_raw(self):
        return self.TE_unet.module.TE if self.train_TE else self.TE_unet.TE

    def config_model(self):
        if self.cfgs.model.enable_xformers:
            if is_xformers_available():
                self.TE_unet.unet.enable_xformers_memory_efficient_attention()
                # self.text_enc_hook.enable_xformers()
            else:
                warnings.warn("xformers is not available. Make sure it is installed correctly")

        self.vae.requires_grad_(False)
        self.TE_unet.requires_grad_(False)

        self.TE_unet.eval()

        if self.cfgs.model.gradient_checkpointing:
            self.TE_unet.enable_gradient_checkpointing()

        self.weight_dtype = self.weight_dtype_map.get(self.cfgs.mixed_precision, torch.float32)
        self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
        # Move vae and text_encoder to device and cast to weight_dtype
        self.vae = self.vae.to(self.device, dtype=self.vae_dtype)
        if not self.train_TE:
            self.TE_unet.TE = self.TE_unet.TE.to(self.device, dtype=self.weight_dtype)

    @torch.no_grad()
    def load_resume(self):
        if self.cfgs.train.resume is not None:
            for ckpt in self.cfgs.train.resume.ckpt_path.unet:
                self.ckpt_manager.load_ckpt_to_model(self.TE_unet.unet, ckpt, model_ema=getattr(self, 'ema_unet', None))
            for ckpt in self.cfgs.train.resume.ckpt_path.TE:
                self.ckpt_manager.load_ckpt_to_model(self.TE_unet.TE, ckpt, model_ema=getattr(self, 'ema_text_encoder', None))
            for name, ckpt in self.cfgs.train.resume.ckpt_path.words:
                self.ex_words_emb[name].data = load_emb(ckpt)

    def make_hooks(self):
        # Hook tokenizer and embedding to support pt
        self.embedding_hook, self.ex_words_emb = ComposeEmbPTHook.hook_from_dir(
            self.cfgs.tokenizer_pt.emb_dir, self.tokenizer, self.TE_unet.TE, log=self.is_local_main_process,
            N_repeats=self.cfgs.model.tokenizer_repeats, device=self.device)

        self.text_enc_hook = ComposeTEEXHook.hook(self.TE_unet.TE, self.tokenizer, N_repeats=self.cfgs.model.tokenizer_repeats,
                                                  device=self.device, clip_skip=self.cfgs.model.clip_skip,
                                                  clip_final_norm=self.cfgs.model.clip_final_norm)

    def build_dataset(self, data_builder: partial):
        batch_size = data_builder.keywords.pop('batch_size')
        cache_latents = data_builder.keywords.pop('cache_latents')
        self.batch_size_list.append(batch_size)

        train_dataset = data_builder(tokenizer=self.tokenizer, tokenizer_repeats=self.cfgs.model.tokenizer_repeats)
        train_dataset.bucket.build(batch_size*self.world_size, file_names=train_dataset.source.get_image_list())
        arb = isinstance(train_dataset.bucket, RatioBucket)
        self.loggers.info(f"len(train_dataset): {len(train_dataset)}")

        if cache_latents:
            self.cache_latents = True
            train_dataset.cache_latents(self.vae, self.vae_dtype, self.device, show_prog=self.is_local_main_process)
        return train_dataset, batch_size, arb

    def build_data(self, data_builder: partial) -> torch.utils.data.DataLoader:
        train_dataset, batch_size, arb = self.build_dataset(data_builder)

        # Pytorch Data loader
        train_sampler = get_sampler()(train_dataset, num_replicas=self.world_size, rank=self.local_rank, shuffle=not arb)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=self.cfgs.train.workers,
                                                   sampler=train_sampler, collate_fn=train_dataset.collate_fn)
        return train_loader

    def get_param_group_train(self):
        # make miniFT and warp with lora
        self.DA_lora = False
        train_params_unet, self.lora_unet = make_hcpdiff(self.TE_unet.unet, self.cfgs.unet, self.cfgs.lora_unet)
        if isinstance(self.lora_unet, tuple): # creat negative lora
            self.DA_lora = True
            self.lora_unet, self.lora_unet_neg = self.lora_unet
        train_params_unet_plugin, self.all_plugin_unet = make_plugin(self.TE_unet.unet, self.cfgs.plugin_unet)
        train_params_unet += train_params_unet_plugin

        if self.train_TE:
            train_params_text_encoder, self.lora_TE = make_hcpdiff(self.TE_unet.TE, self.cfgs.text_encoder, self.cfgs.lora_text_encoder)
            if isinstance(self.lora_TE, tuple): # creat negative lora
                self.DA_lora = True
                self.lora_TE, self.lora_TE_neg = self.lora_TE
            train_params_TE_plugin, self.all_plugin_TE = make_plugin(self.TE_unet.TE, self.cfgs.plugin_TE)
            train_params_text_encoder += train_params_TE_plugin
        else:
            train_params_text_encoder = []

        N_params_unet = format_number(sum(sum(x.numel() for x in p['params']) for p in train_params_unet))
        N_params_TE = format_number(sum(sum(x.numel() for x in p['params']) for p in train_params_text_encoder))
        self.loggers.info(f'unet trainable params: {N_params_unet}; text encoder trainable params: {N_params_TE}')

        # params for embedding
        train_params_emb = []
        self.train_pts = {}
        if self.cfgs.tokenizer_pt.train is not None:
            for v in self.cfgs.tokenizer_pt.train:
                word_emb = self.ex_words_emb[v.name]
                self.train_pts[v.name] = word_emb
                word_emb.requires_grad = True
                self.embedding_hook.emb_train.append(word_emb)
                train_params_emb.append({'params':word_emb, 'lr':v.lr})

        return train_params_unet+train_params_text_encoder, train_params_emb

    def build_optimizer_scheduler(self):
        # set optimizer
        parameters, parameters_pt = self.get_param_group_train()

        if len(parameters)>0: # do fine-tuning
            cfg_opt = self.cfgs.train.optimizer
            if self.cfgs.train.scale_lr:
                self.scale_lr(parameters)
            assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
            self.optimizer = cfg_opt(params=parameters)
            self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)

        if len(parameters_pt)>0: # do prompt-tuning
            cfg_opt_pt = self.cfgs.train.optimizer_pt
            if self.cfgs.train.scale_lr_pt:
                self.scale_lr(parameters_pt)
            assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
            self.optimizer_pt = cfg_opt_pt(params=parameters_pt)
            self.lr_scheduler_pt = get_scheduler(self.cfgs.train.scheduler_pt, self.optimizer_pt)

    def train(self, loss_ema=0.93):
        total_batch_size = sum(self.batch_size_list)*self.world_size*self.cfgs.train.gradient_accumulation_steps

        self.loggers.info("***** Running training *****")
        self.loggers.info(f" Num batches each epoch = {len(self.train_loader_group.loader_list[0])}")
        self.loggers.info(f" Num Steps = {self.cfgs.train.train_steps}")
        self.loggers.info(f" Instantaneous batch size per device = {sum(self.batch_size_list)}")
        self.loggers.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        self.loggers.info(f" Gradient Accumulation steps = {self.cfgs.train.gradient_accumulation_steps}")
        self.global_step = 0
        if self.cfgs.train.resume is not None:
            self.global_step = self.cfgs.train.resume.start_step

        loss_sum = None
        for data_list in self.train_loader_group:
            loss = self.train_one_step(data_list)
            loss_sum = loss if loss_sum is None else (loss_ema*loss_sum+(1-loss_ema)*loss)

            self.global_step += 1
            if self.is_local_main_process:
                if self.global_step%self.cfgs.train.save_step == 0:
                    self.save_model()
                if self.global_step%self.min_log_step == 0:
                    # get learning rate from optimizer
                    lr_model = self.optimizer.param_groups[0]['lr'] if hasattr(self, 'optimizer') else 0.
                    lr_word = self.optimizer_pt.param_groups[0]['lr'] if hasattr(self, 'optimizer_pt') else 0.
                    self.loggers.log(datas={
                        'Step':{'format':'[{}/{}]', 'data':[self.global_step, self.cfgs.train.train_steps]},
                        'Epoch':{'format':'[{}/{}]<{}/{}>', 'data':[self.global_step//self.steps_per_epoch, self.cfgs.train.train_epochs,
                                                                    self.global_step%self.steps_per_epoch, self.steps_per_epoch]},
                        'LR_model':{'format':'{:.2e}', 'data':[lr_model]},
                        'LR_word':{'format':'{:.2e}', 'data':[lr_word]},
                        'Loss':{'format':'{:.5f}', 'data':[loss_sum]},
                    }, step=self.global_step)
                if self.min_img_log_step>0 and self.global_step%self.min_img_log_step == 0:
                    self.loggers.log_image(self.previewer.preview_dict(), self.global_step)

            if self.global_step>=self.cfgs.train.train_steps:
                break

        self.wait_for_everyone()
        if self.is_local_main_process:
            self.save_model()

    def wait_for_everyone(self):
        self.accelerator.wait_for_everyone()

    @torch.no_grad()
    def get_latents(self, image, dataset):
        if dataset.latents is None:
            latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
            latents = latents*self.vae.config.scaling_factor
        else:
            latents = image # Cached latents
        return latents

    def make_noise(self, latents):
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        return self.noise_scheduler.add_noise(latents, noise, timesteps), noise, timesteps

    def forward(self, latents, prompt_ids, attn_mask=None, position_ids=None, **kwargs):
        noisy_latents, noise, timesteps = self.make_noise(latents)

        # CFG context for DreamArtist
        noisy_latents, timesteps = self.cfg_context.pre(noisy_latents, timesteps)
        model_pred = self.TE_unet(prompt_ids, noisy_latents, timesteps, attn_mask=attn_mask, position_ids=position_ids, **kwargs)
        model_pred = self.cfg_context.post(model_pred)

        # Get the target for loss depending on the prediction type
        if self.cfgs.train.loss.type == "eps":
            target = noise
        elif self.cfgs.train.loss.type == "sample":
            target = self.noise_scheduler.step(noise, timesteps, noisy_latents)
            model_pred = self.noise_scheduler.step(model_pred, timesteps, noisy_latents)
        else:
            raise ValueError(f"Unknown loss type {self.cfgs.train.loss.type}")
        return model_pred, target, timesteps

    def train_one_step(self, data_list):
        with self.accelerator.accumulate(self.TE_unet):
            for idx, data in enumerate(data_list):
                image = data.pop('img').to(self.device, dtype=self.weight_dtype)
                img_mask = data.pop('mask').to(self.device) if 'mask' in data else None
                prompt_ids = data.pop('prompt').to(self.device)
                attn_mask = data.pop('attn_mask').to(self.device) if 'attn_mask' in data else None
                position_ids = data.pop('position_ids').to(self.device) if 'position_ids' in data else None
                other_datas = {k:v.to(self.device) for k, v in data.items() if k!='plugin_input'}
                if 'plugin_input' in data:
                    other_datas['plugin_input'] = {k:v.to(self.device, dtype=self.weight_dtype) for k, v in data['plugin_input'].items()}

                latents = self.get_latents(image, self.train_loader_group.get_dataset(idx))
                model_pred, target, timesteps = self.forward(latents, prompt_ids, attn_mask, position_ids, **other_datas)
                loss = self.get_loss(model_pred, target, timesteps, img_mask)*self.train_loader_group.get_loss_weights(idx)
                self.accelerator.backward(loss)

            if hasattr(self, 'optimizer'):
                if self.accelerator.sync_gradients: # fine-tuning
                    if hasattr(self.TE_unet, 'trainable_parameters'):
                        clip_param = self.TE_unet.trainable_parameters()
                    else:
                        clip_param = self.TE_unet.module.trainable_parameters()
                    self.accelerator.clip_grad_norm_(clip_param, self.cfgs.train.max_grad_norm)
                self.optimizer.step()
                if self.lr_scheduler:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad(set_to_none=self.cfgs.train.set_grads_to_none)

            if hasattr(self, 'optimizer_pt'): # prompt tuning
                self.optimizer_pt.step()
                if self.lr_scheduler_pt:
                    self.lr_scheduler_pt.step()
                self.optimizer_pt.zero_grad(set_to_none=self.cfgs.train.set_grads_to_none)

            if self.accelerator.sync_gradients:
                self.update_ema()
        return loss.item()

    def get_loss(self, model_pred, target, timesteps, att_mask):
        if att_mask is None:
            att_mask = 1.0
        if getattr(self.criterion, 'need_timesteps', False):
            loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
        else:
            loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
        if len(self.embedding_hook.emb_train)>0:
            loss = loss+0*sum([emb.mean() for emb in self.embedding_hook.emb_train])
        return loss

    def update_ema(self):
        if hasattr(self, 'ema_unet'):
            self.ema_unet.update(self.unet_raw)
        if hasattr(self, 'ema_text_encoder'):
            self.ema_text_encoder.update(self.TE_raw)

    def save_model(self, from_raw=False):
        unet_raw = self.unet_raw
        self.ckpt_manager.save_model_with_lora(unet_raw, self.lora_unet, model_ema=getattr(self, 'ema_unet', None),
                                               name='unet', step=self.global_step)
        self.ckpt_manager.save_plugins(unet_raw, self.all_plugin_unet, name='unet', step=self.global_step,
                                       model_ema=getattr(self, 'ema_unet', None))
        if self.train_TE:
            TE_raw = self.TE_raw
            # exclude_key: embeddings should not save with text-encoder
            self.ckpt_manager.save_model_with_lora(TE_raw, self.lora_TE, model_ema=getattr(self, 'ema_text_encoder', None),
                                                   name='text_encoder', step=self.global_step, exclude_key='emb_ex.')
            self.ckpt_manager.save_plugins(TE_raw, self.all_plugin_TE, name='text_encoder', step=self.global_step,
                                           model_ema=getattr(self, 'ema_text_encoder', None))

        if self.DA_lora:
            self.ckpt_manager.save_model_with_lora(None, self.lora_unet_neg, name='unet-neg', step=self.global_step)
            if self.train_TE:
                self.ckpt_manager.save_model_with_lora(None, self.lora_TE_neg, name='text_encoder-neg', step=self.global_step)

        self.ckpt_manager.save_embedding(self.train_pts, self.global_step, self.cfgs.tokenizer_pt.replace)

        self.loggers.info(f"Saved state, step: {self.global_step}")

    def make_vis(self):
        vis_dir = os.path.join(self.exp_dir, f'vis-{self.global_step}')
        new_components = {
            'unet':self.unet_raw,
            'text_encoder':self.TE_raw,
            'tokenizer':self.tokenizer,
            'vae':self.vae,
        }
        viser = Visualizer(self.cfgs.model.pretrained_model_name_or_path, new_components=new_components)
        if self.cfgs.vis_info.prompt:
            raise ValueError('vis_info.prompt is None. cannot generate without prompt.')
        viser.vis_to_dir(vis_dir, self.cfgs.vis_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Stable Diffusion Training')
    parser.add_argument('--cfg', type=str, default=None, required=True)
    args, cfg_args = parser.parse_known_args()

    conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
    trainer = Trainer(conf)
    trainer.train()
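Judging from the __main__ block above, this deleted trainer was driven entirely by a config file passed via --cfg (with any remaining CLI arguments forwarded to load_config_with_cli as overrides), presumably launched under accelerate for multi-GPU runs. Below is a minimal programmatic sketch against the 0.9.0 API shown above; the config path and the override syntax are illustrative assumptions, not taken from this diff.

# Sketch against the deleted 0.9.0 API (Trainer and load_config_with_cli appear in the file above).
from hcpdiff.train_ac import Trainer
from hcpdiff.utils.utils import load_config_with_cli

# 'cfgs/train/examples/lora_conventional.yaml' mirrors one of the example configs listed
# under hcpdiff-0.9.0.data/ above; adjust the path to wherever the configs live in your install.
conf = load_config_with_cli('cfgs/train/examples/lora_conventional.yaml',
                            args_list=['train.train_steps=1000'])  # override syntax assumed (OmegaConf dot-list style)
Trainer(conf).train()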
hcpdiff/train_ac_single.py
DELETED
@@ -1,39 +0,0 @@
import argparse
import sys
from functools import partial

import torch
from accelerate import Accelerator
from loguru import logger

from hcpdiff.train_ac import Trainer, RatioBucket, load_config_with_cli, set_seed, get_sampler

class TrainerSingleCard(Trainer):
    def init_context(self, cfgs_raw):
        self.accelerator = Accelerator(
            gradient_accumulation_steps=self.cfgs.train.gradient_accumulation_steps,
            mixed_precision=self.cfgs.mixed_precision,
            step_scheduler_with_optimizer=False,
        )

        self.local_rank = 0
        self.world_size = self.accelerator.num_processes

        set_seed(self.cfgs.seed+self.local_rank)

    @property
    def unet_raw(self):
        return self.TE_unet.unet

    @property
    def TE_raw(self):
        return self.TE_unet.TE

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Stable Diffusion Training')
    parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
    args, cfg_args = parser.parse_known_args()

    conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
    trainer = TrainerSingleCard(conf)
    trainer.train()
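The single-card variant only overrides init_context (dropping the DistributedDataParallelKwargs handler and pinning local_rank to 0) and the unet_raw/TE_raw properties, which no longer need to unwrap a DDP .module. A sketch of the matching programmatic entry, parallel to the one after train_ac.py above; the config path is illustrative, and in 2.1 both of these modules are superseded by hcpdiff/trainer_ac.py and hcpdiff/trainer_ac_single.py plus the console scripts in hcpdiff-2.1.dist-info/entry_points.txt.

# Same pattern as the sketch above, using the deleted single-card trainer (no accelerate launch needed).
from hcpdiff.train_ac_single import TrainerSingleCard
from hcpdiff.utils.utils import load_config_with_cli

conf = load_config_with_cli('cfgs/train/examples/fine-tuning.yaml', args_list=[])  # illustrative path
TrainerSingleCard(conf).train()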