hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208) hide show
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,22 @@
1
+ import torch
2
+
3
+ from .base import BaseSampler
4
+ from .sigma_scheduler import SigmaScheduler
5
+
6
class EDMSampler(BaseSampler):
    """Sampler exposing the EDM preconditioning coefficients.

    With ``s = sigma**2 + sigma_data**2`` the coefficients are
    ``c_in = 1/sqrt(s)``, ``c_out = sigma*sigma_data/sqrt(s)`` and ``c_skip = sigma_data**2/s``.
    ``sigma`` must be a tensor because ``.sqrt()`` is called on the intermediate result.

    :param sigma_scheduler: forwarded to :class:`BaseSampler`.
    :param generator: forwarded to :class:`BaseSampler`.
    :param sigma_data: the ``sigma_data`` constant used in every coefficient.
    :param sigma_thr: stored as ``self.sigma_thr``; not read anywhere in this class.
    """

    def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator = None, sigma_data: float = 1.0, sigma_thr=1000):
        super().__init__(sigma_scheduler, generator)
        self.sigma_data, self.sigma_thr = sigma_data, sigma_thr

    def _noise_var(self, sigma):
        # Shared denominator term of all three coefficients: sigma**2 + sigma_data**2.
        return sigma**2+self.sigma_data**2

    def c_in(self, sigma):
        """Input scaling ``1/sqrt(sigma**2 + sigma_data**2)``."""
        scale = self._noise_var(sigma).sqrt()
        return 1/scale

    def c_out(self, sigma):
        """Output scaling ``sigma*sigma_data/sqrt(sigma**2 + sigma_data**2)``."""
        scale = self._noise_var(sigma).sqrt()
        return (sigma*self.sigma_data)/scale

    def c_skip(self, sigma):
        """Skip-connection weight ``sigma_data**2/(sigma**2 + sigma_data**2)``."""
        return self.sigma_data**2/self._noise_var(sigma)

    def denoise(self, x, sigma, eps=None, generator=None):
        # Denoising is not implemented for this sampler.
        raise NotImplementedError
