hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +90 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +3 -3
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text.py +40 -0
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +207 -0
- hcpdiff/easy/cfg/sdxl_train.py +147 -0
- hcpdiff/easy/cfg/t2i.py +228 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +118 -128
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +60 -47
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.2.dist-info/METADATA +299 -0
- hcpdiff-2.2.dist-info/RECORD +115 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
- hcpdiff-2.2.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/tools/dataset_generator.py ADDED

```diff
@@ -0,0 +1,94 @@
+import argparse
+import json
+import os.path
+from typing import Callable
+
+import pyarrow.parquet as pq
+import torch
+from PIL import Image
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from tqdm.auto import tqdm
+
+from hcpdiff.data.caption_loader import auto_caption_loader
+
+class DatasetCreator:
+    def __init__(self, pretrained_model, out_dir: str, img_w: int=512, img_h: int=512):
+        scheduler = DPMSolverMultistepScheduler(
+            beta_start = 0.00085,
+            beta_end = 0.012,
+            beta_schedule = 'scaled_linear',
+            algorithm_type = 'dpmsolver++',
+            use_karras_sigmas = True,
+        )
+
+        self.pipeline = DiffusionPipeline.from_pretrained(pretrained_model, scheduler=scheduler, torch_dtype=torch.float16)
+        self.pipeline.requires_safety_checker = False
+        self.pipeline.safety_checker = None
+        self.pipeline.to("cuda")
+        self.pipeline.unet.to(memory_format=torch.channels_last)
+        #self.pipeline.enable_xformers_memory_efficient_attention()
+
+        self.out_dir = out_dir
+        self.img_w = img_w
+        self.img_h = img_h
+
+    def create_from_prompt_dataset(self, prompt_file: str, negative_prompt: str, bs: int, num: int=None, repeat:int=1, save_fmt:str='txt',
+                                   callback: Callable[[int, int], bool] = None):
+        os.makedirs(self.out_dir, exist_ok=True)
+        data = auto_caption_loader(prompt_file).load()
+        data = list(data.items())
+        data = self.split_batch(data, bs) # [[(k,v),...],...]
+
+        if num is None:
+            num = len(data)
+        total = num*bs
+        count = 0
+        captions = {}
+        with torch.inference_mode():
+            for i in tqdm(range(num)):
+                for r in range(repeat):
+                    name_batch, p_batch = list(zip(*data[i%len(data)]))
+                    imgs = self.pipeline(list(p_batch), negative_prompt=[negative_prompt]*len(p_batch), num_inference_steps=25,
+                                         width=self.img_w, height=self.img_h).images
+                    for name, prompt, img in zip(name_batch, p_batch, imgs):
+                        img.save(os.path.join(self.out_dir, f'{count}_{name}.png'), format='PNG')
+                        captions[f'{count}_{name}'] = prompt
+                        count += 1
+                    if callback:
+                        if not callback(count, total):
+                            break
+
+        if save_fmt=='txt':
+            for k, v in captions.items():
+                with open(os.path.join(self.out_dir, f'{k}.txt'), "w") as f:
+                    f.write(v)
+        elif save_fmt=='json':
+            with open(os.path.join(self.out_dir, f'image_captions.json'), "w") as f:
+                json.dump(captions, f)
+        else:
+            raise ValueError(f"Invalid save_fmt: {save_fmt}")
+
+    @staticmethod
+    def split_batch(data, bs):
+        return [data[i:i+bs] for i in range(0, len(data), bs)]
+
+# python dataset_generator.py --prompt_file <caption file or folder path> --model <model name> --out_dir <output folder path> --repeat <images per prompt> --bs <batch_size> --img_w <image width> --img_h <image height>
+# python dataset_generator.py --prompt_file <caption file or folder path> --model <model name> --out_dir <output folder path> --repeat 1 --bs 4 --img_w 640 --img_h 640
+if __name__ == '__main__':
+    torch.backends.cudnn.benchmark = True
+    parser = argparse.ArgumentParser(description='Diffusion Dataset Generator')
+    parser.add_argument('--prompt_file', type=str, default='')
+    parser.add_argument('--model', type=str, default='runwayml/stable-diffusion-v1-5')
+    parser.add_argument('--out_dir', type=str, default=r'./prompt_ds')
+    parser.add_argument('--negative_prompt', type=str,
+                        default='lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry')
+    parser.add_argument('--num', type=int, default=200)
+    parser.add_argument('--repeat', type=int, default=1)
+    parser.add_argument('--save_fmt', type=str, default='txt')
+    parser.add_argument('--bs', type=int, default=4)
+    parser.add_argument('--img_w', type=int, default=512)
+    parser.add_argument('--img_h', type=int, default=512)
+    args = parser.parse_args()
+
+    ds_creator = DatasetCreator(args.model, args.out_dir, args.img_w, args.img_h)
+    ds_creator.create_from_prompt_dataset(args.prompt_file, args.negative_prompt, args.bs, args.num, repeat=args.repeat, save_fmt=args.save_fmt)
```
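For readers who want to drive the generator from Python rather than the CLI comments above, a minimal usage sketch (paths are hypothetical; assumes hcpdiff 2.x is installed with this module layout, a CUDA device is available, and the prompt source is anything `auto_caption_loader` accepts):

```python
# Minimal sketch of calling DatasetCreator directly (hypothetical paths).
from hcpdiff.tools.dataset_generator import DatasetCreator

creator = DatasetCreator('runwayml/stable-diffusion-v1-5', out_dir='./prompt_ds',
                         img_w=640, img_h=640)
# Generate batches of 4; the callback can stop generation early by returning False.
creator.create_from_prompt_dataset('./captions', negative_prompt='lowres, blurry',
                                   bs=4, num=50, repeat=1, save_fmt='json',
                                   callback=lambda done, total: done < 200)
```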
hcpdiff/tools/download_hf_model.py ADDED

```diff
@@ -0,0 +1,24 @@
+from diffusers import DiffusionPipeline
+import argparse
+import torch
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Download Model')
+    parser.add_argument('--model', type=str, default='runwayml/stable-diffusion-v1-5')
+    parser.add_argument("--fp16", default=False, action="store_true")
+    parser.add_argument("--use_safetensors", default=False, action="store_true")
+    parser.add_argument("--out_path", type=str, default='ckpts/sd15')
+    args = parser.parse_args()
+
+    load_args = dict(torch_dtype = torch.float16 if args.fp16 else torch.float32)
+    save_args = dict()
+
+    if args.fp16:
+        load_args['variant'] = "fp16"
+        save_args['variant'] = "fp16"
+    if args.use_safetensors:
+        load_args['use_safetensors'] = True
+        save_args['safe_serialization'] = True
+
+    pipe = DiffusionPipeline.from_pretrained(args.model, **load_args)
+    pipe.save_pretrained(args.out_path, **save_args)
```
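The `--fp16` flag maps onto diffusers' `variant` mechanism: it both loads the half-precision weight files and re-saves them under the same variant suffix. Written out directly (same calls the script makes with `--fp16 --use_safetensors`):

```python
# Equivalent direct calls for --fp16 --use_safetensors (local path is hypothetical).
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    torch_dtype=torch.float16, variant='fp16', use_safetensors=True)
pipe.save_pretrained('ckpts/sd15', variant='fp16', safe_serialization=True)
```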
hcpdiff/tools/init_proj.py CHANGED

```diff
@@ -1,23 +1,5 @@
-import
-import shutil
-import os
+from rainbowneko.tools.init_proj import copy_package_data
 
 def main():
-
-
-    prefix = os.path.join(prefix, 'local')
-    try:
-        if os.path.exists(r'./cfgs'):
-            shutil.rmtree(r'./cfgs')
-        if os.path.exists(r'./prompt_tuning_template'):
-            shutil.rmtree(r'./prompt_tuning_template')
-        shutil.copytree(os.path.join(prefix, 'hcpdiff/cfgs'), r'./cfgs')
-        shutil.copytree(os.path.join(prefix, 'hcpdiff/prompt_tuning_template'), r'./prompt_tuning_template')
-    except:
-        try:
-            shutil.copytree(os.path.join(prefix, '../hcpdiff/cfgs'), r'./cfgs')
-            shutil.copytree(os.path.join(prefix, '../hcpdiff/prompt_tuning_template'), r'./prompt_tuning_template')
-        except:
-            this_file_dir = os.path.dirname(os.path.abspath(__file__))
-            shutil.copytree(os.path.join(this_file_dir, '../../cfgs'), r'./cfgs')
-            shutil.copytree(os.path.join(this_file_dir, '../../prompt_tuning_template'), r'./prompt_tuning_template')
+    copy_package_data('hcpdiff', 'cfgs', './cfgs')
+    copy_package_data('hcpdiff', 'prompt_template', './prompt_template')
```
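The 0.9.1 version guessed the installed data location through nested try/except path probing; 2.2 delegates that to rainbowneko's `copy_package_data`. A rough stand-in using only the standard library might look like the sketch below (an illustration of what the one-liner replaces, not rainbowneko's actual implementation):

```python
# Hypothetical stand-in for rainbowneko.tools.init_proj.copy_package_data.
import shutil
from importlib import resources

def copy_package_data(package: str, src: str, dst: str):
    # Resolve the data directory shipped inside the installed package,
    # then copy it into the current project.
    with resources.as_file(resources.files(package) / src) as src_path:
        shutil.copytree(src_path, dst, dirs_exist_ok=True)
```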
hcpdiff/tools/lora_convert.py CHANGED

```diff
@@ -3,15 +3,14 @@ import os.path
 from typing import List
 import math
 
-from
-from hcpdiff.deprecated import convert_to_webui_maybe_old, convert_to_webui_xl_maybe_old
+from rainbowneko.ckpt_manager import auto_ckpt_loader, NekoModelSaver
 
 class LoraConverter:
     com_name_unet = ['down_blocks', 'up_blocks', 'mid_block', 'transformer_blocks', 'to_q', 'to_k', 'to_v', 'to_out', 'proj_in', 'proj_out', 'input_blocks', 'middle_block', 'output_blocks']
     com_name_TE = ['self_attn', 'q_proj', 'v_proj', 'k_proj', 'out_proj', 'text_model']
     prefix_unet = 'lora_unet_'
     prefix_TE = 'lora_te_'
-
+    prefix_TE_xl_clip_L = 'lora_te1_'
     prefix_TE_xl_clip_bigG = 'lora_te2_'
 
     lora_w_map = {'lora_down.weight': 'W_down', 'lora_up.weight':'W_up'}
@@ -26,14 +25,14 @@ class LoraConverter:
             sd_TE = self.convert_from_webui_(state, prefix=self.prefix_TE, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
         else:
             sd_unet = self.convert_from_webui_xl_unet_(state, prefix=self.prefix_unet, com_name=self.com_name_unet, com_name_tmp=self.com_name_unet_tmp)
-            sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.
+            sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_L, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
             sd_TE2 = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_bigG, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
             sd_TE.update(sd_TE2)
 
         if auto_scale_alpha:
             sd_unet = self.alpha_scale_from_webui(sd_unet)
             sd_TE = self.alpha_scale_from_webui(sd_TE)
-        return {'
+        return {'plugin': sd_TE}, {'plugin': sd_unet}
 
     def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
         sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
@@ -59,7 +58,6 @@ class LoraConverter:
             sd_covert[f'{model_k}.___.layer.{self.lora_w_map[lora_k]}'] = v
         return sd_covert
 
-    @convert_to_webui_maybe_old
     def convert_to_webui_(self, state, prefix):
         sd_covert = {}
         for k, v in state.items():
@@ -75,7 +73,6 @@ class LoraConverter:
             sd_covert[f"{prefix}{model_k.replace('.', '_')}.{lora_k}"] = v
         return sd_covert
 
-    @convert_to_webui_xl_maybe_old
     def convert_to_webui_xl_(self, state, prefix):
         sd_convert = {}
         for k, v in state.items():
@@ -90,7 +87,7 @@ class LoraConverter:
 
             new_k = f"{prefix}{model_k.replace('.', '_')}.{lora_k}"
             if 'clip' in new_k:
-                new_k = new_k.replace('
+                new_k = new_k.replace('_clip_L', '1') if 'clip_L' in new_k else new_k.replace('_clip_bigG', '2')
             sd_convert[new_k] = v
         return sd_convert
 
@@ -103,7 +100,7 @@ class LoraConverter:
             model_k, lora_k = k[prefix_len:].split('.', 1)
             model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
             if prefix == 'lora_te1_':
-                model_k = f'
+                model_k = f'clip_L.{model_k}'
             else:
                 model_k = f'clip_bigG.{model_k}'
 
@@ -224,23 +221,27 @@ if __name__ == '__main__':
 
     # load lora model
     print('convert lora model')
-
+    ckpt_loader = auto_ckpt_loader(args.lora_path)
+    ckpt_saver = NekoModelSaver(
+        format=ckpt_loader.format,
+        source=ckpt_loader.source,
+    )
 
     if args.from_webui:
-        state =
+        state = ckpt_loader.load(args.lora_path)
         # convert the weight name
         sd_TE, sd_unet = converter.convert_from_webui(state, auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
        # weight save
         os.makedirs(args.dump_path, exist_ok=True)
         TE_path = os.path.join(args.dump_path, 'TE-'+lora_name)
         unet_path = os.path.join(args.dump_path, 'unet-'+lora_name)
-
-
+        ckpt_saver.save(sd_TE, TE_path)
+        ckpt_saver.save(sd_unet, unet_path)
         print('save text encoder lora to:', TE_path)
         print('save unet lora to:', unet_path)
     elif args.to_webui:
-        sd_unet =
-        sd_TE =
-        state = converter.convert_to_webui(sd_unet['
-
+        sd_unet = ckpt_loader.load(args.lora_path)
+        sd_TE = ckpt_loader.load(args.lora_path_TE) if args.lora_path_TE else {'base':{}}
+        state = converter.convert_to_webui(sd_unet['base'], sd_TE['base'], auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
+        ckpt_saver.save(state, args.dump_path)
     print('save lora to:', args.dump_path)
```
hcpdiff/tools/save_model.py ADDED

```diff
@@ -0,0 +1,12 @@
+from diffusers import DiffusionPipeline
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("model", default=None, type=str)
+parser.add_argument("output", default=None, type=str)
+args = parser.parse_args()
+
+pipe = DiffusionPipeline.from_pretrained(args.model, safety_checker=None, requires_safety_checker=False,
+                                         resume_download=True)
+
+pipe.save_pretrained(args.output)
```
hcpdiff/tools/sd2diffusers.py CHANGED

```diff
@@ -211,7 +211,7 @@ def sd_vae_to_diffuser(args):
 def convert_ckpt(args):
     pipe = load_sd_ckpt(
         args.checkpoint_path,
-
+        config_files={'v1': args.original_config_file},
         image_size=args.image_size,
         prediction_type=args.prediction_type,
         model_type=args.pipeline_type,
```
hcpdiff/train_colo.py CHANGED

```diff
@@ -23,7 +23,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.utils.model.colo_init_context import _convert_to_coloparam
 from colossalai.tensor import ColoParameter
 
-from hcpdiff.
+from hcpdiff.train_ac_old import Trainer, get_scheduler, ModelEMA
 from diffusers import UNet2DConditionModel
 from hcpdiff.utils.colo_utils import gemini_zero_dpp, GeminiAdamOptimizerP
 from hcpdiff.utils.utils import load_config_with_cli
```
hcpdiff/train_deepspeed.py CHANGED

```diff
@@ -7,7 +7,7 @@ from functools import partial
 import torch
 
 from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
-from hcpdiff.
+from hcpdiff.train_ac_old import Trainer, load_config_with_cli
 from hcpdiff.utils.net_utils import get_scheduler
 
 class TrainerDeepSpeed(Trainer):
```
hcpdiff/trainer_ac.py ADDED

```diff
@@ -0,0 +1,79 @@
+import argparse
+import warnings
+
+import torch
+from rainbowneko.parser import load_config_with_cli
+from rainbowneko.ckpt_manager import NekoSaver
+from rainbowneko.train import Trainer
+from rainbowneko.utils import xformers_available, is_dict
+from hcpdiff.ckpt_manager import EmbFormat
+
+class HCPTrainer(Trainer):
+    def config_model(self):
+        if self.cfgs.model.enable_xformers:
+            if xformers_available:
+                self.model_wrapper.enable_xformers()
+            else:
+                warnings.warn("xformers is not available. Make sure it is installed correctly")
+
+        if self.model_wrapper.vae is not None:
+            self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
+            self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
+
+        if self.cfgs.model.gradient_checkpointing:
+            self.model_wrapper.enable_gradient_checkpointing()
+
+    def get_param_group_train(self):
+        train_params = super().get_param_group_train()
+
+        # For prompt-tuning
+        if self.cfgs.emb_pt is None:
+            train_params_emb, self.train_pts = [], {}
+        else:
+            from hcpdiff.parser import CfgEmbPTParser
+            self.cfgs.emb_pt: CfgEmbPTParser
+
+            train_params_emb, self.train_pts = self.cfgs.emb_pt.get_params_group(self.model_wrapper)
+            self.emb_format = EmbFormat()
+        train_params += train_params_emb
+        return train_params
+
+    @property
+    def pt_trainable(self):
+        return self.cfgs.emb_pt is not None
+
+    def get_loss(self, ds_name, model_pred, inputs):
+        loss = super().get_loss(ds_name, model_pred, inputs)
+        # make DDP happy
+        if len(self.train_pts)>0:
+            loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
+        return loss
+
+    def save_model(self, from_raw=False):
+        NekoSaver.save_all(
+            self.model_raw,
+            plugin_groups={**self.all_plugin, 'embs': self.train_pts},
+            cfg=self.ckpt_saver,
+            model_ema=getattr(self, "ema_model", None),
+            name_template=f'{{}}-{self.real_step}',
+        )
+
+        self.loggers.info(f"Saved state, step: {self.real_step}")
+
+def hcp_train():
+    import subprocess
+    parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
+    parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/multi.yaml')
+    args, train_args = parser.parse_known_args()
+
+    subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
+                    "hcpdiff.trainer_ac"] + train_args, check=True)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="HCP Diffusion Trainer")
+    parser.add_argument("--cfg", type=str, default=None, required=True)
+    args, cfg_args = parser.parse_known_args()
+
+    parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
+    trainer = HCPTrainer(parser, conf)
+    trainer.train()
```
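The "make DDP happy" line in `get_loss` is worth a note: DistributedDataParallel raises an error when a parameter with `requires_grad=True` never contributes to the loss, so the zero-weighted sum wires unused prompt-tuning embeddings into the graph without changing the loss value. A self-contained illustration of the idea:

```python
# Why the 0 * sum(...) term exists: it keeps the loss value identical while
# still registering the embedding in the autograd graph for DDP.
import torch

emb = torch.nn.Parameter(torch.randn(4, 768))   # a prompt-tuning embedding
loss = torch.tensor(1.5, requires_grad=True)

loss = loss + 0 * emb.mean()    # value unchanged, but emb is now in the graph
loss.backward()
print(emb.grad.abs().sum())     # tensor(0.) -- gradients exist, all zeros
```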
hcpdiff/trainer_ac_single.py ADDED

```diff
@@ -0,0 +1,31 @@
+import argparse
+import sys
+from functools import partial
+
+import torch
+from accelerate import Accelerator
+from loguru import logger
+
+from rainbowneko.train.trainer import TrainerSingleCard
+from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
+
+class HCPTrainerSingleCard(TrainerSingleCard, HCPTrainer):
+    pass
+
+def hcp_train():
+    import subprocess
+    parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
+    parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/single.yaml')
+    args, train_args = parser.parse_known_args()
+
+    subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
+                    "hcpdiff.trainer_ac_single"] + train_args, check=True)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='HCP Diffusion Trainer')
+    parser.add_argument("--cfg", type=str, default=None, required=True)
+    args, cfg_args = parser.parse_known_args()
+
+    parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
+    trainer = HCPTrainerSingleCard(parser, conf)
+    trainer.train()
```
hcpdiff/utils/__init__.py CHANGED
hcpdiff/utils/inpaint_pipe.py CHANGED

```diff
@@ -21,18 +21,23 @@ import torch
 from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
+from diffusers import StableDiffusionInpaintPipelineLegacy
 from diffusers.configuration_utils import FrozenDict
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import PIL_INTERPOLATION,
+from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
+try:
+    from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
+except:
+    USE_PEFT_BACKEND = False
 
 logger = logging.get_logger(__name__)
 
```
hcpdiff/utils/net_utils.py CHANGED

```diff
@@ -6,11 +6,19 @@ import torch
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
 from torch import nn
 from torch.optim import lr_scheduler
-from transformers import PretrainedConfig, AutoTokenizer
+from transformers import PretrainedConfig, AutoTokenizer, T5EncoderModel, CLIPTextModel
 from functools import partial
+from huggingface_hub import hf_hub_download
+import json
 
 dtype_dict = {'fp32':torch.float32, 'amp':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
 
+try:
+    dtype_dict['fp8_e4m3'] = torch.float8_e4m3fn
+    dtype_dict['fp8_e5m2'] = torch.float8_e5m2
+except:
+    pass
+
 def get_scheduler(cfg, optimizer):
     if cfg is None:
         return None
@@ -90,7 +98,7 @@ def auto_tokenizer_cls(pretrained_model_name_or_path: str, revision: str = None)
             revision=revision, use_fast=False,
         )
         return SDXLTokenizer
-    except
+    except:
         # not sdxl, only one tokenizer
         return AutoTokenizer
 
@@ -102,8 +110,10 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = No
             subfolder="text_encoder_2",
             revision=revision,
         )
+        if text_encoder_config.architectures is None:
+            raise ValueError()
         return SDXLTextEncoder
-    except
+    except:
         text_encoder_config = PretrainedConfig.from_pretrained(
             pretrained_model_name_or_path,
             subfolder="text_encoder",
@@ -112,16 +122,26 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = No
     model_class = text_encoder_config.architectures[0]
 
     if model_class == "CLIPTextModel":
-        from transformers import CLIPTextModel
-
         return CLIPTextModel
     elif model_class == "RobertaSeriesModelWithTransformation":
         from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
 
         return RobertaSeriesModelWithTransformation
+    elif model_class == "T5EncoderModel":
+        return T5EncoderModel
     else:
         raise ValueError(f"{model_class} is not supported.")
 
+def get_pipe_name(path: str):
+    if os.path.isdir(path):
+        json_file = os.path.join(path, "model_index.json")
+    else:
+        json_file = hf_hub_download(path, "model_index.json")
+    with open(json_file, "r", encoding="utf-8") as reader:
+        text = reader.read()
+    data = json.loads(text)
+    return data['_class_name']
+
 def auto_tokenizer(pretrained_model_name_or_path: str, revision: str = None, **kwargs):
     return auto_tokenizer_cls(pretrained_model_name_or_path, revision).from_pretrained(pretrained_model_name_or_path, revision=revision, **kwargs)
 
@@ -225,4 +245,7 @@ def split_module_name(layer_name):
     return parent_name, host_name
 
 def get_dtype(dtype):
-
+    if isinstance(dtype, torch.dtype):
+        return dtype
+    else:
+        return dtype_dict.get(dtype, torch.float32)
```
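The reworked `get_dtype` accepts either a `torch.dtype` or one of the string keys in `dtype_dict`, and the fp8 entries are registered only when the running torch build exposes them. Expected behaviour, as a short sketch:

```python
# get_dtype after this change (fp8 keys exist only on new enough torch builds).
import torch
from hcpdiff.utils.net_utils import get_dtype

assert get_dtype('bf16') is torch.bfloat16
assert get_dtype(torch.float16) is torch.float16   # passthrough for real dtypes
assert get_dtype('no-such-key') is torch.float32   # unknown strings fall back to fp32
```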
hcpdiff/utils/pipe_hook.py CHANGED

```diff
@@ -2,9 +2,9 @@ from typing import Union, List, Optional, Callable, Dict, Any
 
 import PIL
 import torch
-from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, PixArtTransformer2DModel
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from .inpaint_pipe import preprocess_mask, preprocess_image, StableDiffusionInpaintPipelineLegacy
 from einops import repeat
 
@@ -122,12 +122,20 @@ class HookPipe_T2I(StableDiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 if pooled_output is None:
-
-
+                    if isinstance(self.unet, PixArtTransformer2DModel):
+                        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+                        noise_pred = self.unet(latent_model_input, timestep=t.repeat(latent_model_input.shape[0]), encoder_hidden_states=prompt_embeds[i],
+                                               encoder_attention_mask=encoder_attention_mask,
+                                               cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                    else:
+                        noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
+                                               encoder_attention_mask=encoder_attention_mask,
+                                               cross_attention_kwargs=cross_attention_kwargs).sample
                 else:
                     added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                     # predict the noise residual
-                    noise_pred = self.unet(latent_model_input, t, prompt_embeds[i],
+                    noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
+                                           encoder_attention_mask=encoder_attention_mask,
                                            cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
 
                 # perform guidance
@@ -135,6 +143,10 @@ class HookPipe_T2I(StableDiffusionPipeline):
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)
 
+                # learned sigma
+                if self.unet.config.out_channels // 2 == num_channels_latents:
+                    noise_pred = noise_pred.chunk(2, dim=1)[0]
+
                 # x_t -> x_0
                 alpha_prod_t = alphas_cumprod[t.long()]
                 beta_prod_t = 1-alpha_prod_t
@@ -271,8 +283,13 @@ class HookPipe_I2I(StableDiffusionImg2ImgPipeline):
 
                 # predict the noise residual
                 if pooled_output is None:
-
-
+                    if isinstance(self.unet, PixArtTransformer2DModel):
+                        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+                        noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                               cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                    else:
+                        noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                               cross_attention_kwargs=cross_attention_kwargs, ).sample
                 else:
                     added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                     # predict the noise residual
```
hcpdiff/utils/utils.py CHANGED

```diff
@@ -56,8 +56,8 @@ def remove_config_undefined(cfg):
 def load_config(path, remove_undefined=True):
     cfg = OmegaConf.load(path)
     if '_base_' in cfg:
-        for base in cfg['_base_']
-
+        base_cfgs = [load_config(base, remove_undefined=False) for base in cfg['_base_']]
+        cfg = OmegaConf.merge(*base_cfgs, cfg)
         del cfg['_base_']
     if remove_undefined:
         cfg = remove_config_undefined(cfg)
@@ -85,7 +85,7 @@ def get_cfg_range(cfg_text:str):
 def to_validate_file(name):
     rstr = r"[\/\\\:\*\?\"\<\>\|]"  # '/ \ : * ? " < > |'
     new_title = re.sub(rstr, "_", name)  # replace with underscores
-    return new_title[:
+    return new_title[:200]
 
 def make_mask(start, end, length):
     mask=torch.zeros(length)
@@ -159,4 +159,21 @@ def pad_attn_bias(x, attn_bias, block_size=8):
     # pad along the k dimension
     x_padded = F.pad(x, (0, 0, 0, padding_l, 0, 0), mode='constant', value=0)
     attn_bias_padded = F.pad(attn_bias, (0, padding_l, 0, 0), mode='constant', value=0)
-    return x_padded, attn_bias_padded
+    return x_padded, attn_bias_padded
+
+def linear_interp(t, x):
+    '''
+    t_l ---------t_h
+        ^x
+    '''
+    if (x>=len(t)).any():
+        x = x.clamp(max=len(t)-1e-6)
+    x0 = x.floor().long()
+    x1 = x0 + 1
+
+    y0 = t[x0]
+    y1 = t[x1]
+
+    xd = (x - x0.float())
+
+    return y0 * (1 - xd) + y1 * xd
```
hcpdiff/workflow/__init__.py CHANGED

```diff
@@ -1,15 +1,20 @@
-from .
-
-    X0PredAction, SeedAction, MakeTimestepsAction
+from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
+    X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter
 from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
 from .vae import EncodeAction, DecodeAction
-from .io import
-from .utils import LatentResizeAction, ImageResizeAction
-from .model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction
+from .io import BuildModelsAction, SaveImageAction, LoadImageAction
+from .utils import LatentResizeAction, ImageResizeAction, FeedtoCNetAction
+from .model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction
+#from .flow import FilePromptAction
+
+try:
+    from .fast import SFastCompileAction
+except:
+    print('stable fast not installed.')
 
 from omegaconf import OmegaConf
 
-OmegaConf.register_new_resolver("hcp.from_memory", lambda mem_name:
-    '_target_':
-    'mem_name':
-}))
+OmegaConf.register_new_resolver("hcp.from_memory", lambda mem_name:OmegaConf.create({
+    '_target_':'hcpdiff.workflow.from_memory',
+    'mem_name':mem_name,
+}))
```