hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/utils/caption_tools.py DELETED
@@ -1,105 +0,0 @@
- """
- caption_tools.py
- ====================
- :Name: process prompts
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
-
- import random
- from string import Formatter
- from typing import List, Dict, Union
-
- import numpy as np
-
-
- class TagShuffle:
-     def __call__(self, data):
-         if 'caption' in data:
-             text = data['caption']
-             if text is not None:
-                 tags = text.split(',')
-                 random.shuffle(tags)
-                 data['caption'] = ','.join(tags)
-             return data
-         else:
-             for i, item in enumerate(data['prompt']):
-                 tags = item.split(',')
-                 random.shuffle(tags)
-                 data['prompt'][i] = ','.join(tags)
-             return data
-
-     def __repr__(self):
-         return 'TagShuffle()'
-
-
- class TagDropout:
-     def __init__(self, p=0.1):
-         self.p = p
-
-     def __call__(self, data):
-         if 'caption' in data:
-             text = data['caption']
-             if text is not None:
-                 tags = np.array(text.split(','))
-                 data['caption'] = ','.join(tags[np.random.random(len(tags)) > self.p])
-             return data
-         else:
-             for i, item in enumerate(data['prompt']):
-                 tags = np.array(item.split(','))  # np.array is needed for the boolean-mask indexing below
-                 data['prompt'][i] = ','.join(tags[np.random.random(len(tags)) > self.p])
-             return data
-
-     def __repr__(self):
-         return f'TagDropout(p={self.p})'
-
- class TagErase:
-     def __init__(self, p=0.1):
-         self.p = p
-
-     def __call__(self, data):
-         for i, item in enumerate(data['prompt']):
-             if random.random()<self.p:
-                 data['prompt'][i] = ''
-         return data
-
-     def __repr__(self):
-         return f'TagErase(p={self.p})'
-
- class TemplateFill:
-     def __init__(self, word_names: Dict[str, Union[str, List[str]]]):
-         self.word_names = word_names
-         self.DA_names = {k: v for k, v in word_names.items() if not isinstance(v, str)}
-         self.dream_artist = len(self.DA_names) > 0
-
-     def __call__(self, data):
-         template, caption = data['prompt'], data['caption']
-
-         keys_need = {i[1] for i in Formatter().parse(template) if i[1] is not None}
-         fill_dict = {k: v for k, v in self.word_names.items() if k in keys_need}
-
-         if (caption is not None) and ('caption' in keys_need):
-             if self.dream_artist:
-                 cap_fill = fill_dict.get('caption', [None, None])
-                 fill_dict.update(caption=[cap_fill[0] or caption, cap_fill[1] or caption])
-             else:
-                 fill_dict.update(caption=fill_dict.get('caption', None) or caption)
-
-         # skip keys that are not provided
-         for k in keys_need:
-             if k not in fill_dict:
-                 fill_dict[k] = ''
-
-         if self.dream_artist:
-             fill_dict_pos = {k: ((v if isinstance(v, str) else v[0]) or '') for k, v in fill_dict.items()}
-             fill_dict_neg = {k: ((v if isinstance(v, str) else v[1]) or '') for k, v in fill_dict.items()}
-             return {'prompt':[template.format(**fill_dict_neg), template.format(**fill_dict_pos)]}
-         else:
-             # replace None values with ''
-             fill_dict = {k:(v or '') for k, v in fill_dict.items()}
-             return {'prompt':[template.format(**fill_dict)]}
-
-     def __repr__(self):
-         return f'TemplateFill(\nword_names={self.word_names}\n)'
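For context: these transforms operate on per-sample dicts during dataset loading. A minimal usage sketch of the deleted classes above (the chaining and sample values are illustrative, not the package's actual pipeline; import path as in 0.9.1):

    from hcpdiff.utils.caption_tools import TagShuffle, TagDropout, TemplateFill

    data = {'caption': '1girl, blue hair, smile'}
    for t in (TagShuffle(), TagDropout(p=0.1)):  # shuffle tag order, then drop each tag with probability 0.1
        data = t(data)

    # TemplateFill reads the template from 'prompt' and fills {caption} and custom word slots
    fill = TemplateFill(word_names={'pt1': 'my-character'})
    out = fill({'prompt': 'a photo of {pt1}, {caption}', 'caption': data['caption']})
    print(out['prompt'][0])  # e.g. "a photo of my-character, smile, 1girl, blue hair"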
hcpdiff/utils/cfg_net_tools.py DELETED
@@ -1,321 +0,0 @@
- """
- cfg_net_tools.py
- ====================
- :Name: create model and plugin from config
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
- import warnings
- from typing import Dict, List, Tuple, Union, Any
-
- import re
- import torch
- from torch import nn
-
- from .utils import net_path_join
- from hcpdiff.models import LoraBlock, LoraGroup, lora_layer_map
- from hcpdiff.models.plugin import SinglePluginBlock, MultiPluginBlock, PluginBlock, PluginGroup, PatchPluginBlock
- from hcpdiff.ckpt_manager import auto_manager
- from .net_utils import split_module_name
- from hcpdiff.tools.convert_old_lora import convert_state
-
- def get_class_match_layer(class_name, block:nn.Module):
-     if type(block).__name__==class_name:
-         return ['']
-     else:
-         return ['.'+name for name, layer in block.named_modules() if type(layer).__name__==class_name]
-
- def get_match_layers(layers, all_layers, return_metas=False) -> Union[List[str], List[Dict[str, Any]]]:
-     res=[]
-     for name in layers:
-         metas = name.split(':')
-
-         use_re = False
-         pre_hook = False
-         cls_filter = None
-         for meta in metas[:-1]:
-             if meta=='re':
-                 use_re=True
-             elif meta=='pre_hook':
-                 pre_hook=True
-             elif meta.startswith('cls('):
-                 cls_filter=meta[4:-1]
-
-         name = metas[-1]
-         if use_re:
-             pattern = re.compile(name)
-             match_layers = filter(lambda x: pattern.match(x) != None, all_layers.keys())
-         else:
-             match_layers = [name]
-
-         if cls_filter is not None:
-             match_layers_new = []
-             for layer in match_layers:
-                 match_layers_new.extend([layer + x for x in get_class_match_layer(cls_filter, all_layers[layer])])
-             match_layers = match_layers_new
-
-         for layer in match_layers:
-             if return_metas:
-                 res.append({'layer': layer, 'pre_hook': pre_hook})
-             else:
-                 res.append(layer)
-
-     # Remove duplicates and keep the original order
-     if return_metas:
-         layer_set=set()
-         res_unique = []
-         for item in res:
-             if item['layer'] not in layer_set:
-                 layer_set.add(item['layer'])
-                 res_unique.append(item)
-         return res_unique
-     else:
-         return sorted(set(res), key=res.index)
-
- def get_lora_rank_and_cls(lora_state):
-     if 'layer.lora_down.weight' in lora_state: # old format
-         warnings.warn("The old lora format is deprecated.", DeprecationWarning)
-         rank = lora_state['layer.lora_down.weight'].shape[0]
-         lora_layer_cls = lora_layer_map['lora']
-         return lora_layer_cls, rank, True
-     elif 'layer.W_down' in lora_state:
-         rank = lora_state['layer.W_down'].shape[0]
-         lora_layer_cls = lora_layer_map['lora']
-         return lora_layer_cls, rank, False
-     else:
-         raise ValueError('Unknown lora format.')
-
- def make_hcpdiff(model, cfg_model, cfg_lora, default_lr=1e-5) -> Tuple[List[Dict], Union[LoraGroup, Tuple[LoraGroup, LoraGroup]]]:
-     named_modules = {k:v for k,v in model.named_modules()}
-
-     train_params=[]
-     all_lora_blocks={}
-     all_lora_blocks_neg={}
-
-     if cfg_model is not None:
-         for item in cfg_model:
-             params_group = []
-             for layer_name in get_match_layers(item.layers, named_modules):
-                 layer = named_modules[layer_name]
-                 layer.requires_grad_(True)
-                 layer.train()
-                 params_group.extend(list(LoraBlock.extract_param_without_lora(layer).values()))
-             train_params.append({'params':list(set(params_group)), 'lr':getattr(item, 'lr', default_lr)})
-
-     if cfg_lora is not None:
-         for lora_id, item in enumerate(cfg_lora):
-             params_group = []
-             for layer_name in get_match_layers(item.layers, named_modules):
-                 parent_name, host_name = split_module_name(layer_name)
-                 layer = named_modules[layer_name]
-                 arg_dict = {k:v for k,v in item.items() if k!='layers'}
-                 lora_block_dict = lora_layer_map[arg_dict.get('type', 'lora')].wrap_model(lora_id, layer, parent_block=named_modules[parent_name], host_name=host_name, **arg_dict)
-
-                 for k,v in lora_block_dict.items():
-                     block_path = net_path_join(layer_name, k)
-                     all_lora_blocks[block_path] = v
-                     v.requires_grad_(True)
-                     v.train()
-                     params_group.extend(v.parameters())
-
-             train_params.append({'params': params_group, 'lr':getattr(item, 'lr', default_lr)})
-
-     if len(all_lora_blocks_neg)>0:
-         return train_params, (LoraGroup(all_lora_blocks), LoraGroup(all_lora_blocks_neg))
-     else:
-         return train_params, LoraGroup(all_lora_blocks)
-
- def make_plugin(model, cfg_plugin, default_lr=1e-5) -> Tuple[List, Dict[str, PluginGroup]]:
-     train_params=[]
-     all_plugin_group={}
-
-     if cfg_plugin is None:
-         return train_params, all_plugin_group
-
-     named_modules = {k: v for k, v in model.named_modules()}
-
-     # builder: functools.partial
-     for plugin_name, builder in cfg_plugin.items():
-         all_plugin_blocks={}
-
-         lr = builder.keywords.pop('lr') if 'lr' in builder.keywords else default_lr
-         train_plugin = builder.keywords.pop('train') if 'train' in builder.keywords else True
-         plugin_class = getattr(builder.func, '__self__', builder.func) # support static or class method
-
-         params_group = []
-         if issubclass(plugin_class, MultiPluginBlock):
-             from_layers = [{**item, 'layer':named_modules[item['layer']]} for item in get_match_layers(builder.keywords.pop('from_layers'), named_modules, return_metas=True)]
-             to_layers = [{**item, 'layer':named_modules[item['layer']]} for item in get_match_layers(builder.keywords.pop('to_layers'), named_modules, return_metas=True)]
-
-             layer = builder(name=plugin_name, host_model=model, from_layers=from_layers, to_layers=to_layers)
-             if train_plugin:
-                 layer.train()
-                 params = layer.get_trainable_parameters()
-                 for p in params:
-                     p.requires_grad_(True)
-                     params_group.append(p)
-             else:
-                 layer.requires_grad_(False)
-                 layer.eval()
-             all_plugin_blocks[''] = layer
-         elif issubclass(plugin_class, SinglePluginBlock):
-             layers_name = builder.keywords.pop('layers')
-             for layer_name in get_match_layers(layers_name, named_modules):
-                 blocks = builder(name=plugin_name, host_model=model, host=named_modules[layer_name])
-                 if not isinstance(blocks, dict):
-                     blocks={'':blocks}
-
-                 for k,v in blocks.items():
-                     all_plugin_blocks[net_path_join(layer_name, k)] = v
-                     if train_plugin:
-                         v.train()
-                         params = v.get_trainable_parameters()
-                         for p in params:
-                             p.requires_grad_(True)
-                             params_group.append(p)
-                     else:
-                         v.requires_grad_(False)
-                         v.eval()
-         elif issubclass(plugin_class, PluginBlock):
-             from_layer = get_match_layers(builder.keywords.pop('from_layer'), named_modules, return_metas=True)
-             to_layer = get_match_layers(builder.keywords.pop('to_layer'), named_modules, return_metas=True)
-
-             for from_layer_meta, to_layer_meta in zip(from_layer, to_layer):
-                 from_layer_name=from_layer_meta['layer']
-                 from_layer_meta['layer']=named_modules[from_layer_name]
-                 to_layer_meta['layer']=named_modules[to_layer_meta['layer']]
-                 layer = builder(name=plugin_name, host_model=model, from_layer=from_layer_meta, to_layer=to_layer_meta)
-                 if train_plugin:
-                     layer.train()
-                     params = layer.get_trainable_parameters()
-                     for p in params:
-                         p.requires_grad_(True)
-                         params_group.append(p)
-                 else:
-                     layer.requires_grad_(False)
-                     layer.eval()
-                 all_plugin_blocks[from_layer_name] = layer
-         elif issubclass(plugin_class, PatchPluginBlock):
-             layers_name = builder.keywords.pop('layers')
-             for layer_name in get_match_layers(layers_name, named_modules):
-                 parent_name, host_name = split_module_name(layer_name)
-                 layers = builder(name=plugin_name, host_model=model, host=named_modules[layer_name],
-                                  parent_block=named_modules[parent_name], host_name=host_name)
-                 if not isinstance(layers, dict):
-                     layers={'':layers}
-
-                 for k,v in layers.items():
-                     all_plugin_blocks[net_path_join(layer_name, k)] = v
-                     if train_plugin:
-                         v.train()
-                         params = v.get_trainable_parameters()
-                         for p in params:
-                             p.requires_grad_(True)
-                             params_group.append(p)
-                     else:
-                         v.requires_grad_(False)
-                         v.eval()
-         else:
-             raise NotImplementedError(f'Unknown plugin {plugin_class}')
-         if train_plugin:
-             train_params.append({'params':params_group, 'lr':lr})
-         all_plugin_group[plugin_name] = PluginGroup(all_plugin_blocks)
-     return train_params, all_plugin_group
-
- class HCPModelLoader:
-     def __init__(self, host):
-         self.host = host
-         self.named_modules = {k:v for k, v in host.named_modules()}
-         self.named_params = {k:v for k, v in host.named_parameters()}
-
-     @torch.no_grad()
-     def load_part(self, cfg, base_model_alpha=0.0, load_ema=False):
-         if cfg is None:
-             return
-         for item in cfg:
-             part_state = auto_manager(item.path).load_ckpt(item.path, map_location='cpu')['base_ema' if load_ema else 'base']
-             layers = item.get('layers', 'all')
-             if layers == 'all':
-                 for k, v in part_state.items():
-                     self.named_params[k].data = base_model_alpha * self.named_params[k].data + item.alpha * v
-             else:
-                 match_blocks = get_match_layers(layers, self.named_modules)
-                 state_add = {k:v for blk in match_blocks for k,v in part_state.items() if k.startswith(blk)}
-                 for k, v in state_add.items():
-                     self.named_params[k].data = base_model_alpha * self.named_params[k].data + item.alpha * v
-
-     @torch.no_grad()
-     def load_lora(self, cfg, base_model_alpha=1.0, load_ema=False):
-         if cfg is None:
-             return
-
-         all_lora_blocks = {}
-         for lora_id, item in enumerate(cfg):
-             lora_state = auto_manager(item.path).load_ckpt(item.path, map_location='cpu')['lora_ema' if load_ema else 'lora']
-             lora_block_state = {}
-             # get all layers in the lora_state
-             for name, p in lora_state.items():
-                 # lora_block. is the old format
-                 prefix, block_name = name.split('.___.' if name.rfind('lora_block.')==-1 else '.lora_block.', 1)
-                 if prefix not in lora_block_state:
-                     lora_block_state[prefix] = {}
-                 lora_block_state[prefix][block_name] = p
-             # get selected layers
-             layers = item.get('layers', 'all')
-             if layers != 'all':
-                 match_blocks = get_match_layers(layers, self.named_modules)
-                 lora_state_new = {}
-                 for k, v in lora_block_state.items():
-                     for mk in match_blocks:
-                         if k.startswith(mk):
-                             lora_state_new[k]=v
-                             break
-                 lora_block_state = lora_state_new
-             # add lora to host and load weights
-             for layer_name, lora_state in lora_block_state.items():
-                 parent_name, host_name = split_module_name(layer_name)
-                 lora_layer_cls, rank, old_format = get_lora_rank_and_cls(lora_state)
-                 if 'alpha' in lora_state:
-                     del lora_state['alpha']
-
-                 if old_format:
-                     lora_state = convert_state(lora_state)
-
-                 lora_block = lora_layer_cls.wrap_layer(lora_id, self.named_modules[layer_name], rank=rank, dropout=getattr(item, 'dropout', 0.0),
-                                                        alpha=getattr(item, 'alpha', 1.0), bias='layer.bias' in lora_state, alpha_auto_scale=getattr(item, 'alpha_auto_scale', True),
-                                                        parent_block=self.named_modules[parent_name], host_name=host_name)
-                 all_lora_blocks[f'{layer_name}.{lora_block.name}'] = lora_block
-                 lora_block.load_state_dict(lora_state, strict=False)
-                 lora_block.to(self.host.device)
-         return LoraGroup(all_lora_blocks)
-
-     @torch.no_grad()
-     def load_plugin(self, cfg, load_ema=False):
-         if cfg is None:
-             return
-
-         for name, item in cfg.items():
-             plugin_state = auto_manager(item.path).load_ckpt(item.path, map_location='cpu')['plugin_ema' if load_ema else 'plugin']
-             layers = item.get('layers', 'all')
-             if layers != 'all':
-                 match_blocks = get_match_layers(layers, self.named_modules)
-                 plugin_state = {k:v for blk in match_blocks for k, v in plugin_state.items() if k.startswith(blk)}
-             plugin_key_set = set([k.split('___', 1)[0]+name for k in plugin_state.keys()])
-             plugin_state = {k.replace('___', name):v for k, v in plugin_state.items()} # replace the placeholder with the target plugin name
-             self.host.load_state_dict(plugin_state, strict=False)
-             if 'layers' in item:
-                 del item.layers
-             del item.path
-             if hasattr(self.host, name): # MultiPluginBlock
-                 getattr(self.host, name).set_hyper_params(**item)
-             else:
-                 for plugin_key in plugin_key_set:
-                     self.named_modules[plugin_key].set_hyper_params(**item)
-
-     def load_all(self, cfg_merge, load_ema=False):
-         self.load_part(cfg_merge.get('part', []), base_model_alpha=cfg_merge.get('base_model_alpha', 0.0), load_ema=load_ema)
-         lora_group = self.load_lora(cfg_merge.get('lora', []), base_model_alpha=cfg_merge.get('base_model_alpha', 1.0), load_ema=load_ema)
-         self.load_plugin(cfg_merge.get('plugin', {}), load_ema=load_ema)
-         return lora_group
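The layer selectors consumed by get_match_layers use ':'-separated prefixes before the final name: 're' treats it as a regex, 'pre_hook' tags the match, and 'cls(Name)' expands each match to submodules of the given class (per the repaired cls_filter handling above). A small illustration against a toy module, using the function as defined above; the module layout is invented for the example:

    import torch.nn as nn

    host = nn.ModuleDict({'unet': nn.ModuleDict({'attn1': nn.Linear(4, 4), 'attn2': nn.Linear(4, 4)})})
    all_layers = {k: v for k, v in host.named_modules()}

    get_match_layers([r're:unet\.attn.*'], all_layers)       # -> ['unet.attn1', 'unet.attn2']
    get_match_layers([r're:cls(Linear):unet.*'], all_layers)  # same layers, selected by class name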
hcpdiff/utils/cfg_resolvers.py DELETED
@@ -1,16 +0,0 @@
- import time
- import warnings
- from omegaconf import OmegaConf
- import torch
- from .net_utils import dtype_dict
-
- def times(a,b):
-     warnings.warn(f"${{times:{a},{b}}} is deprecated and will be removed in the future. Please use ${{hcp.eval:{a}*{b}}} instead.", DeprecationWarning)
-     return a*b
-
- OmegaConf.register_new_resolver("times", times)
-
- OmegaConf.register_new_resolver("hcp.eval", lambda exp: eval(exp))
- OmegaConf.register_new_resolver("hcp.time", lambda format="%Y-%m-%d-%H-%M-%S": time.strftime(format))
-
- OmegaConf.register_new_resolver("hcp.dtype", lambda dtype: dtype_dict.get(dtype, torch.float32))
hcpdiff/utils/ema.py DELETED
@@ -1,52 +0,0 @@
- import torch
- from torch import nn
- from copy import deepcopy
- from typing import Iterable, Tuple, Dict
- import numpy as np
-
- class ModelEMA:
-     def __init__(self, model: nn.Module, decay_max=0.9997, inv_gamma=1., power=2/3, start_step=0, device='cpu'):
-         self.train_params = {name:p.data.to(device) for name, p in model.named_parameters() if p.requires_grad}
-         self.train_params.update({name:p.to(device) for name, p in model.named_buffers()})
-         self.decay_max = decay_max
-         self.inv_gamma = inv_gamma
-         self.power = power
-         self.step = start_step
-         self.device=device
-
-     @torch.no_grad()
-     def update(self, model: nn.Module):
-         self.step += 1
-         # Compute the decay factor for the exponential moving average.
-         decay = 1-(1+self.step/self.inv_gamma)**-self.power
-         decay = np.clip(decay, 0., self.decay_max)
-
-         for name, param in model.named_parameters():
-             if name in self.train_params:
-                 self.train_params[name].lerp_(param.data.to(self.device), 1-decay) # (1-e)x + e*x_
-
-         for name, param in model.named_buffers():
-             if name in self.train_params:
-                 self.train_params[name].copy_(param.to(self.device))
-
-         #torch.cuda.empty_cache()
-
-     def copy_to(self, model: nn.Module) -> None:
-         for name, param in model.named_parameters():
-             if name in self.train_params:
-                 param.data.copy_(self.train_params[name])
-
-     def to(self, device=None, dtype=None):
-         # .to() on the tensors handles None correctly
-         self.train_params = {
-             name:(p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)) for name, p in self.train_params.items()
-         }
-         return self
-
-     def state_dict(self) -> Dict[str, torch.Tensor]:
-         return self.train_params
-
-     def load_state_dict(self, state: Dict[str, torch.Tensor]):
-         for k, v in state.items():
-             if k in self.train_params:
-                 self.train_params[k]=v
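A rough sketch of how this EMA helper was driven from a training loop (build_model, loader, and train_step are placeholders, not hcpdiff APIs):

    model = build_model()                    # placeholder for the trainable model
    ema = ModelEMA(model, decay_max=0.9997)  # snapshot trainable params and buffers

    for batch in loader:                     # placeholder training loop
        train_step(model, batch)             # forward/backward/optimizer.step()
        ema.update(model)                    # after each step; decay warms up via inv_gamma/power

    ema.copy_to(model)                       # write EMA weights back for eval or saving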