hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/models/plugin.py DELETED
@@ -1,348 +0,0 @@
-"""
-plugin.py
-====================
-:Name: model plugin
-:Author: Dong Ziyi
-:Affiliation: HCP Lab, SYSU
-:Created: 10/03/2023
-:Licence: Apache-2.0
-"""
-
-import weakref
-import re
-from typing import Tuple, List, Dict, Any, Iterable
-
-import torch
-from torch import nn
-
-from hcpdiff.utils.net_utils import split_module_name
-
-class BasePluginBlock(nn.Module):
-    def __init__(self, name: str):
-        super().__init__()
-        self.name = name
-
-    def forward(self, host: nn.Module, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-        return fea_out
-
-    def remove(self):
-        pass
-
-    def feed_input_data(self, data):
-        self.input_data = data
-
-    def register_input_feeder_to(self, host_model):
-        if not hasattr(host_model, 'input_feeder'):
-            host_model.input_feeder = []
-        host_model.input_feeder.append(self.feed_input_data)
-
-    def set_hyper_params(self, **kwargs):
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
-    @staticmethod
-    def extract_state_without_plugin(model: nn.Module, trainable=False):
-        trainable_keys = {k for k, v in model.named_parameters() if v.requires_grad}
-        plugin_names = {k for k, v in model.named_modules() if isinstance(v, BasePluginBlock)}
-        model_sd = {}
-        for k, v in model.state_dict().items():
-            if (not trainable) or k in trainable_keys:
-                for name in plugin_names:
-                    if k.startswith(name):
-                        break
-                else:
-                    model_sd[k] = v
-        return model_sd
-
-    def get_trainable_parameters(self) -> Iterable[nn.Parameter]:
-        return self.parameters()
-
-class WrapablePlugin:
-    wrapable_classes = ()
-
-    @classmethod
-    def wrap_layer(cls, name: str, layer: nn.Module, **kwargs):
-        plugin = cls(name, layer, **kwargs)
-        return plugin
-
-    @classmethod
-    def named_modules_with_exclude(cls, self, memo = None, prefix: str = '', remove_duplicate: bool = True,
-                                   exclude_key=None, exclude_classes=tuple()):
-
-        if memo is None:
-            memo = set()
-        if self not in memo:
-            if remove_duplicate:
-                memo.add(self)
-            if (exclude_key is None or not re.search(exclude_key, prefix)) and not isinstance(self, exclude_classes):
-                yield prefix, self
-            for name, module in self._modules.items():
-                if module is None:
-                    continue
-                submodule_prefix = prefix + ('.' if prefix else '') + name
-                for m in cls.named_modules_with_exclude(module, memo, submodule_prefix, remove_duplicate, exclude_key, exclude_classes):
-                    yield m
-
-    @classmethod
-    def wrap_model(cls, name: str, host: nn.Module, exclude_key=None, exclude_classes=tuple(), **kwargs):  # -> Dict[str, SinglePluginBlock]:
-        '''
-        parent_block and other args required in __init__ will be put into kwargs, compatible with multiple models.
-        '''
-        plugin_block_dict = {}
-        if isinstance(host, cls.wrapable_classes):
-            plugin_block_dict[''] = cls.wrap_layer(name, host, **kwargs)
-        else:
-            named_modules = {layer_name:layer for layer_name, layer in cls.named_modules_with_exclude(
-                host, exclude_key=exclude_key, exclude_classes=exclude_classes)}
-            for layer_name, layer in named_modules.items():
-                if isinstance(layer, cls.wrapable_classes):
-                    # For plugins that need parent_block
-                    if 'parent_block' in kwargs:
-                        parent_name, host_name = split_module_name(layer_name)
-                        kwargs['parent_block'] = named_modules[parent_name]
-                        kwargs['host_name'] = host_name
-                    plugin_block_dict[layer_name] = cls.wrap_layer(name, layer, **kwargs)
-        return plugin_block_dict
-
-class SinglePluginBlock(BasePluginBlock, WrapablePlugin):
-
-    def __init__(self, name: str, host: nn.Module, hook_param=None, host_model=None):
-        super().__init__(name)
-        self.host = weakref.ref(host)
-        setattr(host, name, self)
-
-        if hook_param is None:
-            self.hook_handle = host.register_forward_hook(self.layer_hook)
-        else:  # hook for model parameters
-            self.backup = getattr(host, hook_param)
-            self.target = hook_param
-            self.handle_pre = host.register_forward_pre_hook(self.pre_hook)
-            self.handle_post = host.register_forward_hook(self.post_hook)
-
-    def layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-        return self(fea_in, fea_out)
-
-    def pre_hook(self, host, fea_in: torch.Tensor):
-        host.weight_restored = False
-        host_param = getattr(host, self.target)
-        delattr(host, self.target)
-        setattr(host, self.target, self(host_param))
-        return fea_in
-
-    def post_hook(self, host, fea_int, fea_out):
-        if not getattr(host, 'weight_restored', False):
-            setattr(host, self.target, self.backup)
-            host.weight_restored = True
-
-    def remove(self):
-        host = self.host()
-        delattr(host, self.name)
-        if hasattr(self, 'hook_handle'):
-            self.hook_handle.remove()
-        else:
-            self.handle_pre.remove()
-            self.handle_post.remove()
-
-class PluginBlock(BasePluginBlock):
-    def __init__(self, name, from_layer: Dict[str, Any], to_layer: Dict[str, Any], host_model=None):
-        super().__init__(name)
-        self.host_from = weakref.ref(from_layer['layer'])
-        self.host_to = weakref.ref(to_layer['layer'])
-        setattr(from_layer['layer'], name, self)
-
-        if from_layer['pre_hook']:
-            self.hook_handle_from = from_layer['layer'].register_forward_pre_hook(lambda host, fea_in:self.from_layer_hook(host, fea_in, None))
-        else:
-            self.hook_handle_from = from_layer['layer'].register_forward_hook(
-                lambda host, fea_in, fea_out:self.from_layer_hook(host, fea_in, fea_out))
-        if to_layer['pre_hook']:
-            self.hook_handle_to = to_layer['layer'].register_forward_pre_hook(lambda host, fea_in:self.to_layer_hook(host, fea_in, None))
-        else:
-            self.hook_handle_to = to_layer['layer'].register_forward_hook(lambda host, fea_in, fea_out:self.to_layer_hook(host, fea_in, fea_out))
-
-    def from_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-        self.feat_from = fea_in
-
-    def to_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: torch.Tensor):
-        return self(self.feat_from, fea_in, fea_out)
-
-    def remove(self):
-        host_from = self.host_from()
-        delattr(host_from, self.name)
-        self.hook_handle_from.remove()
-        self.hook_handle_to.remove()
-
-class MultiPluginBlock(BasePluginBlock):
-    def __init__(self, name: str, from_layers: List[Dict[str, Any]], to_layers: List[Dict[str, Any]], host_model=None):
-        super().__init__(name)
-        assert host_model is not None
-        self.host_from = [weakref.ref(x['layer']) for x in from_layers]
-        self.host_to = [weakref.ref(x['layer']) for x in to_layers]
-        self.host_model = weakref.ref(host_model)
-        setattr(host_model, name, self)
-
-        self.feat_from = [None for _ in range(len(from_layers))]
-
-        self.hook_handle_from = []
-        self.hook_handle_to = []
-
-        for idx, layer in enumerate(from_layers):
-            if layer['pre_hook']:
-                handle_from = layer['layer'].register_forward_pre_hook(lambda host, fea_in, idx=idx:self.from_layer_hook(host, fea_in, None, idx))
-            else:
-                handle_from = layer['layer'].register_forward_hook(
-                    lambda host, fea_in, fea_out, idx=idx:self.from_layer_hook(host, fea_in, fea_out, idx))
-            self.hook_handle_from.append(handle_from)
-        for idx, layer in enumerate(to_layers):
-            if layer['pre_hook']:
-                handle_to = layer['layer'].register_forward_pre_hook(lambda host, fea_in, idx=idx:self.to_layer_hook(host, fea_in, None, idx))
-            else:
-                handle_to = layer['layer'].register_forward_hook(lambda host, fea_in, fea_out, idx=idx:self.to_layer_hook(host, fea_in, fea_out, idx))
-            self.hook_handle_to.append(handle_to)
-
-        self.record_count = 0
-
-    def from_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: Tuple[torch.Tensor], idx: int):
-        self.feat_from[idx] = fea_in
-        self.record_count += 1
-        if self.record_count == len(self.feat_from):  # call forward when all feat is record
-            self.record_count = 0
-            self.feat_to = self(self.feat_from)
-
-    def to_layer_hook(self, host, fea_in: Tuple[torch.Tensor], fea_out: Tuple[torch.Tensor], idx: int):
-        return self.feat_to[idx]+fea_out
-
-    def remove(self):
-        host_model = self.host_model()
-        delattr(host_model, self.name)
-        for handle_from in self.hook_handle_from:
-            handle_from.remove()
-        for handle_to in self.hook_handle_to:
-            handle_to.remove()
-
-class PatchPluginContainer(nn.Module):
-    def __init__(self, host_name, host, parent_block):
-        super().__init__()
-        self._host = host
-        self.host_name = host_name
-        self.parent_block = weakref.ref(parent_block)
-        self.plugin_names = []
-
-        delattr(parent_block, host_name)
-        setattr(parent_block, host_name, self)
-
-    def add_plugin(self, name: str, plugin: 'PatchPluginBlock'):
-        setattr(self, name, plugin)
-        self.plugin_names.append(name)
-
-    def remove_plugin(self, name: str):
-        delattr(self, name)
-        self.plugin_names.remove(name)
-        if len(self.plugin_names) == 0:
-            self.remove()
-
-    def forward(self, *args, **kwargs):
-        for name, plugin in self:
-            args, kwargs = plugin.pre_forward(*args, **kwargs)
-        output = self._host(*args, **kwargs)
-        for name, plugin in self:
-            output = plugin.post_forward(output, *args, **kwargs)
-        return output
-
-    def remove(self):
-        parent_block = self.parent_block()
-        delattr(parent_block, self.host_name)
-        setattr(parent_block, self.host_name, self._host)
-
-    def __iter__(self):
-        for name in self.plugin_names:
-            yield name, self[name]
-
-    def __getitem__(self, name):
-        return getattr(self, name)
-
-class PatchPluginBlock(BasePluginBlock, WrapablePlugin):
-    container_cls = PatchPluginContainer
-
-    def __init__(self, name: str, host: nn.Module, host_model=None, parent_block: nn.Module = None, host_name: str = None):
-        super().__init__(name)
-        if isinstance(host, self.container_cls):
-            self.host = weakref.ref(host._host)
-        else:
-            self.host = weakref.ref(host)
-        self.parent_block = weakref.ref(parent_block)
-        self.host_name = host_name
-
-        container = self.get_container(host, host_name, parent_block)
-        container.add_plugin(name, self)
-        self.container = weakref.ref(container)
-
-    def pre_forward(self, *args, **kwargs):
-        return args, kwargs
-
-    def post_forward(self, output, *args, **kwargs):
-        return output
-
-    def remove(self):
-        container = self.container()
-        container.remove_plugin(self.name)
-
-    def get_container(self, host, host_name, parent_block):
-        if isinstance(host, self.container_cls):
-            return host
-        else:
-            return self.container_cls(host_name, host, parent_block)
-
-    @classmethod
-    def wrap_model(cls, name: str, host: nn.Module, exclude_key=None, exclude_classes=tuple(), **kwargs):  # -> Dict[str, SinglePluginBlock]:
-        '''
-        parent_block and other args required in __init__ will be put into kwargs, compatible with multiple models.
-        '''
-        plugin_block_dict = {}
-        if isinstance(host, cls.wrapable_classes):
-            plugin_block_dict[''] = cls.wrap_layer(name, host, **kwargs)
-        else:
-            named_modules = {layer_name:layer for layer_name, layer in cls.named_modules_with_exclude(
-                host, exclude_key=exclude_key or '_host', exclude_classes=exclude_classes)}
-            for layer_name, layer in named_modules.items():
-                if isinstance(layer, cls.wrapable_classes) or isinstance(layer, cls.container_cls):
-                    # For plugins that need parent_block
-                    if 'parent_block' in kwargs:
-                        parent_name, host_name = split_module_name(layer_name)
-                        kwargs['parent_block'] = named_modules[parent_name]
-                        kwargs['host_name'] = host_name
-                    plugin_block_dict[layer_name] = cls.wrap_layer(name, layer, **kwargs)
-        return plugin_block_dict
-
-class PluginGroup:
-    def __init__(self, plugin_dict: Dict[str, BasePluginBlock]):
-        self.plugin_dict = plugin_dict  # {host_model_path: plugin_object}
-
-    def __setitem__(self, k, v):
-        self.plugin_dict[k] = v
-
-    def __getitem__(self, k):
-        return self.plugin_dict[k]
-
-    @property
-    def plugin_name(self):
-        if self.empty():
-            return None
-        return next(iter(self.plugin_dict.values())).name
-
-    def remove(self):
-        for plugin in self.plugin_dict.values():
-            plugin.remove()
-
-    def state_dict(self, model=None):
-        if model is None:
-            return {f'{k}.___.{ks}':vs for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()}
-        else:
-            sd_model = model.state_dict()
-            return {f'{k}.___.{ks}':sd_model[f'{k}.{v.name}.{ks}'] for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()}
-
-    def state_keys_raw(self):
-        return [f'{k}.{v.name}.{ks}' for k, v in self.plugin_dict.items() for ks, vs in v.state_dict().items()]
-
-    def empty(self):
-        return len(self.plugin_dict) == 0
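Note (not part of the diff): the plugin/hook system above was dropped in 2.x. As a reading aid, here is a minimal sketch of how the removed 0.9.x API was typically wired together; `ScalePlugin` is a hypothetical subclass written only for illustration, not code shipped in the package.

import torch
from torch import nn
from hcpdiff.models.plugin import SinglePluginBlock, PluginGroup  # 0.9.x import path

class ScalePlugin(SinglePluginBlock):
    # hypothetical plugin: rescale the output of every nn.Linear it wraps
    wrapable_classes = (nn.Linear,)

    def __init__(self, name, host, host_model=None):
        super().__init__(name, host)              # registers a forward hook on `host`
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, fea_in, fea_out):           # called via SinglePluginBlock.layer_hook
        return fea_out*self.scale

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
plugins = PluginGroup(ScalePlugin.wrap_model('scale_plugin', model))  # {layer_path: plugin}
out = model(torch.randn(2, 4))     # hooks apply each plugin during the forward pass
print(list(plugins.state_dict()))  # plugin-only weights, host weights excluded
plugins.remove()                   # detach all hooks and restore the bare model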
hcpdiff/models/wrapper.py DELETED
@@ -1,75 +0,0 @@
-from torch import nn
-import itertools
-from transformers import CLIPTextModel
-from hcpdiff.utils import pad_attn_bias
-
-class TEUnetWrapper(nn.Module):
-    def __init__(self, unet, TE, train_TE=False):
-        super().__init__()
-        self.unet = unet
-        self.TE = TE
-
-        self.train_TE = train_TE
-
-    def forward(self, prompt_ids, noisy_latents, timesteps, attn_mask=None, position_ids=None, plugin_input={}, **kwargs):
-        input_all = dict(prompt_ids=prompt_ids, noisy_latents=noisy_latents, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
-
-        if hasattr(self.TE, 'input_feeder'):
-            for feeder in self.TE.input_feeder:
-                feeder(input_all)
-        encoder_hidden_states = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)[0]  # Get the text embedding for conditioning
-
-        if attn_mask is not None:
-            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
-
-        input_all['encoder_hidden_states'] = encoder_hidden_states
-        if hasattr(self.unet, 'input_feeder'):
-            for feeder in self.unet.input_feeder:
-                feeder(input_all)
-        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask).sample  # Predict the noise residual
-        return model_pred
-
-    def prepare(self, accelerator):
-        if self.train_TE:
-            return accelerator.prepare(self)
-        else:
-            self.unet = accelerator.prepare(self.unet)
-            return self
-
-    def enable_gradient_checkpointing(self):
-        def grad_ckpt_enable(m):
-            if hasattr(m, 'gradient_checkpointing'):
-                m.training = True
-
-        self.unet.enable_gradient_checkpointing()
-        if self.train_TE:
-            self.TE.gradient_checkpointing_enable()
-            self.apply(grad_ckpt_enable)
-        else:
-            self.unet.apply(grad_ckpt_enable)
-
-    def trainable_parameters(self):
-        if self.train_TE:
-            return itertools.chain(self.unet.parameters(), self.TE.parameters())
-        else:
-            return self.unet.parameters()
-
-class SDXLTEUnetWrapper(TEUnetWrapper):
-    def forward(self, prompt_ids, noisy_latents, timesteps, attn_mask=None, position_ids=None, crop_info=None, plugin_input={}, **kwargs):
-        input_all = dict(prompt_ids=prompt_ids, noisy_latents=noisy_latents, timesteps=timesteps, position_ids=position_ids, attn_mask=attn_mask, **plugin_input)
-
-        if hasattr(self.TE, 'input_feeder'):
-            for feeder in self.TE.input_feeder:
-                feeder(input_all)
-        encoder_hidden_states, pooled_output = self.TE(prompt_ids, position_ids=position_ids, attention_mask=attn_mask, output_hidden_states=True)  # Get the text embedding for conditioning
-
-        added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
-        if attn_mask is not None:
-            encoder_hidden_states, attn_mask = pad_attn_bias(encoder_hidden_states, attn_mask)
-
-        input_all['encoder_hidden_states'] = encoder_hidden_states
-        if hasattr(self.unet, 'input_feeder'):
-            for feeder in self.unet.input_feeder:
-                feeder(input_all)
-        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, encoder_attention_mask=attn_mask, added_cond_kwargs=added_cond_kwargs).sample  # Predict the noise residual
-        return model_pred
hcpdiff/noise/__init__.py DELETED
@@ -1,3 +0,0 @@
-from .noise_base import NoiseBase
-from .pyramid_noise import PyramidNoiseScheduler
-from .zero_terminal import ZeroTerminalScheduler
hcpdiff/noise/noise_base.py DELETED
@@ -1,16 +0,0 @@
-
-class NoiseBase:
-    def __init__(self, base_scheduler):
-        self.base_scheduler = base_scheduler
-
-    def __getattr__(self, item):
-        try:
-            return super(NoiseBase, self).__getattr__(item)
-        except:
-            return getattr(self.base_scheduler, item)
-
-    def __setattr__(self, key, value):
-        if hasattr(super(), 'base_scheduler') and hasattr(self.base_scheduler, key):
-            setattr(self.base_scheduler, key, value)
-        else:
-            super(NoiseBase, self).__setattr__(key, value)
hcpdiff/noise/pyramid_noise.py DELETED
@@ -1,50 +0,0 @@
-import random
-
-import torch
-from torch.nn import functional as F
-from diffusers import SchedulerMixin
-
-from .noise_base import NoiseBase
-
-class PyramidNoiseScheduler(NoiseBase, SchedulerMixin):
-    def __init__(self, base_scheduler, level: int = 10, discount: float = 0.9, step_size: float = 2., resize_mode: str = 'bilinear'):
-        super().__init__(base_scheduler)
-        self.level = level
-        self.step_size = step_size
-        self.resize_mode = resize_mode
-        self.discount = discount
-
-    def add_noise(
-        self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
-        with torch.no_grad():
-            b, c, h, w = noise.shape
-            for i in range(1, self.level):
-                r = random.random()*2+self.step_size
-                wn, hn = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
-                noise += F.interpolate(torch.randn(b, c, hn, wn).to(noise), (h, w), None, self.resize_mode)*(self.discount**i)
-                if wn == 1 or hn == 1:
-                    break
-            noise = noise/noise.std()
-        return self.base_scheduler.add_noise(original_samples, noise, timesteps)
-
-# if __name__ == '__main__':
-#     noise = torch.randn(1,3,512,512)
-#     level=10
-#     discount=0.6
-#     b, c, h, w = noise.shape
-#     for i in range(level):
-#         r = random.random() * 2 + 2
-#         wn, hn = max(1, int(w / (r ** i))), max(1, int(h / (r ** i)))
-#         noise += F.interpolate(torch.randn(b, c, wn, hn).to(noise), (w, h), None, 'bilinear') * discount ** i
-#         if wn == 1 or hn == 1:
-#             break
-#     noise = noise / noise.std()
-#
-#     from matplotlib import pyplot as plt
-#     plt.figure()
-#     plt.imshow(noise[0].permute(1,2,0))
-#     plt.show()
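Note (not part of the diff): the removed schedulers are thin wrappers that delegate unknown attributes to a base diffusers scheduler via `NoiseBase.__getattr__`. A minimal usage sketch under the 0.9.x layout; the choice of `DDPMScheduler` and the parameter values are illustrative, not prescribed by the package:

import torch
from diffusers import DDPMScheduler
from hcpdiff.noise import PyramidNoiseScheduler  # 0.9.x import path

base = DDPMScheduler(num_train_timesteps=1000)
sched = PyramidNoiseScheduler(base, level=6, discount=0.8)  # still exposes `base`'s attributes

latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (2,), dtype=torch.long)

# multi-resolution noise is mixed in before the usual forward-diffusion add_noise
noisy_latents = sched.add_noise(latents, noise, timesteps)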
hcpdiff/noise/zero_terminal.py DELETED
@@ -1,44 +0,0 @@
-import torch
-from diffusers import SchedulerMixin
-from .noise_base import NoiseBase
-
-class ZeroTerminalScheduler(NoiseBase, SchedulerMixin):
-    def __init__(self, base_scheduler):
-        super().__init__(base_scheduler)
-        base_scheduler.betas = self.rescale_zero_terminal_snr(base_scheduler.betas)
-        base_scheduler.alphas = 1.0-base_scheduler.betas
-        base_scheduler.alphas_cumprod = torch.cumprod(base_scheduler.alphas, dim=0)
-
-    def rescale_zero_terminal_snr(self, betas):
-        """
-        Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-        Args:
-            betas (`torch.FloatTensor`):
-                the betas that the scheduler is being initialized with.
-        Returns:
-            `torch.FloatTensor`: rescaled betas with zero terminal SNR
-        """
-        # Convert betas to alphas_bar_sqrt
-        alphas = 1.0-betas
-        alphas_cumprod = torch.cumprod(alphas, dim=0)
-        alphas_bar_sqrt = alphas_cumprod.sqrt()
-
-        # Store old values.
-        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-
-        # Shift so the last timestep is zero.
-        alphas_bar_sqrt -= alphas_bar_sqrt_T
-
-        # Scale so the first timestep is back to the old value.
-        alphas_bar_sqrt *= alphas_bar_sqrt_0/(alphas_bar_sqrt_0-alphas_bar_sqrt_T)
-
-        # Convert alphas_bar_sqrt to betas
-        alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
-        alphas = alphas_bar[1:]/alphas_bar[:-1]  # Revert cumprod
-        alphas = torch.cat([alphas_bar[0:1], alphas])
-        betas = 1-alphas
-
-        return betas
-
-
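Note (not part of the diff): written out, `rescale_zero_terminal_snr` above follows Algorithm 1 of arXiv:2305.08891. With $\bar\alpha_t$ the cumulative product of $1-\beta_t$ and $s_t = \sqrt{\bar\alpha_t}$ (first entry $s_1$, last entry $s_T$), the code computes

$$ s'_t = (s_t - s_T)\cdot\frac{s_1}{s_1 - s_T}, \qquad \bar\alpha'_t = (s'_t)^2, \qquad \alpha'_t = \frac{\bar\alpha'_t}{\bar\alpha'_{t-1}}\ (\alpha'_1 = \bar\alpha'_1), \qquad \beta'_t = 1 - \alpha'_t, $$

so $s'_T = 0$ (zero terminal SNR) while $s'_1 = s_1$ is preserved.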