hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/container.py +1 -1
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/embedding_convert.py +6 -2
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +19 -15
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +790 -0
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +46 -33
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +128 -136
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -68
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +84 -52
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -565
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/visualizer.py +0 -258
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.0.dist-info/METADATA +0 -199
- hcpdiff-0.9.0.dist-info/RECORD +0 -155
- hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/tools/lora_convert.py
CHANGED
@@ -3,14 +3,14 @@ import os.path
|
|
3
3
|
from typing import List
|
4
4
|
import math
|
5
5
|
|
6
|
-
from
|
6
|
+
from rainbowneko.ckpt_manager import auto_ckpt_loader, NekoModelSaver
|
7
7
|
|
8
8
|
class LoraConverter:
|
9
9
|
com_name_unet = ['down_blocks', 'up_blocks', 'mid_block', 'transformer_blocks', 'to_q', 'to_k', 'to_v', 'to_out', 'proj_in', 'proj_out', 'input_blocks', 'middle_block', 'output_blocks']
|
10
10
|
com_name_TE = ['self_attn', 'q_proj', 'v_proj', 'k_proj', 'out_proj', 'text_model']
|
11
11
|
prefix_unet = 'lora_unet_'
|
12
12
|
prefix_TE = 'lora_te_'
|
13
|
-
|
13
|
+
prefix_TE_xl_clip_L = 'lora_te1_'
|
14
14
|
prefix_TE_xl_clip_bigG = 'lora_te2_'
|
15
15
|
|
16
16
|
lora_w_map = {'lora_down.weight': 'W_down', 'lora_up.weight':'W_up'}
|
@@ -25,14 +25,14 @@ class LoraConverter:
|
|
25
25
|
sd_TE = self.convert_from_webui_(state, prefix=self.prefix_TE, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
|
26
26
|
else:
|
27
27
|
sd_unet = self.convert_from_webui_xl_unet_(state, prefix=self.prefix_unet, com_name=self.com_name_unet, com_name_tmp=self.com_name_unet_tmp)
|
28
|
-
sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.
|
28
|
+
sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_L, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
|
29
29
|
sd_TE2 = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_bigG, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
|
30
30
|
sd_TE.update(sd_TE2)
|
31
31
|
|
32
32
|
if auto_scale_alpha:
|
33
33
|
sd_unet = self.alpha_scale_from_webui(sd_unet)
|
34
34
|
sd_TE = self.alpha_scale_from_webui(sd_TE)
|
35
|
-
return {'
|
35
|
+
return {'plugin': sd_TE}, {'plugin': sd_unet}
|
36
36
|
|
37
37
|
def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
|
38
38
|
sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
|
@@ -72,7 +72,7 @@ class LoraConverter:
|
|
72
72
|
|
73
73
|
sd_covert[f"{prefix}{model_k.replace('.', '_')}.{lora_k}"] = v
|
74
74
|
return sd_covert
|
75
|
-
|
75
|
+
|
76
76
|
def convert_to_webui_xl_(self, state, prefix):
|
77
77
|
sd_convert = {}
|
78
78
|
for k, v in state.items():
|
@@ -87,7 +87,7 @@ class LoraConverter:
|
|
87
87
|
|
88
88
|
new_k = f"{prefix}{model_k.replace('.', '_')}.{lora_k}"
|
89
89
|
if 'clip' in new_k:
|
90
|
-
new_k = new_k.replace('
|
90
|
+
new_k = new_k.replace('_clip_L', '1') if 'clip_L' in new_k else new_k.replace('_clip_bigG', '2')
|
91
91
|
sd_convert[new_k] = v
|
92
92
|
return sd_convert
|
93
93
|
|
@@ -100,7 +100,7 @@ class LoraConverter:
|
|
100
100
|
model_k, lora_k = k[prefix_len:].split('.', 1)
|
101
101
|
model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
|
102
102
|
if prefix == 'lora_te1_':
|
103
|
-
model_k = f'
|
103
|
+
model_k = f'clip_L.{model_k}'
|
104
104
|
else:
|
105
105
|
model_k = f'clip_bigG.{model_k}'
|
106
106
|
|
@@ -221,23 +221,27 @@ if __name__ == '__main__':
|
|
221
221
|
|
222
222
|
# load lora model
|
223
223
|
print('convert lora model')
|
224
|
-
|
224
|
+
ckpt_loader = auto_ckpt_loader(args.lora_path)
|
225
|
+
ckpt_saver = NekoModelSaver(
|
226
|
+
format=ckpt_loader.format,
|
227
|
+
source=ckpt_loader.source,
|
228
|
+
)
|
225
229
|
|
226
230
|
if args.from_webui:
|
227
|
-
state =
|
231
|
+
state = ckpt_loader.load(args.lora_path)
|
228
232
|
# convert the weight name
|
229
233
|
sd_TE, sd_unet = converter.convert_from_webui(state, auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
|
230
234
|
# wegiht save
|
231
235
|
os.makedirs(args.dump_path, exist_ok=True)
|
232
236
|
TE_path = os.path.join(args.dump_path, 'TE-'+lora_name)
|
233
237
|
unet_path = os.path.join(args.dump_path, 'unet-'+lora_name)
|
234
|
-
|
235
|
-
|
238
|
+
ckpt_saver.save(sd_TE, TE_path)
|
239
|
+
ckpt_saver.save(sd_unet, unet_path)
|
236
240
|
print('save text encoder lora to:', TE_path)
|
237
241
|
print('save unet lora to:', unet_path)
|
238
242
|
elif args.to_webui:
|
239
|
-
sd_unet =
|
240
|
-
sd_TE =
|
241
|
-
state = converter.convert_to_webui(sd_unet['
|
242
|
-
|
243
|
+
sd_unet = ckpt_loader.load(args.lora_path)
|
244
|
+
sd_TE = ckpt_loader.load(args.lora_path_TE) if args.lora_path_TE else {'base':{}}
|
245
|
+
state = converter.convert_to_webui(sd_unet['base'], sd_TE['base'], auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
|
246
|
+
ckpt_saver.save(state, args.dump_path)
|
243
247
|
print('save lora to:', args.dump_path)
|
@@ -0,0 +1,12 @@
|
|
1
|
+
from diffusers import DiffusionPipeline
|
2
|
+
import argparse
|
3
|
+
|
4
|
+
parser = argparse.ArgumentParser()
|
5
|
+
parser.add_argument("model", default=None, type=str)
|
6
|
+
parser.add_argument("output", default=None, type=str)
|
7
|
+
args = parser.parse_args()
|
8
|
+
|
9
|
+
pipe = DiffusionPipeline.from_pretrained(args.model, safety_checker=None, requires_safety_checker=False,
|
10
|
+
resume_download=True)
|
11
|
+
|
12
|
+
pipe.save_pretrained(args.output)
|
hcpdiff/tools/sd2diffusers.py
CHANGED
@@ -211,7 +211,7 @@ def sd_vae_to_diffuser(args):
|
|
211
211
|
def convert_ckpt(args):
|
212
212
|
pipe = load_sd_ckpt(
|
213
213
|
args.checkpoint_path,
|
214
|
-
|
214
|
+
config_files={'v1': args.original_config_file},
|
215
215
|
image_size=args.image_size,
|
216
216
|
prediction_type=args.prediction_type,
|
217
217
|
model_type=args.pipeline_type,
|
hcpdiff/train_colo.py
CHANGED
@@ -23,7 +23,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
23
23
|
from colossalai.utils.model.colo_init_context import _convert_to_coloparam
|
24
24
|
from colossalai.tensor import ColoParameter
|
25
25
|
|
26
|
-
from hcpdiff.
|
26
|
+
from hcpdiff.train_ac_old import Trainer, get_scheduler, ModelEMA
|
27
27
|
from diffusers import UNet2DConditionModel
|
28
28
|
from hcpdiff.utils.colo_utils import gemini_zero_dpp, GeminiAdamOptimizerP
|
29
29
|
from hcpdiff.utils.utils import load_config_with_cli
|
hcpdiff/train_deepspeed.py
CHANGED
@@ -7,7 +7,7 @@ from functools import partial
|
|
7
7
|
import torch
|
8
8
|
|
9
9
|
from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
|
10
|
-
from hcpdiff.
|
10
|
+
from hcpdiff.train_ac_old import Trainer, load_config_with_cli
|
11
11
|
from hcpdiff.utils.net_utils import get_scheduler
|
12
12
|
|
13
13
|
class TrainerDeepSpeed(Trainer):
|
hcpdiff/trainer_ac.py
ADDED
@@ -0,0 +1,79 @@
|
|
1
|
+
import argparse
|
2
|
+
import warnings
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from rainbowneko.parser import load_config_with_cli
|
6
|
+
from rainbowneko.ckpt_manager import NekoSaver
|
7
|
+
from rainbowneko.train import Trainer
|
8
|
+
from rainbowneko.utils import xformers_available, is_dict
|
9
|
+
from hcpdiff.ckpt_manager import EmbFormat
|
10
|
+
|
11
|
+
class HCPTrainer(Trainer):
|
12
|
+
def config_model(self):
|
13
|
+
if self.cfgs.model.enable_xformers:
|
14
|
+
if xformers_available:
|
15
|
+
self.model_wrapper.enable_xformers()
|
16
|
+
else:
|
17
|
+
warnings.warn("xformers is not available. Make sure it is installed correctly")
|
18
|
+
|
19
|
+
if self.model_wrapper.vae is not None:
|
20
|
+
self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
|
21
|
+
self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
|
22
|
+
|
23
|
+
if self.cfgs.model.gradient_checkpointing:
|
24
|
+
self.model_wrapper.enable_gradient_checkpointing()
|
25
|
+
|
26
|
+
def get_param_group_train(self):
|
27
|
+
train_params = super().get_param_group_train()
|
28
|
+
|
29
|
+
# For prompt-tuning
|
30
|
+
if self.cfgs.emb_pt is None:
|
31
|
+
train_params_emb, self.train_pts = [], {}
|
32
|
+
else:
|
33
|
+
from hcpdiff.parser import CfgEmbPTParser
|
34
|
+
self.cfgs.emb_pt: CfgEmbPTParser
|
35
|
+
|
36
|
+
train_params_emb, self.train_pts = self.cfgs.emb_pt.get_params_group(self.model_wrapper)
|
37
|
+
self.emb_format = EmbFormat()
|
38
|
+
train_params += train_params_emb
|
39
|
+
return train_params
|
40
|
+
|
41
|
+
@property
|
42
|
+
def pt_trainable(self):
|
43
|
+
return self.cfgs.emb_pt is not None
|
44
|
+
|
45
|
+
def get_loss(self, ds_name, model_pred, inputs):
|
46
|
+
loss = super().get_loss(ds_name, model_pred, inputs)
|
47
|
+
# make DDP happy
|
48
|
+
if len(self.train_pts)>0:
|
49
|
+
loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
|
50
|
+
return loss
|
51
|
+
|
52
|
+
def save_model(self, from_raw=False):
|
53
|
+
NekoSaver.save_all(
|
54
|
+
self.model_raw,
|
55
|
+
plugin_groups={**self.all_plugin, 'embs': self.train_pts},
|
56
|
+
cfg=self.ckpt_saver,
|
57
|
+
model_ema=getattr(self, "ema_model", None),
|
58
|
+
name_template=f'{{}}-{self.real_step}',
|
59
|
+
)
|
60
|
+
|
61
|
+
self.loggers.info(f"Saved state, step: {self.real_step}")
|
62
|
+
|
63
|
+
def hcp_train():
|
64
|
+
import subprocess
|
65
|
+
parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
|
66
|
+
parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/multi.yaml')
|
67
|
+
args, train_args = parser.parse_known_args()
|
68
|
+
|
69
|
+
subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
|
70
|
+
"hcpdiff.trainer_ac"] + train_args, check=True)
|
71
|
+
|
72
|
+
if __name__ == "__main__":
|
73
|
+
parser = argparse.ArgumentParser(description="HCP Diffusion Trainer")
|
74
|
+
parser.add_argument("--cfg", type=str, default=None, required=True)
|
75
|
+
args, cfg_args = parser.parse_known_args()
|
76
|
+
|
77
|
+
parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
|
78
|
+
trainer = HCPTrainer(parser, conf)
|
79
|
+
trainer.train()
|
@@ -0,0 +1,31 @@
|
|
1
|
+
import argparse
|
2
|
+
import sys
|
3
|
+
from functools import partial
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from accelerate import Accelerator
|
7
|
+
from loguru import logger
|
8
|
+
|
9
|
+
from rainbowneko.train.trainer import TrainerSingleCard
|
10
|
+
from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
|
11
|
+
|
12
|
+
class HCPTrainerSingleCard(TrainerSingleCard, HCPTrainer):
|
13
|
+
pass
|
14
|
+
|
15
|
+
def hcp_train():
|
16
|
+
import subprocess
|
17
|
+
parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
|
18
|
+
parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/single.yaml')
|
19
|
+
args, train_args = parser.parse_known_args()
|
20
|
+
|
21
|
+
subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
|
22
|
+
"hcpdiff.trainer_ac_single"] + train_args, check=True)
|
23
|
+
|
24
|
+
if __name__ == '__main__':
|
25
|
+
parser = argparse.ArgumentParser(description='HCP Diffusion Trainer')
|
26
|
+
parser.add_argument("--cfg", type=str, default=None, required=True)
|
27
|
+
args, cfg_args = parser.parse_known_args()
|
28
|
+
|
29
|
+
parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
|
30
|
+
trainer = HCPTrainerSingleCard(parser, conf)
|
31
|
+
trainer.train()
|
hcpdiff/utils/__init__.py
CHANGED