hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +114 -125
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +78 -46
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/utils/caption_tools.py
DELETED
@@ -1,105 +0,0 @@
```python
"""
caption_tools.py
====================
:Name: process prompts
:Author: Dong Ziyi
:Affiliation: HCP Lab, SYSU
:Created: 10/03/2023
:Licence: Apache-2.0
"""

import random
from string import Formatter
from typing import List, Dict, Union

import numpy as np


class TagShuffle:
    def __call__(self, data):
        if 'caption' in data:
            text = data['caption']
            if text is not None:
                tags = text.split(',')
                random.shuffle(tags)
                data['caption'] = ','.join(tags)
            return data
        else:
            for i, item in enumerate(data['prompt']):
                tags = item.split(',')
                random.shuffle(tags)
                data['prompt'][i] = ','.join(tags)
            return data

    def __repr__(self):
        return 'TagShuffle()'


class TagDropout:
    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, data):
        if 'caption' in data:
            text = data['caption']
            if text is not None:
                tags = np.array(text.split(','))
                data['caption'] = ','.join(tags[np.random.random(len(tags)) > self.p])
            return data
        else:
            for i, item in enumerate(data['prompt']):
                tags = item.split(',')
                data['prompt'][i] = ','.join(tags[np.random.random(len(tags)) > self.p])
            return data

    def __repr__(self):
        return f'TagDropout(p={self.p})'

class TagErase:
    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, data):
        for i, item in enumerate(data['prompt']):
            if random.random()<self.p:
                data['prompt'][i] = ''
        return data

    def __repr__(self):
        return f'TagErase(p={self.p})'

class TemplateFill:
    def __init__(self, word_names: Dict[str, Union[str, List[str]]]):
        self.word_names = word_names
        self.DA_names = {k: v for k, v in word_names.items() if not isinstance(v, str)}
        self.dream_artist = len(self.DA_names) > 0

    def __call__(self, data):
        template, caption = data['prompt'], data['caption']

        keys_need = {i[1] for i in Formatter().parse(template) if i[1] is not None}
        fill_dict = {k: v for k, v in self.word_names.items() if k in keys_need}

        if (caption is not None) and ('caption' in keys_need):
            if self.dream_artist:
                cap_fill = fill_dict.get('caption', [None, None])
                fill_dict.update(caption=[cap_fill[0] or caption, cap_fill[1] or caption])
            else:
                fill_dict.update(caption=fill_dict.get('caption', None) or caption)

        # skip keys that not provide
        for k in keys_need:
            if k not in fill_dict:
                fill_dict[k] = ''

        if self.dream_artist:
            fill_dict_pos = {k: ((v if isinstance(v, str) else v[0]) or '') for k, v in fill_dict.items()}
            fill_dict_neg = {k: ((v if isinstance(v, str) else v[1]) or '') for k, v in fill_dict.items()}
            return {'prompt':[template.format(**fill_dict_neg), template.format(**fill_dict_pos)]}
        else:
            # replace None value with ''
            fill_dict = {k:(v or '') for k, v in fill_dict.items()}
            return {'prompt':[template.format(**fill_dict)]}

    def __repr__(self):
        return f'TemplateFill(\nword_names={self.word_names}\n)'
```
hcpdiff/utils/cfg_net_tools.py
DELETED
@@ -1,321 +0,0 @@
```python
"""
cfg_net_tools.py
====================
:Name: creat model and plugin from config
:Author: Dong Ziyi
:Affiliation: HCP Lab, SYSU
:Created: 10/03/2023
:Licence: Apache-2.0
"""
import warnings
from typing import Dict, List, Tuple, Union, Any

import re
import torch
from torch import nn

from .utils import net_path_join
from hcpdiff.models import LoraBlock, LoraGroup, lora_layer_map
from hcpdiff.models.plugin import SinglePluginBlock, MultiPluginBlock, PluginBlock, PluginGroup, PatchPluginBlock
from hcpdiff.ckpt_manager import auto_manager
from .net_utils import split_module_name
from hcpdiff.tools.convert_old_lora import convert_state

def get_class_match_layer(class_name, block:nn.Module):
    if type(block).__name__==class_name:
        return ['']
    else:
        return ['.'+name for name, layer in block.named_modules() if type(layer).__name__==class_name]

def get_match_layers(layers, all_layers, return_metas=False) -> Union[List[str], List[Dict[str, Any]]]:
    res=[]
    for name in layers:
        metas = name.split(':')

        use_re = False
        pre_hook = False
        cls_filter = None
        for meta in metas[:-1]:
            if meta=='re':
                use_re=True
            elif meta=='pre_hook':
                pre_hook=True
            elif meta.startswith('cls('):
                cls_filter=meta[4:-1]

        name = metas[-1]
        if use_re:
            pattern = re.compile(name)
            match_layers = filter(lambda x: pattern.match(x) != None, all_layers.keys())
        else:
            match_layers = [name]

        if cls_filter is not None:
            match_layers_new = []
            for layer in match_layers:
                match_layers_new.extend([layer + x for x in get_class_match_layer(name[1], all_layers[layer])])
            match_layers = match_layers_new

        for layer in match_layers:
            if return_metas:
                res.append({'layer': layer, 'pre_hook': pre_hook})
            else:
                res.append(layer)

    # Remove duplicates and keep the original order
    if return_metas:
        layer_set=set()
        res_unique = []
        for item in res:
            if item['layer'] not in layer_set:
                layer_set.add(item['layer'])
                res_unique.append(item)
        return res_unique
    else:
        return sorted(set(res), key=res.index)

def get_lora_rank_and_cls(lora_state):
    if 'layer.lora_down.weight' in lora_state: # old format
        warnings.warn("The old lora format is deprecated.", DeprecationWarning)
        rank = lora_state['layer.lora_down.weight'].shape[0]
        lora_layer_cls = lora_layer_map['lora']
        return lora_layer_cls, rank, True
    elif 'layer.W_down' in lora_state:
        rank = lora_state['layer.W_down'].shape[0]
        lora_layer_cls = lora_layer_map['lora']
        return lora_layer_cls, rank, False
    else:
        raise ValueError('Unknown lora format.')

def make_hcpdiff(model, cfg_model, cfg_lora, default_lr=1e-5) -> Tuple[List[Dict], Union[LoraGroup, Tuple[LoraGroup, LoraGroup]]]:
    named_modules = {k:v for k,v in model.named_modules()}

    train_params=[]
    all_lora_blocks={}
    all_lora_blocks_neg={}

    if cfg_model is not None:
        for item in cfg_model:
            params_group = []
            for layer_name in get_match_layers(item.layers, named_modules):
                layer = named_modules[layer_name]
                layer.requires_grad_(True)
                layer.train()
                params_group.extend(list(LoraBlock.extract_param_without_lora(layer).values()))
            train_params.append({'params':list(set(params_group)), 'lr':getattr(item, 'lr', default_lr)})

    if cfg_lora is not None:
        for lora_id, item in enumerate(cfg_lora):
            params_group = []
            for layer_name in get_match_layers(item.layers, named_modules):
                parent_name, host_name = split_module_name(layer_name)
                layer = named_modules[layer_name]
                arg_dict = {k:v for k,v in item.items() if k!='layers'}
                lora_block_dict = lora_layer_map[arg_dict.get('type', 'lora')].wrap_model(lora_id, layer, parent_block=named_modules[parent_name], host_name=host_name, **arg_dict)

                for k,v in lora_block_dict.items():
                    block_path = net_path_join(layer_name, k)
                    all_lora_blocks[block_path] = v
                    v.requires_grad_(True)
                    v.train()
                    params_group.extend(v.parameters())

            train_params.append({'params': params_group, 'lr':getattr(item, 'lr', default_lr)})

    if len(all_lora_blocks_neg)>0:
        return train_params, (LoraGroup(all_lora_blocks), LoraGroup(all_lora_blocks_neg))
    else:
        return train_params, LoraGroup(all_lora_blocks)

def make_plugin(model, cfg_plugin, default_lr=1e-5) -> Tuple[List, Dict[str, PluginGroup]]:
    train_params=[]
    all_plugin_group={}

    if cfg_plugin is None:
        return train_params, all_plugin_group

    named_modules = {k: v for k, v in model.named_modules()}

    # builder: functools.partial
    for plugin_name, builder in cfg_plugin.items():
        all_plugin_blocks={}

        lr = builder.keywords.pop('lr') if 'lr' in builder.keywords else default_lr
        train_plugin = builder.keywords.pop('train') if 'train' in builder.keywords else True
        plugin_class = getattr(builder.func, '__self__', builder.func) # support static or class method

        params_group = []
        if issubclass(plugin_class, MultiPluginBlock):
            from_layers = [{**item, 'layer':named_modules[item['layer']]} for item in get_match_layers(builder.keywords.pop('from_layers'), named_modules, return_metas=True)]
            to_layers = [{**item, 'layer':named_modules[item['layer']]} for item in get_match_layers(builder.keywords.pop('to_layers'), named_modules, return_metas=True)]

            layer = builder(name=plugin_name, host_model=model, from_layers=from_layers, to_layers=to_layers)
            if train_plugin:
                layer.train()
                params = layer.get_trainable_parameters()
                for p in params:
                    p.requires_grad_(True)
                    params_group.append(p)
            else:
                layer.requires_grad_(False)
                layer.eval()
            all_plugin_blocks[''] = layer
        elif issubclass(plugin_class, SinglePluginBlock):
            layers_name = builder.keywords.pop('layers')
            for layer_name in get_match_layers(layers_name, named_modules):
                blocks = builder(name=plugin_name, host_model=model, host=named_modules[layer_name])
                if not isinstance(blocks, dict):
                    blocks={'':blocks}

                for k,v in blocks.items():
                    all_plugin_blocks[net_path_join(layer_name, k)] = v
                    if train_plugin:
                        v.train()
                        params = v.get_trainable_parameters()
                        for p in params:
                            p.requires_grad_(True)
                            params_group.append(p)
                    else:
                        v.requires_grad_(False)
                        v.eval()
        elif issubclass(plugin_class, PluginBlock):
            from_layer = get_match_layers(builder.keywords.pop('from_layer'), named_modules, return_metas=True)
            to_layer = get_match_layers(builder.keywords.pop('to_layer'), named_modules, return_metas=True)

            for from_layer_meta, to_layer_meta in zip(from_layer, to_layer):
                from_layer_name=from_layer_meta['layer']
                from_layer_meta['layer']=named_modules[from_layer_name]
                to_layer_meta['layer']=named_modules[to_layer_meta['layer']]
                layer = builder(name=plugin_name, host_model=model, from_layer=from_layer_meta, to_layer=to_layer_meta)
                if train_plugin:
                    layer.train()
                    params = layer.get_trainable_parameters()
                    for p in params:
                        p.requires_grad_(True)
                        params_group.append(p)
                else:
                    layer.requires_grad_(False)
                    layer.eval()
                all_plugin_blocks[from_layer_name] = layer
        elif issubclass(plugin_class, PatchPluginBlock):
            layers_name = builder.keywords.pop('layers')
            for layer_name in get_match_layers(layers_name, named_modules):
                parent_name, host_name = split_module_name(layer_name)
                layers = builder(name=plugin_name, host_model=model, host=named_modules[layer_name],
                                 parent_block=named_modules[parent_name], host_name=host_name)
                if not isinstance(layers, dict):
                    layers={'':layers}

                for k,v in layers.items():
                    all_plugin_blocks[net_path_join(layer_name, k)] = v
                    if train_plugin:
                        v.train()
                        params = v.get_trainable_parameters()
                        for p in params:
                            p.requires_grad_(True)
                            params_group.append(p)
                    else:
                        v.requires_grad_(False)
                        v.eval()
        else:
            raise NotImplementedError(f'Unknown plugin {plugin_class}')
        if train_plugin:
            train_params.append({'params':params_group, 'lr':lr})
        all_plugin_group[plugin_name] = PluginGroup(all_plugin_blocks)
    return train_params, all_plugin_group

class HCPModelLoader:
    def __init__(self, host):
        self.host = host
        self.named_modules = {k:v for k, v in host.named_modules()}
        self.named_params = {k:v for k, v in host.named_parameters()}

    @torch.no_grad()
    def load_part(self, cfg, base_model_alpha=0.0, load_ema=False):
        if cfg is None:
            return
        for item in cfg:
            part_state = auto_manager(item.path).load_ckpt(item.path, map_location='cpu')['base_ema' if load_ema else 'base']
            layers = item.get('layers', 'all')
            if layers == 'all':
                for k, v in part_state.items():
                    self.named_params[k].data = base_model_alpha * self.named_params[k].data + item.alpha * v
            else:
                match_blocks = get_match_layers(layers, self.named_modules)
                state_add = {k:v for blk in match_blocks for k,v in part_state.items() if k.startswith(blk)}
                for k, v in state_add.items():
                    self.named_params[k].data = base_model_alpha * self.named_params[k].data + item.alpha * v

    @torch.no_grad()
    def load_lora(self, cfg, base_model_alpha=1.0, load_ema=False):
        if cfg is None:
            return

        all_lora_blocks = {}
        for lora_id, item in enumerate(cfg):
            lora_state = auto_manager(item.path).load_ckpt(item.path, map_location='cpu')['lora_ema' if load_ema else 'lora']
            lora_block_state = {}
            # get all layers in the lora_state
            for name, p in lora_state.items():
                # lora_block. is the old format
                prefix, block_name = name.split('.___.' if name.rfind('lora_block.')==-1 else '.lora_block.', 1)
                if prefix not in lora_block_state:
                    lora_block_state[prefix] = {}
                lora_block_state[prefix][block_name] = p
            # get selected layers
            layers = item.get('layers', 'all')
            if layers != 'all':
                match_blocks = get_match_layers(layers, self.named_modules)
                lora_state_new = {}
                for k, v in lora_block_state.items():
                    for mk in match_blocks:
                        if k.startswith(mk):
                            lora_state_new[k]=v
                            break
                lora_block_state = lora_state_new
            # add lora to host and load weights
            for layer_name, lora_state in lora_block_state.items():
                parent_name, host_name = split_module_name(layer_name)
                lora_layer_cls, rank, old_format = get_lora_rank_and_cls(lora_state)
                if 'alpha' in lora_state:
                    del lora_state['alpha']

                if old_format:
                    lora_state = convert_state(lora_state)

                lora_block = lora_layer_cls.wrap_layer(lora_id, self.named_modules[layer_name], rank=rank, dropout=getattr(item, 'dropout', 0.0),
                                                       alpha=getattr(item, 'alpha', 1.0), bias='layer.bias' in lora_state, alpha_auto_scale=getattr(item, 'alpha_auto_scale', True),
                                                       parent_block=self.named_modules[parent_name], host_name=host_name)
                all_lora_blocks[f'{layer_name}.{lora_block.name}'] = lora_block
                lora_block.load_state_dict(lora_state, strict=False)
                lora_block.to(self.host.device)
        return LoraGroup(all_lora_blocks)

    @torch.no_grad()
    def load_plugin(self, cfg, load_ema=False):
        if cfg is None:
            return

        for name, item in cfg.items():
            plugin_state = auto_manager(item.path).load_ckpt(item.path, map_location='cpu')['plugin_ema' if load_ema else 'plugin']
            layers = item.get('layers', 'all')
            if layers != 'all':
                match_blocks = get_match_layers(layers, self.named_modules)
                plugin_state = {k:v for blk in match_blocks for k, v in plugin_state.items() if k.startswith(blk)}
            plugin_key_set = set([k.split('___', 1)[0]+name for k in plugin_state.keys()])
            plugin_state = {k.replace('___', name):v for k, v in plugin_state.items()} # replace placeholder to target plugin name
            self.host.load_state_dict(plugin_state, strict=False)
            if 'layers' in item:
                del item.layers
            del item.path
            if hasattr(self.host, name): # MultiPluginBlock
                getattr(self.host, name).set_hyper_params(**item)
            else:
                for plugin_key in plugin_key_set:
                    self.named_modules[plugin_key].set_hyper_params(**item)

    def load_all(self, cfg_merge, load_ema=False):
        self.load_part(cfg_merge.get('part', []), base_model_alpha=cfg_merge.get('base_model_alpha', 0.0), load_ema=load_ema)
        lora_group = self.load_lora(cfg_merge.get('lora', []), base_model_alpha=cfg_merge.get('base_model_alpha', 1.0), load_ema=load_ema)
        self.load_plugin(cfg_merge.get('plugin', {}), load_ema=load_ema)
        return lora_group
```
hcpdiff/utils/cfg_resolvers.py
DELETED
@@ -1,16 +0,0 @@
```python
import time
import warnings
from omegaconf import OmegaConf
import torch
from .net_utils import dtype_dict

def times(a,b):
    warnings.warn(f"${{times:{a},{b}}} is deprecated and will be removed in the future. Please use ${{hcp.eval:{a}*{b}}} instead.", DeprecationWarning)
    return a*b

OmegaConf.register_new_resolver("times", times)

OmegaConf.register_new_resolver("hcp.eval", lambda exp: eval(exp))
OmegaConf.register_new_resolver("hcp.time", lambda format="%Y-%m-%d-%H-%M-%S": time.strftime(format))

OmegaConf.register_new_resolver("hcp.dtype", lambda dtype: dtype_dict.get(dtype, torch.float32))
```
hcpdiff/utils/ema.py
DELETED
@@ -1,52 +0,0 @@
```python
import torch
from torch import nn
from copy import deepcopy
from typing import Iterable, Tuple, Dict
import numpy as np

class ModelEMA:
    def __init__(self, model: nn.Module, decay_max=0.9997, inv_gamma=1., power=2/3, start_step=0, device='cpu'):
        self.train_params = {name:p.data.to(device) for name, p in model.named_parameters() if p.requires_grad}
        self.train_params.update({name:p.to(device) for name, p in model.named_buffers()})
        self.decay_max = decay_max
        self.inv_gamma = inv_gamma
        self.power = power
        self.step = start_step
        self.device=device

    @torch.no_grad()
    def update(self, model: nn.Module):
        self.step += 1
        # Compute the decay factor for the exponential moving average.
        decay = 1-(1+self.step/self.inv_gamma)**-self.power
        decay = np.clip(decay, 0., self.decay_max)

        for name, param in model.named_parameters():
            if name in self.train_params:
                self.train_params[name].lerp_(param.data.to(self.device), 1-decay) # (1-e)x + e*x_

        for name, param in model.named_buffers():
            if name in self.train_params:
                self.train_params[name].copy_(param.to(self.device))

        #torch.cuda.empty_cache()

    def copy_to(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if name in self.train_params:
                param.data.copy_(self.train_params[name])

    def to(self, device=None, dtype=None):
        # .to() on the tensors handles None correctly
        self.train_params = {
            name:(p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)) for name, p in self.train_params.items()
        }
        return self

    def state_dict(self) -> Dict[str, torch.Tensor]:
        return self.train_params

    def load_state_dict(self, state: Dict[str, torch.Tensor]):
        for k, v in state:
            if k in self.train_params:
                self.train_params[k]=v
```