hcpdiff 2.2-py3-none-any.whl → 2.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/easy/cfg/sd15_train.py +2 -2
- hcpdiff/easy/cfg/sdxl_train.py +2 -2
- hcpdiff/models/text_emb_ex.py +4 -0
- hcpdiff/trainer_ac.py +0 -7
- hcpdiff/trainer_deepspeed.py +47 -0
- {hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/METADATA +8 -4
- {hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/RECORD +11 -11
- {hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/WHEEL +1 -1
- {hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/entry_points.txt +1 -0
- hcpdiff/train_deepspeed.py +0 -69
- {hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/top_level.txt +0 -0
hcpdiff/easy/cfg/sd15_train.py
CHANGED

@@ -47,7 +47,7 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        (old line not captured in this view)
+        lr_scheduler=ConstantLR(
             _partial_=True,
             warmup_steps=warmup_steps,
         ),
@@ -132,7 +132,7 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        (old line not captured in this view)
+        lr_scheduler=ConstantLR(
             _partial_=True,
             warmup_steps=warmup_steps,
         ),
hcpdiff/easy/cfg/sdxl_train.py
CHANGED

@@ -44,7 +44,7 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        (old line not captured in this view)
+        lr_scheduler=ConstantLR(
             _partial_=True,
             warmup_steps=warmup_steps,
         ),
@@ -128,7 +128,7 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        (old line not captured in this view)
+        lr_scheduler=ConstantLR(
             _partial_=True,
             warmup_steps=warmup_steps,
         ),
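In all four easy-config functions above, the scheduler is declared with `_partial_=True`, the config convention for storing a factory rather than a finished object: the trainer completes the call later with the arguments only it knows, such as the optimizer. A rough, runnable illustration of the idea using `functools.partial` and a stand-in scheduler (the helper below is illustrative; it is not rainbowneko's `ConstantLR`):

from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

# Stand-in for a constant-with-warmup schedule, assumed to resemble
# what ConstantLR(warmup_steps=...) builds.
def constant_with_warmup(optimizer, warmup_steps):
    return LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

# Roughly what ConstantLR(_partial_=True, warmup_steps=warmup_steps) stores:
scheduler_factory = partial(constant_with_warmup, warmup_steps=100)

# The trainer later supplies the optimizer and completes the factory.
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
scheduler = scheduler_factory(optimizer)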
hcpdiff/models/text_emb_ex.py
CHANGED

@@ -126,6 +126,10 @@ class EmbeddingPTInterpHook(SinglePluginBlock):
         BOS = repeat(inputs_embeds[0,0,:], 'e -> r 1 e', r=self.N_repeats)
         EOS = repeat(inputs_embeds[0,-1,:], 'e -> r 1 e', r=self.N_repeats)
 
+        # make DDP happy
+        if len(self.emb_train) > 0:
+            BOS = BOS + sum(emb.mean()*0 for emb in self.emb_train if emb.requires_grad)
+
         replaced_embeds = []
         for item, rep_idxs, ids_raw in zip(inputs_embeds, rep_idxs_B, self.input_ids):
             # insert pt to embeddings
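The added `emb.mean()*0` sum changes nothing numerically, but it ties every trainable prompt-tuning embedding into the autograd graph, so DistributedDataParallel no longer raises its unused-parameter error when a batch uses none of the embeddings. The same workaround previously lived in `HCPTrainer.get_loss` and is removed from `hcpdiff/trainer_ac.py` below. A minimal self-contained sketch of the trick (toy names, not hcpdiff's API):

import torch

# Trainable embeddings that may go unused in a given forward pass.
emb_train = [torch.nn.Parameter(torch.randn(4, 8)) for _ in range(2)]

def forward(x):
    out = x * 2.0  # the embeddings play no real part in this toy output
    # Zero-valued terms keep every parameter in the graph, so DDP's
    # gradient reducer sees all of them participate in the loss.
    if len(emb_train) > 0:
        out = out + sum(emb.mean() * 0 for emb in emb_train if emb.requires_grad)
    return out

loss = forward(torch.ones(3)).sum()
loss.backward()
print(emb_train[0].grad)  # all-zero gradients, but gradients nonetheless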
hcpdiff/trainer_ac.py
CHANGED

@@ -42,13 +42,6 @@ class HCPTrainer(Trainer):
     def pt_trainable(self):
         return self.cfgs.emb_pt is not None
 
-    def get_loss(self, ds_name, model_pred, inputs):
-        loss = super().get_loss(ds_name, model_pred, inputs)
-        # make DDP happy
-        if len(self.train_pts)>0:
-            loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
-        return loss
-
     def save_model(self, from_raw=False):
         NekoSaver.save_all(
             self.model_raw,
hcpdiff/trainer_deepspeed.py
ADDED

@@ -0,0 +1,47 @@
+import argparse
+import warnings
+
+import torch
+from rainbowneko.ckpt_manager import NekoPluginSaver
+from rainbowneko.train.trainer import TrainerDeepspeed
+from rainbowneko.utils import xformers_available
+
+from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
+
+class HCPTrainerDeepspeed(TrainerDeepspeed, HCPTrainer):
+    def config_model(self):
+        if self.cfgs.model.enable_xformers:
+            if xformers_available:
+                self.model_wrapper.enable_xformers()
+            else:
+                warnings.warn("xformers is not available. Make sure it is installed correctly")
+
+        if self.model_wrapper.vae is not None:
+            self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
+            self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
+
+        if self.cfgs.model.gradient_checkpointing:
+            self.model_wrapper.enable_gradient_checkpointing()
+
+        if self.is_local_main_process:
+            for saver in self.ckpt_saver.values():
+                if isinstance(saver, NekoPluginSaver):
+                    saver.plugin_from_raw = True
+
+def hcp_train():
+    import subprocess
+    parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
+    parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/deepspeed.yaml')
+    args, train_args = parser.parse_known_args()
+
+    subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
+                    "hcpdiff.trainer_deepspeed"]+train_args, check=True)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='HCP Diffusion Trainer for DeepSpeed')
+    parser.add_argument("--cfg", type=str, default=None, required=True)
+    args, cfg_args = parser.parse_known_args()
+
+    parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
+    trainer = HCPTrainerDeepspeed(parser, conf)
+    trainer.train()
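`hcp_train()` shells out to `accelerate launch` with the module itself as the target, so the `__main__` block is what actually runs in each worker; assuming the defaults shown above, a manual launch would look roughly like `accelerate launch --config_file cfgs/launcher/deepspeed.yaml -m hcpdiff.trainer_deepspeed --cfg <train_cfg>`. The one-line addition to `entry_points.txt` presumably wires `hcp_train` up as a console script. This class-based trainer replaces the standalone `hcpdiff/train_deepspeed.py` deleted at the end of this diff.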
{hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hcpdiff
-Version: 2.2
+Version: 2.2.1
 Summary: A universal Diffusion toolbox
 Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
 Author: Ziyi Dong
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: rainbowneko
+Requires-Dist: rainbowneko==1.6
 Requires-Dist: diffusers
 Requires-Dist: matplotlib
 Requires-Dist: pyarrow
@@ -262,9 +262,13 @@ hcp_run --cfg cfgs/workflow/text2img_cli.py \
     seed=42
 ```
 
-### Tutorials
+### 📚 Tutorials
 
-
++ 🧠 [Model Training Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/train.html)
++ 🔧 [LoRA Training Tutorial](https://hcpdiff.readthedocs.io/en/latest/tutorial/lora.html)
++ 🎨 [Image Generation Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/workflow.html)
++ ⚙️ [Configuration File Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/cfg.html)
++ 🧩 [Model Format Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/model_format.html)
 
 ---
 
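Pinning `rainbowneko` to exactly 1.6 is plausibly tied to the new trainer above, which imports `TrainerDeepspeed` and `NekoPluginSaver` from rainbowneko; an unpinned dependency could resolve to a release without those APIs.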
{hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/RECORD
CHANGED

@@ -1,8 +1,8 @@
 hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
 hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
-hcpdiff/
-hcpdiff/trainer_ac.py,sha256=6KAzo54in7ZRHud_rHjJdwRRZ4uWtc0B4SxVCxgcrmM,2990
+hcpdiff/trainer_ac.py,sha256=scH3FU0onCQtwLiy0-pcrhuowTZob3fLQqRP52iwY0c,2717
 hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
+hcpdiff/trainer_deepspeed.py,sha256=7lGsiAstWuIlmhRMwWTcJCkoxzUaakVxBngKDnJdSJk,1947
 hcpdiff/ckpt_manager/__init__.py,sha256=Mn_5KOC4xbf2GcN6OXg_XdbF5wO9zWeER_1ZO_prKAI,256
 hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
 hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
@@ -40,8 +40,8 @@ hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc
 hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
 hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
 hcpdiff/easy/cfg/__init__.py,sha256=SxHMWG6T2CXhX3dP0xizSMd9vFWPaZQDc4Gj4CF__yQ,253
-hcpdiff/easy/cfg/sd15_train.py,sha256=
-hcpdiff/easy/cfg/sdxl_train.py,sha256=
+hcpdiff/easy/cfg/sd15_train.py,sha256=kKdESVqAxNlBhhz12PvwrpHJBea80OUFzDDMHwiulVs,6710
+hcpdiff/easy/cfg/sdxl_train.py,sha256=FUWE_hRJdQc9Qd9J6730jAyK0H4EIKS7-3BSufCItXU,4275
 hcpdiff/easy/cfg/t2i.py,sha256=SnjFjZAKd9orjJr3RW5_N2_EIlW2Ree7JMvdNUAR9gc,9507
 hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
 hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
@@ -62,7 +62,7 @@ hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6
 hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
 hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
 hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
-hcpdiff/models/text_emb_ex.py,sha256=
+hcpdiff/models/text_emb_ex.py,sha256=O0XZqid01OrB0dHY7hCiBvdU2026SvZ38yfQaF2TWrs,8018
 hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
 hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
 hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
@@ -107,9 +107,9 @@ hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
 hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
 hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
 hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
-hcpdiff-2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-hcpdiff-2.2.dist-info/METADATA,sha256=
-hcpdiff-2.2.dist-info/WHEEL,sha256=
-hcpdiff-2.2.dist-info/entry_points.txt,sha256=
-hcpdiff-2.2.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
-hcpdiff-2.2.dist-info/RECORD,,
+hcpdiff-2.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+hcpdiff-2.2.1.dist-info/METADATA,sha256=f96Tc90K5WTBbJ35wWJw60G2JR46eGpUvQSaPIysVDg,10323
+hcpdiff-2.2.1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
+hcpdiff-2.2.1.dist-info/entry_points.txt,sha256=_4VRsEsEWOhHfzBDu9bx8Wh_S8Wi4ZTHpI0n6rU0J-I,258
+hcpdiff-2.2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
+hcpdiff-2.2.1.dist-info/RECORD,,
hcpdiff/train_deepspeed.py
DELETED

@@ -1,69 +0,0 @@
-import argparse
-import os
-import sys
-import warnings
-from functools import partial
-
-import torch
-
-from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
-from hcpdiff.train_ac_old import Trainer, load_config_with_cli
-from hcpdiff.utils.net_utils import get_scheduler
-
-class TrainerDeepSpeed(Trainer):
-
-    def build_ckpt_manager(self):
-        self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type](plugin_from_raw=True)
-        if self.is_local_main_process:
-            self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)
-
-    @property
-    def unet_raw(self):
-        return self.accelerator.unwrap_model(self.TE_unet).unet if self.train_TE else self.accelerator.unwrap_model(self.TE_unet.unet)
-
-    @property
-    def TE_raw(self):
-        return self.accelerator.unwrap_model(self.TE_unet).TE if self.train_TE else self.TE_unet.TE
-
-    def get_loss(self, model_pred, target, timesteps, att_mask):
-        if att_mask is None:
-            att_mask = 1.0
-        if getattr(self.criterion, 'need_timesteps', False):
-            loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
-        else:
-            loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
-        return loss
-
-    def build_optimizer_scheduler(self):
-        # set optimizer
-        parameters, parameters_pt = self.get_param_group_train()
-
-        if len(parameters_pt)>0: # do prompt-tuning
-            cfg_opt_pt = self.cfgs.train.optimizer_pt
-            # if self.cfgs.train.scale_lr_pt:
-            #     self.scale_lr(parameters_pt)
-            assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-            weight_decay = cfg_opt_pt.keywords.get('weight_decay', None)
-            if weight_decay is not None:
-                for param in parameters_pt:
-                    param['weight_decay'] = weight_decay
-
-            parameters += parameters_pt
-            warnings.warn('deepspeed dose not support multi optimizer and lr_scheduler. optimizer_pt and scheduler_pt will not work.')
-
-        if len(parameters)>0:
-            cfg_opt = self.cfgs.train.optimizer
-            if self.cfgs.train.scale_lr:
-                self.scale_lr(parameters)
-            assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-            self.optimizer = cfg_opt(params=parameters)
-            self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Stable Diffusion Training')
-    parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
-    args, cfg_args = parser.parse_known_args()
-
-    conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
-    trainer = TrainerDeepSpeed(conf)
-    trainer.train()
{hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/licenses/LICENSE
File without changes

{hcpdiff-2.2.dist-info → hcpdiff-2.2.1.dist-info}/top_level.txt
File without changes