hcpdiff-2.2-py3-none-any.whl → hcpdiff-2.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hcpdiff/easy/cfg/sd15_train.py CHANGED
@@ -47,7 +47,7 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        scheduler=ConstantLR(
+        lr_scheduler=ConstantLR(
             _partial_=True,
             warmup_steps=warmup_steps,
         ),
@@ -132,7 +132,7 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        scheduler=ConstantLR(
+        lr_scheduler=ConstantLR(
             _partial_=True,
             warmup_steps=warmup_steps,
         ),
hcpdiff/easy/cfg/sdxl_train.py CHANGED
@@ -44,7 +44,7 @@ def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        scheduler=ConstantLR(
+        lr_scheduler=ConstantLR(
             _partial_=True,
             warmup_steps=warmup_steps,
         ),
@@ -128,7 +128,7 @@ def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
 
         optimizer=optimizer,
 
-        scheduler=ConstantLR(
+        lr_scheduler=ConstantLR(
             _partial_=True,
             warmup_steps=warmup_steps,
         ),
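All four hunks above make the same one-word fix: the warmup scheduler was passed under the key `scheduler`, while the trainer resolves it under `lr_scheduler`, so the `ConstantLR` config never took effect. A minimal, self-contained sketch of that failure mode — `constant_with_warmup` and the config layout here are illustrative assumptions, not the hcpdiff API:

```python
# Toy sketch: a trainer that looks up its scheduler under 'lr_scheduler'
# silently skips warmup when the config was built with 'scheduler=' instead.
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

def constant_with_warmup(optimizer, warmup_steps: int):
    # Constant LR after a linear warmup, analogous to the ConstantLR above.
    return LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / max(1, warmup_steps)))

model = torch.nn.Linear(4, 4)
train_cfg = dict(
    optimizer=partial(torch.optim.AdamW, lr=1e-4),
    lr_scheduler=partial(constant_with_warmup, warmup_steps=100),  # the renamed key
)

optimizer = train_cfg['optimizer'](model.parameters())
# Found only under 'lr_scheduler'; under the old 'scheduler' key this
# lookup returns None and training runs without warmup.
scheduler_fn = train_cfg.get('lr_scheduler')
scheduler = scheduler_fn(optimizer) if scheduler_fn is not None else None
```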
hcpdiff/models/text_emb_ex.py CHANGED
@@ -126,6 +126,10 @@ class EmbeddingPTInterpHook(SinglePluginBlock):
         BOS = repeat(inputs_embeds[0,0,:], 'e -> r 1 e', r=self.N_repeats)
         EOS = repeat(inputs_embeds[0,-1,:], 'e -> r 1 e', r=self.N_repeats)
 
+        # make DDP happy
+        if len(self.emb_train) > 0:
+            BOS = BOS + sum(emb.mean()*0 for emb in self.emb_train if emb.requires_grad)
+
         replaced_embeds = []
         for item, rep_idxs, ids_raw in zip(inputs_embeds, rep_idxs_B, self.input_ids):
             # insert pt to embeddings
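The added lines are the standard workaround for `DistributedDataParallel` models whose trainable parameters may go unused on a given step. A standalone sketch with toy tensors (not the hcpdiff classes):

```python
# Tying each trainable embedding into the forward graph with a term that is
# numerically zero, so DDP's gradient reducer sees a (zero) gradient for
# every registered parameter instead of raising an unused-parameter error.
import torch

emb_train = [torch.nn.Parameter(torch.randn(4, 8)) for _ in range(2)]
BOS = torch.randn(1, 1, 8)  # stand-in for the repeated BOS embedding

if len(emb_train) > 0:
    # emb.mean() * 0 contributes nothing numerically, but records an
    # autograd dependency on each embedding.
    BOS = BOS + sum(emb.mean() * 0 for emb in emb_train if emb.requires_grad)

loss = BOS.sum()
loss.backward()
print(emb_train[0].grad.abs().max())  # tensor(0.) — a gradient exists, but is zero
```

Because the extra term is exactly zero, the model output is unchanged; the only effect is the added autograd edges.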
hcpdiff/trainer_ac.py CHANGED
@@ -42,13 +42,6 @@ class HCPTrainer(Trainer):
     def pt_trainable(self):
         return self.cfgs.emb_pt is not None
 
-    def get_loss(self, ds_name, model_pred, inputs):
-        loss = super().get_loss(ds_name, model_pred, inputs)
-        # make DDP happy
-        if len(self.train_pts)>0:
-            loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
-        return loss
-
     def save_model(self, from_raw=False):
         NekoSaver.save_all(
             self.model_raw,
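This removal is the counterpart of the text_emb_ex.py change above: the zero-valued DDP tie-in moves from the loss into EmbeddingPTInterpHook, where the trainable embeddings are actually consumed, so HCPTrainer now inherits get_loss unchanged from the base Trainer.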
hcpdiff/trainer_deepspeed.py ADDED
@@ -0,0 +1,47 @@
+import argparse
+import warnings
+
+import torch
+from rainbowneko.ckpt_manager import NekoPluginSaver
+from rainbowneko.train.trainer import TrainerDeepspeed
+from rainbowneko.utils import xformers_available
+
+from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
+
+class HCPTrainerDeepspeed(TrainerDeepspeed, HCPTrainer):
+    def config_model(self):
+        if self.cfgs.model.enable_xformers:
+            if xformers_available:
+                self.model_wrapper.enable_xformers()
+            else:
+                warnings.warn("xformers is not available. Make sure it is installed correctly")
+
+        if self.model_wrapper.vae is not None:
+            self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
+            self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
+
+        if self.cfgs.model.gradient_checkpointing:
+            self.model_wrapper.enable_gradient_checkpointing()
+
+        if self.is_local_main_process:
+            for saver in self.ckpt_saver.values():
+                if isinstance(saver, NekoPluginSaver):
+                    saver.plugin_from_raw = True
+
+def hcp_train():
+    import subprocess
+    parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
+    parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/deepspeed.yaml')
+    args, train_args = parser.parse_known_args()
+
+    subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
+                    "hcpdiff.trainer_deepspeed"]+train_args, check=True)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='HCP Diffusion Trainer for DeepSpeed')
+    parser.add_argument("--cfg", type=str, default=None, required=True)
+    args, cfg_args = parser.parse_known_args()
+
+    parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
+    trainer = HCPTrainerDeepspeed(parser, conf)
+    trainer.train()
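The added module doubles as a launcher: `hcp_train()` is a thin wrapper around `accelerate launch`. A hedged sketch of the equivalent manual call — both config paths below are placeholders, not files guaranteed to ship with the package:

```python
# Spelling out what hcp_train() above does: launch this module under
# accelerate with a DeepSpeed launcher config, forwarding the remaining
# arguments (here --cfg) to the module's __main__ block.
import subprocess

subprocess.run([
    "accelerate", "launch",
    "--config_file", "cfgs/launcher/deepspeed.yaml",  # --launch_cfg default above
    "-m", "hcpdiff.trainer_deepspeed",                # runs the __main__ block above
    "--cfg", "cfgs/train/my_train_cfg.py",            # placeholder training config
], check=True)
```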
hcpdiff-2.2.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hcpdiff
-Version: 2.2
+Version: 2.2.1
 Summary: A universal Diffusion toolbox
 Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
 Author: Ziyi Dong
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: rainbowneko
+Requires-Dist: rainbowneko==1.6
 Requires-Dist: diffusers
 Requires-Dist: matplotlib
 Requires-Dist: pyarrow
@@ -262,9 +262,13 @@ hcp_run --cfg cfgs/workflow/text2img_cli.py \
     seed=42
 ```
 
-### Tutorials
+### 📚 Tutorials
 
-🚧 In Development
++ 🧠 [Model Training Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/train.html)
++ 🔧 [LoRA Training Tutorial](https://hcpdiff.readthedocs.io/en/latest/tutorial/lora.html)
++ 🎨 [Image Generation Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/workflow.html)
++ ⚙️ [Configuration File Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/cfg.html)
++ 🧩 [Model Format Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/model_format.html)
 
 ---
 
hcpdiff-2.2.1.dist-info/RECORD CHANGED
@@ -1,8 +1,8 @@
 hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
 hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
-hcpdiff/train_deepspeed.py,sha256=PwyNukWi0of6TXy_VRDgBQSMLCZBhipO5g3Lq0nCYNk,2988
-hcpdiff/trainer_ac.py,sha256=6KAzo54in7ZRHud_rHjJdwRRZ4uWtc0B4SxVCxgcrmM,2990
+hcpdiff/trainer_ac.py,sha256=scH3FU0onCQtwLiy0-pcrhuowTZob3fLQqRP52iwY0c,2717
 hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
+hcpdiff/trainer_deepspeed.py,sha256=7lGsiAstWuIlmhRMwWTcJCkoxzUaakVxBngKDnJdSJk,1947
 hcpdiff/ckpt_manager/__init__.py,sha256=Mn_5KOC4xbf2GcN6OXg_XdbF5wO9zWeER_1ZO_prKAI,256
 hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
 hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
@@ -40,8 +40,8 @@ hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc
 hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
 hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
 hcpdiff/easy/cfg/__init__.py,sha256=SxHMWG6T2CXhX3dP0xizSMd9vFWPaZQDc4Gj4CF__yQ,253
-hcpdiff/easy/cfg/sd15_train.py,sha256=KtplqN-OhzdZjsX2s60J3XR6o7tRJ-QDx7Eqza_eDkM,6704
-hcpdiff/easy/cfg/sdxl_train.py,sha256=ZKfJ19IvR2dZqDNXULmhZEmqjE7qV4QYxSTvEhI7efQ,4269
+hcpdiff/easy/cfg/sd15_train.py,sha256=kKdESVqAxNlBhhz12PvwrpHJBea80OUFzDDMHwiulVs,6710
+hcpdiff/easy/cfg/sdxl_train.py,sha256=FUWE_hRJdQc9Qd9J6730jAyK0H4EIKS7-3BSufCItXU,4275
 hcpdiff/easy/cfg/t2i.py,sha256=SnjFjZAKd9orjJr3RW5_N2_EIlW2Ree7JMvdNUAR9gc,9507
 hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
 hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
@@ -62,7 +62,7 @@ hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6
 hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
 hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
 hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
-hcpdiff/models/text_emb_ex.py,sha256=a5QImxzvj0zWR12qXOPP9kmpESl8J9VLabA0W9D_i_c,7867
+hcpdiff/models/text_emb_ex.py,sha256=O0XZqid01OrB0dHY7hCiBvdU2026SvZ38yfQaF2TWrs,8018
 hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
 hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
 hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
@@ -107,9 +107,9 @@ hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
 hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
 hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
 hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
-hcpdiff-2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-hcpdiff-2.2.dist-info/METADATA,sha256=u52mZtA0hI2P_fObmJZRUkZZfnKFYg5c24f4p0trH0o,9833
-hcpdiff-2.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-hcpdiff-2.2.dist-info/entry_points.txt,sha256=86wPOMzsfWWflTJ-sQPLc7WG5Vtu0kGYBH9C_vR3ur8,207
-hcpdiff-2.2.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
-hcpdiff-2.2.dist-info/RECORD,,
+hcpdiff-2.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+hcpdiff-2.2.1.dist-info/METADATA,sha256=f96Tc90K5WTBbJ35wWJw60G2JR46eGpUvQSaPIysVDg,10323
+hcpdiff-2.2.1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
+hcpdiff-2.2.1.dist-info/entry_points.txt,sha256=_4VRsEsEWOhHfzBDu9bx8Wh_S8Wi4ZTHpI0n6rU0J-I,258
+hcpdiff-2.2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
+hcpdiff-2.2.1.dist-info/RECORD,,
hcpdiff-2.2.1.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (78.1.0)
+Generator: setuptools (78.1.1)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
hcpdiff-2.2.1.dist-info/entry_points.txt CHANGED
@@ -2,4 +2,5 @@
 hcp_run = rainbowneko.infer.infer_workflow:run_workflow
 hcp_train = hcpdiff.trainer_ac:hcp_train
 hcp_train_1gpu = hcpdiff.trainer_ac_single:hcp_train
+hcp_train_ds = hcpdiff.trainer_deepspeed:hcp_train
 hcpinit = hcpdiff.tools.init_proj:main
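The new `hcp_train_ds` console script maps to the `hcp_train()` launcher in hcpdiff/trainer_deepspeed.py shown above, joining the existing `hcp_train` and `hcp_train_1gpu` commands, so installing the wheel gives a one-command DeepSpeed launch.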
hcpdiff/train_deepspeed.py DELETED
@@ -1,69 +0,0 @@
-import argparse
-import os
-import sys
-import warnings
-from functools import partial
-
-import torch
-
-from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
-from hcpdiff.train_ac_old import Trainer, load_config_with_cli
-from hcpdiff.utils.net_utils import get_scheduler
-
-class TrainerDeepSpeed(Trainer):
-
-    def build_ckpt_manager(self):
-        self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type](plugin_from_raw=True)
-        if self.is_local_main_process:
-            self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)
-
-    @property
-    def unet_raw(self):
-        return self.accelerator.unwrap_model(self.TE_unet).unet if self.train_TE else self.accelerator.unwrap_model(self.TE_unet.unet)
-
-    @property
-    def TE_raw(self):
-        return self.accelerator.unwrap_model(self.TE_unet).TE if self.train_TE else self.TE_unet.TE
-
-    def get_loss(self, model_pred, target, timesteps, att_mask):
-        if att_mask is None:
-            att_mask = 1.0
-        if getattr(self.criterion, 'need_timesteps', False):
-            loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
-        else:
-            loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
-        return loss
-
-    def build_optimizer_scheduler(self):
-        # set optimizer
-        parameters, parameters_pt = self.get_param_group_train()
-
-        if len(parameters_pt)>0: # do prompt-tuning
-            cfg_opt_pt = self.cfgs.train.optimizer_pt
-            # if self.cfgs.train.scale_lr_pt:
-            #     self.scale_lr(parameters_pt)
-            assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-            weight_decay = cfg_opt_pt.keywords.get('weight_decay', None)
-            if weight_decay is not None:
-                for param in parameters_pt:
-                    param['weight_decay'] = weight_decay
-
-            parameters += parameters_pt
-            warnings.warn('deepspeed dose not support multi optimizer and lr_scheduler. optimizer_pt and scheduler_pt will not work.')
-
-        if len(parameters)>0:
-            cfg_opt = self.cfgs.train.optimizer
-            if self.cfgs.train.scale_lr:
-                self.scale_lr(parameters)
-            assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-            self.optimizer = cfg_opt(params=parameters)
-            self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Stable Diffusion Training')
-    parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
-    args, cfg_args = parser.parse_known_args()
-
-    conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
-    trainer = TrainerDeepSpeed(conf)
-    trainer.train()