hcpdiff 2.2-py3-none-any.whl → 2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. hcpdiff/ckpt_manager/__init__.py +1 -1
  2. hcpdiff/ckpt_manager/ckpt.py +21 -17
  3. hcpdiff/ckpt_manager/format/diffusers.py +4 -4
  4. hcpdiff/ckpt_manager/format/sd_single.py +3 -3
  5. hcpdiff/ckpt_manager/loader.py +11 -4
  6. hcpdiff/diffusion/noise/__init__.py +0 -1
  7. hcpdiff/diffusion/sampler/VP.py +27 -0
  8. hcpdiff/diffusion/sampler/__init__.py +2 -3
  9. hcpdiff/diffusion/sampler/base.py +106 -44
  10. hcpdiff/diffusion/sampler/diffusers.py +11 -17
  11. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
  12. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
  13. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
  14. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
  15. hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
  16. hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
  17. hcpdiff/easy/cfg/sd15_train.py +35 -24
  18. hcpdiff/easy/cfg/sdxl_train.py +34 -25
  19. hcpdiff/evaluate/__init__.py +3 -1
  20. hcpdiff/evaluate/evaluator.py +76 -0
  21. hcpdiff/evaluate/metrics/__init__.py +1 -0
  22. hcpdiff/evaluate/metrics/clip_score.py +23 -0
  23. hcpdiff/evaluate/previewer.py +29 -12
  24. hcpdiff/loss/base.py +9 -26
  25. hcpdiff/loss/weighting.py +36 -18
  26. hcpdiff/models/lora_base_patch.py +26 -0
  27. hcpdiff/models/text_emb_ex.py +4 -0
  28. hcpdiff/models/wrapper/sd.py +17 -19
  29. hcpdiff/trainer_ac.py +7 -12
  30. hcpdiff/trainer_ac_single.py +1 -6
  31. hcpdiff/trainer_deepspeed.py +47 -0
  32. hcpdiff/utils/__init__.py +2 -1
  33. hcpdiff/utils/torch_utils.py +25 -0
  34. hcpdiff/workflow/__init__.py +1 -1
  35. hcpdiff/workflow/diffusion.py +27 -7
  36. hcpdiff/workflow/io.py +20 -3
  37. hcpdiff/workflow/text.py +6 -1
  38. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/METADATA +8 -4
  39. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/RECORD +43 -39
  40. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/WHEEL +1 -1
  41. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/entry_points.txt +1 -0
  42. hcpdiff/diffusion/noise/zero_terminal.py +0 -39
  43. hcpdiff/diffusion/sampler/ddpm.py +0 -20
  44. hcpdiff/diffusion/sampler/edm.py +0 -22
  45. hcpdiff/train_deepspeed.py +0 -69
  46. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/licenses/LICENSE +0 -0
  47. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/top_level.txt +0 -0
hcpdiff/diffusion/sampler/ddpm.py (deleted)
@@ -1,20 +0,0 @@
- import torch
-
- from .base import BaseSampler
- from .sigma_scheduler import SigmaScheduler
-
- class DDPMSampler(BaseSampler):
-     def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator=None):
-         super().__init__(sigma_scheduler, generator)
-
-     def c_in(self, sigma):
-         return 1./(sigma**2+1).sqrt()
-
-     def c_out(self, sigma):
-         return -sigma
-
-     def c_skip(self, sigma):
-         return 1.
-
-     def denoise(self, x, sigma, eps=None, generator=None):
-         raise NotImplementedError
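The c_in / c_out / c_skip coefficients in the deleted DDPMSampler are the ε-prediction preconditioning of Karras et al. (2022): a raw network output F is wrapped as D(x, σ) = c_skip(σ)·x + c_out(σ)·F(c_in(σ)·x, σ). Below is a minimal sketch of that combination, assuming a noise-predicting `model`; the function and argument names are illustrative only, not hcpdiff API.

```python
import torch

def denoise_ddpm(model, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # D(x, sigma) = c_skip*x + c_out*model(c_in*x, sigma). With c_skip=1,
    # c_out=-sigma and c_in=1/sqrt(sigma^2+1) (the coefficients above), this is
    # the usual epsilon-prediction form x0_hat = x - sigma*eps_hat,
    # where x is expressed in the x0 + sigma*n parameterization.
    c_in = 1./(sigma**2 + 1).sqrt()
    return x - sigma*model(c_in*x, sigma)
```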
hcpdiff/diffusion/sampler/edm.py (deleted)
@@ -1,22 +0,0 @@
- import torch
-
- from .base import BaseSampler
- from .sigma_scheduler import SigmaScheduler
-
- class EDMSampler(BaseSampler):
-     def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator = None, sigma_data: float = 1.0, sigma_thr=1000):
-         super().__init__(sigma_scheduler, generator)
-         self.sigma_data = sigma_data
-         self.sigma_thr = sigma_thr
-
-     def c_in(self, sigma):
-         return 1/(sigma**2+self.sigma_data**2).sqrt()
-
-     def c_out(self, sigma):
-         return (sigma*self.sigma_data)/(sigma**2+self.sigma_data**2).sqrt()
-
-     def c_skip(self, sigma):
-         return self.sigma_data**2/(sigma**2+self.sigma_data**2)
-
-     def denoise(self, x, sigma, eps=None, generator=None):
-         raise NotImplementedError
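Likewise, the deleted EDMSampler carries the EDM preconditioning with a configurable sigma_data; the 2.3 refactor presumably folds this logic into the expanded sampler/base.py and the new sampler/VP.py listed above. A hedged sketch of the same combination (names below are illustrative, not hcpdiff API):

```python
import torch

def denoise_edm(model, x: torch.Tensor, sigma: torch.Tensor, sigma_data: float = 1.0) -> torch.Tensor:
    # Coefficients copied from the removed EDMSampler above; wrapping the raw
    # network with them yields a prediction of the clean sample at any sigma.
    c_in = 1/(sigma**2 + sigma_data**2).sqrt()
    c_out = (sigma*sigma_data)/(sigma**2 + sigma_data**2).sqrt()
    c_skip = sigma_data**2/(sigma**2 + sigma_data**2)
    return c_skip*x + c_out*model(c_in*x, sigma)
```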
hcpdiff/train_deepspeed.py (deleted)
@@ -1,69 +0,0 @@
- import argparse
- import os
- import sys
- import warnings
- from functools import partial
-
- import torch
-
- from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
- from hcpdiff.train_ac_old import Trainer, load_config_with_cli
- from hcpdiff.utils.net_utils import get_scheduler
-
- class TrainerDeepSpeed(Trainer):
-
-     def build_ckpt_manager(self):
-         self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type](plugin_from_raw=True)
-         if self.is_local_main_process:
-             self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)
-
-     @property
-     def unet_raw(self):
-         return self.accelerator.unwrap_model(self.TE_unet).unet if self.train_TE else self.accelerator.unwrap_model(self.TE_unet.unet)
-
-     @property
-     def TE_raw(self):
-         return self.accelerator.unwrap_model(self.TE_unet).TE if self.train_TE else self.TE_unet.TE
-
-     def get_loss(self, model_pred, target, timesteps, att_mask):
-         if att_mask is None:
-             att_mask = 1.0
-         if getattr(self.criterion, 'need_timesteps', False):
-             loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
-         else:
-             loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
-         return loss
-
-     def build_optimizer_scheduler(self):
-         # set optimizer
-         parameters, parameters_pt = self.get_param_group_train()
-
-         if len(parameters_pt)>0: # do prompt-tuning
-             cfg_opt_pt = self.cfgs.train.optimizer_pt
-             # if self.cfgs.train.scale_lr_pt:
-             #     self.scale_lr(parameters_pt)
-             assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-             weight_decay = cfg_opt_pt.keywords.get('weight_decay', None)
-             if weight_decay is not None:
-                 for param in parameters_pt:
-                     param['weight_decay'] = weight_decay
-
-             parameters += parameters_pt
-             warnings.warn('deepspeed dose not support multi optimizer and lr_scheduler. optimizer_pt and scheduler_pt will not work.')
-
-         if len(parameters)>0:
-             cfg_opt = self.cfgs.train.optimizer
-             if self.cfgs.train.scale_lr:
-                 self.scale_lr(parameters)
-             assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-             self.optimizer = cfg_opt(params=parameters)
-             self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(description='Stable Diffusion Training')
-     parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
-     args, cfg_args = parser.parse_known_args()
-
-     conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
-     trainer = TrainerDeepSpeed(conf)
-     trainer.train()
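The assert messages in the removed TrainerDeepSpeed document the optimizer config convention: cfgs.train.optimizer must resolve to a functools.partial over an optimizer class such as torch.optim.AdamW, which the trainer then instantiates with params=... . A minimal sketch of that pattern, assuming placeholder hyperparameters and a throwaway model (none of these values come from hcpdiff):

```python
# Illustration of the partial-based optimizer convention checked in the deleted
# build_optimizer_scheduler above. Only the partial/instantiation pattern mirrors
# the diff; the learning rate, weight decay, and model are placeholders.
from functools import partial
import torch

# A config entry resolves to a partial over the optimizer class with everything
# except `params` already bound.
cfg_opt = partial(torch.optim.AdamW, lr=1e-5, weight_decay=1e-2)

parameters = [{'params': torch.nn.Linear(4, 4).parameters()}]
assert isinstance(cfg_opt, partial)          # same check as in the removed trainer
optimizer = cfg_opt(params=parameters)       # -> AdamW(params=parameters, lr=1e-5, ...)
print(cfg_opt.keywords.get('weight_decay'))  # how the trainer read back weight_decay
```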