hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +90 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +3 -3
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text.py +40 -0
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +207 -0
- hcpdiff/easy/cfg/sdxl_train.py +147 -0
- hcpdiff/easy/cfg/t2i.py +228 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +118 -128
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +60 -47
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.2.dist-info/METADATA +299 -0
- hcpdiff-2.2.dist-info/RECORD +115 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
- hcpdiff-2.2.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/loss/ssim.py
ADDED
@@ -0,0 +1,37 @@
+from pytorch_msssim import SSIM, MS_SSIM
+from torch.nn.modules.loss import _Loss
+import torch
+
+class SSIMLoss(_Loss):
+    target_type = 'x0'
+
+    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean'):
+        super().__init__(size_average=size_average, reduce=reduce, reduction=reduction)
+        self.ssim = SSIM(data_range=1., size_average=False, channel=4)
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        '''
+
+        :param input: [B,C,H,W]
+        :param target: [B,C,H,W]
+        :return: [B,1,1,1]
+        '''
+        input = (input+1)/2
+        target = (target+1)/2
+        return 1-self.ssim(input, target).view(-1,1,1,1)
+
+class MS_SSIMLoss(_Loss):
+    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean'):
+        super().__init__(size_average=size_average, reduce=reduce, reduction=reduction)
+        self.ssim = MS_SSIM(data_range=1., size_average=False, channel=4)
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        '''
+
+        :param input: [B,C,H,W]
+        :param target: [B,C,H,W]
+        :return: [B,1,1,1]
+        '''
+        input = (input+1)/2
+        target = (target+1)/2
+        return 1-self.ssim(input, target).view(-1,1,1,1)
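Note on the two losses just added: `channel=4` matches the Stable Diffusion latent space, and the `(x+1)/2` rescale maps inputs assumed to lie in [-1, 1] onto the `data_range=1.` interval that `pytorch_msssim` expects. A minimal sketch of the computation, using `pytorch_msssim` directly rather than the package class:

```python
import torch
from pytorch_msssim import SSIM

# Per-sample SSIM loss on 4-channel latents, as in SSIMLoss.forward above.
ssim = SSIM(data_range=1., size_average=False, channel=4)
x0_pred = torch.rand(2, 4, 64, 64)*2 - 1   # hypothetical x0 predictions in [-1, 1]
x0_true = torch.rand(2, 4, 64, 64)*2 - 1
loss = 1 - ssim((x0_pred+1)/2, (x0_true+1)/2).view(-1, 1, 1, 1)
print(loss.shape)  # torch.Size([2, 1, 1, 1]): one loss value per sample
```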
hcpdiff/loss/vlb.py
ADDED
@@ -0,0 +1,79 @@
+import torch
+from torch import nn
+import numpy as np
+
+class VLBLoss(nn.Module):
+    need_sigma = True
+    need_timesteps = True
+    need_sampler = True
+    var_pred = True
+
+    def __init__(self, loss, weight: float = 1.):
+        super().__init__()
+        self.loss = loss
+        self.weight = weight
+
+    def normal_kl(self, mean1, logvar1, mean2, logvar2):
+        """
+        Compute the KL divergence between two gaussians.
+        """
+
+        return 0.5*(-1.0+logvar2-logvar1+(logvar1-logvar2).exp()+((mean1-mean2)**2)*(-logvar2).exp())
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor, sigma, timesteps: torch.Tensor, x_t: torch.Tensor, sampler):
+        eps_pred, var_pred = input.chunk(2, dim=1)
+        x0_pred = sampler.eps_to_x0(eps_pred, x_t, sigma)
+
+        true_mean = sampler.sigma_scheduler.get_post_mean(timesteps, target, x_t)
+        true_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps)
+
+        pred_mean = sampler.sigma_scheduler.get_post_mean(timesteps, x0_pred, x_t)
+        pred_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps, x_t_var=var_pred)
+
+        kl = self.normal_kl(true_mean, true_logvar, pred_mean, pred_logvar)
+        kl = kl.mean(dim=(1,2,3))/np.log(2.0)
+
+        decoder_nll = -self.discretized_gaussian_log_likelihood(target, means=pred_mean, log_scales=0.5*pred_logvar)
+        assert decoder_nll.shape == target.shape
+        decoder_nll = decoder_nll.mean(dim=(1,2,3))/np.log(2.0)
+
+        # At the first timestep return the decoder NLL,
+        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+        output = torch.where((timesteps == 0), decoder_nll, kl)
+
+        return self.weight*output
+
+    def approx_standard_normal_cdf(self, x):
+        """
+        A fast approximation of the cumulative distribution function of the
+        standard normal.
+        """
+        return 0.5*(1.0+torch.tanh(np.sqrt(2.0/np.pi)*(x+0.044715*torch.pow(x, 3))))
+
+    def discretized_gaussian_log_likelihood(self, x, *, means, log_scales):
+        """
+        Compute the log-likelihood of a Gaussian distribution discretizing to a
+        given image.
+        :param x: the target images. It is assumed that this was uint8 values,
+                  rescaled to the range [-1, 1].
+        :param means: the Gaussian mean Tensor.
+        :param log_scales: the Gaussian log stddev Tensor.
+        :return: a tensor like x of log probabilities (in nats).
+        """
+        assert x.shape == means.shape == log_scales.shape
+        centered_x = x-means
+        inv_stdv = torch.exp(-log_scales)
+        plus_in = inv_stdv*(centered_x+1.0/255.0)
+        cdf_plus = self.approx_standard_normal_cdf(plus_in)
+        min_in = inv_stdv*(centered_x-1.0/255.0)
+        cdf_min = self.approx_standard_normal_cdf(min_in)
+        log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+        log_one_minus_cdf_min = torch.log((1.0-cdf_min).clamp(min=1e-12))
+        cdf_delta = cdf_plus-cdf_min
+        log_probs = torch.where(
+            x<-0.999,
+            log_cdf_plus,
+            torch.where(x>0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
+        )
+        assert log_probs.shape == x.shape
+        return log_probs
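`normal_kl` above is the standard closed form for KL(N(μ₁, σ₁²) ‖ N(μ₂, σ₂²)) written in log-variances. A quick sanity check against `torch.distributions` (verification sketch, not package code):

```python
import torch
from torch.distributions import Normal, kl_divergence

mean1, logvar1 = torch.randn(4), torch.randn(4)
mean2, logvar2 = torch.randn(4), torch.randn(4)

# Closed form used by VLBLoss.normal_kl
kl_closed = 0.5*(-1.0 + logvar2 - logvar1 + (logvar1 - logvar2).exp()
                 + ((mean1 - mean2)**2)*(-logvar2).exp())
# Reference: torch.distributions with std = exp(logvar/2)
kl_ref = kl_divergence(Normal(mean1, (0.5*logvar1).exp()),
                       Normal(mean2, (0.5*logvar2).exp()))
assert torch.allclose(kl_closed, kl_ref, atol=1e-5)
```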
hcpdiff/loss/weighting.py
ADDED
@@ -0,0 +1,66 @@
+from torch import nn
+
+from .base import DiffusionLossContainer
+
+class LossWeight(nn.Module):
+    def __init__(self, loss: DiffusionLossContainer):
+        super().__init__()
+        self.loss = loss
+
+    def get_weight(self, pred, inputs):
+        '''
+
+        :param input: [B,C,H,W]
+        :param target: [B,C,H,W]
+        :return: [B,1,1,1] or [B,C,H,W]
+        '''
+        raise NotImplementedError
+
+    def forward(self, pred, inputs):
+        '''
+        weight: [B,1,1,1] or [B,C,H,W]
+        loss: [B,*,*,*]
+        '''
+        return self.get_weight(pred, inputs)*self.loss(pred, inputs)
+
+class SNRWeight(LossWeight):
+    def get_weight(self, pred, inputs):
+        if self.loss.target_type == 'eps':
+            return 1
+        elif self.loss.target_type == "x0":
+            sigma = pred['sigma']
+            return (1./sigma**2).view(-1, 1, 1, 1)
+        else:
+            raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+
+class MinSNRWeight(LossWeight):
+    def __init__(self, loss: DiffusionLossContainer, gamma: float = 1.):
+        super().__init__(loss)
+        self.gamma = gamma
+
+    def get_weight(self, pred, inputs):
+        sigma = pred['sigma']
+        if self.loss.target_type == 'eps':
+            w_snr = (self.gamma*sigma**2).clip(max=1).float()
+        elif self.loss.target_type == "x0":
+            w_snr = (1/(sigma**2)).clip(max=self.gamma).float()
+        else:
+            raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+
+        return w_snr.view(-1, 1, 1, 1)
+
+class EDMWeight(LossWeight):
+    def __init__(self, loss: DiffusionLossContainer, gamma: float = 1.):
+        super().__init__(loss)
+        self.gamma = gamma
+
+    def get_weight(self, pred, inputs):
+        sigma = pred['sigma']
+        if self.loss.target_type == 'eps':
+            w_snr = ((sigma**2+self.gamma**2)/(self.gamma**2)).float()
+        elif self.loss.target_type == "x0":
+            w_snr = ((sigma**2+self.gamma**2)/((sigma*self.gamma)**2)).float()
+        else:
+            raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+
+        return w_snr.view(-1, 1, 1, 1)
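All three weights are functions of the per-sample noise level read from `pred['sigma']`. For `eps` targets, `MinSNRWeight` computes `min(γ·σ², 1)`, which with SNR = 1/σ² is the Min-SNR-γ weight `min(SNR, γ)/SNR`; a small check of that identity (illustration only):

```python
import torch

gamma = 5.0
sigma = torch.logspace(-2, 1, steps=50)   # hypothetical per-sample noise levels
snr = 1/sigma**2
w_eps = (gamma*sigma**2).clip(max=1)      # MinSNRWeight.get_weight, target_type='eps'
print(torch.allclose(w_eps, snr.clip(max=gamma)/snr))  # True: min(SNR, γ)/SNR
```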
hcpdiff/models/__init__.py
CHANGED
@@ -1,4 +1,3 @@
-from .plugin import PluginBlock, PluginGroup, SinglePluginBlock, MultiPluginBlock, PatchPluginBlock
 # from .lora_base import LoraBlock, LoraGroup
 # from .lora_layers import lora_layer_map
 from .lora_base_patch import LoraBlock, LoraGroup
@@ -7,4 +6,5 @@ from .text_emb_ex import EmbeddingPTHook
 from .textencoder_ex import TEEXHook
 from .tokenizer_ex import TokenizerHook
 from .cfg_context import CFGContext, DreamArtistPTContext
-from .wrapper import
+from .wrapper import SD15Wrapper, SDXLWrapper, PixArtWrapper, TEHookCFG
+from .controlnet import ControlNetPlugin
hcpdiff/models/cfg_context.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 from einops import repeat
 import math
+from typing import Union, Callable
 
 class CFGContext:
     def pre(self, noisy_latents, timesteps):
@@ -10,9 +11,11 @@ class CFGContext:
         return model_pred
 
 class DreamArtistPTContext(CFGContext):
-    def __init__(self,
-        self.
-        self.
+    def __init__(self, cfg_low: float, cfg_high: float=None, cfg_func: Union[str, Callable]=None, num_train_timesteps=1000):
+        self.cfg_low = cfg_low
+        self.cfg_high = cfg_high or cfg_low
+        self.cfg_func = cfg_func
+        self.num_train_timesteps = num_train_timesteps
 
     def pre(self, noisy_latents, timesteps):
         self.t_raw = timesteps
@@ -22,18 +25,18 @@ class DreamArtistPTContext(CFGContext):
 
     def post(self, model_pred):
         e_t_uncond, e_t = model_pred.chunk(2)
-        if self.
-            rate = self.t_raw
-            if self.
-                rate = torch.cos((rate
-            elif self.
-                rate = 1
-            elif self.
+        if self.cfg_low != self.cfg_high:
+            rate = self.t_raw/(self.num_train_timesteps-1)
+            if self.cfg_func == 'cos':
+                rate = torch.cos((rate-1)*math.pi/2)
+            elif self.cfg_func == 'cos2':
+                rate = 1-torch.cos(rate*math.pi/2)
+            elif self.cfg_func == 'ln':
                 pass
             else:
-                rate =
-                rate = rate.view(-1,1,1,1)
+                rate = self.cfg_func(rate)
+                rate = rate.view(-1, 1, 1, 1)
         else:
             rate = 1
-        model_pred = e_t_uncond
-        return model_pred
+        model_pred = e_t_uncond+((self.cfg_high-self.cfg_low)*rate+self.cfg_low)*(e_t-e_t_uncond)
+        return model_pred
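The rebuilt `post()` is classifier-free guidance with a timestep-dependent scale: `rate = t/(T-1)` is optionally reshaped by `'cos'`, `'cos2'`, `'ln'` (identity), or a callable, and the guidance scale interpolates between `cfg_low` and `cfg_high`. A sketch of just the combine step, assuming the usual layout of unconditional predictions stacked before conditional ones:

```python
import torch

cfg_low, cfg_high, T = 3.0, 5.0, 1000
t = torch.tensor([10, 500])                       # hypothetical timesteps, shape [B]
rate = t/(T-1)                                    # cfg_func='ln' leaves rate untouched
scale = ((cfg_high-cfg_low)*rate + cfg_low).view(-1, 1, 1, 1)

model_pred = torch.randn(4, 4, 64, 64)            # [2B, ...]: uncond batch, then cond batch
e_uncond, e_cond = model_pred.chunk(2)
guided = e_uncond + scale*(e_cond - e_uncond)     # same combine as DreamArtistPTContext.post
```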
hcpdiff/models/compose/compose_hook.py
CHANGED
@@ -38,42 +38,42 @@ class ComposeEmbPTHook(nn.Module):
         hook.remove()
 
     @classmethod
-    def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder,
+    def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder, **kwargs):
         if isinstance(text_encoder, ComposeTextEncoder):
             hook_list = []
 
             emb_len = 0
-            for i,
+            for i, name in enumerate(tokenizer.tokenizer_names):
                 text_encoder_i = getattr(text_encoder, name)
-
-                logger.info(f'compose hook: {name}')
+                tokenizer_i = getattr(tokenizer, name)
                 embedding_dim = text_encoder_i.get_input_embeddings().embedding_dim
                 ex_words_emb_i = {k:v[i] for k, v in ex_words_emb.items()}
                 emb_len += embedding_dim
-                hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i,
+                hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, **kwargs)))
 
             return cls(hook_list)
         else:
-            return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder,
+            return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder, **kwargs)
 
     @classmethod
-    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder,
+    def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs) -> Union[
         Tuple['ComposeEmbPTHook', Dict], Tuple[EmbeddingPTHook, Dict]]:
         if isinstance(text_encoder, ComposeTextEncoder):
             # multi text encoder
-            #ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}
+            # ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}
 
             # slice of nn.Parameter cannot return grad. Split the tensor
             ex_words_emb = {}
-
-
-
-
-
-
-
+            if emb_dir is not None and os.path.exists(emb_dir):
+                emb_dims = [x.embedding_dim for x in text_encoder.get_input_embeddings()]
+                for file in os.listdir(emb_dir):
+                    if file.endswith('.pt'):
+                        emb = load_emb(os.path.join(emb_dir, file)).to(device)
+                        emb = ParameterGroup([nn.Parameter(item, requires_grad=False) for item in emb.split(emb_dims, dim=1)])
+                        ex_words_emb[file[:-3]] = emb
+            return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
         else:
-            return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder,
+            return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder, **kwargs)
 
 class ComposeTEEXHook:
     def __init__(self, tehook_list: List[Tuple[str, TEEXHook]], cat_dim=-1):
@@ -98,10 +98,28 @@ class ComposeTEEXHook:
         for name, tehook in self.tehook_list:
             tehook.clip_skip = value
 
+    @property
+    def clip_final_norm(self):
+        return self.tehook_list[0][1].clip_final_norm
+
+    @clip_final_norm.setter
+    def clip_final_norm(self, value: bool):
+        for name, tehook in self.tehook_list:
+            tehook.clip_final_norm = value
+
+    @property
+    def use_attention_mask(self):
+        return self.tehook_list[0][1].use_attention_mask
+
+    @use_attention_mask.setter
+    def use_attention_mask(self, value: bool):
+        for name, tehook in self.tehook_list:
+            tehook.use_attention_mask = value
+
     def encode_prompt_to_emb(self, prompt):
         emb_list = [tehook.encode_prompt_to_emb(prompt) for name, tehook in self.tehook_list]
-        encoder_hidden_states, pooled_output = list(zip(*emb_list))
-        return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output
+        encoder_hidden_states, pooled_output, attention_mask = list(zip(*emb_list))
+        return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output, attention_mask[0]
 
     def enable_xformers(self):
         for name, tehook in self.tehook_list:
@@ -112,16 +130,19 @@ class ComposeTEEXHook:
         return TEEXHook.mult_attn(prompt_embeds, attn_mult)
 
     @classmethod
-    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True,
+    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False) -> Union[
+        'ComposeTEEXHook', TEEXHook]:
         if isinstance(text_enc, ComposeTextEncoder):
             # multi text encoder
-            tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name),
-
+            tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name), getattr(tokenizer, name), N_repeats, clip_skip, clip_final_norm,
+                                                use_attention_mask=use_attention_mask))
+                           for name in tokenizer.tokenizer_names]
             return cls(tehook_list)
         else:
             # single text encoder
-            return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip,
+            return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip, clip_final_norm, use_attention_mask=use_attention_mask)
 
     @classmethod
     def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
-        return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats,
+        return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                        use_attention_mask=use_attention_mask)
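One breaking change worth flagging: `encode_prompt_to_emb` now returns three values (hidden states, pooled outputs, attention mask) where 0.9.1 returned two, so callers written against the old API fail at unpacking. Shape sketch with a stub hook (illustration only; real shapes depend on the text encoders):

```python
import torch

class StubTEHook:  # stand-in for a TEEXHook, demonstrating the return shape only
    def encode_prompt_to_emb(self, prompt):
        return torch.zeros(1, 77, 768), torch.zeros(1, 768), torch.ones(1, 77)

# 0.9.1: emb, pooled = hook.encode_prompt_to_emb(p)   -> now a ValueError
emb, pooled, attn_mask = StubTEHook().encode_prompt_to_emb('a cat')  # 2.2 unpacking
```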
hcpdiff/models/compose/compose_tokenizer.py
CHANGED
@@ -18,14 +18,19 @@ from transformers.tokenization_utils_base import BatchEncoding
 class ComposeTokenizer(PreTrainedTokenizer):
     def __init__(self, tokenizer_list: List[Tuple[str, CLIPTokenizer]], cat_dim=-1):
         self.cat_dim = cat_dim
-
+
+        self.tokenizer_names = []
+        for name, tokenizer in tokenizer_list:
+            setattr(self, name, tokenizer)
+            self.tokenizer_names.append(name)
+
         super().__init__()
 
-        self.model_max_length =
+        self.model_max_length = torch.tensor([tokenizer.model_max_length for name, tokenizer in tokenizer_list])
 
     @property
     def first_tokenizer(self):
-        return self.
+        return getattr(self, self.tokenizer_names[0])
 
     @property
     def vocab_size(self):
@@ -40,18 +45,26 @@ class ComposeTokenizer(PreTrainedTokenizer):
         return self.first_tokenizer.bos_token_id
 
     def get_vocab(self):
-        return
+        return self.first_tokenizer.get_vocab()
 
     def tokenize(self, text, **kwargs) -> List[str]:
         return self.first_tokenizer.tokenize(text, **kwargs)
 
     def add_tokens( self, new_tokens, special_tokens: bool = False) -> List[int]:
-        return [
+        return [getattr(self, name).add_tokens(new_tokens, special_tokens) for name in self.tokenizer_names]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix = None) -> Tuple[str]:
+        return self.first_tokenizer.save_vocabulary(save_directory, filename_prefix)
+
+    def __call__(self, text, *args, max_length=None, **kwargs):
+        if isinstance(max_length, torch.Tensor):
+            token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length_i, **kwargs)
+                                               for name, max_length_i in zip(self.tokenizer_names, max_length)]
+        else:
+            token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length, **kwargs) for name in self.tokenizer_names]
 
-    def __call__(self, text, *args, **kwargs):
-        token_list: List[BatchEncoding] = [tokenizer(text, *args, **kwargs) for name, tokenizer in self.tokenizer_list]
         input_ids = torch.cat([token.input_ids for token in token_list], dim=-1)  # [N_tokenizer, N_token]
-        attention_mask = [token.attention_mask for token in token_list]
+        attention_mask = torch.cat([token.attention_mask for token in token_list], dim=-1)
         return BatchEncoding({'input_ids':input_ids, 'attention_mask':attention_mask})
 
     @classmethod
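`model_max_length` is now a tensor with one entry per sub-tokenizer, and `__call__` routes a tensor `max_length` element-wise so each sub-tokenizer can use its own limit. A sketch of the equivalent manual flow with the two SDXL tokenizers (setup assumed from the diff above):

```python
import torch
from transformers import AutoTokenizer

repo = 'stabilityai/stable-diffusion-xl-base-1.0'
tok_L = AutoTokenizer.from_pretrained(repo, subfolder='tokenizer')
tok_bigG = AutoTokenizer.from_pretrained(repo, subfolder='tokenizer_2')
max_lengths = torch.tensor([tok_L.model_max_length, tok_bigG.model_max_length])  # tensor([77, 77])

toks = [t('a photo of a cat', max_length=int(m), padding='max_length', return_tensors='pt')
        for t, m in zip((tok_L, tok_bigG), max_lengths)]
input_ids = torch.cat([t.input_ids for t in toks], dim=-1)            # [1, 77+77]
attention_mask = torch.cat([t.attention_mask for t in toks], dim=-1)  # 2.2 also concatenates the mask
```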
hcpdiff/models/compose/sdxl_composer.py
CHANGED
@@ -27,13 +27,13 @@ class CLIPTextModelWithProjection_Align(CLIPTextModelWithProjection):
 class SDXLTextEncoder(ComposeTextEncoder):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
-
+        clip_L = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
         clip_bigG = CLIPTextModelWithProjection_Align.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder_2', **kwargs)
-        return cls([('
+        return cls([('clip_L', clip_L), ('clip_bigG', clip_bigG)])
 
 class SDXLTokenizer(ComposeTokenizer):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
-
+        clip_L = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
         clip_bigG = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer_2', **kwargs)
-        return cls([('
+        return cls([('clip_L', clip_L), ('clip_bigG', clip_bigG)])
hcpdiff/models/controlnet.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from copy import deepcopy
 
-from .plugin import MultiPluginBlock, BasePluginBlock
+from rainbowneko.models.plugin import MultiPluginBlock, BasePluginBlock
 from hcpdiff.utils.net_utils import remove_all_hooks, remove_layers
 
 class ControlNetPlugin(MultiPluginBlock):
@@ -55,25 +55,25 @@ class ControlNetPlugin(MultiPluginBlock):
         self.cond_head = nn.Sequential(*cond_head)
 
     def reset_parameters(self) -> None:
-        def
-
-
-        self.controlnet_down_blocks.apply(
-        self.controlnet_mid_block.apply(
-        self.cond_head[-1].apply(
-
-    def from_layer_hook(self, host,
+        def zero_weight_init(m):
+            for p in m.parameters():
+                p.detach().zero_()
+        self.controlnet_down_blocks.apply(zero_weight_init)
+        self.controlnet_mid_block.apply(zero_weight_init)
+        self.cond_head[-1].apply(zero_weight_init)
+
+    def from_layer_hook(self, host, idx: int, args: Tuple[Any, ...], kwargs: Dict[str, Any], fea_out: Any=None):
         if idx==0:
-            self.data_input =
+            self.data_input = (args, kwargs)
         elif idx==1:
-            self.feat_to = self(*self.data_input)
+            self.feat_to = self(*self.data_input[0], **self.data_input[1])
 
-    def to_layer_hook(self, host,
+    def to_layer_hook(self, host, idx: int, args: Tuple[Any, ...], kwargs: Dict[str, Any], fea_out: Any=None):
         if idx == 5:
-            sp =
-            new_feat =
-            new_feat[:, sp:, ...] =
-            return (new_feat,
+            sp = args[0].shape[1]//2
+            new_feat = args[0].clone()
+            new_feat[:, sp:, ...] = args[0][:, sp:, ...] + self.feat_to[0]
+            return (new_feat, args[1])
         elif idx == 3:
             return (fea_out[0], tuple(fea_out[1][i] + self.feat_to[(idx) * 3 + i+1] for i in range(2)))
         elif idx == 4:
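`zero_weight_init` zeroes the output projections, the standard ControlNet trick: with zeroed heads the plugin contributes exactly nothing at initialization, so training starts from the unmodified base model. A minimal illustration of why (generic module, not the plugin itself):

```python
import torch
from torch import nn

head = nn.Conv2d(320, 320, 1)
for p in head.parameters():      # same zeroing as zero_weight_init above
    p.detach().zero_()

x = torch.randn(1, 320, 64, 64)
assert torch.all(head(x) == 0)   # a zero-initialized head adds nothing to the host features
```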
hcpdiff/models/lora_base_patch.py
CHANGED
@@ -13,7 +13,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from hcpdiff.utils.utils import make_mask, low_rank_approximate, isinstance_list
-from .plugin import PatchPluginBlock, PluginGroup, PatchPluginContainer
+from rainbowneko.models.plugin import PatchPluginBlock, PluginGroup, PatchPluginContainer
 
 from typing import Union, Tuple, Dict, Type
 
@@ -38,9 +38,9 @@ class LoraBlock(PatchPluginBlock):
     container_cls = LoraPatchContainer
     wrapable_classes = (nn.Linear, nn.Conv2d)
 
-    def __init__(self,
+    def __init__(self, name:int, host:Union[nn.Linear, nn.Conv2d], rank, dropout=0.1, alpha=1.0, bias=False,
                  alpha_auto_scale=True, parent_block=None, host_name=None, **kwargs):
-        super().__init__(
+        super().__init__(name, host, parent_block=parent_block, host_name=host_name)
 
         self.bias=bias
 
@@ -56,8 +56,14 @@ class LoraBlock(PatchPluginBlock):
         self.dropout = nn.Dropout(dropout)
 
         self.rank = self.layer.rank
+        self.alpha_auto_scale = alpha_auto_scale
         self.register_buffer('alpha', torch.tensor(alpha/self.rank if alpha_auto_scale else alpha))
 
+    def set_hyper_params(self, alpha=None, **kwargs):
+        if alpha is not None:
+            self.register_buffer('alpha', torch.tensor(alpha/self.rank if self.alpha_auto_scale else alpha))
+        super().set_hyper_params(**kwargs)
+
     def get_weight(self):
         return self.layer.get_weight() * self.alpha
 
@@ -91,7 +97,7 @@ class LoraBlock(PatchPluginBlock):
                 host.weight.data * base_alpha + alpha * re_w.to(host.weight.device, dtype=host.weight.dtype)
             )
 
-            if
+            if re_b is not None:
                 if host.bias is None:
                     host.bias = nn.Parameter(re_b.to(host.weight.device, dtype=host.weight.dtype))
                 else:
@@ -145,32 +151,15 @@ class LoraBlock(PatchPluginBlock):
         pass
 
     @classmethod
-    def wrap_layer(cls,
+    def wrap_layer(cls, name:str, host: Union[nn.Linear, nn.Conv2d], rank=1, dropout=0.0, alpha=1.0, svd_init=False,
                    bias=False, mask=None, **kwargs):# -> LoraBlock:
-        lora_block = cls(
+        lora_block = cls(name, host, rank, dropout, alpha, bias=bias, **kwargs)
         lora_block.init_weights(svd_init)
         return lora_block
 
     @classmethod
-    def wrap_model(cls,
-        return super(
-
-    @staticmethod
-    def extract_lora_state(model:nn.Module):
-        return {k:v for k,v in model.state_dict().items() if 'lora_block_' in k}
-
-    @staticmethod
-    def extract_state_without_lora(model:nn.Module):
-        return {k:v for k,v in model.state_dict().items() if 'lora_block_' not in k}
-
-    @staticmethod
-    def extract_param_without_lora(model:nn.Module):
-        return {k:v for k,v in model.named_parameters() if 'lora_block_' not in k}
-
-    @staticmethod
-    def extract_trainable_state_without_lora(model:nn.Module):
-        trainable_keys = {k for k,v in model.named_parameters() if ('lora_block_' not in k) and v.requires_grad}
-        return {k: v for k, v in model.state_dict().items() if k in trainable_keys}
+    def wrap_model(cls, name:str, host: nn.Module, **kwargs):# -> Dict[str, LoraBlock]:
+        return super().wrap_model(name, host, exclude_classes=(LoraBlock,), **kwargs)
 
 class LoraGroup(PluginGroup):
     def set_mask(self, batch_mask):
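With `alpha_auto_scale=True` the `alpha` buffer stores `alpha/rank`, so the effective update follows the common LoRA convention ΔW = (α/r)·B·A, and the new `set_hyper_params` re-applies that scaling when `alpha` is changed after construction. A standalone sketch of the convention (not package code):

```python
import torch

rank, alpha = 8, 4.0
B = torch.randn(64, rank)        # lora_up weight
A = torch.randn(rank, 32)        # lora_down weight
scale = alpha/rank               # value held by the 'alpha' buffer when alpha_auto_scale=True
delta_w = scale * (B @ A)        # weight offset merged into the host layer
print(delta_w.shape)             # torch.Size([64, 32])
```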
hcpdiff/models/lora_layers.py
CHANGED
@@ -15,7 +15,7 @@ from einops import repeat, rearrange, einsum
 from torch import nn
 
 from .lora_base import LoraBlock
-from .layers import GroupLinear
+from rainbowneko.models.layers import GroupLinear
 import warnings
 
 class LoraLayer(LoraBlock):
@@ -59,8 +59,8 @@ class LoraLayerGroup(LoraBlock):
     def __init__(self, host, rank, bias, dropout, block):
         super().__init__(host, rank, bias, dropout, block)
         self.register_buffer('rank_groups', torch.tensor(block.rank_groups_raw, dtype=torch.int))
-        self.lora_down = GroupLinear(host.in_features
-        self.lora_up = GroupLinear(self.rank, host.out_features
+        self.lora_down = GroupLinear(host.in_features, self.rank//self.rank_groups, group=self.rank_groups, bias=False)
+        self.lora_up = GroupLinear(self.rank//self.rank_groups, host.out_features, group=self.rank_groups, bias=bias)
 
     def feed_svd(self, U, V, weight):
         self.lora_up.weight.data = rearrange(U, 'o (g ri) -> g ri o', g=self.rank_groups).to(device=weight.device, dtype=weight.dtype)
@@ -137,9 +137,3 @@ class LohaLayer(LoraBlock):
         w = torch.prod(einsum(self.W_up.data, self.W_down.data, 'g o r ..., g r i ... -> g o i ...'), dim=0)
         b = None
         return w, b
-
-lora_layer_map={
-    'lora': LoraLayer,
-    'loha_group': LoraLayerGroup,
-    'loha': LohaLayer,
-}
|