hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,80 @@
|
|
1
|
+
from typing import Union, Dict, Any
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import torch
|
5
|
+
import torchvision.transforms as T
|
6
|
+
from PIL import Image
|
7
|
+
from rainbowneko.data import DataHandler, HandlerChain, LoadImageHandler, ImageHandler
|
8
|
+
|
9
|
+
from .text import TemplateFillHandler, TagDropoutHandler, TagEraseHandler, TagShuffleHandler, TokenizeHandler
|
10
|
+
|
11
|
+
class LossMapHandler(DataHandler):
    """Load a grayscale loss-weight map and bring it to latent resolution.

    The map is loaded in ``'L'`` mode and passed through the bucket handler, so it is
    cropped/resized like its image. It is then downscaled by ``vae_scale`` and converted
    to a tensor with ``T.ToTensor()``, giving values in ``[0, 1]``.

    Values are remapped piecewise:

    * ``v <= 0.5`` becomes ``2*v``, in ``[0, 1]``.
    * ``v > 0.5`` becomes ``(v-0.5)*4+1``, in ``(1, 3]``.

    Args:
        bucket: bucket object whose ``.handler`` performs the crop/resize.
        vae_scale: integer downscale factor applied to the map's width and height.
        key_map_in / key_map_out: key mappings forwarded to ``DataHandler``.
    """

    def __init__(self, bucket, vae_scale=8, key_map_in=('loss_map -> image', 'image_size -> image_size'),
                 key_map_out=('image -> loss_map', 'coord -> coord')):
        super().__init__(key_map_in, key_map_out)
        self.vae_scale = vae_scale

        self.handlers = HandlerChain(
            load=LoadImageHandler(mode='L'),
            bucket=bucket.handler,
            image=ImageHandler(transform=T.Compose([
                lambda x: x.resize((x.size[0]//self.vae_scale, x.size[1]//self.vae_scale), Image.BILINEAR),
                T.ToTensor()
            ]), )
        )

    def handle(self, image: Union[Image.Image, str], image_size: np.ndarray[int]):
        data = self.handlers(dict(image=image, image_size=image_size))
        loss_map = data['image']

        # Decide both branches from the ORIGINAL values. Doing the "<= 0.5" remap in place
        # first would push values in (0.25, 0.5] above 0.5, and they would be remapped a
        # second time by the "> 0.5" branch.
        high = loss_map > 0.5
        loss_map = torch.where(high, (loss_map-0.5)*4+1, loss_map*2)

        # The handler chain has already run. Re-running it would try to load/crop a tensor,
        # and ``dict(**data, image=...)`` raises TypeError because ``data`` already holds
        # 'image'. So just replace the processed map in a copy of the chain output.
        return {**data, 'image': loss_map}
|
32
|
+
|
33
|
+
class DiffusionImageHandler(DataHandler):
    """Prepare a training image for the diffusion model.

    Tensor inputs are treated as cached latents and returned as-is, together with
    ``image_size``. Any other input is processed as follows:

    1. Loaded.
    2. Cropped/resized by the bucket handler.
    3. Converted with ``ToTensor`` followed by ``Normalize([0.5], [0.5])``.
    """

    def __init__(self, bucket, key_map_in=('image -> image', 'image_size -> image_size'), key_map_out=('image -> image', 'coord -> coord')):
        super().__init__(key_map_in, key_map_out)

        to_model_range = T.Compose([
            T.ToTensor(),
            T.Normalize([0.5], [0.5])
        ])
        self.handlers = HandlerChain(
            load=LoadImageHandler(),
            bucket=bucket.handler,
            image=ImageHandler(transform=to_model_range),
        )

    def handle(self, image: Image.Image, image_size: np.ndarray[int]):
        payload = dict(image=image, image_size=image_size)
        # Cached latents arrive as tensors and must not be re-loaded or re-normalized.
        if not isinstance(image, torch.Tensor):
            return self.handlers(payload)
        return payload
|
51
|
+
|
52
|
+
class StableDiffusionHandler(DataHandler):
    """Combined image + prompt handler for Stable-Diffusion-style training.

    Images go through ``DiffusionImageHandler``. Prompts go through the following chain,
    in this order:

    1. Tag dropout (if enabled).
    2. Caption erase (if enabled).
    3. Tag shuffle (if enabled).
    4. Template fill.
    5. Tokenization (if enabled).

    Args:
        bucket: bucket object whose ``.handler`` crops/resizes images.
        encoder_attention_mask: forwarded to ``TokenizeHandler``. When true, the
            tokenizer's attention mask (if it returns one) is emitted as ``attn_mask``.
        erase: probability of blanking the whole caption. Disabled when ``<= 0``.
        dropout: per-tag drop probability. Disabled when ``<= 0``.
        shuffle: any value ``> 0`` enables tag shuffling on every sample. The value is
            not used as a probability.
        word_names: placeholder -> word mapping used to fill prompt templates. Defaults
            to a fresh empty dict per instance.
        tokenize: whether to append a ``TokenizeHandler`` to the prompt chain.
    """

    def __init__(self, bucket, encoder_attention_mask=False, key_map_in=('image -> image', 'image_size -> image_size', 'prompt -> prompt'),
                 key_map_out=('image -> image', 'coord -> coord', 'prompt -> prompt'),
                 erase=0.15, dropout=0.0, shuffle=0.0, word_names=None, tokenize=True):
        super().__init__(key_map_in, key_map_out)

        self.image_handlers = DiffusionImageHandler(bucket)

        # A shared ``{}`` default would be stored by TemplateFillHandler and leak any later
        # mutation across every handler instance, so build a fresh dict per instance.
        if word_names is None:
            word_names = {}

        # Insertion order defines execution order in the HandlerChain. Tag augmentations
        # act on the raw caption before the template is filled; tokenization comes last.
        text_handlers = {}
        if dropout > 0:
            text_handlers['dropout'] = TagDropoutHandler(p=dropout)
        if erase > 0:
            text_handlers['erase'] = TagEraseHandler(p=erase)
        if shuffle > 0:
            text_handlers['shuffle'] = TagShuffleHandler()
        text_handlers['fill'] = TemplateFillHandler(word_names)
        if tokenize:
            text_handlers['tokenize'] = TokenizeHandler(encoder_attention_mask)
        self.text_handlers = HandlerChain(**text_handlers)

    def handle(self, image: Image.Image, image_size: np.ndarray[int], prompt: str):
        return dict(**self.image_handlers(dict(image=image, image_size=image_size)), **self.text_handlers(dict(prompt=prompt)))

    def __call__(self, data) -> Dict[str, Any]:
        """Run ``handle`` on the mapped input keys and merge the mapped outputs.

        Outputs are written into a shallow copy of ``data``. Input keys that ``handle``
        does not produce are preserved.
        """
        data_proc = self.handle(**self.key_mapper_in.map_data(data)[1])
        out_data = self.key_mapper_out.map_data(data_proc)[1]
        data = dict(**data)
        data.update(out_data)
        return data
|
@@ -0,0 +1,111 @@
|
|
1
|
+
import random
|
2
|
+
from typing import Dict, Union, List
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from string import Formatter
|
6
|
+
from rainbowneko.data import DataHandler
|
7
|
+
from rainbowneko._share import register_model_callback
|
8
|
+
|
9
|
+
class TagShuffleHandler(DataHandler):
    """Randomly reorder the comma-separated tags of a prompt.

    ``prompt`` may take two forms:

    * A plain string, which is shuffled directly.
    * A dict whose ``'caption'`` entry holds the tag string. The dict is updated in
      place; a missing or ``None`` caption (unlabeled image) is left untouched.
    """

    def __init__(self, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
        super().__init__(key_map_in, key_map_out)

    @staticmethod
    def _shuffle_tags(text: str) -> str:
        # Split on ',' and rejoin with ',' so the surrounding whitespace of each tag is kept.
        tags = text.split(',')
        random.shuffle(tags)
        return ','.join(tags)

    def handle(self, prompt: Union[Dict[str, str], str]):
        if isinstance(prompt, str):
            prompt = self._shuffle_tags(prompt)
        elif prompt.get('caption') is not None:
            # Sources emit caption=None for images without a label; there is nothing to
            # shuffle, and calling .split on None would raise AttributeError.
            prompt['caption'] = self._shuffle_tags(prompt['caption'])
        return {'prompt': prompt}

    def __repr__(self):
        return 'TagShuffleHandler()'
|
26
|
+
|
27
|
+
class TagDropoutHandler(DataHandler):
    """Drop each comma-separated tag independently with probability ``p``.

    ``prompt`` may take two forms:

    * A plain string, from which tags are dropped directly.
    * A dict whose ``'caption'`` entry holds the tag string. The dict is updated in
      place; a missing or ``None`` caption (unlabeled image) is left untouched.
    """

    def __init__(self, p=0.1, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
        super().__init__(key_map_in, key_map_out)
        self.p = p

    def _drop_tags(self, text: str) -> str:
        # Boolean-mask indexing needs an ndarray: indexing a plain list with a numpy bool
        # array raises TypeError (the dict branch previously did exactly that).
        tags = np.array(text.split(','))
        keep = np.random.random(len(tags)) > self.p
        return ','.join(tags[keep])

    def handle(self, prompt: Union[Dict[str, str], str]):
        if isinstance(prompt, str):
            prompt = self._drop_tags(prompt)
        elif prompt.get('caption') is not None:
            prompt['caption'] = self._drop_tags(prompt['caption'])
        return {'prompt': prompt}

    def __repr__(self):
        return f'TagDropoutHandler(p={self.p})'
|
43
|
+
|
44
|
+
class TagEraseHandler(DataHandler):
    """With probability ``p``, blank out the whole caption.

    A string prompt is replaced by ``''``. For a dict prompt, ``prompt['caption']`` is
    set to ``''`` in place.
    """

    def __init__(self, p=0.1, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
        super().__init__(key_map_in, key_map_out)
        self.p = p

    def handle(self, prompt):
        erase = random.random() < self.p
        if isinstance(prompt, str):
            return {'prompt': '' if erase else prompt}
        if erase:
            prompt['caption'] = ''
        return {'prompt': prompt}

    def __repr__(self):
        return f'TagEraseHandler(p={self.p})'
|
60
|
+
|
61
|
+
|
62
|
+
class TemplateFillHandler(DataHandler):
    """Render a ``str.format`` prompt template from ``word_names`` and the caption.

    ``prompt`` is a dict with ``'template'`` and ``'caption'`` entries. Placeholders are
    filled as follows:

    * Each placeholder found in ``word_names`` takes that value.
    * ``{caption}`` takes ``word_names['caption']`` if truthy, otherwise the sample's
      caption.
    * Any placeholder without a value, or whose value is falsy, renders as ``''``.
    """

    def __init__(self, word_names: Dict[str, str], key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
        super().__init__(key_map_in, key_map_out)
        self.word_names = word_names

    def handle(self, prompt):
        template = prompt['template']
        caption = prompt['caption']

        fields = {name for _, name, _, _ in Formatter().parse(template) if name is not None}

        # Start with every placeholder empty, then overlay the known words.
        values = dict.fromkeys(fields, '')
        for name, word in self.word_names.items():
            if name in fields:
                values[name] = word

        if caption is not None and 'caption' in fields:
            values['caption'] = values['caption'] or caption

        rendered = template.format(**{name: (val or '') for name, val in values.items()})
        return {'prompt': rendered}

    def __repr__(self):
        return f'TemplateFill(\nword_names={self.word_names}\n)'
|
87
|
+
|
88
|
+
class TokenizeHandler(DataHandler):
    """Tokenize the (already template-filled) prompt with the model's tokenizer.

    The tokenizer is not known at construction time. It is assigned by
    ``acquire_tokenizer``, which is registered via ``register_model_callback``
    (presumably invoked once the model wrapper is built; confirm in rainbowneko).

    Output keys:

    * ``prompt``: padded/truncated token ids.
    * ``attn_mask``: emitted only when ``encoder_attention_mask`` is set and the tokenizer
      returns an attention mask.
    * ``position_ids``: emitted only when the tokenizer returns them.
    """

    def __init__(self, encoder_attention_mask=False, key_map_in=('prompt -> prompt',), key_map_out=None):
        super().__init__(key_map_in, key_map_out)
        self.encoder_attention_mask = encoder_attention_mask
        # Initialise before registering the callback. This lets __repr__ work before a
        # model is attached, and does not clobber a tokenizer the callback may set
        # immediately during registration.
        self.tokenizer = None

        register_model_callback(self.acquire_tokenizer)

    def acquire_tokenizer(self, model_wrapper):
        self.tokenizer = model_wrapper.tokenizer

    def handle(self, prompt):
        # The padding length is model_max_length * N_repeats. N_repeats is expected on
        # the tokenizer; presumably the project's extended tokenizer sets it (verify).
        token_info = self.tokenizer(prompt, truncation=True, padding="max_length", return_tensors="pt",
                                    max_length=self.tokenizer.model_max_length*self.tokenizer.N_repeats)
        tokens = token_info.input_ids.squeeze()
        data = {'prompt': tokens}
        if self.encoder_attention_mask and 'attention_mask' in token_info:
            data['attn_mask'] = token_info.attention_mask.squeeze()
        if 'position_ids' in token_info:
            data['position_ids'] = token_info.position_ids.squeeze()

        return data

    def __repr__(self):
        return f'TokenizeHandler(\nencoder_attention_mask={self.encoder_attention_mask}, tokenizer={self.tokenizer}\n)'
|
hcpdiff/data/source/__init__.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
from .
|
2
|
-
from .text2img import Text2ImageSource, Text2ImageAttMapSource
|
1
|
+
from .text2img import Text2ImageSource, Text2ImageLossMapSource
|
3
2
|
from .text2img_cond import Text2ImageCondSource
|
4
3
|
from .folder_class import T2IFolderClassSource
|
@@ -1,40 +1,23 @@
|
|
1
|
-
import os
|
2
|
-
from typing import List, Tuple, Union
|
3
|
-
from hcpdiff.utils.utils import get_file_name, get_file_ext
|
4
|
-
from hcpdiff.utils.img_size_tool import types_support
|
5
|
-
from .text2img import Text2ImageAttMapSource
|
6
|
-
from hcpdiff.data.caption_loader import BaseCaptionLoader, auto_caption_loader
|
7
1
|
from copy import copy
|
2
|
+
from typing import Union
|
8
3
|
|
9
|
-
|
4
|
+
from rainbowneko.data.label_loader import BaseLabelLoader, auto_label_loader
|
10
5
|
|
11
|
-
|
12
|
-
sub_folders = [os.path.join(self.img_root, x) for x in os.listdir(self.img_root)]
|
13
|
-
class_imgs = []
|
14
|
-
for class_folder in sub_folders:
|
15
|
-
class_name = os.path.basename(class_folder)
|
16
|
-
imgs = [(os.path.join(class_folder, x), self) for x in os.listdir(class_folder) if get_file_ext(x) in types_support]
|
17
|
-
class_imgs.extend(imgs*self.repeat[class_name])
|
18
|
-
return class_imgs
|
6
|
+
from .text2img import Text2ImageLossMapSource
|
19
7
|
|
20
|
-
|
21
|
-
|
8
|
+
class T2IFolderClassSource(Text2ImageLossMapSource):
    """Text-to-image source whose labels live in one sub-folder per class.

    When ``label_file`` is a path, each sub-folder of that path is loaded with a copy of
    the auto-detected label loader. The keys are prefixed with the folder name, so that
    identically named images in different classes do not collide.
    """

    def _load_label_data(self, label_file: Union[str, BaseLabelLoader]):
        ''' {class_name/image.ext: label} '''
        if label_file is None:
            return {}
        elif isinstance(label_file, str):
            captions = {}
            caption_loader = auto_label_loader(label_file)
            for class_folder in caption_loader.path.iterdir():
                # Only sub-directories are classes. Stray files next to them (e.g. a README)
                # would otherwise be handed to the loader as if they were a folder.
                if not class_folder.is_dir():
                    continue
                caption_loader_class = copy(caption_loader)
                caption_loader_class.path = class_folder
                # ``.items()`` -- the previous ``.item()`` raised AttributeError on every dict.
                captions_class = {f'{class_folder.name}/{name}': caption
                                  for name, caption in caption_loader_class.load().items()}
                captions.update(captions_class)
            return captions
        else:
            return label_file.load()
|
hcpdiff/data/source/text2img.py
CHANGED
@@ -1,13 +1,11 @@
|
|
1
|
-
from .base import DataSource
|
2
|
-
from hcpdiff.data.caption_loader import BaseCaptionLoader, auto_caption_loader
|
3
|
-
from typing import Union, Any
|
4
1
|
import os
|
5
|
-
from hcpdiff.utils.utils import get_file_name, get_file_ext
|
6
|
-
from hcpdiff.utils.img_size_tool import types_support
|
7
|
-
from typing import Dict, List, Tuple
|
8
|
-
from PIL import Image
|
9
|
-
import numpy as np
|
10
2
|
import random
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Any
|
5
|
+
from typing import Dict
|
6
|
+
|
7
|
+
from rainbowneko.data import ImageLabelSource
|
8
|
+
from rainbowneko.utils.utils import is_image_file
|
11
9
|
from torchvision.transforms import transforms
|
12
10
|
|
13
11
|
default_image_transforms = transforms.Compose([
|
@@ -15,77 +13,41 @@ default_image_transforms = transforms.Compose([
|
|
15
13
|
transforms.Normalize([0.5], [0.5])
|
16
14
|
])
|
17
15
|
|
18
|
-
class Text2ImageSource(
|
19
|
-
def __init__(self, img_root,
|
20
|
-
|
21
|
-
super(Text2ImageSource, self).__init__(img_root, repeat=repeat)
|
16
|
+
class Text2ImageSource(ImageLabelSource):
    """Image/label source that also picks a random prompt template for each sample.

    Items have the form ``{'id', 'image', 'prompt': {'template', 'caption'}}``:

    * ``image`` is the joined path ``img_root/img_name``; loading happens in the data
      handlers.
    * ``caption`` is ``None`` when the image has no label entry.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, img_root, label_file, prompt_template, repeat=1, **kwargs):
        super().__init__(img_root, label_file, repeat=repeat)
        self.prompt_template = self.load_template(prompt_template)

    def load_template(self, template_file):
        """Read ``template_file`` (UTF-8) and return its stripped lines as templates."""
        with open(template_file, 'r', encoding='utf-8') as fp:
            content = fp.read()
        return content.strip().split('\n')

    def __getitem__(self, index) -> Dict[str, Any]:
        img_name = self.img_ids[index]
        image_path = os.path.join(self.img_root, img_name)
        prompt = dict(
            template=random.choice(self.prompt_template),
            caption=self.label_dict.get(img_name, None),
        )
        return {'id': img_name, 'image': image_path, 'prompt': prompt}
|
38
|
+
|
39
|
+
class Text2ImageLossMapSource(Text2ImageSource):
    """``Text2ImageSource`` that can also attach a per-image loss-map path.

    Loss maps are image files in a single directory ``loss_map``. They are matched to
    training images by file stem, ignoring any sub-folder of the image id.

    When ``loss_map`` is ``None``, no ``'loss_map'`` key is added to the items. When it
    is set, every image must have a map, and a missing one raises KeyError.
    """

    def __init__(self, img_root, caption_file, prompt_template, loss_map=None, repeat=1, **kwargs):
        super().__init__(img_root, caption_file, prompt_template, repeat=repeat)

        if loss_map is None:
            self.loss_map = {}
        else:
            loss_map = Path(loss_map)
            # iterdir() already yields paths that include ``loss_map``. Joining them onto
            # ``loss_map`` again (``loss_map/file``) doubled the prefix for relative dirs.
            self.loss_map = {file.stem: file for file in loss_map.iterdir() if is_image_file(file)}

    def __getitem__(self, index) -> Dict[str, Any]:
        data = super().__getitem__(index)
        # Without configured loss maps, the lookup below raised KeyError on every item,
        # which broke this class and its subclasses in their default configuration.
        if self.loss_map:
            img_name = self.img_ids[index]
            data['loss_map'] = self.loss_map[Path(img_name).stem]
        return data
|
@@ -1,22 +1,16 @@
|
|
1
1
|
import os
|
2
2
|
from typing import Dict, Any
|
3
3
|
|
4
|
-
from
|
5
|
-
from torchvision import transforms
|
4
|
+
from .text2img import Text2ImageSource
|
6
5
|
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
def __init__(self, img_root, caption_file, prompt_template, text_transforms, image_transforms=default_image_transforms,
|
11
|
-
bg_color=(255, 255, 255), repeat=1, cond_dir=None, **kwargs):
|
12
|
-
super().__init__(img_root, caption_file, prompt_template, image_transforms=image_transforms, text_transforms=text_transforms,
|
13
|
-
bg_color=bg_color, repeat=repeat)
|
14
|
-
self.cond_transform = transforms.ToTensor()
|
6
|
+
class Text2ImageCondSource(Text2ImageSource):
    """``Text2ImageSource`` that also yields the path of a per-image condition file.

    The condition path is ``cond_dir/<img_name>``, i.e. same relative name as the
    training image (presumably a control/hint image; confirm with the consuming handler).
    """

    def __init__(self, img_root, caption_file, prompt_template, repeat=1, cond_dir=None, **kwargs):
        super().__init__(img_root, caption_file, prompt_template, repeat=repeat)
        self.cond_dir = cond_dir

    def __getitem__(self, index) -> Dict[str, Any]:
        sample = super().__getitem__(index)
        sample['cond'] = os.path.join(self.cond_dir, self.img_ids[index])
        return sample
|
File without changes
|
@@ -0,0 +1,42 @@
|
|
1
|
+
import random
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch.nn import functional as F
|
5
|
+
|
6
|
+
from hcpdiff.diffusion.sampler import BaseSampler
|
7
|
+
|
8
|
+
class PyramidNoiseSampler:
    """Multi-scale ("pyramid") noise generator that can be patched onto a sampler.

    Noise is built as follows:

    1. Start from full-resolution Gaussian noise.
    2. For each level ``i``, draw Gaussian noise at a resolution reduced by ``r**i``,
       where ``r`` is uniform in ``[step_size, step_size+2)``.
    3. Upsample it back with ``resize_mode`` and add it with weight ``discount**i``.
    4. Stop early once a level collapses to width or height 1.
    5. Divide the result by its standard deviation (computed over the whole tensor).

    Args:
        level: number of pyramid levels, including the full-resolution one.
        discount: geometric weight of each coarser level.
        step_size: minimum per-level downscale ratio.
        resize_mode: ``torch.nn.functional.interpolate`` mode used for upsampling.
    """

    def __init__(self, level: int = 6, discount: float = 0.4, step_size: float = 2., resize_mode: str = 'bilinear'):
        self.level = level
        self.step_size = step_size
        self.resize_mode = resize_mode
        self.discount = discount

    def make_nosie(self, shape, device='cuda', dtype=torch.float32):
        """Return pyramid noise of 4-D ``shape`` ``(b, c, h, w)`` on ``device``/``dtype``.

        The (misspelled) name matches ``BaseSampler.make_nosie`` so ``patch`` can swap
        it in directly.
        """
        noise = torch.randn(shape, device=device, dtype=dtype)
        with torch.no_grad():
            b, c, h, w = noise.shape
            for i in range(1, self.level):
                r = random.random()*2+self.step_size
                wn, hn = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
                # Sample the coarse level directly on the target device/dtype, rather than
                # drawing on the CPU and copying it over for every level.
                coarse = torch.randn(b, c, hn, wn, device=noise.device, dtype=noise.dtype)
                noise += F.interpolate(coarse, (h, w), None, self.resize_mode)*(self.discount**i)
                if wn == 1 or hn == 1:
                    break
            noise = noise/noise.std()
        return noise

    @classmethod
    def patch(cls, base_sampler: 'BaseSampler', level: int = 6, discount: float = 0.4, step_size: float = 2., resize_mode: str = 'bilinear'):
        """Replace ``base_sampler.make_nosie`` with pyramid noise.

        The sampler is modified in place and returned.
        """
        patcher = cls(level, discount, step_size, resize_mode)
        base_sampler.make_nosie = patcher.make_nosie
        return base_sampler
|
33
|
+
|
34
|
+
if __name__ == '__main__':
    # Visual smoke test: patch a sampler with pyramid noise and display one sample.
    from hcpdiff.diffusion.sampler import EDM_DDPMSampler, DDPMContinuousSigmaScheduler
    from matplotlib import pyplot as plt

    demo_sampler = PyramidNoiseSampler.patch(EDM_DDPMSampler(DDPMContinuousSigmaScheduler()))
    sample = demo_sampler.make_nosie((1, 3, 512, 512), device='cpu')
    plt.figure()
    channels_last = sample[0].permute(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()
|
@@ -0,0 +1,39 @@
|
|
1
|
+
import torch
|
2
|
+
from ..sampler.sigma_scheduler import DDPMDiscreteSigmaScheduler
|
3
|
+
|
4
|
+
class ZeroTerminalSampler:
    """Rescale a discrete DDPM schedule to (near) zero terminal SNR.

    Based on Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf.
    """

    @classmethod
    def patch(cls, base_sampler):
        """Rescale ``base_sampler.sigma_scheduler`` in place and return ``base_sampler``.

        Both ``alphas_cumprod`` and ``sigmas`` are replaced. ``sigmas`` is derived from
        the rescaled ``alphas_cumprod`` as ``sqrt((1-a)/a)``.
        """
        assert isinstance(base_sampler.sigma_scheduler, DDPMDiscreteSigmaScheduler), "ZeroTerminalScheduler only works with DDPM SigmaScheduler"

        sigma_scheduler = base_sampler.sigma_scheduler
        alphas_cumprod = cls.rescale_zero_terminal_snr(sigma_scheduler.alphas_cumprod)
        sigma_scheduler.alphas_cumprod = alphas_cumprod
        # Derive sigmas from the RESCALED schedule. Previously they were computed from the
        # original alphas_cumprod, so the patch left sigmas unchanged and inconsistent.
        sigma_scheduler.sigmas = ((1-alphas_cumprod)/alphas_cumprod).sqrt()
        # Return the sampler, consistent with PyramidNoiseSampler.patch.
        return base_sampler

    @staticmethod
    def rescale_zero_terminal_snr(alphas_cumprod, thr=1e-4):
        """
        Rescale ``alphas_cumprod`` to zero terminal SNR.

        Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1). The input tensor is
        not modified.

        Args:
            alphas_cumprod (`torch.FloatTensor`)
            thr: value written into the last ``sqrt(alphas_cumprod)`` instead of exactly 0,
                so that the derived sigma stays finite.
        Returns:
            `torch.FloatTensor`: rescaled alphas_cumprod. The first entry is preserved and
            the last entry equals ``thr**2``.
        """
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T

        # Scale so the first timestep is back to the old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0/(alphas_bar_sqrt_0-alphas_bar_sqrt_T)
        alphas_bar_sqrt[-1] = thr  # avoid nan sigma

        # Undo the sqrt to get back to alphas_cumprod.
        alphas_bar = alphas_bar_sqrt**2
        return alphas_bar
|
@@ -0,0 +1,72 @@
|
|
1
|
+
from typing import Tuple
|
2
|
+
import torch
|
3
|
+
from .sigma_scheduler import SigmaScheduler
|
4
|
+
from diffusers import DDPMScheduler
|
5
|
+
|
6
|
+
class BaseSampler:
    """Base class for diffusion samplers that work in sigma (noise-level) space.

    The preconditioning hooks ``c_in``, ``c_out`` and ``c_skip`` default to 1 and are meant to be
    overridden by subclasses. ``add_noise`` and ``eps_to_x0`` are written in terms of them so that
    ``eps_to_x0`` exactly inverts ``add_noise`` for the same sigma and noise.
    """

    def __init__(self, sigma_scheduler: 'SigmaScheduler', generator: torch.Generator = None):
        """
        Args:
            sigma_scheduler: provides ``sigma_max``, ``sample_sigma`` and optionally
                ``num_timesteps``.
            generator: RNG used for all noise created by this sampler (``None`` uses torch's
                global RNG).
        """
        self.sigma_scheduler = sigma_scheduler
        self.generator = generator

    def c_in(self, sigma):
        """Scale applied to the model input (identity by default)."""
        return 1

    def c_out(self, sigma):
        """Coefficient of the noise term (identity by default)."""
        return 1

    def c_skip(self, sigma):
        """Coefficient of the sample term (identity by default)."""
        return 1

    @property
    def num_timesteps(self):
        """Number of discrete timesteps; falls back to ``1000.`` when the scheduler has none."""
        return getattr(self.sigma_scheduler, 'num_timesteps', 1000.)

    def get_timesteps(self, N_steps, device='cuda'):
        """Return ``N_steps`` evenly spaced timesteps from 0 to ``num_timesteps`` inclusive.

        NOTE(review): ``add_noise_rand_t`` scales timesteps to ``[0, num_timesteps - 1]``, while this
        spans ``[0, num_timesteps]``. Confirm which range callers expect.
        """
        return torch.linspace(0, self.num_timesteps, N_steps, device=device)

    def make_nosie(self, shape, device='cuda', dtype=torch.float32):
        """Draw standard normal noise of ``shape`` using ``self.generator``.

        The name keeps its historical misspelling for compatibility; ``make_noise`` is an alias.
        """
        return torch.randn(shape, generator=self.generator, device=device, dtype=dtype)

    def make_noise(self, shape, device='cuda', dtype=torch.float32):
        """Correctly spelled alias of ``make_nosie`` (delegates, so subclass overrides still apply)."""
        return self.make_nosie(shape, device, dtype)

    def init_noise(self, shape, device='cuda', dtype=torch.float32):
        """Initial sample for generation: standard noise scaled by the scheduler's ``sigma_max``."""
        sigma = self.sigma_scheduler.sigma_max
        return self.make_nosie(shape, device, dtype)*sigma

    def add_noise(self, x, sigma) -> Tuple[torch.Tensor, torch.Tensor]:
        """Noise ``x`` at level ``sigma``.

        Returns:
            ``(noisy_x, noise)`` where ``noisy_x = (x - c_out * noise) / c_skip``. The maths runs in
            float32, and both outputs are cast back to ``x.dtype``.
        """
        noise = self.make_nosie(x.shape, device=x.device)
        noisy_x = (x.to(dtype=torch.float32)-self.c_out(sigma)*noise)/self.c_skip(sigma)
        return noisy_x.to(dtype=x.dtype), noise.to(dtype=x.dtype)

    def add_noise_rand_t(self, x):
        """Noise a batch with one random noise level per sample.

        Returns:
            ``(noisy_x, noise, sigma, timesteps)``. ``sigma`` is shaped ``(bs, 1, 1, 1)``, and
            ``timesteps`` is scaled by ``num_timesteps - 1``.
        """
        bs = x.shape[0]
        # timesteps: [0, 1] (normalized, as produced by sample_sigma)
        sigma, timesteps = self.sigma_scheduler.sample_sigma(shape=(bs,))
        sigma = sigma.view(-1, 1, 1, 1).to(x.device)
        timesteps = timesteps.to(x.device)
        noisy_x, noise = self.add_noise(x, sigma)

        # Map the normalized timesteps onto the model's index range [0, num_timesteps - 1].
        timesteps = timesteps*(self.num_timesteps-1)
        return noisy_x, noise, sigma, timesteps

    def denoise(self, x, sigma, eps=None, generator=None):
        """Single reverse step; implemented by subclasses."""
        raise NotImplementedError

    def eps_to_x0(self, eps, x_t, sigma):
        """Recover x0 from predicted noise; the exact inverse of ``add_noise``."""
        return self.c_skip(sigma)*x_t+self.c_out(sigma)*eps

    @staticmethod
    def _alpha_terms(sigma):
        """Return ``(alpha, sqrt(alpha), sqrt(1 - alpha))`` with ``alpha = 1 / (sigma**2 + 1)``.

        ``alpha`` is the cumulative alpha of a variance-preserving schedule whose
        ``sigma = sqrt((1 - alpha) / alpha)``. ``sigma`` must be a tensor (``.sqrt()`` is used).
        """
        alpha = 1/(sigma**2+1)
        return alpha, alpha.sqrt(), (1-alpha).sqrt()

    # The velocity helpers below treat x_t as the sigma-space sample (the variance-preserving
    # sample divided by sqrt(alpha)); NOTE(review): derived from the formulas, confirm with callers.

    def velocity_to_eps(self, v_pred, x_t, sigma):
        """Convert a v-prediction into an epsilon prediction."""
        _, sqrt_alpha, one_sqrt_alpha = self._alpha_terms(sigma)
        return sqrt_alpha*v_pred + one_sqrt_alpha*(x_t*sqrt_alpha)

    def eps_to_velocity(self, eps, x_t, sigma):
        """Convert an epsilon prediction into a v-prediction (inverse of ``velocity_to_eps``)."""
        _, sqrt_alpha, one_sqrt_alpha = self._alpha_terms(sigma)
        return eps/sqrt_alpha - one_sqrt_alpha*x_t

    def velocity_to_x0(self, v_pred, x_t, sigma):
        """Convert a v-prediction into an x0 prediction."""
        alpha, _, one_sqrt_alpha = self._alpha_terms(sigma)
        return alpha*x_t - one_sqrt_alpha*v_pred
|
@@ -0,0 +1,20 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
from .base import BaseSampler
|
4
|
+
from .sigma_scheduler import SigmaScheduler
|
5
|
+
|
6
|
+
class DDPMSampler(BaseSampler):
    """Sigma-space preconditioning for DDPM models.

    With these hooks, ``BaseSampler.add_noise`` yields ``x + sigma * noise``, and
    ``BaseSampler.eps_to_x0`` yields ``x_t - sigma * eps``. The model input is scaled by
    ``1 / sqrt(sigma**2 + 1)``.
    """

    def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator=None):
        # No DDPM-specific state; everything is stored by BaseSampler.
        super().__init__(sigma_scheduler, generator)

    def c_in(self, sigma):
        """Input scale ``1 / sqrt(sigma**2 + 1)``; ``sigma`` must be a tensor."""
        norm = (sigma**2 + 1).sqrt()
        return 1. / norm

    def c_out(self, sigma):
        """Noise coefficient: the negated noise level."""
        return -sigma

    def c_skip(self, sigma):
        """Sample coefficient: constant 1 (no rescaling of the sample)."""
        return 1.

    def denoise(self, x, sigma, eps=None, generator=None):
        """Reverse step; not implemented for this sampler."""
        raise NotImplementedError
|
@@ -0,0 +1,66 @@
|
|
1
|
+
import torch
|
2
|
+
import inspect
|
3
|
+
from diffusers import SchedulerMixin, DDPMScheduler
|
4
|
+
|
5
|
+
try:
    from diffusers.utils import randn_tensor
except ImportError:
    # Newer diffusers releases moved randn_tensor to diffusers.utils.torch_utils.
    # Only ImportError is caught so that unrelated failures are not masked.
    from diffusers.utils.torch_utils import randn_tensor
|
10
|
+
|
11
|
+
from .base import BaseSampler
|
12
|
+
from .sigma_scheduler import TimeSigmaScheduler
|
13
|
+
|
14
|
+
class DiffusersSampler(BaseSampler):
    """Adapter exposing a ``diffusers`` scheduler through the BaseSampler interface.

    The wrapped ``scheduler`` handles:
    - timestep selection,
    - initial-noise scaling,
    - forward noising,
    - reverse steps.

    The sampler's own ``sigma_scheduler`` is a ``TimeSigmaScheduler``.
    """

    def __init__(self, scheduler: SchedulerMixin, eta=0.0, generator: torch.Generator=None):
        """
        Args:
            scheduler: the diffusers scheduler to delegate to.
            eta: forwarded to ``scheduler.step`` only if its signature has an ``eta`` parameter
                (e.g. DDIM); ignored otherwise.
            generator: RNG used for noise creation. Also used by ``denoise`` when no
                per-call generator is given.
        """
        sigma_scheduler = TimeSigmaScheduler()
        super().__init__(sigma_scheduler, generator)
        self.scheduler = scheduler
        self.eta = eta

    def c_in(self, sigma):
        """Model-input scale, obtained by letting the scheduler scale a tensor of ones.

        Resets the scheduler's ``_step_index`` (when present) first. Presumably this makes
        ``scale_model_input`` look up the index for ``sigma`` instead of reusing cached step
        state. NOTE(review): confirm against the diffusers version in use.
        """
        one = torch.ones_like(sigma)
        if hasattr(self.scheduler, '_step_index'):
            self.scheduler._step_index = None
        return self.scheduler.scale_model_input(one, sigma)

    def c_out(self, sigma):
        """Noise coefficient: the negated noise level."""
        return -sigma

    def c_skip(self, sigma):
        """Sample coefficient, chosen by whether the scheduler rescales model input.

        If ``c_in`` is 1 everywhere, the scheduler is treated as DDPM and ``sqrt(sigma**2 + 1)``
        is returned. Otherwise (EDM-style input scaling) the coefficient is 1.
        """
        in_scale = torch.as_tensor(self.c_in(sigma))
        # .all() keeps this valid for batched sigma: using a multi-element tensor directly
        # in `if` raises "Boolean value of Tensor with more than one element is ambiguous".
        if bool((in_scale == 1.).all()):  # DDPM model
            return (sigma**2+1).sqrt()  # 1/sqrt(alpha_)
        return 1.  # EDM model

    def get_timesteps(self, N_steps, device='cuda'):
        """Configure the scheduler for ``N_steps`` and return its timesteps."""
        self.scheduler.set_timesteps(N_steps, device=device)
        return self.scheduler.timesteps

    def init_noise(self, shape, device='cuda', dtype=torch.float32):
        """Standard noise scaled by the scheduler's ``init_noise_sigma``."""
        return randn_tensor(shape, generator=self.generator, device=device, dtype=dtype)*self.scheduler.init_noise_sigma

    def add_noise(self, x, sigma):
        """Noise ``x`` via ``scheduler.add_noise``.

        ``sigma`` is passed as the scheduler's timesteps argument.

        Returns:
            ``(noisy_x, noise)``.
        """
        noise = randn_tensor(x.shape, generator=self.generator, device=x.device, dtype=x.dtype)
        return self.scheduler.add_noise(x, noise, sigma), noise

    def prepare_extra_step_kwargs(self, scheduler, generator, eta):
        """Build the optional keyword arguments accepted by ``scheduler.step``.

        Not all schedulers share the same ``step`` signature:
        - eta (η) is only used by DDIM-like schedulers (η from https://arxiv.org/abs/2010.02502,
          expected in [0, 1]) and is ignored for others.
        - generator is passed only to schedulers whose ``step`` accepts it.
        """
        step_params = inspect.signature(scheduler.step).parameters
        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def denoise(self, x_t, sigma, eps=None, generator=None):
        """Run one ``scheduler.step`` and return its ``prev_sample``.

        Args:
            x_t: current sample.
            sigma: current timestep, as understood by the scheduler.
            eps: model output for ``x_t``. Whether this is noise or velocity depends on the
                scheduler's configuration (TODO confirm).
            generator: RNG for stochastic schedulers. Falls back to ``self.generator`` so that
                seeding at construction also makes sampling reproducible.
        """
        if generator is None:
            generator = self.generator
        extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator, self.eta)
        return self.scheduler.step(eps, sigma, x_t, **extra_step_kwargs).prev_sample
|