hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +90 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +3 -3
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text.py +40 -0
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +207 -0
- hcpdiff/easy/cfg/sdxl_train.py +147 -0
- hcpdiff/easy/cfg/t2i.py +228 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +118 -128
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +60 -47
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.2.dist-info/METADATA +299 -0
- hcpdiff-2.2.dist-info/RECORD +115 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
- hcpdiff-2.2.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/loggers/preview/image_previewer.py
DELETED
@@ -1,149 +0,0 @@
-from contextlib import contextmanager
-from typing import List
-
-import hydra
-import torch
-from accelerate import infer_auto_device_map, dispatch_model
-from accelerate.hooks import remove_hook_from_module
-from diffusers import PNDMScheduler
-from torch.cuda.amp import autocast
-
-from hcpdiff.models import TokenizerHook
-from hcpdiff.utils.net_utils import to_cpu
-from hcpdiff.utils.utils import prepare_seed, load_config, size_to_int, int_to_size
-from hcpdiff.utils.utils import to_validate_file
-from hcpdiff.visualizer import Visualizer
-
-class ImagePreviewer(Visualizer):
-    def __init__(self, infer_cfg, exp_dir, te_hook,
-                 unet, TE, tokenizer, vae, save_cfg=False):
-        self.exp_dir = exp_dir
-        self.cfgs_raw = load_config(infer_cfg)
-        self.cfgs = hydra.utils.instantiate(self.cfgs_raw)
-        self.save_cfg = save_cfg
-        self.offload = 'offload' in self.cfgs and self.cfgs.offload is not None
-        self.dtype = self.dtype_dict[self.cfgs.dtype]
-
-        if getattr(self.cfgs.new_components, 'scheduler', None) is None:
-            scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')
-        else:
-            scheduler = self.cfgs.new_components.scheduler
-
-        pipe_cls = self.get_pipeline()
-        self.pipe = pipe_cls(vae=vae, text_encoder=TE, tokenizer=tokenizer, unet=unet, scheduler=scheduler, feature_extractor=None,
-                             safety_checker=None, requires_safety_checker=False)
-
-        self.token_ex = TokenizerHook(tokenizer)
-        self.te_hook = te_hook
-
-        if self.cfgs.seed is not None:
-            self.seeds = list(range(self.cfgs.seed, self.cfgs.seed+self.cfgs.num*self.cfgs.bs))
-        else:
-            self.seeds = [None]*(self.cfgs.num*self.cfgs.bs)
-
-    def build_vae_offload(self, offload_cfg):
-        vram = size_to_int(offload_cfg.max_VRAM)
-        if not offload_cfg.vae_cpu:
-            device_map = infer_auto_device_map(self.pipe.vae, max_memory={0:int_to_size(vram >> 5), "cpu":offload_cfg.max_RAM}, dtype=torch.float32)
-            self.pipe.vae = dispatch_model(self.pipe.vae, device_map)
-        else:
-            to_cpu(self.pipe.vae)
-            self.vae_decode_raw = self.pipe.vae.decode
-
-            def vae_decode_offload(latents, return_dict=True, decode_raw=self.pipe.vae.decode):
-                self.pipe.vae.to(dtype=torch.float32)
-                res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
-                return res
-
-            self.pipe.vae.decode = vae_decode_offload
-
-            self.vae_encode_raw = self.pipe.vae.encode
-
-            def vae_encode_offload(x, return_dict=True, encode_raw=self.pipe.vae.encode):
-                self.pipe.vae.to(dtype=torch.float32)
-                res = encode_raw(x.cpu().to(dtype=torch.float32), return_dict=return_dict)
-                return res
-
-            self.pipe.vae.encode = vae_encode_offload
-
-    def remove_vae_offload(self, offload_cfg):
-        if not offload_cfg.vae_cpu:
-            remove_hook_from_module(self.pipe.vae, recurse=True)
-        else:
-            self.pipe.vae.encode = self.vae_encode_raw
-            self.pipe.vae.decode = self.vae_decode_raw
-
-    @contextmanager
-    def infer_optimize(self):
-        if getattr(self.cfgs, 'vae_optimize', None) is not None:
-            if self.cfgs.vae_optimize.tiling:
-                self.pipe.vae.enable_tiling()
-            if self.cfgs.vae_optimize.slicing:
-                self.pipe.vae.enable_slicing()
-        vae_device = self.pipe.vae.device
-        if self.offload:
-            self.build_vae_offload(self.cfgs.offload)
-        else:
-            self.pipe.vae.to(self.pipe.unet.device)
-
-        yield
-
-        if self.offload:
-            self.remove_vae_offload(self.cfgs.offload)
-        self.pipe.vae.to(vae_device)
-        self.pipe.vae.disable_tiling()
-        self.pipe.vae.disable_slicing()
-
-    def preview(self):
-        image_list, info_list = [], []
-        with self.infer_optimize():
-            for i in range(self.cfgs.num):
-                prompt = self.cfgs.prompt[i*self.cfgs.bs:(i+1)*self.cfgs.bs] if isinstance(self.cfgs.prompt, list) \
-                    else [self.cfgs.prompt]*self.cfgs.bs
-                negative_prompt = self.cfgs.neg_prompt[i*self.cfgs.bs:(i+1)*self.cfgs.bs] if isinstance(self.cfgs.neg_prompt, list) \
-                    else [self.cfgs.neg_prompt]*self.cfgs.bs
-                seeds = self.seeds[i*self.cfgs.bs:(i+1)*self.cfgs.bs]
-                images = self.vis_images(prompt=prompt, negative_prompt=negative_prompt, seeds=seeds,
-                                         **self.cfgs.infer_args)
-                for prompt_i, negative_prompt_i, seed in zip(prompt, negative_prompt, seeds):
-                    info_list.append({
-                        'prompt':prompt_i,
-                        'negative_prompt':negative_prompt_i,
-                        'seed':seed,
-                    })
-                image_list += images
-
-        return image_list, info_list
-
-    def preview_dict(self):
-        image_list, info_list = self.preview()
-        imgs = {f'{info["seed"]}-{to_validate_file(info["prompt"])}':img for img, info in zip(image_list, info_list)}
-        return imgs
-
-    @torch.no_grad()
-    def vis_images(self, prompt, negative_prompt='', seeds: List[int] = None, **kwargs):
-        G = prepare_seed(seeds or [None]*len(prompt))
-
-        ex_input_dict, pipe_input_dict = self.get_ex_input()
-        kwargs.update(pipe_input_dict)
-
-        mult_p, clean_text_p = self.token_ex.parse_attn_mult(prompt)
-        mult_n, clean_text_n = self.token_ex.parse_attn_mult(negative_prompt)
-        with autocast(enabled=self.cfgs.amp, dtype=self.dtype):
-            emb, pooled_output, attention_mask = self.te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
-            if not self.cfgs.encoder_attention_mask:
-                attention_mask = None
-            emb_n, emb_p = emb.chunk(2)
-            emb_p = self.te_hook.mult_attn(emb_p, mult_p)
-            emb_n = self.te_hook.mult_attn(emb_n, mult_n)
-
-        if hasattr(self.pipe.unet, 'input_feeder'):
-            for feeder in self.pipe.unet.input_feeder:
-                feeder(ex_input_dict)
-
-        if pooled_output is not None:
-            pooled_output = pooled_output[-1]
-
-        images = self.pipe(prompt_embeds=emb_p, negative_prompt_embeds=emb_n, callback=self.inter_callback, generator=G,
-                           pooled_output=pooled_output, encoder_attention_mask=attention_mask, **kwargs).images
-        return images
hcpdiff/loggers/tensorboard_logger.py
DELETED
@@ -1,30 +0,0 @@
-import os
-from typing import Dict, Any
-
-import numpy as np
-from PIL import Image
-from torch.utils.tensorboard import SummaryWriter
-
-from .base_logger import BaseLogger
-
-
-class TBLogger(BaseLogger):
-    def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200):
-        super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-        if exp_dir is not None:  # exp_dir is only available in local main process
-            self.writer = SummaryWriter(os.path.join(exp_dir, out_path))
-        else:
-            self.writer = None
-            self.disable()
-
-    def _info(self, info):
-        pass
-
-    def _log(self, datas: Dict[str, Any], step: int = 0):
-        for k, v in datas.items():
-            if len(v['data']) == 1:
-                self.writer.add_scalar(k, v['data'][0], global_step=step)
-
-    def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-        for name, img in imgs.items():
-            self.writer.add_image(f'img/{name}', np.array(img), dataformats='HWC', global_step=step)
hcpdiff/loggers/wandb_logger.py
DELETED
@@ -1,31 +0,0 @@
-from typing import Dict, Any
-
-import os
-import wandb
-from PIL import Image
-
-from .base_logger import BaseLogger
-
-
-class WanDBLogger(BaseLogger):
-    def __init__(self, exp_dir, out_path=None, enable_log_image=False, project='hcp-diffusion', log_step=10, image_log_step=200):
-        super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-        if exp_dir is not None:  # exp_dir is only available in local main process
-            wandb.init(project=project, name=os.path.basename(exp_dir))
-            wandb.save(os.path.join(exp_dir, 'cfg.yaml'), base_path=exp_dir)
-        else:
-            self.writer = None
-            self.disable()
-
-    def _info(self, info):
-        pass
-
-    def _log(self, datas: Dict[str, Any], step: int = 0):
-        log_dict = {'step': step}
-        for k, v in datas.items():
-            if len(v['data']) == 1:
-                log_dict[k] = v['data'][0]
-        wandb.log(log_dict)
-
-    def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-        wandb.log({next(iter(imgs.keys())): list(imgs.values())}, step=step)
hcpdiff/loggers/webui_logger.py
DELETED
@@ -1,9 +0,0 @@
-from typing import Dict, Any
-
-from loguru import logger
-
-from .cli_logger import CLILogger
-
-class WebUILogger(CLILogger):
-    def _log(self, datas: Dict[str, Any], step: int = 0):
-        logger.info('this progress steps:'+', '.join([f"{k} = {v['format'].format(*v['data'])}" for k, v in datas.items()]))
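All three deleted loggers consume the same `datas` mapping: each entry carries a `'data'` list and a `'format'` string. A hypothetical call, inferred from the `_log` implementations above rather than from any documented schema:

```python
# Inferred shape of the `datas` argument (not an official API):
datas = {
    'loss': {'data': [0.123], 'format': '{:.4f}'},
    'lr':   {'data': [1e-4],  'format': '{:.2e}'},
}
# TBLogger._log    -> writer.add_scalar('loss', 0.123, ...) per single-value entry
# WanDBLogger._log -> wandb.log({'step': step, 'loss': 0.123, 'lr': 1e-4})
# WebUILogger._log -> logger.info("this progress steps:loss = 0.1230, lr = 1.00e-04")
```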
hcpdiff/loss/min_snr_loss.py
DELETED
@@ -1,52 +0,0 @@
-import torch
-from diffusers import SchedulerMixin
-from torch import nn
-
-class MinSNRLoss(nn.MSELoss):
-    need_timesteps = True
-
-    def __init__(self, size_average=None, reduce=None, reduction: str = 'none', gamma=1.,
-                 noise_scheduler: SchedulerMixin = None, device='cuda:0', **kwargs):
-        super().__init__(size_average, reduce, reduction)
-        self.gamma = gamma
-
-        # calculate SNR
-        alphas_cumprod = noise_scheduler.alphas_cumprod
-        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
-        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0-alphas_cumprod)
-        self.alpha = sqrt_alphas_cumprod.to(device)
-        self.sigma = sqrt_one_minus_alphas_cumprod.to(device)
-        self.all_snr = ((self.alpha/self.sigma)**2).to(device)
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        loss = super(MinSNRLoss, self).forward(input, target)
-        snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-        snr_weight = (self.gamma/snr).clip(max=1.).float()
-        return loss*snr_weight.view(-1, 1, 1, 1)
-
-
-class SoftMinSNRLoss(MinSNRLoss):
-    # gamma=2
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        loss = super(MinSNRLoss, self).forward(input, target)
-        snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-        snr_weight = (self.gamma**3/(snr**2 + self.gamma**3)).float()
-        return loss*snr_weight.view(-1, 1, 1, 1)
-
-class KDiffMinSNRLoss(MinSNRLoss):
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        loss = super(MinSNRLoss, self).forward(input, target)
-        snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-        snr_weight = 4*(((self.gamma*snr)**2/(snr**2 + self.gamma**2)**2)).float()
-        return loss*snr_weight.view(-1, 1, 1, 1)
-
-class EDMLoss(MinSNRLoss):
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        loss = super(MinSNRLoss, self).forward(input, target)
-        sigma = self.sigma[timesteps[:loss.shape[0], ...].squeeze()]
-        snr = self.all_snr[timesteps[:loss.shape[0], ...].squeeze()]
-        snr_weight = ((sigma**2+self.gamma**2)/(snr*(sigma*self.gamma)**2)).float()
-        return loss*snr_weight.view(-1, 1, 1, 1)
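These loss weightings appear to be superseded by the new hcpdiff/loss/weighting.py in 2.2 (added above with +66 lines). Reading the per-timestep weights directly off the four deleted `forward` passes, with SNR(t) = (α_t/σ_t)² and γ the `gamma` hyperparameter (the first is the Min-SNR-γ weighting of Hang et al., 2023):

```latex
% Per-timestep loss weights implemented by the deleted classes,
% with $\mathrm{SNR}(t) = (\alpha_t/\sigma_t)^2$:
\begin{aligned}
w_{\text{MinSNR}}(t)      &= \min\!\left(\frac{\gamma}{\mathrm{SNR}(t)},\, 1\right) \\
w_{\text{SoftMinSNR}}(t)  &= \frac{\gamma^3}{\mathrm{SNR}(t)^2 + \gamma^3} \\
w_{\text{KDiffMinSNR}}(t) &= \frac{4\,\gamma^2\,\mathrm{SNR}(t)^2}{\left(\mathrm{SNR}(t)^2 + \gamma^2\right)^2} \\
w_{\text{EDM}}(t)         &= \frac{\sigma_t^2 + \gamma^2}{\mathrm{SNR}(t)\,(\sigma_t\,\gamma)^2}
\end{aligned}
```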
hcpdiff/models/layers.py
DELETED
@@ -1,81 +0,0 @@
-"""
-layers.py
-====================
-:Name: GroupLinear and other layers
-:Author: Dong Ziyi
-:Affiliation: HCP Lab, SYSU
-:Created: 09/04/2023
-:Licence: Apache-2.0
-"""
-
-import torch
-from torch import nn
-import math
-from einops import rearrange
-
-class GroupLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, groups: int, bias: bool = True,
-                 device=None, dtype=None):
-        super().__init__()
-        assert in_features%groups == 0
-        assert out_features%groups == 0
-
-        factory_kwargs = {'device': device, 'dtype': dtype}
-
-        self.groups = groups
-        self.in_features = in_features
-        self.out_features = out_features
-
-        self.weight = nn.Parameter(torch.empty((groups, in_features//groups, out_features//groups), **factory_kwargs))
-        if bias:
-            self.bias = nn.Parameter(torch.empty(groups, 1, out_features//groups, **factory_kwargs))
-        else:
-            self.register_parameter('bias', None)
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
-        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
-        # https://github.com/pytorch/pytorch/issues/57109
-        self.kaiming_uniform_group(self.weight, a=math.sqrt(5))
-        if self.bias is not None:
-            fan_in, _ = self._calculate_fan_in_and_fan_out(self.weight)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-
-    @staticmethod
-    def _calculate_fan_in_and_fan_out(tensor):
-        receptive_field_size = 1
-        num_input_fmaps = tensor.size(-2)
-        num_output_fmaps = tensor.size(-1)
-        fan_in = num_input_fmaps * receptive_field_size
-        fan_out = num_output_fmaps * receptive_field_size
-
-        return fan_in, fan_out
-
-    @staticmethod
-    def kaiming_uniform_group(tensor: torch.Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu') -> torch.Tensor:
-        def _calculate_correct_fan(tensor, mode):
-            mode = mode.lower()
-            valid_modes = ['fan_in', 'fan_out']
-            if mode not in valid_modes:
-                raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
-
-            fan_in, fan_out = GroupLinear._calculate_fan_in_and_fan_out(tensor)
-            return fan_in if mode == 'fan_in' else fan_out
-
-        fan = _calculate_correct_fan(tensor, mode)
-        gain = nn.init.calculate_gain(nonlinearity, a)
-        std = gain / math.sqrt(fan)
-        bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
-        with torch.no_grad():
-            return tensor.uniform_(-bound, bound)
-
-    def forward(self, x: torch.Tensor):  # x: [G,B,L,C]
-        x = rearrange(x, '(g b) l c -> g (b l) c', g=self.num_groups)
-        if self.bias is not None:
-            out = torch.bmm(x, self.weight) + self.bias
-        else:
-            out = torch.bmm(x, self.weight)
-        out = rearrange(out, 'g (b l) c -> (g b) l c', b=B)
-        return out