hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/workflow/diffusion.py CHANGED
@@ -1,8 +1,13 @@
-from .base import BasicAction, from_memory_context, ExecAction, MemoryMixin
+import random
+import warnings
 from typing import Dict, Any, Union, List
+
 import torch
+from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
+from hcpdiff.utils import prepare_seed
+from hcpdiff.utils.net_utils import get_dtype, to_cuda
+from rainbowneko.infer import BasicAction
 from torch.cuda.amp import autocast
-import inspect
 
 try:
     from diffusers.utils import randn_tensor
@@ -10,197 +15,184 @@ except:
     # new version of diffusers
     from diffusers.utils.torch_utils import randn_tensor
 
-from hcpdiff.utils import prepare_seed
-from hcpdiff.utils.net_utils import get_dtype, to_cpu, to_cuda
-import random
-
 class InputFeederAction(BasicAction):
-    @from_memory_context
-    def __init__(self, ex_inputs:Dict[str, Any], unet=None):
-        super().__init__()
+    def __init__(self, ex_inputs: Dict[str, Any], key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.ex_inputs = ex_inputs
-        self.unet = unet
 
-    def forward(self, **states):
-        if hasattr(self.unet, 'input_feeder'):
-            for feeder in self.unet.input_feeder:
-                feeder(self.ex_inputs)
-        return states
+    def forward(self, model, ex_inputs=None, **states):
+        ex_inputs = self.ex_inputs if ex_inputs is None else {**ex_inputs, **self.ex_inputs}
+        if hasattr(model, 'input_feeder'):
+            for feeder in model.input_feeder:
+                feeder(ex_inputs)
 
 class SeedAction(BasicAction):
-    @from_memory_context
-    def __init__(self, seed:Union[int, List[int]], bs:int=1):
-        super().__init__()
+    def __init__(self, seed: Union[int, List[int]], bs: int = 1, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.seed = seed
         self.bs = bs
 
-    def forward(self, device, **states):
+    def forward(self, device, gen_step=0, **states):
         bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
         if self.seed is None:
             seeds = [None]*bs
         elif isinstance(self.seed, int):
-            seeds = list(range(self.seed, self.seed+bs))
+            seeds = list(range(self.seed+gen_step*bs, self.seed+(gen_step+1)*bs))
         else:
             seeds = self.seed
         seeds = [s or random.randint(0, 1 << 30) for s in seeds]
 
         G = prepare_seed(seeds, device=device)
-        return {**states, 'seeds':seeds, 'generator':G, 'device':device}
-
-class PrepareDiffusionAction(BasicAction, MemoryMixin):
-    def __init__(self, dtype='fp32'):
-        self.dtype = dtype
-
-    def forward(self, memory, **states):
-        dtype = get_dtype(self.dtype)
-        memory.unet.to(dtype=dtype)
-        memory.text_encoder.to(dtype=dtype)
-        memory.vae.to(dtype=dtype)
-
-        device = memory.unet.device
-        vae_scale_factor = 2**(len(memory.vae.config.block_out_channels)-1)
-        return {**states, 'dtype': self.dtype, 'device':device, 'vae_scale_factor':vae_scale_factor}
-
-class MakeTimestepsAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, scheduler=None, N_steps:int=30, strength:float=None):
-        self.scheduler = scheduler
+        return {'seeds':seeds, 'generator':G}
+
+class PrepareDiffusionAction(BasicAction):
+    def __init__(self, model_offload=False, amp=torch.float16, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.model_offload = model_offload
+        self.amp = amp
+
+    def forward(self, device, denoiser, TE, vae, **states):
+        denoiser.to(device)
+        TE.to(device)
+        vae.to(device)
+
+        TE.eval()
+        denoiser.eval()
+        vae.eval()
+        return {'amp':self.amp, 'model_offload':self.model_offload}
+
+class MakeTimestepsAction(BasicAction):
+    def __init__(self, N_steps: int = 30, strength: float = None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.N_steps = N_steps
         self.strength = strength
 
-    def get_timesteps(self, timesteps, strength):
+    def get_timesteps(self, noise_sampler:BaseSampler, timesteps, strength):
         # get the original timestep using init_timestep
         num_inference_steps = len(timesteps)
-        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        init_timestep = min(int(num_inference_steps*strength), num_inference_steps)
 
-        t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = timesteps[t_start * self.scheduler.order :]
+        t_start = max(num_inference_steps-init_timestep, 0)
+        if isinstance(noise_sampler, DiffusersSampler):
+            timesteps = timesteps[t_start*noise_sampler.scheduler.order:]
+        else:
+            timesteps = timesteps[t_start:]
 
         return timesteps
 
-    def forward(self, memory, device, **states):
-        self.scheduler = self.scheduler or memory.scheduler
-
-        self.scheduler.set_timesteps(self.N_steps, device=device)
-        timesteps = self.scheduler.timesteps
+    def forward(self, noise_sampler:BaseSampler, device, **states):
+        timesteps = noise_sampler.get_timesteps(self.N_steps, device=device)
         if self.strength:
-            timesteps = self.get_timesteps(timesteps, self.strength)
-        alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)
-        return {**states, 'device':device, 'timesteps':timesteps, 'alphas_cumprod':alphas_cumprod}
-
-class MakeLatentAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, scheduler=None, N_ch=4, height=512, width=512):
-        self.scheduler = scheduler
-        self.N_ch=N_ch
-        self.height=height
-        self.width=width
-
-    def forward(self, memory, generator, device, dtype, bs=None, latents=None, vae_scale_factor=8, start_timestep=None, **states):
+            timesteps = self.get_timesteps(noise_sampler, timesteps, self.strength)
+            return {'timesteps':timesteps, 'start_timestep':timesteps[:1]}
+        else:
+            return {'timesteps':timesteps}
+
+class MakeLatentAction(BasicAction):
+    def __init__(self, N_ch=4, height=None, width=None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.N_ch = N_ch
+        self.height = height
+        self.width = width
+
+    def forward(self, noise_sampler:BaseSampler, vae, generator, device, dtype, bs=None, latents=None, start_timestep=None,
+                pooled_output=None, crop_coord=None, **states):
         if bs is None:
             if 'prompt' in states:
                 bs = len(states['prompt'])
-        scheduler = self.scheduler or memory.scheduler
+        vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+        device = torch.device(device)
 
-        shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)
+        if latents is None:
+            shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)
+        else:
+            if self.height is not None:
+                warnings.warn('latents exist! User-specified width and height will be ignored!')
+            shape = latents.shape
         if isinstance(generator, list) and len(generator) != bs:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {bs}. Make sure the batch size matches the length of the generators."
            )
 
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=get_dtype(dtype))
         if latents is None:
-            # scale the initial noise by the standard deviation required by the scheduler
-            latents = noise*scheduler.init_noise_sigma
+            # scale the initial noise by the standard deviation required by the noise_sampler
+            noise_sampler.generator = generator
+            latents = noise_sampler.init_noise(shape, device=device, dtype=get_dtype(dtype))
         else:
             # image to image
             latents = latents.to(device)
-            latents = scheduler.add_noise(latents, noise, start_timestep)
+            latents, noise = noise_sampler.add_noise(latents, start_timestep)
+
+        output = {'latents':latents}
+
+        # SDXL inputs
+        if pooled_output is not None:
+            width, height = shape[3]*vae_scale_factor, shape[2]*vae_scale_factor
+            if crop_coord is None:
+                crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
+            else:
+                crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
+            crop_info = crop_info.to(device).repeat(bs, 1)
+            output['text_embeds'] = pooled_output[-1].to(device)
 
-        return {**states, 'latents': latents, 'device':device, 'dtype':dtype, 'generator':generator}
+            if 'negative_prompt' in states:
+                output['crop_info'] = torch.cat([crop_info, crop_info], dim=0)
 
-class NoisePredAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, unet=None, scheduler=None, guidance_scale:float=7.0):
-        self.guidance_scale=guidance_scale
-        self.unet = unet
-        self.scheduler = scheduler
+        return output
 
-    def forward(self, memory, t, latents, prompt_embeds, pooled_output=None, encoder_attention_mask=None, crop_info=None,
-                cross_attention_kwargs=None, dtype='fp32', **states):
-        self.scheduler = self.scheduler or memory.scheduler
-        self.unet = self.unet or memory.unet
+class DenoiseAction(BasicAction):
+    def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.guidance_scale = guidance_scale
 
-        with autocast(enabled=dtype == 'amp'):
+    def forward(self, denoiser, noise_sampler: BaseSampler, t, latents, prompt_embeds, text_embeds=None, encoder_attention_mask=None, crop_info=None,
+                cross_attention_kwargs=None, dtype='fp32', amp=None, model_offload=False, **states):
+
+        if model_offload:
+            to_cuda(denoiser)  # to_cpu in VAE
+
+        with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
             latent_model_input = torch.cat([latents]*2) if self.guidance_scale>1 else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = noise_sampler.c_in(t)*latent_model_input
 
-            if pooled_output is None:
-                noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                       cross_attention_kwargs=cross_attention_kwargs, ).sample
+            if text_embeds is None:
+                noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                      cross_attention_kwargs=cross_attention_kwargs, ).sample
             else:
-                added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
+                added_cond_kwargs = {"text_embeds":text_embeds, "time_ids":crop_info}
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                       cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                      cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
 
             # perform guidance
             if self.guidance_scale>1:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond+self.guidance_scale*(noise_pred_text-noise_pred_uncond)
 
-        return {**states, 'noise_pred':noise_pred, 'latents': latents, 't':t, 'prompt_embeds':prompt_embeds, 'pooled_output':pooled_output,
-                'crop_info':crop_info, 'cross_attention_kwargs':cross_attention_kwargs, 'dtype':dtype}
-
-class SampleAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, scheduler=None, eta=0.0):
-        self.scheduler = scheduler
-        self.eta = eta
-
-    def prepare_extra_step_kwargs(self, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
-    def forward(self, memory, noise_pred, t, latents, generator, **states):
-        self.scheduler = self.scheduler or memory.scheduler
-
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, self.eta)
+        return {'noise_pred':noise_pred}
 
+class SampleAction(BasicAction):
+    def forward(self, noise_sampler: BaseSampler, noise_pred, t, latents, generator, **states):
         # compute the previous noisy sample x_t -> x_t-1
-        sc_out = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
-        latents = sc_out.prev_sample
-        return {**states, 'latents': latents, 't':t, 'generator':generator}
-
-class DiffusionStepAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, unet=None, scheduler=None, guidance_scale:float=7.0):
-        self.act_noise_pred = NoisePredAction(unet, scheduler, guidance_scale)
-        self.act_sample = SampleAction(scheduler)
-
-    def forward(self, memory, **states):
-        states = self.act_noise_pred(memory=memory, **states)
-        states = self.act_sample(memory=memory, **states)
+        latents = noise_sampler.denoise(latents, t, noise_pred, generator=generator)
+        return {'latents':latents}
+
+class DiffusionStepAction(BasicAction):
+    def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.act_noise_pred = DenoiseAction(guidance_scale)
+        self.act_sample = SampleAction()
+
+    def forward(self, denoiser, noise_sampler, **states):
+        states = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
+        states = self.act_sample(**states)
         return states
 
 class X0PredAction(BasicAction):
-    def forward(self, latents, alphas_cumprod, t, noise_pred, **states):
-        # x_t -> x_0
-        alpha_prod_t = alphas_cumprod[t.long()]
-        beta_prod_t = 1-alpha_prod_t
-        latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5)  # approximate x_0
-        return {**states, 'latents_x0': latents_x0, 'latents': latents, 'alphas_cumprod':alphas_cumprod, 't':t, 'noise_pred':noise_pred}
+    def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
+        latents_x0 = noise_sampler.eps_to_x0(noise_pred, latents, t)
+        return {'latents_x0':latents_x0}
+
+def time_iter(timesteps, **states):
+    return [{'t':t} for t in timesteps]
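
Note: the refactor above drops the 0.9 `memory`/`from_memory_context` plumbing; each action now receives the shared `states` dict as keyword arguments and returns only the keys it produces, with `key_map_in`/`key_map_out` handling renames. A minimal sketch of chaining the new actions follows (hypothetical wiring, not from the package: the `denoiser`/`TE`/`vae`/`noise_sampler` objects are assumed to exist, and `BasicAction.__call__` is assumed to merge `forward`'s returned keys back into `states`, matching the `states = act(**states)` pattern used in `flow.py` below):

```python
# Hypothetical sketch of a denoising loop built from the new actions.
from hcpdiff.workflow.diffusion import (PrepareDiffusionAction, SeedAction,
                                        MakeTimestepsAction, MakeLatentAction,
                                        DiffusionStepAction, time_iter)

states = {'device': 'cuda', 'denoiser': denoiser, 'TE': TE, 'vae': vae,
          'noise_sampler': noise_sampler, 'prompt': ['a cat'], 'dtype': 'fp16'}
for act in (PrepareDiffusionAction(), SeedAction(seed=42),
            MakeTimestepsAction(N_steps=30), MakeLatentAction(height=512, width=512)):
    states = act(**states)                  # each action merges its outputs

step = DiffusionStepAction(guidance_scale=7.0)
for t_state in time_iter(**states):         # yields [{'t': t}, ...]
    states = step(**{**states, **t_state})  # one x_t -> x_{t-1} step
```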
hcpdiff/workflow/fast.py ADDED
@@ -0,0 +1,31 @@
+from sfast.compilers.diffusion_pipeline_compiler import (compile_unet, CompilationConfig)
+from rainbowneko.infer import BasicAction
+
+
+class SFastCompileAction(BasicAction):
+
+    @staticmethod
+    def compile_model(unet):
+        # compile model
+        config = CompilationConfig.Default()
+        config.enable_xformers = False
+        try:
+            import xformers
+            config.enable_xformers = True
+        except ImportError:
+            print('xformers not installed, skip')
+        # NOTE:
+        # When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
+        # Disable Triton if you encounter this problem.
+        try:
+            import triton
+            config.enable_triton = True
+        except ImportError:
+            print('Triton not installed, skip')
+        config.enable_cuda_graph = True
+
+        return compile_unet(unet, config)
+
+    def forward(self, denoiser, **states):
+        denoiser = self.compile_model(denoiser)
+        return {'denoiser': denoiser}
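
Note: `SFastCompileAction` wraps the stable-fast compiler (the `sfast` import above comes from the `stable-fast` package). A hypothetical drop-in before the denoising loop; since `enable_cuda_graph` is switched on unconditionally, input shapes should stay fixed across steps:

```python
# Hypothetical usage: compile the UNet once, then reuse the returned denoiser.
from hcpdiff.workflow.fast import SFastCompileAction

states = SFastCompileAction()(denoiser=denoiser)  # -> {'denoiser': compiled_unet}
```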
hcpdiff/workflow/flow.py ADDED
@@ -0,0 +1,67 @@
+from rainbowneko.infer import BasicAction
+from typing import List, Dict
+from tqdm import tqdm
+import math
+
+class FilePromptAction(BasicAction):
+    def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        if prompt.endswith('.txt'):
+            with open(prompt, 'r') as f:
+                prompt = f.read().split('\n')
+        else:
+            prompt = [prompt]
+
+        if negative_prompt.endswith('.txt'):
+            with open(negative_prompt, 'r') as f:
+                negative_prompt = f.read().split('\n')
+        else:
+            negative_prompt = [negative_prompt]*len(prompt)
+
+        self.prompt = prompt
+        self.negative_prompt = negative_prompt
+        self.bs = bs
+        self.actions = actions
+
+
+    def forward(self, **states):
+        states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
+        states_ref = dict(**states)
+
+        pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
+        N_steps = len(self.actions)
+        for gen_step in pbar:
+            states = dict(**states_ref)
+            feed_data = {'gen_step': gen_step}
+            states.update(feed_data)
+            for step, act in enumerate(self.actions):
+                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
+                states = act(**states)
+        return states
+
+class FlowPromptAction(BasicAction):
+    def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, num: int = None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        prompt = [prompt]*num
+        negative_prompt = [negative_prompt]*num
+
+        self.prompt = prompt
+        self.negative_prompt = negative_prompt
+        self.bs = bs
+        self.actions = actions
+
+
+    def forward(self, **states):
+        states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
+        states_ref = dict(**states)
+
+        pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
+        N_steps = len(self.actions)
+        for gen_step in pbar:
+            states = dict(**states_ref)
+            feed_data = {'gen_step': gen_step}
+            states.update(feed_data)
+            for step, act in enumerate(self.actions):
+                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
+                states = act(**states)
+        return states
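
Note: both flow actions replay their inner action list once per batch, injecting an incrementing `gen_step` (which `SeedAction` in `diffusion.py` uses to advance seeds batch by batch) and publishing the full lists as `prompt_all`/`negative_prompt_all` for downstream actions to slice. A minimal hypothetical wiring, assuming `prompts.txt` holds one prompt per line:

```python
# Hypothetical: run a trivial inner flow over a prompt file in batches of 4.
from hcpdiff.workflow.diffusion import SeedAction
from hcpdiff.workflow.flow import FilePromptAction

flow = FilePromptAction(actions=[SeedAction(seed=42)],
                        prompt='prompts.txt',       # '.txt' -> one prompt per line
                        negative_prompt='lowres',   # literal, repeated per prompt
                        bs=4)
states = flow(device='cuda')
```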
hcpdiff/workflow/io.py CHANGED
@@ -1,88 +1,56 @@
 import os
+from functools import partial
+from typing import List, Union
 
-from diffusers import UNet2DConditionModel, AutoencoderKL, PNDMScheduler
-
-from hcpdiff.utils import auto_text_encoder, auto_tokenizer, to_validate_file
-from hcpdiff.utils.cfg_net_tools import HCPModelLoader
-from hcpdiff.utils.img_size_tool import types_support
+import torch
+from hcpdiff.utils import to_validate_file
 from hcpdiff.utils.net_utils import get_dtype
-from .base import BasicAction, from_memory_context, MemoryMixin
-
-class LoadModelsAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, pretrained_model: str, dtype: str, unet=None, text_encoder=None, tokenizer=None, vae=None, scheduler=None):
-        self.pretrained_model = pretrained_model
+from rainbowneko.ckpt_manager import NekoLoader
+from rainbowneko.infer import BasicAction
+from rainbowneko.infer import LoadImageAction as Neko_LoadImageAction
+from rainbowneko.utils.img_size_tool import types_support
+
+class BuildModelsAction(BasicAction):
+    def __init__(self, model_loader: partial[NekoLoader.load], dtype: str=torch.float32, device='cuda', key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.model_loader = model_loader
         self.dtype = get_dtype(dtype)
+        self.device = device
 
-        self.unet = unet
-        self.text_encoder = text_encoder
-        self.tokenizer = tokenizer
-        self.vae = vae
-        self.scheduler = scheduler
+    def forward(self, in_preview=False, model=None, **states):
+        if in_preview:
+            model = self.model_loader(dtype=self.dtype, device=self.device, denoiser=model.denoiser, TE=model.TE, vae=model.vae)
+        else:
+            model = self.model_loader(dtype=self.dtype, device=self.device)
 
-    def forward(self, memory, **states):
-        memory.unet = self.unet or UNet2DConditionModel.from_pretrained(self.pretrained_model, subfolder="unet", torch_dtype=self.dtype)
-        memory.text_encoder = self.text_encoder or auto_text_encoder(self.pretrained_model, subfolder="text_encoder", torch_dtype=self.dtype)
-        memory.tokenizer = self.tokenizer or auto_tokenizer(self.pretrained_model, subfolder="tokenizer", use_fast=False)
-        memory.vae = self.vae or AutoencoderKL.from_pretrained(self.pretrained_model, subfolder="vae", torch_dtype=self.dtype)
-        memory.scheduler = self.scheduler or PNDMScheduler.from_pretrained(self.pretrained_model, subfolder="scheduler", torch_dtype=self.dtype)
+        if isinstance(model, dict):
+            return model
+        else:
+            return {'model':model}
 
-        return states
+class LoadImageAction(Neko_LoadImageAction):
+    def __init__(self, image_paths: Union[str, List[str]], image_transforms=None, key_map_in=None, key_map_out=('input.x -> images',)):
+        super().__init__(image_paths, image_transforms, key_map_in, key_map_out)
 
 class SaveImageAction(BasicAction):
-    @from_memory_context
-    def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95):
+    def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.save_root = save_root
         self.image_type = image_type
         self.quality = quality
+        self.save_cfg = save_cfg
 
         os.makedirs(save_root, exist_ok=True)
 
-    def forward(self, images, prompt, negative_prompt, seeds=None, **states):
-        num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(self.save_root) if x.rsplit('.', 1)[-1] in types_support])+1
+    def forward(self, images, prompt, negative_prompt, seeds, cfgs=None, parser=None, preview_root=None, preview_step=None, **states):
+        save_root = preview_root or self.save_root
+        num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(save_root) if x.rsplit('.', 1)[-1] in types_support])+1
 
         for bid, (p, pn, img) in enumerate(zip(prompt, negative_prompt, images)):
-            img_path = os.path.join(self.save_root, f"{num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.{self.image_type}")
+            img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.{self.image_type}")
             img.save(img_path, quality=self.quality)
             num_img_exist += 1
 
-        return {**states, 'images':images, 'prompt':prompt, 'negative_prompt':negative_prompt, 'seeds':seeds}
-
-class BuildModelLoaderAction(BasicAction, MemoryMixin):
-    def forward(self, memory, **states):
-        memory.model_loader_unet = HCPModelLoader(memory.unet)
-        memory.model_loader_TE = HCPModelLoader(memory.text_encoder)
-        return states
-
-class LoadPartAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, model: str, cfg):
-        self.model = model
-        self.cfg = cfg
-
-    def forward(self, memory, **states):
-        model_loader = memory[f"model_loader_{self.model}"]
-        model_loader.load_part(self.cfg)
-        return states
-
-class LoadLoraAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, model: str, cfg):
-        self.model = model
-        self.cfg = cfg
-
-    def forward(self, memory, **states):
-        model_loader = memory[f"model_loader_{self.model}"]
-        model_loader.load_lora(self.cfg)
-        return states
-
-class LoadPluginAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, model: str, cfg):
-        self.model = model
-        self.cfg = cfg
-
-    def forward(self, memory, **states):
-        model_loader = memory[f"model_loader_{self.model}"]
-        model_loader.load_plugin(self.cfg)
-        return states
+        if self.save_cfg:
+            cfgs.seed = seeds[bid]
+            parser.save_configs(cfgs, os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-info"))
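
Note: `SaveImageAction` keeps the `{index}-{seed}-{prompt}` naming but can now also write the generation config next to each image; with `save_cfg=True` it expects `cfgs` and `parser` in `states`. A hypothetical call with config saving disabled (the `images`/`prompts`/`neg_prompts`/`seeds` lists are assumed):

```python
# Hypothetical: save a batch of PIL images without per-image config dumps.
from hcpdiff.workflow.io import SaveImageAction

save = SaveImageAction(save_root='output/', image_type='png', save_cfg=False)
states = save(images=images, prompt=prompts,
              negative_prompt=neg_prompts, seeds=seeds)
```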