hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/workflow/daam/__init__.py
@@ -0,0 +1 @@
+ from .act import CaptureCrossAttnAction, SaveWordAttnAction
hcpdiff/workflow/daam/act.py
@@ -0,0 +1,66 @@
+ import os
+ from io import BytesIO
+
+ import numpy as np
+ from PIL import Image
+ from hcpdiff.utils import to_validate_file
+ from rainbowneko.utils import types_support
+ from matplotlib import pyplot as plt
+ from rainbowneko.infer import BasicAction, Actions
+
+ from .hook import DiffusionHeatMapHooker
+
+ class CaptureCrossAttnAction(Actions):
+     def forward(self, prompt, denoiser, tokenizer, vae, **states):
+         bs = len(prompt)
+         N_head = 8
+         with DiffusionHeatMapHooker(denoiser, tokenizer, vae_scale_factor=vae.vae_scale_factor) as tc:
+             states = super().forward(**states)
+             heat_maps = [tc.compute_global_heat_map(prompt=prompt[i], head_idxs=range(N_head*i, N_head*(i+1))) for i in range(bs)]
+
+         return {**states, 'cross_attn_heat_maps':heat_maps}
+
+ class SaveWordAttnAction(BasicAction):
+
+     def __init__(self, save_root: str, N_col: int = 4, image_type: str = 'png', quality: int = 95, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.save_root = save_root
+         self.image_type = image_type
+         self.quality = quality
+         self.N_col = N_col
+
+         os.makedirs(save_root, exist_ok=True)
+
+     def draw_attn(self, tokenizer, prompt, image, global_heat_map):
+         prompt = tokenizer.bos_token+prompt+tokenizer.eos_token
+         tokens = [token.replace("</w>", "") for token in tokenizer.tokenize(prompt)]
+
+         d_len = self.N_col
+         plt.rcParams['figure.dpi'] = 300
+         plt.rcParams.update({'font.size':12})
+         h = int(np.ceil(len(tokens)/d_len))
+         fig, ax = plt.subplots(h, d_len, figsize=(2*d_len, 2*h))
+         for ax_ in ax.flatten():
+             ax_.set_xticks([])
+             ax_.set_yticks([])
+         for i, token in enumerate(tokens):
+             heat_map = global_heat_map.compute_word_heat_map(token, word_idx=i)
+             if h == 1:
+                 heat_map.plot_overlay(image, ax=ax[i%d_len])
+             else:
+                 heat_map.plot_overlay(image, ax=ax[i//d_len, i%d_len])
+         # plt.tight_layout()
+
+         buf = BytesIO()
+         plt.savefig(buf, format='png')
+         buf.seek(0)
+         return Image.open(buf)
+
+     def forward(self, tokenizer, images, prompt, seeds, cross_attn_heat_maps, **states):
+         num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(self.save_root) if x.rsplit('.', 1)[-1] in types_support])
+
+         for bid, (p, img) in enumerate(zip(prompt, images)):
+             img_path = os.path.join(self.save_root, f"{num_img_exist}-{seeds[bid]}-cross_attn-{to_validate_file(prompt[0])}.{self.image_type}")
+             img = self.draw_attn(tokenizer, p, img, cross_attn_heat_maps[bid])
+             img.save(img_path, quality=self.quality)
+             num_img_exist += 1
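
CaptureCrossAttnAction subclasses Actions, so it wraps the denoising actions and keeps the DAAM hook installed for every U-Net call they make; SaveWordAttnAction then consumes the collected cross_attn_heat_maps. A minimal composition sketch, assuming Actions accepts its sub-action list as the first constructor argument (that signature and the surrounding workflow are assumptions, not shown in this diff):

    from hcpdiff.workflow.daam import CaptureCrossAttnAction, SaveWordAttnAction

    denoise_actions = []  # the usual per-step denoising actions would go here
    workflow = [
        # The hook is installed on __enter__ and removed on __exit__, so every
        # U-Net forward made by the wrapped actions is captured.
        CaptureCrossAttnAction(denoise_actions),  # assumed Actions(...) constructor
        SaveWordAttnAction(save_root='output/attn', N_col=4, image_type='png'),
    ]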
hcpdiff/workflow/daam/hook.py
@@ -0,0 +1,109 @@
+ from daam import AggregateHooker, RawHeatMapCollection, UNetCrossAttentionLocator, GlobalHeatMap
+ from daam.trace import UNetCrossAttentionHooker
+ from typing import List
+ from diffusers import UNet2DConditionModel
+ from PIL import Image
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ def auto_autocast(*args, **kwargs):
+     if not torch.cuda.is_available():
+         kwargs['enabled'] = False
+
+     return torch.cuda.amp.autocast(*args, **kwargs)
+
+ class DiffusionHeatMapHooker(AggregateHooker):
+     def __init__(
+             self,
+             unet: UNet2DConditionModel,
+             tokenizer,
+             vae_scale_factor: int,
+             low_memory: bool = False,
+             load_heads: bool = False,
+             save_heads: bool = False,
+             data_dir: str = None
+     ):
+         self.all_heat_maps = RawHeatMapCollection()
+         h = (unet.config.sample_size * vae_scale_factor)
+         self.latent_hw = 4096 if h == 512 or h == 1024 else 9216  # 64x64 or 96x96 depending on if it's 2.0-v or 2.0
+         locate_middle = load_heads or save_heads
+         self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
+         self.last_prompt: str = ''
+         self.last_image: Image.Image = None
+         self.time_idx = 0
+         self._gen_idx = 0
+
+         self.tokenizer = tokenizer
+
+         modules = [
+             UNetCrossAttentionHooker(
+                 x,
+                 self,
+                 layer_idx=idx,
+                 latent_hw=self.latent_hw,
+                 load_heads=load_heads,
+                 save_heads=save_heads,
+                 data_dir=data_dir
+             ) for idx, x in enumerate(self.locator.locate(unet))
+         ]
+
+         super().__init__(modules)
+
+     def time_callback(self, *args, **kwargs):
+         self.time_idx += 1
+
+     @property
+     def layer_names(self):
+         return self.locator.layer_names
+
+     def compute_global_heat_map(self, prompt=None, factors=None, head_idxs: List[int] = None, layer_idx=None, normalize=False):
+         # type: (str, List[float], List[int], int, bool) -> GlobalHeatMap
+         """
+         Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different
+         spatial transformer block heat maps).
+
+         Args:
+             prompt: The prompt to compute the heat map for. If None, uses the last prompt that was used for generation.
+             factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes.
+             head_idxs: Restrict the application to heat maps with these head indices. If `None`, use all heads.
+             layer_idx: Restrict the application to heat maps with this layer index. If `None`, use all layers.
+
+         Returns:
+             A heat map object for computing word-level heat maps.
+         """
+         heat_maps = self.all_heat_maps
+
+         if prompt is None:
+             prompt = self.last_prompt
+
+         if factors is None:
+             factors = {0, 1, 2, 4, 8, 16, 32, 64}
+         else:
+             factors = set(factors)
+
+         all_merges = []
+         x = int(np.sqrt(self.latent_hw))
+
+         with auto_autocast(dtype=torch.float32):
+             for (factor, layer, head), heat_map in heat_maps:
+                 if (head_idxs is None or head in head_idxs) and (layer_idx is None or layer_idx == layer):
+                     heat_map = heat_map.unsqueeze(1)/25  # [L,1,H,W]
+                     # The clamping fixes undershoot.
+                     all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0))
+
+         try:
+             maps = torch.stack(all_merges, dim=0)  # [B*head, L, 1, H, W]
+         except RuntimeError:
+             if head_idxs is not None or layer_idx is not None:
+                 raise RuntimeError('No heat maps found for the given parameters.')
+             else:
+                 raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?')
+
+         maps = maps.mean(0)[:, 0]  # [L,H,W]
+         # maps = maps[:len(self.tokenizer.tokenize(prompt)) + 2]  # 1 for SOS and 1 for padding
+
+         if normalize:
+             maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6)  # drop out [SOS] and [PAD] for proper probabilities
+
+         return GlobalHeatMap(self.tokenizer, prompt, maps)
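
The hooker is used as a context manager, exactly as CaptureCrossAttnAction does above, and the RuntimeError text refers to that same pattern. A condensed sketch, where denoiser, tokenizer, and vae are assumed to come from an already-loaded Stable Diffusion pipeline and run_denoising_loop is a placeholder for any code that calls the U-Net:

    with DiffusionHeatMapHooker(denoiser, tokenizer, vae_scale_factor=vae.vae_scale_factor) as tc:
        run_denoising_loop()  # hypothetical: generation happens while the hook is active
        heat_map = tc.compute_global_heat_map(prompt='a photo of a cat')

    # Per-token overlay map, as used by SaveWordAttnAction.draw_attn above.
    word_map = heat_map.compute_word_heat_map('cat')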
hcpdiff/workflow/diffusion.py
@@ -1,209 +1,199 @@
- import inspect
+ import random
+ import warnings
  from typing import Dict, Any, Union, List

  import torch
+ from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
+ from hcpdiff.utils import prepare_seed
+ from hcpdiff.utils.net_utils import get_dtype, to_cuda
+ from rainbowneko.infer import BasicAction
  from torch.cuda.amp import autocast

- from .base import BasicAction, from_memory_context, MemoryMixin
-
  try:
      from diffusers.utils import randn_tensor
  except:
      # new version of diffusers
      from diffusers.utils.torch_utils import randn_tensor

- from hcpdiff.utils import prepare_seed
- from hcpdiff.utils.net_utils import get_dtype
- import random
-
  class InputFeederAction(BasicAction):
-     @from_memory_context
-     def __init__(self, ex_inputs: Dict[str, Any], unet=None):
-         super().__init__()
+     def __init__(self, ex_inputs: Dict[str, Any], key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.ex_inputs = ex_inputs
-         self.unet = unet

-     def forward(self, **states):
-         if hasattr(self.unet, 'input_feeder'):
-             for feeder in self.unet.input_feeder:
-                 feeder(self.ex_inputs)
-         return states
+     def forward(self, model, ex_inputs=None, **states):
+         ex_inputs = self.ex_inputs if ex_inputs is None else {**ex_inputs, **self.ex_inputs}
+         if hasattr(model, 'input_feeder'):
+             for feeder in model.input_feeder:
+                 feeder(ex_inputs)

  class SeedAction(BasicAction):
-     @from_memory_context
-     def __init__(self, seed: Union[int, List[int]], bs: int = 1):
-         super().__init__()
+     def __init__(self, seed: Union[int, List[int]], bs: int = 1, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.seed = seed
          self.bs = bs

-     def forward(self, device, **states):
+     def forward(self, device, seed=None, **states):
          bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
-         if self.seed is None:
+         seed = seed or self.seed
+         if seed is None:
              seeds = [None]*bs
-         elif isinstance(self.seed, int):
-             seeds = list(range(self.seed, self.seed+bs))
+         elif isinstance(seed, int):
+             seeds = list(range(seed, seed+bs))
          else:
-             seeds = self.seed
+             seeds = seed
          seeds = [s or random.randint(0, 1 << 30) for s in seeds]

          G = prepare_seed(seeds, device=device)
-         return {**states, 'seeds':seeds, 'generator':G, 'device':device}
+         return {'seeds':seeds, 'generator':G}

- class PrepareDiffusionAction(BasicAction, MemoryMixin):
-     def __init__(self, dtype='fp32', amp=True):
-         self.dtype = dtype
+ class PrepareDiffusionAction(BasicAction):
+     def __init__(self, model_offload=False, amp=torch.float16, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.model_offload = model_offload
          self.amp = amp

-     def forward(self, memory, **states):
-         dtype = get_dtype(self.dtype)
-         memory.unet.to(dtype=dtype)
-         memory.text_encoder.to(dtype=dtype)
-         memory.vae.to(dtype=dtype)
+     def forward(self, device, denoiser, TE, vae, **states):
+         denoiser.to(device)
+         TE.to(device)
+         vae.to(device)

-         device = memory.unet.device
-         vae_scale_factor = 2**(len(memory.vae.config.block_out_channels)-1)
-         return {**states, 'dtype':self.dtype, 'amp':self.amp, 'device':device, 'vae_scale_factor':vae_scale_factor}
+         TE.eval()
+         denoiser.eval()
+         vae.eval()
+         return {'amp':self.amp, 'model_offload':self.model_offload}

- class MakeTimestepsAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, scheduler=None, N_steps: int = 30, strength: float = None):
-         self.scheduler = scheduler
+ class MakeTimestepsAction(BasicAction):
+     def __init__(self, N_steps: int = 30, strength: float = None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.N_steps = N_steps
          self.strength = strength

-     def get_timesteps(self, timesteps, strength):
+     def get_timesteps(self, noise_sampler: BaseSampler, timesteps, strength):
          # get the original timestep using init_timestep
          num_inference_steps = len(timesteps)
          init_timestep = min(int(num_inference_steps*strength), num_inference_steps)

          t_start = max(num_inference_steps-init_timestep, 0)
-         timesteps = timesteps[t_start*self.scheduler.order:]
+         if isinstance(noise_sampler, DiffusersSampler):
+             timesteps = timesteps[t_start*noise_sampler.scheduler.order:]
+         else:
+             timesteps = timesteps[t_start:]

          return timesteps

-     def forward(self, memory, device, **states):
-         self.scheduler = self.scheduler or memory.scheduler
-
-         self.scheduler.set_timesteps(self.N_steps, device=device)
-         timesteps = self.scheduler.timesteps
+     def forward(self, noise_sampler: BaseSampler, device, **states):
+         timesteps = noise_sampler.get_timesteps(self.N_steps, device=device)
          if self.strength:
-             timesteps = self.get_timesteps(timesteps, self.strength)
-         alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)
-         return {**states, 'device':device, 'timesteps':timesteps, 'alphas_cumprod':alphas_cumprod}
-
- class MakeLatentAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, scheduler=None, N_ch=4, height=512, width=512):
-         self.scheduler = scheduler
+             timesteps = self.get_timesteps(noise_sampler, timesteps, self.strength)
+             return {'timesteps':timesteps, 'start_timestep':timesteps[:1]}
+         else:
+             return {'timesteps':timesteps}
+
+ class MakeLatentAction(BasicAction):
+     def __init__(self, N_ch=4, height=None, width=None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.N_ch = N_ch
          self.height = height
          self.width = width

-     def forward(self, memory, generator, device, dtype, bs=None, latents=None, vae_scale_factor=8, start_timestep=None, **states):
+     def forward(self, noise_sampler: BaseSampler, vae, generator, device, dtype, bs=None, latents=None, start_timestep=None,
+                 pooled_output=None, crop_coord=None, **states):
          if bs is None:
              if 'prompt' in states:
                  bs = len(states['prompt'])
-         scheduler = self.scheduler or memory.scheduler
+         vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+         device = torch.device(device)

-         shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)
+         if latents is None:
+             shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)
+         else:
+             if self.height is not None:
+                 warnings.warn('latents exist! User-specified width and height will be ignored!')
+             shape = latents.shape
          if isinstance(generator, list) and len(generator) != bs:
              raise ValueError(
                  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                  f" size of {bs}. Make sure the batch size matches the length of the generators."
              )

-         noise = randn_tensor(shape, generator=generator, device=device, dtype=get_dtype(dtype))
          if latents is None:
-             # scale the initial noise by the standard deviation required by the scheduler
-             latents = noise*scheduler.init_noise_sigma
+             # scale the initial noise by the standard deviation required by the noise_sampler
+             noise_sampler.generator = generator
+             latents = noise_sampler.init_noise(shape, device=device, dtype=get_dtype(dtype))
          else:
              # image to image
              latents = latents.to(device)
-             latents = scheduler.add_noise(latents, noise, start_timestep)
+             latents, noise = noise_sampler.add_noise(latents, start_timestep)

-         return {**states, 'latents':latents, 'device':device, 'dtype':dtype, 'generator':generator}
+         output = {'latents':latents}

- class NoisePredAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, unet=None, scheduler=None, guidance_scale: float = 7.0):
+         # SDXL inputs
+         if pooled_output is not None:
+             width, height = shape[3]*vae_scale_factor, shape[2]*vae_scale_factor
+             if crop_coord is None:
+                 crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
+             else:
+                 crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
+             crop_info = crop_info.to(device).repeat(bs, 1)
+             output['text_embeds'] = pooled_output[-1].to(device)
+
+             if 'negative_prompt' in states:
+                 output['crop_info'] = torch.cat([crop_info, crop_info], dim=0)
+
+         return output
+
+ class DenoiseAction(BasicAction):
+     def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.guidance_scale = guidance_scale
-         self.unet = unet
-         self.scheduler = scheduler

-     def forward(self, memory, t, latents, prompt_embeds, pooled_output=None, encoder_attention_mask=None, crop_info=None,
-                 cross_attention_kwargs=None, dtype='fp32', amp=None, **states):
-         self.scheduler = self.scheduler or memory.scheduler
-         self.unet = self.unet or memory.unet
+     def forward(self, denoiser, noise_sampler: BaseSampler, t, latents, prompt_embeds, text_embeds=None, encoder_attention_mask=None, crop_info=None,
+                 cross_attention_kwargs=None, dtype='fp32', amp=None, model_offload=False, **states):
+
+         if model_offload:
+             to_cuda(denoiser)  # to_cpu in VAE

          with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
              latent_model_input = torch.cat([latents]*2) if self.guidance_scale>1 else latents
-             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+             latent_model_input = noise_sampler.c_in(t)*latent_model_input

-             if pooled_output is None:
-                 noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                        cross_attention_kwargs=cross_attention_kwargs, ).sample
+             if text_embeds is None:
+                 noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                       cross_attention_kwargs=cross_attention_kwargs, ).sample
              else:
-                 added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
+                 added_cond_kwargs = {"text_embeds":text_embeds, "time_ids":crop_info}
                  # predict the noise residual
-                 noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                        cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                 noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                       cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample

              # perform guidance
              if self.guidance_scale>1:
                  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                  noise_pred = noise_pred_uncond+self.guidance_scale*(noise_pred_text-noise_pred_uncond)

-         return {**states, 'noise_pred':noise_pred, 'latents':latents, 't':t, 'prompt_embeds':prompt_embeds, 'pooled_output':pooled_output,
-                 'crop_info':crop_info, 'cross_attention_kwargs':cross_attention_kwargs, 'dtype':dtype, 'amp':amp}
-
- class SampleAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, scheduler=None, eta=0.0):
-         self.scheduler = scheduler
-         self.eta = eta
-
-     def prepare_extra_step_kwargs(self, generator, eta):
-         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-         # and should be between [0, 1]
-
-         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-         extra_step_kwargs = {}
-         if accepts_eta:
-             extra_step_kwargs["eta"] = eta
-
-         # check if the scheduler accepts generator
-         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-         if accepts_generator:
-             extra_step_kwargs["generator"] = generator
-         return extra_step_kwargs
-
-     def forward(self, memory, noise_pred, t, latents, generator, **states):
-         self.scheduler = self.scheduler or memory.scheduler
-
-         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, self.eta)
+         return {'noise_pred':noise_pred}

+ class SampleAction(BasicAction):
+     def forward(self, noise_sampler: BaseSampler, noise_pred, t, latents, generator, **states):
          # compute the previous noisy sample x_t -> x_t-1
-         sc_out = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
-         latents = sc_out.prev_sample
-         return {**states, 'latents':latents, 't':t, 'generator':generator}
-
- class DiffusionStepAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, unet=None, scheduler=None, guidance_scale: float = 7.0):
-         self.act_noise_pred = NoisePredAction(unet, scheduler, guidance_scale)
-         self.act_sample = SampleAction(scheduler)
-
-     def forward(self, memory, **states):
-         states = self.act_noise_pred(memory=memory, **states)
-         states = self.act_sample(memory=memory, **states)
+         latents = noise_sampler.denoise(latents, t, noise_pred, generator=generator)
+         return {'latents':latents}
+
+ class DiffusionStepAction(BasicAction):
+     def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.act_noise_pred = DenoiseAction(guidance_scale)
+         self.act_sample = SampleAction()
+
+     def forward(self, denoiser, noise_sampler, **states):
+         states = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
+         states = self.act_sample(**states)
          return states

  class X0PredAction(BasicAction):
-     def forward(self, latents, alphas_cumprod, t, noise_pred, **states):
-         # x_t -> x_0
-         alpha_prod_t = alphas_cumprod[t.long()]
-         beta_prod_t = 1-alpha_prod_t
-         latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5)  # approximate x_0
-         return {**states, 'latents_x0':latents_x0, 'latents':latents, 'alphas_cumprod':alphas_cumprod, 't':t, 'noise_pred':noise_pred}
+     def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
+         latents_x0 = noise_sampler.eps_to_x0(noise_pred, latents, t)
+         return {'latents_x0':latents_x0}
+
+ def time_iter(timesteps, **states):
+     return [{'t':t} for t in timesteps]
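
The eps_to_x0 call replaces the closed-form conversion the old X0PredAction inlined. For a DDPM-style schedule it follows from the forward process x_t = sqrt(alpha_bar_t)*x_0 + sqrt(1-alpha_bar_t)*eps; solving for x_0 gives the removed expression. A sketch matching the removed code (the actual BaseSampler.eps_to_x0 in 2.2 may differ, e.g. for EDM-style sigma schedules):

    def eps_to_x0_ddpm(noise_pred, latents, alphas_cumprod, t):
        # Inputs are torch tensors; alphas_cumprod is indexed by integer timestep.
        # x_t = sqrt(a_t)*x_0 + sqrt(1-a_t)*eps  =>  x_0 = (x_t - sqrt(1-a_t)*eps) / sqrt(a_t)
        alpha_prod_t = alphas_cumprod[t.long()]
        beta_prod_t = 1 - alpha_prod_t
        return (latents - beta_prod_t**0.5 * noise_pred) / alpha_prod_t**0.5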
hcpdiff/workflow/fast.py
@@ -0,0 +1,31 @@
+ from sfast.compilers.diffusion_pipeline_compiler import (compile_unet, CompilationConfig)
+ from rainbowneko.infer import BasicAction
+
+
+ class SFastCompileAction(BasicAction):
+
+     @staticmethod
+     def compile_model(unet):
+         # compile model
+         config = CompilationConfig.Default()
+         config.enable_xformers = False
+         try:
+             import xformers
+             config.enable_xformers = True
+         except ImportError:
+             print('xformers not installed, skip')
+         # NOTE:
+         # When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
+         # Disable Triton if you encounter this problem.
+         try:
+             import triton
+             config.enable_triton = True
+         except ImportError:
+             print('Triton not installed, skip')
+         config.enable_cuda_graph = True
+
+         return compile_unet(unet, config)
+
+     def forward(self, denoiser, **states):
+         denoiser = self.compile_model(denoiser)
+         return {'denoiser': denoiser}
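
Actions are invoked as act(**states) (see flow.py below), so this action swaps the compiled module in under the denoiser key. A hedged placement sketch; running it once after PrepareDiffusionAction is an assumption, chosen so stable-fast captures CUDA graphs with the final device and dtype:

    states = {'denoiser': denoiser}          # plus whatever else the workflow carries
    states = SFastCompileAction()(**states)  # forward() returns the compiled module
    denoiser = states['denoiser']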
hcpdiff/workflow/flow.py
@@ -0,0 +1,67 @@
+ from rainbowneko.infer import BasicAction
+ from typing import List, Dict
+ from tqdm import tqdm
+ import math
+
+ class FilePromptAction(BasicAction):
+     def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         if prompt.endswith('.txt'):
+             with open(prompt, 'r') as f:
+                 prompt = f.read().split('\n')
+         else:
+             prompt = [prompt]
+
+         if negative_prompt.endswith('.txt'):
+             with open(negative_prompt, 'r') as f:
+                 negative_prompt = f.read().split('\n')
+         else:
+             negative_prompt = [negative_prompt]*len(prompt)
+
+         self.prompt = prompt
+         self.negative_prompt = negative_prompt
+         self.bs = bs
+         self.actions = actions
+
+
+     def forward(self, **states):
+         states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
+         states_ref = dict(**states)
+
+         pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
+         N_steps = len(self.actions)
+         for gen_step in pbar:
+             states = dict(**states_ref)
+             feed_data = {'gen_step': gen_step}
+             states.update(feed_data)
+             for step, act in enumerate(self.actions):
+                 pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
+                 states = act(**states)
+         return states
+
+ class FlowPromptAction(BasicAction):
+     def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, num: int = None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         prompt = [prompt]*num
+         negative_prompt = [negative_prompt]*num
+
+         self.prompt = prompt
+         self.negative_prompt = negative_prompt
+         self.bs = bs
+         self.actions = actions
+
+
+     def forward(self, **states):
+         states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
+         states_ref = dict(**states)
+
+         pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
+         N_steps = len(self.actions)
+         for gen_step in pbar:
+             states = dict(**states_ref)
+             feed_data = {'gen_step': gen_step}
+             states.update(feed_data)
+             for step, act in enumerate(self.actions):
+                 pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
+                 states = act(**states)
+         return states
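
The two classes share the same batched outer loop and differ only in where the prompt list comes from: a .txt file with one prompt per line, versus a single prompt repeated num times. A usage sketch, with the inner action list left as a placeholder:

    inner_actions = []  # the per-batch generation actions would go here
    flow = FilePromptAction(inner_actions, prompt='prompts.txt',
                            negative_prompt='lowres, bad anatomy', bs=4)
    final_states = flow(device='cuda')  # runs all batches, returns the last states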