hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from .
-from .train_ac_single import TrainerSingleCard
-from .visualizer import Visualizer
-from .visualizer_reloadable import VisualizerReloadable
+#from .train_ac_old import Trainer
+#from .train_ac_single import TrainerSingleCard
+# from .visualizer import Visualizer
+# from .visualizer_reloadable import VisualizerReloadable
hcpdiff/ckpt_manager/__init__.py
CHANGED
@@ -1,5 +1,4 @@
-from .
-
-
-
-    return CkptManagerSafe() if ckpt_path.endswith('.safetensors') else CkptManagerPKL()
+from .format import EmbFormat, DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSDXLFormat, \
+    OfficialSD15Format
+from .ckpt import EmbSaver, easy_emb_saver
+from .loader import HCPLoraLoader
hcpdiff/ckpt_manager/ckpt.py
ADDED
@@ -0,0 +1,24 @@
+from rainbowneko.ckpt_manager import NekoSaver, CkptFormat, LocalCkptSource, PKLFormat
+from torch import nn
+from typing import Dict, Any
+
+class EmbSaver(NekoSaver):
+    def __init__(self, format: CkptFormat, source: LocalCkptSource, target_key='embs', prefix=None):
+        super().__init__(format, source)
+        self.target_key = target_key
+        self.prefix = prefix
+
+    def save_to(self, name, model: nn.Module, plugin_groups: Dict[str, Any], model_ema=None, exclude_key=None,
+                name_template=None):
+        train_pts = plugin_groups[self.target_key]
+        for pt_name, pt in train_pts.items():
+            self.save(pt_name, (pt_name, pt), prefix=self.prefix)
+            if name_template is not None:
+                pt_name = name_template.format(pt_name)
+                self.save(pt_name, (pt_name, pt), prefix=self.prefix)
+
+def easy_emb_saver():
+    return EmbSaver(
+        format=PKLFormat(),
+        source=LocalCkptSource(),
+    )
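Usage note (not from the package docs): EmbSaver replaces the 0.9 CkptManager* classes for embedding checkpoints. A minimal sketch, assuming the rainbowneko saver backend writes relative to the working directory; the embedding name and tensor shape are examples:

    import torch
    from hcpdiff.ckpt_manager import easy_emb_saver

    emb = torch.randn(4, 768)  # example: a 4-token SD1.5 embedding
    saver = easy_emb_saver()
    # plugin_groups must carry the trained embeddings under target_key ('embs')
    saver.save_to('embs', model=None, plugin_groups={'embs': {'my_word': emb}})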
hcpdiff/ckpt_manager/format/diffusers.py
ADDED
@@ -0,0 +1,59 @@
+import torch
+from diffusers import ModelMixin, AutoencoderKL, UNet2DConditionModel, PixArtTransformer2DModel
+from rainbowneko.ckpt_manager.format import CkptFormat
+from transformers import CLIPTextModel, AutoTokenizer, T5EncoderModel
+
+from hcpdiff.diffusion.sampler import DDPMSampler, DDPMDiscreteSigmaScheduler
+from hcpdiff.models.compose import SDXLTokenizer, SDXLTextEncoder
+
+class DiffusersModelFormat(CkptFormat):
+    def __init__(self, builder: ModelMixin):
+        self.builder = builder
+
+    def save_ckpt(self, sd_model: ModelMixin, save_f: str, **kwargs):
+        sd_model.save_pretrained(save_f)
+
+    def load_ckpt(self, ckpt_f: str, map_location="cpu", **kwargs):
+        self.builder.from_pretrained(ckpt_f, **kwargs)
+
+class DiffusersSD15Format(CkptFormat):
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        denoiser = denoiser or UNet2DConditionModel.from_pretrained(
+            pretrained_model, subfolder="unet", revision=revision, torch_dtype=dtype
+        )
+        vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+
+        TE = TE or CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
+        tokenizer = tokenizer or AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
+
+        return dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
+
+class DiffusersSDXLFormat(CkptFormat):
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        denoiser = denoiser or UNet2DConditionModel.from_pretrained(
+            pretrained_model, subfolder="unet", revision=revision, torch_dtype=dtype
+        )
+        vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+
+        TE = TE or SDXLTextEncoder.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
+        tokenizer = tokenizer or SDXLTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
+
+        return dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
+
+class DiffusersPixArtFormat(CkptFormat):
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        denoiser = denoiser or PixArtTransformer2DModel.from_pretrained(
+            pretrained_model, subfolder="transformer", revision=revision, torch_dtype=dtype
+        )
+        vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+
+        TE = TE or T5EncoderModel.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
+        tokenizer = tokenizer or AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
+
+        return dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
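Usage note (not from the package docs): all three Diffusers formats return the same part dict (denoiser/TE/vae/noise_sampler/tokenizer), so they are interchangeable in loader configs. A minimal sketch; the model id is an example:

    import torch
    from hcpdiff.ckpt_manager import DiffusersSD15Format

    parts = DiffusersSD15Format().load_ckpt('runwayml/stable-diffusion-v1-5', dtype=torch.float16)
    denoiser, vae, tokenizer = parts['denoiser'], parts['vae'], parts['tokenizer']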
hcpdiff/ckpt_manager/format/emb.py
ADDED
@@ -0,0 +1,21 @@
+from typing import Tuple
+
+import torch
+from rainbowneko.ckpt_manager.format import CkptFormat
+from torch.serialization import FILE_LIKE
+
+class EmbFormat(CkptFormat):
+    EXT = 'pt'
+
+    def save_ckpt(self, sd_model: Tuple[str, torch.Tensor], save_f: FILE_LIKE):
+        name, emb = sd_model
+        torch.save({'string_to_param':{'*':emb}, 'name':name}, save_f)
+
+    def load_ckpt(self, ckpt_f: FILE_LIKE, map_location="cpu"):
+        state = torch.load(ckpt_f, map_location=map_location)
+        if 'string_to_param' in state:
+            emb = state['string_to_param']['*']
+        else:
+            emb = state['emb_params']
+        emb.requires_grad_(False)
+        return emb
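Usage note (not from the package docs): EmbFormat writes the webui-style 'string_to_param' layout and can read both that and the 'emb_params' layout. A round-trip sketch; names and shape are examples:

    import torch
    from hcpdiff.ckpt_manager import EmbFormat

    fmt = EmbFormat()
    fmt.save_ckpt(('my_style', torch.randn(4, 768)), 'my_style.pt')
    emb = fmt.load_ckpt('my_style.pt')  # (4, 768) tensor with requires_grad disabled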
hcpdiff/ckpt_manager/format/lora_webui.py
ADDED
@@ -0,0 +1,244 @@
+import math
+import re
+from typing import List, Dict, Any
+
+from rainbowneko.ckpt_manager.format import CkptFormat
+from torch.serialization import FILE_LIKE
+
+class LoraConverter:
+    com_name_unet = ['down_blocks', 'up_blocks', 'mid_block', 'transformer_blocks', 'to_q', 'to_k', 'to_v', 'to_out', 'proj_in', 'proj_out',
+                     'input_blocks', 'middle_block', 'output_blocks']
+    com_name_TE = ['self_attn', 'q_proj', 'v_proj', 'k_proj', 'out_proj', 'text_model']
+    prefix_unet = 'lora_unet_'
+    prefix_TE = 'lora_te_'
+    prefix_TE_xl_clip_L = 'lora_te1_'
+    prefix_TE_xl_clip_bigG = 'lora_te2_'
+
+    lora_w_map = {'lora_down.weight':'W_down', 'lora_up.weight':'W_up'}
+
+    def __init__(self):
+        self.com_name_unet_tmp = [x.replace('_', '%') for x in self.com_name_unet]
+        self.com_name_TE_tmp = [x.replace('_', '%') for x in self.com_name_TE]
+
+    def convert_from_webui(self, state, auto_scale_alpha=False, sdxl=False):
+        if not sdxl:
+            sd_unet = self.convert_from_webui_(state, prefix=self.prefix_unet, com_name=self.com_name_unet, com_name_tmp=self.com_name_unet_tmp)
+            sd_TE = self.convert_from_webui_(state, prefix=self.prefix_TE, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
+        else:
+            sd_unet = self.convert_from_webui_xl_unet_(state, prefix=self.prefix_unet, com_name=self.com_name_unet,
+                                                       com_name_tmp=self.com_name_unet_tmp)
+            sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_L, com_name=self.com_name_TE,
+                                                   com_name_tmp=self.com_name_TE_tmp)
+            sd_TE2 = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_bigG, com_name=self.com_name_TE,
+                                                    com_name_tmp=self.com_name_TE_tmp)
+            sd_TE.update(sd_TE2)
+
+        if auto_scale_alpha:
+            sd_unet = self.alpha_scale_from_webui(sd_unet)
+            sd_TE = self.alpha_scale_from_webui(sd_TE)
+        return {'plugin':sd_TE}, {'plugin':sd_unet}
+
+    def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
+        sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
+        if sdxl:
+            sd_TE = self.convert_to_webui_xl_(sd_TE, prefix=self.prefix_TE)
+        else:
+            sd_TE = self.convert_to_webui_(sd_TE, prefix=self.prefix_TE)
+        sd_unet.update(sd_TE)
+        if auto_scale_alpha:
+            sd_unet = self.alpha_scale_to_webui(sd_unet)
+        return sd_unet
+
+    def convert_from_webui_(self, state, prefix, com_name, com_name_tmp):
+        state = {k:v for k, v in state.items() if k.startswith(prefix)}
+        prefix_len = len(prefix)
+        sd_covert = {}
+        for k, v in state.items():
+            model_k, lora_k = k[prefix_len:].split('.', 1)
+            model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
+            if lora_k == 'alpha':
+                sd_covert[f'{model_k}.___.{lora_k}'] = v
+            else:
+                sd_covert[f'{model_k}.___.layer.{self.lora_w_map[lora_k]}'] = v
+        return sd_covert
+
+    def convert_to_webui_(self, state, prefix):
+        sd_covert = {}
+        for k, v in state.items():
+            if k.endswith('W_down'):
+                model_k, _ = k.split('.___.', 1)
+                lora_k = 'lora_down.weight'
+            elif k.endswith('W_up'):
+                model_k, _ = k.split('.___.', 1)
+                lora_k = 'lora_up.weight'
+            else:
+                model_k, lora_k = k.split('.___.', 1)
+
+            sd_covert[f"{prefix}{model_k.replace('.', '_')}.{lora_k}"] = v
+        return sd_covert
+
+    def convert_to_webui_xl_(self, state, prefix):
+        sd_convert = {}
+        for k, v in state.items():
+            if k.endswith('W_down'):
+                model_k, _ = k.split('.___.', 1)
+                lora_k = 'lora_down.weight'
+            elif k.endswith('W_up'):
+                model_k, _ = k.split('.___.', 1)
+                lora_k = 'lora_up.weight'
+            else:
+                model_k, lora_k = k.split('.___.', 1)
+
+            new_k = f"{prefix}{model_k.replace('.', '_')}.{lora_k}"
+            if 'clip' in new_k:
+                new_k = new_k.replace('_clip_L', '1') if 'clip_L' in new_k else new_k.replace('_clip_bigG', '2')
+            sd_convert[new_k] = v
+        return sd_convert
+
+    def convert_from_webui_xl_te_(self, state, prefix, com_name, com_name_tmp):
+        state = {k:v for k, v in state.items() if k.startswith(prefix)}
+        sd_covert = {}
+        prefix_len = len(prefix)
+
+        for k, v in state.items():
+            model_k, lora_k = k[prefix_len:].split('.', 1)
+            model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
+            if prefix == 'lora_te1_':
+                model_k = f'clip_L.{model_k}'
+            else:
+                model_k = f'clip_bigG.{model_k}'
+
+            if lora_k == 'alpha':
+                sd_covert[f'{model_k}.___.{lora_k}'] = v
+            else:
+                sd_covert[f'{model_k}.___.layer.{self.lora_w_map[lora_k]}'] = v
+        return sd_covert
+
+    def convert_from_webui_xl_unet_(self, state, prefix, com_name, com_name_tmp):
+        # Down:
+        # 4 -> 1, 0  4 = 1 + 3 * 1 + 0
+        # 5 -> 1, 1  5 = 1 + 3 * 1 + 1
+        # 7 -> 2, 0  7 = 1 + 3 * 2 + 0
+        # 8 -> 2, 1  8 = 1 + 3 * 2 + 1
+
+        # Up
+        # 0 -> 0, 0  0 = 0 * 3 + 0
+        # 1 -> 0, 1  1 = 0 * 3 + 1
+        # 2 -> 0, 2  2 = 0 * 3 + 2
+        # 3 -> 1, 0  3 = 1 * 3 + 0
+        # 4 -> 1, 1  4 = 1 * 3 + 1
+        # 5 -> 1, 2  5 = 1 * 3 + 2
+
+        down = {
+            '4':[1, 0],
+            '5':[1, 1],
+            '7':[2, 0],
+            '8':[2, 1],
+        }
+        up = {
+            '0':[0, 0],
+            '1':[0, 1],
+            '2':[0, 2],
+            '3':[1, 0],
+            '4':[1, 1],
+            '5':[1, 2],
+        }
+
+        m = []
+
+        def match(key, regex_text):
+            regex = re.compile(regex_text)
+            r = re.match(regex, key)
+            if not r:
+                return False
+
+            m.clear()
+            m.extend(r.groups())
+            return True
+
+        state = {k:v for k, v in state.items() if k.startswith(prefix)}
+        sd_covert = {}
+        prefix_len = len(prefix)
+        for k, v in state.items():
+            model_k, lora_k = k[prefix_len:].split('.', 1)
+
+            model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
+
+            if match(model_k, r'input_blocks.(\d+).1.(.+)'):
+                new_k = f'down_blocks.{down[m[0]][0]}.attentions.{down[m[0]][1]}.{m[1]}'
+            elif match(model_k, r'middle_block.1.(.+)'):
+                new_k = f'mid_block.attentions.0.{m[0]}'
+                pass
+            elif match(model_k, r'output_blocks.(\d+).(\d+).(.+)'):
+                new_k = f'up_blocks.{up[m[0]][0]}.attentions.{up[m[0]][1]}.{m[2]}'
+            else:
+                raise NotImplementedError
+
+            if lora_k == 'alpha':
+                sd_covert[f'{new_k}.___.{lora_k}'] = v
+            else:
+                sd_covert[f'{new_k}.___.layer.{self.lora_w_map[lora_k]}'] = v
+
+        return sd_covert
+
+    @staticmethod
+    def replace_all(data: str, srcs: List[str], dsts: List[str]):
+        for src, dst in zip(srcs, dsts):
+            data = data.replace(src, dst)
+        return data
+
+    @staticmethod
+    def alpha_scale_from_webui(state):
+        # Apply to "lora_down" and "lora_up" respectively to prevent overflow
+        for k, v in state.items():
+            if 'W_up' in k:
+                state[k] = v*math.sqrt(v.shape[1])
+            elif 'W_down' in k:
+                state[k] = v*math.sqrt(v.shape[0])
+        return state
+
+    @staticmethod
+    def alpha_scale_to_webui(state):
+        for k, v in state.items():
+            if 'lora_up' in k:
+                state[k] = v*math.sqrt(v.shape[1])
+            elif 'lora_down' in k:
+                state[k] = v*math.sqrt(v.shape[0])
+        return state
+
+class LoraWebuiFormat(CkptFormat):
+    def __init__(self, format, auto_scale_alpha=False):
+        self.converter = LoraConverter()
+        self.auto_scale_alpha = auto_scale_alpha
+        self.format = format
+
+    def save_ckpt(self, sd_model: Dict[str, Any], save_f: FILE_LIKE):
+        sd_denoiser = {k.removeprefix('denoiser.'):v for k, v in sd_model['base'].items() if k.startswith('denoiser.')}
+        sd_TE = {k.removeprefix('TE.'):v for k, v in sd_model['base'].items() if k.startswith('TE.')}
+
+        if len(sd_denoiser)>0 or len(sd_TE)>0:
+            sdxl = False
+            for k in sd_TE.keys():
+                if 'clip_L' in k or 'clip_bigG' in k:
+                    sdxl = True
+                    break
+            sd_webui = self.converter.convert_to_webui(sd_denoiser, sd_TE, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
+        else:
+            sd_webui = self.converter.convert_to_webui(sd_model['base'], {}, auto_scale_alpha=self.auto_scale_alpha)
+
+        self.format.save_ckpt(sd_webui, save_f)
+
+    def load_ckpt(self, ckpt_f: str, map_location="cpu", **kwargs):
+        sd_webui = self.format.load_ckpt(ckpt_f, map_location=map_location, **kwargs)
+
+        sdxl = False
+        for k in sd_webui.keys():
+            if ('lora_te1_' in k or 'lora_te2_' in k or
+                    re.match(r'input_blocks.(\d+).1.(.+)', k) or
+                    re.match(r'middle_block.1.(.+)', k) or
+                    re.match(r'output_blocks.(\d+).(\d+).(.+)', k)):
+                sdxl = True
+                break
+
+        sd_TE, sd_unet = self.converter.convert_from_webui(sd_webui, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
+        return sd_TE, sd_unet
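Usage note (not from the package docs): LoraConverter maps webui `lora_unet_*`/`lora_te_*` keys to HCP's `<module>.___.layer.W_down`/`W_up` plugin keys (plus `.___.alpha`). A one-off conversion sketch; the file name is an example:

    from safetensors.torch import load_file
    from hcpdiff.ckpt_manager.format.lora_webui import LoraConverter

    state = load_file('lora_webui.safetensors')
    sd_TE, sd_unet = LoraConverter().convert_from_webui(state, auto_scale_alpha=False, sdxl=False)
    # each result is {'plugin': {...}}, ready to load as HCP plugin state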
hcpdiff/ckpt_manager/format/sd_single.py
ADDED
@@ -0,0 +1,41 @@
+import torch
+from diffusers import AutoencoderKL, StableDiffusionPipeline, StableDiffusionXLPipeline
+from rainbowneko.ckpt_manager.format import CkptFormat
+
+from hcpdiff.diffusion.sampler import DDPMSampler, DDPMDiscreteSigmaScheduler
+from hcpdiff.models.compose import SDXLTextEncoder, SDXLTokenizer
+
+class OfficialSD15Format(CkptFormat):
+    # Single file format
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        pipe_args = dict(unet=denoiser, vae=vae, text_encoder=TE, tokenizer=tokenizer)
+        pipe_args = {k:v for k,v in pipe_args.items() if v is not None}
+        pipe = StableDiffusionPipeline.from_single_file(
+            pretrained_model, revision=revision, torch_dtype=dtype, **pipe_args
+        )
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+        return dict(denoiser=pipe.unet, TE=pipe.text_encoder, vae=pipe.vae, noise_sampler=noise_sampler, tokenizer=pipe.tokenizer)
+
+class OfficialSDXLFormat(CkptFormat):
+    # Single file format
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        pipe_args = dict(unet=denoiser, vae=vae)
+        if TE is not None:
+            pipe_args['text_encoder'] = TE.clip_L
+            pipe_args['text_encoder_2'] = TE.clip_bigG
+        if tokenizer is not None:
+            pipe_args['tokenizer'] = tokenizer.clip_L
+            pipe_args['tokenizer_2'] = tokenizer.clip_bigG
+
+        pipe_args = {k:v for k,v in pipe_args.items() if v is not None}
+        pipe = StableDiffusionXLPipeline.from_single_file(
+            pretrained_model, revision=revision, torch_dtype=dtype, **pipe_args
+        )
+
+        noise_sampler = noise_sampler or DDPMSampler(DDPMDiscreteSigmaScheduler())
+        TE = SDXLTextEncoder([('clip_L', pipe.text_encoder), ('clip_bigG', pipe.text_encoder_2)])
+        tokenizer = SDXLTokenizer([('clip_L', pipe.tokenizer), ('clip_bigG', pipe.tokenizer_2)])
+
+        return dict(denoiser=pipe.unet, TE=TE, vae=pipe.vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
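Usage note (not from the package docs): both Official*Format classes delegate to diffusers' from_single_file loaders and repackage the pipeline into the same part dict as the Diffusers formats above. A sketch; the checkpoint path is an example:

    from hcpdiff.ckpt_manager import OfficialSD15Format

    parts = OfficialSD15Format().load_ckpt('v1-5-pruned-emaonly.safetensors')
    denoiser, TE = parts['denoiser'], parts['TE']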
hcpdiff/ckpt_manager/loader.py
ADDED
@@ -0,0 +1,64 @@
+from hcpdiff.models.lora_layers_patch import LoraLayer
+from torch import nn
+from hcpdiff.utils.net_utils import split_module_name
+from rainbowneko.ckpt_manager import NekoPluginLoader, LocalCkptSource, CkptFormat
+from rainbowneko.ckpt_manager.locator import get_match_layers
+from rainbowneko.models.plugin import PluginGroup
+
+def get_lora_rank_and_cls(lora_state):
+    if 'layer.W_down' in lora_state:
+        rank = lora_state['layer.W_down'].shape[0]
+        return LoraLayer, rank
+    else:
+        raise ValueError('Unknown lora format.')
+
+class HCPLoraLoader(NekoPluginLoader):
+    def __init__(self, format: CkptFormat=None, source: LocalCkptSource=None, path: str = None, layers='all', target_plugin=None,
+                 state_prefix=None, base_model_alpha=0.0, load_ema=False, module_to_load='', **plugin_kwargs):
+        super().__init__(format, source, path=path, layers=layers, target_plugin=target_plugin, state_prefix=state_prefix,
+                         base_model_alpha=base_model_alpha, load_ema=load_ema, **plugin_kwargs)
+        self.module_to_load = module_to_load
+
+    def load_to(self, name, model):
+        # get model to load plugin and its named_modules
+        model = model if self.module_to_load == '' else eval(f"model.{self.module_to_load}")
+
+        named_modules = {k:v for k, v in model.named_modules()}
+        plugin_state = self.load(self.path, map_location='cpu')['base_ema' if self.load_ema else 'base']
+
+        # filter layers to load
+        if self.layers != 'all':
+            match_blocks = get_match_layers(self.layers, named_modules)
+            plugin_state = {k: v for blk in match_blocks for k, v in plugin_state.items() if k.startswith(blk)}
+
+        if self.state_prefix:
+            state_prefix_len = len(self.state_prefix)
+            plugin_state = {k[state_prefix_len:]: v for k, v in plugin_state.items() if k.startswith(self.state_prefix)}
+
+        lora_block_state = {}
+        # get all layers in the lora_state
+        for pname, p in plugin_state.items():
+            # lora_block. is the old format
+            prefix, block_name = pname.split('.___.', 1)
+            if prefix not in lora_block_state:
+                lora_block_state[prefix] = {}
+            lora_block_state[prefix][block_name] = p
+
+        # add lora to host and load weights
+        lora_blocks = {}
+        for layer_name, lora_state in lora_block_state.items():
+            lora_layer_cls, rank = get_lora_rank_and_cls(lora_state)
+
+            if 'alpha' in lora_state:
+                lora_state['alpha'] *= self.plugin_kwargs.get('alpha', 1.0)
+
+            parent_name, host_name = split_module_name(layer_name)
+
+            lora_block = lora_layer_cls.wrap_layer(name, named_modules[layer_name], rank=rank, bias='layer.bias' in lora_state,
+                                                   parent_block=named_modules[parent_name], host_name=host_name)
+            lora_block.set_hyper_params(**self.plugin_kwargs)
+            lora_blocks[layer_name] = lora_block
+            load_info = lora_block.load_state_dict(lora_state, strict=False)
+            if len(load_info.unexpected_keys) > 0:
+                print(name, 'unexpected_keys', load_info.unexpected_keys)
+        return PluginGroup(lora_blocks)
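Usage note (not from the package docs): HCPLoraLoader groups the saved state by the `.___.` separator, rebuilds one LoraLayer per host layer via wrap_layer, and returns them as a PluginGroup. A sketch, assuming a wrapper that exposes the UNet as `.denoiser`; path and alpha are examples:

    from hcpdiff.ckpt_manager import HCPLoraLoader

    loader = HCPLoraLoader(path='ckpts/unet-lora.ckpt', module_to_load='denoiser', alpha=0.8)
    plugins = loader.load_to('lora1', model)  # model: an SD15Wrapper-like module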
hcpdiff/data/__init__.py
CHANGED
@@ -1,28 +1,4 @@
-from .
-from .
-from .
-from .
-from .utils import CycleData
-from .caption_loader import JsonCaptionLoader, TXTCaptionLoader
-from .sampler import DistributedCycleSampler, get_sampler
-
-class DataGroup:
-    def __init__(self, loader_list, loss_weights):
-        self.loader_list = loader_list
-        self.loss_weights = loss_weights
-
-    def __iter__(self):
-        self.data_iter_list = [iter(CycleData(loader)) for loader in self.loader_list]
-        return self
-
-    def __next__(self):
-        return [next(data_iter) for data_iter in self.data_iter_list]
-
-    def __len__(self):
-        return len(self.loader_list)
-
-    def get_dataset(self, idx):
-        return self.loader_list[idx].dataset
-
-    def get_loss_weights(self, idx):
-        return self.loss_weights[idx]
+from .dataset import TextImagePairDataset
+from .source import Text2ImageSource, Text2ImageLossMapSource, Text2ImageCondSource, T2IFolderClassSource
+from .handler import StableDiffusionHandler, LossMapHandler, DiffusionImageHandler
+from .cache import VaeCache
hcpdiff/data/cache/__init__.py
ADDED
@@ -0,0 +1 @@
+from .vae import VaeCache
hcpdiff/data/cache/vae.py
ADDED
@@ -0,0 +1,102 @@
+from io import BytesIO
+from pathlib import Path
+from typing import Dict, Any
+
+import lmdb
+import torch
+from hcpdiff.models.wrapper import SD15Wrapper
+from rainbowneko import _share
+from rainbowneko.data import DataCache, CacheableDataset
+from rainbowneko.utils import Path_Like
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+
+class VaeCache(DataCache):
+    def __init__(self, pre_build: Path_Like = None, lazy=False, bs=1):
+        super().__init__(pre_build)
+        self.lazy = lazy
+        self.bs = bs
+
+    def load_latent(self, id):
+        if self.lazy:
+            with self.env.begin() as txn:
+                byte_tensor = txn.get(str(id).encode())
+                return torch.load(BytesIO(byte_tensor))
+        else:
+            return self.cache[id]
+
+    def before_handler(self, index: int, data: Dict[str, Any]):
+        cached_data = self.load_latent(data['id'])
+        data['image'] = cached_data['latent']
+        data['coord'] = cached_data['coord']
+        return data
+
+    def on_finish(self, index, data):
+        return data
+
+    def load(self, path):
+        if self.lazy:
+            self.env = lmdb.open(path, subdir=False, readonly=True, lock=False, readahead=False, meminit=False)
+            return {}
+        elif len(self.cache)>0:
+            return self.cache
+        else:
+            env = lmdb.open(path, subdir=False, readonly=True, lock=False, readahead=False, meminit=False)
+            with env.begin() as txn:
+                cache = {k.decode():torch.load(BytesIO(v)) for k, v in txn.cursor()}
+            env.close()
+            return cache
+
+    def build(self, dataset: CacheableDataset, model: SD15Wrapper, all_gather):
+        if (self.pre_build and Path(self.pre_build).exists()) or len(self.cache)>0:
+            model.vae = None
+            return
+
+        vae = model.vae.to(_share.device)
+        with dataset.disable_cache():
+            dataset.bucket.rest(0)
+
+            loader = DataLoader(
+                dataset,
+                batch_size=self.bs,
+                num_workers=0,
+                sampler=DistributedSampler(dataset, num_replicas=_share.world_size, rank=_share.local_rank, shuffle=False),
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+            )
+
+            if self.pre_build:
+                Path(self.pre_build).parent.mkdir(parents=True, exist_ok=True)
+                env = lmdb.open(self.pre_build, map_size=1099511627776)
+                with env.begin(write=True) as txn:
+                    for data in tqdm(loader):
+                        image = data['image'].to(device=_share.device, dtype=vae.dtype)
+                        latents = model.vae.encode(image).latent_dist.sample()
+                        latents = (latents*vae.config.scaling_factor).cpu()
+
+                        for img_id, latent, coord in zip(data['id'], latents, data['coord']):
+                            data_cache = {'latent': latent, 'coord': coord}
+
+                            byte_stream = BytesIO()
+                            torch.save(data_cache, byte_stream)
+                            txn.put(str(img_id).encode(), byte_stream.getvalue())
+                            if not self.lazy:
+                                self.cache[img_id] = data_cache
+                env.close()
+            else:
+                for data in tqdm(loader):
+                    img_id = data['id']
+                    image = data['image'].to(device=_share.device, dtype=vae.dtype)
+                    latents = model.vae.encode(image).latent_dist.sample()
+                    latents = (latents*vae.config.scaling_factor).cpu()
+                    for img_id, latent, coord in zip(data['id'], latents, data['coord']):
+                        self.cache[img_id] = {'latent': latent, 'coord': coord}
+
+        model.vae.to('cpu')
+        #model.vae = None
+        torch.cuda.empty_cache()
+
+        cache_all = all_gather(self.cache)
+        for cache in cache_all:
+            self.cache.update(cache)
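Usage note (not from the package docs): VaeCache keeps VAE latents in LMDB; with lazy=True only the environment handle is held and each latent is deserialized on access, otherwise load() reads the whole store into memory. A sketch; the path and id are examples:

    from hcpdiff.data import VaeCache

    cache = VaeCache(pre_build='cache/latents.lmdb', lazy=True)
    cache.load('cache/latents.lmdb')            # opens the LMDB env only
    latent = cache.load_latent('0')['latent']   # deserialized per item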
hcpdiff/data/dataset.py
ADDED
@@ -0,0 +1,20 @@
+"""
+pair_dataset.py
+====================
+    :Name:        text-image pair dataset
+    :Author:      Dong Ziyi
+    :Affiliation: HCP Lab, SYSU
+    :Created:     10/03/2023
+    :Licence:     Apache-2.0
+"""
+
+from typing import Union, Dict
+
+from rainbowneko.data import CacheableDataset, BaseDataset, BaseBucket, DataSource, DataHandler, DataCache
+
+def TextImagePairDataset(bucket: BaseBucket = None, source: Dict[str, DataSource] = None, handler: DataHandler = None,
+                         batch_handler: DataHandler = None, cache: DataCache = None, **kwargs) -> Union[CacheableDataset, BaseDataset]:
+    if cache is None:
+        return BaseDataset(bucket=bucket, source=source, handler=handler, batch_handler=batch_handler, **kwargs)
+    else:
+        return CacheableDataset(bucket=bucket, source=source, handler=handler, batch_handler=batch_handler, cache=cache, **kwargs)
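Usage note (not from the package docs): the dataset is now a thin factory over rainbowneko, and passing a cache switches the returned class. A sketch; `bucket`, `source`, and `handler` are assumed to be built elsewhere (e.g. by the hcpdiff.easy cfg presets):

    from hcpdiff.data import TextImagePairDataset, VaeCache

    dataset = TextImagePairDataset(bucket=bucket, source=source, handler=handler,
                                   cache=VaeCache(lazy=True))  # -> CacheableDataset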
hcpdiff/data/handler/controlnet.py
ADDED
@@ -0,0 +1,18 @@
+import torchvision.transforms as T
+from PIL import Image
+from rainbowneko.data import DataHandler, HandlerChain, LoadImageHandler, ImageHandler
+
+class ControlNetHandler(DataHandler):
+    def __init__(self, key_map_in=('cond -> image',), key_map_out=('image -> cond',), bucket=None):
+        super().__init__(key_map_in, key_map_out)
+
+        self.handlers = HandlerChain(
+            load=LoadImageHandler(),
+            bucket=bucket.handler if bucket else DataHandler(),
+            image=ImageHandler(
+                transform=T.ToTensor(),
+            )
+        )
+
+    def handle(self, image:Image.Image):
+        return self.handlers(dict(image=image))
|