hcpdiff 2.3.1__py3-none-any.whl → 2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/format/__init__.py +2 -2
- hcpdiff/ckpt_manager/format/diffusers.py +19 -4
- hcpdiff/ckpt_manager/format/emb.py +8 -3
- hcpdiff/ckpt_manager/format/lora_webui.py +1 -1
- hcpdiff/ckpt_manager/format/sd_single.py +28 -5
- hcpdiff/data/cache/vae.py +10 -2
- hcpdiff/data/handler/text.py +15 -14
- hcpdiff/diffusion/sampler/__init__.py +2 -1
- hcpdiff/diffusion/sampler/base.py +17 -6
- hcpdiff/diffusion/sampler/diffusers.py +4 -3
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +5 -14
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +7 -6
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +4 -4
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +3 -3
- hcpdiff/diffusion/sampler/timer/__init__.py +2 -0
- hcpdiff/diffusion/sampler/timer/base.py +26 -0
- hcpdiff/diffusion/sampler/timer/shift.py +49 -0
- hcpdiff/easy/__init__.py +2 -1
- hcpdiff/easy/cfg/sd15_train.py +1 -3
- hcpdiff/easy/model/__init__.py +1 -1
- hcpdiff/easy/model/loader.py +33 -11
- hcpdiff/easy/sampler.py +8 -1
- hcpdiff/loss/__init__.py +4 -3
- hcpdiff/loss/charbonnier.py +17 -0
- hcpdiff/loss/vlb.py +2 -2
- hcpdiff/loss/weighting.py +29 -11
- hcpdiff/models/__init__.py +1 -1
- hcpdiff/models/cfg_context.py +5 -3
- hcpdiff/models/compose/__init__.py +2 -1
- hcpdiff/models/compose/compose_hook.py +69 -67
- hcpdiff/models/compose/compose_textencoder.py +59 -45
- hcpdiff/models/compose/compose_tokenizer.py +48 -11
- hcpdiff/models/compose/flux.py +75 -0
- hcpdiff/models/compose/sdxl.py +86 -0
- hcpdiff/models/text_emb_ex.py +13 -9
- hcpdiff/models/textencoder_ex.py +8 -38
- hcpdiff/models/wrapper/__init__.py +2 -1
- hcpdiff/models/wrapper/flux.py +75 -0
- hcpdiff/models/wrapper/pixart.py +13 -1
- hcpdiff/models/wrapper/sd.py +17 -8
- hcpdiff/parser/embpt.py +7 -7
- hcpdiff/utils/net_utils.py +22 -12
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +145 -18
- hcpdiff/workflow/text.py +49 -18
- hcpdiff/workflow/vae.py +10 -2
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/METADATA +1 -1
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/RECORD +53 -49
- hcpdiff/models/compose/sdxl_composer.py +0 -39
- hcpdiff/utils/inpaint_pipe.py +0 -790
- hcpdiff/utils/pipe_hook.py +0 -656
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/WHEEL +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/entry_points.txt +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/top_level.txt +0 -0
@@ -14,19 +14,20 @@ from typing import Dict, Tuple, List
 import torch
 from transformers import AutoTokenizer, CLIPTokenizer, PreTrainedTokenizer, PretrainedConfig
 from transformers.tokenization_utils_base import BatchEncoding
+from rainbowneko.utils import BatchableDict
 
 class ComposeTokenizer(PreTrainedTokenizer):
-    def __init__(self,
-        self.cat_dim = cat_dim
+    def __init__(self, tokenizers: Dict[str, CLIPTokenizer]):
 
         self.tokenizer_names = []
-        for name, tokenizer in
+        for name, tokenizer in tokenizers.items():
             setattr(self, name, tokenizer)
             self.tokenizer_names.append(name)
 
         super().__init__()
 
-        self.model_max_length = torch.tensor([tokenizer.model_max_length for name, tokenizer in tokenizer_list])
+        # self.model_max_length = torch.tensor([tokenizer.model_max_length for name, tokenizer in tokenizer_list])
+        self.model_max_length = {name: tokenizer.model_max_length for name, tokenizer in tokenizers.items()}
 
     @property
     def first_tokenizer(self):
@@ -57,15 +58,17 @@ class ComposeTokenizer(PreTrainedTokenizer):
         return self.first_tokenizer.save_vocabulary(save_directory, filename_prefix)
 
     def __call__(self, text, *args, max_length=None, **kwargs):
-        if isinstance(max_length,
-
-        for name
+        if isinstance(max_length, dict):
+            token_infos: Dict[str, BatchEncoding] = {name: getattr(self, name)(text, *args, max_length=max_length[name], **kwargs)
+                for name in self.tokenizer_names}
         else:
-
+            token_infos: Dict[str, BatchEncoding] = {name: getattr(self, name)(text, *args, max_length=max_length, **kwargs)
+                for name in self.tokenizer_names}
 
-        input_ids =
-        attention_mask =
-
+        input_ids = BatchableDict({name: token.input_ids for name, token in token_infos.items()}) # [N_tokenizer, N_token]
+        attention_mask = BatchableDict({name: token.get('attention_mask', None) for name, token in token_infos.items()})
+        position_ids = BatchableDict({name: token.get('position_ids', None) for name, token in token_infos.items()})
+        return BatchEncoding({'input_ids':input_ids, 'attention_mask':attention_mask, 'position_ids':position_ids})
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: List[Tuple[str, str]], *args,
@@ -73,3 +76,37 @@ class ComposeTokenizer(PreTrainedTokenizer):
         tokenizer_list = [(name, AutoTokenizer.from_pretrained(path, subfolder=subfolder[name], **kwargs)) for name, path in pretrained_model_name_or_path]
         compose_tokenizer = cls(tokenizer_list)
         return compose_tokenizer
+
+    def __repr__(self):
+        return f'ComposeTokenizer(\n' + '\n'.join([f' {name}: {repr(getattr(self, name))}' for name in self.tokenizer_names]) + ')'
+
+    @staticmethod
+    def tokenize_ex(tokenizer, *args, device='cpu', squeeze=False, **kwargs):
+        if isinstance(tokenizer, ComposeTokenizer):
+            max_length = {name: (tok := getattr(tokenizer, name)).model_max_length * getattr(tok, 'N_repeats', 1) for name in tokenizer.tokenizer_names}
+        else:
+            max_length = tokenizer.model_max_length * getattr(tokenizer, 'N_repeats', 1)
+
+        text_inputs = tokenizer(
+            *args,
+            max_length=max_length,
+            **kwargs
+        )
+
+        def proc_tensor(v):
+            if v is None:
+                return None
+            elif squeeze:
+                return v.squeeze().to(device)
+            else:
+                return v.to(device)
+
+        for k, v in text_inputs.items():
+            if torch.is_tensor(v):
+                text_inputs[k] = proc_tensor(v)
+            elif isinstance(v, (dict, BatchableDict)):
+                for name, vi in v.items():
+                    if torch.is_tensor(vi):
+                        v[name] = proc_tensor(vi)
+
+        return text_inputs
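ComposeTokenizer now takes a name→tokenizer dict instead of a (name, tokenizer) list, exposes `model_max_length` as a per-tokenizer dict, and wraps `input_ids`/`attention_mask`/`position_ids` in rainbowneko's `BatchableDict`; the new `tokenize_ex` helper resolves per-tokenizer max lengths and moves the resulting tensors to a device. A minimal usage sketch (the checkpoint id and tokenizer kwargs are illustrative, not taken from this diff):

```python
# Sketch only: exercising the dict-based ComposeTokenizer API shown above.
from transformers import AutoTokenizer
from hcpdiff.models.compose import ComposeTokenizer

repo = 'stabilityai/stable-diffusion-xl-base-1.0'  # example SDXL-style layout
tok = ComposeTokenizer({
    'clip_L': AutoTokenizer.from_pretrained(repo, subfolder='tokenizer'),
    'clip_bigG': AutoTokenizer.from_pretrained(repo, subfolder='tokenizer_2'),
})
print(tok.model_max_length)  # e.g. {'clip_L': 77, 'clip_bigG': 77}

# tokenize_ex resolves a per-tokenizer max_length and moves tensors to `device`
text_inputs = ComposeTokenizer.tokenize_ex(
    tok, 'a photo of a cat',
    padding='max_length', truncation=True, return_tensors='pt', device='cpu')
print(text_inputs['input_ids'])  # per-tokenizer input_ids wrapped in BatchableDict
```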
hcpdiff/models/compose/flux.py
ADDED
@@ -0,0 +1,75 @@
+from .compose_textencoder import ComposeTextEncoder
+from .compose_tokenizer import ComposeTokenizer
+from transformers import CLIPTextModel, AutoTokenizer, CLIPTextModelWithProjection, T5EncoderModel
+from typing import Optional, Union, Tuple, Dict
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+
+class T5EncoderModel_Align(T5EncoderModel):
+    # fxxk the transformers!
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
+        text_outputs = super().forward(input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
+        return BaseModelOutputWithPooling(
+            last_hidden_state=text_outputs.last_hidden_state,
+            pooler_output=None,
+            hidden_states=text_outputs.hidden_states,
+            attentions=text_outputs.attentions,
+        )
+
+class FluxTextEncoder(ComposeTextEncoder):
+    def forward(
+        self,
+        input_ids: Optional[Dict[str, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        output = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if self.with_hook:
+            encoder_hidden_states_dict, pooled_output_dict = output
+            encoder_hidden_states = encoder_hidden_states_dict['T5']
+            pooled_output = pooled_output_dict['clip']
+        else:
+            last_hidden_state = output['last_hidden_state']['T5']
+            pooler_output = output['pooler_output']['clip']
+            attentions = output['attentions']['T5']
+            hidden_states = output['hidden_states']['T5']
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=pooler_output,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )
+
+        return encoder_hidden_states, pooled_output
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
+        clip_L = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
+        T5 = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder_2', **kwargs)
+        return cls({'clip': clip_L, 'T5': T5})
+
+class FluxTokenizer(ComposeTokenizer):
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
+        clip_L = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
+        T5 = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer_2', **kwargs)
+        return cls({'clip': clip_L, 'T5': T5})
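The new `FluxTextEncoder`/`FluxTokenizer` pair wires a CLIP-L branch (key `'clip'`) and a T5 branch (key `'T5'`) into the compose machinery, with `from_pretrained` hard-coded to the diffusers subfolders `text_encoder`/`text_encoder_2` and `tokenizer`/`tokenizer_2`. A loading sketch (the repo id and dtype are examples):

```python
# Sketch only: load the Flux text stack from a diffusers-style checkpoint.
import torch
from hcpdiff.models.compose import FluxTokenizer, FluxTextEncoder

repo = 'black-forest-labs/FLUX.1-dev'  # example repo id
tokenizer = FluxTokenizer.from_pretrained(repo)  # CLIP-L + T5 under 'clip' / 'T5'
text_encoder = FluxTextEncoder.from_pretrained(repo, torch_dtype=torch.bfloat16)
```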
hcpdiff/models/compose/sdxl.py
ADDED
@@ -0,0 +1,86 @@
+from .compose_textencoder import ComposeTextEncoder
+from .compose_tokenizer import ComposeTokenizer
+from transformers import CLIPTextModel, AutoTokenizer, CLIPTextModelWithProjection
+from typing import Optional, Union, Tuple, Dict
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+
+class CLIPTextModelWithProjection_Align(CLIPTextModelWithProjection):
+    # fxxk the transformers!
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        try: # old version of transformers
+            text_outputs = super().forward(input_ids, attention_mask, position_ids, output_attentions, output_hidden_states, return_dict)
+        except TypeError: # new version (like 4.53.1) of transformers removed 'return_dict'
+            text_outputs = super().forward(input_ids, attention_mask, position_ids, output_attentions, output_hidden_states)
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=text_outputs.last_hidden_state,
+            pooler_output=text_outputs.text_embeds,
+            hidden_states=text_outputs.hidden_states,
+            attentions=text_outputs.attentions,
+        )
+
+class SDXLTextEncoder(ComposeTextEncoder):
+    def forward(
+        self,
+        input_ids: Optional[Dict[str, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        output = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if self.with_hook:
+            encoder_hidden_states_dict, pooled_output_dict = output
+            encoder_hidden_states = torch.cat([encoder_hidden_states_dict['clip_L'], encoder_hidden_states_dict['clip_bigG']], dim=-1)
+            pooled_output = pooled_output_dict['clip_bigG']
+        else:
+            last_hidden_state = torch.cat((output['last_hidden_state']['clip_L'], output['last_hidden_state']['clip_bigG']), dim=-1)
+            pooler_output = output['pooler_output']['clip_bigG']
+            attentions = output['attentions']['clip_bigG']
+            if output['hidden_states']['clip_L'] is None:
+                hidden_states = None
+            else:
+                hidden_states = [torch.cat(states, dim=self.cat_dim) for states in zip(output['hidden_states']['clip_L'], output['hidden_states']['clip_bigG'])]
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=pooler_output,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )
+
+        return encoder_hidden_states, pooled_output
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
+        clip_L = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
+        clip_bigG = CLIPTextModelWithProjection_Align.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder_2', **kwargs)
+        return cls({'clip_L': clip_L, 'clip_bigG': clip_bigG})
+
+class SDXLTokenizer(ComposeTokenizer):
+    def __call__(self, text, *args, max_length=None, **kwargs):
+        token_info = super().__call__(text, *args, max_length=max_length, **kwargs)
+        token_info['attention_mask'] = token_info['attention_mask']['clip_L']
+        return token_info
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
+        clip_L = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
+        clip_bigG = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer_2', **kwargs)
+        return cls({'clip_L': clip_L, 'clip_bigG': clip_bigG})
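When hooked, `SDXLTextEncoder` concatenates the two CLIP branches along the feature dimension and takes the pooled output from the bigG projection head, while `SDXLTokenizer` collapses the returned `attention_mask` to the clip_L mask. A shape-only sketch of the hooked output (hidden sizes 768/1280 are the standard SDXL values, assumed here):

```python
# Shape-only sketch of what the hooked SDXLTextEncoder.forward returns.
import torch

B, N = 2, 77
h_clip_L = torch.randn(B, N, 768)      # CLIP-L last hidden state
h_clip_bigG = torch.randn(B, N, 1280)  # OpenCLIP bigG last hidden state
pooled_bigG = torch.randn(B, 1280)     # text_embeds from the projection head

encoder_hidden_states = torch.cat([h_clip_L, h_clip_bigG], dim=-1)  # (B, N, 2048)
pooled_output = pooled_bigG                                         # (B, 1280)
print(encoder_hidden_states.shape, pooled_output.shape)
```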
hcpdiff/models/text_emb_ex.py
CHANGED
@@ -7,17 +7,18 @@ text_emb_ex.py
 :Created: 10/03/2023
 :Licence: Apache-2.0
 """
+import os
+from pathlib import Path
 from typing import Tuple, Dict, Any
 
 import torch
-from torch import nn
-import os
-from rainbowneko import _share
-from einops import rearrange, repeat
 import torch.nn.functional as F
+from einops import rearrange, repeat
+from rainbowneko import _share
+from rainbowneko.models.plugin import SinglePluginBlock
+from torch import nn
 
 from ..utils.net_utils import load_emb
-from rainbowneko.models.plugin import SinglePluginBlock
 
 class EmbeddingPTHook(SinglePluginBlock):
     def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
@@ -74,7 +75,7 @@ class EmbeddingPTHook(SinglePluginBlock):
         self.handle_pre.remove()
 
     @classmethod
-    def hook(cls, ex_words_emb, tokenizer, text_encoder, **kwargs):
+    def hook(cls, ex_words_emb:Dict[str, nn.Parameter], tokenizer, text_encoder, **kwargs):
         word_list = list(ex_words_emb.keys())
         tokenizer.add_tokens(word_list)
         token_ids = tokenizer(' '.join(word_list)).input_ids[1:-1]
@@ -87,9 +88,12 @@ class EmbeddingPTHook(SinglePluginBlock):
         return embedding_hook
 
     @classmethod
-    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda
-
-
+    def hook_from_dir(cls, emb_dir:str|Path, tokenizer, text_encoder, device='cuda', **kwargs):
+        if emb_dir is None:
+            ex_words_emb = {}
+        else:
+            emb_dir = Path(emb_dir)
+            ex_words_emb = {file.stem: nn.Parameter(load_emb(file).to(device), requires_grad=False) for file in emb_dir.glob('*.pt')}
         return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
 
 class EmbeddingPTInterpHook(SinglePluginBlock):
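`hook_from_dir` now accepts a `str`, `Path`, or `None` embedding directory (an empty dict when `None`) and globs `*.pt` files itself. A usage sketch (the checkpoint id and the `embs/` directory are hypothetical):

```python
# Sketch only: hooking custom word embeddings from a directory.
from transformers import CLIPTextModel, CLIPTokenizer
from hcpdiff.models.text_emb_ex import EmbeddingPTHook

repo = 'stable-diffusion-v1-5/stable-diffusion-v1-5'  # example repo id
tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(repo, subfolder='text_encoder')

# 'embs/' is assumed to hold *.pt word embeddings; None would yield an empty dict
hook, ex_words_emb = EmbeddingPTHook.hook_from_dir(
    'embs/', tokenizer, text_encoder, device='cpu')
print(list(ex_words_emb))  # names of the registered custom words
```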
hcpdiff/models/textencoder_ex.py
CHANGED
@@ -51,48 +51,18 @@ class TEEXHook:
         else:
             self.final_layer_norm = None
 
+    @property
+    def N_repeats(self):
+        return self.tokenizer.N_repeats
+
+    @N_repeats.setter
+    def N_repeats(self, value: int):
+        self.tokenizer.N_repeats = value
+
     @property
     def device(self):
         return self.text_enc.device
 
-    def encode_prompt_to_emb(self, prompt):
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length*self.N_repeats,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        if self.use_attention_mask:
-            attention_mask = text_inputs.get('attention_mask', None)
-        else:
-            attention_mask = None
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(self.device)
-        position_ids = text_inputs.get('position_ids', None)
-        if position_ids is not None:
-            position_ids = position_ids.to(self.device)
-
-        # align with sd-webui
-        if isinstance(self.text_enc, CLIPTextModelWithProjection):
-            self.text_enc.text_projection.weight.data = self.text_enc.text_projection.weight.data.t()
-
-        if isinstance(self.text_enc, T5EncoderModel):
-            prompt_embeds, pooled_output = self.text_enc(
-                text_input_ids.to(self.device),
-                attention_mask=attention_mask,
-                output_hidden_states=True,
-            )
-        else:
-            prompt_embeds, pooled_output = self.text_enc(
-                text_input_ids.to(self.device),
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                output_hidden_states=True,
-            )
-        return prompt_embeds, pooled_output, attention_mask
-
     def forward_hook_input(self, host, feat_in):
         feat_re = rearrange(feat_in[0], 'b (r w) -> (b r) w', r=self.N_repeats) # 使Attention mask的尺寸为N_word+2
         return (feat_re,) if len(feat_in) == 1 else (feat_re, *feat_in[1:])
hcpdiff/models/wrapper/flux.py
ADDED
@@ -0,0 +1,75 @@
+import torch
+from diffusers import FluxTransformer2DModel, AutoencoderKL
+from einops import repeat, rearrange
+from hcpdiff.diffusion.sampler import BaseSampler
+from hcpdiff.utils import pad_attn_bias
+from rainbowneko.utils import add_dims
+
+from .sd import SD15Wrapper
+from .utils import TEHookCFG, SDXL_TEHookCFG
+from ..cfg_context import CFGContext
+
+class FluxWrapper(SD15Wrapper):
+    def __init__(self, denoiser: FluxTransformer2DModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 guidance=5.0, patch_size=2, TE_hook_cfg: TEHookCFG = SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'plugin_input -> plugin_input'))
+        self.guidance = guidance
+        self.patch_size = patch_size
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)
+        return encoder_hidden_states, pooled_output
+
+    def forward_denoiser(self, x_t, H, W, prompt_ids, encoder_hidden_states, pooled_output, timesteps, attn_mask=None, plugin_input={}, **kwargs):
+        attn_mask = attn_mask['T5']
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        img_ids = torch.zeros(H, W, 3)
+        img_ids[..., 1] = img_ids[..., 1]+torch.arange(H)[:, None]
+        img_ids[..., 2] = img_ids[..., 2]+torch.arange(W)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=x_t.shape[0])
+
+        txt_ids = torch.zeros(x_t.shape[0], encoder_hidden_states.shape[1], 3)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, attn_mask=attn_mask, img_ids=img_ids, txt_ids=txt_ids,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, self.guidance, pooled_output, encoder_hidden_states, img_ids=img_ids, txt_ids=txt_ids).sample
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, neg_prompt_ids=None, neg_attn_mask=None, plugin_input={}, **kwargs):
+        # input prepare
+        x_0 = self.get_latents(image)
+        B, C, H, W = x_0.shape
+        x_0_patch = rearrange(x_0, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=self.patch_size, pw=self.patch_size)
+        x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0_patch)
+        x_t_in = x_t*add_dims(self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype), x_t.ndim-1)
+        t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
+        x_t_in = rearrange(x_t_in, "b c h w -> b (h w) c")
+
+        if neg_prompt_ids:
+            prompt_ids = self.pn_cat(neg_prompt_ids, prompt_ids)
+        if neg_attn_mask:
+            attn_mask = self.pn_cat(neg_attn_mask, attn_mask)
+
+        # model forward
+        x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
+        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, plugin_input=plugin_input, **kwargs)
+        model_pred = self.forward_denoiser(x_t_in, H, W, prompt_ids, encoder_hidden_states, pooled_output, t_in, attn_mask=attn_mask,
+                                           plugin_input=plugin_input, **kwargs)
+        model_pred = rearrange(model_pred, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=self.patch_size, pw=self.patch_size, h=H, w=W)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
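`FluxWrapper.model_forward` packs the VAE latent into 2×2 patch tokens before noising and unpacks the prediction afterwards. A standalone sketch of that einops round trip (the latent shape is illustrative):

```python
# Standalone sketch of the patch packing/unpacking pattern used above.
import torch
from einops import rearrange

patch_size = 2
x_0 = torch.randn(1, 16, 64, 64)  # B C H W VAE latent (illustrative shape)

x_patch = rearrange(x_0, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=patch_size, pw=patch_size)
tokens = rearrange(x_patch, "b c h w -> b (h w) c")
print(x_patch.shape, tokens.shape)  # (1, 64, 32, 32) (1, 1024, 64)

x_rec = rearrange(tokens, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                  ph=patch_size, pw=patch_size, h=32, w=32)
assert x_rec.shape == x_0.shape
```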
hcpdiff/models/wrapper/pixart.py
CHANGED
@@ -2,6 +2,16 @@ from .sd import SD15Wrapper
 from hcpdiff.utils import pad_attn_bias
 
 class PixArtWrapper(SD15Wrapper):
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, plugin_input={}, **kwargs):
+        # T5Encoder do not need position_ids (It use relative position embedding for key and query)
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states = self.TE(prompt_ids, attention_mask=attn_mask, output_hidden_states=True)[0]
+        return encoder_hidden_states
+
     def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, resolution=None, aspect_ratio=None,
                          plugin_input={}, **kwargs):
         if attn_mask is not None:
@@ -16,4 +26,6 @@ class PixArtWrapper(SD15Wrapper):
         added_cond_kwargs = {"resolution":resolution, "aspect_ratio":aspect_ratio}
         model_pred = self.denoiser(x_t, encoder_hidden_states, timesteps, encoder_attention_mask=attn_mask,
                                    added_cond_kwargs=added_cond_kwargs).sample # Predict the noise residual
-
+
+        # remove pred vars for pixart output (see DiT for more)
+        return model_pred.chunk(2, dim=1)[0]
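PixArt's DiT-style transformer predicts noise and variance stacked along the channel dimension; the wrapper now keeps only the noise half. A shape-only sketch (the channel count is illustrative):

```python
# Shape-only sketch: keep the noise half of a DiT-style (noise, variance) output.
import torch

model_pred = torch.randn(2, 8, 64, 64)             # B, 2*C_latent, H, W (C_latent=4 assumed)
noise_pred, var_pred = model_pred.chunk(2, dim=1)  # each (2, 4, 64, 64)
print(noise_pred.shape)
```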
hcpdiff/models/wrapper/sd.py
CHANGED
@@ -60,7 +60,10 @@ class SD15Wrapper(BaseWrapper):
         if image.shape[1] == 3:
             with torch.no_grad() if self.vae_trainable else nullcontext():
                 latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
-
+                if shift_factor := getattr(self.vae.config, 'shift_factor', None) is not None:
+                    latents = (latents-shift_factor)*self.vae.config.scaling_factor
+                else:
+                    latents = latents*self.vae.config.scaling_factor
         else:
             latents = image # Cached latents
         return latents
@@ -87,6 +90,12 @@ class SD15Wrapper(BaseWrapper):
         model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample # Predict the noise residual
         return model_pred
 
+    def pn_cat(self, neg, pos, dim=0):
+        if isinstance(pos, dict): # ComposeTextEncoder
+            return {name:torch.cat([neg[name], pos_i], dim=dim) for name, pos_i in pos.items()}
+        else:
+            return torch.cat([neg, pos], dim=dim)
+
     def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
                       plugin_input={}, **kwargs):
         # input prepare
@@ -96,11 +105,11 @@ class SD15Wrapper(BaseWrapper):
         t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
 
         if neg_prompt_ids:
-            prompt_ids =
+            prompt_ids = self.pn_cat(neg_prompt_ids, prompt_ids)
         if neg_attn_mask:
-            attn_mask =
+            attn_mask = self.pn_cat(neg_attn_mask, attn_mask)
         if neg_position_ids:
-            position_ids =
+            position_ids = self.pn_cat(neg_position_ids, position_ids)
 
         # model forward
         x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
@@ -198,17 +207,17 @@ class SDXLWrapper(SD15Wrapper):
         t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
 
         if neg_prompt_ids:
-            prompt_ids =
+            prompt_ids = self.pn_cat(neg_prompt_ids, prompt_ids)
         if neg_attn_mask:
-            attn_mask =
+            attn_mask = self.pn_cat(neg_attn_mask, attn_mask)
         if neg_position_ids:
-            position_ids =
+            position_ids = self.pn_cat(neg_position_ids, position_ids)
 
         # model forward
         x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
         encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
                                                                plugin_input=plugin_input)
-        added_cond_kwargs = {"text_embeds":pooled_output
+        added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
         model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, added_cond_kwargs=added_cond_kwargs,
                                            attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
         model_pred = self.cfg_context.post(model_pred)
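`get_latents` now honours a VAE `shift_factor` when the config defines one (as Flux/SD3-style VAEs do), and the new `pn_cat` helper merges negative/positive conditioning both for plain tensors and for the dict-shaped ids coming out of a composed tokenizer. A minimal sketch of the two `pn_cat` branches (shapes are illustrative):

```python
# Minimal sketch of the two pn_cat branches added to SD15Wrapper.
import torch

neg = torch.zeros(1, 77, dtype=torch.long)
pos = torch.ones(1, 77, dtype=torch.long)
print(torch.cat([neg, pos], dim=0).shape)  # (2, 77) -- plain-tensor branch

neg_d = {'clip_L': torch.zeros(1, 77), 'clip_bigG': torch.zeros(1, 77)}
pos_d = {'clip_L': torch.ones(1, 77), 'clip_bigG': torch.ones(1, 77)}
merged = {name: torch.cat([neg_d[name], pos_i], dim=0) for name, pos_i in pos_d.items()}  # dict branch
print({name: v.shape for name, v in merged.items()})
```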
hcpdiff/parser/embpt.py
CHANGED
@@ -1,7 +1,7 @@
 from typing import Dict, Tuple, List
 from rainbowneko.utils import Path_Like
-from hcpdiff.models import
-from torch import Tensor
+from hcpdiff.models.compose import ComposeEmbPTHook
+from torch import Tensor, nn
 
 class CfgEmbPTParser:
     def __init__(self, emb_dir: Path_Like, cfg_pt: Dict[str, Dict], lr: float = 1e-5, weight_decay: float = 0):
@@ -11,22 +11,22 @@ class CfgEmbPTParser:
         self.weight_decay = weight_decay
 
     def get_params_group(self, model) -> Tuple[List, Dict[str, Tensor]]:
-        self.embedding_hook, self.ex_words_emb =
+        self.embedding_hook, self.ex_words_emb = ComposeEmbPTHook.hook_from_dir(
             self.emb_dir, model.tokenizer, model.TE, N_repeats=model.tokenizer.N_repeats)
         self.embedding_hook.requires_grad_(False)
 
         train_params_emb = []
         train_pts = {}
         for pt_name, info in self.cfg_pt.items():
-            word_emb = self.ex_words_emb[pt_name]
+            word_emb: nn.Parameter | nn.ParameterDict = self.ex_words_emb[pt_name]
             train_pts[pt_name] = word_emb
-            word_emb.
+            word_emb.requires_grad_(True)
             self.embedding_hook.emb_train.append(word_emb)
-            param_group = {'params':word_emb}
+            param_group = {'params':word_emb.parameters() if hasattr(word_emb, 'parameters') else [word_emb]}
             if 'lr' in info:
                 param_group['lr'] = info.lr
             if 'weight_decay' in info:
                 param_group['weight_decay'] = info.weight_decay
             train_params_emb.append(param_group)
 
-        return train_params_emb, train_pts
+        return train_params_emb, train_pts
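A trained word embedding can now be either a single `nn.Parameter` or (for a composed text encoder) an `nn.ParameterDict`, so the optimizer param group is built via `.parameters()` when available. A small sketch of that branch (dimensions are illustrative):

```python
# Sketch of the param-group construction for both word-embedding shapes.
import torch
from torch import nn

single = nn.Parameter(torch.randn(4, 768))
composed = nn.ParameterDict({'clip_L': nn.Parameter(torch.randn(4, 768)),
                             'clip_bigG': nn.Parameter(torch.randn(4, 1280))})

for word_emb in (single, composed):
    params = word_emb.parameters() if hasattr(word_emb, 'parameters') else [word_emb]
    param_group = {'params': list(params), 'lr': 3e-4}
    print(len(param_group['params']))  # 1, then 2
```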
hcpdiff/utils/net_utils.py
CHANGED
@@ -1,15 +1,16 @@
+import json
 import os
 from copy import deepcopy
+from functools import partial
 from typing import Optional, Union
 
 import torch
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
+from huggingface_hub import hf_hub_download
 from torch import nn
 from torch.optim import lr_scheduler
 from transformers import PretrainedConfig, AutoTokenizer, T5EncoderModel, CLIPTextModel
-from
-from huggingface_hub import hf_hub_download
-import json
+from transformers.models.auto.tokenization_auto import get_tokenizer_config
 
 dtype_dict = {'fp32':torch.float32, 'amp':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
 
@@ -91,19 +92,24 @@ def get_scheduler_with_name(
     return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **scheduler_kwargs)
 
 def auto_tokenizer_cls(pretrained_model_name_or_path: str, revision: str = None):
-    from hcpdiff.models.compose import SDXLTokenizer
+    from hcpdiff.models.compose import SDXLTokenizer, FluxTokenizer
     try:
-
-        pretrained_model_name_or_path,
-
+        tokenizer_config = get_tokenizer_config(
+            pretrained_model_name_or_path,
+            subfolder="tokenizer_2",
+            revision=revision
         )
-
+        class_name = tokenizer_config.get("tokenizer_class")
+        if class_name == 'T5Tokenizer' or class_name == 'T5TokenizerFast':
+            return FluxTokenizer
+        else:
+            return SDXLTokenizer
     except:
-        # not
+        # not composed, only one tokenizer
         return AutoTokenizer
 
 def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = None):
-    from hcpdiff.models.compose import SDXLTextEncoder
+    from hcpdiff.models.compose import SDXLTextEncoder, FluxTextEncoder
     try:
         text_encoder_config = PretrainedConfig.from_pretrained(
             pretrained_model_name_or_path,
@@ -112,7 +118,11 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = None):
         )
         if text_encoder_config.architectures is None:
             raise ValueError()
-
+        model_class = text_encoder_config.architectures[0]
+        if model_class == "T5EncoderModel":
+            return FluxTextEncoder
+        else:
+            return SDXLTextEncoder
     except:
         text_encoder_config = PretrainedConfig.from_pretrained(
             pretrained_model_name_or_path,
@@ -248,4 +258,4 @@ def get_dtype(dtype):
     if isinstance(dtype, torch.dtype):
         return dtype
     else:
-        return dtype_dict.get(dtype, torch.float32)
+        return dtype_dict.get(dtype, torch.float32)
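`auto_tokenizer_cls`/`auto_text_encoder_cls` now inspect the checkpoint's `tokenizer_2`/`text_encoder_2` subfolder and dispatch to the Flux classes when a T5 is found there, otherwise to the SDXL classes, falling back to a plain `AutoTokenizer` when the model is not composed. A dispatch sketch (repo ids are examples and assume the usual subfolder layouts):

```python
# Sketch only: class dispatch based on the checkpoint's second text branch.
from hcpdiff.utils.net_utils import auto_tokenizer_cls, auto_text_encoder_cls

print(auto_tokenizer_cls('black-forest-labs/FLUX.1-dev'))              # FluxTokenizer (T5 in tokenizer_2)
print(auto_tokenizer_cls('stabilityai/stable-diffusion-xl-base-1.0'))  # SDXLTokenizer (CLIP in tokenizer_2)
print(auto_text_encoder_cls('black-forest-labs/FLUX.1-dev'))           # FluxTextEncoder
```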
hcpdiff/workflow/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from .diffusion import InputFeederAction, MakeLatentAction,
+from .diffusion import InputFeederAction, MakeLatentAction, SD15DenoiseAction, SDXLDenoiseAction, PixartDenoiseAction, FluxDenoiseAction, SampleAction, DiffusionStepAction, \
     X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter, DiffusionActions
 from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
 from .vae import EncodeAction, DecodeAction