hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/tools/dataset_generator.py ADDED
@@ -0,0 +1,94 @@
+ import argparse
+ import json
+ import os.path
+ from typing import Callable
+
+ import pyarrow.parquet as pq
+ import torch
+ from PIL import Image
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+ from tqdm.auto import tqdm
+
+ from hcpdiff.data.caption_loader import auto_caption_loader
+
+ class DatasetCreator:
+     def __init__(self, pretrained_model, out_dir: str, img_w: int=512, img_h: int=512):
+         scheduler = DPMSolverMultistepScheduler(
+             beta_start = 0.00085,
+             beta_end = 0.012,
+             beta_schedule = 'scaled_linear',
+             algorithm_type = 'dpmsolver++',
+             use_karras_sigmas = True,
+         )
+
+         self.pipeline = DiffusionPipeline.from_pretrained(pretrained_model, scheduler=scheduler, torch_dtype=torch.float16)
+         self.pipeline.requires_safety_checker = False
+         self.pipeline.safety_checker = None
+         self.pipeline.to("cuda")
+         self.pipeline.unet.to(memory_format=torch.channels_last)
+         #self.pipeline.enable_xformers_memory_efficient_attention()
+
+         self.out_dir = out_dir
+         self.img_w = img_w
+         self.img_h = img_h
+
+     def create_from_prompt_dataset(self, prompt_file: str, negative_prompt: str, bs: int, num: int=None, repeat:int=1, save_fmt:str='txt',
+                                    callback: Callable[[int, int], bool] = None):
+         os.makedirs(self.out_dir, exist_ok=True)
+         data = auto_caption_loader(prompt_file).load()
+         data = list(data.items())
+         data = self.split_batch(data, bs)  # [[(k,v),...],...]
+
+         if num is None:
+             num = len(data)
+         total = num*bs
+         count = 0
+         captions = {}
+         with torch.inference_mode():
+             for i in tqdm(range(num)):
+                 for r in range(repeat):
+                     name_batch, p_batch = list(zip(*data[i%len(data)]))
+                     imgs = self.pipeline(list(p_batch), negative_prompt=[negative_prompt]*len(p_batch), num_inference_steps=25,
+                                          width=self.img_w, height=self.img_h).images
+                     for name, prompt, img in zip(name_batch, p_batch, imgs):
+                         img.save(os.path.join(self.out_dir, f'{count}_{name}.png'), format='PNG')
+                         captions[f'{count}_{name}'] = prompt
+                         count += 1
+                     if callback:
+                         if not callback(count, total):
+                             break
+
+         if save_fmt=='txt':
+             for k, v in captions.items():
+                 with open(os.path.join(self.out_dir, f'{k}.txt'), "w") as f:
+                     f.write(v)
+         elif save_fmt=='json':
+             with open(os.path.join(self.out_dir, f'image_captions.json'), "w") as f:
+                 json.dump(captions, f)
+         else:
+             raise ValueError(f"Invalid save_fmt: {save_fmt}")
+
+     @staticmethod
+     def split_batch(data, bs):
+         return [data[i:i+bs] for i in range(0, len(data), bs)]
+
+ # python dataset_generator.py --prompt_file <caption file or folder path> --model <model name> --out_dir <output folder path> --repeat <images per prompt> --bs batch_size --img_w <image width> --img_h <image height>
+ # python dataset_generator.py --prompt_file <caption file or folder path> --model <model name> --out_dir <output folder path> --repeat 1 --bs 4 --img_w 640 --img_h 640
+ if __name__ == '__main__':
+     torch.backends.cudnn.benchmark = True
+     parser = argparse.ArgumentParser(description='Diffusion Dataset Generator')
+     parser.add_argument('--prompt_file', type=str, default='')
+     parser.add_argument('--model', type=str, default='runwayml/stable-diffusion-v1-5')
+     parser.add_argument('--out_dir', type=str, default=r'./prompt_ds')
+     parser.add_argument('--negative_prompt', type=str,
+                         default='lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry')
+     parser.add_argument('--num', type=int, default=200)
+     parser.add_argument('--repeat', type=int, default=1)
+     parser.add_argument('--save_fmt', type=str, default='txt')
+     parser.add_argument('--bs', type=int, default=4)
+     parser.add_argument('--img_w', type=int, default=512)
+     parser.add_argument('--img_h', type=int, default=512)
+     args = parser.parse_args()
+
+     ds_creator = DatasetCreator(args.model, args.out_dir, args.img_w, args.img_h)
+     ds_creator.create_from_prompt_dataset(args.prompt_file, args.negative_prompt, args.bs, args.num, repeat=args.repeat, save_fmt=args.save_fmt)
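Note: the new generator can also be driven from Python rather than the CLI. A minimal sketch based on the signatures above; the model id, caption path, and output folder are illustrative assumptions, not values from the diff:

    from hcpdiff.tools.dataset_generator import DatasetCreator

    # Build the pipeline once; images are generated at 640x640 here.
    creator = DatasetCreator('runwayml/stable-diffusion-v1-5', out_dir='./prompt_ds', img_w=640, img_h=640)
    # prompt_file may be any file or folder accepted by auto_caption_loader.
    creator.create_from_prompt_dataset(
        prompt_file='./captions',
        negative_prompt='lowres, bad anatomy, blurry',
        bs=4, num=50, repeat=1, save_fmt='json',
    )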
hcpdiff/tools/download_hf_model.py ADDED
@@ -0,0 +1,24 @@
+ from diffusers import DiffusionPipeline
+ import argparse
+ import torch
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Download Model')
+     parser.add_argument('--model', type=str, default='runwayml/stable-diffusion-v1-5')
+     parser.add_argument("--fp16", default=False, action="store_true")
+     parser.add_argument("--use_safetensors", default=False, action="store_true")
+     parser.add_argument("--out_path", type=str, default='ckpts/sd15')
+     args = parser.parse_args()
+
+     load_args = dict(torch_dtype = torch.float16 if args.fp16 else torch.float32)
+     save_args = dict()
+
+     if args.fp16:
+         load_args['variant'] = "fp16"
+         save_args['variant'] = "fp16"
+     if args.use_safetensors:
+         load_args['use_safetensors'] = True
+         save_args['safe_serialization'] = True
+
+     pipe = DiffusionPipeline.from_pretrained(args.model, **load_args)
+     pipe.save_pretrained(args.out_path, **save_args)
hcpdiff/tools/init_proj.py CHANGED
@@ -1,23 +1,5 @@
- import sys
- import shutil
- import os
+ from rainbowneko.tools.init_proj import copy_package_data
 
  def main():
-     prefix = sys.prefix
-     if not os.path.exists(os.path.join(prefix, 'hcpdiff')):
-         prefix = os.path.join(prefix, 'local')
-     try:
-         if os.path.exists(r'./cfgs'):
-             shutil.rmtree(r'./cfgs')
-         if os.path.exists(r'./prompt_tuning_template'):
-             shutil.rmtree(r'./prompt_tuning_template')
-         shutil.copytree(os.path.join(prefix, 'hcpdiff/cfgs'), r'./cfgs')
-         shutil.copytree(os.path.join(prefix, 'hcpdiff/prompt_tuning_template'), r'./prompt_tuning_template')
-     except:
-         try:
-             shutil.copytree(os.path.join(prefix, '../hcpdiff/cfgs'), r'./cfgs')
-             shutil.copytree(os.path.join(prefix, '../hcpdiff/prompt_tuning_template'), r'./prompt_tuning_template')
-         except:
-             this_file_dir = os.path.dirname(os.path.abspath(__file__))
-             shutil.copytree(os.path.join(this_file_dir, '../../cfgs'), r'./cfgs')
-             shutil.copytree(os.path.join(this_file_dir, '../../prompt_tuning_template'), r'./prompt_tuning_template')
+     copy_package_data('hcpdiff', 'cfgs', './cfgs')
+     copy_package_data('hcpdiff', 'prompt_template', './prompt_template')
hcpdiff/tools/lora_convert.py CHANGED
@@ -3,15 +3,14 @@ import os.path
  from typing import List
  import math
 
- from hcpdiff.ckpt_manager import auto_manager
- from hcpdiff.deprecated import convert_to_webui_maybe_old, convert_to_webui_xl_maybe_old
+ from rainbowneko.ckpt_manager import auto_ckpt_loader, NekoModelSaver
 
  class LoraConverter:
      com_name_unet = ['down_blocks', 'up_blocks', 'mid_block', 'transformer_blocks', 'to_q', 'to_k', 'to_v', 'to_out', 'proj_in', 'proj_out', 'input_blocks', 'middle_block', 'output_blocks']
      com_name_TE = ['self_attn', 'q_proj', 'v_proj', 'k_proj', 'out_proj', 'text_model']
      prefix_unet = 'lora_unet_'
      prefix_TE = 'lora_te_'
-     prefix_TE_xl_clip_B = 'lora_te1_'
+     prefix_TE_xl_clip_L = 'lora_te1_'
      prefix_TE_xl_clip_bigG = 'lora_te2_'
 
      lora_w_map = {'lora_down.weight': 'W_down', 'lora_up.weight':'W_up'}
@@ -26,14 +25,14 @@ class LoraConverter:
              sd_TE = self.convert_from_webui_(state, prefix=self.prefix_TE, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
          else:
              sd_unet = self.convert_from_webui_xl_unet_(state, prefix=self.prefix_unet, com_name=self.com_name_unet, com_name_tmp=self.com_name_unet_tmp)
-             sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_B, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
+             sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_L, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
              sd_TE2 = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_bigG, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
              sd_TE.update(sd_TE2)
 
          if auto_scale_alpha:
             sd_unet = self.alpha_scale_from_webui(sd_unet)
             sd_TE = self.alpha_scale_from_webui(sd_TE)
-         return {'lora': sd_TE}, {'lora': sd_unet}
+         return {'plugin': sd_TE}, {'plugin': sd_unet}
 
      def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
          sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
@@ -59,7 +58,6 @@ class LoraConverter:
              sd_covert[f'{model_k}.___.layer.{self.lora_w_map[lora_k]}'] = v
          return sd_covert
 
-     @convert_to_webui_maybe_old
      def convert_to_webui_(self, state, prefix):
          sd_covert = {}
          for k, v in state.items():
@@ -75,7 +73,6 @@ class LoraConverter:
              sd_covert[f"{prefix}{model_k.replace('.', '_')}.{lora_k}"] = v
          return sd_covert
 
-     @convert_to_webui_xl_maybe_old
      def convert_to_webui_xl_(self, state, prefix):
          sd_convert = {}
          for k, v in state.items():
@@ -90,7 +87,7 @@ class LoraConverter:
 
              new_k = f"{prefix}{model_k.replace('.', '_')}.{lora_k}"
              if 'clip' in new_k:
-                 new_k = new_k.replace('_clip_B', '1') if 'clip_B' in new_k else new_k.replace('_clip_bigG', '2')
+                 new_k = new_k.replace('_clip_L', '1') if 'clip_L' in new_k else new_k.replace('_clip_bigG', '2')
              sd_convert[new_k] = v
          return sd_convert
 
@@ -103,7 +100,7 @@ class LoraConverter:
              model_k, lora_k = k[prefix_len:].split('.', 1)
              model_k = self.replace_all(model_k, com_name, com_name_tmp).replace('_', '.').replace('%', '_')
              if prefix == 'lora_te1_':
-                 model_k = f'clip_B.{model_k}'
+                 model_k = f'clip_L.{model_k}'
              else:
                  model_k = f'clip_bigG.{model_k}'
 
@@ -224,23 +221,27 @@ if __name__ == '__main__':
 
      # load lora model
      print('convert lora model')
-     ckpt_manager = auto_manager(args.lora_path)
+     ckpt_loader = auto_ckpt_loader(args.lora_path)
+     ckpt_saver = NekoModelSaver(
+         format=ckpt_loader.format,
+         source=ckpt_loader.source,
+     )
 
      if args.from_webui:
-         state = ckpt_manager.load_ckpt(args.lora_path)
+         state = ckpt_loader.load(args.lora_path)
          # convert the weight name
          sd_TE, sd_unet = converter.convert_from_webui(state, auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
          # wegiht save
          os.makedirs(args.dump_path, exist_ok=True)
          TE_path = os.path.join(args.dump_path, 'TE-'+lora_name)
          unet_path = os.path.join(args.dump_path, 'unet-'+lora_name)
-         ckpt_manager._save_ckpt(sd_TE, save_path=TE_path)
-         ckpt_manager._save_ckpt(sd_unet, save_path=unet_path)
+         ckpt_saver.save(sd_TE, TE_path)
+         ckpt_saver.save(sd_unet, unet_path)
          print('save text encoder lora to:', TE_path)
          print('save unet lora to:', unet_path)
      elif args.to_webui:
-         sd_unet = ckpt_manager.load_ckpt(args.lora_path)
-         sd_TE = ckpt_manager.load_ckpt(args.lora_path_TE) if args.lora_path_TE else {'lora':{}}
-         state = converter.convert_to_webui(sd_unet['lora'], sd_TE['lora'], auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
-         ckpt_manager._save_ckpt(state, save_path=args.dump_path)
+         sd_unet = ckpt_loader.load(args.lora_path)
+         sd_TE = ckpt_loader.load(args.lora_path_TE) if args.lora_path_TE else {'base':{}}
+         state = converter.convert_to_webui(sd_unet['base'], sd_TE['base'], auto_scale_alpha=args.auto_scale_alpha, sdxl=args.sdxl)
+         ckpt_saver.save(state, args.dump_path)
      print('save lora to:', args.dump_path)
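Note: a minimal sketch of the updated conversion flow, assuming a webui-format SD1.5 LoRA state dict already loaded as a plain dict of tensors (the checkpoint path is hypothetical):

    import torch
    from hcpdiff.tools.lora_convert import LoraConverter

    converter = LoraConverter()
    state = torch.load('lora_webui.pt')  # hypothetical path; any webui LoRA state dict
    # Returns ({'plugin': TE_weights}, {'plugin': unet_weights}) under the new key layout.
    sd_TE, sd_unet = converter.convert_from_webui(state, auto_scale_alpha=True, sdxl=False)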
hcpdiff/tools/save_model.py ADDED
@@ -0,0 +1,12 @@
+ from diffusers import DiffusionPipeline
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("model", default=None, type=str)
+ parser.add_argument("output", default=None, type=str)
+ args = parser.parse_args()
+
+ pipe = DiffusionPipeline.from_pretrained(args.model, safety_checker=None, requires_safety_checker=False,
+                                          resume_download=True)
+
+ pipe.save_pretrained(args.output)
hcpdiff/tools/sd2diffusers.py CHANGED
@@ -211,7 +211,7 @@ def sd_vae_to_diffuser(args):
  def convert_ckpt(args):
      pipe = load_sd_ckpt(
          args.checkpoint_path,
-         original_config_file=args.original_config_file,
+         config_files={'v1': args.original_config_file},
          image_size=args.image_size,
          prediction_type=args.prediction_type,
          model_type=args.pipeline_type,
hcpdiff/train_colo.py CHANGED
@@ -23,7 +23,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
  from colossalai.utils.model.colo_init_context import _convert_to_coloparam
  from colossalai.tensor import ColoParameter
 
- from hcpdiff.train_ac import Trainer, get_scheduler, ModelEMA
+ from hcpdiff.train_ac_old import Trainer, get_scheduler, ModelEMA
  from diffusers import UNet2DConditionModel
  from hcpdiff.utils.colo_utils import gemini_zero_dpp, GeminiAdamOptimizerP
  from hcpdiff.utils.utils import load_config_with_cli
hcpdiff/train_deepspeed.py CHANGED
@@ -7,7 +7,7 @@ from functools import partial
  import torch
 
  from hcpdiff.ckpt_manager import CkptManagerPKL, CkptManagerSafe
- from hcpdiff.train_ac import Trainer, load_config_with_cli
+ from hcpdiff.train_ac_old import Trainer, load_config_with_cli
  from hcpdiff.utils.net_utils import get_scheduler
 
  class TrainerDeepSpeed(Trainer):
hcpdiff/trainer_ac.py ADDED
@@ -0,0 +1,79 @@
+ import argparse
+ import warnings
+
+ import torch
+ from rainbowneko.parser import load_config_with_cli
+ from rainbowneko.ckpt_manager import NekoSaver
+ from rainbowneko.train import Trainer
+ from rainbowneko.utils import xformers_available, is_dict
+ from hcpdiff.ckpt_manager import EmbFormat
+
+ class HCPTrainer(Trainer):
+     def config_model(self):
+         if self.cfgs.model.enable_xformers:
+             if xformers_available:
+                 self.model_wrapper.enable_xformers()
+             else:
+                 warnings.warn("xformers is not available. Make sure it is installed correctly")
+
+         if self.model_wrapper.vae is not None:
+             self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
+             self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
+
+         if self.cfgs.model.gradient_checkpointing:
+             self.model_wrapper.enable_gradient_checkpointing()
+
+     def get_param_group_train(self):
+         train_params = super().get_param_group_train()
+
+         # For prompt-tuning
+         if self.cfgs.emb_pt is None:
+             train_params_emb, self.train_pts = [], {}
+         else:
+             from hcpdiff.parser import CfgEmbPTParser
+             self.cfgs.emb_pt: CfgEmbPTParser
+
+             train_params_emb, self.train_pts = self.cfgs.emb_pt.get_params_group(self.model_wrapper)
+             self.emb_format = EmbFormat()
+         train_params += train_params_emb
+         return train_params
+
+     @property
+     def pt_trainable(self):
+         return self.cfgs.emb_pt is not None
+
+     def get_loss(self, ds_name, model_pred, inputs):
+         loss = super().get_loss(ds_name, model_pred, inputs)
+         # make DDP happy
+         if len(self.train_pts)>0:
+             loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
+         return loss
+
+     def save_model(self, from_raw=False):
+         NekoSaver.save_all(
+             self.model_raw,
+             plugin_groups={**self.all_plugin, 'embs': self.train_pts},
+             cfg=self.ckpt_saver,
+             model_ema=getattr(self, "ema_model", None),
+             name_template=f'{{}}-{self.real_step}',
+         )
+
+         self.loggers.info(f"Saved state, step: {self.real_step}")
+
+ def hcp_train():
+     import subprocess
+     parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
+     parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/multi.yaml')
+     args, train_args = parser.parse_known_args()
+
+     subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
+                     "hcpdiff.trainer_ac"] + train_args, check=True)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="HCP Diffusion Trainer")
+     parser.add_argument("--cfg", type=str, default=None, required=True)
+     args, cfg_args = parser.parse_known_args()
+
+     parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
+     trainer = HCPTrainer(parser, conf)
+     trainer.train()
hcpdiff/trainer_ac_single.py ADDED
@@ -0,0 +1,31 @@
+ import argparse
+ import sys
+ from functools import partial
+
+ import torch
+ from accelerate import Accelerator
+ from loguru import logger
+
+ from rainbowneko.train.trainer import TrainerSingleCard
+ from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
+
+ class HCPTrainerSingleCard(TrainerSingleCard, HCPTrainer):
+     pass
+
+ def hcp_train():
+     import subprocess
+     parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
+     parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/single.yaml')
+     args, train_args = parser.parse_known_args()
+
+     subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
+                     "hcpdiff.trainer_ac_single"] + train_args, check=True)
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='HCP Diffusion Trainer')
+     parser.add_argument("--cfg", type=str, default=None, required=True)
+     args, cfg_args = parser.parse_known_args()
+
+     parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
+     trainer = HCPTrainerSingleCard(parser, conf)
+     trainer.train()
hcpdiff/utils/__init__.py CHANGED
@@ -1,4 +1,2 @@
  from .utils import *
- from .img_size_tool import get_image_size
- from .cfg_resolvers import *
  from .net_utils import *
hcpdiff/utils/inpaint_pipe.py CHANGED
@@ -21,18 +21,23 @@ import torch
  from packaging import version
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
+ from diffusers import StableDiffusionInpaintPipelineLegacy
  from diffusers.configuration_utils import FrozenDict
  from diffusers.image_processor import VaeImageProcessor
  from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
  from diffusers.models.lora import adjust_lora_scale_text_encoder
  from diffusers.schedulers import KarrasDiffusionSchedulers
- from diffusers.utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
  from diffusers.utils.torch_utils import randn_tensor
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
+ try:
+     from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
+ except:
+     USE_PEFT_BACKEND = False
 
  logger = logging.get_logger(__name__)
hcpdiff/utils/net_utils.py CHANGED
@@ -6,11 +6,19 @@ import torch
  from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, Optimizer
  from torch import nn
  from torch.optim import lr_scheduler
- from transformers import PretrainedConfig, AutoTokenizer
+ from transformers import PretrainedConfig, AutoTokenizer, T5EncoderModel, CLIPTextModel
  from functools import partial
+ from huggingface_hub import hf_hub_download
+ import json
 
  dtype_dict = {'fp32':torch.float32, 'amp':torch.float32, 'fp16':torch.float16, 'bf16':torch.bfloat16}
 
+ try:
+     dtype_dict['fp8_e4m3'] = torch.float8_e4m3fn
+     dtype_dict['fp8_e5m2'] = torch.float8_e5m2
+ except:
+     pass
+
  def get_scheduler(cfg, optimizer):
      if cfg is None:
          return None
@@ -90,7 +98,7 @@ def auto_tokenizer_cls(pretrained_model_name_or_path: str, revision: str = None)
              revision=revision, use_fast=False,
          )
          return SDXLTokenizer
-     except OSError:
+     except:
          # not sdxl, only one tokenizer
         return AutoTokenizer
 
@@ -102,8 +110,10 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = No
              subfolder="text_encoder_2",
              revision=revision,
          )
+         if text_encoder_config.architectures is None:
+             raise ValueError()
          return SDXLTextEncoder
-     except OSError:
+     except:
          text_encoder_config = PretrainedConfig.from_pretrained(
              pretrained_model_name_or_path,
              subfolder="text_encoder",
@@ -112,16 +122,26 @@ def auto_text_encoder_cls(pretrained_model_name_or_path: str, revision: str = No
      model_class = text_encoder_config.architectures[0]
 
      if model_class == "CLIPTextModel":
-         from transformers import CLIPTextModel
-
          return CLIPTextModel
      elif model_class == "RobertaSeriesModelWithTransformation":
          from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
 
          return RobertaSeriesModelWithTransformation
+     elif model_class == "T5EncoderModel":
+         return T5EncoderModel
      else:
          raise ValueError(f"{model_class} is not supported.")
 
+ def get_pipe_name(path: str):
+     if os.path.isdir(path):
+         json_file = os.path.join(path, "model_index.json")
+     else:
+         json_file = hf_hub_download(path, "model_index.json")
+     with open(json_file, "r", encoding="utf-8") as reader:
+         text = reader.read()
+     data = json.loads(text)
+     return data['_class_name']
+
  def auto_tokenizer(pretrained_model_name_or_path: str, revision: str = None, **kwargs):
      return auto_tokenizer_cls(pretrained_model_name_or_path, revision).from_pretrained(pretrained_model_name_or_path, revision=revision, **kwargs)
 
@@ -225,4 +245,7 @@ def split_module_name(layer_name):
      return parent_name, host_name
 
  def get_dtype(dtype):
-     return dtype_dict.get(dtype, torch.float32)
+     if isinstance(dtype, torch.dtype):
+         return dtype
+     else:
+         return dtype_dict.get(dtype, torch.float32)
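Note: a quick illustration of the extended get_dtype() behavior shown above: strings still map through dtype_dict, torch dtypes now pass through unchanged, and unknown values fall back to float32:

    import torch
    from hcpdiff.utils.net_utils import get_dtype

    assert get_dtype('bf16') is torch.bfloat16          # string lookup
    assert get_dtype(torch.float16) is torch.float16    # dtype passes through
    assert get_dtype('unknown') is torch.float32        # fallback default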
hcpdiff/utils/pipe_hook.py CHANGED
@@ -2,9 +2,9 @@ from typing import Union, List, Optional, Callable, Dict, Any
 
  import PIL
  import torch
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, PixArtTransformer2DModel
  from diffusers.image_processor import VaeImageProcessor
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
  from .inpaint_pipe import preprocess_mask, preprocess_image, StableDiffusionInpaintPipelineLegacy
  from einops import repeat
 
@@ -122,12 +122,20 @@ class HookPipe_T2I(StableDiffusionPipeline):
              latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
              if pooled_output is None:
-                 noise_pred = self.unet(latent_model_input, t, prompt_embeds[i], encoder_attention_mask=encoder_attention_mask,
-                                        cross_attention_kwargs=cross_attention_kwargs, ).sample
+                 if isinstance(self.unet, PixArtTransformer2DModel):
+                     added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+                     noise_pred = self.unet(latent_model_input, timestep=t.repeat(latent_model_input.shape[0]), encoder_hidden_states=prompt_embeds[i],
+                                            encoder_attention_mask=encoder_attention_mask,
+                                            cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                 else:
+                     noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
+                                            encoder_attention_mask=encoder_attention_mask,
+                                            cross_attention_kwargs=cross_attention_kwargs).sample
              else:
                  added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                  # predict the noise residual
-                 noise_pred = self.unet(latent_model_input, t, prompt_embeds[i], encoder_attention_mask=encoder_attention_mask,
+                 noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
+                                        encoder_attention_mask=encoder_attention_mask,
                                         cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
 
              # perform guidance
@@ -135,6 +143,10 @@ class HookPipe_T2I(StableDiffusionPipeline):
                  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                  noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)
 
+             # learned sigma
+             if self.unet.config.out_channels // 2 == num_channels_latents:
+                 noise_pred = noise_pred.chunk(2, dim=1)[0]
+
              # x_t -> x_0
              alpha_prod_t = alphas_cumprod[t.long()]
              beta_prod_t = 1-alpha_prod_t
@@ -271,8 +283,13 @@ class HookPipe_I2I(StableDiffusionImg2ImgPipeline):
 
              # predict the noise residual
              if pooled_output is None:
-                 noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                        cross_attention_kwargs=cross_attention_kwargs, ).sample
+                 if isinstance(self.unet, PixArtTransformer2DModel):
+                     added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+                     noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                            cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                 else:
+                     noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                            cross_attention_kwargs=cross_attention_kwargs, ).sample
              else:
                  added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                  # predict the noise residual
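Note: the learned-sigma branch added above handles DiT-style transformers (e.g. PixArt) whose output stacks a predicted noise and a predicted variance along the channel dimension; only the first half is the noise estimate. A small shape-level sketch with made-up dimensions:

    import torch

    num_channels_latents = 4
    out = torch.randn(2, 2*num_channels_latents, 64, 64)  # [noise | variance] on dim 1
    noise_pred = out.chunk(2, dim=1)[0]                   # keep only the noise half
    assert noise_pred.shape == (2, 4, 64, 64)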
hcpdiff/utils/utils.py CHANGED
@@ -56,8 +56,8 @@ def remove_config_undefined(cfg):
  def load_config(path, remove_undefined=True):
      cfg = OmegaConf.load(path)
      if '_base_' in cfg:
-         for base in cfg['_base_']:
-             cfg = OmegaConf.merge(load_config(base, remove_undefined=False), cfg)
+         base_cfgs = [load_config(base, remove_undefined=False) for base in cfg['_base_']]
+         cfg = OmegaConf.merge(*base_cfgs, cfg)
          del cfg['_base_']
      if remove_undefined:
          cfg = remove_config_undefined(cfg)
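Note: this rewrite also fixes base precedence. OmegaConf.merge applies later arguments on top of earlier ones, so listing every base before cfg means later bases override earlier bases and the current file still overrides them all. A small illustration with made-up keys:

    from omegaconf import OmegaConf

    base1 = OmegaConf.create({'lr': 1e-4, 'bs': 8})
    base2 = OmegaConf.create({'bs': 16})
    cfg = OmegaConf.create({'lr': 5e-5})

    merged = OmegaConf.merge(base1, base2, cfg)  # later configs win
    assert merged.lr == 5e-5 and merged.bs == 16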
@@ -85,7 +85,7 @@ def get_cfg_range(cfg_text:str):
  def to_validate_file(name):
      rstr = r"[\/\\\:\*\?\"\<\>\|]"  # '/ \ : * ? " < > |'
      new_title = re.sub(rstr, "_", name)  # replace invalid characters with underscores
-     return new_title[:50]
+     return new_title[:200]
 
  def make_mask(start, end, length):
      mask=torch.zeros(length)
@@ -159,4 +159,21 @@ def pad_attn_bias(x, attn_bias, block_size=8):
      # pad along the k dimension
      x_padded = F.pad(x, (0, 0, 0, padding_l, 0, 0), mode='constant', value=0)
      attn_bias_padded = F.pad(attn_bias, (0, padding_l, 0, 0), mode='constant', value=0)
-     return x_padded, attn_bias_padded
+     return x_padded, attn_bias_padded
+
+ def linear_interp(t, x):
+     '''
+     t_l ---------t_h
+         ^x
+     '''
+     if (x>=len(t)).any():
+         x = x.clamp(max=len(t)-1e-6)
+     x0 = x.floor().long()
+     x1 = x0 + 1
+
+     y0 = t[x0]
+     y1 = t[x1]
+
+     xd = (x - x0.float())
+
+     return y0 * (1 - xd) + y1 * xd
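Note: the new linear_interp treats t as a lookup table sampled at integer positions and x as fractional indices into it. A quick worked example with hypothetical values:

    import torch
    from hcpdiff.utils.utils import linear_interp

    t = torch.tensor([0.0, 10.0, 20.0])   # values at indices 0, 1, 2
    x = torch.tensor([0.50, 1.25])        # fractional indices
    print(linear_interp(t, x))            # tensor([ 5.0000, 12.5000])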
hcpdiff/workflow/__init__.py CHANGED
@@ -1,15 +1,20 @@
- from .base import BasicAction, MemoryMixin, from_memory, ExecAction, LoopAction
- from .diffusion import InputFeederAction, PrepareDiffusionAction, MakeLatentAction, NoisePredAction, SampleAction, DiffusionStepAction, \
-     X0PredAction, SeedAction, MakeTimestepsAction
+ from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
+     X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter
  from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
  from .vae import EncodeAction, DecodeAction
- from .io import LoadModelsAction, SaveImageAction, BuildModelLoaderAction, LoadPartAction, LoadLoraAction, LoadPluginAction
- from .utils import LatentResizeAction, ImageResizeAction
- from .model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction, StartTextEncode, StartDiffusion, EndTextEncode, EndDiffusion
+ from .io import BuildModelsAction, SaveImageAction, LoadImageAction
+ from .utils import LatentResizeAction, ImageResizeAction, FeedtoCNetAction
+ from .model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction
+ #from .flow import FilePromptAction
+
+ try:
+     from .fast import SFastCompileAction
+ except:
+     print('stable fast not installed.')
 
  from omegaconf import OmegaConf
 
- OmegaConf.register_new_resolver("hcp.from_memory", lambda mem_name: OmegaConf.create({
-     '_target_': 'hcpdiff.workflow.from_memory',
-     'mem_name': mem_name,
- }))
+ OmegaConf.register_new_resolver("hcp.from_memory", lambda mem_name:OmegaConf.create({
+     '_target_':'hcpdiff.workflow.from_memory',
+     'mem_name':mem_name,
+ }))