hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/utils/net_utils.py
CHANGED
@@ -6,11 +6,19 @@ import torch
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
 from torch import nn
 from torch.optim import lr_scheduler
-from transformers import PretrainedConfig, AutoTokenizer
+from transformers import PretrainedConfig, AutoTokenizer, T5EncoderModel, CLIPTextModel
 from functools import partial
+from huggingface_hub import hf_hub_download
+import json
 
 dtype_dict = {'fp32':torch.float32, 'amp':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
 
+try:
+    dtype_dict['fp8_e4m3'] = torch.float8_e4m3fn
+    dtype_dict['fp8_e5m2'] = torch.float8_e5m2
+except:
+    pass
+
 def get_scheduler(cfg, optimizer):
     if cfg is None:
         return None
@@ -90,7 +98,7 @@ def auto_tokenizer_cls(pretrained_model_name_or_path: str, revision: str = None)
             revision=revision, use_fast=False,
         )
         return SDXLTokenizer
-    except
+    except:
         # not sdxl, only one tokenizer
         return AutoTokenizer
 
@@ -102,8 +110,10 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = No
             subfolder="text_encoder_2",
             revision=revision,
         )
+        if text_encoder_config.architectures is None:
+            raise ValueError()
         return SDXLTextEncoder
-    except
+    except:
         text_encoder_config = PretrainedConfig.from_pretrained(
             pretrained_model_name_or_path,
             subfolder="text_encoder",
@@ -112,16 +122,26 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = No
     model_class = text_encoder_config.architectures[0]
 
     if model_class == "CLIPTextModel":
-        from transformers import CLIPTextModel
-
         return CLIPTextModel
     elif model_class == "RobertaSeriesModelWithTransformation":
         from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
 
         return RobertaSeriesModelWithTransformation
+    elif model_class == "T5EncoderModel":
+        return T5EncoderModel
     else:
         raise ValueError(f"{model_class} is not supported.")
 
+def get_pipe_name(path: str):
+    if os.path.isdir(path):
+        json_file = os.path.join(path, "model_index.json")
+    else:
+        json_file = hf_hub_download(path, "model_index.json")
+    with open(json_file, "r", encoding="utf-8") as reader:
+        text = reader.read()
+    data = json.loads(text)
+    return data['_class_name']
+
 def auto_tokenizer(pretrained_model_name_or_path: str, revision: str = None, **kwargs):
     return auto_tokenizer_cls(pretrained_model_name_or_path, revision).from_pretrained(pretrained_model_name_or_path, revision=revision, **kwargs)
 
@@ -225,4 +245,7 @@ def split_module_name(layer_name):
     return parent_name, host_name
 
 def get_dtype(dtype):
-
+    if isinstance(dtype, torch.dtype):
+        return dtype
+    else:
+        return dtype_dict.get(dtype, torch.float32)
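Taken together, `get_dtype` now accepts either a `torch.dtype` or a string alias and falls back to fp32 for unknown keys, while `get_pipe_name` reads `_class_name` out of a local or Hub-hosted `model_index.json`. A minimal usage sketch of both, inferred from the hunks above (the Hub repo id is only an example):

    import torch
    from hcpdiff.utils.net_utils import get_dtype, get_pipe_name

    assert get_dtype(torch.bfloat16) is torch.bfloat16   # torch.dtype passes through unchanged
    assert get_dtype('fp16') is torch.float16            # string alias resolved via dtype_dict
    assert get_dtype('unknown') is torch.float32         # unknown keys fall back to fp32

    # Downloads model_index.json via hf_hub_download when the path is not a local dir:
    print(get_pipe_name('runwayml/stable-diffusion-v1-5'))  # 'StableDiffusionPipeline'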
hcpdiff/utils/pipe_hook.py
CHANGED
@@ -2,10 +2,10 @@ from typing import Union, List, Optional, Callable, Dict, Any
 
 import PIL
 import torch
-from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline,
+from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, PixArtTransformer2DModel
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from .inpaint_pipe import preprocess_mask, preprocess_image, StableDiffusionInpaintPipelineLegacy
 from einops import repeat
 
 class HookPipe_T2I(StableDiffusionPipeline):
@@ -17,25 +17,17 @@ class HookPipe_T2I(StableDiffusionPipeline):
     def device(self) -> torch.device:
         return torch.device('cuda')
 
-    def proc_prompt(self, device,
-
-
+    def proc_prompt(self, device, num_inference_steps, prompt_embeds = None, negative_prompt_embeds = None) -> List[torch.Tensor]:
+        if not isinstance(prompt_embeds, list): # to emb for each step
+            prompt_embeds = [prompt_embeds]*num_inference_steps
+        if not isinstance(negative_prompt_embeds, list): # to emb for each step
+            negative_prompt_embeds = [negative_prompt_embeds]*num_inference_steps
 
-
-
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed*num_images_per_prompt, seq_len, -1)
+        prompt_embeds = [p.to(dtype=self.text_encoder.dtype, device=device) for p in prompt_embeds]
+        negative_prompt_embeds = [p.to(dtype=self.text_encoder.dtype, device=device) for p in negative_prompt_embeds]
 
-
-
-
-        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-
-        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        negative_prompt_embeds = negative_prompt_embeds.view(batch_size*num_images_per_prompt, seq_len, -1)
-
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-        return prompt_embeds
+        prompt_embeds = [torch.cat([emb_neg, emb_pos]) for emb_pos, emb_neg in zip(prompt_embeds, negative_prompt_embeds)]
+        return prompt_embeds # List[emb_step_i]*num_inference_steps
 
     @torch.no_grad()
     def __call__(
@@ -46,7 +38,6 @@ class HookPipe_T2I(StableDiffusionPipeline):
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -74,6 +65,8 @@ class HookPipe_T2I(StableDiffusionPipeline):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
+        elif isinstance(prompt_embeds, list):
+            batch_size = prompt_embeds[0].shape[0]
         else:
             batch_size = prompt_embeds.shape[0]
 
@@ -84,7 +77,7 @@ class HookPipe_T2I(StableDiffusionPipeline):
         do_classifier_free_guidance = guidance_scale>1.0
 
         # 3. Encode input prompt
-        prompt_embeds = self.proc_prompt(device,
+        prompt_embeds = self.proc_prompt(device, num_inference_steps,
                                          prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds)
 
         # 4. Prepare timesteps
@@ -95,11 +88,11 @@ class HookPipe_T2I(StableDiffusionPipeline):
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
-            batch_size
+            batch_size,
             num_channels_latents,
             height,
             width,
-            prompt_embeds.dtype,
+            prompt_embeds[0].dtype,
             device,
             generator,
             latents,
@@ -114,7 +107,7 @@ class HookPipe_T2I(StableDiffusionPipeline):
             crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
         else:
             crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
-        crop_info = crop_info.to(device).repeat(batch_size
+        crop_info = crop_info.to(device).repeat(batch_size, 1)
         pooled_output = pooled_output.to(device)
 
         if do_classifier_free_guidance:
@@ -129,12 +122,20 @@ class HookPipe_T2I(StableDiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 if pooled_output is None:
-
-
+                    if isinstance(self.unet, PixArtTransformer2DModel):
+                        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+                        noise_pred = self.unet(latent_model_input, timestep=t.repeat(latent_model_input.shape[0]), encoder_hidden_states=prompt_embeds[i],
+                                               encoder_attention_mask=encoder_attention_mask,
+                                               cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                    else:
+                        noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
+                                               encoder_attention_mask=encoder_attention_mask,
+                                               cross_attention_kwargs=cross_attention_kwargs).sample
                 else:
                     added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                     # predict the noise residual
-                    noise_pred = self.unet(latent_model_input, t, prompt_embeds,
+                    noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
+                                           encoder_attention_mask=encoder_attention_mask,
                                            cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
 
                 # perform guidance
@@ -142,6 +143,10 @@ class HookPipe_T2I(StableDiffusionPipeline):
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)
 
+                # learned sigma
+                if self.unet.config.out_channels // 2 == num_channels_latents:
+                    noise_pred = noise_pred.chunk(2, dim=1)[0]
+
                 # x_t -> x_0
                 alpha_prod_t = alphas_cumprod[t.long()]
                 beta_prod_t = 1-alpha_prod_t
@@ -155,7 +160,8 @@ class HookPipe_T2I(StableDiffusionPipeline):
                 if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i%callback_steps == 0:
-
+                        latents = callback(i, t, num_inference_steps, latents_x0, latents)
+                        if latents is None:
                             return None
 
         latents = latents.to(dtype=self.vae.dtype)
@@ -277,8 +283,13 @@ class HookPipe_I2I(StableDiffusionImg2ImgPipeline):
 
                 # predict the noise residual
                 if pooled_output is None:
-
-
+                    if isinstance(self.unet, PixArtTransformer2DModel):
+                        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+                        noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                               cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                    else:
+                        noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                               cross_attention_kwargs=cross_attention_kwargs, ).sample
                 else:
                     added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                     # predict the noise residual
@@ -302,7 +313,8 @@ class HookPipe_I2I(StableDiffusionImg2ImgPipeline):
                 if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i%callback_steps == 0:
-
+                        latents = callback(i, t, num_inference_steps, latents_x0, latents)
+                        if latents is None:
                             return None
 
         latents = latents.to(dtype=self.vae.dtype)
@@ -450,7 +462,8 @@ class HookPipe_Inpaint(StableDiffusionInpaintPipelineLegacy):
                 if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i%callback_steps == 0:
-
+                        latents = callback(i, t, num_inference_steps, latents_x0, latents)
+                        if latents is None:
                             return None
 
                 # use original latents corresponding to unmasked portions of the image
hcpdiff/utils/utils.py
CHANGED
@@ -56,8 +56,8 @@ def remove_config_undefined(cfg):
 def load_config(path, remove_undefined=True):
     cfg = OmegaConf.load(path)
     if '_base_' in cfg:
-        for base in cfg['_base_']
-
+        base_cfgs = [load_config(base, remove_undefined=False) for base in cfg['_base_']]
+        cfg = OmegaConf.merge(*base_cfgs, cfg)
         del cfg['_base_']
         if remove_undefined:
             cfg = remove_config_undefined(cfg)
@@ -85,7 +85,7 @@ def get_cfg_range(cfg_text:str):
 def to_validate_file(name):
     rstr = r"[\/\\\:\*\?\"\<\>\|]"  # '/ \ : * ? " < > |'
     new_title = re.sub(rstr, "_", name)  # replace with underscores
-    return new_title[:
+    return new_title[:200]
 
 def make_mask(start, end, length):
     mask=torch.zeros(length)
@@ -159,4 +159,21 @@ def pad_attn_bias(x, attn_bias, block_size=8):
     # pad along the k dimension
     x_padded = F.pad(x, (0, 0, 0, padding_l, 0, 0), mode='constant', value=0)
     attn_bias_padded = F.pad(attn_bias, (0, padding_l, 0, 0), mode='constant', value=0)
-    return x_padded, attn_bias_padded
+    return x_padded, attn_bias_padded
+
+def linear_interp(t, x):
+    '''
+    t_l ---------t_h
+        ^x
+    '''
+    if (x>=len(t)).any():
+        x = x.clamp(max=len(t)-1e-6)
+    x0 = x.floor().long()
+    x1 = x0 + 1
+
+    y0 = t[x0]
+    y1 = t[x1]
+
+    xd = (x - x0.float())
+
+    return y0 * (1 - xd) + y1 * xd
hcpdiff/workflow/__init__.py
CHANGED
@@ -1,15 +1,20 @@
-from .
-
-    X0PredAction, SeedAction, MakeTimestepsAction
+from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
+    X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter
 from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
 from .vae import EncodeAction, DecodeAction
-from .io import
-from .utils import LatentResizeAction, ImageResizeAction
-from .model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction
+from .io import BuildModelsAction, SaveImageAction, LoadImageAction
+from .utils import LatentResizeAction, ImageResizeAction, FeedtoCNetAction
+from .model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction
+#from .flow import FilePromptAction
+
+try:
+    from .fast import SFastCompileAction
+except:
+    print('stable fast not installed.')
 
 from omegaconf import OmegaConf
 
-OmegaConf.register_new_resolver("hcp.from_memory", lambda mem_name:
-    '_target_':
-    'mem_name':
-}))
+OmegaConf.register_new_resolver("hcp.from_memory", lambda mem_name:OmegaConf.create({
+    '_target_':'hcpdiff.workflow.from_memory',
+    'mem_name':mem_name,
+}))
hcpdiff/workflow/daam/__init__.py
ADDED
@@ -0,0 +1 @@
+from .act import CaptureCrossAttnAction, SaveWordAttnAction
hcpdiff/workflow/daam/act.py
ADDED
@@ -0,0 +1,66 @@
+import os
+from io import BytesIO
+
+import numpy as np
+from PIL import Image
+from hcpdiff.utils import to_validate_file
+from rainbowneko.utils import types_support
+from matplotlib import pyplot as plt
+from rainbowneko.infer import BasicAction, Actions
+
+from .hook import DiffusionHeatMapHooker
+
+class CaptureCrossAttnAction(Actions):
+    def forward(self, prompt, denoiser, tokenizer, vae, **states):
+        bs = len(prompt)
+        N_head = 8
+        with DiffusionHeatMapHooker(denoiser, tokenizer, vae_scale_factor=vae.vae_scale_factor) as tc:
+            states = super().forward(**states)
+            heat_maps = [tc.compute_global_heat_map(prompt=prompt[i], head_idxs=range(N_head*i, N_head*(i+1))) for i in range(bs)]
+
+        return {**states, 'cross_attn_heat_maps':heat_maps}
+
+class SaveWordAttnAction(BasicAction):
+
+    def __init__(self, save_root: str, N_col: int = 4, image_type: str = 'png', quality: int = 95, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.save_root = save_root
+        self.image_type = image_type
+        self.quality = quality
+        self.N_col = N_col
+
+        os.makedirs(save_root, exist_ok=True)
+
+    def draw_attn(self, tokenizer, prompt, image, global_heat_map):
+        prompt=tokenizer.bos_token+prompt+tokenizer.eos_token
+        tokens = [token.replace("</w>", "") for token in tokenizer.tokenize(prompt)]
+
+        d_len = self.N_col
+        plt.rcParams['figure.dpi'] = 300
+        plt.rcParams.update({'font.size':12})
+        h = int(np.ceil(len(tokens)/d_len))
+        fig, ax = plt.subplots(h, d_len, figsize=(2*d_len, 2*h))
+        for ax_ in ax.flatten():
+            ax_.set_xticks([])
+            ax_.set_yticks([])
+        for i, token in enumerate(tokens):
+            heat_map = global_heat_map.compute_word_heat_map(token, word_idx=i)
+            if h==1:
+                heat_map.plot_overlay(image, ax=ax[i%d_len])
+            else:
+                heat_map.plot_overlay(image, ax=ax[i//d_len, i%d_len])
+        # plt.tight_layout()
+
+        buf = BytesIO()
+        plt.savefig(buf, format='png')
+        buf.seek(0)
+        return Image.open(buf)
+
+    def forward(self, tokenizer, images, prompt, seeds, cross_attn_heat_maps, **states):
+        num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(self.save_root) if x.rsplit('.', 1)[-1] in types_support])
+
+        for bid, (p, img) in enumerate(zip(prompt, images)):
+            img_path = os.path.join(self.save_root, f"{num_img_exist}-{seeds[bid]}-cross_attn-{to_validate_file(prompt[0])}.{self.image_type}")
+            img = self.draw_attn(tokenizer, p, img, cross_attn_heat_maps[bid])
+            img.save(img_path, quality=self.quality)
+            num_img_exist += 1
hcpdiff/workflow/daam/hook.py
ADDED
@@ -0,0 +1,109 @@
+from daam import AggregateHooker, RawHeatMapCollection, UNetCrossAttentionLocator, GlobalHeatMap
+from daam.trace import UNetCrossAttentionHooker
+from typing import List
+from diffusers import UNet2DConditionModel
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+def auto_autocast(*args, **kwargs):
+    if not torch.cuda.is_available():
+        kwargs['enabled'] = False
+
+    return torch.cuda.amp.autocast(*args, **kwargs)
+
+class DiffusionHeatMapHooker(AggregateHooker):
+    def __init__(
+            self,
+            unet: UNet2DConditionModel,
+            tokenizer,
+            vae_scale_factor: int,
+            low_memory: bool = False,
+            load_heads: bool = False,
+            save_heads: bool = False,
+            data_dir: str = None
+    ):
+        self.all_heat_maps = RawHeatMapCollection()
+        h = (unet.config.sample_size * vae_scale_factor)
+        self.latent_hw = 4096 if h == 512 or h == 1024 else 9216  # 64x64 or 96x96 depending on if it's 2.0-v or 2.0
+        locate_middle = load_heads or save_heads
+        self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
+        self.last_prompt: str = ''
+        self.last_image: Image.Image = None
+        self.time_idx = 0
+        self._gen_idx = 0
+
+        self.tokenizer = tokenizer
+
+        modules = [
+            UNetCrossAttentionHooker(
+                x,
+                self,
+                layer_idx=idx,
+                latent_hw=self.latent_hw,
+                load_heads=load_heads,
+                save_heads=save_heads,
+                data_dir=data_dir
+            ) for idx, x in enumerate(self.locator.locate(unet))
+        ]
+
+        super().__init__(modules)
+
+    def time_callback(self, *args, **kwargs):
+        self.time_idx += 1
+
+    @property
+    def layer_names(self):
+        return self.locator.layer_names
+
+    def compute_global_heat_map(self, prompt=None, factors=None, head_idxs: List[int]=None, layer_idx=None, normalize=False):
+        # type: (str, List[float], int, int, bool) -> GlobalHeatMap
+        """
+        Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different
+        spatial transformer block heat maps).
+
+        Args:
+            prompt: The prompt to compute the heat map for. If none, uses the last prompt that was used for generation.
+            factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes.
+            head_idx: Restrict the application to heat maps with this head index. If `None`, use all heads.
+            layer_idx: Restrict the application to heat maps with this layer index. If `None`, use all layers.
+
+        Returns:
+            A heat map object for computing word-level heat maps.
+        """
+        heat_maps = self.all_heat_maps
+
+        if prompt is None:
+            prompt = self.last_prompt
+
+        if factors is None:
+            factors = {0, 1, 2, 4, 8, 16, 32, 64}
+        else:
+            factors = set(factors)
+
+        all_merges = []
+        x = int(np.sqrt(self.latent_hw))
+
+        with auto_autocast(dtype=torch.float32):
+            for (factor, layer, head), heat_map in heat_maps:
+                if (head_idxs is None or head in head_idxs) and (layer_idx is None or layer_idx == layer):
+                    heat_map = heat_map.unsqueeze(1)/25  # [L,1,H,W]
+                    # The clamping fixes undershoot.
+                    all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0))
+
+        try:
+            maps = torch.stack(all_merges, dim=0)  # [B*head, L, 1, H, W]
+        except RuntimeError:
+            if head_idxs is not None or layer_idx is not None:
+                raise RuntimeError('No heat maps found for the given parameters.')
+            else:
+                raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?')
+
+        maps = maps.mean(0)[:, 0]  # [L,H,W]
+        #maps = maps[:len(self.tokenizer.tokenize(prompt)) + 2]  # 1 for SOS and 1 for padding
+
+        if normalize:
+            maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6)  # drop out [SOS] and [PAD] for proper probabilities
+
+        return GlobalHeatMap(self.tokenizer, prompt, maps)