@@ -0,0 +1,3 @@
1
+ from .base import SigmaScheduler
2
+ from .ddpm import DDPMDiscreteSigmaScheduler, DDPMContinuousSigmaScheduler, TimeSigmaScheduler
3
+ from .edm import EDMSigmaScheduler, EDMRefSigmaScheduler
@@ -0,0 +1,14 @@
1
+ from typing import Union, Tuple
2
+
3
+ import torch
4
+
5
class SigmaScheduler:
    """Interface for objects mapping a normalized time rate to a noise level (sigma).

    Concrete subclasses override both methods; the base implementations only raise.
    """

    def get_sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        '''
        Return the noise level for the given time.

        :param t: 0-1, rate of time step
        '''
        raise NotImplementedError

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Randomly draw noise levels whose time rates lie between ``min_rate`` and ``max_rate``.

        Subclasses in this package return a ``(sigma, t)`` pair.
        '''
        raise NotImplementedError
@@ -0,0 +1,197 @@
1
+ import torch
2
+ import math
3
+ from typing import Union, Tuple
4
+ from hcpdiff.utils import linear_interp
5
+ from .base import SigmaScheduler
6
+
7
class DDPMDiscreteSigmaScheduler(SigmaScheduler):
    """DDPM noise schedule tabulated on ``num_timesteps`` discrete steps.

    ``sigmas[i] = sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i])``, where ``alphas_cumprod``
    is the cumulative product of ``1 - betas``. Public methods take time as a rate ``t`` in [0, 1].
    The posterior q(x_{t-1} | x_t, x_0) mean/variance coefficients are precomputed for VLB losses.

    :param beta_schedule: ``linear``, ``scaled_linear``, ``squaredcos_cap_v2`` or ``sigmoid``.
    :param linear_start: first beta (ignored by ``squaredcos_cap_v2``).
    :param linear_end: last beta (ignored by ``squaredcos_cap_v2``).
    :param num_timesteps: number of discrete steps in the tables.
    """

    def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, num_timesteps=1000):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.betas = self.make_betas(beta_schedule, linear_start, linear_end, num_timesteps)
        alphas = 1.0-self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.sigmas = ((1-self.alphas_cumprod)/self.alphas_cumprod).sqrt()

        # for VLB calculation
        self.alphas_cumprod_prev = torch.cat([alphas.new_tensor([1.0]), self.alphas_cumprod[:-1]])
        self.posterior_mean_coef1 = self.betas*torch.sqrt(self.alphas_cumprod_prev)/(1.0-self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0-self.alphas_cumprod_prev)*torch.sqrt(alphas)/(1.0-self.alphas_cumprod)

        self.posterior_variance = self.betas*(1.0-self.alphas_cumprod_prev)/(1.0-self.alphas_cumprod)
        # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1), so its log would be -inf;
        # the first entry reuses the value of step 1 instead.
        self.posterior_log_variance_clipped = torch.log(torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]]))

    @property
    def sigma_min(self):
        """Sigma of the first table entry."""
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """Sigma of the last table entry."""
        return self.sigmas[-1]

    @staticmethod
    def _as_rate_tensor(rate, shape):
        # Python numbers (int or float) are broadcast to `shape`; tensors are used unchanged.
        if torch.is_tensor(rate):
            return rate
        return torch.full(shape, float(rate))

    def _rate_to_index(self, t):
        """Map rate(s) ``t`` in [0, 1] to valid step indices on the tables' device.

        ``floor(t * num_steps)`` is clamped to ``[0, num_steps-1]`` so ``t == 1.0`` selects the
        last step instead of indexing one past the end.
        """
        if not torch.is_tensor(t):
            t = torch.tensor(float(t))
        n = len(self.sigmas)
        # Index tensors must live on the indexed tensor's device (or CPU); the tables are built on CPU.
        return (t*n).long().clamp(0, n-1).to(self.sigmas.device)

    def get_sigma(self, t: Union[float, torch.Tensor]):
        """Look up the tabulated sigma for rate(s) ``t`` in [0, 1]."""
        return self.sigmas[self._rate_to_index(t)]

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
        """Sample rates uniformly in ``[min_rate, max_rate)``; return ``(sigma, t)``.

        ``sigma`` is the table entry at ``floor(t*(num_timesteps-1e-5))``; it is moved to ``t``'s device.
        """
        min_rate = self._as_rate_tensor(min_rate, shape)
        max_rate = self._as_rate_tensor(max_rate, shape)

        t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
        t_scale = (t*(self.num_timesteps-1e-5)).long()  # [0, num_timesteps-1]
        return self.sigmas[t_scale.to(self.sigmas.device)].to(t.device), t

    def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
        """Return the rate of the table entry whose sigma is closest to ``sigma``.

        Intended for a scalar ``sigma``: ``argmin`` runs over the flattened difference.
        """
        if torch.is_tensor(sigma):
            sigma = sigma.to(self.sigmas.device)
        t = (self.sigmas-sigma).abs().argmin()
        return t/self.num_timesteps

    def get_post_mean(self, t, x_0, x_t):
        """Posterior mean ``coef1[t]*x_0 + coef2[t]*x_t``.

        The coefficients are reshaped to ``(-1, 1, 1, 1)``, so this assumes 4-D ``x_0``/``x_t``
        with the batch on dim 0 — NOTE(review): confirm callers only pass 4-D latents.
        """
        idx = self._rate_to_index(t)
        # Move the looked-up coefficients to the data's device so CPU tables work with GPU tensors.
        coef1 = self.posterior_mean_coef1[idx].view(-1, 1, 1, 1).to(x_0.device)
        coef2 = self.posterior_mean_coef2[idx].view(-1, 1, 1, 1).to(x_t.device)
        return coef1*x_0+coef2*x_t

    def get_post_log_var(self, t, x_t_var=None):
        """Posterior log-variance at rate(s) ``t``.

        Without ``x_t_var`` the clipped posterior log-variance is returned. Otherwise ``x_t_var``
        (expected in [-1, 1]) interpolates between that minimum and ``log(beta_t)``.
        """
        idx = self._rate_to_index(t)
        if x_t_var is not None:
            device = x_t_var.device
        elif torch.is_tensor(t):
            device = t.device
        else:
            device = self.posterior_log_variance_clipped.device

        min_log = self.posterior_log_variance_clipped[idx].view(-1, 1, 1, 1).to(device)
        if x_t_var is None:
            return min_log

        max_log = self.betas[idx].log().view(-1, 1, 1, 1).to(device)
        # The model_var_values is [-1, 1] for [min_var, max_var].
        frac = (x_t_var+1)/2
        return frac*max_log+(1-frac)*min_log

    @staticmethod
    def betas_for_alpha_bar(
        num_diffusion_timesteps,
        max_beta=0.999,
        alpha_transform_type="cosine",
    ):
        """
        Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
        (1-beta) over time from t = [0,1].

        Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
        to that part of the diffusion process.

        Args:
            num_diffusion_timesteps (`int`): the number of betas to produce.
            max_beta (`float`): the maximum beta to use; use values lower than 1 to
                prevent singularities.
            alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                Choose from `cosine` or `exp`

        Returns:
            betas (`torch.Tensor`, float32): the betas used by the scheduler to step the model outputs

        Raises:
            ValueError: for an unknown ``alpha_transform_type``.
        """
        if alpha_transform_type == "cosine":

            def alpha_bar_fn(t):
                return math.cos((t+0.008)/1.008*math.pi/2)**2

        elif alpha_transform_type == "exp":

            def alpha_bar_fn(t):
                return math.exp(t*-12.0)

        else:
            raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

        betas = []
        for i in range(num_diffusion_timesteps):
            t1 = i/num_diffusion_timesteps
            t2 = (i+1)/num_diffusion_timesteps
            betas.append(min(1-alpha_bar_fn(t2)/alpha_bar_fn(t1), max_beta))
        return torch.tensor(betas, dtype=torch.float32)

    @staticmethod
    def make_betas(beta_schedule, beta_start, beta_end, num_train_timesteps, betas=None):
        """Build the float32 beta tensor for ``beta_schedule`` (or from explicit ``betas``).

        :raises NotImplementedError: for an unknown schedule name.
        """
        if betas is not None:
            return torch.tensor(betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)**2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            return DDPMDiscreteSigmaScheduler.betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            return torch.sigmoid(betas)*(beta_end-beta_start)+beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented.")
132
+
133
class DDPMContinuousSigmaScheduler(DDPMDiscreteSigmaScheduler):
    """DDPM schedule evaluated between table entries by linear interpolation.

    Reuses the sigma table of :class:`DDPMDiscreteSigmaScheduler`, but samples fractional
    timesteps and interpolates ``sigmas`` with ``linear_interp``.
    """

    def get_sigma(self, t: Union[float, torch.Tensor]):
        """Interpolate the sigma table at ``t``.

        NOTE(review): ``sample_sigma`` scales the rate to table-index units
        (``t*(num_timesteps-1-1e-5)``) before calling ``linear_interp``, but this method passes
        ``t`` unscaled. Confirm which unit ``hcpdiff.utils.linear_interp`` expects; one of the two
        call sites is probably wrong.
        """
        if not torch.is_tensor(t):
            t = torch.tensor(float(t))
        return linear_interp(self.sigmas, t)

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
        """Sample rates uniformly in ``[min_rate, max_rate)`` and return ``(sigma, t)``.

        Python numbers (int or float) are broadcast to ``shape``; tensors are used as-is.
        """
        if not torch.is_tensor(min_rate):
            min_rate = torch.full(shape, float(min_rate))
        if not torch.is_tensor(max_rate):
            max_rate = torch.full(shape, float(max_rate))

        t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
        t_scale = t*(self.num_timesteps-1-1e-5)  # fractional index in [0, num_timesteps-1)
        return linear_interp(self.sigmas, t_scale), t

    def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
        """Invert the piecewise-linear sigma table: return the fractional step index for ``sigma``.

        The result is in table-index units (``0 .. num_timesteps-1``), as in the previous
        implementation. NOTE(review): the parent's ``sigma_to_t`` returns a rate in [0, 1];
        confirm which unit callers expect.

        ``sigma`` may be a scalar or a tensor of any shape; the output has the same shape. Values
        outside ``[sigma_min, sigma_max]`` are linearly extrapolated from the nearest segment.
        The table is assumed to be increasing, which holds for any positive betas.
        """
        sigmas = self.sigmas
        sigma = torch.as_tensor(sigma, dtype=sigmas.dtype, device=sigmas.device)
        flat = sigma.reshape(-1).contiguous()
        # hi = first table index whose sigma is >= the query; [lo, hi] brackets the query.
        hi = torch.searchsorted(sigmas, flat).clamp(1, len(sigmas)-1)
        lo = hi-1
        frac = (flat-sigmas[lo])/(sigmas[hi]-sigmas[lo])
        return (lo+frac).reshape(sigma.shape)
156
+
157
class TimeSigmaScheduler(SigmaScheduler):
    """Scheduler whose "sigma" is the time itself.

    ``get_sigma`` returns the rate unchanged. ``sample_sigma`` returns integer timestep indices
    in ``[0, num_timesteps-1]`` together with the sampled rate.
    """

    def __init__(self, num_timesteps=1000):
        super().__init__()
        self.num_timesteps = num_timesteps

    def get_sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        '''
        :param t: 0-1, rate of time step
        :return: ``t`` unchanged
        '''
        return t

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Sample rates uniformly in ``[min_rate, max_rate)``.

        Python numbers (int or float) are broadcast to ``shape``; tensors are used as-is.

        :return: ``(timestep_index, t)``, where the index is a long tensor
        '''
        if not torch.is_tensor(min_rate):
            min_rate = torch.full(shape, float(min_rate))
        if not torch.is_tensor(max_rate):
            max_rate = torch.full(shape, float(max_rate))

        t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
        t_scale = (t*(self.num_timesteps-1e-5)).long()  # [0, num_timesteps-1]
        return t_scale, t
