hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +90 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +3 -3
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text.py +40 -0
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +207 -0
- hcpdiff/easy/cfg/sdxl_train.py +147 -0
- hcpdiff/easy/cfg/t2i.py +228 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +118 -128
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +60 -47
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.2.dist-info/METADATA +299 -0
- hcpdiff-2.2.dist-info/RECORD +115 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
- hcpdiff-2.2.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/workflow/daam/__init__.py
ADDED
@@ -0,0 +1 @@
+from .act import CaptureCrossAttnAction, SaveWordAttnAction
hcpdiff/workflow/daam/act.py
ADDED
@@ -0,0 +1,66 @@
+import os
+from io import BytesIO
+
+import numpy as np
+from PIL import Image
+from hcpdiff.utils import to_validate_file
+from rainbowneko.utils import types_support
+from matplotlib import pyplot as plt
+from rainbowneko.infer import BasicAction, Actions
+
+from .hook import DiffusionHeatMapHooker
+
+class CaptureCrossAttnAction(Actions):
+    def forward(self, prompt, denoiser, tokenizer, vae, **states):
+        bs = len(prompt)
+        N_head = 8
+        with DiffusionHeatMapHooker(denoiser, tokenizer, vae_scale_factor=vae.vae_scale_factor) as tc:
+            states = super().forward(**states)
+            heat_maps = [tc.compute_global_heat_map(prompt=prompt[i], head_idxs=range(N_head*i, N_head*(i+1))) for i in range(bs)]
+
+        return {**states, 'cross_attn_heat_maps':heat_maps}
+
+class SaveWordAttnAction(BasicAction):
+
+    def __init__(self, save_root: str, N_col: int = 4, image_type: str = 'png', quality: int = 95, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.save_root = save_root
+        self.image_type = image_type
+        self.quality = quality
+        self.N_col = N_col
+
+        os.makedirs(save_root, exist_ok=True)
+
+    def draw_attn(self, tokenizer, prompt, image, global_heat_map):
+        prompt = tokenizer.bos_token+prompt+tokenizer.eos_token
+        tokens = [token.replace("</w>", "") for token in tokenizer.tokenize(prompt)]
+
+        d_len = self.N_col
+        plt.rcParams['figure.dpi'] = 300
+        plt.rcParams.update({'font.size':12})
+        h = int(np.ceil(len(tokens)/d_len))
+        fig, ax = plt.subplots(h, d_len, figsize=(2*d_len, 2*h))
+        for ax_ in ax.flatten():
+            ax_.set_xticks([])
+            ax_.set_yticks([])
+        for i, token in enumerate(tokens):
+            heat_map = global_heat_map.compute_word_heat_map(token, word_idx=i)
+            if h==1:
+                heat_map.plot_overlay(image, ax=ax[i%d_len])
+            else:
+                heat_map.plot_overlay(image, ax=ax[i//d_len, i%d_len])
+        # plt.tight_layout()
+
+        buf = BytesIO()
+        plt.savefig(buf, format='png')
+        buf.seek(0)
+        return Image.open(buf)
+
+    def forward(self, tokenizer, images, prompt, seeds, cross_attn_heat_maps, **states):
+        num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(self.save_root) if x.rsplit('.', 1)[-1] in types_support])
+
+        for bid, (p, img) in enumerate(zip(prompt, images)):
+            img_path = os.path.join(self.save_root, f"{num_img_exist}-{seeds[bid]}-cross_attn-{to_validate_file(prompt[0])}.{self.image_type}")
+            img = self.draw_attn(tokenizer, p, img, cross_attn_heat_maps[bid])
+            img.save(img_path, quality=self.quality)
+            num_img_exist += 1
hcpdiff/workflow/daam/hook.py
ADDED
@@ -0,0 +1,109 @@
+from daam import AggregateHooker, RawHeatMapCollection, UNetCrossAttentionLocator, GlobalHeatMap
+from daam.trace import UNetCrossAttentionHooker
+from typing import List
+from diffusers import UNet2DConditionModel
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+def auto_autocast(*args, **kwargs):
+    if not torch.cuda.is_available():
+        kwargs['enabled'] = False
+
+    return torch.cuda.amp.autocast(*args, **kwargs)
+
+class DiffusionHeatMapHooker(AggregateHooker):
+    def __init__(
+            self,
+            unet: UNet2DConditionModel,
+            tokenizer,
+            vae_scale_factor: int,
+            low_memory: bool = False,
+            load_heads: bool = False,
+            save_heads: bool = False,
+            data_dir: str = None
+    ):
+        self.all_heat_maps = RawHeatMapCollection()
+        h = (unet.config.sample_size * vae_scale_factor)
+        self.latent_hw = 4096 if h == 512 or h == 1024 else 9216  # 64x64 or 96x96 depending on if it's 2.0-v or 2.0
+        locate_middle = load_heads or save_heads
+        self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
+        self.last_prompt: str = ''
+        self.last_image: Image.Image = None
+        self.time_idx = 0
+        self._gen_idx = 0
+
+        self.tokenizer = tokenizer
+
+        modules = [
+            UNetCrossAttentionHooker(
+                x,
+                self,
+                layer_idx=idx,
+                latent_hw=self.latent_hw,
+                load_heads=load_heads,
+                save_heads=save_heads,
+                data_dir=data_dir
+            ) for idx, x in enumerate(self.locator.locate(unet))
+        ]
+
+        super().__init__(modules)
+
+    def time_callback(self, *args, **kwargs):
+        self.time_idx += 1
+
+    @property
+    def layer_names(self):
+        return self.locator.layer_names
+
+    def compute_global_heat_map(self, prompt=None, factors=None, head_idxs: List[int]=None, layer_idx=None, normalize=False):
+        # type: (str, List[float], int, int, bool) -> GlobalHeatMap
+        """
+        Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different
+        spatial transformer block heat maps).
+
+        Args:
+            prompt: The prompt to compute the heat map for. If none, uses the last prompt that was used for generation.
+            factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes.
+            head_idx: Restrict the application to heat maps with this head index. If `None`, use all heads.
+            layer_idx: Restrict the application to heat maps with this layer index. If `None`, use all layers.
+
+        Returns:
+            A heat map object for computing word-level heat maps.
+        """
+        heat_maps = self.all_heat_maps
+
+        if prompt is None:
+            prompt = self.last_prompt
+
+        if factors is None:
+            factors = {0, 1, 2, 4, 8, 16, 32, 64}
+        else:
+            factors = set(factors)
+
+        all_merges = []
+        x = int(np.sqrt(self.latent_hw))
+
+        with auto_autocast(dtype=torch.float32):
+            for (factor, layer, head), heat_map in heat_maps:
+                if (head_idxs is None or head in head_idxs) and (layer_idx is None or layer_idx == layer):
+                    heat_map = heat_map.unsqueeze(1)/25  # [L,1,H,W]
+                    # The clamping fixes undershoot.
+                    all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0))
+
+            try:
+                maps = torch.stack(all_merges, dim=0)  # [B*head, L, 1, H, W]
+            except RuntimeError:
+                if head_idxs is not None or layer_idx is not None:
+                    raise RuntimeError('No heat maps found for the given parameters.')
+                else:
+                    raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?')
+
+            maps = maps.mean(0)[:, 0]  # [L,H,W]
+            #maps = maps[:len(self.tokenizer.tokenize(prompt)) + 2]  # 1 for SOS and 1 for padding
+
+        if normalize:
+            maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6)  # drop out [SOS] and [PAD] for proper probabilities
+
+        return GlobalHeatMap(self.tokenizer, prompt, maps)
hcpdiff/workflow/diffusion.py
CHANGED
@@ -1,209 +1,199 @@
-import
+import random
+import warnings
 from typing import Dict, Any, Union, List
 
 import torch
+from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
+from hcpdiff.utils import prepare_seed
+from hcpdiff.utils.net_utils import get_dtype, to_cuda
+from rainbowneko.infer import BasicAction
 from torch.cuda.amp import autocast
 
-from .base import BasicAction, from_memory_context, MemoryMixin
-
 try:
     from diffusers.utils import randn_tensor
 except:
     # new version of diffusers
     from diffusers.utils.torch_utils import randn_tensor
 
-from hcpdiff.utils import prepare_seed
-from hcpdiff.utils.net_utils import get_dtype
-import random
-
 class InputFeederAction(BasicAction):
-
-
-        super().__init__()
+    def __init__(self, ex_inputs: Dict[str, Any], key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.ex_inputs = ex_inputs
-        self.unet = unet
 
-    def forward(self, **states):
-
-
-
-
+    def forward(self, model, ex_inputs=None, **states):
+        ex_inputs = self.ex_inputs if ex_inputs is None else {**ex_inputs, **self.ex_inputs}
+        if hasattr(model, 'input_feeder'):
+            for feeder in model.input_feeder:
+                feeder(ex_inputs)
 
 class SeedAction(BasicAction):
-
-
-        super().__init__()
+    def __init__(self, seed: Union[int, List[int]], bs: int = 1, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.seed = seed
         self.bs = bs
 
-    def forward(self, device, **states):
+    def forward(self, device, seed=None, **states):
         bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
-
+        seed = seed or self.seed
+        if seed is None:
             seeds = [None]*bs
-        elif isinstance(
-            seeds = list(range(
+        elif isinstance(seed, int):
+            seeds = list(range(seed, seed+bs))
         else:
-            seeds =
+            seeds = seed
         seeds = [s or random.randint(0, 1 << 30) for s in seeds]
 
         G = prepare_seed(seeds, device=device)
-        return {
+        return {'seeds':seeds, 'generator':G}
 
-class PrepareDiffusionAction(BasicAction
-    def __init__(self,
-
+class PrepareDiffusionAction(BasicAction):
+    def __init__(self, model_offload=False, amp=torch.float16, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.model_offload = model_offload
         self.amp = amp
 
-    def forward(self,
-
-
-
-        memory.vae.to(dtype=dtype)
+    def forward(self, device, denoiser, TE, vae, **states):
+        denoiser.to(device)
+        TE.to(device)
+        vae.to(device)
 
-
-
-
+        TE.eval()
+        denoiser.eval()
+        vae.eval()
+        return {'amp':self.amp, 'model_offload':self.model_offload}
 
-class MakeTimestepsAction(BasicAction
-
-
-        self.scheduler = scheduler
+class MakeTimestepsAction(BasicAction):
+    def __init__(self, N_steps: int = 30, strength: float = None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.N_steps = N_steps
         self.strength = strength
 
-    def get_timesteps(self, timesteps, strength):
+    def get_timesteps(self, noise_sampler:BaseSampler, timesteps, strength):
         # get the original timestep using init_timestep
         num_inference_steps = len(timesteps)
         init_timestep = min(int(num_inference_steps*strength), num_inference_steps)
 
         t_start = max(num_inference_steps-init_timestep, 0)
-
+        if isinstance(noise_sampler, DiffusersSampler):
+            timesteps = timesteps[t_start*noise_sampler.scheduler.order:]
+        else:
+            timesteps = timesteps[t_start:]
 
         return timesteps
 
-    def forward(self,
-
-
-        self.scheduler.set_timesteps(self.N_steps, device=device)
-        timesteps = self.scheduler.timesteps
+    def forward(self, noise_sampler:BaseSampler, device, **states):
+        timesteps = noise_sampler.get_timesteps(self.N_steps, device=device)
         if self.strength:
-            timesteps = self.get_timesteps(timesteps, self.strength)
-
-
-
-
-
-    def __init__(self,
-
+            timesteps = self.get_timesteps(noise_sampler, timesteps, self.strength)
+            return {'timesteps':timesteps, 'start_timestep':timesteps[:1]}
+        else:
+            return {'timesteps':timesteps}
+
+class MakeLatentAction(BasicAction):
+    def __init__(self, N_ch=4, height=None, width=None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.N_ch = N_ch
         self.height = height
         self.width = width
 
-    def forward(self,
+    def forward(self, noise_sampler:BaseSampler, vae, generator, device, dtype, bs=None, latents=None, start_timestep=None,
+                pooled_output=None, crop_coord=None, **states):
         if bs is None:
             if 'prompt' in states:
                 bs = len(states['prompt'])
-
+        vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+        device = torch.device(device)
 
-
+        if latents is None:
+            shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)
+        else:
+            if self.height is not None:
+                warnings.warn('latents exist! User-specified width and height will be ignored!')
+            shape = latents.shape
         if isinstance(generator, list) and len(generator) != bs:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {bs}. Make sure the batch size matches the length of the generators."
             )
 
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=get_dtype(dtype))
         if latents is None:
-            # scale the initial noise by the standard deviation required by the
-
+            # scale the initial noise by the standard deviation required by the noise_sampler
+            noise_sampler.generator = generator
+            latents = noise_sampler.init_noise(shape, device=device, dtype=get_dtype(dtype))
         else:
             # image to image
             latents = latents.to(device)
-            latents =
+            latents, noise = noise_sampler.add_noise(latents, start_timestep)
 
-
+        output = {'latents':latents}
 
-
-
-
+        # SDXL inputs
+        if pooled_output is not None:
+            width, height = shape[3]*vae_scale_factor, shape[2]*vae_scale_factor
+            if crop_coord is None:
+                crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
+            else:
+                crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
+            crop_info = crop_info.to(device).repeat(bs, 1)
+            output['text_embeds'] = pooled_output[-1].to(device)
+
+            if 'negative_prompt' in states:
+                output['crop_info'] = torch.cat([crop_info, crop_info], dim=0)
+
+        return output
+
+class DenoiseAction(BasicAction):
+    def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.guidance_scale = guidance_scale
-        self.unet = unet
-        self.scheduler = scheduler
 
-    def forward(self,
-                cross_attention_kwargs=None, dtype='fp32', amp=None, **states):
-
-
+    def forward(self, denoiser, noise_sampler: BaseSampler, t, latents, prompt_embeds, text_embeds=None, encoder_attention_mask=None, crop_info=None,
+                cross_attention_kwargs=None, dtype='fp32', amp=None, model_offload=False, **states):
+
+        if model_offload:
+            to_cuda(denoiser)  # to_cpu in VAE
 
         with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
             latent_model_input = torch.cat([latents]*2) if self.guidance_scale>1 else latents
-            latent_model_input =
+            latent_model_input = noise_sampler.c_in(t)*latent_model_input
 
-            if
-                noise_pred =
-
+            if text_embeds is None:
+                noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                      cross_attention_kwargs=cross_attention_kwargs, ).sample
             else:
-                added_cond_kwargs = {"text_embeds":
+                added_cond_kwargs = {"text_embeds":text_embeds, "time_ids":crop_info}
                 # predict the noise residual
-                noise_pred =
-
+                noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                      cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
 
             # perform guidance
             if self.guidance_scale>1:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond+self.guidance_scale*(noise_pred_text-noise_pred_uncond)
 
-        return {
-            'crop_info':crop_info, 'cross_attention_kwargs':cross_attention_kwargs, 'dtype':dtype, 'amp':amp}
-
-class SampleAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, scheduler=None, eta=0.0):
-        self.scheduler = scheduler
-        self.eta = eta
-
-    def prepare_extra_step_kwargs(self, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
-    def forward(self, memory, noise_pred, t, latents, generator, **states):
-        self.scheduler = self.scheduler or memory.scheduler
-
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, self.eta)
+        return {'noise_pred':noise_pred}
 
+class SampleAction(BasicAction):
+    def forward(self, noise_sampler: BaseSampler, noise_pred, t, latents, generator, **states):
         # compute the previous noisy sample x_t -> x_t-1
-
-        latents
-
-
-
-
-
-        self.
-
-
-
-        states = self.
-        states = self.act_sample(memory=memory, **states)
+        latents = noise_sampler.denoise(latents, t, noise_pred, generator=generator)
+        return {'latents':latents}
+
+class DiffusionStepAction(BasicAction):
+    def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.act_noise_pred = DenoiseAction(guidance_scale)
+        self.act_sample = SampleAction()
+
+    def forward(self, denoiser, noise_sampler, **states):
+        states = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
+        states = self.act_sample(**states)
         return states
 
 class X0PredAction(BasicAction):
-    def forward(self, latents,
-
-
-
-
-
+    def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
+        latents_x0 = noise_sampler.eps_to_x0(noise_pred, latents, t)
+        return {'latents_x0':latents_x0}
+
+def time_iter(timesteps, **states):
+    return [{'t':t} for t in timesteps]
hcpdiff/workflow/fast.py
ADDED
@@ -0,0 +1,31 @@
+from sfast.compilers.diffusion_pipeline_compiler import (compile_unet, CompilationConfig)
+from rainbowneko.infer import BasicAction
+
+
+class SFastCompileAction(BasicAction):
+
+    @staticmethod
+    def compile_model(unet):
+        # compile model
+        config = CompilationConfig.Default()
+        config.enable_xformers = False
+        try:
+            import xformers
+            config.enable_xformers = True
+        except ImportError:
+            print('xformers not installed, skip')
+        # NOTE:
+        # When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
+        # Disable Triton if you encounter this problem.
+        try:
+            import triton
+            config.enable_triton = True
+        except ImportError:
+            print('Triton not installed, skip')
+        config.enable_cuda_graph = True
+
+        return compile_unet(unet, config)
+
+    def forward(self, denoiser, **states):
+        denoiser = self.compile_model(denoiser)
+        return {'denoiser': denoiser}
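
A brief usage note (a sketch; it assumes the optional stable-fast package providing `sfast` is installed and that a denoiser is already in the state dict): the action returns the compiled denoiser under the same key, so it slots in once, before the denoising steps.

from hcpdiff.workflow.fast import SFastCompileAction

states = {'denoiser': pipe.unet}          # pipe: an already-loaded pipeline (illustrative)
states = SFastCompileAction()(**states)   # later actions receive the compiled denoiser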
hcpdiff/workflow/flow.py
ADDED
@@ -0,0 +1,67 @@
+from rainbowneko.infer import BasicAction
+from typing import List, Dict
+from tqdm import tqdm
+import math
+
+class FilePromptAction(BasicAction):
+    def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        if prompt.endswith('.txt'):
+            with open(prompt, 'r') as f:
+                prompt = f.read().split('\n')
+        else:
+            prompt = [prompt]
+
+        if negative_prompt.endswith('.txt'):
+            with open(negative_prompt, 'r') as f:
+                negative_prompt = f.read().split('\n')
+        else:
+            negative_prompt = [negative_prompt]*len(prompt)
+
+        self.prompt = prompt
+        self.negative_prompt = negative_prompt
+        self.bs = bs
+        self.actions = actions
+
+
+    def forward(self, **states):
+        states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
+        states_ref = dict(**states)
+
+        pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
+        N_steps = len(self.actions)
+        for gen_step in pbar:
+            states = dict(**states_ref)
+            feed_data = {'gen_step': gen_step}
+            states.update(feed_data)
+            for step, act in enumerate(self.actions):
+                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
+                states = act(**states)
+        return states
+
+class FlowPromptAction(BasicAction):
+    def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, num: int = None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        prompt = [prompt]*num
+        negative_prompt = [negative_prompt]*num
+
+        self.prompt = prompt
+        self.negative_prompt = negative_prompt
+        self.bs = bs
+        self.actions = actions
+
+
+    def forward(self, **states):
+        states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
+        states_ref = dict(**states)
+
+        pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
+        N_steps = len(self.actions)
+        for gen_step in pbar:
+            states = dict(**states_ref)
+            feed_data = {'gen_step': gen_step}
+            states.update(feed_data)
+            for step, act in enumerate(self.actions):
+                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
+                states = act(**states)
+        return states
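
Usage sketch (hypothetical: `text2img_actions` stands in for the usual prepare/encode/denoise/decode/save actions): FilePromptAction treats a `.txt` argument as one prompt per line and replays the inner actions once per `gen_step`, resetting the state dict each pass so downstream actions can slice `prompt_all` by `gen_step` and `bs`.

from hcpdiff.workflow.flow import FilePromptAction

text2img_actions = []  # fill with the workflow's generation actions
flow = FilePromptAction(actions=text2img_actions,
                        prompt='prompts.txt',              # one prompt per line
                        negative_prompt='lowres, bad anatomy',
                        bs=4)
final_states = flow()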