hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/data/pair_dataset.py DELETED
@@ -1,146 +0,0 @@
1
- """
2
- pair_dataset.py
3
- ====================
4
- :Name: text-image pair dataset
5
- :Author: Dong Ziyi
6
- :Affiliation: HCP Lab, SYSU
7
- :Created: 10/03/2023
8
- :Licence: Apache-2.0
9
- """
10
-
11
- import os.path
12
- from argparse import Namespace
13
-
14
- import cv2
15
- import torch
16
- from PIL import Image
17
- from torch.utils.data import Dataset
18
- from tqdm.auto import tqdm
19
- from typing import Tuple
20
-
21
- from hcpdiff.utils.caption_tools import *
22
- from hcpdiff.utils.utils import get_file_name, get_file_ext
23
- from .bucket import BaseBucket
24
- from .source import DataSource, ComposeDataSource
25
-
26
- class TextImagePairDataset(Dataset):
27
- """
28
- A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
29
- It pre-processes the images and the tokenizes prompts.
30
- """
31
-
32
- def __init__(self, tokenizer, tokenizer_repeats: int = 1, att_mask_encode: bool = False,
33
- bucket: BaseBucket = None, source: Dict[str, DataSource] = None, return_path: bool = False,
34
- cache_path:str=None, encoder_attention_mask=False, **kwargs):
35
- self.return_path = return_path
36
-
37
- self.tokenizer = tokenizer
38
- self.tokenizer_repeats = tokenizer_repeats
39
- self.bucket: BaseBucket = bucket
40
- self.att_mask_encode = att_mask_encode
41
- self.source = ComposeDataSource(source)
42
- self.latents = None # Cache latents for faster training. Works only without image argumentations.
43
- self.cache_path = cache_path
44
- self.encoder_attention_mask = encoder_attention_mask
45
-
46
- def load_data(self, path:str, data_source:DataSource, size:Tuple[int]):
47
- image_dict = data_source.load_image(path)
48
- image = image_dict['image']
49
- att_mask = image_dict.get('att_mask', None)
50
- if att_mask is None:
51
- data, crop_coord = self.bucket.crop_resize({"img":image}, size)
52
- image = data_source.procees_image(data['img']) # resize to bucket size
53
- att_mask = torch.ones((size[1]//8, size[0]//8))
54
- else:
55
- data, crop_coord = self.bucket.crop_resize({"img":image, "mask":att_mask}, size)
56
- image = data_source.procees_image(data['img'])
57
- att_mask = torch.tensor(cv2.resize(data['mask'], (size[0]//8, size[1]//8), interpolation=cv2.INTER_LINEAR))
58
- return {'img':image, 'mask':att_mask}
59
-
60
- @torch.no_grad()
61
- def cache_latents(self, vae, weight_dtype, device, show_prog=True):
62
- if self.cache_path and os.path.exists(self.cache_path):
63
- self.latents = torch.load(self.cache_path)
64
- return
65
-
66
- self.latents = {}
67
- self.bucket.rest(0)
68
-
69
- for (path, data_source), size in tqdm(self.bucket, disable=not show_prog):
70
- img_name = data_source.get_image_name(path)
71
- if img_name not in self.latents:
72
- data = self.load_data(path, data_source, size)
73
- image = data['img'].unsqueeze(0).to(device, dtype=weight_dtype)
74
- latents = vae.encode(image).latent_dist.sample().squeeze(0)
75
- data['img'] = (latents*vae.config.scaling_factor).cpu()
76
- self.latents[img_name] = data
77
-
78
- if self.cache_path:
79
- torch.save(self.latents, self.cache_path)
80
-
81
- def __len__(self):
82
- return len(self.bucket)
83
-
84
- def __getitem__(self, index):
85
- (path, data_source), size = self.bucket[index]
86
- img_name = data_source.get_image_name(path)
87
-
88
- if self.latents is None:
89
- data = self.load_data(path, data_source, size)
90
- else:
91
- data = self.latents[img_name].copy()
92
-
93
- prompt_ist = data_source.load_caption(img_name)
94
-
95
- # tokenize Sp or (Sn, Sp)
96
- tokens = self.tokenizer(prompt_ist, truncation=True, padding="max_length", return_tensors="pt",
97
- max_length=self.tokenizer.model_max_length*self.tokenizer_repeats)
98
- data['prompt'] = tokens.input_ids.squeeze()
99
- if self.encoder_attention_mask and 'attention_mask' in tokens:
100
- data['attn_mask'] = tokens.attention_mask.squeeze()
101
- if 'position_ids' in tokens:
102
- data['position_ids'] = tokens.position_ids.squeeze()
103
-
104
- if self.return_path:
105
- return data, path
106
- else:
107
- return data
108
-
109
- @staticmethod
110
- def collate_fn(batch):
111
- '''
112
- batch: [{img:tensor, prompt:str, ..., plugin_input:{...}},{}]
113
- '''
114
- has_plugin_input = 'plugin_input' in batch[0]
115
- if has_plugin_input:
116
- plugin_input = {k:[] for k in batch[0]['plugin_input'].keys()}
117
-
118
- datas = {k:[] for k in batch[0].keys() if k != 'plugin_input' and k != 'prompt'}
119
- sn_list, sp_list = [], []
120
-
121
- for data in batch:
122
- if has_plugin_input:
123
- for k, v in data.pop('plugin_input').items():
124
- plugin_input[k].append(v)
125
-
126
- prompt = data.pop('prompt')
127
- if len(prompt.shape) == 2:
128
- sn_list.append(prompt[0])
129
- sp_list.append(prompt[1])
130
- else:
131
- sp_list.append(prompt)
132
-
133
- for k, v in data.items():
134
- datas[k].append(v)
135
-
136
- for k, v in datas.items():
137
- datas[k] = torch.stack(v)
138
- if k == 'mask':
139
- datas[k] = datas[k].unsqueeze(1)
140
-
141
- sn_list += sp_list
142
- datas['prompt'] = torch.stack(sn_list)
143
- if has_plugin_input:
144
- datas['plugin_input'] = {k:torch.stack(v) for k, v in plugin_input.items()}
145
-
146
- return datas
hcpdiff/data/sampler.py DELETED
@@ -1,54 +0,0 @@
1
- import torch
2
- from torch.utils.data.distributed import DistributedSampler
3
- from typing import Iterator
4
- import platform
5
- import math
6
-
7
- class DistributedCycleSampler(DistributedSampler):
8
- _cycle = 1000
9
-
10
- def __iter__(self) -> Iterator:
11
- def _iter():
12
- while True:
13
- if self.shuffle:
14
- # deterministically shuffle based on epoch and seed
15
- g = torch.Generator()
16
- g.manual_seed(self.seed + self.epoch)
17
- indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
18
- else:
19
- indices = list(range(len(self.dataset))) # type: ignore[arg-type]
20
-
21
- if not self.drop_last:
22
- # add extra samples to make it evenly divisible
23
- padding_size = self.total_size - len(indices)
24
- if padding_size <= len(indices):
25
- indices += indices[:padding_size]
26
- else:
27
- indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
28
- else:
29
- # remove tail of data to make it evenly divisible.
30
- indices = indices[:self.total_size]
31
- assert len(indices) == self.total_size
32
-
33
- # subsample
34
- indices = indices[self.rank:self.total_size:self.num_replicas]
35
- assert len(indices) == self.num_samples
36
-
37
- for idx in indices:
38
- yield idx
39
- self.epoch+=1
40
-
41
- if self.epoch>=self._cycle:
42
- break
43
-
44
- return _iter()
45
-
46
- def __len__(self) -> int:
47
- return self.num_samples #*self._cycle
48
-
49
- def get_sampler():
50
- # Fix DataLoader frequently reload bugs in windows
51
- if platform.system().lower() == 'windows':
52
- return DistributedCycleSampler
53
- else:
54
- return DistributedSampler
hcpdiff/data/source/base.py DELETED
@@ -1,30 +0,0 @@
1
- import os
2
- from typing import Dict, List, Tuple, Any
3
-
4
- class DataSource:
5
- def __init__(self, img_root, repeat=1, **kwargs):
6
- self.img_root = img_root
7
- self.repeat = repeat
8
-
9
- def get_image_list(self) -> List[Tuple[str, "DataSource"]]:
10
- raise NotImplementedError()
11
-
12
- def procees_image(self, image):
13
- raise NotImplementedError()
14
-
15
- def load_image(self, path) -> Dict[str, Any]:
16
- raise NotImplementedError()
17
-
18
- def get_image_name(self, path: str) -> str:
19
- img_root, img_name = os.path.split(path)
20
- return img_name.rsplit('.')[0]
21
-
22
- class ComposeDataSource(DataSource):
23
- def __init__(self, source_dict: Dict[str, DataSource]):
24
- self.source_dict = source_dict
25
-
26
- def get_image_list(self) -> List[Tuple[str, DataSource]]:
27
- img_list = []
28
- for source in self.source_dict.values():
29
- img_list.extend(source.get_image_list())
30
- return img_list
hcpdiff/data/utils.py DELETED
@@ -1,80 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- from PIL import Image
4
- from torchvision import transforms as T
5
- from torchvision.transforms import functional as F
6
-
7
- class DualRandomCrop:
8
- def __init__(self, size):
9
- self.size = size
10
-
11
- def __call__(self, img):
12
- crop_params = T.RandomCrop.get_params(img['img'], self.size)
13
- img['img'] = F.crop(img['img'], *crop_params)
14
- if "mask" in img:
15
- img['mask'] = self.crop(img['mask'], *crop_params)
16
- if "cond" in img:
17
- img['cond'] = F.crop(img['cond'], *crop_params)
18
- return img, crop_params[:2]
19
-
20
- @staticmethod
21
- def crop(img: np.ndarray, top: int, left: int, height: int, width: int) -> np.ndarray:
22
- right = left+width
23
- bottom = top+height
24
- return img[top:bottom, left:right, ...]
25
-
26
- def resize_crop_fix(img, target_size, mask_interp=cv2.INTER_CUBIC):
27
- w, h = img['img'].size
28
- if w == target_size[0] and h == target_size[1]:
29
- return img, [h,w,0,0,h,w]
30
-
31
- ratio_img = w/h
32
- if ratio_img>target_size[0]/target_size[1]:
33
- new_size = (round(ratio_img*target_size[1]), target_size[1])
34
- interp_type = Image.LANCZOS if h>target_size[1] else Image.BICUBIC
35
- else:
36
- new_size = (target_size[0], round(target_size[0]/ratio_img))
37
- interp_type = Image.LANCZOS if w>target_size[0] else Image.BICUBIC
38
- img['img'] = img['img'].resize(new_size, interp_type)
39
- if "mask" in img:
40
- img['mask'] = cv2.resize(img['mask'], new_size, interpolation=mask_interp)
41
- if "cond" in img:
42
- img['cond'] = img['cond'].resize(new_size, interp_type)
43
-
44
- img, crop_coord = DualRandomCrop(target_size[::-1])(img)
45
- return img, [*new_size, *crop_coord[::-1], *target_size]
46
-
47
- def pad_crop_fix(img, target_size):
48
- w, h = img['img'].size
49
- if w == target_size[0] and h == target_size[1]:
50
- return img, (h,w,0,0,h,w)
51
-
52
- pad_size = [0, 0, max(target_size[0]-w, 0), max(target_size[1]-h, 0)]
53
- if pad_size[2]>0 or pad_size[3]>0:
54
- img['img'] = F.pad(img['img'], pad_size)
55
- if "mask" in img:
56
- img['mask'] = np.pad(img['mask'], ((0, pad_size[3]), (0, pad_size[2])), 'constant', constant_values=(0, 0))
57
- if "cond" in img:
58
- img['cond'] = F.pad(img['cond'], pad_size)
59
-
60
- if pad_size[2]>0 and pad_size[3]>0:
61
- return img, (h,w,0,0,h,w) # No need to crop
62
- else:
63
- img, crop_coord = DualRandomCrop(target_size[::-1])(img)
64
- return img, crop_coord
65
-
66
- class CycleData():
67
- def __init__(self, data_loader):
68
- self.data_loader = data_loader
69
-
70
- def __iter__(self):
71
- self.epoch = 0
72
-
73
- def cycle():
74
- while True:
75
- self.data_loader.dataset.bucket.rest(self.epoch)
76
- for data in self.data_loader:
77
- yield data
78
- self.epoch += 1
79
-
80
- return cycle()
hcpdiff/deprecated/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .lora_convert import convert_to_webui_maybe_old, convert_to_webui_xl_maybe_old
hcpdiff/deprecated/cfg_converter.py DELETED
@@ -1,81 +0,0 @@
1
- """
2
- train_ac.py
3
- ====================
4
- :Name: convert old cfg format to new format
5
- :Author: Dong Ziyi
6
- :Affiliation: HCP Lab, SYSU
7
- :Created: 10/03/2023
8
- :Licence: Apache-2.0
9
- """
10
-
11
- from omegaconf import ListConfig, DictConfig, OmegaConf
12
-
13
- class DatasetCFGConverter:
14
-
15
- def convert_source(self, cfg_source:DictConfig):
16
- if '_target_' not in cfg_source:
17
- cfg_source['_target_'] = 'hcpdiff.data.source.Text2ImageAttMapSource'
18
-
19
- if 'tag_transforms' in cfg_source:
20
- cfg_source['text_transforms'] = cfg_source.pop('tag_transforms')
21
-
22
- def convert(self, cfg:DictConfig):
23
- for dataset in cfg['data'].values():
24
- for source in dataset['source'].values():
25
- self.convert_source(source)
26
- return cfg
27
-
28
- class TrainCFGConverter:
29
- def __init__(self):
30
- self.dataset_converter = DatasetCFGConverter()
31
-
32
- def convert_model(self, cfg_model:DictConfig):
33
- if 'ema_unet' in cfg_model and 'ema' not in cfg_model:
34
- if cfg_model['ema_unet']==0: # no ema
35
- cfg_model['ema'] = None
36
- else:
37
- cfg_model['ema'] = OmegaConf.create({
38
- '_target_': 'hcpdiff.utils.ema.ModelEMA',
39
- '_partial_': True,
40
- 'decay_max': cfg_model['ema_unet'],
41
- 'power': 0.85
42
- })
43
-
44
- if 'tokenizer' not in cfg_model:
45
- cfg_model['tokenizer'] = None
46
- if 'noise_scheduler' not in cfg_model:
47
- cfg_model['noise_scheduler'] = None
48
- if 'unet' not in cfg_model:
49
- cfg_model['unet'] = None
50
- if 'text_encoder' not in cfg_model:
51
- cfg_model['text_encoder'] = None
52
- if 'vae' not in cfg_model:
53
- cfg_model['vae'] = None
54
-
55
- def convert_loss(self, cfg_train:DictConfig):
56
- if cfg_train['loss']['criterion']['_target_']=='hcpdiff.loss.MSELoss':
57
- cfg_train['loss']['criterion']['_target_'] = 'torch.nn.MSELoss'
58
-
59
- def convert(self, cfg:DictConfig):
60
- self.convert_model(cfg['model'])
61
- self.convert_loss(cfg['train'])
62
-
63
- if 'previewer' not in cfg:
64
- cfg['previewer'] = None
65
-
66
- cfg = self.dataset_converter.convert(cfg)
67
- return cfg
68
-
69
- class InferCFGConverter:
70
-
71
- def convert(self, cfg:DictConfig):
72
- if 'encoder_attention_mask' not in cfg:
73
- cfg['encoder_attention_mask'] = False
74
-
75
- if 'amp' not in cfg:
76
- if cfg['dtype']=='amp':
77
- cfg['dtype'] = 'fp32'
78
- cfg['amp'] = True
79
- else:
80
- cfg['amp'] = False
81
- return cfg
hcpdiff/deprecated/lora_convert.py DELETED
@@ -1,31 +0,0 @@
1
-
2
- def convert_to_webui_maybe_old(new_func):
3
- def convert_to_webui_(self, state, prefix):
4
- sd_covert = {}
5
- for k, v in state.items():
6
- # new lora format
7
- if k.endswith('W_down'):
8
- return new_func(self, state, prefix)
9
-
10
- # old lora format
11
- model_k, lora_k = k.split('.___.' if ('alpha' in k or 'scale' in k) else '.___.layer.', 1)
12
- sd_covert[f"{prefix}{model_k.replace('.', '_')}.{lora_k}"] = v
13
- return sd_covert
14
- return convert_to_webui_
15
-
16
- def convert_to_webui_xl_maybe_old(new_func):
17
- def convert_to_webui_xl_(self, state, prefix):
18
- sd_convert = {}
19
- for k, v in state.items():
20
- # new lora format
21
- if k.endswith('W_down'):
22
- return new_func(self, state, prefix)
23
-
24
- # old lora format
25
- model_k, lora_k = k.split('.___.' if ('alpha' in k or 'scale' in k) else '.___.layer.', 1)
26
- new_k = f"{prefix}{model_k.replace('.', '_')}.{lora_k}"
27
- if 'clip' in new_k:
28
- new_k = new_k.replace('_clip_B', '1') if 'clip_B' in new_k else new_k.replace('_clip_bigG', '2')
29
- sd_convert[new_k] = v
30
- return sd_convert
31
- return convert_to_webui_xl_
hcpdiff/infer_workflow.py DELETED
@@ -1,57 +0,0 @@
1
- import argparse
2
-
3
- import torch
4
- import hydra
5
- from omegaconf import OmegaConf, DictConfig
6
- from easydict import EasyDict
7
-
8
- from hcpdiff.utils.utils import load_config_with_cli
9
- from .workflow import MemoryMixin
10
- from copy import deepcopy
11
-
12
- class WorkflowRunner:
13
- def __init__(self, cfgs):
14
- self.cfgs_raw = deepcopy(cfgs)
15
- self.cfgs = OmegaConf.structured(cfgs, flags={"allow_objects": True})
16
- OmegaConf.resolve(self.cfgs)
17
- self.memory = EasyDict(**hydra.utils.instantiate(self.cfgs.memory))
18
- self.attach_memory(self.cfgs)
19
-
20
- def start(self):
21
- prepare_actions = hydra.utils.instantiate(self.cfgs.prepare)
22
- states = self.run(prepare_actions, {'cfgs': self.cfgs_raw})
23
- actions = hydra.utils.instantiate(self.cfgs.actions)
24
- states = self.run(actions, states)
25
-
26
- def attach_memory(self, cfgs):
27
- if OmegaConf.is_dict(cfgs):
28
- if '_target_' in cfgs and cfgs['_target_'].endswith('.from_memory'):
29
- cfgs._set_flag('allow_objects', True)
30
- cfgs['memory'] = self.memory
31
- else:
32
- for v in cfgs.values():
33
- self.attach_memory(v)
34
- elif OmegaConf.is_list(cfgs):
35
- for v in cfgs:
36
- self.attach_memory(v)
37
-
38
- @torch.inference_mode()
39
- def run(self, actions, states):
40
- N_steps = len(actions)
41
- for step, act in enumerate(actions):
42
- print(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
43
- if isinstance(act, MemoryMixin):
44
- states = act(memory=self.memory, **states)
45
- else:
46
- states = act(**states)
47
- print(f'states: {", ".join(states.keys())}')
48
- return states
49
-
50
- if __name__ == '__main__':
51
- parser = argparse.ArgumentParser(description='HCP-Diffusion workflow')
52
- parser.add_argument('--cfg', type=str, default='')
53
- args, cfg_args = parser.parse_known_args()
54
- cfgs = load_config_with_cli(args.cfg, args_list=cfg_args) # skip --cfg
55
-
56
- runner = WorkflowRunner(cfgs)
57
- runner.start()
hcpdiff/loggers/__init__.py DELETED
@@ -1,13 +0,0 @@
1
- from .base_logger import BaseLogger, LoggerGroup
2
- from .cli_logger import CLILogger
3
- from .webui_logger import WebUILogger
4
-
5
- try:
6
- from .tensorboard_logger import TBLogger
7
- except:
8
- print('tensorboard is not available')
9
-
10
- try:
11
- from .wandb_logger import WanDBLogger
12
- except:
13
- print('wandb is not available')
hcpdiff/loggers/base_logger.py DELETED
@@ -1,76 +0,0 @@
1
- from typing import Dict, Any, List
2
-
3
- from PIL import Image
4
-
5
- from .preview import ImagePreviewer
6
-
7
- class BaseLogger:
8
- def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200):
9
- self.exp_dir = exp_dir
10
- self.out_path = out_path
11
- self.enable_log_image = enable_log_image
12
- self.log_step = log_step
13
- self.image_log_step = image_log_step
14
- self.enable_log = True
15
- self.previewer_list: List[ImagePreviewer] = []
16
-
17
- def enable(self):
18
- self.enable_log = True
19
-
20
- def disable(self):
21
- self.enable_log = False
22
-
23
- def add_previewer(self, previewer: ImagePreviewer):
24
- self.previewer_list.append(previewer)
25
-
26
- def info(self, info):
27
- if self.enable_log:
28
- self._info(info)
29
-
30
- def _info(self, info):
31
- raise NotImplementedError()
32
-
33
- def log(self, datas: Dict[str, Any], step: int = 0):
34
- if self.enable_log and step%self.log_step == 0:
35
- self._log(datas, step)
36
-
37
- def _log(self, datas: Dict[str, Any], step: int = 0):
38
- raise NotImplementedError()
39
-
40
- def log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
41
- if self.enable_log and self.enable_log_image and step%self.image_log_step == 0:
42
- self._log_image(imgs, step)
43
-
44
- def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
45
- raise NotImplementedError()
46
-
47
- class LoggerGroup:
48
- def __init__(self, logger_list: List[BaseLogger]):
49
- self.logger_list = logger_list
50
-
51
- def enable(self):
52
- for logger in self.logger_list:
53
- logger.enable()
54
-
55
- def disable(self):
56
- for logger in self.logger_list:
57
- logger.disable()
58
-
59
- def add_previewer(self, previewer):
60
- for logger in self.logger_list:
61
- logger.add_previewer(previewer)
62
-
63
- def info(self, info):
64
- for logger in self.logger_list:
65
- logger.info(info)
66
-
67
- def log(self, datas: Dict[str, Any], step: int = 0):
68
- for logger in self.logger_list:
69
- logger.log(datas, step)
70
-
71
- def log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
72
- for logger in self.logger_list:
73
- logger.log_image(imgs, step)
74
-
75
- def __len__(self):
76
- return len(self.logger_list)
hcpdiff/loggers/cli_logger.py DELETED
@@ -1,40 +0,0 @@
1
- import os
2
- from typing import Dict, Any
3
-
4
- from PIL import Image
5
- from loguru import logger
6
-
7
- from .base_logger import BaseLogger
8
-
9
- class CLILogger(BaseLogger):
10
- def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200,
11
- img_log_dir='preview', img_ext='png', img_quality=95):
12
- super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
13
- if exp_dir is not None: # exp_dir is only available in local main process
14
- logger.add(os.path.join(exp_dir, out_path))
15
- if enable_log_image:
16
- self.img_log_dir = os.path.join(exp_dir, img_log_dir)
17
- os.makedirs(self.img_log_dir, exist_ok=True)
18
- self.img_ext = img_ext
19
- self.img_quality = img_quality
20
- else:
21
- self.disable()
22
-
23
- def enable(self):
24
- super(CLILogger, self).enable()
25
- logger.enable("__main__")
26
-
27
- def disable(self):
28
- super(CLILogger, self).disable()
29
- logger.disable("__main__")
30
-
31
- def _info(self, info):
32
- logger.info(info)
33
-
34
- def _log(self, datas: Dict[str, Any], step: int = 0):
35
- logger.info(', '.join([f"{k} = {v['format'].format(*v['data'])}" for k, v in datas.items()]))
36
-
37
- def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
38
- logger.info(f'log {len(imgs)} images')
39
- for name, img in imgs.items():
40
- img.save(os.path.join(self.img_log_dir, f'{step}-{name}.{self.img_ext}'), quality=self.img_quality)
hcpdiff/loggers/preview/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .image_previewer import ImagePreviewer