177
+
178
if __name__ == '__main__':
    # Debug visualisation: compare the DDPM sigma table with a power-law (rho) interpolation
    # between sigma_min and sigma_max, on a linear and on a log scale.
    from matplotlib import pyplot as plt
    import numpy as np

    sigma_scheduler = DDPMDiscreteSigmaScheduler()
    print(sigma_scheduler.sigma_min, sigma_scheduler.sigma_max)
    num_steps = sigma_scheduler.num_timesteps
    # np.interp below pairs these rates with the sigma table, so both need num_steps entries.
    t = torch.linspace(0, 1, num_steps)
    rho = 1.
    s2 = (sigma_scheduler.sigma_min**(1/rho)+t*(sigma_scheduler.sigma_max**(1/rho)-sigma_scheduler.sigma_min**(1/rho)))**rho
    # Map each interpolated sigma back to a rate on the DDPM table (interpolating in log-sigma).
    t2 = np.interp(s2.log().numpy(), sigma_scheduler.sigmas.log().numpy(), t.numpy())

    # plt.plot(sigmas) uses x = 0..num_steps-1, so rates are scaled by num_steps-1 to share that axis.
    plt.figure()
    plt.plot(sigma_scheduler.sigmas)
    plt.plot(t2*(num_steps-1), s2)
    plt.show()

    plt.figure()
    plt.plot(sigma_scheduler.sigmas.log())
    plt.plot(t2*(num_steps-1), s2.log())
    plt.show()
