hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/loss/ssim.py ADDED
@@ -0,0 +1,37 @@
+ from pytorch_msssim import SSIM, MS_SSIM
+ from torch.nn.modules.loss import _Loss
+ import torch
+
+ class SSIMLoss(_Loss):
+     target_type = 'x0'
+
+     def __init__(self, size_average=None, reduce=None, reduction: str = 'mean'):
+         super().__init__(size_average=size_average, reduce=reduce, reduction=reduction)
+         self.ssim = SSIM(data_range=1., size_average=False, channel=4)
+
+     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         '''
+
+         :param input: [B,C,H,W]
+         :param target: [B,C,H,W]
+         :return: [B,1,1,1]
+         '''
+         input = (input+1)/2
+         target = (target+1)/2
+         return 1-self.ssim(input, target).view(-1,1,1,1)
+
+ class MS_SSIMLoss(_Loss):
+     def __init__(self, size_average=None, reduce=None, reduction: str = 'mean'):
+         super().__init__(size_average=size_average, reduce=reduce, reduction=reduction)
+         self.ssim = MS_SSIM(data_range=1., size_average=False, channel=4)
+
+     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         '''
+
+         :param input: [B,C,H,W]
+         :param target: [B,C,H,W]
+         :return: [B,1,1,1]
+         '''
+         input = (input+1)/2
+         target = (target+1)/2
+         return 1-self.ssim(input, target).view(-1,1,1,1)
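A quick way to see the shapes involved (a minimal sketch, not part of the package; it assumes `pytorch_msssim` is installed and uses the 4-channel latent layout implied by `channel=4` above):

import torch
from hcpdiff.loss.ssim import SSIMLoss

loss_fn = SSIMLoss()
pred = torch.rand(2, 4, 64, 64)*2 - 1    # values in [-1, 1]; forward() rescales to [0, 1]
target = torch.rand(2, 4, 64, 64)*2 - 1
loss = loss_fn(pred, target)
print(loss.shape)                        # torch.Size([2, 1, 1, 1]): one loss per sample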
hcpdiff/loss/vlb.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ from torch import nn
+ import numpy as np
+
+ class VLBLoss(nn.Module):
+     need_sigma = True
+     need_timesteps = True
+     need_sampler = True
+     var_pred = True
+
+     def __init__(self, loss, weight: float = 1.):
+         super().__init__()
+         self.loss = loss
+         self.weight = weight
+
+     def normal_kl(self, mean1, logvar1, mean2, logvar2):
+         """
+         Compute the KL divergence between two gaussians.
+         """
+
+         return 0.5*(-1.0+logvar2-logvar1+(logvar1-logvar2).exp()+((mean1-mean2)**2)*(-logvar2).exp())
+
+     def forward(self, input: torch.Tensor, target: torch.Tensor, sigma, timesteps: torch.Tensor, x_t: torch.Tensor, sampler):
+         eps_pred, var_pred = input.chunk(2, dim=1)
+         x0_pred = sampler.eps_to_x0(eps_pred, x_t, sigma)
+
+         true_mean = sampler.sigma_scheduler.get_post_mean(timesteps, target, x_t)
+         true_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps)
+
+         pred_mean = sampler.sigma_scheduler.get_post_mean(timesteps, x0_pred, x_t)
+         pred_logvar = sampler.sigma_scheduler.get_post_log_var(timesteps, x_t_var=var_pred)
+
+         kl = self.normal_kl(true_mean, true_logvar, pred_mean, pred_logvar)
+         kl = kl.mean(dim=(1,2,3))/np.log(2.0)
+
+         decoder_nll = -self.discretized_gaussian_log_likelihood(target, means=pred_mean, log_scales=0.5*pred_logvar)
+         assert decoder_nll.shape == target.shape
+         decoder_nll = decoder_nll.mean(dim=(1,2,3))/np.log(2.0)
+
+         # At the first timestep return the decoder NLL,
+         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+         output = torch.where((timesteps == 0), decoder_nll, kl)
+
+         return self.weight*output
+
+     def approx_standard_normal_cdf(self, x):
+         """
+         A fast approximation of the cumulative distribution function of the
+         standard normal.
+         """
+         return 0.5*(1.0+torch.tanh(np.sqrt(2.0/np.pi)*(x+0.044715*torch.pow(x, 3))))
+
+     def discretized_gaussian_log_likelihood(self, x, *, means, log_scales):
+         """
+         Compute the log-likelihood of a Gaussian distribution discretizing to a
+         given image.
+         :param x: the target images. It is assumed that this was uint8 values,
+                   rescaled to the range [-1, 1].
+         :param means: the Gaussian mean Tensor.
+         :param log_scales: the Gaussian log stddev Tensor.
+         :return: a tensor like x of log probabilities (in nats).
+         """
+         assert x.shape == means.shape == log_scales.shape
+         centered_x = x-means
+         inv_stdv = torch.exp(-log_scales)
+         plus_in = inv_stdv*(centered_x+1.0/255.0)
+         cdf_plus = self.approx_standard_normal_cdf(plus_in)
+         min_in = inv_stdv*(centered_x-1.0/255.0)
+         cdf_min = self.approx_standard_normal_cdf(min_in)
+         log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+         log_one_minus_cdf_min = torch.log((1.0-cdf_min).clamp(min=1e-12))
+         cdf_delta = cdf_plus-cdf_min
+         log_probs = torch.where(
+             x<-0.999,
+             log_cdf_plus,
+             torch.where(x>0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
+         )
+         assert log_probs.shape == x.shape
+         return log_probs
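For reference, `normal_kl` is the closed-form KL divergence between two diagonal Gaussians written in terms of log-variances, and the later division by `np.log(2.0)` converts nats to bits per dimension:

D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\big\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big) = \frac{1}{2}\left(-1 + \log\sigma_2^2 - \log\sigma_1^2 + \frac{\sigma_1^2}{\sigma_2^2} + \frac{(\mu_1-\mu_2)^2}{\sigma_2^2}\right)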
hcpdiff/loss/weighting.py ADDED
@@ -0,0 +1,66 @@
+ from torch import nn
+
+ from .base import DiffusionLossContainer
+
+ class LossWeight(nn.Module):
+     def __init__(self, loss: DiffusionLossContainer):
+         super().__init__()
+         self.loss = loss
+
+     def get_weight(self, pred, inputs):
+         '''
+
+         :param input: [B,C,H,W]
+         :param target: [B,C,H,W]
+         :return: [B,1,1,1] or [B,C,H,W]
+         '''
+         raise NotImplementedError
+
+     def forward(self, pred, inputs):
+         '''
+         weight: [B,1,1,1] or [B,C,H,W]
+         loss: [B,*,*,*]
+         '''
+         return self.get_weight(pred, inputs)*self.loss(pred, inputs)
+
+ class SNRWeight(LossWeight):
+     def get_weight(self, pred, inputs):
+         if self.loss.target_type == 'eps':
+             return 1
+         elif self.loss.target_type == "x0":
+             sigma = pred['sigma']
+             return (1./sigma**2).view(-1, 1, 1, 1)
+         else:
+             raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+
+ class MinSNRWeight(LossWeight):
+     def __init__(self, loss: DiffusionLossContainer, gamma: float = 1.):
+         super().__init__(loss)
+         self.gamma = gamma
+
+     def get_weight(self, pred, inputs):
+         sigma = pred['sigma']
+         if self.loss.target_type == 'eps':
+             w_snr = (self.gamma*sigma**2).clip(max=1).float()
+         elif self.loss.target_type == "x0":
+             w_snr = (1/(sigma**2)).clip(max=self.gamma).float()
+         else:
+             raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+
+         return w_snr.view(-1, 1, 1, 1)
+
+ class EDMWeight(LossWeight):
+     def __init__(self, loss: DiffusionLossContainer, gamma: float = 1.):
+         super().__init__(loss)
+         self.gamma = gamma
+
+     def get_weight(self, pred, inputs):
+         sigma = pred['sigma']
+         if self.loss.target_type == 'eps':
+             w_snr = ((sigma**2+self.gamma**2)/(self.gamma**2)).float()
+         elif self.loss.target_type == "x0":
+             w_snr = ((sigma**2+self.gamma**2)/((sigma*self.gamma)**2)).float()
+         else:
+             raise ValueError(f"{self.__class__.__name__} is not support for target_type {self.loss.target_type}")
+
+         return w_snr.view(-1, 1, 1, 1)
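Reading `sigma` as the noise-to-signal level so that SNR(t) = 1/σ_t² (an interpretation taken from the surrounding sampler code, not stated in this file), the weights reduce to the familiar Min-SNR forms:

w_{\epsilon}(t) = \min\big(\gamma\,\sigma_t^2,\ 1\big) = \frac{\min(\mathrm{SNR}(t),\ \gamma)}{\mathrm{SNR}(t)}, \qquad w_{x_0}(t) = \min\big(\sigma_t^{-2},\ \gamma\big) = \min(\mathrm{SNR}(t),\ \gamma)

and for `EDMWeight`, with γ playing the role of σ_data in the EDM loss weighting:

w_{x_0}(t) = \frac{\sigma_t^2+\gamma^2}{(\sigma_t\,\gamma)^2}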
hcpdiff/models/__init__.py CHANGED
@@ -1,4 +1,3 @@
- from .plugin import PluginBlock, PluginGroup, SinglePluginBlock, MultiPluginBlock, PatchPluginBlock
  # from .lora_base import LoraBlock, LoraGroup
  # from .lora_layers import lora_layer_map
  from .lora_base_patch import LoraBlock, LoraGroup
@@ -7,4 +6,5 @@ from .text_emb_ex import EmbeddingPTHook
  from .textencoder_ex import TEEXHook
  from .tokenizer_ex import TokenizerHook
  from .cfg_context import CFGContext, DreamArtistPTContext
- from .wrapper import TEUnetWrapper, SDXLTEUnetWrapper
+ from .wrapper import SD15Wrapper, SDXLWrapper, PixArtWrapper, TEHookCFG
+ from .controlnet import ControlNetPlugin
hcpdiff/models/cfg_context.py CHANGED
@@ -1,6 +1,7 @@
  import torch
  from einops import repeat
  import math
+ from typing import Union, Callable
 
  class CFGContext:
      def pre(self, noisy_latents, timesteps):
@@ -10,9 +11,11 @@ class CFGContext:
          return model_pred
 
  class DreamArtistPTContext(CFGContext):
-     def __init__(self, cfg_scale, num_train_timesteps):
-         self.cfg_scale=cfg_scale
-         self.num_train_timesteps=num_train_timesteps
+     def __init__(self, cfg_low: float, cfg_high: float=None, cfg_func: Union[str, Callable]=None, num_train_timesteps=1000):
+         self.cfg_low = cfg_low
+         self.cfg_high = cfg_high or cfg_low
+         self.cfg_func = cfg_func
+         self.num_train_timesteps = num_train_timesteps
 
      def pre(self, noisy_latents, timesteps):
          self.t_raw = timesteps
@@ -22,18 +25,18 @@ class DreamArtistPTContext(CFGContext):
 
      def post(self, model_pred):
          e_t_uncond, e_t = model_pred.chunk(2)
-         if self.cfg_scale[0] != self.cfg_scale[1]:
-             rate = self.t_raw / (self.num_train_timesteps - 1)
-             if self.cfg_scale[2] == 'cos':
-                 rate = torch.cos((rate - 1) * math.pi / 2)
-             elif self.cfg_scale[2] == 'cos2':
-                 rate = 1 - torch.cos(rate * math.pi / 2)
-             elif self.cfg_scale[2] == 'ln':
+         if self.cfg_low != self.cfg_high:
+             rate = self.t_raw/(self.num_train_timesteps-1)
+             if self.cfg_func == 'cos':
+                 rate = torch.cos((rate-1)*math.pi/2)
+             elif self.cfg_func == 'cos2':
+                 rate = 1-torch.cos(rate*math.pi/2)
+             elif self.cfg_func == 'ln':
                  pass
              else:
-                 rate = eval(self.cfg_scale[2])
-             rate = rate.view(-1,1,1,1)
+                 rate = self.cfg_func(rate)
+             rate = rate.view(-1, 1, 1, 1)
          else:
              rate = 1
-         model_pred = e_t_uncond + ((self.cfg_scale[1] - self.cfg_scale[0]) * rate + self.cfg_scale[0]) * (e_t - e_t_uncond)
-         return model_pred
+         model_pred = e_t_uncond+((self.cfg_high-self.cfg_low)*rate+self.cfg_low)*(e_t-e_t_uncond)
+         return model_pred
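The refactor replaces the packed `cfg_scale` tuple with named fields and drops the `eval()` fallback in favor of a real callable, but the guidance it computes is unchanged. With T = `num_train_timesteps` and f the chosen schedule (`'ln'` leaves the rate linear, despite the name):

r(t) = f\!\left(\frac{t}{T-1}\right), \qquad w(t) = w_{\mathrm{low}} + (w_{\mathrm{high}} - w_{\mathrm{low}})\,r(t), \qquad \epsilon_{\mathrm{guided}} = \epsilon_{\mathrm{uncond}} + w(t)\,(\epsilon_{\mathrm{cond}} - \epsilon_{\mathrm{uncond}})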
hcpdiff/models/compose/compose_hook.py CHANGED
@@ -38,42 +38,42 @@ class ComposeEmbPTHook(nn.Module):
              hook.remove()
 
      @classmethod
-     def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder, log=False, **kwargs):
+     def hook(cls, ex_words_emb: Dict[str, ParameterGroup], tokenizer, text_encoder, **kwargs):
          if isinstance(text_encoder, ComposeTextEncoder):
              hook_list = []
 
              emb_len = 0
-             for i, (name, tokenizer_i) in enumerate(tokenizer.tokenizer_list):
+             for i, name in enumerate(tokenizer.tokenizer_names):
                  text_encoder_i = getattr(text_encoder, name)
-                 if log:
-                     logger.info(f'compose hook: {name}')
+                 tokenizer_i = getattr(tokenizer, name)
                  embedding_dim = text_encoder_i.get_input_embeddings().embedding_dim
                  ex_words_emb_i = {k:v[i] for k, v in ex_words_emb.items()}
                  emb_len += embedding_dim
-                 hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, log=log, **kwargs)))
+                 hook_list.append((name, EmbeddingPTHook.hook(ex_words_emb_i, tokenizer_i, text_encoder_i, **kwargs)))
 
              return cls(hook_list)
          else:
-             return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs)
+             return EmbeddingPTHook.hook(ex_words_emb, tokenizer, text_encoder, **kwargs)
 
      @classmethod
-     def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, log=True, device='cuda:0', **kwargs) -> Union[
+     def hook_from_dir(cls, emb_dir, tokenizer, text_encoder, device='cuda:0', **kwargs) -> Union[
          Tuple['ComposeEmbPTHook', Dict], Tuple[EmbeddingPTHook, Dict]]:
          if isinstance(text_encoder, ComposeTextEncoder):
              # multi text encoder
-             #ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}
+             # ex_words_emb = {file[:-3]:load_emb(os.path.join(emb_dir, file)).to(device) for file in os.listdir(emb_dir) if file.endswith('.pt')}
 
              # slice of nn.Parameter cannot return grad. Split the tensor
              ex_words_emb = {}
-             emb_dims = [x.embedding_dim for x in text_encoder.get_input_embeddings()]
-             for file in os.listdir(emb_dir):
-                 if file.endswith('.pt'):
-                     emb = load_emb(os.path.join(emb_dir, file)).to(device)
-                     emb = ParameterGroup([nn.Parameter(item, requires_grad=False) for item in emb.split(emb_dims, dim=1)])
-                     ex_words_emb[file[:-3]] = emb
-             return cls.hook(ex_words_emb, tokenizer, text_encoder, log, **kwargs), ex_words_emb
+             if emb_dir is not None and os.path.exists(emb_dir):
+                 emb_dims = [x.embedding_dim for x in text_encoder.get_input_embeddings()]
+                 for file in os.listdir(emb_dir):
+                     if file.endswith('.pt'):
+                         emb = load_emb(os.path.join(emb_dir, file)).to(device)
+                         emb = ParameterGroup([nn.Parameter(item, requires_grad=False) for item in emb.split(emb_dims, dim=1)])
+                         ex_words_emb[file[:-3]] = emb
+             return cls.hook(ex_words_emb, tokenizer, text_encoder, **kwargs), ex_words_emb
          else:
-             return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder, log, device, **kwargs)
+             return EmbeddingPTHook.hook_from_dir(emb_dir, tokenizer, text_encoder, **kwargs)
 
  class ComposeTEEXHook:
      def __init__(self, tehook_list: List[Tuple[str, TEEXHook]], cat_dim=-1):
@@ -98,10 +98,28 @@ class ComposeTEEXHook:
          for name, tehook in self.tehook_list:
              tehook.clip_skip = value
 
+     @property
+     def clip_final_norm(self):
+         return self.tehook_list[0][1].clip_final_norm
+
+     @clip_final_norm.setter
+     def clip_final_norm(self, value: bool):
+         for name, tehook in self.tehook_list:
+             tehook.clip_final_norm = value
+
+     @property
+     def use_attention_mask(self):
+         return self.tehook_list[0][1].use_attention_mask
+
+     @use_attention_mask.setter
+     def use_attention_mask(self, value: bool):
+         for name, tehook in self.tehook_list:
+             tehook.use_attention_mask = value
+
      def encode_prompt_to_emb(self, prompt):
          emb_list = [tehook.encode_prompt_to_emb(prompt) for name, tehook in self.tehook_list]
-         encoder_hidden_states, pooled_output = list(zip(*emb_list))
-         return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output
+         encoder_hidden_states, pooled_output, attention_mask = list(zip(*emb_list))
+         return torch.cat(encoder_hidden_states, dim=self.cat_dim), pooled_output, attention_mask[0]
 
      def enable_xformers(self):
          for name, tehook in self.tehook_list:
@@ -112,16 +130,19 @@ class ComposeTEEXHook:
          return TEEXHook.mult_attn(prompt_embeds, attn_mult)
 
      @classmethod
-     def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, device='cuda', use_attention_mask=False) -> Union['ComposeTEEXHook', TEEXHook]:
+     def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False) -> Union[
+         'ComposeTEEXHook', TEEXHook]:
          if isinstance(text_enc, ComposeTextEncoder):
              # multi text encoder
-             tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name), tokenizer_i, N_repeats, clip_skip, clip_final_norm, device=device, use_attention_mask=use_attention_mask))
-                            for name, tokenizer_i in tokenizer.tokenizer_list]
+             tehook_list = [(name, TEEXHook.hook(getattr(text_enc, name), getattr(tokenizer, name), N_repeats, clip_skip, clip_final_norm,
+                                                 use_attention_mask=use_attention_mask))
+                            for name in tokenizer.tokenizer_names]
              return cls(tehook_list)
          else:
              # single text encoder
-             return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip, device=device, use_attention_mask=use_attention_mask)
+             return TEEXHook.hook(text_enc, tokenizer, N_repeats, clip_skip, clip_final_norm, use_attention_mask=use_attention_mask)
 
      @classmethod
      def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
-         return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, device='cuda', clip_skip=clip_skip, clip_final_norm=clip_final_norm, use_attention_mask=use_attention_mask)
+         return cls.hook(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                         use_attention_mask=use_attention_mask)
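`encode_prompt_to_emb` now returns a three-tuple, with only the first encoder's attention mask passed through. A standalone sketch of the concatenation for an SDXL-style encoder pair (shapes are the usual CLIP-L/bigG sizes; the tensors here are illustrative stand-ins, not the package API):

import torch

# per-encoder results: (hidden_states, pooled_output, attention_mask)
emb_L = (torch.randn(1, 77, 768), torch.randn(1, 768), torch.ones(1, 77, dtype=torch.long))
emb_G = (torch.randn(1, 77, 1280), torch.randn(1, 1280), torch.ones(1, 77, dtype=torch.long))

emb_list = [emb_L, emb_G]
hidden, pooled, mask = list(zip(*emb_list))
encoder_hidden_states = torch.cat(hidden, dim=-1)   # [1, 77, 2048] with cat_dim=-1
attention_mask = mask[0]                            # attention_mask[0]: first encoder's mask
print(encoder_hidden_states.shape, len(pooled))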
hcpdiff/models/compose/compose_tokenizer.py CHANGED
@@ -18,14 +18,19 @@ from transformers.tokenization_utils_base import BatchEncoding
  class ComposeTokenizer(PreTrainedTokenizer):
      def __init__(self, tokenizer_list: List[Tuple[str, CLIPTokenizer]], cat_dim=-1):
          self.cat_dim = cat_dim
-         self.tokenizer_list = tokenizer_list
+
+         self.tokenizer_names = []
+         for name, tokenizer in tokenizer_list:
+             setattr(self, name, tokenizer)
+             self.tokenizer_names.append(name)
+
          super().__init__()
 
-         self.model_max_length = self.first_tokenizer.model_max_length
+         self.model_max_length = torch.tensor([tokenizer.model_max_length for name, tokenizer in tokenizer_list])
 
      @property
      def first_tokenizer(self):
-         return self.tokenizer_list[0][1]
+         return getattr(self, self.tokenizer_names[0])
 
      @property
      def vocab_size(self):
@@ -40,18 +45,26 @@ class ComposeTokenizer(PreTrainedTokenizer):
          return self.first_tokenizer.bos_token_id
 
      def get_vocab(self):
-         return dict(self.first_tokenizer.encoder, **self.first_tokenizer.added_tokens_encoder)
+         return self.first_tokenizer.get_vocab()
 
      def tokenize(self, text, **kwargs) -> List[str]:
          return self.first_tokenizer.tokenize(text, **kwargs)
 
      def add_tokens( self, new_tokens, special_tokens: bool = False) -> List[int]:
-         return [tokenizer.add_tokens(new_tokens, special_tokens) for name, tokenizer in self.tokenizer_list]
+         return [getattr(self, name).add_tokens(new_tokens, special_tokens) for name in self.tokenizer_names]
+
+     def save_vocabulary(self, save_directory: str, filename_prefix = None) -> Tuple[str]:
+         return self.first_tokenizer.save_vocabulary(save_directory, filename_prefix)
+
+     def __call__(self, text, *args, max_length=None, **kwargs):
+         if isinstance(max_length, torch.Tensor):
+             token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length_i, **kwargs)
+                                                for name, max_length_i in zip(self.tokenizer_names, max_length)]
+         else:
+             token_list: List[BatchEncoding] = [getattr(self, name)(text, *args, max_length=max_length, **kwargs) for name in self.tokenizer_names]
 
-     def __call__(self, text, *args, **kwargs):
-         token_list: List[BatchEncoding] = [tokenizer(text, *args, **kwargs) for name, tokenizer in self.tokenizer_list]
          input_ids = torch.cat([token.input_ids for token in token_list], dim=-1) # [N_tokenizer, N_token]
-         attention_mask = [token.attention_mask for token in token_list]
+         attention_mask = torch.cat([token.attention_mask for token in token_list], dim=-1)
          return BatchEncoding({'input_ids':input_ids, 'attention_mask':attention_mask})
 
      @classmethod
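Since `model_max_length` is now a tensor of per-tokenizer lengths, passing it back in as `max_length` tokenizes each sub-tokenizer with its own limit before concatenation. A rough sketch of the dispatch (standalone; the repeated CLIP-L tokenizer stands in for a real second tokenizer):

import torch
from transformers import AutoTokenizer

tok_L = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14')
tok_G = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14')  # placeholder

max_length = torch.tensor([tok_L.model_max_length, tok_G.model_max_length])  # tensor([77, 77])
token_list = [t('a photo of a cat', max_length=int(m), padding='max_length',
                truncation=True, return_tensors='pt')
              for t, m in zip((tok_L, tok_G), max_length)]
input_ids = torch.cat([t.input_ids for t in token_list], dim=-1)            # [1, 154]
attention_mask = torch.cat([t.attention_mask for t in token_list], dim=-1)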
hcpdiff/models/compose/sdxl_composer.py CHANGED
@@ -27,13 +27,13 @@ class CLIPTextModelWithProjection_Align(CLIPTextModelWithProjection):
  class SDXLTextEncoder(ComposeTextEncoder):
      @classmethod
      def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
-         clip_B = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
+         clip_L = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', **kwargs)
          clip_bigG = CLIPTextModelWithProjection_Align.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder_2', **kwargs)
-         return cls([('clip_B', clip_B), ('clip_bigG', clip_bigG)])
+         return cls([('clip_L', clip_L), ('clip_bigG', clip_bigG)])
 
  class SDXLTokenizer(ComposeTokenizer):
      @classmethod
      def from_pretrained(cls, pretrained_model_name_or_path: str, *args, subfolder=None, revision:str=None, **kwargs):
-         clip_B = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
+         clip_L = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer', **kwargs)
          clip_bigG = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer_2', **kwargs)
-         return cls([('clip_B', clip_B), ('clip_bigG', clip_bigG)])
+         return cls([('clip_L', clip_L), ('clip_bigG', clip_bigG)])
hcpdiff/models/controlnet.py CHANGED
@@ -5,7 +5,7 @@ import torch
  from torch import nn
  from copy import deepcopy
 
- from .plugin import MultiPluginBlock, BasePluginBlock
+ from rainbowneko.models.plugin import MultiPluginBlock, BasePluginBlock
  from hcpdiff.utils.net_utils import remove_all_hooks, remove_layers
 
  class ControlNetPlugin(MultiPluginBlock):
@@ -55,25 +55,25 @@ class ControlNetPlugin(MultiPluginBlock):
          self.cond_head = nn.Sequential(*cond_head)
 
      def reset_parameters(self) -> None:
-         def weight_init(m):
-             if isinstance(m, nn.Conv2d):
-                 nn.init.constant_(m.weight, 0)
-         self.controlnet_down_blocks.apply(weight_init)
-         self.controlnet_mid_block.apply(weight_init)
-         self.cond_head[-1].apply(weight_init)
-
-     def from_layer_hook(self, host, fea_in:Tuple[torch.Tensor], fea_out:Tuple[torch.Tensor], idx: int):
+         def zero_weight_init(m):
+             for p in m.parameters():
+                 p.detach().zero_()
+         self.controlnet_down_blocks.apply(zero_weight_init)
+         self.controlnet_mid_block.apply(zero_weight_init)
+         self.cond_head[-1].apply(zero_weight_init)
+
+     def from_layer_hook(self, host, idx: int, args: Tuple[Any, ...], kwargs: Dict[str, Any], fea_out: Any=None):
          if idx==0:
-             self.data_input = fea_in
+             self.data_input = (args, kwargs)
          elif idx==1:
-             self.feat_to = self(*self.data_input)
+             self.feat_to = self(*self.data_input[0], **self.data_input[1])
 
-     def to_layer_hook(self, host, fea_in:Tuple[torch.Tensor], fea_out:Tuple[torch.Tensor], idx: int):
+     def to_layer_hook(self, host, idx: int, args: Tuple[Any, ...], kwargs: Dict[str, Any], fea_out: Any=None):
          if idx == 5:
-             sp = fea_in[0].shape[1]//2
-             new_feat = fea_in[0].clone()
-             new_feat[:, sp:, ...] = fea_in[0][:, sp:, ...] + self.feat_to[0]
-             return (new_feat, fea_in[1])
+             sp = args[0].shape[1]//2
+             new_feat = args[0].clone()
+             new_feat[:, sp:, ...] = args[0][:, sp:, ...] + self.feat_to[0]
+             return (new_feat, args[1])
          elif idx == 3:
              return (fea_out[0], tuple(fea_out[1][i] + self.feat_to[(idx) * 3 + i+1] for i in range(2)))
          elif idx == 4:
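The rewritten `reset_parameters` zeroes every parameter of the projection layers rather than only `nn.Conv2d` weights, so biases start at zero too; this makes the plugin an exact no-op at initialization, the usual ControlNet zero-convolution setup. A standalone illustration of the difference:

import torch
from torch import nn

conv = nn.Conv2d(4, 4, 1)

# old behavior: only the weight is zeroed, the bias keeps its random init
nn.init.constant_(conv.weight, 0)
print(bool(conv.bias.abs().sum() > 0))            # True: output is biased, not zero

# new behavior: zero every parameter, weight and bias alike
for p in conv.parameters():
    p.detach().zero_()
print(conv(torch.randn(1, 4, 8, 8)).abs().sum())  # tensor(0.): exact no-op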
hcpdiff/models/lora_base_patch.py CHANGED
@@ -13,7 +13,7 @@ from torch import nn
  from torch.nn import functional as F
 
  from hcpdiff.utils.utils import make_mask, low_rank_approximate, isinstance_list
- from .plugin import PatchPluginBlock, PluginGroup, PatchPluginContainer
+ from rainbowneko.models.plugin import PatchPluginBlock, PluginGroup, PatchPluginContainer
 
  from typing import Union, Tuple, Dict, Type
 
@@ -38,9 +38,9 @@ class LoraBlock(PatchPluginBlock):
      container_cls = LoraPatchContainer
      wrapable_classes = (nn.Linear, nn.Conv2d)
 
-     def __init__(self, lora_id:int, host:Union[nn.Linear, nn.Conv2d], rank, dropout=0.1, alpha=1.0, bias=False,
+     def __init__(self, name:int, host:Union[nn.Linear, nn.Conv2d], rank, dropout=0.1, alpha=1.0, bias=False,
                   alpha_auto_scale=True, parent_block=None, host_name=None, **kwargs):
-         super().__init__(f'lora_block_{lora_id}', host, parent_block=parent_block, host_name=host_name)
+         super().__init__(name, host, parent_block=parent_block, host_name=host_name)
 
          self.bias=bias
 
@@ -56,8 +56,14 @@ class LoraBlock(PatchPluginBlock):
          self.dropout = nn.Dropout(dropout)
 
          self.rank = self.layer.rank
+         self.alpha_auto_scale = alpha_auto_scale
          self.register_buffer('alpha', torch.tensor(alpha/self.rank if alpha_auto_scale else alpha))
 
+     def set_hyper_params(self, alpha=None, **kwargs):
+         if alpha is not None:
+             self.register_buffer('alpha', torch.tensor(alpha/self.rank if self.alpha_auto_scale else alpha))
+         super().set_hyper_params(**kwargs)
+
      def get_weight(self):
          return self.layer.get_weight() * self.alpha
 
@@ -91,7 +97,7 @@ class LoraBlock(PatchPluginBlock):
                  host.weight.data * base_alpha + alpha * re_w.to(host.weight.device, dtype=host.weight.dtype)
              )
 
-         if self.layer.lora_up.bias is not None:
+         if re_b is not None:
              if host.bias is None:
                  host.bias = nn.Parameter(re_b.to(host.weight.device, dtype=host.weight.dtype))
              else:
@@ -145,32 +151,15 @@ class LoraBlock(PatchPluginBlock):
          pass
 
      @classmethod
-     def wrap_layer(cls, lora_id:int, layer: Union[nn.Linear, nn.Conv2d], rank=1, dropout=0.0, alpha=1.0, svd_init=False,
+     def wrap_layer(cls, name:str, host: Union[nn.Linear, nn.Conv2d], rank=1, dropout=0.0, alpha=1.0, svd_init=False,
                     bias=False, mask=None, **kwargs):# -> LoraBlock:
-         lora_block = cls(lora_id, layer, rank, dropout, alpha, bias=bias, **kwargs)
+         lora_block = cls(name, host, rank, dropout, alpha, bias=bias, **kwargs)
          lora_block.init_weights(svd_init)
          return lora_block
 
      @classmethod
-     def wrap_model(cls, lora_id:int, model: nn.Module, **kwargs):# -> Dict[str, LoraBlock]:
-         return super(LoraBlock, cls).wrap_model(lora_id, model, exclude_classes=(LoraBlock,), **kwargs)
-
-     @staticmethod
-     def extract_lora_state(model:nn.Module):
-         return {k:v for k,v in model.state_dict().items() if 'lora_block_' in k}
-
-     @staticmethod
-     def extract_state_without_lora(model:nn.Module):
-         return {k:v for k,v in model.state_dict().items() if 'lora_block_' not in k}
-
-     @staticmethod
-     def extract_param_without_lora(model:nn.Module):
-         return {k:v for k,v in model.named_parameters() if 'lora_block_' not in k}
-
-     @staticmethod
-     def extract_trainable_state_without_lora(model:nn.Module):
-         trainable_keys = {k for k,v in model.named_parameters() if ('lora_block_' not in k) and v.requires_grad}
-         return {k: v for k, v in model.state_dict().items() if k in trainable_keys}
+     def wrap_model(cls, name:str, host: nn.Module, **kwargs):# -> Dict[str, LoraBlock]:
+         return super().wrap_model(name, host, exclude_classes=(LoraBlock,), **kwargs)
 
  class LoraGroup(PluginGroup):
      def set_mask(self, batch_mask):
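With `alpha_auto_scale`, the `alpha` buffer stores α/r, so `get_weight` yields the standard LoRA update scaling (the standard LoRA convention; the diff itself only shows the buffer bookkeeping):

W' = W + \frac{\alpha}{r}\,\Delta W, \qquad \Delta W = B\,A, \quad B \in \mathbb{R}^{d_{\mathrm{out}} \times r},\ A \in \mathbb{R}^{r \times d_{\mathrm{in}}}

`set_hyper_params` re-registers the buffer with the same rule, so α can be changed after loading without re-wrapping the layer.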
hcpdiff/models/lora_layers.py CHANGED
@@ -15,7 +15,7 @@ from einops import repeat, rearrange, einsum
  from torch import nn
 
  from .lora_base import LoraBlock
- from .layers import GroupLinear
+ from rainbowneko.models.layers import GroupLinear
  import warnings
 
  class LoraLayer(LoraBlock):
@@ -59,8 +59,8 @@ class LoraLayerGroup(LoraBlock):
      def __init__(self, host, rank, bias, dropout, block):
          super().__init__(host, rank, bias, dropout, block)
          self.register_buffer('rank_groups', torch.tensor(block.rank_groups_raw, dtype=torch.int))
-         self.lora_down = GroupLinear(host.in_features*self.rank_groups, self.rank, groups=self.rank_groups, bias=False)
-         self.lora_up = GroupLinear(self.rank, host.out_features*self.rank_groups, groups=self.rank_groups, bias=bias)
+         self.lora_down = GroupLinear(host.in_features, self.rank//self.rank_groups, group=self.rank_groups, bias=False)
+         self.lora_up = GroupLinear(self.rank//self.rank_groups, host.out_features, group=self.rank_groups, bias=bias)
 
      def feed_svd(self, U, V, weight):
          self.lora_up.weight.data = rearrange(U, 'o (g ri) -> g ri o', g=self.rank_groups).to(device=weight.device, dtype=weight.dtype)
@@ -137,9 +137,3 @@ class LohaLayer(LoraBlock):
          w = torch.prod(einsum(self.W_up.data, self.W_down.data, 'g o r ..., g r i ... -> g o i ...'), dim=0)
          b = None
          return w, b
-
- lora_layer_map={
-     'lora': LoraLayer,
-     'loha_group': LoraLayerGroup,
-     'loha': LohaLayer,
-     }