hcpdiff-0.9.1-py3-none-any.whl → hcpdiff-2.1-py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (210)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/controlnet.py +16 -16
  58. hcpdiff/models/lora_base_patch.py +14 -25
  59. hcpdiff/models/lora_layers.py +3 -9
  60. hcpdiff/models/lora_layers_patch.py +14 -24
  61. hcpdiff/models/text_emb_ex.py +84 -6
  62. hcpdiff/models/textencoder_ex.py +54 -18
  63. hcpdiff/models/wrapper/__init__.py +3 -0
  64. hcpdiff/models/wrapper/pixart.py +19 -0
  65. hcpdiff/models/wrapper/sd.py +218 -0
  66. hcpdiff/models/wrapper/utils.py +20 -0
  67. hcpdiff/parser/__init__.py +1 -0
  68. hcpdiff/parser/embpt.py +32 -0
  69. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  70. hcpdiff/tools/dataset_generator.py +94 -0
  71. hcpdiff/tools/download_hf_model.py +24 -0
  72. hcpdiff/tools/init_proj.py +3 -21
  73. hcpdiff/tools/lora_convert.py +18 -17
  74. hcpdiff/tools/save_model.py +12 -0
  75. hcpdiff/tools/sd2diffusers.py +1 -1
  76. hcpdiff/train_colo.py +1 -1
  77. hcpdiff/train_deepspeed.py +1 -1
  78. hcpdiff/trainer_ac.py +79 -0
  79. hcpdiff/trainer_ac_single.py +31 -0
  80. hcpdiff/utils/__init__.py +0 -2
  81. hcpdiff/utils/inpaint_pipe.py +7 -2
  82. hcpdiff/utils/net_utils.py +29 -6
  83. hcpdiff/utils/pipe_hook.py +24 -7
  84. hcpdiff/utils/utils.py +21 -4
  85. hcpdiff/workflow/__init__.py +15 -10
  86. hcpdiff/workflow/daam/__init__.py +1 -0
  87. hcpdiff/workflow/daam/act.py +66 -0
  88. hcpdiff/workflow/daam/hook.py +109 -0
  89. hcpdiff/workflow/diffusion.py +114 -125
  90. hcpdiff/workflow/fast.py +31 -0
  91. hcpdiff/workflow/flow.py +67 -0
  92. hcpdiff/workflow/io.py +36 -130
  93. hcpdiff/workflow/model.py +46 -43
  94. hcpdiff/workflow/text.py +78 -46
  95. hcpdiff/workflow/utils.py +32 -12
  96. hcpdiff/workflow/vae.py +37 -38
  97. hcpdiff-2.1.dist-info/METADATA +285 -0
  98. hcpdiff-2.1.dist-info/RECORD +114 -0
  99. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  100. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  101. hcpdiff/ckpt_manager/base.py +0 -16
  102. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  103. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  104. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  105. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  106. hcpdiff/data/bucket.py +0 -358
  107. hcpdiff/data/caption_loader.py +0 -80
  108. hcpdiff/data/cond_dataset.py +0 -40
  109. hcpdiff/data/crop_info_dataset.py +0 -40
  110. hcpdiff/data/data_processor.py +0 -33
  111. hcpdiff/data/pair_dataset.py +0 -146
  112. hcpdiff/data/sampler.py +0 -54
  113. hcpdiff/data/source/base.py +0 -30
  114. hcpdiff/data/utils.py +0 -80
  115. hcpdiff/deprecated/__init__.py +0 -1
  116. hcpdiff/deprecated/cfg_converter.py +0 -81
  117. hcpdiff/deprecated/lora_convert.py +0 -31
  118. hcpdiff/infer_workflow.py +0 -57
  119. hcpdiff/loggers/__init__.py +0 -13
  120. hcpdiff/loggers/base_logger.py +0 -76
  121. hcpdiff/loggers/cli_logger.py +0 -40
  122. hcpdiff/loggers/preview/__init__.py +0 -1
  123. hcpdiff/loggers/preview/image_previewer.py +0 -149
  124. hcpdiff/loggers/tensorboard_logger.py +0 -30
  125. hcpdiff/loggers/wandb_logger.py +0 -31
  126. hcpdiff/loggers/webui_logger.py +0 -9
  127. hcpdiff/loss/min_snr_loss.py +0 -52
  128. hcpdiff/models/layers.py +0 -81
  129. hcpdiff/models/plugin.py +0 -348
  130. hcpdiff/models/wrapper.py +0 -75
  131. hcpdiff/noise/__init__.py +0 -3
  132. hcpdiff/noise/noise_base.py +0 -16
  133. hcpdiff/noise/pyramid_noise.py +0 -50
  134. hcpdiff/noise/zero_terminal.py +0 -44
  135. hcpdiff/train_ac.py +0 -566
  136. hcpdiff/train_ac_single.py +0 -39
  137. hcpdiff/utils/caption_tools.py +0 -105
  138. hcpdiff/utils/cfg_net_tools.py +0 -321
  139. hcpdiff/utils/cfg_resolvers.py +0 -16
  140. hcpdiff/utils/ema.py +0 -52
  141. hcpdiff/utils/img_size_tool.py +0 -248
  142. hcpdiff/vis/__init__.py +0 -3
  143. hcpdiff/vis/base_interface.py +0 -12
  144. hcpdiff/vis/disk_interface.py +0 -48
  145. hcpdiff/vis/webui_interface.py +0 -17
  146. hcpdiff/viser_fast.py +0 -138
  147. hcpdiff/visualizer.py +0 -265
  148. hcpdiff/visualizer_reloadable.py +0 -237
  149. hcpdiff/workflow/base.py +0 -59
  150. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  198. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  206. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  207. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  208. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  209. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/train_ac.py DELETED
@@ -1,566 +0,0 @@
- """
- train_ac.py
- ====================
- :Name: train with accelerate
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
-
- import argparse
- import math
- import os
- import time
- import warnings
- from functools import partial
-
- import diffusers
- import hydra
- import torch
- import torch.utils.checkpoint
- # fix checkpoint bug for train part of model
- import torch.utils.checkpoint
- import torch.utils.data
- import transformers
- from accelerate import Accelerator, DistributedDataParallelKwargs
- from accelerate.utils import set_seed
- from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
- from diffusers.utils.import_utils import is_xformers_available
- from omegaconf import OmegaConf
-
- from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
- from hcpdiff.data import RatioBucket, DataGroup, get_sampler
- from hcpdiff.deprecated.cfg_converter import TrainCFGConverter
- from hcpdiff.loggers import LoggerGroup
- from hcpdiff.models import CFGContext, DreamArtistPTContext, TEUnetWrapper, SDXLTEUnetWrapper
- from hcpdiff.models.compose import ComposeEmbPTHook, ComposeTEEXHook
- from hcpdiff.models.compose import SDXLTextEncoder
- from hcpdiff.utils.cfg_net_tools import make_hcpdiff, make_plugin
- from hcpdiff.utils.net_utils import get_scheduler, auto_tokenizer_cls, auto_text_encoder_cls, load_emb
- from hcpdiff.utils.utils import load_config_with_cli, get_cfg_range, mgcd, format_number
- from hcpdiff.visualizer import Visualizer
-
- def checkpoint_fix(function, *args, use_reentrant: bool = False, checkpoint_raw=torch.utils.checkpoint.checkpoint, **kwargs):
-     return checkpoint_raw(function, *args, use_reentrant=use_reentrant, **kwargs)
-
- torch.utils.checkpoint.checkpoint = checkpoint_fix
-
- class Trainer:
-     weight_dtype_map = {'fp32':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
-     ckpt_manager_map = {'torch':CkptManagerPKL, 'safetensors':CkptManagerSafe}
-
-     def __init__(self, cfgs_raw):
-         cfgs_raw = TrainCFGConverter().convert(cfgs_raw)  # support old cfgs format
-         cfgs = hydra.utils.instantiate(cfgs_raw)
-         self.cfgs = cfgs
-
-         self.init_context(cfgs_raw)
-         self.build_loggers(cfgs_raw)
-
-         self.train_TE = any([cfgs.text_encoder, cfgs.lora_text_encoder, cfgs.plugin_TE])
-
-         self.build_ckpt_manager()
-         self.build_model()
-         self.make_hooks()
-         self.config_model()
-         self.cache_latents = False
-
-         self.batch_size_list = []
-         assert len(cfgs.data)>0, "At least one dataset is need."
-         loss_weights = [dataset.keywords['loss_weight'] for name, dataset in cfgs.data.items()]
-         self.train_loader_group = DataGroup([self.build_data(dataset) for name, dataset in cfgs.data.items()], loss_weights)
-
-         if self.cache_latents:
-             self.vae = self.vae.to('cpu')
-         self.build_optimizer_scheduler()
-         try:
-             self.criterion = cfgs.train.loss.criterion(noise_scheduler=self.noise_scheduler, device=self.device)
-         except:
-             self.criterion = cfgs.train.loss.criterion()
-
-         self.cfg_scale = get_cfg_range(cfgs.train.cfg_scale)
-         if self.cfg_scale[1] == 1.0:
-             self.cfg_context = CFGContext()
-         else:  # DreamArtist
-             self.cfg_context = DreamArtistPTContext(self.cfg_scale, self.num_train_timesteps)
-
-         with torch.no_grad():
-             self.build_ema()
-
-         self.load_resume()
-
-         torch.backends.cuda.matmul.allow_tf32 = cfgs.allow_tf32
-
-         # calculate steps and epochs
-         self.steps_per_epoch = len(self.train_loader_group.loader_list[0])
-         if self.cfgs.train.train_epochs is not None:
-             self.cfgs.train.train_steps = self.cfgs.train.train_epochs*self.steps_per_epoch
-         else:
-             self.cfgs.train.train_epochs = math.ceil(self.cfgs.train.train_steps/self.steps_per_epoch)
-
-         if self.is_local_main_process and self.cfgs.previewer is not None:
-             self.previewer = self.cfgs.previewer(exp_dir=self.exp_dir, te_hook=self.text_enc_hook, unet=self.TE_unet.unet,
-                                                  TE=self.TE_unet.TE, tokenizer=self.tokenizer, vae=self.vae)
-
-         self.prepare()
-
-     @property
-     def device(self):
-         return self.accelerator.device
-
-     @property
-     def is_local_main_process(self):
-         return self.accelerator.is_local_main_process
-
-     def init_context(self, cfgs_raw):
-         ddp_kwargs = DistributedDataParallelKwargs(broadcast_buffers=False)
-         self.accelerator = Accelerator(
-             gradient_accumulation_steps=self.cfgs.train.gradient_accumulation_steps,
-             mixed_precision=self.cfgs.mixed_precision,
-             step_scheduler_with_optimizer=False,
-             kwargs_handlers=[ddp_kwargs],  # fix inplace bug in DDP while use data_class
-         )
-
-         self.local_rank = int(os.environ.get("LOCAL_RANK", -1))
-         self.world_size = self.accelerator.num_processes
-
-         set_seed(self.cfgs.seed+self.local_rank)
-
-     def build_loggers(self, cfgs_raw):
-         if self.is_local_main_process:
-             self.exp_dir = self.cfgs.exp_dir.format(time=time.strftime("%Y-%m-%d-%H-%M-%S"))
-             os.makedirs(os.path.join(self.exp_dir, 'ckpts/'), exist_ok=True)
-             with open(os.path.join(self.exp_dir, 'cfg.yaml'), 'w', encoding='utf-8') as f:
-                 f.write(OmegaConf.to_yaml(cfgs_raw))
-             self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=self.exp_dir) for builder in self.cfgs.logger])
-         else:
-             self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=None) for builder in self.cfgs.logger])
-
-         self.min_log_step = mgcd(*([item.log_step for item in self.loggers.logger_list]))
-         image_log_steps = [item.image_log_step for item in self.loggers.logger_list if item.enable_log_image]
-         if len(image_log_steps)>0:
-             self.min_img_log_step = mgcd(*image_log_steps)
-         else:
-             self.min_img_log_step = -1
-
-         self.loggers.info(f'world_size: {self.world_size}')
-         self.loggers.info(f'accumulation: {self.cfgs.train.gradient_accumulation_steps}')
-
-         if self.is_local_main_process:
-             transformers.utils.logging.set_verbosity_warning()
-             diffusers.utils.logging.set_verbosity_warning()
-         else:
-             transformers.utils.logging.set_verbosity_error()
-             diffusers.utils.logging.set_verbosity_error()
-
-     def prepare(self):
-         # Prepare everything with accelerator.
-         prepare_name_list, prepare_obj_list = [], []
-         if self.TE_unet.train_TE:
-             prepare_obj_list.append(self.TE_unet)
-             prepare_name_list.append('TE_unet')
-         else:
-             prepare_obj_list.append(self.TE_unet.unet)
-             prepare_name_list.append('TE_unet.unet')
-
-         if hasattr(self, 'optimizer'):
-             prepare_obj_list.extend([self.optimizer, self.lr_scheduler] if self.lr_scheduler else [self.optimizer])
-             prepare_name_list.extend(['optimizer', 'lr_scheduler'] if self.lr_scheduler else ['optimizer'])
-         if hasattr(self, 'optimizer_pt'):
-             prepare_obj_list.extend([self.optimizer_pt, self.lr_scheduler_pt] if self.lr_scheduler_pt else [self.optimizer_pt])
-             prepare_name_list.extend(['optimizer_pt', 'lr_scheduler_pt'] if self.lr_scheduler_pt else ['optimizer_pt'])
-
-         prepare_obj_list.extend(self.train_loader_group.loader_list)
-         prepared_obj = self.accelerator.prepare(*prepare_obj_list)
-
-         if not self.TE_unet.train_TE:
-             self.TE_unet.unet = prepared_obj[0]
-             prepared_obj = prepared_obj[1:]
-             prepare_name_list = prepare_name_list[1:]
-
-         ds_num = len(self.train_loader_group.loader_list)
-         self.train_loader_group.loader_list = list(prepared_obj[-ds_num:])
-         prepared_obj = prepared_obj[:-ds_num]
-
-         for name, obj in zip(prepare_name_list, prepared_obj):
-             setattr(self, name, obj)
-
-         if self.cfgs.model.force_cast_precision:
-             self.TE_unet.to(dtype=self.weight_dtype)
-
-     def scale_lr(self, parameters):
-         bs = sum(self.batch_size_list)
-         scale_factor = bs*self.world_size*self.cfgs.train.gradient_accumulation_steps
-         for param in parameters:
-             if 'lr' in param:
-                 param['lr'] *= scale_factor
-
-     def build_model(self):
-         # Load the tokenizer
-         if self.cfgs.model.get('tokenizer', None) is not None:
-             self.tokenizer = self.cfgs.model.tokenizer
-         else:
-             tokenizer_cls = auto_tokenizer_cls(self.cfgs.model.pretrained_model_name_or_path, self.cfgs.model.revision)
-             self.tokenizer = tokenizer_cls.from_pretrained(
-                 self.cfgs.model.pretrained_model_name_or_path, subfolder="tokenizer",
-                 revision=self.cfgs.model.revision, use_fast=False,
-             )
-
-         # Load scheduler and models
-         self.noise_scheduler = self.cfgs.model.get('noise_scheduler', None) or \
-             DDPMScheduler.from_pretrained(self.cfgs.model.pretrained_model_name_or_path, subfolder='scheduler')
-
-         self.num_train_timesteps = len(self.noise_scheduler.timesteps)
-         self.vae: AutoencoderKL = self.cfgs.model.get('vae', None) or AutoencoderKL.from_pretrained(
-             self.cfgs.model.pretrained_model_name_or_path, subfolder="vae", revision=self.cfgs.model.revision)
-         self.build_unet_and_TE()
-
-     def build_unet_and_TE(self):  # for easy to use colossalAI
-         unet = self.cfgs.model.get('unet', None) or UNet2DConditionModel.from_pretrained(
-             self.cfgs.model.pretrained_model_name_or_path, subfolder="unet", revision=self.cfgs.model.revision
-         )
-
-         if self.cfgs.model.get('text_encoder', None) is not None:
-             text_encoder = self.cfgs.model.text_encoder
-             text_encoder_cls = type(text_encoder)
-         else:
-             # import correct text encoder class
-             text_encoder_cls = auto_text_encoder_cls(self.cfgs.model.pretrained_model_name_or_path, self.cfgs.model.revision)
-             text_encoder = text_encoder_cls.from_pretrained(
-                 self.cfgs.model.pretrained_model_name_or_path, subfolder="text_encoder", revision=self.cfgs.model.revision
-             )
-
-         # Wrap unet and text_encoder to make DDP happy. Multiple DDP has soooooo many fxxking bugs!
-         wrapper_cls = SDXLTEUnetWrapper if text_encoder_cls == SDXLTextEncoder else TEUnetWrapper
-         self.TE_unet = wrapper_cls(unet, text_encoder, train_TE=self.train_TE)
-
-     def build_ema(self):
-         if self.cfgs.model.ema is not None:
-             self.ema_unet = self.cfgs.model.ema(self.TE_unet.unet)
-             if self.train_TE:
-                 self.ema_text_encoder = self.cfgs.model.ema(self.TE_unet.TE)
-
-     def build_ckpt_manager(self):
-         self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type]()
-         if self.is_local_main_process:
-             self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)
-
-     @property
-     def unet_raw(self):
-         return self.TE_unet.module.unet if self.train_TE else self.TE_unet.unet.module
-
-     @property
-     def TE_raw(self):
-         return self.TE_unet.module.TE if self.train_TE else self.TE_unet.TE
-
-     def config_model(self):
-         if self.cfgs.model.enable_xformers:
-             if is_xformers_available():
-                 self.TE_unet.unet.enable_xformers_memory_efficient_attention()
-                 # self.text_enc_hook.enable_xformers()
-             else:
-                 warnings.warn("xformers is not available. Make sure it is installed correctly")
-
-         self.vae.requires_grad_(False)
-         self.TE_unet.requires_grad_(False)
-
-         self.TE_unet.eval()
-
-         if self.cfgs.model.gradient_checkpointing:
-             self.TE_unet.enable_gradient_checkpointing()
-
-         self.weight_dtype = self.weight_dtype_map.get(self.cfgs.mixed_precision, torch.float32)
-         self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
-         # Move vae and text_encoder to device and cast to weight_dtype
-         self.vae = self.vae.to(self.device, dtype=self.vae_dtype)
-         if not self.train_TE:
-             self.TE_unet.TE = self.TE_unet.TE.to(self.device, dtype=self.weight_dtype)
-
-     @torch.no_grad()
-     def load_resume(self):
-         if self.cfgs.train.resume is not None:
-             for ckpt in self.cfgs.train.resume.ckpt_path.unet:
-                 self.ckpt_manager.load_ckpt_to_model(self.TE_unet.unet, ckpt, model_ema=getattr(self, 'ema_unet', None))
-             for ckpt in self.cfgs.train.resume.ckpt_path.TE:
-                 self.ckpt_manager.load_ckpt_to_model(self.TE_unet.TE, ckpt, model_ema=getattr(self, 'ema_text_encoder', None))
-             for name, ckpt in self.cfgs.train.resume.ckpt_path.words:
-                 self.ex_words_emb[name].data = load_emb(ckpt)
-
-     def make_hooks(self):
-         # Hook tokenizer and embedding to support pt
-         self.embedding_hook, self.ex_words_emb = ComposeEmbPTHook.hook_from_dir(
-             self.cfgs.tokenizer_pt.emb_dir, self.tokenizer, self.TE_unet.TE, log=self.is_local_main_process,
-             N_repeats=self.cfgs.model.tokenizer_repeats, device=self.device)
-
-         self.text_enc_hook = ComposeTEEXHook.hook(self.TE_unet.TE, self.tokenizer, N_repeats=self.cfgs.model.tokenizer_repeats,
-                                                   device=self.device, clip_skip=self.cfgs.model.clip_skip,
-                                                   clip_final_norm=self.cfgs.model.clip_final_norm)
-
-     def build_dataset(self, data_builder: partial):
-         batch_size = data_builder.keywords.pop('batch_size')
-         cache_latents = data_builder.keywords.pop('cache_latents')
-         self.batch_size_list.append(batch_size)
-
-         train_dataset = data_builder(tokenizer=self.tokenizer, tokenizer_repeats=self.cfgs.model.tokenizer_repeats)
-         train_dataset.bucket.build(batch_size*self.world_size, file_names=train_dataset.source.get_image_list())
-         arb = isinstance(train_dataset.bucket, RatioBucket)
-         self.loggers.info(f"len(train_dataset): {len(train_dataset)}")
-
-         if cache_latents:
-             self.cache_latents = True
-             train_dataset.cache_latents(self.vae, self.vae_dtype, self.device, show_prog=self.is_local_main_process)
-         return train_dataset, batch_size, arb
-
-     def build_data(self, data_builder: partial) -> torch.utils.data.DataLoader:
-         train_dataset, batch_size, arb = self.build_dataset(data_builder)
-
-         # Pytorch Data loader
-         train_sampler = get_sampler()(train_dataset, num_replicas=self.world_size, rank=self.local_rank, shuffle=not arb)
-         train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=self.cfgs.train.workers,
-                                                    sampler=train_sampler, collate_fn=train_dataset.collate_fn)
-         return train_loader
-
-     def get_param_group_train(self):
-         # make miniFT and warp with lora
-         self.DA_lora = False
-         train_params_unet, self.lora_unet = make_hcpdiff(self.TE_unet.unet, self.cfgs.unet, self.cfgs.lora_unet)
-         if isinstance(self.lora_unet, tuple):  # creat negative lora
-             self.DA_lora = True
-             self.lora_unet, self.lora_unet_neg = self.lora_unet
-         train_params_unet_plugin, self.all_plugin_unet = make_plugin(self.TE_unet.unet, self.cfgs.plugin_unet)
-         train_params_unet += train_params_unet_plugin
-
-         if self.train_TE:
-             train_params_text_encoder, self.lora_TE = make_hcpdiff(self.TE_unet.TE, self.cfgs.text_encoder, self.cfgs.lora_text_encoder)
-             if isinstance(self.lora_TE, tuple):  # creat negative lora
-                 self.DA_lora = True
-                 self.lora_TE, self.lora_TE_neg = self.lora_TE
-             train_params_TE_plugin, self.all_plugin_TE = make_plugin(self.TE_unet.TE, self.cfgs.plugin_TE)
-             train_params_text_encoder += train_params_TE_plugin
-         else:
-             train_params_text_encoder = []
-
-         N_params_unet = format_number(sum(sum(x.numel() for x in p['params']) for p in train_params_unet))
-         N_params_TE = format_number(sum(sum(x.numel() for x in p['params']) for p in train_params_text_encoder))
-         self.loggers.info(f'unet trainable params: {N_params_unet}; text encoder trainable params: {N_params_TE}')
-
-         # params for embedding
-         train_params_emb = []
-         self.train_pts = {}
-         if self.cfgs.tokenizer_pt.train is not None:
-             for v in self.cfgs.tokenizer_pt.train:
-                 word_emb = self.ex_words_emb[v.name]
-                 self.train_pts[v.name] = word_emb
-                 word_emb.requires_grad = True
-                 self.embedding_hook.emb_train.append(word_emb)
-                 train_params_emb.append({'params':word_emb, 'lr':v.lr})
-
-         return train_params_unet+train_params_text_encoder, train_params_emb
-
-     def build_optimizer_scheduler(self):
-         # set optimizer
-         parameters, parameters_pt = self.get_param_group_train()
-
-         if len(parameters)>0:  # do fine-tuning
-             cfg_opt = self.cfgs.train.optimizer
-             if self.cfgs.train.scale_lr:
-                 self.scale_lr(parameters)
-             assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-             self.optimizer = cfg_opt(params=parameters)
-             self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)
-
-         if len(parameters_pt)>0:  # do prompt-tuning
-             cfg_opt_pt = self.cfgs.train.optimizer_pt
-             if self.cfgs.train.scale_lr_pt:
-                 self.scale_lr(parameters_pt)
-             assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-             self.optimizer_pt = cfg_opt_pt(params=parameters_pt)
-             self.lr_scheduler_pt = get_scheduler(self.cfgs.train.scheduler_pt, self.optimizer_pt)
-
-     def train(self, loss_ema=0.93):
-         total_batch_size = sum(self.batch_size_list)*self.world_size*self.cfgs.train.gradient_accumulation_steps
-
-         self.loggers.info("***** Running training *****")
-         self.loggers.info(f"  Num batches each epoch = {len(self.train_loader_group.loader_list[0])}")
-         self.loggers.info(f"  Num Steps = {self.cfgs.train.train_steps}")
-         self.loggers.info(f"  Instantaneous batch size per device = {sum(self.batch_size_list)}")
-         self.loggers.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-         self.loggers.info(f"  Gradient Accumulation steps = {self.cfgs.train.gradient_accumulation_steps}")
-         self.global_step = 0
-         if self.cfgs.train.resume is not None:
-             self.global_step = self.cfgs.train.resume.start_step
-
-         loss_sum = None
-         for data_list in self.train_loader_group:
-             loss = self.train_one_step(data_list)
-             loss_sum = loss if loss_sum is None else (loss_ema*loss_sum+(1-loss_ema)*loss)
-
-             self.global_step += 1
-             if self.is_local_main_process:
-                 if self.global_step%self.cfgs.train.save_step == 0:
-                     self.save_model()
-                 if self.global_step%self.min_log_step == 0:
-                     # get learning rate from optimizer
-                     lr_model = self.optimizer.param_groups[0]['lr'] if hasattr(self, 'optimizer') else 0.
-                     lr_word = self.optimizer_pt.param_groups[0]['lr'] if hasattr(self, 'optimizer_pt') else 0.
-                     self.loggers.log(datas={
-                         'Step':{'format':'[{}/{}]', 'data':[self.global_step, self.cfgs.train.train_steps]},
-                         'Epoch':{'format':'[{}/{}]<{}/{}>', 'data':[self.global_step//self.steps_per_epoch, self.cfgs.train.train_epochs,
-                                                                     self.global_step%self.steps_per_epoch, self.steps_per_epoch]},
-                         'LR_model':{'format':'{:.2e}', 'data':[lr_model]},
-                         'LR_word':{'format':'{:.2e}', 'data':[lr_word]},
-                         'Loss':{'format':'{:.5f}', 'data':[loss_sum]},
-                     }, step=self.global_step)
-                 if self.min_img_log_step>0 and self.global_step%self.min_img_log_step == 0:
-                     self.loggers.log_image(self.previewer.preview_dict(), self.global_step)
-
-             if self.global_step>=self.cfgs.train.train_steps:
-                 break
-
-         self.wait_for_everyone()
-         if self.is_local_main_process:
-             self.save_model()
-
-     def wait_for_everyone(self):
-         self.accelerator.wait_for_everyone()
-
-     @torch.no_grad()
-     def get_latents(self, image, dataset):
-         if dataset.latents is None:
-             latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
-             latents = latents*self.vae.config.scaling_factor
-         else:
-             latents = image  # Cached latents
-         return latents
-
-     def make_noise(self, latents):
-         # Sample noise that we'll add to the latents
-         noise = torch.randn_like(latents)
-         bsz = latents.shape[0]
-         # Sample a random timestep for each image
-         timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-         timesteps = timesteps.long()
-
-         # Add noise to the latents according to the noise magnitude at each timestep
-         # (this is the forward diffusion process)
-         return self.noise_scheduler.add_noise(latents, noise, timesteps), noise, timesteps
-
-     def forward(self, latents, prompt_ids, attn_mask=None, position_ids=None, **kwargs):
-         noisy_latents, noise, timesteps = self.make_noise(latents)
-
-         # CFG context for DreamArtist
-         noisy_latents, timesteps = self.cfg_context.pre(noisy_latents, timesteps)
-         model_pred = self.TE_unet(prompt_ids, noisy_latents, timesteps, attn_mask=attn_mask, position_ids=position_ids, **kwargs)
-         model_pred = self.cfg_context.post(model_pred)
-
-         # Get the target for loss depending on the prediction type
-         if self.cfgs.train.loss.type == "eps":
-             target = noise
-         elif self.cfgs.train.loss.type == "sample":
-             target = self.noise_scheduler.step(noise, timesteps, noisy_latents)
-             model_pred = self.noise_scheduler.step(model_pred, timesteps, noisy_latents)
-         else:
-             raise ValueError(f"Unknown loss type {self.cfgs.train.loss.type}")
-         return model_pred, target, timesteps
-
-     def train_one_step(self, data_list):
-         with self.accelerator.accumulate(self.TE_unet):
-             for idx, data in enumerate(data_list):
-                 image = data.pop('img').to(self.device, dtype=self.weight_dtype)
-                 img_mask = data.pop('mask').to(self.device) if 'mask' in data else None
-                 prompt_ids = data.pop('prompt').to(self.device)
-                 attn_mask = data.pop('attn_mask').to(self.device) if 'attn_mask' in data else None
-                 position_ids = data.pop('position_ids').to(self.device) if 'position_ids' in data else None
-                 other_datas = {k:v.to(self.device) for k, v in data.items() if k != 'plugin_input'}
-                 if 'plugin_input' in data:
-                     other_datas['plugin_input'] = {k:v.to(self.device, dtype=self.weight_dtype) for k, v in data['plugin_input'].items()}
-
-                 latents = self.get_latents(image, self.train_loader_group.get_dataset(idx))
-                 model_pred, target, timesteps = self.forward(latents, prompt_ids, attn_mask, position_ids, **other_datas)
-                 loss = self.get_loss(model_pred, target, timesteps, img_mask)*self.train_loader_group.get_loss_weights(idx)
-                 self.accelerator.backward(loss)
-
-             if hasattr(self, 'optimizer'):
-                 if self.accelerator.sync_gradients:  # fine-tuning
-                     if hasattr(self.TE_unet, 'trainable_parameters'):
-                         clip_param = self.TE_unet.trainable_parameters()
-                     else:
-                         clip_param = self.TE_unet.module.trainable_parameters()
-                     self.accelerator.clip_grad_norm_(clip_param, self.cfgs.train.max_grad_norm)
-                 self.optimizer.step()
-                 if self.lr_scheduler:
-                     self.lr_scheduler.step()
-                 self.optimizer.zero_grad(set_to_none=self.cfgs.train.set_grads_to_none)
-
-             if hasattr(self, 'optimizer_pt'):  # prompt tuning
-                 self.optimizer_pt.step()
-                 if self.lr_scheduler_pt:
-                     self.lr_scheduler_pt.step()
-                 self.optimizer_pt.zero_grad(set_to_none=self.cfgs.train.set_grads_to_none)
-
-             if self.accelerator.sync_gradients:
-                 self.update_ema()
-         return loss.item()
-
-     def get_loss(self, model_pred, target, timesteps, att_mask):
-         if att_mask is None:
-             att_mask = 1.0
-         if getattr(self.criterion, 'need_timesteps', False):
-             loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
-         else:
-             loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
-         if len(self.embedding_hook.emb_train)>0:
-             loss = loss+0*sum([emb.mean() for emb in self.embedding_hook.emb_train])
-         return loss
-
-     def update_ema(self):
-         if hasattr(self, 'ema_unet'):
-             self.ema_unet.update(self.unet_raw)
-         if hasattr(self, 'ema_text_encoder'):
-             self.ema_text_encoder.update(self.TE_raw)
-
-     def save_model(self, from_raw=False):
-         unet_raw = self.unet_raw
-         self.ckpt_manager.save_model_with_lora(unet_raw, self.lora_unet, model_ema=getattr(self, 'ema_unet', None),
-                                                name='unet', step=self.global_step)
-         self.ckpt_manager.save_plugins(unet_raw, self.all_plugin_unet, name='unet', step=self.global_step,
-                                        model_ema=getattr(self, 'ema_unet', None))
-         if self.train_TE:
-             TE_raw = self.TE_raw
-             # exclude_key: embeddings should not save with text-encoder
-             self.ckpt_manager.save_model_with_lora(TE_raw, self.lora_TE, model_ema=getattr(self, 'ema_text_encoder', None),
-                                                    name='text_encoder', step=self.global_step, exclude_key='emb_ex.')
-             self.ckpt_manager.save_plugins(TE_raw, self.all_plugin_TE, name='text_encoder', step=self.global_step,
-                                            model_ema=getattr(self, 'ema_text_encoder', None))
-
-         if self.DA_lora:
-             self.ckpt_manager.save_model_with_lora(None, self.lora_unet_neg, name='unet-neg', step=self.global_step)
-             if self.train_TE:
-                 self.ckpt_manager.save_model_with_lora(None, self.lora_TE_neg, name='text_encoder-neg', step=self.global_step)
-
-         self.ckpt_manager.save_embedding(self.train_pts, self.global_step, self.cfgs.tokenizer_pt.replace)
-
-         self.loggers.info(f"Saved state, step: {self.global_step}")
-
-     def make_vis(self):
-         vis_dir = os.path.join(self.exp_dir, f'vis-{self.global_step}')
-         new_components = {
-             'unet':self.unet_raw,
-             'text_encoder':self.TE_raw,
-             'tokenizer':self.tokenizer,
-             'vae':self.vae,
-         }
-         viser = Visualizer(self.cfgs.model.pretrained_model_name_or_path, new_components=new_components)
-         if self.cfgs.vis_info.prompt:
-             raise ValueError('vis_info.prompt is None. cannot generate without prompt.')
-         viser.vis_to_dir(vis_dir, self.cfgs.vis_prompt)
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(description='Stable Diffusion Training')
-     parser.add_argument('--cfg', type=str, default=None, required=True)
-     args, cfg_args = parser.parse_known_args()
-
-     conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
-     trainer = Trainer(conf)
-     trainer.train()
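
Note on the checkpoint shim above: train_ac.py rebinds torch.utils.checkpoint.checkpoint at import time so that every checkpoint call defaults to use_reentrant=False, which keeps gradient checkpointing working when only part of the model (e.g. LoRA layers or plugins) has requires_grad=True. A minimal standalone sketch of the same pattern, assuming only PyTorch is installed:

import torch
import torch.utils.checkpoint

# Capture the original implementation before rebinding the module attribute.
_checkpoint_raw = torch.utils.checkpoint.checkpoint

def checkpoint_fix(function, *args, use_reentrant: bool = False, **kwargs):
    # Forward to the original, forcing the non-reentrant path unless the
    # caller explicitly passes use_reentrant=True.
    return _checkpoint_raw(function, *args, use_reentrant=use_reentrant, **kwargs)

torch.utils.checkpoint.checkpoint = checkpoint_fix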
hcpdiff/train_ac_single.py DELETED
@@ -1,39 +0,0 @@
- import argparse
- import sys
- from functools import partial
-
- import torch
- from accelerate import Accelerator
- from loguru import logger
-
- from hcpdiff.train_ac import Trainer, RatioBucket, load_config_with_cli, set_seed, get_sampler
-
- class TrainerSingleCard(Trainer):
-     def init_context(self, cfgs_raw):
-         self.accelerator = Accelerator(
-             gradient_accumulation_steps=self.cfgs.train.gradient_accumulation_steps,
-             mixed_precision=self.cfgs.mixed_precision,
-             step_scheduler_with_optimizer=False,
-         )
-
-         self.local_rank = 0
-         self.world_size = self.accelerator.num_processes
-
-         set_seed(self.cfgs.seed+self.local_rank)
-
-     @property
-     def unet_raw(self):
-         return self.TE_unet.unet
-
-     @property
-     def TE_raw(self):
-         return self.TE_unet.TE
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(description='Stable Diffusion Training')
-     parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
-     args, cfg_args = parser.parse_known_args()
-
-     conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
-     trainer = TrainerSingleCard(conf)
-     trainer.train()
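
Note on the optimizer convention: both deleted trainers assert that cfgs.train.optimizer instantiates to a functools.partial (built by Hydra from a class path such as torch.optim.AdamW) and complete it with the collected parameter groups. A minimal sketch of that call pattern; the parameter group below is illustrative, not taken from the package:

from functools import partial

import torch

# Roughly what hydra.utils.instantiate yields for an optimizer config whose
# _target_ is a class path like torch.optim.AdamW (instantiated as a partial).
cfg_opt = partial(torch.optim.AdamW, lr=1e-5, weight_decay=1e-2)

# The trainer fills in the parameter groups, mirroring
# `self.optimizer = cfg_opt(params=parameters)` in build_optimizer_scheduler.
parameters = [{'params': [torch.nn.Parameter(torch.zeros(4))], 'lr': 1e-5}]
optimizer = cfg_opt(params=parameters)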