hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/workflow/model.py CHANGED
@@ -1,67 +1,70 @@
+ import torch
  from accelerate import infer_auto_device_map, dispatch_model
  from diffusers.utils.import_utils import is_xformers_available
+ from rainbowneko.infer import BasicAction
 
- from hcpdiff.utils.net_utils import get_dtype, to_cpu, to_cuda
+ from hcpdiff.utils.net_utils import get_dtype
+ from hcpdiff.utils.net_utils import to_cpu
  from hcpdiff.utils.utils import size_to_int, int_to_size
- from .base import BasicAction, from_memory_context, MemoryMixin
 
- class VaeOptimizeAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, vae=None, slicing=True, tiling=False):
-         super().__init__()
+ class VaeOptimizeAction(BasicAction):
+     def __init__(self, slicing=True, tiling=False, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.slicing = slicing
          self.tiling = tiling
-         self.vae = vae
-
-     def forward(self, memory, **states):
-         vae = self.vae or memory.vae
 
+     def forward(self, vae, **states):
          if self.tiling:
              vae.enable_tiling()
          if self.slicing:
             vae.enable_slicing()
-         return states
 
- class BuildOffloadAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, max_VRAM: str, max_RAM: str):
-         super().__init__()
+ class BuildOffloadAction(BasicAction):
+     def __init__(self, max_VRAM: str, max_RAM: str, vae_cpu=False, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.max_VRAM = max_VRAM
          self.max_RAM = max_RAM
+         self.vae_cpu = vae_cpu
 
-     def forward(self, memory, dtype: str, **states):
+     def forward(self, vae, denoiser, dtype: str, **states):
+         # denoiser offload
          torch_dtype = get_dtype(dtype)
          vram = size_to_int(self.max_VRAM)
-         device_map = infer_auto_device_map(memory.unet, max_memory={0:int_to_size(vram >> 1), "cpu":self.max_RAM}, dtype=torch_dtype)
-         memory.unet = dispatch_model(memory.unet, device_map)
+         device_map = infer_auto_device_map(denoiser, max_memory={0:int_to_size(vram >> 1), "cpu":self.max_RAM}, dtype=torch_dtype)
+         denoiser = dispatch_model(denoiser, device_map)
 
-         device_map = infer_auto_device_map(memory.vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch_dtype)
-         memory.vae = dispatch_model(memory.vae, device_map)
-         return {'dtype':dtype, **states}
+         device_map = infer_auto_device_map(vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch_dtype)
+         vae = dispatch_model(vae, device_map)
+         # VAE offload
+         vram = size_to_int(self.max_VRAM)
+         if not self.vae_cpu:
+             device_map = infer_auto_device_map(vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch.float32)
+             vae = dispatch_model(vae, device_map)
+         else:
+             to_cpu(vae)
+             vae_decode_raw = vae.decode
 
- class XformersEnableAction(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         if is_xformers_available():
-             memory.unet.enable_xformers_memory_efficient_attention()
-             # self.te_hook.enable_xformers()
-         return states
+             def vae_decode_offload(latents, return_dict=True, decode_raw=vae.decode):
+                 vae.to(dtype=torch.float32)
+                 res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
+                 return res
 
- class StartTextEncode(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         to_cuda(memory.text_encoder)
-         return states
+             vae.decode = vae_decode_offload
 
- class EndTextEncode(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         to_cpu(memory.text_encoder)
-         return states
+             vae_encode_raw = vae.encode
 
- class StartDiffusion(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         to_cuda(memory.unet)
-         return states
+             def vae_encode_offload(x, return_dict=True, encode_raw=vae.encode):
+                 vae.to(dtype=torch.float32)
+                 res = encode_raw(x.cpu().to(dtype=torch.float32), return_dict=return_dict)
+                 return res
 
- class EndDiffusion(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         to_cpu(memory.unet)
-         return states
+             vae.encode = vae_encode_offload
+             return {'denoiser':denoiser, 'vae':vae, 'vae_decode_raw':vae_decode_raw, 'vae_encode_raw':vae_encode_raw}
+
+         return {'denoiser':denoiser, 'vae':vae}
+
+ class XformersEnableAction(BasicAction):
+     def forward(self, denoiser, **states):
+         if is_xformers_available():
+             denoiser.enable_xformers_memory_efficient_attention()
+             # self.te_hook.enable_xformers()
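The 2.1 refactor moves these actions onto rainbowneko's BasicAction: instead of reading and writing a shared memory object, each action receives named states (denoiser, vae, dtype, ...) and returns the states it updates, with key_map_in/key_map_out available for renaming. A minimal sketch of driving the new actions by hand, assuming unet and vae are already-loaded diffusers modules (the variable names and sizes are illustrative, not from the package):

# Hedged sketch of the 2.1 calling convention; `unet` and `vae` are assumed to be
# already-loaded diffusers modules, and the state keys follow the signatures in the diff above.
from hcpdiff.workflow.model import VaeOptimizeAction, BuildOffloadAction, XformersEnableAction

states = {'denoiser': unet, 'vae': vae, 'dtype': 'fp16'}

VaeOptimizeAction(slicing=True).forward(**states)              # enables VAE slicing in place
states.update(BuildOffloadAction(max_VRAM='10GB', max_RAM='32GB').forward(**states))
XformersEnableAction().forward(**states)                       # no-op if xformers is unavailable

In a real workflow these actions are executed by the flow runner rather than called directly; calling forward here only illustrates the new state keys.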
hcpdiff/workflow/text.py CHANGED
@@ -1,80 +1,112 @@
  from typing import List, Union
 
  import torch
- from torch.cuda.amp import autocast
-
  from hcpdiff.models import TokenizerHook
  from hcpdiff.models.compose import ComposeTEEXHook, ComposeEmbPTHook
- from .base import BasicAction, from_memory_context, MemoryMixin
+ from hcpdiff.utils import pad_attn_bias
  from hcpdiff.utils.net_utils import get_dtype, to_cpu, to_cuda
+ from rainbowneko.infer import BasicAction
+ from torch.cuda.amp import autocast
 
- class TextHookAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, TE=None, tokenizer=None, emb_dir: str = 'embs/', N_repeats: int = 1, layer_skip: int = 0, TE_final_norm: bool = True):
-         super().__init__()
-         self.TE = TE
-         self.tokenizer = tokenizer
+ class TextHookAction(BasicAction):
+     def __init__(self, emb_dir: str = None, N_repeats: int = 1, layer_skip: int = 0, TE_final_norm: bool = True,
+                  use_attention_mask=False, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
 
          self.emb_dir = emb_dir
          self.N_repeats = N_repeats
          self.layer_skip = layer_skip
          self.TE_final_norm = TE_final_norm
-
-     def forward(self, memory, **states):
-         self.TE = self.TE or memory.text_encoder
-         self.tokenizer = self.tokenizer or memory.tokenizer
-
-         memory.emb_hook, _ = ComposeEmbPTHook.hook_from_dir(self.emb_dir, self.tokenizer, self.TE, N_repeats=self.N_repeats)
-         memory.te_hook = ComposeTEEXHook.hook(self.TE, self.tokenizer, N_repeats=self.N_repeats, device='cuda',
-                                               clip_skip=self.layer_skip, clip_final_norm=self.TE_final_norm)
-         memory.token_ex = TokenizerHook(self.tokenizer)
-         return states
-
- class TextEncodeAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, te_hook=None):
-         super().__init__()
+         self.use_attention_mask = use_attention_mask
+
+     def forward(self, TE, tokenizer, in_preview=False, te_hook:ComposeTEEXHook=None, emb_hook=None, **states):
+         if in_preview and emb_hook is not None:
+             emb_hook.N_repeats = self.N_repeats
+         else:
+             emb_hook, _ = ComposeEmbPTHook.hook_from_dir(self.emb_dir, tokenizer, TE, N_repeats=self.N_repeats)
+         tokenizer.N_repeats = self.N_repeats
+
+         if in_preview:
+             te_hook.N_repeats = self.N_repeats
+             te_hook.clip_skip = self.layer_skip
+             te_hook.clip_final_norm = self.TE_final_norm
+             te_hook.use_attention_mask = self.use_attention_mask
+         else:
+             te_hook = ComposeTEEXHook.hook(TE, tokenizer, N_repeats=self.N_repeats,
+                                            clip_skip=self.layer_skip, clip_final_norm=self.TE_final_norm, use_attention_mask=self.use_attention_mask)
+         token_ex = TokenizerHook(tokenizer)
+         return {'te_hook':te_hook, 'emb_hook':emb_hook, 'token_ex':token_ex}
+
+ class TextEncodeAction(BasicAction):
+     def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          if isinstance(prompt, str) and bs is not None:
             prompt = [prompt]*bs
             negative_prompt = [negative_prompt]*bs
 
          self.prompt = prompt
          self.negative_prompt = negative_prompt
+         self.bs = bs
 
-         self.te_hook = te_hook
+     def forward(self, te_hook, TE, dtype: str, device, amp=None, gen_step=None, prompt_all=None, negative_prompt_all=None, model_offload=False,
+                 **states):
+         prompt_all = prompt_all or self.prompt
+         negative_prompt_all = negative_prompt_all or self.negative_prompt
 
-     def forward(self, memory, dtype: str, device, **states):
-         te_hook = self.te_hook or memory.te_hook
-         with autocast(enabled=dtype == 'amp'):
-             emb, pooled_output = te_hook.encode_prompt_to_emb(self.negative_prompt+self.prompt)
-             # emb = emb.to(dtype=get_dtype(dtype), device=device)
-         return {**states, 'prompt':self.prompt, 'negative_prompt':self.negative_prompt, 'prompt_embeds':emb, 'device':device, 'dtype':dtype}
+         if gen_step is not None:
+             idx = (gen_step*self.bs)%len(prompt_all)
+             prompt = prompt_all[idx:idx+self.bs]
+             negative_prompt = negative_prompt_all[idx:idx+self.bs]
+         else:
+             prompt = prompt_all
+             negative_prompt = negative_prompt_all
+
+         if model_offload:
+             to_cuda(TE)
+
+         with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
+             emb, pooled_output, attention_mask = te_hook.encode_prompt_to_emb(negative_prompt+prompt)
+             if attention_mask is not None:
+                 emb, attention_mask = pad_attn_bias(emb, attention_mask)
+
+         if model_offload:
+             to_cpu(TE)
+
+         if not isinstance(te_hook, ComposeTEEXHook):
+             pooled_output = None
+         return {'prompt':prompt, 'negative_prompt':negative_prompt, 'prompt_embeds':emb, 'encoder_attention_mask':attention_mask,
+                 'pooled_output':pooled_output}
 
  class AttnMultTextEncodeAction(TextEncodeAction):
-     @from_memory_context
-     def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, te_hook=None, token_ex=None):
-         super().__init__(prompt, negative_prompt, bs, te_hook)
-         self.token_ex = token_ex
-
-     def forward(self, memory, dtype: str, device, **states):
-         te_hook = self.te_hook or memory.te_hook
-         token_ex = self.token_ex or memory.token_ex
-
-         offload = memory.text_encoder.device.type == 'cpu'
-         if offload:
-             to_cuda(memory.text_encoder)
-
-         mult_p, clean_text_p = token_ex.parse_attn_mult(self.prompt)
-         mult_n, clean_text_n = token_ex.parse_attn_mult(self.negative_prompt)
-         with autocast(enabled=dtype == 'amp'):
+
+     def forward(self, te_hook, token_ex, TE, dtype: str, device, amp=None, gen_step=None, prompt_all=None, negative_prompt_all=None,
+                 model_offload=False, **states):
+         prompt_all = prompt_all if prompt_all is not None else self.prompt
+         negative_prompt_all = negative_prompt_all if negative_prompt_all is not None else self.negative_prompt
+
+         if gen_step is not None:
+             idx = (gen_step*self.bs)%len(prompt_all)
+             prompt = prompt_all[idx:idx+self.bs]
+             negative_prompt = negative_prompt_all[idx:idx+self.bs]
+         else:
+             prompt = prompt_all
+             negative_prompt = negative_prompt_all
+
+         if model_offload:
+             to_cuda(TE)
+
+         mult_p, clean_text_p = token_ex.parse_attn_mult(prompt)
+         mult_n, clean_text_n = token_ex.parse_attn_mult(negative_prompt)
+         with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
             emb, pooled_output, attention_mask = te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
-             # emb = emb.to(dtype=dtype, device=device)
+             if attention_mask is not None:
+                 emb, attention_mask = pad_attn_bias(emb, attention_mask)
          emb_n, emb_p = emb.chunk(2)
          emb_p = te_hook.mult_attn(emb_p, mult_p)
          emb_n = te_hook.mult_attn(emb_n, mult_n)
 
-         if offload:
-             to_cpu(memory.text_encoder)
+         if model_offload:
+             to_cpu(TE)
 
-         return {**states, 'prompt':self.prompt, 'negative_prompt':self.negative_prompt, 'prompt_embeds':torch.cat([emb_n, emb_p], dim=0),
-                 'device':device, 'dtype':dtype, 'encoder_attention_mask': attention_mask}
+         return {'prompt':list(clean_text_p), 'negative_prompt':list(clean_text_n), 'prompt_embeds':torch.cat([emb_n, emb_p], dim=0),
+                 'encoder_attention_mask':attention_mask, 'pooled_output':pooled_output}
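The same pattern applies here: hooks and encoded prompts now travel through the state dict ('te_hook', 'emb_hook', 'token_ex', 'prompt_embeds', ...) instead of a memory object, and autocast is controlled by an explicit amp state rather than dtype == 'amp'. A rough sketch, assuming TE and tokenizer are an already-loaded text encoder and its tokenizer (the names are illustrative, not from the package):

# Hedged sketch; `TE` and `tokenizer` are assumed to be a loaded text encoder and its
# tokenizer. The state keys follow the signatures shown in the diff above.
from hcpdiff.workflow.text import TextHookAction, TextEncodeAction

states = dict(TE=TE, tokenizer=tokenizer, dtype='fp16', device='cuda')

# Hooks come back as states instead of being attached to a shared memory object.
states.update(TextHookAction(N_repeats=1, layer_skip=0).forward(**states))

# Prompts are encoded into 'prompt_embeds', 'encoder_attention_mask' and 'pooled_output'.
action = TextEncodeAction(prompt='a photo of a cat', negative_prompt='lowres', bs=2)
states.update(action.forward(amp=None, **states))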
hcpdiff/workflow/utils.py CHANGED
@@ -1,13 +1,14 @@
- import torch
+ from typing import List, Union
 
- from .base import BasicAction, from_memory_context
- from torch import nn
+ import torch
  from PIL import Image
- from typing import List
+ from hcpdiff.data.handler import ControlNetHandler
+ from rainbowneko.infer import BasicAction
+ from torch import nn
 
  class LatentResizeAction(BasicAction):
-     @from_memory_context
-     def __init__(self, width=1024, height=1024, mode='bicubic', antialias=True):
+     def __init__(self, width=1024, height=1024, mode='bicubic', antialias=True, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.size = (height//8, width//8)
          self.mode = mode
          self.antialias = antialias
@@ -16,18 +17,37 @@ class LatentResizeAction(BasicAction):
          latents_dtype = latents.dtype
          latents = nn.functional.interpolate(latents.to(dtype=torch.float32), size=self.size, mode=self.mode)
          latents = latents.to(dtype=latents_dtype)
-         return {**states, 'latents':latents}
+         return {'latents':latents}
 
  class ImageResizeAction(BasicAction):
      # resample name to Image.xxx
      mode_map = {'nearest':Image.NEAREST, 'bilinear':Image.BILINEAR, 'bicubic':Image.BICUBIC, 'lanczos':Image.LANCZOS, 'box':Image.BOX,
-                 'hamming':Image.HAMMING, 'antialias':Image.ANTIALIAS}
+                 'hamming':Image.HAMMING, 'antialias':Image.LANCZOS}
 
-     @from_memory_context
-     def __init__(self, width=1024, height=1024, mode='bicubic'):
+     def __init__(self, width=1024, height=1024, mode='bicubic', key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.size = (width, height)
          self.mode = self.mode_map[mode]
 
-     def forward(self, images:List[Image.Image], **states):
+     def forward(self, images: List[Image.Image], **states):
          images = [image.resize(self.size, resample=self.mode) for image in images]
-         return {**states, 'images':images}
+         return {'images':images}
+
+ class FeedtoCNetAction(BasicAction):
+     def __init__(self, width=None, height=None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.size = (width, height)
+         self.cnet_handler = ControlNetHandler()
+
+     def forward(self, images: Union[List[Image.Image], Image.Image], device='cuda', dtype=None, bs=None, latents=None, **states):
+         if bs is None:
+             if 'prompt' in states:
+                 bs = len(states['prompt'])
+
+         if latents is not None:
+             width, height = latents.shape[3]*8, latents.shape[2]*8
+         else:
+             width, height = self.size
+
+         images = self.cnet_handler.handle(images).to(device, dtype=dtype).expand(bs*2, 3, width, height)
+         return {'ex_inputs':{'cond':images}}
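Two behavioural details worth noting here: ImageResizeAction now maps the 'antialias' mode to Image.LANCZOS, since Image.ANTIALIAS was removed in Pillow 10, and the new FeedtoCNetAction packs the ControlNet condition into 'ex_inputs' via ControlNetHandler. A small sketch of the resize actions, assuming LatentResizeAction.forward takes the latents from the state dict (its signature is not shown in this hunk):

# Hedged sketch; the forward() signature of LatentResizeAction is assumed to be
# forward(self, latents, **states), which is not visible in the hunk above.
import torch
from PIL import Image
from hcpdiff.workflow.utils import LatentResizeAction, ImageResizeAction

latents = torch.randn(1, 4, 64, 64)                        # latent for a 512x512 image
out = LatentResizeAction(width=1024, height=1024).forward(latents=latents)
print(out['latents'].shape)                                # -> torch.Size([1, 4, 128, 128])

image = Image.new('RGB', (512, 512))
out = ImageResizeAction(width=768, height=768, mode='antialias').forward(images=[image])
print(out['images'][0].size)                               # -> (768, 768)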
hcpdiff/workflow/vae.py CHANGED
@@ -1,33 +1,32 @@
- from .base import BasicAction, from_memory_context
- from diffusers import AutoencoderKL
- from diffusers.image_processor import VaeImageProcessor
- from typing import Dict, Any
  import torch
+ from diffusers.image_processor import VaeImageProcessor
  from hcpdiff.utils import to_cuda, to_cpu
  from hcpdiff.utils.net_utils import get_dtype
+ from rainbowneko.infer import BasicAction
 
  class EncodeAction(BasicAction):
-     @from_memory_context
-     def __init__(self, vae: AutoencoderKL, image_processor=None, offload: Dict[str, Any] = None):
-         super().__init__()
-         self.vae = vae
-         self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels)-1)
-         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if image_processor is None else image_processor
-         self.offload = offload
+     def __init__(self, image_processor=None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.image_processor = image_processor
 
-     def forward(self, images, dtype:str, device, generator, bs=None, **states):
+     def forward(self, vae, images, dtype: str, device, generator, bs=None, model_offload=False, **states):
          if bs is None:
             if 'prompt' in states:
                 bs = len(states['prompt'])
+         vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+         if self.image_processor is None:
+             self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
 
          image = self.image_processor.preprocess(images)
-         image = image.to(device=device, dtype=self.vae.dtype)
+         if bs is not None and image.shape[0] != bs:
+             image = image.repeat(bs//image.shape[0], 1, 1, 1)
+         image = image.to(device=device, dtype=vae.dtype)
 
          if image.shape[1] == 4:
             init_latents = image
          else:
-             if self.offload:
-                 to_cuda(self.vae)
+             if model_offload:
+                 to_cuda(vae)
             if isinstance(generator, list) and len(generator) != bs:
                 raise ValueError(
                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -36,38 +35,38 @@ class EncodeAction(BasicAction):
 
          elif isinstance(generator, list):
             init_latents = [
-                 self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(bs)
+                 vae.encode(image[i: i+1]).latent_dist.sample(generator[i]) for i in range(bs)
             ]
             init_latents = torch.cat(init_latents, dim=0)
          else:
-             init_latents = self.vae.encode(image).latent_dist.sample(generator)
+             init_latents = vae.encode(image).latent_dist.sample(generator)
 
-         init_latents = self.vae.config.scaling_factor * init_latents.to(dtype=get_dtype(dtype))
-         if self.offload:
-             to_cpu(self.vae)
-         return {**states, 'latents':init_latents, 'dtype':dtype, 'device':device, 'bs':bs}
+         init_latents = vae.config.scaling_factor*init_latents.to(dtype=get_dtype(dtype))
+         if model_offload:
+             to_cpu(vae)
+         return {'latents':init_latents}
 
  class DecodeAction(BasicAction):
-     @from_memory_context
-     def __init__(self, vae: AutoencoderKL, image_processor=None, output_type='pil', offload: Dict[str, Any] = None, decode_key='latents'):
-         super().__init__()
-         self.vae = vae
-         self.offload = offload
+     def __init__(self, image_processor=None, output_type='pil', key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
 
-         self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels)-1)
-         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if image_processor is None else image_processor
+         self.image_processor = image_processor
          self.output_type = output_type
-         self.decode_key = decode_key
 
-     def forward(self, **states):
-         latents = states[self.decode_key]
-         if self.offload:
-             to_cuda(self.vae)
-         latents = latents.to(dtype=self.vae.dtype)
-         image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
-         if self.offload:
-             to_cpu(self.vae)
+     def forward(self, vae, denoiser, latents, model_offload=False, **states):
+         vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+         if self.image_processor is None:
+             self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+         if model_offload:
+             to_cpu(denoiser)
+             torch.cuda.synchronize()
+             to_cuda(vae)
+         latents = latents.to(dtype=vae.dtype)
+         image = vae.decode(latents/vae.config.scaling_factor, return_dict=False)[0]
+         if model_offload:
+             to_cpu(vae)
 
          do_denormalize = [True]*image.shape[0]
          image = self.image_processor.postprocess(image, output_type=self.output_type, do_denormalize=do_denormalize)
-         return {**states, 'images':image}
+         return {'images':image}
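EncodeAction and DecodeAction likewise stop holding the VAE themselves: the vae (and, for decoding, the denoiser to offload) comes from the state dict, and the image processor is built lazily from vae.config. A minimal round-trip sketch, assuming vae and unet are loaded diffusers modules on CUDA (illustrative names, not from the package):

# Hedged sketch; `vae` and `unet` are assumed to be loaded diffusers modules on CUDA.
# State keys follow the signatures in the diff above.
import torch
from PIL import Image
from hcpdiff.workflow.vae import EncodeAction, DecodeAction

states = dict(vae=vae, denoiser=unet, dtype='fp16', device='cuda',
              generator=torch.Generator('cuda').manual_seed(42))

out = EncodeAction().forward(images=Image.new('RGB', (512, 512)), bs=1, **states)
images = DecodeAction(output_type='pil').forward(latents=out['latents'], **states)['images']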