hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/workflow/diffusion.py
CHANGED
@@ -1,8 +1,13 @@
|
|
1
|
-
|
1
|
+
import random
|
2
|
+
import warnings
|
2
3
|
from typing import Dict, Any, Union, List
|
4
|
+
|
3
5
|
import torch
|
6
|
+
from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
|
7
|
+
from hcpdiff.utils import prepare_seed
|
8
|
+
from hcpdiff.utils.net_utils import get_dtype, to_cuda
|
9
|
+
from rainbowneko.infer import BasicAction
|
4
10
|
from torch.cuda.amp import autocast
|
5
|
-
import inspect
|
6
11
|
|
7
12
|
try:
|
8
13
|
from diffusers.utils import randn_tensor
|
@@ -10,197 +15,184 @@ except:
|
|
10
15
|
# new version of diffusers
|
11
16
|
from diffusers.utils.torch_utils import randn_tensor
|
12
17
|
|
13
|
-
from hcpdiff.utils import prepare_seed
|
14
|
-
from hcpdiff.utils.net_utils import get_dtype, to_cpu, to_cuda
|
15
|
-
import random
|
16
|
-
|
17
18
|
class InputFeederAction(BasicAction):
|
18
|
-
|
19
|
-
|
20
|
-
super().__init__()
|
19
|
+
def __init__(self, ex_inputs: Dict[str, Any], key_map_in=None, key_map_out=None):
|
20
|
+
super().__init__(key_map_in, key_map_out)
|
21
21
|
self.ex_inputs = ex_inputs
|
22
|
-
self.unet = unet
|
23
22
|
|
24
|
-
def forward(self, **states):
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
23
|
+
def forward(self, model, ex_inputs=None, **states):
|
24
|
+
ex_inputs = self.ex_inputs if ex_inputs is None else {**ex_inputs, **self.ex_inputs}
|
25
|
+
if hasattr(model, 'input_feeder'):
|
26
|
+
for feeder in model.input_feeder:
|
27
|
+
feeder(ex_inputs)
|
29
28
|
|
30
29
|
class SeedAction(BasicAction):
|
31
|
-
|
32
|
-
|
33
|
-
super().__init__()
|
30
|
+
def __init__(self, seed: Union[int, List[int]], bs: int = 1, key_map_in=None, key_map_out=None):
|
31
|
+
super().__init__(key_map_in, key_map_out)
|
34
32
|
self.seed = seed
|
35
33
|
self.bs = bs
|
36
34
|
|
37
|
-
def forward(self, device, **states):
|
35
|
+
def forward(self, device, gen_step=0, **states):
|
38
36
|
bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
|
39
37
|
if self.seed is None:
|
40
38
|
seeds = [None]*bs
|
41
39
|
elif isinstance(self.seed, int):
|
42
|
-
seeds = list(range(self.seed, self.seed+bs))
|
40
|
+
seeds = list(range(self.seed+gen_step*bs, self.seed+(gen_step+1)*bs))
|
43
41
|
else:
|
44
42
|
seeds = self.seed
|
45
43
|
seeds = [s or random.randint(0, 1 << 30) for s in seeds]
|
46
44
|
|
47
45
|
G = prepare_seed(seeds, device=device)
|
48
|
-
return {
|
49
|
-
|
50
|
-
class PrepareDiffusionAction(BasicAction
|
51
|
-
def __init__(self,
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
46
|
+
return {'seeds':seeds, 'generator':G}
|
47
|
+
|
48
|
+
class PrepareDiffusionAction(BasicAction):
|
49
|
+
def __init__(self, model_offload=False, amp=torch.float16, key_map_in=None, key_map_out=None):
|
50
|
+
super().__init__(key_map_in, key_map_out)
|
51
|
+
self.model_offload = model_offload
|
52
|
+
self.amp = amp
|
53
|
+
|
54
|
+
def forward(self, device, denoiser, TE, vae, **states):
|
55
|
+
denoiser.to(device)
|
56
|
+
TE.to(device)
|
57
|
+
vae.to(device)
|
58
|
+
|
59
|
+
TE.eval()
|
60
|
+
denoiser.eval()
|
61
|
+
vae.eval()
|
62
|
+
return {'amp':self.amp, 'model_offload':self.model_offload}
|
63
|
+
|
64
|
+
class MakeTimestepsAction(BasicAction):
|
65
|
+
def __init__(self, N_steps: int = 30, strength: float = None, key_map_in=None, key_map_out=None):
|
66
|
+
super().__init__(key_map_in, key_map_out)
|
68
67
|
self.N_steps = N_steps
|
69
68
|
self.strength = strength
|
70
69
|
|
71
|
-
def get_timesteps(self, timesteps, strength):
|
70
|
+
def get_timesteps(self, noise_sampler:BaseSampler, timesteps, strength):
|
72
71
|
# get the original timestep using init_timestep
|
73
72
|
num_inference_steps = len(timesteps)
|
74
|
-
init_timestep = min(int(num_inference_steps
|
73
|
+
init_timestep = min(int(num_inference_steps*strength), num_inference_steps)
|
75
74
|
|
76
|
-
t_start = max(num_inference_steps
|
77
|
-
|
75
|
+
t_start = max(num_inference_steps-init_timestep, 0)
|
76
|
+
if isinstance(noise_sampler, DiffusersSampler):
|
77
|
+
timesteps = timesteps[t_start*noise_sampler.scheduler.order:]
|
78
|
+
else:
|
79
|
+
timesteps = timesteps[t_start:]
|
78
80
|
|
79
81
|
return timesteps
|
80
82
|
|
81
|
-
def forward(self,
|
82
|
-
|
83
|
-
|
84
|
-
self.scheduler.set_timesteps(self.N_steps, device=device)
|
85
|
-
timesteps = self.scheduler.timesteps
|
83
|
+
def forward(self, noise_sampler:BaseSampler, device, **states):
|
84
|
+
timesteps = noise_sampler.get_timesteps(self.N_steps, device=device)
|
86
85
|
if self.strength:
|
87
|
-
timesteps = self.get_timesteps(timesteps, self.strength)
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
def __init__(self,
|
94
|
-
|
95
|
-
self.N_ch=N_ch
|
96
|
-
self.height=height
|
97
|
-
self.width=width
|
98
|
-
|
99
|
-
def forward(self,
|
86
|
+
timesteps = self.get_timesteps(noise_sampler, timesteps, self.strength)
|
87
|
+
return {'timesteps':timesteps, 'start_timestep':timesteps[:1]}
|
88
|
+
else:
|
89
|
+
return {'timesteps':timesteps}
|
90
|
+
|
91
|
+
class MakeLatentAction(BasicAction):
|
92
|
+
def __init__(self, N_ch=4, height=None, width=None, key_map_in=None, key_map_out=None):
|
93
|
+
super().__init__(key_map_in, key_map_out)
|
94
|
+
self.N_ch = N_ch
|
95
|
+
self.height = height
|
96
|
+
self.width = width
|
97
|
+
|
98
|
+
def forward(self, noise_sampler:BaseSampler, vae, generator, device, dtype, bs=None, latents=None, start_timestep=None,
|
99
|
+
pooled_output=None, crop_coord=None, **states):
|
100
100
|
if bs is None:
|
101
101
|
if 'prompt' in states:
|
102
102
|
bs = len(states['prompt'])
|
103
|
-
|
103
|
+
vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
|
104
|
+
device = torch.device(device)
|
104
105
|
|
105
|
-
|
106
|
+
if latents is None:
|
107
|
+
shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)
|
108
|
+
else:
|
109
|
+
if self.height is not None:
|
110
|
+
warnings.warn('latents exist! User-specified width and height will be ignored!')
|
111
|
+
shape = latents.shape
|
106
112
|
if isinstance(generator, list) and len(generator) != bs:
|
107
113
|
raise ValueError(
|
108
114
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
109
115
|
f" size of {bs}. Make sure the batch size matches the length of the generators."
|
110
116
|
)
|
111
117
|
|
112
|
-
noise = randn_tensor(shape, generator=generator, device=device, dtype=get_dtype(dtype))
|
113
118
|
if latents is None:
|
114
|
-
# scale the initial noise by the standard deviation required by the
|
115
|
-
|
119
|
+
# scale the initial noise by the standard deviation required by the noise_sampler
|
120
|
+
noise_sampler.generator = generator
|
121
|
+
latents = noise_sampler.init_noise(shape, device=device, dtype=get_dtype(dtype))
|
116
122
|
else:
|
117
123
|
# image to image
|
118
124
|
latents = latents.to(device)
|
119
|
-
latents =
|
125
|
+
latents, noise = noise_sampler.add_noise(latents, start_timestep)
|
126
|
+
|
127
|
+
output = {'latents':latents}
|
128
|
+
|
129
|
+
# SDXL inputs
|
130
|
+
if pooled_output is not None:
|
131
|
+
width, height = shape[3]*vae_scale_factor, shape[2]*vae_scale_factor
|
132
|
+
if crop_coord is None:
|
133
|
+
crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
|
134
|
+
else:
|
135
|
+
crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
|
136
|
+
crop_info = crop_info.to(device).repeat(bs, 1)
|
137
|
+
output['text_embeds'] = pooled_output[-1].to(device)
|
120
138
|
|
121
|
-
|
139
|
+
if 'negative_prompt' in states:
|
140
|
+
output['crop_info'] = torch.cat([crop_info, crop_info], dim=0)
|
122
141
|
|
123
|
-
|
124
|
-
@from_memory_context
|
125
|
-
def __init__(self, unet=None, scheduler=None, guidance_scale:float=7.0):
|
126
|
-
self.guidance_scale=guidance_scale
|
127
|
-
self.unet = unet
|
128
|
-
self.scheduler = scheduler
|
142
|
+
return output
|
129
143
|
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
self.
|
144
|
+
class DenoiseAction(BasicAction):
|
145
|
+
def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
|
146
|
+
super().__init__(key_map_in, key_map_out)
|
147
|
+
self.guidance_scale = guidance_scale
|
134
148
|
|
135
|
-
|
149
|
+
def forward(self, denoiser, noise_sampler: BaseSampler, t, latents, prompt_embeds, text_embeds=None, encoder_attention_mask=None, crop_info=None,
|
150
|
+
cross_attention_kwargs=None, dtype='fp32', amp=None, model_offload=False, **states):
|
151
|
+
|
152
|
+
if model_offload:
|
153
|
+
to_cuda(denoiser) # to_cpu in VAE
|
154
|
+
|
155
|
+
with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
|
136
156
|
latent_model_input = torch.cat([latents]*2) if self.guidance_scale>1 else latents
|
137
|
-
latent_model_input =
|
157
|
+
latent_model_input = noise_sampler.c_in(t)*latent_model_input
|
138
158
|
|
139
|
-
if
|
140
|
-
noise_pred =
|
141
|
-
|
159
|
+
if text_embeds is None:
|
160
|
+
noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
|
161
|
+
cross_attention_kwargs=cross_attention_kwargs, ).sample
|
142
162
|
else:
|
143
|
-
added_cond_kwargs = {"text_embeds":
|
163
|
+
added_cond_kwargs = {"text_embeds":text_embeds, "time_ids":crop_info}
|
144
164
|
# predict the noise residual
|
145
|
-
noise_pred =
|
146
|
-
|
165
|
+
noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
|
166
|
+
cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
|
147
167
|
|
148
168
|
# perform guidance
|
149
169
|
if self.guidance_scale>1:
|
150
170
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
151
171
|
noise_pred = noise_pred_uncond+self.guidance_scale*(noise_pred_text-noise_pred_uncond)
|
152
172
|
|
153
|
-
return {
|
154
|
-
'crop_info':crop_info, 'cross_attention_kwargs':cross_attention_kwargs, 'dtype':dtype}
|
155
|
-
|
156
|
-
class SampleAction(BasicAction, MemoryMixin):
|
157
|
-
@from_memory_context
|
158
|
-
def __init__(self, scheduler=None, eta=0.0):
|
159
|
-
self.scheduler = scheduler
|
160
|
-
self.eta = eta
|
161
|
-
|
162
|
-
def prepare_extra_step_kwargs(self, generator, eta):
|
163
|
-
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
164
|
-
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
165
|
-
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
166
|
-
# and should be between [0, 1]
|
167
|
-
|
168
|
-
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
169
|
-
extra_step_kwargs = {}
|
170
|
-
if accepts_eta:
|
171
|
-
extra_step_kwargs["eta"] = eta
|
172
|
-
|
173
|
-
# check if the scheduler accepts generator
|
174
|
-
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
175
|
-
if accepts_generator:
|
176
|
-
extra_step_kwargs["generator"] = generator
|
177
|
-
return extra_step_kwargs
|
178
|
-
|
179
|
-
def forward(self, memory, noise_pred, t, latents, generator, **states):
|
180
|
-
self.scheduler = self.scheduler or memory.scheduler
|
181
|
-
|
182
|
-
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, self.eta)
|
173
|
+
return {'noise_pred':noise_pred}
|
183
174
|
|
175
|
+
class SampleAction(BasicAction):
|
176
|
+
def forward(self, noise_sampler: BaseSampler, noise_pred, t, latents, generator, **states):
|
184
177
|
# compute the previous noisy sample x_t -> x_t-1
|
185
|
-
|
186
|
-
latents
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
self.
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
states = self.
|
197
|
-
states = self.act_sample(memory=memory, **states)
|
178
|
+
latents = noise_sampler.denoise(latents, t, noise_pred, generator=generator)
|
179
|
+
return {'latents':latents}
|
180
|
+
|
181
|
+
class DiffusionStepAction(BasicAction):
|
182
|
+
def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
|
183
|
+
super().__init__(key_map_in, key_map_out)
|
184
|
+
self.act_noise_pred = DenoiseAction(guidance_scale)
|
185
|
+
self.act_sample = SampleAction()
|
186
|
+
|
187
|
+
def forward(self, denoiser, noise_sampler, **states):
|
188
|
+
states = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
|
189
|
+
states = self.act_sample(**states)
|
198
190
|
return states
|
199
191
|
|
200
192
|
class X0PredAction(BasicAction):
|
201
|
-
def forward(self, latents,
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
193
|
+
def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
|
194
|
+
latents_x0 = noise_sampler.eps_to_x0(noise_pred, latents, t)
|
195
|
+
return {'latents_x0':latents_x0}
|
196
|
+
|
197
|
+
def time_iter(timesteps, **states):
|
198
|
+
return [{'t':t} for t in timesteps]
|
hcpdiff/workflow/fast.py
ADDED
@@ -0,0 +1,31 @@
|
|
1
|
+
from sfast.compilers.diffusion_pipeline_compiler import (compile_unet, CompilationConfig)
|
2
|
+
from rainbowneko.infer import BasicAction
|
3
|
+
|
4
|
+
|
5
|
+
class SFastCompileAction(BasicAction):
|
6
|
+
|
7
|
+
@staticmethod
|
8
|
+
def compile_model(unet):
|
9
|
+
# compile model
|
10
|
+
config = CompilationConfig.Default()
|
11
|
+
config.enable_xformers = False
|
12
|
+
try:
|
13
|
+
import xformers
|
14
|
+
config.enable_xformers = True
|
15
|
+
except ImportError:
|
16
|
+
print('xformers not installed, skip')
|
17
|
+
# NOTE:
|
18
|
+
# When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
|
19
|
+
# Disable Triton if you encounter this problem.
|
20
|
+
try:
|
21
|
+
import triton
|
22
|
+
config.enable_triton = True
|
23
|
+
except ImportError:
|
24
|
+
print('Triton not installed, skip')
|
25
|
+
config.enable_cuda_graph = True
|
26
|
+
|
27
|
+
return compile_unet(unet, config)
|
28
|
+
|
29
|
+
def forward(self, denoiser, **states):
|
30
|
+
denoiser = self.compile_model(denoiser)
|
31
|
+
return {'denoiser': denoiser}
|
hcpdiff/workflow/flow.py
ADDED
@@ -0,0 +1,67 @@
|
|
1
|
+
from rainbowneko.infer import BasicAction
|
2
|
+
from typing import List, Dict
|
3
|
+
from tqdm import tqdm
|
4
|
+
import math
|
5
|
+
|
6
|
+
class FilePromptAction(BasicAction):
|
7
|
+
def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, key_map_in=None, key_map_out=None):
|
8
|
+
super().__init__(key_map_in, key_map_out)
|
9
|
+
if prompt.endswith('.txt'):
|
10
|
+
with open(prompt, 'r') as f:
|
11
|
+
prompt = f.read().split('\n')
|
12
|
+
else:
|
13
|
+
prompt = [prompt]
|
14
|
+
|
15
|
+
if negative_prompt.endswith('.txt'):
|
16
|
+
with open(negative_prompt, 'r') as f:
|
17
|
+
negative_prompt = f.read().split('\n')
|
18
|
+
else:
|
19
|
+
negative_prompt = [negative_prompt]*len(prompt)
|
20
|
+
|
21
|
+
self.prompt = prompt
|
22
|
+
self.negative_prompt = negative_prompt
|
23
|
+
self.bs = bs
|
24
|
+
self.actions = actions
|
25
|
+
|
26
|
+
|
27
|
+
def forward(self, **states):
|
28
|
+
states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
|
29
|
+
states_ref = dict(**states)
|
30
|
+
|
31
|
+
pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
|
32
|
+
N_steps = len(self.actions)
|
33
|
+
for gen_step in pbar:
|
34
|
+
states = dict(**states_ref)
|
35
|
+
feed_data = {'gen_step': gen_step}
|
36
|
+
states.update(feed_data)
|
37
|
+
for step, act in enumerate(self.actions):
|
38
|
+
pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
|
39
|
+
states = act(**states)
|
40
|
+
return states
|
41
|
+
|
42
|
+
class FlowPromptAction(BasicAction):
|
43
|
+
def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, num: int = None, key_map_in=None, key_map_out=None):
|
44
|
+
super().__init__(key_map_in, key_map_out)
|
45
|
+
prompt = [prompt]*num
|
46
|
+
negative_prompt = [negative_prompt]*num
|
47
|
+
|
48
|
+
self.prompt = prompt
|
49
|
+
self.negative_prompt = negative_prompt
|
50
|
+
self.bs = bs
|
51
|
+
self.actions = actions
|
52
|
+
|
53
|
+
|
54
|
+
def forward(self, **states):
|
55
|
+
states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
|
56
|
+
states_ref = dict(**states)
|
57
|
+
|
58
|
+
pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
|
59
|
+
N_steps = len(self.actions)
|
60
|
+
for gen_step in pbar:
|
61
|
+
states = dict(**states_ref)
|
62
|
+
feed_data = {'gen_step': gen_step}
|
63
|
+
states.update(feed_data)
|
64
|
+
for step, act in enumerate(self.actions):
|
65
|
+
pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
|
66
|
+
states = act(**states)
|
67
|
+
return states
|
hcpdiff/workflow/io.py
CHANGED
@@ -1,88 +1,56 @@
|
|
1
1
|
import os
|
2
|
+
from functools import partial
|
3
|
+
from typing import List, Union
|
2
4
|
|
3
|
-
|
4
|
-
|
5
|
-
from hcpdiff.utils import auto_text_encoder, auto_tokenizer, to_validate_file
|
6
|
-
from hcpdiff.utils.cfg_net_tools import HCPModelLoader
|
7
|
-
from hcpdiff.utils.img_size_tool import types_support
|
5
|
+
import torch
|
6
|
+
from hcpdiff.utils import to_validate_file
|
8
7
|
from hcpdiff.utils.net_utils import get_dtype
|
9
|
-
from .
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
8
|
+
from rainbowneko.ckpt_manager import NekoLoader
|
9
|
+
from rainbowneko.infer import BasicAction
|
10
|
+
from rainbowneko.infer import LoadImageAction as Neko_LoadImageAction
|
11
|
+
from rainbowneko.utils.img_size_tool import types_support
|
12
|
+
|
13
|
+
class BuildModelsAction(BasicAction):
|
14
|
+
def __init__(self, model_loader: partial[NekoLoader.load], dtype: str=torch.float32, device='cuda', key_map_in=None, key_map_out=None):
|
15
|
+
super().__init__(key_map_in, key_map_out)
|
16
|
+
self.model_loader = model_loader
|
15
17
|
self.dtype = get_dtype(dtype)
|
18
|
+
self.device = device
|
16
19
|
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
20
|
+
def forward(self, in_preview=False, model=None, **states):
|
21
|
+
if in_preview:
|
22
|
+
model = self.model_loader(dtype=self.dtype, device=self.device, denoiser=model.denoiser, TE=model.TE, vae=model.vae)
|
23
|
+
else:
|
24
|
+
model = self.model_loader(dtype=self.dtype, device=self.device)
|
22
25
|
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
memory.vae = self.vae or AutoencoderKL.from_pretrained(self.pretrained_model, subfolder="vae", torch_dtype=self.dtype)
|
28
|
-
memory.scheduler = self.scheduler or PNDMScheduler.from_pretrained(self.pretrained_model, subfolder="scheduler", torch_dtype=self.dtype)
|
26
|
+
if isinstance(model, dict):
|
27
|
+
return model
|
28
|
+
else:
|
29
|
+
return {'model':model}
|
29
30
|
|
30
|
-
|
31
|
+
class LoadImageAction(Neko_LoadImageAction):
|
32
|
+
def __init__(self, image_paths: Union[str, List[str]], image_transforms=None, key_map_in=None, key_map_out=('input.x -> images',)):
|
33
|
+
super().__init__(image_paths, image_transforms, key_map_in, key_map_out)
|
31
34
|
|
32
35
|
class SaveImageAction(BasicAction):
|
33
|
-
|
34
|
-
|
36
|
+
def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, key_map_in=None, key_map_out=None):
|
37
|
+
super().__init__(key_map_in, key_map_out)
|
35
38
|
self.save_root = save_root
|
36
39
|
self.image_type = image_type
|
37
40
|
self.quality = quality
|
41
|
+
self.save_cfg = save_cfg
|
38
42
|
|
39
43
|
os.makedirs(save_root, exist_ok=True)
|
40
44
|
|
41
|
-
def forward(self, images, prompt, negative_prompt, seeds=None, **states):
|
42
|
-
|
45
|
+
def forward(self, images, prompt, negative_prompt, seeds, cfgs=None, parser=None, preview_root=None, preview_step=None, **states):
|
46
|
+
save_root = preview_root or self.save_root
|
47
|
+
num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(save_root) if x.rsplit('.', 1)[-1] in types_support])+1
|
43
48
|
|
44
49
|
for bid, (p, pn, img) in enumerate(zip(prompt, negative_prompt, images)):
|
45
|
-
img_path = os.path.join(
|
50
|
+
img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.{self.image_type}")
|
46
51
|
img.save(img_path, quality=self.quality)
|
47
52
|
num_img_exist += 1
|
48
53
|
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
def forward(self, memory, **states):
|
53
|
-
memory.model_loader_unet = HCPModelLoader(memory.unet)
|
54
|
-
memory.model_loader_TE = HCPModelLoader(memory.text_encoder)
|
55
|
-
return states
|
56
|
-
|
57
|
-
class LoadPartAction(BasicAction, MemoryMixin):
|
58
|
-
@from_memory_context
|
59
|
-
def __init__(self, model: str, cfg):
|
60
|
-
self.model = model
|
61
|
-
self.cfg = cfg
|
62
|
-
|
63
|
-
def forward(self, memory, **states):
|
64
|
-
model_loader = memory[f"model_loader_{self.model}"]
|
65
|
-
model_loader.load_part(self.cfg)
|
66
|
-
return states
|
67
|
-
|
68
|
-
class LoadLoraAction(BasicAction, MemoryMixin):
|
69
|
-
@from_memory_context
|
70
|
-
def __init__(self, model: str, cfg):
|
71
|
-
self.model = model
|
72
|
-
self.cfg = cfg
|
73
|
-
|
74
|
-
def forward(self, memory, **states):
|
75
|
-
model_loader = memory[f"model_loader_{self.model}"]
|
76
|
-
model_loader.load_lora(self.cfg)
|
77
|
-
return states
|
78
|
-
|
79
|
-
class LoadPluginAction(BasicAction, MemoryMixin):
|
80
|
-
@from_memory_context
|
81
|
-
def __init__(self, model: str, cfg):
|
82
|
-
self.model = model
|
83
|
-
self.cfg = cfg
|
84
|
-
|
85
|
-
def forward(self, memory, **states):
|
86
|
-
model_loader = memory[f"model_loader_{self.model}"]
|
87
|
-
model_loader.load_plugin(self.cfg)
|
88
|
-
return states
|
54
|
+
if self.save_cfg:
|
55
|
+
cfgs.seed = seeds[bid]
|
56
|
+
parser.save_configs(cfgs, os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-info"))
|