hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/noise/zero_terminal.py DELETED
@@ -1,44 +0,0 @@
- import torch
- from diffusers import SchedulerMixin
- from .noise_base import NoiseBase
-
- class ZeroTerminalScheduler(NoiseBase, SchedulerMixin):
-     def __init__(self, base_scheduler):
-         super().__init__(base_scheduler)
-         base_scheduler.betas = self.rescale_zero_terminal_snr(base_scheduler.betas)
-         base_scheduler.alphas = 1.0-base_scheduler.betas
-         base_scheduler.alphas_cumprod = torch.cumprod(base_scheduler.alphas, dim=0)
-
-     def rescale_zero_terminal_snr(self, betas):
-         """
-         Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-         Args:
-             betas (`torch.FloatTensor`):
-                 the betas that the scheduler is being initialized with.
-         Returns:
-             `torch.FloatTensor`: rescaled betas with zero terminal SNR
-         """
-         # Convert betas to alphas_bar_sqrt
-         alphas = 1.0-betas
-         alphas_cumprod = torch.cumprod(alphas, dim=0)
-         alphas_bar_sqrt = alphas_cumprod.sqrt()
-
-         # Store old values.
-         alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-         alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-
-         # Shift so the last timestep is zero.
-         alphas_bar_sqrt -= alphas_bar_sqrt_T
-
-         # Scale so the first timestep is back to the old value.
-         alphas_bar_sqrt *= alphas_bar_sqrt_0/(alphas_bar_sqrt_0-alphas_bar_sqrt_T)
-
-         # Convert alphas_bar_sqrt to betas
-         alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
-         alphas = alphas_bar[1:]/alphas_bar[:-1] # Revert cumprod
-         alphas = torch.cat([alphas_bar[0:1], alphas])
-         betas = 1-alphas
-
-         return betas
-
-
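In 2.1 a new hcpdiff/diffusion/noise/zero_terminal.py is added (entry 25 in the file list above), so the rescale deleted here does not disappear outright. The rescale itself is Algorithm 1 of arXiv:2305.08891 and works on any beta schedule. Below is a minimal standalone sketch that mirrors the deleted rescale_zero_terminal_snr and patches a diffusers DDPMScheduler the way the deleted class's __init__ did; it is illustrative only and not part of the 2.1 API.

import torch
from diffusers import DDPMScheduler

def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    # betas -> sqrt of the cumulative alpha product
    alphas = 1.0 - betas
    alphas_bar_sqrt = torch.cumprod(alphas, dim=0).sqrt()

    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # shift so the last timestep has zero SNR, rescale so the first keeps its old value
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # convert back: sqrt -> cumulative product -> per-step alphas -> betas
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - alphas

# patch an existing scheduler in place, as the deleted ZeroTerminalScheduler did
scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.betas = rescale_zero_terminal_snr(scheduler.betas)
scheduler.alphas = 1.0 - scheduler.betas
scheduler.alphas_cumprod = torch.cumprod(scheduler.alphas, dim=0)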
hcpdiff/train_ac.py DELETED
@@ -1,565 +0,0 @@
- """
- train_ac.py
- ====================
- :Name: train with accelerate
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
-
- import argparse
- import math
- import os
- import time
- import warnings
- from functools import partial
-
- import diffusers
- import hydra
- import torch
- import torch.utils.checkpoint
- # fix checkpoint bug for train part of model
- import torch.utils.checkpoint
- import torch.utils.data
- import transformers
- from accelerate import Accelerator, DistributedDataParallelKwargs
- from accelerate.utils import set_seed
- from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
- from diffusers.utils.import_utils import is_xformers_available
- from omegaconf import OmegaConf
-
- from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
- from hcpdiff.data import RatioBucket, DataGroup, get_sampler
- from hcpdiff.loggers import LoggerGroup
- from hcpdiff.models import CFGContext, DreamArtistPTContext, TEUnetWrapper, SDXLTEUnetWrapper
- from hcpdiff.models.compose import ComposeEmbPTHook, ComposeTEEXHook
- from hcpdiff.models.compose import SDXLTextEncoder
- from hcpdiff.utils.cfg_net_tools import make_hcpdiff, make_plugin
- from hcpdiff.utils.ema import ModelEMA
- from hcpdiff.utils.net_utils import get_scheduler, auto_tokenizer_cls, auto_text_encoder_cls, load_emb
- from hcpdiff.utils.utils import load_config_with_cli, get_cfg_range, mgcd, format_number
- from hcpdiff.visualizer import Visualizer
-
- def checkpoint_fix(function, *args, use_reentrant: bool = False, checkpoint_raw=torch.utils.checkpoint.checkpoint, **kwargs):
-     return checkpoint_raw(function, *args, use_reentrant=use_reentrant, **kwargs)
-
- torch.utils.checkpoint.checkpoint = checkpoint_fix
-
- class Trainer:
-     weight_dtype_map = {'fp32':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
-     ckpt_manager_map = {'torch':CkptManagerPKL, 'safetensors':CkptManagerSafe}
-
-     def __init__(self, cfgs_raw):
-         cfgs = hydra.utils.instantiate(cfgs_raw)
-         self.cfgs = cfgs
-
-         self.init_context(cfgs_raw)
-         self.build_loggers(cfgs_raw)
-
-         self.train_TE = any([cfgs.text_encoder, cfgs.lora_text_encoder, cfgs.plugin_TE])
-
-         self.build_ckpt_manager()
-         self.build_model()
-         self.make_hooks()
-         self.config_model()
-         self.cache_latents = False
-
-         self.batch_size_list = []
-         assert len(cfgs.data)>0, "At least one dataset is need."
-         loss_weights = [dataset.keywords['loss_weight'] for name, dataset in cfgs.data.items()]
-         self.train_loader_group = DataGroup([self.build_data(dataset) for name, dataset in cfgs.data.items()], loss_weights)
-
-         if self.cache_latents:
-             self.vae = self.vae.to('cpu')
-         self.build_optimizer_scheduler()
-         try:
-             self.criterion = cfgs.train.loss.criterion(noise_scheduler=self.noise_scheduler, device=self.device)
-         except:
-             self.criterion = cfgs.train.loss.criterion()
-
-         self.cfg_scale = get_cfg_range(cfgs.train.cfg_scale)
-         if self.cfg_scale[1] == 1.0:
-             self.cfg_context = CFGContext()
-         else: # DreamArtist
-             self.cfg_context = DreamArtistPTContext(self.cfg_scale, self.num_train_timesteps)
-
-         with torch.no_grad():
-             self.build_ema()
-
-         self.load_resume()
-
-         torch.backends.cuda.matmul.allow_tf32 = cfgs.allow_tf32
-
-         # calculate steps and epochs
-         self.steps_per_epoch = len(self.train_loader_group.loader_list[0])
-         if self.cfgs.train.train_epochs is not None:
-             self.cfgs.train.train_steps = self.cfgs.train.train_epochs*self.steps_per_epoch
-         else:
-             self.cfgs.train.train_epochs = math.ceil(self.cfgs.train.train_steps/self.steps_per_epoch)
-
-         if self.is_local_main_process and self.cfgs.previewer is not None:
-             self.previewer = self.cfgs.previewer(exp_dir=self.exp_dir, te_hook=self.text_enc_hook, unet=self.TE_unet.unet,
-                                                  TE=self.TE_unet.TE, tokenizer=self.tokenizer, vae=self.vae)
-
-         self.prepare()
-
-     @property
-     def device(self):
-         return self.accelerator.device
-
-     @property
-     def is_local_main_process(self):
-         return self.accelerator.is_local_main_process
-
-     def init_context(self, cfgs_raw):
-         ddp_kwargs = DistributedDataParallelKwargs(broadcast_buffers=False)
-         self.accelerator = Accelerator(
-             gradient_accumulation_steps=self.cfgs.train.gradient_accumulation_steps,
-             mixed_precision=self.cfgs.mixed_precision,
-             step_scheduler_with_optimizer=False,
-             kwargs_handlers=[ddp_kwargs], # fix inplace bug in DDP while use data_class
-         )
-
-         self.local_rank = int(os.environ.get("LOCAL_RANK", -1))
-         self.world_size = self.accelerator.num_processes
-
-         set_seed(self.cfgs.seed+self.local_rank)
-
-     def build_loggers(self, cfgs_raw):
-         if self.is_local_main_process:
-             self.exp_dir = self.cfgs.exp_dir.format(time=time.strftime("%Y-%m-%d-%H-%M-%S"))
-             os.makedirs(os.path.join(self.exp_dir, 'ckpts/'), exist_ok=True)
-             with open(os.path.join(self.exp_dir, 'cfg.yaml'), 'w', encoding='utf-8') as f:
-                 f.write(OmegaConf.to_yaml(cfgs_raw))
-             self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=self.exp_dir) for builder in self.cfgs.logger])
-         else:
-             self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=None) for builder in self.cfgs.logger])
-
-         self.min_log_step = mgcd(*([item.log_step for item in self.loggers.logger_list]))
-         image_log_steps = [item.image_log_step for item in self.loggers.logger_list if item.enable_log_image]
-         if len(image_log_steps)>0:
-             self.min_img_log_step = mgcd(*image_log_steps)
-         else:
-             self.min_img_log_step = -1
-
-         self.loggers.info(f'world_size: {self.world_size}')
-         self.loggers.info(f'accumulation: {self.cfgs.train.gradient_accumulation_steps}')
-
-         if self.is_local_main_process:
-             transformers.utils.logging.set_verbosity_warning()
-             diffusers.utils.logging.set_verbosity_warning()
-         else:
-             transformers.utils.logging.set_verbosity_error()
-             diffusers.utils.logging.set_verbosity_error()
-
-     def prepare(self):
-         # Prepare everything with accelerator.
-         prepare_name_list, prepare_obj_list = [], []
-         if self.TE_unet.train_TE:
-             prepare_obj_list.append(self.TE_unet)
-             prepare_name_list.append('TE_unet')
-         else:
-             prepare_obj_list.append(self.TE_unet.unet)
-             prepare_name_list.append('TE_unet.unet')
-
-         if hasattr(self, 'optimizer'):
-             prepare_obj_list.extend([self.optimizer, self.lr_scheduler] if self.lr_scheduler else [self.optimizer])
-             prepare_name_list.extend(['optimizer', 'lr_scheduler'] if self.lr_scheduler else ['optimizer'])
-         if hasattr(self, 'optimizer_pt'):
-             prepare_obj_list.extend([self.optimizer_pt, self.lr_scheduler_pt] if self.lr_scheduler_pt else [self.optimizer_pt])
-             prepare_name_list.extend(['optimizer_pt', 'lr_scheduler_pt'] if self.lr_scheduler_pt else ['optimizer_pt'])
-
-         prepare_obj_list.extend(self.train_loader_group.loader_list)
-         prepared_obj = self.accelerator.prepare(*prepare_obj_list)
-
-         if not self.TE_unet.train_TE:
-             self.TE_unet.unet = prepared_obj[0]
-             prepared_obj = prepared_obj[1:]
-             prepare_name_list = prepare_name_list[1:]
-
-         ds_num = len(self.train_loader_group.loader_list)
-         self.train_loader_group.loader_list = list(prepared_obj[-ds_num:])
-         prepared_obj = prepared_obj[:-ds_num]
-
-         for name, obj in zip(prepare_name_list, prepared_obj):
-             setattr(self, name, obj)
-
-         if self.cfgs.model.force_cast_precision:
-             self.TE_unet.to(dtype=self.weight_dtype)
-
-     def scale_lr(self, parameters):
-         bs = sum(self.batch_size_list)
-         scale_factor = bs*self.world_size*self.cfgs.train.gradient_accumulation_steps
-         for param in parameters:
-             if 'lr' in param:
-                 param['lr'] *= scale_factor
-
-     def build_model(self):
-         # Load the tokenizer
-         if self.cfgs.model.get('tokenizer', None) is not None:
-             self.tokenizer = self.cfgs.model.tokenizer
-         else:
-             tokenizer_cls = auto_tokenizer_cls(self.cfgs.model.pretrained_model_name_or_path, self.cfgs.model.revision)
-             self.tokenizer = tokenizer_cls.from_pretrained(
-                 self.cfgs.model.pretrained_model_name_or_path, subfolder="tokenizer",
-                 revision=self.cfgs.model.revision, use_fast=False,
-             )
-
-         # Load scheduler and models
-         self.noise_scheduler = self.cfgs.model.get('noise_scheduler', None) or \
-                                DDPMScheduler.from_pretrained(self.cfgs.model.pretrained_model_name_or_path, subfolder='scheduler')
-
-         self.num_train_timesteps = len(self.noise_scheduler.timesteps)
-         self.vae: AutoencoderKL = self.cfgs.model.get('vae', None) or AutoencoderKL.from_pretrained(
-             self.cfgs.model.pretrained_model_name_or_path, subfolder="vae", revision=self.cfgs.model.revision)
-         self.build_unet_and_TE()
-
-     def build_unet_and_TE(self): # for easy to use colossalAI
-         unet = self.cfgs.model.get('unet', None) or UNet2DConditionModel.from_pretrained(
-             self.cfgs.model.pretrained_model_name_or_path, subfolder="unet", revision=self.cfgs.model.revision
-         )
-
-         if self.cfgs.model.get('text_encoder', None) is not None:
-             text_encoder = self.cfgs.model.text_encoder
-             text_encoder_cls = type(text_encoder)
-         else:
-             # import correct text encoder class
-             text_encoder_cls = auto_text_encoder_cls(self.cfgs.model.pretrained_model_name_or_path, self.cfgs.model.revision)
-             text_encoder = text_encoder_cls.from_pretrained(
-                 self.cfgs.model.pretrained_model_name_or_path, subfolder="text_encoder", revision=self.cfgs.model.revision
-             )
-
-         # Wrap unet and text_encoder to make DDP happy. Multiple DDP has soooooo many fxxking bugs!
-         wrapper_cls = SDXLTEUnetWrapper if text_encoder_cls == SDXLTextEncoder else TEUnetWrapper
-         self.TE_unet = wrapper_cls(unet, text_encoder, train_TE=self.train_TE)
-
-     def build_ema(self):
-         if self.cfgs.model.ema is not None:
-             self.ema_unet = self.cfgs.model.ema(self.TE_unet.unet)
-             if self.train_TE:
-                 self.ema_text_encoder = self.cfgs.model.ema(self.TE_unet.TE)
-
-     def build_ckpt_manager(self):
-         self.ckpt_manager = self.ckpt_manager_map[self.cfgs.ckpt_type]()
-         if self.is_local_main_process:
-             self.ckpt_manager.set_save_dir(os.path.join(self.exp_dir, 'ckpts'), emb_dir=self.cfgs.tokenizer_pt.emb_dir)
-
-     @property
-     def unet_raw(self):
-         return self.TE_unet.module.unet if self.train_TE else self.TE_unet.unet.module
-
-     @property
-     def TE_raw(self):
-         return self.TE_unet.module.TE if self.train_TE else self.TE_unet.TE
-
-     def config_model(self):
-         if self.cfgs.model.enable_xformers:
-             if is_xformers_available():
-                 self.TE_unet.unet.enable_xformers_memory_efficient_attention()
-                 # self.text_enc_hook.enable_xformers()
-             else:
-                 warnings.warn("xformers is not available. Make sure it is installed correctly")
-
-         self.vae.requires_grad_(False)
-         self.TE_unet.requires_grad_(False)
-
-         self.TE_unet.eval()
-
-         if self.cfgs.model.gradient_checkpointing:
-             self.TE_unet.enable_gradient_checkpointing()
-
-         self.weight_dtype = self.weight_dtype_map.get(self.cfgs.mixed_precision, torch.float32)
-         self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
-         # Move vae and text_encoder to device and cast to weight_dtype
-         self.vae = self.vae.to(self.device, dtype=self.vae_dtype)
-         if not self.train_TE:
-             self.TE_unet.TE = self.TE_unet.TE.to(self.device, dtype=self.weight_dtype)
-
-     @torch.no_grad()
-     def load_resume(self):
-         if self.cfgs.train.resume is not None:
-             for ckpt in self.cfgs.train.resume.ckpt_path.unet:
-                 self.ckpt_manager.load_ckpt_to_model(self.TE_unet.unet, ckpt, model_ema=getattr(self, 'ema_unet', None))
-             for ckpt in self.cfgs.train.resume.ckpt_path.TE:
-                 self.ckpt_manager.load_ckpt_to_model(self.TE_unet.TE, ckpt, model_ema=getattr(self, 'ema_text_encoder', None))
-             for name, ckpt in self.cfgs.train.resume.ckpt_path.words:
-                 self.ex_words_emb[name].data = load_emb(ckpt)
-
-     def make_hooks(self):
-         # Hook tokenizer and embedding to support pt
-         self.embedding_hook, self.ex_words_emb = ComposeEmbPTHook.hook_from_dir(
-             self.cfgs.tokenizer_pt.emb_dir, self.tokenizer, self.TE_unet.TE, log=self.is_local_main_process,
-             N_repeats=self.cfgs.model.tokenizer_repeats, device=self.device)
-
-         self.text_enc_hook = ComposeTEEXHook.hook(self.TE_unet.TE, self.tokenizer, N_repeats=self.cfgs.model.tokenizer_repeats,
-                                                   device=self.device, clip_skip=self.cfgs.model.clip_skip,
-                                                   clip_final_norm=self.cfgs.model.clip_final_norm)
-
-     def build_dataset(self, data_builder: partial):
-         batch_size = data_builder.keywords.pop('batch_size')
-         cache_latents = data_builder.keywords.pop('cache_latents')
-         self.batch_size_list.append(batch_size)
-
-         train_dataset = data_builder(tokenizer=self.tokenizer, tokenizer_repeats=self.cfgs.model.tokenizer_repeats)
-         train_dataset.bucket.build(batch_size*self.world_size, file_names=train_dataset.source.get_image_list())
-         arb = isinstance(train_dataset.bucket, RatioBucket)
-         self.loggers.info(f"len(train_dataset): {len(train_dataset)}")
-
-         if cache_latents:
-             self.cache_latents = True
-             train_dataset.cache_latents(self.vae, self.vae_dtype, self.device, show_prog=self.is_local_main_process)
-         return train_dataset, batch_size, arb
-
-     def build_data(self, data_builder: partial) -> torch.utils.data.DataLoader:
-         train_dataset, batch_size, arb = self.build_dataset(data_builder)
-
-         # Pytorch Data loader
-         train_sampler = get_sampler()(train_dataset, num_replicas=self.world_size, rank=self.local_rank, shuffle=not arb)
-         train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=self.cfgs.train.workers,
-                                                    sampler=train_sampler, collate_fn=train_dataset.collate_fn)
-         return train_loader
-
-     def get_param_group_train(self):
-         # make miniFT and warp with lora
-         self.DA_lora = False
-         train_params_unet, self.lora_unet = make_hcpdiff(self.TE_unet.unet, self.cfgs.unet, self.cfgs.lora_unet)
-         if isinstance(self.lora_unet, tuple): # creat negative lora
-             self.DA_lora = True
-             self.lora_unet, self.lora_unet_neg = self.lora_unet
-         train_params_unet_plugin, self.all_plugin_unet = make_plugin(self.TE_unet.unet, self.cfgs.plugin_unet)
-         train_params_unet += train_params_unet_plugin
-
-         if self.train_TE:
-             train_params_text_encoder, self.lora_TE = make_hcpdiff(self.TE_unet.TE, self.cfgs.text_encoder, self.cfgs.lora_text_encoder)
-             if isinstance(self.lora_TE, tuple): # creat negative lora
-                 self.DA_lora = True
-                 self.lora_TE, self.lora_TE_neg = self.lora_TE
-             train_params_TE_plugin, self.all_plugin_TE = make_plugin(self.TE_unet.TE, self.cfgs.plugin_TE)
-             train_params_text_encoder += train_params_TE_plugin
-         else:
-             train_params_text_encoder = []
-
-         N_params_unet = format_number(sum(sum(x.numel() for x in p['params']) for p in train_params_unet))
-         N_params_TE = format_number(sum(sum(x.numel() for x in p['params']) for p in train_params_text_encoder))
-         self.loggers.info(f'unet trainable params: {N_params_unet}; text encoder trainable params: {N_params_TE}')
-
-         # params for embedding
-         train_params_emb = []
-         self.train_pts = {}
-         if self.cfgs.tokenizer_pt.train is not None:
-             for v in self.cfgs.tokenizer_pt.train:
-                 word_emb = self.ex_words_emb[v.name]
-                 self.train_pts[v.name] = word_emb
-                 word_emb.requires_grad = True
-                 self.embedding_hook.emb_train.append(word_emb)
-                 train_params_emb.append({'params':word_emb, 'lr':v.lr})
-
-         return train_params_unet+train_params_text_encoder, train_params_emb
-
-     def build_optimizer_scheduler(self):
-         # set optimizer
-         parameters, parameters_pt = self.get_param_group_train()
-
-         if len(parameters)>0: # do fine-tuning
-             cfg_opt = self.cfgs.train.optimizer
-             if self.cfgs.train.scale_lr:
-                 self.scale_lr(parameters)
-             assert isinstance(cfg_opt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-             self.optimizer = cfg_opt(params=parameters)
-             self.lr_scheduler = get_scheduler(self.cfgs.train.scheduler, self.optimizer)
-
-         if len(parameters_pt)>0: # do prompt-tuning
-             cfg_opt_pt = self.cfgs.train.optimizer_pt
-             if self.cfgs.train.scale_lr_pt:
-                 self.scale_lr(parameters_pt)
-             assert isinstance(cfg_opt_pt, partial), f'optimizer.type is not supported anymore, please use class path like "torch.optim.AdamW".'
-             self.optimizer_pt = cfg_opt_pt(params=parameters_pt)
-             self.lr_scheduler_pt = get_scheduler(self.cfgs.train.scheduler_pt, self.optimizer_pt)
-
-     def train(self, loss_ema=0.93):
-         total_batch_size = sum(self.batch_size_list)*self.world_size*self.cfgs.train.gradient_accumulation_steps
-
-         self.loggers.info("***** Running training *****")
-         self.loggers.info(f" Num batches each epoch = {len(self.train_loader_group.loader_list[0])}")
-         self.loggers.info(f" Num Steps = {self.cfgs.train.train_steps}")
-         self.loggers.info(f" Instantaneous batch size per device = {sum(self.batch_size_list)}")
-         self.loggers.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-         self.loggers.info(f" Gradient Accumulation steps = {self.cfgs.train.gradient_accumulation_steps}")
-         self.global_step = 0
-         if self.cfgs.train.resume is not None:
-             self.global_step = self.cfgs.train.resume.start_step
-
-         loss_sum = None
-         for data_list in self.train_loader_group:
-             loss = self.train_one_step(data_list)
-             loss_sum = loss if loss_sum is None else (loss_ema*loss_sum+(1-loss_ema)*loss)
-
-             self.global_step += 1
-             if self.is_local_main_process:
-                 if self.global_step%self.cfgs.train.save_step == 0:
-                     self.save_model()
-                 if self.global_step%self.min_log_step == 0:
-                     # get learning rate from optimizer
-                     lr_model = self.optimizer.param_groups[0]['lr'] if hasattr(self, 'optimizer') else 0.
-                     lr_word = self.optimizer_pt.param_groups[0]['lr'] if hasattr(self, 'optimizer_pt') else 0.
-                     self.loggers.log(datas={
-                         'Step':{'format':'[{}/{}]', 'data':[self.global_step, self.cfgs.train.train_steps]},
-                         'Epoch':{'format':'[{}/{}]<{}/{}>', 'data':[self.global_step//self.steps_per_epoch, self.cfgs.train.train_epochs,
-                                                                     self.global_step%self.steps_per_epoch, self.steps_per_epoch]},
-                         'LR_model':{'format':'{:.2e}', 'data':[lr_model]},
-                         'LR_word':{'format':'{:.2e}', 'data':[lr_word]},
-                         'Loss':{'format':'{:.5f}', 'data':[loss_sum]},
-                     }, step=self.global_step)
-                 if self.min_img_log_step>0 and self.global_step%self.min_img_log_step == 0:
-                     self.loggers.log_image(self.previewer.preview_dict(), self.global_step)
-
-             if self.global_step>=self.cfgs.train.train_steps:
-                 break
-
-         self.wait_for_everyone()
-         if self.is_local_main_process:
-             self.save_model()
-
-     def wait_for_everyone(self):
-         self.accelerator.wait_for_everyone()
-
-     @torch.no_grad()
-     def get_latents(self, image, dataset):
-         if dataset.latents is None:
-             latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
-             latents = latents*self.vae.config.scaling_factor
-         else:
-             latents = image # Cached latents
-         return latents
-
-     def make_noise(self, latents):
-         # Sample noise that we'll add to the latents
-         noise = torch.randn_like(latents)
-         bsz = latents.shape[0]
-         # Sample a random timestep for each image
-         timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-         timesteps = timesteps.long()
-
-         # Add noise to the latents according to the noise magnitude at each timestep
-         # (this is the forward diffusion process)
-         return self.noise_scheduler.add_noise(latents, noise, timesteps), noise, timesteps
-
-     def forward(self, latents, prompt_ids, attn_mask=None, position_ids=None, **kwargs):
-         noisy_latents, noise, timesteps = self.make_noise(latents)
-
-         # CFG context for DreamArtist
-         noisy_latents, timesteps = self.cfg_context.pre(noisy_latents, timesteps)
-         model_pred = self.TE_unet(prompt_ids, noisy_latents, timesteps, attn_mask=attn_mask, position_ids=position_ids, **kwargs)
-         model_pred = self.cfg_context.post(model_pred)
-
-         # Get the target for loss depending on the prediction type
-         if self.cfgs.train.loss.type == "eps":
-             target = noise
-         elif self.cfgs.train.loss.type == "sample":
-             target = self.noise_scheduler.step(noise, timesteps, noisy_latents)
-             model_pred = self.noise_scheduler.step(model_pred, timesteps, noisy_latents)
-         else:
-             raise ValueError(f"Unknown loss type {self.cfgs.train.loss.type}")
-         return model_pred, target, timesteps
-
-     def train_one_step(self, data_list):
-         with self.accelerator.accumulate(self.TE_unet):
-             for idx, data in enumerate(data_list):
-                 image = data.pop('img').to(self.device, dtype=self.weight_dtype)
-                 img_mask = data.pop('mask').to(self.device) if 'mask' in data else None
-                 prompt_ids = data.pop('prompt').to(self.device)
-                 attn_mask = data.pop('attn_mask').to(self.device) if 'attn_mask' in data else None
-                 position_ids = data.pop('position_ids').to(self.device) if 'position_ids' in data else None
-                 other_datas = {k:v.to(self.device) for k, v in data.items() if k!='plugin_input'}
-                 if 'plugin_input' in data:
-                     other_datas['plugin_input'] = {k:v.to(self.device, dtype=self.weight_dtype) for k, v in data['plugin_input'].items()}
-
-                 latents = self.get_latents(image, self.train_loader_group.get_dataset(idx))
-                 model_pred, target, timesteps = self.forward(latents, prompt_ids, attn_mask, position_ids, **other_datas)
-                 loss = self.get_loss(model_pred, target, timesteps, img_mask)*self.train_loader_group.get_loss_weights(idx)
-                 self.accelerator.backward(loss)
-
-             if hasattr(self, 'optimizer'):
-                 if self.accelerator.sync_gradients: # fine-tuning
-                     if hasattr(self.TE_unet, 'trainable_parameters'):
-                         clip_param = self.TE_unet.trainable_parameters()
-                     else:
-                         clip_param = self.TE_unet.module.trainable_parameters()
-                     self.accelerator.clip_grad_norm_(clip_param, self.cfgs.train.max_grad_norm)
-                 self.optimizer.step()
-                 if self.lr_scheduler:
-                     self.lr_scheduler.step()
-                 self.optimizer.zero_grad(set_to_none=self.cfgs.train.set_grads_to_none)
-
-             if hasattr(self, 'optimizer_pt'): # prompt tuning
-                 self.optimizer_pt.step()
-                 if self.lr_scheduler_pt:
-                     self.lr_scheduler_pt.step()
-                 self.optimizer_pt.zero_grad(set_to_none=self.cfgs.train.set_grads_to_none)
-
-             if self.accelerator.sync_gradients:
-                 self.update_ema()
-         return loss.item()
-
-     def get_loss(self, model_pred, target, timesteps, att_mask):
-         if att_mask is None:
-             att_mask = 1.0
-         if getattr(self.criterion, 'need_timesteps', False):
-             loss = (self.criterion(model_pred.float(), target.float(), timesteps)*att_mask).mean()
-         else:
-             loss = (self.criterion(model_pred.float(), target.float())*att_mask).mean()
-         if len(self.embedding_hook.emb_train)>0:
-             loss = loss+0*sum([emb.mean() for emb in self.embedding_hook.emb_train])
-         return loss
-
-     def update_ema(self):
-         if hasattr(self, 'ema_unet'):
-             self.ema_unet.update(self.unet_raw)
-         if hasattr(self, 'ema_text_encoder'):
-             self.ema_text_encoder.update(self.TE_raw)
-
-     def save_model(self, from_raw=False):
-         unet_raw = self.unet_raw
-         self.ckpt_manager.save_model_with_lora(unet_raw, self.lora_unet, model_ema=getattr(self, 'ema_unet', None),
-                                                name='unet', step=self.global_step)
-         self.ckpt_manager.save_plugins(unet_raw, self.all_plugin_unet, name='unet', step=self.global_step,
-                                        model_ema=getattr(self, 'ema_unet', None))
-         if self.train_TE:
-             TE_raw = self.TE_raw
-             # exclude_key: embeddings should not save with text-encoder
-             self.ckpt_manager.save_model_with_lora(TE_raw, self.lora_TE, model_ema=getattr(self, 'ema_text_encoder', None),
-                                                    name='text_encoder', step=self.global_step, exclude_key='emb_ex.')
-             self.ckpt_manager.save_plugins(TE_raw, self.all_plugin_TE, name='text_encoder', step=self.global_step,
-                                            model_ema=getattr(self, 'ema_text_encoder', None))
-
-         if self.DA_lora:
-             self.ckpt_manager.save_model_with_lora(None, self.lora_unet_neg, name='unet-neg', step=self.global_step)
-             if self.train_TE:
-                 self.ckpt_manager.save_model_with_lora(None, self.lora_TE_neg, name='text_encoder-neg', step=self.global_step)
-
-         self.ckpt_manager.save_embedding(self.train_pts, self.global_step, self.cfgs.tokenizer_pt.replace)
-
-         self.loggers.info(f"Saved state, step: {self.global_step}")
-
-     def make_vis(self):
-         vis_dir = os.path.join(self.exp_dir, f'vis-{self.global_step}')
-         new_components = {
-             'unet':self.unet_raw,
-             'text_encoder':self.TE_raw,
-             'tokenizer':self.tokenizer,
-             'vae':self.vae,
-         }
-         viser = Visualizer(self.cfgs.model.pretrained_model_name_or_path, new_components=new_components)
-         if self.cfgs.vis_info.prompt:
-             raise ValueError('vis_info.prompt is None. cannot generate without prompt.')
-         viser.vis_to_dir(vis_dir, self.cfgs.vis_prompt)
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(description='Stable Diffusion Training')
-     parser.add_argument('--cfg', type=str, default=None, required=True)
-     args, cfg_args = parser.parse_known_args()
-
-     conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
-     trainer = Trainer(conf)
-     trainer.train()
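The 0.9.0 accelerate trainer is removed here; 2.1 adds hcpdiff/trainer_ac.py and hcpdiff/trainer_ac_single.py (entries 80 and 81 above) in its place. One detail in the deleted module worth calling out is the checkpoint_fix patch near the top: it forces torch.utils.checkpoint to run non-reentrant (use_reentrant=False) so that gradients still reach the trainable subset (LoRA layers, plugins, embeddings) of an otherwise frozen, checkpointed model. Below is a minimal sketch of that same pattern in isolation, with an illustrative function name; it is not part of the 2.1 API.

import torch
import torch.utils.checkpoint

_checkpoint_raw = torch.utils.checkpoint.checkpoint

def checkpoint_non_reentrant(function, *args, use_reentrant: bool = False, **kwargs):
    # default to the non-reentrant implementation unless a caller overrides it explicitly
    return _checkpoint_raw(function, *args, use_reentrant=use_reentrant, **kwargs)

# patch globally before building the model, as the deleted train_ac.py did
torch.utils.checkpoint.checkpoint = checkpoint_non_reentrant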
hcpdiff/train_ac_single.py DELETED
@@ -1,39 +0,0 @@
- import argparse
- import sys
- from functools import partial
-
- import torch
- from accelerate import Accelerator
- from loguru import logger
-
- from hcpdiff.train_ac import Trainer, RatioBucket, load_config_with_cli, set_seed, get_sampler
-
- class TrainerSingleCard(Trainer):
-     def init_context(self, cfgs_raw):
-         self.accelerator = Accelerator(
-             gradient_accumulation_steps=self.cfgs.train.gradient_accumulation_steps,
-             mixed_precision=self.cfgs.mixed_precision,
-             step_scheduler_with_optimizer=False,
-         )
-
-         self.local_rank = 0
-         self.world_size = self.accelerator.num_processes
-
-         set_seed(self.cfgs.seed+self.local_rank)
-
-     @property
-     def unet_raw(self):
-         return self.TE_unet.unet
-
-     @property
-     def TE_raw(self):
-         return self.TE_unet.TE
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(description='Stable Diffusion Training')
-     parser.add_argument('--cfg', type=str, default='cfg/train/demo.yaml')
-     args, cfg_args = parser.parse_known_args()
-
-     conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
-     trainer = TrainerSingleCard(conf)
-     trainer.train()
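The single-card variant differs from the multi-GPU trainer mainly in how it reaches the raw modules: under DDP, accelerate wraps the model so the underlying module sits behind .module, while on a single card nothing needs unwrapping (compare unet_raw and TE_raw here with the versions in train_ac.py above). A hedged sketch of that unwrap pattern, with an illustrative helper name not taken from the package:

import torch.nn as nn

def unwrap_model(model: nn.Module) -> nn.Module:
    # DDP and similar wrappers expose the original module as .module;
    # a bare single-card module is returned unchanged
    return model.module if hasattr(model, 'module') else model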