hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +90 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +3 -3
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text.py +40 -0
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +207 -0
- hcpdiff/easy/cfg/sdxl_train.py +147 -0
- hcpdiff/easy/cfg/t2i.py +228 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +118 -128
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +60 -47
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.2.dist-info/METADATA +299 -0
- hcpdiff-2.2.dist-info/RECORD +115 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
- hcpdiff-2.2.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/models/lora_layers_patch.py
CHANGED
@@ -8,19 +8,18 @@ lora_layers.py
 :Licence: Apache-2.0
 """
 
+import math
+
 import torch
-from einops import einsum
+from einops import einsum
 from torch import nn
 from torch.nn import functional as F
 
 from .lora_base_patch import LoraBlock, PatchPluginContainer
-from .layers import GroupLinear
-import math
-from typing import Union, List
 
 class LoraLayer(LoraBlock):
-    def __init__(self,
-        super().__init__(
+    def __init__(self, name: str, host, rank=1, dropout=0.0, alpha=1.0, bias=False, alpha_auto_scale=True, **kwargs):
+        super().__init__(name, host, rank, dropout, alpha=alpha, bias=bias, alpha_auto_scale=alpha_auto_scale, **kwargs)
 
     class LinearLayer(LoraBlock.LinearLayer):
        def __init__(self, host:nn.Linear, rank, bias, block):
@@ -99,6 +98,11 @@ class LoraLayer(LoraBlock):
            b = self.bias.data if self.bias else None
            return w, b
 
+def none_add(a, b):
+    if a is None:
+        return b
+    return a+b
+
 class DAPPPatchContainer(PatchPluginContainer):
     def forward(self, x, *args, **kwargs):
         weight_p = None
@@ -107,25 +111,11 @@ class DAPPPatchContainer(PatchPluginContainer):
         bias_n = None
         for name in self.plugin_names:
             if self[name].branch=='p':
-                if weight_p is None:
-                    weight_p = self[name].get_weight()
-                else:
-                    weight_p = weight_p + self[name].get_weight()
-
-                if bias_p is None:
-                    bias_p = self[name].get_bias()
-                else:
-                    bias_p = bias_p+self[name].get_bias()
+                weight_p = none_add(weight_p, self[name].get_weight())
+                bias_p = none_add(bias_p, self[name].get_bias())
             elif self[name].branch=='n':
-                if weight_n is None:
-                    weight_n = self[name].get_weight()
-                else:
-                    weight_n = weight_n + self[name].get_weight()
-
-                if bias_n is None:
-                    bias_n = self[name].get_bias()
-                else:
-                    bias_n = bias_n+self[name].get_bias()
+                weight_n = none_add(weight_n, self[name].get_weight())
+                bias_n = none_add(bias_n, self[name].get_bias())
 
         B = x.shape[0]//2
         x_p = self[name].post_forward(x[B:], self._host.weight, weight_p, self._host.bias, bias_p)
hcpdiff/models/text_emb_ex.py
CHANGED
@@ -7,16 +7,17 @@ text_emb_ex.py
 :Created: 10/03/2023
 :Licence: Apache-2.0
 """
-from typing import Tuple
+from typing import Tuple, Dict, Any
 
 import torch
 from torch import nn
 import os
-from
+from rainbowneko import _share
 from einops import rearrange, repeat
+import torch.nn.functional as F
 
 from ..utils.net_utils import load_emb
-from .plugin import SinglePluginBlock
+from rainbowneko.models.plugin import SinglePluginBlock
 
 class EmbeddingPTHook(SinglePluginBlock):
     def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
@@ -37,6 +38,84 @@ class EmbeddingPTHook(SinglePluginBlock):
         self.input_ids = rearrange(input_ids[0], '(b r) w -> b (r w)', r=self.N_repeats) # compatible with attention mask
         return self.input_ids.clip(0, self.num_embeddings-1)
 
+    def forward(self, inputs_embeds:torch.Tensor, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]):
+        '''
+        :param input_ids: [B, N_ids]
+        :param inputs_embeds: [B, N_repeat*(N_word+2), N_emb]
+        :return: [B, N_repeat, N_word+2, N_emb]
+        '''
+        rep_idxs_B = self.input_ids >= self.num_embeddings
+        BOS = repeat(inputs_embeds[:,0,:], 'b e -> b r 1 e', r=self.N_repeats)
+        EOS = repeat(inputs_embeds[:,-1,:], 'b e -> b r 1 e', r=self.N_repeats)
+
+        replaced_embeds = []
+        for i, (item, rep_idxs, ids_raw) in enumerate(zip(inputs_embeds, rep_idxs_B, self.input_ids)):
+            # insert pt to embeddings
+            rep_idxs=torch.where(rep_idxs)[0]
+            item_new=[]
+            rep_idx_last=0
+            for rep_idx in rep_idxs:
+                rep_idx=rep_idx.item()
+                item_new.append(item[rep_idx_last:rep_idx, :])
+                item_new.append(self.emb[ids_raw[rep_idx].item()].to(dtype=item.dtype))
+                rep_idx_last=rep_idx+1
+            item_new.append(item[rep_idx_last:, :])
+
+            # split to N_repeat sentence
+            replaced_item = torch.cat(item_new, dim=0)[1:self.N_word*self.N_repeats+1, :]
+            replaced_item = rearrange(replaced_item, '(r w) e -> r w e', r=self.N_repeats, w=self.N_word)
+            replaced_item = torch.cat([BOS[i], replaced_item, EOS[i]], dim=1) # [N_repeat, N_word+2, N_emb]
+
+            replaced_embeds.append(replaced_item)
+        return torch.cat(replaced_embeds, dim=0) # [B*N_repeat, N_word+2, N_emb]
+
+    def remove(self):
+        super(EmbeddingPTHook, self).remove()
+        self.handle_pre.remove()
+
+    @classmethod
+    def hook(cls, ex_words_emb, tokenizer, text_encoder, **kwargs):
+        word_list = list(ex_words_emb.keys())
+        tokenizer.add_tokens(word_list)
+        token_ids = tokenizer(' '.join(word_list)).input_ids[1:-1]
+
+        embedding_hook = cls(text_encoder.get_input_embeddings(), N_word=tokenizer.model_max_length-2, **kwargs)
+        #text_encoder.text_model.embeddings.token_embedding = embedding_hook
+        for tid, word in zip(token_ids, word_list):
+            embedding_hook.add_emb(ex_words_emb[word], tid)
+            _share.loggers.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
+        return embedding_hook
+
+    @classmethod
+    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs):
+        ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
+                        for file in os.listdir(emb_dir) if file.endswith('.pt')}
+        return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
+
+class EmbeddingPTInterpHook(SinglePluginBlock):
+    def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
+        super().__init__('emb_ex', token_embedding)
+        self.handle_pre = token_embedding.register_forward_pre_hook(self.pre_hook)
+
+        new_len = int(token_embedding.num_embeddings*N_repeats)
+        original_weights = token_embedding.weight.data.unsqueeze(1)
+        token_embedding.weight.data = F.interpolate(original_weights, size=new_len, mode='linear', align_corners=False).squeeze(1)
+        token_embedding.num_embeddings = new_len
+
+        self.N_word=N_word
+        self.N_repeats=N_repeats
+        self.num_embeddings=token_embedding.num_embeddings
+        self.embedding_dim=token_embedding.embedding_dim
+        self.emb={}
+        self.emb_train=nn.ParameterList()
+
+    def add_emb(self, emb:nn.Parameter, token_id:int):
+        self.emb[token_id]=emb
+
+    def pre_hook(self, host, input_ids: Tuple[torch.Tensor]):
+        self.input_ids = rearrange(input_ids[0], '(b r) w -> b (r w)', r=self.N_repeats) # compatible with attention mask
+        return self.input_ids.clip(0, self.num_embeddings-1)
+
     def forward(self, fea_in:Tuple[torch.Tensor], inputs_embeds:torch.Tensor):
         '''
         :param input_ids: [B, N_ids]
@@ -83,12 +162,11 @@ class EmbeddingPTHook(SinglePluginBlock):
         for tid, word in zip(token_ids, word_list):
             embedding_hook.add_emb(ex_words_emb[word], tid)
             if log:
-                logger.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
+                _share.logger.info(f'hook: {word}, len: {ex_words_emb[word].shape[0]}, id: {tid}')
         return embedding_hook
 
     @classmethod
     def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, log=True, device='cuda:0', **kwargs):
         ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
                         for file in os.listdir(emb_dir) if file.endswith('.pt')}
-        return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
-
+        return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
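
A hedged usage sketch for the reworked hook: the model id and embedding directory below are assumptions for illustration. hook_from_dir loads every *.pt file, registers the file names as new tokens, and resolves token ids beyond the original vocabulary from the loaded embeddings inside forward:

from transformers import CLIPTokenizer, CLIPTextModel
from hcpdiff.models.text_emb_ex import EmbeddingPTHook

base = 'runwayml/stable-diffusion-v1-5'   # assumed model id
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(base, subfolder='text_encoder')

# Returns the hook plus {word: embedding}; each embedding is an nn.Parameter
# loaded from '<emb_dir>/<word>.pt' via load_emb.
hook, ex_words_emb = EmbeddingPTHook.hook_from_dir('embs/', tokenizer, text_encoder, device='cpu')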
hcpdiff/models/textencoder_ex.py
CHANGED
@@ -8,29 +8,53 @@ textencoder_ex.py
 :Licence: Apache-2.0
 """
 
-from typing import Tuple, Optional
+from typing import Tuple, Optional
 
 import torch
 from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
+from loguru import logger
 from torch import nn
+from transformers import CLIPTextModelWithProjection, T5EncoderModel
 from transformers.models.clip.modeling_clip import CLIPAttention
 
 class TEEXHook:
-    def __init__(self, text_enc: nn.Module, tokenizer, N_repeats=
+    def __init__(self, text_enc: nn.Module, tokenizer, N_repeats=1, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
         self.text_enc = text_enc
         self.tokenizer = tokenizer
 
         self.N_repeats = N_repeats
         self.clip_skip = clip_skip
         self.clip_final_norm = clip_final_norm
-        self.device = device
-        self.attn_mult = None
         self.use_attention_mask = use_attention_mask
 
         text_enc.register_forward_hook(self.forward_hook)
         text_enc.register_forward_pre_hook(self.forward_hook_input)
 
+    def find_final_norm(self, text_enc: nn.Module):
+        for module in text_enc.modules():
+            if 'final_layer_norm' in module._modules:
+                logger.info(f'find final_layer_norm in {type(module)}')
+                return module.final_layer_norm
+
+        logger.info(f'final_layer_norm not found in {type(text_enc)}')
+        return None
+
+    @property
+    def clip_final_norm(self):
+        return self.final_layer_norm is not None
+
+    @clip_final_norm.setter
+    def clip_final_norm(self, value: bool):
+        if value:
+            self.final_layer_norm = self.find_final_norm(self.text_enc)
+        else:
+            self.final_layer_norm = None
+
+    @property
+    def device(self):
+        return self.text_enc.device
+
     def encode_prompt_to_emb(self, prompt):
         text_inputs = self.tokenizer(
             prompt,
@@ -50,12 +74,23 @@ class TEEXHook:
         if position_ids is not None:
             position_ids = position_ids.to(self.device)
 
-
-
-
-
-
-
+        # align with sd-webui
+        if isinstance(self.text_enc, CLIPTextModelWithProjection):
+            self.text_enc.text_projection.weight.data = self.text_enc.text_projection.weight.data.t()
+
+        if isinstance(self.text_enc, T5EncoderModel):
+            prompt_embeds, pooled_output = self.text_enc(
+                text_input_ids.to(self.device),
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+            )
+        else:
+            prompt_embeds, pooled_output = self.text_enc(
+                text_input_ids.to(self.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                output_hidden_states=True,
+            )
         return prompt_embeds, pooled_output, attention_mask
 
     def forward_hook_input(self, host, feat_in):
@@ -64,13 +99,12 @@ class TEEXHook:
 
     def forward_hook(self, host, feat_in: Tuple[torch.Tensor], feat_out):
         encoder_hidden_states = feat_out['hidden_states'][-self.clip_skip-1]
-        if self.clip_final_norm:
-            encoder_hidden_states = self.
+        if self.clip_final_norm and self.final_layer_norm is not None:
+            encoder_hidden_states = self.final_layer_norm(encoder_hidden_states)
         if self.text_enc.training and self.clip_skip>0:
             encoder_hidden_states = encoder_hidden_states+0*feat_out['last_hidden_state'].mean() # avoid unused parameters, make gradient checkpointing happy
-
         encoder_hidden_states = rearrange(encoder_hidden_states, '(b r) ... -> b r ...', r=self.N_repeats) # [B, N_repeat, N_word+2, N_emb]
-        pooled_output = feat_out.pooler_output
+        pooled_output = feat_out.get('pooler_output', feat_out.get('text_embeds', None))
         # TODO: may have better fusion method
         if pooled_output is not None:
             pooled_output = rearrange(pooled_output, '(b r) ... -> b r ...', r=self.N_repeats).mean(dim=1)
@@ -81,7 +115,7 @@ class TEEXHook:
         return encoder_hidden_states, pooled_output
 
     def pool_hidden_states(self, encoder_hidden_states, input_ids):
-        pooled_output = encoder_hidden_states[:, :, -1, :].mean(dim=1)
+        pooled_output = encoder_hidden_states[:, :, -1, :].mean(dim=1) # [B, N_emb]
         return pooled_output
 
     @staticmethod
@@ -147,9 +181,11 @@ class TEEXHook:
         layer.forward = forward
 
     @classmethod
-    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True,
-        return cls(text_enc, tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
+        return cls(text_enc, tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                   use_attention_mask=use_attention_mask)
 
     @classmethod
     def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
-        return cls(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats,
+        return cls(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                   use_attention_mask=use_attention_mask)
hcpdiff/models/wrapper/pixart.py
ADDED
@@ -0,0 +1,19 @@
+from .sd import SD15Wrapper
+from hcpdiff.utils import pad_attn_bias
+
+class PixArtWrapper(SD15Wrapper):
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, resolution=None, aspect_ratio=None,
+                         plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        added_cond_kwargs = {"resolution":resolution, "aspect_ratio":aspect_ratio}
+        model_pred = self.denoiser(x_t, encoder_hidden_states, timesteps, encoder_attention_mask=attn_mask,
+                                   added_cond_kwargs=added_cond_kwargs).sample # Predict the noise residual
+        return model_pred
hcpdiff/models/wrapper/sd.py
ADDED
@@ -0,0 +1,218 @@
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, Union
+
+import torch
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from rainbowneko.models.wrapper import BaseWrapper
+from torch import Tensor
+from torch import nn
+
+from hcpdiff.diffusion.sampler import BaseSampler
+from hcpdiff.models import TEEXHook
+from hcpdiff.models.compose import ComposeTEEXHook
+from hcpdiff.utils import pad_attn_bias
+from .utils import TEHookCFG, SD15_TEHookCFG, SDXL_TEHookCFG
+from ..cfg_context import CFGContext
+
+class SD15Wrapper(BaseWrapper):
+    def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 pred_type='eps', TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__()
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input'))
+        self.key_mapper_out = self.build_mapper(key_map_out, None, None)
+
+        self.denoiser = denoiser
+        self.TE = TE
+        self.vae = vae
+        self.noise_sampler = noise_sampler
+        self.tokenizer = tokenizer
+        self.min_attnmask = min_attnmask
+
+        self.pred_type = pred_type
+
+        self.TE_hook_cfg = TEHookCFG.create(TE_hook_cfg)
+        self.cfg_context = cfg_context
+        self.tokenizer.N_repeats = self.TE_hook_cfg.tokenizer_repeats
+
+    def post_init(self):
+        self.make_TE_hook(self.TE_hook_cfg)
+
+        self.vae_trainable = False
+        if self.vae is not None:
+            for p in self.vae.parameters():
+                if p.requires_grad:
+                    self.vae_trainable = True
+                    break
+
+        self.TE_trainable = False
+        for p in self.TE.parameters():
+            if p.requires_grad:
+                self.TE_trainable = True
+                break
+
+    def make_TE_hook(self, TE_hook_cfg):
+        # Hook and extend text_encoder
+        self.text_enc_hook = TEEXHook.hook(self.TE, self.tokenizer, N_repeats=TE_hook_cfg.tokenizer_repeats,
+                                           clip_skip=TE_hook_cfg.clip_skip, clip_final_norm=TE_hook_cfg.clip_final_norm)
+
+    def get_latents(self, image: Tensor):
+        if image.shape[1] == 3:
+            with torch.no_grad() if self.vae_trainable else nullcontext():
+                latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
+                latents = latents*self.vae.config.scaling_factor
+        else:
+            latents = image # Cached latents
+        return latents
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)[0]
+        return encoder_hidden_states
+
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample # Predict the noise residual
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
+                      plugin_input={}, **kwargs):
+        # input prepare
+        x_0 = self.get_latents(image)
+        x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+
+        if neg_prompt_ids:
+            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+        if neg_attn_mask:
+            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+        if neg_position_ids:
+            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+
+        # model forward
+        x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
+        encoder_hidden_states = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                                plugin_input=plugin_input, **kwargs)
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                           plugin_input=plugin_input, **kwargs)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
+                    noise_sampler=self.noise_sampler)
+
+    def forward(self, ds_name=None, **kwargs):
+        model_args, model_kwargs = self.get_map_data(self.key_mapper_in, kwargs, ds_name)
+        out = self.model_forward(*model_args, **model_kwargs)
+        return self.get_map_data(self.key_mapper_out, out, ds_name=ds_name)[1]
+
+    def enable_gradient_checkpointing(self):
+        def grad_ckpt_enable(m):
+            if getattr(m, 'gradient_checkpointing', False):
+                m.training = True
+
+        self.denoiser.enable_gradient_checkpointing()
+        if self.TE_trainable:
+            self.TE.gradient_checkpointing_enable()
+        self.apply(grad_ckpt_enable)
+
+    def enable_xformers(self):
+        self.denoiser.enable_xformers_memory_efficient_attention()
+
+    @property
+    def trainable_parameters(self):
+        return [p for p in self.parameters() if p.requires_grad]
+
+    @property
+    def trainable_models(self) -> Dict[str, nn.Module]:
+        return {'self':self}
+
+    def set_dtype(self, dtype, vae_dtype):
+        self.dtype = dtype
+        self.vae_dtype = vae_dtype
+        # Move vae and text_encoder to device and cast to weight_dtype
+        if self.vae is not None:
+            self.vae = self.vae.to(dtype=vae_dtype)
+        if not self.TE_trainable:
+            self.TE = self.TE.to(dtype=dtype)
+
+    @classmethod
+    def from_pretrained(cls, models: Union[partial, Dict[str, nn.Module]], **kwargs):
+        models = models() if isinstance(models, partial) else models
+        return cls(models['denoiser'], models['TE'], models['vae'], models['noise_sampler'], models['tokenizer'], **kwargs)
+
+class SDXLWrapper(SD15Wrapper):
+    def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 pred_type='eps', TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, pred_type, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input', 'coord -> crop_info'))
+
+    def make_TE_hook(self, TE_hook_cfg):
+        # Hook and extend text_encoder
+        self.text_enc_hook = ComposeTEEXHook.hook(self.TE, self.tokenizer, N_repeats=TE_hook_cfg.tokenizer_repeats,
+                                                  clip_skip=TE_hook_cfg.clip_skip, clip_final_norm=TE_hook_cfg.clip_final_norm)
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)
+        return encoder_hidden_states, pooled_output
+
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs, attn_mask=None, position_ids=None,
+                         plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask,
+                                   added_cond_kwargs=added_cond_kwargs).sample # Predict the noise residual
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
+                      crop_info=None, plugin_input={}):
+        # input prepare
+        x_0 = self.get_latents(image)
+        x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+
+        if neg_prompt_ids:
+            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+        if neg_attn_mask:
+            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+        if neg_position_ids:
+            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+
+        # model forward
+        x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
+        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                                               plugin_input=plugin_input)
+        added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs=added_cond_kwargs,
+                                           attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
+                    noise_sampler=self.noise_sampler)
hcpdiff/models/wrapper/utils.py
ADDED
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from rainbowneko.utils import is_dict
+
+class TEHookCFG:
+    def __init__(self, tokenizer_repeats: int = 1, clip_skip: int = 0, clip_final_norm: bool = True):
+        self.tokenizer_repeats = tokenizer_repeats
+        self.clip_skip = clip_skip
+        self.clip_final_norm = clip_final_norm
+
+    @classmethod
+    def create(cls, cfg):
+        if is_dict(cfg):
+            return cls(**cfg)
+        elif isinstance(cfg, cls):
+            return cfg
+        else:
+            raise ValueError(f'Invalid TEHookCFG type: {type(cfg)}')
+
+SD15_TEHookCFG = TEHookCFG()
+SDXL_TEHookCFG = TEHookCFG(clip_skip=1, clip_final_norm=False)
hcpdiff/parser/__init__.py
ADDED
@@ -0,0 +1 @@
+from .embpt import CfgEmbPTParser
hcpdiff/parser/embpt.py
ADDED
@@ -0,0 +1,32 @@
+from typing import Dict, Tuple, List
+from rainbowneko.utils import Path_Like
+from hcpdiff.models import EmbeddingPTHook
+from torch import Tensor
+
+class CfgEmbPTParser:
+    def __init__(self, emb_dir: Path_Like, cfg_pt: Dict[str, Dict], lr: float = 1e-5, weight_decay: float = 0):
+        self.emb_dir = emb_dir
+        self.cfg_pt = cfg_pt
+        self.lr = lr
+        self.weight_decay = weight_decay
+
+    def get_params_group(self, model) -> Tuple[List, Dict[str, Tensor]]:
+        self.embedding_hook, self.ex_words_emb = EmbeddingPTHook.hook_from_dir(
+            self.emb_dir, model.tokenizer, model.TE, N_repeats=model.tokenizer.N_repeats)
+        self.embedding_hook.requires_grad_(False)
+
+        train_params_emb = []
+        train_pts = {}
+        for pt_name, info in self.cfg_pt.items():
+            word_emb = self.ex_words_emb[pt_name]
+            train_pts[pt_name] = word_emb
+            word_emb.requires_grad = True
+            self.embedding_hook.emb_train.append(word_emb)
+            param_group = {'params':word_emb}
+            if 'lr' in info:
+                param_group['lr'] = info.lr
+            if 'weight_decay' in info:
+                param_group['weight_decay'] = info.weight_decay
+            train_params_emb.append(param_group)
+
+        return train_params_emb, train_pts
hcpdiff/tools/convert_caption_txt2json.py
CHANGED
@@ -2,7 +2,7 @@ import argparse
 import json
 import os
 
-from
+from rainbowneko.utils import types_support
 
 parser = argparse.ArgumentParser(description='Stable Diffusion Training')
 parser.add_argument('--data_root', type=str, default='')