hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/data/pair_dataset.py DELETED
@@ -1,146 +0,0 @@
- """
- pair_dataset.py
- ====================
- :Name: text-image pair dataset
- :Author: Dong Ziyi
- :Affiliation: HCP Lab, SYSU
- :Created: 10/03/2023
- :Licence: Apache-2.0
- """
-
- import os.path
- from argparse import Namespace
-
- import cv2
- import torch
- from PIL import Image
- from torch.utils.data import Dataset
- from tqdm.auto import tqdm
- from typing import Tuple
-
- from hcpdiff.utils.caption_tools import *
- from hcpdiff.utils.utils import get_file_name, get_file_ext
- from .bucket import BaseBucket
- from .source import DataSource, ComposeDataSource
-
- class TextImagePairDataset(Dataset):
-     """
-     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-     It pre-processes the images and the tokenizes prompts.
-     """
-
-     def __init__(self, tokenizer, tokenizer_repeats: int = 1, att_mask_encode: bool = False,
-                  bucket: BaseBucket = None, source: Dict[str, DataSource] = None, return_path: bool = False,
-                  cache_path:str=None, encoder_attention_mask=False, **kwargs):
-         self.return_path = return_path
-
-         self.tokenizer = tokenizer
-         self.tokenizer_repeats = tokenizer_repeats
-         self.bucket: BaseBucket = bucket
-         self.att_mask_encode = att_mask_encode
-         self.source = ComposeDataSource(source)
-         self.latents = None  # Cache latents for faster training. Works only without image argumentations.
-         self.cache_path = cache_path
-         self.encoder_attention_mask = encoder_attention_mask
-
-     def load_data(self, path:str, data_source:DataSource, size:Tuple[int]):
-         image_dict = data_source.load_image(path)
-         image = image_dict['image']
-         att_mask = image_dict.get('att_mask', None)
-         if att_mask is None:
-             data, crop_coord = self.bucket.crop_resize({"img":image}, size)
-             image = data_source.procees_image(data['img'])  # resize to bucket size
-             att_mask = torch.ones((size[1]//8, size[0]//8))
-         else:
-             data, crop_coord = self.bucket.crop_resize({"img":image, "mask":att_mask}, size)
-             image = data_source.procees_image(data['img'])
-             att_mask = torch.tensor(cv2.resize(data['mask'], (size[0]//8, size[1]//8), interpolation=cv2.INTER_LINEAR))
-         return {'img':image, 'mask':att_mask}
-
-     @torch.no_grad()
-     def cache_latents(self, vae, weight_dtype, device, show_prog=True):
-         if self.cache_path and os.path.exists(self.cache_path):
-             self.latents = torch.load(self.cache_path)
-             return
-
-         self.latents = {}
-         self.bucket.rest(0)
-
-         for (path, data_source), size in tqdm(self.bucket, disable=not show_prog):
-             img_name = data_source.get_image_name(path)
-             if img_name not in self.latents:
-                 data = self.load_data(path, data_source, size)
-                 image = data['img'].unsqueeze(0).to(device, dtype=weight_dtype)
-                 latents = vae.encode(image).latent_dist.sample().squeeze(0)
-                 data['img'] = (latents*vae.config.scaling_factor).cpu()
-                 self.latents[img_name] = data
-
-         if self.cache_path:
-             torch.save(self.latents, self.cache_path)
-
-     def __len__(self):
-         return len(self.bucket)
-
-     def __getitem__(self, index):
-         (path, data_source), size = self.bucket[index]
-         img_name = data_source.get_image_name(path)
-
-         if self.latents is None:
-             data = self.load_data(path, data_source, size)
-         else:
-             data = self.latents[img_name].copy()
-
-         prompt_ist = data_source.load_caption(img_name)
-
-         # tokenize Sp or (Sn, Sp)
-         tokens = self.tokenizer(prompt_ist, truncation=True, padding="max_length", return_tensors="pt",
-                                 max_length=self.tokenizer.model_max_length*self.tokenizer_repeats)
-         data['prompt'] = tokens.input_ids.squeeze()
-         if self.encoder_attention_mask and 'attention_mask' in tokens:
-             data['attn_mask'] = tokens.attention_mask.squeeze()
-         if 'position_ids' in tokens:
-             data['position_ids'] = tokens.position_ids.squeeze()
-
-         if self.return_path:
-             return data, path
-         else:
-             return data
-
-     @staticmethod
-     def collate_fn(batch):
-         '''
-         batch: [{img:tensor, prompt:str, ..., plugin_input:{...}},{}]
-         '''
-         has_plugin_input = 'plugin_input' in batch[0]
-         if has_plugin_input:
-             plugin_input = {k:[] for k in batch[0]['plugin_input'].keys()}
-
-         datas = {k:[] for k in batch[0].keys() if k != 'plugin_input' and k != 'prompt'}
-         sn_list, sp_list = [], []
-
-         for data in batch:
-             if has_plugin_input:
-                 for k, v in data.pop('plugin_input').items():
-                     plugin_input[k].append(v)
-
-             prompt = data.pop('prompt')
-             if len(prompt.shape) == 2:
-                 sn_list.append(prompt[0])
-                 sp_list.append(prompt[1])
-             else:
-                 sp_list.append(prompt)
-
-             for k, v in data.items():
-                 datas[k].append(v)
-
-         for k, v in datas.items():
-             datas[k] = torch.stack(v)
-             if k == 'mask':
-                 datas[k] = datas[k].unsqueeze(1)
-
-         sn_list += sp_list
-         datas['prompt'] = torch.stack(sn_list)
-         if has_plugin_input:
-             datas['plugin_input'] = {k:torch.stack(v) for k, v in plugin_input.items()}
-
-         return datas
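
For context, a minimal sketch of how the removed TextImagePairDataset was typically driven. The tokenizer, bucket, source, and vae objects are placeholders; in 0.9.x they were instantiated from the hydra training config rather than built by hand:

import torch
from torch.utils.data import DataLoader

# tokenizer, bucket, source, vae: hypothetical instances from the training config
dataset = TextImagePairDataset(tokenizer=tokenizer, bucket=bucket, source={'data1': source})
dataset.cache_latents(vae, torch.float16, 'cuda')  # optional: pre-encode images to VAE latents
loader = DataLoader(dataset, batch_size=4, collate_fn=TextImagePairDataset.collate_fn)
batch = next(iter(loader))
# collate_fn stacks negative (Sn) prompts before positive (Sp) ones, so with
# DreamArtist-style (Sn, Sp) prompt pairs, batch['prompt'] has shape [2*bs, seq_len].
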
hcpdiff/data/sampler.py DELETED
@@ -1,54 +0,0 @@
- import torch
- from torch.utils.data.distributed import DistributedSampler
- from typing import Iterator
- import platform
- import math
-
- class DistributedCycleSampler(DistributedSampler):
-     _cycle = 1000
-
-     def __iter__(self) -> Iterator:
-         def _iter():
-             while True:
-                 if self.shuffle:
-                     # deterministically shuffle based on epoch and seed
-                     g = torch.Generator()
-                     g.manual_seed(self.seed + self.epoch)
-                     indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
-                 else:
-                     indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
-
-                 if not self.drop_last:
-                     # add extra samples to make it evenly divisible
-                     padding_size = self.total_size - len(indices)
-                     if padding_size <= len(indices):
-                         indices += indices[:padding_size]
-                     else:
-                         indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
-                 else:
-                     # remove tail of data to make it evenly divisible.
-                     indices = indices[:self.total_size]
-                 assert len(indices) == self.total_size
-
-                 # subsample
-                 indices = indices[self.rank:self.total_size:self.num_replicas]
-                 assert len(indices) == self.num_samples
-
-                 for idx in indices:
-                     yield idx
-                 self.epoch+=1
-
-                 if self.epoch>=self._cycle:
-                     break
-
-         return _iter()
-
-     def __len__(self) -> int:
-         return self.num_samples  # *self._cycle
-
- def get_sampler():
-     # Fix DataLoader frequently reload bugs in windows
-     if platform.system().lower() == 'windows':
-         return DistributedCycleSampler
-     else:
-         return DistributedSampler
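
A sketch of how the get_sampler() factory above was consumed (dataset, rank, and world_size are placeholders). The cycling variant yields indices for up to _cycle epochs from a single iterator, so on Windows the DataLoader does not tear down and respawn its workers at every epoch boundary:

from torch.utils.data import DataLoader

sampler_cls = get_sampler()  # DistributedCycleSampler on Windows, else DistributedSampler
sampler = sampler_cls(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
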
hcpdiff/data/source/base.py DELETED
@@ -1,30 +0,0 @@
- import os
- from typing import Dict, List, Tuple, Any
-
- class DataSource:
-     def __init__(self, img_root, repeat=1, **kwargs):
-         self.img_root = img_root
-         self.repeat = repeat
-
-     def get_image_list(self) -> List[Tuple[str, "DataSource"]]:
-         raise NotImplementedError()
-
-     def procees_image(self, image):
-         raise NotImplementedError()
-
-     def load_image(self, path) -> Dict[str, Any]:
-         raise NotImplementedError()
-
-     def get_image_name(self, path: str) -> str:
-         img_root, img_name = os.path.split(path)
-         return img_name.rsplit('.')[0]
-
- class ComposeDataSource(DataSource):
-     def __init__(self, source_dict: Dict[str, DataSource]):
-         self.source_dict = source_dict
-
-     def get_image_list(self) -> List[Tuple[str, DataSource]]:
-         img_list = []
-         for source in self.source_dict.values():
-             img_list.extend(source.get_image_list())
-         return img_list
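
To show what the removed DataSource contract expected of a subclass, a hypothetical example (the real implementations, e.g. hcpdiff/data/source/text2img.py, are listed above but not shown in this diff):

import os
from PIL import Image

class FolderSource(DataSource):  # hypothetical subclass for illustration
    def get_image_list(self):
        files = [os.path.join(self.img_root, f) for f in os.listdir(self.img_root)]
        return [(path, self) for path in files]*self.repeat  # honor the repeat count

    def load_image(self, path):
        return {'image': Image.open(path).convert('RGB')}

    def procees_image(self, image):  # (sic) spelling follows the base class
        return image
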
hcpdiff/data/utils.py DELETED
@@ -1,80 +0,0 @@
- import cv2
- import numpy as np
- from PIL import Image
- from torchvision import transforms as T
- from torchvision.transforms import functional as F
-
- class DualRandomCrop:
-     def __init__(self, size):
-         self.size = size
-
-     def __call__(self, img):
-         crop_params = T.RandomCrop.get_params(img['img'], self.size)
-         img['img'] = F.crop(img['img'], *crop_params)
-         if "mask" in img:
-             img['mask'] = self.crop(img['mask'], *crop_params)
-         if "cond" in img:
-             img['cond'] = F.crop(img['cond'], *crop_params)
-         return img, crop_params[:2]
-
-     @staticmethod
-     def crop(img: np.ndarray, top: int, left: int, height: int, width: int) -> np.ndarray:
-         right = left+width
-         bottom = top+height
-         return img[top:bottom, left:right, ...]
-
- def resize_crop_fix(img, target_size, mask_interp=cv2.INTER_CUBIC):
-     w, h = img['img'].size
-     if w == target_size[0] and h == target_size[1]:
-         return img, [h,w,0,0,h,w]
-
-     ratio_img = w/h
-     if ratio_img>target_size[0]/target_size[1]:
-         new_size = (round(ratio_img*target_size[1]), target_size[1])
-         interp_type = Image.LANCZOS if h>target_size[1] else Image.BICUBIC
-     else:
-         new_size = (target_size[0], round(target_size[0]/ratio_img))
-         interp_type = Image.LANCZOS if w>target_size[0] else Image.BICUBIC
-     img['img'] = img['img'].resize(new_size, interp_type)
-     if "mask" in img:
-         img['mask'] = cv2.resize(img['mask'], new_size, interpolation=mask_interp)
-     if "cond" in img:
-         img['cond'] = img['cond'].resize(new_size, interp_type)
-
-     img, crop_coord = DualRandomCrop(target_size[::-1])(img)
-     return img, [*new_size, *crop_coord[::-1], *target_size]
-
- def pad_crop_fix(img, target_size):
-     w, h = img['img'].size
-     if w == target_size[0] and h == target_size[1]:
-         return img, (h,w,0,0,h,w)
-
-     pad_size = [0, 0, max(target_size[0]-w, 0), max(target_size[1]-h, 0)]
-     if pad_size[2]>0 or pad_size[3]>0:
-         img['img'] = F.pad(img['img'], pad_size)
-         if "mask" in img:
-             img['mask'] = np.pad(img['mask'], ((0, pad_size[3]), (0, pad_size[2])), 'constant', constant_values=(0, 0))
-         if "cond" in img:
-             img['cond'] = F.pad(img['cond'], pad_size)
-
-     if pad_size[2]>0 and pad_size[3]>0:
-         return img, (h,w,0,0,h,w)  # No need to crop
-     else:
-         img, crop_coord = DualRandomCrop(target_size[::-1])(img)
-         return img, crop_coord
-
- class CycleData():
-     def __init__(self, data_loader):
-         self.data_loader = data_loader
-
-     def __iter__(self):
-         self.epoch = 0
-
-         def cycle():
-             while True:
-                 self.data_loader.dataset.bucket.rest(self.epoch)
-                 for data in self.data_loader:
-                     yield data
-                 self.epoch += 1
-
-         return cycle()
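
The record returned by resize_crop_fix() describes the resize and crop that were applied; a quick check of its shape (sizes are illustrative):

from PIL import Image

img = {'img': Image.new('RGB', (768, 512))}
img, coord = resize_crop_fix(img, (512, 512))
print(img['img'].size)  # (512, 512)
print(coord)            # [new_w, new_h, crop_x, crop_y, target_w, target_h]
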
hcpdiff/infer_workflow.py DELETED
@@ -1,57 +0,0 @@
- import argparse
-
- import torch
- import hydra
- from omegaconf import OmegaConf, DictConfig
- from easydict import EasyDict
-
- from hcpdiff.utils.utils import load_config_with_cli
- from .workflow import MemoryMixin
- from copy import deepcopy
-
- class WorkflowRunner:
-     def __init__(self, cfgs):
-         self.cfgs_raw = deepcopy(cfgs)
-         self.cfgs = OmegaConf.structured(cfgs, flags={"allow_objects": True})
-         OmegaConf.resolve(self.cfgs)
-         self.memory = EasyDict(**hydra.utils.instantiate(self.cfgs.memory))
-         self.attach_memory(self.cfgs)
-
-     def start(self):
-         prepare_actions = hydra.utils.instantiate(self.cfgs.prepare)
-         states = self.run(prepare_actions, {'cfgs': self.cfgs_raw})
-         actions = hydra.utils.instantiate(self.cfgs.actions)
-         states = self.run(actions, states)
-
-     def attach_memory(self, cfgs):
-         if OmegaConf.is_dict(cfgs):
-             if '_target_' in cfgs and cfgs['_target_'].endswith('.from_memory'):
-                 cfgs._set_flag('allow_objects', True)
-                 cfgs['memory'] = self.memory
-             else:
-                 for v in cfgs.values():
-                     self.attach_memory(v)
-         elif OmegaConf.is_list(cfgs):
-             for v in cfgs:
-                 self.attach_memory(v)
-
-     @torch.inference_mode()
-     def run(self, actions, states):
-         N_steps = len(actions)
-         for step, act in enumerate(actions):
-             print(f'[{step+1}/{N_steps}] action: {type(act).__name__}')
-             if isinstance(act, MemoryMixin):
-                 states = act(memory=self.memory, **states)
-             else:
-                 states = act(**states)
-             print(f'states: {", ".join(states.keys())}')
-         return states
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(description='HCP-Diffusion workflow')
-     parser.add_argument('--cfg', type=str, default='')
-     args, cfg_args = parser.parse_known_args()
-     cfgs = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
-
-     runner = WorkflowRunner(cfgs)
-     runner.start()
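
The removed entry point was run as python -m hcpdiff.infer_workflow --cfg <workflow.yaml>; a programmatic sketch of the same flow, using one of the 0.9.0 workflow configs listed above as an example path:

from hcpdiff.utils.utils import load_config_with_cli
from hcpdiff.infer_workflow import WorkflowRunner

cfgs = load_config_with_cli('cfgs/workflow/text2img.yaml', args_list=[])
runner = WorkflowRunner(cfgs)
runner.start()
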
hcpdiff/loggers/__init__.py DELETED
@@ -1,13 +0,0 @@
- from .base_logger import BaseLogger, LoggerGroup
- from .cli_logger import CLILogger
- from .webui_logger import WebUILogger
-
- try:
-     from .tensorboard_logger import TBLogger
- except:
-     print('tensorboard is not available')
-
- try:
-     from .wandb_logger import WanDBLogger
- except:
-     print('wandb is not available')
hcpdiff/loggers/base_logger.py DELETED
@@ -1,76 +0,0 @@
- from typing import Dict, Any, List
-
- from PIL import Image
-
- from .preview import ImagePreviewer
-
- class BaseLogger:
-     def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200):
-         self.exp_dir = exp_dir
-         self.out_path = out_path
-         self.enable_log_image = enable_log_image
-         self.log_step = log_step
-         self.image_log_step = image_log_step
-         self.enable_log = True
-         self.previewer_list: List[ImagePreviewer] = []
-
-     def enable(self):
-         self.enable_log = True
-
-     def disable(self):
-         self.enable_log = False
-
-     def add_previewer(self, previewer: ImagePreviewer):
-         self.previewer_list.append(previewer)
-
-     def info(self, info):
-         if self.enable_log:
-             self._info(info)
-
-     def _info(self, info):
-         raise NotImplementedError()
-
-     def log(self, datas: Dict[str, Any], step: int = 0):
-         if self.enable_log and step%self.log_step == 0:
-             self._log(datas, step)
-
-     def _log(self, datas: Dict[str, Any], step: int = 0):
-         raise NotImplementedError()
-
-     def log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-         if self.enable_log and self.enable_log_image and step%self.image_log_step == 0:
-             self._log_image(imgs, step)
-
-     def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-         raise NotImplementedError()
-
- class LoggerGroup:
-     def __init__(self, logger_list: List[BaseLogger]):
-         self.logger_list = logger_list
-
-     def enable(self):
-         for logger in self.logger_list:
-             logger.enable()
-
-     def disable(self):
-         for logger in self.logger_list:
-             logger.disable()
-
-     def add_previewer(self, previewer):
-         for logger in self.logger_list:
-             logger.add_previewer(previewer)
-
-     def info(self, info):
-         for logger in self.logger_list:
-             logger.info(info)
-
-     def log(self, datas: Dict[str, Any], step: int = 0):
-         for logger in self.logger_list:
-             logger.log(datas, step)
-
-     def log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-         for logger in self.logger_list:
-             logger.log_image(imgs, step)
-
-     def __len__(self):
-         return len(self.logger_list)
hcpdiff/loggers/cli_logger.py DELETED
@@ -1,40 +0,0 @@
- import os
- from typing import Dict, Any
-
- from PIL import Image
- from loguru import logger
-
- from .base_logger import BaseLogger
-
- class CLILogger(BaseLogger):
-     def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200,
-                  img_log_dir='preview', img_ext='png', img_quality=95):
-         super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-         if exp_dir is not None:  # exp_dir is only available in local main process
-             logger.add(os.path.join(exp_dir, out_path))
-             if enable_log_image:
-                 self.img_log_dir = os.path.join(exp_dir, img_log_dir)
-                 os.makedirs(self.img_log_dir, exist_ok=True)
-                 self.img_ext = img_ext
-                 self.img_quality = img_quality
-         else:
-             self.disable()
-
-     def enable(self):
-         super(CLILogger, self).enable()
-         logger.enable("__main__")
-
-     def disable(self):
-         super(CLILogger, self).disable()
-         logger.disable("__main__")
-
-     def _info(self, info):
-         logger.info(info)
-
-     def _log(self, datas: Dict[str, Any], step: int = 0):
-         logger.info(', '.join([f"{k} = {v['format'].format(*v['data'])}" for k, v in datas.items()]))
-
-     def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-         logger.info(f'log {len(imgs)} images')
-         for name, img in imgs.items():
-             img.save(os.path.join(self.img_log_dir, f'{step}-{name}.{self.img_ext}'), quality=self.img_quality)
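
A sketch of the removed logging API end to end; the constructor arguments are illustrative, and log() expects {'format': ..., 'data': [...]} values, as CLILogger._log() above shows (TBLogger is shown further below):

loggers = LoggerGroup([
    CLILogger(exp_dir='exps/demo', out_path='train.log'),
    TBLogger(exp_dir='exps/demo', out_path='tensorboard'),
])
loggers.info('training started')
loggers.log({'loss': {'format': '{:.4f}', 'data': [0.1234]}}, step=10)  # step % log_step == 0, so it logs
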
hcpdiff/loggers/preview/__init__.py DELETED
@@ -1 +0,0 @@
- from .image_previewer import ImagePreviewer
hcpdiff/loggers/preview/image_previewer.py DELETED
@@ -1,149 +0,0 @@
- from contextlib import contextmanager
- from typing import List
-
- import hydra
- import torch
- from accelerate import infer_auto_device_map, dispatch_model
- from accelerate.hooks import remove_hook_from_module
- from diffusers import PNDMScheduler
- from torch.cuda.amp import autocast
-
- from hcpdiff.models import TokenizerHook
- from hcpdiff.utils.net_utils import to_cpu
- from hcpdiff.utils.utils import prepare_seed, load_config, size_to_int, int_to_size
- from hcpdiff.utils.utils import to_validate_file
- from hcpdiff.visualizer import Visualizer
-
- class ImagePreviewer(Visualizer):
-     def __init__(self, infer_cfg, exp_dir, te_hook,
-                  unet, TE, tokenizer, vae, save_cfg=False):
-         self.exp_dir = exp_dir
-         self.cfgs_raw = load_config(infer_cfg)
-         self.cfgs = hydra.utils.instantiate(self.cfgs_raw)
-         self.save_cfg = save_cfg
-         self.offload = 'offload' in self.cfgs and self.cfgs.offload is not None
-         self.dtype = self.dtype_dict[self.cfgs.dtype]
-
-         if getattr(self.cfgs.new_components, 'scheduler', None) is None:
-             scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')
-         else:
-             scheduler = self.cfgs.new_components.scheduler
-
-         pipe_cls = self.get_pipeline()
-         self.pipe = pipe_cls(vae=vae, text_encoder=TE, tokenizer=tokenizer, unet=unet, scheduler=scheduler, feature_extractor=None,
-                              safety_checker=None, requires_safety_checker=False)
-
-         self.token_ex = TokenizerHook(tokenizer)
-         self.te_hook = te_hook
-
-         if self.cfgs.seed is not None:
-             self.seeds = list(range(self.cfgs.seed, self.cfgs.seed+self.cfgs.num*self.cfgs.bs))
-         else:
-             self.seeds = [None]*(self.cfgs.num*self.cfgs.bs)
-
-     def build_vae_offload(self, offload_cfg):
-         vram = size_to_int(offload_cfg.max_VRAM)
-         if not offload_cfg.vae_cpu:
-             device_map = infer_auto_device_map(self.pipe.vae, max_memory={0:int_to_size(vram >> 5), "cpu":offload_cfg.max_RAM}, dtype=torch.float32)
-             self.pipe.vae = dispatch_model(self.pipe.vae, device_map)
-         else:
-             to_cpu(self.pipe.vae)
-             self.vae_decode_raw = self.pipe.vae.decode
-
-             def vae_decode_offload(latents, return_dict=True, decode_raw=self.pipe.vae.decode):
-                 self.pipe.vae.to(dtype=torch.float32)
-                 res = decode_raw(latents.cpu().to(dtype=torch.float32), return_dict=return_dict)
-                 return res
-
-             self.pipe.vae.decode = vae_decode_offload
-
-             self.vae_encode_raw = self.pipe.vae.encode
-
-             def vae_encode_offload(x, return_dict=True, encode_raw=self.pipe.vae.encode):
-                 self.pipe.vae.to(dtype=torch.float32)
-                 res = encode_raw(x.cpu().to(dtype=torch.float32), return_dict=return_dict)
-                 return res
-
-             self.pipe.vae.encode = vae_encode_offload
-
-     def remove_vae_offload(self, offload_cfg):
-         if not offload_cfg.vae_cpu:
-             remove_hook_from_module(self.pipe.vae, recurse=True)
-         else:
-             self.pipe.vae.encode = self.vae_encode_raw
-             self.pipe.vae.decode = self.vae_decode_raw
-
-     @contextmanager
-     def infer_optimize(self):
-         if getattr(self.cfgs, 'vae_optimize', None) is not None:
-             if self.cfgs.vae_optimize.tiling:
-                 self.pipe.vae.enable_tiling()
-             if self.cfgs.vae_optimize.slicing:
-                 self.pipe.vae.enable_slicing()
-         vae_device = self.pipe.vae.device
-         if self.offload:
-             self.build_vae_offload(self.cfgs.offload)
-         else:
-             self.pipe.vae.to(self.pipe.unet.device)
-
-         yield
-
-         if self.offload:
-             self.remove_vae_offload(self.cfgs.offload)
-         self.pipe.vae.to(vae_device)
-         self.pipe.vae.disable_tiling()
-         self.pipe.vae.disable_slicing()
-
-     def preview(self):
-         image_list, info_list = [], []
-         with self.infer_optimize():
-             for i in range(self.cfgs.num):
-                 prompt = self.cfgs.prompt[i*self.cfgs.bs:(i+1)*self.cfgs.bs] if isinstance(self.cfgs.prompt, list) \
-                     else [self.cfgs.prompt]*self.cfgs.bs
-                 negative_prompt = self.cfgs.neg_prompt[i*self.cfgs.bs:(i+1)*self.cfgs.bs] if isinstance(self.cfgs.neg_prompt, list) \
-                     else [self.cfgs.neg_prompt]*self.cfgs.bs
-                 seeds = self.seeds[i*self.cfgs.bs:(i+1)*self.cfgs.bs]
-                 images = self.vis_images(prompt=prompt, negative_prompt=negative_prompt, seeds=seeds,
-                                          **self.cfgs.infer_args)
-                 for prompt_i, negative_prompt_i, seed in zip(prompt, negative_prompt, seeds):
-                     info_list.append({
-                         'prompt':prompt_i,
-                         'negative_prompt':negative_prompt_i,
-                         'seed':seed,
-                     })
-                 image_list += images
-
-         return image_list, info_list
-
-     def preview_dict(self):
-         image_list, info_list = self.preview()
-         imgs = {f'{info["seed"]}-{to_validate_file(info["prompt"])}':img for img, info in zip(image_list, info_list)}
-         return imgs
-
-     @torch.no_grad()
-     def vis_images(self, prompt, negative_prompt='', seeds: List[int] = None, **kwargs):
-         G = prepare_seed(seeds or [None]*len(prompt))
-
-         ex_input_dict, pipe_input_dict = self.get_ex_input()
-         kwargs.update(pipe_input_dict)
-
-         mult_p, clean_text_p = self.token_ex.parse_attn_mult(prompt)
-         mult_n, clean_text_n = self.token_ex.parse_attn_mult(negative_prompt)
-         with autocast(enabled=self.cfgs.amp, dtype=self.dtype):
-             emb, pooled_output, attention_mask = self.te_hook.encode_prompt_to_emb(clean_text_n+clean_text_p)
-             if not self.cfgs.encoder_attention_mask:
-                 attention_mask = None
-             emb_n, emb_p = emb.chunk(2)
-             emb_p = self.te_hook.mult_attn(emb_p, mult_p)
-             emb_n = self.te_hook.mult_attn(emb_n, mult_n)
-
-             if hasattr(self.pipe.unet, 'input_feeder'):
-                 for feeder in self.pipe.unet.input_feeder:
-                     feeder(ex_input_dict)
-
-             if pooled_output is not None:
-                 pooled_output = pooled_output[-1]
-
-             images = self.pipe(prompt_embeds=emb_p, negative_prompt_embeds=emb_n, callback=self.inter_callback, generator=G,
-                                pooled_output=pooled_output, encoder_attention_mask=attention_mask, **kwargs).images
-         return images
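
Illustrative use of the removed previewer inside a training loop; every object here is assumed to come from the trainer, and the infer config path is an example from the 0.9.0 data files listed above:

previewer = ImagePreviewer('cfgs/infer/text2img.yaml', exp_dir, te_hook,
                           unet, text_encoder, tokenizer, vae)
imgs = previewer.preview_dict()  # {'<seed>-<prompt>': PIL.Image, ...}
logger_group.log_image(imgs, step=200)  # a LoggerGroup built with enable_log_image=True
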
hcpdiff/loggers/tensorboard_logger.py DELETED
@@ -1,30 +0,0 @@
- import os
- from typing import Dict, Any
-
- import numpy as np
- from PIL import Image
- from torch.utils.tensorboard import SummaryWriter
-
- from .base_logger import BaseLogger
-
-
- class TBLogger(BaseLogger):
-     def __init__(self, exp_dir, out_path, enable_log_image=False, log_step=10, image_log_step=200):
-         super().__init__(exp_dir, out_path, enable_log_image, log_step, image_log_step)
-         if exp_dir is not None:  # exp_dir is only available in local main process
-             self.writer = SummaryWriter(os.path.join(exp_dir, out_path))
-         else:
-             self.writer = None
-             self.disable()
-
-     def _info(self, info):
-         pass
-
-     def _log(self, datas: Dict[str, Any], step: int = 0):
-         for k, v in datas.items():
-             if len(v['data']) == 1:
-                 self.writer.add_scalar(k, v['data'][0], global_step=step)
-
-     def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
-         for name, img in imgs.items():
-             self.writer.add_image(f'img/{name}', np.array(img), dataformats='HWC', global_step=step)