hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -3,14 +3,14 @@ import os.path
3
3
  from typing import List
4
4
  import math
5
5
 
6
- from hcpdiff.ckpt_manager import auto_manager
6
+ from rainbowneko.ckpt_manager import auto_ckpt_loader, NekoModelSaver
7
7
 
8
8
  class LoraConverter:
9
9
  com_name_unet = ['down_blocks', 'up_blocks', 'mid_block', 'transformer_blocks', 'to_q', 'to_k', 'to_v', 'to_out', 'proj_in', 'proj_out', 'input_blocks', 'middle_block', 'output_blocks']
10
10
  com_name_TE = ['self_attn', 'q_proj', 'v_proj', 'k_proj', 'out_proj', 'text_model']
11
11
  prefix_unet = 'lora_unet_'
12
12
  prefix_TE = 'lora_te_'
13
- prefix_TE_xl_clip_B = 'lora_te1_'
13
+ prefix_TE_xl_clip_L = 'lora_te1_'
14
14
  prefix_TE_xl_clip_bigG = 'lora_te2_'
15
15
 
16
16
  lora_w_map = {'lora_down.weight': 'W_down', 'lora_up.weight':'W_up'}
@@ -25,14 +25,14 @@ class LoraConverter:
25
25
  sd_TE = self.convert_from_webui_(state, prefix=self.prefix_TE, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
26
26
  else:
27
27
  sd_unet = self.convert_from_webui_xl_unet_(state, prefix=self.prefix_unet, com_name=self.com_name_unet, com_name_tmp=self.com_name_unet_tmp)
28
- sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_B, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
28
+ sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_L, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
29
29
  sd_TE2 = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_bigG, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
30
30
  sd_TE.update(sd_TE2)
31
31
 
32
32
  if auto_scale_alpha:
33
33
  sd_unet = self.alpha_scale_from_webui(sd_unet)
34
34
  sd_TE = self.alpha_scale_from_webui(sd_TE)
35
- return {'lora': sd_TE}, {'lora': sd_unet}
35
+ return {'plugin': sd_TE}, {'plugin': sd_unet}
36
36
 
37
37
  def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
38
38
  sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
@@ -72,7 +72,7 @@ class LoraConverter:
72
72
 
73
73
  sd_covert[f"{prefix}{model_k.replace('.', '_')}.{lora_k}"] = v
74
74
  return sd_covert
75
-
75
+
76
76
  def convert_to_webui_xl_(self, state, prefix):
77
77
  sd_convert = {}
78
78
  for k, v in state.items():
@@ -87,7 +87,7 @@ class LoraConverter:
87
87
 
88
88
  new_k = f"{prefix}{model_k.replace('.', '_')}.{lora_k}"
89
89
  if 'clip' in new_k:
90
- new_k = new_k.replace('_clip_B', '1') if 'clip_B' in new_k else new_k.replace('_clip_bigG', '2')
90
+ new_k = new_k.replace('_clip_L', '1') if 'clip_L' in new_k else new_k.replace('_clip_bigG', '2')
91
91
  sd_convert[new_k] = v
92
92
  return sd_convert
93
93
 
@@ -100,7 +100,7 @@ class LoraConverter:
100
100
  model_k, lora_k = k[prefix_len:].split('.', 1)
101
101
  model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
102
102
  if prefix == 'lora_te1_':
103
- model_k = f'clip_B.{model_k}'
103
+ model_k = f'clip_L.{model_k}'
104
104
  else:
105
105
  model_k = f'clip_bigG.{model_k}'
106
106
 
@@ -221,23 +221,27 @@ if __name__ == '__main__':
221
221
 
222
222
  # load lora model
223
223
  print('convert lora model')
224
- ckpt_manager = auto_manager(args.lora_path)
224
+ ckpt_loader = auto_ckpt_loader(args.lora_path)
225
+ ckpt_saver = NekoModelSaver(
226
+ format=ckpt_loader.format,
227
+ source=ckpt_loader.source,
228
+ )
225
229
 
226
230
  if args.from_webui:
227
- state = ckpt_manager.load_ckpt(args.lora_path)
231
+ state = ckpt_loader.load(args.lora_path)
228
232
  # convert the weight name
229
233
  sd_TE, sd_unet = converter.convert_from_webui(state, auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
230
234
  # weight save
231
235
  os.makedirs(args.dump_path, exist_ok=True)
232
236
  TE_path = os.path.join(args.dump_path, 'TE-'+lora_name)
233
237
  unet_path = os.path.join(args.dump_path, 'unet-'+lora_name)
234
- ckpt_manager._save_ckpt(sd_TE, save_path=TE_path)
235
- ckpt_manager._save_ckpt(sd_unet, save_path=unet_path)
238
+ ckpt_saver.save(sd_TE, TE_path)
239
+ ckpt_saver.save(sd_unet, unet_path)
236
240
  print('save text encoder lora to:', TE_path)
237
241
  print('save unet lora to:', unet_path)
238
242
  elif args.to_webui:
239
- sd_unet = ckpt_manager.load_ckpt(args.lora_path)
240
- sd_TE = ckpt_manager.load_ckpt(args.lora_path_TE)
241
- state = converter.convert_to_webui(sd_unet['lora'], sd_TE['lora'], auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
242
- ckpt_manager._save_ckpt(state, save_path=args.dump_path)
243
+ sd_unet = ckpt_loader.load(args.lora_path)
244
+ sd_TE = ckpt_loader.load(args.lora_path_TE) if args.lora_path_TE else {'base':{}}
245
+ state = converter.convert_to_webui(sd_unet['base'], sd_TE['base'], auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
246
+ ckpt_saver.save(state, args.dump_path)
243
247
  print('save lora to:', args.dump_path)
@@ -0,0 +1,12 @@
1
+ from diffusers import DiffusionPipeline
2
+ import argparse
3
+
4
+ parser = argparse.ArgumentParser()
5
+ parser.add_argument("model", default=None, type=str)
6
+ parser.add_argument("output", default=None, type=str)
7
+ args = parser.parse_args()
8
+
9
+ pipe = DiffusionPipeline.from_pretrained(args.model, safety_checker=None, requires_safety_checker=False,
10
+ resume_download=True)
11
+
12
+ pipe.save_pretrained(args.output)
@@ -211,7 +211,7 @@ def sd_vae_to_diffuser(args):
211
211
  def convert_ckpt(args):
212
212
  pipe = load_sd_ckpt(
213
213
  args.checkpoint_path,
214
- original_config_file=args.original_config_file,
214
+ config_files={'v1': args.original_config_file},
215
215
  image_size=args.image_size,
216
216
  prediction_type=args.prediction_type,
217
217
  model_type=args.pipeline_type,
hcpdiff/train_colo.py CHANGED
@@ -23,7 +23,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
23
23
  from colossalai.utils.model.colo_init_context import _convert_to_coloparam
24
24
  from colossalai.tensor import ColoParameter
25
25
 
26
- from hcpdiff.train_ac import Trainer, get_scheduler, ModelEMA
26
+ from hcpdiff.train_ac_old import Trainer, get_scheduler, ModelEMA
27
27
  from diffusers import UNet2DConditionModel
28
28
  from hcpdiff.utils.colo_utils import gemini_zero_dpp, GeminiAdamOptimizerP
29
29
  from hcpdiff.utils.utils import load_config_with_cli
@@ -7,7 +7,7 @@ from functools import partial
7
7
  import torch
8
8
 
9
9
  from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
10
- from hcpdiff.train_ac import Trainer, load_config_with_cli
10
+ from hcpdiff.train_ac_old import Trainer, load_config_with_cli
11
11
  from hcpdiff.utils.net_utils import get_scheduler
12
12
 
13
13
  class TrainerDeepSpeed(Trainer):
hcpdiff/trainer_ac.py ADDED
@@ -0,0 +1,79 @@
1
+ import argparse
2
+ import warnings
3
+
4
+ import torch
5
+ from rainbowneko.parser import load_config_with_cli
6
+ from rainbowneko.ckpt_manager import NekoSaver
7
+ from rainbowneko.train import Trainer
8
+ from rainbowneko.utils import xformers_available, is_dict
9
+ from hcpdiff.ckpt_manager import EmbFormat
10
+
11
+ class HCPTrainer(Trainer):
12
+ def config_model(self):
13
+ if self.cfgs.model.enable_xformers:
14
+ if xformers_available:
15
+ self.model_wrapper.enable_xformers()
16
+ else:
17
+ warnings.warn("xformers is not available. Make sure it is installed correctly")
18
+
19
+ if self.model_wrapper.vae is not None:
20
+ self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
21
+ self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
22
+
23
+ if self.cfgs.model.gradient_checkpointing:
24
+ self.model_wrapper.enable_gradient_checkpointing()
25
+
26
+ def get_param_group_train(self):
27
+ train_params = super().get_param_group_train()
28
+
29
+ # For prompt-tuning
30
+ if self.cfgs.emb_pt is None:
31
+ train_params_emb, self.train_pts = [], {}
32
+ else:
33
+ from hcpdiff.parser import CfgEmbPTParser
34
+ self.cfgs.emb_pt: CfgEmbPTParser
35
+
36
+ train_params_emb, self.train_pts = self.cfgs.emb_pt.get_params_group(self.model_wrapper)
37
+ self.emb_format = EmbFormat()
38
+ train_params += train_params_emb
39
+ return train_params
40
+
41
+ @property
42
+ def pt_trainable(self):
43
+ return self.cfgs.emb_pt is not None
44
+
45
+ def get_loss(self, ds_name, model_pred, inputs):
46
+ loss = super().get_loss(ds_name, model_pred, inputs)
47
+ # make DDP happy
48
+ if len(self.train_pts)>0:
49
+ loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
50
+ return loss
51
+
52
+ def save_model(self, from_raw=False):
53
+ NekoSaver.save_all(
54
+ self.model_raw,
55
+ plugin_groups={**self.all_plugin, 'embs': self.train_pts},
56
+ cfg=self.ckpt_saver,
57
+ model_ema=getattr(self, "ema_model", None),
58
+ name_template=f'{{}}-{self.real_step}',
59
+ )
60
+
61
+ self.loggers.info(f"Saved state, step: {self.real_step}")
62
+
63
+ def hcp_train():
64
+ import subprocess
65
+ parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
66
+ parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/multi.yaml')
67
+ args, train_args = parser.parse_known_args()
68
+
69
+ subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
70
+ "hcpdiff.trainer_ac"] + train_args, check=True)
71
+
72
+ if __name__ == "__main__":
73
+ parser = argparse.ArgumentParser(description="HCP Diffusion Trainer")
74
+ parser.add_argument("--cfg", type=str, default=None, required=True)
75
+ args, cfg_args = parser.parse_known_args()
76
+
77
+ parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
78
+ trainer = HCPTrainer(parser, conf)
79
+ trainer.train()
@@ -0,0 +1,31 @@
1
+ import argparse
2
+ import sys
3
+ from functools import partial
4
+
5
+ import torch
6
+ from accelerate import Accelerator
7
+ from loguru import logger
8
+
9
+ from rainbowneko.train.trainer import TrainerSingleCard
10
+ from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
11
+
12
+ class HCPTrainerSingleCard(TrainerSingleCard, HCPTrainer):
13
+ pass
14
+
15
+ def hcp_train():
16
+ import subprocess
17
+ parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
18
+ parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/single.yaml')
19
+ args, train_args = parser.parse_known_args()
20
+
21
+ subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
22
+ "hcpdiff.trainer_ac_single"] + train_args, check=True)
23
+
24
+ if __name__ == '__main__':
25
+ parser = argparse.ArgumentParser(description='HCP Diffusion Trainer')
26
+ parser.add_argument("--cfg", type=str, default=None, required=True)
27
+ args, cfg_args = parser.parse_known_args()
28
+
29
+ parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
30
+ trainer = HCPTrainerSingleCard(parser, conf)
31
+ trainer.train()
hcpdiff/utils/__init__.py CHANGED
@@ -1,4 +1,2 @@
1
1
  from .utils import *
2
- from .img_size_tool import get_image_size
3
- from .cfg_resolvers import *
4
2
  from .net_utils import *