hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +90 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +3 -3
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text.py +40 -0
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +207 -0
- hcpdiff/easy/cfg/sdxl_train.py +147 -0
- hcpdiff/easy/cfg/t2i.py +228 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +118 -128
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +60 -47
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.2.dist-info/METADATA +299 -0
- hcpdiff-2.2.dist-info/RECORD +115 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
- hcpdiff-2.2.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
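The per-file diffs below follow one recurring pattern: workflow actions drop the old `MemoryMixin`/`from_memory_context` machinery that mutated a shared `memory` object, and instead subclass `rainbowneko.infer.BasicAction`, accept `key_map_in`/`key_map_out` remapping arguments, receive their inputs as named `forward` parameters, and return a plain dict that is merged back into the workflow state. A minimal sketch of that convention (the action name and the `latents` key are illustrative, not part of the package):

```python
# Sketch of the 2.2 action convention seen throughout the diffs below.
# ScaleLatentsAction and the 'latents' state key are hypothetical examples.
from rainbowneko.infer import BasicAction

class ScaleLatentsAction(BasicAction):
    def __init__(self, scale: float = 1.0, key_map_in=None, key_map_out=None):
        # key_map_in/key_map_out rename states on the way in and out
        super().__init__(key_map_in, key_map_out)
        self.scale = scale

    def forward(self, latents, **states):
        # inputs arrive as named parameters pulled from the state dict;
        # the returned dict is merged back into the workflow states
        return {'latents': latents * self.scale}
```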
hcpdiff/workflow/io.py
CHANGED
```diff
@@ -1,150 +1,56 @@
 import os
-from …
-import …
+from functools import partial
+from typing import List, Union
 
-
-
-from hcpdiff.utils import auto_text_encoder, auto_tokenizer, to_validate_file
-from hcpdiff.utils.cfg_net_tools import HCPModelLoader, make_plugin
-from hcpdiff.utils.img_size_tool import types_support
+import torch
+from hcpdiff.utils import to_validate_file
 from hcpdiff.utils.net_utils import get_dtype
-from .…
-
-
-
-
-
+from rainbowneko.ckpt_manager import NekoLoader
+from rainbowneko.infer import BasicAction
+from rainbowneko.infer import LoadImageAction as Neko_LoadImageAction
+from rainbowneko.utils.img_size_tool import types_support
+
+class BuildModelsAction(BasicAction):
+    def __init__(self, model_loader: partial[NekoLoader.load], dtype: str=torch.float32, device='cuda', key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.model_loader = model_loader
         self.dtype = get_dtype(dtype)
+        self.device = device
 
-
-
-
-
-
+    def forward(self, in_preview=False, model=None, **states):
+        if in_preview:
+            model = self.model_loader(dtype=self.dtype, device=self.device, denoiser=model.denoiser, TE=model.TE, vae=model.vae)
+        else:
+            model = self.model_loader(dtype=self.dtype, device=self.device)
 
-
-
-
-
-        memory.vae = self.vae or AutoencoderKL.from_pretrained(self.pretrained_model, subfolder="vae", torch_dtype=self.dtype)
-        memory.scheduler = self.scheduler or PNDMScheduler.from_pretrained(self.pretrained_model, subfolder="scheduler", torch_dtype=self.dtype)
+        if isinstance(model, dict):
+            return model
+        else:
+            return {'model':model}
 
-
+class LoadImageAction(Neko_LoadImageAction):
+    def __init__(self, image_paths: Union[str, List[str]], image_transforms=None, key_map_in=None, key_map_out=('input.x -> images',)):
+        super().__init__(image_paths, image_transforms, key_map_in, key_map_out)
 
 class SaveImageAction(BasicAction):
-
-
+    def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.save_root = save_root
         self.image_type = image_type
         self.quality = quality
+        self.save_cfg = save_cfg
 
         os.makedirs(save_root, exist_ok=True)
 
-    def forward(self, images, prompt, negative_prompt, seeds=None, **states):
-
+    def forward(self, images, prompt, negative_prompt, seeds, cfgs=None, parser=None, preview_root=None, preview_step=None, **states):
+        save_root = preview_root or self.save_root
+        num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(save_root) if x.rsplit('.', 1)[-1] in types_support])+1
 
         for bid, (p, pn, img) in enumerate(zip(prompt, negative_prompt, images)):
-            img_path = os.path.join(…
+            img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.{self.image_type}")
             img.save(img_path, quality=self.quality)
             num_img_exist += 1
 
-
-
-
-    def forward(self, memory, **states):
-        memory.model_loader_unet = HCPModelLoader(memory.unet)
-        memory.model_loader_TE = HCPModelLoader(memory.text_encoder)
-        return states
-
-class LoadPartAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, model: str, cfg):
-        self.model = model
-        self.cfg = cfg
-
-    def forward(self, memory, **states):
-        model_loader = memory[f"model_loader_{self.model}"]
-        model_loader.load_part(self.cfg)
-        return states
-
-class LoadLoraAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, model: str, cfg):
-        self.model = model
-        self.cfg = cfg
-
-    def forward(self, memory, **states):
-        model_loader = memory[f"model_loader_{self.model}"]
-        lora_group = model_loader.load_lora(self.cfg)
-        if 'lora_dict' not in memory:
-            memory.lora_dict = {}
-        if path in memory.lora_dict:
-            warnings.warn(f"Lora {path} already loaded, and will be replaced!")
-            memory.lora_dict[path].remove()
-        memory.lora_dict[path] = lora_group
-        return states
-
-class BuildPluginAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, model: str, cfg):
-        self.model = model
-        self.cfg = cfg
-
-    def forward(self, memory, **states):
-        if isinstance(self.cfg_merge.plugin_cfg, str):
-            plugin_cfg = load_config(self.cfg_merge.plugin_cfg)
-            plugin_cfg = {'plugin_unet':hydra.utils.instantiate(plugin_cfg['plugin_unet']),
-                          'plugin_TE':hydra.utils.instantiate(plugin_cfg['plugin_TE'])}
-        else:
-            plugin_cfg = self.cfg_merge.plugin_cfg
-        all_plugin_group_unet = make_plugin(memory.unet, plugin_cfg['plugin_unet'])
-        all_plugin_group_TE = make_plugin(memory.text_encoder, plugin_cfg['plugin_TE'])
-
-        if 'plugin_dict' not in memory:
-            memory.plugin_dict = {}
-
-        for name, plugin_group in all_plugin_group_unet.items():
-            memory.plugin_dict[name] = plugin_group
-        for name, plugin_group in all_plugin_group_TE.items():
-            memory.plugin_dict[name] = plugin_group
-
-        return states
-
-class LoadPluginAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, model: str, cfg):
-        self.model = model
-        self.cfg = cfg
-
-    def forward(self, memory, **states):
-        model_loader = memory[f"model_loader_{self.model}"]
-        model_loader.load_plugin(self.cfg)
-        return states
-
-class RemoveLoraAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, path_list: List[str]):
-        self.path_list = path_list
-
-    def forward(self, memory, **states):
-        for path in self.path_list:
-            if path in memory.lora_dict:
-                memory.lora_dict[path].remove()
-                del memory.lora_dict[path]
-            else:
-                warnings.warn(f"Lora {path} not loaded!")
-        return states
-
-class RemovePluginAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, name_list: List[str]):
-        self.name_list = name_list
-
-    def forward(self, memory, **states):
-        for name in self.name_list:
-            if name in memory.plugin_dict:
-                memory.plugin_dict[name].remove()
-                del memory.plugin_dict[name]
-            else:
-                warnings.warn(f"Plugin {name} not loaded!")
-        return states
+            if self.save_cfg:
+                cfgs.seed = seeds[bid]
+                parser.save_configs(cfgs, os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-info"))
```
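Note on the io.py change: `SaveImageAction` now names files `{index}-{seed}-{sanitized prompt}.{image_type}` (the index resumes from whatever is already in `save_root`), and with `save_cfg=True` it writes a `...-info` config snapshot beside each image. A toy illustration of the naming scheme (all values are examples):

```python
# Toy illustration of the SaveImageAction filename scheme; values are examples.
num_img_exist, seed, prompt, image_type = 3, 42, 'a cat', 'png'
img_name = f"{num_img_exist}-{seed}-{prompt}.{image_type}"  # -> '3-42-a cat.png'
# the real code first runs the prompt through to_validate_file() to make it filename-safe
```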
hcpdiff/workflow/model.py
CHANGED
```diff
@@ -1,67 +1,70 @@
+import torch
 from accelerate import infer_auto_device_map, dispatch_model
 from diffusers.utils.import_utils import is_xformers_available
+from rainbowneko.infer import BasicAction
 
-from hcpdiff.utils.net_utils import get_dtype
+from hcpdiff.utils.net_utils import get_dtype
+from hcpdiff.utils.net_utils import to_cpu
 from hcpdiff.utils.utils import size_to_int, int_to_size
-from .base import BasicAction, from_memory_context, MemoryMixin
 
-class VaeOptimizeAction(BasicAction…
-
-
-        super().__init__()
+class VaeOptimizeAction(BasicAction):
+    def __init__(self, slicing=True, tiling=False, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.slicing = slicing
         self.tiling = tiling
-        self.vae = vae
-
-    def forward(self, memory, **states):
-        vae = self.vae or memory.vae
 
+    def forward(self, vae, **states):
         if self.tiling:
             vae.enable_tiling()
         if self.slicing:
             vae.enable_slicing()
-        return states
 
-class BuildOffloadAction(BasicAction…
-
-
-        super().__init__()
+class BuildOffloadAction(BasicAction):
+    def __init__(self, max_VRAM: str, max_RAM: str, vae_cpu=False, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.max_VRAM = max_VRAM
         self.max_RAM = max_RAM
+        self.vae_cpu = vae_cpu
 
-    def forward(self, …
+    def forward(self, vae, denoiser, dtype: str, **states):
+        # denoiser offload
         torch_dtype = get_dtype(dtype)
         vram = size_to_int(self.max_VRAM)
-        device_map = infer_auto_device_map(…
-
+        device_map = infer_auto_device_map(denoiser, max_memory={0:int_to_size(vram >> 1), "cpu":self.max_RAM}, dtype=torch_dtype)
+        denoiser = dispatch_model(denoiser, device_map)
 
-        device_map = infer_auto_device_map(…
-
-
+        device_map = infer_auto_device_map(vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch_dtype)
+        vae = dispatch_model(vae, device_map)
+        # VAE offload
+        vram = size_to_int(self.max_VRAM)
+        if not self.vae_cpu:
+            device_map = infer_auto_device_map(vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch.float32)
+            vae = dispatch_model(vae, device_map)
+        else:
+            to_cpu(vae)
+            vae_decode_raw = vae.decode
 
-
-
-
-
-        # self.te_hook.enable_xformers()
-        return states
+            def vae_decode_offload(latents, return_dict=True, decode_raw=vae.decode):
+                vae.to(dtype=torch.float32)
+                res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
+                return res
 
-
-    def forward(self, memory, **states):
-        to_cuda(memory.text_encoder)
-        return states
+            vae.decode = vae_decode_offload
 
-
-    def forward(self, memory, **states):
-        to_cpu(memory.text_encoder)
-        return states
+            vae_encode_raw = vae.encode
 
-
-
-
-
+            def vae_encode_offload(x, return_dict=True, encode_raw=vae.encode):
+                vae.to(dtype=torch.float32)
+                res = encode_raw(x.cpu().to(dtype=torch.float32), return_dict=return_dict)
+                return res
 
-
-
-
-        return …
+            vae.encode = vae_encode_offload
+            return {'denoiser':denoiser, 'vae':vae, 'vae_decode_raw':vae_decode_raw, 'vae_encode_raw':vae_encode_raw}
+
+        return {'denoiser':denoiser, 'vae':vae}
+
+class XformersEnableAction(BasicAction):
+    def forward(self, denoiser, **states):
+        if is_xformers_available():
+            denoiser.enable_xformers_memory_efficient_attention()
+        # self.te_hook.enable_xformers()
```
hcpdiff/workflow/text.py
CHANGED
```diff
@@ -3,78 +3,91 @@ from typing import List, Union
 import torch
 from hcpdiff.models import TokenizerHook
 from hcpdiff.models.compose import ComposeTEEXHook, ComposeEmbPTHook
+from hcpdiff.utils import pad_attn_bias
 from hcpdiff.utils.net_utils import get_dtype, to_cpu, to_cuda
+from rainbowneko.infer import BasicAction
 from torch.cuda.amp import autocast
 
-
-
-
-
-    def __init__(self, TE=None, tokenizer=None, emb_dir: str = 'embs/', N_repeats: int = 1, layer_skip: int = 0, TE_final_norm: bool = True):
-        super().__init__()
-        self.TE = TE
-        self.tokenizer = tokenizer
+class TextHookAction(BasicAction):
+    def __init__(self, emb_dir: str = None, N_repeats: int = 1, layer_skip: int = 0, TE_final_norm: bool = True,
+                 use_attention_mask=False, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
 
         self.emb_dir = emb_dir
         self.N_repeats = N_repeats
         self.layer_skip = layer_skip
         self.TE_final_norm = TE_final_norm
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self.use_attention_mask = use_attention_mask
+
+    def forward(self, TE, tokenizer, in_preview=False, te_hook:ComposeTEEXHook=None, emb_hook=None, **states):
+        if in_preview and emb_hook is not None:
+            emb_hook.N_repeats = self.N_repeats
+        else:
+            emb_hook, _ = ComposeEmbPTHook.hook_from_dir(self.emb_dir, tokenizer, TE, N_repeats=self.N_repeats)
+        tokenizer.N_repeats = self.N_repeats
+
+        if in_preview:
+            te_hook.N_repeats = self.N_repeats
+            te_hook.clip_skip = self.layer_skip
+            te_hook.clip_final_norm = self.TE_final_norm
+            te_hook.use_attention_mask = self.use_attention_mask
+        else:
+            te_hook = ComposeTEEXHook.hook(TE, tokenizer, N_repeats=self.N_repeats,
+                                           clip_skip=self.layer_skip, clip_final_norm=self.TE_final_norm, use_attention_mask=self.use_attention_mask)
+        token_ex = TokenizerHook(tokenizer)
+        return {'te_hook':te_hook, 'emb_hook':emb_hook, 'token_ex':token_ex}
+
+class TextEncodeAction(BasicAction):
+    def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         if isinstance(prompt, str) and bs is not None:
             prompt = [prompt]*bs
             negative_prompt = [negative_prompt]*bs
 
         self.prompt = prompt
         self.negative_prompt = negative_prompt
+        self.bs = bs
 
-
+    def forward(self, te_hook, TE, dtype: str, device, amp=None, prompt=None, negative_prompt=None, model_offload=False, **states):
+        prompt = prompt or self.prompt
+        negative_prompt = negative_prompt or self.negative_prompt
+
+        if model_offload:
+            to_cuda(TE)
 
-    def forward(self, memory, dtype: str, device, amp=None, **states):
-        te_hook = self.te_hook or memory.te_hook
         with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
-            emb, pooled_output = te_hook.encode_prompt_to_emb(…
-
-
-            'device':device, 'dtype':dtype}
+            emb, pooled_output, attention_mask = te_hook.encode_prompt_to_emb(negative_prompt+prompt)
+            if attention_mask is not None:
+                emb, attention_mask = pad_attn_bias(emb, attention_mask)
 
-
-
-    def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, te_hook=None, token_ex=None):
-        super().__init__(prompt, negative_prompt, bs, te_hook)
-        self.token_ex = token_ex
+        if model_offload:
+            to_cpu(TE)
 
-
-
-
+        if not isinstance(te_hook, ComposeTEEXHook):
+            pooled_output = None
+        return {'prompt':prompt, 'negative_prompt':negative_prompt, 'prompt_embeds':emb, 'encoder_attention_mask':attention_mask,
+                'pooled_output':pooled_output}
+
+class AttnMultTextEncodeAction(TextEncodeAction):
+    def forward(self, te_hook, token_ex, TE, dtype: str, device, amp=None, prompt=None, negative_prompt=None, model_offload=False, **states):
+        prompt = prompt or self.prompt
+        negative_prompt = negative_prompt or self.negative_prompt
 
-
-
-        to_cuda(memory.text_encoder)
+        if model_offload:
+            to_cuda(TE)
 
-        mult_p, clean_text_p = token_ex.parse_attn_mult(…
-        mult_n, clean_text_n = token_ex.parse_attn_mult(…
+        mult_p, clean_text_p = token_ex.parse_attn_mult(prompt)
+        mult_n, clean_text_n = token_ex.parse_attn_mult(negative_prompt)
         with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
             emb, pooled_output, attention_mask = te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
+            if attention_mask is not None:
+                emb, attention_mask = pad_attn_bias(emb, attention_mask)
         emb_n, emb_p = emb.chunk(2)
         emb_p = te_hook.mult_attn(emb_p, mult_p)
         emb_n = te_hook.mult_attn(emb_n, mult_n)
 
-        if …
-            to_cpu(…
+        if model_offload:
+            to_cpu(TE)
 
-        return {…
-            '…
+        return {'prompt':list(clean_text_p), 'negative_prompt':list(clean_text_n), 'prompt_embeds':torch.cat([emb_n, emb_p], dim=0),
+                'encoder_attention_mask':attention_mask, 'pooled_output':pooled_output}
```
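Both encode actions use the classifier-free-guidance batching convention visible above: negative and positive prompts are concatenated (`negative_prompt+prompt`, or the cleaned `clean_text_n+clean_text_p`), encoded in one text-encoder pass, then split back with `chunk(2)` so `emb_n` and `emb_p` stay index-aligned. A tensor-level sketch of that split (shapes are illustrative):

```python
# Sketch of the CFG batching/split used by TextEncodeAction; shapes are examples.
import torch

bs, seq, dim = 2, 77, 768
emb = torch.randn(2 * bs, seq, dim)  # encoder output for negative_prompt + prompt
emb_n, emb_p = emb.chunk(2)          # first half negative, second half positive
assert emb_n.shape == emb_p.shape == (bs, seq, dim)
```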
hcpdiff/workflow/utils.py
CHANGED
```diff
@@ -1,13 +1,14 @@
-import …
+from typing import List, Union
 
-
-from torch import nn
+import torch
 from PIL import Image
-from …
+from hcpdiff.data.handler import ControlNetHandler
+from rainbowneko.infer import BasicAction
+from torch import nn
 
 class LatentResizeAction(BasicAction):
-
-
+    def __init__(self, width=1024, height=1024, mode='bicubic', antialias=True, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.size = (height//8, width//8)
         self.mode = mode
         self.antialias = antialias
@@ -16,18 +17,37 @@ class LatentResizeAction(BasicAction):
         latents_dtype = latents.dtype
         latents = nn.functional.interpolate(latents.to(dtype=torch.float32), size=self.size, mode=self.mode)
         latents = latents.to(dtype=latents_dtype)
-        return {…
+        return {'latents':latents}
 
 class ImageResizeAction(BasicAction):
     # resample name to Image.xxx
     mode_map = {'nearest':Image.NEAREST, 'bilinear':Image.BILINEAR, 'bicubic':Image.BICUBIC, 'lanczos':Image.LANCZOS, 'box':Image.BOX,
-                'hamming':Image.HAMMING, 'antialias':Image.…
+                'hamming':Image.HAMMING, 'antialias':Image.LANCZOS}
 
-
-
+    def __init__(self, width=1024, height=1024, mode='bicubic', key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.size = (width, height)
         self.mode = self.mode_map[mode]
 
-    def forward(self, images:List[Image.Image], **states):
+    def forward(self, images: List[Image.Image], **states):
         images = [image.resize(self.size, resample=self.mode) for image in images]
-        return {…
+        return {'images':images}
+
+class FeedtoCNetAction(BasicAction):
+    def __init__(self, width=None, height=None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.size = (width, height)
+        self.cnet_handler = ControlNetHandler()
+
+    def forward(self, images: Union[List[Image.Image], Image.Image], device='cuda', dtype=None, bs=None, latents=None, **states):
+        if bs is None:
+            if 'prompt' in states:
+                bs = len(states['prompt'])
+
+        if latents is not None:
+            width, height = latents.shape[3]*8, latents.shape[2]*8
+        else:
+            width, height = self.size
+
+        images = self.cnet_handler.handle(images).to(device, dtype=dtype).expand(bs*2, 3, width, height)
+        return {'ex_inputs':{'cond':images}}
```
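`FeedtoCNetAction` recovers the pixel resolution from the latent batch via the SD VAE's fixed 8x spatial downscale, then expands the ControlNet condition to `bs*2` images so it covers both the negative and positive halves of the CFG batch. The size recovery in isolation (the latent shape is an example):

```python
# Sketch: pixel size from an SD latent batch (8x VAE downscale); shape is an example.
import torch

latents = torch.randn(2, 4, 64, 96)  # (bs, channels, height/8, width/8)
width, height = latents.shape[3] * 8, latents.shape[2] * 8
assert (width, height) == (768, 512)
```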
hcpdiff/workflow/vae.py
CHANGED
```diff
@@ -1,33 +1,32 @@
-from .base import BasicAction, from_memory_context
-from diffusers import AutoencoderKL
-from diffusers.image_processor import VaeImageProcessor
-from typing import Dict, Any
 import torch
+from diffusers.image_processor import VaeImageProcessor
 from hcpdiff.utils import to_cuda, to_cpu
 from hcpdiff.utils.net_utils import get_dtype
+from rainbowneko.infer import BasicAction
 
 class EncodeAction(BasicAction):
-
-
-
-        self.vae = vae
-        self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels)-1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if image_processor is None else image_processor
-        self.offload = offload
+    def __init__(self, image_processor=None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.image_processor = image_processor
 
-    def forward(self, images, dtype:str, device, generator, bs=None, **states):
+    def forward(self, vae, images, dtype: str, device, generator, bs=None, model_offload=False, **states):
         if bs is None:
             if 'prompt' in states:
                 bs = len(states['prompt'])
+        vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+        if self.image_processor is None:
+            self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
 
         image = self.image_processor.preprocess(images)
-
+        if bs is not None and image.shape[0] != bs:
+            image = image.repeat(bs//image.shape[0], 1, 1, 1)
+        image = image.to(device=device, dtype=vae.dtype)
 
         if image.shape[1] == 4:
             init_latents = image
         else:
-            if …
-                to_cuda(…
+            if model_offload:
+                to_cuda(vae)
             if isinstance(generator, list) and len(generator) != bs:
                 raise ValueError(
                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -36,38 +35,38 @@ class EncodeAction(BasicAction):
 
             elif isinstance(generator, list):
                 init_latents = [
-                    …
+                    vae.encode(image[i: i+1]).latent_dist.sample(generator[i]) for i in range(bs)
                 ]
                 init_latents = torch.cat(init_latents, dim=0)
             else:
-                init_latents = …
+                init_latents = vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = …
-        if …
-            to_cpu(…
-        return {…
+        init_latents = vae.config.scaling_factor*init_latents.to(dtype=get_dtype(dtype))
+        if model_offload:
+            to_cpu(vae)
+        return {'latents':init_latents}
 
 class DecodeAction(BasicAction):
-
-
-        super().__init__()
-        self.vae = vae
-        self.offload = offload
+    def __init__(self, image_processor=None, output_type='pil', key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
 
-        self.…
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if image_processor is None else image_processor
+        self.image_processor = image_processor
         self.output_type = output_type
-        self.decode_key = decode_key
 
-    def forward(self, **states):
-
-        if self.…
-
-
-
-
-
+    def forward(self, vae, denoiser, latents, model_offload=False, **states):
+        vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+        if self.image_processor is None:
+            self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+        if model_offload:
+            to_cpu(denoiser)
+            torch.cuda.synchronize()
+            to_cuda(vae)
+        latents = latents.to(dtype=vae.dtype)
+        image = vae.decode(latents/vae.config.scaling_factor, return_dict=False)[0]
+        if model_offload:
+            to_cpu(vae)
 
         do_denormalize = [True]*image.shape[0]
         image = self.image_processor.postprocess(image, output_type=self.output_type, do_denormalize=do_denormalize)
-        return {…
+        return {'images':image}
```
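`EncodeAction` multiplies sampled latents by `vae.config.scaling_factor` and `DecodeAction` divides it back out, the usual diffusers convention for keeping latents near unit variance. A hedged round-trip sketch with a diffusers `AutoencoderKL` (the checkpoint id is an example; any SD1.x-style VAE behaves the same):

```python
# Sketch of the scaling_factor round trip used by EncodeAction/DecodeAction.
# The model id is an example checkpoint, not something hcpdiff pins.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image in [-1, 1]
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    decoded = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
print(latents.shape, decoded.shape)  # torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
```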