hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,147 @@
1
+ import torch
2
+ from rainbowneko.ckpt_manager import ckpt_saver, NekoPluginSaver, LAYERS_TRAINABLE, SafeTensorFormat
3
+ from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
4
+ from rainbowneko.utils import ConstantLR
5
+
6
+ from hcpdiff.easy import SDXL_auto_loader
7
+ from hcpdiff.models import SDXLWrapper
8
+ from hcpdiff.models.lora_layers_patch import LoraLayer
9
+ from hcpdiff.ckpt_manager import LoraWebuiFormat
10
+
11
@neko_cfg
def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5,
                    dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL'):
    """Build a training config that fine-tunes the whole SDXL denoiser (UNet).

    Args:
        base_model: Checkpoint path handed to ``SDXL_auto_loader`` as ``ckpt_path``.
        train_steps: Total number of training steps.
        dataset: Training-data config, stored unchanged under ``data_train``.
        save_step: Interval (in steps) between checkpoint saves.
        lr: Learning rate of the single ``denoiser`` parameter group.
        dtype: Value stored as ``mixed_precision``.
        low_vram: Use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.AdamW``.
        warmup_steps: Warmup steps for the ``ConstantLR`` scheduler.
        name: Value stored as ``model.name``.

    Returns:
        A config dict that extends ``train_base`` and ``tuning_base``.
    """
    # NOTE(review): @neko_cfg presumably turns the calls below into lazy config
    # entries (hence ``_partial_=True``) instead of executing them — confirm
    # against rainbowneko before restructuring this body.
    if low_vram:
        # Imported here so bitsandbytes is only needed when low_vram is requested.
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True)
    else:
        optimizer = torch.optim.AdamW(_partial_=True)

    from cfgs.train.py import train_base, tuning_base

    return dict(
        _base_=[train_base, tuning_base],
        mixed_precision=dtype,

        model_part=CfgWDModelParser([
            dict(
                lr=lr,
                layers=['denoiser'],  # train UNet
            )
        ], weight_decay=1e-2),

        # Save only the trainable layers of the denoiser as safetensors.
        ckpt_saver=dict(
            SDXL=ckpt_saver(
                ckpt_type='safetensors',
                target_module='denoiser',
                layers=LAYERS_TRAINABLE,
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            ## Easy config
            wrapper=SDXLWrapper.from_pretrained(
                _partial_=True,
                models=SDXL_auto_loader(ckpt_path=base_model, _partial_=True),
            ),
        ),

        data_train=dataset,
    )
65
+
66
@neko_cfg
def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
                    with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL',
                    save_webui_format=False):
    """Build a training config that trains LoRA adapters on the SDXL denoiser.

    Args:
        base_model: Checkpoint path handed to ``SDXL_auto_loader`` as ``ckpt_path``.
        train_steps: Total number of training steps.
        dataset: Training-data config, stored unchanged under ``data_train``.
        save_step: Interval (in steps) between checkpoint saves.
        lr: Learning rate given to the LoRA plugin.
        rank: LoRA rank.
        alpha: LoRA alpha; ``None`` means "use ``rank``".
        with_conv: Also wrap ``resnets``/``proj_in``/``proj_out``/``conv`` modules,
            not only attention and feed-forward modules.
        dtype: Value stored as ``mixed_precision``.
        low_vram: Use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.AdamW``.
        warmup_steps: Warmup steps for the ``ConstantLR`` scheduler.
        name: Value stored as ``model.name``.
        save_webui_format: Save LoRA weights with ``LoraWebuiFormat`` instead of
            ``SafeTensorFormat``.

    Returns:
        A config dict that extends ``SD_FT``.
    """
    # NOTE(review): the extent of this `with` block is assumed, not certain. It is
    # taken to cover only the plain-Python defaults below. The optimizer calls pass
    # ``_partial_=True`` like the other config calls, which presumably requires
    # @neko_cfg to handle them (i.e. outside this block) — confirm against upstream.
    with disable_neko_cfg:
        if alpha is None:
            alpha = rank

        # Module-name patterns (``re:`` prefix, presumably a regex match) of the
        # denoiser submodules that get LoRA layers.
        if with_conv:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
                r're:denoiser.*\.resnets$',
                r're:denoiser.*\.proj_in$',
                r're:denoiser.*\.proj_out$',
                r're:denoiser.*\.conv$',
            ]
        else:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
            ]

    if low_vram:
        # Imported here so bitsandbytes is only needed when low_vram is requested.
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True, betas=(0.9, 0.99))
    else:
        optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))

    if save_webui_format:
        lora_format = LoraWebuiFormat()
    else:
        lora_format = SafeTensorFormat()

    from cfgs.train.py.examples import SD_FT

    return dict(
        _base_=[SD_FT],
        mixed_precision=dtype,

        # No full-model parameter groups; presumably only the LoRA plugin trains.
        model_part=None,
        model_plugin=CfgWDPluginParser(cfg_plugin=dict(
            lora1=LoraLayer.wrap_model(
                _partial_=True,
                lr=lr,
                rank=rank,
                alpha=alpha,
                layers=lora_layers
            )
        ), weight_decay=0.1),

        ckpt_saver=dict(
            # presumably replaces (rather than merges with) the inherited savers
            _replace_ = True,
            lora_unet=NekoPluginSaver(
                format=lora_format,
                target_plugin='lora1',
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            wrapper=SDXLWrapper.from_pretrained(
                models=SDXL_auto_loader(ckpt_path=base_model, _partial_=True),
                _partial_=True,
            ),
        ),

        data_train=dataset,
    )
@@ -0,0 +1,228 @@
1
+ import torch
2
+ from rainbowneko.infer.workflow import (Actions, PrepareAction, LoopAction, LoadModelAction)
3
+ from rainbowneko.ckpt_manager import NekoModelLoader
4
+ from rainbowneko.parser import neko_cfg, disable_neko_cfg
5
+ from typing import Union, List
6
+
7
+ from hcpdiff.ckpt_manager import HCPLoraLoader
8
+ from hcpdiff.easy import Diffusers_SD, SD15_auto_loader, SDXL_auto_loader
9
+ from hcpdiff.workflow import (BuildModelsAction, PrepareDiffusionAction, XformersEnableAction, VaeOptimizeAction, TextHookAction,
10
+ AttnMultTextEncodeAction, SeedAction, MakeTimestepsAction, MakeLatentAction, DiffusionStepAction,
11
+ time_iter, DecodeAction, SaveImageAction, LatentResizeAction)
12
+
13
# Default negative prompt; bound as the ``negative_prompt`` default argument of
# the prompt/workflow helpers in this module.
negative_prompt = (
    'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, '
    'cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, '
    'username, blurry'
)
14
+
15
## Easy config
@neko_cfg
def build_model(pretrained_model='ckpts/any5', noise_sampler=Diffusers_SD.dpmpp_2m_karras) -> Actions:
    """Build the actions that prepare CUDA/float16 and load an SD1.5 model.

    Args:
        pretrained_model: Checkpoint path given to ``SD15_auto_loader`` as ``ckpt_path``.
        noise_sampler: Sampler forwarded to the loader.

    Returns:
        ``Actions`` holding a ``PrepareAction`` followed by a ``BuildModelsAction``.
    """
    return Actions([
        PrepareAction(device='cuda', dtype=torch.float16),
        BuildModelsAction(
            # _partial_=True: presumably the loader is invoked later by the action.
            model_loader=SD15_auto_loader(
                _partial_=True,
                ckpt_path=pretrained_model,
                noise_sampler=noise_sampler
            )
        ),
    ])
28
+
29
@neko_cfg
def load_parts(info: List[str]) -> Actions:
    """Build actions that load partial checkpoints into the denoiser and text encoder.

    For every path two ``LoadModelAction`` entries are created: one using state
    prefix ``'denoiser.'`` with ``denoiser`` mapped to the loader's ``model``,
    and one using prefix ``'TE.'`` with ``TE`` mapped to ``model``.

    Args:
        info: Checkpoint paths; each path feeds both actions.

    Returns:
        ``Actions`` containing two load actions per path, in input order.
    """
    acts = []
    for i, path in enumerate(info):
        part_unet = LoadModelAction(cfg={
            f'part_unet_{i}':NekoModelLoader(
                path=path,
                state_prefix='denoiser.'
            )
        }, key_map_in=('denoiser -> model', 'in_preview -> in_preview'))
        part_TE = LoadModelAction(cfg={
            f'part_TE_{i}':NekoModelLoader(
                path=path,
                state_prefix='TE.',
            )
        }, key_map_in=('TE -> model', 'in_preview -> in_preview'))

        # The appends must run as ordinary Python; presumably @neko_cfg would
        # otherwise treat these method calls as config entries.
        with disable_neko_cfg:
            acts.append(part_unet)
            acts.append(part_TE)

    return Actions(acts)
51
+
52
@neko_cfg
def load_lora(info: List[List]) -> Actions:
    """Build actions that load LoRA weights into the denoiser and text encoder.

    Args:
        info: One entry per LoRA; ``item[0]`` is the checkpoint path and
            ``item[1]`` the ``alpha`` passed to ``HCPLoraLoader``. Each entry
            produces a denoiser action (prefix ``'denoiser.'``) and a text-encoder
            action (prefix ``'TE.'``).

    Returns:
        ``Actions`` containing two load actions per LoRA, in input order.
    """
    lora_acts = []
    for i, item in enumerate(info):
        lora_unet = LoadModelAction(cfg={
            f'lora_unet_{i}':HCPLoraLoader(
                path=item[0],
                state_prefix='denoiser.',
                alpha=item[1],
            )
        }, key_map_in=('denoiser -> model', 'in_preview -> in_preview'))
        lora_TE = LoadModelAction(cfg={
            f'lora_TE_{i}':HCPLoraLoader(
                path=item[0],
                state_prefix='TE.',
                alpha=item[1],
            )
        }, key_map_in=('TE -> model', 'in_preview -> in_preview'))

        # The appends must run as ordinary Python; presumably @neko_cfg would
        # otherwise treat these method calls as config entries.
        with disable_neko_cfg:
            lora_acts.append(lora_unet)
            lora_acts.append(lora_TE)

    return Actions(lora_acts)
76
+
77
@neko_cfg
def optimize_model() -> Actions:
    """Build the post-load optimization actions.

    Returns:
        ``Actions`` running ``PrepareDiffusionAction``, ``XformersEnableAction``
        and ``VaeOptimizeAction(slicing=True)`` in that order.
    """
    return Actions([
        PrepareDiffusionAction(),
        XformersEnableAction(),
        VaeOptimizeAction(slicing=True),
    ])
84
+
85
@neko_cfg
def text(prompt, negative_prompt=negative_prompt, bs=4, N_repeats=1, layer_skip=1) -> Actions:
    """Build the actions that hook the text encoder and encode the prompts.

    Args:
        prompt: Positive prompt(s) for ``AttnMultTextEncodeAction``.
        negative_prompt: Negative prompt(s); defaults to the module-level
            ``negative_prompt`` string.
        bs: Batch size passed to the encode action.
        N_repeats: Passed to ``TextHookAction``.
        layer_skip: Passed to ``TextHookAction``.

    Returns:
        ``Actions`` with the hook action followed by the encode action.
    """
    return Actions([
        TextHookAction(N_repeats=N_repeats, layer_skip=layer_skip),
        AttnMultTextEncodeAction(
            prompt=prompt,
            negative_prompt=negative_prompt,
            bs=bs
        ),
    ])
95
+
96
@neko_cfg
def build_model_SDXL(pretrained_model='ckpts/any5', noise_sampler=Diffusers_SD.dpmpp_2m_karras) -> Actions:
    """Build the actions that prepare CUDA/float16 and load an SDXL model.

    Args:
        pretrained_model: Checkpoint path given to ``SDXL_auto_loader`` as
            ``ckpt_path``. NOTE(review): the default is the same path used for
            SD1.5 in ``build_model``; confirm it is meant for SDXL.
        noise_sampler: Sampler forwarded to the loader.

    Returns:
        ``Actions`` holding a ``PrepareAction`` followed by a ``BuildModelsAction``.
    """
    return Actions([
        PrepareAction(device='cuda', dtype=torch.float16),
        ## Easy config
        BuildModelsAction(
            model_loader=SDXL_auto_loader(
                _partial_=True,
                ckpt_path=pretrained_model,
                noise_sampler=noise_sampler
            )
        ),
    ])
109
+
110
@neko_cfg
def text_SDXL(prompt, negative_prompt=negative_prompt, bs=4, N_repeats=1, layer_skip=1) -> Actions:
    """Build the SDXL text-encoding actions.

    Same as the SD1.5 variant except that ``TextHookAction`` is created with
    ``TE_final_norm=False``.

    Args:
        prompt: Positive prompt(s) for ``AttnMultTextEncodeAction``.
        negative_prompt: Negative prompt(s); defaults to the module-level
            ``negative_prompt`` string.
        bs: Batch size passed to the encode action.
        N_repeats: Passed to ``TextHookAction``.
        layer_skip: Passed to ``TextHookAction``.

    Returns:
        ``Actions`` with the hook action followed by the encode action.
    """
    return Actions([
        TextHookAction(N_repeats=N_repeats, layer_skip=layer_skip, TE_final_norm=False),
        AttnMultTextEncodeAction(
            prompt=prompt,
            negative_prompt=negative_prompt,
            bs=bs
        ),
    ])
120
+
121
@neko_cfg
def config_diffusion(width=512, height=512, seed=None, N_steps=20, strength: float = None) -> Actions:
    """Build the actions that seed the RNG, create timesteps and create the initial latent.

    Args:
        width: Image width passed to ``MakeLatentAction``.
        height: Image height passed to ``MakeLatentAction``.
        seed: Seed passed to ``SeedAction`` (``None`` by default).
        N_steps: Number of sampling steps for ``MakeTimestepsAction``.
        strength: Passed to ``MakeTimestepsAction``; presumably for partial
            (img2img-style) denoising — confirm.

    Returns:
        ``Actions`` running seed, timestep and latent creation in that order.
    """
    return Actions([
        SeedAction(seed),
        MakeTimestepsAction(N_steps=N_steps, strength=strength),
        MakeLatentAction(width=width, height=height)
    ])
128
+
129
@neko_cfg
def diffusion(guidance_scale=7.0) -> Actions:
    """Build the denoising loop: one ``DiffusionStepAction`` per item of ``time_iter``.

    Args:
        guidance_scale: Passed to every ``DiffusionStepAction``.

    Returns:
        ``Actions`` wrapping a single ``LoopAction``.
    """
    return Actions([
        LoopAction(
            iterator=time_iter,
            actions=[
                DiffusionStepAction(guidance_scale=guidance_scale)
            ]
        )
    ])
139
+
140
@neko_cfg
def decode(save_root='output_pipe/') -> Actions:
    """Build the actions that decode the latents and save the images as PNG.

    Args:
        save_root: Output directory passed to ``SaveImageAction``.

    Returns:
        ``Actions`` running ``DecodeAction`` then ``SaveImageAction``.
    """
    return Actions([
        DecodeAction(),
        SaveImageAction(save_root=save_root, image_type='png'),
    ])
146
+
147
@neko_cfg
def resize(width=1024, height=1024) -> Actions:
    """Build an action that resizes the current latent to ``width`` x ``height``.

    Args:
        width: Target width passed to ``LatentResizeAction``.
        height: Target height passed to ``LatentResizeAction``.

    Returns:
        ``Actions`` wrapping a single ``LatentResizeAction``.
    """
    return Actions([
        LatentResizeAction(width=width, height=height)
    ])
152
+
153
@neko_cfg
def SD15_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=512, height=512,
             seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """Assemble a complete SD1.5 text-to-image workflow config.

    The workflow runs, in order: build the model, optimize it, encode the
    prompts, configure seed/timesteps/latent, run the diffusion loop, then
    decode and save images. Arguments are forwarded to the step builders of
    the same name.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        optimize_model(),
        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
164
+
165
@neko_cfg
def SD15_t2i_parts(pretrained_model, parts, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=512, height=512,
                   seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """Assemble an SD1.5 text-to-image workflow that also loads partial checkpoints.

    Identical to the plain SD1.5 workflow except that ``load_parts(parts)``
    runs right after the model is built.

    Args:
        parts: List of checkpoint paths for ``load_parts``.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        load_parts(parts),
        optimize_model(),
        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
177
+
178
@neko_cfg
def SD15_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4,
                  width=512, height=512, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """Assemble an SD1.5 text-to-image workflow that also loads LoRA weights.

    Identical to the plain SD1.5 workflow except that ``load_lora(info=lora_info)``
    runs right after the model is built.

    Args:
        lora_info: List of ``[path, alpha]`` entries for ``load_lora``.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        load_lora(info=lora_info),
        optimize_model(),
        text(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
190
+
191
@neko_cfg
def SDXL_t2i(pretrained_model, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=1024, height=1024,
             seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """Assemble a complete SDXL text-to-image workflow config.

    Uses the SDXL model builder and SDXL text encoding, with a 1024x1024
    default size; the remaining steps match the SD1.5 workflow.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        optimize_model(),
        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
202
+
203
@neko_cfg
def SDXL_t2i_parts(pretrained_model, parts, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4, width=1024, height=1024,
                   seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """Assemble an SDXL text-to-image workflow that also loads partial checkpoints.

    Args:
        parts: List of checkpoint paths for ``load_parts``, applied right after
            the model is built.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        load_parts(parts),
        optimize_model(),
        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
215
+
216
+
217
@neko_cfg
def SDXL_t2i_lora(pretrained_model, lora_info, prompt, negative_prompt=negative_prompt, noise_sampler=Diffusers_SD.dpmpp_2m_karras, bs=4,
                  width=1024, height=1024, seed=None, N_steps=20, guidance_scale=7.0, save_root='output_pipe/', N_repeats=1, layer_skip=1):
    """Assemble an SDXL text-to-image workflow that also loads LoRA weights.

    Args:
        lora_info: List of ``[path, alpha]`` entries for ``load_lora``, applied
            right after the model is built.

    Returns:
        ``dict(workflow=Actions(...))``.
    """
    return dict(workflow=Actions(actions=[
        build_model_SDXL(pretrained_model=pretrained_model, noise_sampler=noise_sampler),
        load_lora(info=lora_info),
        optimize_model(),
        text_SDXL(prompt=prompt, negative_prompt=negative_prompt, bs=bs, N_repeats=N_repeats, layer_skip=layer_skip),
        config_diffusion(width=width, height=height, seed=seed, N_steps=N_steps),
        diffusion(guidance_scale=guidance_scale),
        decode(save_root=save_root)
    ]))
@@ -0,0 +1,2 @@
1
+ from .loader import SD15_auto_loader, SDXL_auto_loader, PixArt_auto_loader
2
+ from .cnet import ControlNet_SD15, make_controlnet_handler
@@ -0,0 +1,31 @@
1
+ from hcpdiff.data.handler import ControlNetHandler, StableDiffusionHandler
2
+ from hcpdiff.models import ControlNetPlugin
3
+ from rainbowneko.data import SyncHandler
4
+ from rainbowneko.parser import neko_cfg
5
+
6
@neko_cfg
def ControlNet_SD15(lr=1e-4):
    """Build a partial ``ControlNetPlugin`` config for an SD1.5 denoiser.

    Args:
        lr: Learning rate of the ControlNet plugin.

    Returns:
        A ``ControlNetPlugin`` entry created with ``_partial_=True``.
    """
    return ControlNetPlugin(
        _partial_=True,
        lr=lr,
        # NOTE(review): 'pre_hook:' with an empty name presumably hooks the
        # denoiser itself — confirm against ControlNetPlugin.
        from_layers=[
            'pre_hook:',
            'pre_hook:conv_in',  # to make forward inside autocast
        ],
        to_layers=[
            'down_blocks.0',
            'down_blocks.1',
            'down_blocks.2',
            'down_blocks.3',
            'mid_block',
            'pre_hook:up_blocks.3.resnets.2',
        ]
    )
24
+
25
@neko_cfg
def make_controlnet_handler(bucket=None, encoder_attention_mask=False, erase=0.15, dropout=0.0, shuffle=0.0, word_names=None):
    """Build a data handler that runs the diffusion and ControlNet handlers side by side.

    Args:
        bucket: Bucket passed to both ``StableDiffusionHandler`` and ``ControlNetHandler``.
        encoder_attention_mask: Forwarded to ``StableDiffusionHandler``.
        erase: Forwarded to ``StableDiffusionHandler``.
        dropout: Forwarded to ``StableDiffusionHandler``.
        shuffle: Forwarded to ``StableDiffusionHandler``.
        word_names: Forwarded to ``StableDiffusionHandler``; ``None`` means an
            empty mapping.

    Returns:
        A ``SyncHandler`` with ``diffusion`` and ``cnet`` sub-handlers.
    """
    # A fresh dict per call: the former ``word_names={}`` default was one object
    # shared by every config this function produced.
    if word_names is None:
        word_names = {}
    return SyncHandler(
        diffusion=StableDiffusionHandler(bucket=bucket, encoder_attention_mask=encoder_attention_mask, erase=erase, dropout=dropout, shuffle=shuffle,
                                         word_names=word_names),
        cnet=ControlNetHandler(bucket=bucket)
    )
@@ -0,0 +1,79 @@
1
+ import torch
2
+ from hcpdiff.ckpt_manager import DiffusersSD15Format, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSD15Format, OfficialSDXLFormat
3
+ from rainbowneko.ckpt_manager import NekoLoader, LocalCkptSource
4
+ from hcpdiff.utils import auto_tokenizer_cls, auto_text_encoder_cls, get_pipe_name
5
+ from hcpdiff.models.wrapper import SDXLWrapper, SD15Wrapper, PixArtWrapper
6
+ from hcpdiff.models.compose import SDXLTextEncoder
7
+ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
8
+
9
def SD15_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
                     tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
    """Load SD1.5 components from a diffusers folder or an official single-file checkpoint.

    The format is detected by probing ``ckpt_path`` with
    ``StableDiffusionPipeline.load_config``. If that succeeds, ``DiffusersSD15Format``
    is used; if it raises ``EnvironmentError``, ``OfficialSD15Format`` is used.

    Args:
        ckpt_path: Path of the checkpoint (folder or file).
        denoiser, TE, vae, noise_sampler, tokenizer: Optional pre-built components,
            forwarded to ``NekoLoader.load``.
        revision: Forwarded to ``NekoLoader.load``.
        dtype: Forwarded to ``NekoLoader.load``.
        **kwargs: Extra arguments forwarded to ``NekoLoader.load``.

    Returns:
        Whatever ``NekoLoader.load`` returns for the detected format.
    """
    try:
        # Only the probe belongs inside the try; its return value is not needed.
        StableDiffusionPipeline.load_config(ckpt_path)
    except EnvironmentError:
        ckpt_format = OfficialSD15Format()
    else:
        ckpt_format = DiffusersSD15Format()

    loader = NekoLoader(
        format=ckpt_format,
        source=LocalCkptSource(),
    )
    return loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
                       dtype=dtype, **kwargs)
25
+
26
def SDXL_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
                     tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
    """Load SDXL components from a diffusers folder or an official single-file checkpoint.

    The format is detected by probing ``ckpt_path`` with
    ``StableDiffusionXLPipeline.load_config``. If that succeeds,
    ``DiffusersSDXLFormat`` is used; if it raises ``EnvironmentError``,
    ``OfficialSDXLFormat`` is used.

    Args:
        ckpt_path: Path of the checkpoint (folder or file).
        denoiser, TE, vae, noise_sampler, tokenizer: Optional pre-built components,
            forwarded to ``NekoLoader.load``.
        revision: Forwarded to ``NekoLoader.load``.
        dtype: Forwarded to ``NekoLoader.load``.
        **kwargs: Extra arguments forwarded to ``NekoLoader.load``.

    Returns:
        Whatever ``NekoLoader.load`` returns for the detected format.
    """
    try:
        # Only the probe belongs inside the try; its return value is not needed.
        StableDiffusionXLPipeline.load_config(ckpt_path)
    except EnvironmentError:
        ckpt_format = OfficialSDXLFormat()
    else:
        ckpt_format = DiffusersSDXLFormat()

    loader = NekoLoader(
        format=ckpt_format,
        source=LocalCkptSource(),
    )
    return loader.load(ckpt_path, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
                       dtype=dtype, **kwargs)
42
+
43
def PixArt_auto_loader(ckpt_path, denoiser=None, TE=None, vae=None, noise_sampler=None,
                       tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
    """Load PixArt components from a diffusers-format checkpoint.

    Unlike the SD loaders there is no format probe: ``DiffusersPixArtFormat``
    is always used.

    Args:
        ckpt_path: Path of the diffusers checkpoint folder.
        denoiser, TE, vae, noise_sampler, tokenizer: Optional pre-built components.
        revision: Forwarded to ``NekoLoader.load``.
        dtype: Forwarded to ``NekoLoader.load``.
        **kwargs: Extra arguments forwarded to ``NekoLoader.load``.

    Returns:
        Whatever ``NekoLoader.load`` returns.
    """
    pixart_loader = NekoLoader(format=DiffusersPixArtFormat(), source=LocalCkptSource())
    components = dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler,
                      tokenizer=tokenizer, revision=revision, dtype=dtype)
    return pixart_loader.load(ckpt_path, **components, **kwargs)
52
+
53
def auto_load_wrapper(pretrained_model, denoiser=None, TE=None, vae=None, noise_sampler=None, tokenizer=None, revision=None,
                      dtype=torch.float32, **kwargs):
    """Load a diffusers checkpoint and build the matching model wrapper.

    Wrapper and checkpoint format are chosen as follows:

    * ``SDXLWrapper`` / ``DiffusersSDXLFormat`` when the text-encoder class is
      ``SDXLTextEncoder``;
    * ``PixArtWrapper`` / ``DiffusersPixArtFormat`` when the pipeline name
      contains ``'PixArt'``;
    * ``SD15Wrapper`` / ``DiffusersSD15Format`` otherwise.

    Args:
        pretrained_model: Path of the diffusers checkpoint.
        denoiser, vae, noise_sampler, tokenizer: Optional pre-built components
            forwarded to the loader.
        TE: Optional pre-built text encoder; when given, its class is used for
            the wrapper choice instead of probing the checkpoint.
        revision: Forwarded to ``auto_text_encoder_cls`` and the loader.
        dtype: Forwarded to the loader.
        **kwargs: Forwarded to ``wrapper_cls.build_from_pretrained`` (not to the loader).

    Returns:
        The wrapper instance returned by ``build_from_pretrained``.
    """
    if TE is not None:
        text_encoder_cls = type(TE)
    else:
        text_encoder_cls = auto_text_encoder_cls(pretrained_model, revision)

    pipe_name = get_pipe_name(pretrained_model)

    # ``ckpt_format`` instead of ``format`` so the builtin is not shadowed.
    if text_encoder_cls == SDXLTextEncoder:
        wrapper_cls = SDXLWrapper
        ckpt_format = DiffusersSDXLFormat()
    elif 'PixArt' in pipe_name:
        wrapper_cls = PixArtWrapper
        ckpt_format = DiffusersPixArtFormat()
    else:
        wrapper_cls = SD15Wrapper
        ckpt_format = DiffusersSD15Format()

    loader = NekoLoader(
        format=ckpt_format,
        source=LocalCkptSource(),
    )
    models = loader.load(pretrained_model, denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer, revision=revision,
                         dtype=dtype)

    return wrapper_cls.build_from_pretrained(models, **kwargs)
@@ -0,0 +1,46 @@
1
+ from hcpdiff.diffusion.sampler import DiffusersSampler
2
+ from diffusers import DPMSolverMultistepScheduler, DDIMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
3
+
4
class Diffusers_SD:
    """Namespace of ready-made ``DiffusersSampler`` instances.

    Every sampler wraps a diffusers scheduler configured with
    ``beta_start=0.00085``, ``beta_end=0.012`` and
    ``beta_schedule='scaled_linear'``; they differ only in the scheduler class
    and its solver options.
    """

    # Shared beta-schedule kwargs; removed at the end of the class body so they
    # do not become a class attribute.
    _betas = dict(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')

    dpmpp_2m = DiffusersSampler(
        DPMSolverMultistepScheduler(**_betas, algorithm_type='dpmsolver++')
    )

    dpmpp_2m_karras = DiffusersSampler(
        DPMSolverMultistepScheduler(**_betas, algorithm_type='dpmsolver++', use_karras_sigmas=True)
    )

    ddim = DiffusersSampler(DDIMScheduler(**_betas))

    euler = DiffusersSampler(EulerDiscreteScheduler(**_betas))

    euler_a = DiffusersSampler(EulerAncestralDiscreteScheduler(**_betas))

    del _betas
@@ -0,0 +1 @@
1
+ from .previewer import HCPPreviewer
@@ -0,0 +1,60 @@
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ from rainbowneko.evaluate.preview import WorkflowPreviewer
5
+ from rainbowneko.utils import to_cuda
6
+
7
+ from hcpdiff.models.wrapper import SD15Wrapper
8
+ from accelerate.hooks import remove_hook_from_module
9
+
10
class HCPPreviewer(WorkflowPreviewer):
    """Workflow-based previewer that renders sample images during training.

    ``evaluate`` switches the model to eval mode and runs the configured preview
    workflow. Afterwards it undoes the changes the workflow may have made to the
    model: VAE tiling/slicing/offload hooks, text-encoder hook settings and the
    embedding hook. Finally it puts the submodules that were training back into
    train mode.
    """

    @torch.no_grad()
    def evaluate(self, step: int, model: SD15Wrapper, prefix='eval/'):
        """Run the preview workflow every ``self.interval`` steps on the local main process.

        Args:
            step: Current training step; a preview runs only when it is a
                multiple of ``self.interval``.
            model: The wrapped diffusion model being trained.
            prefix: Unused here; kept for interface compatibility.
        """
        if step%self.interval != 0 or not self.trainer.is_local_main_process:
            return

        # record training layers
        training_layers = [layer for layer in model.modules() if layer.training]

        model.eval()
        self.trainer.loggers.info('Preview')

        # Snapshot the text-encoder hook settings; the workflow may change them.
        te_hook = model.text_enc_hook
        hook_settings = {
            'N_repeats': te_hook.N_repeats,
            'clip_skip': te_hook.clip_skip,
            'clip_final_norm': te_hook.clip_final_norm,
            'use_attention_mask': te_hook.use_attention_mask,
        }

        preview_root = Path(self.trainer.exp_dir)/'imgs'
        preview_root.mkdir(parents=True, exist_ok=True)

        states = {}
        try:
            states = self.workflow_runner.run(model=model, in_preview=True, te_hook=model.text_enc_hook,
                                              device=self.device, dtype=self.dtype, preview_root=preview_root, preview_step=step,
                                              world_size=self.trainer.world_size, local_rank=self.trainer.local_rank,
                                              emb_hook=self.trainer.cfgs.emb_pt.embedding_hook if self.trainer.pt_trainable else None)
        finally:
            # Restore even if the workflow fails. Otherwise training would continue
            # in eval mode with preview-only hook settings.
            self._restore_after_preview(model, states, hook_settings, training_layers)

    def _restore_after_preview(self, model, states, hook_settings, training_layers):
        """Undo the model changes made for previewing and re-enable training layers."""
        if model.vae is not None:
            model.vae.disable_tiling()
            model.vae.disable_slicing()
            remove_hook_from_module(model.vae, recurse=True)
            # Put back the original encode/decode if the workflow stored them.
            if 'vae_encode_raw' in states:
                model.vae.encode = states['vae_encode_raw']
                model.vae.decode = states['vae_decode_raw']

        if 'emb_hook' in states and not self.trainer.pt_trainable:
            states['emb_hook'].remove()

        N_repeats = hook_settings['N_repeats']
        if self.trainer.pt_trainable:
            self.trainer.cfgs.emb_pt.embedding_hook.N_repeats = N_repeats

        model.tokenizer.N_repeats = N_repeats
        # Written in the order N_repeats, clip_skip, clip_final_norm, use_attention_mask.
        for attr, value in hook_settings.items():
            setattr(model.text_enc_hook, attr, value)

        to_cuda(model)

        for layer in training_layers:
            layer.train()
hcpdiff/loss/__init__.py CHANGED
@@ -1 +1,4 @@
1
- from .min_snr_loss import MinSNRLoss, SoftMinSNRLoss, KDiffMinSNRLoss, EDMLoss
1
+ from .weighting import MinSNRWeight, SNRWeight, EDMWeight, LossWeight
2
+ from .ssim import SSIMLoss, MS_SSIMLoss
3
+ from .gw import GWLoss
4
+ from .base import DiffusionLossContainer
hcpdiff/loss/base.py ADDED
@@ -0,0 +1,41 @@
1
+ from rainbowneko.train.loss import LossContainer
2
+ from typing import Dict, Any
3
+ from torch import Tensor
4
+
5
class DiffusionLossContainer(LossContainer):
    """Loss container that compares a diffusion prediction in the wrapped loss's target space.

    The target space is ``loss.target_type`` (``'eps'``, ``'x0'`` or ``'velocity'``;
    ``'eps'`` when the attribute is missing). If the model's ``pred_type`` differs, the
    prediction is converted with the noise sampler's ``{pred_type}_to_{target_type}`` method.
    """

    def __init__(self, loss, weight=1.0, key_map=None):
        """
        :param loss: loss module; may define ``_key_map`` and ``target_type`` attributes.
        :param weight: loss multiplier. It is forwarded to ``LossContainer`` and also multiplied
            in ``forward`` (NOTE(review): confirm the base class does not apply it a second time).
        :param key_map: how ``pred``/``inputs`` entries map to loss arguments. Defaults to
            ``loss._key_map``, then to ``pred.model_pred`` and ``pred.target`` as positional args.
        """
        key_map = key_map or getattr(loss, '_key_map', None) or ('pred.model_pred -> 0', 'pred.target -> 1')
        super().__init__(loss, weight, key_map)
        self.target_type = getattr(loss, 'target_type', 'eps')

    def get_target(self, pred_type, model_pred, x_0, noise, x_t, sigma, noise_sampler, **kwargs):
        """Build the regression target and express ``model_pred`` in the same space.

        :param pred_type: space the model predicts in (e.g. ``'eps'``).
        :param model_pred: raw model output.
        :param x_0: clean latents.
        :param noise: noise added to ``x_0``.
        :param x_t: noisy latents.
        :param sigma: noise level(s) used to build ``x_t``.
        :param noise_sampler: object providing ``eps_to_velocity`` and ``<a>_to_<b>`` conversions.
        :return: ``(model_pred, target)``, both in ``self.target_type`` space.
        :raises ValueError: unknown ``target_type`` or no conversion from ``pred_type``.
        """
        if self.target_type == "eps":
            target = noise
        elif self.target_type == "x0":
            target = x_0
        elif self.target_type == "velocity":
            target = noise_sampler.eps_to_velocity(noise, x_t, sigma)
        else:
            raise ValueError(f"Unsupport target_type {self.target_type}")

        # TODO: move to the model wrapper -- drop predicted-variance channels when
        #  model_pred has twice as many channels as target (model_pred.chunk(2, dim=1)).

        if pred_type != self.target_type:
            cvt_func = getattr(noise_sampler, f'{pred_type}_to_{self.target_type}', None)
            if cvt_func is None:
                raise ValueError(f"Unsupport pred_type {pred_type} with target_type {self.target_type}")
            model_pred = cvt_func(model_pred, x_t, sigma)
        return model_pred, target

    def forward(self, pred: Dict[str, Any], inputs: Dict[str, Any]) -> Tensor:
        """Compute the weighted, mean-reduced loss.

        :param pred: model outputs; must contain the keyword arguments of ``get_target``.
        :param inputs: batch inputs, available to ``key_map``.
        :return: scalar loss tensor.
        """
        model_pred, target = self.get_target(**pred)
        # Use a shallow copy instead of writing into the caller's dict. Otherwise another container
        # sharing the same ``pred`` would see an already-converted ``model_pred`` and convert it again.
        pred = {**pred, 'model_pred': model_pred, 'target': target}
        loss = super().forward(pred, inputs)*self.weight  # [B,*,*,*]
        return loss.mean()
hcpdiff/loss/gw.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
class GWLoss(nn.Module):
    """Gradient-weighted L1 loss.

    The per-pixel absolute error ``|pred - target|`` is scaled by
    ``(1 + 4*|dx|) * (1 + 4*|dy|)``. Here ``dx``/``dy`` are the differences between the
    Sobel x/y responses of ``pred`` and ``target``. No reduction is applied.
    """

    def __init__(self):
        super().__init__()

        sobel_x = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
        sobel_y = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]
        # Register the kernels directly. Assigning them as plain attributes first makes
        # ``register_buffer`` raise ``KeyError: attribute 'sobel_x' already exists``.
        # Non-persistent: they are fixed constants and should not appear in state_dicts.
        self.register_buffer('sobel_x', torch.tensor(sobel_x, dtype=torch.float32), persistent=False)
        self.register_buffer('sobel_y', torch.tensor(sobel_y, dtype=torch.float32), persistent=False)

    def forward(self, pred, target):
        '''
        :param pred: [B,C,H,W]
        :param target: [B,C,H,W]
        :return: [B,C,H,W] un-reduced loss
        '''
        _, c, _, _ = pred.shape  # also enforces 4-D input

        # One depthwise 3x3 kernel per channel (groups=c), matched to the input's device and dtype.
        sobel_x = self.sobel_x.expand(c, 1, 3, 3).to(device=pred.device, dtype=pred.dtype)
        sobel_y = self.sobel_y.expand(c, 1, 3, 3).to(device=pred.device, dtype=pred.dtype)
        Ix1 = F.conv2d(pred, sobel_x, stride=1, padding=1, groups=c)
        Ix2 = F.conv2d(target, sobel_x, stride=1, padding=1, groups=c)
        Iy1 = F.conv2d(pred, sobel_y, stride=1, padding=1, groups=c)
        Iy2 = F.conv2d(target, sobel_y, stride=1, padding=1, groups=c)

        dx = torch.abs(Ix1 - Ix2)
        dy = torch.abs(Iy1 - Iy2)
        loss = (1 + 4 * dx) * (1 + 4 * dy) * torch.abs(pred - target)
        return loss