hcpdiff 2.3.1__py3-none-any.whl → 2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. hcpdiff/ckpt_manager/__init__.py +1 -1
  2. hcpdiff/ckpt_manager/format/__init__.py +2 -2
  3. hcpdiff/ckpt_manager/format/diffusers.py +19 -4
  4. hcpdiff/ckpt_manager/format/emb.py +8 -3
  5. hcpdiff/ckpt_manager/format/lora_webui.py +1 -1
  6. hcpdiff/ckpt_manager/format/sd_single.py +28 -5
  7. hcpdiff/data/cache/vae.py +10 -2
  8. hcpdiff/data/handler/text.py +15 -14
  9. hcpdiff/diffusion/sampler/__init__.py +2 -1
  10. hcpdiff/diffusion/sampler/base.py +17 -6
  11. hcpdiff/diffusion/sampler/diffusers.py +4 -3
  12. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +5 -14
  13. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +7 -6
  14. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +4 -4
  15. hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +3 -3
  16. hcpdiff/diffusion/sampler/timer/__init__.py +2 -0
  17. hcpdiff/diffusion/sampler/timer/base.py +26 -0
  18. hcpdiff/diffusion/sampler/timer/shift.py +49 -0
  19. hcpdiff/easy/__init__.py +2 -1
  20. hcpdiff/easy/cfg/sd15_train.py +1 -3
  21. hcpdiff/easy/model/__init__.py +1 -1
  22. hcpdiff/easy/model/loader.py +33 -11
  23. hcpdiff/easy/sampler.py +8 -1
  24. hcpdiff/loss/__init__.py +4 -3
  25. hcpdiff/loss/charbonnier.py +17 -0
  26. hcpdiff/loss/vlb.py +2 -2
  27. hcpdiff/loss/weighting.py +29 -11
  28. hcpdiff/models/__init__.py +1 -1
  29. hcpdiff/models/cfg_context.py +5 -3
  30. hcpdiff/models/compose/__init__.py +2 -1
  31. hcpdiff/models/compose/compose_hook.py +69 -67
  32. hcpdiff/models/compose/compose_textencoder.py +59 -45
  33. hcpdiff/models/compose/compose_tokenizer.py +48 -11
  34. hcpdiff/models/compose/flux.py +75 -0
  35. hcpdiff/models/compose/sdxl.py +86 -0
  36. hcpdiff/models/text_emb_ex.py +13 -9
  37. hcpdiff/models/textencoder_ex.py +8 -38
  38. hcpdiff/models/wrapper/__init__.py +2 -1
  39. hcpdiff/models/wrapper/flux.py +75 -0
  40. hcpdiff/models/wrapper/pixart.py +13 -1
  41. hcpdiff/models/wrapper/sd.py +17 -8
  42. hcpdiff/parser/embpt.py +7 -7
  43. hcpdiff/utils/net_utils.py +22 -12
  44. hcpdiff/workflow/__init__.py +1 -1
  45. hcpdiff/workflow/diffusion.py +145 -18
  46. hcpdiff/workflow/text.py +49 -18
  47. hcpdiff/workflow/vae.py +10 -2
  48. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/METADATA +1 -1
  49. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/RECORD +53 -49
  50. hcpdiff/models/compose/sdxl_composer.py +0 -39
  51. hcpdiff/utils/inpaint_pipe.py +0 -790
  52. hcpdiff/utils/pipe_hook.py +0 -656
  53. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/WHEEL +0 -0
  54. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/entry_points.txt +0 -0
  55. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/licenses/LICENSE +0 -0
  56. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/top_level.txt +0 -0
hcpdiff/easy/model/loader.py CHANGED
@@ -1,10 +1,12 @@
  import torch
- from hcpdiff.ckpt_manager import DiffusersSD15Format, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSD15Format, OfficialSDXLFormat
+ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline
  from rainbowneko.ckpt_manager import NekoLoader, LocalCkptSource
- from hcpdiff.utils import auto_tokenizer_cls, auto_text_encoder_cls, get_pipe_name
- from hcpdiff.models.wrapper import SDXLWrapper, SD15Wrapper, PixArtWrapper
- from hcpdiff.models.compose import SDXLTextEncoder
- from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
+
+ from hcpdiff.ckpt_manager import DiffusersSD15Format, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSD15Format, OfficialSDXLFormat, \
+     DiffusersFluxFormat, OneFileFluxFormat
+ from hcpdiff.models.compose import SDXLTextEncoder, FluxTextEncoder
+ from hcpdiff.models.wrapper import SDXLWrapper, SD15Wrapper, PixArtWrapper, FluxWrapper
+ from hcpdiff.utils import auto_text_encoder_cls, get_pipe_name
 
  def SD15_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
                       tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
@@ -20,7 +22,7 @@ def SD15_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=
          source=LocalCkptSource(),
      )
      models = loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
-                          dtype=dtype, **kwargs)
+                          dtype=dtype, **kwargs)
      return models
 
  def SDXL_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
@@ -37,17 +39,34 @@ def SDXL_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=
          source=LocalCkptSource(),
      )
      models = loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
-                          dtype=dtype, **kwargs)
+                          dtype=dtype, **kwargs)
      return models
 
  def PixArt_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
-                        tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+                        tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
      loader = NekoLoader(
          format=DiffusersPixArtFormat(),
          source=LocalCkptSource(),
      )
      models = loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
-                          dtype=dtype, **kwargs)
+                          dtype=dtype, **kwargs)
+     return models
+
+ def Flux_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
+                      tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+     try:
+         try_diffusers = FluxPipeline.load_config(ckpt_path)
+         loader = NekoLoader(
+             format=DiffusersFluxFormat(),
+             source=LocalCkptSource(),
+         )
+     except EnvironmentError:
+         loader = NekoLoader(
+             format=OneFileFluxFormat(),
+             source=LocalCkptSource(),
+         )
+     models = loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
+                          dtype=dtype, **kwargs)
      return models
 
  def auto_load_wrapper(pretrained_model, denoiser=None, TE=None, vae=None, noise_sampler=None, tokenizer=None, revision=None,
@@ -62,6 +81,9 @@ def auto_load_wrapper(pretrained_model, denoiser=None, TE=None, vae=None, noise_
      if text_encoder_cls == SDXLTextEncoder:
          wrapper_cls = SDXLWrapper
          format = DiffusersSDXLFormat()
+     elif text_encoder_cls == FluxTextEncoder:
+         wrapper_cls = FluxWrapper
+         format = DiffusersFluxFormat()
      elif 'PixArt' in pipe_name:
          wrapper_cls = PixArtWrapper
          format = DiffusersPixArtFormat()
@@ -74,6 +96,6 @@ def auto_load_wrapper(pretrained_model, denoiser=None, TE=None, vae=None, noise_
          source=LocalCkptSource(),
      )
      models = loader.load(pretrained_model, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
-                          dtype=dtype)
+                          dtype=dtype)
 
-     return wrapper_cls.build_from_pretrained(models, **kwargs)
+     return wrapper_cls.build_from_pretrained(models, **kwargs)
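The new Flux_auto_loader probes ckpt_path with FluxPipeline.load_config and falls back to OneFileFluxFormat when the path is not a Diffusers folder layout. A minimal usage sketch, assuming hcpdiff.easy re-exports the loader in 2.4 alongside the existing SD15/SDXL ones (the checkpoint path is a placeholder):

    import torch
    from hcpdiff.easy import Flux_auto_loader  # assumed re-export, like the SD15/SDXL loaders

    # The same call works for a Diffusers folder or a single-file checkpoint;
    # the format is picked by the load_config probe shown in the diff above.
    models = Flux_auto_loader('ckpts/flux1-dev', dtype=torch.bfloat16)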
hcpdiff/easy/sampler.py CHANGED
@@ -1,5 +1,5 @@
  from hcpdiff.diffusion.sampler import DiffusersSampler
- from diffusers import DPMSolverMultistepScheduler, DDIMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
+ from diffusers import DPMSolverMultistepScheduler, DDIMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, FlowMatchEulerDiscreteScheduler
 
  class Diffusers_SD:
      dpmpp_2m = DiffusersSampler(
@@ -43,4 +43,11 @@ class Diffusers_SD:
              beta_end=0.012,
              beta_schedule='scaled_linear',
          )
+     )
+
+     euler_flow = DiffusersSampler(
+         FlowMatchEulerDiscreteScheduler(
+             shift=3.0,
+             use_dynamic_shifting=True,
+         )
      )
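The new euler_flow preset wires diffusers' FlowMatchEulerDiscreteScheduler into the same preset table, so flow-matching models (e.g. Flux) get a ready-made sampler next to the DDPM-family presets. Presets are plain class attributes, so selection is a one-line sketch:

    from hcpdiff.easy.sampler import Diffusers_SD

    sampler = Diffusers_SD.euler_flow  # flow-matching preset: shift=3.0, dynamic shifting on
    # sampler = Diffusers_SD.dpmpp_2m  # the DDPM-family presets are unchanged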
hcpdiff/loss/__init__.py CHANGED
@@ -1,4 +1,5 @@
- from .weighting import MinSNRWeight, SNRWeight, EDMWeight, LossWeight
- from .ssim import SSIMLoss, MS_SSIMLoss
+ from .base import DiffusionLossContainer
+ from .charbonnier import CharbonnierLoss
  from .gw import GWLoss
- from .base import DiffusionLossContainer
+ from .ssim import SSIMLoss, MS_SSIMLoss
+ from .weighting import MinSNRWeight, SNRWeight, EDMWeight, LossWeight, LossMapWeight
hcpdiff/loss/charbonnier.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ from torch import nn
+
+ class CharbonnierLoss(nn.Module):
+     """Charbonnier Loss (L1)"""
+
+     def __init__(self, eps=1e-3, size_average=True):
+         super(CharbonnierLoss, self).__init__()
+         self.eps = eps
+         self.size_average = size_average
+
+     def forward(self, x, y):
+         diff = x - y
+         loss = torch.sqrt((diff * diff) + (self.eps*self.eps))
+         if self.size_average:
+             loss = loss.mean()
+         return loss
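Charbonnier is a smoothed L1 penalty: sqrt(d^2 + eps^2) is quadratic for |d| much smaller than eps and linear for |d| much larger, keeping L1's robustness without the kink at zero. Since the updated hcpdiff/loss/__init__.py exports it, usage is a one-liner; a minimal sketch with random tensors:

    import torch
    from hcpdiff.loss import CharbonnierLoss

    crit = CharbonnierLoss(eps=1e-3)  # size_average=True reduces to a scalar mean
    loss = crit(torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8))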
hcpdiff/loss/vlb.py CHANGED
@@ -25,10 +25,10 @@ class VLBLoss(nn.Module):
          x0_pred = sampler.eps_to_x0(eps_pred, x_t, sigma)
 
          true_mean = sampler.sigma_scheduler.get_post_mean(timesteps, target, x_t)
-         true_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps)
+         true_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps, ndim=input.ndim)
 
          pred_mean = sampler.sigma_scheduler.get_post_mean(timesteps, x0_pred, x_t)
-         pred_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps, x_t_var=var_pred)
+         pred_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps, ndim=input.ndim, x_t_var=var_pred)
 
          kl = self.normal_kl(true_mean, true_logvar, pred_mean, pred_logvar)
          kl = kl.mean(dim=(1,2,3))/np.log(2.0)
hcpdiff/loss/weighting.py CHANGED
@@ -1,9 +1,10 @@
+ from rainbowneko.utils import add_dims
+ from rainbowneko.train.loss import FullInputLoss
  from torch import nn
+ from typing import Callable
 
- from .base import DiffusionLossContainer
-
- class LossWeight(nn.Module):
-     def __init__(self, loss: DiffusionLossContainer):
+ class LossWeight(nn.Module, FullInputLoss):
+     def __init__(self, loss: Callable):
          super().__init__()
          self.loss = loss
 
@@ -21,12 +22,29 @@ class LossWeight(nn.Module):
          '''
          raise NotImplementedError
 
-     def forward(self, pred, inputs):
+     def forward(self, pred, inputs, _full_pred, _full_inputs):
          '''
          weight: [B,1,1,1] or [B,C,H,W]
          loss: [B,*,*,*]
          '''
-         return self.get_weight(pred, inputs)*self.loss(pred, inputs)
+         return self.get_weight(_full_pred, _full_inputs)*self.loss(pred, inputs)
+
+ class LossMapWeight(LossWeight):
+     def __init__(self, loss: Callable, normalize: bool = False):
+         super().__init__(loss)
+         self.normalize = normalize
+
+     def get_weight(self, pred, inputs):
+         ndim = pred['model_pred'].ndim
+         loss_map = inputs['loss_map'].float()
+         if ndim == 4:
+             if self.normalize:
+                 loss_map /= loss_map.mean(dim=(1,2), keepdim=True)
+             return loss_map.unsqueeze(1)
+         elif ndim == 3:
+             if self.normalize:
+                 loss_map /= loss_map.mean(dim=1, keepdim=True)
+             return loss_map.unsqueeze(-1)
 
  class SNRWeight(LossWeight):
      def get_weight(self, pred, inputs):
@@ -42,10 +60,10 @@ class SNRWeight(LossWeight):
          else:
              raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
 
-         return w_snr.view(-1, 1, 1, 1)
+         return add_dims(w_snr, pred['model_pred'].ndim-1)
 
  class MinSNRWeight(LossWeight):
-     def __init__(self, loss: DiffusionLossContainer, gamma: float = 1.):
+     def __init__(self, loss: Callable, gamma: float = 1.):
          super().__init__(loss)
          self.gamma = gamma
 
@@ -63,10 +81,10 @@ class MinSNRWeight(LossWeight):
          else:
              raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
 
-         return w_snr.view(-1, 1, 1, 1)
+         return add_dims(w_snr, pred['model_pred'].ndim-1)
 
  class EDMWeight(LossWeight):
-     def __init__(self, loss: DiffusionLossContainer, gamma: float = 1.):
+     def __init__(self, loss: Callable, gamma: float = 1.):
          super().__init__(loss)
          self.gamma = gamma
 
@@ -81,4 +99,4 @@ class EDMWeight(LossWeight):
          else:
              raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
 
-         return w_snr.view(-1, 1, 1, 1)
+         return add_dims(w_snr, pred['model_pred'].ndim-1)
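The recurring w_snr.view(-1, 1, 1, 1) → add_dims(w_snr, pred['model_pred'].ndim-1) change removes the hard-coded assumption of 4-D UNet latents, so the same per-sample weight broadcasts over 3-D token-sequence predictions (DiT/Flux) as well. The real helper lives in rainbowneko.utils; a plausible equivalent, shown only to pin down the shape semantics (an assumption, not the library's code):

    import torch

    def add_dims(t: torch.Tensor, n: int) -> torch.Tensor:
        # Append n trailing singleton dims: [B] -> [B, 1, ..., 1],
        # so a per-sample weight broadcasts over any prediction shape.
        return t.reshape(t.shape + (1,)*n)

    w = torch.rand(8)                    # per-sample SNR weight
    pred4d = torch.randn(8, 4, 64, 64)   # UNet latents
    pred3d = torch.randn(8, 1024, 16)    # token sequence
    assert (add_dims(w, pred4d.ndim-1)*pred4d).shape == pred4d.shape
    assert (add_dims(w, pred3d.ndim-1)*pred3d).shape == pred3d.shape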
hcpdiff/models/__init__.py CHANGED
@@ -6,5 +6,5 @@ from .text_emb_ex import EmbeddingPTHook
  from .textencoder_ex import TEEXHook
  from .tokenizer_ex import TokenizerHook
  from .cfg_context import CFGContext, DreamArtistPTContext
- from .wrapper import SD15Wrapper, SDXLWrapper, PixArtWrapper, TEHookCFG
+ from .wrapper import SD15Wrapper, SDXLWrapper, PixArtWrapper, TEHookCFG, FluxWrapper
  from .controlnet import ControlNetPlugin
hcpdiff/models/cfg_context.py CHANGED
@@ -1,8 +1,10 @@
- import torch
- from einops import repeat
  import math
  from typing import Union, Callable
 
+ import torch
+ from einops import repeat
+ from rainbowneko.utils import add_dims
+
  class CFGContext:
      def pre(self, noisy_latents, timesteps):
          return noisy_latents, timesteps
@@ -35,7 +37,7 @@ class DreamArtistPTContext(CFGContext):
                  pass
              else:
                  rate = self.cfg_func(rate)
-                 rate = rate.view(-1, 1, 1, 1)
+                 rate = add_dims(rate, model_pred.ndim-1)
          else:
              rate = 1
          model_pred = e_t_uncond+((self.cfg_high-self.cfg_low)*rate+self.cfg_low)*(e_t-e_t_uncond)
hcpdiff/models/compose/__init__.py CHANGED
@@ -1,4 +1,5 @@
  from .compose_tokenizer import ComposeTokenizer
  from .compose_textencoder import ComposeTextEncoder
  from .compose_hook import ComposeTEEXHook, ComposeEmbPTHook
- from .sdxl_composer import SDXLTokenizer, SDXLTextEncoder
+ from .sdxl import SDXLTokenizer, SDXLTextEncoder
+ from .flux import FluxTokenizer, FluxTextEncoder
hcpdiff/models/compose/compose_hook.py CHANGED
@@ -1,128 +1,129 @@
- import os
- from typing import Dict, Union, Tuple, List
+ from pathlib import Path
+ from typing import Dict, Union, Tuple
 
  import torch
- from loguru import logger
  from torch import nn
 
+ from hcpdiff.utils.net_utils import load_emb
  from .compose_textencoder import ComposeTextEncoder
  from ..text_emb_ex import EmbeddingPTHook
  from ..textencoder_ex import TEEXHook
- from ...utils.net_utils import load_emb
- from ..container import ParameterGroup
 
  class ComposeEmbPTHook(nn.Module):
-     def __init__(self, hook_list: List[Tuple[str, EmbeddingPTHook]]):
+     def __init__(self, hooks: Dict[str, EmbeddingPTHook]):
          super().__init__()
-         self.hook_list = hook_list
-         self.emb_train = nn.ParameterList()
+         self.hooks = hooks
+         self.emb_train = nn.ParameterList()  # [ParameterDict{model_name:Parameter, ...}, ...]
 
      @property
      def N_repeats(self):
-         return self.hook_list[0][1].N_repeats
+         return {name:hook.N_repeats for name, hook in self.hooks.items()}
 
      @N_repeats.setter
      def N_repeats(self, value):
-         for name, hook in self.hook_list:
-             hook.N_repeats = value
+         for name, hook in self.hooks.items():
+             if isinstance(value, int):
+                 hook.N_repeats = value
+             else:
+                 hook.N_repeats = value[name]
 
-     def add_emb(self, emb: nn.Parameter, token_id_list: List[int]):
-         emb_len = 0
+     def add_emb(self, emb: Dict[str, nn.Parameter], token_ids: Dict[str, int]):
          # Same word in different tokenizer may have different token_id
-         for (name, hook), token_id in zip(self.hook_list, token_id_list):
-             hook.add_emb(emb[:, emb_len:emb_len+hook.embedding_dim], token_id)
-             emb_len += hook.embedding_dim
+         for name, hook in self.hooks.items():
+             hook.add_emb(emb[name], token_ids[name])
 
      def remove(self):
-         for name, hook in self.hook_list:
+         for name, hook in self.hooks.items():
              hook.remove()
 
      @classmethod
-     def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder, **kwargs):
+     def hook(cls, ex_words_emb: Dict[str, nn.ParameterDict], tokenizer, text_encoder, **kwargs):
          if isinstance(text_encoder, ComposeTextEncoder):
-             hook_list = []
+             hooks = {}
 
              emb_len = 0
-             for i, name in enumerate(tokenizer.tokenizer_names):
+             for name in tokenizer.tokenizer_names:
                  text_encoder_i = getattr(text_encoder, name)
                  tokenizer_i = getattr(tokenizer, name)
                  embedding_dim = text_encoder_i.get_input_embeddings().embedding_dim
-                 ex_words_emb_i = {k:v[i] for k, v in ex_words_emb.items()}
+                 ex_words_emb_i = {k:v[name] for k, v in ex_words_emb.items()}  # {word_name:Parameter, ...}
                  emb_len += embedding_dim
-                 hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, **kwargs)))
+                 hooks[name] = EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, **kwargs)
 
-             return cls(hook_list)
+             return cls(hooks)
          else:
              return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder, **kwargs)
 
      @classmethod
-     def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs) -> Union[
-         Tuple['ComposeEmbPTHook', Dict], Tuple[EmbeddingPTHook, Dict]]:
+     def hook_from_dir(cls, emb_dir: str | Path, tokenizer, text_encoder, device='cuda', **kwargs) -> (
+             Tuple['ComposeEmbPTHook', Dict[str, nn.ParameterDict]] | Tuple[EmbeddingPTHook, Dict[str, nn.Parameter]]):
+         emb_dir = Path(emb_dir) if emb_dir is not None else None
          if isinstance(text_encoder, ComposeTextEncoder):
              # multi text encoder
-             # ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}
-
-             # slice of nn.Parameter cannot return grad. Split the tensor
-             ex_words_emb = {}
-             if emb_dir is not None and os.path.exists(emb_dir):
-                 emb_dims = [x.embedding_dim for x in text_encoder.get_input_embeddings()]
-                 for file in os.listdir(emb_dir):
-                     if file.endswith('.pt'):
-                         emb = load_emb(os.path.join(emb_dir, file)).to(device)
-                         emb = ParameterGroup([nn.Parameter(item, requires_grad=False) for item in emb.split(emb_dims, dim=1)])
-                         ex_words_emb[file[:-3]] = emb
+             ex_words_emb = {}  # {word_name:{model_name:Tensor, ...}, ...}
+             if emb_dir is not None and emb_dir.exists():
+                 for file in emb_dir.glob('*.pt'):
+                     emb = load_emb(file)  # {model_name:Tensor, ...}
+                     emb = nn.ParameterDict({name:nn.Parameter(emb_i.to(device), requires_grad=False) for name, emb_i in emb.items()})
+                     ex_words_emb[file.stem] = emb
              return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
          else:
             return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder, **kwargs)
 
  class ComposeTEEXHook:
-     def __init__(self, tehook_list: List[Tuple[str, TEEXHook]], cat_dim=-1):
-         self.tehook_list = tehook_list
-         self.cat_dim = cat_dim
+     def __init__(self, tehooks: Dict[str, TEEXHook]):
+         self.tehooks = tehooks
 
      @property
      def N_repeats(self):
-         return self.tehook_list[0][1].N_repeats
+         return {name:tehook.N_repeats for name, tehook in self.tehooks.items()}
 
      @N_repeats.setter
-     def N_repeats(self, value):
-         for name, tehook in self.tehook_list:
-             tehook.N_repeats = value
+     def N_repeats(self, value: int | Dict[str, int]):
+         for name, tehook in self.tehooks.items():
+             if isinstance(value, int):
+                 tehook.N_repeats = value
+             else:
+                 tehook.N_repeats = value[name]
 
      @property
      def clip_skip(self):
-         return self.tehook_list[0][1].clip_skip
+         return {name:tehook.clip_skip for name, tehook in self.tehooks.items()}
 
      @clip_skip.setter
-     def clip_skip(self, value):
-         for name, tehook in self.tehook_list:
-             tehook.clip_skip = value
+     def clip_skip(self, value: int | Dict[str, int]):
+         for name, tehook in self.tehooks.items():
+             if isinstance(value, int):
+                 tehook.clip_skip = value
+             else:
+                 tehook.clip_skip = value[name]
 
      @property
      def clip_final_norm(self):
-         return self.tehook_list[0][1].clip_final_norm
+         return {name:tehook.clip_final_norm for name, tehook in self.tehooks.items()}
 
      @clip_final_norm.setter
-     def clip_final_norm(self, value: bool):
-         for name, tehook in self.tehook_list:
-             tehook.clip_final_norm = value
+     def clip_final_norm(self, value: bool | Dict[str, bool]):
+         for name, tehook in self.tehooks.items():
+             if isinstance(value, bool):
+                 tehook.clip_final_norm = value
+             else:
+                 tehook.clip_final_norm = value[name]
 
      @property
      def use_attention_mask(self):
-         return self.tehook_list[0][1].use_attention_mask
+         return {name:tehook.use_attention_mask for name, tehook in self.tehooks.items()}
 
      @use_attention_mask.setter
-     def use_attention_mask(self, value: bool):
-         for name, tehook in self.tehook_list:
-             tehook.use_attention_mask = value
-
-     def encode_prompt_to_emb(self, prompt):
-         emb_list = [tehook.encode_prompt_to_emb(prompt) for name, tehook in self.tehook_list]
-         encoder_hidden_states, pooled_output, attention_mask = list(zip(*emb_list))
-         return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output, attention_mask[0]
+     def use_attention_mask(self, value: bool | Dict[str, bool]):
+         for name, tehook in self.tehooks.items():
+             if isinstance(value, bool):
+                 tehook.use_attention_mask = value
+             else:
+                 tehook.use_attention_mask = value[name]
 
      def enable_xformers(self):
-         for name, tehook in self.tehook_list:
+         for name, tehook in self.tehooks.items():
              tehook.enable_xformers()
 
      @staticmethod
@@ -130,19 +131,20 @@ class ComposeTEEXHook:
          return TEEXHook.mult_attn(prompt_embeds, attn_mult)
 
      @classmethod
-     def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False) -> Union[
+     def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=1, clip_skip=0, clip_final_norm=True, use_attention_mask=False) -> Union[
          'ComposeTEEXHook', TEEXHook]:
          if isinstance(text_enc, ComposeTextEncoder):
              # multi text encoder
-             tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name), getattr(tokenizer, name), N_repeats, clip_skip, clip_final_norm,
-                                                 use_attention_mask=use_attention_mask))
-                            for name in tokenizer.tokenizer_names]
-             return cls(tehook_list)
+             get_data = lambda name, data:data[name] if isinstance(data, dict) else data
+             tehooks = {name:TEEXHook.hook(getattr(text_enc, name), getattr(tokenizer, name), get_data(name, N_repeats), get_data(name, clip_skip),
+                                           get_data(name, clip_final_norm), use_attention_mask=get_data(name, use_attention_mask))
+                        for name in tokenizer.tokenizer_names}
+             return cls(tehooks)
          else:
              # single text encoder
              return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip, clip_final_norm, use_attention_mask=use_attention_mask)
 
      @classmethod
-     def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
+     def hook_pipe(cls, pipe, N_repeats=1, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
          return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
                          use_attention_mask=use_attention_mask)
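ComposeTEEXHook now keys sub-hooks by encoder name: every property returns a per-encoder dict, and every setter accepts either a scalar (broadcast to all encoders) or a dict. A sketch against an already-hooked SDXL pipeline, reusing the clip_B/clip_bigG names from the ComposeTextEncoder docstring (pipe is a hypothetical loaded pipeline):

    from hcpdiff.models.compose import ComposeTEEXHook

    te_hook = ComposeTEEXHook.hook_pipe(pipe)          # note: N_repeats now defaults to 1
    te_hook.N_repeats = 2                              # scalar: applied to every encoder
    te_hook.clip_skip = {'clip_B': 0, 'clip_bigG': 2}  # dict: one value per encoder
    print(te_hook.clip_skip)                           # {'clip_B': 0, 'clip_bigG': 2}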
hcpdiff/models/compose/compose_textencoder.py CHANGED
@@ -13,24 +13,24 @@ from typing import Dict, Optional, Union, Tuple, List
 
  import torch
  from torch import nn
- from transformers import CLIPTextModel, PreTrainedModel, PretrainedConfig
+ from transformers import CLIPTextModel, PreTrainedModel, PretrainedConfig, AutoModel
  from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from rainbowneko.utils import BatchableDict
 
  class ComposeTextEncoder(PreTrainedModel):
-     def __init__(self, model_list: List[Tuple[str, CLIPTextModel]], cat_dim=-1, with_hook=True):
-         super().__init__(PretrainedConfig(**{name:model.config for name, model in model_list}))
-         self.cat_dim = cat_dim
+     def __init__(self, models: Dict[str, PreTrainedModel], with_hook=True):
+         super().__init__(PretrainedConfig(**{name:model.config for name, model in models.items()}))
          self.with_hook = with_hook
 
          self.model_names = []
-         for name, model in model_list:
+         for name, model in models.items():
              setattr(self, name, model)
              self.model_names.append(name)
 
-     def get_input_embeddings(self) -> List[nn.Module]:
-         return [getattr(self, name).get_input_embeddings() for name in self.model_names]
+     def get_input_embeddings(self) -> Dict[str, nn.Module]:
+         return {name: getattr(self, name).get_input_embeddings() for name in self.model_names}
 
-     def set_input_embeddings(self, value_dict: Dict[str, int]):
+     def set_input_embeddings(self, value_dict: Dict[str, torch.Tensor]):
          for name, value in value_dict.items():
              getattr(self, name).set_input_embeddings(value)
 
@@ -60,7 +60,7 @@ class ComposeTextEncoder(PreTrainedModel):
          >>> tokenizer_B = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
          >>> tokenizer_bigG = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
 
-         >>> clip_model = MultiTextEncoder([('clip_B', clip_B), ('clip_bigG', clip_bigG)])
+         >>> clip_model = ComposeTextEncoder({'clip_B': clip_B, 'clip_bigG': clip_bigG})
 
          >>> inputs = {
          >>>     'clip_B':tokenizer_B(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt").input_ids
@@ -72,28 +72,42 @@ class ComposeTextEncoder(PreTrainedModel):
          >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
          ```"""
 
-         input_ids_list = input_ids.chunk(len(self.model_names),dim=-1)
+         def get_data(name, data):
+             if data is None:
+                 return None
+             elif isinstance(data, (dict, BatchableDict)):
+                 return data[name]
+             else:
+                 return data
 
          if self.with_hook:
-             encoder_hidden_states_list, pooled_output_list = [], []
-             for name, input_ids in zip(self.model_names, input_ids_list):
-                 encoder_hidden_states, pooled_output = getattr(self, name)(
-                     input_ids, # get token for model self.{name}
-                     attention_mask=attention_mask,
-                     position_ids=position_ids,
-                     output_attentions=output_attentions,
-                     output_hidden_states=output_hidden_states,
-                     return_dict=True,
-                 )
-                 encoder_hidden_states_list.append(encoder_hidden_states)
-                 pooled_output_list.append(pooled_output)
-             encoder_hidden_states = torch.cat(encoder_hidden_states_list, dim=self.cat_dim)
-             return encoder_hidden_states, pooled_output_list
+             encoder_hidden_states_dict, pooled_output_dict = {}, {}
+             for name in self.model_names:
+                 if (position_ids_i := get_data(name, position_ids)) is None:
+                     encoder_hidden_states, pooled_output = getattr(self, name)(
+                         get_data(name, input_ids), # get token for model self.{name}
+                         attention_mask=get_data(name, attention_mask),
+                         output_attentions=get_data(name, output_attentions),
+                         output_hidden_states=get_data(name, output_hidden_states),
+                         return_dict=True,
+                     )
+                 else:
+                     encoder_hidden_states, pooled_output = getattr(self, name)(
+                         get_data(name, input_ids), # get token for model self.{name}
+                         attention_mask=get_data(name, attention_mask),
+                         position_ids=position_ids_i,
+                         output_attentions=get_data(name, output_attentions),
+                         output_hidden_states=get_data(name, output_hidden_states),
+                         return_dict=True,
+                     )
+                 encoder_hidden_states_dict[name] = encoder_hidden_states
+                 pooled_output_dict[name] = pooled_output
+             return encoder_hidden_states_dict, pooled_output_dict
          else:
              return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-             text_feat_list = {'last_hidden_state':[], 'pooler_output':[], 'hidden_states':[], 'attentions':[]}
-             for name, input_ids in zip(self.model_names, input_ids_list):
+             text_feat_list = {'last_hidden_state':{}, 'pooler_output':{}, 'hidden_states':{}, 'attentions':{}}
+             for name in self.model_names:
                  text_feat: BaseModelOutputWithPooling = getattr(self, name)(
                      input_ids, # get token for model self.{name}
                      attention_mask=attention_mask,
@@ -102,31 +116,31 @@ class ComposeTextEncoder(PreTrainedModel):
                      output_hidden_states=output_hidden_states,
                      return_dict=True,
                  )
-                 text_feat_list['last_hidden_state'].append(text_feat.last_hidden_state)
-                 text_feat_list['pooler_output'].append(text_feat.pooler_output)
-                 text_feat_list['hidden_states'].append(text_feat.hidden_states)
-                 text_feat_list['attentions'].append(text_feat.attentions)
-
-             last_hidden_state = torch.cat(text_feat_list['last_hidden_state'], dim=self.cat_dim)
-             # pooler_output = torch.cat(text_feat_list['pooler_output'], dim=self.cat_dim)
-             pooler_output = text_feat_list['pooler_output']
-             if text_feat_list['hidden_states'][0] is None:
-                 hidden_states = None
-             else:
-                 hidden_states = [torch.cat(states, dim=self.cat_dim) for states in zip(*text_feat_list['hidden_states'])]
+                 text_feat_list['last_hidden_state'][name] = text_feat.last_hidden_state
+                 text_feat_list['pooler_output'][name] = text_feat.pooler_output
+                 text_feat_list['hidden_states'][name] = text_feat.hidden_states
+                 text_feat_list['attentions'][name] = text_feat.attentions
+
+             # last_hidden_state = torch.cat(text_feat_list['last_hidden_state'], dim=self.cat_dim)
+             # # pooler_output = torch.cat(text_feat_list['pooler_output'], dim=self.cat_dim)
+             # pooler_output = text_feat_list['pooler_output']
+             # if text_feat_list['hidden_states'][0] is None:
+             #     hidden_states = None
+             # else:
+             #     hidden_states = [torch.cat(states, dim=self.cat_dim) for states in zip(*text_feat_list['hidden_states'])]
 
          if return_dict:
              return BaseModelOutputWithPooling(
-                 last_hidden_state=last_hidden_state,
-                 pooler_output=pooler_output,
-                 hidden_states=hidden_states,
+                 last_hidden_state=text_feat_list['last_hidden_state'],
+                 pooler_output=text_feat_list['pooler_output'],
+                 hidden_states=text_feat_list['hidden_states'],
                  attentions=text_feat_list['attentions'],
              )
          else:
-             return (last_hidden_state, pooler_output)+hidden_states
+             return text_feat_list['last_hidden_state'], text_feat_list['pooler_output'], text_feat_list['hidden_states']
 
      @classmethod
-     def from_pretrained(cls, pretrained_model_name_or_path: List[Tuple[str, str]], *args,
+     def from_pretrained(cls, pretrained_model_name_or_path: Dict[str, str], *args,
                          subfolder: Dict[str, str] = None, revision: str = None, **kwargs):
          r"""
          Examples: sdxl text encoder
@@ -138,6 +152,6 @@ class ComposeTextEncoder(PreTrainedModel):
          >>> ], subfolder={'clip_B':'text_encoder', 'clip_bigG':'text_encoder_2'})
          ```
          """
-         clip_list = [(name, CLIPTextModel.from_pretrained(path, subfolder=subfolder[name], **kwargs)) for name, path in pretrained_model_name_or_path]
-         compose_model = cls(clip_list)
+         models = {name: AutoModel.from_pretrained(path, subfolder=subfolder[name], **kwargs) for name, path in pretrained_model_name_or_path.items()}
+         compose_model = cls(models)
          return compose_model
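from_pretrained and forward now speak dicts end to end: models are named at construction, and with with_hook=True the forward pass returns per-encoder dicts instead of tensors concatenated along the removed cat_dim (combining them is now the downstream wrapper's job). A sketch assembled from the SDXL docstring above; ids_B and ids_G stand for already-tokenized prompts:

    te = ComposeTextEncoder.from_pretrained(
        {'clip_B': 'stabilityai/stable-diffusion-xl-base-1.0',
         'clip_bigG': 'stabilityai/stable-diffusion-xl-base-1.0'},
        subfolder={'clip_B': 'text_encoder', 'clip_bigG': 'text_encoder_2'},
    )
    hidden, pooled = te({'clip_B': ids_B, 'clip_bigG': ids_G})
    # hidden == {'clip_B': Tensor, 'clip_bigG': Tensor}; pooled likewise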