hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +114 -125
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +78 -46
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/visualizer.py
DELETED
@@ -1,265 +0,0 @@
|
|
1
|
-
import argparse
|
2
|
-
import os
|
3
|
-
import random
|
4
|
-
from typing import List
|
5
|
-
|
6
|
-
import hydra
|
7
|
-
import torch
|
8
|
-
from PIL import Image
|
9
|
-
from accelerate import infer_auto_device_map, dispatch_model
|
10
|
-
from diffusers.utils.import_utils import is_xformers_available
|
11
|
-
from hcpdiff.models import TokenizerHook, LoraBlock
|
12
|
-
from hcpdiff.models.compose import ComposeTEEXHook, ComposeEmbPTHook, ComposeTextEncoder
|
13
|
-
from hcpdiff.utils.cfg_net_tools import HCPModelLoader, make_plugin
|
14
|
-
from hcpdiff.utils.net_utils import to_cpu, to_cuda, auto_tokenizer, auto_text_encoder
|
15
|
-
from hcpdiff.utils.pipe_hook import HookPipe_T2I, HookPipe_I2I, HookPipe_Inpaint
|
16
|
-
from hcpdiff.utils.utils import load_config_with_cli, load_config, size_to_int, int_to_size, prepare_seed, is_list, pad_attn_bias
|
17
|
-
from hcpdiff.deprecated.cfg_converter import InferCFGConverter
|
18
|
-
from omegaconf import OmegaConf
|
19
|
-
from torch.cuda.amp import autocast
|
20
|
-
|
21
|
-
class Visualizer:
|
22
|
-
dtype_dict = {'fp32':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
|
23
|
-
|
24
|
-
def __init__(self, cfgs):
|
25
|
-
self.cfgs_raw = cfgs
|
26
|
-
self.cfgs = hydra.utils.instantiate(self.cfgs_raw)
|
27
|
-
self.cfg_merge = self.cfgs.merge
|
28
|
-
self.offload = 'offload' in self.cfgs and self.cfgs.offload is not None
|
29
|
-
self.dtype = self.dtype_dict[self.cfgs.dtype]
|
30
|
-
|
31
|
-
self.need_inter_imgs = any(item.need_inter_imgs for item in self.cfgs.interface)
|
32
|
-
|
33
|
-
self.pipe = self.load_model(self.cfgs.pretrained_model)
|
34
|
-
|
35
|
-
if self.cfg_merge:
|
36
|
-
self.merge_model()
|
37
|
-
|
38
|
-
self.pipe = self.pipe.to(torch_dtype=self.dtype)
|
39
|
-
|
40
|
-
if isinstance(self.pipe.text_encoder, ComposeTextEncoder):
|
41
|
-
self.pipe.vae = self.pipe.vae.to(dtype=torch.float32)
|
42
|
-
|
43
|
-
if 'save_model' in self.cfgs and self.cfgs.save_model is not None:
|
44
|
-
self.save_model(self.cfgs.save_model)
|
45
|
-
os._exit(0)
|
46
|
-
|
47
|
-
self.build_optimize()
|
48
|
-
|
49
|
-
def load_model(self, pretrained_model):
|
50
|
-
pipeline = self.get_pipeline()
|
51
|
-
te = auto_text_encoder(pretrained_model, subfolder="text_encoder", torch_dtype=self.dtype, resume_download=True)
|
52
|
-
tokenizer = auto_tokenizer(pretrained_model, subfolder="tokenizer", use_fast=False)
|
53
|
-
|
54
|
-
return pipeline.from_pretrained(pretrained_model, safety_checker=None, requires_safety_checker=False,
|
55
|
-
text_encoder=te, tokenizer=tokenizer, resume_download=True,
|
56
|
-
torch_dtype=self.dtype, **self.cfgs.new_components)
|
57
|
-
|
58
|
-
def build_optimize(self):
|
59
|
-
if self.offload:
|
60
|
-
self.build_offload(self.cfgs.offload)
|
61
|
-
else:
|
62
|
-
self.pipe.unet.to('cuda')
|
63
|
-
self.build_vae_offload()
|
64
|
-
|
65
|
-
if getattr(self.cfgs, 'vae_optimize', None) is not None:
|
66
|
-
if self.cfgs.vae_optimize.tiling:
|
67
|
-
self.pipe.vae.enable_tiling()
|
68
|
-
if self.cfgs.vae_optimize.slicing:
|
69
|
-
self.pipe.vae.enable_slicing()
|
70
|
-
|
71
|
-
self.emb_hook, _ = ComposeEmbPTHook.hook_from_dir(self.cfgs.emb_dir, self.pipe.tokenizer, self.pipe.text_encoder,
|
72
|
-
N_repeats=self.cfgs.N_repeats)
|
73
|
-
self.te_hook = ComposeTEEXHook.hook_pipe(self.pipe, N_repeats=self.cfgs.N_repeats, clip_skip=self.cfgs.clip_skip,
|
74
|
-
clip_final_norm=self.cfgs.clip_final_norm, use_attention_mask=self.cfgs.encoder_attention_mask)
|
75
|
-
self.token_ex = TokenizerHook(self.pipe.tokenizer)
|
76
|
-
|
77
|
-
if is_xformers_available():
|
78
|
-
self.pipe.unet.enable_xformers_memory_efficient_attention()
|
79
|
-
# self.te_hook.enable_xformers()
|
80
|
-
|
81
|
-
def save_model(self, save_cfg):
|
82
|
-
for k, v in self.pipe.unet.named_modules():
|
83
|
-
if isinstance(v, LoraBlock):
|
84
|
-
v.reparameterization_to_host()
|
85
|
-
v.remove()
|
86
|
-
for k, v in self.pipe.text_encoder.named_modules():
|
87
|
-
if isinstance(v, LoraBlock):
|
88
|
-
v.reparameterization_to_host()
|
89
|
-
v.remove()
|
90
|
-
|
91
|
-
if save_cfg.path.endswith('.ckpt'):
|
92
|
-
from hcpdiff.tools.diffusers2sd import save_state_dict
|
93
|
-
save_state_dict(save_cfg.path, self.pipe.unet.state_dict(), self.pipe.vae.state_dict(), self.pipe.text_encoder.state_dict(),
|
94
|
-
use_safetensors=save_cfg.to_safetensors)
|
95
|
-
|
96
|
-
else:
|
97
|
-
self.pipe.save_pretrained(save_cfg.path, safe_serialization=save_cfg.to_safetensors)
|
98
|
-
|
99
|
-
def get_pipeline(self):
|
100
|
-
if self.cfgs.condition is None:
|
101
|
-
pipe_cls = HookPipe_T2I
|
102
|
-
else:
|
103
|
-
if self.cfgs.condition.type == 'i2i':
|
104
|
-
pipe_cls = HookPipe_I2I
|
105
|
-
elif self.cfgs.condition.type == 'inpaint':
|
106
|
-
pipe_cls = HookPipe_Inpaint
|
107
|
-
else:
|
108
|
-
raise NotImplementedError(f'No condition type named {self.cfgs.condition.type}')
|
109
|
-
|
110
|
-
return pipe_cls
|
111
|
-
|
112
|
-
def build_offload(self, offload_cfg):
|
113
|
-
vram = size_to_int(offload_cfg.max_VRAM)
|
114
|
-
device_map = infer_auto_device_map(self.pipe.unet, max_memory={0:int_to_size(vram >> 1), "cpu":offload_cfg.max_RAM}, dtype=self.dtype)
|
115
|
-
self.pipe.unet = dispatch_model(self.pipe.unet, device_map)
|
116
|
-
if not offload_cfg.vae_cpu:
|
117
|
-
device_map = infer_auto_device_map(self.pipe.vae, max_memory={0:int_to_size(vram >> 5), "cpu":offload_cfg.max_RAM}, dtype=self.dtype)
|
118
|
-
self.pipe.vae = dispatch_model(self.pipe.vae, device_map)
|
119
|
-
|
120
|
-
def build_vae_offload(self):
|
121
|
-
def vae_decode_offload(latents, return_dict=True, decode_raw=self.pipe.vae.decode):
|
122
|
-
if self.need_inter_imgs:
|
123
|
-
to_cuda(self.pipe.vae)
|
124
|
-
res = decode_raw(latents, return_dict=return_dict)
|
125
|
-
else:
|
126
|
-
to_cpu(self.pipe.unet)
|
127
|
-
|
128
|
-
if self.offload and self.cfgs.offload.vae_cpu:
|
129
|
-
self.pipe.vae.to(dtype=torch.float32)
|
130
|
-
res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
|
131
|
-
else:
|
132
|
-
to_cuda(self.pipe.vae)
|
133
|
-
res = decode_raw(latents.to(dtype=self.pipe.vae.dtype), return_dict=return_dict)
|
134
|
-
|
135
|
-
to_cpu(self.pipe.vae)
|
136
|
-
to_cuda(self.pipe.unet)
|
137
|
-
return res
|
138
|
-
|
139
|
-
self.pipe.vae.decode = vae_decode_offload
|
140
|
-
|
141
|
-
def vae_encode_offload(x, return_dict=True, encode_raw=self.pipe.vae.encode):
|
142
|
-
to_cuda(self.pipe.vae)
|
143
|
-
res = encode_raw(x.to(dtype=self.pipe.vae.dtype), return_dict=return_dict)
|
144
|
-
to_cpu(self.pipe.vae)
|
145
|
-
return res
|
146
|
-
|
147
|
-
self.pipe.vae.encode = vae_encode_offload
|
148
|
-
|
149
|
-
def merge_model(self):
|
150
|
-
if 'plugin_cfg' in self.cfg_merge: # Build plugins
|
151
|
-
if isinstance(self.cfg_merge.plugin_cfg, str):
|
152
|
-
plugin_cfg = load_config(self.cfg_merge.plugin_cfg)
|
153
|
-
plugin_cfg = {'plugin_unet': hydra.utils.instantiate(plugin_cfg['plugin_unet']),
|
154
|
-
'plugin_TE': hydra.utils.instantiate(plugin_cfg['plugin_TE'])}
|
155
|
-
else:
|
156
|
-
plugin_cfg = self.cfg_merge.plugin_cfg
|
157
|
-
make_plugin(self.pipe.unet, plugin_cfg['plugin_unet'])
|
158
|
-
make_plugin(self.pipe.text_encoder, plugin_cfg['plugin_TE'])
|
159
|
-
|
160
|
-
load_ema = self.cfg_merge.get('load_ema', False)
|
161
|
-
for cfg_group in self.cfg_merge.values():
|
162
|
-
if hasattr(cfg_group, 'type'):
|
163
|
-
if cfg_group.type == 'unet':
|
164
|
-
HCPModelLoader(self.pipe.unet).load_all(cfg_group, load_ema=load_ema)
|
165
|
-
elif cfg_group.type == 'TE':
|
166
|
-
HCPModelLoader(self.pipe.text_encoder).load_all(cfg_group, load_ema=load_ema)
|
167
|
-
|
168
|
-
def set_scheduler(self, scheduler):
|
169
|
-
self.pipe.scheduler = scheduler
|
170
|
-
|
171
|
-
def get_ex_input(self):
|
172
|
-
ex_input_dict, pipe_input_dict = {}, {}
|
173
|
-
if self.cfgs.condition is not None:
|
174
|
-
if self.cfgs.condition.type == 'i2i':
|
175
|
-
pipe_input_dict['image'] = Image.open(self.cfgs.condition.image).convert('RGB')
|
176
|
-
elif self.cfgs.condition.type == 'inpaint':
|
177
|
-
pipe_input_dict['image'] = Image.open(self.cfgs.condition.image).convert('RGB')
|
178
|
-
pipe_input_dict['mask_image'] = Image.open(self.cfgs.condition.mask).convert('L')
|
179
|
-
|
180
|
-
if getattr(self.cfgs, 'ex_input', None) is not None:
|
181
|
-
for key, processor in self.cfgs.ex_input.items():
|
182
|
-
ex_input_dict[key] = processor(self.cfgs.infer_args.width, self.cfgs.infer_args.height, self.cfgs.bs*2, 'cuda', self.dtype)
|
183
|
-
return ex_input_dict, pipe_input_dict
|
184
|
-
|
185
|
-
@torch.no_grad()
|
186
|
-
def vis_images(self, prompt, negative_prompt='', seeds: List[int] = None, **kwargs):
|
187
|
-
G = prepare_seed(seeds or [None]*len(prompt))
|
188
|
-
|
189
|
-
ex_input_dict, pipe_input_dict = self.get_ex_input()
|
190
|
-
kwargs.update(pipe_input_dict)
|
191
|
-
|
192
|
-
to_cuda(self.pipe.text_encoder)
|
193
|
-
|
194
|
-
mult_p, clean_text_p = self.token_ex.parse_attn_mult(prompt)
|
195
|
-
mult_n, clean_text_n = self.token_ex.parse_attn_mult(negative_prompt)
|
196
|
-
with autocast(enabled=self.cfgs.amp, dtype=self.dtype):
|
197
|
-
if hasattr(self.pipe.text_encoder, 'input_feeder'):
|
198
|
-
for feeder in self.pipe.text_encoder.input_feeder:
|
199
|
-
feeder(ex_input_dict)
|
200
|
-
|
201
|
-
emb, pooled_output, attention_mask = self.te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
|
202
|
-
if self.cfgs.encoder_attention_mask:
|
203
|
-
emb, attention_mask = pad_attn_bias(emb, attention_mask)
|
204
|
-
else:
|
205
|
-
attention_mask = None
|
206
|
-
emb_n, emb_p = emb.chunk(2)
|
207
|
-
emb_p = self.te_hook.mult_attn(emb_p, mult_p)
|
208
|
-
emb_n = self.te_hook.mult_attn(emb_n, mult_n)
|
209
|
-
|
210
|
-
to_cpu(self.pipe.text_encoder)
|
211
|
-
to_cuda(self.pipe.unet)
|
212
|
-
|
213
|
-
if hasattr(self.pipe.unet, 'input_feeder'):
|
214
|
-
for feeder in self.pipe.unet.input_feeder:
|
215
|
-
feeder(ex_input_dict)
|
216
|
-
|
217
|
-
images = self.pipe(prompt_embeds=emb_p, negative_prompt_embeds=emb_n, callback=self.inter_callback, generator=G,
|
218
|
-
pooled_output=pooled_output[-1], encoder_attention_mask=attention_mask, **kwargs).images
|
219
|
-
return images
|
220
|
-
|
221
|
-
def inter_callback(self, i, t, num_t, latents_x0, latents):
|
222
|
-
images = None
|
223
|
-
interrupt = False
|
224
|
-
for interface in self.cfgs.interface:
|
225
|
-
if interface.show_steps>0 and i%interface.show_steps == 0:
|
226
|
-
if self.need_inter_imgs and images is None:
|
227
|
-
images = self.pipe.decode_latents(latents_x0)
|
228
|
-
images = self.pipe.numpy_to_pil(images)
|
229
|
-
feed_back = interface.on_inter_step(i, num_t, t, latents_x0, images)
|
230
|
-
interrupt |= bool(feed_back)
|
231
|
-
return None if interrupt else latents
|
232
|
-
|
233
|
-
def save_images(self, images, prompt, negative_prompt='', seeds: List[int] = None):
|
234
|
-
for interface in self.cfgs.interface:
|
235
|
-
interface.on_infer_finish(images, prompt, negative_prompt, self.cfgs_raw, seeds=seeds)
|
236
|
-
|
237
|
-
def vis_to_dir(self, prompt, negative_prompt='', seeds: List[int] = None, **kwargs):
|
238
|
-
seeds = [s or random.randint(0, 1 << 30) for s in seeds]
|
239
|
-
|
240
|
-
images = self.vis_images(prompt, negative_prompt, seeds=seeds, **kwargs)
|
241
|
-
self.save_images(images, prompt, negative_prompt, seeds=seeds)
|
242
|
-
|
243
|
-
if __name__ == '__main__':
|
244
|
-
parser = argparse.ArgumentParser(description='HCP Diffusion Inference')
|
245
|
-
parser.add_argument('--cfg', type=str, default='')
|
246
|
-
args, cfg_args = parser.parse_known_args()
|
247
|
-
cfgs = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
|
248
|
-
|
249
|
-
cfgs = InferCFGConverter().convert(cfgs) # support old cfgs format
|
250
|
-
|
251
|
-
if cfgs.seed is not None:
|
252
|
-
if is_list(cfgs.seed):
|
253
|
-
assert len(cfgs.seed) == cfgs.num*cfgs.bs, 'seed list length should be equal to num*bs'
|
254
|
-
seeds = list(cfgs.seed)
|
255
|
-
else:
|
256
|
-
seeds = list(range(cfgs.seed, cfgs.seed+cfgs.num*cfgs.bs))
|
257
|
-
else:
|
258
|
-
seeds = [None]*(cfgs.num*cfgs.bs)
|
259
|
-
|
260
|
-
viser = Visualizer(cfgs)
|
261
|
-
for i in range(cfgs.num):
|
262
|
-
prompt = cfgs.prompt[i*cfgs.bs:(i+1)*cfgs.bs] if is_list(cfgs.prompt) else [cfgs.prompt]*cfgs.bs
|
263
|
-
negative_prompt = cfgs.neg_prompt[i*cfgs.bs:(i+1)*cfgs.bs] if is_list(cfgs.neg_prompt) else [cfgs.neg_prompt]*cfgs.bs
|
264
|
-
viser.vis_to_dir(prompt=prompt, negative_prompt=negative_prompt,
|
265
|
-
seeds=seeds[i*cfgs.bs:(i+1)*cfgs.bs], save_cfg=cfgs.save.save_cfg, **cfgs.infer_args)
|
hcpdiff/visualizer_reloadable.py
DELETED
@@ -1,237 +0,0 @@
|
|
1
|
-
from hcpdiff.visualizer import Visualizer
|
2
|
-
import accelerate.hooks
|
3
|
-
from omegaconf import OmegaConf
|
4
|
-
from hcpdiff.models import EmbeddingPTHook
|
5
|
-
import hydra
|
6
|
-
from diffusers import AutoencoderKL, PNDMScheduler
|
7
|
-
import torch
|
8
|
-
from hcpdiff.utils.cfg_net_tools import HCPModelLoader, make_plugin
|
9
|
-
from hcpdiff.utils import load_config, hash_str
|
10
|
-
from copy import deepcopy
|
11
|
-
|
12
|
-
class VisualizerReloadable(Visualizer):
|
13
|
-
def __init__(self, cfgs):
|
14
|
-
self.lora_dict = {}
|
15
|
-
self.part_plugin_cfg_set = set()
|
16
|
-
super().__init__(cfgs)
|
17
|
-
|
18
|
-
def _merge_model(self, cfg_merge):
|
19
|
-
if 'plugin_cfg' in cfg_merge: # Build plugins
|
20
|
-
plugin_cfg = hydra.utils.instantiate(load_config(cfg_merge.plugin_cfg))
|
21
|
-
make_plugin(self.pipe.unet, plugin_cfg.plugin_unet)
|
22
|
-
make_plugin(self.pipe.text_encoder, plugin_cfg.plugin_TE)
|
23
|
-
|
24
|
-
for cfg_group in cfg_merge.values():
|
25
|
-
if hasattr(cfg_group, 'type'):
|
26
|
-
if cfg_group.type == 'unet':
|
27
|
-
lora_group = HCPModelLoader(self.pipe.unet).load_all(cfg_group)
|
28
|
-
elif cfg_group.type == 'TE':
|
29
|
-
lora_group = HCPModelLoader(self.pipe.text_encoder).load_all(cfg_group)
|
30
|
-
else:
|
31
|
-
raise ValueError(f'no host model type named {cfg_group.type}')
|
32
|
-
|
33
|
-
# record all lora plugin with its config hash
|
34
|
-
if not lora_group.empty():
|
35
|
-
for cfg_lora, lora_plugin in zip(cfg_group.lora, lora_group.plugin_dict.values()):
|
36
|
-
self.lora_dict[hash_str(OmegaConf.to_yaml(cfg_lora, resolve=True))] = lora_plugin
|
37
|
-
|
38
|
-
# record all part and plugin config hash
|
39
|
-
for cfg_part in getattr(cfg_group, "part", None) or []:
|
40
|
-
self.part_plugin_cfg_set.add(hash_str(OmegaConf.to_yaml(cfg_part, resolve=True)))
|
41
|
-
for cfg_plugin in getattr(cfg_group, "plugin", None) or []:
|
42
|
-
self.part_plugin_cfg_set.add(hash_str(OmegaConf.to_yaml(cfg_plugin, resolve=True)))
|
43
|
-
|
44
|
-
def merge_model(self):
|
45
|
-
self.part_plugin_cfg_set.clear()
|
46
|
-
self.lora_dict.clear()
|
47
|
-
self._merge_model(self.cfg_merge)
|
48
|
-
|
49
|
-
def part_plugin_changed(self):
|
50
|
-
if not self.cfg_merge:
|
51
|
-
return not self.cfg_same(self.cfg_merge, self.cfgs_old.merge)
|
52
|
-
part_plugin_cfg_set_new = set()
|
53
|
-
for cfg_group in self.cfg_merge.values():
|
54
|
-
for cfg_part in getattr(cfg_group, "part", None) or []:
|
55
|
-
part_plugin_cfg_set_new.add(hash_str(OmegaConf.to_yaml(cfg_part, resolve=True)))
|
56
|
-
for cfg_plugin in getattr(cfg_group, "plugin", None) or []:
|
57
|
-
part_plugin_cfg_set_new.add(hash_str(OmegaConf.to_yaml(cfg_plugin, resolve=True)))
|
58
|
-
return part_plugin_cfg_set_new != self.part_plugin_cfg_set
|
59
|
-
|
60
|
-
@staticmethod
|
61
|
-
def cfg_same(cfg1, cfg2):
|
62
|
-
if cfg1 is None:
|
63
|
-
return cfg2 is None
|
64
|
-
elif cfg2 is None:
|
65
|
-
return cfg1 is None
|
66
|
-
else:
|
67
|
-
return OmegaConf.to_yaml(cfg1) == OmegaConf.to_yaml(cfg2)
|
68
|
-
|
69
|
-
def reload_offload(self) -> bool:
|
70
|
-
if not self.cfg_same(self.cfgs_raw.offload, self.cfgs_raw_old.offload):
|
71
|
-
if self.offload_old:
|
72
|
-
# remove offload hooks
|
73
|
-
accelerate.hooks.remove_hook_from_module(self.pipe.unet, recurse=True)
|
74
|
-
accelerate.hooks.remove_hook_from_module(self.pipe.vae, recurse=True)
|
75
|
-
else:
|
76
|
-
return False
|
77
|
-
|
78
|
-
if self.offload:
|
79
|
-
self.pipe.unet.to('cpu')
|
80
|
-
self.pipe.vae.to('cpu')
|
81
|
-
torch.cuda.empty_cache()
|
82
|
-
torch.cuda.synchronize()
|
83
|
-
self.build_offload(self.cfgs.offload)
|
84
|
-
else:
|
85
|
-
self.pipe.unet.to('cuda')
|
86
|
-
return True
|
87
|
-
|
88
|
-
def reload_emb_hook(self) -> bool:
|
89
|
-
if self.cfgs.emb_dir!=self.cfgs_old.emb_dir or self.cfgs.N_repeats!=self.cfgs_old.N_repeats:
|
90
|
-
self.emb_hook.remove()
|
91
|
-
self.emb_hook, _ = EmbeddingPTHook.hook_from_dir(self.cfgs.emb_dir, self.pipe.tokenizer, self.pipe.text_encoder,
|
92
|
-
N_repeats=self.cfgs.N_repeats)
|
93
|
-
return True
|
94
|
-
return False
|
95
|
-
|
96
|
-
def reload_te_hook(self) -> bool:
|
97
|
-
if self.cfgs.clip_skip != self.cfgs_old.clip_skip or self.cfgs.N_repeats != self.cfgs_old.N_repeats:
|
98
|
-
self.te_hook.N_repeats = self.cfgs.N_repeats
|
99
|
-
self.te_hook.clip_skip = self.cfgs.clip_skip
|
100
|
-
return True
|
101
|
-
return False
|
102
|
-
|
103
|
-
def reload_model(self) -> bool:
|
104
|
-
pipeline = self.get_pipeline()
|
105
|
-
if self.cfgs.pretrained_model!=self.cfgs_old.pretrained_model or self.part_plugin_changed():
|
106
|
-
comp = pipeline.from_pretrained(self.cfgs.pretrained_model, safety_checker=None, requires_safety_checker=False,
|
107
|
-
torch_dtype=self.dtype).components
|
108
|
-
if 'vae' in self.cfgs.new_components:
|
109
|
-
self.cfgs.new_components.vae = hydra.utils.instantiate(self.cfgs.new_components.vae)
|
110
|
-
comp.update(self.cfgs.new_components)
|
111
|
-
self.pipe = pipeline(**comp)
|
112
|
-
if self.cfg_merge:
|
113
|
-
self.merge_model()
|
114
|
-
self.pipe = self.pipe.to(torch_dtype=self.dtype)
|
115
|
-
self.build_optimize()
|
116
|
-
return True
|
117
|
-
return False
|
118
|
-
|
119
|
-
def reload_pipe(self) -> bool:
|
120
|
-
pipeline = self.get_pipeline()
|
121
|
-
if type(self.pipe)!=pipeline:
|
122
|
-
self.pipe = pipeline(**self.pipe.components)
|
123
|
-
return True
|
124
|
-
return False
|
125
|
-
|
126
|
-
|
127
|
-
def reload_scheduler(self) -> bool:
|
128
|
-
if 'scheduler' in self.cfgs_raw_old.new_components and 'scheduler' not in self.cfgs_raw.new_components:
|
129
|
-
# load default scheduler
|
130
|
-
self.pipe.scheduler = PNDMScheduler.from_pretrained(self.cfgs.pretrained_model, subfolder='scheduler', torch_dtype=self.dtype)
|
131
|
-
return True
|
132
|
-
elif not self.cfg_same(getattr(self.cfgs_raw_old.new_components, 'scheduler', {}), getattr(self.cfgs_raw.new_components, 'scheduler', {})):
|
133
|
-
self.pipe.scheduler = self.cfgs.new_components.scheduler
|
134
|
-
return True
|
135
|
-
return False
|
136
|
-
|
137
|
-
def reload_vae(self) -> bool:
|
138
|
-
if 'vae' in self.cfgs_raw_old.new_components and 'vae' not in self.cfgs_raw.new_components:
|
139
|
-
# load default VAE
|
140
|
-
self.cfgs.new_components.vae = AutoencoderKL.from_pretrained(self.cfgs.pretrained_model, subfolder='vae', torch_dtype=self.dtype)
|
141
|
-
return True
|
142
|
-
elif not self.cfg_same(getattr(self.cfgs_raw_old.new_components, 'vae', {}), getattr(self.cfgs_raw.new_components, 'vae', {})):
|
143
|
-
# VAE config changed, need reload
|
144
|
-
if 'vae' in self.cfgs_old.new_components:
|
145
|
-
del self.cfgs_old.new_components.vae
|
146
|
-
torch.cuda.empty_cache()
|
147
|
-
self.cfgs.new_components.vae = hydra.utils.instantiate(self.cfgs.new_components.vae)
|
148
|
-
self.pipe.vae = self.cfgs.new_components.vae
|
149
|
-
return True
|
150
|
-
return False
|
151
|
-
|
152
|
-
def reload_lora(self):
|
153
|
-
if self.cfg_merge is None:
|
154
|
-
if self.cfgs_old.merge is None:
|
155
|
-
return False
|
156
|
-
else:
|
157
|
-
for lora in self.lora_dict.values():
|
158
|
-
lora.remove()
|
159
|
-
self.lora_dict.clear()
|
160
|
-
return True
|
161
|
-
|
162
|
-
cfg_merge = deepcopy(self.cfg_merge)
|
163
|
-
all_lora_hash = set()
|
164
|
-
for k, cfg_group in self.cfg_merge.items():
|
165
|
-
if 'part' in cfg_merge[k]:
|
166
|
-
del cfg_merge[k].part
|
167
|
-
if 'plugin' in cfg_merge[k]:
|
168
|
-
del cfg_merge[k].plugin
|
169
|
-
|
170
|
-
lora_add = []
|
171
|
-
for cfg_lora in getattr(cfg_group, "lora", None) or []:
|
172
|
-
cfg_hash = hash_str(OmegaConf.to_yaml(cfg_lora, resolve=True))
|
173
|
-
if cfg_hash not in self.lora_dict:
|
174
|
-
lora_add.append(cfg_lora)
|
175
|
-
all_lora_hash.add(cfg_hash)
|
176
|
-
cfg_merge[k].lora = OmegaConf.create(lora_add)
|
177
|
-
|
178
|
-
lora_rm_set = set(self.lora_dict.keys())-all_lora_hash
|
179
|
-
for cfg_hash in lora_rm_set:
|
180
|
-
self.lora_dict[cfg_hash].remove()
|
181
|
-
for cfg_hash in lora_rm_set:
|
182
|
-
del self.lora_dict[cfg_hash]
|
183
|
-
|
184
|
-
self._merge_model(cfg_merge)
|
185
|
-
|
186
|
-
def check_reload(self, cfgs):
    '''
    Reload and modify each module based on the changes of configuration file.

    The previous raw config, instantiated config and offload flag are kept
    on ``self`` (``cfgs_raw_old``, ``cfgs_old``, ``offload_old``) while the
    ``reload_*`` helpers run. The two config snapshots are deleted again
    at the end.

    :param cfgs: the new raw configuration; it must provide
        ``new_components``, ``merge``, ``dtype`` and ``interface``.
    '''
    self.cfgs_raw_old = self.cfgs_raw
    self.cfgs_old = self.cfgs
    self.offload_old = self.offload

    self.cfgs_raw = cfgs

    # Reload vae only when vae config changes: the VAE entry is hidden from
    # instantiate() and carried over as plain config instead.
    if 'vae' in self.cfgs_raw.new_components:
        vae_cfg = self.cfgs_raw.new_components.vae
        self.cfgs_raw.new_components.vae = None
        try:
            self.cfgs = hydra.utils.instantiate(self.cfgs_raw)
        finally:
            # Restore the raw config even if instantiation fails, so the
            # caller's config object is never left with a nulled VAE entry.
            self.cfgs_raw.new_components.vae = vae_cfg
        self.cfgs.new_components.vae = vae_cfg
    else:
        self.cfgs = hydra.utils.instantiate(self.cfgs_raw)

    self.cfg_merge = self.cfgs.merge
    self.offload = 'offload' in self.cfgs and self.cfgs.offload is not None
    self.dtype = self.dtype_dict[self.cfgs.dtype]

    self.need_inter_imgs = any(item.need_inter_imgs for item in self.cfgs.interface)

    # NOTE(review): the nesting below was reconstructed from a listing that
    # lost its indentation; confirm that all component reloads belong under
    # the "model was not reloaded" branch.
    is_model_reload = self.reload_model()
    if not is_model_reload:
        is_vae_reload = self.reload_vae()
        if is_vae_reload:
            self.build_vae_offload()
        self.reload_lora()
        self.reload_scheduler()
        self.reload_offload()
        self.reload_emb_hook()
        self.reload_te_hook()
        self.reload_pipe()

    # Apply the VAE memory options on every call so toggling them in the
    # config takes effect without a reload.
    if getattr(self.cfgs, 'vae_optimize', None) is not None:
        if self.cfgs.vae_optimize.tiling:
            self.pipe.vae.enable_tiling()
        else:
            self.pipe.vae.disable_tiling()

        if self.cfgs.vae_optimize.slicing:
            self.pipe.vae.enable_slicing()
        else:
            self.pipe.vae.disable_slicing()

    del self.cfgs_raw_old
    del self.cfgs_old
|
237
|
-
|
hcpdiff/workflow/base.py
DELETED
@@ -1,59 +0,0 @@
|
|
1
|
-
from typing import List, Dict
|
2
|
-
from tqdm.auto import tqdm
|
3
|
-
|
4
|
-
class from_memory:
    """Deferred reference to a value stored on a memory object.

    The reference is resolved each time the instance is called, so it always
    yields the value the memory object holds at call time.
    """
    # TODO: add memory for all from_memory in cfg

    def __init__(self, memory, mem_name):
        # ``mem_name`` is the expression text appended after ``memory.``.
        self.memory = memory
        self.mem_name = mem_name

    def __call__(self):
        """Evaluate ``memory.<mem_name>`` and return the result."""
        # NOTE(security): ``mem_name`` is evaluated as Python source; it must
        # never come from an untrusted configuration.
        memory = self.memory  # the evaluated expression refers to this local by name
        return eval(f'memory.{self.mem_name}')
|
13
|
-
|
14
|
-
def from_memory_context(fun):
    """Decorator that resolves ``from_memory`` keyword arguments before the call.

    Every keyword argument that is a ``from_memory`` instance is replaced by
    the value it returns when called; all other keyword arguments and all
    positional arguments are passed through untouched.

    :param fun: the callable to wrap.
    :return: a wrapper that keeps ``fun``'s name, docstring and signature metadata.
    """
    from functools import wraps

    @wraps(fun)  # keep __name__/__doc__ of the wrapped callable for debugging and introspection
    def f(*args, **kwargs):
        filter_kwargs = {k: (v() if isinstance(v, from_memory) else v) for k, v in kwargs.items()}
        return fun(*args, **filter_kwargs)

    return f
|
19
|
-
|
20
|
-
class BasicAction:
    """Base class of workflow actions.

    Calling an action forwards every argument to :meth:`forward`, which
    subclasses are expected to override.
    """

    def __init__(self):
        """Nothing to set up in the base class."""

    def __call__(self, *args, **kwargs):
        """Run the action; identical to calling :meth:`forward` directly."""
        result = self.forward(*args, **kwargs)
        return result

    def forward(self, *args, **kwargs):
        """Perform the action. The base implementation always raises ``NotImplementedError``."""
        raise NotImplementedError()
|
29
|
-
|
30
|
-
class MemoryMixin:
    """Marker mixin without behaviour of its own; it only tags a class for ``isinstance`` checks."""
|
32
|
-
|
33
|
-
class ExecAction(BasicAction, MemoryMixin):
    """Action that runs a snippet of Python source against the workflow state."""

    def __init__(self, prog:str):
        # Source text that ``forward`` passes to ``exec``.
        self.prog = prog

    def forward(self, memory, **states):
        """Execute ``self.prog`` and return ``states``.

        The snippet runs in this method's scope, so it can read ``self``,
        ``memory`` and ``states``. Only in-place changes (e.g.
        ``states['k'] = v``) are visible in the returned dict; rebinding
        ``states`` inside the snippet does not change what is returned.
        """
        # NOTE(security): this executes arbitrary code taken from the
        # configuration; never load workflow configs from untrusted sources.
        exec(self.prog)
        return states
|
40
|
-
|
41
|
-
class LoopAction(BasicAction, MemoryMixin):
    """Action that runs a list of actions once per element of one or more state sequences."""

    def __init__(self, loop_value:Dict[str, str], actions:List[BasicAction]):
        # Maps the state key holding a sequence -> the state key that
        # receives the current element during each iteration.
        self.loop_value = loop_value
        # Actions executed, in order, on every iteration.
        self.actions = actions

    def forward(self, memory, **states):
        """Iterate the looped sequences in lockstep and run every action per step.

        The sequences named by the keys of ``loop_value`` are removed from
        ``states``; on each iteration their current elements are written back
        under the corresponding ``loop_value`` values. Actions that inherit
        from ``MemoryMixin`` additionally receive ``memory``. Returns the
        states produced by the last executed action.
        """
        loop_data = []
        for source_key in self.loop_value.keys():
            loop_data.append(states.pop(source_key))

        # The progress total is taken from the first sequence only.
        pbar = tqdm(zip(*loop_data), total=len(loop_data[0]))
        N_steps = len(self.actions)
        for data in pbar:
            states.update(dict(zip(self.loop_value.values(), data)))
            for step, act in enumerate(self.actions):
                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
                if not isinstance(act, MemoryMixin):
                    states = act(**states)
                    continue
                states = act(memory=memory, **states)
        return states
|
@@ -1,21 +0,0 @@
|
|
1
|
-
_base_: [cfgs/infer/text2img.yaml]
|
2
|
-
|
3
|
-
pretrained_model: 'deepghs/animefull-latest' # animefull-latest model
|
4
|
-
prompt: 'masterpiece, best quality, 1girl, solo, tohsaka rin' # image of 远坂凛(tohsaka rin)
|
5
|
-
neg_prompt: 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'
|
6
|
-
|
7
|
-
clip_skip: 1 # anime models usually skip one CLIP layer
|
8
|
-
|
9
|
-
infer_args:
|
10
|
-
width: 512
|
11
|
-
height: 768 # image size
|
12
|
-
guidance_scale: 7.5 # guidance scale; higher values make the images follow the prompt more closely
|
13
|
-
num_inference_steps: 30 # how many steps
|
14
|
-
|
15
|
-
new_components:
|
16
|
-
scheduler:
|
17
|
-
_target_: diffusers.EulerAncestralDiscreteScheduler # change Sampler
|
18
|
-
beta_start: 0.00085
|
19
|
-
beta_end: 0.012
|
20
|
-
beta_schedule: 'scaled_linear'
|
21
|
-
|
@@ -1,58 +0,0 @@
|
|
1
|
-
_base_:
|
2
|
-
- cfgs/infer/anime/text2img_anime.yaml
|
3
|
-
|
4
|
-
pretrained_model: 'stablediffusionapi/anything-v5' # better generic anime model
|
5
|
-
|
6
|
-
# safe prompt
|
7
|
-
prompt: 'masterpiece, best quality, highres, game cg, 1girl, solo, {night}, {starry sky}, beach, beautiful detailed sky, {extremely detailed background:1.2}, mature, {surtr_arknights-${model_steps}:1.2}, red_hair, horns, long_hair, purple_eyes, bangs, looking_at_viewer, bare_shoulders, hair_between_eyes, cleavage, {standing}, looking at viewer, {bikini:1.3}, light smile'
|
8
|
-
|
9
|
-
# r18 prompt
|
10
|
-
# prompt: 'nsfw, masterpiece, best quality, highres, 1girl, solo, {lyging on bed}, {extremely detailed background:1.2}, {nude:1.4}, {spread legs}, {arms up}, mature, {surtr_arknights-1000:1.2}, red_hair, horns, long_hair, purple_eyes, bangs, looking_at_viewer, bare_shoulders, hair_between_eyes, cleavage, nipples, {pussy:1.15}, {pussy juice:1.3}, looking at viewer, {embarrassed}, endured face, feet out of frame'
|
11
|
-
|
12
|
-
# negative prompt
|
13
|
-
neg_prompt: 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, white border'
|
14
|
-
N_repeats: 2 # if prompt or neg_prompt is too long, increase this number
|
15
|
-
|
16
|
-
bs: 1
|
17
|
-
num: 1
|
18
|
-
|
19
|
-
# when seed is not set, random seed will be used
|
20
|
-
# seed: 758691538 # seed for safe
|
21
|
-
# seed: 465191133 # seed for r18
|
22
|
-
|
23
|
-
infer_args:
|
24
|
-
width: 512
|
25
|
-
height: 768 # image size
|
26
|
-
guidance_scale: 7.5 # guidance scale; higher values make the images follow the prompt more closely
|
27
|
-
num_inference_steps: 30 # how many steps
|
28
|
-
|
29
|
-
exp_dir: 'exps/2023-07-26-01-05-35' # experiment directory
|
30
|
-
model_steps: 1000 # steps of selected model
|
31
|
-
emb_dir: '${exp_dir}/ckpts/'
|
32
|
-
output_dir: 'output/'
|
33
|
-
|
34
|
-
merge:
|
35
|
-
alpha: 0.85 # LoRA weight, default: 0.85
|
36
|
-
|
37
|
-
group1:
|
38
|
-
type: 'unet'
|
39
|
-
base_model_alpha: 1.0 # base model weight to merge with lora or part
|
40
|
-
lora:
|
41
|
-
- path: '${.....exp_dir}/ckpts/unet-${.....model_steps}.safetensors'
|
42
|
-
alpha: ${....alpha}
|
43
|
-
layers: 'all'
|
44
|
-
part: null
|
45
|
-
|
46
|
-
group2:
|
47
|
-
type: 'TE'
|
48
|
-
base_model_alpha: 1.0 # base model weight to merge with lora or part
|
49
|
-
lora:
|
50
|
-
- path: '${.....exp_dir}/ckpts/text_encoder-${.....model_steps}.safetensors'
|
51
|
-
alpha: ${....alpha}
|
52
|
-
layers: 'all'
|
53
|
-
part: null
|
54
|
-
|
55
|
-
interface:
|
56
|
-
- _target_: hcpdiff.vis.DiskInterface
|
57
|
-
show_steps: 0
|
58
|
-
save_root: '${output_dir}'
|