@@ -0,0 +1,48 @@
1
+ from typing import Union
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from .base import SigmaScheduler
7
+
8
class EDMSigmaScheduler(SigmaScheduler):
    """Karras-style (EDM) power-law sigma schedule.

    ``sigma(t) = (sigma_min**(1/rho) + t*(sigma_max**(1/rho) - sigma_min**(1/rho)))**rho`` for a rate
    ``t`` in [0, 1].

    :param sigma_min: sigma at ``t = 0``.
    :param sigma_max: sigma at ``t = 1``.
    :param rho: exponent of the power-law interpolation.
    :param num_timesteps: stored as ``self.num_timesteps``; not read by this class.
    """

    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0, num_timesteps=1000):
        super().__init__()
        # Stored as float tensors so integer arguments (e.g. sigma_max=80) still give floating sigmas.
        self.sigma_min = torch.tensor(float(sigma_min))
        self.sigma_max = torch.tensor(float(sigma_max))
        self.rho = rho

        self.num_timesteps = num_timesteps

    def get_sigma(self, t: Union[float, torch.Tensor]):
        """Return sigma for rate(s) ``t``.

        The result is placed on ``t``'s device. Its dtype is ``t``'s dtype promoted with float32,
        so integer or half-precision ``t`` is computed in float32.
        """
        if not torch.is_tensor(t):
            t = torch.tensor(float(t))
        # torch.lerp requires start/end/weight to share dtype (and device).
        dtype = torch.promote_types(t.dtype, self.sigma_min.dtype)
        t = t.to(dtype)

        min_inv_rho = (self.sigma_min**(1/self.rho)).to(device=t.device, dtype=dtype)
        max_inv_rho = (self.sigma_max**(1/self.rho)).to(device=t.device, dtype=dtype)
        return torch.lerp(min_inv_rho, max_inv_rho, t)**self.rho

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
        """Sample rates uniformly in ``[min_rate, max_rate)`` and return ``(sigma, t)``.

        Python numbers (int or float) are broadcast to ``shape``; tensors are used as-is.
        """
        if not torch.is_tensor(min_rate):
            min_rate = torch.full(shape, float(min_rate))
        if not torch.is_tensor(max_rate):
            max_rate = torch.full(shape, float(max_rate))

        t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
        return self.get_sigma(t), t
