hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/controlnet.py +16 -16
  58. hcpdiff/models/lora_base_patch.py +14 -25
  59. hcpdiff/models/lora_layers.py +3 -9
  60. hcpdiff/models/lora_layers_patch.py +14 -24
  61. hcpdiff/models/text_emb_ex.py +84 -6
  62. hcpdiff/models/textencoder_ex.py +54 -18
  63. hcpdiff/models/wrapper/__init__.py +3 -0
  64. hcpdiff/models/wrapper/pixart.py +19 -0
  65. hcpdiff/models/wrapper/sd.py +218 -0
  66. hcpdiff/models/wrapper/utils.py +20 -0
  67. hcpdiff/parser/__init__.py +1 -0
  68. hcpdiff/parser/embpt.py +32 -0
  69. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  70. hcpdiff/tools/dataset_generator.py +94 -0
  71. hcpdiff/tools/download_hf_model.py +24 -0
  72. hcpdiff/tools/init_proj.py +3 -21
  73. hcpdiff/tools/lora_convert.py +18 -17
  74. hcpdiff/tools/save_model.py +12 -0
  75. hcpdiff/tools/sd2diffusers.py +1 -1
  76. hcpdiff/train_colo.py +1 -1
  77. hcpdiff/train_deepspeed.py +1 -1
  78. hcpdiff/trainer_ac.py +79 -0
  79. hcpdiff/trainer_ac_single.py +31 -0
  80. hcpdiff/utils/__init__.py +0 -2
  81. hcpdiff/utils/inpaint_pipe.py +7 -2
  82. hcpdiff/utils/net_utils.py +29 -6
  83. hcpdiff/utils/pipe_hook.py +24 -7
  84. hcpdiff/utils/utils.py +21 -4
  85. hcpdiff/workflow/__init__.py +15 -10
  86. hcpdiff/workflow/daam/__init__.py +1 -0
  87. hcpdiff/workflow/daam/act.py +66 -0
  88. hcpdiff/workflow/daam/hook.py +109 -0
  89. hcpdiff/workflow/diffusion.py +114 -125
  90. hcpdiff/workflow/fast.py +31 -0
  91. hcpdiff/workflow/flow.py +67 -0
  92. hcpdiff/workflow/io.py +36 -130
  93. hcpdiff/workflow/model.py +46 -43
  94. hcpdiff/workflow/text.py +78 -46
  95. hcpdiff/workflow/utils.py +32 -12
  96. hcpdiff/workflow/vae.py +37 -38
  97. hcpdiff-2.1.dist-info/METADATA +285 -0
  98. hcpdiff-2.1.dist-info/RECORD +114 -0
  99. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  100. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  101. hcpdiff/ckpt_manager/base.py +0 -16
  102. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  103. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  104. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  105. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  106. hcpdiff/data/bucket.py +0 -358
  107. hcpdiff/data/caption_loader.py +0 -80
  108. hcpdiff/data/cond_dataset.py +0 -40
  109. hcpdiff/data/crop_info_dataset.py +0 -40
  110. hcpdiff/data/data_processor.py +0 -33
  111. hcpdiff/data/pair_dataset.py +0 -146
  112. hcpdiff/data/sampler.py +0 -54
  113. hcpdiff/data/source/base.py +0 -30
  114. hcpdiff/data/utils.py +0 -80
  115. hcpdiff/deprecated/__init__.py +0 -1
  116. hcpdiff/deprecated/cfg_converter.py +0 -81
  117. hcpdiff/deprecated/lora_convert.py +0 -31
  118. hcpdiff/infer_workflow.py +0 -57
  119. hcpdiff/loggers/__init__.py +0 -13
  120. hcpdiff/loggers/base_logger.py +0 -76
  121. hcpdiff/loggers/cli_logger.py +0 -40
  122. hcpdiff/loggers/preview/__init__.py +0 -1
  123. hcpdiff/loggers/preview/image_previewer.py +0 -149
  124. hcpdiff/loggers/tensorboard_logger.py +0 -30
  125. hcpdiff/loggers/wandb_logger.py +0 -31
  126. hcpdiff/loggers/webui_logger.py +0 -9
  127. hcpdiff/loss/min_snr_loss.py +0 -52
  128. hcpdiff/models/layers.py +0 -81
  129. hcpdiff/models/plugin.py +0 -348
  130. hcpdiff/models/wrapper.py +0 -75
  131. hcpdiff/noise/__init__.py +0 -3
  132. hcpdiff/noise/noise_base.py +0 -16
  133. hcpdiff/noise/pyramid_noise.py +0 -50
  134. hcpdiff/noise/zero_terminal.py +0 -44
  135. hcpdiff/train_ac.py +0 -566
  136. hcpdiff/train_ac_single.py +0 -39
  137. hcpdiff/utils/caption_tools.py +0 -105
  138. hcpdiff/utils/cfg_net_tools.py +0 -321
  139. hcpdiff/utils/cfg_resolvers.py +0 -16
  140. hcpdiff/utils/ema.py +0 -52
  141. hcpdiff/utils/img_size_tool.py +0 -248
  142. hcpdiff/vis/__init__.py +0 -3
  143. hcpdiff/vis/base_interface.py +0 -12
  144. hcpdiff/vis/disk_interface.py +0 -48
  145. hcpdiff/vis/webui_interface.py +0 -17
  146. hcpdiff/viser_fast.py +0 -138
  147. hcpdiff/visualizer.py +0 -265
  148. hcpdiff/visualizer_reloadable.py +0 -237
  149. hcpdiff/workflow/base.py +0 -59
  150. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  198. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  206. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  207. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  208. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  209. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0

hcpdiff/models/cfg_context.py
@@ -1,6 +1,7 @@
 import torch
 from einops import repeat
 import math
+from typing import Union, Callable

 class CFGContext:
     def pre(self, noisy_latents, timesteps):
@@ -10,9 +11,11 @@ class CFGContext:
         return model_pred

 class DreamArtistPTContext(CFGContext):
-    def __init__(self, cfg_scale, num_train_timesteps):
-        self.cfg_scale=cfg_scale
-        self.num_train_timesteps=num_train_timesteps
+    def __init__(self, cfg_low: float, cfg_high: float=None, cfg_func: Union[str, Callable]=None, num_train_timesteps=1000):
+        self.cfg_low = cfg_low
+        self.cfg_high = cfg_high or cfg_low
+        self.cfg_func = cfg_func
+        self.num_train_timesteps = num_train_timesteps

     def pre(self, noisy_latents, timesteps):
         self.t_raw = timesteps
@@ -22,18 +25,18 @@ class DreamArtistPTContext(CFGContext):

     def post(self, model_pred):
         e_t_uncond, e_t = model_pred.chunk(2)
-        if self.cfg_scale[0] != self.cfg_scale[1]:
-            rate = self.t_raw / (self.num_train_timesteps - 1)
-            if self.cfg_scale[2] == 'cos':
-                rate = torch.cos((rate - 1) * math.pi / 2)
-            elif self.cfg_scale[2] == 'cos2':
-                rate = 1 - torch.cos(rate * math.pi / 2)
-            elif self.cfg_scale[2] == 'ln':
+        if self.cfg_low != self.cfg_high:
+            rate = self.t_raw/(self.num_train_timesteps-1)
+            if self.cfg_func == 'cos':
+                rate = torch.cos((rate-1)*math.pi/2)
+            elif self.cfg_func == 'cos2':
+                rate = 1-torch.cos(rate*math.pi/2)
+            elif self.cfg_func == 'ln':
                 pass
             else:
-                rate = eval(self.cfg_scale[2])
-            rate = rate.view(-1,1,1,1)
+                rate = self.cfg_func(rate)
+            rate = rate.view(-1, 1, 1, 1)
         else:
             rate = 1
-        model_pred = e_t_uncond + ((self.cfg_scale[1] - self.cfg_scale[0]) * rate + self.cfg_scale[0]) * (e_t - e_t_uncond)
-        return model_pred
+        model_pred = e_t_uncond+((self.cfg_high-self.cfg_low)*rate+self.cfg_low)*(e_t-e_t_uncond)
+        return model_pred
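
The DreamArtistPTContext constructor now takes explicit cfg_low/cfg_high/cfg_func arguments instead of a packed cfg_scale tuple, and a callable cfg_func replaces the old eval() of a string. A minimal standalone sketch of the resulting timestep-dependent guidance scale (the guidance_scale helper name is illustrative, not part of hcpdiff):

    import math
    import torch

    def guidance_scale(t, cfg_low, cfg_high=None, cfg_func=None, num_train_timesteps=1000):
        # Mirrors the rate interpolation in DreamArtistPTContext.post (2.1): ramp the CFG
        # scale from cfg_low at t=0 to cfg_high at the last training timestep.
        cfg_high = cfg_high or cfg_low
        if cfg_low == cfg_high:
            return torch.as_tensor(float(cfg_low))
        rate = t/(num_train_timesteps-1)
        if cfg_func == 'cos':
            rate = torch.cos((rate-1)*math.pi/2)
        elif cfg_func == 'cos2':
            rate = 1-torch.cos(rate*math.pi/2)
        elif cfg_func == 'ln':
            pass                      # keep the linear ramp
        elif callable(cfg_func):
            rate = cfg_func(rate)     # 2.1 calls the function; 0.9.1 ran eval() on a string
        return (cfg_high-cfg_low)*rate+cfg_low

    print(guidance_scale(torch.tensor([500.]), cfg_low=3.0, cfg_high=7.0, cfg_func='cos'))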

hcpdiff/models/compose/compose_hook.py
@@ -38,42 +38,42 @@ class ComposeEmbPTHook(nn.Module):
             hook.remove()

     @classmethod
-    def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder, log=False, **kwargs):
+    def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder, **kwargs):
         if isinstance(text_encoder, ComposeTextEncoder):
             hook_list = []

             emb_len = 0
-            for i, (name, tokenizer_i) in enumerate(tokenizer.tokenizer_list):
+            for i, name in enumerate(tokenizer.tokenizer_names):
                 text_encoder_i = getattr(text_encoder, name)
-                if log:
-                    logger.info(f'compose hook: {name}')
+                tokenizer_i = getattr(tokenizer, name)
                 embedding_dim = text_encoder_i.get_input_embeddings().embedding_dim
                 ex_words_emb_i = {k:v[i] for k, v in ex_words_emb.items()}
                 emb_len += embedding_dim
-                hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, log=log, **kwargs)))
+                hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, **kwargs)))

             return cls(hook_list)
         else:
-            return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs)
+            return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder, **kwargs)

     @classmethod
-    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, log=True, device='cuda:0', **kwargs) -> Union[
+    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs) -> Union[
         Tuple['ComposeEmbPTHook', Dict], Tuple[EmbeddingPTHook, Dict]]:
         if isinstance(text_encoder, ComposeTextEncoder):
             # multi text encoder
-            #ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}
+            # ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}

             # slice of nn.Parameter cannot return grad. Split the tensor
             ex_words_emb = {}
-            emb_dims = [x.embedding_dim for x in text_encoder.get_input_embeddings()]
-            for file in os.listdir(emb_dir):
-                if file.endswith('.pt'):
-                    emb = load_emb(os.path.join(emb_dir, file)).to(device)
-                    emb = ParameterGroup([nn.Parameter(item, requires_grad=False) for item in emb.split(emb_dims, dim=1)])
-                    ex_words_emb[file[:-3]] = emb
-            return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
+            if emb_dir is not None and os.path.exists(emb_dir):
+                emb_dims = [x.embedding_dim for x in text_encoder.get_input_embeddings()]
+                for file in os.listdir(emb_dir):
+                    if file.endswith('.pt'):
+                        emb = load_emb(os.path.join(emb_dir, file)).to(device)
+                        emb = ParameterGroup([nn.Parameter(item, requires_grad=False) for item in emb.split(emb_dims, dim=1)])
+                        ex_words_emb[file[:-3]] = emb
+            return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
         else:
-            return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder, log, device, **kwargs)
+            return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder, **kwargs)

 class ComposeTEEXHook:
     def __init__(self, tehook_list: List[Tuple[str, TEEXHook]], cat_dim=-1):
@@ -98,10 +98,28 @@ class ComposeTEEXHook:
         for name, tehook in self.tehook_list:
             tehook.clip_skip = value

+    @property
+    def clip_final_norm(self):
+        return self.tehook_list[0][1].clip_final_norm
+
+    @clip_final_norm.setter
+    def clip_final_norm(self, value: bool):
+        for name, tehook in self.tehook_list:
+            tehook.clip_final_norm = value
+
+    @property
+    def use_attention_mask(self):
+        return self.tehook_list[0][1].use_attention_mask
+
+    @use_attention_mask.setter
+    def use_attention_mask(self, value: bool):
+        for name, tehook in self.tehook_list:
+            tehook.use_attention_mask = value
+
     def encode_prompt_to_emb(self, prompt):
         emb_list = [tehook.encode_prompt_to_emb(prompt) for name, tehook in self.tehook_list]
-        encoder_hidden_states, pooled_output = list(zip(*emb_list))
-        return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output
+        encoder_hidden_states, pooled_output, attention_mask = list(zip(*emb_list))
+        return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output, attention_mask[0]

     def enable_xformers(self):
         for name, tehook in self.tehook_list:
@@ -112,16 +130,19 @@ class ComposeTEEXHook:
         return TEEXHook.mult_attn(prompt_embeds, attn_mult)

     @classmethod
-    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, device='cuda', use_attention_mask=False) -> Union['ComposeTEEXHook', TEEXHook]:
+    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False) -> Union[
+        'ComposeTEEXHook', TEEXHook]:
         if isinstance(text_enc, ComposeTextEncoder):
             # multi text encoder
-            tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name), tokenizer_i, N_repeats, clip_skip, clip_final_norm, device=device, use_attention_mask=use_attention_mask))
-                           for name, tokenizer_i in tokenizer.tokenizer_list]
+            tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name), getattr(tokenizer, name), N_repeats, clip_skip, clip_final_norm,
+                                                use_attention_mask=use_attention_mask))
+                           for name in tokenizer.tokenizer_names]
             return cls(tehook_list)
         else:
             # single text encoder
-            return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip, device=device, use_attention_mask=use_attention_mask)
+            return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip, clip_final_norm, use_attention_mask=use_attention_mask)

     @classmethod
     def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
-        return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, device='cuda', clip_skip=clip_skip, clip_final_norm=clip_final_norm, use_attention_mask=use_attention_mask)
+        return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                        use_attention_mask=use_attention_mask)
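
A hedged usage sketch of the reworked text-encoder hooks: the log argument is gone, sub-tokenizers are looked up by name, clip_final_norm and use_attention_mask are now settable on the composed hook, and encode_prompt_to_emb returns an attention mask as a third element. The checkpoint id is a placeholder, and the SDXL composer classes come from sdxl_composer.py (diffed further below):

    from hcpdiff.models.compose.compose_hook import ComposeTEEXHook
    from hcpdiff.models.compose.sdxl_composer import SDXLTextEncoder, SDXLTokenizer

    # any SDXL checkpoint in diffusers layout; the repo id here is only an example
    text_encoder = SDXLTextEncoder.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0')
    tokenizer = SDXLTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0')

    te_hook = ComposeTEEXHook.hook(text_encoder, tokenizer, N_repeats=1, clip_skip=0, clip_final_norm=True)
    te_hook.use_attention_mask = False  # new property, fanned out to every sub-hook
    emb, pooled, attn_mask = te_hook.encode_prompt_to_emb('a photo of a cat')  # 3-tuple in 2.1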

hcpdiff/models/compose/compose_tokenizer.py
@@ -18,14 +18,19 @@ from transformers.tokenization_utils_base import BatchEncoding
 class ComposeTokenizer(PreTrainedTokenizer):
     def __init__(self, tokenizer_list: List[Tuple[str, CLIPTokenizer]], cat_dim=-1):
         self.cat_dim = cat_dim
-        self.tokenizer_list = tokenizer_list
+
+        self.tokenizer_names = []
+        for name, tokenizer in tokenizer_list:
+            setattr(self, name, tokenizer)
+            self.tokenizer_names.append(name)
+
         super().__init__()

-        self.model_max_length = self.first_tokenizer.model_max_length
+        self.model_max_length = torch.tensor([tokenizer.model_max_length for name, tokenizer in tokenizer_list])

     @property
     def first_tokenizer(self):
-        return self.tokenizer_list[0][1]
+        return getattr(self, self.tokenizer_names[0])

     @property
     def vocab_size(self):
@@ -40,18 +45,26 @@ class ComposeTokenizer(PreTrainedTokenizer):
         return self.first_tokenizer.bos_token_id

     def get_vocab(self):
-        return dict(self.first_tokenizer.encoder, **self.first_tokenizer.added_tokens_encoder)
+        return self.first_tokenizer.get_vocab()

     def tokenize(self, text, **kwargs) -> List[str]:
         return self.first_tokenizer.tokenize(text, **kwargs)

     def add_tokens( self, new_tokens, special_tokens: bool = False) -> List[int]:
-        return [tokenizer.add_tokens(new_tokens, special_tokens) for name, tokenizer in self.tokenizer_list]
+        return [getattr(self, name).add_tokens(new_tokens, special_tokens) for name in self.tokenizer_names]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix = None) -> Tuple[str]:
+        return self.first_tokenizer.save_vocabulary(save_directory, filename_prefix)
+
+    def __call__(self, text, *args, max_length=None, **kwargs):
+        if isinstance(max_length, torch.Tensor):
+            token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length_i, **kwargs)
+                                               for name, max_length_i in zip(self.tokenizer_names, max_length)]
+        else:
+            token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length, **kwargs) for name in self.tokenizer_names]

-    def __call__(self, text, *args, **kwargs):
-        token_list: List[BatchEncoding] = [tokenizer(text, *args, **kwargs) for name, tokenizer in self.tokenizer_list]
         input_ids = torch.cat([token.input_ids for token in token_list], dim=-1) # [N_tokenizer, N_token]
-        attention_mask = [token.attention_mask for token in token_list]
+        attention_mask = torch.cat([token.attention_mask for token in token_list], dim=-1)
         return BatchEncoding({'input_ids':input_ids, 'attention_mask':attention_mask})

     @classmethod
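
ComposeTokenizer no longer keeps a tokenizer_list attribute: each sub-tokenizer becomes an attribute listed in tokenizer_names, model_max_length becomes a per-tokenizer tensor, and __call__ splits a tensor max_length across the sub-tokenizers. A sketch of direct construction (the repo id is a placeholder, and passing the tensor max_length through to stock transformers tokenizers is assumed to behave as it does in SDXL pipelines):

    from transformers import AutoTokenizer
    from hcpdiff.models.compose.compose_tokenizer import ComposeTokenizer

    tok_L = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer')
    tok_bigG = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer_2')
    tokenizer = ComposeTokenizer([('clip_L', tok_L), ('clip_bigG', tok_bigG)])

    print(tokenizer.tokenizer_names)    # ['clip_L', 'clip_bigG']
    print(tokenizer.model_max_length)   # tensor([77, 77]): one max_length per sub-tokenizer
    tokens = tokenizer('a photo of a cat', truncation=True,
                       max_length=tokenizer.model_max_length,  # tensor is split per sub-tokenizer
                       return_tensors='pt')
    print(tokens.input_ids.shape)       # ids of both tokenizers concatenated on the last dim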

hcpdiff/models/compose/sdxl_composer.py
@@ -27,13 +27,13 @@ class CLIPTextModelWithProjection_Align(CLIPTextModelWithProjection):
 class SDXLTextEncoder(ComposeTextEncoder):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
-        clip_B = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
+        clip_L = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
         clip_bigG = CLIPTextModelWithProjection_Align.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder_2', **kwargs)
-        return cls([('clip_B', clip_B), ('clip_bigG', clip_bigG)])
+        return cls([('clip_L', clip_L), ('clip_bigG', clip_bigG)])

 class SDXLTokenizer(ComposeTokenizer):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
-        clip_B = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
+        clip_L = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
         clip_bigG = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer_2', **kwargs)
-        return cls([('clip_B', clip_B), ('clip_bigG', clip_bigG)])
+        return cls([('clip_L', clip_L), ('clip_bigG', clip_bigG)])
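
The first CLIP branch is renamed from clip_B to clip_L, so anything addressed by sub-module name (hooks, plugins, saved embedding groups) follows the new key. A small sketch of the attribute access (the repo id is again a placeholder):

    from hcpdiff.models.compose.sdxl_composer import SDXLTextEncoder

    te = SDXLTextEncoder.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0')
    print(type(te.clip_L).__name__)     # CLIPTextModel; this attribute was te.clip_B in 0.9.1
    print(type(te.clip_bigG).__name__)  # CLIPTextModelWithProjection_Align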

hcpdiff/models/controlnet.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from copy import deepcopy

-from .plugin import MultiPluginBlock, BasePluginBlock
+from rainbowneko.models.plugin import MultiPluginBlock, BasePluginBlock
 from hcpdiff.utils.net_utils import remove_all_hooks, remove_layers

 class ControlNetPlugin(MultiPluginBlock):
@@ -55,25 +55,25 @@ class ControlNetPlugin(MultiPluginBlock):
         self.cond_head = nn.Sequential(*cond_head)

     def reset_parameters(self) -> None:
-        def weight_init(m):
-            if isinstance(m, nn.Conv2d):
-                nn.init.constant_(m.weight, 0)
-        self.controlnet_down_blocks.apply(weight_init)
-        self.controlnet_mid_block.apply(weight_init)
-        self.cond_head[-1].apply(weight_init)
-
-    def from_layer_hook(self, host, fea_in:Tuple[torch.Tensor], fea_out:Tuple[torch.Tensor], idx: int):
+        def zero_weight_init(m):
+            for p in m.parameters():
+                p.detach().zero_()
+        self.controlnet_down_blocks.apply(zero_weight_init)
+        self.controlnet_mid_block.apply(zero_weight_init)
+        self.cond_head[-1].apply(zero_weight_init)
+
+    def from_layer_hook(self, host, idx: int, args: Tuple[Any, ...], kwargs: Dict[str, Any], fea_out: Any=None):
         if idx==0:
-            self.data_input = fea_in
+            self.data_input = (args, kwargs)
         elif idx==1:
-            self.feat_to = self(*self.data_input)
+            self.feat_to = self(*self.data_input[0], **self.data_input[1])

-    def to_layer_hook(self, host, fea_in:Tuple[torch.Tensor], fea_out:Tuple[torch.Tensor], idx: int):
+    def to_layer_hook(self, host, idx: int, args: Tuple[Any, ...], kwargs: Dict[str, Any], fea_out: Any=None):
         if idx == 5:
-            sp = fea_in[0].shape[1]//2
-            new_feat = fea_in[0].clone()
-            new_feat[:, sp:, ...] = fea_in[0][:, sp:, ...] + self.feat_to[0]
-            return (new_feat, fea_in[1])
+            sp = args[0].shape[1]//2
+            new_feat = args[0].clone()
+            new_feat[:, sp:, ...] = args[0][:, sp:, ...] + self.feat_to[0]
+            return (new_feat, args[1])
         elif idx == 3:
             return (fea_out[0], tuple(fea_out[1][i] + self.feat_to[(idx) * 3 + i+1] for i in range(2)))
         elif idx == 4:
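
Two things change here: the plugin hooks now receive the host's (args, kwargs) instead of a flat fea_in tuple, matching rainbowneko's plugin API, and reset_parameters zeroes every parameter of the heads rather than only Conv2d weights. A standalone sketch of the new zero initialisation (the toy head below is illustrative):

    import torch
    from torch import nn

    def zero_weight_init(m: nn.Module):
        # 2.1 zeroes all parameters (weights and biases), not just Conv2d weights
        for p in m.parameters():
            p.detach().zero_()

    cond_head = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.SiLU(), nn.Conv2d(16, 320, 3, padding=1))
    cond_head[-1].apply(zero_weight_init)  # mirrors ControlNetPlugin.reset_parameters
    assert all(p.abs().sum() == 0 for p in cond_head[-1].parameters())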

hcpdiff/models/lora_base_patch.py
@@ -13,7 +13,7 @@ from torch import nn
 from torch.nn import functional as F

 from hcpdiff.utils.utils import make_mask, low_rank_approximate, isinstance_list
-from .plugin import PatchPluginBlock, PluginGroup, PatchPluginContainer
+from rainbowneko.models.plugin import PatchPluginBlock, PluginGroup, PatchPluginContainer

 from typing import Union, Tuple, Dict, Type

@@ -38,9 +38,9 @@ class LoraBlock(PatchPluginBlock):
     container_cls = LoraPatchContainer
     wrapable_classes = (nn.Linear, nn.Conv2d)

-    def __init__(self, lora_id:int, host:Union[nn.Linear, nn.Conv2d], rank, dropout=0.1, alpha=1.0, bias=False,
+    def __init__(self, name:int, host:Union[nn.Linear, nn.Conv2d], rank, dropout=0.1, alpha=1.0, bias=False,
                  alpha_auto_scale=True, parent_block=None, host_name=None, **kwargs):
-        super().__init__(f'lora_block_{lora_id}', host, parent_block=parent_block, host_name=host_name)
+        super().__init__(name, host, parent_block=parent_block, host_name=host_name)

         self.bias=bias

@@ -56,8 +56,14 @@ class LoraBlock(PatchPluginBlock):
         self.dropout = nn.Dropout(dropout)

         self.rank = self.layer.rank
+        self.alpha_auto_scale = alpha_auto_scale
         self.register_buffer('alpha', torch.tensor(alpha/self.rank if alpha_auto_scale else alpha))

+    def set_hyper_params(self, alpha=None, **kwargs):
+        if alpha is not None:
+            self.register_buffer('alpha', torch.tensor(alpha/self.rank if self.alpha_auto_scale else alpha))
+        super().set_hyper_params(**kwargs)
+
     def get_weight(self):
         return self.layer.get_weight() * self.alpha

@@ -91,7 +97,7 @@ class LoraBlock(PatchPluginBlock):
                 host.weight.data * base_alpha + alpha * re_w.to(host.weight.device, dtype=host.weight.dtype)
             )

-        if self.layer.lora_up.bias is not None:
+        if re_b is not None:
            if host.bias is None:
                host.bias = nn.Parameter(re_b.to(host.weight.device, dtype=host.weight.dtype))
            else:
@@ -145,32 +151,15 @@ class LoraBlock(PatchPluginBlock):
         pass

     @classmethod
-    def wrap_layer(cls, lora_id:int, layer: Union[nn.Linear, nn.Conv2d], rank=1, dropout=0.0, alpha=1.0, svd_init=False,
+    def wrap_layer(cls, name:str, host: Union[nn.Linear, nn.Conv2d], rank=1, dropout=0.0, alpha=1.0, svd_init=False,
                    bias=False, mask=None, **kwargs):# -> LoraBlock:
-        lora_block = cls(lora_id, layer, rank, dropout, alpha, bias=bias, **kwargs)
+        lora_block = cls(name, host, rank, dropout, alpha, bias=bias, **kwargs)
         lora_block.init_weights(svd_init)
         return lora_block

     @classmethod
-    def wrap_model(cls, lora_id:int, model: nn.Module, **kwargs):# -> Dict[str, LoraBlock]:
-        return super(LoraBlock, cls).wrap_model(lora_id, model, exclude_classes=(LoraBlock,), **kwargs)
-
-    @staticmethod
-    def extract_lora_state(model:nn.Module):
-        return {k:v for k,v in model.state_dict().items() if 'lora_block_' in k}
-
-    @staticmethod
-    def extract_state_without_lora(model:nn.Module):
-        return {k:v for k,v in model.state_dict().items() if 'lora_block_' not in k}
-
-    @staticmethod
-    def extract_param_without_lora(model:nn.Module):
-        return {k:v for k,v in model.named_parameters() if 'lora_block_' not in k}
-
-    @staticmethod
-    def extract_trainable_state_without_lora(model:nn.Module):
-        trainable_keys = {k for k,v in model.named_parameters() if ('lora_block_' not in k) and v.requires_grad}
-        return {k: v for k, v in model.state_dict().items() if k in trainable_keys}
+    def wrap_model(cls, name:str, host: nn.Module, **kwargs):# -> Dict[str, LoraBlock]:
+        return super().wrap_model(name, host, exclude_classes=(LoraBlock,), **kwargs)

 class LoraGroup(PluginGroup):
     def set_mask(self, batch_mask):
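
LoraBlock is now addressed by an explicit name instead of a generated lora_block_{id} prefix, the extract_* helpers move out of the class, and a new set_hyper_params lets alpha be changed after construction with the same rank-based auto scaling. A standalone sketch of that alpha handling (the AlphaBuffer class is illustrative, not hcpdiff API):

    import torch
    from torch import nn

    class AlphaBuffer(nn.Module):
        def __init__(self, rank=8, alpha=1.0, alpha_auto_scale=True):
            super().__init__()
            self.rank = rank
            self.alpha_auto_scale = alpha_auto_scale
            self.register_buffer('alpha', torch.tensor(alpha/rank if alpha_auto_scale else alpha))

        def set_hyper_params(self, alpha=None):
            if alpha is not None:
                # re-registering an existing buffer name simply overwrites it
                self.register_buffer('alpha', torch.tensor(alpha/self.rank if self.alpha_auto_scale else alpha))

    blk = AlphaBuffer(rank=8, alpha=8.0)  # alpha buffer starts at 8.0/8 = 1.0
    blk.set_hyper_params(alpha=4.0)       # becomes 4.0/8 = 0.5
    print(blk.alpha)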

hcpdiff/models/lora_layers.py
@@ -15,7 +15,7 @@ from einops import repeat, rearrange, einsum
 from torch import nn

 from .lora_base import LoraBlock
-from .layers import GroupLinear
+from rainbowneko.models.layers import GroupLinear
 import warnings

 class LoraLayer(LoraBlock):
@@ -59,8 +59,8 @@ class LoraLayerGroup(LoraBlock):
         def __init__(self, host, rank, bias, dropout, block):
             super().__init__(host, rank, bias, dropout, block)
             self.register_buffer('rank_groups', torch.tensor(block.rank_groups_raw, dtype=torch.int))
-            self.lora_down = GroupLinear(host.in_features*self.rank_groups, self.rank, groups=self.rank_groups, bias=False)
-            self.lora_up = GroupLinear(self.rank, host.out_features*self.rank_groups, groups=self.rank_groups, bias=bias)
+            self.lora_down = GroupLinear(host.in_features, self.rank//self.rank_groups, group=self.rank_groups, bias=False)
+            self.lora_up = GroupLinear(self.rank//self.rank_groups, host.out_features, group=self.rank_groups, bias=bias)

         def feed_svd(self, U, V, weight):
             self.lora_up.weight.data = rearrange(U, 'o (g ri) -> g ri o', g=self.rank_groups).to(device=weight.device, dtype=weight.dtype)
@@ -137,9 +137,3 @@ class LohaLayer(LoraBlock):
             w = torch.prod(einsum(self.W_up.data, self.W_down.data, 'g o r ..., g r i ... -> g o i ...'), dim=0)
             b = None
             return w, b
-
-lora_layer_map={
-    'lora': LoraLayer,
-    'loha_group': LoraLayerGroup,
-    'loha': LohaLayer,
-    }
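
As in the other model files, the generic building blocks (plugins, GroupLinear) now come from the rainbowneko dependency rather than hcpdiff's own removed modules. A hypothetical import migration for downstream code:

    # 0.9.1 (modules removed in 2.1):
    #   from hcpdiff.models.plugin import PatchPluginBlock
    #   from hcpdiff.models.layers import GroupLinear
    # 2.1:
    from rainbowneko.models.plugin import PatchPluginBlock
    from rainbowneko.models.layers import GroupLinear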

hcpdiff/models/lora_layers_patch.py
@@ -8,19 +8,18 @@ lora_layers.py
 :Licence: Apache-2.0
 """

+import math
+
 import torch
-from einops import einsum, rearrange
+from einops import einsum
 from torch import nn
 from torch.nn import functional as F

 from .lora_base_patch import LoraBlock, PatchPluginContainer
-from .layers import GroupLinear
-import math
-from typing import Union, List

 class LoraLayer(LoraBlock):
-    def __init__(self, lora_id: int, host, rank=1, dropout=0.1, alpha=1.0, bias=False, alpha_auto_scale=True, **kwargs):
-        super().__init__(lora_id, host, rank, dropout, alpha=alpha, bias=bias, alpha_auto_scale=alpha_auto_scale, **kwargs)
+    def __init__(self, name: str, host, rank=1, dropout=0.0, alpha=1.0, bias=False, alpha_auto_scale=True, **kwargs):
+        super().__init__(name, host, rank, dropout, alpha=alpha, bias=bias, alpha_auto_scale=alpha_auto_scale, **kwargs)

     class LinearLayer(LoraBlock.LinearLayer):
         def __init__(self, host:nn.Linear, rank, bias, block):
@@ -99,6 +98,11 @@ class LoraLayer(LoraBlock):
             b = self.bias.data if self.bias else None
             return w, b

+def none_add(a, b):
+    if a is None:
+        return b
+    return a+b
+
 class DAPPPatchContainer(PatchPluginContainer):
     def forward(self, x, *args, **kwargs):
         weight_p = None
@@ -107,25 +111,11 @@ class DAPPPatchContainer(PatchPluginContainer):
         bias_n = None
         for name in self.plugin_names:
             if self[name].branch=='p':
-                if weight_p is None:
-                    weight_p = self[name].get_weight()
-                else:
-                    weight_p = weight_p + self[name].get_weight()
-
-                if bias_p is None:
-                    bias_p = self[name].get_bias()
-                else:
-                    bias_p = bias_p+self[name].get_bias()
+                weight_p = none_add(weight_p, self[name].get_weight())
+                bias_p = none_add(bias_p, self[name].get_bias())
             elif self[name].branch=='n':
-                if weight_n is None:
-                    weight_n = self[name].get_weight()
-                else:
-                    weight_n = weight_n + self[name].get_weight()
-
-                if bias_n is None:
-                    bias_n = self[name].get_bias()
-                else:
-                    bias_n = bias_n+self[name].get_bias()
+                weight_n = none_add(weight_n, self[name].get_weight())
+                bias_n = none_add(bias_n, self[name].get_bias())

         B = x.shape[0]//2
         x_p = self[name].post_forward(x[B:], self._host.weight, weight_p, self._host.bias, bias_p)
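
The repeated None checks in DAPPPatchContainer.forward collapse into the small none_add helper shown above; a toy usage sketch of the accumulation pattern:

    def none_add(a, b):
        return b if a is None else a + b

    acc = None
    for delta in (3.0, 2.0):   # e.g. per-plugin get_weight() contributions
        acc = none_add(acc, delta)
    print(acc)  # 5.0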

hcpdiff/models/text_emb_ex.py
@@ -7,16 +7,17 @@ text_emb_ex.py
 :Created: 10/03/2023
 :Licence: Apache-2.0
 """
-from typing import Tuple
+from typing import Tuple, Dict, Any

 import torch
 from torch import nn
 import os
-from loguru import logger
+from rainbowneko import _share
 from einops import rearrange, repeat
+import torch.nn.functional as F

 from ..utils.net_utils import load_emb
-from .plugin import SinglePluginBlock
+from rainbowneko.models.plugin import SinglePluginBlock

 class EmbeddingPTHook(SinglePluginBlock):
     def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
@@ -37,6 +38,84 @@ class EmbeddingPTHook(SinglePluginBlock):
         self.input_ids = rearrange(input_ids[0], '(b r) w -> b (r w)', r=self.N_repeats)  # compatible with attention mask
         return self.input_ids.clip(0, self.num_embeddings-1)

+    def forward(self, inputs_embeds:torch.Tensor, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]):
+        '''
+        :param input_ids: [B, N_ids]
+        :param inputs_embeds: [B, N_repeat*(N_word+2), N_emb]
+        :return: [B, N_repeat, N_word+2, N_emb]
+        '''
+        rep_idxs_B = self.input_ids >= self.num_embeddings
+        BOS = repeat(inputs_embeds[:,0,:], 'b e -> b r 1 e', r=self.N_repeats)
+        EOS = repeat(inputs_embeds[:,-1,:], 'b e -> b r 1 e', r=self.N_repeats)
+
+        replaced_embeds = []
+        for i, (item, rep_idxs, ids_raw) in enumerate(zip(inputs_embeds, rep_idxs_B, self.input_ids)):
+            # insert pt to embeddings
+            rep_idxs=torch.where(rep_idxs)[0]
+            item_new=[]
+            rep_idx_last=0
+            for rep_idx in rep_idxs:
+                rep_idx=rep_idx.item()
+                item_new.append(item[rep_idx_last:rep_idx, :])
+                item_new.append(self.emb[ids_raw[rep_idx].item()].to(dtype=item.dtype))
+                rep_idx_last=rep_idx+1
+            item_new.append(item[rep_idx_last:, :])
+
+            # split to N_repeat sentence
+            replaced_item = torch.cat(item_new, dim=0)[1:self.N_word*self.N_repeats+1, :]
+            replaced_item = rearrange(replaced_item, '(r w) e -> r w e', r=self.N_repeats, w=self.N_word)
+            replaced_item = torch.cat([BOS[i], replaced_item, EOS[i]], dim=1)  # [N_repeat, N_word+2, N_emb]
+
+            replaced_embeds.append(replaced_item)
+        return torch.cat(replaced_embeds, dim=0)  # [B*N_repeat, N_word+2, N_emb]
+
+    def remove(self):
+        super(EmbeddingPTHook, self).remove()
+        self.handle_pre.remove()
+
+    @classmethod
+    def hook(cls, ex_words_emb, tokenizer, text_encoder, **kwargs):
+        word_list = list(ex_words_emb.keys())
+        tokenizer.add_tokens(word_list)
+        token_ids = tokenizer(' '.join(word_list)).input_ids[1:-1]
+
+        embedding_hook = cls(text_encoder.get_input_embeddings(), N_word=tokenizer.model_max_length-2, **kwargs)
+        #text_encoder.text_model.embeddings.token_embedding = embedding_hook
+        for tid, word in zip(token_ids, word_list):
+            embedding_hook.add_emb(ex_words_emb[word], tid)
+            _share.loggers.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
+        return embedding_hook
+
+    @classmethod
+    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs):
+        ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
+                        for file in os.listdir(emb_dir) if file.endswith('.pt')}
+        return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
+
+class EmbeddingPTInterpHook(SinglePluginBlock):
+    def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
+        super().__init__('emb_ex', token_embedding)
+        self.handle_pre = token_embedding.register_forward_pre_hook(self.pre_hook)
+
+        new_len = int(token_embedding.num_embeddings*N_repeats)
+        original_weights = token_embedding.weight.data.unsqueeze(1)
+        token_embedding.weight.data = F.interpolate(original_weights, size=new_len, mode='linear', align_corners=False).squeeze(1)
+        token_embedding.num_embeddings = new_len
+
+        self.N_word=N_word
+        self.N_repeats=N_repeats
+        self.num_embeddings=token_embedding.num_embeddings
+        self.embedding_dim=token_embedding.embedding_dim
+        self.emb={}
+        self.emb_train=nn.ParameterList()
+
+    def add_emb(self, emb:nn.Parameter, token_id:int):
+        self.emb[token_id]=emb
+
+    def pre_hook(self, host, input_ids: Tuple[torch.Tensor]):
+        self.input_ids = rearrange(input_ids[0], '(b r) w -> b (r w)', r=self.N_repeats)  # compatible with attention mask
+        return self.input_ids.clip(0, self.num_embeddings-1)
+
     def forward(self, fea_in:Tuple[torch.Tensor], inputs_embeds:torch.Tensor):
         '''
         :param input_ids: [B, N_ids]
@@ -83,12 +162,11 @@ class EmbeddingPTHook(SinglePluginBlock):
         for tid, word in zip(token_ids, word_list):
             embedding_hook.add_emb(ex_words_emb[word], tid)
             if log:
-                logger.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
+                _share.logger.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
         return embedding_hook

     @classmethod
     def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, log=True, device='cuda:0', **kwargs):
         ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
                         for file in os.listdir(emb_dir) if file.endswith('.pt')}
-        return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
-
+        return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
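
A hedged usage sketch of the 2.1 embedding hook: the log flag is gone and messages go through rainbowneko's shared logger, so a rainbowneko logging context is assumed to be set up by the trainer or workflow; 'embs/' is a placeholder directory of *.pt prompt-tuning embeddings and the checkpoint id is only an example:

    from transformers import CLIPTokenizer, CLIPTextModel
    from hcpdiff.models.text_emb_ex import EmbeddingPTHook

    tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='text_encoder')

    hook, ex_words_emb = EmbeddingPTHook.hook_from_dir('embs/', tokenizer, text_encoder,
                                                       device='cpu', N_repeats=1)
    # ... encode prompts that use the hooked words ...
    hook.remove()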