hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/utils/net_utils.py CHANGED
@@ -6,11 +6,19 @@ import torch
  from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
  from torch import nn
  from torch.optim import lr_scheduler
- from transformers import PretrainedConfig, AutoTokenizer
+ from transformers import PretrainedConfig, AutoTokenizer, T5EncoderModel, CLIPTextModel
  from functools import partial
+ from huggingface_hub import hf_hub_download
+ import json
  
  dtype_dict = {'fp32':torch.float32, 'amp':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
  
+ try:
+     dtype_dict['fp8_e4m3'] = torch.float8_e4m3fn
+     dtype_dict['fp8_e5m2'] = torch.float8_e5m2
+ except:
+     pass
+ 
  def get_scheduler(cfg, optimizer):
      if cfg is None:
          return None
@@ -90,7 +98,7 @@ def auto_tokenizer_cls(pretrained_model_name_or_path: str, revision: str = None)
              revision=revision, use_fast=False,
          )
          return SDXLTokenizer
-     except OSError:
+     except:
          # not sdxl, only one tokenizer
          return AutoTokenizer
  
@@ -102,8 +110,10 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = None)
              subfolder="text_encoder_2",
              revision=revision,
          )
+         if text_encoder_config.architectures is None:
+             raise ValueError()
          return SDXLTextEncoder
-     except OSError:
+     except:
          text_encoder_config = PretrainedConfig.from_pretrained(
              pretrained_model_name_or_path,
              subfolder="text_encoder",
@@ -112,16 +122,26 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = None)
          model_class = text_encoder_config.architectures[0]
  
          if model_class == "CLIPTextModel":
-             from transformers import CLIPTextModel
- 
              return CLIPTextModel
          elif model_class == "RobertaSeriesModelWithTransformation":
              from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
  
              return RobertaSeriesModelWithTransformation
+         elif model_class == "T5EncoderModel":
+             return T5EncoderModel
          else:
              raise ValueError(f"{model_class} is not supported.")
  
+ def get_pipe_name(path: str):
+     if os.path.isdir(path):
+         json_file = os.path.join(path, "model_index.json")
+     else:
+         json_file = hf_hub_download(path, "model_index.json")
+     with open(json_file, "r", encoding="utf-8") as reader:
+         text = reader.read()
+         data = json.loads(text)
+     return data['_class_name']
+ 
  def auto_tokenizer(pretrained_model_name_or_path: str, revision: str = None, **kwargs):
      return auto_tokenizer_cls(pretrained_model_name_or_path, revision).from_pretrained(pretrained_model_name_or_path, revision=revision, **kwargs)
  
@@ -225,4 +245,7 @@ def split_module_name(layer_name):
      return parent_name, host_name
  
  def get_dtype(dtype):
-     return dtype_dict.get(dtype, torch.float32)
+     if isinstance(dtype, torch.dtype):
+         return dtype
+     else:
+         return dtype_dict.get(dtype, torch.float32)
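Taken together, the net_utils.py changes register optional fp8 dtypes and let get_dtype accept either a registry key or a torch.dtype object. A minimal standalone sketch of the resulting behaviour (re-declared here for illustration, not imported from the package):

    import torch

    dtype_dict = {'fp32': torch.float32, 'amp': torch.float32, 'fp16': torch.float16, 'bf16': torch.bfloat16}
    try:
        # float8 types only exist on newer torch builds
        dtype_dict['fp8_e4m3'] = torch.float8_e4m3fn
        dtype_dict['fp8_e5m2'] = torch.float8_e5m2
    except AttributeError:
        pass

    def get_dtype(dtype):
        # pass torch.dtype objects through, look strings up with a float32 fallback
        if isinstance(dtype, torch.dtype):
            return dtype
        return dtype_dict.get(dtype, torch.float32)

    assert get_dtype('bf16') is torch.bfloat16
    assert get_dtype(torch.float16) is torch.float16
    assert get_dtype('unknown') is torch.float32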
hcpdiff/utils/pipe_hook.py CHANGED
@@ -2,10 +2,10 @@ from typing import Union, List, Optional, Callable, Dict, Any
  
  import PIL
  import torch
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, PixArtTransformer2DModel
  from diffusers.image_processor import VaeImageProcessor
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint_legacy import preprocess_mask, preprocess_image
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+ from .inpaint_pipe import preprocess_mask, preprocess_image, StableDiffusionInpaintPipelineLegacy
  from einops import repeat
  
  class HookPipe_T2I(StableDiffusionPipeline):
@@ -17,25 +17,17 @@ class HookPipe_T2I(StableDiffusionPipeline):
      def device(self) -> torch.device:
          return torch.device('cuda')
  
-     def proc_prompt(self, device, num_images_per_prompt, prompt_embeds = None, negative_prompt_embeds = None):
-         batch_size = prompt_embeds.shape[0]
-         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+     def proc_prompt(self, device, num_inference_steps, prompt_embeds = None, negative_prompt_embeds = None) -> List[torch.Tensor]:
+         if not isinstance(prompt_embeds, list): # to emb for each step
+             prompt_embeds = [prompt_embeds]*num_inference_steps
+         if not isinstance(negative_prompt_embeds, list): # to emb for each step
+             negative_prompt_embeds = [negative_prompt_embeds]*num_inference_steps
  
-         bs_embed, seq_len, _ = prompt_embeds.shape
-         # duplicate text embeddings for each generation per prompt, using mps friendly method
-         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-         prompt_embeds = prompt_embeds.view(bs_embed*num_images_per_prompt, seq_len, -1)
+         prompt_embeds = [p.to(dtype=self.text_encoder.dtype, device=device) for p in prompt_embeds]
+         negative_prompt_embeds = [p.to(dtype=self.text_encoder.dtype, device=device) for p in negative_prompt_embeds]
  
-         # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-         seq_len = negative_prompt_embeds.shape[1]
- 
-         negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
- 
-         negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-         negative_prompt_embeds = negative_prompt_embeds.view(batch_size*num_images_per_prompt, seq_len, -1)
- 
-         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-         return prompt_embeds
+         prompt_embeds = [torch.cat([emb_neg, emb_pos]) for emb_pos, emb_neg in zip(prompt_embeds, negative_prompt_embeds)]
+         return prompt_embeds # List[emb_step_i]*num_inference_steps
  
      @torch.no_grad()
      def __call__(
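The rewritten proc_prompt no longer duplicates embeddings per image; it normalizes its inputs into one CFG-ready embedding per inference step, accepting either a single tensor or a per-step list. A rough sketch of the same broadcasting logic in isolation (dtype/device casting omitted, tensor shapes are illustrative):

    import torch

    def proc_prompt_sketch(prompt_embeds, negative_prompt_embeds, num_inference_steps):
        # broadcast a single tensor to one embedding per denoising step
        if not isinstance(prompt_embeds, list):
            prompt_embeds = [prompt_embeds] * num_inference_steps
        if not isinstance(negative_prompt_embeds, list):
            negative_prompt_embeds = [negative_prompt_embeds] * num_inference_steps
        # concatenate [negative, positive] per step for classifier-free guidance
        return [torch.cat([neg, pos]) for pos, neg in zip(prompt_embeds, negative_prompt_embeds)]

    pos, neg = torch.randn(1, 77, 768), torch.randn(1, 77, 768)
    per_step = proc_prompt_sketch(pos, neg, num_inference_steps=3)
    assert len(per_step) == 3 and per_step[0].shape == (2, 77, 768)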
@@ -46,7 +38,6 @@ class HookPipe_T2I(StableDiffusionPipeline):
          num_inference_steps: int = 50,
          guidance_scale: float = 7.5,
          negative_prompt: Optional[Union[str, List[str]]] = None,
-         num_images_per_prompt: Optional[int] = 1,
          eta: float = 0.0,
          generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
          latents: Optional[torch.FloatTensor] = None,
@@ -74,6 +65,8 @@
              batch_size = 1
          elif prompt is not None and isinstance(prompt, list):
              batch_size = len(prompt)
+         elif isinstance(prompt_embeds, list):
+             batch_size = prompt_embeds[0].shape[0]
          else:
              batch_size = prompt_embeds.shape[0]
  
@@ -84,7 +77,7 @@
          do_classifier_free_guidance = guidance_scale>1.0
  
          # 3. Encode input prompt
-         prompt_embeds = self.proc_prompt(device, num_images_per_prompt,
+         prompt_embeds = self.proc_prompt(device, num_inference_steps,
                                           prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds)
  
          # 4. Prepare timesteps
@@ -95,11 +88,11 @@
          # 5. Prepare latent variables
          num_channels_latents = self.unet.config.in_channels
          latents = self.prepare_latents(
-             batch_size*num_images_per_prompt,
+             batch_size,
              num_channels_latents,
              height,
              width,
-             prompt_embeds.dtype,
+             prompt_embeds[0].dtype,
              device,
              generator,
              latents,
@@ -114,7 +107,7 @@
                  crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
              else:
                  crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
-             crop_info = crop_info.to(device).repeat(batch_size*num_images_per_prompt, 1)
+             crop_info = crop_info.to(device).repeat(batch_size, 1)
              pooled_output = pooled_output.to(device)
  
              if do_classifier_free_guidance:
@@ -129,12 +122,20 @@
                  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
  
                  if pooled_output is None:
-                     noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                            cross_attention_kwargs=cross_attention_kwargs, ).sample
+                     if isinstance(self.unet, PixArtTransformer2DModel):
+                         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+                         noise_pred = self.unet(latent_model_input, timestep=t.repeat(latent_model_input.shape[0]), encoder_hidden_states=prompt_embeds[i],
+                                                encoder_attention_mask=encoder_attention_mask,
+                                                cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                     else:
+                         noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
+                                                encoder_attention_mask=encoder_attention_mask,
+                                                cross_attention_kwargs=cross_attention_kwargs).sample
                  else:
                      added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                      # predict the noise residual
-                     noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                     noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
+                                            encoder_attention_mask=encoder_attention_mask,
                                             cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
  
                  # perform guidance
@@ -142,6 +143,10 @@
                      noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                      noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)
  
+                 # learned sigma
+                 if self.unet.config.out_channels // 2 == num_channels_latents:
+                     noise_pred = noise_pred.chunk(2, dim=1)[0]
+ 
                  # x_t -> x_0
                  alpha_prod_t = alphas_cumprod[t.long()]
                  beta_prod_t = 1-alpha_prod_t
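The learned-sigma guard above covers denoisers whose output stacks a predicted variance on top of the noise estimate (the PixArt transformer handled earlier is one such case): only the first half of the output channels is the epsilon prediction. A toy illustration of the channel split, assuming 4 latent channels:

    import torch

    num_channels_latents = 4
    model_output = torch.randn(1, 8, 64, 64)  # out_channels == 2 * latent channels

    if model_output.shape[1] // 2 == num_channels_latents:
        noise_pred, predicted_variance = model_output.chunk(2, dim=1)

    print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])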
@@ -155,7 +160,8 @@
                  if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                      progress_bar.update()
                      if callback is not None and i%callback_steps == 0:
-                         if callback(i, t, num_inference_steps, latents_x0):
+                         latents = callback(i, t, num_inference_steps, latents_x0, latents)
+                         if latents is None:
                              return None
  
          latents = latents.to(dtype=self.vae.dtype)
@@ -277,8 +283,13 @@ class HookPipe_I2I(StableDiffusionImg2ImgPipeline):
  
                  # predict the noise residual
                  if pooled_output is None:
-                     noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                            cross_attention_kwargs=cross_attention_kwargs, ).sample
+                     if isinstance(self.unet, PixArtTransformer2DModel):
+                         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+                         noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                                cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                     else:
+                         noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                                cross_attention_kwargs=cross_attention_kwargs, ).sample
                  else:
                      added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                      # predict the noise residual
@@ -302,7 +313,8 @@ class HookPipe_I2I(StableDiffusionImg2ImgPipeline):
                  if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                      progress_bar.update()
                      if callback is not None and i%callback_steps == 0:
-                         if callback(i, t, num_inference_steps, latents_x0):
+                         latents = callback(i, t, num_inference_steps, latents_x0, latents)
+                         if latents is None:
                              return None
  
          latents = latents.to(dtype=self.vae.dtype)
@@ -450,7 +462,8 @@ class HookPipe_Inpaint(StableDiffusionInpaintPipelineLegacy):
                  if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                      progress_bar.update()
                      if callback is not None and i%callback_steps == 0:
-                         if callback(i, t, num_inference_steps, latents_x0):
+                         latents = callback(i, t, num_inference_steps, latents_x0, latents)
+                         if latents is None:
                              return None
  
                  # use original latents corresponding to unmasked portions of the image
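All three hooked pipelines now expect the step callback to return the (possibly modified) latents rather than a truthy abort flag; returning None stops generation and makes the pipeline return None. A sketch of a callback matching that contract (the function name and cancellation flag are illustrative):

    cancel_requested = False  # would be toggled by the caller, e.g. from a UI

    def preview_callback(i, t, num_inference_steps, latents_x0, latents):
        # latents_x0 is the predicted clean sample, latents the current noisy state
        if cancel_requested:
            return None       # aborts the denoising loop
        return latents        # return latents (optionally edited) to continue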
hcpdiff/utils/utils.py CHANGED
@@ -56,8 +56,8 @@ def remove_config_undefined(cfg):
  def load_config(path, remove_undefined=True):
      cfg = OmegaConf.load(path)
      if '_base_' in cfg:
-         for base in cfg['_base_']:
-             cfg = OmegaConf.merge(load_config(base, remove_undefined=False), cfg)
+         base_cfgs = [load_config(base, remove_undefined=False) for base in cfg['_base_']]
+         cfg = OmegaConf.merge(*base_cfgs, cfg)
          del cfg['_base_']
      if remove_undefined:
          cfg = remove_config_undefined(cfg)
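With this change every entry of _base_ is loaded first and merged in a single call, so later bases override earlier ones and the child config overrides them all. A small sketch of that precedence with plain OmegaConf objects (keys are hypothetical):

    from omegaconf import OmegaConf

    base_a = OmegaConf.create({'lr': 1e-4, 'batch_size': 8})
    base_b = OmegaConf.create({'lr': 5e-5})
    child = OmegaConf.create({'batch_size': 16})

    # same precedence as the new load_config: base_a < base_b < child
    merged = OmegaConf.merge(base_a, base_b, child)
    assert merged.lr == 5e-5 and merged.batch_size == 16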
@@ -85,7 +85,7 @@ def get_cfg_range(cfg_text:str):
  def to_validate_file(name):
      rstr = r"[\/\\\:\*\?\"\<\>\|]"  # '/ \ : * ? " < > |'
      new_title = re.sub(rstr, "_", name)  # replace with underscores
-     return new_title[:50]
+     return new_title[:200]
  
  def make_mask(start, end, length):
      mask=torch.zeros(length)
@@ -159,4 +159,21 @@ def pad_attn_bias(x, attn_bias, block_size=8):
      # pad along the k dimension
      x_padded = F.pad(x, (0, 0, 0, padding_l, 0, 0), mode='constant', value=0)
      attn_bias_padded = F.pad(attn_bias, (0, padding_l, 0, 0), mode='constant', value=0)
-     return x_padded, attn_bias_padded
+     return x_padded, attn_bias_padded
+ 
+ def linear_interp(t, x):
+     '''
+     t_l ---------t_h
+         ^x
+     '''
+     if (x>=len(t)).any():
+         x = x.clamp(max=len(t)-1e-6)
+     x0 = x.floor().long()
+     x1 = x0 + 1
+ 
+     y0 = t[x0]
+     y1 = t[x1]
+ 
+     xd = (x - x0.float())
+ 
+     return y0 * (1 - xd) + y1 * xd
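The new linear_interp helper performs fractional indexing into a 1-D tensor by interpolating linearly between neighbouring entries (useful for sigma/timestep lookups). A small worked example, re-declaring the helper so it runs standalone:

    import torch

    def linear_interp(t, x):
        # clamp so that x0 + 1 stays a valid index
        if (x >= len(t)).any():
            x = x.clamp(max=len(t) - 1e-6)
        x0 = x.floor().long()
        x1 = x0 + 1
        y0, y1 = t[x0], t[x1]
        xd = x - x0.float()
        return y0 * (1 - xd) + y1 * xd

    t = torch.tensor([0.0, 10.0, 20.0])
    print(linear_interp(t, torch.tensor([0.5, 1.25])))  # tensor([ 5.0000, 12.5000])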
hcpdiff/workflow/__init__.py CHANGED
@@ -1,15 +1,20 @@
- from .base import BasicAction, MemoryMixin, from_memory, ExecAction, LoopAction
- from .diffusion import InputFeederAction, PrepareDiffusionAction, MakeLatentAction, NoisePredAction, SampleAction, DiffusionStepAction, \
-     X0PredAction, SeedAction, MakeTimestepsAction
+ from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
+     X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter
  from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
  from .vae import EncodeAction, DecodeAction
- from .io import LoadModelsAction, SaveImageAction, BuildModelLoaderAction, LoadPartAction, LoadLoraAction, LoadPluginAction
- from .utils import LatentResizeAction, ImageResizeAction
- from .model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction, StartTextEncode, StartDiffusion, EndTextEncode, EndDiffusion
+ from .io import BuildModelsAction, SaveImageAction, LoadImageAction
+ from .utils import LatentResizeAction, ImageResizeAction, FeedtoCNetAction
+ from .model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction
+ #from .flow import FilePromptAction
+ 
+ try:
+     from .fast import SFastCompileAction
+ except:
+     print('stable fast not installed.')
  
  from omegaconf import OmegaConf
  
- OmegaConf.register_new_resolver("hcp.from_memory", lambda mem_name: OmegaConf.create({
-     '_target_': 'hcpdiff.workflow.from_memory',
-     'mem_name': mem_name,
- }))
+ OmegaConf.register_new_resolver("hcp.from_memory", lambda mem_name:OmegaConf.create({
+     '_target_':'hcpdiff.workflow.from_memory',
+     'mem_name':mem_name,
+ }))
hcpdiff/workflow/daam/__init__.py ADDED
@@ -0,0 +1 @@
+ from .act import CaptureCrossAttnAction, SaveWordAttnAction
hcpdiff/workflow/daam/act.py ADDED
@@ -0,0 +1,66 @@
+ import os
+ from io import BytesIO
+ 
+ import numpy as np
+ from PIL import Image
+ from hcpdiff.utils import to_validate_file
+ from rainbowneko.utils import types_support
+ from matplotlib import pyplot as plt
+ from rainbowneko.infer import BasicAction, Actions
+ 
+ from .hook import DiffusionHeatMapHooker
+ 
+ class CaptureCrossAttnAction(Actions):
+     def forward(self, prompt, denoiser, tokenizer, vae, **states):
+         bs = len(prompt)
+         N_head = 8
+         with DiffusionHeatMapHooker(denoiser, tokenizer, vae_scale_factor=vae.vae_scale_factor) as tc:
+             states = super().forward(**states)
+             heat_maps = [tc.compute_global_heat_map(prompt=prompt[i], head_idxs=range(N_head*i, N_head*(i+1))) for i in range(bs)]
+ 
+         return {**states, 'cross_attn_heat_maps':heat_maps}
+ 
+ class SaveWordAttnAction(BasicAction):
+ 
+     def __init__(self, save_root: str, N_col: int = 4, image_type: str = 'png', quality: int = 95, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.save_root = save_root
+         self.image_type = image_type
+         self.quality = quality
+         self.N_col = N_col
+ 
+         os.makedirs(save_root, exist_ok=True)
+ 
+     def draw_attn(self, tokenizer, prompt, image, global_heat_map):
+         prompt=tokenizer.bos_token+prompt+tokenizer.eos_token
+         tokens = [token.replace("</w>", "") for token in tokenizer.tokenize(prompt)]
+ 
+         d_len = self.N_col
+         plt.rcParams['figure.dpi'] = 300
+         plt.rcParams.update({'font.size':12})
+         h = int(np.ceil(len(tokens)/d_len))
+         fig, ax = plt.subplots(h, d_len, figsize=(2*d_len, 2*h))
+         for ax_ in ax.flatten():
+             ax_.set_xticks([])
+             ax_.set_yticks([])
+         for i, token in enumerate(tokens):
+             heat_map = global_heat_map.compute_word_heat_map(token, word_idx=i)
+             if h==1:
+                 heat_map.plot_overlay(image, ax=ax[i%d_len])
+             else:
+                 heat_map.plot_overlay(image, ax=ax[i//d_len, i%d_len])
+         # plt.tight_layout()
+ 
+         buf = BytesIO()
+         plt.savefig(buf, format='png')
+         buf.seek(0)
+         return Image.open(buf)
+ 
+     def forward(self, tokenizer, images, prompt, seeds, cross_attn_heat_maps, **states):
+         num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(self.save_root) if x.rsplit('.', 1)[-1] in types_support])
+ 
+         for bid, (p, img) in enumerate(zip(prompt, images)):
+             img_path = os.path.join(self.save_root, f"{num_img_exist}-{seeds[bid]}-cross_attn-{to_validate_file(prompt[0])}.{self.image_type}")
+             img = self.draw_attn(tokenizer, p, img, cross_attn_heat_maps[bid])
+             img.save(img_path, quality=self.quality)
+             num_img_exist += 1
hcpdiff/workflow/daam/hook.py ADDED
@@ -0,0 +1,109 @@
+ from daam import AggregateHooker, RawHeatMapCollection, UNetCrossAttentionLocator, GlobalHeatMap
+ from daam.trace import UNetCrossAttentionHooker
+ from typing import List
+ from diffusers import UNet2DConditionModel
+ from PIL import Image
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ 
+ def auto_autocast(*args, **kwargs):
+     if not torch.cuda.is_available():
+         kwargs['enabled'] = False
+ 
+     return torch.cuda.amp.autocast(*args, **kwargs)
+ 
+ class DiffusionHeatMapHooker(AggregateHooker):
+     def __init__(
+             self,
+             unet: UNet2DConditionModel,
+             tokenizer,
+             vae_scale_factor: int,
+             low_memory: bool = False,
+             load_heads: bool = False,
+             save_heads: bool = False,
+             data_dir: str = None
+     ):
+         self.all_heat_maps = RawHeatMapCollection()
+         h = (unet.config.sample_size * vae_scale_factor)
+         self.latent_hw = 4096 if h == 512 or h == 1024 else 9216  # 64x64 or 96x96 depending on if it's 2.0-v or 2.0
+         locate_middle = load_heads or save_heads
+         self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
+         self.last_prompt: str = ''
+         self.last_image: Image.Image = None
+         self.time_idx = 0
+         self._gen_idx = 0
+ 
+         self.tokenizer = tokenizer
+ 
+         modules = [
+             UNetCrossAttentionHooker(
+                 x,
+                 self,
+                 layer_idx=idx,
+                 latent_hw=self.latent_hw,
+                 load_heads=load_heads,
+                 save_heads=save_heads,
+                 data_dir=data_dir
+             ) for idx, x in enumerate(self.locator.locate(unet))
+         ]
+ 
+         super().__init__(modules)
+ 
+     def time_callback(self, *args, **kwargs):
+         self.time_idx += 1
+ 
+     @property
+     def layer_names(self):
+         return self.locator.layer_names
+ 
+     def compute_global_heat_map(self, prompt=None, factors=None, head_idxs: List[int]=None, layer_idx=None, normalize=False):
+         # type: (str, List[float], int, int, bool) -> GlobalHeatMap
+         """
+         Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different
+         spatial transformer block heat maps).
+ 
+         Args:
+             prompt: The prompt to compute the heat map for. If none, uses the last prompt that was used for generation.
+             factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes.
+             head_idx: Restrict the application to heat maps with this head index. If `None`, use all heads.
+             layer_idx: Restrict the application to heat maps with this layer index. If `None`, use all layers.
+ 
+         Returns:
+             A heat map object for computing word-level heat maps.
+         """
+         heat_maps = self.all_heat_maps
+ 
+         if prompt is None:
+             prompt = self.last_prompt
+ 
+         if factors is None:
+             factors = {0, 1, 2, 4, 8, 16, 32, 64}
+         else:
+             factors = set(factors)
+ 
+         all_merges = []
+         x = int(np.sqrt(self.latent_hw))
+ 
+         with auto_autocast(dtype=torch.float32):
+             for (factor, layer, head), heat_map in heat_maps:
+                 if (head_idxs is None or head in head_idxs) and (layer_idx is None or layer_idx == layer):
+                     heat_map = heat_map.unsqueeze(1)/25  # [L,1,H,W]
+                     # The clamping fixes undershoot.
+                     all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0))
+ 
+         try:
+             maps = torch.stack(all_merges, dim=0)  # [B*head, L, 1, H, W]
+         except RuntimeError:
+             if head_idxs is not None or layer_idx is not None:
+                 raise RuntimeError('No heat maps found for the given parameters.')
+             else:
+                 raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?')
+ 
+         maps = maps.mean(0)[:, 0]  # [L,H,W]
+         #maps = maps[:len(self.tokenizer.tokenize(prompt)) + 2]  # 1 for SOS and 1 for padding
+ 
+         if normalize:
+             maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6)  # drop out [SOS] and [PAD] for proper probabilities
+ 
+         return GlobalHeatMap(self.tokenizer, prompt, maps)
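For reference, the aggregation in compute_global_heat_map amounts to upsampling every collected per-head map to the latent grid, averaging them, and optionally renormalizing over tokens. A condensed sketch of that math on dummy tensors (shapes and the /25 scaling follow the code above; the maps themselves are random placeholders):

    import torch
    import torch.nn.functional as F

    latent_hw = 4096                 # 64x64 latent grid for 512px models
    x = int(latent_hw ** 0.5)        # 64

    # pretend three per-head maps were collected at different spatial factors: [L, H, W]
    raw_maps = [torch.rand(77, 16, 16), torch.rand(77, 32, 32), torch.rand(77, 8, 8)]

    merged = [F.interpolate(m.unsqueeze(1) / 25, size=(x, x), mode='bicubic').clamp_(min=0)
              for m in raw_maps]                            # each becomes [L, 1, 64, 64]
    maps = torch.stack(merged, dim=0).mean(0)[:, 0]         # average heads/layers -> [L, 64, 64]
    maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6)  # normalize, dropping the SOS/EOS rows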