hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/workflow/model.py
CHANGED
```diff
@@ -1,67 +1,70 @@
+import torch
 from accelerate import infer_auto_device_map, dispatch_model
 from diffusers.utils.import_utils import is_xformers_available
+from rainbowneko.infer import BasicAction

-from hcpdiff.utils.net_utils import get_dtype
+from hcpdiff.utils.net_utils import get_dtype
+from hcpdiff.utils.net_utils import to_cpu
 from hcpdiff.utils.utils import size_to_int, int_to_size
-from .base import BasicAction, from_memory_context, MemoryMixin

-class VaeOptimizeAction(BasicAction…
-…
-…
-        super().__init__()
+class VaeOptimizeAction(BasicAction):
+    def __init__(self, slicing=True, tiling=False, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.slicing = slicing
         self.tiling = tiling
-        self.vae = vae
-
-    def forward(self, memory, **states):
-        vae = self.vae or memory.vae

+    def forward(self, vae, **states):
         if self.tiling:
             vae.enable_tiling()
         if self.slicing:
             vae.enable_slicing()
-        return states

-class BuildOffloadAction(BasicAction…
-…
-…
-        super().__init__()
+class BuildOffloadAction(BasicAction):
+    def __init__(self, max_VRAM: str, max_RAM: str, vae_cpu=False, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.max_VRAM = max_VRAM
         self.max_RAM = max_RAM
+        self.vae_cpu = vae_cpu

-    def forward(self,…
+    def forward(self, vae, denoiser, dtype: str, **states):
+        # denoiser offload
         torch_dtype = get_dtype(dtype)
         vram = size_to_int(self.max_VRAM)
-        device_map = infer_auto_device_map(…
-…
+        device_map = infer_auto_device_map(denoiser, max_memory={0:int_to_size(vram >> 1), "cpu":self.max_RAM}, dtype=torch_dtype)
+        denoiser = dispatch_model(denoiser, device_map)

-        device_map = infer_auto_device_map(…
-…
-…
+        device_map = infer_auto_device_map(vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch_dtype)
+        vae = dispatch_model(vae, device_map)
+        # VAE offload
+        vram = size_to_int(self.max_VRAM)
+        if not self.vae_cpu:
+            device_map = infer_auto_device_map(vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch.float32)
+            vae = dispatch_model(vae, device_map)
+        else:
+            to_cpu(vae)
+            vae_decode_raw = vae.decode

-…
-…
-…
-…
-        # self.te_hook.enable_xformers()
-        return states
+            def vae_decode_offload(latents, return_dict=True, decode_raw=vae.decode):
+                vae.to(dtype=torch.float32)
+                res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
+                return res

-…
-    def forward(self, memory, **states):
-        to_cuda(memory.text_encoder)
-        return states
+            vae.decode = vae_decode_offload

-…
-    def forward(self, memory, **states):
-        to_cpu(memory.text_encoder)
-        return states
+            vae_encode_raw = vae.encode

-…
-…
-…
-…
+            def vae_encode_offload(x, return_dict=True, encode_raw=vae.encode):
+                vae.to(dtype=torch.float32)
+                res = encode_raw(x.cpu().to(dtype=torch.float32), return_dict=return_dict)
+                return res

-…
-…
-…
-        return…
+            vae.encode = vae_encode_offload
+            return {'denoiser':denoiser, 'vae':vae, 'vae_decode_raw':vae_decode_raw, 'vae_encode_raw':vae_encode_raw}
+
+        return {'denoiser':denoiser, 'vae':vae}
+
+class XformersEnableAction(BasicAction):
+    def forward(self, denoiser, **states):
+        if is_xformers_available():
+            denoiser.enable_xformers_memory_efficient_attention()
+        # self.te_hook.enable_xformers()
```
(Removed lines shown as `…` were truncated by the diff viewer and are not recoverable.)
hcpdiff/workflow/text.py
CHANGED
```diff
@@ -1,80 +1,112 @@
 from typing import List, Union

 import torch
-from torch.cuda.amp import autocast
-
 from hcpdiff.models import TokenizerHook
 from hcpdiff.models.compose import ComposeTEEXHook, ComposeEmbPTHook
-from .…
+from hcpdiff.utils import pad_attn_bias
 from hcpdiff.utils.net_utils import get_dtype, to_cpu, to_cuda
+from rainbowneko.infer import BasicAction
+from torch.cuda.amp import autocast

-class TextHookAction(BasicAction…
-…
-…
-        super().__init__()
-        self.TE = TE
-        self.tokenizer = tokenizer
+class TextHookAction(BasicAction):
+    def __init__(self, emb_dir: str = None, N_repeats: int = 1, layer_skip: int = 0, TE_final_norm: bool = True,
+                 use_attention_mask=False, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)

         self.emb_dir = emb_dir
         self.N_repeats = N_repeats
         self.layer_skip = layer_skip
         self.TE_final_norm = TE_final_norm
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+        self.use_attention_mask = use_attention_mask
+
+    def forward(self, TE, tokenizer, in_preview=False, te_hook:ComposeTEEXHook=None, emb_hook=None, **states):
+        if in_preview and emb_hook is not None:
+            emb_hook.N_repeats = self.N_repeats
+        else:
+            emb_hook, _ = ComposeEmbPTHook.hook_from_dir(self.emb_dir, tokenizer, TE, N_repeats=self.N_repeats)
+        tokenizer.N_repeats = self.N_repeats
+
+        if in_preview:
+            te_hook.N_repeats = self.N_repeats
+            te_hook.clip_skip = self.layer_skip
+            te_hook.clip_final_norm = self.TE_final_norm
+            te_hook.use_attention_mask = self.use_attention_mask
+        else:
+            te_hook = ComposeTEEXHook.hook(TE, tokenizer, N_repeats=self.N_repeats,
+                clip_skip=self.layer_skip, clip_final_norm=self.TE_final_norm, use_attention_mask=self.use_attention_mask)
+        token_ex = TokenizerHook(tokenizer)
+        return {'te_hook':te_hook, 'emb_hook':emb_hook, 'token_ex':token_ex}
+
+class TextEncodeAction(BasicAction):
+    def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         if isinstance(prompt, str) and bs is not None:
             prompt = [prompt]*bs
             negative_prompt = [negative_prompt]*bs

         self.prompt = prompt
         self.negative_prompt = negative_prompt
+        self.bs = bs

-…
+    def forward(self, te_hook, TE, dtype: str, device, amp=None, gen_step=None, prompt_all=None, negative_prompt_all=None, model_offload=False,
+                **states):
+        prompt_all = prompt_all or self.prompt
+        negative_prompt_all = negative_prompt_all or self.negative_prompt

-…
-…
-…
-…
-…
-…
+        if gen_step is not None:
+            idx = (gen_step*self.bs)%len(prompt_all)
+            prompt = prompt_all[idx:idx+self.bs]
+            negative_prompt = negative_prompt_all[idx:idx+self.bs]
+        else:
+            prompt = prompt_all
+            negative_prompt = negative_prompt_all
+
+        if model_offload:
+            to_cuda(TE)
+
+        with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
+            emb, pooled_output, attention_mask = te_hook.encode_prompt_to_emb(negative_prompt+prompt)
+            if attention_mask is not None:
+                emb, attention_mask = pad_attn_bias(emb, attention_mask)
+
+        if model_offload:
+            to_cpu(TE)
+
+        if not isinstance(te_hook, ComposeTEEXHook):
+            pooled_output = None
+        return {'prompt':prompt, 'negative_prompt':negative_prompt, 'prompt_embeds':emb, 'encoder_attention_mask':attention_mask,
+                'pooled_output':pooled_output}

 class AttnMultTextEncodeAction(TextEncodeAction):
-…
-    def …
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+
+    def forward(self, te_hook, token_ex, TE, dtype: str, device, amp=None, gen_step=None, prompt_all=None, negative_prompt_all=None,
+                model_offload=False, **states):
+        prompt_all = prompt_all if prompt_all is not None else self.prompt
+        negative_prompt_all = negative_prompt_all if negative_prompt_all is not None else self.negative_prompt
+
+        if gen_step is not None:
+            idx = (gen_step*self.bs)%len(prompt_all)
+            prompt = prompt_all[idx:idx+self.bs]
+            negative_prompt = negative_prompt_all[idx:idx+self.bs]
+        else:
+            prompt = prompt_all
+            negative_prompt = negative_prompt_all
+
+        if model_offload:
+            to_cuda(TE)
+
+        mult_p, clean_text_p = token_ex.parse_attn_mult(prompt)
+        mult_n, clean_text_n = token_ex.parse_attn_mult(negative_prompt)
+        with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
             emb, pooled_output, attention_mask = te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
-…
+            if attention_mask is not None:
+                emb, attention_mask = pad_attn_bias(emb, attention_mask)
             emb_n, emb_p = emb.chunk(2)
             emb_p = te_hook.mult_attn(emb_p, mult_p)
             emb_n = te_hook.mult_attn(emb_n, mult_n)

-        if …
-            to_cpu(…
+        if model_offload:
+            to_cpu(TE)

-        return {…
-            '…
+        return {'prompt':list(clean_text_p), 'negative_prompt':list(clean_text_n), 'prompt_embeds':torch.cat([emb_n, emb_p], dim=0),
+                'encoder_attention_mask':attention_mask, 'pooled_output':pooled_output}
```
(Removed lines shown as `…` were truncated by the diff viewer and are not recoverable.)
hcpdiff/workflow/utils.py
CHANGED
```diff
@@ -1,13 +1,14 @@
-import …
+from typing import List, Union

-…
-from torch import nn
+import torch
 from PIL import Image
-from …
+from hcpdiff.data.handler import ControlNetHandler
+from rainbowneko.infer import BasicAction
+from torch import nn

 class LatentResizeAction(BasicAction):
-…
-…
+    def __init__(self, width=1024, height=1024, mode='bicubic', antialias=True, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.size = (height//8, width//8)
         self.mode = mode
         self.antialias = antialias
@@ -16,18 +17,37 @@ class LatentResizeAction(BasicAction):
         latents_dtype = latents.dtype
         latents = nn.functional.interpolate(latents.to(dtype=torch.float32), size=self.size, mode=self.mode)
         latents = latents.to(dtype=latents_dtype)
-        return {…
+        return {'latents':latents}

 class ImageResizeAction(BasicAction):
     # resample name to Image.xxx
     mode_map = {'nearest':Image.NEAREST, 'bilinear':Image.BILINEAR, 'bicubic':Image.BICUBIC, 'lanczos':Image.LANCZOS, 'box':Image.BOX,
-                'hamming':Image.HAMMING, 'antialias':Image.…
+                'hamming':Image.HAMMING, 'antialias':Image.LANCZOS}

-…
-…
+    def __init__(self, width=1024, height=1024, mode='bicubic', key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.size = (width, height)
         self.mode = self.mode_map[mode]

-    def forward(self, images:List[Image.Image], **states):
+    def forward(self, images: List[Image.Image], **states):
         images = [image.resize(self.size, resample=self.mode) for image in images]
-        return {…
+        return {'images':images}
+
+class FeedtoCNetAction(BasicAction):
+    def __init__(self, width=None, height=None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.size = (width, height)
+        self.cnet_handler = ControlNetHandler()
+
+    def forward(self, images: Union[List[Image.Image], Image.Image], device='cuda', dtype=None, bs=None, latents=None, **states):
+        if bs is None:
+            if 'prompt' in states:
+                bs = len(states['prompt'])
+
+        if latents is not None:
+            width, height = latents.shape[3]*8, latents.shape[2]*8
+        else:
+            width, height = self.size
+
+        images = self.cnet_handler.handle(images).to(device, dtype=dtype).expand(bs*2, 3, width, height)
+        return {'ex_inputs':{'cond':images}}
```
(Removed lines shown as `…` were truncated by the diff viewer and are not recoverable.)
hcpdiff/workflow/vae.py
CHANGED
```diff
@@ -1,33 +1,32 @@
-from .base import BasicAction, from_memory_context
-from diffusers import AutoencoderKL
-from diffusers.image_processor import VaeImageProcessor
-from typing import Dict, Any
 import torch
+from diffusers.image_processor import VaeImageProcessor
 from hcpdiff.utils import to_cuda, to_cpu
 from hcpdiff.utils.net_utils import get_dtype
+from rainbowneko.infer import BasicAction

 class EncodeAction(BasicAction):
-…
-…
-…
-        self.vae = vae
-        self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels)-1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if image_processor is None else image_processor
-        self.offload = offload
+    def __init__(self, image_processor=None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.image_processor = image_processor

-    def forward(self, images, dtype:str, device, generator, bs=None, **states):
+    def forward(self, vae, images, dtype: str, device, generator, bs=None, model_offload=False, **states):
         if bs is None:
             if 'prompt' in states:
                 bs = len(states['prompt'])
+        vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+        if self.image_processor is None:
+            self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

         image = self.image_processor.preprocess(images)
-…
+        if bs is not None and image.shape[0] != bs:
+            image = image.repeat(bs//image.shape[0], 1, 1, 1)
+        image = image.to(device=device, dtype=vae.dtype)

         if image.shape[1] == 4:
             init_latents = image
         else:
-            if …
-                to_cuda(…
+            if model_offload:
+                to_cuda(vae)
             if isinstance(generator, list) and len(generator) != bs:
                 raise ValueError(
                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -36,38 +35,38 @@ class EncodeAction(BasicAction):

             elif isinstance(generator, list):
                 init_latents = [
-…
+                    vae.encode(image[i: i+1]).latent_dist.sample(generator[i]) for i in range(bs)
                 ]
                 init_latents = torch.cat(init_latents, dim=0)
             else:
-                init_latents = …
+                init_latents = vae.encode(image).latent_dist.sample(generator)

-        init_latents = …
-        if …
-            to_cpu(…
-        return {…
+        init_latents = vae.config.scaling_factor*init_latents.to(dtype=get_dtype(dtype))
+        if model_offload:
+            to_cpu(vae)
+        return {'latents':init_latents}

 class DecodeAction(BasicAction):
-…
-…
-        super().__init__()
-        self.vae = vae
-        self.offload = offload
+    def __init__(self, image_processor=None, output_type='pil', key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)

-        self.…
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if image_processor is None else image_processor
+        self.image_processor = image_processor
         self.output_type = output_type
-        self.decode_key = decode_key

-    def forward(self, **states):
-…
-        if self.…
-…
-…
-…
-…
-…
+    def forward(self, vae, denoiser, latents, model_offload=False, **states):
+        vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+        if self.image_processor is None:
+            self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+        if model_offload:
+            to_cpu(denoiser)
+            torch.cuda.synchronize()
+            to_cuda(vae)
+        latents = latents.to(dtype=vae.dtype)
+        image = vae.decode(latents/vae.config.scaling_factor, return_dict=False)[0]
+        if model_offload:
+            to_cpu(vae)

         do_denormalize = [True]*image.shape[0]
         image = self.image_processor.postprocess(image, output_type=self.output_type, do_denormalize=do_denormalize)
-        return {…
+        return {'images':image}
```
(Removed lines shown as `…` were truncated by the diff viewer and are not recoverable.)