hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/models/textencoder_ex.py
CHANGED
@@ -8,29 +8,53 @@ textencoder_ex.py
 :Licence: Apache-2.0
 """

-from typing import Tuple, Optional
+from typing import Tuple, Optional

 import torch
 from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
+from loguru import logger
 from torch import nn
+from transformers import CLIPTextModelWithProjection, T5EncoderModel
 from transformers.models.clip.modeling_clip import CLIPAttention

 class TEEXHook:
-    def __init__(self, text_enc: nn.Module, tokenizer, N_repeats=…
+    def __init__(self, text_enc: nn.Module, tokenizer, N_repeats=1, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
         self.text_enc = text_enc
         self.tokenizer = tokenizer

         self.N_repeats = N_repeats
         self.clip_skip = clip_skip
         self.clip_final_norm = clip_final_norm
-        self.device = device
-        self.attn_mult = None
         self.use_attention_mask = use_attention_mask

         text_enc.register_forward_hook(self.forward_hook)
         text_enc.register_forward_pre_hook(self.forward_hook_input)

+    def find_final_norm(self, text_enc: nn.Module):
+        for module in text_enc.modules():
+            if 'final_layer_norm' in module._modules:
+                logger.info(f'find final_layer_norm in {type(module)}')
+                return module.final_layer_norm
+
+        logger.info(f'final_layer_norm not found in {type(text_enc)}')
+        return None
+
+    @property
+    def clip_final_norm(self):
+        return self.final_layer_norm is not None
+
+    @clip_final_norm.setter
+    def clip_final_norm(self, value: bool):
+        if value:
+            self.final_layer_norm = self.find_final_norm(self.text_enc)
+        else:
+            self.final_layer_norm = None
+
+    @property
+    def device(self):
+        return self.text_enc.device
+
     def encode_prompt_to_emb(self, prompt):
         text_inputs = self.tokenizer(
             prompt,
@@ -50,12 +74,23 @@ class TEEXHook:
         if position_ids is not None:
             position_ids = position_ids.to(self.device)

-        … (6 removed lines; content not recoverable from this rendering)
+        # align with sd-webui
+        if isinstance(self.text_enc, CLIPTextModelWithProjection):
+            self.text_enc.text_projection.weight.data = self.text_enc.text_projection.weight.data.t()
+
+        if isinstance(self.text_enc, T5EncoderModel):
+            prompt_embeds, pooled_output = self.text_enc(
+                text_input_ids.to(self.device),
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+            )
+        else:
+            prompt_embeds, pooled_output = self.text_enc(
+                text_input_ids.to(self.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                output_hidden_states=True,
+            )
         return prompt_embeds, pooled_output, attention_mask

     def forward_hook_input(self, host, feat_in):
@@ -64,13 +99,12 @@ class TEEXHook:

     def forward_hook(self, host, feat_in: Tuple[torch.Tensor], feat_out):
         encoder_hidden_states = feat_out['hidden_states'][-self.clip_skip-1]
-        if self.clip_final_norm:
-            encoder_hidden_states = self.…
+        if self.clip_final_norm and self.final_layer_norm is not None:
+            encoder_hidden_states = self.final_layer_norm(encoder_hidden_states)
         if self.text_enc.training and self.clip_skip>0:
             encoder_hidden_states = encoder_hidden_states+0*feat_out['last_hidden_state'].mean()  # avoid unused parameters, make gradient checkpointing happy
-
         encoder_hidden_states = rearrange(encoder_hidden_states, '(b r) ... -> b r ...', r=self.N_repeats)  # [B, N_repeat, N_word+2, N_emb]
-        pooled_output = feat_out.pooler_output
+        pooled_output = feat_out.get('pooler_output', feat_out.get('text_embeds', None))
         # TODO: may have better fusion method
         if pooled_output is not None:
             pooled_output = rearrange(pooled_output, '(b r) ... -> b r ...', r=self.N_repeats).mean(dim=1)
@@ -81,7 +115,7 @@ class TEEXHook:
         return encoder_hidden_states, pooled_output

     def pool_hidden_states(self, encoder_hidden_states, input_ids):
-        pooled_output = encoder_hidden_states[:, :, -1, :].mean(dim=1)
+        pooled_output = encoder_hidden_states[:, :, -1, :].mean(dim=1)  # [B, N_emb]
         return pooled_output

     @staticmethod
@@ -147,9 +181,11 @@ class TEEXHook:
         layer.forward = forward

     @classmethod
-    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, …
-        return cls(text_enc, tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm, …
+    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
+        return cls(text_enc, tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                   use_attention_mask=use_attention_mask)

     @classmethod
     def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
-        return cls(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, …
+        return cls(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                   use_attention_mask=use_attention_mask)
hcpdiff/models/wrapper/pixart.py
ADDED
@@ -0,0 +1,19 @@
+from .sd import SD15Wrapper
+from hcpdiff.utils import pad_attn_bias
+
+class PixArtWrapper(SD15Wrapper):
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, resolution=None, aspect_ratio=None,
+                         plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        added_cond_kwargs = {"resolution":resolution, "aspect_ratio":aspect_ratio}
+        model_pred = self.denoiser(x_t, encoder_hidden_states, timesteps, encoder_attention_mask=attn_mask,
+                                   added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
+        return model_pred
hcpdiff/models/wrapper/sd.py
ADDED
@@ -0,0 +1,218 @@
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, Union
+
+import torch
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from rainbowneko.models.wrapper import BaseWrapper
+from torch import Tensor
+from torch import nn
+
+from hcpdiff.diffusion.sampler import BaseSampler
+from hcpdiff.models import TEEXHook
+from hcpdiff.models.compose import ComposeTEEXHook
+from hcpdiff.utils import pad_attn_bias
+from .utils import TEHookCFG, SD15_TEHookCFG, SDXL_TEHookCFG
+from ..cfg_context import CFGContext
+
+class SD15Wrapper(BaseWrapper):
+    def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 pred_type='eps', TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__()
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input'))
+        self.key_mapper_out = self.build_mapper(key_map_out, None, None)
+
+        self.denoiser = denoiser
+        self.TE = TE
+        self.vae = vae
+        self.noise_sampler = noise_sampler
+        self.tokenizer = tokenizer
+        self.min_attnmask = min_attnmask
+
+        self.pred_type = pred_type
+
+        self.TE_hook_cfg = TEHookCFG.create(TE_hook_cfg)
+        self.cfg_context = cfg_context
+        self.tokenizer.N_repeats = self.TE_hook_cfg.tokenizer_repeats
+
+    def post_init(self):
+        self.make_TE_hook(self.TE_hook_cfg)
+
+        self.vae_trainable = False
+        if self.vae is not None:
+            for p in self.vae.parameters():
+                if p.requires_grad:
+                    self.vae_trainable = True
+                    break
+
+        self.TE_trainable = False
+        for p in self.TE.parameters():
+            if p.requires_grad:
+                self.TE_trainable = True
+                break
+
+    def make_TE_hook(self, TE_hook_cfg):
+        # Hook and extend text_encoder
+        self.text_enc_hook = TEEXHook.hook(self.TE, self.tokenizer, N_repeats=TE_hook_cfg.tokenizer_repeats,
+                                           clip_skip=TE_hook_cfg.clip_skip, clip_final_norm=TE_hook_cfg.clip_final_norm)
+
+    def get_latents(self, image: Tensor):
+        if image.shape[1] == 3:
+            with torch.no_grad() if self.vae_trainable else nullcontext():
+                latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
+                latents = latents*self.vae.config.scaling_factor
+        else:
+            latents = image  # Cached latents
+        return latents
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)[0]
+        return encoder_hidden_states
+
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample  # Predict the noise residual
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
+                      plugin_input={}, **kwargs):
+        # input prepare
+        x_0 = self.get_latents(image)
+        x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+
+        if neg_prompt_ids:
+            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+        if neg_attn_mask:
+            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+        if neg_position_ids:
+            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+
+        # model forward
+        x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
+        encoder_hidden_states = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                                plugin_input=plugin_input, **kwargs)
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                           plugin_input=plugin_input, **kwargs)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
+                    noise_sampler=self.noise_sampler)
+
+    def forward(self, ds_name=None, **kwargs):
+        model_args, model_kwargs = self.get_map_data(self.key_mapper_in, kwargs, ds_name)
+        out = self.model_forward(*model_args, **model_kwargs)
+        return self.get_map_data(self.key_mapper_out, out, ds_name=ds_name)[1]
+
+    def enable_gradient_checkpointing(self):
+        def grad_ckpt_enable(m):
+            if getattr(m, 'gradient_checkpointing', False):
+                m.training = True
+
+        self.denoiser.enable_gradient_checkpointing()
+        if self.TE_trainable:
+            self.TE.gradient_checkpointing_enable()
+        self.apply(grad_ckpt_enable)
+
+    def enable_xformers(self):
+        self.denoiser.enable_xformers_memory_efficient_attention()
+
+    @property
+    def trainable_parameters(self):
+        return [p for p in self.parameters() if p.requires_grad]
+
+    @property
+    def trainable_models(self) -> Dict[str, nn.Module]:
+        return {'self':self}
+
+    def set_dtype(self, dtype, vae_dtype):
+        self.dtype = dtype
+        self.vae_dtype = vae_dtype
+        # Move vae and text_encoder to device and cast to weight_dtype
+        if self.vae is not None:
+            self.vae = self.vae.to(dtype=vae_dtype)
+        if not self.TE_trainable:
+            self.TE = self.TE.to(dtype=dtype)
+
+    @classmethod
+    def from_pretrained(cls, models: Union[partial, Dict[str, nn.Module]], **kwargs):
+        models = models() if isinstance(models, partial) else models
+        return cls(models['denoiser'], models['TE'], models['vae'], models['noise_sampler'], models['tokenizer'], **kwargs)
+
+class SDXLWrapper(SD15Wrapper):
+    def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 pred_type='eps', TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, pred_type, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input', 'coord -> crop_info'))
+
+    def make_TE_hook(self, TE_hook_cfg):
+        # Hook and extend text_encoder
+        self.text_enc_hook = ComposeTEEXHook.hook(self.TE, self.tokenizer, N_repeats=TE_hook_cfg.tokenizer_repeats,
+                                                  clip_skip=TE_hook_cfg.clip_skip, clip_final_norm=TE_hook_cfg.clip_final_norm)
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)
+        return encoder_hidden_states, pooled_output
+
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs, attn_mask=None, position_ids=None,
+                         plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask,
+                                   added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
+                      crop_info=None, plugin_input={}):
+        # input prepare
+        x_0 = self.get_latents(image)
+        x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+
+        if neg_prompt_ids:
+            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+        if neg_attn_mask:
+            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+        if neg_position_ids:
+            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+
+        # model forward
+        x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
+        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                                               plugin_input=plugin_input)
+        added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs=added_cond_kwargs,
+                                           attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
+                    noise_sampler=self.noise_sampler)
hcpdiff/models/wrapper/utils.py
ADDED
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from rainbowneko.utils import is_dict
+
+class TEHookCFG:
+    def __init__(self, tokenizer_repeats: int = 1, clip_skip: int = 0, clip_final_norm: bool = True):
+        self.tokenizer_repeats = tokenizer_repeats
+        self.clip_skip = clip_skip
+        self.clip_final_norm = clip_final_norm
+
+    @classmethod
+    def create(cls, cfg):
+        if is_dict(cfg):
+            return cls(**cfg)
+        elif isinstance(cfg, cls):
+            return cfg
+        else:
+            raise ValueError(f'Invalid TEHookCFG type: {type(cfg)}')
+
+SD15_TEHookCFG = TEHookCFG()
+SDXL_TEHookCFG = TEHookCFG(clip_skip=1, clip_final_norm=False)
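TEHookCFG.create normalizes either form, so configs can pass a plain mapping instead of an instance. A small sketch (assuming rainbowneko's is_dict accepts ordinary dicts):

    from hcpdiff.models.wrapper.utils import TEHookCFG, SDXL_TEHookCFG

    cfg = TEHookCFG.create({'tokenizer_repeats': 2, 'clip_skip': 1})
    assert cfg.clip_skip == 1 and cfg.clip_final_norm
    cfg2 = TEHookCFG.create(SDXL_TEHookCFG)  # instances pass through unchanged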
hcpdiff/parser/__init__.py
ADDED
@@ -0,0 +1 @@
+from .embpt import CfgEmbPTParser
hcpdiff/parser/embpt.py
ADDED
@@ -0,0 +1,32 @@
+from typing import Dict, Tuple, List
+from rainbowneko.utils import Path_Like
+from hcpdiff.models import EmbeddingPTHook
+from torch import Tensor
+
+class CfgEmbPTParser:
+    def __init__(self, emb_dir: Path_Like, cfg_pt: Dict[str, Dict], lr: float = 1e-5, weight_decay: float = 0):
+        self.emb_dir = emb_dir
+        self.cfg_pt = cfg_pt
+        self.lr = lr
+        self.weight_decay = weight_decay
+
+    def get_params_group(self, model) -> Tuple[List, Dict[str, Tensor]]:
+        self.embedding_hook, self.ex_words_emb = EmbeddingPTHook.hook_from_dir(
+            self.emb_dir, model.tokenizer, model.TE, N_repeats=model.tokenizer.N_repeats)
+        self.embedding_hook.requires_grad_(False)
+
+        train_params_emb = []
+        train_pts = {}
+        for pt_name, info in self.cfg_pt.items():
+            word_emb = self.ex_words_emb[pt_name]
+            train_pts[pt_name] = word_emb
+            word_emb.requires_grad = True
+            self.embedding_hook.emb_train.append(word_emb)
+            param_group = {'params':word_emb}
+            if 'lr' in info:
+                param_group['lr'] = info.lr
+            if 'weight_decay' in info:
+                param_group['weight_decay'] = info.weight_decay
+            train_params_emb.append(param_group)
+
+        return train_params_emb, train_pts
hcpdiff/tools/convert_caption_txt2json.py
CHANGED
@@ -2,7 +2,7 @@ import argparse
 import json
 import os

-from …
+from rainbowneko.utils import types_support

 parser = argparse.ArgumentParser(description='Stable Diffusion Training')
 parser.add_argument('--data_root', type=str, default='')
hcpdiff/tools/dataset_generator.py
ADDED
@@ -0,0 +1,94 @@
+import argparse
+import json
+import os.path
+from typing import Callable
+
+import pyarrow.parquet as pq
+import torch
+from PIL import Image
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from tqdm.auto import tqdm
+
+from hcpdiff.data.caption_loader import auto_caption_loader
+
+class DatasetCreator:
+    def __init__(self, pretrained_model, out_dir: str, img_w: int=512, img_h: int=512):
+        scheduler = DPMSolverMultistepScheduler(
+            beta_start = 0.00085,
+            beta_end = 0.012,
+            beta_schedule = 'scaled_linear',
+            algorithm_type = 'dpmsolver++',
+            use_karras_sigmas = True,
+        )
+
+        self.pipeline = DiffusionPipeline.from_pretrained(pretrained_model, scheduler=scheduler, torch_dtype=torch.float16)
+        self.pipeline.requires_safety_checker = False
+        self.pipeline.safety_checker = None
+        self.pipeline.to("cuda")
+        self.pipeline.unet.to(memory_format=torch.channels_last)
+        #self.pipeline.enable_xformers_memory_efficient_attention()
+
+        self.out_dir = out_dir
+        self.img_w = img_w
+        self.img_h = img_h
+
+    def create_from_prompt_dataset(self, prompt_file: str, negative_prompt: str, bs: int, num: int=None, repeat:int=1, save_fmt:str='txt',
+                                   callback: Callable[[int, int], bool] = None):
+        os.makedirs(self.out_dir, exist_ok=True)
+        data = auto_caption_loader(prompt_file).load()
+        data = list(data.items())
+        data = self.split_batch(data, bs)  # [[(k,v),...],...]
+
+        if num is None:
+            num = len(data)
+        total = num*bs
+        count = 0
+        captions = {}
+        with torch.inference_mode():
+            for i in tqdm(range(num)):
+                for r in range(repeat):
+                    name_batch, p_batch = list(zip(*data[i%len(data)]))
+                    imgs = self.pipeline(list(p_batch), negative_prompt=[negative_prompt]*len(p_batch), num_inference_steps=25,
+                                         width=self.img_w, height=self.img_h).images
+                    for name, prompt, img in zip(name_batch, p_batch, imgs):
+                        img.save(os.path.join(self.out_dir, f'{count}_{name}.png'), format='PNG')
+                        captions[f'{count}_{name}'] = prompt
+                        count += 1
+                if callback:
+                    if not callback(count, total):
+                        break
+
+        if save_fmt=='txt':
+            for k, v in captions.items():
+                with open(os.path.join(self.out_dir, f'{k}.txt'), "w") as f:
+                    f.write(v)
+        elif save_fmt=='json':
+            with open(os.path.join(self.out_dir, f'image_captions.json'), "w") as f:
+                json.dump(captions, f)
+        else:
+            raise ValueError(f"Invalid save_fmt: {save_fmt}")
+
+    @staticmethod
+    def split_batch(data, bs):
+        return [data[i:i+bs] for i in range(0, len(data), bs)]
+
+# python dataset_generator.py --prompt_file <caption file or folder> --model <model name> --out_dir <output folder> --repeat <images per prompt> --bs <batch size> --img_w <image width> --img_h <image height>
+# python dataset_generator.py --prompt_file <caption file or folder> --model <model name> --out_dir <output folder> --repeat 1 --bs 4 --img_w 640 --img_h 640
+if __name__ == '__main__':
+    torch.backends.cudnn.benchmark = True
+    parser = argparse.ArgumentParser(description='Diffusion Dataset Generator')
+    parser.add_argument('--prompt_file', type=str, default='')
+    parser.add_argument('--model', type=str, default='runwayml/stable-diffusion-v1-5')
+    parser.add_argument('--out_dir', type=str, default=r'./prompt_ds')
+    parser.add_argument('--negative_prompt', type=str,
+                        default='lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry')
+    parser.add_argument('--num', type=int, default=200)
+    parser.add_argument('--repeat', type=int, default=1)
+    parser.add_argument('--save_fmt', type=str, default='txt')
+    parser.add_argument('--bs', type=int, default=4)
+    parser.add_argument('--img_w', type=int, default=512)
+    parser.add_argument('--img_h', type=int, default=512)
+    args = parser.parse_args()
+
+    ds_creator = DatasetCreator(args.model, args.out_dir, args.img_w, args.img_h)
+    ds_creator.create_from_prompt_dataset(args.prompt_file, args.negative_prompt, args.bs, args.num, repeat=args.repeat, save_fmt=args.save_fmt)
hcpdiff/tools/download_hf_model.py
ADDED
@@ -0,0 +1,24 @@
+from diffusers import DiffusionPipeline
+import argparse
+import torch
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Download Model')
+    parser.add_argument('--model', type=str, default='runwayml/stable-diffusion-v1-5')
+    parser.add_argument("--fp16", default=False, action="store_true")
+    parser.add_argument("--use_safetensors", default=False, action="store_true")
+    parser.add_argument("--out_path", type=str, default='ckpts/sd15')
+    args = parser.parse_args()
+
+    load_args = dict(torch_dtype = torch.float16 if args.fp16 else torch.float32)
+    save_args = dict()
+
+    if args.fp16:
+        load_args['variant'] = "fp16"
+        save_args['variant'] = "fp16"
+    if args.use_safetensors:
+        load_args['use_safetensors'] = True
+        save_args['safe_serialization'] = True
+
+    pipe = DiffusionPipeline.from_pretrained(args.model, **load_args)
+    pipe.save_pretrained(args.out_path, **save_args)
hcpdiff/tools/embedding_convert.py
CHANGED
@@ -1,4 +1,6 @@
 import argparse
+import os
+
 import torch
 import shutil

@@ -14,6 +16,8 @@ if __name__ == '__main__':
     parser.add_argument("--sdxl", default=None, type=str)
     args = parser.parse_args()

+    os.makedirs(os.path.dirname(args.dump_path), exist_ok=True)
+
     print(f'convert embedding')
     ckpt_manager = auto_manager(args.embedding_path)
     embedding = ckpt_manager.load_ckpt(args.embedding_path)
@@ -24,12 +28,12 @@ if __name__ == '__main__':
     if args.to_webui:
         new = embedding['string_to_param']['*']
         new = {'clip_l':new[:, :768], 'clip_g':new[:, 768:]}
-        ckpt_manager._save_ckpt(new, args.dump_path)
+        ckpt_manager._save_ckpt(new, save_path=args.dump_path)

     elif args.from_webui:
         new = torch.cat([embedding['clip_l'], embedding['clip_g']], dim=1)
         new = {'string_to_param':{'*':new}}
-        ckpt_manager._save_ckpt(new, args.dump_path)
+        ckpt_manager._save_ckpt(new, save_path=args.dump_path)
     else:
         raise ValueError("Either --to_webui or --from_webui should be set.")
hcpdiff/tools/init_proj.py
CHANGED
@@ -1,23 +1,5 @@
-import …
-import shutil
-import os
+from rainbowneko.tools.init_proj import copy_package_data

 def main():
-    … (2 removed lines; content not recoverable from this rendering)
-    prefix = os.path.join(prefix, 'local')
-    try:
-        if os.path.exists(r'./cfgs'):
-            shutil.rmtree(r'./cfgs')
-        if os.path.exists(r'./prompt_tuning_template'):
-            shutil.rmtree(r'./prompt_tuning_template')
-        shutil.copytree(os.path.join(prefix, 'hcpdiff/cfgs'), r'./cfgs')
-        shutil.copytree(os.path.join(prefix, 'hcpdiff/prompt_tuning_template'), r'./prompt_tuning_template')
-    except:
-        try:
-            shutil.copytree(os.path.join(prefix, '../hcpdiff/cfgs'), r'./cfgs')
-            shutil.copytree(os.path.join(prefix, '../hcpdiff/prompt_tuning_template'), r'./prompt_tuning_template')
-        except:
-            this_file_dir = os.path.dirname(os.path.abspath(__file__))
-            shutil.copytree(os.path.join(this_file_dir, '../../cfgs'), r'./cfgs')
-            shutil.copytree(os.path.join(this_file_dir, '../../prompt_tuning_template'), r'./prompt_tuning_template')
+    copy_package_data('hcpdiff', 'cfgs', './cfgs')
+    copy_package_data('hcpdiff', 'prompt_template', './prompt_template')