hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/controlnet.py +16 -16
  58. hcpdiff/models/lora_base_patch.py +14 -25
  59. hcpdiff/models/lora_layers.py +3 -9
  60. hcpdiff/models/lora_layers_patch.py +14 -24
  61. hcpdiff/models/text_emb_ex.py +84 -6
  62. hcpdiff/models/textencoder_ex.py +54 -18
  63. hcpdiff/models/wrapper/__init__.py +3 -0
  64. hcpdiff/models/wrapper/pixart.py +19 -0
  65. hcpdiff/models/wrapper/sd.py +218 -0
  66. hcpdiff/models/wrapper/utils.py +20 -0
  67. hcpdiff/parser/__init__.py +1 -0
  68. hcpdiff/parser/embpt.py +32 -0
  69. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  70. hcpdiff/tools/dataset_generator.py +94 -0
  71. hcpdiff/tools/download_hf_model.py +24 -0
  72. hcpdiff/tools/init_proj.py +3 -21
  73. hcpdiff/tools/lora_convert.py +18 -17
  74. hcpdiff/tools/save_model.py +12 -0
  75. hcpdiff/tools/sd2diffusers.py +1 -1
  76. hcpdiff/train_colo.py +1 -1
  77. hcpdiff/train_deepspeed.py +1 -1
  78. hcpdiff/trainer_ac.py +79 -0
  79. hcpdiff/trainer_ac_single.py +31 -0
  80. hcpdiff/utils/__init__.py +0 -2
  81. hcpdiff/utils/inpaint_pipe.py +7 -2
  82. hcpdiff/utils/net_utils.py +29 -6
  83. hcpdiff/utils/pipe_hook.py +24 -7
  84. hcpdiff/utils/utils.py +21 -4
  85. hcpdiff/workflow/__init__.py +15 -10
  86. hcpdiff/workflow/daam/__init__.py +1 -0
  87. hcpdiff/workflow/daam/act.py +66 -0
  88. hcpdiff/workflow/daam/hook.py +109 -0
  89. hcpdiff/workflow/diffusion.py +114 -125
  90. hcpdiff/workflow/fast.py +31 -0
  91. hcpdiff/workflow/flow.py +67 -0
  92. hcpdiff/workflow/io.py +36 -130
  93. hcpdiff/workflow/model.py +46 -43
  94. hcpdiff/workflow/text.py +78 -46
  95. hcpdiff/workflow/utils.py +32 -12
  96. hcpdiff/workflow/vae.py +37 -38
  97. hcpdiff-2.1.dist-info/METADATA +285 -0
  98. hcpdiff-2.1.dist-info/RECORD +114 -0
  99. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  100. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  101. hcpdiff/ckpt_manager/base.py +0 -16
  102. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  103. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  104. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  105. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  106. hcpdiff/data/bucket.py +0 -358
  107. hcpdiff/data/caption_loader.py +0 -80
  108. hcpdiff/data/cond_dataset.py +0 -40
  109. hcpdiff/data/crop_info_dataset.py +0 -40
  110. hcpdiff/data/data_processor.py +0 -33
  111. hcpdiff/data/pair_dataset.py +0 -146
  112. hcpdiff/data/sampler.py +0 -54
  113. hcpdiff/data/source/base.py +0 -30
  114. hcpdiff/data/utils.py +0 -80
  115. hcpdiff/deprecated/__init__.py +0 -1
  116. hcpdiff/deprecated/cfg_converter.py +0 -81
  117. hcpdiff/deprecated/lora_convert.py +0 -31
  118. hcpdiff/infer_workflow.py +0 -57
  119. hcpdiff/loggers/__init__.py +0 -13
  120. hcpdiff/loggers/base_logger.py +0 -76
  121. hcpdiff/loggers/cli_logger.py +0 -40
  122. hcpdiff/loggers/preview/__init__.py +0 -1
  123. hcpdiff/loggers/preview/image_previewer.py +0 -149
  124. hcpdiff/loggers/tensorboard_logger.py +0 -30
  125. hcpdiff/loggers/wandb_logger.py +0 -31
  126. hcpdiff/loggers/webui_logger.py +0 -9
  127. hcpdiff/loss/min_snr_loss.py +0 -52
  128. hcpdiff/models/layers.py +0 -81
  129. hcpdiff/models/plugin.py +0 -348
  130. hcpdiff/models/wrapper.py +0 -75
  131. hcpdiff/noise/__init__.py +0 -3
  132. hcpdiff/noise/noise_base.py +0 -16
  133. hcpdiff/noise/pyramid_noise.py +0 -50
  134. hcpdiff/noise/zero_terminal.py +0 -44
  135. hcpdiff/train_ac.py +0 -566
  136. hcpdiff/train_ac_single.py +0 -39
  137. hcpdiff/utils/caption_tools.py +0 -105
  138. hcpdiff/utils/cfg_net_tools.py +0 -321
  139. hcpdiff/utils/cfg_resolvers.py +0 -16
  140. hcpdiff/utils/ema.py +0 -52
  141. hcpdiff/utils/img_size_tool.py +0 -248
  142. hcpdiff/vis/__init__.py +0 -3
  143. hcpdiff/vis/base_interface.py +0 -12
  144. hcpdiff/vis/disk_interface.py +0 -48
  145. hcpdiff/vis/webui_interface.py +0 -17
  146. hcpdiff/viser_fast.py +0 -138
  147. hcpdiff/visualizer.py +0 -265
  148. hcpdiff/visualizer_reloadable.py +0 -237
  149. hcpdiff/workflow/base.py +0 -59
  150. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  198. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  206. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  207. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  208. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  209. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/loggers/preview/image_previewer.py DELETED
@@ -1,149 +0,0 @@
- from contextlib import contextmanager
- from typing import List
-
- import hydra
- import torch
- from accelerate import infer_auto_device_map, dispatch_model
- from accelerate.hooks import remove_hook_from_module
- from diffusers import PNDMScheduler
- from torch.cuda.amp import autocast
-
- from hcpdiff.models import TokenizerHook
- from hcpdiff.utils.net_utils import to_cpu
- from hcpdiff.utils.utils import prepare_seed, load_config, size_to_int, int_to_size
- from hcpdiff.utils.utils import to_validate_file
- from hcpdiff.visualizer import Visualizer
-
- class ImagePreviewer(Visualizer):
-     def __init__(self, infer_cfg, exp_dir, te_hook,
-                  unet, TE, tokenizer, vae, save_cfg=False):
-         self.exp_dir = exp_dir
-         self.cfgs_raw = load_config(infer_cfg)
-         self.cfgs = hydra.utils.instantiate(self.cfgs_raw)
-         self.save_cfg = save_cfg
-         self.offload = 'offload' in self.cfgs and self.cfgs.offload is not None
-         self.dtype = self.dtype_dict[self.cfgs.dtype]
-
-         if getattr(self.cfgs.new_components, 'scheduler', None) is None:
-             scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')
-         else:
-             scheduler = self.cfgs.new_components.scheduler
-
-         pipe_cls = self.get_pipeline()
-         self.pipe = pipe_cls(vae=vae, text_encoder=TE, tokenizer=tokenizer, unet=unet, scheduler=scheduler, feature_extractor=None,
-                              safety_checker=None, requires_safety_checker=False)
-
-         self.token_ex = TokenizerHook(tokenizer)
-         self.te_hook = te_hook
-
-         if self.cfgs.seed is not None:
-             self.seeds = list(range(self.cfgs.seed, self.cfgs.seed+self.cfgs.num*self.cfgs.bs))
-         else:
-             self.seeds = [None]*(self.cfgs.num*self.cfgs.bs)
-
-     def build_vae_offload(self, offload_cfg):
-         vram = size_to_int(offload_cfg.max_VRAM)
-         if not offload_cfg.vae_cpu:
-             device_map = infer_auto_device_map(self.pipe.vae, max_memory={0:int_to_size(vram >> 5), "cpu":offload_cfg.max_RAM}, dtype=torch.float32)
-             self.pipe.vae = dispatch_model(self.pipe.vae, device_map)
-         else:
-             to_cpu(self.pipe.vae)
-             self.vae_decode_raw = self.pipe.vae.decode
-
-             def vae_decode_offload(latents, return_dict=True, decode_raw=self.pipe.vae.decode):
-                 self.pipe.vae.to(dtype=torch.float32)
-                 res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
-                 return res
-
-             self.pipe.vae.decode = vae_decode_offload
-
-             self.vae_encode_raw = self.pipe.vae.encode
-
-             def vae_encode_offload(x, return_dict=True, encode_raw=self.pipe.vae.encode):
-                 self.pipe.vae.to(dtype=torch.float32)
-                 res = encode_raw(x.cpu().to(dtype=torch.float32), return_dict=return_dict)
-                 return res
-
-             self.pipe.vae.encode = vae_encode_offload
-
-     def remove_vae_offload(self, offload_cfg):
-         if not offload_cfg.vae_cpu:
-             remove_hook_from_module(self.pipe.vae, recurse=True)
-         else:
-             self.pipe.vae.encode = self.vae_encode_raw
-             self.pipe.vae.decode = self.vae_decode_raw
-
-     @contextmanager
-     def infer_optimize(self):
-         if getattr(self.cfgs, 'vae_optimize', None) is not None:
-             if self.cfgs.vae_optimize.tiling:
-                 self.pipe.vae.enable_tiling()
-             if self.cfgs.vae_optimize.slicing:
-                 self.pipe.vae.enable_slicing()
-         vae_device = self.pipe.vae.device
-         if self.offload:
-             self.build_vae_offload(self.cfgs.offload)
-         else:
-             self.pipe.vae.to(self.pipe.unet.device)
-
-         yield
-
-         if self.offload:
-             self.remove_vae_offload(self.cfgs.offload)
-         self.pipe.vae.to(vae_device)
-         self.pipe.vae.disable_tiling()
-         self.pipe.vae.disable_slicing()
-
-     def preview(self):
-         image_list, info_list = [], []
-         with self.infer_optimize():
-             for i in range(self.cfgs.num):
-                 prompt = self.cfgs.prompt[i*self.cfgs.bs:(i+1)*self.cfgs.bs] if isinstance(self.cfgs.prompt, list) \
-                     else [self.cfgs.prompt]*self.cfgs.bs
-                 negative_prompt = self.cfgs.neg_prompt[i*self.cfgs.bs:(i+1)*self.cfgs.bs] if isinstance(self.cfgs.neg_prompt, list) \
-                     else [self.cfgs.neg_prompt]*self.cfgs.bs
-                 seeds = self.seeds[i*self.cfgs.bs:(i+1)*self.cfgs.bs]
-                 images = self.vis_images(prompt=prompt, negative_prompt=negative_prompt, seeds=seeds,
-                                          **self.cfgs.infer_args)
-                 for prompt_i, negative_prompt_i, seed in zip(prompt, negative_prompt, seeds):
-                     info_list.append({
-                         'prompt':prompt_i,
-                         'negative_prompt':negative_prompt_i,
-                         'seed':seed,
-                     })
-                 image_list += images
-
-         return image_list, info_list
-
-     def preview_dict(self):
-         image_list, info_list = self.preview()
-         imgs = {f'{info["seed"]}-{to_validate_file(info["prompt"])}':img for img, info in zip(image_list, info_list)}
-         return imgs
-
-     @torch.no_grad()
-     def vis_images(self, prompt, negative_prompt='', seeds: List[int] = None, **kwargs):
-         G = prepare_seed(seeds or [None]*len(prompt))
-
-         ex_input_dict, pipe_input_dict = self.get_ex_input()
-         kwargs.update(pipe_input_dict)
-
-         mult_p, clean_text_p = self.token_ex.parse_attn_mult(prompt)
-         mult_n, clean_text_n = self.token_ex.parse_attn_mult(negative_prompt)
-         with autocast(enabled=self.cfgs.amp, dtype=self.dtype):
-             emb, pooled_output, attention_mask = self.te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
-             if not self.cfgs.encoder_attention_mask:
-                 attention_mask = None
-             emb_n, emb_p = emb.chunk(2)
-             emb_p = self.te_hook.mult_attn(emb_p, mult_p)
-             emb_n = self.te_hook.mult_attn(emb_n, mult_n)
-
-         if hasattr(self.pipe.unet, 'input_feeder'):
-             for feeder in self.pipe.unet.input_feeder:
-                 feeder(ex_input_dict)
-
-         if pooled_output is not None:
-             pooled_output = pooled_output[-1]
-
-         images = self.pipe(prompt_embeds=emb_p, negative_prompt_embeds=emb_n, callback=self.inter_callback, generator=G,
-                            pooled_output=pooled_output, encoder_attention_mask=attention_mask, **kwargs).images
-         return images
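Note: the deleted ImagePreviewer built its low-VRAM path on accelerate's device-map utilities (infer_auto_device_map plus dispatch_model) with a float32 CPU fallback for the VAE. Below is a minimal standalone sketch of that offload pattern, not part of the diff; the VAE checkpoint and memory budgets are placeholders, not values from the package.

import torch
from accelerate import infer_auto_device_map, dispatch_model
from diffusers import AutoencoderKL

# Placeholder checkpoint; any diffusers AutoencoderKL behaves the same here.
vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae')

# Split the VAE between GPU 0 and CPU under an explicit memory budget,
# mirroring the build_vae_offload branch that calls infer_auto_device_map/dispatch_model.
device_map = infer_auto_device_map(vae, max_memory={0: '512MB', 'cpu': '8GB'}, dtype=torch.float32)
vae = dispatch_model(vae, device_map)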
hcpdiff/loggers/tensorboard_logger.py DELETED
@@ -1,30 +0,0 @@
- import os
- from typing import Dict, Any
-
- import numpy as np
- from PIL import Image
- from torch.utils.tensorboard import SummaryWriter
-
- from .base_logger import BaseLogger
-
-
- class TBLogger(BaseLogger):
-     def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200):
-         super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-         if exp_dir is not None:  # exp_dir is only available in local main process
-             self.writer = SummaryWriter(os.path.join(exp_dir, out_path))
-         else:
-             self.writer = None
-             self.disable()
-
-     def _info(self, info):
-         pass
-
-     def _log(self, datas: Dict[str, Any], step: int = 0):
-         for k, v in datas.items():
-             if len(v['data']) == 1:
-                 self.writer.add_scalar(k, v['data'][0], global_step=step)
-
-     def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-         for name, img in imgs.items():
-             self.writer.add_image(f'img/{name}', np.array(img), dataformats='HWC', global_step=step)
hcpdiff/loggers/wandb_logger.py DELETED
@@ -1,31 +0,0 @@
- from typing import Dict, Any
-
- import os
- import wandb
- from PIL import Image
-
- from .base_logger import BaseLogger
-
-
- class WanDBLogger(BaseLogger):
-     def __init__(self, exp_dir, out_path=None, enable_log_image=False, project='hcp-diffusion', log_step=10, image_log_step=200):
-         super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-         if exp_dir is not None:  # exp_dir is only available in local main process
-             wandb.init(project=project, name=os.path.basename(exp_dir))
-             wandb.save(os.path.join(exp_dir, 'cfg.yaml'), base_path=exp_dir)
-         else:
-             self.writer = None
-             self.disable()
-
-     def _info(self, info):
-         pass
-
-     def _log(self, datas: Dict[str, Any], step: int = 0):
-         log_dict = {'step': step}
-         for k, v in datas.items():
-             if len(v['data']) == 1:
-                 log_dict[k] = v['data'][0]
-         wandb.log(log_dict)
-
-     def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-         wandb.log({next(iter(imgs.keys())): list(imgs.values())}, step=step)
hcpdiff/loggers/webui_logger.py DELETED
@@ -1,9 +0,0 @@
- from typing import Dict, Any
-
- from loguru import logger
-
- from .cli_logger import CLILogger
-
- class WebUILogger(CLILogger):
-     def _log(self, datas: Dict[str, Any], step: int = 0):
-         logger.info('this progress steps:'+', '.join([f"{k} = {v['format'].format(*v['data'])}" for k, v in datas.items()]))
hcpdiff/loss/min_snr_loss.py DELETED
@@ -1,52 +0,0 @@
- import torch
- from diffusers import SchedulerMixin
- from torch import nn
-
- class MinSNRLoss(nn.MSELoss):
-     need_timesteps = True
-
-     def __init__(self, size_average=None, reduce=None, reduction: str = 'none', gamma=1.,
-                  noise_scheduler: SchedulerMixin = None, device='cuda:0', **kwargs):
-         super().__init__(size_average, reduce, reduction)
-         self.gamma = gamma
-
-         # calculate SNR
-         alphas_cumprod = noise_scheduler.alphas_cumprod
-         sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
-         sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0-alphas_cumprod)
-         self.alpha = sqrt_alphas_cumprod.to(device)
-         self.sigma = sqrt_one_minus_alphas_cumprod.to(device)
-         self.all_snr = ((self.alpha/self.sigma)**2).to(device)
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         loss = super(MinSNRLoss, self).forward(input, target)
-         snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-         snr_weight = (self.gamma/snr).clip(max=1.).float()
-         return loss*snr_weight.view(-1, 1, 1, 1)
-
-
- class SoftMinSNRLoss(MinSNRLoss):
-     # gamma=2
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         loss = super(MinSNRLoss, self).forward(input, target)
-         snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-         snr_weight = (self.gamma**3/(snr**2 + self.gamma**3)).float()
-         return loss*snr_weight.view(-1, 1, 1, 1)
-
- class KDiffMinSNRLoss(MinSNRLoss):
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         loss = super(MinSNRLoss, self).forward(input, target)
-         snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-         snr_weight = 4*(((self.gamma*snr)**2/(snr**2 + self.gamma**2)**2)).float()
-         return loss*snr_weight.view(-1, 1, 1, 1)
-
- class EDMLoss(MinSNRLoss):
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         loss = super(MinSNRLoss, self).forward(input, target)
-         sigma = self.sigma[timesteps[:loss.shape[0], ...].squeeze()]
-         snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-         snr_weight = ((sigma**2+self.gamma**2)/(snr*(sigma*self.gamma)**2)).float()
-         return loss*snr_weight.view(-1, 1, 1, 1)
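Note: the Min-SNR weighting removed here scales a per-pixel MSE by min(gamma/SNR(t), 1), where SNR(t) = alphas_cumprod[t]/(1 - alphas_cumprod[t]) comes from the scheduler. Below is a minimal sketch of that weighting against a diffusers DDPMScheduler, not part of the diff; the scheduler settings and tensor shapes are illustrative, not taken from the package.

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

# Illustrative scheduler; the deleted loss accepted any scheduler exposing alphas_cumprod.
scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')
all_snr = scheduler.alphas_cumprod/(1.0 - scheduler.alphas_cumprod)  # SNR(t)

gamma = 1.0
pred = torch.randn(4, 4, 64, 64)     # model prediction
target = torch.randn(4, 4, 64, 64)   # target noise
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (4,))

mse = F.mse_loss(pred, target, reduction='none')
snr_weight = (gamma/all_snr[timesteps]).clip(max=1.0)   # min(gamma/SNR, 1), as in MinSNRLoss.forward
loss = (mse*snr_weight.view(-1, 1, 1, 1)).mean()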
hcpdiff/models/layers.py DELETED
@@ -1,81 +0,0 @@
- """
- layers.py
- ====================
- :Name: GroupLinear and other layers
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 09/04/2023
- :Licence: Apache-2.0
- """
-
- import torch
- from torch import nn
- import math
- from einops import rearrange
-
- class GroupLinear(nn.Module):
-     def __init__(self, in_features: int, out_features: int, groups: int, bias: bool = True,
-                  device=None, dtype=None):
-         super().__init__()
-         assert in_features%groups == 0
-         assert out_features%groups == 0
-
-         factory_kwargs = {'device': device, 'dtype': dtype}
-
-         self.groups = groups
-         self.in_features = in_features
-         self.out_features = out_features
-
-         self.weight = nn.Parameter(torch.empty((groups, in_features//groups, out_features//groups), **factory_kwargs))
-         if bias:
-             self.bias = nn.Parameter(torch.empty(groups, 1, out_features//groups, **factory_kwargs))
-         else:
-             self.register_parameter('bias', None)
-         self.reset_parameters()
-
-     def reset_parameters(self) -> None:
-         # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
-         # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
-         # https://github.com/pytorch/pytorch/issues/57109
-         self.kaiming_uniform_group(self.weight, a=math.sqrt(5))
-         if self.bias is not None:
-             fan_in, _ = self._calculate_fan_in_and_fan_out(self.weight)
-             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-             nn.init.uniform_(self.bias, -bound, bound)
-
-     @staticmethod
-     def _calculate_fan_in_and_fan_out(tensor):
-         receptive_field_size = 1
-         num_input_fmaps = tensor.size(-2)
-         num_output_fmaps = tensor.size(-1)
-         fan_in = num_input_fmaps * receptive_field_size
-         fan_out = num_output_fmaps * receptive_field_size
-
-         return fan_in, fan_out
-
-     @staticmethod
-     def kaiming_uniform_group(tensor: torch.Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu') -> torch.Tensor:
-         def _calculate_correct_fan(tensor, mode):
-             mode = mode.lower()
-             valid_modes = ['fan_in', 'fan_out']
-             if mode not in valid_modes:
-                 raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
-
-             fan_in, fan_out = GroupLinear._calculate_fan_in_and_fan_out(tensor)
-             return fan_in if mode == 'fan_in' else fan_out
-
-         fan = _calculate_correct_fan(tensor, mode)
-         gain = nn.init.calculate_gain(nonlinearity, a)
-         std = gain / math.sqrt(fan)
-         bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
-         with torch.no_grad():
-             return tensor.uniform_(-bound, bound)
-
-     def forward(self, x: torch.Tensor):  # x: [G,B,L,C]
-         x = rearrange(x, '(g b) l c -> g (b l) c', g=self.num_groups)
-         if self.bias is not None:
-             out = torch.bmm(x, self.weight) + self.bias
-         else:
-             out = torch.bmm(x, self.weight)
-         out = rearrange(out, 'g (b l) c -> (g b) l c', b=B)
-         return out
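Note: as released, GroupLinear.forward references self.num_groups and an undefined B even though the constructor stores the group count as self.groups, so the module would raise at call time; 2.1 no longer ships this file. The grouped batched-matmul idea it implements can be sketched independently as follows, not part of the diff; all names and shapes are illustrative.

import torch
from einops import rearrange

groups, batch, seq_len = 4, 2, 16
in_features, out_features = 64, 128

x = torch.randn(groups*batch, seq_len, in_features//groups)               # [(G B), L, C_in/G]
weight = torch.randn(groups, in_features//groups, out_features//groups)   # one weight matrix per group
bias = torch.randn(groups, 1, out_features//groups)

# Fold the groups out of the batch dimension, run one matmul per group, then fold back.
xg = rearrange(x, '(g b) l c -> g (b l) c', g=groups)
out = torch.bmm(xg, weight) + bias
out = rearrange(out, 'g (b l) c -> (g b) l c', b=batch)
print(out.shape)  # torch.Size([8, 16, 32])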