hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/models/lora_layers_patch.py
@@ -8,19 +8,18 @@ lora_layers.py
 :Licence: Apache-2.0
 """
 
+import math
+
 import torch
-from einops import einsum, rearrange
+from einops import einsum
 from torch import nn
 from torch.nn import functional as F
 
 from .lora_base_patch import LoraBlock, PatchPluginContainer
-from .layers import GroupLinear
-import math
-from typing import Union, List
 
 class LoraLayer(LoraBlock):
-    def __init__(self, lora_id: int, host, rank=1, dropout=0.1, alpha=1.0, bias=False, alpha_auto_scale=True, **kwargs):
-        super().__init__(lora_id, host, rank, dropout, alpha=alpha, bias=bias, alpha_auto_scale=alpha_auto_scale, **kwargs)
+    def __init__(self, name: str, host, rank=1, dropout=0.0, alpha=1.0, bias=False, alpha_auto_scale=True, **kwargs):
+        super().__init__(name, host, rank, dropout, alpha=alpha, bias=bias, alpha_auto_scale=alpha_auto_scale, **kwargs)
 
 class LinearLayer(LoraBlock.LinearLayer):
     def __init__(self, host:nn.Linear, rank, bias, block):
@@ -99,6 +98,11 @@ class LoraLayer(LoraBlock):
         b = self.bias.data if self.bias else None
         return w, b
 
+def none_add(a, b):
+    if a is None:
+        return b
+    return a+b
+
 class DAPPPatchContainer(PatchPluginContainer):
     def forward(self, x, *args, **kwargs):
         weight_p = None
@@ -107,25 +111,11 @@ class DAPPPatchContainer(PatchPluginContainer):
         bias_n = None
         for name in self.plugin_names:
             if self[name].branch=='p':
-                if weight_p is None:
-                    weight_p = self[name].get_weight()
-                else:
-                    weight_p = weight_p + self[name].get_weight()
-
-                if bias_p is None:
-                    bias_p = self[name].get_bias()
-                else:
-                    bias_p = bias_p+self[name].get_bias()
+                weight_p = none_add(weight_p, self[name].get_weight())
+                bias_p = none_add(bias_p, self[name].get_bias())
             elif self[name].branch=='n':
-                if weight_n is None:
-                    weight_n = self[name].get_weight()
-                else:
-                    weight_n = weight_n + self[name].get_weight()
-
-                if bias_n is None:
-                    bias_n = self[name].get_bias()
-                else:
-                    bias_n = bias_n+self[name].get_bias()
+                weight_n = none_add(weight_n, self[name].get_weight())
+                bias_n = none_add(bias_n, self[name].get_bias())
 
         B = x.shape[0]//2
         x_p = self[name].post_forward(x[B:], self._host.weight, weight_p, self._host.bias, bias_p)
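
The eight-line None-guarded branches above collapse into the new none_add helper, which treats None as an empty accumulator. A standalone sketch of the pattern (illustrative, not part of the diff):

import torch

def none_add(a, b):
    # fold b into a, with None acting as the identity element
    return b if a is None else a+b

acc = None
for w in (torch.ones(2), torch.full((2,), 2.0)):
    acc = none_add(acc, w)
print(acc)  # tensor([3., 3.])
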
hcpdiff/models/text_emb_ex.py
@@ -7,16 +7,17 @@ text_emb_ex.py
 :Created: 10/03/2023
 :Licence: Apache-2.0
 """
-from typing import Tuple
+from typing import Tuple, Dict, Any
 
 import torch
 from torch import nn
 import os
-from loguru import logger
+from rainbowneko import _share
 from einops import rearrange, repeat
+import torch.nn.functional as F
 
 from ..utils.net_utils import load_emb
-from .plugin import SinglePluginBlock
+from rainbowneko.models.plugin import SinglePluginBlock
 
 class EmbeddingPTHook(SinglePluginBlock):
     def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
@@ -37,6 +38,84 @@ class EmbeddingPTHook(SinglePluginBlock):
         self.input_ids = rearrange(input_ids[0], '(b r) w -> b (r w)', r=self.N_repeats)  # compatible with attention mask
         return self.input_ids.clip(0, self.num_embeddings-1)
 
+    def forward(self, inputs_embeds:torch.Tensor, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]):
+        '''
+        :param input_ids: [B, N_ids]
+        :param inputs_embeds: [B, N_repeat*(N_word+2), N_emb]
+        :return: [B, N_repeat, N_word+2, N_emb]
+        '''
+        rep_idxs_B = self.input_ids >= self.num_embeddings
+        BOS = repeat(inputs_embeds[:,0,:], 'b e -> b r 1 e', r=self.N_repeats)
+        EOS = repeat(inputs_embeds[:,-1,:], 'b e -> b r 1 e', r=self.N_repeats)
+
+        replaced_embeds = []
+        for i, (item, rep_idxs, ids_raw) in enumerate(zip(inputs_embeds, rep_idxs_B, self.input_ids)):
+            # insert pt to embeddings
+            rep_idxs = torch.where(rep_idxs)[0]
+            item_new = []
+            rep_idx_last = 0
+            for rep_idx in rep_idxs:
+                rep_idx = rep_idx.item()
+                item_new.append(item[rep_idx_last:rep_idx, :])
+                item_new.append(self.emb[ids_raw[rep_idx].item()].to(dtype=item.dtype))
+                rep_idx_last = rep_idx+1
+            item_new.append(item[rep_idx_last:, :])
+
+            # split to N_repeat sentence
+            replaced_item = torch.cat(item_new, dim=0)[1:self.N_word*self.N_repeats+1, :]
+            replaced_item = rearrange(replaced_item, '(r w) e -> r w e', r=self.N_repeats, w=self.N_word)
+            replaced_item = torch.cat([BOS[i], replaced_item, EOS[i]], dim=1)  # [N_repeat, N_word+2, N_emb]
+
+            replaced_embeds.append(replaced_item)
+        return torch.cat(replaced_embeds, dim=0)  # [B*N_repeat, N_word+2, N_emb]
+
+    def remove(self):
+        super(EmbeddingPTHook, self).remove()
+        self.handle_pre.remove()
+
+    @classmethod
+    def hook(cls, ex_words_emb, tokenizer, text_encoder, **kwargs):
+        word_list = list(ex_words_emb.keys())
+        tokenizer.add_tokens(word_list)
+        token_ids = tokenizer(' '.join(word_list)).input_ids[1:-1]
+
+        embedding_hook = cls(text_encoder.get_input_embeddings(), N_word=tokenizer.model_max_length-2, **kwargs)
+        #text_encoder.text_model.embeddings.token_embedding = embedding_hook
+        for tid, word in zip(token_ids, word_list):
+            embedding_hook.add_emb(ex_words_emb[word], tid)
+            _share.loggers.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
+        return embedding_hook
+
+    @classmethod
+    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs):
+        ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
+                        for file in os.listdir(emb_dir) if file.endswith('.pt')}
+        return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
+
+class EmbeddingPTInterpHook(SinglePluginBlock):
+    def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
+        super().__init__('emb_ex', token_embedding)
+        self.handle_pre = token_embedding.register_forward_pre_hook(self.pre_hook)
+
+        new_len = int(token_embedding.num_embeddings*N_repeats)
+        original_weights = token_embedding.weight.data.unsqueeze(1)
+        token_embedding.weight.data = F.interpolate(original_weights, size=new_len, mode='linear', align_corners=False).squeeze(1)
+        token_embedding.num_embeddings = new_len
+
+        self.N_word = N_word
+        self.N_repeats = N_repeats
+        self.num_embeddings = token_embedding.num_embeddings
+        self.embedding_dim = token_embedding.embedding_dim
+        self.emb = {}
+        self.emb_train = nn.ParameterList()
+
+    def add_emb(self, emb:nn.Parameter, token_id:int):
+        self.emb[token_id] = emb
+
+    def pre_hook(self, host, input_ids: Tuple[torch.Tensor]):
+        self.input_ids = rearrange(input_ids[0], '(b r) w -> b (r w)', r=self.N_repeats)  # compatible with attention mask
+        return self.input_ids.clip(0, self.num_embeddings-1)
+
     def forward(self, fea_in:Tuple[torch.Tensor], inputs_embeds:torch.Tensor):
         '''
         :param input_ids: [B, N_ids]
@@ -83,12 +162,11 @@ class EmbeddingPTHook(SinglePluginBlock):
         for tid, word in zip(token_ids, word_list):
             embedding_hook.add_emb(ex_words_emb[word], tid)
             if log:
-                logger.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
+                _share.logger.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
         return embedding_hook
 
     @classmethod
     def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, log=True, device='cuda:0', **kwargs):
         ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
                         for file in os.listdir(emb_dir) if file.endswith('.pt')}
-        return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
-
+        return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
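
A minimal usage sketch for the hook API above, assuming a diffusers-style CLIP tokenizer and text encoder plus a folder of .pt embedding files; the checkpoint path and folder name are placeholders:

from transformers import CLIPTextModel, CLIPTokenizer
from hcpdiff.models import EmbeddingPTHook

# placeholder checkpoint; any SD1.5-style text encoder works the same way
tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='text_encoder')

# every embs/*.pt file becomes a new token; returns the hook and the loaded embeddings
hook, ex_words_emb = EmbeddingPTHook.hook_from_dir('embs/', tokenizer, text_encoder, device='cpu', N_repeats=3)
hook.remove()  # detach the forward/pre hooks when finished
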
hcpdiff/models/textencoder_ex.py
@@ -8,29 +8,53 @@ textencoder_ex.py
 :Licence: Apache-2.0
 """
 
-from typing import Tuple, Optional, List
+from typing import Tuple, Optional
 
 import torch
 from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
+from loguru import logger
 from torch import nn
+from transformers import CLIPTextModelWithProjection, T5EncoderModel
 from transformers.models.clip.modeling_clip import CLIPAttention
 
 class TEEXHook:
-    def __init__(self, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, device='cuda', use_attention_mask=False):
+    def __init__(self, text_enc: nn.Module, tokenizer, N_repeats=1, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
         self.text_enc = text_enc
         self.tokenizer = tokenizer
 
         self.N_repeats = N_repeats
         self.clip_skip = clip_skip
         self.clip_final_norm = clip_final_norm
-        self.device = device
-        self.attn_mult = None
         self.use_attention_mask = use_attention_mask
 
         text_enc.register_forward_hook(self.forward_hook)
         text_enc.register_forward_pre_hook(self.forward_hook_input)
 
+    def find_final_norm(self, text_enc: nn.Module):
+        for module in text_enc.modules():
+            if 'final_layer_norm' in module._modules:
+                logger.info(f'find final_layer_norm in {type(module)}')
+                return module.final_layer_norm
+
+        logger.info(f'final_layer_norm not found in {type(text_enc)}')
+        return None
+
+    @property
+    def clip_final_norm(self):
+        return self.final_layer_norm is not None
+
+    @clip_final_norm.setter
+    def clip_final_norm(self, value: bool):
+        if value:
+            self.final_layer_norm = self.find_final_norm(self.text_enc)
+        else:
+            self.final_layer_norm = None
+
+    @property
+    def device(self):
+        return self.text_enc.device
+
     def encode_prompt_to_emb(self, prompt):
         text_inputs = self.tokenizer(
             prompt,
@@ -50,12 +74,23 @@ class TEEXHook:
         if position_ids is not None:
             position_ids = position_ids.to(self.device)
 
-        prompt_embeds, pooled_output = self.text_enc(
-            text_input_ids.to(self.device),
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            output_hidden_states=True,
-        )
+        # align with sd-webui
+        if isinstance(self.text_enc, CLIPTextModelWithProjection):
+            self.text_enc.text_projection.weight.data = self.text_enc.text_projection.weight.data.t()
+
+        if isinstance(self.text_enc, T5EncoderModel):
+            prompt_embeds, pooled_output = self.text_enc(
+                text_input_ids.to(self.device),
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+            )
+        else:
+            prompt_embeds, pooled_output = self.text_enc(
+                text_input_ids.to(self.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                output_hidden_states=True,
+            )
         return prompt_embeds, pooled_output, attention_mask
 
     def forward_hook_input(self, host, feat_in):
@@ -64,13 +99,12 @@ class TEEXHook:
 
     def forward_hook(self, host, feat_in: Tuple[torch.Tensor], feat_out):
         encoder_hidden_states = feat_out['hidden_states'][-self.clip_skip-1]
-        if self.clip_final_norm:
-            encoder_hidden_states = self.text_enc.text_model.final_layer_norm(encoder_hidden_states)
+        if self.clip_final_norm and self.final_layer_norm is not None:
+            encoder_hidden_states = self.final_layer_norm(encoder_hidden_states)
         if self.text_enc.training and self.clip_skip>0:
             encoder_hidden_states = encoder_hidden_states+0*feat_out['last_hidden_state'].mean()  # avoid unused parameters, make gradient checkpointing happy
-
         encoder_hidden_states = rearrange(encoder_hidden_states, '(b r) ... -> b r ...', r=self.N_repeats)  # [B, N_repeat, N_word+2, N_emb]
-        pooled_output = feat_out.pooler_output
+        pooled_output = feat_out.get('pooler_output', feat_out.get('text_embeds', None))
         # TODO: may have better fusion method
         if pooled_output is not None:
             pooled_output = rearrange(pooled_output, '(b r) ... -> b r ...', r=self.N_repeats).mean(dim=1)
@@ -81,7 +115,7 @@ class TEEXHook:
         return encoder_hidden_states, pooled_output
 
     def pool_hidden_states(self, encoder_hidden_states, input_ids):
-        pooled_output = encoder_hidden_states[:, :, -1, :].mean(dim=1) # [B, N_emb]
+        pooled_output = encoder_hidden_states[:, :, -1, :].mean(dim=1)  # [B, N_emb]
         return pooled_output
 
     @staticmethod
@@ -147,9 +181,11 @@ class TEEXHook:
             layer.forward = forward
 
     @classmethod
-    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, device='cuda', use_attention_mask=False):
-        return cls(text_enc, tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm, device=device, use_attention_mask=use_attention_mask)
+    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
+        return cls(text_enc, tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                   use_attention_mask=use_attention_mask)
 
     @classmethod
     def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
-        return cls(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, device='cuda', clip_skip=clip_skip, clip_final_norm=clip_final_norm, use_attention_mask=use_attention_mask)
+        return cls(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                   use_attention_mask=use_attention_mask)
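
Since device is now a property derived from the wrapped encoder, the hook classmethods no longer take a device argument. A usage sketch against a stock diffusers pipeline (checkpoint name illustrative):

from diffusers import StableDiffusionPipeline
from hcpdiff.models import TEEXHook

pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
te_hook = TEEXHook.hook_pipe(pipe, N_repeats=3, clip_skip=1)
emb, pooled, attn_mask = te_hook.encode_prompt_to_emb('a photo of a cat')
print(te_hook.device)  # proxies pipe.text_encoder.device
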
hcpdiff/models/wrapper/__init__.py
@@ -0,0 +1,3 @@
+from .sd import SD15Wrapper, SDXLWrapper
+from .pixart import PixArtWrapper
+from .utils import TEHookCFG, SD15_TEHookCFG, SDXL_TEHookCFG
hcpdiff/models/wrapper/pixart.py
@@ -0,0 +1,19 @@
+from .sd import SD15Wrapper
+from hcpdiff.utils import pad_attn_bias
+
+class PixArtWrapper(SD15Wrapper):
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, resolution=None, aspect_ratio=None,
+                         plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        added_cond_kwargs = {"resolution":resolution, "aspect_ratio":aspect_ratio}
+        model_pred = self.denoiser(x_t, encoder_hidden_states, timesteps, encoder_attention_mask=attn_mask,
+                                   added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
+        return model_pred
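
PixArt's denoiser receives resolution and aspect-ratio micro-conditioning through added_cond_kwargs rather than SDXL's time_ids. A rough sketch of preparing those extra inputs (shapes and dtypes here are assumptions, not taken from the diff):

import torch

B, H, W = 4, 1024, 1024
resolution = torch.tensor([H, W], dtype=torch.float32).repeat(B, 1)    # per-sample (height, width)
aspect_ratio = torch.tensor([H/W], dtype=torch.float32).repeat(B, 1)   # per-sample h/w ratio
# these tensors would be routed into PixArtWrapper.forward_denoiser(..., resolution=resolution, aspect_ratio=aspect_ratio)
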
hcpdiff/models/wrapper/sd.py
@@ -0,0 +1,218 @@
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, Union
+
+import torch
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from rainbowneko.models.wrapper import BaseWrapper
+from torch import Tensor
+from torch import nn
+
+from hcpdiff.diffusion.sampler import BaseSampler
+from hcpdiff.models import TEEXHook
+from hcpdiff.models.compose import ComposeTEEXHook
+from hcpdiff.utils import pad_attn_bias
+from .utils import TEHookCFG, SD15_TEHookCFG, SDXL_TEHookCFG
+from ..cfg_context import CFGContext
+
+class SD15Wrapper(BaseWrapper):
+    def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 pred_type='eps', TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__()
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input'))
+        self.key_mapper_out = self.build_mapper(key_map_out, None, None)
+
+        self.denoiser = denoiser
+        self.TE = TE
+        self.vae = vae
+        self.noise_sampler = noise_sampler
+        self.tokenizer = tokenizer
+        self.min_attnmask = min_attnmask
+
+        self.pred_type = pred_type
+
+        self.TE_hook_cfg = TEHookCFG.create(TE_hook_cfg)
+        self.cfg_context = cfg_context
+        self.tokenizer.N_repeats = self.TE_hook_cfg.tokenizer_repeats
+
+    def post_init(self):
+        self.make_TE_hook(self.TE_hook_cfg)
+
+        self.vae_trainable = False
+        if self.vae is not None:
+            for p in self.vae.parameters():
+                if p.requires_grad:
+                    self.vae_trainable = True
+                    break
+
+        self.TE_trainable = False
+        for p in self.TE.parameters():
+            if p.requires_grad:
+                self.TE_trainable = True
+                break
+
+    def make_TE_hook(self, TE_hook_cfg):
+        # Hook and extend text_encoder
+        self.text_enc_hook = TEEXHook.hook(self.TE, self.tokenizer, N_repeats=TE_hook_cfg.tokenizer_repeats,
+                                           clip_skip=TE_hook_cfg.clip_skip, clip_final_norm=TE_hook_cfg.clip_final_norm)
+
+    def get_latents(self, image: Tensor):
+        if image.shape[1] == 3:
+            with torch.no_grad() if self.vae_trainable else nullcontext():
+                latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
+                latents = latents*self.vae.config.scaling_factor
+        else:
+            latents = image  # Cached latents
+        return latents
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)[0]
+        return encoder_hidden_states
+
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample  # Predict the noise residual
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
+                      plugin_input={}, **kwargs):
+        # input prepare
+        x_0 = self.get_latents(image)
+        x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+
+        if neg_prompt_ids:
+            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+        if neg_attn_mask:
+            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+        if neg_position_ids:
+            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+
+        # model forward
+        x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
+        encoder_hidden_states = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                                plugin_input=plugin_input, **kwargs)
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                           plugin_input=plugin_input, **kwargs)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
+                    noise_sampler=self.noise_sampler)
+
+    def forward(self, ds_name=None, **kwargs):
+        model_args, model_kwargs = self.get_map_data(self.key_mapper_in, kwargs, ds_name)
+        out = self.model_forward(*model_args, **model_kwargs)
+        return self.get_map_data(self.key_mapper_out, out, ds_name=ds_name)[1]
+
+    def enable_gradient_checkpointing(self):
+        def grad_ckpt_enable(m):
+            if getattr(m, 'gradient_checkpointing', False):
+                m.training = True
+
+        self.denoiser.enable_gradient_checkpointing()
+        if self.TE_trainable:
+            self.TE.gradient_checkpointing_enable()
+        self.apply(grad_ckpt_enable)
+
+    def enable_xformers(self):
+        self.denoiser.enable_xformers_memory_efficient_attention()
+
+    @property
+    def trainable_parameters(self):
+        return [p for p in self.parameters() if p.requires_grad]
+
+    @property
+    def trainable_models(self) -> Dict[str, nn.Module]:
+        return {'self':self}
+
+    def set_dtype(self, dtype, vae_dtype):
+        self.dtype = dtype
+        self.vae_dtype = vae_dtype
+        # Move vae and text_encoder to device and cast to weight_dtype
+        if self.vae is not None:
+            self.vae = self.vae.to(dtype=vae_dtype)
+        if not self.TE_trainable:
+            self.TE = self.TE.to(dtype=dtype)
+
+    @classmethod
+    def from_pretrained(cls, models: Union[partial, Dict[str, nn.Module]], **kwargs):
+        models = models() if isinstance(models, partial) else models
+        return cls(models['denoiser'], models['TE'], models['vae'], models['noise_sampler'], models['tokenizer'], **kwargs)
+
+class SDXLWrapper(SD15Wrapper):
+    def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 pred_type='eps', TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, pred_type, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input', 'coord -> crop_info'))
+
+    def make_TE_hook(self, TE_hook_cfg):
+        # Hook and extend text_encoder
+        self.text_enc_hook = ComposeTEEXHook.hook(self.TE, self.tokenizer, N_repeats=TE_hook_cfg.tokenizer_repeats,
+                                                  clip_skip=TE_hook_cfg.clip_skip, clip_final_norm=TE_hook_cfg.clip_final_norm)
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)
+        return encoder_hidden_states, pooled_output
+
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs, attn_mask=None, position_ids=None,
+                         plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask,
+                                   added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
+                      crop_info=None, plugin_input={}):
+        # input prepare
+        x_0 = self.get_latents(image)
+        x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+
+        if neg_prompt_ids:
+            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+        if neg_attn_mask:
+            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+        if neg_position_ids:
+            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+
+        # model forward
+        x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
+        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                                               plugin_input=plugin_input)
+        added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs=added_cond_kwargs,
+                                           attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
+                    noise_sampler=self.noise_sampler)
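
from_pretrained above accepts either a component dict or a partial that builds one. A construction sketch, with the actual loading of the UNet, text encoder, VAE, sampler, and tokenizer left to surrounding code (all component names here are placeholders):

import torch
from functools import partial
from hcpdiff.models.wrapper import SD15Wrapper

def build_components(unet, text_encoder, vae, noise_sampler, tokenizer):
    # stand-in loader: return the five components from_pretrained consumes
    return {'denoiser': unet, 'TE': text_encoder, 'vae': vae,
            'noise_sampler': noise_sampler, 'tokenizer': tokenizer}

# the components themselves are loaded elsewhere (e.g. via diffusers/transformers)
wrapper = SD15Wrapper.from_pretrained(partial(build_components, unet, text_encoder, vae, noise_sampler, tokenizer),
                                      pred_type='eps')
wrapper.post_init()                              # install the TE hook, record trainable submodules
wrapper.set_dtype(torch.float16, torch.float32)  # half-precision UNet/TE, fp32 VAE
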
hcpdiff/models/wrapper/utils.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from rainbowneko.utils import is_dict
+
+class TEHookCFG:
+    def __init__(self, tokenizer_repeats: int = 1, clip_skip: int = 0, clip_final_norm: bool = True):
+        self.tokenizer_repeats = tokenizer_repeats
+        self.clip_skip = clip_skip
+        self.clip_final_norm = clip_final_norm
+
+    @classmethod
+    def create(cls, cfg):
+        if is_dict(cfg):
+            return cls(**cfg)
+        elif isinstance(cfg, cls):
+            return cfg
+        else:
+            raise ValueError(f'Invalid TEHookCFG type: {type(cfg)}')
+
+SD15_TEHookCFG = TEHookCFG()
+SDXL_TEHookCFG = TEHookCFG(clip_skip=1, clip_final_norm=False)
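
TEHookCFG.create accepts either a ready-made instance or a plain dict, so the hook settings can come straight from a config file. A small sketch:

from hcpdiff.models.wrapper import TEHookCFG, SDXL_TEHookCFG

cfg = TEHookCFG.create({'tokenizer_repeats': 2, 'clip_skip': 1})
print(cfg.clip_skip)                                        # 1
print(TEHookCFG.create(SDXL_TEHookCFG) is SDXL_TEHookCFG)   # True: instances pass through unchanged
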
hcpdiff/parser/__init__.py
@@ -0,0 +1 @@
+from .embpt import CfgEmbPTParser
hcpdiff/parser/embpt.py
@@ -0,0 +1,32 @@
+from typing import Dict, Tuple, List
+from rainbowneko.utils import Path_Like
+from hcpdiff.models import EmbeddingPTHook
+from torch import Tensor
+
+class CfgEmbPTParser:
+    def __init__(self, emb_dir: Path_Like, cfg_pt: Dict[str, Dict], lr: float = 1e-5, weight_decay: float = 0):
+        self.emb_dir = emb_dir
+        self.cfg_pt = cfg_pt
+        self.lr = lr
+        self.weight_decay = weight_decay
+
+    def get_params_group(self, model) -> Tuple[List, Dict[str, Tensor]]:
+        self.embedding_hook, self.ex_words_emb = EmbeddingPTHook.hook_from_dir(
+            self.emb_dir, model.tokenizer, model.TE, N_repeats=model.tokenizer.N_repeats)
+        self.embedding_hook.requires_grad_(False)
+
+        train_params_emb = []
+        train_pts = {}
+        for pt_name, info in self.cfg_pt.items():
+            word_emb = self.ex_words_emb[pt_name]
+            train_pts[pt_name] = word_emb
+            word_emb.requires_grad = True
+            self.embedding_hook.emb_train.append(word_emb)
+            param_group = {'params':word_emb}
+            if 'lr' in info:
+                param_group['lr'] = info.lr
+            if 'weight_decay' in info:
+                param_group['weight_decay'] = info.weight_decay
+            train_params_emb.append(param_group)
+
+        return train_params_emb, train_pts
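
Because get_params_group reads per-word overrides with attribute access (info.lr), cfg_pt is typically an OmegaConf node rather than a plain dict. A usage sketch (the embedding folder, word name, and wrapper instance are placeholders):

from omegaconf import OmegaConf
from hcpdiff.parser import CfgEmbPTParser

cfg_pt = OmegaConf.create({'my_word': {'lr': 1e-3}})   # attribute-style access matches info.lr above
pt_parser = CfgEmbPTParser(emb_dir='embs/', cfg_pt=cfg_pt)
# wrapper: an SD15Wrapper-like model exposing .tokenizer and .TE, built elsewhere
params_group, train_pts = pt_parser.get_params_group(wrapper)
# params_group plugs straight into torch.optim, e.g. AdamW(params_group, lr=pt_parser.lr)
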
hcpdiff/tools/convert_caption_txt2json.py
@@ -2,7 +2,7 @@ import argparse
 import json
 import os
 
-from hcpdiff.utils.img_size_tool import types_support
+from rainbowneko.utils import types_support
 
 parser = argparse.ArgumentParser(description='Stable Diffusion Training')
 parser.add_argument('--data_root', type=str, default='')