32
+
33
class EDMRefSigmaScheduler(EDMSigmaScheduler):
    """EDM schedule whose returned time is re-expressed on a reference scheduler's table.

    Sigmas are drawn from the EDM schedule. Each is then mapped to the rate at which
    ``ref_scheduler.sigmas`` reaches the same value, by linear interpolation in log-sigma.

    :param ref_scheduler: object exposing a ``sigmas`` tensor (e.g. a DDPM sigma scheduler).
    """

    def __init__(self, ref_scheduler, sigma_min=0.002, sigma_max=80.0, rho=7.0, num_timesteps=1000):
        super().__init__(sigma_min, sigma_max, rho, num_timesteps=num_timesteps)
        # np.interp needs increasing sample points, so ref_scheduler.sigmas must be increasing.
        # Clipping keeps log() finite for a zero sigma.
        self.ref_sigmas = ref_scheduler.sigmas.cpu().clip(min=1e-8).log().numpy()
        self.ref_t = np.linspace(0, 1, len(self.ref_sigmas))

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
        """Sample EDM sigmas and return ``(sigma, t_rect)``.

        ``t_rect`` is the reference-table rate for each sigma, returned with the same dtype and
        device as the sampled rates. Sigmas outside the reference range map to 0 or 1
        (``np.interp`` clamps to the end points).
        """
        if not torch.is_tensor(min_rate):
            min_rate = torch.full(shape, float(min_rate))
        if not torch.is_tensor(max_rate):
            max_rate = torch.full(shape, float(max_rate))

        t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
        sigma = self.get_sigma(t)
        log_sigma = sigma.detach().cpu().clip(min=1e-8).log().numpy()
        # np.interp returns float64 (or a NumPy scalar for 0-d input); convert to match `t`.
        t_rect = torch.as_tensor(np.asarray(np.interp(log_sigma, self.ref_sigmas, self.ref_t)), dtype=t.dtype, device=t.device)
        return sigma, t_rect
