hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff shows the contents of publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/diffusion/sampler/ddpm.py
@@ -0,0 +1,20 @@
+ import torch
+
+ from .base import BaseSampler
+ from .sigma_scheduler import SigmaScheduler
+
+ class DDPMSampler(BaseSampler):
+     def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator=None):
+         super().__init__(sigma_scheduler, generator)
+
+     def c_in(self, sigma):
+         return 1./(sigma**2+1).sqrt()
+
+     def c_out(self, sigma):
+         return -sigma
+
+     def c_skip(self, sigma):
+         return 1.
+
+     def denoise(self, x, sigma, eps=None, generator=None):
+         raise NotImplementedError
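
The new sampler classes expose preconditioning coefficients (c_in, c_out, c_skip) computed from a sigma scheduler. A minimal usage sketch, assuming the module paths shown in this diff are importable from an installed hcpdiff 2.2 and that BaseSampler (not shown here) stores the scheduler passed to it:

    # Sketch only: pair DDPMSampler with the discrete DDPM sigma scheduler defined later in this diff.
    import torch
    from hcpdiff.diffusion.sampler.ddpm import DDPMSampler
    from hcpdiff.diffusion.sampler.sigma_scheduler import DDPMDiscreteSigmaScheduler

    sigma_scheduler = DDPMDiscreteSigmaScheduler()   # scaled_linear betas, 1000 steps
    sampler = DDPMSampler(sigma_scheduler)

    # Draw a batch of noise levels and inspect the preconditioning coefficients.
    sigma, t = sigma_scheduler.sample_sigma(shape=(4,))
    print(sampler.c_in(sigma))    # 1/sqrt(sigma^2 + 1)
    print(sampler.c_out(sigma))   # -sigma
    print(sampler.c_skip(sigma))  # 1. in this sampler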
hcpdiff/diffusion/sampler/diffusers.py
@@ -0,0 +1,66 @@
+ import torch
+ import inspect
+ from diffusers import SchedulerMixin, DDPMScheduler
+
+ try:
+     from diffusers.utils import randn_tensor
+ except:
+     # new version of diffusers
+     from diffusers.utils.torch_utils import randn_tensor
+
+ from .base import BaseSampler
+ from .sigma_scheduler import TimeSigmaScheduler
+
+ class DiffusersSampler(BaseSampler):
+     def __init__(self, scheduler: SchedulerMixin, eta=0.0, generator: torch.Generator=None):
+         sigma_scheduler = TimeSigmaScheduler()
+         super().__init__(sigma_scheduler, generator)
+         self.scheduler = scheduler
+         self.eta = eta
+
+     def c_in(self, sigma):
+         one = torch.ones_like(sigma)
+         if hasattr(self.scheduler, '_step_index'):
+             self.scheduler._step_index = None
+         return self.scheduler.scale_model_input(one, sigma)
+
+     def c_out(self, sigma):
+         return -sigma
+
+     def c_skip(self, sigma):
+         if self.c_in(sigma) == 1.: # DDPM model
+             return (sigma**2+1).sqrt() # 1/sqrt(alpha_)
+         else: # EDM model
+             return 1.
+
+     def get_timesteps(self, N_steps, device='cuda'):
+         self.scheduler.set_timesteps(N_steps, device=device)
+         return self.scheduler.timesteps
+
+     def init_noise(self, shape, device='cuda', dtype=torch.float32):
+         return randn_tensor(shape, generator=self.generator, device=device, dtype=dtype)*self.scheduler.init_noise_sigma
+
+     def add_noise(self, x, sigma):
+         noise = randn_tensor(x.shape, generator=self.generator, device=x.device, dtype=x.dtype)
+         return self.scheduler.add_noise(x, noise, sigma), noise
+
+     def prepare_extra_step_kwargs(self, scheduler, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     def denoise(self, x_t, sigma, eps=None, generator=None):
+         extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator, self.eta)
+         return self.scheduler.step(eps, sigma, x_t, **extra_step_kwargs).prev_sample
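
DiffusersSampler adapts an off-the-shelf diffusers scheduler to the same sampler interface, deferring timestep spacing, noise scaling, and the update step to the wrapped scheduler. A hedged sketch of the intended use, assuming the stock diffusers DDIMScheduler and that BaseSampler stores the generator (its subclasses' use of self.generator suggests it does):

    # Sketch only: wrap a diffusers scheduler in the new sampler interface.
    import torch
    from diffusers import DDIMScheduler
    from hcpdiff.diffusion.sampler.diffusers import DiffusersSampler

    scheduler = DDIMScheduler(beta_schedule='scaled_linear')
    sampler = DiffusersSampler(scheduler, eta=0.0)

    timesteps = sampler.get_timesteps(30, device=torch.device('cpu'))   # calls scheduler.set_timesteps(...)
    x_t = sampler.init_noise((1, 4, 64, 64), device=torch.device('cpu'))  # scaled by scheduler.init_noise_sigma
    # Given a model prediction eps at timestep t, sampler.denoise(x_t, t, eps)
    # forwards to scheduler.step(eps, t, x_t, ...) and returns prev_sample.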
hcpdiff/diffusion/sampler/edm.py
@@ -0,0 +1,22 @@
+ import torch
+
+ from .base import BaseSampler
+ from .sigma_scheduler import SigmaScheduler
+
+ class EDMSampler(BaseSampler):
+     def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator = None, sigma_data: float = 1.0, sigma_thr=1000):
+         super().__init__(sigma_scheduler, generator)
+         self.sigma_data = sigma_data
+         self.sigma_thr = sigma_thr
+
+     def c_in(self, sigma):
+         return 1/(sigma**2+self.sigma_data**2).sqrt()
+
+     def c_out(self, sigma):
+         return (sigma*self.sigma_data)/(sigma**2+self.sigma_data**2).sqrt()
+
+     def c_skip(self, sigma):
+         return self.sigma_data**2/(sigma**2+self.sigma_data**2)
+
+     def denoise(self, x, sigma, eps=None, generator=None):
+         raise NotImplementedError
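
EDMSampler implements the EDM (Karras-style) preconditioning, parameterized by sigma_data. A minimal sketch, assuming the EDMSigmaScheduler shown later in this diff:

    # Sketch only: EDM preconditioning coefficients with an EDM sigma scheduler.
    import torch
    from hcpdiff.diffusion.sampler.edm import EDMSampler
    from hcpdiff.diffusion.sampler.sigma_scheduler import EDMSigmaScheduler

    sigma_scheduler = EDMSigmaScheduler(sigma_min=0.002, sigma_max=80.0, rho=7.0)
    sampler = EDMSampler(sigma_scheduler, sigma_data=0.5)

    sigma, t = sigma_scheduler.sample_sigma(shape=(4,))
    # c_in   = 1 / sqrt(sigma^2 + sigma_data^2)
    # c_out  = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
    # c_skip = sigma_data^2 / (sigma^2 + sigma_data^2)
    print(sampler.c_in(sigma), sampler.c_out(sigma), sampler.c_skip(sigma))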
hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py
@@ -0,0 +1,3 @@
+ from .base import SigmaScheduler
+ from .ddpm import DDPMDiscreteSigmaScheduler, DDPMContinuousSigmaScheduler, TimeSigmaScheduler
+ from .edm import EDMSigmaScheduler, EDMRefSigmaScheduler
hcpdiff/diffusion/sampler/sigma_scheduler/base.py
@@ -0,0 +1,14 @@
+ from typing import Union, Tuple
+
+ import torch
+
+ class SigmaScheduler:
+
+     def get_sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+         '''
+         :param t: 0-1, rate of time step
+         '''
+         raise NotImplementedError
+
+     def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)) -> Tuple[torch.Tensor, torch.Tensor]:
+         raise NotImplementedError
hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py
@@ -0,0 +1,197 @@
+ import torch
+ import math
+ from typing import Union, Tuple
+ from hcpdiff.utils import linear_interp
+ from .base import SigmaScheduler
+
+ class DDPMDiscreteSigmaScheduler(SigmaScheduler):
+     def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, num_timesteps=1000):
+         super().__init__()
+         self.num_timesteps = num_timesteps
+         self.betas = self.make_betas(beta_schedule, linear_start, linear_end, num_timesteps)
+         alphas = 1.0-self.betas
+         self.alphas_cumprod = torch.cumprod(alphas, dim=0)
+         self.sigmas = ((1-self.alphas_cumprod)/self.alphas_cumprod).sqrt()
+
+         # for VLB calculation
+         self.alphas_cumprod_prev = torch.cat([alphas.new_tensor([1.0]), self.alphas_cumprod[:-1]])
+         self.posterior_mean_coef1 = self.betas*torch.sqrt(self.alphas_cumprod_prev)/(1.0-self.alphas_cumprod)
+         self.posterior_mean_coef2 = (1.0-self.alphas_cumprod_prev)*torch.sqrt(alphas)/(1.0-self.alphas_cumprod)
+
+         self.posterior_variance = self.betas*(1.0-self.alphas_cumprod_prev)/(1.0-self.alphas_cumprod)
+         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+         self.posterior_log_variance_clipped = torch.log(torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]]))
+
+
+     @property
+     def sigma_min(self):
+         return self.sigmas[0]
+
+     @property
+     def sigma_max(self):
+         return self.sigmas[-1]
+
+     def get_sigma(self, t: Union[float, torch.Tensor]):
+         if isinstance(t, float):
+             t = torch.tensor(t)
+         return self.sigmas[(t*len(self.sigmas)).long()]
+
+     def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
+         if isinstance(min_rate, float):
+             min_rate = torch.full(shape, min_rate)
+         if isinstance(max_rate, float):
+             max_rate = torch.full(shape, max_rate)
+
+         t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
+         t_scale = (t*(self.num_timesteps-1e-5)).long() # [0, num_timesteps-1)
+         return self.sigmas[t_scale], t
+
+     def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
+         t = (self.sigmas-sigma).abs().argmin()
+         return t/self.num_timesteps
+
+     def get_post_mean(self, t, x_0, x_t):
+         t = (t*len(self.sigmas)).long()
+         return self.posterior_mean_coef1[t].view(-1, 1, 1, 1).to(t.device)*x_0 + self.posterior_mean_coef2[t].view(-1, 1, 1, 1).to(t.device)*x_t
+
+     def get_post_log_var(self, t, x_t_var=None):
+         t = (t*len(self.sigmas)).long()
+         min_log = self.posterior_log_variance_clipped[t].view(-1, 1, 1, 1).to(t.device)
+         if x_t_var is None:
+             return min_log
+         else:
+             max_log = self.betas.log()[t].view(-1, 1, 1, 1).to(t.device)
+             # The model_var_values is [-1, 1] for [min_var, max_var].
+             frac = (x_t_var+1)/2
+             model_log_variance = frac*max_log+(1-frac)*min_log
+             return model_log_variance
+
+
+     @staticmethod
+     def betas_for_alpha_bar(
+         num_diffusion_timesteps,
+         max_beta=0.999,
+         alpha_transform_type="cosine",
+     ):
+         """
+         Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+         (1-beta) over time from t = [0,1].
+
+         Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+         to that part of the diffusion process.
+
+
+         Args:
+             num_diffusion_timesteps (`int`): the number of betas to produce.
+             max_beta (`float`): the maximum beta to use; use values lower than 1 to
+                 prevent singularities.
+             alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+                 Choose from `cosine` or `exp`
+
+         Returns:
+             betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+         """
+         if alpha_transform_type == "cosine":
+
+             def alpha_bar_fn(t):
+                 return math.cos((t+0.008)/1.008*math.pi/2)**2
+
+         elif alpha_transform_type == "exp":
+
+             def alpha_bar_fn(t):
+                 return math.exp(t*-12.0)
+
+         else:
+             raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+         betas = []
+         for i in range(num_diffusion_timesteps):
+             t1 = i/num_diffusion_timesteps
+             t2 = (i+1)/num_diffusion_timesteps
+             betas.append(min(1-alpha_bar_fn(t2)/alpha_bar_fn(t1), max_beta))
+         return torch.tensor(betas, dtype=torch.float32)
+
+     @staticmethod
+     def make_betas(beta_schedule, beta_start, beta_end, num_train_timesteps, betas=None):
+         if betas is not None:
+             return torch.tensor(betas, dtype=torch.float32)
+         elif beta_schedule == "linear":
+             return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+         elif beta_schedule == "scaled_linear":
+             # this schedule is very specific to the latent diffusion model.
+             return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)**2
+         elif beta_schedule == "squaredcos_cap_v2":
+             # Glide cosine schedule
+             return DDPMDiscreteSigmaScheduler.betas_for_alpha_bar(num_train_timesteps)
+         elif beta_schedule == "sigmoid":
+             # GeoDiff sigmoid schedule
+             betas = torch.linspace(-6, 6, num_train_timesteps)
+             return torch.sigmoid(betas)*(beta_end-beta_start)+beta_start
+         else:
+             raise NotImplementedError(f"{beta_schedule} does is not implemented.")
+
+ class DDPMContinuousSigmaScheduler(DDPMDiscreteSigmaScheduler):
+
+     def get_sigma(self, t: Union[float, torch.Tensor]):
+         if isinstance(t, float):
+             t = torch.tensor(t)
+         return linear_interp(self.sigmas, t)
+
+     def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
+         if isinstance(min_rate, float):
+             min_rate = torch.full(shape, min_rate)
+         if isinstance(max_rate, float):
+             max_rate = torch.full(shape, max_rate)
+
+         t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
+         t_scale = (t*(self.num_timesteps-1-1e-5)) # [0, num_timesteps-1)
+
+         return linear_interp(self.sigmas, t_scale), t
+
+     def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
+         diff = self.sigmas-sigma
+         diff[diff<0] = float('inf')
+         t0 = diff.argmin().clamp(0, self.num_timesteps-2)
+         return t0 + diff.min()/(self.sigmas[t0+1]-self.sigmas[t0])
+
+ class TimeSigmaScheduler(SigmaScheduler):
+     def __init__(self, num_timesteps=1000):
+         super().__init__()
+         self.num_timesteps = num_timesteps
+
+     def get_sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+         '''
+         :param t: 0-1, rate of time step
+         '''
+         return t
+
+     def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)) -> Tuple[torch.Tensor, torch.Tensor]:
+         if isinstance(min_rate, float):
+             min_rate = torch.full(shape, min_rate)
+         if isinstance(max_rate, float):
+             max_rate = torch.full(shape, max_rate)
+
+         t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
+         t_scale = (t*(self.num_timesteps-1e-5)).long() # [0, num_timesteps-1)
+         return t_scale, t
+
+ if __name__ == '__main__':
+     from matplotlib import pyplot as plt
+     import numpy as np
+
+     sigma_scheduler = DDPMDiscreteSigmaScheduler()
+     print(sigma_scheduler.sigma_min, sigma_scheduler.sigma_max)
+     t = torch.linspace(0, 1, 1000)
+     rho = 1.
+     s2 = (sigma_scheduler.sigma_min**(1/rho)+t*(sigma_scheduler.sigma_max**(1/rho)-sigma_scheduler.sigma_min**(1/rho)))**rho
+     t2 = np.interp(s2.log().numpy(), sigma_scheduler.sigmas.log().numpy(), t.numpy())
+
+     plt.figure()
+     plt.plot(sigma_scheduler.sigmas)
+     plt.plot(t2*1000, s2)
+     plt.show()
+
+     plt.figure()
+     plt.plot(sigma_scheduler.sigmas.log())
+     plt.plot(t2*1000, s2.log())
+     plt.show()
1
+ from typing import Union
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from .base import SigmaScheduler
7
+
8
+ class EDMSigmaScheduler(SigmaScheduler):
9
+ def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0, num_timesteps=1000):
10
+ self.sigma_min = torch.tensor(sigma_min)
11
+ self.sigma_max = torch.tensor(sigma_max)
12
+ self.rho = rho
13
+
14
+ self.num_timesteps=num_timesteps
15
+
16
+ def get_sigma(self, t: Union[float, torch.Tensor]):
17
+ if isinstance(t, float):
18
+ t = torch.tensor(t)
19
+
20
+ min_inv_rho = self.sigma_min**(1/self.rho)
21
+ max_inv_rho = self.sigma_max**(1/self.rho)
22
+ return torch.lerp(min_inv_rho, max_inv_rho, t)**self.rho
23
+
24
+ def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
25
+ if isinstance(min_rate, float):
26
+ min_rate = torch.full(shape, min_rate)
27
+ if isinstance(max_rate, float):
28
+ max_rate = torch.full(shape, max_rate)
29
+
30
+ t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
31
+ return self.get_sigma(t), t
32
+
33
+ class EDMRefSigmaScheduler(EDMSigmaScheduler):
34
+ def __init__(self, ref_scheduler, sigma_min=0.002, sigma_max=80.0, rho=7.0, num_timesteps=1000):
35
+ super().__init__(sigma_min, sigma_max, rho, num_timesteps=num_timesteps)
36
+ self.ref_sigmas = ref_scheduler.sigmas.cpu().clip(min=1e-8).log().numpy()
37
+ self.ref_t = np.linspace(0, 1, len(self.ref_sigmas))
38
+
39
+ def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
40
+ if isinstance(min_rate, float):
41
+ min_rate = torch.full(shape, min_rate)
42
+ if isinstance(max_rate, float):
43
+ max_rate = torch.full(shape, max_rate)
44
+
45
+ t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
46
+ sigma = self.get_sigma(t)
47
+ t_rect = torch.tensor(np.interp(sigma.cpu().clip(min=1e-8).log().numpy(), self.ref_sigmas, self.ref_t))
48
+ return sigma, t_rect
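
EDMSigmaScheduler spaces sigmas by the rho-power interpolation between sigma_min and sigma_max; EDMRefSigmaScheduler additionally re-maps each sampled sigma onto the timestep axis of a reference schedule via log-sigma interpolation. A hedged sketch of both, using the DDPM table defined above as the reference:

    # Sketch only: the Karras-style sigma ramp and a reference-matched variant.
    import torch
    from hcpdiff.diffusion.sampler.sigma_scheduler import (
        DDPMDiscreteSigmaScheduler, EDMSigmaScheduler, EDMRefSigmaScheduler)

    edm = EDMSigmaScheduler(sigma_min=0.002, sigma_max=80.0, rho=7.0)
    t = torch.linspace(0, 1, 5)
    print(edm.get_sigma(t))   # sigmas spaced by the rho=7 power interpolation

    ref = EDMRefSigmaScheduler(DDPMDiscreteSigmaScheduler(), sigma_min=0.002, sigma_max=80.0)
    sigma, t_rect = ref.sample_sigma(shape=(4,))   # t_rect follows the reference schedule's time axis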
hcpdiff/easy/__init__.py
@@ -0,0 +1,2 @@
+ from .model import SD15_auto_loader, SDXL_auto_loader, PixArt_auto_loader, ControlNet_SD15, make_controlnet_handler
+ from .sampler import Diffusers_SD
hcpdiff/easy/cfg/__init__.py
@@ -0,0 +1,3 @@
+ from .sd15_train import SD15_lora_train, cfg_data_SD_ARB, cfg_data_SD_resize_crop, SD15_finetuning
+ from .sdxl_train import SDXL_lora_train, SDXL_finetuning
+ from .t2i import SD15_t2i, SDXL_t2i, SDXL_t2i_lora, SD15_t2i_lora, SDXL_t2i_parts, SD15_t2i_parts
hcpdiff/easy/cfg/sd15_train.py
@@ -0,0 +1,207 @@
+ import torch
+ from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE, NekoPluginSaver, SafeTensorFormat
+ from rainbowneko.data import RatioBucket, FixedBucket
+ from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
+ from rainbowneko.utils import ConstantLR, Path_Like
+
+ from hcpdiff.ckpt_manager import LoraWebuiFormat
+ from hcpdiff.data import TextImagePairDataset, Text2ImageSource, StableDiffusionHandler
+ from hcpdiff.data import VaeCache
+ from hcpdiff.easy import SD15_auto_loader
+ from hcpdiff.models import SD15Wrapper, TEHookCFG
+ from hcpdiff.models.lora_layers_patch import LoraLayer
+
+ @neko_cfg
+ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5, clip_skip: int = 0,
+                     dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SD15'):
+     if low_vram:
+         from bitsandbytes.optim import AdamW8bit
+         optimizer = AdamW8bit(_partial_=True)
+     else:
+         optimizer = torch.optim.AdamW(_partial_=True)
+
+     from cfgs.train.py import train_base, tuning_base
+
+     return dict(
+         _base_=[train_base, tuning_base],
+         mixed_precision=dtype,
+
+         model_part=CfgWDModelParser([
+             dict(
+                 lr=lr,
+                 layers=['denoiser'], # train UNet
+             )
+         ], weight_decay=1e-2),
+
+         ckpt_saver=dict(
+             SD15=ckpt_saver(
+                 ckpt_type='safetensors',
+                 target_module='denoiser',
+                 layers=LAYERS_TRAINABLE,
+             )
+         ),
+
+         train=dict(
+             train_steps=train_steps,
+             save_step=save_step,
+
+             optimizer=optimizer,
+
+             scheduler=ConstantLR(
+                 _partial_=True,
+                 warmup_steps=warmup_steps,
+             ),
+         ),
+
+         model=dict(
+             name=name,
+
+             ## Easy config
+             wrapper=SD15Wrapper.from_pretrained(
+                 _partial_=True,
+                 models=SD15_auto_loader(ckpt_path=base_model, _partial_=True),
+                 TE_hook_cfg=TEHookCFG(clip_skip=clip_skip),
+             ),
+         ),
+
+         data_train=dataset,
+     )
+
+ @neko_cfg
+ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
+                     clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
+                     name: str = 'SD15', save_webui_format=False):
+     with disable_neko_cfg:
+         if alpha is None:
+             alpha = rank
+
+         if with_conv:
+             lora_layers = [
+                 r're:denoiser.*\.attn.?$',
+                 r're:denoiser.*\.ff$',
+                 r're:denoiser.*\.resnets$',
+                 r're:denoiser.*\.proj_in$',
+                 r're:denoiser.*\.proj_out$',
+                 r're:denoiser.*\.conv$',
+             ]
+         else:
+             lora_layers = [
+                 r're:denoiser.*\.attn.?$',
+                 r're:denoiser.*\.ff$',
+             ]
+
+         if low_vram:
+             from bitsandbytes.optim import AdamW8bit
+             optimizer = AdamW8bit(_partial_=True, betas=(0.9, 0.99))
+         else:
+             optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))
+
+         if save_webui_format:
+             lora_format = LoraWebuiFormat()
+         else:
+             lora_format = SafeTensorFormat()
+
+     from cfgs.train.py.examples import SD_FT
+
+     return dict(
+         _base_=[SD_FT],
+         mixed_precision=dtype,
+
+         model_part=None,
+         model_plugin=CfgWDPluginParser(cfg_plugin=dict(
+             lora1=LoraLayer.wrap_model(
+                 _partial_=True,
+                 lr=lr,
+                 rank=rank,
+                 alpha=alpha,
+                 layers=lora_layers
+             )
+         ), weight_decay=0.1),
+
+         ckpt_saver=dict(
+             _replace_ = True,
+             lora_unet=NekoPluginSaver(
+                 format=lora_format,
+                 target_plugin='lora1',
+             )
+         ),
+
+         train=dict(
+             train_steps=train_steps,
+             save_step=save_step,
+
+             optimizer=optimizer,
+
+             scheduler=ConstantLR(
+                 _partial_=True,
+                 warmup_steps=warmup_steps,
+             ),
+         ),
+
+         model=dict(
+             name=name,
+
+             wrapper=SD15Wrapper.from_pretrained(
+                 _partial_=True,
+                 models=SD15_auto_loader(ckpt_path=base_model, _partial_=True),
+                 TE_hook_cfg=TEHookCFG(clip_skip=clip_skip),
+             ),
+         ),
+
+         data_train=dataset,
+     )
+
+ @neko_cfg
+ def cfg_data_SD_ARB(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', resolution: int = 512*512, num_bucket=4, word_names=None,
+                     prompt_dropout: float = 0, prompt_template: Path_Like = 'prompt_template/caption.txt', loss_weight=1.0):
+     if word_names is None:
+         word_names = dict(pt1=trigger_word)
+     else:
+         word_names = word_names
+
+     return TextImagePairDataset(
+         _partial_=True, batch_size=batch_size, loss_weight=loss_weight,
+         source=dict(
+             data_source1=Text2ImageSource(
+                 img_root=img_root,
+                 label_file='${.img_root}', # path to image captions (file_words)
+                 prompt_template=prompt_template,
+             ),
+         ),
+         handler=StableDiffusionHandler(
+             bucket=RatioBucket,
+             word_names=word_names,
+             erase=prompt_dropout,
+         ),
+         bucket=RatioBucket.from_files(
+             target_area=resolution,
+             num_bucket=num_bucket,
+         ),
+         cache=VaeCache(bs=batch_size)
+     )
+
+ @neko_cfg
+ def cfg_data_SD_resize_crop(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', target_size = (512, 512), word_names=None,
+                             prompt_dropout: float = 0, prompt_template: Path_Like = 'prompt_template/caption.txt', loss_weight=1.0):
+     if word_names is None:
+         word_names = dict(pt1=trigger_word)
+     else:
+         word_names = word_names
+
+     return TextImagePairDataset(
+         _partial_=True, batch_size=batch_size, loss_weight=loss_weight,
+         source=dict(
+             data_source1=Text2ImageSource(
+                 img_root=img_root,
+                 label_file='${.img_root}', # path to image captions (file_words)
+                 prompt_template=prompt_template,
+             ),
+         ),
+         handler=StableDiffusionHandler(
+             bucket=FixedBucket,
+             word_names=word_names,
+             erase=prompt_dropout,
+         ),
+         bucket=FixedBucket(target_size=target_size),
+         cache=VaeCache(bs=batch_size)
+     )
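
These @neko_cfg helpers assemble a full training config from a handful of arguments. A hedged sketch of how a user config built on them might look, assuming the rainbowneko config launcher consumes the returned config (the checkpoint path and image folder below are placeholders, not files shipped with the package):

    # Sketch only: compose a LoRA training config from the helpers above.
    from hcpdiff.easy.cfg import SD15_lora_train, cfg_data_SD_ARB

    dataset = cfg_data_SD_ARB(
        img_root='data/my_character',      # placeholder image folder
        trigger_word='my_character',
        batch_size=4,
        resolution=512*512,                # ARB target area
    )

    config = SD15_lora_train(
        base_model='ckpt/sd15.safetensors',  # placeholder checkpoint path
        train_steps=1000,
        save_step=200,
        lr=1e-4,
        rank=8,
        dataset=dataset,
        save_webui_format=True,              # export LoRA in webui layout via LoraWebuiFormat
    )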