hcpdiff 2.3.1__py3-none-any.whl → 2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/format/__init__.py +2 -2
- hcpdiff/ckpt_manager/format/diffusers.py +19 -4
- hcpdiff/ckpt_manager/format/emb.py +8 -3
- hcpdiff/ckpt_manager/format/lora_webui.py +1 -1
- hcpdiff/ckpt_manager/format/sd_single.py +28 -5
- hcpdiff/data/cache/vae.py +10 -2
- hcpdiff/data/handler/text.py +15 -14
- hcpdiff/diffusion/sampler/__init__.py +2 -1
- hcpdiff/diffusion/sampler/base.py +17 -6
- hcpdiff/diffusion/sampler/diffusers.py +4 -3
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +5 -14
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +7 -6
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +4 -4
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +3 -3
- hcpdiff/diffusion/sampler/timer/__init__.py +2 -0
- hcpdiff/diffusion/sampler/timer/base.py +26 -0
- hcpdiff/diffusion/sampler/timer/shift.py +49 -0
- hcpdiff/easy/__init__.py +2 -1
- hcpdiff/easy/cfg/sd15_train.py +1 -3
- hcpdiff/easy/model/__init__.py +1 -1
- hcpdiff/easy/model/loader.py +33 -11
- hcpdiff/easy/sampler.py +8 -1
- hcpdiff/loss/__init__.py +4 -3
- hcpdiff/loss/charbonnier.py +17 -0
- hcpdiff/loss/vlb.py +2 -2
- hcpdiff/loss/weighting.py +29 -11
- hcpdiff/models/__init__.py +1 -1
- hcpdiff/models/cfg_context.py +5 -3
- hcpdiff/models/compose/__init__.py +2 -1
- hcpdiff/models/compose/compose_hook.py +69 -67
- hcpdiff/models/compose/compose_textencoder.py +59 -45
- hcpdiff/models/compose/compose_tokenizer.py +48 -11
- hcpdiff/models/compose/flux.py +75 -0
- hcpdiff/models/compose/sdxl.py +86 -0
- hcpdiff/models/text_emb_ex.py +13 -9
- hcpdiff/models/textencoder_ex.py +8 -38
- hcpdiff/models/wrapper/__init__.py +2 -1
- hcpdiff/models/wrapper/flux.py +75 -0
- hcpdiff/models/wrapper/pixart.py +13 -1
- hcpdiff/models/wrapper/sd.py +17 -8
- hcpdiff/parser/embpt.py +7 -7
- hcpdiff/utils/net_utils.py +22 -12
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +145 -18
- hcpdiff/workflow/text.py +49 -18
- hcpdiff/workflow/vae.py +10 -2
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/METADATA +1 -1
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/RECORD +53 -49
- hcpdiff/models/compose/sdxl_composer.py +0 -39
- hcpdiff/utils/inpaint_pipe.py +0 -790
- hcpdiff/utils/pipe_hook.py +0 -656
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/WHEEL +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/entry_points.txt +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/top_level.txt +0 -0
hcpdiff/easy/model/loader.py
CHANGED
```diff
@@ -1,10 +1,12 @@
 import torch
-from
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline
 from rainbowneko.ckpt_manager import NekoLoader, LocalCkptSource
-
-from hcpdiff.
-
-from
+
+from hcpdiff.ckpt_manager import DiffusersSD15Format, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSD15Format, OfficialSDXLFormat, \
+    DiffusersFluxFormat, OneFileFluxFormat
+from hcpdiff.models.compose import SDXLTextEncoder, FluxTextEncoder
+from hcpdiff.models.wrapper import SDXLWrapper, SD15Wrapper, PixArtWrapper, FluxWrapper
+from hcpdiff.utils import auto_text_encoder_cls, get_pipe_name
 
 def SD15_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
                      tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
@@ -20,7 +22,7 @@ def SD15_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=
         source=LocalCkptSource(),
     )
     models = loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
-
+                         dtype=dtype, **kwargs)
     return models
 
 def SDXL_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
@@ -37,17 +39,34 @@ def SDXL_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=
         source=LocalCkptSource(),
     )
     models = loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
-
+                         dtype=dtype, **kwargs)
     return models
 
 def PixArt_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
-
+                       tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
     loader = NekoLoader(
         format=DiffusersPixArtFormat(),
         source=LocalCkptSource(),
     )
     models = loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
-
+                         dtype=dtype, **kwargs)
+    return models
+
+def Flux_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
+                     tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+    try:
+        try_diffusers = FluxPipeline.load_config(ckpt_path)
+        loader = NekoLoader(
+            format=DiffusersFluxFormat(),
+            source=LocalCkptSource(),
+        )
+    except EnvironmentError:
+        loader = NekoLoader(
+            format=OneFileFluxFormat(),
+            source=LocalCkptSource(),
+        )
+    models = loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
+                         dtype=dtype, **kwargs)
     return models
 
 def auto_load_wrapper(pretrained_model, denoiser=None, TE=None, vae=None, noise_sampler=None, tokenizer=None, revision=None,
@@ -62,6 +81,9 @@ def auto_load_wrapper(pretrained_model, denoiser=None, TE=None, vae=None, noise_
     if text_encoder_cls == SDXLTextEncoder:
         wrapper_cls = SDXLWrapper
         format = DiffusersSDXLFormat()
+    elif text_encoder_cls == FluxTextEncoder:
+        wrapper_cls = FluxWrapper
+        format = DiffusersFluxFormat()
     elif 'PixArt' in pipe_name:
         wrapper_cls = PixArtWrapper
         format = DiffusersPixArtFormat()
@@ -74,6 +96,6 @@ def auto_load_wrapper(pretrained_model, denoiser=None, TE=None, vae=None, noise_
         source=LocalCkptSource(),
     )
     models = loader.load(pretrained_model, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
-
+                         dtype=dtype)
 
-    return wrapper_cls.build_from_pretrained(models, **kwargs)
+    return wrapper_cls.build_from_pretrained(models, **kwargs)
```
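The new `Flux_auto_loader` probes the checkpoint layout before choosing a format: `FluxPipeline.load_config` only succeeds on a Diffusers-style model directory, and the `EnvironmentError` it raises otherwise routes single-file checkpoints to `OneFileFluxFormat`. A minimal usage sketch (the checkpoint paths are hypothetical):

```python
import torch
from hcpdiff.easy.model.loader import Flux_auto_loader

# Diffusers-style directory: FluxPipeline.load_config() finds the pipeline
# config, so DiffusersFluxFormat is selected.
models = Flux_auto_loader('ckpts/FLUX.1-dev', dtype=torch.bfloat16)

# Single-file checkpoint: load_config() raises EnvironmentError, so the
# loader falls back to OneFileFluxFormat.
models = Flux_auto_loader('ckpts/flux1-dev.safetensors', dtype=torch.bfloat16)
```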
hcpdiff/easy/sampler.py
CHANGED
```diff
@@ -1,5 +1,5 @@
 from hcpdiff.diffusion.sampler import DiffusersSampler
-from diffusers import DPMSolverMultistepScheduler, DDIMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
+from diffusers import DPMSolverMultistepScheduler, DDIMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, FlowMatchEulerDiscreteScheduler
 
 class Diffusers_SD:
     dpmpp_2m = DiffusersSampler(
@@ -43,4 +43,11 @@ class Diffusers_SD:
             beta_end=0.012,
             beta_schedule='scaled_linear',
         )
+    )
+
+    euler_flow = DiffusersSampler(
+        FlowMatchEulerDiscreteScheduler(
+            shift=3.0,
+            use_dynamic_shifting=True,
+        )
     )
```
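The `euler_flow` preset wires Diffusers' `FlowMatchEulerDiscreteScheduler`, the scheduler family used by flow-matching models such as Flux, into the same `DiffusersSampler` wrapper as the existing SD presets. A sketch of picking the preset or building a variant (the fixed-shift variant below is an illustration, not a shipped preset):

```python
from diffusers import FlowMatchEulerDiscreteScheduler
from hcpdiff.diffusion.sampler import DiffusersSampler
from hcpdiff.easy.sampler import Diffusers_SD

sampler = Diffusers_SD.euler_flow  # preset: shift=3.0, dynamic shifting enabled

# Hypothetical variant: fixed shift instead of resolution-dependent shifting.
sampler_fixed = DiffusersSampler(
    FlowMatchEulerDiscreteScheduler(shift=1.0, use_dynamic_shifting=False),
)
```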
hcpdiff/loss/__init__.py
CHANGED
```diff
@@ -1,4 +1,5 @@
-from .
-from .
+from .base import DiffusionLossContainer
+from .charbonnier import CharbonnierLoss
 from .gw import GWLoss
-from .
+from .ssim import SSIMLoss, MS_SSIMLoss
+from .weighting import MinSNRWeight, SNRWeight, EDMWeight, LossWeight, LossMapWeight
```
hcpdiff/loss/charbonnier.py
ADDED
```diff
@@ -0,0 +1,17 @@
+import torch
+from torch import nn
+
+class CharbonnierLoss(nn.Module):
+    """Charbonnier Loss (L1)"""
+
+    def __init__(self, eps=1e-3, size_average=True):
+        super(CharbonnierLoss, self).__init__()
+        self.eps = eps
+        self.size_average = size_average
+
+    def forward(self, x, y):
+        diff = x - y
+        loss = torch.sqrt((diff * diff) + (self.eps*self.eps))
+        if self.size_average:
+            loss = loss.mean()
+        return loss
```
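Charbonnier is a smooth L1: sqrt(d^2 + eps^2) behaves like |d| when |d| >> eps but stays differentiable at d = 0, which makes it a common robust reconstruction loss. A quick sketch of the new class used standalone:

```python
import torch
from hcpdiff.loss import CharbonnierLoss

crit = CharbonnierLoss(eps=1e-3)  # size_average=True -> scalar mean
pred, target = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)
loss = crit(pred, target)

# size_average=False keeps the element-wise loss map instead of reducing it.
loss_map = CharbonnierLoss(size_average=False)(pred, target)
```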
hcpdiff/loss/vlb.py
CHANGED
```diff
@@ -25,10 +25,10 @@ class VLBLoss(nn.Module):
         x0_pred = sampler.eps_to_x0(eps_pred, x_t, sigma)
 
         true_mean = sampler.sigma_scheduler.get_post_mean(timesteps, target, x_t)
-        true_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps)
+        true_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps, ndim=input.ndim)
 
         pred_mean = sampler.sigma_scheduler.get_post_mean(timesteps, x0_pred, x_t)
-        pred_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps, x_t_var=var_pred)
+        pred_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps, ndim=input.ndim, x_t_var=var_pred)
 
         kl = self.normal_kl(true_mean, true_logvar, pred_mean, pred_logvar)
         kl = kl.mean(dim=(1,2,3))/np.log(2.0)
```
hcpdiff/loss/weighting.py
CHANGED
```diff
@@ -1,9 +1,10 @@
+from rainbowneko.utils import add_dims
+from rainbowneko.train.loss import FullInputLoss
 from torch import nn
+from typing import Callable
 
-
-
-class LossWeight(nn.Module):
-    def __init__(self, loss: DiffusionLossContainer):
+class LossWeight(nn.Module, FullInputLoss):
+    def __init__(self, loss: Callable):
         super().__init__()
         self.loss = loss
 
@@ -21,12 +22,29 @@ class LossWeight(nn.Module):
         '''
         raise NotImplementedError
 
-    def forward(self, pred, inputs):
+    def forward(self, pred, inputs, _full_pred, _full_inputs):
         '''
         weight: [B,1,1,1] or [B,C,H,W]
         loss: [B,*,*,*]
         '''
-        return self.get_weight(
+        return self.get_weight(_full_pred, _full_inputs)*self.loss(pred, inputs)
+
+class LossMapWeight(LossWeight):
+    def __init__(self, loss: Callable, normalize: bool = False):
+        super().__init__(loss)
+        self.normalize = normalize
+
+    def get_weight(self, pred, inputs):
+        ndim = pred['model_pred'].ndim
+        loss_map = inputs['loss_map'].float()
+        if ndim == 4:
+            if self.normalize:
+                loss_map /= loss_map.mean(dim=(1,2), keepdim=True)
+            return loss_map.unsqueeze(1)
+        elif ndim == 3:
+            if self.normalize:
+                loss_map /= loss_map.mean(dim=1, keepdim=True)
+            return loss_map.unsqueeze(-1)
 
 class SNRWeight(LossWeight):
     def get_weight(self, pred, inputs):
@@ -42,10 +60,10 @@ class SNRWeight(LossWeight):
         else:
             raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
 
-        return w_snr.
+        return add_dims(w_snr, pred['model_pred'].ndim-1)
 
 class MinSNRWeight(LossWeight):
-    def __init__(self, loss:
+    def __init__(self, loss: Callable, gamma: float = 1.):
         super().__init__(loss)
         self.gamma = gamma
 
@@ -63,10 +81,10 @@ class MinSNRWeight(LossWeight):
         else:
             raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
 
-        return w_snr.
+        return add_dims(w_snr, pred['model_pred'].ndim-1)
 
 class EDMWeight(LossWeight):
-    def __init__(self, loss:
+    def __init__(self, loss: Callable, gamma: float = 1.):
         super().__init__(loss)
         self.gamma = gamma
 
@@ -81,4 +99,4 @@ class EDMWeight(LossWeight):
         else:
             raise ValueError(f"{self.__class__.__name__} is not support for target_type {target_type}")
 
-        return w_snr.
+        return add_dims(w_snr, pred['model_pred'].ndim-1)
```
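`LossWeight` now also derives from rainbowneko's `FullInputLoss`, so `forward` receives the full prediction/input dicts (`_full_pred`, `_full_inputs`) alongside the per-item ones, and all reshaping is delegated to `add_dims`. The new `LossMapWeight` scales the element-wise loss with a user-supplied `loss_map`. A sketch with a hypothetical base criterion (real configs wrap a `DiffusionLossContainer`, and the trainer normally supplies the `_full_*` arguments):

```python
import torch
from hcpdiff.loss import LossMapWeight

# Hypothetical (pred, inputs) -> element-wise loss, standing in for a real
# DiffusionLossContainer.
base_loss = lambda pred, inputs: (pred['model_pred'] - inputs['target'])**2

weighted = LossMapWeight(base_loss, normalize=True)

pred = {'model_pred': torch.randn(2, 4, 8, 8)}
inputs = {'target': torch.randn(2, 4, 8, 8),
          'loss_map': torch.rand(2, 8, 8)}  # per-pixel weights, [B,H,W]

# get_weight() returns [B,1,H,W] for 4-D predictions and broadcasts over C.
loss = weighted(pred, inputs, pred, inputs)
```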
hcpdiff/models/__init__.py
CHANGED
```diff
@@ -6,5 +6,5 @@ from .text_emb_ex import EmbeddingPTHook
 from .textencoder_ex import TEEXHook
 from .tokenizer_ex import TokenizerHook
 from .cfg_context import CFGContext, DreamArtistPTContext
-from .wrapper import SD15Wrapper, SDXLWrapper, PixArtWrapper, TEHookCFG
+from .wrapper import SD15Wrapper, SDXLWrapper, PixArtWrapper, TEHookCFG, FluxWrapper
 from .controlnet import ControlNetPlugin
```
hcpdiff/models/cfg_context.py
CHANGED
```diff
@@ -1,8 +1,10 @@
-import torch
-from einops import repeat
 import math
 from typing import Union, Callable
 
+import torch
+from einops import repeat
+from rainbowneko.utils import add_dims
+
 class CFGContext:
     def pre(self, noisy_latents, timesteps):
         return noisy_latents, timesteps
@@ -35,7 +37,7 @@ class DreamArtistPTContext(CFGContext):
                 pass
             else:
                 rate = self.cfg_func(rate)
-            rate = rate.
+            rate = add_dims(rate, model_pred.ndim-1)
         else:
             rate = 1
         model_pred = e_t_uncond+((self.cfg_high-self.cfg_low)*rate+self.cfg_low)*(e_t-e_t_uncond)
```
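Both here and in `loss/weighting.py`, ad-hoc `.view()`/`repeat` reshapes are replaced with rainbowneko's `add_dims`, which appends singleton dimensions so a per-sample scalar broadcasts over predictions of any rank (4-D UNet latents or 3-D token sequences alike). A sketch of the intended behavior; the helper below is an assumed equivalent, not rainbowneko's source:

```python
import torch

def add_dims(x: torch.Tensor, ndim: int) -> torch.Tensor:
    # Assumed-equivalent helper: append `ndim` trailing singleton dims.
    return x.reshape(*x.shape, *([1]*ndim))

rate = torch.rand(8)                 # one CFG rate per sample
pred_4d = torch.randn(8, 4, 64, 64)  # UNet-style latents
pred_3d = torch.randn(8, 1024, 64)   # DiT/flow-style token sequences

out_4d = add_dims(rate, pred_4d.ndim-1)*pred_4d  # rate -> [8,1,1,1]
out_3d = add_dims(rate, pred_3d.ndim-1)*pred_3d  # rate -> [8,1,1]
```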
hcpdiff/models/compose/__init__.py
CHANGED
```diff
@@ -1,4 +1,5 @@
 from .compose_tokenizer import ComposeTokenizer
 from .compose_textencoder import ComposeTextEncoder
 from .compose_hook import ComposeTEEXHook, ComposeEmbPTHook
-from .
+from .sdxl import SDXLTokenizer, SDXLTextEncoder
+from .flux import FluxTokenizer, FluxTextEncoder
```
hcpdiff/models/compose/compose_hook.py
CHANGED
```diff
@@ -1,128 +1,129 @@
-import
-from typing import Dict, Union, Tuple
+from pathlib import Path
+from typing import Dict, Union, Tuple
 
 import torch
-from loguru import logger
 from torch import nn
 
+from hcpdiff.utils.net_utils import load_emb
 from .compose_textencoder import ComposeTextEncoder
 from ..text_emb_ex import EmbeddingPTHook
 from ..textencoder_ex import TEEXHook
-from ...utils.net_utils import load_emb
-from ..container import ParameterGroup
 
 class ComposeEmbPTHook(nn.Module):
-    def __init__(self,
+    def __init__(self, hooks: Dict[str, EmbeddingPTHook]):
         super().__init__()
-        self.
-        self.emb_train = nn.ParameterList()
+        self.hooks = hooks
+        self.emb_train = nn.ParameterList()  # [ParameterDict{model_name:Parameter, ...}, ...]
 
     @property
     def N_repeats(self):
-        return self.
+        return {name:hook.N_repeats for name, hook in self.hooks.items()}
 
     @N_repeats.setter
    def N_repeats(self, value):
-        for name, hook in self.
-
+        for name, hook in self.hooks.items():
+            if isinstance(value, int):
+                hook.N_repeats = value
+            else:
+                hook.N_repeats = value[name]
 
-    def add_emb(self, emb: nn.Parameter,
-        emb_len = 0
+    def add_emb(self, emb: Dict[str, nn.Parameter], token_ids: Dict[str, int]):
         # Same word in different tokenizer may have different token_id
-        for
-            hook.add_emb(emb[
-            emb_len += hook.embedding_dim
+        for name, hook in self.hooks.items():
+            hook.add_emb(emb[name], token_ids[name])
 
     def remove(self):
-        for name, hook in self.
+        for name, hook in self.hooks.items():
             hook.remove()
 
     @classmethod
-    def hook(cls, ex_words_emb: Dict[str,
+    def hook(cls, ex_words_emb: Dict[str, nn.ParameterDict], tokenizer, text_encoder, **kwargs):
         if isinstance(text_encoder, ComposeTextEncoder):
-
+            hooks = {}
 
             emb_len = 0
-            for
+            for name in tokenizer.tokenizer_names:
                 text_encoder_i = getattr(text_encoder, name)
                 tokenizer_i = getattr(tokenizer, name)
                 embedding_dim = text_encoder_i.get_input_embeddings().embedding_dim
-                ex_words_emb_i = {k:v[
+                ex_words_emb_i = {k:v[name] for k, v in ex_words_emb.items()}  # {word_name:Parameter, ...}
                 emb_len += embedding_dim
-
+                hooks[name] = EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, **kwargs)
 
-            return cls(
+            return cls(hooks)
         else:
             return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder, **kwargs)
 
     @classmethod
-    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda
-
+    def hook_from_dir(cls, emb_dir: str | Path, tokenizer, text_encoder, device='cuda', **kwargs) -> (
+            Tuple['ComposeEmbPTHook', Dict[str, nn.ParameterDict]] | Tuple[EmbeddingPTHook, Dict[str, nn.Parameter]]):
+        emb_dir = Path(emb_dir) if emb_dir is not None else None
         if isinstance(text_encoder, ComposeTextEncoder):
             # multi text encoder
-
-
-
-
-
-
-            for file in os.listdir(emb_dir):
-                if file.endswith('.pt'):
-                    emb = load_emb(os.path.join(emb_dir, file)).to(device)
-                    emb = ParameterGroup([nn.Parameter(item, requires_grad=False) for item in emb.split(emb_dims, dim=1)])
-                    ex_words_emb[file[:-3]] = emb
+            ex_words_emb = {}  # {word_name:{model_name:Tensor, ...}, ...}
+            if emb_dir is not None and emb_dir.exists():
+                for file in emb_dir.glob('*.pt'):
+                    emb = load_emb(file)  # {model_name:Tensor, ...}
+                    emb = nn.ParameterDict({name:nn.Parameter(emb_i.to(device), requires_grad=False) for name, emb_i in emb.items()})
+                    ex_words_emb[file.stem] = emb
             return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
         else:
             return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder, **kwargs)
 
 class ComposeTEEXHook:
-    def __init__(self,
-        self.
-        self.cat_dim = cat_dim
+    def __init__(self, tehooks: Dict[str, TEEXHook]):
+        self.tehooks = tehooks
 
     @property
     def N_repeats(self):
-        return self.
+        return {name:tehook.N_repeats for name, tehook in self.tehooks.items()}
 
     @N_repeats.setter
-    def N_repeats(self, value):
-        for name, tehook in self.
-
+    def N_repeats(self, value: int | Dict[str, int]):
+        for name, tehook in self.tehooks.items():
+            if isinstance(value, int):
+                tehook.N_repeats = value
+            else:
+                tehook.N_repeats = value[name]
 
     @property
     def clip_skip(self):
-        return self.
+        return {name:tehook.clip_skip for name, tehook in self.tehooks.items()}
 
     @clip_skip.setter
-    def clip_skip(self, value):
-        for name, tehook in self.
-
+    def clip_skip(self, value: int | Dict[str, int]):
+        for name, tehook in self.tehooks.items():
+            if isinstance(value, int):
+                tehook.clip_skip = value
+            else:
+                tehook.clip_skip = value[name]
 
     @property
     def clip_final_norm(self):
-        return self.
+        return {name:tehook.clip_final_norm for name, tehook in self.tehooks.items()}
 
     @clip_final_norm.setter
-    def clip_final_norm(self, value: bool):
-        for name, tehook in self.
-
+    def clip_final_norm(self, value: bool | Dict[str, bool]):
+        for name, tehook in self.tehooks.items():
+            if isinstance(value, bool):
+                tehook.clip_final_norm = value
+            else:
+                tehook.clip_final_norm = value[name]
 
     @property
     def use_attention_mask(self):
-        return self.
+        return {name:tehook.use_attention_mask for name, tehook in self.tehooks.items()}
 
     @use_attention_mask.setter
-    def use_attention_mask(self, value: bool):
-        for name, tehook in self.
-
-
-
-
-            encoder_hidden_states, pooled_output, attention_mask = list(zip(*emb_list))
-            return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output, attention_mask[0]
+    def use_attention_mask(self, value: bool | Dict[str, bool]):
+        for name, tehook in self.tehooks.items():
+            if isinstance(value, bool):
+                tehook.use_attention_mask = value
+            else:
+                tehook.use_attention_mask = value[name]
 
     def enable_xformers(self):
-        for name, tehook in self.
+        for name, tehook in self.tehooks.items():
             tehook.enable_xformers()
 
     @staticmethod
@@ -130,19 +131,20 @@ class ComposeTEEXHook:
         return TEEXHook.mult_attn(prompt_embeds, attn_mult)
 
     @classmethod
-    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=
+    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=1, clip_skip=0, clip_final_norm=True, use_attention_mask=False) -> Union[
         'ComposeTEEXHook', TEEXHook]:
         if isinstance(text_enc, ComposeTextEncoder):
             # multi text encoder
-
-
-
-
+            get_data = lambda name, data:data[name] if isinstance(data, dict) else data
+            tehooks = {name:TEEXHook.hook(getattr(text_enc, name), getattr(tokenizer, name), get_data(name, N_repeats), get_data(name, clip_skip),
+                                          get_data(name, clip_final_norm), use_attention_mask=get_data(name, use_attention_mask))
+                       for name in tokenizer.tokenizer_names}
+            return cls(tehooks)
         else:
             # single text encoder
             return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip, clip_final_norm, use_attention_mask=use_attention_mask)
 
     @classmethod
-    def hook_pipe(cls, pipe, N_repeats=
+    def hook_pipe(cls, pipe, N_repeats=1, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
         return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
                         use_attention_mask=use_attention_mask)
```
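With hooks now keyed by encoder name, every composed setting accepts either a single value broadcast to all text encoders or a per-encoder dict. A sketch, assuming `pipe` is a pipeline whose `text_encoder` is already a `ComposeTextEncoder` with sub-models `clip_B` and `clip_bigG` (the names used in the SDXL docstring example):

```python
from hcpdiff.models.compose import ComposeTEEXHook

# `pipe` is an assumed SDXL-style pipeline with a composed text encoder.
te_hook = ComposeTEEXHook.hook_pipe(pipe, N_repeats=1)

te_hook.clip_skip = 1                              # broadcast to every encoder
te_hook.clip_skip = {'clip_B': 1, 'clip_bigG': 2}  # per-encoder override
print(te_hook.clip_skip)                           # {'clip_B': 1, 'clip_bigG': 2}
```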
hcpdiff/models/compose/compose_textencoder.py
CHANGED
```diff
@@ -13,24 +13,24 @@ from typing import Dict, Optional, Union, Tuple, List
 
 import torch
 from torch import nn
-from transformers import CLIPTextModel, PreTrainedModel, PretrainedConfig
+from transformers import CLIPTextModel, PreTrainedModel, PretrainedConfig, AutoModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
+from rainbowneko.utils import BatchableDict
 
 class ComposeTextEncoder(PreTrainedModel):
-    def __init__(self,
-        super().__init__(PretrainedConfig(**{name:model.config for name, model in
-        self.cat_dim = cat_dim
+    def __init__(self, models: Dict[str, PreTrainedModel], with_hook=True):
+        super().__init__(PretrainedConfig(**{name:model.config for name, model in models.items()}))
         self.with_hook = with_hook
 
         self.model_names = []
-        for name, model in
+        for name, model in models.items():
             setattr(self, name, model)
             self.model_names.append(name)
 
-    def get_input_embeddings(self) ->
-        return
+    def get_input_embeddings(self) -> Dict[str, nn.Module]:
+        return {name: getattr(self, name).get_input_embeddings() for name in self.model_names}
 
-    def set_input_embeddings(self, value_dict: Dict[str,
+    def set_input_embeddings(self, value_dict: Dict[str, torch.Tensor]):
         for name, value in value_dict.items():
             getattr(self, name).set_input_embeddings(value)
 
@@ -60,7 +60,7 @@ class ComposeTextEncoder(PreTrainedModel):
         >>> tokenizer_B = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
         >>> tokenizer_bigG = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
 
-        >>> clip_model =
+        >>> clip_model = ComposeTextEncoder({'clip_B': clip_B, 'clip_bigG': clip_bigG})
 
         >>> inputs = {
         >>>     'clip_B':tokenizer_B(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt").input_ids
@@ -72,28 +72,42 @@ class ComposeTextEncoder(PreTrainedModel):
         >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
         ```"""
 
-
+        def get_data(name, data):
+            if data is None:
+                return None
+            elif isinstance(data, (dict, BatchableDict)):
+                return data[name]
+            else:
+                return data
 
         if self.with_hook:
-
-            for name
-
-
-
-
-
-
-
-
-
-
-
-
+            encoder_hidden_states_dict, pooled_output_dict = {}, {}
+            for name in self.model_names:
+                if position_ids_i := get_data(name, position_ids) is None:
+                    encoder_hidden_states, pooled_output = getattr(self, name)(
+                        get_data(name, input_ids),  # get token for model self.{name}
+                        attention_mask=get_data(name, attention_mask),
+                        output_attentions=get_data(name, output_attentions),
+                        output_hidden_states=get_data(name, output_hidden_states),
+                        return_dict=True,
+                    )
+                else:
+                    encoder_hidden_states, pooled_output = getattr(self, name)(
+                        get_data(name, input_ids),  # get token for model self.{name}
+                        attention_mask=get_data(name, attention_mask),
+                        position_ids=position_ids_i,
+                        output_attentions=get_data(name, output_attentions),
+                        output_hidden_states=get_data(name, output_hidden_states),
+                        return_dict=True,
+                    )
+                encoder_hidden_states_dict[name] = encoder_hidden_states
+                pooled_output_dict[name] = pooled_output
+            return encoder_hidden_states_dict, pooled_output_dict
         else:
             return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-            text_feat_list = {'last_hidden_state':
-            for name
+            text_feat_list = {'last_hidden_state':{}, 'pooler_output':{}, 'hidden_states':{}, 'attentions':{}}
+            for name in self.model_names:
                 text_feat: BaseModelOutputWithPooling = getattr(self, name)(
                     input_ids,  # get token for model self.{name}
                     attention_mask=attention_mask,
@@ -102,31 +116,31 @@ class ComposeTextEncoder(PreTrainedModel):
                     output_hidden_states=output_hidden_states,
                     return_dict=True,
                 )
-                text_feat_list['last_hidden_state']
-                text_feat_list['pooler_output']
-                text_feat_list['hidden_states']
-                text_feat_list['attentions']
-
-            last_hidden_state = torch.cat(text_feat_list['last_hidden_state'], dim=self.cat_dim)
-            # pooler_output = torch.cat(text_feat_list['pooler_output'], dim=self.cat_dim)
-            pooler_output = text_feat_list['pooler_output']
-            if text_feat_list['hidden_states'][0] is None:
-
-            else:
-
+                text_feat_list['last_hidden_state'][name] = text_feat.last_hidden_state
+                text_feat_list['pooler_output'][name] = text_feat.pooler_output
+                text_feat_list['hidden_states'][name] = text_feat.hidden_states
+                text_feat_list['attentions'][name] = text_feat.attentions
+
+            # last_hidden_state = torch.cat(text_feat_list['last_hidden_state'], dim=self.cat_dim)
+            # # pooler_output = torch.cat(text_feat_list['pooler_output'], dim=self.cat_dim)
+            # pooler_output = text_feat_list['pooler_output']
+            # if text_feat_list['hidden_states'][0] is None:
+            #     hidden_states = None
+            # else:
+            #     hidden_states = [torch.cat(states, dim=self.cat_dim) for states in zip(*text_feat_list['hidden_states'])]
 
             if return_dict:
                 return BaseModelOutputWithPooling(
-                    last_hidden_state=last_hidden_state,
-                    pooler_output=pooler_output,
-                    hidden_states=hidden_states,
+                    last_hidden_state=text_feat_list['last_hidden_state'],
+                    pooler_output=text_feat_list['pooler_output'],
+                    hidden_states=text_feat_list['hidden_states'],
                     attentions=text_feat_list['attentions'],
                 )
             else:
-                return
+                return text_feat_list['last_hidden_state'], text_feat_list['pooler_output'], text_feat_list['hidden_states']
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path:
+    def from_pretrained(cls, pretrained_model_name_or_path: Dict[str, str], *args,
                         subfolder: Dict[str, str] = None, revision: str = None, **kwargs):
         r"""
         Examples: sdxl text encoder
@@ -138,6 +152,6 @@ class ComposeTextEncoder(PreTrainedModel):
         >>> ], subfolder={'clip_B':'text_encoder', 'clip_bigG':'text_encoder_2'})
         ```
         """
-
-        compose_model = cls(
+        models = {name: AutoModel.from_pretrained(path, subfolder=subfolder[name], **kwargs) for name, path in pretrained_model_name_or_path.items()}
+        compose_model = cls(models)
         return compose_model
```
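`ComposeTextEncoder` no longer concatenates sub-encoder outputs along a `cat_dim`; both the hooked and plain paths now return `{model_name: tensor}` dicts and leave any merging to the model-specific composers (`SDXLTextEncoder`, `FluxTextEncoder`). A sketch of the plain path with a single sub-encoder (the checkpoint follows the docstring's example; the call order is assumed from the forward shown above):

```python
from transformers import CLIPTextModel, CLIPTokenizer
from hcpdiff.models.compose import ComposeTextEncoder

clip_B = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer_B = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
te = ComposeTextEncoder({'clip_B': clip_B}, with_hook=False)

ids = tokenizer_B(["a photo of a cat"], padding=True, return_tensors="pt").input_ids
out = te(ids, return_dict=True)
print(out.last_hidden_state['clip_B'].shape)  # per-encoder dict, no concatenation
```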