hcpdiff 2.2__py3-none-any.whl → 2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/ckpt.py +21 -17
- hcpdiff/ckpt_manager/format/diffusers.py +4 -4
- hcpdiff/ckpt_manager/format/sd_single.py +3 -3
- hcpdiff/ckpt_manager/loader.py +11 -4
- hcpdiff/diffusion/noise/__init__.py +0 -1
- hcpdiff/diffusion/sampler/VP.py +27 -0
- hcpdiff/diffusion/sampler/__init__.py +2 -3
- hcpdiff/diffusion/sampler/base.py +106 -44
- hcpdiff/diffusion/sampler/diffusers.py +11 -17
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
- hcpdiff/easy/cfg/sd15_train.py +35 -24
- hcpdiff/easy/cfg/sdxl_train.py +34 -25
- hcpdiff/evaluate/__init__.py +3 -1
- hcpdiff/evaluate/evaluator.py +76 -0
- hcpdiff/evaluate/metrics/__init__.py +1 -0
- hcpdiff/evaluate/metrics/clip_score.py +23 -0
- hcpdiff/evaluate/previewer.py +29 -12
- hcpdiff/loss/base.py +9 -26
- hcpdiff/loss/weighting.py +36 -18
- hcpdiff/models/lora_base_patch.py +26 -0
- hcpdiff/models/text_emb_ex.py +4 -0
- hcpdiff/models/wrapper/sd.py +17 -19
- hcpdiff/trainer_ac.py +7 -12
- hcpdiff/trainer_ac_single.py +1 -6
- hcpdiff/trainer_deepspeed.py +47 -0
- hcpdiff/utils/__init__.py +2 -1
- hcpdiff/utils/torch_utils.py +25 -0
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +27 -7
- hcpdiff/workflow/io.py +20 -3
- hcpdiff/workflow/text.py +6 -1
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/METADATA +8 -4
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/RECORD +43 -39
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/WHEEL +1 -1
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/entry_points.txt +1 -0
- hcpdiff/diffusion/noise/zero_terminal.py +0 -39
- hcpdiff/diffusion/sampler/ddpm.py +0 -20
- hcpdiff/diffusion/sampler/edm.py +0 -22
- hcpdiff/train_deepspeed.py +0 -69
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/top_level.txt +0 -0
@@ -1,20 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
|
3
|
-
from .base import BaseSampler
|
4
|
-
from .sigma_scheduler import SigmaScheduler
|
5
|
-
|
6
|
-
class DDPMSampler(BaseSampler):
|
7
|
-
def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator=None):
|
8
|
-
super().__init__(sigma_scheduler, generator)
|
9
|
-
|
10
|
-
def c_in(self, sigma):
|
11
|
-
return 1./(sigma**2+1).sqrt()
|
12
|
-
|
13
|
-
def c_out(self, sigma):
|
14
|
-
return -sigma
|
15
|
-
|
16
|
-
def c_skip(self, sigma):
|
17
|
-
return 1.
|
18
|
-
|
19
|
-
def denoise(self, x, sigma, eps=None, generator=None):
|
20
|
-
raise NotImplementedError
|
hcpdiff/diffusion/sampler/edm.py
DELETED
@@ -1,22 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
|
3
|
-
from .base import BaseSampler
|
4
|
-
from .sigma_scheduler import SigmaScheduler
|
5
|
-
|
6
|
-
class EDMSampler(BaseSampler):
|
7
|
-
def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator = None, sigma_data: float = 1.0, sigma_thr=1000):
|
8
|
-
super().__init__(sigma_scheduler, generator)
|
9
|
-
self.sigma_data = sigma_data
|
10
|
-
self.sigma_thr = sigma_thr
|
11
|
-
|
12
|
-
def c_in(self, sigma):
|
13
|
-
return 1/(sigma**2+self.sigma_data**2).sqrt()
|
14
|
-
|
15
|
-
def c_out(self, sigma):
|
16
|
-
return (sigma*self.sigma_data)/(sigma**2+self.sigma_data**2).sqrt()
|
17
|
-
|
18
|
-
def c_skip(self, sigma):
|
19
|
-
return self.sigma_data**2/(sigma**2+self.sigma_data**2)
|
20
|
-
|
21
|
-
def denoise(self, x, sigma, eps=None, generator=None):
|
22
|
-
raise NotImplementedError
|
hcpdiff/train_deepspeed.py
DELETED
@@ -1,69 +0,0 @@
|
|
1
|
-
import argparse
|
2
|
-
import os
|
3
|
-
import sys
|
4
|
-
import warnings
|
5
|
-
from functools import partial
|
6
|
-
|
7
|
-
import torch
|
8
|
-
|
9
|
-
from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
|
10
|
-
from hcpdiff.train_ac_old import Trainer, load_config_with_cli
|
11
|
-
from hcpdiff.utils.net_utils import get_scheduler
|
12
|
-
|
13
|
-
class TrainerDeepSpeed(Trainer):
|
14
|
-
|
15
|
-
def build_ckpt_manager(self):
|
16
|
-
self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type](plugin_from_raw=True)
|
17
|
-
if self.is_local_main_process:
|
18
|
-
self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)
|
19
|
-
|
20
|
-
@property
|
21
|
-
def unet_raw(self):
|
22
|
-
return self.accelerator.unwrap_model(self.TE_unet).unet if self.train_TE else self.accelerator.unwrap_model(self.TE_unet.unet)
|
23
|
-
|
24
|
-
@property
|
25
|
-
def TE_raw(self):
|
26
|
-
return self.accelerator.unwrap_model(self.TE_unet).TE if self.train_TE else self.TE_unet.TE
|
27
|
-
|
28
|
-
def get_loss(self, model_pred, target, timesteps, att_mask):
|
29
|
-
if att_mask is None:
|
30
|
-
att_mask = 1.0
|
31
|
-
if getattr(self.criterion, 'need_timesteps', False):
|
32
|
-
loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
|
33
|
-
else:
|
34
|
-
loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
|
35
|
-
return loss
|
36
|
-
|
37
|
-
def build_optimizer_scheduler(self):
|
38
|
-
# set optimizer
|
39
|
-
parameters, parameters_pt = self.get_param_group_train()
|
40
|
-
|
41
|
-
if len(parameters_pt)>0: # do prompt-tuning
|
42
|
-
cfg_opt_pt = self.cfgs.train.optimizer_pt
|
43
|
-
# if self.cfgs.train.scale_lr_pt:
|
44
|
-
# self.scale_lr(parameters_pt)
|
45
|
-
assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
|
46
|
-
weight_decay = cfg_opt_pt.keywords.get('weight_decay', None)
|
47
|
-
if weight_decay is not None:
|
48
|
-
for param in parameters_pt:
|
49
|
-
param['weight_decay'] = weight_decay
|
50
|
-
|
51
|
-
parameters += parameters_pt
|
52
|
-
warnings.warn('deepspeed dose not support multi optimizer and lr_scheduler. optimizer_pt and scheduler_pt will not work.')
|
53
|
-
|
54
|
-
if len(parameters)>0:
|
55
|
-
cfg_opt = self.cfgs.train.optimizer
|
56
|
-
if self.cfgs.train.scale_lr:
|
57
|
-
self.scale_lr(parameters)
|
58
|
-
assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
|
59
|
-
self.optimizer = cfg_opt(params=parameters)
|
60
|
-
self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)
|
61
|
-
|
62
|
-
if __name__ == '__main__':
|
63
|
-
parser = argparse.ArgumentParser(description='Stable Diffusion Training')
|
64
|
-
parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
|
65
|
-
args, cfg_args = parser.parse_known_args()
|
66
|
-
|
67
|
-
conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
|
68
|
-
trainer = TrainerDeepSpeed(conf)
|
69
|
-
trainer.train()
|
File without changes
|
File without changes
|