hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (210)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/controlnet.py +16 -16
  58. hcpdiff/models/lora_base_patch.py +14 -25
  59. hcpdiff/models/lora_layers.py +3 -9
  60. hcpdiff/models/lora_layers_patch.py +14 -24
  61. hcpdiff/models/text_emb_ex.py +84 -6
  62. hcpdiff/models/textencoder_ex.py +54 -18
  63. hcpdiff/models/wrapper/__init__.py +3 -0
  64. hcpdiff/models/wrapper/pixart.py +19 -0
  65. hcpdiff/models/wrapper/sd.py +218 -0
  66. hcpdiff/models/wrapper/utils.py +20 -0
  67. hcpdiff/parser/__init__.py +1 -0
  68. hcpdiff/parser/embpt.py +32 -0
  69. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  70. hcpdiff/tools/dataset_generator.py +94 -0
  71. hcpdiff/tools/download_hf_model.py +24 -0
  72. hcpdiff/tools/init_proj.py +3 -21
  73. hcpdiff/tools/lora_convert.py +18 -17
  74. hcpdiff/tools/save_model.py +12 -0
  75. hcpdiff/tools/sd2diffusers.py +1 -1
  76. hcpdiff/train_colo.py +1 -1
  77. hcpdiff/train_deepspeed.py +1 -1
  78. hcpdiff/trainer_ac.py +79 -0
  79. hcpdiff/trainer_ac_single.py +31 -0
  80. hcpdiff/utils/__init__.py +0 -2
  81. hcpdiff/utils/inpaint_pipe.py +7 -2
  82. hcpdiff/utils/net_utils.py +29 -6
  83. hcpdiff/utils/pipe_hook.py +24 -7
  84. hcpdiff/utils/utils.py +21 -4
  85. hcpdiff/workflow/__init__.py +15 -10
  86. hcpdiff/workflow/daam/__init__.py +1 -0
  87. hcpdiff/workflow/daam/act.py +66 -0
  88. hcpdiff/workflow/daam/hook.py +109 -0
  89. hcpdiff/workflow/diffusion.py +114 -125
  90. hcpdiff/workflow/fast.py +31 -0
  91. hcpdiff/workflow/flow.py +67 -0
  92. hcpdiff/workflow/io.py +36 -130
  93. hcpdiff/workflow/model.py +46 -43
  94. hcpdiff/workflow/text.py +78 -46
  95. hcpdiff/workflow/utils.py +32 -12
  96. hcpdiff/workflow/vae.py +37 -38
  97. hcpdiff-2.1.dist-info/METADATA +285 -0
  98. hcpdiff-2.1.dist-info/RECORD +114 -0
  99. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  100. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  101. hcpdiff/ckpt_manager/base.py +0 -16
  102. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  103. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  104. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  105. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  106. hcpdiff/data/bucket.py +0 -358
  107. hcpdiff/data/caption_loader.py +0 -80
  108. hcpdiff/data/cond_dataset.py +0 -40
  109. hcpdiff/data/crop_info_dataset.py +0 -40
  110. hcpdiff/data/data_processor.py +0 -33
  111. hcpdiff/data/pair_dataset.py +0 -146
  112. hcpdiff/data/sampler.py +0 -54
  113. hcpdiff/data/source/base.py +0 -30
  114. hcpdiff/data/utils.py +0 -80
  115. hcpdiff/deprecated/__init__.py +0 -1
  116. hcpdiff/deprecated/cfg_converter.py +0 -81
  117. hcpdiff/deprecated/lora_convert.py +0 -31
  118. hcpdiff/infer_workflow.py +0 -57
  119. hcpdiff/loggers/__init__.py +0 -13
  120. hcpdiff/loggers/base_logger.py +0 -76
  121. hcpdiff/loggers/cli_logger.py +0 -40
  122. hcpdiff/loggers/preview/__init__.py +0 -1
  123. hcpdiff/loggers/preview/image_previewer.py +0 -149
  124. hcpdiff/loggers/tensorboard_logger.py +0 -30
  125. hcpdiff/loggers/wandb_logger.py +0 -31
  126. hcpdiff/loggers/webui_logger.py +0 -9
  127. hcpdiff/loss/min_snr_loss.py +0 -52
  128. hcpdiff/models/layers.py +0 -81
  129. hcpdiff/models/plugin.py +0 -348
  130. hcpdiff/models/wrapper.py +0 -75
  131. hcpdiff/noise/__init__.py +0 -3
  132. hcpdiff/noise/noise_base.py +0 -16
  133. hcpdiff/noise/pyramid_noise.py +0 -50
  134. hcpdiff/noise/zero_terminal.py +0 -44
  135. hcpdiff/train_ac.py +0 -566
  136. hcpdiff/train_ac_single.py +0 -39
  137. hcpdiff/utils/caption_tools.py +0 -105
  138. hcpdiff/utils/cfg_net_tools.py +0 -321
  139. hcpdiff/utils/cfg_resolvers.py +0 -16
  140. hcpdiff/utils/ema.py +0 -52
  141. hcpdiff/utils/img_size_tool.py +0 -248
  142. hcpdiff/vis/__init__.py +0 -3
  143. hcpdiff/vis/base_interface.py +0 -12
  144. hcpdiff/vis/disk_interface.py +0 -48
  145. hcpdiff/vis/webui_interface.py +0 -17
  146. hcpdiff/viser_fast.py +0 -138
  147. hcpdiff/visualizer.py +0 -265
  148. hcpdiff/visualizer_reloadable.py +0 -237
  149. hcpdiff/workflow/base.py +0 -59
  150. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  198. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  206. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  207. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  208. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  209. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/workflow/daam/hook.py (new file)
@@ -0,0 +1,109 @@
+from daam import AggregateHooker, RawHeatMapCollection, UNetCrossAttentionLocator, GlobalHeatMap
+from daam.trace import UNetCrossAttentionHooker
+from typing import List
+from diffusers import UNet2DConditionModel
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+def auto_autocast(*args, **kwargs):
+    if not torch.cuda.is_available():
+        kwargs['enabled'] = False
+
+    return torch.cuda.amp.autocast(*args, **kwargs)
+
+class DiffusionHeatMapHooker(AggregateHooker):
+    def __init__(
+            self,
+            unet: UNet2DConditionModel,
+            tokenizer,
+            vae_scale_factor: int,
+            low_memory: bool = False,
+            load_heads: bool = False,
+            save_heads: bool = False,
+            data_dir: str = None
+    ):
+        self.all_heat_maps = RawHeatMapCollection()
+        h = (unet.config.sample_size * vae_scale_factor)
+        self.latent_hw = 4096 if h == 512 or h == 1024 else 9216  # 64x64 or 96x96 depending on if it's 2.0-v or 2.0
+        locate_middle = load_heads or save_heads
+        self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
+        self.last_prompt: str = ''
+        self.last_image: Image.Image = None
+        self.time_idx = 0
+        self._gen_idx = 0
+
+        self.tokenizer = tokenizer
+
+        modules = [
+            UNetCrossAttentionHooker(
+                x,
+                self,
+                layer_idx=idx,
+                latent_hw=self.latent_hw,
+                load_heads=load_heads,
+                save_heads=save_heads,
+                data_dir=data_dir
+            ) for idx, x in enumerate(self.locator.locate(unet))
+        ]
+
+        super().__init__(modules)
+
+    def time_callback(self, *args, **kwargs):
+        self.time_idx += 1
+
+    @property
+    def layer_names(self):
+        return self.locator.layer_names
+
+    def compute_global_heat_map(self, prompt=None, factors=None, head_idxs: List[int]=None, layer_idx=None, normalize=False):
+        # type: (str, List[float], int, int, bool) -> GlobalHeatMap
+        """
+        Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different
+        spatial transformer block heat maps).
+
+        Args:
+            prompt: The prompt to compute the heat map for. If none, uses the last prompt that was used for generation.
+            factors: Restrict the application to heat maps with spatial factors in this set. If `None`, use all sizes.
+            head_idxs: Restrict the application to heat maps with these head indices. If `None`, use all heads.
+            layer_idx: Restrict the application to heat maps with this layer index. If `None`, use all layers.
+
+        Returns:
+            A heat map object for computing word-level heat maps.
+        """
+        heat_maps = self.all_heat_maps
+
+        if prompt is None:
+            prompt = self.last_prompt
+
+        if factors is None:
+            factors = {0, 1, 2, 4, 8, 16, 32, 64}
+        else:
+            factors = set(factors)
+
+        all_merges = []
+        x = int(np.sqrt(self.latent_hw))
+
+        with auto_autocast(dtype=torch.float32):
+            for (factor, layer, head), heat_map in heat_maps:
+                if (head_idxs is None or head in head_idxs) and (layer_idx is None or layer_idx == layer):
+                    heat_map = heat_map.unsqueeze(1)/25  # [L,1,H,W]
+                    # The clamping fixes undershoot.
+                    all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0))
+
+        try:
+            maps = torch.stack(all_merges, dim=0)  # [B*head, L, 1, H, W]
+        except RuntimeError:
+            if head_idxs is not None or layer_idx is not None:
+                raise RuntimeError('No heat maps found for the given parameters.')
+            else:
+                raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?')
+
+        maps = maps.mean(0)[:, 0]  # [L,H,W]
+        # maps = maps[:len(self.tokenizer.tokenize(prompt)) + 2]  # 1 for SOS and 1 for padding
+
+        if normalize:
+            maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6)  # drop out [SOS] and [PAD] for proper probabilities
+
+        return GlobalHeatMap(self.tokenizer, prompt, maps)
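Usage sketch (not part of the diff): the new tracer can be wired to a diffusers pipeline roughly as below. The pipeline loading and the context-manager use are assumptions based on daam's AggregateHooker; only the constructor and compute_global_heat_map signatures come from the code above.

    from diffusers import StableDiffusionPipeline
    from hcpdiff.workflow.daam.hook import DiffusionHeatMapHooker

    pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5').to('cuda')
    tracer = DiffusionHeatMapHooker(pipe.unet, pipe.tokenizer, vae_scale_factor=8)
    with tracer:  # assumption: AggregateHooker attaches/removes the cross-attention hooks as a context manager
        image = pipe('a photo of a corgi').images[0]
    heat_map = tracer.compute_global_heat_map(prompt='a photo of a corgi', normalize=True)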
hcpdiff/workflow/diffusion.py
@@ -1,209 +1,198 @@
-import inspect
+import random
+import warnings
 from typing import Dict, Any, Union, List
 
 import torch
+from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
+from hcpdiff.utils import prepare_seed
+from hcpdiff.utils.net_utils import get_dtype, to_cuda
+from rainbowneko.infer import BasicAction
 from torch.cuda.amp import autocast
 
-from .base import BasicAction, from_memory_context, MemoryMixin
-
 try:
     from diffusers.utils import randn_tensor
 except:
     # new version of diffusers
     from diffusers.utils.torch_utils import randn_tensor
 
-from hcpdiff.utils import prepare_seed
-from hcpdiff.utils.net_utils import get_dtype
-import random
-
 class InputFeederAction(BasicAction):
-    @from_memory_context
-    def __init__(self, ex_inputs: Dict[str, Any], unet=None):
-        super().__init__()
+    def __init__(self, ex_inputs: Dict[str, Any], key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.ex_inputs = ex_inputs
-        self.unet = unet
 
-    def forward(self, **states):
-        if hasattr(self.unet, 'input_feeder'):
-            for feeder in self.unet.input_feeder:
-                feeder(self.ex_inputs)
-        return states
+    def forward(self, model, ex_inputs=None, **states):
+        ex_inputs = self.ex_inputs if ex_inputs is None else {**ex_inputs, **self.ex_inputs}
+        if hasattr(model, 'input_feeder'):
+            for feeder in model.input_feeder:
+                feeder(ex_inputs)
 
 class SeedAction(BasicAction):
-    @from_memory_context
-    def __init__(self, seed: Union[int, List[int]], bs: int = 1):
-        super().__init__()
+    def __init__(self, seed: Union[int, List[int]], bs: int = 1, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
         self.seed = seed
         self.bs = bs
 
-    def forward(self, device, **states):
+    def forward(self, device, gen_step=0, **states):
         bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
         if self.seed is None:
             seeds = [None]*bs
         elif isinstance(self.seed, int):
-            seeds = list(range(self.seed, self.seed+bs))
+            seeds = list(range(self.seed+gen_step*bs, self.seed+(gen_step+1)*bs))
         else:
            seeds = self.seed
         seeds = [s or random.randint(0, 1 << 30) for s in seeds]
 
        G = prepare_seed(seeds, device=device)
-        return {**states, 'seeds':seeds, 'generator':G, 'device':device}
+        return {'seeds':seeds, 'generator':G}
 
-class PrepareDiffusionAction(BasicAction, MemoryMixin):
-    def __init__(self, dtype='fp32', amp=True):
-        self.dtype = dtype
+class PrepareDiffusionAction(BasicAction):
+    def __init__(self, model_offload=False, amp=torch.float16, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.model_offload = model_offload
        self.amp = amp
 
-    def forward(self, memory, **states):
-        dtype = get_dtype(self.dtype)
-        memory.unet.to(dtype=dtype)
-        memory.text_encoder.to(dtype=dtype)
-        memory.vae.to(dtype=dtype)
+    def forward(self, device, denoiser, TE, vae, **states):
+        denoiser.to(device)
+        TE.to(device)
+        vae.to(device)
 
-        device = memory.unet.device
-        vae_scale_factor = 2**(len(memory.vae.config.block_out_channels)-1)
-        return {**states, 'dtype':self.dtype, 'amp':self.amp, 'device':device, 'vae_scale_factor':vae_scale_factor}
+        TE.eval()
+        denoiser.eval()
+        vae.eval()
+        return {'amp':self.amp, 'model_offload':self.model_offload}
 
-class MakeTimestepsAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, scheduler=None, N_steps: int = 30, strength: float = None):
-        self.scheduler = scheduler
+class MakeTimestepsAction(BasicAction):
+    def __init__(self, N_steps: int = 30, strength: float = None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
        self.N_steps = N_steps
        self.strength = strength
 
-    def get_timesteps(self, timesteps, strength):
+    def get_timesteps(self, noise_sampler:BaseSampler, timesteps, strength):
        # get the original timestep using init_timestep
        num_inference_steps = len(timesteps)
        init_timestep = min(int(num_inference_steps*strength), num_inference_steps)
 
        t_start = max(num_inference_steps-init_timestep, 0)
-        timesteps = timesteps[t_start*self.scheduler.order:]
+        if isinstance(noise_sampler, DiffusersSampler):
+            timesteps = timesteps[t_start*noise_sampler.scheduler.order:]
+        else:
+            timesteps = timesteps[t_start:]
 
        return timesteps
 
-    def forward(self, memory, device, **states):
-        self.scheduler = self.scheduler or memory.scheduler
-
-        self.scheduler.set_timesteps(self.N_steps, device=device)
-        timesteps = self.scheduler.timesteps
+    def forward(self, noise_sampler:BaseSampler, device, **states):
+        timesteps = noise_sampler.get_timesteps(self.N_steps, device=device)
        if self.strength:
-            timesteps = self.get_timesteps(timesteps, self.strength)
-        alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)
-        return {**states, 'device':device, 'timesteps':timesteps, 'alphas_cumprod':alphas_cumprod}
-
-class MakeLatentAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, scheduler=None, N_ch=4, height=512, width=512):
-        self.scheduler = scheduler
+            timesteps = self.get_timesteps(noise_sampler, timesteps, self.strength)
+            return {'timesteps':timesteps, 'start_timestep':timesteps[:1]}
+        else:
+            return {'timesteps':timesteps}
+
+class MakeLatentAction(BasicAction):
+    def __init__(self, N_ch=4, height=None, width=None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
        self.N_ch = N_ch
        self.height = height
        self.width = width
 
-    def forward(self, memory, generator, device, dtype, bs=None, latents=None, vae_scale_factor=8, start_timestep=None, **states):
+    def forward(self, noise_sampler:BaseSampler, vae, generator, device, dtype, bs=None, latents=None, start_timestep=None,
+                pooled_output=None, crop_coord=None, **states):
        if bs is None:
            if 'prompt' in states:
                bs = len(states['prompt'])
-        scheduler = self.scheduler or memory.scheduler
+        vae_scale_factor = 2**(len(vae.config.block_out_channels)-1)
+        device = torch.device(device)
 
-        shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)
+        if latents is None:
+            shape = (bs, self.N_ch, self.height//vae_scale_factor, self.width//vae_scale_factor)
+        else:
+            if self.height is not None:
+                warnings.warn('latents exist! User-specified width and height will be ignored!')
+            shape = latents.shape
        if isinstance(generator, list) and len(generator) != bs:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {bs}. Make sure the batch size matches the length of the generators."
            )
 
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=get_dtype(dtype))
        if latents is None:
-            # scale the initial noise by the standard deviation required by the scheduler
-            latents = noise*scheduler.init_noise_sigma
+            # scale the initial noise by the standard deviation required by the noise_sampler
+            noise_sampler.generator = generator
+            latents = noise_sampler.init_noise(shape, device=device, dtype=get_dtype(dtype))
        else:
            # image to image
            latents = latents.to(device)
-            latents = scheduler.add_noise(latents, noise, start_timestep)
+            latents, noise = noise_sampler.add_noise(latents, start_timestep)
 
-        return {**states, 'latents':latents, 'device':device, 'dtype':dtype, 'generator':generator}
+        output = {'latents':latents}
 
-class NoisePredAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, unet=None, scheduler=None, guidance_scale: float = 7.0):
+        # SDXL inputs
+        if pooled_output is not None:
+            width, height = shape[3]*vae_scale_factor, shape[2]*vae_scale_factor
+            if crop_coord is None:
+                crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
+            else:
+                crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
+            crop_info = crop_info.to(device).repeat(bs, 1)
+            output['text_embeds'] = pooled_output[-1].to(device)
+
+            if 'negative_prompt' in states:
+                output['crop_info'] = torch.cat([crop_info, crop_info], dim=0)
+
+        return output
+
+class DenoiseAction(BasicAction):
+    def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
        self.guidance_scale = guidance_scale
-        self.unet = unet
-        self.scheduler = scheduler
 
-    def forward(self, memory, t, latents, prompt_embeds, pooled_output=None, encoder_attention_mask=None, crop_info=None,
-                cross_attention_kwargs=None, dtype='fp32', amp=None, **states):
-        self.scheduler = self.scheduler or memory.scheduler
-        self.unet = self.unet or memory.unet
+    def forward(self, denoiser, noise_sampler: BaseSampler, t, latents, prompt_embeds, text_embeds=None, encoder_attention_mask=None, crop_info=None,
+                cross_attention_kwargs=None, dtype='fp32', amp=None, model_offload=False, **states):
+
+        if model_offload:
+            to_cuda(denoiser)  # to_cpu in VAE
 
        with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
            latent_model_input = torch.cat([latents]*2) if self.guidance_scale>1 else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = noise_sampler.c_in(t)*latent_model_input
 
-            if pooled_output is None:
-                noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                       cross_attention_kwargs=cross_attention_kwargs, ).sample
+            if text_embeds is None:
+                noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                      cross_attention_kwargs=cross_attention_kwargs, ).sample
            else:
-                added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
+                added_cond_kwargs = {"text_embeds":text_embeds, "time_ids":crop_info}
                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                       cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
+                noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
+                                      cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
 
            # perform guidance
            if self.guidance_scale>1:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond+self.guidance_scale*(noise_pred_text-noise_pred_uncond)
 
-        return {**states, 'noise_pred':noise_pred, 'latents':latents, 't':t, 'prompt_embeds':prompt_embeds, 'pooled_output':pooled_output,
-                'crop_info':crop_info, 'cross_attention_kwargs':cross_attention_kwargs, 'dtype':dtype, 'amp':amp}
-
-class SampleAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, scheduler=None, eta=0.0):
-        self.scheduler = scheduler
-        self.eta = eta
-
-    def prepare_extra_step_kwargs(self, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
-    def forward(self, memory, noise_pred, t, latents, generator, **states):
-        self.scheduler = self.scheduler or memory.scheduler
-
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, self.eta)
+        return {'noise_pred':noise_pred}
 
+class SampleAction(BasicAction):
+    def forward(self, noise_sampler: BaseSampler, noise_pred, t, latents, generator, **states):
        # compute the previous noisy sample x_t -> x_t-1
-        sc_out = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
-        latents = sc_out.prev_sample
-        return {**states, 'latents':latents, 't':t, 'generator':generator}
-
-class DiffusionStepAction(BasicAction, MemoryMixin):
-    @from_memory_context
-    def __init__(self, unet=None, scheduler=None, guidance_scale: float = 7.0):
-        self.act_noise_pred = NoisePredAction(unet, scheduler, guidance_scale)
-        self.act_sample = SampleAction(scheduler)
-
-    def forward(self, memory, **states):
-        states = self.act_noise_pred(memory=memory, **states)
-        states = self.act_sample(memory=memory, **states)
+        latents = noise_sampler.denoise(latents, t, noise_pred, generator=generator)
+        return {'latents':latents}
+
+class DiffusionStepAction(BasicAction):
+    def __init__(self, guidance_scale: float = 7.0, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        self.act_noise_pred = DenoiseAction(guidance_scale)
+        self.act_sample = SampleAction()
+
+    def forward(self, denoiser, noise_sampler, **states):
+        states = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
+        states = self.act_sample(**states)
        return states
 
 class X0PredAction(BasicAction):
-    def forward(self, latents, alphas_cumprod, t, noise_pred, **states):
-        # x_t -> x_0
-        alpha_prod_t = alphas_cumprod[t.long()]
-        beta_prod_t = 1-alpha_prod_t
-        latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5)  # approximate x_0
-        return {**states, 'latents_x0':latents_x0, 'latents':latents, 'alphas_cumprod':alphas_cumprod, 't':t, 'noise_pred':noise_pred}
+    def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
+        latents_x0 = noise_sampler.eps_to_x0(noise_pred, latents, t)
+        return {'latents_x0':latents_x0}
+
+def time_iter(timesteps, **states):
+    return [{'t':t} for t in timesteps]
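The pattern running through this hunk: actions no longer thread a `memory` object and re-emit the full state dict; they read what they need (denoiser, noise_sampler, vae, ...) from the shared states and return only the keys they update. A toy runner illustrating that convention (the real merging lives in rainbowneko's BasicAction/flow machinery; this is a sketch, not that API):

    def run_flow(actions, states):
        # Merge each action's partial update back into the shared state dict.
        for act in actions:
            updates = act(**states)   # an action picks the kwargs it needs via **states
            if updates:
                states.update(updates)  # only the changed keys come back
        return states

    states = run_flow([lambda device, **s: {'generator': f'seeded-on-{device}'}], {'device': 'cuda'})
    print(states)  # {'device': 'cuda', 'generator': 'seeded-on-cuda'}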
hcpdiff/workflow/fast.py (new file)
@@ -0,0 +1,31 @@
+from sfast.compilers.diffusion_pipeline_compiler import (compile_unet, CompilationConfig)
+from rainbowneko.infer import BasicAction
+
+
+class SFastCompileAction(BasicAction):
+
+    @staticmethod
+    def compile_model(unet):
+        # compile model
+        config = CompilationConfig.Default()
+        config.enable_xformers = False
+        try:
+            import xformers
+            config.enable_xformers = True
+        except ImportError:
+            print('xformers not installed, skip')
+        # NOTE:
+        # When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
+        # Disable Triton if you encounter this problem.
+        try:
+            import tritonx
+            config.enable_triton = True
+        except ImportError:
+            print('Triton not installed, skip')
+        config.enable_cuda_graph = True
+
+        return compile_unet(unet, config)
+
+    def forward(self, denoiser, **states):
+        denoiser = self.compile_model(denoiser)
+        return {'denoiser': denoiser}
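Since compile_model is a staticmethod, the stable-fast compilation can also be applied outside a workflow. A hedged sketch (the UNet loading is an assumption; only SFastCompileAction.compile_model comes from the code above):

    from diffusers import UNet2DConditionModel
    from hcpdiff.workflow.fast import SFastCompileAction

    unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet').to('cuda')
    unet = SFastCompileAction.compile_model(unet)  # returns the stable-fast compiled UNet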
hcpdiff/workflow/flow.py (new file)
@@ -0,0 +1,67 @@
+from rainbowneko.infer import BasicAction
+from typing import List, Dict
+from tqdm import tqdm
+import math
+
+class FilePromptAction(BasicAction):
+    def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        if prompt.endswith('.txt'):
+            with open(prompt, 'r') as f:
+                prompt = f.read().split('\n')
+        else:
+            prompt = [prompt]
+
+        if negative_prompt.endswith('.txt'):
+            with open(negative_prompt, 'r') as f:
+                negative_prompt = f.read().split('\n')
+        else:
+            negative_prompt = [negative_prompt]*len(prompt)
+
+        self.prompt = prompt
+        self.negative_prompt = negative_prompt
+        self.bs = bs
+        self.actions = actions
+
+
+    def forward(self, **states):
+        states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
+        states_ref = dict(**states)
+
+        pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
+        N_steps = len(self.actions)
+        for gen_step in pbar:
+            states = dict(**states_ref)
+            feed_data = {'gen_step': gen_step}
+            states.update(feed_data)
+            for step, act in enumerate(self.actions):
+                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
+                states = act(**states)
+        return states
+
+class FlowPromptAction(BasicAction):
+    def __init__(self, actions: List[BasicAction], prompt: str, negative_prompt: str, bs: int = 4, num: int = None, key_map_in=None, key_map_out=None):
+        super().__init__(key_map_in, key_map_out)
+        prompt = [prompt]*num
+        negative_prompt = [negative_prompt]*num
+
+        self.prompt = prompt
+        self.negative_prompt = negative_prompt
+        self.bs = bs
+        self.actions = actions
+
+
+    def forward(self, **states):
+        states.update({'prompt_all':self.prompt, 'negative_prompt_all':self.negative_prompt})
+        states_ref = dict(**states)
+
+        pbar = tqdm(range(math.ceil(len(self.prompt)/self.bs)))
+        N_steps = len(self.actions)
+        for gen_step in pbar:
+            states = dict(**states_ref)
+            feed_data = {'gen_step': gen_step}
+            states.update(feed_data)
+            for step, act in enumerate(self.actions):
+                pbar.set_description(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
+                states = act(**states)
+        return states
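Usage sketch (not part of the diff): FilePromptAction reads one prompt per line when given a .txt path, as implied by the split('\n') above. The keyword call-through to forward() follows rainbowneko's BasicAction and is an assumption here; the empty action list is a placeholder for the diffusion actions from diffusion.py.

    from hcpdiff.workflow.flow import FilePromptAction

    flow = FilePromptAction(
        actions=[],                     # placeholder: e.g. SeedAction, MakeTimestepsAction, ...
        prompt='prompts.txt',           # a .txt path is read and split into one prompt per line
        negative_prompt='lowres, bad anatomy',
        bs=4,                           # ceil(len(prompts)/bs) generation steps
    )
    final_states = flow(device='cuda')  # each gen_step reruns the action list with gen_step in states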