hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
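Much of the churn above is a re-layout: the 0.9.x loggers/, noise/, and vis/ packages disappear, and their roles reappear under hcpdiff/evaluate/, hcpdiff/diffusion/noise/, and the new trainer/workflow modules. A minimal sketch (not part of either release) of how downstream code could pick the right import path, keyed only off module paths that appear in the file list above:

```python
# Hedged compatibility sketch: the module paths are taken from the file list
# above; everything else is illustrative.
from importlib.util import find_spec

if find_spec('hcpdiff.diffusion') is not None:
    # 2.x layout: noise schedulers live under hcpdiff/diffusion/noise/
    from hcpdiff.diffusion.noise import pyramid_noise as pyramid_noise_mod
else:
    # 0.9.x layout: hcpdiff/noise/ (removed in 2.1)
    from hcpdiff.noise import pyramid_noise as pyramid_noise_mod

print(pyramid_noise_mod.__file__)
```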
hcpdiff/loggers/wandb_logger.py DELETED
@@ -1,31 +0,0 @@
- from typing import Dict, Any
-
- import os
- import wandb
- from PIL import Image
-
- from .base_logger import BaseLogger
-
-
- class WanDBLogger(BaseLogger):
-     def __init__(self, exp_dir, out_path=None, enable_log_image=False, project='hcp-diffusion', log_step=10, image_log_step=200):
-         super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-         if exp_dir is not None: # exp_dir is only available in local main process
-             wandb.init(project=project, name=os.path.basename(exp_dir))
-             wandb.save(os.path.join(exp_dir, 'cfg.yaml'), base_path=exp_dir)
-         else:
-             self.writer = None
-             self.disable()
-
-     def _info(self, info):
-         pass
-
-     def _log(self, datas: Dict[str, Any], step: int = 0):
-         log_dict = {'step': step}
-         for k, v in datas.items():
-             if len(v['data']) == 1:
-                 log_dict[k] = v['data'][0]
-         wandb.log(log_dict)
-
-     def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-         wandb.log({next(iter(imgs.keys())): list(imgs.values())}, step=step)
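For context, a hedged usage sketch of the removed 0.9.x logger API shown in the hunk above. The surrounding BaseLogger plumbing is not part of this diff, so the sketch calls the internal _log/_log_image hooks directly just to show the payload shapes they expect; the run directory is made up.

```python
from PIL import Image
from hcpdiff.loggers.wandb_logger import WanDBLogger  # 0.9.x module, removed in 2.1

logger = WanDBLogger(exp_dir='exps/my_run', project='hcp-diffusion', log_step=10)
# each entry carries a 'data' list; only single-element entries reach wandb.log
logger._log({'loss': {'format': '{:.4f}', 'data': [0.1234]}}, step=100)
logger._log_image({'preview': Image.new('RGB', (64, 64))}, step=200)
```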
hcpdiff/loggers/webui_logger.py DELETED
@@ -1,9 +0,0 @@
- from typing import Dict, Any
-
- from loguru import logger
-
- from .cli_logger import CLILogger
-
- class WebUILogger(CLILogger):
-     def _log(self, datas: Dict[str, Any], step: int = 0):
-         logger.info('this progress steps:'+', '.join([f"{k} = {v['format'].format(*v['data'])}" for k, v in datas.items()]))
hcpdiff/loss/min_snr_loss.py DELETED
@@ -1,52 +0,0 @@
- import torch
- from diffusers import SchedulerMixin
- from torch import nn
-
- class MinSNRLoss(nn.MSELoss):
-     need_timesteps = True
-
-     def __init__(self, size_average=None, reduce=None, reduction: str = 'none', gamma=1.,
-                  noise_scheduler: SchedulerMixin = None, device='cuda:0', **kwargs):
-         super().__init__(size_average, reduce, reduction)
-         self.gamma = gamma
-
-         # calculate SNR
-         alphas_cumprod = noise_scheduler.alphas_cumprod
-         sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
-         sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0-alphas_cumprod)
-         self.alpha = sqrt_alphas_cumprod.to(device)
-         self.sigma = sqrt_one_minus_alphas_cumprod.to(device)
-         self.all_snr = ((self.alpha/self.sigma)**2).to(device)
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         loss = super(MinSNRLoss, self).forward(input, target)
-         snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-         snr_weight = (self.gamma/snr).clip(max=1.).float()
-         return loss*snr_weight.view(-1, 1, 1, 1)
-
-
- class SoftMinSNRLoss(MinSNRLoss):
-     # gamma=2
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         loss = super(MinSNRLoss, self).forward(input, target)
-         snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-         snr_weight = (self.gamma**3/(snr**2 + self.gamma**3)).float()
-         return loss*snr_weight.view(-1, 1, 1, 1)
-
- class KDiffMinSNRLoss(MinSNRLoss):
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         loss = super(MinSNRLoss, self).forward(input, target)
-         snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-         snr_weight = 4*(((self.gamma*snr)**2/(snr**2 + self.gamma**2)**2)).float()
-         return loss*snr_weight.view(-1, 1, 1, 1)
-
- class EDMLoss(MinSNRLoss):
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         loss = super(MinSNRLoss, self).forward(input, target)
-         sigma = self.sigma[timesteps[:loss.shape[0], ...].squeeze()]
-         snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-         snr_weight = ((sigma**2+self.gamma**2)/(snr*(sigma*self.gamma)**2)).float()
-         return loss*snr_weight.view(-1, 1, 1, 1)
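All four removed losses follow the same pattern: an unreduced MSE scaled by a per-timestep weight derived from SNR(t) = alpha_bar_t / (1 - alpha_bar_t). A standalone sketch of the plain Min-SNR variant implemented by MinSNRLoss above, i.e. weight = min(gamma / SNR(t), 1), with made-up shapes for the smoke test:

```python
import torch
import torch.nn.functional as F

def min_snr_weighted_mse(pred, target, timesteps, alphas_cumprod, gamma=1.0):
    snr = alphas_cumprod/(1.0-alphas_cumprod)              # SNR(t) per timestep
    snr_weight = (gamma/snr[timesteps]).clamp(max=1.0)     # min(gamma/SNR, 1)
    loss = F.mse_loss(pred, target, reduction='none')
    return (loss*snr_weight.view(-1, 1, 1, 1)).mean()

# smoke test with toy shapes
alphas_cumprod = torch.linspace(0.999, 0.01, 1000)
x = torch.randn(4, 4, 8, 8)
t = torch.tensor([1, 250, 500, 999])
print(min_snr_weighted_mse(x, torch.zeros_like(x), t, alphas_cumprod))
```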
hcpdiff/models/layers.py DELETED
@@ -1,81 +0,0 @@
- """
- layers.py
- ====================
- :Name: GroupLinear and other layers
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 09/04/2023
- :Licence: Apache-2.0
- """
-
- import torch
- from torch import nn
- import math
- from einops import rearrange
-
- class GroupLinear(nn.Module):
-     def __init__(self, in_features: int, out_features: int, groups: int, bias: bool = True,
-                  device=None, dtype=None):
-         super().__init__()
-         assert in_features%groups == 0
-         assert out_features%groups == 0
-
-         factory_kwargs = {'device': device, 'dtype': dtype}
-
-         self.groups = groups
-         self.in_features = in_features
-         self.out_features = out_features
-
-         self.weight = nn.Parameter(torch.empty((groups, in_features//groups, out_features//groups), **factory_kwargs))
-         if bias:
-             self.bias = nn.Parameter(torch.empty(groups, 1, out_features//groups, **factory_kwargs))
-         else:
-             self.register_parameter('bias', None)
-         self.reset_parameters()
-
-     def reset_parameters(self) -> None:
-         # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
-         # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
-         # https://github.com/pytorch/pytorch/issues/57109
-         self.kaiming_uniform_group(self.weight, a=math.sqrt(5))
-         if self.bias is not None:
-             fan_in, _ = self._calculate_fan_in_and_fan_out(self.weight)
-             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-             nn.init.uniform_(self.bias, -bound, bound)
-
-     @staticmethod
-     def _calculate_fan_in_and_fan_out(tensor):
-         receptive_field_size = 1
-         num_input_fmaps = tensor.size(-2)
-         num_output_fmaps = tensor.size(-1)
-         fan_in = num_input_fmaps * receptive_field_size
-         fan_out = num_output_fmaps * receptive_field_size
-
-         return fan_in, fan_out
-
-     @staticmethod
-     def kaiming_uniform_group(tensor: torch.Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu') -> torch.Tensor:
-         def _calculate_correct_fan(tensor, mode):
-             mode = mode.lower()
-             valid_modes = ['fan_in', 'fan_out']
-             if mode not in valid_modes:
-                 raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
-
-             fan_in, fan_out = GroupLinear._calculate_fan_in_and_fan_out(tensor)
-             return fan_in if mode == 'fan_in' else fan_out
-
-         fan = _calculate_correct_fan(tensor, mode)
-         gain = nn.init.calculate_gain(nonlinearity, a)
-         std = gain / math.sqrt(fan)
-         bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
-         with torch.no_grad():
-             return tensor.uniform_(-bound, bound)
-
-     def forward(self, x: torch.Tensor): # x: [G,B,L,C]
-         x = rearrange(x, '(g b) l c -> g (b l) c', g=self.num_groups)
-         if self.bias is not None:
-             out = torch.bmm(x, self.weight) + self.bias
-         else:
-             out = torch.bmm(x, self.weight)
-         out = rearrange(out, 'g (b l) c -> (g b) l c', b=B)
-         return out
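Note that the removed forward references self.num_groups and B, neither of which is defined in the file; a hedged reading is that they stand for self.groups and the per-group batch size. A standalone sketch of the grouped matmul under that assumption (the function name and toy sizes are invented):

```python
import torch
from einops import rearrange

def group_linear_forward(x, weight, bias=None, groups=1):
    """x: [(G*B), L, C_in/G]; weight: [G, C_in/G, C_out/G]; bias: [G, 1, C_out/G]."""
    b = x.shape[0]//groups
    x = rearrange(x, '(g b) l c -> g (b l) c', g=groups)
    out = torch.bmm(x, weight)
    if bias is not None:
        out = out+bias
    return rearrange(out, 'g (b l) c -> (g b) l c', b=b)

# shape check: 2 groups, batch 3, seq 5, 8 -> 16 features overall
x = torch.randn(2*3, 5, 8//2)
w = torch.randn(2, 8//2, 16//2)
print(group_linear_forward(x, w, groups=2).shape)  # torch.Size([6, 5, 8])
```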
hcpdiff/models/plugin.py DELETED
@@ -1,348 +0,0 @@
- """
- plugin.py
- ====================
- :Name: model plugin
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
-
- import weakref
- import re
- from typing import Tuple, List, Dict, Any, Iterable
-
- import torch
- from torch import nn
-
- from hcpdiff.utils.net_utils import split_module_name
-
- class BasePluginBlock(nn.Module):
-     def __init__(self, name: str):
-         super().__init__()
-         self.name = name
-
-     def forward(self, host: nn.Module, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-         return fea_out
-
-     def remove(self):
-         pass
-
-     def feed_input_data(self, data):
-         self.input_data = data
-
-     def register_input_feeder_to(self, host_model):
-         if not hasattr(host_model, 'input_feeder'):
-             host_model.input_feeder = []
-         host_model.input_feeder.append(self.feed_input_data)
-
-     def set_hyper_params(self, **kwargs):
-         for k, v in kwargs.items():
-             setattr(self, k, v)
-
-     @staticmethod
-     def extract_state_without_plugin(model: nn.Module, trainable=False):
-         trainable_keys = {k for k, v in model.named_parameters() if v.requires_grad}
-         plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
-         model_sd = {}
-         for k, v in model.state_dict().items():
-             if (not trainable) or k in trainable_keys:
-                 for name in plugin_names:
-                     if k.startswith(name):
-                         break
-                 else:
-                     model_sd[k] = v
-         return model_sd
-
-     def get_trainable_parameters(self) -> Iterable[nn.Parameter]:
-         return self.parameters()
-
- class WrapablePlugin:
-     wrapable_classes = ()
-
-     @classmethod
-     def wrap_layer(cls, name: str, layer: nn.Module, **kwargs):
-         plugin = cls(name, layer, **kwargs)
-         return plugin
-
-     @classmethod
-     def named_modules_with_exclude(cls, self, memo = None, prefix: str = '', remove_duplicate: bool = True,
-                                    exclude_key=None, exclude_classes=tuple()):
-
-         if memo is None:
-             memo = set()
-         if self not in memo:
-             if remove_duplicate:
-                 memo.add(self)
-             if (exclude_key is None or not re.search(exclude_key, prefix)) and not isinstance(self, exclude_classes):
-                 yield prefix, self
-             for name, module in self._modules.items():
-                 if module is None:
-                     continue
-                 submodule_prefix = prefix + ('.' if prefix else '') + name
-                 for m in cls.named_modules_with_exclude(module, memo, submodule_prefix, remove_duplicate, exclude_key, exclude_classes):
-                     yield m
-
-     @classmethod
-     def wrap_model(cls, name: str, host: nn.Module, exclude_key=None, exclude_classes=tuple(), **kwargs): # -> Dict[str, SinglePluginBlock]:
-         '''
-         parent_block and other args required in __init__ will be put into kwargs, compatible with multiple models.
-         '''
-         plugin_block_dict = {}
-         if isinstance(host, cls.wrapable_classes):
-             plugin_block_dict[''] = cls.wrap_layer(name, host, **kwargs)
-         else:
-             named_modules = {layer_name:layer for layer_name, layer in cls.named_modules_with_exclude(
-                 host, exclude_key=exclude_key, exclude_classes=exclude_classes)}
-             for layer_name, layer in named_modules.items():
-                 if isinstance(layer, cls.wrapable_classes):
-                     # For plugins that need parent_block
-                     if 'parent_block' in kwargs:
-                         parent_name, host_name = split_module_name(layer_name)
-                         kwargs['parent_block'] = named_modules[parent_name]
-                         kwargs['host_name'] = host_name
-                     plugin_block_dict[layer_name] = cls.wrap_layer(name, layer, **kwargs)
-         return plugin_block_dict
-
- class SinglePluginBlock(BasePluginBlock, WrapablePlugin):
-
-     def __init__(self, name: str, host: nn.Module, hook_param=None, host_model=None):
-         super().__init__(name)
-         self.host = weakref.ref(host)
-         setattr(host, name, self)
-
-         if hook_param is None:
-             self.hook_handle = host.register_forward_hook(self.layer_hook)
-         else: # hook for model parameters
-             self.backup = getattr(host, hook_param)
-             self.target = hook_param
-             self.handle_pre = host.register_forward_pre_hook(self.pre_hook)
-             self.handle_post = host.register_forward_hook(self.post_hook)
-
-     def layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-         return self(fea_in, fea_out)
-
-     def pre_hook(self, host, fea_in: torch.Tensor):
-         host.weight_restored = False
-         host_param = getattr(host, self.target)
-         delattr(host, self.target)
-         setattr(host, self.target, self(host_param))
-         return fea_in
-
-     def post_hook(self, host, fea_int, fea_out):
-         if not getattr(host, 'weight_restored', False):
-             setattr(host, self.target, self.backup)
-             host.weight_restored = True
-
-     def remove(self):
-         host = self.host()
-         delattr(host, self.name)
-         if hasattr(self, 'hook_handle'):
-             self.hook_handle.remove()
-         else:
-             self.handle_pre.remove()
-             self.handle_post.remove()
-
- class PluginBlock(BasePluginBlock):
-     def __init__(self, name, from_layer: Dict[str, Any], to_layer: Dict[str, Any], host_model=None):
-         super().__init__(name)
-         self.host_from = weakref.ref(from_layer['layer'])
-         self.host_to = weakref.ref(to_layer['layer'])
-         setattr(from_layer['layer'], name, self)
-
-         if from_layer['pre_hook']:
-             self.hook_handle_from = from_layer['layer'].register_forward_pre_hook(lambda host, fea_in:self.from_layer_hook(host, fea_in, None))
-         else:
-             self.hook_handle_from = from_layer['layer'].register_forward_hook(
-                 lambda host, fea_in, fea_out:self.from_layer_hook(host, fea_in, fea_out))
-         if to_layer['pre_hook']:
-             self.hook_handle_to = to_layer['layer'].register_forward_pre_hook(lambda host, fea_in:self.to_layer_hook(host, fea_in, None))
-         else:
-             self.hook_handle_to = to_layer['layer'].register_forward_hook(lambda host, fea_in, fea_out:self.to_layer_hook(host, fea_in, fea_out))
-
-     def from_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-         self.feat_from = fea_in
-
-     def to_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-         return self(self.feat_from, fea_in, fea_out)
-
-     def remove(self):
-         host_from = self.host_from()
-         delattr(host_from, self.name)
-         self.hook_handle_from.remove()
-         self.hook_handle_to.remove()
-
- class MultiPluginBlock(BasePluginBlock):
-     def __init__(self, name: str, from_layers: List[Dict[str, Any]], to_layers: List[Dict[str, Any]], host_model=None):
-         super().__init__(name)
-         assert host_model is not None
-         self.host_from = [weakref.ref(x['layer']) for x in from_layers]
-         self.host_to = [weakref.ref(x['layer']) for x in to_layers]
-         self.host_model = weakref.ref(host_model)
-         setattr(host_model, name, self)
-
-         self.feat_from = [None for _ in range(len(from_layers))]
-
-         self.hook_handle_from = []
-         self.hook_handle_to = []
-
-         for idx, layer in enumerate(from_layers):
-             if layer['pre_hook']:
-                 handle_from = layer['layer'].register_forward_pre_hook(lambda host, fea_in, idx=idx:self.from_layer_hook(host, fea_in, None, idx))
-             else:
-                 handle_from = layer['layer'].register_forward_hook(
-                     lambda host, fea_in, fea_out, idx=idx:self.from_layer_hook(host, fea_in, fea_out, idx))
-             self.hook_handle_from.append(handle_from)
-         for idx, layer in enumerate(to_layers):
-             if layer['pre_hook']:
-                 handle_to = layer['layer'].register_forward_pre_hook(lambda host, fea_in, idx=idx:self.to_layer_hook(host, fea_in, None, idx))
-             else:
-                 handle_to = layer['layer'].register_forward_hook(lambda host, fea_in, fea_out, idx=idx:self.to_layer_hook(host, fea_in, fea_out, idx))
-             self.hook_handle_to.append(handle_to)
-
-         self.record_count = 0
-
-     def from_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: Tuple[torch.Tensor], idx: int):
-         self.feat_from[idx] = fea_in
-         self.record_count += 1
-         if self.record_count == len(self.feat_from): # call forward when all feat is record
-             self.record_count = 0
-             self.feat_to = self(self.feat_from)
-
-     def to_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: Tuple[torch.Tensor], idx: int):
-         return self.feat_to[idx]+fea_out
-
-     def remove(self):
-         host_model = self.host_model()
-         delattr(host_model, self.name)
-         for handle_from in self.hook_handle_from:
-             handle_from.remove()
-         for handle_to in self.hook_handle_to:
-             handle_to.remove()
-
- class PatchPluginContainer(nn.Module):
-     def __init__(self, host_name, host, parent_block):
-         super().__init__()
-         self._host = host
-         self.host_name = host_name
-         self.parent_block = weakref.ref(parent_block)
-         self.plugin_names = []
-
-         delattr(parent_block, host_name)
-         setattr(parent_block, host_name, self)
-
-     def add_plugin(self, name: str, plugin: 'PatchPluginBlock'):
-         setattr(self, name, plugin)
-         self.plugin_names.append(name)
-
-     def remove_plugin(self, name: str):
-         delattr(self, name)
-         self.plugin_names.remove(name)
-         if len(self.plugin_names) == 0:
-             self.remove()
-
-     def forward(self, *args, **kwargs):
-         for name, plugin in self:
-             args, kwargs = plugin.pre_forward(*args, **kwargs)
-         output = self._host(*args, **kwargs)
-         for name, plugin in self:
-             output = plugin.post_forward(output, *args, **kwargs)
-         return output
-
-     def remove(self):
-         parent_block = self.parent_block()
-         delattr(parent_block, self.host_name)
-         setattr(parent_block, self.host_name, self._host)
-
-     def __iter__(self):
-         for name in self.plugin_names:
-             yield name, self[name]
-
-     def __getitem__(self, name):
-         return getattr(self, name)
-
- class PatchPluginBlock(BasePluginBlock, WrapablePlugin):
-     container_cls = PatchPluginContainer
-
-     def __init__(self, name: str, host: nn.Module, host_model=None, parent_block: nn.Module = None, host_name: str = None):
-         super().__init__(name)
-         if isinstance(host, self.container_cls):
-             self.host = weakref.ref(host._host)
-         else:
-             self.host = weakref.ref(host)
-         self.parent_block = weakref.ref(parent_block)
-         self.host_name = host_name
-
-         container = self.get_container(host, host_name, parent_block)
-         container.add_plugin(name, self)
-         self.container = weakref.ref(container)
-
-     def pre_forward(self, *args, **kwargs):
-         return args, kwargs
-
-     def post_forward(self, output, *args, **kwargs):
-         return output
-
-     def remove(self):
-         container = self.container()
-         container.remove_plugin(self.name)
-
-     def get_container(self, host, host_name, parent_block):
-         if isinstance(host, self.container_cls):
-             return host
-         else:
-             return self.container_cls(host_name, host, parent_block)
-
-     @classmethod
-     def wrap_model(cls, name: str, host: nn.Module, exclude_key=None, exclude_classes=tuple(), **kwargs): # -> Dict[str, SinglePluginBlock]:
-         '''
-         parent_block and other args required in __init__ will be put into kwargs, compatible with multiple models.
-         '''
-         plugin_block_dict = {}
-         if isinstance(host, cls.wrapable_classes):
-             plugin_block_dict[''] = cls.wrap_layer(name, host, **kwargs)
-         else:
-             named_modules = {layer_name:layer for layer_name, layer in cls.named_modules_with_exclude(
-                 host, exclude_key=exclude_key or '_host', exclude_classes=exclude_classes)}
-             for layer_name, layer in named_modules.items():
-                 if isinstance(layer, cls.wrapable_classes) or isinstance(layer, cls.container_cls):
-                     # For plugins that need parent_block
-                     if 'parent_block' in kwargs:
-                         parent_name, host_name = split_module_name(layer_name)
-                         kwargs['parent_block'] = named_modules[parent_name]
-                         kwargs['host_name'] = host_name
-                     plugin_block_dict[layer_name] = cls.wrap_layer(name, layer, **kwargs)
-         return plugin_block_dict
-
- class PluginGroup:
-     def __init__(self, plugin_dict: Dict[str, BasePluginBlock]):
-         self.plugin_dict = plugin_dict # {host_model_path: plugin_object}
-
-     def __setitem__(self, k, v):
-         self.plugin_dict[k] = v
-
-     def __getitem__(self, k):
-         return self.plugin_dict[k]
-
-     @property
-     def plugin_name(self):
-         if self.empty():
-             return None
-         return next(iter(self.plugin_dict.values())).name
-
-     def remove(self):
-         for plugin in self.plugin_dict.values():
-             plugin.remove()
-
-     def state_dict(self, model=None):
-         if model is None:
-             return {f'{k}.___.{ks}':vs for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()}
-         else:
-             sd_model = model.state_dict()
-             return {f'{k}.___.{ks}':sd_model[f'{k}.{v.name}.{ks}'] for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()}
-
-     def state_keys_raw(self):
-         return [f'{k}.{v.name}.{ks}' for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()]
-
-     def empty(self):
-         return len(self.plugin_dict) == 0
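A hedged sketch of how this removed plugin API was typically used: subclass PatchPluginBlock, declare which layer types it may wrap, and override pre_forward/post_forward; wrap_model then patches every matching submodule in place. ScaleOutputPlugin, Toy, and the scale factor are invented for illustration; only the plugin API itself comes from the code above.

```python
import torch
from torch import nn
from hcpdiff.models.plugin import PatchPluginBlock  # 0.9.x module, removed in 2.1

class ScaleOutputPlugin(PatchPluginBlock):
    wrapable_classes = (nn.Linear,)     # layer types this plugin may wrap

    def __init__(self, name, host, scale=0.5, **kwargs):
        super().__init__(name, host, **kwargs)
        self.scale = scale

    def post_forward(self, output, *args, **kwargs):
        return output*self.scale        # rescale the wrapped layer's output

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 4))

    def forward(self, x):
        return self.block(x)

host = Toy()
# passing parent_block in kwargs makes wrap_model fill in each layer's real parent
plugins = ScaleOutputPlugin.wrap_model('scale_demo', host, parent_block=host)
print(sorted(plugins.keys()), host(torch.randn(1, 4)).shape)
```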
hcpdiff/models/wrapper.py DELETED
@@ -1,75 +0,0 @@
- from torch import nn
- import itertools
- from transformers import CLIPTextModel
- from hcpdiff.utils import pad_attn_bias
-
- class TEUnetWrapper(nn.Module):
-     def __init__(self, unet, TE, train_TE=False):
-         super().__init__()
-         self.unet = unet
-         self.TE = TE
-
-         self.train_TE = train_TE
-
-     def forward(self, prompt_ids, noisy_latents, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
-         input_all = dict(prompt_ids=prompt_ids, noisy_latents=noisy_latents, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
-
-         if hasattr(self.TE, 'input_feeder'):
-             for feeder in self.TE.input_feeder:
-                 feeder(input_all)
-         encoder_hidden_states = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)[0] # Get the text embedding for conditioning
-
-         if attn_mask is not None:
-             encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
-
-         input_all['encoder_hidden_states'] = encoder_hidden_states
-         if hasattr(self.unet, 'input_feeder'):
-             for feeder in self.unet.input_feeder:
-                 feeder(input_all)
-         model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample # Predict the noise residual
-         return model_pred
-
-     def prepare(self, accelerator):
-         if self.train_TE:
-             return accelerator.prepare(self)
-         else:
-             self.unet = accelerator.prepare(self.unet)
-             return self
-
-     def enable_gradient_checkpointing(self):
-         def grad_ckpt_enable(m):
-             if hasattr(m, 'gradient_checkpointing'):
-                 m.training = True
-
-         self.unet.enable_gradient_checkpointing()
-         if self.train_TE:
-             self.TE.gradient_checkpointing_enable()
-             self.apply(grad_ckpt_enable)
-         else:
-             self.unet.apply(grad_ckpt_enable)
-
-     def trainable_parameters(self):
-         if self.train_TE:
-             return itertools.chain(self.unet.parameters(), self.TE.parameters())
-         else:
-             return self.unet.parameters()
-
- class SDXLTEUnetWrapper(TEUnetWrapper):
-     def forward(self, prompt_ids, noisy_latents, timesteps, attn_mask=None, position_ids=None, crop_info=None, plugin_input={}, **kwargs):
-         input_all = dict(prompt_ids=prompt_ids, noisy_latents=noisy_latents, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
-
-         if hasattr(self.TE, 'input_feeder'):
-             for feeder in self.TE.input_feeder:
-                 feeder(input_all)
-         encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True) # Get the text embedding for conditioning
-
-         added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
-         if attn_mask is not None:
-             encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
-
-         input_all['encoder_hidden_states'] = encoder_hidden_states
-         if hasattr(self.unet, 'input_feeder'):
-             for feeder in self.unet.input_feeder:
-                 feeder(input_all)
-         model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask, added_cond_kwargs=added_cond_kwargs).sample # Predict the noise residual
-         return model_pred
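A hedged sketch of driving the removed TEUnetWrapper (2.x replaces it with the modules under hcpdiff/models/wrapper/). Only the wrapper's constructor and call signature come from the diff above; the checkpoint path is a placeholder for any local SD1.5-style diffusers folder and the shapes are illustrative.

```python
import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from hcpdiff.models.wrapper import TEUnetWrapper  # 0.9.x module, removed in 2.1

ckpt = 'path/to/stable-diffusion-v1-5'  # placeholder path
unet = UNet2DConditionModel.from_pretrained(ckpt, subfolder='unet')
TE = CLIPTextModel.from_pretrained(ckpt, subfolder='text_encoder')
tokenizer = CLIPTokenizer.from_pretrained(ckpt, subfolder='tokenizer')

wrapper = TEUnetWrapper(unet, TE, train_TE=False)
prompt_ids = tokenizer(['a photo of a cat'], padding='max_length',
                       truncation=True, return_tensors='pt').input_ids
noisy_latents = torch.randn(1, 4, 64, 64)
timesteps = torch.randint(0, 1000, (1,))
pred = wrapper(prompt_ids, noisy_latents, timesteps)  # predicted noise residual
print(pred.shape)  # same shape as noisy_latents
```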
hcpdiff/noise/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .noise_base import NoiseBase
- from .pyramid_noise import PyramidNoiseScheduler
- from .zero_terminal import ZeroTerminalScheduler
hcpdiff/noise/noise_base.py DELETED
@@ -1,16 +0,0 @@
-
- class NoiseBase:
-     def __init__(self, base_scheduler):
-         self.base_scheduler = base_scheduler
-
-     def __getattr__(self, item):
-         try:
-             return super(NoiseBase, self).__getattr__(item)
-         except:
-             return getattr(self.base_scheduler, item)
-
-     def __setattr__(self, key, value):
-         if hasattr(super(), 'base_scheduler') and hasattr(self.base_scheduler, key):
-             setattr(self.base_scheduler, key, value)
-         else:
-             super(NoiseBase, self).__setattr__(key, value)
hcpdiff/noise/pyramid_noise.py DELETED
@@ -1,50 +0,0 @@
- import random
-
- import torch
- from torch.nn import functional as F
- from diffusers import SchedulerMixin
-
- from .noise_base import NoiseBase
-
- class PyramidNoiseScheduler(NoiseBase, SchedulerMixin):
-     def __init__(self, base_scheduler, level: int = 10, discount: float = 0.9, step_size: float = 2., resize_mode: str = 'bilinear'):
-         super().__init__(base_scheduler)
-         self.level = level
-         self.step_size = step_size
-         self.resize_mode = resize_mode
-         self.discount = discount
-
-     def add_noise(
-             self,
-             original_samples: torch.FloatTensor,
-             noise: torch.FloatTensor,
-             timesteps: torch.IntTensor,
-     ) -> torch.FloatTensor:
-         with torch.no_grad():
-             b, c, h, w = noise.shape
-             for i in range(1, self.level):
-                 r = random.random()*2+self.step_size
-                 wn, hn = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
-                 noise += F.interpolate(torch.randn(b, c, hn, wn).to(noise), (h, w), None, self.resize_mode)*(self.discount**i)
-                 if wn == 1 or hn == 1:
-                     break
-             noise = noise/noise.std()
-         return self.base_scheduler.add_noise(original_samples, noise, timesteps)
-
- # if __name__ == '__main__':
- #     noise = torch.randn(1,3,512,512)
- #     level=10
- #     discount=0.6
- #     b, c, h, w = noise.shape
- #     for i in range(level):
- #         r = random.random() * 2 + 2
- #         wn, hn = max(1, int(w / (r ** i))), max(1, int(h / (r ** i)))
- #         noise += F.interpolate(torch.randn(b, c, wn, hn).to(noise), (w, h), None, 'bilinear') * discount ** i
- #         if wn == 1 or hn == 1:
- #             break
- #     noise = noise / noise.std()
- #
- #     from matplotlib import pyplot as plt
- #     plt.figure()
- #     plt.imshow(noise[0].permute(1,2,0))
- #     plt.show()
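A short usage sketch tying the two removed noise classes together: PyramidNoiseScheduler wraps a diffusers scheduler, NoiseBase.__getattr__ forwards unknown attributes to it, and add_noise injects multi-scale noise before delegating. The DDPMScheduler configuration and tensor shapes are illustrative.

```python
import torch
from diffusers import DDPMScheduler
from hcpdiff.noise import PyramidNoiseScheduler  # 0.9.x import path, removed in 2.1

base = DDPMScheduler(num_train_timesteps=1000)
sched = PyramidNoiseScheduler(base, level=10, discount=0.9)

latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (2,))

# pyramid noise is built in add_noise, then the wrapped scheduler does the mixing
noisy = sched.add_noise(latents, noise, timesteps)
# attribute reads fall through to the wrapped scheduler via NoiseBase.__getattr__
print(noisy.shape, sched.alphas_cumprod.shape)
```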