hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/controlnet.py +16 -16
  58. hcpdiff/models/lora_base_patch.py +14 -25
  59. hcpdiff/models/lora_layers.py +3 -9
  60. hcpdiff/models/lora_layers_patch.py +14 -24
  61. hcpdiff/models/text_emb_ex.py +84 -6
  62. hcpdiff/models/textencoder_ex.py +54 -18
  63. hcpdiff/models/wrapper/__init__.py +3 -0
  64. hcpdiff/models/wrapper/pixart.py +19 -0
  65. hcpdiff/models/wrapper/sd.py +218 -0
  66. hcpdiff/models/wrapper/utils.py +20 -0
  67. hcpdiff/parser/__init__.py +1 -0
  68. hcpdiff/parser/embpt.py +32 -0
  69. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  70. hcpdiff/tools/dataset_generator.py +94 -0
  71. hcpdiff/tools/download_hf_model.py +24 -0
  72. hcpdiff/tools/init_proj.py +3 -21
  73. hcpdiff/tools/lora_convert.py +18 -17
  74. hcpdiff/tools/save_model.py +12 -0
  75. hcpdiff/tools/sd2diffusers.py +1 -1
  76. hcpdiff/train_colo.py +1 -1
  77. hcpdiff/train_deepspeed.py +1 -1
  78. hcpdiff/trainer_ac.py +79 -0
  79. hcpdiff/trainer_ac_single.py +31 -0
  80. hcpdiff/utils/__init__.py +0 -2
  81. hcpdiff/utils/inpaint_pipe.py +7 -2
  82. hcpdiff/utils/net_utils.py +29 -6
  83. hcpdiff/utils/pipe_hook.py +24 -7
  84. hcpdiff/utils/utils.py +21 -4
  85. hcpdiff/workflow/__init__.py +15 -10
  86. hcpdiff/workflow/daam/__init__.py +1 -0
  87. hcpdiff/workflow/daam/act.py +66 -0
  88. hcpdiff/workflow/daam/hook.py +109 -0
  89. hcpdiff/workflow/diffusion.py +114 -125
  90. hcpdiff/workflow/fast.py +31 -0
  91. hcpdiff/workflow/flow.py +67 -0
  92. hcpdiff/workflow/io.py +36 -130
  93. hcpdiff/workflow/model.py +46 -43
  94. hcpdiff/workflow/text.py +78 -46
  95. hcpdiff/workflow/utils.py +32 -12
  96. hcpdiff/workflow/vae.py +37 -38
  97. hcpdiff-2.1.dist-info/METADATA +285 -0
  98. hcpdiff-2.1.dist-info/RECORD +114 -0
  99. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  100. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  101. hcpdiff/ckpt_manager/base.py +0 -16
  102. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  103. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  104. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  105. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  106. hcpdiff/data/bucket.py +0 -358
  107. hcpdiff/data/caption_loader.py +0 -80
  108. hcpdiff/data/cond_dataset.py +0 -40
  109. hcpdiff/data/crop_info_dataset.py +0 -40
  110. hcpdiff/data/data_processor.py +0 -33
  111. hcpdiff/data/pair_dataset.py +0 -146
  112. hcpdiff/data/sampler.py +0 -54
  113. hcpdiff/data/source/base.py +0 -30
  114. hcpdiff/data/utils.py +0 -80
  115. hcpdiff/deprecated/__init__.py +0 -1
  116. hcpdiff/deprecated/cfg_converter.py +0 -81
  117. hcpdiff/deprecated/lora_convert.py +0 -31
  118. hcpdiff/infer_workflow.py +0 -57
  119. hcpdiff/loggers/__init__.py +0 -13
  120. hcpdiff/loggers/base_logger.py +0 -76
  121. hcpdiff/loggers/cli_logger.py +0 -40
  122. hcpdiff/loggers/preview/__init__.py +0 -1
  123. hcpdiff/loggers/preview/image_previewer.py +0 -149
  124. hcpdiff/loggers/tensorboard_logger.py +0 -30
  125. hcpdiff/loggers/wandb_logger.py +0 -31
  126. hcpdiff/loggers/webui_logger.py +0 -9
  127. hcpdiff/loss/min_snr_loss.py +0 -52
  128. hcpdiff/models/layers.py +0 -81
  129. hcpdiff/models/plugin.py +0 -348
  130. hcpdiff/models/wrapper.py +0 -75
  131. hcpdiff/noise/__init__.py +0 -3
  132. hcpdiff/noise/noise_base.py +0 -16
  133. hcpdiff/noise/pyramid_noise.py +0 -50
  134. hcpdiff/noise/zero_terminal.py +0 -44
  135. hcpdiff/train_ac.py +0 -566
  136. hcpdiff/train_ac_single.py +0 -39
  137. hcpdiff/utils/caption_tools.py +0 -105
  138. hcpdiff/utils/cfg_net_tools.py +0 -321
  139. hcpdiff/utils/cfg_resolvers.py +0 -16
  140. hcpdiff/utils/ema.py +0 -52
  141. hcpdiff/utils/img_size_tool.py +0 -248
  142. hcpdiff/vis/__init__.py +0 -3
  143. hcpdiff/vis/base_interface.py +0 -12
  144. hcpdiff/vis/disk_interface.py +0 -48
  145. hcpdiff/vis/webui_interface.py +0 -17
  146. hcpdiff/viser_fast.py +0 -138
  147. hcpdiff/visualizer.py +0 -265
  148. hcpdiff/visualizer_reloadable.py +0 -237
  149. hcpdiff/workflow/base.py +0 -59
  150. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  198. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  206. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  207. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  208. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  209. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -8,29 +8,53 @@ textencoder_ex.py
     :Licence: Apache-2.0
 """
 
-from typing import Tuple, Optional, List
+from typing import Tuple, Optional
 
 import torch
 from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
+from loguru import logger
 from torch import nn
+from transformers import CLIPTextModelWithProjection, T5EncoderModel
 from transformers.models.clip.modeling_clip import CLIPAttention
 
 class TEEXHook:
-    def __init__(self, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, device='cuda', use_attention_mask=False):
+    def __init__(self, text_enc: nn.Module, tokenizer, N_repeats=1, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
         self.text_enc = text_enc
         self.tokenizer = tokenizer
 
         self.N_repeats = N_repeats
         self.clip_skip = clip_skip
         self.clip_final_norm = clip_final_norm
-        self.device = device
-        self.attn_mult = None
         self.use_attention_mask = use_attention_mask
 
         text_enc.register_forward_hook(self.forward_hook)
         text_enc.register_forward_pre_hook(self.forward_hook_input)
 
+    def find_final_norm(self, text_enc: nn.Module):
+        for module in text_enc.modules():
+            if 'final_layer_norm' in module._modules:
+                logger.info(f'find final_layer_norm in {type(module)}')
+                return module.final_layer_norm
+
+        logger.info(f'final_layer_norm not found in {type(text_enc)}')
+        return None
+
+    @property
+    def clip_final_norm(self):
+        return self.final_layer_norm is not None
+
+    @clip_final_norm.setter
+    def clip_final_norm(self, value: bool):
+        if value:
+            self.final_layer_norm = self.find_final_norm(self.text_enc)
+        else:
+            self.final_layer_norm = None
+
+    @property
+    def device(self):
+        return self.text_enc.device
+
     def encode_prompt_to_emb(self, prompt):
         text_inputs = self.tokenizer(
             prompt,
@@ -50,12 +74,23 @@ class TEEXHook:
         if position_ids is not None:
             position_ids = position_ids.to(self.device)
 
-        prompt_embeds, pooled_output = self.text_enc(
-            text_input_ids.to(self.device),
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            output_hidden_states=True,
-        )
+        # align with sd-webui
+        if isinstance(self.text_enc, CLIPTextModelWithProjection):
+            self.text_enc.text_projection.weight.data = self.text_enc.text_projection.weight.data.t()
+
+        if isinstance(self.text_enc, T5EncoderModel):
+            prompt_embeds, pooled_output = self.text_enc(
+                text_input_ids.to(self.device),
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+            )
+        else:
+            prompt_embeds, pooled_output = self.text_enc(
+                text_input_ids.to(self.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                output_hidden_states=True,
+            )
         return prompt_embeds, pooled_output, attention_mask
 
     def forward_hook_input(self, host, feat_in):
@@ -64,13 +99,12 @@ class TEEXHook:
 
     def forward_hook(self, host, feat_in: Tuple[torch.Tensor], feat_out):
         encoder_hidden_states = feat_out['hidden_states'][-self.clip_skip-1]
-        if self.clip_final_norm:
-            encoder_hidden_states = self.text_enc.text_model.final_layer_norm(encoder_hidden_states)
+        if self.clip_final_norm and self.final_layer_norm is not None:
+            encoder_hidden_states = self.final_layer_norm(encoder_hidden_states)
         if self.text_enc.training and self.clip_skip>0:
             encoder_hidden_states = encoder_hidden_states+0*feat_out['last_hidden_state'].mean()  # avoid unused parameters, make gradient checkpointing happy
-
         encoder_hidden_states = rearrange(encoder_hidden_states, '(b r) ... -> b r ...', r=self.N_repeats)  # [B, N_repeat, N_word+2, N_emb]
-        pooled_output = feat_out.pooler_output
+        pooled_output = feat_out.get('pooler_output', feat_out.get('text_embeds', None))
         # TODO: may have better fusion method
         if pooled_output is not None:
             pooled_output = rearrange(pooled_output, '(b r) ... -> b r ...', r=self.N_repeats).mean(dim=1)
@@ -81,7 +115,7 @@ class TEEXHook:
         return encoder_hidden_states, pooled_output
 
     def pool_hidden_states(self, encoder_hidden_states, input_ids):
-        pooled_output = encoder_hidden_states[:, :, -1, :].mean(dim=1) # [B, N_emb]
+        pooled_output = encoder_hidden_states[:, :, -1, :].mean(dim=1)  # [B, N_emb]
         return pooled_output
 
     @staticmethod
@@ -147,9 +181,11 @@ class TEEXHook:
         layer.forward = forward
 
     @classmethod
-    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, device='cuda', use_attention_mask=False):
-        return cls(text_enc, tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm, device=device, use_attention_mask=use_attention_mask)
+    def hook(cls, text_enc: nn.Module, tokenizer, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
+        return cls(text_enc, tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                   use_attention_mask=use_attention_mask)
 
     @classmethod
     def hook_pipe(cls, pipe, N_repeats=3, clip_skip=0, clip_final_norm=True, use_attention_mask=False):
-        return cls(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, device='cuda', clip_skip=clip_skip, clip_final_norm=clip_final_norm, use_attention_mask=use_attention_mask)
+        return cls(pipe.text_encoder, pipe.tokenizer, N_repeats=N_repeats, clip_skip=clip_skip, clip_final_norm=clip_final_norm,
+                   use_attention_mask=use_attention_mask)
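Taken together, these hunks move device handling onto the text encoder itself and make the final LayerNorm lookup model-agnostic. A minimal usage sketch, not part of the diff, assuming a standard diffusers SD1.5 pipeline:

    from diffusers import StableDiffusionPipeline
    from hcpdiff.models import TEEXHook

    pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
    # device is no longer passed; the new `device` property reads it off the text encoder
    te_hook = TEEXHook.hook_pipe(pipe, N_repeats=1, clip_skip=0, clip_final_norm=True)
    prompt_embeds, pooled_output, attention_mask = te_hook.encode_prompt_to_emb('a photo of a cat')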
@@ -0,0 +1,3 @@ hcpdiff/models/wrapper/__init__.py
+from .sd import SD15Wrapper, SDXLWrapper
+from .pixart import PixArtWrapper
+from .utils import TEHookCFG, SD15_TEHookCFG, SDXL_TEHookCFG
@@ -0,0 +1,19 @@ hcpdiff/models/wrapper/pixart.py
+from .sd import SD15Wrapper
+from hcpdiff.utils import pad_attn_bias
+
+class PixArtWrapper(SD15Wrapper):
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, resolution=None, aspect_ratio=None,
+                         plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        added_cond_kwargs = {"resolution":resolution, "aspect_ratio":aspect_ratio}
+        model_pred = self.denoiser(x_t, encoder_hidden_states, timesteps, encoder_attention_mask=attn_mask,
+                                   added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
+        return model_pred
@@ -0,0 +1,218 @@ hcpdiff/models/wrapper/sd.py
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, Union
+
+import torch
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from rainbowneko.models.wrapper import BaseWrapper
+from torch import Tensor
+from torch import nn
+
+from hcpdiff.diffusion.sampler import BaseSampler
+from hcpdiff.models import TEEXHook
+from hcpdiff.models.compose import ComposeTEEXHook
+from hcpdiff.utils import pad_attn_bias
+from .utils import TEHookCFG, SD15_TEHookCFG, SDXL_TEHookCFG
+from ..cfg_context import CFGContext
+
+class SD15Wrapper(BaseWrapper):
+    def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 pred_type='eps', TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__()
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input'))
+        self.key_mapper_out = self.build_mapper(key_map_out, None, None)
+
+        self.denoiser = denoiser
+        self.TE = TE
+        self.vae = vae
+        self.noise_sampler = noise_sampler
+        self.tokenizer = tokenizer
+        self.min_attnmask = min_attnmask
+
+        self.pred_type = pred_type
+
+        self.TE_hook_cfg = TEHookCFG.create(TE_hook_cfg)
+        self.cfg_context = cfg_context
+        self.tokenizer.N_repeats = self.TE_hook_cfg.tokenizer_repeats
+
+    def post_init(self):
+        self.make_TE_hook(self.TE_hook_cfg)
+
+        self.vae_trainable = False
+        if self.vae is not None:
+            for p in self.vae.parameters():
+                if p.requires_grad:
+                    self.vae_trainable = True
+                    break
+
+        self.TE_trainable = False
+        for p in self.TE.parameters():
+            if p.requires_grad:
+                self.TE_trainable = True
+                break
+
+    def make_TE_hook(self, TE_hook_cfg):
+        # Hook and extend text_encoder
+        self.text_enc_hook = TEEXHook.hook(self.TE, self.tokenizer, N_repeats=TE_hook_cfg.tokenizer_repeats,
+                                           clip_skip=TE_hook_cfg.clip_skip, clip_final_norm=TE_hook_cfg.clip_final_norm)
+
+    def get_latents(self, image: Tensor):
+        if image.shape[1] == 3:
+            with torch.no_grad() if self.vae_trainable else nullcontext():
+                latents = self.vae.encode(image.to(dtype=self.vae.dtype)).latent_dist.sample()
+                latents = latents*self.vae.config.scaling_factor
+        else:
+            latents = image  # Cached latents
+        return latents
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)[0]
+        return encoder_hidden_states
+
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample  # Predict the noise residual
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
+                      plugin_input={}, **kwargs):
+        # input prepare
+        x_0 = self.get_latents(image)
+        x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+
+        if neg_prompt_ids:
+            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+        if neg_attn_mask:
+            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+        if neg_position_ids:
+            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+
+        # model forward
+        x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
+        encoder_hidden_states = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                                plugin_input=plugin_input, **kwargs)
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                           plugin_input=plugin_input, **kwargs)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
+                    noise_sampler=self.noise_sampler)
+
+    def forward(self, ds_name=None, **kwargs):
+        model_args, model_kwargs = self.get_map_data(self.key_mapper_in, kwargs, ds_name)
+        out = self.model_forward(*model_args, **model_kwargs)
+        return self.get_map_data(self.key_mapper_out, out, ds_name=ds_name)[1]
+
+    def enable_gradient_checkpointing(self):
+        def grad_ckpt_enable(m):
+            if getattr(m, 'gradient_checkpointing', False):
+                m.training = True
+
+        self.denoiser.enable_gradient_checkpointing()
+        if self.TE_trainable:
+            self.TE.gradient_checkpointing_enable()
+        self.apply(grad_ckpt_enable)
+
+    def enable_xformers(self):
+        self.denoiser.enable_xformers_memory_efficient_attention()
+
+    @property
+    def trainable_parameters(self):
+        return [p for p in self.parameters() if p.requires_grad]
+
+    @property
+    def trainable_models(self) -> Dict[str, nn.Module]:
+        return {'self':self}
+
+    def set_dtype(self, dtype, vae_dtype):
+        self.dtype = dtype
+        self.vae_dtype = vae_dtype
+        # Move vae and text_encoder to device and cast to weight_dtype
+        if self.vae is not None:
+            self.vae = self.vae.to(dtype=vae_dtype)
+        if not self.TE_trainable:
+            self.TE = self.TE.to(dtype=dtype)
+
+    @classmethod
+    def from_pretrained(cls, models: Union[partial, Dict[str, nn.Module]], **kwargs):
+        models = models() if isinstance(models, partial) else models
+        return cls(models['denoiser'], models['TE'], models['vae'], models['noise_sampler'], models['tokenizer'], **kwargs)
+
+class SDXLWrapper(SD15Wrapper):
+    def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
+                 pred_type='eps', TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, pred_type, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
+        self.key_mapper_in = self.build_mapper(key_map_in, None, (
+            'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
+            'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input', 'coord -> crop_info'))
+
+    def make_TE_hook(self, TE_hook_cfg):
+        # Hook and extend text_encoder
+        self.text_enc_hook = ComposeTEEXHook.hook(self.TE, self.tokenizer, N_repeats=TE_hook_cfg.tokenizer_repeats,
+                                                  clip_skip=TE_hook_cfg.clip_skip, clip_final_norm=TE_hook_cfg.clip_final_norm)
+
+    def forward_TE(self, prompt_ids, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
+        if hasattr(self.TE, 'input_feeder'):
+            for feeder in self.TE.input_feeder:
+                feeder(input_all)
+        # Get the text embedding for conditioning
+        encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)
+        return encoder_hidden_states, pooled_output
+
+    def forward_denoiser(self, x_t, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs, attn_mask=None, position_ids=None,
+                         plugin_input={}, **kwargs):
+        if attn_mask is not None:
+            attn_mask[:, :self.min_attnmask] = 1
+            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
+
+        input_all = dict(prompt_ids=prompt_ids, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask,
+                         encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, **plugin_input)
+        if hasattr(self.denoiser, 'input_feeder'):
+            for feeder in self.denoiser.input_feeder:
+                feeder(input_all)
+        model_pred = self.denoiser(x_t, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask,
+                                   added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
+        return model_pred
+
+    def model_forward(self, prompt_ids, image, attn_mask=None, position_ids=None, neg_prompt_ids=None, neg_attn_mask=None, neg_position_ids=None,
+                      crop_info=None, plugin_input={}):
+        # input prepare
+        x_0 = self.get_latents(image)
+        x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
+
+        if neg_prompt_ids:
+            prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
+        if neg_attn_mask:
+            attn_mask = torch.cat([neg_attn_mask, attn_mask], dim=0)
+        if neg_position_ids:
+            position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
+
+        # model forward
+        x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
+        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
+                                                               plugin_input=plugin_input)
+        added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs=added_cond_kwargs,
+                                           attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
+        model_pred = self.cfg_context.post(model_pred)
+
+        return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
+                    noise_sampler=self.noise_sampler)
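For orientation, a sketch of how SD15Wrapper is typically assembled and called. Component loading is shown with standard diffusers/transformers loaders as an assumption; `noise_sampler`, `prompt_ids`, and `images` are placeholders, since the sampler constructor and data pipeline live outside this hunk:

    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer
    from hcpdiff.models.wrapper import SD15Wrapper

    repo = 'runwayml/stable-diffusion-v1-5'  # example checkpoint
    models = dict(
        denoiser=UNet2DConditionModel.from_pretrained(repo, subfolder='unet'),
        TE=CLIPTextModel.from_pretrained(repo, subfolder='text_encoder'),
        vae=AutoencoderKL.from_pretrained(repo, subfolder='vae'),
        tokenizer=CLIPTokenizer.from_pretrained(repo, subfolder='tokenizer'),
        noise_sampler=noise_sampler,  # placeholder: any hcpdiff.diffusion.sampler.BaseSampler
    )
    wrapper = SD15Wrapper.from_pretrained(models, pred_type='eps')
    wrapper.post_init()  # installs the TE hook and probes which parts are trainable
    out = wrapper(prompt=prompt_ids, image=images)  # dict with model_pred, noise, sigma, timesteps, ...

Note that forward() routes batch keys through key_mapper_in ('prompt -> prompt_ids', 'image -> image', ...), so callers pass dataset-style keys rather than the model_forward argument names.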
@@ -0,0 +1,20 @@ hcpdiff/models/wrapper/utils.py
+from dataclasses import dataclass
+from rainbowneko.utils import is_dict
+
+class TEHookCFG:
+    def __init__(self, tokenizer_repeats: int = 1, clip_skip: int = 0, clip_final_norm: bool = True):
+        self.tokenizer_repeats = tokenizer_repeats
+        self.clip_skip = clip_skip
+        self.clip_final_norm = clip_final_norm
+
+    @classmethod
+    def create(cls, cfg):
+        if is_dict(cfg):
+            return cls(**cfg)
+        elif isinstance(cfg, cls):
+            return cfg
+        else:
+            raise ValueError(f'Invalid TEHookCFG type: {type(cfg)}')
+
+SD15_TEHookCFG = TEHookCFG()
+SDXL_TEHookCFG = TEHookCFG(clip_skip=1, clip_final_norm=False)
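TEHookCFG.create normalizes either a mapping or an existing instance, so wrapper configs can be written as plain dicts. A small sketch, assuming only the code shown above:

    cfg = TEHookCFG.create({'tokenizer_repeats': 2, 'clip_skip': 1})
    assert cfg.clip_skip == 1 and cfg.clip_final_norm
    cfg = TEHookCFG.create(SDXL_TEHookCFG)  # instances pass through unchanged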
@@ -0,0 +1 @@ hcpdiff/parser/__init__.py
+from .embpt import CfgEmbPTParser
@@ -0,0 +1,32 @@ hcpdiff/parser/embpt.py
+from typing import Dict, Tuple, List
+from rainbowneko.utils import Path_Like
+from hcpdiff.models import EmbeddingPTHook
+from torch import Tensor
+
+class CfgEmbPTParser:
+    def __init__(self, emb_dir: Path_Like, cfg_pt: Dict[str, Dict], lr: float = 1e-5, weight_decay: float = 0):
+        self.emb_dir = emb_dir
+        self.cfg_pt = cfg_pt
+        self.lr = lr
+        self.weight_decay = weight_decay
+
+    def get_params_group(self, model) -> Tuple[List, Dict[str, Tensor]]:
+        self.embedding_hook, self.ex_words_emb = EmbeddingPTHook.hook_from_dir(
+            self.emb_dir, model.tokenizer, model.TE, N_repeats=model.tokenizer.N_repeats)
+        self.embedding_hook.requires_grad_(False)
+
+        train_params_emb = []
+        train_pts = {}
+        for pt_name, info in self.cfg_pt.items():
+            word_emb = self.ex_words_emb[pt_name]
+            train_pts[pt_name] = word_emb
+            word_emb.requires_grad = True
+            self.embedding_hook.emb_train.append(word_emb)
+            param_group = {'params':word_emb}
+            if 'lr' in info:
+                param_group['lr'] = info.lr
+            if 'weight_decay' in info:
+                param_group['weight_decay'] = info.weight_decay
+            train_params_emb.append(param_group)
+
+        return train_params_emb, train_pts
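Note that get_params_group reads per-embedding overrides via attribute access (info.lr), so cfg_pt entries are expected to be attribute-style mappings such as OmegaConf nodes rather than plain dicts. A hedged sketch; the embedding name 'my-char' is hypothetical, and `wrapper` stands for an SD15Wrapper-like model exposing .tokenizer and .TE:

    from omegaconf import OmegaConf
    from hcpdiff.parser import CfgEmbPTParser

    cfg_pt = OmegaConf.create({'my-char': {'lr': 3e-4}})  # per-embedding lr override
    parser = CfgEmbPTParser(emb_dir='embs/', cfg_pt=cfg_pt, lr=1e-5)
    params_group, train_pts = parser.get_params_group(wrapper)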
@@ -2,7 +2,7 @@ hcpdiff/tools/convert_caption_txt2json.py
 import json
 import os
 
-from hcpdiff.utils.img_size_tool import types_support
+from rainbowneko.utils import types_support
 
 parser = argparse.ArgumentParser(description='Stable Diffusion Training')
 parser.add_argument('--data_root', type=str, default='')
@@ -0,0 +1,94 @@ hcpdiff/tools/dataset_generator.py
+import argparse
+import json
+import os.path
+from typing import Callable
+
+import pyarrow.parquet as pq
+import torch
+from PIL import Image
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from tqdm.auto import tqdm
+
+from hcpdiff.data.caption_loader import auto_caption_loader
+
+class DatasetCreator:
+    def __init__(self, pretrained_model, out_dir: str, img_w: int=512, img_h: int=512):
+        scheduler = DPMSolverMultistepScheduler(
+            beta_start = 0.00085,
+            beta_end = 0.012,
+            beta_schedule = 'scaled_linear',
+            algorithm_type = 'dpmsolver++',
+            use_karras_sigmas = True,
+        )
+
+        self.pipeline = DiffusionPipeline.from_pretrained(pretrained_model, scheduler=scheduler, torch_dtype=torch.float16)
+        self.pipeline.requires_safety_checker = False
+        self.pipeline.safety_checker = None
+        self.pipeline.to("cuda")
+        self.pipeline.unet.to(memory_format=torch.channels_last)
+        #self.pipeline.enable_xformers_memory_efficient_attention()
+
+        self.out_dir = out_dir
+        self.img_w = img_w
+        self.img_h = img_h
+
+    def create_from_prompt_dataset(self, prompt_file: str, negative_prompt: str, bs: int, num: int=None, repeat:int=1, save_fmt:str='txt',
+                                   callback: Callable[[int, int], bool] = None):
+        os.makedirs(self.out_dir, exist_ok=True)
+        data = auto_caption_loader(prompt_file).load()
+        data = list(data.items())
+        data = self.split_batch(data, bs)  # [[(k,v),...],...]
+
+        if num is None:
+            num = len(data)
+        total = num*bs
+        count = 0
+        captions = {}
+        with torch.inference_mode():
+            for i in tqdm(range(num)):
+                for r in range(repeat):
+                    name_batch, p_batch = list(zip(*data[i%len(data)]))
+                    imgs = self.pipeline(list(p_batch), negative_prompt=[negative_prompt]*len(p_batch), num_inference_steps=25,
+                                         width=self.img_w, height=self.img_h).images
+                    for name, prompt, img in zip(name_batch, p_batch, imgs):
+                        img.save(os.path.join(self.out_dir, f'{count}_{name}.png'), format='PNG')
+                        captions[f'{count}_{name}'] = prompt
+                        count += 1
+                    if callback:
+                        if not callback(count, total):
+                            break
+
+        if save_fmt=='txt':
+            for k, v in captions.items():
+                with open(os.path.join(self.out_dir, f'{k}.txt'), "w") as f:
+                    f.write(v)
+        elif save_fmt=='json':
+            with open(os.path.join(self.out_dir, f'image_captions.json'), "w") as f:
+                json.dump(captions, f)
+        else:
+            raise ValueError(f"Invalid save_fmt: {save_fmt}")
+
+    @staticmethod
+    def split_batch(data, bs):
+        return [data[i:i+bs] for i in range(0, len(data), bs)]
+
+# python dataset_generator.py --prompt_file <caption file or directory> --model <model name> --out_dir <output directory> --repeat <images per prompt> --bs <batch size> --img_w <image width> --img_h <image height>
+# python dataset_generator.py --prompt_file <caption file or directory> --model <model name> --out_dir <output directory> --repeat 1 --bs 4 --img_w 640 --img_h 640
+if __name__ == '__main__':
+    torch.backends.cudnn.benchmark = True
+    parser = argparse.ArgumentParser(description='Diffusion Dataset Generator')
+    parser.add_argument('--prompt_file', type=str, default='')
+    parser.add_argument('--model', type=str, default='runwayml/stable-diffusion-v1-5')
+    parser.add_argument('--out_dir', type=str, default=r'./prompt_ds')
+    parser.add_argument('--negative_prompt', type=str,
+                        default='lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry')
+    parser.add_argument('--num', type=int, default=200)
+    parser.add_argument('--repeat', type=int, default=1)
+    parser.add_argument('--save_fmt', type=str, default='txt')
+    parser.add_argument('--bs', type=int, default=4)
+    parser.add_argument('--img_w', type=int, default=512)
+    parser.add_argument('--img_h', type=int, default=512)
+    args = parser.parse_args()
+
+    ds_creator = DatasetCreator(args.model, args.out_dir, args.img_w, args.img_h)
+    ds_creator.create_from_prompt_dataset(args.prompt_file, args.negative_prompt, args.bs, args.num, repeat=args.repeat, save_fmt=args.save_fmt)
@@ -0,0 +1,24 @@ hcpdiff/tools/download_hf_model.py
+from diffusers import DiffusionPipeline
+import argparse
+import torch
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Download Model')
+    parser.add_argument('--model', type=str, default='runwayml/stable-diffusion-v1-5')
+    parser.add_argument("--fp16", default=False, action="store_true")
+    parser.add_argument("--use_safetensors", default=False, action="store_true")
+    parser.add_argument("--out_path", type=str, default='ckpts/sd15')
+    args = parser.parse_args()
+
+    load_args = dict(torch_dtype = torch.float16 if args.fp16 else torch.float32)
+    save_args = dict()
+
+    if args.fp16:
+        load_args['variant'] = "fp16"
+        save_args['variant'] = "fp16"
+    if args.use_safetensors:
+        load_args['use_safetensors'] = True
+        save_args['safe_serialization'] = True
+
+    pipe = DiffusionPipeline.from_pretrained(args.model, **load_args)
+    pipe.save_pretrained(args.out_path, **save_args)
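Example invocation, using only the flags defined above:

    python hcpdiff/tools/download_hf_model.py --model runwayml/stable-diffusion-v1-5 --fp16 --use_safetensors --out_path ckpts/sd15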
@@ -1,23 +1,5 @@ hcpdiff/tools/init_proj.py
-import sys
-import shutil
-import os
+from rainbowneko.tools.init_proj import copy_package_data
 
 def main():
-    prefix = sys.prefix
-    if not os.path.exists(os.path.join(prefix, 'hcpdiff')):
-        prefix = os.path.join(prefix, 'local')
-    try:
-        if os.path.exists(r'./cfgs'):
-            shutil.rmtree(r'./cfgs')
-        if os.path.exists(r'./prompt_tuning_template'):
-            shutil.rmtree(r'./prompt_tuning_template')
-        shutil.copytree(os.path.join(prefix, 'hcpdiff/cfgs'), r'./cfgs')
-        shutil.copytree(os.path.join(prefix, 'hcpdiff/prompt_tuning_template'), r'./prompt_tuning_template')
-    except:
-        try:
-            shutil.copytree(os.path.join(prefix, '../hcpdiff/cfgs'), r'./cfgs')
-            shutil.copytree(os.path.join(prefix, '../hcpdiff/prompt_tuning_template'), r'./prompt_tuning_template')
-        except:
-            this_file_dir = os.path.dirname(os.path.abspath(__file__))
-            shutil.copytree(os.path.join(this_file_dir, '../../cfgs'), r'./cfgs')
-            shutil.copytree(os.path.join(this_file_dir, '../../prompt_tuning_template'), r'./prompt_tuning_template')
+    copy_package_data('hcpdiff', 'cfgs', './cfgs')
+    copy_package_data('hcpdiff', 'prompt_template', './prompt_template')