@@ -0,0 +1,2 @@
1
+ from .model import SD15_auto_loader, SDXL_auto_loader, PixArt_auto_loader, ControlNet_SD15, make_controlnet_handler
2
+ from .sampler import Diffusers_SD
@@ -0,0 +1,3 @@
1
+ from .sd15_train import SD15_lora_train, cfg_data_SD_ARB, cfg_data_SD_resize_crop, SD15_finetuning
2
+ from .sdxl_train import SDXL_lora_train, SDXL_finetuning
3
+ from .t2i import SD15_t2i, SDXL_t2i, SDXL_t2i_lora, SD15_t2i_lora
@@ -0,0 +1,201 @@
1
+ import torch
2
+ from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE, plugin_saver
3
+ from rainbowneko.data import RatioBucket, FixedBucket
4
+ from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
5
+ from rainbowneko.utils import ConstantLR, Path_Like
6
+
7
+ from hcpdiff.data import TextImagePairDataset, Text2ImageSource, StableDiffusionHandler
8
+ from hcpdiff.data import VaeCache
9
+ from hcpdiff.easy import SD15_auto_loader
10
+ from hcpdiff.models import SD15Wrapper, TEHookCFG
11
+ from hcpdiff.models.lora_layers_patch import LoraLayer
12
+
13
# Config factory: fine-tune the whole SD 1.5 denoiser (UNet).
#
# Calls such as `torch.optim.AdamW(_partial_=True)` are presumably recorded as config nodes by
# @neko_cfg rather than executed here — NOTE(review): confirm against rainbowneko's `neko_cfg`.
#
# Args:
#   base_model:   checkpoint path given to `SD15_auto_loader(ckpt_path=...)`.
#   train_steps:  total number of training steps.
#   dataset:      used verbatim as `data_train` (e.g. the result of `cfg_data_SD_ARB`).
#   save_step:    checkpoint interval.
#   lr:           learning rate of the `denoiser` parameter group (weight decay 1e-2).
#   clip_skip:    forwarded to `TEHookCFG(clip_skip=...)`.
#   dtype:        value of `mixed_precision`.
#   low_vram:     use bitsandbytes `AdamW8bit` instead of `torch.optim.AdamW`.
#   warmup_steps: warmup for the `ConstantLR` scheduler.
#   name:         model name.
@neko_cfg
def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5, clip_skip: int = 0,
                    dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SD15'):
    if low_vram:
        # bitsandbytes is imported only on this path, so it is an optional dependency.
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True)
    else:
        optimizer = torch.optim.AdamW(_partial_=True)

    # Base configs listed in `_base_` below; presumably merged with this dict by the config system.
    from cfgs.train.py import train_base, tuning_base

    return dict(
        _base_=[train_base, tuning_base],
        mixed_precision=dtype,

        model_part=CfgWDModelParser([
            dict(
                lr=lr,
                layers=['denoiser'], # train UNet
            )
        ], weight_decay=1e-2),

        # Save the trainable layers of the denoiser as safetensors.
        ckpt_saver=dict(
            SD15=ckpt_saver(
                ckpt_type='safetensors',
                target_module='denoiser',
                layers=LAYERS_TRAINABLE,
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            ## Easy config
            wrapper=SD15Wrapper.from_pretrained(
                _partial_=True,
                models=SD15_auto_loader(ckpt_path=base_model, _partial_=True),
                TE_hook_cfg=TEHookCFG(clip_skip=clip_skip),
            ),
        ),

        data_train=dataset,
    )
68
+
69
# Config factory: train a LoRA plugin ("lora1") on the SD 1.5 denoiser.
#
# Args (beyond those shared with `SD15_finetuning`):
#   rank:      LoRA rank.
#   alpha:     LoRA alpha; defaults to `rank` when None.
#   with_conv: also wrap resnets, proj_in/proj_out and conv modules, not only attention/ff.
# The whole model is not trained (`model_part=None`); only the LoRA plugin is optimized
# (weight decay 0.1) and saved.
@neko_cfg
def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
                    clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
                    name: str = 'SD15'):
    # Plain Python logic; presumably kept out of @neko_cfg's config recording — NOTE(review): confirm.
    with disable_neko_cfg:
        if alpha is None:
            alpha = rank

        # Module-name patterns ("re:" prefix = regex) selecting the layers to wrap with LoRA.
        if with_conv:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
                r're:denoiser.*\.resnets$',
                r're:denoiser.*\.proj_in$',
                r're:denoiser.*\.proj_out$',
                r're:denoiser.*\.conv$',
            ]
        else:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
            ]

    if low_vram:
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True, betas=(0.9, 0.99))
    else:
        optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))

    from cfgs.train.py.examples import SD_FT

    return dict(
        _base_=[SD_FT],
        mixed_precision=dtype,

        model_part=None,
        model_plugin=CfgWDPluginParser(cfg_plugin=dict(
            lora1=LoraLayer.wrap_model(
                _partial_=True,
                lr=lr,
                rank=rank,
                alpha=alpha,
                layers=lora_layers
            )
        ), weight_decay=0.1),

        # `_replace_` presumably replaces the savers inherited from SD_FT instead of merging.
        ckpt_saver=dict(
            _replace_ = True,
            lora_unet=plugin_saver(
                ckpt_type='safetensors',
                target_plugin='lora1',
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            wrapper=SD15Wrapper.from_pretrained(
                _partial_=True,
                models=SD15_auto_loader(ckpt_path=base_model, _partial_=True),
                TE_hook_cfg=TEHookCFG(clip_skip=clip_skip),
            ),
        ),

        data_train=dataset,
    )
147
+
148
# Config factory: text-image dataset with ratio (aspect-ratio) bucketing and VAE latent caching.
#
# Args:
#   img_root:        image folder; also reused as the caption source via '${.img_root}'.
#   batch_size:      dataset batch size (also the VAE cache batch size).
#   trigger_word:    used as `word_names['pt1']` when `word_names` is None — presumably fills
#                    a `{pt1}` placeholder in the prompt template; confirm in the handler.
#   resolution:      `target_area` (pixel count) of the ratio buckets, 512*512 by default.
#   num_bucket:      number of ratio buckets.
#   word_names:      explicit word mapping; overrides `trigger_word`.
#   prompt_dropout:  forwarded to the handler as `erase`.
#   prompt_template: prompt template file.
#   loss_weight:     per-dataset loss weight.
@neko_cfg
def cfg_data_SD_ARB(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', resolution: int = 512*512, num_bucket=4, word_names=None,
                    prompt_dropout: float = 0, prompt_template: Path_Like = 'prompt_template/caption.txt', loss_weight=1.0):
    if word_names is None:
        word_names = dict(pt1=trigger_word)
    else:
        # NOTE(review): no-op self-assignment; the branch could be dropped.
        word_names = word_names

    return TextImagePairDataset(
        _partial_=True, batch_size=batch_size, loss_weight=loss_weight,
        source=dict(
            data_source1=Text2ImageSource(
                img_root=img_root,
                label_file='${.img_root}', # path to image captions (file_words)
                prompt_template=prompt_template,
            ),
        ),
        handler=StableDiffusionHandler(
            bucket=RatioBucket,
            word_names=word_names,
            erase=prompt_dropout,
        ),
        bucket=RatioBucket.from_files(
            target_area=resolution,
            num_bucket=num_bucket,
        ),
        cache=VaeCache(bs=batch_size)
    )
176
+
177
# Config factory: text-image dataset where every image goes to one fixed-size bucket
# (`FixedBucket(target_size=...)`), with VAE latent caching.
#
# Args are as in `cfg_data_SD_ARB`, except:
#   target_size: (width, height)-style pair given to `FixedBucket` — NOTE(review): confirm
#                the axis order expected by FixedBucket.
@neko_cfg
def cfg_data_SD_resize_crop(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', target_size = (512, 512), word_names=None,
                            prompt_dropout: float = 0, prompt_template: Path_Like = 'prompt_template/caption.txt', loss_weight=1.0):
    if word_names is None:
        word_names = dict(pt1=trigger_word)
    else:
        # NOTE(review): no-op self-assignment; the branch could be dropped.
        word_names = word_names

    return TextImagePairDataset(
        _partial_=True, batch_size=batch_size, loss_weight=loss_weight,
        source=dict(
            data_source1=Text2ImageSource(
                img_root=img_root,
                label_file='${.img_root}', # path to image captions (file_words)
                prompt_template=prompt_template,
            ),
        ),
        handler=StableDiffusionHandler(
            bucket=FixedBucket,
            word_names=word_names,
            erase=prompt_dropout,
        ),
        bucket=FixedBucket(target_size=target_size),
        cache=VaeCache(bs=batch_size)
    )
@@ -0,0 +1,140 @@
1
+ import torch
2
+ from rainbowneko.ckpt_manager import ckpt_saver, plugin_saver, LAYERS_TRAINABLE
3
+ from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
4
+ from rainbowneko.utils import ConstantLR
5
+
6
+ from hcpdiff.easy import SDXL_auto_loader
7
+ from hcpdiff.models import SDXLWrapper
8
+ from hcpdiff.models.lora_layers_patch import LoraLayer
9
+
10
# Config factory: fine-tune the whole SDXL denoiser (UNet).
#
# Same structure as the SD 1.5 variant, but uses `SDXLWrapper` / `SDXL_auto_loader` and has
# no `clip_skip` option.
#
# Args:
#   base_model:   checkpoint path given to `SDXL_auto_loader(ckpt_path=...)`.
#   train_steps:  total number of training steps.
#   dataset:      used verbatim as `data_train`.
#   save_step:    checkpoint interval.
#   lr:           learning rate of the `denoiser` parameter group (weight decay 1e-2).
#   dtype:        value of `mixed_precision`.
#   low_vram:     use bitsandbytes `AdamW8bit` instead of `torch.optim.AdamW`.
#   warmup_steps: warmup for the `ConstantLR` scheduler.
#   name:         model name.
@neko_cfg
def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5,
                    dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL'):
    if low_vram:
        # bitsandbytes is imported only on this path, so it is an optional dependency.
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True)
    else:
        optimizer = torch.optim.AdamW(_partial_=True)

    from cfgs.train.py import train_base, tuning_base

    return dict(
        _base_=[train_base, tuning_base],
        mixed_precision=dtype,

        model_part=CfgWDModelParser([
            dict(
                lr=lr,
                layers=['denoiser'], # train UNet
            )
        ], weight_decay=1e-2),

        # Save the trainable layers of the denoiser as safetensors.
        ckpt_saver=dict(
            SDXL=ckpt_saver(
                ckpt_type='safetensors',
                target_module='denoiser',
                layers=LAYERS_TRAINABLE,
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            ## Easy config
            wrapper=SDXLWrapper.from_pretrained(
                _partial_=True,
                models=SDXL_auto_loader(ckpt_path=base_model, _partial_=True),
            ),
        ),

        data_train=dataset,
    )
64
+
65
# Config factory: train a LoRA plugin ("lora1") on the SDXL denoiser.
#
# Args:
#   rank:      LoRA rank.
#   alpha:     LoRA alpha; defaults to `rank` when None.
#   with_conv: also wrap resnets, proj_in/proj_out and conv modules, not only attention/ff.
#   name:      model name. NOTE(review): defaults to 'SD15', while `SDXL_finetuning` uses
#              'SDXL' — looks like a copy-paste slip; left unchanged to keep the interface.
# The whole model is not trained (`model_part=None`); only the LoRA plugin is optimized
# (weight decay 0.1) and saved.
@neko_cfg
def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
                    with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SD15'):
    # Plain Python logic; presumably kept out of @neko_cfg's config recording — NOTE(review): confirm.
    with disable_neko_cfg:
        if alpha is None:
            alpha = rank

        # Module-name patterns ("re:" prefix = regex) selecting the layers to wrap with LoRA.
        if with_conv:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
                r're:denoiser.*\.resnets$',
                r're:denoiser.*\.proj_in$',
                r're:denoiser.*\.proj_out$',
                r're:denoiser.*\.conv$',
            ]
        else:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
            ]

    if low_vram:
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True, betas=(0.9, 0.99))
    else:
        optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))

    from cfgs.train.py.examples import SD_FT

    return dict(
        _base_=[SD_FT],
        mixed_precision=dtype,

        model_part=None,
        model_plugin=CfgWDPluginParser(cfg_plugin=dict(
            lora1=LoraLayer.wrap_model(
                _partial_=True,
                lr=lr,
                rank=rank,
                alpha=alpha,
                layers=lora_layers
            )
        ), weight_decay=0.1),

        # `_replace_` presumably replaces the savers inherited from SD_FT instead of merging.
        ckpt_saver=dict(
            _replace_ = True,
            lora_unet=plugin_saver(
                ckpt_type='safetensors',
                target_plugin='lora1',
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            wrapper=SDXLWrapper.from_pretrained(
                models=SDXL_auto_loader(ckpt_path=base_model, _partial_=True),
                _partial_=True,
            ),
        ),

        data_train=dataset,
    )