hcpdiff 2.2.1__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. hcpdiff/ckpt_manager/__init__.py +1 -1
  2. hcpdiff/ckpt_manager/ckpt.py +21 -17
  3. hcpdiff/ckpt_manager/format/diffusers.py +4 -4
  4. hcpdiff/ckpt_manager/format/sd_single.py +3 -3
  5. hcpdiff/ckpt_manager/loader.py +11 -4
  6. hcpdiff/diffusion/noise/__init__.py +0 -1
  7. hcpdiff/diffusion/sampler/VP.py +27 -0
  8. hcpdiff/diffusion/sampler/__init__.py +2 -3
  9. hcpdiff/diffusion/sampler/base.py +106 -44
  10. hcpdiff/diffusion/sampler/diffusers.py +11 -17
  11. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
  12. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
  13. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
  14. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
  15. hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
  16. hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
  17. hcpdiff/easy/cfg/sd15_train.py +33 -22
  18. hcpdiff/easy/cfg/sdxl_train.py +32 -23
  19. hcpdiff/evaluate/__init__.py +3 -1
  20. hcpdiff/evaluate/evaluator.py +76 -0
  21. hcpdiff/evaluate/metrics/__init__.py +1 -0
  22. hcpdiff/evaluate/metrics/clip_score.py +23 -0
  23. hcpdiff/evaluate/previewer.py +29 -12
  24. hcpdiff/loss/base.py +9 -26
  25. hcpdiff/loss/weighting.py +36 -18
  26. hcpdiff/models/lora_base_patch.py +26 -0
  27. hcpdiff/models/wrapper/sd.py +17 -19
  28. hcpdiff/trainer_ac.py +7 -5
  29. hcpdiff/trainer_ac_single.py +1 -6
  30. hcpdiff/utils/__init__.py +2 -1
  31. hcpdiff/utils/torch_utils.py +25 -0
  32. hcpdiff/workflow/__init__.py +1 -1
  33. hcpdiff/workflow/diffusion.py +27 -7
  34. hcpdiff/workflow/io.py +20 -3
  35. hcpdiff/workflow/text.py +6 -1
  36. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/METADATA +2 -2
  37. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/RECORD +41 -37
  38. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/WHEEL +1 -1
  39. hcpdiff/diffusion/noise/zero_terminal.py +0 -39
  40. hcpdiff/diffusion/sampler/ddpm.py +0 -20
  41. hcpdiff/diffusion/sampler/edm.py +0 -22
  42. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/entry_points.txt +0 -0
  43. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/licenses/LICENSE +0 -0
  44. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/top_level.txt +0 -0
hcpdiff/easy/cfg/sdxl_train.py CHANGED
@@ -1,15 +1,14 @@
  import torch
- from rainbowneko.ckpt_manager import ckpt_saver, NekoPluginSaver, LAYERS_TRAINABLE, SafeTensorFormat
- from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
- from rainbowneko.utils import ConstantLR
-
+ from hcpdiff.ckpt_manager import LoraWebuiFormat
  from hcpdiff.easy import SDXL_auto_loader
  from hcpdiff.models import SDXLWrapper
  from hcpdiff.models.lora_layers_patch import LoraLayer
- from hcpdiff.ckpt_manager import LoraWebuiFormat
+ from rainbowneko.ckpt_manager import ckpt_saver, NekoPluginSaver, LAYERS_TRAINABLE, SafeTensorFormat, NekoOptimizerSaver
+ from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
+ from rainbowneko.utils import ConstantLR

  @neko_cfg
- def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5,
+ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, save_optimizer=False, lr: float = 1e-5,
  dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL'):
  if low_vram:
  from bitsandbytes.optim import AdamW8bit
@@ -17,6 +16,17 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
  else:
  optimizer = torch.optim.AdamW(_partial_=True)

+ ckpt_saver_dict = dict(
+ SDXL=ckpt_saver(
+ ckpt_type='safetensors',
+ target_module='denoiser',
+ layers=LAYERS_TRAINABLE,
+ )
+ )
+
+ if save_optimizer:
+ ckpt_saver_dict['optimizer'] = NekoOptimizerSaver()
+
  from cfgs.train.py import train_base, tuning_base

  return dict(
@@ -30,13 +40,7 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
  )
  ], weight_decay=1e-2),

- ckpt_saver=dict(
- SDXL=ckpt_saver(
- ckpt_type='safetensors',
- target_module='denoiser',
- layers=LAYERS_TRAINABLE,
- )
- ),
+ ckpt_saver=ckpt_saver_dict,

  train=dict(
  train_steps=train_steps,
@@ -64,9 +68,9 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
  )

  @neko_cfg
- def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
- with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL',
- save_webui_format=False):
+ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, save_optimizer=False, lr: float = 1e-4, rank: int = 4,
+ alpha: float = None, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
+ name: str = 'SDXL', save_webui_format=False):
  with disable_neko_cfg:
  if alpha is None:
  alpha = rank
@@ -97,6 +101,17 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
  else:
  lora_format = SafeTensorFormat()

+ ckpt_saver_dict = dict(
+ _replace_=True,
+ lora_unet=NekoPluginSaver(
+ format=lora_format,
+ target_plugin='lora1',
+ )
+ )
+
+ if save_optimizer:
+ ckpt_saver_dict['optimizer'] = NekoOptimizerSaver()
+
  from cfgs.train.py.examples import SD_FT

  return dict(
@@ -114,13 +129,7 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
  )
  ), weight_decay=0.1),

- ckpt_saver=dict(
- _replace_ = True,
- lora_unet=NekoPluginSaver(
- format=lora_format,
- target_plugin='lora1',
- )
- ),
+ ckpt_saver=ckpt_saver_dict,

  train=dict(
  train_steps=train_steps,
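Note: both SDXL_finetuning and SDXL_lora_train gain a save_optimizer flag in 2.3.1; when it is set, the checkpoint-saver dict additionally gets a NekoOptimizerSaver entry. A minimal sketch of passing the flag (the model id and the None dataset are illustrative placeholders, not part of the diff):

    from hcpdiff.easy.cfg.sdxl_train import SDXL_lora_train

    cfg = SDXL_lora_train(
        base_model='stabilityai/stable-diffusion-xl-base-1.0',  # placeholder model id
        train_steps=1000,
        dataset=None,           # placeholder; a real config passes a dataset definition here
        save_step=200,
        save_optimizer=True,    # new flag: also writes optimizer state via NekoOptimizerSaver
        rank=8,
    )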
hcpdiff/evaluate/__init__.py CHANGED
@@ -1 +1,3 @@
- from .previewer import HCPPreviewer
+ from .previewer import HCPPreviewer
+ from .evaluator import HCPEvaluator
+ from .metrics import CLIPScoreMetric
hcpdiff/evaluate/evaluator.py ADDED
@@ -0,0 +1,76 @@
+ from pathlib import Path
+
+ import torch
+ from accelerate.hooks import remove_hook_from_module
+ from rainbowneko.evaluate import WorkflowEvaluator, MetricGroup
+ from rainbowneko.utils import to_cuda
+
+ from hcpdiff.models.wrapper import SD15Wrapper
+
+ class HCPEvaluator(WorkflowEvaluator):
+
+ @torch.no_grad()
+ def evaluate(self, step: int, model: SD15Wrapper, prefix='eval/'):
+ if step%self.interval != 0 or not self.trainer.is_local_main_process:
+ return
+
+ # record training layers
+ training_layers = [layer for layer in model.modules() if layer.training]
+
+ model.eval()
+ self.trainer.loggers.info(f'Preview')
+
+ N_repeats = model.text_enc_hook.N_repeats
+ clip_skip = model.text_enc_hook.clip_skip
+ clip_final_norm = model.text_enc_hook.clip_final_norm
+ use_attention_mask = model.text_enc_hook.use_attention_mask
+
+ preview_root = Path(self.trainer.exp_dir)/'imgs'
+ preview_root.mkdir(parents=True, exist_ok=True)
+
+ states = self.workflow_runner.run(model=model, in_preview=True, te_hook=model.text_enc_hook,
+ device=self.device, dtype=self.dtype, preview_root=preview_root, preview_step=step,
+ world_size=self.trainer.world_size, local_rank=self.trainer.local_rank,
+ emb_hook=self.trainer.cfgs.emb_pt.embedding_hook if self.trainer.pt_trainable else None)
+
+ # get metrics
+ metric = states['_metric']
+
+ v_metric = metric.finish(self.trainer.accelerator.gather, self.trainer.is_local_main_process)
+ if not isinstance(v_metric, dict):
+ v_metric = {'metric':v_metric}
+
+ log_data = {
+ "eval/Step":{
+ "format":"{}",
+ "data":[step],
+ }
+ }
+ log_data.update(MetricGroup.format(v_metric, prefix=prefix))
+ self.trainer.loggers.log(log_data, step, force=True)
+
+ # restore model states
+ if model.vae is not None:
+ model.vae.disable_tiling()
+ model.vae.disable_slicing()
+ remove_hook_from_module(model.vae, recurse=True)
+ if 'vae_encode_raw' in states:
+ model.vae.encode = states['vae_encode_raw']
+ model.vae.decode = states['vae_decode_raw']
+
+ if 'emb_hook' in states and not self.trainer.pt_trainable:
+ states['emb_hook'].remove()
+
+ if self.trainer.pt_trainable:
+ self.trainer.cfgs.emb_pt.embedding_hook.N_repeats = N_repeats
+
+ model.tokenizer.N_repeats = N_repeats
+ model.text_enc_hook.N_repeats = N_repeats
+ model.text_enc_hook.clip_skip = clip_skip
+ model.text_enc_hook.clip_final_norm = clip_final_norm
+ model.text_enc_hook.use_attention_mask = use_attention_mask
+
+ to_cuda(model)
+
+ for layer in training_layers:
+ layer.train()
hcpdiff/evaluate/metrics/__init__.py ADDED
@@ -0,0 +1 @@
+ from .clip_score import CLIPScoreMetric
hcpdiff/evaluate/metrics/clip_score.py ADDED
@@ -0,0 +1,23 @@
+ from torchmetrics.multimodal.clip_score import CLIPScore, _clip_score_update
+ from torch import Tensor
+ from typing import List
+
+ class CLIPScoreMetric(CLIPScore):
+ def update(self, images: Tensor | List[Tensor], text: str | list[str]) -> None:
+ """Update CLIP score on a batch of images and text.
+
+ Args:
+ images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors, in the [-1, 1] range
+ text: Either a single caption or a list of captions
+
+ Raises:
+ ValueError:
+ If not all images have format [C, H, W]
+ ValueError:
+ If the number of images and captions do not match
+
+ """
+ images = (images+1)/2 # [-1,1] -> [0,1]
+ score, n_samples = _clip_score_update(images, text, self.model, self.processor)
+ self.score += score.sum(0)
+ self.n_samples += n_samples
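For reference, a minimal sketch of driving the new metric on its own. It relies only on the torchmetrics CLIPScore interface the class inherits; the checkpoint name is an example, not something fixed by this diff:

    import torch
    from hcpdiff.evaluate.metrics import CLIPScoreMetric

    metric = CLIPScoreMetric(model_name_or_path='openai/clip-vit-base-patch16')
    images = torch.rand(2, 3, 224, 224)*2 - 1   # decoded images in [-1, 1], as the docstring expects
    metric.update(images, ['a photo of a cat', 'a photo of a dog'])
    print(metric.compute())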
hcpdiff/evaluate/previewer.py CHANGED
@@ -6,32 +6,49 @@ from rainbowneko.utils import to_cuda

  from hcpdiff.models.wrapper import SD15Wrapper
  from accelerate.hooks import remove_hook_from_module
+ from typing import Dict
+ from types import ModuleType

  class HCPPreviewer(WorkflowPreviewer):
+ def __init__(self, parser, cfgs_raw, workflow: str | ModuleType | Dict, ds_name=None, interval=100, trainer=None,
+ mixed_precision=None, seed=42, **cfgs):
+ super().__init__(parser, cfgs_raw, workflow, ds_name=ds_name, interval=interval, trainer=trainer,
+ mixed_precision=mixed_precision, seed=seed, **cfgs)
+ if trainer is None:
+ self.pt_trainable = False
+ else:
+ self.emb_pt = trainer.cfgs.emb_pt
+ self.pt_trainable = trainer.pt_trainable

  @torch.no_grad()
- def evaluate(self, step: int, model: SD15Wrapper, prefix='eval/'):
- if step%self.interval != 0 or not self.trainer.is_local_main_process:
+ def evaluate(self, step: int, prefix='eval/'):
+ if step%self.interval != 0 or not self.is_local_main_process:
  return

  # record training layers
- training_layers = [layer for layer in model.modules() if layer.training]
+ if self.model_wrapper is not None:
+ training_layers = [layer for layer in self.model_raw.modules() if layer.training]
+ self.model_wrapper.eval()
+ model = self.model_raw
+ else:
+ training_layers = []
+ model = None

- model.eval()
- self.trainer.loggers.info(f'Preview')
+ if self.loggers is not None:
+ self.loggers.info(f'Preview')

  N_repeats = model.text_enc_hook.N_repeats
  clip_skip = model.text_enc_hook.clip_skip
  clip_final_norm = model.text_enc_hook.clip_final_norm
  use_attention_mask = model.text_enc_hook.use_attention_mask

- preview_root = Path(self.trainer.exp_dir)/'imgs'
+ preview_root = Path(self.exp_dir)/'imgs'
  preview_root.mkdir(parents=True, exist_ok=True)

  states = self.workflow_runner.run(model=model, in_preview=True, te_hook=model.text_enc_hook,
- device=self.device, dtype=self.dtype, preview_root=preview_root, preview_step=step,
- world_size=self.trainer.world_size, local_rank=self.trainer.local_rank,
- emb_hook=self.trainer.cfgs.emb_pt.embedding_hook if self.trainer.pt_trainable else None)
+ device=self.device, dtype=self.weight_dtype, preview_root=preview_root, preview_step=step,
+ world_size=self.world_size, local_rank=self.local_rank,
+ emb_hook=self.emb_pt.embedding_hook if self.pt_trainable else None)

  # restore model states
  if model.vae is not None:
@@ -42,11 +59,11 @@ class HCPPreviewer(WorkflowPreviewer):
  model.vae.encode = states['vae_encode_raw']
  model.vae.decode = states['vae_decode_raw']

- if 'emb_hook' in states and not self.trainer.pt_trainable:
+ if 'emb_hook' in states and not self.pt_trainable:
  states['emb_hook'].remove()

- if self.trainer.pt_trainable:
- self.trainer.cfgs.emb_pt.embedding_hook.N_repeats = N_repeats
+ if self.pt_trainable:
+ self.emb_pt.embedding_hook.N_repeats = N_repeats

  model.tokenizer.N_repeats = N_repeats
  model.text_enc_hook.N_repeats = N_repeats
hcpdiff/loss/base.py CHANGED
@@ -6,36 +6,19 @@ class DiffusionLossContainer(LossContainer):
  def __init__(self, loss, weight=1.0, key_map=None):
  key_map = key_map or getattr(loss, '_key_map', None) or ('pred.model_pred -> 0', 'pred.target -> 1')
  super().__init__(loss, weight, key_map)
- self.target_type = getattr(loss, 'target_type', 'eps')
+ self.target_type = getattr(loss, 'target_type', None)

- def get_target(self, pred_type, model_pred, x_0, noise, x_t, sigma, noise_sampler, **kwargs):
+ def get_target(self, model_pred, x_0, noise, x_t, timesteps, noise_sampler, **kwargs):
  # Get target
- if self.target_type == "eps":
- target = noise
- elif self.target_type == "x0":
- target = x_0
- elif self.target_type == "velocity":
- target = noise_sampler.eps_to_velocity(noise, x_t, sigma)
- else:
- raise ValueError(f"Unsupport target_type {self.target_type}")
+ target = noise_sampler.get_target(x_0, x_t, timesteps, eps=noise, target_type=self.target_type)

- # TODO: put in wrapper
- # # remove pred vars
- # if model_pred.shape[1] == target.shape[1]*2:
- # model_pred, _ = model_pred.chunk(2, dim=1)
-
- # Convert pred_type to target_type
- if pred_type != self.target_type:
- cvt_func = getattr(noise_sampler, f'{pred_type}_to_{self.target_type}', None)
- if cvt_func is None:
- raise ValueError(f"Unsupport pred_type {pred_type} with target_type {self.target_type}")
- else:
- model_pred = cvt_func(model_pred, x_t, sigma)
- return model_pred, target
+ # Convert pred_type for target_type
+ pred = noise_sampler.pred_for_target(model_pred, x_t, timesteps, eps=noise, target_type=self.target_type)
+ return pred, target

  def forward(self, pred:Dict[str,Any], inputs:Dict[str,Any]) -> Tensor:
- model_pred, target = self.get_target(**pred)
- pred['model_pred'] = model_pred
+ pred_cvt, target = self.get_target(**pred)
+ pred['model_pred'] = pred_cvt
  pred['target'] = target
- loss = super().forward(pred, inputs) * self.weight # [B,*,*,*]
+ loss = super().forward(pred, inputs) # [B,*,*,*]
  return loss.mean()
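The container now delegates target construction and prediction conversion to the noise sampler (get_target / pred_for_target), and a target_type of None defers to the sampler's own target type. A small sketch of wrapping a plain per-pixel loss, using only the constructor signature shown above (the MSE choice is illustrative):

    import torch.nn as nn
    from hcpdiff.loss.base import DiffusionLossContainer

    # No target_type on the wrapped loss, so it stays None and the noise sampler decides
    # (see the target_type fallback in weighting.py below).
    loss = DiffusionLossContainer(nn.MSELoss(reduction='none'))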
hcpdiff/loss/weighting.py CHANGED
@@ -7,6 +7,11 @@ class LossWeight(nn.Module):
  super().__init__()
  self.loss = loss

+ def get_c_out(self, pred):
+ t = pred['timesteps']
+ noise_sampler = pred['noise_sampler']
+ return noise_sampler.sigma_scheduler.c_out(t)
+
  def get_weight(self, pred, inputs):
  '''

@@ -25,13 +30,19 @@

  class SNRWeight(LossWeight):
  def get_weight(self, pred, inputs):
- if self.loss.target_type == 'eps':
- return 1
- elif self.loss.target_type == "x0":
- sigma = pred['sigma']
- return (1./sigma**2).view(-1, 1, 1, 1)
+ noise_sampler = pred['noise_sampler']
+ c_out = self.get_c_out(pred)
+ target_type = getattr(self.loss, 'target_type', None) or noise_sampler.target_type
+ if target_type == 'eps':
+ w_snr = 1
+ elif target_type == "x0":
+ w_snr = (1./c_out**2).float()
+ elif target_type == "velocity":
+ w_snr = (1./(1-c_out)**2).float()
  else:
- raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+ raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
+
+ return w_snr.view(-1, 1, 1, 1)

  class MinSNRWeight(LossWeight):
  def __init__(self, loss: DiffusionLossContainer, gamma: float = 1.):
@@ -39,13 +50,18 @@ class MinSNRWeight(LossWeight):
  self.gamma = gamma

  def get_weight(self, pred, inputs):
- sigma = pred['sigma']
- if self.loss.target_type == 'eps':
- w_snr = (self.gamma*sigma**2).clip(max=1).float()
- elif self.loss.target_type == "x0":
- w_snr = (1/(sigma**2)).clip(max=self.gamma).float()
+ noise_sampler = pred['noise_sampler']
+ c_out = self.get_c_out(pred)
+ target_type = getattr(self.loss, 'target_type', None) or noise_sampler.target_type
+ if target_type == 'eps':
+ w_snr = (self.gamma*c_out**2).clip(max=1).float()
+ elif target_type == "x0":
+ w_snr = (1./c_out**2).clip(max=self.gamma).float()
+ elif target_type == "velocity":
+ w_v = 1/(1-c_out)**2
+ w_snr = (self.gamma*c_out**2/w_v).clip(max=w_v).float()
  else:
- raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+ raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")

  return w_snr.view(-1, 1, 1, 1)

@@ -55,12 +71,14 @@ class EDMWeight(LossWeight):
  self.gamma = gamma

  def get_weight(self, pred, inputs):
- sigma = pred['sigma']
- if self.loss.target_type == 'eps':
- w_snr = ((sigma**2+self.gamma**2)/(self.gamma**2)).float()
- elif self.loss.target_type == "x0":
- w_snr = ((sigma**2+self.gamma**2)/((sigma*self.gamma)**2)).float()
+ c_out = self.get_c_out(pred)
+ noise_sampler = pred['noise_sampler']
+ target_type = getattr(self.loss, 'target_type', None) or noise_sampler.target_type
+ if target_type == 'edm':
+ w_snr = 1
+ elif target_type == "x0":
+ w_snr = (1./c_out**2).float()
  else:
- raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+ raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")

  return w_snr.view(-1, 1, 1, 1)
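In 2.3.1 the weights are expressed through the sigma scheduler's c_out(t) rather than a sigma value carried in pred. A quick illustration of the MinSNRWeight branches above, with made-up c_out values:

    import torch

    gamma = 5.0
    c_out = torch.tensor([0.1, 0.5, 0.9])      # hypothetical c_out(t) values
    w_eps = (gamma*c_out**2).clip(max=1)       # target_type == 'eps'
    w_x0 = (1./c_out**2).clip(max=gamma)       # target_type == 'x0'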
hcpdiff/models/lora_base_patch.py CHANGED
@@ -34,6 +34,32 @@ class LoraPatchContainer(PatchPluginContainer):

  return self[name].post_forward(x, self._host.weight, weight_, self._host.bias, bias_)

+ @property
+ def weight(self):
+ weight_ = None
+ for name in self.plugin_names:
+ if weight_ is None:
+ weight_ = self[name].get_weight()
+ else:
+ weight_ = weight_+self[name].get_weight()
+ return self._host.weight + weight_
+
+ @property
+ def bias(self):
+ bias_ = None
+ for name in self.plugin_names:
+ if bias_ is None:
+ bias_ = self[name].get_bias()
+ else:
+ bias_ = bias_+self[name].get_bias()
+
+ if self._host.bias is not None:
+ if bias_ is None:
+ bias_ = self._host.bias
+ else:
+ bias_ = self._host.bias + bias_
+ return bias_
+
  class LoraBlock(PatchPluginBlock):
  container_cls = LoraPatchContainer
  wrapable_classes = (nn.Linear, nn.Conv2d)
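The new weight/bias properties expose the host parameters with every attached LoRA delta summed in. A standalone sketch of the accumulation they perform (the tensors stand in for the host weight and the plugins' get_weight() results; this is not the hcpdiff API itself):

    import torch

    def merged_weight(host_weight, plugin_deltas):
        # mirrors LoraPatchContainer.weight: sum each plugin's delta, then add the host weight
        delta = None
        for d in plugin_deltas:
            delta = d if delta is None else delta+d
        return host_weight if delta is None else host_weight+delta

    w = merged_weight(torch.randn(8, 8), [0.1*torch.randn(8, 8)])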
hcpdiff/models/wrapper/sd.py CHANGED
@@ -17,7 +17,7 @@ from ..cfg_context import CFGContext

  class SD15Wrapper(BaseWrapper):
  def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
- pred_type='eps', TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+ TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
  super().__init__()
  self.key_mapper_in = self.build_mapper(key_map_in, None, (
  'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
@@ -31,8 +31,6 @@ class SD15Wrapper(BaseWrapper):
  self.tokenizer = tokenizer
  self.min_attnmask = min_attnmask

- self.pred_type = pred_type
-
  self.TE_hook_cfg = TEHookCFG.create(TE_hook_cfg)
  self.cfg_context = cfg_context
  self.tokenizer.N_repeats = self.TE_hook_cfg.tokenizer_repeats
@@ -93,8 +91,9 @@ class SD15Wrapper(BaseWrapper):
  plugin_input={}, **kwargs):
  # input prepare
  x_0 = self.get_latents(image)
- x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
- x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+ x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+ x_t_in = x_t*self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype).view(-1,1,1,1)
+ t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)

  if neg_prompt_ids:
  prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
@@ -104,15 +103,14 @@ class SD15Wrapper(BaseWrapper):
  position_ids = torch.cat([neg_position_ids, position_ids], dim=0)

  # model forward
- x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
- encoder_hidden_states = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+ x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
+ encoder_hidden_states = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
  plugin_input=plugin_input, **kwargs)
- model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+ model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, attn_mask=attn_mask, position_ids=position_ids,
  plugin_input=plugin_input, **kwargs)
  model_pred = self.cfg_context.post(model_pred)

- return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
- noise_sampler=self.noise_sampler)
+ return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)

  def forward(self, ds_name=None, **kwargs):
  model_args, model_kwargs = self.get_map_data(self.key_mapper_in, kwargs, ds_name)
@@ -156,8 +154,8 @@ class SD15Wrapper(BaseWrapper):

  class SDXLWrapper(SD15Wrapper):
  def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
- pred_type='eps', TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
- super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, pred_type, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
+ TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+ super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
  self.key_mapper_in = self.build_mapper(key_map_in, None, (
  'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
  'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input', 'coord -> crop_info'))
@@ -195,8 +193,9 @@ class SDXLWrapper(SD15Wrapper):
  crop_info=None, plugin_input={}):
  # input prepare
  x_0 = self.get_latents(image)
- x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
- x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+ x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+ x_t_in = x_t*self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype).view(-1,1,1,1)
+ t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)

  if neg_prompt_ids:
  prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
@@ -206,13 +205,12 @@ class SDXLWrapper(SD15Wrapper):
  position_ids = torch.cat([neg_position_ids, position_ids], dim=0)

  # model forward
- x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
- encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+ x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
+ encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
  plugin_input=plugin_input)
  added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
- model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs=added_cond_kwargs,
+ model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, added_cond_kwargs=added_cond_kwargs,
  attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
  model_pred = self.cfg_context.post(model_pred)

- return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
- noise_sampler=self.noise_sampler)
+ return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
hcpdiff/trainer_ac.py CHANGED
@@ -4,8 +4,8 @@ import warnings
  import torch
  from rainbowneko.parser import load_config_with_cli
  from rainbowneko.ckpt_manager import NekoSaver
- from rainbowneko.train import Trainer
- from rainbowneko.utils import xformers_available, is_dict
+ from rainbowneko.train.trainer import Trainer
+ from rainbowneko.utils import xformers_available, is_dict, weight_dtype_map
  from hcpdiff.ckpt_manager import EmbFormat

  class HCPTrainer(Trainer):
@@ -17,7 +17,7 @@ class HCPTrainer(Trainer):
  warnings.warn("xformers is not available. Make sure it is installed correctly")

  if self.model_wrapper.vae is not None:
- self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
+ self.vae_dtype = weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
  self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)

  if self.cfgs.model.gradient_checkpointing:
@@ -44,10 +44,12 @@ class HCPTrainer(Trainer):

  def save_model(self, from_raw=False):
  NekoSaver.save_all(
- self.model_raw,
- plugin_groups={**self.all_plugin, 'embs': self.train_pts},
  cfg=self.ckpt_saver,
+ model=self.model_raw,
+ plugin_groups=self.all_plugin,
+ embs=self.train_pts,
  model_ema=getattr(self, "ema_model", None),
+ optimizer=self.optimizer,
  name_template=f'{{}}-{self.real_step}',
  )

hcpdiff/trainer_ac_single.py CHANGED
@@ -1,12 +1,7 @@
  import argparse
- import sys
- from functools import partial
-
- import torch
- from accelerate import Accelerator
- from loguru import logger

  from rainbowneko.train.trainer import TrainerSingleCard
+
  from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli

  class HCPTrainerSingleCard(TrainerSingleCard, HCPTrainer):
hcpdiff/utils/__init__.py CHANGED
@@ -1,2 +1,3 @@
  from .utils import *
- from .net_utils import *
+ from .net_utils import *
+ from .torch_utils import invert_func
hcpdiff/utils/torch_utils.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+
+ def invert_func(func, y, x_min=0.0, x_max=1.0, tol=1e-5, max_iter=100):
+ """
+ y: [B]
+ :return: x [B]
+ """
+ y = y.to(dtype=torch.float32)
+ left = torch.full_like(y, x_min)
+ right = torch.full_like(y, x_max)
+
+ for _ in range(max_iter):
+ mid = (left+right)/2
+ val = func(mid)
+
+ too_large = val>y
+ too_small = ~too_large
+
+ left = torch.where(too_small, mid, left)
+ right = torch.where(too_large, mid, right)
+
+ if torch.all(torch.abs(val-y)<tol):
+ break
+
+ return (left+right)/2
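invert_func performs an element-wise bisection search, so it assumes func is monotonically increasing on [x_min, x_max]. A quick illustrative check:

    import torch
    from hcpdiff.utils import invert_func

    y = torch.tensor([0.25, 0.64])
    x = invert_func(torch.square, y)   # solves x**2 = y on the default range [0, 1]
    print(x)                           # approximately [0.5, 0.8]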
hcpdiff/workflow/__init__.py CHANGED
@@ -1,5 +1,5 @@
  from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
- X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter
+ X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter, DiffusionActions
  from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
  from .vae import EncodeAction, DecodeAction
  from .io import BuildModelsAction, SaveImageAction, LoadImageAction