hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/loggers/wandb_logger.py
DELETED
@@ -1,31 +0,0 @@
-from typing import Dict, Any
-
-import os
-import wandb
-from PIL import Image
-
-from .base_logger import BaseLogger
-
-
-class WanDBLogger(BaseLogger):
-    def __init__(self, exp_dir, out_path=None, enable_log_image=False, project='hcp-diffusion', log_step=10, image_log_step=200):
-        super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-        if exp_dir is not None:  # exp_dir is only available in local main process
-            wandb.init(project=project, name=os.path.basename(exp_dir))
-            wandb.save(os.path.join(exp_dir, 'cfg.yaml'), base_path=exp_dir)
-        else:
-            self.writer = None
-            self.disable()
-
-    def _info(self, info):
-        pass
-
-    def _log(self, datas: Dict[str, Any], step: int = 0):
-        log_dict = {'step': step}
-        for k, v in datas.items():
-            if len(v['data']) == 1:
-                log_dict[k] = v['data'][0]
-        wandb.log(log_dict)
-
-    def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-        wandb.log({next(iter(imgs.keys())): list(imgs.values())}, step=step)
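For orientation, here is a hypothetical sketch of the logging payload implied by `_log` above: each entry carries a `data` list (plus a `format` string used by the text-based loggers), and single-element entries are logged as scalars. The keys and values below are illustrative, not taken from the package.

```python
# Illustrative payload shape consumed by the deleted logger classes (assumed, not from the package).
datas = {
    'loss': {'format': '{:.4f}', 'data': [0.0231]},  # one element -> logged as a scalar
    'lr':   {'format': '{:.2e}', 'data': [1e-5]},
}
# WanDBLogger._log(datas, step=100) would forward {'step': 100, 'loss': 0.0231, 'lr': 1e-5} to wandb.log.
```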
hcpdiff/loggers/webui_logger.py
DELETED
@@ -1,9 +0,0 @@
-from typing import Dict, Any
-
-from loguru import logger
-
-from .cli_logger import CLILogger
-
-class WebUILogger(CLILogger):
-    def _log(self, datas: Dict[str, Any], step: int = 0):
-        logger.info('this progress steps:'+', '.join([f"{k} = {v['format'].format(*v['data'])}" for k, v in datas.items()]))
hcpdiff/loss/min_snr_loss.py
DELETED
@@ -1,52 +0,0 @@
-import torch
-from diffusers import SchedulerMixin
-from torch import nn
-
-class MinSNRLoss(nn.MSELoss):
-    need_timesteps = True
-
-    def __init__(self, size_average=None, reduce=None, reduction: str = 'none', gamma=1.,
-                 noise_scheduler: SchedulerMixin = None, device='cuda:0', **kwargs):
-        super().__init__(size_average, reduce, reduction)
-        self.gamma = gamma
-
-        # calculate SNR
-        alphas_cumprod = noise_scheduler.alphas_cumprod
-        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
-        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0-alphas_cumprod)
-        self.alpha = sqrt_alphas_cumprod.to(device)
-        self.sigma = sqrt_one_minus_alphas_cumprod.to(device)
-        self.all_snr = ((self.alpha/self.sigma)**2).to(device)
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        loss = super(MinSNRLoss, self).forward(input, target)
-        snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-        snr_weight = (self.gamma/snr).clip(max=1.).float()
-        return loss*snr_weight.view(-1, 1, 1, 1)
-
-
-class SoftMinSNRLoss(MinSNRLoss):
-    # gamma=2
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        loss = super(MinSNRLoss, self).forward(input, target)
-        snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-        snr_weight = (self.gamma**3/(snr**2 + self.gamma**3)).float()
-        return loss*snr_weight.view(-1, 1, 1, 1)
-
-class KDiffMinSNRLoss(MinSNRLoss):
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        loss = super(MinSNRLoss, self).forward(input, target)
-        snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-        snr_weight = 4*(((self.gamma*snr)**2/(snr**2 + self.gamma**2)**2)).float()
-        return loss*snr_weight.view(-1, 1, 1, 1)
-
-class EDMLoss(MinSNRLoss):
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        loss = super(MinSNRLoss, self).forward(input, target)
-        sigma = self.sigma[timesteps[:loss.shape[0], ...].squeeze()]
-        snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-        snr_weight = ((sigma**2+self.gamma**2)/(snr*(sigma*self.gamma)**2)).float()
-        return loss*snr_weight.view(-1, 1, 1, 1)
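All of the classes above reweight a per-sample MSE map by a function of the signal-to-noise ratio of the sampled timestep. Below is a minimal standalone sketch of the basic Min-SNR-gamma weighting, assuming only `torch` and `diffusers`; the tensor shapes and names are illustrative, not the package's API.

```python
# Minimal sketch of Min-SNR-gamma weighting: weight = min(gamma / SNR(t), 1),
# with SNR(t) = alpha_bar_t / (1 - alpha_bar_t), mirroring MinSNRLoss above.
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()                      # default 1000-step DDPM schedule
alphas_cumprod = scheduler.alphas_cumprod        # shape [1000]
all_snr = alphas_cumprod / (1.0 - alphas_cumprod)

gamma = 1.0
timesteps = torch.randint(0, 1000, (4,))         # one timestep per batch item
snr = all_snr[timesteps]
snr_weight = (gamma / snr).clip(max=1.0)         # Min-SNR-gamma weight per sample

mse = torch.randn(4, 4, 64, 64).pow(2)           # stand-in for the per-pixel MSE map
weighted_loss = (mse * snr_weight.view(-1, 1, 1, 1)).mean()
```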
hcpdiff/models/layers.py
DELETED
@@ -1,81 +0,0 @@
-"""
-layers.py
-====================
-    :Name:        GroupLinear and other layers
-    :Author:      Dong Ziyi
-    :Affiliation: HCP Lab, SYSU
-    :Created:     09/04/2023
-    :Licence:     Apache-2.0
-"""
-
-import torch
-from torch import nn
-import math
-from einops import rearrange
-
-class GroupLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, groups: int, bias: bool = True,
-                 device=None, dtype=None):
-        super().__init__()
-        assert in_features%groups == 0
-        assert out_features%groups == 0
-
-        factory_kwargs = {'device': device, 'dtype': dtype}
-
-        self.groups = groups
-        self.in_features = in_features
-        self.out_features = out_features
-
-        self.weight = nn.Parameter(torch.empty((groups, in_features//groups, out_features//groups), **factory_kwargs))
-        if bias:
-            self.bias = nn.Parameter(torch.empty(groups, 1, out_features//groups, **factory_kwargs))
-        else:
-            self.register_parameter('bias', None)
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
-        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
-        # https://github.com/pytorch/pytorch/issues/57109
-        self.kaiming_uniform_group(self.weight, a=math.sqrt(5))
-        if self.bias is not None:
-            fan_in, _ = self._calculate_fan_in_and_fan_out(self.weight)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-
-    @staticmethod
-    def _calculate_fan_in_and_fan_out(tensor):
-        receptive_field_size = 1
-        num_input_fmaps = tensor.size(-2)
-        num_output_fmaps = tensor.size(-1)
-        fan_in = num_input_fmaps * receptive_field_size
-        fan_out = num_output_fmaps * receptive_field_size
-
-        return fan_in, fan_out
-
-    @staticmethod
-    def kaiming_uniform_group(tensor: torch.Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu') -> torch.Tensor:
-        def _calculate_correct_fan(tensor, mode):
-            mode = mode.lower()
-            valid_modes = ['fan_in', 'fan_out']
-            if mode not in valid_modes:
-                raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
-
-            fan_in, fan_out = GroupLinear._calculate_fan_in_and_fan_out(tensor)
-            return fan_in if mode == 'fan_in' else fan_out
-
-        fan = _calculate_correct_fan(tensor, mode)
-        gain = nn.init.calculate_gain(nonlinearity, a)
-        std = gain / math.sqrt(fan)
-        bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
-        with torch.no_grad():
-            return tensor.uniform_(-bound, bound)
-
-    def forward(self, x: torch.Tensor):  # x: [G,B,L,C]
-        x = rearrange(x, '(g b) l c -> g (b l) c', g=self.num_groups)
-        if self.bias is not None:
-            out = torch.bmm(x, self.weight) + self.bias
-        else:
-            out = torch.bmm(x, self.weight)
-        out = rearrange(out, 'g (b l) c -> (g b) l c', b=B)
-        return out
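Note that the deleted `forward` references `self.num_groups` and `B`, neither of which is defined in this file, so it cannot run as written. The operation it aims at is a batched matrix multiply per channel group; here is a standalone sketch with illustrative shapes (not the package's API), using `torch.bmm` and `einops`.

```python
# Standalone sketch of a grouped linear layer: split channels into groups,
# apply one matmul per group, then merge the groups back.
import torch
from einops import rearrange

groups, batch, seq_len = 4, 2, 16
in_features, out_features = 64, 128

x = torch.randn(batch, seq_len, in_features)
weight = torch.randn(groups, in_features // groups, out_features // groups)
bias = torch.randn(groups, 1, out_features // groups)

xg = rearrange(x, 'b l (g c) -> g (b l) c', g=groups)     # [G, B*L, C_in/G]
out = torch.bmm(xg, weight) + bias                         # [G, B*L, C_out/G]
out = rearrange(out, 'g (b l) c -> b l (g c)', b=batch)    # [B, L, C_out]
print(out.shape)                                           # torch.Size([2, 16, 128])
```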
hcpdiff/models/plugin.py
DELETED
@@ -1,348 +0,0 @@
-"""
-plugin.py
-====================
-    :Name:        model plugin
-    :Author:      Dong Ziyi
-    :Affiliation: HCP Lab, SYSU
-    :Created:     10/03/2023
-    :Licence:     Apache-2.0
-"""
-
-import weakref
-import re
-from typing import Tuple, List, Dict, Any, Iterable
-
-import torch
-from torch import nn
-
-from hcpdiff.utils.net_utils import split_module_name
-
-class BasePluginBlock(nn.Module):
-    def __init__(self, name: str):
-        super().__init__()
-        self.name = name
-
-    def forward(self, host: nn.Module, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-        return fea_out
-
-    def remove(self):
-        pass
-
-    def feed_input_data(self, data):
-        self.input_data = data
-
-    def register_input_feeder_to(self, host_model):
-        if not hasattr(host_model, 'input_feeder'):
-            host_model.input_feeder = []
-        host_model.input_feeder.append(self.feed_input_data)
-
-    def set_hyper_params(self, **kwargs):
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
-    @staticmethod
-    def extract_state_without_plugin(model: nn.Module, trainable=False):
-        trainable_keys = {k for k, v in model.named_parameters() if v.requires_grad}
-        plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
-        model_sd = {}
-        for k, v in model.state_dict().items():
-            if (not trainable) or k in trainable_keys:
-                for name in plugin_names:
-                    if k.startswith(name):
-                        break
-                else:
-                    model_sd[k] = v
-        return model_sd
-
-    def get_trainable_parameters(self) -> Iterable[nn.Parameter]:
-        return self.parameters()
-
-class WrapablePlugin:
-    wrapable_classes = ()
-
-    @classmethod
-    def wrap_layer(cls, name: str, layer: nn.Module, **kwargs):
-        plugin = cls(name, layer, **kwargs)
-        return plugin
-
-    @classmethod
-    def named_modules_with_exclude(cls, self, memo = None, prefix: str = '', remove_duplicate: bool = True,
-                                   exclude_key=None, exclude_classes=tuple()):
-
-        if memo is None:
-            memo = set()
-        if self not in memo:
-            if remove_duplicate:
-                memo.add(self)
-            if (exclude_key is None or not re.search(exclude_key, prefix)) and not isinstance(self, exclude_classes):
-                yield prefix, self
-            for name, module in self._modules.items():
-                if module is None:
-                    continue
-                submodule_prefix = prefix + ('.' if prefix else '') + name
-                for m in cls.named_modules_with_exclude(module, memo, submodule_prefix, remove_duplicate, exclude_key, exclude_classes):
-                    yield m
-
-    @classmethod
-    def wrap_model(cls, name: str, host: nn.Module, exclude_key=None, exclude_classes=tuple(), **kwargs):  # -> Dict[str, SinglePluginBlock]:
-        '''
-        parent_block and other args required in __init__ will be put into kwargs, compatible with multiple models.
-        '''
-        plugin_block_dict = {}
-        if isinstance(host, cls.wrapable_classes):
-            plugin_block_dict[''] = cls.wrap_layer(name, host, **kwargs)
-        else:
-            named_modules = {layer_name:layer for layer_name, layer in cls.named_modules_with_exclude(
-                host, exclude_key=exclude_key, exclude_classes=exclude_classes)}
-            for layer_name, layer in named_modules.items():
-                if isinstance(layer, cls.wrapable_classes):
-                    # For plugins that need parent_block
-                    if 'parent_block' in kwargs:
-                        parent_name, host_name = split_module_name(layer_name)
-                        kwargs['parent_block'] = named_modules[parent_name]
-                        kwargs['host_name'] = host_name
-                    plugin_block_dict[layer_name] = cls.wrap_layer(name, layer, **kwargs)
-        return plugin_block_dict
-
-class SinglePluginBlock(BasePluginBlock, WrapablePlugin):
-
-    def __init__(self, name: str, host: nn.Module, hook_param=None, host_model=None):
-        super().__init__(name)
-        self.host = weakref.ref(host)
-        setattr(host, name, self)
-
-        if hook_param is None:
-            self.hook_handle = host.register_forward_hook(self.layer_hook)
-        else:  # hook for model parameters
-            self.backup = getattr(host, hook_param)
-            self.target = hook_param
-            self.handle_pre = host.register_forward_pre_hook(self.pre_hook)
-            self.handle_post = host.register_forward_hook(self.post_hook)
-
-    def layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-        return self(fea_in, fea_out)
-
-    def pre_hook(self, host, fea_in: torch.Tensor):
-        host.weight_restored = False
-        host_param = getattr(host, self.target)
-        delattr(host, self.target)
-        setattr(host, self.target, self(host_param))
-        return fea_in
-
-    def post_hook(self, host, fea_int, fea_out):
-        if not getattr(host, 'weight_restored', False):
-            setattr(host, self.target, self.backup)
-            host.weight_restored = True
-
-    def remove(self):
-        host = self.host()
-        delattr(host, self.name)
-        if hasattr(self, 'hook_handle'):
-            self.hook_handle.remove()
-        else:
-            self.handle_pre.remove()
-            self.handle_post.remove()
-
-class PluginBlock(BasePluginBlock):
-    def __init__(self, name, from_layer: Dict[str, Any], to_layer: Dict[str, Any], host_model=None):
-        super().__init__(name)
-        self.host_from = weakref.ref(from_layer['layer'])
-        self.host_to = weakref.ref(to_layer['layer'])
-        setattr(from_layer['layer'], name, self)
-
-        if from_layer['pre_hook']:
-            self.hook_handle_from = from_layer['layer'].register_forward_pre_hook(lambda host, fea_in:self.from_layer_hook(host, fea_in, None))
-        else:
-            self.hook_handle_from = from_layer['layer'].register_forward_hook(
-                lambda host, fea_in, fea_out:self.from_layer_hook(host, fea_in, fea_out))
-        if to_layer['pre_hook']:
-            self.hook_handle_to = to_layer['layer'].register_forward_pre_hook(lambda host, fea_in:self.to_layer_hook(host, fea_in, None))
-        else:
-            self.hook_handle_to = to_layer['layer'].register_forward_hook(lambda host, fea_in, fea_out:self.to_layer_hook(host, fea_in, fea_out))
-
-    def from_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-        self.feat_from = fea_in
-
-    def to_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-        return self(self.feat_from, fea_in, fea_out)
-
-    def remove(self):
-        host_from = self.host_from()
-        delattr(host_from, self.name)
-        self.hook_handle_from.remove()
-        self.hook_handle_to.remove()
-
-class MultiPluginBlock(BasePluginBlock):
-    def __init__(self, name: str, from_layers: List[Dict[str, Any]], to_layers: List[Dict[str, Any]], host_model=None):
-        super().__init__(name)
-        assert host_model is not None
-        self.host_from = [weakref.ref(x['layer']) for x in from_layers]
-        self.host_to = [weakref.ref(x['layer']) for x in to_layers]
-        self.host_model = weakref.ref(host_model)
-        setattr(host_model, name, self)
-
-        self.feat_from = [None for _ in range(len(from_layers))]
-
-        self.hook_handle_from = []
-        self.hook_handle_to = []
-
-        for idx, layer in enumerate(from_layers):
-            if layer['pre_hook']:
-                handle_from = layer['layer'].register_forward_pre_hook(lambda host, fea_in, idx=idx:self.from_layer_hook(host, fea_in, None, idx))
-            else:
-                handle_from = layer['layer'].register_forward_hook(
-                    lambda host, fea_in, fea_out, idx=idx:self.from_layer_hook(host, fea_in, fea_out, idx))
-            self.hook_handle_from.append(handle_from)
-        for idx, layer in enumerate(to_layers):
-            if layer['pre_hook']:
-                handle_to = layer['layer'].register_forward_pre_hook(lambda host, fea_in, idx=idx:self.to_layer_hook(host, fea_in, None, idx))
-            else:
-                handle_to = layer['layer'].register_forward_hook(lambda host, fea_in, fea_out, idx=idx:self.to_layer_hook(host, fea_in, fea_out, idx))
-            self.hook_handle_to.append(handle_to)
-
-        self.record_count = 0
-
-    def from_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: Tuple[torch.Tensor], idx: int):
-        self.feat_from[idx] = fea_in
-        self.record_count += 1
-        if self.record_count == len(self.feat_from):  # call forward when all feat is record
-            self.record_count = 0
-            self.feat_to = self(self.feat_from)
-
-    def to_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: Tuple[torch.Tensor], idx: int):
-        return self.feat_to[idx]+fea_out
-
-    def remove(self):
-        host_model = self.host_model()
-        delattr(host_model, self.name)
-        for handle_from in self.hook_handle_from:
-            handle_from.remove()
-        for handle_to in self.hook_handle_to:
-            handle_to.remove()
-
-class PatchPluginContainer(nn.Module):
-    def __init__(self, host_name, host, parent_block):
-        super().__init__()
-        self._host = host
-        self.host_name = host_name
-        self.parent_block = weakref.ref(parent_block)
-        self.plugin_names = []
-
-        delattr(parent_block, host_name)
-        setattr(parent_block, host_name, self)
-
-    def add_plugin(self, name: str, plugin: 'PatchPluginBlock'):
-        setattr(self, name, plugin)
-        self.plugin_names.append(name)
-
-    def remove_plugin(self, name: str):
-        delattr(self, name)
-        self.plugin_names.remove(name)
-        if len(self.plugin_names) == 0:
-            self.remove()
-
-    def forward(self, *args, **kwargs):
-        for name, plugin in self:
-            args, kwargs = plugin.pre_forward(*args, **kwargs)
-        output = self._host(*args, **kwargs)
-        for name, plugin in self:
-            output = plugin.post_forward(output, *args, **kwargs)
-        return output
-
-    def remove(self):
-        parent_block = self.parent_block()
-        delattr(parent_block, self.host_name)
-        setattr(parent_block, self.host_name, self._host)
-
-    def __iter__(self):
-        for name in self.plugin_names:
-            yield name, self[name]
-
-    def __getitem__(self, name):
-        return getattr(self, name)
-
-class PatchPluginBlock(BasePluginBlock, WrapablePlugin):
-    container_cls = PatchPluginContainer
-
-    def __init__(self, name: str, host: nn.Module, host_model=None, parent_block: nn.Module = None, host_name: str = None):
-        super().__init__(name)
-        if isinstance(host, self.container_cls):
-            self.host = weakref.ref(host._host)
-        else:
-            self.host = weakref.ref(host)
-        self.parent_block = weakref.ref(parent_block)
-        self.host_name = host_name
-
-        container = self.get_container(host, host_name, parent_block)
-        container.add_plugin(name, self)
-        self.container = weakref.ref(container)
-
-    def pre_forward(self, *args, **kwargs):
-        return args, kwargs
-
-    def post_forward(self, output, *args, **kwargs):
-        return output
-
-    def remove(self):
-        container = self.container()
-        container.remove_plugin(self.name)
-
-    def get_container(self, host, host_name, parent_block):
-        if isinstance(host, self.container_cls):
-            return host
-        else:
-            return self.container_cls(host_name, host, parent_block)
-
-    @classmethod
-    def wrap_model(cls, name: str, host: nn.Module, exclude_key=None, exclude_classes=tuple(), **kwargs):  # -> Dict[str, SinglePluginBlock]:
-        '''
-        parent_block and other args required in __init__ will be put into kwargs, compatible with multiple models.
-        '''
-        plugin_block_dict = {}
-        if isinstance(host, cls.wrapable_classes):
-            plugin_block_dict[''] = cls.wrap_layer(name, host, **kwargs)
-        else:
-            named_modules = {layer_name:layer for layer_name, layer in cls.named_modules_with_exclude(
-                host, exclude_key=exclude_key or '_host', exclude_classes=exclude_classes)}
-            for layer_name, layer in named_modules.items():
-                if isinstance(layer, cls.wrapable_classes) or isinstance(layer, cls.container_cls):
-                    # For plugins that need parent_block
-                    if 'parent_block' in kwargs:
-                        parent_name, host_name = split_module_name(layer_name)
-                        kwargs['parent_block'] = named_modules[parent_name]
-                        kwargs['host_name'] = host_name
-                    plugin_block_dict[layer_name] = cls.wrap_layer(name, layer, **kwargs)
-        return plugin_block_dict
-
-class PluginGroup:
-    def __init__(self, plugin_dict: Dict[str, BasePluginBlock]):
-        self.plugin_dict = plugin_dict  # {host_model_path: plugin_object}
-
-    def __setitem__(self, k, v):
-        self.plugin_dict[k] = v
-
-    def __getitem__(self, k):
-        return self.plugin_dict[k]
-
-    @property
-    def plugin_name(self):
-        if self.empty():
-            return None
-        return next(iter(self.plugin_dict.values())).name
-
-    def remove(self):
-        for plugin in self.plugin_dict.values():
-            plugin.remove()
-
-    def state_dict(self, model=None):
-        if model is None:
-            return {f'{k}.___.{ks}':vs for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()}
-        else:
-            sd_model = model.state_dict()
-            return {f'{k}.___.{ks}':sd_model[f'{k}.{v.name}.{ks}'] for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()}
-
-    def state_keys_raw(self):
-        return [f'{k}.{v.name}.{ks}' for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()]
-
-    def empty(self):
-        return len(self.plugin_dict) == 0
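As a usage sketch (assuming `hcpdiff==0.9.0` is installed), a `SinglePluginBlock` subclass can post-process the output of every layer it wraps through the forward hook registered in `__init__` above; the plugin class, scale factor, and toy model below are illustrative, not part of the package.

```python
# Illustrative SinglePluginBlock subclass (assumption: hcpdiff==0.9.0 installed).
import torch
from torch import nn
from hcpdiff.models.plugin import SinglePluginBlock

class ScaleOutputPlugin(SinglePluginBlock):
    wrapable_classes = (nn.Linear,)          # wrap_model attaches one plugin per nn.Linear

    def __init__(self, name, host, scale=0.5, **kwargs):
        super().__init__(name, host)         # registers the forward hook on the host layer
        self.scale = scale

    def forward(self, fea_in, fea_out):      # called by SinglePluginBlock.layer_hook
        return fea_out * self.scale

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
plugins = ScaleOutputPlugin.wrap_model('scale_plugin', model)
print(list(plugins.keys()))                  # one entry per wrapped nn.Linear, e.g. ['0', '2']

y = model(torch.randn(1, 8))                 # each Linear output is scaled by 0.5 via the hooks
```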
hcpdiff/models/wrapper.py
DELETED
@@ -1,75 +0,0 @@
-from torch import nn
-import itertools
-from transformers import CLIPTextModel
-from hcpdiff.utils import pad_attn_bias
-
-class TEUnetWrapper(nn.Module):
-    def __init__(self, unet, TE, train_TE=False):
-        super().__init__()
-        self.unet = unet
-        self.TE = TE
-
-        self.train_TE = train_TE
-
-    def forward(self, prompt_ids, noisy_latents, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
-        input_all = dict(prompt_ids=prompt_ids, noisy_latents=noisy_latents, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
-
-        if hasattr(self.TE, 'input_feeder'):
-            for feeder in self.TE.input_feeder:
-                feeder(input_all)
-        encoder_hidden_states = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)[0]  # Get the text embedding for conditioning
-
-        if attn_mask is not None:
-            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
-
-        input_all['encoder_hidden_states'] = encoder_hidden_states
-        if hasattr(self.unet, 'input_feeder'):
-            for feeder in self.unet.input_feeder:
-                feeder(input_all)
-        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample  # Predict the noise residual
-        return model_pred
-
-    def prepare(self, accelerator):
-        if self.train_TE:
-            return accelerator.prepare(self)
-        else:
-            self.unet = accelerator.prepare(self.unet)
-            return self
-
-    def enable_gradient_checkpointing(self):
-        def grad_ckpt_enable(m):
-            if hasattr(m, 'gradient_checkpointing'):
-                m.training = True
-
-        self.unet.enable_gradient_checkpointing()
-        if self.train_TE:
-            self.TE.gradient_checkpointing_enable()
-            self.apply(grad_ckpt_enable)
-        else:
-            self.unet.apply(grad_ckpt_enable)
-
-    def trainable_parameters(self):
-        if self.train_TE:
-            return itertools.chain(self.unet.parameters(), self.TE.parameters())
-        else:
-            return self.unet.parameters()
-
-class SDXLTEUnetWrapper(TEUnetWrapper):
-    def forward(self, prompt_ids, noisy_latents, timesteps, attn_mask=None, position_ids=None, crop_info=None, plugin_input={}, **kwargs):
-        input_all = dict(prompt_ids=prompt_ids, noisy_latents=noisy_latents, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
-
-        if hasattr(self.TE, 'input_feeder'):
-            for feeder in self.TE.input_feeder:
-                feeder(input_all)
-        encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)  # Get the text embedding for conditioning
-
-        added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
-        if attn_mask is not None:
-            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
-
-        input_all['encoder_hidden_states'] = encoder_hidden_states
-        if hasattr(self.unet, 'input_feeder'):
-            for feeder in self.unet.input_feeder:
-                feeder(input_all)
-        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask, added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
-        return model_pred
hcpdiff/noise/__init__.py
DELETED
hcpdiff/noise/noise_base.py
DELETED
@@ -1,16 +0,0 @@
-
-class NoiseBase:
-    def __init__(self, base_scheduler):
-        self.base_scheduler = base_scheduler
-
-    def __getattr__(self, item):
-        try:
-            return super(NoiseBase, self).__getattr__(item)
-        except:
-            return getattr(self.base_scheduler, item)
-
-    def __setattr__(self, key, value):
-        if hasattr(super(), 'base_scheduler') and hasattr(self.base_scheduler, key):
-            setattr(self.base_scheduler, key, value)
-        else:
-            super(NoiseBase, self).__setattr__(key, value)
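`NoiseBase` forwards unknown attribute reads to the wrapped scheduler through `__getattr__`, so the wrapper can stand in for the original scheduler object. A small sketch, assuming `hcpdiff==0.9.0` and `diffusers` are installed:

```python
# Attribute delegation sketch (assumption: hcpdiff==0.9.0 installed).
from diffusers import DDPMScheduler
from hcpdiff.noise.noise_base import NoiseBase

wrapper = NoiseBase(DDPMScheduler())
print(wrapper.alphas_cumprod.shape)   # resolved on the wrapped DDPMScheduler: torch.Size([1000])
```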
hcpdiff/noise/pyramid_noise.py
DELETED
@@ -1,50 +0,0 @@
-import random
-
-import torch
-from torch.nn import functional as F
-from diffusers import SchedulerMixin
-
-from .noise_base import NoiseBase
-
-class PyramidNoiseScheduler(NoiseBase, SchedulerMixin):
-    def __init__(self, base_scheduler, level: int = 10, discount: float = 0.9, step_size: float = 2., resize_mode: str = 'bilinear'):
-        super().__init__(base_scheduler)
-        self.level = level
-        self.step_size = step_size
-        self.resize_mode = resize_mode
-        self.discount = discount
-
-    def add_noise(
-        self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
-        with torch.no_grad():
-            b, c, h, w = noise.shape
-            for i in range(1, self.level):
-                r = random.random()*2+self.step_size
-                wn, hn = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
-                noise += F.interpolate(torch.randn(b, c, hn, wn).to(noise), (h, w), None, self.resize_mode)*(self.discount**i)
-                if wn == 1 or hn == 1:
-                    break
-            noise = noise/noise.std()
-        return self.base_scheduler.add_noise(original_samples, noise, timesteps)
-
-# if __name__ == '__main__':
-#     noise = torch.randn(1,3,512,512)
-#     level=10
-#     discount=0.6
-#     b, c, h, w = noise.shape
-#     for i in range(level):
-#         r = random.random() * 2 + 2
-#         wn, hn = max(1, int(w / (r ** i))), max(1, int(h / (r ** i)))
-#         noise += F.interpolate(torch.randn(b, c, wn, hn).to(noise), (w, h), None, 'bilinear') * discount ** i
-#         if wn == 1 or hn == 1:
-#             break
-#     noise = noise / noise.std()
-#
-#     from matplotlib import pyplot as plt
-#     plt.figure()
-#     plt.imshow(noise[0].permute(1,2,0))
-#     plt.show()
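A usage sketch of the deleted wrapper (assuming `hcpdiff==0.9.0` and `diffusers` are installed; shapes and hyperparameters below are illustrative): the scheduler perturbs the supplied noise with progressively downsampled noise maps before delegating to the base scheduler's `add_noise`.

```python
# Pyramid-noise usage sketch (assumption: hcpdiff==0.9.0 installed).
import torch
from diffusers import DDPMScheduler
from hcpdiff.noise.pyramid_noise import PyramidNoiseScheduler

scheduler = PyramidNoiseScheduler(DDPMScheduler(), level=10, discount=0.9)

latents = torch.randn(2, 4, 64, 64)                       # clean latents
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (2,))
noisy = scheduler.add_noise(latents, noise, timesteps)    # pyramid noise mixed into `noise`
```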