hcpdiff 2.3.1__py3-none-any.whl → 2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. hcpdiff/ckpt_manager/__init__.py +1 -1
  2. hcpdiff/ckpt_manager/format/__init__.py +2 -2
  3. hcpdiff/ckpt_manager/format/diffusers.py +19 -4
  4. hcpdiff/ckpt_manager/format/emb.py +8 -3
  5. hcpdiff/ckpt_manager/format/lora_webui.py +1 -1
  6. hcpdiff/ckpt_manager/format/sd_single.py +28 -5
  7. hcpdiff/data/cache/vae.py +10 -2
  8. hcpdiff/data/handler/text.py +15 -14
  9. hcpdiff/diffusion/sampler/__init__.py +2 -1
  10. hcpdiff/diffusion/sampler/base.py +17 -6
  11. hcpdiff/diffusion/sampler/diffusers.py +4 -3
  12. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +5 -14
  13. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +7 -6
  14. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +4 -4
  15. hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +3 -3
  16. hcpdiff/diffusion/sampler/timer/__init__.py +2 -0
  17. hcpdiff/diffusion/sampler/timer/base.py +26 -0
  18. hcpdiff/diffusion/sampler/timer/shift.py +49 -0
  19. hcpdiff/easy/__init__.py +2 -1
  20. hcpdiff/easy/cfg/sd15_train.py +1 -3
  21. hcpdiff/easy/model/__init__.py +1 -1
  22. hcpdiff/easy/model/loader.py +33 -11
  23. hcpdiff/easy/sampler.py +8 -1
  24. hcpdiff/loss/__init__.py +4 -3
  25. hcpdiff/loss/charbonnier.py +17 -0
  26. hcpdiff/loss/vlb.py +2 -2
  27. hcpdiff/loss/weighting.py +29 -11
  28. hcpdiff/models/__init__.py +1 -1
  29. hcpdiff/models/cfg_context.py +5 -3
  30. hcpdiff/models/compose/__init__.py +2 -1
  31. hcpdiff/models/compose/compose_hook.py +69 -67
  32. hcpdiff/models/compose/compose_textencoder.py +59 -45
  33. hcpdiff/models/compose/compose_tokenizer.py +48 -11
  34. hcpdiff/models/compose/flux.py +75 -0
  35. hcpdiff/models/compose/sdxl.py +86 -0
  36. hcpdiff/models/text_emb_ex.py +13 -9
  37. hcpdiff/models/textencoder_ex.py +8 -38
  38. hcpdiff/models/wrapper/__init__.py +2 -1
  39. hcpdiff/models/wrapper/flux.py +75 -0
  40. hcpdiff/models/wrapper/pixart.py +13 -1
  41. hcpdiff/models/wrapper/sd.py +17 -8
  42. hcpdiff/parser/embpt.py +7 -7
  43. hcpdiff/utils/net_utils.py +22 -12
  44. hcpdiff/workflow/__init__.py +1 -1
  45. hcpdiff/workflow/diffusion.py +145 -18
  46. hcpdiff/workflow/text.py +49 -18
  47. hcpdiff/workflow/vae.py +10 -2
  48. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/METADATA +1 -1
  49. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/RECORD +53 -49
  50. hcpdiff/models/compose/sdxl_composer.py +0 -39
  51. hcpdiff/utils/inpaint_pipe.py +0 -790
  52. hcpdiff/utils/pipe_hook.py +0 -656
  53. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/WHEEL +0 -0
  54. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/entry_points.txt +0 -0
  55. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/licenses/LICENSE +0 -0
  56. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/top_level.txt +0 -0
hcpdiff/models/compose/compose_tokenizer.py CHANGED
@@ -14,19 +14,20 @@ from typing import Dict, Tuple, List
 import torch
 from transformers import AutoTokenizer, CLIPTokenizer, PreTrainedTokenizer, PretrainedConfig
 from transformers.tokenization_utils_base import BatchEncoding
+from rainbowneko.utils import BatchableDict
 
 class ComposeTokenizer(PreTrainedTokenizer):
-    def __init__(self, tokenizer_list: List[Tuple[str, CLIPTokenizer]], cat_dim=-1):
-        self.cat_dim = cat_dim
+    def __init__(self, tokenizers: Dict[str, CLIPTokenizer]):
 
         self.tokenizer_names = []
-        for name, tokenizer in tokenizer_list:
+        for name, tokenizer in tokenizers.items():
             setattr(self, name, tokenizer)
             self.tokenizer_names.append(name)
 
         super().__init__()
 
-        self.model_max_length = torch.tensor([tokenizer.model_max_length for name, tokenizer in tokenizer_list])
+        # self.model_max_length = torch.tensor([tokenizer.model_max_length for name, tokenizer in tokenizer_list])
+        self.model_max_length = {name: tokenizer.model_max_length for name, tokenizer in tokenizers.items()}
 
     @property
     def first_tokenizer(self):
@@ -57,15 +58,17 @@ class ComposeTokenizer(PreTrainedTokenizer):
         return self.first_tokenizer.save_vocabulary(save_directory, filename_prefix)
 
     def __call__(self, text, *args, max_length=None, **kwargs):
-        if isinstance(max_length, torch.Tensor):
-            token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length_i, **kwargs)
-                                               for name, max_length_i in zip(self.tokenizer_names, max_length)]
+        if isinstance(max_length, dict):
+            token_infos: Dict[str, BatchEncoding] = {name: getattr(self, name)(text, *args, max_length=max_length[name], **kwargs)
+                                                     for name in self.tokenizer_names}
         else:
-            token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length, **kwargs) for name in self.tokenizer_names]
+            token_infos: Dict[str, BatchEncoding] = {name: getattr(self, name)(text, *args, max_length=max_length, **kwargs)
+                                                     for name in self.tokenizer_names}
 
-        input_ids = torch.cat([token.input_ids for token in token_list], dim=-1)  # [N_tokenizer, N_token]
-        attention_mask = torch.cat([token.attention_mask for token in token_list], dim=-1)
-        return BatchEncoding({'input_ids':input_ids, 'attention_mask':attention_mask})
+        input_ids = BatchableDict({name: token.input_ids for name, token in token_infos.items()})  # [N_tokenizer, N_token]
+        attention_mask = BatchableDict({name: token.get('attention_mask', None) for name, token in token_infos.items()})
+        position_ids = BatchableDict({name: token.get('position_ids', None) for name, token in token_infos.items()})
+        return BatchEncoding({'input_ids':input_ids, 'attention_mask':attention_mask, 'position_ids':position_ids})
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: List[Tuple[str, str]], *args,
@@ -73,3 +76,37 @@ class ComposeTokenizer(PreTrainedTokenizer):
         tokenizer_list = [(name, AutoTokenizer.from_pretrained(path, subfolder=subfolder[name], **kwargs)) for name, path in pretrained_model_name_or_path]
         compose_tokenizer = cls(tokenizer_list)
         return compose_tokenizer
+
+    def __repr__(self):
+        return f'ComposeTokenizer(\n' + '\n'.join([f' {name}: {repr(getattr(self, name))}' for name in self.tokenizer_names]) + ')'
+
+    @staticmethod
+    def tokenize_ex(tokenizer, *args, device='cpu', squeeze=False, **kwargs):
+        if isinstance(tokenizer, ComposeTokenizer):
+            max_length = {name: (tok := getattr(tokenizer, name)).model_max_length * getattr(tok, 'N_repeats', 1) for name in tokenizer.tokenizer_names}
+        else:
+            max_length = tokenizer.model_max_length * getattr(tokenizer, 'N_repeats', 1)
+
+        text_inputs = tokenizer(
+            *args,
+            max_length=max_length,
+            **kwargs
+        )
+
+        def proc_tensor(v):
+            if v is None:
+                return None
+            elif squeeze:
+                return v.squeeze().to(device)
+            else:
+                return v.to(device)
+
+        for k, v in text_inputs.items():
+            if torch.is_tensor(v):
+                text_inputs[k] = proc_tensor(v)
+            elif isinstance(v, (dict, BatchableDict)):
+                for name, vi in v.items():
+                    if torch.is_tensor(vi):
+                        v[name] = proc_tensor(vi)
+
+        return text_inputs
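
Note: with this change, tokenizer outputs are keyed by tokenizer name instead of being concatenated along a tensor axis. A minimal, hypothetical usage sketch, assuming BatchableDict behaves like a mapping (checkpoint ids and lengths below are illustrative):

    # Hypothetical usage of the dict-based ComposeTokenizer API.
    from transformers import AutoTokenizer
    from hcpdiff.models.compose import ComposeTokenizer

    tokenizer = ComposeTokenizer({
        'clip': AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14'),
        'T5': AutoTokenizer.from_pretrained('google/t5-v1_1-base'),
    })
    # max_length may now be a per-tokenizer dict rather than a torch.Tensor
    tokens = tokenizer('a photo of a cat', padding='max_length', truncation=True,
                       max_length={'clip': 77, 'T5': 512}, return_tensors='pt')
    print(tokens['input_ids']['clip'].shape)  # results are looked up by name, not torch.cat'ed
    print(tokens['input_ids']['T5'].shape)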
hcpdiff/models/compose/flux.py ADDED
@@ -0,0 +1,75 @@
+from .compose_textencoder import ComposeTextEncoder
+from .compose_tokenizer import ComposeTokenizer
+from transformers import CLIPTextModel, AutoTokenizer, CLIPTextModelWithProjection, T5EncoderModel
+from typing import Optional, Union, Tuple, Dict
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+
+class T5EncoderModel_Align(T5EncoderModel):
+    # fxxk the transformers!
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
+        text_outputs = super().forward(input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
+        return BaseModelOutputWithPooling(
+            last_hidden_state=text_outputs.last_hidden_state,
+            pooler_output=None,
+            hidden_states=text_outputs.hidden_states,
+            attentions=text_outputs.attentions,
+        )
+
+class FluxTextEncoder(ComposeTextEncoder):
+    def forward(
+        self,
+        input_ids: Optional[Dict[str, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        output = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if self.with_hook:
+            encoder_hidden_states_dict, pooled_output_dict = output
+            encoder_hidden_states = encoder_hidden_states_dict['T5']
+            pooled_output = pooled_output_dict['clip']
+        else:
+            last_hidden_state = output['last_hidden_state']['T5']
+            pooler_output = output['pooler_output']['clip']
+            attentions = output['attentions']['T5']
+            hidden_states = output['hidden_states']['T5']
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=pooler_output,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )
+
+        return encoder_hidden_states, pooled_output
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
+        clip_L = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
+        T5 = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder_2', **kwargs)
+        return cls({'clip': clip_L, 'T5': T5})
+
+class FluxTokenizer(ComposeTokenizer):
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
+        clip_L = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
+        T5 = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer_2', **kwargs)
+        return cls({'clip': clip_L, 'T5': T5})
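
Note: FluxTextEncoder pairs T5 sequence embeddings with the CLIP pooled output, matching how Flux conditions its transformer. A hedged construction sketch, assuming a diffusers-style repo layout (the repo id is illustrative):

    # Hypothetical: build the composed Flux text stack from one repo.
    from hcpdiff.models.compose import FluxTextEncoder, FluxTokenizer

    repo = 'black-forest-labs/FLUX.1-dev'  # assumed to contain tokenizer/tokenizer_2 and text_encoder/text_encoder_2
    tokenizer = FluxTokenizer.from_pretrained(repo)
    text_encoder = FluxTextEncoder.from_pretrained(repo)
    # forward() then yields (T5 hidden states, CLIP pooled output) once hooks are installed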
hcpdiff/models/compose/sdxl.py ADDED
@@ -0,0 +1,86 @@
+from .compose_textencoder import ComposeTextEncoder
+from .compose_tokenizer import ComposeTokenizer
+from transformers import CLIPTextModel, AutoTokenizer, CLIPTextModelWithProjection
+from typing import Optional, Union, Tuple, Dict
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+
+class CLIPTextModelWithProjection_Align(CLIPTextModelWithProjection):
+    # fxxk the transformers!
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        try:  # old versions of transformers
+            text_outputs = super().forward(input_ids, attention_mask, position_ids, output_attentions, output_hidden_states, return_dict)
+        except TypeError:  # newer versions of transformers (e.g. 4.53.1) removed 'return_dict'
+            text_outputs = super().forward(input_ids, attention_mask, position_ids, output_attentions, output_hidden_states)
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=text_outputs.last_hidden_state,
+            pooler_output=text_outputs.text_embeds,
+            hidden_states=text_outputs.hidden_states,
+            attentions=text_outputs.attentions,
+        )
+
+class SDXLTextEncoder(ComposeTextEncoder):
+    def forward(
+        self,
+        input_ids: Optional[Dict[str, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        output = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if self.with_hook:
+            encoder_hidden_states_dict, pooled_output_dict = output
+            encoder_hidden_states = torch.cat([encoder_hidden_states_dict['clip_L'], encoder_hidden_states_dict['clip_bigG']], dim=-1)
+            pooled_output = pooled_output_dict['clip_bigG']
+        else:
+            last_hidden_state = torch.cat((output['last_hidden_state']['clip_L'], output['last_hidden_state']['clip_bigG']), dim=-1)
+            pooler_output = output['pooler_output']['clip_bigG']
+            attentions = output['attentions']['clip_bigG']
+            if output['hidden_states']['clip_L'] is None:
+                hidden_states = None
+            else:
+                hidden_states = [torch.cat(states, dim=self.cat_dim) for states in zip(output['hidden_states']['clip_L'], output['hidden_states']['clip_bigG'])]
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=pooler_output,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )
+
+        return encoder_hidden_states, pooled_output
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
+        clip_L = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
+        clip_bigG = CLIPTextModelWithProjection_Align.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder_2', **kwargs)
+        return cls({'clip_L': clip_L, 'clip_bigG': clip_bigG})
+
+class SDXLTokenizer(ComposeTokenizer):
+    def __call__(self, text, *args, max_length=None, **kwargs):
+        token_info = super().__call__(text, *args, max_length=max_length, **kwargs)
+        token_info['attention_mask'] = token_info['attention_mask']['clip_L']
+        return token_info
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
+        clip_L = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
+        clip_bigG = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer_2', **kwargs)
+        return cls({'clip_L': clip_L, 'clip_bigG': clip_bigG})
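
Note: SDXLTextEncoder concatenates the two CLIP streams on the channel axis and takes the pooled vector from clip_bigG alone. A standalone shape check, assuming the usual SDXL hidden sizes:

    # Shape sketch of the SDXL concatenation (768- and 1280-dim streams assumed).
    import torch

    last_L = torch.randn(2, 77, 768)    # clip_L last_hidden_state
    last_G = torch.randn(2, 77, 1280)   # clip_bigG last_hidden_state
    encoder_hidden_states = torch.cat([last_L, last_G], dim=-1)
    assert encoder_hidden_states.shape == (2, 77, 2048)  # what the UNet cross-attention consumes
    pooled_output = torch.randn(2, 1280)  # pooled vector comes from clip_bigG only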
hcpdiff/models/text_emb_ex.py CHANGED
@@ -7,17 +7,18 @@ text_emb_ex.py
 :Created: 10/03/2023
 :Licence: Apache-2.0
 """
+import os
+from pathlib import Path
 from typing import Tuple, Dict, Any
 
 import torch
-from torch import nn
-import os
-from rainbowneko import _share
-from einops import rearrange, repeat
 import torch.nn.functional as F
+from einops import rearrange, repeat
+from rainbowneko import _share
+from rainbowneko.models.plugin import SinglePluginBlock
+from torch import nn
 
 from ..utils.net_utils import load_emb
-from rainbowneko.models.plugin import SinglePluginBlock
 
 class EmbeddingPTHook(SinglePluginBlock):
     def __init__(self, token_embedding:nn.Embedding, N_word=75, N_repeats=3):
@@ -74,7 +75,7 @@ class EmbeddingPTHook(SinglePluginBlock):
         self.handle_pre.remove()
 
     @classmethod
-    def hook(cls, ex_words_emb, tokenizer, text_encoder, **kwargs):
+    def hook(cls, ex_words_emb:Dict[str, nn.Parameter], tokenizer, text_encoder, **kwargs):
         word_list = list(ex_words_emb.keys())
         tokenizer.add_tokens(word_list)
         token_ids = tokenizer(' '.join(word_list)).input_ids[1:-1]
@@ -87,9 +88,12 @@ class EmbeddingPTHook(SinglePluginBlock):
         return embedding_hook
 
     @classmethod
-    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs):
-        ex_words_emb = {file[:-3]: nn.Parameter(load_emb(os.path.join(emb_dir, file)).to(device), requires_grad=False)
-                        for file in os.listdir(emb_dir) if file.endswith('.pt')}
+    def hook_from_dir(cls, emb_dir:str|Path, tokenizer, text_encoder, device='cuda', **kwargs):
+        if emb_dir is None:
+            ex_words_emb = {}
+        else:
+            emb_dir = Path(emb_dir)
+            ex_words_emb = {file.stem: nn.Parameter(load_emb(file).to(device), requires_grad=False) for file in emb_dir.glob('*.pt')}
         return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
 
 class EmbeddingPTInterpHook(SinglePluginBlock):
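
Note: hook_from_dir now tolerates emb_dir=None and globs *.pt via pathlib, naming each pseudo-word after the file stem. A hedged usage sketch (the directory and checkpoint id are placeholders):

    # Hypothetical: register custom-word embeddings from a directory.
    from transformers import CLIPTextModel, CLIPTokenizer
    from hcpdiff.models import EmbeddingPTHook  # import path as used elsewhere in hcpdiff

    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
    # every embs/*.pt becomes a token named after its file stem
    hook, ex_words_emb = EmbeddingPTHook.hook_from_dir('embs', tokenizer, text_encoder, device='cuda')
    # hook_from_dir(None, ...) is now valid and simply registers no extra words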
hcpdiff/models/textencoder_ex.py CHANGED
@@ -51,48 +51,18 @@ class TEEXHook:
         else:
             self.final_layer_norm = None
 
+    @property
+    def N_repeats(self):
+        return self.tokenizer.N_repeats
+
+    @N_repeats.setter
+    def N_repeats(self, value: int):
+        self.tokenizer.N_repeats = value
+
     @property
     def device(self):
         return self.text_enc.device
 
-    def encode_prompt_to_emb(self, prompt):
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length*self.N_repeats,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        if self.use_attention_mask:
-            attention_mask = text_inputs.get('attention_mask', None)
-        else:
-            attention_mask = None
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(self.device)
-        position_ids = text_inputs.get('position_ids', None)
-        if position_ids is not None:
-            position_ids = position_ids.to(self.device)
-
-        # align with sd-webui
-        if isinstance(self.text_enc, CLIPTextModelWithProjection):
-            self.text_enc.text_projection.weight.data = self.text_enc.text_projection.weight.data.t()
-
-        if isinstance(self.text_enc, T5EncoderModel):
-            prompt_embeds, pooled_output = self.text_enc(
-                text_input_ids.to(self.device),
-                attention_mask=attention_mask,
-                output_hidden_states=True,
-            )
-        else:
-            prompt_embeds, pooled_output = self.text_enc(
-                text_input_ids.to(self.device),
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                output_hidden_states=True,
-            )
-        return prompt_embeds, pooled_output, attention_mask
-
     def forward_hook_input(self, host, feat_in):
         feat_re = rearrange(feat_in[0], 'b (r w) -> (b r) w', r=self.N_repeats)  # make the attention mask size N_word+2
         return (feat_re,) if len(feat_in) == 1 else (feat_re, *feat_in[1:])
hcpdiff/models/wrapper/__init__.py CHANGED
@@ -1,3 +1,4 @@
 from .sd import SD15Wrapper, SDXLWrapper
 from .pixart import PixArtWrapper
-from .utils import TEHookCFG, SD15_TEHookCFG, SDXL_TEHookCFG
+from .utils import TEHookCFG, SD15_TEHookCFG, SDXL_TEHookCFG
+from .flux import FluxWrapper
hcpdiff/models/wrapper/flux.py ADDED
@@ -0,0 +1,75 @@
+import torch
+from diffusers import FluxTransformer2DModel, AutoencoderKL
+from einops import repeat, rearrange
+from hcpdiff.diffusion.sampler import BaseSampler
+from hcpdiff.utils import pad_attn_bias
+from rainbowneko.utils import add_dims
+
+from .sd import SD15Wrapper
+from .utils import TEHookCFG, SDXL_TEHookCFG
+from ..cfg_context import CFGContext
+
+class FluxWrapper(SD15Wrapper):
+    def __init__(self, denoiser: FluxTransformer2DModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 guidance=5.0, patch_size=2, TE_hook_cfg: TEHookCFG = SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'plugin_input -> plugin_input'))
+        self.guidance = guidance
+        self.patch_size = patch_size
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)
+        return encoder_hidden_states, pooled_output
+
+    def forward_denoiser(self, x_t, H, W, prompt_ids, encoder_hidden_states, pooled_output, timesteps, attn_mask=None, plugin_input={}, **kwargs):
+        attn_mask = attn_mask['T5']
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        img_ids = torch.zeros(H, W, 3)
+        img_ids[..., 1] = img_ids[..., 1]+torch.arange(H)[:, None]
+        img_ids[..., 2] = img_ids[..., 2]+torch.arange(W)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=x_t.shape[0])
+
+        txt_ids = torch.zeros(x_t.shape[0], encoder_hidden_states.shape[1], 3)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, attn_mask=attn_mask, img_ids=img_ids, txt_ids=txt_ids,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, self.guidance, pooled_output, encoder_hidden_states, img_ids=img_ids, txt_ids=txt_ids).sample
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, neg_prompt_ids=None, neg_attn_mask=None, plugin_input={}, **kwargs):
+        # input prepare
+        x_0 = self.get_latents(image)
+        B, C, H, W = x_0.shape
+        x_0_patch = rearrange(x_0, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=self.patch_size, pw=self.patch_size)
+        x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0_patch)
+        x_t_in = x_t*add_dims(self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype), x_t.ndim-1)
+        t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
+        x_t_in = rearrange(x_t_in, "b c h w -> b (h w) c")
+
+        if neg_prompt_ids:
+            prompt_ids = self.pn_cat(neg_prompt_ids, prompt_ids)
+        if neg_attn_mask:
+            attn_mask = self.pn_cat(neg_attn_mask, attn_mask)
+
+        # model forward
+        x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
+        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, plugin_input=plugin_input, **kwargs)
+        model_pred = self.forward_denoiser(x_t_in, H, W, prompt_ids, encoder_hidden_states, pooled_output, t_in, attn_mask=attn_mask,
+                                           plugin_input=plugin_input, **kwargs)
+        model_pred = rearrange(model_pred, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=self.patch_size, pw=self.patch_size, h=H, w=W)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
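
Note: model_forward patchifies the latent before noising, so the transformer sees an (h*w)-token sequence with c*ph*pw channels. A standalone check of the rearrange math (patch_size=2, toy sizes):

    # Shape sketch of the Flux patchify used above (toy sizes).
    import torch
    from einops import rearrange

    B, C, H, W, p = 2, 16, 64, 64, 2
    x_0 = torch.randn(B, C, H, W)
    x_patch = rearrange(x_0, 'b c (h ph) (w pw) -> b (c ph pw) h w', ph=p, pw=p)
    assert x_patch.shape == (B, C*p*p, H//p, W//p)
    tokens = rearrange(x_patch, 'b c h w -> b (h w) c')  # the sequence fed to the transformer
    assert tokens.shape == (B, (H//p)*(W//p), C*p*p)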
hcpdiff/models/wrapper/pixart.py CHANGED
@@ -2,6 +2,16 @@ from .sd import SD15Wrapper
 from hcpdiff.utils import pad_attn_bias
 
 class PixArtWrapper(SD15Wrapper):
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, plugin_input={}, **kwargs):
+        # The T5 encoder does not need position_ids (it uses relative position embeddings for key and query)
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states = self.TE(prompt_ids, attention_mask=attn_mask, output_hidden_states=True)[0]
+        return encoder_hidden_states
+
     def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, resolution=None, aspect_ratio=None,
                          plugin_input={}, **kwargs):
         if attn_mask is not None:
@@ -16,4 +26,6 @@ class PixArtWrapper(SD15Wrapper):
         added_cond_kwargs = {"resolution":resolution, "aspect_ratio":aspect_ratio}
         model_pred = self.denoiser(x_t, encoder_hidden_states, timesteps, encoder_attention_mask=attn_mask,
                                    added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
-        return model_pred
+
+        # remove the predicted variance from the PixArt output (see DiT for details)
+        return model_pred.chunk(2, dim=1)[0]
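
Note: like DiT, the PixArt transformer predicts noise and variance stacked along the channel axis, so only the first half is the epsilon estimate. A toy illustration:

    # Toy illustration of dropping the variance half of a DiT-style output.
    import torch

    latent_channels = 4
    model_out = torch.randn(2, 2*latent_channels, 32, 32)  # [noise_pred, variance_pred] on dim 1
    noise_pred = model_out.chunk(2, dim=1)[0]
    assert noise_pred.shape == (2, latent_channels, 32, 32)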
hcpdiff/models/wrapper/sd.py CHANGED
@@ -60,7 +60,10 @@ class SD15Wrapper(BaseWrapper):
         if image.shape[1] == 3:
             with torch.no_grad() if self.vae_trainable else nullcontext():
                 latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
-                latents = latents*self.vae.config.scaling_factor
+                if (shift_factor := getattr(self.vae.config, 'shift_factor', None)) is not None:
+                    latents = (latents-shift_factor)*self.vae.config.scaling_factor
+                else:
+                    latents = latents*self.vae.config.scaling_factor
         else:
             latents = image  # Cached latents
         return latents
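
Note: the shift-then-scale matches how VAEs that define shift_factor (e.g. the Flux/SD3 autoencoder) normalize latents; a decoder must apply the inverse z/scale + shift. A numeric sketch (factor values are illustrative):

    # Round-trip sketch of the shift/scale latent normalization.
    import torch

    scaling_factor, shift_factor = 0.3611, 0.1159  # illustrative, not read from a real config
    raw = torch.randn(2, 16, 32, 32)
    latents = (raw - shift_factor) * scaling_factor     # what get_latents returns
    restored = latents / scaling_factor + shift_factor  # what decoding must undo
    assert torch.allclose(restored, raw, atol=1e-6)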
@@ -87,6 +90,12 @@ class SD15Wrapper(BaseWrapper):
         model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample  # Predict the noise residual
         return model_pred
 
+    def pn_cat(self, neg, pos, dim=0):
+        if isinstance(pos, dict):  # ComposeTextEncoder
+            return {name:torch.cat([neg[name], pos_i], dim=dim) for name, pos_i in pos.items()}
+        else:
+            return torch.cat([neg, pos], dim=dim)
+
     def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
                       plugin_input={}, **kwargs):
         # input prepare
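
Note: pn_cat keeps negative/positive batching working for both plain tensors and the per-encoder dicts that ComposeTokenizer now emits. A standalone re-implementation for illustration:

    # Sketch: pn_cat on tensor vs dict inputs (re-implemented for clarity).
    import torch

    def pn_cat(neg, pos, dim=0):
        if isinstance(pos, dict):  # per-encoder ids from a composed tokenizer
            return {name: torch.cat([neg[name], p], dim=dim) for name, p in pos.items()}
        return torch.cat([neg, pos], dim=dim)

    pos = {'clip': torch.ones(1, 77, dtype=torch.long), 'T5': torch.ones(1, 512, dtype=torch.long)}
    neg = {'clip': torch.zeros(1, 77, dtype=torch.long), 'T5': torch.zeros(1, 512, dtype=torch.long)}
    out = pn_cat(neg, pos)
    assert out['clip'].shape == (2, 77) and out['T5'].shape == (2, 512)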
@@ -96,11 +105,11 @@ class SD15Wrapper(BaseWrapper):
         t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
 
         if neg_prompt_ids:
-            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+            prompt_ids = self.pn_cat(neg_prompt_ids, prompt_ids)
         if neg_attn_mask:
-            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+            attn_mask = self.pn_cat(neg_attn_mask, attn_mask)
         if neg_position_ids:
-            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+            position_ids = self.pn_cat(neg_position_ids, position_ids)
 
         # model forward
         x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
@@ -198,17 +207,17 @@ class SDXLWrapper(SD15Wrapper):
         t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
 
         if neg_prompt_ids:
-            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+            prompt_ids = self.pn_cat(neg_prompt_ids, prompt_ids)
         if neg_attn_mask:
-            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+            attn_mask = self.pn_cat(neg_attn_mask, attn_mask)
         if neg_position_ids:
-            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+            position_ids = self.pn_cat(neg_position_ids, position_ids)
 
         # model forward
         x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
         encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
                                                                plugin_input=plugin_input)
-        added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
+        added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
         model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, added_cond_kwargs=added_cond_kwargs,
                                            attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
         model_pred = self.cfg_context.post(model_pred)
hcpdiff/parser/embpt.py CHANGED
@@ -1,7 +1,7 @@
 from typing import Dict, Tuple, List
 from rainbowneko.utils import Path_Like
-from hcpdiff.models import EmbeddingPTHook
-from torch import Tensor
+from hcpdiff.models.compose import ComposeEmbPTHook
+from torch import Tensor, nn
 
 class CfgEmbPTParser:
     def __init__(self, emb_dir: Path_Like, cfg_pt: Dict[str, Dict], lr: float = 1e-5, weight_decay: float = 0):
@@ -11,22 +11,22 @@ class CfgEmbPTParser:
         self.weight_decay = weight_decay
 
     def get_params_group(self, model) -> Tuple[List, Dict[str, Tensor]]:
-        self.embedding_hook, self.ex_words_emb = EmbeddingPTHook.hook_from_dir(
+        self.embedding_hook, self.ex_words_emb = ComposeEmbPTHook.hook_from_dir(
             self.emb_dir, model.tokenizer, model.TE, N_repeats=model.tokenizer.N_repeats)
         self.embedding_hook.requires_grad_(False)
 
         train_params_emb = []
         train_pts = {}
         for pt_name, info in self.cfg_pt.items():
-            word_emb = self.ex_words_emb[pt_name]
+            word_emb: nn.Parameter | nn.ParameterDict = self.ex_words_emb[pt_name]
             train_pts[pt_name] = word_emb
-            word_emb.requires_grad = True
+            word_emb.requires_grad_(True)
             self.embedding_hook.emb_train.append(word_emb)
-            param_group = {'params':word_emb}
+            param_group = {'params':word_emb.parameters() if hasattr(word_emb, 'parameters') else [word_emb]}
             if 'lr' in info:
                 param_group['lr'] = info.lr
             if 'weight_decay' in info:
                 param_group['weight_decay'] = info.weight_decay
             train_params_emb.append(param_group)
 
-        return train_params_emb, train_pts
+        return train_params_emb, train_pts
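
Note: ex_words_emb entries may now be nn.ParameterDict objects (one sub-embedding per text encoder), which is why the param_group construction branches on parameters(). A sketch of the assumed layout:

    # Sketch: optimizer group for a per-encoder embedding dict (assumed layout).
    import torch
    from torch import nn

    word_emb = nn.ParameterDict({
        'clip_L': nn.Parameter(torch.randn(4, 768)),
        'clip_bigG': nn.Parameter(torch.randn(4, 1280)),
    })
    params = list(word_emb.parameters()) if hasattr(word_emb, 'parameters') else [word_emb]
    optimizer = torch.optim.AdamW([{'params': params, 'lr': 1e-3}])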
hcpdiff/utils/net_utils.py CHANGED
@@ -1,15 +1,16 @@
+import json
 import os
 from copy import deepcopy
+from functools import partial
 from typing import Optional, Union
 
 import torch
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
+from huggingface_hub import hf_hub_download
 from torch import nn
 from torch.optim import lr_scheduler
 from transformers import PretrainedConfig, AutoTokenizer, T5EncoderModel, CLIPTextModel
-from functools import partial
-from huggingface_hub import hf_hub_download
-import json
+from transformers.models.auto.tokenization_auto import get_tokenizer_config
 
 dtype_dict = {'fp32':torch.float32, 'amp':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
 
@@ -91,19 +92,24 @@ def get_scheduler_with_name(
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **scheduler_kwargs)
 
 def auto_tokenizer_cls(pretrained_model_name_or_path: str, revision: str = None):
-    from hcpdiff.models.compose import SDXLTokenizer
+    from hcpdiff.models.compose import SDXLTokenizer, FluxTokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            pretrained_model_name_or_path, subfolder="tokenizer_2",
-            revision=revision, use_fast=False,
+        tokenizer_config = get_tokenizer_config(
+            pretrained_model_name_or_path,
+            subfolder="tokenizer_2",
+            revision=revision
         )
-        return SDXLTokenizer
+        class_name = tokenizer_config.get("tokenizer_class")
+        if class_name == 'T5Tokenizer' or class_name == 'T5TokenizerFast':
+            return FluxTokenizer
+        else:
+            return SDXLTokenizer
     except:
-        # not sdxl, only one tokenizer
+        # not composed, only one tokenizer
         return AutoTokenizer
 
 def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = None):
-    from hcpdiff.models.compose import SDXLTextEncoder
+    from hcpdiff.models.compose import SDXLTextEncoder, FluxTextEncoder
     try:
         text_encoder_config = PretrainedConfig.from_pretrained(
             pretrained_model_name_or_path,
@@ -112,7 +118,11 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = None):
         )
         if text_encoder_config.architectures is None:
             raise ValueError()
-        return SDXLTextEncoder
+        model_class = text_encoder_config.architectures[0]
+        if model_class == "T5EncoderModel":
+            return FluxTextEncoder
+        else:
+            return SDXLTextEncoder
     except:
         text_encoder_config = PretrainedConfig.from_pretrained(
             pretrained_model_name_or_path,
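
Note: class selection now keys off the tokenizer_2/text_encoder_2 config: T5 there means a Flux-style pair, CLIP means SDXL, and a missing subfolder falls back to a single encoder. A hedged usage sketch (the repo id is illustrative and resolving it needs network access):

    # Sketch: resolving composed classes from a repo layout.
    from hcpdiff.utils.net_utils import auto_tokenizer_cls, auto_text_encoder_cls

    repo = 'stabilityai/stable-diffusion-xl-base-1.0'
    tokenizer_cls = auto_tokenizer_cls(repo)        # tokenizer_2 is CLIP -> SDXLTokenizer
    text_encoder_cls = auto_text_encoder_cls(repo)  # -> SDXLTextEncoder
    # a Flux repo (T5 in tokenizer_2) resolves to FluxTokenizer/FluxTextEncoder;
    # a plain SD1.5 repo (no tokenizer_2) falls back to AutoTokenizer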
@@ -248,4 +258,4 @@ def get_dtype(dtype):
     if isinstance(dtype, torch.dtype):
         return dtype
     else:
-        return dtype_dict.get(dtype, torch.float32)
+        return dtype_dict.get(dtype, torch.float32)
hcpdiff/workflow/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
+from .diffusion import InputFeederAction, MakeLatentAction, SD15DenoiseAction, SDXLDenoiseAction, PixartDenoiseAction, FluxDenoiseAction, SampleAction, DiffusionStepAction, \
     X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter, DiffusionActions
 from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
 from .vae import EncodeAction, DecodeAction