hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +114 -125
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +78 -46
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/models/cfg_context.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 from einops import repeat
 import math
+from typing import Union, Callable

 class CFGContext:
     def pre(self, noisy_latents, timesteps):
@@ -10,9 +11,11 @@ class CFGContext:
         return model_pred

 class DreamArtistPTContext(CFGContext):
-    def __init__(self,
-        self.
-        self.
+    def __init__(self, cfg_low: float, cfg_high: float=None, cfg_func: Union[str, Callable]=None, num_train_timesteps=1000):
+        self.cfg_low = cfg_low
+        self.cfg_high = cfg_high or cfg_low
+        self.cfg_func = cfg_func
+        self.num_train_timesteps = num_train_timesteps

     def pre(self, noisy_latents, timesteps):
         self.t_raw = timesteps
@@ -22,18 +25,18 @@ class DreamArtistPTContext(CFGContext):

     def post(self, model_pred):
         e_t_uncond, e_t = model_pred.chunk(2)
-        if self.
-            rate = self.t_raw
-            if self.
-                rate = torch.cos((rate
-            elif self.
-                rate = 1
-            elif self.
+        if self.cfg_low != self.cfg_high:
+            rate = self.t_raw/(self.num_train_timesteps-1)
+            if self.cfg_func == 'cos':
+                rate = torch.cos((rate-1)*math.pi/2)
+            elif self.cfg_func == 'cos2':
+                rate = 1-torch.cos(rate*math.pi/2)
+            elif self.cfg_func == 'ln':
                 pass
             else:
-                rate =
-            rate = rate.view(-1,1,1,1)
+                rate = self.cfg_func(rate)
+            rate = rate.view(-1, 1, 1, 1)
         else:
             rate = 1
-        model_pred = e_t_uncond
-        return model_pred
+        model_pred = e_t_uncond+((self.cfg_high-self.cfg_low)*rate+self.cfg_low)*(e_t-e_t_uncond)
+        return model_pred
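For reference, the rewritten post() interpolates the guidance scale from cfg_low to cfg_high as the normalized timestep grows. A minimal standalone sketch of that schedule in plain PyTorch (written here for illustration; the cfg_low/cfg_high defaults are arbitrary example values, not package defaults):

    import math
    import torch

    def guidance_rate(timesteps: torch.Tensor, num_train_timesteps: int = 1000, cfg_func='cos') -> torch.Tensor:
        # Normalize the timestep to [0, 1], as DreamArtistPTContext.post does.
        rate = timesteps / (num_train_timesteps - 1)
        if cfg_func == 'cos':
            rate = torch.cos((rate - 1) * math.pi / 2)
        elif cfg_func == 'cos2':
            rate = 1 - torch.cos(rate * math.pi / 2)
        elif cfg_func == 'ln':
            pass  # keep the linear ramp
        elif callable(cfg_func):
            rate = cfg_func(rate)
        return rate.view(-1, 1, 1, 1)

    def dream_artist_guidance(e_t_uncond, e_t, rate, cfg_low=3.0, cfg_high=7.0):
        # Same interpolation as the new post(): the scale slides from cfg_low to cfg_high with rate.
        return e_t_uncond + ((cfg_high - cfg_low) * rate + cfg_low) * (e_t - e_t_uncond)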
hcpdiff/models/compose/compose_hook.py
CHANGED
@@ -38,42 +38,42 @@ class ComposeEmbPTHook(nn.Module):
         hook.remove()

     @classmethod
-    def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder,
+    def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder, **kwargs):
         if isinstance(text_encoder, ComposeTextEncoder):
             hook_list = []

             emb_len = 0
-            for i,
+            for i, name in enumerate(tokenizer.tokenizer_names):
                 text_encoder_i = getattr(text_encoder, name)
-
-                logger.info(f'compose hook: {name}')
+                tokenizer_i = getattr(tokenizer, name)
                 embedding_dim = text_encoder_i.get_input_embeddings().embedding_dim
                 ex_words_emb_i = {k:v[i] for k, v in ex_words_emb.items()}
                 emb_len += embedding_dim
-                hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i,
+                hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, **kwargs)))

             return cls(hook_list)
         else:
-            return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder,
+            return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder, **kwargs)

     @classmethod
-    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder,
+    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs) -> Union[
         Tuple['ComposeEmbPTHook', Dict], Tuple[EmbeddingPTHook, Dict]]:
         if isinstance(text_encoder, ComposeTextEncoder):
             # multi text encoder
-            #ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}
+            # ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}

             # slice of nn.Parameter cannot return grad. Split the tensor
             ex_words_emb = {}
-
-
-
-
-
-
-
+            if emb_dir is not None and os.path.exists(emb_dir):
+                emb_dims = [x.embedding_dim for x in text_encoder.get_input_embeddings()]
+                for file in os.listdir(emb_dir):
+                    if file.endswith('.pt'):
+                        emb = load_emb(os.path.join(emb_dir, file)).to(device)
+                        emb = ParameterGroup([nn.Parameter(item, requires_grad=False) for item in emb.split(emb_dims, dim=1)])
+                        ex_words_emb[file[:-3]] = emb
+            return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
         else:
-            return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder,
+            return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder, **kwargs)

 class ComposeTEEXHook:
     def __init__(self, tehook_list: List[Tuple[str, TEEXHook]], cat_dim=-1):
@@ -98,10 +98,28 @@ class ComposeTEEXHook:
         for name, tehook in self.tehook_list:
             tehook.clip_skip = value

+    @property
+    def clip_final_norm(self):
+        return self.tehook_list[0][1].clip_final_norm
+
+    @clip_final_norm.setter
+    def clip_final_norm(self, value: bool):
+        for name, tehook in self.tehook_list:
+            tehook.clip_final_norm = value
+
+    @property
+    def use_attention_mask(self):
+        return self.tehook_list[0][1].use_attention_mask
+
+    @use_attention_mask.setter
+    def use_attention_mask(self, value: bool):
+        for name, tehook in self.tehook_list:
+            tehook.use_attention_mask = value
+
     def encode_prompt_to_emb(self, prompt):
         emb_list = [tehook.encode_prompt_to_emb(prompt) for name, tehook in self.tehook_list]
-        encoder_hidden_states, pooled_output = list(zip(*emb_list))
-        return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output
+        encoder_hidden_states, pooled_output, attention_mask = list(zip(*emb_list))
+        return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output, attention_mask[0]

     def enable_xformers(self):
         for name, tehook in self.tehook_list:
@@ -112,16 +130,19 @@ class ComposeTEEXHook:
         return TEEXHook.mult_attn(prompt_embeds, attn_mult)

     @classmethod
-    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True,
+    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False) -> Union[
+        'ComposeTEEXHook', TEEXHook]:
         if isinstance(text_enc, ComposeTextEncoder):
             # multi text encoder
-            tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name),
-
+            tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name), getattr(tokenizer, name), N_repeats, clip_skip, clip_final_norm,
+                                                use_attention_mask=use_attention_mask))
+                           for name in tokenizer.tokenizer_names]
             return cls(tehook_list)
         else:
             # single text encoder
-            return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip,
+            return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip, clip_final_norm, use_attention_mask=use_attention_mask)

     @classmethod
     def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
-        return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats,
+        return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                        use_attention_mask=use_attention_mask)
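The compose hooks now thread **kwargs end to end, encode_prompt_to_emb additionally returns the attention mask, and clip_final_norm / use_attention_mask become pass-through properties over every wrapped TEEXHook. A hedged usage sketch, assuming a diffusers pipeline `pipe` that carries the composed text encoder and an import path that is my guess rather than documented here:

    from hcpdiff.models.compose import ComposeTEEXHook  # import path assumed

    # Wrap all text encoders of the pipeline behind one hook object.
    te_hook = ComposeTEEXHook.hook_pipe(pipe, N_repeats=3, clip_skip=0,
                                        clip_final_norm=True, use_attention_mask=False)
    # The new setters fan the flags out to every sub-hook in tehook_list.
    te_hook.clip_final_norm = False
    te_hook.use_attention_mask = True
    # Since this change a third value, the attention mask, is returned as well.
    emb, pooled, attn_mask = te_hook.encode_prompt_to_emb('a photo of a cat')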
hcpdiff/models/compose/compose_tokenizer.py
CHANGED
@@ -18,14 +18,19 @@ from transformers.tokenization_utils_base import BatchEncoding
 class ComposeTokenizer(PreTrainedTokenizer):
     def __init__(self, tokenizer_list: List[Tuple[str, CLIPTokenizer]], cat_dim=-1):
         self.cat_dim = cat_dim
-
+
+        self.tokenizer_names = []
+        for name, tokenizer in tokenizer_list:
+            setattr(self, name, tokenizer)
+            self.tokenizer_names.append(name)
+
         super().__init__()

-        self.model_max_length =
+        self.model_max_length = torch.tensor([tokenizer.model_max_length for name, tokenizer in tokenizer_list])

     @property
     def first_tokenizer(self):
-        return self.
+        return getattr(self, self.tokenizer_names[0])

     @property
     def vocab_size(self):
@@ -40,18 +45,26 @@ class ComposeTokenizer(PreTrainedTokenizer):
         return self.first_tokenizer.bos_token_id

     def get_vocab(self):
-        return
+        return self.first_tokenizer.get_vocab()

     def tokenize(self, text, **kwargs) -> List[str]:
         return self.first_tokenizer.tokenize(text, **kwargs)

     def add_tokens( self, new_tokens, special_tokens: bool = False) -> List[int]:
-        return [
+        return [getattr(self, name).add_tokens(new_tokens, special_tokens) for name in self.tokenizer_names]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix = None) -> Tuple[str]:
+        return self.first_tokenizer.save_vocabulary(save_directory, filename_prefix)
+
+    def __call__(self, text, *args, max_length=None, **kwargs):
+        if isinstance(max_length, torch.Tensor):
+            token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length_i, **kwargs)
+                                               for name, max_length_i in zip(self.tokenizer_names, max_length)]
+        else:
+            token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length, **kwargs) for name in self.tokenizer_names]

-    def __call__(self, text, *args, **kwargs):
-        token_list: List[BatchEncoding] = [tokenizer(text, *args, **kwargs) for name, tokenizer in self.tokenizer_list]
         input_ids = torch.cat([token.input_ids for token in token_list], dim=-1) # [N_tokenizer, N_token]
-        attention_mask = [token.attention_mask for token in token_list]
+        attention_mask = torch.cat([token.attention_mask for token in token_list], dim=-1)
         return BatchEncoding({'input_ids':input_ids, 'attention_mask':attention_mask})

     @classmethod
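Since model_max_length is now a tensor with one entry per sub-tokenizer, __call__ dispatches a matching max_length to each of them and falls back to broadcasting anything else. A small illustrative re-implementation of just that dispatch rule (not the package API itself):

    import torch

    def split_max_length(tokenizer_names, max_length):
        # Tensor input: one max_length per sub-tokenizer, in tokenizer_names order.
        if isinstance(max_length, torch.Tensor):
            return dict(zip(tokenizer_names, (int(m) for m in max_length)))
        # Anything else (an int or None) is shared by every sub-tokenizer.
        return {name: max_length for name in tokenizer_names}

    print(split_max_length(['clip_L', 'clip_bigG'], torch.tensor([77, 77])))
    # {'clip_L': 77, 'clip_bigG': 77}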
hcpdiff/models/compose/sdxl_composer.py
CHANGED
@@ -27,13 +27,13 @@ class CLIPTextModelWithProjection_Align(CLIPTextModelWithProjection):
 class SDXLTextEncoder(ComposeTextEncoder):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
-
+        clip_L = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
         clip_bigG = CLIPTextModelWithProjection_Align.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder_2', **kwargs)
-        return cls([('
+        return cls([('clip_L', clip_L), ('clip_bigG', clip_bigG)])

 class SDXLTokenizer(ComposeTokenizer):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
-
+        clip_L = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
         clip_bigG = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer_2', **kwargs)
-        return cls([('
+        return cls([('clip_L', clip_L), ('clip_bigG', clip_bigG)])
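With the two encoders registered under the names clip_L and clip_bigG, loading the SDXL pair could look roughly like this (the import path and the checkpoint id are assumptions for the sketch, not taken from this diff):

    from hcpdiff.models.compose import SDXLTextEncoder, SDXLTokenizer  # import path assumed

    repo = 'stabilityai/stable-diffusion-xl-base-1.0'      # any diffusers-layout SDXL checkpoint
    text_encoder = SDXLTextEncoder.from_pretrained(repo)   # wraps text_encoder + text_encoder_2
    tokenizer = SDXLTokenizer.from_pretrained(repo)        # wraps tokenizer + tokenizer_2
    print(tokenizer.tokenizer_names)  # ['clip_L', 'clip_bigG']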
hcpdiff/models/controlnet.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from copy import deepcopy

-from .plugin import MultiPluginBlock, BasePluginBlock
+from rainbowneko.models.plugin import MultiPluginBlock, BasePluginBlock
 from hcpdiff.utils.net_utils import remove_all_hooks, remove_layers

 class ControlNetPlugin(MultiPluginBlock):
@@ -55,25 +55,25 @@ class ControlNetPlugin(MultiPluginBlock):
         self.cond_head = nn.Sequential(*cond_head)

     def reset_parameters(self) -> None:
-        def
-
-
-        self.controlnet_down_blocks.apply(
-        self.controlnet_mid_block.apply(
-        self.cond_head[-1].apply(
-
-    def from_layer_hook(self, host,
+        def zero_weight_init(m):
+            for p in m.parameters():
+                p.detach().zero_()
+        self.controlnet_down_blocks.apply(zero_weight_init)
+        self.controlnet_mid_block.apply(zero_weight_init)
+        self.cond_head[-1].apply(zero_weight_init)
+
+    def from_layer_hook(self, host, idx: int, args: Tuple[Any, ...], kwargs: Dict[str, Any], fea_out: Any=None):
         if idx==0:
-            self.data_input =
+            self.data_input = (args, kwargs)
         elif idx==1:
-            self.feat_to = self(*self.data_input)
+            self.feat_to = self(*self.data_input[0], **self.data_input[1])

-    def to_layer_hook(self, host,
+    def to_layer_hook(self, host, idx: int, args: Tuple[Any, ...], kwargs: Dict[str, Any], fea_out: Any=None):
         if idx == 5:
-            sp =
-            new_feat =
-            new_feat[:, sp:, ...] =
-            return (new_feat,
+            sp = args[0].shape[1]//2
+            new_feat = args[0].clone()
+            new_feat[:, sp:, ...] = args[0][:, sp:, ...] + self.feat_to[0]
+            return (new_feat, args[1])
         elif idx == 3:
             return (fea_out[0], tuple(fea_out[1][i] + self.feat_to[(idx) * 3 + i+1] for i in range(2)))
         elif idx == 4:
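reset_parameters now zero-initializes the ControlNet projection layers through a local helper, so the control branch contributes nothing until it is trained. The pattern in isolation (a standalone sketch, not the plugin itself):

    import torch
    from torch import nn

    def zero_weight_init(m: nn.Module):
        # Zero every parameter in place, as ControlNetPlugin.reset_parameters does.
        for p in m.parameters():
            p.detach().zero_()

    proj = nn.Conv2d(320, 320, kernel_size=1)
    proj.apply(zero_weight_init)
    print(proj.weight.abs().sum().item(), proj.bias.abs().sum().item())  # 0.0 0.0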
hcpdiff/models/lora_base_patch.py
CHANGED
@@ -13,7 +13,7 @@ from torch import nn
 from torch.nn import functional as F

 from hcpdiff.utils.utils import make_mask, low_rank_approximate, isinstance_list
-from .plugin import PatchPluginBlock, PluginGroup, PatchPluginContainer
+from rainbowneko.models.plugin import PatchPluginBlock, PluginGroup, PatchPluginContainer

 from typing import Union, Tuple, Dict, Type

@@ -38,9 +38,9 @@ class LoraBlock(PatchPluginBlock):
     container_cls = LoraPatchContainer
     wrapable_classes = (nn.Linear, nn.Conv2d)

-    def __init__(self,
+    def __init__(self, name:int, host:Union[nn.Linear, nn.Conv2d], rank, dropout=0.1, alpha=1.0, bias=False,
                  alpha_auto_scale=True, parent_block=None, host_name=None, **kwargs):
-        super().__init__(
+        super().__init__(name, host, parent_block=parent_block, host_name=host_name)

         self.bias=bias

@@ -56,8 +56,14 @@ class LoraBlock(PatchPluginBlock):
         self.dropout = nn.Dropout(dropout)

         self.rank = self.layer.rank
+        self.alpha_auto_scale = alpha_auto_scale
         self.register_buffer('alpha', torch.tensor(alpha/self.rank if alpha_auto_scale else alpha))

+    def set_hyper_params(self, alpha=None, **kwargs):
+        if alpha is not None:
+            self.register_buffer('alpha', torch.tensor(alpha/self.rank if self.alpha_auto_scale else alpha))
+        super().set_hyper_params(**kwargs)
+
     def get_weight(self):
         return self.layer.get_weight() * self.alpha

@@ -91,7 +97,7 @@ class LoraBlock(PatchPluginBlock):
                 host.weight.data * base_alpha + alpha * re_w.to(host.weight.device, dtype=host.weight.dtype)
             )

-            if
+            if re_b is not None:
                 if host.bias is None:
                     host.bias = nn.Parameter(re_b.to(host.weight.device, dtype=host.weight.dtype))
                 else:
@@ -145,32 +151,15 @@ class LoraBlock(PatchPluginBlock):
         pass

     @classmethod
-    def wrap_layer(cls,
+    def wrap_layer(cls, name:str, host: Union[nn.Linear, nn.Conv2d], rank=1, dropout=0.0, alpha=1.0, svd_init=False,
                    bias=False, mask=None, **kwargs):# -> LoraBlock:
-        lora_block = cls(
+        lora_block = cls(name, host, rank, dropout, alpha, bias=bias, **kwargs)
         lora_block.init_weights(svd_init)
         return lora_block

     @classmethod
-    def wrap_model(cls,
-        return super(
-
-    @staticmethod
-    def extract_lora_state(model:nn.Module):
-        return {k:v for k,v in model.state_dict().items() if 'lora_block_' in k}
-
-    @staticmethod
-    def extract_state_without_lora(model:nn.Module):
-        return {k:v for k,v in model.state_dict().items() if 'lora_block_' not in k}
-
-    @staticmethod
-    def extract_param_without_lora(model:nn.Module):
-        return {k:v for k,v in model.named_parameters() if 'lora_block_' not in k}
-
-    @staticmethod
-    def extract_trainable_state_without_lora(model:nn.Module):
-        trainable_keys = {k for k,v in model.named_parameters() if ('lora_block_' not in k) and v.requires_grad}
-        return {k: v for k, v in model.state_dict().items() if k in trainable_keys}
+    def wrap_model(cls, name:str, host: nn.Module, **kwargs):# -> Dict[str, LoraBlock]:
+        return super().wrap_model(name, host, exclude_classes=(LoraBlock,), **kwargs)

 class LoraGroup(PluginGroup):
     def set_mask(self, batch_mask):
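LoraBlock now keeps alpha_auto_scale around so the new set_hyper_params can re-register the alpha buffer later with the same rank scaling; the effective weight delta returned by get_weight() is multiplied by this buffer. The scaling rule on its own (an illustrative sketch):

    import torch

    def lora_alpha_buffer(alpha: float, rank: int, alpha_auto_scale: bool = True) -> torch.Tensor:
        # alpha is divided by the rank when auto-scaling, matching LoraBlock.__init__ and set_hyper_params.
        return torch.tensor(alpha / rank if alpha_auto_scale else alpha)

    print(lora_alpha_buffer(1.0, rank=8))                          # tensor(0.1250)
    print(lora_alpha_buffer(1.0, rank=8, alpha_auto_scale=False))  # tensor(1.)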
hcpdiff/models/lora_layers.py
CHANGED
@@ -15,7 +15,7 @@ from einops import repeat, rearrange, einsum
 from torch import nn

 from .lora_base import LoraBlock
-from .layers import GroupLinear
+from rainbowneko.models.layers import GroupLinear
 import warnings

 class LoraLayer(LoraBlock):
@@ -59,8 +59,8 @@ class LoraLayerGroup(LoraBlock):
         def __init__(self, host, rank, bias, dropout, block):
             super().__init__(host, rank, bias, dropout, block)
             self.register_buffer('rank_groups', torch.tensor(block.rank_groups_raw, dtype=torch.int))
-            self.lora_down = GroupLinear(host.in_features
-            self.lora_up = GroupLinear(self.rank, host.out_features
+            self.lora_down = GroupLinear(host.in_features, self.rank//self.rank_groups, group=self.rank_groups, bias=False)
+            self.lora_up = GroupLinear(self.rank//self.rank_groups, host.out_features, group=self.rank_groups, bias=bias)

         def feed_svd(self, U, V, weight):
             self.lora_up.weight.data = rearrange(U, 'o (g ri) -> g ri o', g=self.rank_groups).to(device=weight.device, dtype=weight.dtype)
@@ -137,9 +137,3 @@ class LohaLayer(LoraBlock):
         w = torch.prod(einsum(self.W_up.data, self.W_down.data, 'g o r ..., g r i ... -> g o i ...'), dim=0)
         b = None
         return w, b
-
-lora_layer_map={
-    'lora': LoraLayer,
-    'loha_group': LoraLayerGroup,
-    'loha': LohaLayer,
-}
hcpdiff/models/lora_layers_patch.py
CHANGED
@@ -8,19 +8,18 @@ lora_layers.py
 :Licence: Apache-2.0
 """

+import math
+
 import torch
-from einops import einsum
+from einops import einsum
 from torch import nn
 from torch.nn import functional as F

 from .lora_base_patch import LoraBlock, PatchPluginContainer
-from .layers import GroupLinear
-import math
-from typing import Union, List

 class LoraLayer(LoraBlock):
-    def __init__(self,
-        super().__init__(
+    def __init__(self, name: str, host, rank=1, dropout=0.0, alpha=1.0, bias=False, alpha_auto_scale=True, **kwargs):
+        super().__init__(name, host, rank, dropout, alpha=alpha, bias=bias, alpha_auto_scale=alpha_auto_scale, **kwargs)

     class LinearLayer(LoraBlock.LinearLayer):
         def __init__(self, host:nn.Linear, rank, bias, block):
@@ -99,6 +98,11 @@ class LoraLayer(LoraBlock):
             b = self.bias.data if self.bias else None
             return w, b

+def none_add(a, b):
+    if a is None:
+        return b
+    return a+b
+
 class DAPPPatchContainer(PatchPluginContainer):
     def forward(self, x, *args, **kwargs):
         weight_p = None
@@ -107,25 +111,11 @@ class DAPPPatchContainer(PatchPluginContainer):
         bias_n = None
         for name in self.plugin_names:
             if self[name].branch=='p':
-
-
-                else:
-                    weight_p = weight_p + self[name].get_weight()
-
-                if bias_p is None:
-                    bias_p = self[name].get_bias()
-                else:
-                    bias_p = bias_p+self[name].get_bias()
+                weight_p = none_add(weight_p, self[name].get_weight())
+                bias_p = none_add(bias_p, self[name].get_bias())
             elif self[name].branch=='n':
-
-
-                else:
-                    weight_n = weight_n + self[name].get_weight()
-
-                if bias_n is None:
-                    bias_n = self[name].get_bias()
-                else:
-                    bias_n = bias_n+self[name].get_bias()
+                weight_n = none_add(weight_n, self[name].get_weight())
+                bias_n = none_add(bias_n, self[name].get_bias())

         B = x.shape[0]//2
         x_p = self[name].post_forward(x[B:], self._host.weight, weight_p, self._host.bias, bias_p)
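The repeated None checks in DAPPPatchContainer.forward collapse into the new module-level none_add helper, which starts an accumulator on first use. In isolation:

    import torch

    def none_add(a, b):
        # Accumulate b into a, treating a None accumulator as empty (as in the patched container).
        if a is None:
            return b
        return a + b

    weight_p = None
    for delta in (torch.ones(2, 2), 2 * torch.ones(2, 2)):
        weight_p = none_add(weight_p, delta)
    print(weight_p)  # every entry is 3.0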
hcpdiff/models/text_emb_ex.py
CHANGED
@@ -7,16 +7,17 @@ text_emb_ex.py
 :Created: 10/03/2023
 :Licence: Apache-2.0
 """
-from typing import Tuple
+from typing import Tuple, Dict, Any

 import torch
 from torch import nn
 import os
-from
+from rainbowneko import _share
 from einops import rearrange, repeat
+import torch.nn.functional as F

 from ..utils.net_utils import load_emb
-from .plugin import SinglePluginBlock
+from rainbowneko.models.plugin import SinglePluginBlock

 class EmbeddingPTHook(SinglePluginBlock):
     def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
@@ -37,6 +38,84 @@ class EmbeddingPTHook(SinglePluginBlock):
         self.input_ids = rearrange(input_ids[0], '(b r) w -> b (r w)', r=self.N_repeats) # 兼容Attention mask
         return self.input_ids.clip(0, self.num_embeddings-1)

+    def forward(self, inputs_embeds:torch.Tensor, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]):
+        '''
+        :param input_ids: [B, N_ids]
+        :param inputs_embeds: [B, N_repeat*(N_word+2), N_emb]
+        :return: [B, N_repeat, N_word+2, N_emb]
+        '''
+        rep_idxs_B = self.input_ids >= self.num_embeddings
+        BOS = repeat(inputs_embeds[:,0,:], 'b e -> b r 1 e', r=self.N_repeats)
+        EOS = repeat(inputs_embeds[:,-1,:], 'b e -> b r 1 e', r=self.N_repeats)
+
+        replaced_embeds = []
+        for i, (item, rep_idxs, ids_raw) in enumerate(zip(inputs_embeds, rep_idxs_B, self.input_ids)):
+            # insert pt to embeddings
+            rep_idxs=torch.where(rep_idxs)[0]
+            item_new=[]
+            rep_idx_last=0
+            for rep_idx in rep_idxs:
+                rep_idx=rep_idx.item()
+                item_new.append(item[rep_idx_last:rep_idx, :])
+                item_new.append(self.emb[ids_raw[rep_idx].item()].to(dtype=item.dtype))
+                rep_idx_last=rep_idx+1
+            item_new.append(item[rep_idx_last:, :])
+
+            # split to N_repeat sentence
+            replaced_item = torch.cat(item_new, dim=0)[1:self.N_word*self.N_repeats+1, :]
+            replaced_item = rearrange(replaced_item, '(r w) e -> r w e', r=self.N_repeats, w=self.N_word)
+            replaced_item = torch.cat([BOS[i], replaced_item, EOS[i]], dim=1) # [N_repeat, N_word+2, N_emb]
+
+            replaced_embeds.append(replaced_item)
+        return torch.cat(replaced_embeds, dim=0) # [B*N_repeat, N_word+2, N_emb]
+
+    def remove(self):
+        super(EmbeddingPTHook, self).remove()
+        self.handle_pre.remove()
+
+    @classmethod
+    def hook(cls, ex_words_emb, tokenizer, text_encoder, **kwargs):
+        word_list = list(ex_words_emb.keys())
+        tokenizer.add_tokens(word_list)
+        token_ids = tokenizer(' '.join(word_list)).input_ids[1:-1]
+
+        embedding_hook = cls(text_encoder.get_input_embeddings(), N_word=tokenizer.model_max_length-2, **kwargs)
+        #text_encoder.text_model.embeddings.token_embedding = embedding_hook
+        for tid, word in zip(token_ids, word_list):
+            embedding_hook.add_emb(ex_words_emb[word], tid)
+            _share.loggers.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
+        return embedding_hook
+
+    @classmethod
+    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs):
+        ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
+                        for file in os.listdir(emb_dir) if file.endswith('.pt')}
+        return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
+
+class EmbeddingPTInterpHook(SinglePluginBlock):
+    def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
+        super().__init__('emb_ex', token_embedding)
+        self.handle_pre = token_embedding.register_forward_pre_hook(self.pre_hook)
+
+        new_len = int(token_embedding.num_embeddings*N_repeats)
+        original_weights = token_embedding.weight.data.unsqueeze(1)
+        token_embedding.weight.data = F.interpolate(original_weights, size=new_len, mode='linear', align_corners=False).squeeze(1)
+        token_embedding.num_embeddings = new_len
+
+        self.N_word=N_word
+        self.N_repeats=N_repeats
+        self.num_embeddings=token_embedding.num_embeddings
+        self.embedding_dim=token_embedding.embedding_dim
+        self.emb={}
+        self.emb_train=nn.ParameterList()
+
+    def add_emb(self, emb:nn.Parameter, token_id:int):
+        self.emb[token_id]=emb
+
+    def pre_hook(self, host, input_ids: Tuple[torch.Tensor]):
+        self.input_ids = rearrange(input_ids[0], '(b r) w -> b (r w)', r=self.N_repeats) # 兼容Attention mask
+        return self.input_ids.clip(0, self.num_embeddings-1)
+
     def forward(self, fea_in:Tuple[torch.Tensor], inputs_embeds:torch.Tensor):
         '''
         :param input_ids: [B, N_ids]
@@ -83,12 +162,11 @@ class EmbeddingPTHook(SinglePluginBlock):
         for tid, word in zip(token_ids, word_list):
             embedding_hook.add_emb(ex_words_emb[word], tid)
             if log:
-                logger.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
+                _share.logger.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
         return embedding_hook

     @classmethod
     def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, log=True, device='cuda:0', **kwargs):
         ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
                         for file in os.listdir(emb_dir) if file.endswith('.pt')}
-        return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
-
+        return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
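EmbeddingPTHook.hook registers each custom embedding word with the tokenizer and reads back the ids it was assigned before wiring them into the hook. That id-mapping step, shown on a plain CLIP tokenizer as a sketch (the word names are made up for the example):

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
    word_list = ['my-style-pt1', 'my-style-pt2']                 # hypothetical embedding names
    tokenizer.add_tokens(word_list)
    token_ids = tokenizer(' '.join(word_list)).input_ids[1:-1]   # drop BOS/EOS
    print(dict(zip(word_list, token_ids)))                       # each word maps to one freshly added id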
|