hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/controlnet.py +16 -16
  58. hcpdiff/models/lora_base_patch.py +14 -25
  59. hcpdiff/models/lora_layers.py +3 -9
  60. hcpdiff/models/lora_layers_patch.py +14 -24
  61. hcpdiff/models/text_emb_ex.py +84 -6
  62. hcpdiff/models/textencoder_ex.py +54 -18
  63. hcpdiff/models/wrapper/__init__.py +3 -0
  64. hcpdiff/models/wrapper/pixart.py +19 -0
  65. hcpdiff/models/wrapper/sd.py +218 -0
  66. hcpdiff/models/wrapper/utils.py +20 -0
  67. hcpdiff/parser/__init__.py +1 -0
  68. hcpdiff/parser/embpt.py +32 -0
  69. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  70. hcpdiff/tools/dataset_generator.py +94 -0
  71. hcpdiff/tools/download_hf_model.py +24 -0
  72. hcpdiff/tools/init_proj.py +3 -21
  73. hcpdiff/tools/lora_convert.py +18 -17
  74. hcpdiff/tools/save_model.py +12 -0
  75. hcpdiff/tools/sd2diffusers.py +1 -1
  76. hcpdiff/train_colo.py +1 -1
  77. hcpdiff/train_deepspeed.py +1 -1
  78. hcpdiff/trainer_ac.py +79 -0
  79. hcpdiff/trainer_ac_single.py +31 -0
  80. hcpdiff/utils/__init__.py +0 -2
  81. hcpdiff/utils/inpaint_pipe.py +7 -2
  82. hcpdiff/utils/net_utils.py +29 -6
  83. hcpdiff/utils/pipe_hook.py +24 -7
  84. hcpdiff/utils/utils.py +21 -4
  85. hcpdiff/workflow/__init__.py +15 -10
  86. hcpdiff/workflow/daam/__init__.py +1 -0
  87. hcpdiff/workflow/daam/act.py +66 -0
  88. hcpdiff/workflow/daam/hook.py +109 -0
  89. hcpdiff/workflow/diffusion.py +114 -125
  90. hcpdiff/workflow/fast.py +31 -0
  91. hcpdiff/workflow/flow.py +67 -0
  92. hcpdiff/workflow/io.py +36 -130
  93. hcpdiff/workflow/model.py +46 -43
  94. hcpdiff/workflow/text.py +78 -46
  95. hcpdiff/workflow/utils.py +32 -12
  96. hcpdiff/workflow/vae.py +37 -38
  97. hcpdiff-2.1.dist-info/METADATA +285 -0
  98. hcpdiff-2.1.dist-info/RECORD +114 -0
  99. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  100. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  101. hcpdiff/ckpt_manager/base.py +0 -16
  102. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  103. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  104. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  105. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  106. hcpdiff/data/bucket.py +0 -358
  107. hcpdiff/data/caption_loader.py +0 -80
  108. hcpdiff/data/cond_dataset.py +0 -40
  109. hcpdiff/data/crop_info_dataset.py +0 -40
  110. hcpdiff/data/data_processor.py +0 -33
  111. hcpdiff/data/pair_dataset.py +0 -146
  112. hcpdiff/data/sampler.py +0 -54
  113. hcpdiff/data/source/base.py +0 -30
  114. hcpdiff/data/utils.py +0 -80
  115. hcpdiff/deprecated/__init__.py +0 -1
  116. hcpdiff/deprecated/cfg_converter.py +0 -81
  117. hcpdiff/deprecated/lora_convert.py +0 -31
  118. hcpdiff/infer_workflow.py +0 -57
  119. hcpdiff/loggers/__init__.py +0 -13
  120. hcpdiff/loggers/base_logger.py +0 -76
  121. hcpdiff/loggers/cli_logger.py +0 -40
  122. hcpdiff/loggers/preview/__init__.py +0 -1
  123. hcpdiff/loggers/preview/image_previewer.py +0 -149
  124. hcpdiff/loggers/tensorboard_logger.py +0 -30
  125. hcpdiff/loggers/wandb_logger.py +0 -31
  126. hcpdiff/loggers/webui_logger.py +0 -9
  127. hcpdiff/loss/min_snr_loss.py +0 -52
  128. hcpdiff/models/layers.py +0 -81
  129. hcpdiff/models/plugin.py +0 -348
  130. hcpdiff/models/wrapper.py +0 -75
  131. hcpdiff/noise/__init__.py +0 -3
  132. hcpdiff/noise/noise_base.py +0 -16
  133. hcpdiff/noise/pyramid_noise.py +0 -50
  134. hcpdiff/noise/zero_terminal.py +0 -44
  135. hcpdiff/train_ac.py +0 -566
  136. hcpdiff/train_ac_single.py +0 -39
  137. hcpdiff/utils/caption_tools.py +0 -105
  138. hcpdiff/utils/cfg_net_tools.py +0 -321
  139. hcpdiff/utils/cfg_resolvers.py +0 -16
  140. hcpdiff/utils/ema.py +0 -52
  141. hcpdiff/utils/img_size_tool.py +0 -248
  142. hcpdiff/vis/__init__.py +0 -3
  143. hcpdiff/vis/base_interface.py +0 -12
  144. hcpdiff/vis/disk_interface.py +0 -48
  145. hcpdiff/vis/webui_interface.py +0 -17
  146. hcpdiff/viser_fast.py +0 -138
  147. hcpdiff/visualizer.py +0 -265
  148. hcpdiff/visualizer_reloadable.py +0 -237
  149. hcpdiff/workflow/base.py +0 -59
  150. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  198. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  206. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  207. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  208. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  209. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/workflow/io.py CHANGED
@@ -1,150 +1,56 @@
  import os
- from typing import List
- import warnings
+ from functools import partial
+ from typing import List, Union

- from diffusers import UNet2DConditionModel, AutoencoderKL, PNDMScheduler
-
- from hcpdiff.utils import auto_text_encoder, auto_tokenizer, to_validate_file
- from hcpdiff.utils.cfg_net_tools import HCPModelLoader, make_plugin
- from hcpdiff.utils.img_size_tool import types_support
+ import torch
+ from hcpdiff.utils import to_validate_file
  from hcpdiff.utils.net_utils import get_dtype
- from .base import BasicAction, from_memory_context, MemoryMixin
-
- class LoadModelsAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, pretrained_model: str, dtype: str, unet=None, text_encoder=None, tokenizer=None, vae=None, scheduler=None):
-         self.pretrained_model = pretrained_model
+ from rainbowneko.ckpt_manager import NekoLoader
+ from rainbowneko.infer import BasicAction
+ from rainbowneko.infer import LoadImageAction as Neko_LoadImageAction
+ from rainbowneko.utils.img_size_tool import types_support
+
+ class BuildModelsAction(BasicAction):
+     def __init__(self, model_loader: partial[NekoLoader.load], dtype: str = torch.float32, device='cuda', key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.model_loader = model_loader
          self.dtype = get_dtype(dtype)
+         self.device = device

-         self.unet = unet
-         self.text_encoder = text_encoder
-         self.tokenizer = tokenizer
-         self.vae = vae
-         self.scheduler = scheduler
+     def forward(self, in_preview=False, model=None, **states):
+         if in_preview:
+             model = self.model_loader(dtype=self.dtype, device=self.device, denoiser=model.denoiser, TE=model.TE, vae=model.vae)
+         else:
+             model = self.model_loader(dtype=self.dtype, device=self.device)

-     def forward(self, memory, **states):
-         memory.unet = self.unet or UNet2DConditionModel.from_pretrained(self.pretrained_model, subfolder="unet", torch_dtype=self.dtype)
-         memory.text_encoder = self.text_encoder or auto_text_encoder(self.pretrained_model, subfolder="text_encoder", torch_dtype=self.dtype)
-         memory.tokenizer = self.tokenizer or auto_tokenizer(self.pretrained_model, subfolder="tokenizer", use_fast=False)
-         memory.vae = self.vae or AutoencoderKL.from_pretrained(self.pretrained_model, subfolder="vae", torch_dtype=self.dtype)
-         memory.scheduler = self.scheduler or PNDMScheduler.from_pretrained(self.pretrained_model, subfolder="scheduler", torch_dtype=self.dtype)
+         if isinstance(model, dict):
+             return model
+         else:
+             return {'model':model}

-         return states
+ class LoadImageAction(Neko_LoadImageAction):
+     def __init__(self, image_paths: Union[str, List[str]], image_transforms=None, key_map_in=None, key_map_out=('input.x -> images',)):
+         super().__init__(image_paths, image_transforms, key_map_in, key_map_out)

  class SaveImageAction(BasicAction):
-     @from_memory_context
-     def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95):
+     def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.save_root = save_root
          self.image_type = image_type
          self.quality = quality
+         self.save_cfg = save_cfg

          os.makedirs(save_root, exist_ok=True)

-     def forward(self, images, prompt, negative_prompt, seeds=None, **states):
-         num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(self.save_root) if x.rsplit('.', 1)[-1] in types_support])+1
+     def forward(self, images, prompt, negative_prompt, seeds, cfgs=None, parser=None, preview_root=None, preview_step=None, **states):
+         save_root = preview_root or self.save_root
+         num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(save_root) if x.rsplit('.', 1)[-1] in types_support])+1

          for bid, (p, pn, img) in enumerate(zip(prompt, negative_prompt, images)):
-             img_path = os.path.join(self.save_root, f"{num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.{self.image_type}")
+             img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.{self.image_type}")
              img.save(img_path, quality=self.quality)
              num_img_exist += 1

-         return {**states, 'images':images, 'prompt':prompt, 'negative_prompt':negative_prompt, 'seeds':seeds}
-
- class BuildModelLoaderAction(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         memory.model_loader_unet = HCPModelLoader(memory.unet)
-         memory.model_loader_TE = HCPModelLoader(memory.text_encoder)
-         return states
-
- class LoadPartAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, model: str, cfg):
-         self.model = model
-         self.cfg = cfg
-
-     def forward(self, memory, **states):
-         model_loader = memory[f"model_loader_{self.model}"]
-         model_loader.load_part(self.cfg)
-         return states
-
- class LoadLoraAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, model: str, cfg):
-         self.model = model
-         self.cfg = cfg
-
-     def forward(self, memory, **states):
-         model_loader = memory[f"model_loader_{self.model}"]
-         lora_group = model_loader.load_lora(self.cfg)
-         if 'lora_dict' not in memory:
-             memory.lora_dict = {}
-         if path in memory.lora_dict:
-             warnings.warn(f"Lora {path} already loaded, and will be replaced!")
-             memory.lora_dict[path].remove()
-         memory.lora_dict[path] = lora_group
-         return states
-
- class BuildPluginAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, model: str, cfg):
-         self.model = model
-         self.cfg = cfg
-
-     def forward(self, memory, **states):
-         if isinstance(self.cfg_merge.plugin_cfg, str):
-             plugin_cfg = load_config(self.cfg_merge.plugin_cfg)
-             plugin_cfg = {'plugin_unet':hydra.utils.instantiate(plugin_cfg['plugin_unet']),
-                           'plugin_TE':hydra.utils.instantiate(plugin_cfg['plugin_TE'])}
-         else:
-             plugin_cfg = self.cfg_merge.plugin_cfg
-         all_plugin_group_unet = make_plugin(memory.unet, plugin_cfg['plugin_unet'])
-         all_plugin_group_TE = make_plugin(memory.text_encoder, plugin_cfg['plugin_TE'])
-
-         if 'plugin_dict' not in memory:
-             memory.plugin_dict = {}
-
-         for name, plugin_group in all_plugin_group_unet.items():
-             memory.plugin_dict[name] = plugin_group
-         for name, plugin_group in all_plugin_group_TE.items():
-             memory.plugin_dict[name] = plugin_group
-
-         return states
-
- class LoadPluginAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, model: str, cfg):
-         self.model = model
-         self.cfg = cfg
-
-     def forward(self, memory, **states):
-         model_loader = memory[f"model_loader_{self.model}"]
-         model_loader.load_plugin(self.cfg)
-         return states
-
- class RemoveLoraAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, path_list: List[str]):
-         self.path_list = path_list
-
-     def forward(self, memory, **states):
-         for path in self.path_list:
-             if path in memory.lora_dict:
-                 memory.lora_dict[path].remove()
-                 del memory.lora_dict[path]
-             else:
-                 warnings.warn(f"Lora {path} not loaded!")
-         return states
-
- class RemovePluginAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, name_list: List[str]):
-         self.name_list = name_list
-
-     def forward(self, memory, **states):
-         for name in self.name_list:
-             if name in memory.plugin_dict:
-                 memory.plugin_dict[name].remove()
-                 del memory.plugin_dict[name]
-             else:
-                 warnings.warn(f"Plugin {name} not loaded!")
-         return states
+             if self.save_cfg:
+                 cfgs.seed = seeds[bid]
+                 parser.save_configs(cfgs, os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-info"))
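Note: the 0.9.1 actions above pulled everything from a shared memory object; in 2.1 the rainbowneko BasicAction base passes state between actions as keyword arguments, with key_map_in/key_map_out renaming keys at the boundaries. A minimal sketch of driving the new SaveImageAction by hand, outside a workflow runner (the direct forward() call and the sample values are illustrative assumptions, not the documented 2.1 API):

import os
from PIL import Image
from hcpdiff.workflow.io import SaveImageAction

# save_cfg=False so no cfgs/parser state is needed in this standalone sketch
action = SaveImageAction(save_root='output/', image_type='png', save_cfg=False)
images = [Image.new('RGB', (512, 512))]
action.forward(images=images, prompt=['a photo of a cat'],
               negative_prompt=[''], seeds=[42])
print(os.listdir('output/'))  # e.g. ['1-42-a photo of a cat.png']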
hcpdiff/workflow/model.py CHANGED
@@ -1,67 +1,70 @@
+ import torch
  from accelerate import infer_auto_device_map, dispatch_model
  from diffusers.utils.import_utils import is_xformers_available
+ from rainbowneko.infer import BasicAction

- from hcpdiff.utils.net_utils import get_dtype, to_cpu, to_cuda
+ from hcpdiff.utils.net_utils import get_dtype
+ from hcpdiff.utils.net_utils import to_cpu
  from hcpdiff.utils.utils import size_to_int, int_to_size
- from .base import BasicAction, from_memory_context, MemoryMixin

- class VaeOptimizeAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, vae=None, slicing=True, tiling=False):
-         super().__init__()
+ class VaeOptimizeAction(BasicAction):
+     def __init__(self, slicing=True, tiling=False, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.slicing = slicing
          self.tiling = tiling
-         self.vae = vae
-
-     def forward(self, memory, **states):
-         vae = self.vae or memory.vae

+     def forward(self, vae, **states):
          if self.tiling:
              vae.enable_tiling()
          if self.slicing:
              vae.enable_slicing()
-         return states

- class BuildOffloadAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, max_VRAM: str, max_RAM: str):
-         super().__init__()
+ class BuildOffloadAction(BasicAction):
+     def __init__(self, max_VRAM: str, max_RAM: str, vae_cpu=False, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.max_VRAM = max_VRAM
          self.max_RAM = max_RAM
+         self.vae_cpu = vae_cpu

-     def forward(self, memory, dtype: str, **states):
+     def forward(self, vae, denoiser, dtype: str, **states):
+         # denoiser offload
          torch_dtype = get_dtype(dtype)
          vram = size_to_int(self.max_VRAM)
-         device_map = infer_auto_device_map(memory.unet, max_memory={0:int_to_size(vram >> 1), "cpu":self.max_RAM}, dtype=torch_dtype)
-         memory.unet = dispatch_model(memory.unet, device_map)
+         device_map = infer_auto_device_map(denoiser, max_memory={0:int_to_size(vram >> 1), "cpu":self.max_RAM}, dtype=torch_dtype)
+         denoiser = dispatch_model(denoiser, device_map)

-         device_map = infer_auto_device_map(memory.vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch_dtype)
-         memory.vae = dispatch_model(memory.vae, device_map)
-         return {'dtype':dtype, **states}
+         device_map = infer_auto_device_map(vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch_dtype)
+         vae = dispatch_model(vae, device_map)
+         # VAE offload
+         vram = size_to_int(self.max_VRAM)
+         if not self.vae_cpu:
+             device_map = infer_auto_device_map(vae, max_memory={0:int_to_size(vram >> 5), "cpu":self.max_RAM}, dtype=torch.float32)
+             vae = dispatch_model(vae, device_map)
+         else:
+             to_cpu(vae)
+             vae_decode_raw = vae.decode

- class XformersEnableAction(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         if is_xformers_available():
-             memory.unet.enable_xformers_memory_efficient_attention()
-             # self.te_hook.enable_xformers()
-         return states
+             def vae_decode_offload(latents, return_dict=True, decode_raw=vae.decode):
+                 vae.to(dtype=torch.float32)
+                 res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
+                 return res

- class StartTextEncode(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         to_cuda(memory.text_encoder)
-         return states
+             vae.decode = vae_decode_offload

- class EndTextEncode(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         to_cpu(memory.text_encoder)
-         return states
+             vae_encode_raw = vae.encode

- class StartDiffusion(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         to_cuda(memory.unet)
-         return states
+             def vae_encode_offload(x, return_dict=True, encode_raw=vae.encode):
+                 vae.to(dtype=torch.float32)
+                 res = encode_raw(x.cpu().to(dtype=torch.float32), return_dict=return_dict)
+                 return res

- class EndDiffusion(BasicAction, MemoryMixin):
-     def forward(self, memory, **states):
-         to_cpu(memory.unet)
-         return states
+             vae.encode = vae_encode_offload
+             return {'denoiser':denoiser, 'vae':vae, 'vae_decode_raw':vae_decode_raw, 'vae_encode_raw':vae_encode_raw}
+
+         return {'denoiser':denoiser, 'vae':vae}
+
+ class XformersEnableAction(BasicAction):
+     def forward(self, denoiser, **states):
+         if is_xformers_available():
+             denoiser.enable_xformers_memory_efficient_attention()
+             # self.te_hook.enable_xformers()
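Note: with vae_cpu=True, BuildOffloadAction keeps the VAE on the CPU and monkey-patches vae.decode/vae.encode so every call first casts the module and its input to float32 on the CPU, while the raw methods are returned as vae_decode_raw/vae_encode_raw for later restoration. The same wrapper pattern in isolation, on a stand-in nn.Module rather than a real AutoencoderKL (an illustrative sketch):

import torch
from torch import nn

vae = nn.Linear(4, 4)      # stand-in for the real VAE
decode_raw = vae.forward   # keep a handle on the original method

def decode_offload(latents, decode_raw=decode_raw):
    # force the module and the input onto CPU/float32 before the real call
    vae.to(dtype=torch.float32)
    return decode_raw(latents.cpu().to(dtype=torch.float32))

vae.forward = decode_offload  # instance attribute shadows the bound method
out = vae.forward(torch.randn(1, 4, dtype=torch.float64))
print(out.dtype)  # torch.float32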
hcpdiff/workflow/text.py CHANGED
@@ -3,78 +3,110 @@ from typing import List, Union
  import torch
  from hcpdiff.models import TokenizerHook
  from hcpdiff.models.compose import ComposeTEEXHook, ComposeEmbPTHook
+ from hcpdiff.utils import pad_attn_bias
  from hcpdiff.utils.net_utils import get_dtype, to_cpu, to_cuda
+ from rainbowneko.infer import BasicAction
  from torch.cuda.amp import autocast

- from .base import BasicAction, from_memory_context, MemoryMixin
-
- class TextHookAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, TE=None, tokenizer=None, emb_dir: str = 'embs/', N_repeats: int = 1, layer_skip: int = 0, TE_final_norm: bool = True):
-         super().__init__()
-         self.TE = TE
-         self.tokenizer = tokenizer
+ class TextHookAction(BasicAction):
+     def __init__(self, emb_dir: str = None, N_repeats: int = 1, layer_skip: int = 0, TE_final_norm: bool = True,
+                  use_attention_mask=False, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)

          self.emb_dir = emb_dir
          self.N_repeats = N_repeats
          self.layer_skip = layer_skip
          self.TE_final_norm = TE_final_norm
-
-     def forward(self, memory, **states):
-         self.TE = self.TE or memory.text_encoder
-         self.tokenizer = self.tokenizer or memory.tokenizer
-
-         memory.emb_hook, _ = ComposeEmbPTHook.hook_from_dir(self.emb_dir, self.tokenizer, self.TE, N_repeats=self.N_repeats)
-         memory.te_hook = ComposeTEEXHook.hook(self.TE, self.tokenizer, N_repeats=self.N_repeats, device='cuda',
-                                               clip_skip=self.layer_skip, clip_final_norm=self.TE_final_norm)
-         memory.token_ex = TokenizerHook(self.tokenizer)
-         return states
-
- class TextEncodeAction(BasicAction, MemoryMixin):
-     @from_memory_context
-     def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, te_hook=None):
-         super().__init__()
+         self.use_attention_mask = use_attention_mask
+
+     def forward(self, TE, tokenizer, in_preview=False, te_hook:ComposeTEEXHook=None, emb_hook=None, **states):
+         if in_preview and emb_hook is not None:
+             emb_hook.N_repeats = self.N_repeats
+         else:
+             emb_hook, _ = ComposeEmbPTHook.hook_from_dir(self.emb_dir, tokenizer, TE, N_repeats=self.N_repeats)
+         tokenizer.N_repeats = self.N_repeats
+
+         if in_preview:
+             te_hook.N_repeats = self.N_repeats
+             te_hook.clip_skip = self.layer_skip
+             te_hook.clip_final_norm = self.TE_final_norm
+             te_hook.use_attention_mask = self.use_attention_mask
+         else:
+             te_hook = ComposeTEEXHook.hook(TE, tokenizer, N_repeats=self.N_repeats,
+                                            clip_skip=self.layer_skip, clip_final_norm=self.TE_final_norm, use_attention_mask=self.use_attention_mask)
+         token_ex = TokenizerHook(tokenizer)
+         return {'te_hook':te_hook, 'emb_hook':emb_hook, 'token_ex':token_ex}
+
+ class TextEncodeAction(BasicAction):
+     def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          if isinstance(prompt, str) and bs is not None:
              prompt = [prompt]*bs
              negative_prompt = [negative_prompt]*bs

          self.prompt = prompt
          self.negative_prompt = negative_prompt
+         self.bs = bs
+
+     def forward(self, te_hook, TE, dtype: str, device, amp=None, gen_step=None, prompt_all=None, negative_prompt_all=None, model_offload=False,
+                 **states):
+         prompt_all = prompt_all or self.prompt
+         negative_prompt_all = negative_prompt_all or self.negative_prompt

-         self.te_hook = te_hook
+         if gen_step is not None:
+             idx = (gen_step*self.bs)%len(prompt_all)
+             prompt = prompt_all[idx:idx+self.bs]
+             negative_prompt = negative_prompt_all[idx:idx+self.bs]
+         else:
+             prompt = prompt_all
+             negative_prompt = negative_prompt_all
+
+         if model_offload:
+             to_cuda(TE)

-     def forward(self, memory, dtype: str, device, amp=None, **states):
-         te_hook = self.te_hook or memory.te_hook
          with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
-             emb, pooled_output = te_hook.encode_prompt_to_emb(self.negative_prompt+self.prompt)
-             # emb = emb.to(dtype=get_dtype(dtype), device=device)
-         return {**states, 'prompt':self.prompt, 'negative_prompt':self.negative_prompt, 'prompt_embeds':emb, 'amp':amp,
-                 'device':device, 'dtype':dtype}
+             emb, pooled_output, attention_mask = te_hook.encode_prompt_to_emb(negative_prompt+prompt)
+             if attention_mask is not None:
+                 emb, attention_mask = pad_attn_bias(emb, attention_mask)
+
+         if model_offload:
+             to_cpu(TE)
+
+         if not isinstance(te_hook, ComposeTEEXHook):
+             pooled_output = None
+         return {'prompt':prompt, 'negative_prompt':negative_prompt, 'prompt_embeds':emb, 'encoder_attention_mask':attention_mask,
+                 'pooled_output':pooled_output}

  class AttnMultTextEncodeAction(TextEncodeAction):
-     @from_memory_context
-     def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, te_hook=None, token_ex=None):
-         super().__init__(prompt, negative_prompt, bs, te_hook)
-         self.token_ex = token_ex

-     def forward(self, memory, dtype: str, device, amp=None, **states):
-         te_hook = self.te_hook or memory.te_hook
-         token_ex = self.token_ex or memory.token_ex
+     def forward(self, te_hook, token_ex, TE, dtype: str, device, amp=None, gen_step=None, prompt_all=None, negative_prompt_all=None,
+                 model_offload=False, **states):
+         prompt_all = prompt_all if prompt_all is not None else self.prompt
+         negative_prompt_all = negative_prompt_all if negative_prompt_all is not None else self.negative_prompt
+
+         if gen_step is not None:
+             idx = (gen_step*self.bs)%len(prompt_all)
+             prompt = prompt_all[idx:idx+self.bs]
+             negative_prompt = negative_prompt_all[idx:idx+self.bs]
+         else:
+             prompt = prompt_all
+             negative_prompt = negative_prompt_all

-         offload = memory.text_encoder.device.type == 'cpu'
-         if offload:
-             to_cuda(memory.text_encoder)
+         if model_offload:
+             to_cuda(TE)

-         mult_p, clean_text_p = token_ex.parse_attn_mult(self.prompt)
-         mult_n, clean_text_n = token_ex.parse_attn_mult(self.negative_prompt)
+         mult_p, clean_text_p = token_ex.parse_attn_mult(prompt)
+         mult_n, clean_text_n = token_ex.parse_attn_mult(negative_prompt)
          with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
              emb, pooled_output, attention_mask = te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
+             if attention_mask is not None:
+                 emb, attention_mask = pad_attn_bias(emb, attention_mask)
              emb_n, emb_p = emb.chunk(2)
              emb_p = te_hook.mult_attn(emb_p, mult_p)
              emb_n = te_hook.mult_attn(emb_n, mult_n)

-         if offload:
-             to_cpu(memory.text_encoder)
+         if model_offload:
+             to_cpu(TE)

-         return {**states, 'prompt':self.prompt, 'negative_prompt':self.negative_prompt, 'prompt_embeds':torch.cat([emb_n, emb_p], dim=0),
-                 'device':device, 'dtype':dtype, 'amp':amp, 'encoder_attention_mask':attention_mask}
+         return {'prompt':list(clean_text_p), 'negative_prompt':list(clean_text_n), 'prompt_embeds':torch.cat([emb_n, emb_p], dim=0),
+                 'encoder_attention_mask':attention_mask, 'pooled_output':pooled_output}
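Note: both encode actions can now step through a prompt list across generation steps: with batch size bs, step gen_step takes the slice starting at (gen_step*bs) % len(prompt_all). The indexing on its own, with invented values:

prompt_all = ['p0', 'p1', 'p2', 'p3']
bs = 2
for gen_step in range(3):
    idx = (gen_step*bs) % len(prompt_all)
    print(gen_step, prompt_all[idx:idx+bs])
# 0 ['p0', 'p1']
# 1 ['p2', 'p3']
# 2 ['p0', 'p1']  (wraps around)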
hcpdiff/workflow/utils.py CHANGED
@@ -1,13 +1,14 @@
- import torch
+ from typing import List, Union

- from .base import BasicAction, from_memory_context
- from torch import nn
+ import torch
  from PIL import Image
- from typing import List
+ from hcpdiff.data.handler import ControlNetHandler
+ from rainbowneko.infer import BasicAction
+ from torch import nn

  class LatentResizeAction(BasicAction):
-     @from_memory_context
-     def __init__(self, width=1024, height=1024, mode='bicubic', antialias=True):
+     def __init__(self, width=1024, height=1024, mode='bicubic', antialias=True, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.size = (height//8, width//8)
          self.mode = mode
          self.antialias = antialias
@@ -16,18 +17,37 @@ class LatentResizeAction(BasicAction):
          latents_dtype = latents.dtype
          latents = nn.functional.interpolate(latents.to(dtype=torch.float32), size=self.size, mode=self.mode)
          latents = latents.to(dtype=latents_dtype)
-         return {**states, 'latents':latents}
+         return {'latents':latents}

  class ImageResizeAction(BasicAction):
      # resample name to Image.xxx
      mode_map = {'nearest':Image.NEAREST, 'bilinear':Image.BILINEAR, 'bicubic':Image.BICUBIC, 'lanczos':Image.LANCZOS, 'box':Image.BOX,
-                 'hamming':Image.HAMMING, 'antialias':Image.ANTIALIAS}
+                 'hamming':Image.HAMMING, 'antialias':Image.LANCZOS}

-     @from_memory_context
-     def __init__(self, width=1024, height=1024, mode='bicubic'):
+     def __init__(self, width=1024, height=1024, mode='bicubic', key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
          self.size = (width, height)
          self.mode = self.mode_map[mode]

-     def forward(self, images:List[Image.Image], **states):
+     def forward(self, images: List[Image.Image], **states):
          images = [image.resize(self.size, resample=self.mode) for image in images]
-         return {**states, 'images':images}
+         return {'images':images}
+
+ class FeedtoCNetAction(BasicAction):
+     def __init__(self, width=None, height=None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.size = (width, height)
+         self.cnet_handler = ControlNetHandler()
+
+     def forward(self, images: Union[List[Image.Image], Image.Image], device='cuda', dtype=None, bs=None, latents=None, **states):
+         if bs is None:
+             if 'prompt' in states:
+                 bs = len(states['prompt'])
+
+         if latents is not None:
+             width, height = latents.shape[3]*8, latents.shape[2]*8
+         else:
+             width, height = self.size
+
+         images = self.cnet_handler.handle(images).to(device, dtype=dtype).expand(bs*2, 3, width, height)
+         return {'ex_inputs':{'cond':images}}
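Note: both resize paths assume the standard Stable Diffusion factor-8 latent grid: LatentResizeAction stores (height//8, width//8), and FeedtoCNetAction recovers the pixel size from a latent batch of shape (B, 4, H/8, W/8). A quick check of the arithmetic:

import torch

latents = torch.zeros(2, 4, 64, 96)  # (B, C, H/8, W/8)
width, height = latents.shape[3]*8, latents.shape[2]*8
print(width, height)  # 768 512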
hcpdiff/workflow/vae.py CHANGED
@@ -1,33 +1,32 @@
- from .base import BasicAction, from_memory_context
- from diffusers import AutoencoderKL
- from diffusers.image_processor import VaeImageProcessor
- from typing import Dict, Any
  import torch
+ from diffusers.image_processor import VaeImageProcessor
  from hcpdiff.utils import to_cuda, to_cpu
  from hcpdiff.utils.net_utils import get_dtype
+ from rainbowneko.infer import BasicAction

  class EncodeAction(BasicAction):
-     @from_memory_context
-     def __init__(self, vae: AutoencoderKL, image_processor=None, offload: Dict[str, Any] = None):
-         super().__init__()
-         self.vae = vae
-         self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels)-1)
-         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if image_processor is None else image_processor
-         self.offload = offload
+     def __init__(self, image_processor=None, key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.image_processor = image_processor

-     def forward(self, images, dtype:str, device, generator, bs=None, **states):
+     def forward(self, vae, images, dtype: str, device, generator, bs=None, model_offload=False, **states):
          if bs is None:
              if 'prompt' in states:
                  bs = len(states['prompt'])
+         vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+         if self.image_processor is None:
+             self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

          image = self.image_processor.preprocess(images)
-         image = image.to(device=device, dtype=self.vae.dtype)
+         if bs is not None and image.shape[0] != bs:
+             image = image.repeat(bs//image.shape[0], 1, 1, 1)
+         image = image.to(device=device, dtype=vae.dtype)

          if image.shape[1] == 4:
              init_latents = image
          else:
-             if self.offload:
-                 to_cuda(self.vae)
+             if model_offload:
+                 to_cuda(vae)
              if isinstance(generator, list) and len(generator) != bs:
                  raise ValueError(
                      f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -36,38 +35,38 @@ class EncodeAction(BasicAction):

              elif isinstance(generator, list):
                  init_latents = [
-                     self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(bs)
+                     vae.encode(image[i: i+1]).latent_dist.sample(generator[i]) for i in range(bs)
                  ]
                  init_latents = torch.cat(init_latents, dim=0)
              else:
-                 init_latents = self.vae.encode(image).latent_dist.sample(generator)
+                 init_latents = vae.encode(image).latent_dist.sample(generator)

-         init_latents = self.vae.config.scaling_factor * init_latents.to(dtype=get_dtype(dtype))
-         if self.offload:
-             to_cpu(self.vae)
-         return {**states, 'latents':init_latents, 'dtype':dtype, 'device':device, 'bs':bs}
+         init_latents = vae.config.scaling_factor*init_latents.to(dtype=get_dtype(dtype))
+         if model_offload:
+             to_cpu(vae)
+         return {'latents':init_latents}

  class DecodeAction(BasicAction):
-     @from_memory_context
-     def __init__(self, vae: AutoencoderKL, image_processor=None, output_type='pil', offload: Dict[str, Any] = None, decode_key='latents'):
-         super().__init__()
-         self.vae = vae
-         self.offload = offload
+     def __init__(self, image_processor=None, output_type='pil', key_map_in=None, key_map_out=None):
+         super().__init__(key_map_in, key_map_out)

-         self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels)-1)
-         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if image_processor is None else image_processor
+         self.image_processor = image_processor
          self.output_type = output_type
-         self.decode_key = decode_key

-     def forward(self, **states):
-         latents = states[self.decode_key]
-         if self.offload:
-             to_cuda(self.vae)
-         latents = latents.to(dtype=self.vae.dtype)
-         image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
-         if self.offload:
-             to_cpu(self.vae)
+     def forward(self, vae, denoiser, latents, model_offload=False, **states):
+         vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+         if self.image_processor is None:
+             self.image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+         if model_offload:
+             to_cpu(denoiser)
+             torch.cuda.synchronize()
+             to_cuda(vae)
+         latents = latents.to(dtype=vae.dtype)
+         image = vae.decode(latents/vae.config.scaling_factor, return_dict=False)[0]
+         if model_offload:
+             to_cpu(vae)

          do_denormalize = [True]*image.shape[0]
          image = self.image_processor.postprocess(image, output_type=self.output_type, do_denormalize=do_denormalize)
-         return {**states, 'images':image}
+         return {'images':image}
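Note: EncodeAction/DecodeAction now derive the scale factor from the live model at forward time instead of caching it in __init__, since the VAE only arrives through the action state. For the common SD 1.x/2.x VAE config, block_out_channels has four entries, so the factor is 2**(4-1) = 8, i.e. one latent cell per 8x8 pixel block. The computation on its own (channel counts assumed from the usual SD VAE config):

block_out_channels = (128, 256, 512, 512)  # assumed; read from vae.config in 2.1
vae_scale_factor = 2**(len(block_out_channels)-1)
assert vae_scale_factor == 8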