hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +114 -125
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +78 -46
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,109 @@
|
|
1
|
+
from daam import AggregateHooker, RawHeatMapCollection, UNetCrossAttentionLocator, GlobalHeatMap
|
2
|
+
from daam.trace import UNetCrossAttentionHooker
|
3
|
+
from typing import List
|
4
|
+
from diffusers import UNet2DConditionModel
|
5
|
+
from PIL import Image
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
import torch.nn.functional as F
|
9
|
+
|
10
|
+
def auto_autocast(*args, **kwargs):
    """Return a ``torch.cuda.amp.autocast`` context, disabled when CUDA is absent.

    All positional and keyword arguments are forwarded unchanged, except
    that ``enabled`` is forced to ``False`` on machines without CUDA.
    """
    cuda_present = torch.cuda.is_available()
    if not cuda_present:
        # Without a CUDA device the context must be a no-op.
        kwargs.update(enabled=False)

    return torch.cuda.amp.autocast(*args, **kwargs)
|
15
|
+
|
16
|
+
class DiffusionHeatMapHooker(AggregateHooker):
    """Aggregate hooker that gathers cross-attention heat maps from a UNet.

    One ``UNetCrossAttentionHooker`` is built for every module returned by
    ``UNetCrossAttentionLocator.locate(unet)``; each of them receives this
    object as its parent. The heat maps stored in ``self.all_heat_maps`` are
    reduced by :meth:`compute_global_heat_map`.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        tokenizer,
        vae_scale_factor: int,
        low_memory: bool = False,
        load_heads: bool = False,
        save_heads: bool = False,
        data_dir: str = None
    ):
        """
        Args:
            unet: The UNet whose cross-attention modules are hooked.
            tokenizer: Tokenizer handed to the resulting ``GlobalHeatMap``.
            vae_scale_factor: Multiplied with ``unet.config.sample_size`` to get the pixel size.
            low_memory: If true, the locator is restricted to ``{0}``.
            load_heads: Forwarded to every ``UNetCrossAttentionHooker``.
            save_heads: Forwarded to every ``UNetCrossAttentionHooker``.
            data_dir: Forwarded to every ``UNetCrossAttentionHooker``.
        """
        self.all_heat_maps = RawHeatMapCollection()
        h = (unet.config.sample_size * vae_scale_factor)
        self.latent_hw = 4096 if h == 512 or h == 1024 else 9216  # 64x64 or 96x96 depending on if it's 2.0-v or 2.0
        # The middle block is only located when heads are loaded or saved.
        locate_middle = load_heads or save_heads
        self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
        self.last_prompt: str = ''
        self.last_image: Image.Image = None
        self.time_idx = 0
        self._gen_idx = 0

        self.tokenizer = tokenizer

        modules = [
            UNetCrossAttentionHooker(
                x,
                self,
                layer_idx=idx,
                latent_hw=self.latent_hw,
                load_heads=load_heads,
                save_heads=save_heads,
                data_dir=data_dir
            ) for idx, x in enumerate(self.locator.locate(unet))
        ]

        super().__init__(modules)

    def time_callback(self, *args, **kwargs):
        """Advance the internal time-step counter by one; all arguments are ignored."""
        self.time_idx += 1

    @property
    def layer_names(self):
        """Layer names as reported by the locator."""
        return self.locator.layer_names

    def compute_global_heat_map(self, prompt=None, factors=None, head_idxs: List[int]=None, layer_idx=None, normalize=False):
        """
        Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different
        spatial transformer block heat maps).

        Args:
            prompt: The prompt to compute the heat map for. If none, uses the last prompt that was used for generation.
            factors: Restrict the application to heat maps with spatial factors in this collection. If `None`, heat
                maps of every factor are used.
            head_idxs: Restrict the application to heat maps whose head index is in this list. If `None`, use all heads.
            layer_idx: Restrict the application to heat maps with this layer index. If `None`, use all layers.
            normalize: If true, divide every map by the per-pixel sum over ``maps[1:-1]``.

        Returns:
            A ``GlobalHeatMap`` object for computing word-level heat maps.

        Raises:
            RuntimeError: If no recorded heat map matches the restrictions.
        """
        heat_maps = self.all_heat_maps

        if prompt is None:
            prompt = self.last_prompt

        # Only filter by factor when the caller asked for it, so the default
        # call keeps accepting heat maps of every factor.
        if factors is not None:
            factors = set(factors)

        all_merges = []
        x = int(np.sqrt(self.latent_hw))

        with auto_autocast(dtype=torch.float32):
            for (factor, layer, head), heat_map in heat_maps:
                if factors is not None and factor not in factors:
                    continue
                if (head_idxs is None or head in head_idxs) and (layer_idx is None or layer_idx == layer):
                    # NOTE(review): the divisor 25 is not explained anywhere in this file.
                    heat_map = heat_map.unsqueeze(1)/25  # [L,1,H,W]
                    # The clamping fixes undershoot.
                    all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0))

        # Check for the empty case explicitly so unrelated torch.stack errors are not mislabelled.
        if not all_merges:
            if factors is not None or head_idxs is not None or layer_idx is not None:
                raise RuntimeError('No heat maps found for the given parameters.')
            else:
                raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?')

        maps = torch.stack(all_merges, dim=0)  # [B*head, L, 1, H, W]
        maps = maps.mean(0)[:, 0]  # [L,H,W]
        #maps = maps[:len(self.tokenizer.tokenize(prompt)) + 2]  # 1 for SOS and 1 for padding

        if normalize:
            maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6)  # drop out [SOS] and [PAD] for proper probabilities

        return GlobalHeatMap(self.tokenizer, prompt, maps)
|
hcpdiff/workflow/diffusion.py
CHANGED
@@ -1,209 +1,198 @@
|
|
1
|
-
import
|
1
|
+
import random
|
2
|
+
import warnings
|
2
3
|
from typing import Dict, Any, Union, List
|
3
4
|
|
4
5
|
import torch
|
6
|
+
from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
|
7
|
+
from hcpdiff.utils import prepare_seed
|
8
|
+
from hcpdiff.utils.net_utils import get_dtype, to_cuda
|
9
|
+
from rainbowneko.infer import BasicAction
|
5
10
|
from torch.cuda.amp import autocast
|
6
11
|
|
7
|
-
from .base import BasicAction, from_memory_context, MemoryMixin
|
8
|
-
|
9
12
|
try:
|
10
13
|
from diffusers.utils import randn_tensor
|
11
14
|
except:
|
12
15
|
# new version of diffusers
|
13
16
|
from diffusers.utils.torch_utils import randn_tensor
|
14
17
|
|
15
|
-
from hcpdiff.utils import prepare_seed
|
16
|
-
from hcpdiff.utils.net_utils import get_dtype
|
17
|
-
import random
|
18
|
-
|
19
18
|
class InputFeederAction(BasicAction):
    """Hand extra inputs to every feeder callable registered on the model."""

    def __init__(self, ex_inputs: Dict[str, Any], key_map_in=None, key_map_out=None):
        super().__init__(key_map_in, key_map_out)
        self.ex_inputs = ex_inputs

    def forward(self, model, ex_inputs=None, **states):
        """Call each entry of ``model.input_feeder`` with the merged extra inputs.

        Args:
            model: Object that may expose an iterable ``input_feeder`` attribute.
            ex_inputs: Optional extra inputs coming from the workflow states.

        Returns:
            None; nothing is added to the states.
        """
        if ex_inputs is None:
            merged_inputs = self.ex_inputs
        else:
            # On key clashes the values configured on this action win.
            merged_inputs = {**ex_inputs, **self.ex_inputs}

        if not hasattr(model, 'input_feeder'):
            return
        for feed in model.input_feeder:
            feed(merged_inputs)
|
31
28
|
|
32
29
|
class SeedAction(BasicAction):
    """Build the per-sample seeds and the matching random generator(s)."""

    def __init__(self, seed: Union[int, List[int]], bs: int = 1, key_map_in=None, key_map_out=None):
        """
        Args:
            seed: ``None`` (random seeds), a base integer, or an explicit list of seeds.
            bs: Batch size used when the states carry no ``prompt_embeds``.
            key_map_in: Forwarded to ``BasicAction``.
            key_map_out: Forwarded to ``BasicAction``.
        """
        super().__init__(key_map_in, key_map_out)
        self.seed = seed
        self.bs = bs

    def forward(self, device, gen_step=0, **states):
        """Return ``{'seeds': ..., 'generator': ...}`` for the current generation step.

        Args:
            device: Passed to ``prepare_seed``.
            gen_step: Index of the current batch; shifts an integer base seed by ``gen_step*bs``.
        """
        # presumably prompt_embeds holds negative and positive halves, hence //2 -- confirm
        bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
        if self.seed is None:
            seeds = [None]*bs
        elif isinstance(self.seed, int):
            seeds = list(range(self.seed+gen_step*bs, self.seed+(gen_step+1)*bs))
        else:
            seeds = self.seed
        # Only a missing seed (None) is randomised; 0 is a valid explicit seed and must be kept.
        seeds = [random.randint(0, 1 << 30) if s is None else s for s in seeds]

        G = prepare_seed(seeds, device=device)
        return {'seeds':seeds, 'generator':G}
|
51
47
|
|
52
|
-
class PrepareDiffusionAction(BasicAction
|
53
|
-
def __init__(self,
|
54
|
-
|
48
|
+
class PrepareDiffusionAction(BasicAction):
    """Move the denoiser, text encoder and VAE to the device and switch them to eval mode."""

    def __init__(self, model_offload=False, amp=torch.float16, key_map_in=None, key_map_out=None):
        super().__init__(key_map_in, key_map_out)
        self.model_offload = model_offload
        self.amp = amp

    def forward(self, device, denoiser, TE, vae, **states):
        """Prepare the three modules and publish the ``amp`` / ``model_offload`` settings.

        Returns:
            Dict with the keys ``amp`` and ``model_offload`` as configured on this action.
        """
        for module in (denoiser, TE, vae):
            module.to(device)

        for module in (TE, denoiser, vae):
            module.eval()

        settings = {'amp':self.amp}
        settings['model_offload'] = self.model_offload
        return settings
|
66
63
|
|
67
|
-
class MakeTimestepsAction(BasicAction
|
68
|
-
|
69
|
-
|
70
|
-
self.scheduler = scheduler
|
64
|
+
class MakeTimestepsAction(BasicAction):
    """Ask the noise sampler for its timesteps, optionally truncated by ``strength``."""

    def __init__(self, N_steps: int = 30, strength: float = None, key_map_in=None, key_map_out=None):
        super().__init__(key_map_in, key_map_out)
        self.N_steps = N_steps
        self.strength = strength

    def get_timesteps(self, noise_sampler:BaseSampler, timesteps, strength):
        """Drop the leading part of ``timesteps`` so that only a ``strength`` fraction remains."""
        total = len(timesteps)
        # get the original timestep using init_timestep
        kept = min(int(total*strength), total)
        first = max(total-kept, 0)

        # For a DiffusersSampler the start index is scaled by the scheduler order.
        stride = noise_sampler.scheduler.order if isinstance(noise_sampler, DiffusersSampler) else 1
        return timesteps[first*stride:]

    def forward(self, noise_sampler:BaseSampler, device, **states):
        """Return ``timesteps`` and, when ``strength`` is set, also ``start_timestep``."""
        timesteps = noise_sampler.get_timesteps(self.N_steps, device=device)
        if not self.strength:
            return {'timesteps':timesteps}

        timesteps = self.get_timesteps(noise_sampler, timesteps, self.strength)
        return {'timesteps':timesteps, 'start_timestep':timesteps[:1]}
|
90
|
+
|
91
|
+
class MakeLatentAction(BasicAction):
    """Create the starting latents: fresh noise, or noised input latents for image-to-image."""

    def __init__(self, N_ch=4, height=None, width=None, key_map_in=None, key_map_out=None):
        super().__init__(key_map_in, key_map_out)
        self.N_ch = N_ch
        self.height = height
        self.width = width

    def forward(self, noise_sampler:BaseSampler, vae, generator, device, dtype, bs=None, latents=None, start_timestep=None,
                pooled_output=None, crop_coord=None, **states):
        """Build the ``latents`` entry and, when ``pooled_output`` is given, the SDXL extras.

        Returns:
            Dict with ``latents``; with ``pooled_output`` also ``text_embeds`` and, if
            ``negative_prompt`` is among the states, ``crop_info``.

        Raises:
            ValueError: If ``generator`` is a list whose length differs from the batch size.
        """
        if bs is None and 'prompt' in states:
            bs = len(states['prompt'])
        vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
        device = torch.device(device)

        from_image = latents is not None
        if from_image:
            if self.height is not None:
                warnings.warn('latents exist! User-specified width and height will be ignored!')
            shape = latents.shape
        else:
            shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)

        if isinstance(generator, list) and len(generator) != bs:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {bs}. Make sure the batch size matches the length of the generators."
            )

        if from_image:
            # image to image
            latents = latents.to(device)
            latents, _ = noise_sampler.add_noise(latents, start_timestep)
        else:
            # scale the initial noise by the standard deviation required by the noise_sampler
            noise_sampler.generator = generator
            latents = noise_sampler.init_noise(shape, device=device, dtype=get_dtype(dtype))

        output = {'latents':latents}
        if pooled_output is None:
            return output

        # SDXL inputs
        width, height = shape[3]*vae_scale_factor, shape[2]*vae_scale_factor
        coords = (0, 0, height, width) if crop_coord is None else crop_coord
        crop_info = torch.tensor([height, width, *coords], dtype=torch.float)
        crop_info = crop_info.to(device).repeat(bs, 1)
        output['text_embeds'] = pooled_output[-1].to(device)

        if 'negative_prompt' in states:
            output['crop_info'] = torch.cat([crop_info, crop_info], dim=0)

        return output
|
143
|
+
|
144
|
+
class DenoiseAction(BasicAction):
    """Run the denoiser once and apply classifier-free guidance to its prediction."""

    def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
        super().__init__(key_map_in, key_map_out)
        self.guidance_scale = guidance_scale

    def forward(self, denoiser, noise_sampler: BaseSampler, t, latents, prompt_embeds, text_embeds=None, encoder_attention_mask=None, crop_info=None,
                cross_attention_kwargs=None, dtype='fp32', amp=None, model_offload=False, **states):
        """Predict the noise for ``latents`` at step ``t``.

        Returns:
            Dict with the single key ``noise_pred``.
        """
        if model_offload:
            to_cuda(denoiser)  # to_cpu in VAE

        use_guidance = self.guidance_scale>1
        with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
            # With guidance the batch is duplicated so both halves go through the denoiser together.
            model_input = torch.cat([latents]*2) if use_guidance else latents
            model_input = noise_sampler.c_in(t)*model_input

            extra_kwargs = {}
            if text_embeds is not None:
                extra_kwargs['added_cond_kwargs'] = {"text_embeds":text_embeds, "time_ids":crop_info}

            # predict the noise residual
            noise_pred = denoiser(model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
                                  cross_attention_kwargs=cross_attention_kwargs, **extra_kwargs).sample

            # perform guidance
            if use_guidance:
                pred_uncond, pred_text = noise_pred.chunk(2)
                noise_pred = pred_uncond+self.guidance_scale*(pred_text-pred_uncond)

        return {'noise_pred':noise_pred}
|
186
174
|
|
175
|
+
class SampleAction(BasicAction):
    """Advance the latents by one sampler step using the predicted noise."""

    def forward(self, noise_sampler: BaseSampler, noise_pred, t, latents, generator, **states):
        """Return ``{'latents': ...}`` as produced by ``noise_sampler.denoise``."""
        # compute the previous noisy sample x_t -> x_t-1
        stepped = noise_sampler.denoise(latents, t, noise_pred, generator=generator)
        return {'latents':stepped}
|
180
|
+
|
181
|
+
class DiffusionStepAction(BasicAction):
    """Chain a ``DenoiseAction`` and a ``SampleAction`` into one diffusion step."""

    def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
        super().__init__(key_map_in, key_map_out)
        # The inner actions are built without key maps of their own.
        self.act_noise_pred = DenoiseAction(guidance_scale)
        self.act_sample = SampleAction()

    def forward(self, denoiser, noise_sampler, **states):
        """Run noise prediction, then sampling, and return the states from the sampling action."""
        after_prediction = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
        return self.act_sample(**after_prediction)
|
202
191
|
|
203
192
|
class X0PredAction(BasicAction):
    """Convert the predicted noise into an x0 estimate through the noise sampler."""

    def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
        """Return ``{'latents_x0': ...}`` computed by ``noise_sampler.eps_to_x0``."""
        x0_estimate = noise_sampler.eps_to_x0(noise_pred, latents, t)
        return {'latents_x0':x0_estimate}
|
196
|
+
|
197
|
+
def time_iter(timesteps, **states):
    """Wrap every timestep in its own ``{'t': timestep}`` dict; other states are ignored."""
    per_step_states = []
    for step_value in timesteps:
        per_step_states.append({'t':step_value})
    return per_step_states
|
hcpdiff/workflow/fast.py
ADDED
@@ -0,0 +1,31 @@
|
|
1
|
+
from sfast.compilers.diffusion_pipeline_compiler import (compile_unet, CompilationConfig)
|
2
|
+
from rainbowneko.infer import BasicAction
|
3
|
+
|
4
|
+
|
5
|
+
class SFastCompileAction(BasicAction):
    """Workflow action that compiles the denoiser with stable-fast (``sfast``)."""

    @staticmethod
    def compile_model(unet):
        """Compile ``unet`` with a config that enables every optional backend found.

        xformers and Triton are enabled only when their modules can be
        imported; CUDA graph is always enabled.

        Returns:
            The result of ``compile_unet(unet, config)``.
        """
        # compile model
        config = CompilationConfig.Default()
        config.enable_xformers = False
        try:
            import xformers  # noqa: F401 -- imported only to probe availability
            config.enable_xformers = True
        except ImportError:
            print('xformers not installed, skip')
        # NOTE:
        # When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
        # Disable Triton if you encounter this problem.
        try:
            # The module is named `triton`; probing any other name would
            # always fail and leave Triton disabled.
            import triton  # noqa: F401 -- imported only to probe availability
            config.enable_triton = True
        except ImportError:
            print('Triton not installed, skip')
        config.enable_cuda_graph = True

        return compile_unet(unet, config)

    def forward(self, denoiser, **states):
        """Replace ``denoiser`` in the states with its compiled version."""
        denoiser = self.compile_model(denoiser)
        return {'denoiser': denoiser}
|
hcpdiff/workflow/flow.py
ADDED
@@ -0,0 +1,67 @@
|
|
1
|
+
from rainbowneko.infer import BasicAction
|
2
|
+
from typing import List, Dict
|
3
|
+
from tqdm import tqdm
|
4
|
+
import math
|
5
|
+
|
6
|
+
class FilePromptAction(BasicAction):
    """Run a list of actions once per batch of prompts, reading prompts from ``.txt`` files if given."""

    def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, key_map_in=None, key_map_out=None):
        """
        Args:
            actions: Actions executed in order for every batch.
            prompt: A prompt, or a path ending in ``.txt`` holding one prompt per line.
            negative_prompt: A negative prompt (repeated for every prompt), or a path
                ending in ``.txt`` holding one negative prompt per line.
            bs: Batch size used to compute the number of generation steps.
            key_map_in: Forwarded to ``BasicAction``.
            key_map_out: Forwarded to ``BasicAction``.
        """
        super().__init__(key_map_in, key_map_out)
        if prompt.endswith('.txt'):
            prompt = self._read_lines(prompt)
        else:
            prompt = [prompt]

        if negative_prompt.endswith('.txt'):
            negative_prompt = self._read_lines(negative_prompt)
        else:
            negative_prompt = [negative_prompt]*len(prompt)

        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.bs = bs
        self.actions = actions

    @staticmethod
    def _read_lines(path):
        """Read ``path`` and return its lines without the artefact of a final newline."""
        with open(path, 'r') as f:
            lines = f.read().split('\n')
        # A file ending in '\n' would otherwise contribute a spurious empty prompt.
        if lines and lines[-1] == '':
            lines.pop()
        return lines

    def forward(self, **states):
        """Execute all actions for every batch and return the states of the last batch.

        Each batch starts from a copy of the incoming states extended with
        ``prompt_all``, ``negative_prompt_all`` and the current ``gen_step``.
        """
        states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
        states_ref = dict(**states)

        pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
        N_steps = len(self.actions)
        for gen_step in pbar:
            # Every batch restarts from the reference states, not from the previous batch's output.
            states = dict(**states_ref)
            feed_data = {'gen_step': gen_step}
            states.update(feed_data)
            for step, act in enumerate(self.actions):
                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
                states = act(**states)
        return states
|
41
|
+
|
42
|
+
class FlowPromptAction(BasicAction):
    """Run a list of actions once per batch, repeating one prompt ``num`` times."""

    def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, num: int = None, key_map_in=None, key_map_out=None):
        """
        Args:
            actions: Actions executed in order for every batch.
            prompt: The prompt that is repeated ``num`` times.
            negative_prompt: The negative prompt that is repeated ``num`` times.
            bs: Batch size used to compute the number of generation steps.
            num: Total number of prompt repetitions. ``None`` falls back to ``bs``
                (one full batch); NOTE(review): confirm this is the intended default.
            key_map_in: Forwarded to ``BasicAction``.
            key_map_out: Forwarded to ``BasicAction``.
        """
        super().__init__(key_map_in, key_map_out)
        # The declared default (None) cannot multiply a list, so give it a usable meaning.
        if num is None:
            num = bs
        prompt = [prompt]*num
        negative_prompt = [negative_prompt]*num

        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.bs = bs
        self.actions = actions

    def forward(self, **states):
        """Execute all actions for every batch and return the states of the last batch.

        Each batch starts from a copy of the incoming states extended with
        ``prompt_all``, ``negative_prompt_all`` and the current ``gen_step``.
        """
        states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
        states_ref = dict(**states)

        pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
        N_steps = len(self.actions)
        for gen_step in pbar:
            # Every batch restarts from the reference states, not from the previous batch's output.
            states = dict(**states_ref)
            feed_data = {'gen_step': gen_step}
            states.update(feed_data)
            for step, act in enumerate(self.actions):
                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
                states = act(**states)
        return states
|