hcpdiff-0.9.0-py3-none-any.whl → hcpdiff-2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/data/handler/diffusion.py
@@ -0,0 +1,80 @@
+ from typing import Union, Dict, Any
+
+ import numpy as np
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from rainbowneko.data import DataHandler, HandlerChain, LoadImageHandler, ImageHandler
+
+ from .text import TemplateFillHandler, TagDropoutHandler, TagEraseHandler, TagShuffleHandler, TokenizeHandler
+
+ class LossMapHandler(DataHandler):
+     def __init__(self, bucket, vae_scale=8, key_map_in=('loss_map -> image', 'image_size -> image_size'),
+                  key_map_out=('image -> loss_map', 'coord -> coord')):
+         super().__init__(key_map_in, key_map_out)
+         self.vae_scale = vae_scale
+
+         self.handlers = HandlerChain(
+             load=LoadImageHandler(mode='L'),
+             bucket=bucket.handler,
+             image=ImageHandler(transform=T.Compose([
+                 lambda x:x.resize((x.size[0]//self.vae_scale, x.size[1]//self.vae_scale), Image.BILINEAR),
+                 T.ToTensor()
+             ]), )
+         )
+
+     def handle(self, image: Union[Image.Image, str], image_size: np.ndarray[int]):
+         data = self.handlers(dict(image=image, image_size=image_size))
+         image = data['image']
+         image[image<=0.5] *= 2
+         image[image>0.5] = (image[image>0.5]-0.5)*4+1
+         return self.handlers(dict(**data, image=image))
+
+ class DiffusionImageHandler(DataHandler):
+     def __init__(self, bucket, key_map_in=('image -> image', 'image_size -> image_size'), key_map_out=('image -> image', 'coord -> coord')):
+         super().__init__(key_map_in, key_map_out)
+
+         self.handlers = HandlerChain(
+             load=LoadImageHandler(),
+             bucket=bucket.handler,
+             image=ImageHandler(transform=T.Compose([
+                 T.ToTensor(),
+                 T.Normalize([0.5], [0.5])
+             ]), )
+         )
+
+     def handle(self, image: Image.Image, image_size: np.ndarray[int]):
+         if isinstance(image, torch.Tensor): # cached latents
+             return dict(image=image, image_size=image_size)
+         else:
+             return self.handlers(dict(image=image, image_size=image_size))
+
+ class StableDiffusionHandler(DataHandler):
+     def __init__(self, bucket, encoder_attention_mask=False, key_map_in=('image -> image', 'image_size -> image_size', 'prompt -> prompt'),
+                  key_map_out=('image -> image', 'coord -> coord', 'prompt -> prompt'),
+                  erase=0.15, dropout=0.0, shuffle=0.0, word_names={}, tokenize=True):
+         super().__init__(key_map_in, key_map_out)
+
+         self.image_handlers = DiffusionImageHandler(bucket)
+
+         text_handlers = {}
+         if dropout>0:
+             text_handlers['dropout'] = TagDropoutHandler(p=dropout)
+         if erase>0:
+             text_handlers['erase'] = TagEraseHandler(p=erase)
+         if shuffle>0:
+             text_handlers['shuffle'] = TagShuffleHandler()
+         text_handlers['fill'] = TemplateFillHandler(word_names)
+         if tokenize:
+             text_handlers['tokenize'] = TokenizeHandler(encoder_attention_mask)
+         self.text_handlers = HandlerChain(**text_handlers)
+
+     def handle(self, image: Image.Image, image_size: np.ndarray[int], prompt: str):
+         return dict(**self.image_handlers(dict(image=image, image_size=image_size)), **self.text_handlers(dict(prompt=prompt)))
+
+     def __call__(self, data) -> Dict[str, Any]:
+         data_proc = self.handle(**self.key_mapper_in.map_data(data)[1])
+         out_data = self.key_mapper_out.map_data(data_proc)[1]
+         data = dict(**data)
+         data.update(out_data)
+         return data
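Note: the two in-place ops in LossMapHandler.handle implement a piecewise per-pixel loss weight. After the grayscale mask is resized to latent resolution and scaled to v in [0, 1] by ToTensor:

w(v) = \begin{cases} 2v & v \le 0.5 \\ 4(v - 0.5) + 1 & v > 0.5 \end{cases}

so black maps to weight 0, mid-gray to 1, and white to 3. This takes over the role of the att_mask scaling removed from hcpdiff/data/source/text2img.py further down in this diff.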
hcpdiff/data/handler/text.py
@@ -0,0 +1,111 @@
+ import random
+ from typing import Dict, Union, List
+
+ import numpy as np
+ from string import Formatter
+ from rainbowneko.data import DataHandler
+ from rainbowneko._share import register_model_callback
+
+ class TagShuffleHandler(DataHandler):
+     def __init__(self, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
+         super().__init__(key_map_in, key_map_out)
+
+     def handle(self, prompt: Union[Dict[str, str], str]):
+         if isinstance(prompt, str):
+             tags = prompt.split(',')
+             random.shuffle(tags)
+             prompt = ','.join(tags)
+         else:
+             tags = prompt['caption'].split(',')
+             random.shuffle(tags)
+             prompt['caption'] = ','.join(tags)
+         return {'prompt':prompt}
+
+     def __repr__(self):
+         return 'TagShuffleHandler()'
+
+ class TagDropoutHandler(DataHandler):
+     def __init__(self, p=0.1, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
+         super().__init__(key_map_in, key_map_out)
+         self.p = p
+
+     def handle(self, prompt: Union[Dict[str, str], str]):
+         if isinstance(prompt, str):
+             tags = np.array(prompt.split(','))
+             prompt = ','.join(tags[np.random.random(len(tags))>self.p])
+         else:
+             tags = prompt['caption'].split(',')
+             prompt['caption'] = ','.join(tags[np.random.random(len(tags))>self.p])
+         return {'prompt':prompt}
+
+     def __repr__(self):
+         return f'TagDropoutHandler(p={self.p})'
+
+ class TagEraseHandler(DataHandler):
+     def __init__(self, p=0.1, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
+         super().__init__(key_map_in, key_map_out)
+         self.p = p
+
+     def handle(self, prompt):
+         if isinstance(prompt, str):
+             if random.random()<self.p:
+                 prompt = ''
+         else:
+             if random.random()<self.p:
+                 prompt['caption'] = ''
+         return {'prompt':prompt}
+
+     def __repr__(self):
+         return f'TagEraseHandler(p={self.p})'
+
+
+ class TemplateFillHandler(DataHandler):
+     def __init__(self, word_names: Dict[str, str], key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
+         super().__init__(key_map_in, key_map_out)
+         self.word_names = word_names
+
+     def handle(self, prompt):
+         template, caption = prompt['template'], prompt['caption']
+
+         keys_need = {i[1] for i in Formatter().parse(template) if i[1] is not None}
+         fill_dict = {k: v for k, v in self.word_names.items() if k in keys_need}
+
+         if (caption is not None) and ('caption' in keys_need):
+             fill_dict.update(caption=fill_dict.get('caption', None) or caption)
+
+         # skip keys that are not provided
+         for k in keys_need:
+             if k not in fill_dict:
+                 fill_dict[k] = ''
+
+         # replace None values with ''
+         fill_dict = {k:(v or '') for k, v in fill_dict.items()}
+         return {'prompt':template.format(**fill_dict)}
+
+     def __repr__(self):
+         return f'TemplateFill(\nword_names={self.word_names}\n)'
+
+ class TokenizeHandler(DataHandler):
+     def __init__(self, encoder_attention_mask=False, key_map_in=('prompt -> prompt',), key_map_out=None):
+         super().__init__(key_map_in, key_map_out)
+         self.encoder_attention_mask = encoder_attention_mask
+
+         register_model_callback(self.acquire_tokenizer)
+
+     def acquire_tokenizer(self, model_wrapper):
+         self.tokenizer = model_wrapper.tokenizer
+
+     def handle(self, prompt):
+         token_info = self.tokenizer(prompt, truncation=True, padding="max_length", return_tensors="pt",
+                                     max_length=self.tokenizer.model_max_length*self.tokenizer.N_repeats)
+         tokens = token_info.input_ids.squeeze()
+         data = {'prompt':tokens}
+         if self.encoder_attention_mask and 'attention_mask' in token_info:
+             data['attn_mask'] = token_info.attention_mask.squeeze()
+         if 'position_ids' in token_info:
+             data['position_ids'] = token_info.position_ids.squeeze()
+
+         return data
+
+     def __repr__(self):
+         return f'TokenizeHandler(\nencoder_attention_mask={self.encoder_attention_mask}, tokenizer={self.tokenizer}\n)'
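To illustrate the TemplateFillHandler logic above, here is a minimal standalone sketch using only the standard library; the template, word_names, and caption values are hypothetical:

from string import Formatter

template = 'a photo of {pt1}, {caption}'   # hypothetical prompt template
word_names = {'pt1': 'sks-character'}      # hypothetical trigger-word mapping
caption = 'sitting on a bench'             # caption loaded for this image

# collect the placeholders the template actually uses
keys_need = {f[1] for f in Formatter().parse(template) if f[1] is not None}
fill_dict = {k: v for k, v in word_names.items() if k in keys_need}
# the image caption fills {caption} unless word_names already provides one
if caption is not None and 'caption' in keys_need:
    fill_dict['caption'] = fill_dict.get('caption') or caption
# placeholders with no value fall back to '' instead of raising KeyError
for k in keys_need:
    fill_dict.setdefault(k, '')
print(template.format(**{k: (v or '') for k, v in fill_dict.items()}))
# -> a photo of sks-character, sitting on a bench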
hcpdiff/data/source/__init__.py
@@ -1,4 +1,3 @@
- from .base import DataSource, ComposeDataSource
- from .text2img import Text2ImageSource, Text2ImageAttMapSource
+ from .text2img import Text2ImageSource, Text2ImageLossMapSource
  from .text2img_cond import Text2ImageCondSource
  from .folder_class import T2IFolderClassSource
hcpdiff/data/source/folder_class.py
@@ -1,40 +1,23 @@
- import os
- from typing import List, Tuple, Union
- from hcpdiff.utils.utils import get_file_name, get_file_ext
- from hcpdiff.utils.img_size_tool import types_support
- from .text2img import Text2ImageAttMapSource
- from hcpdiff.data.caption_loader import BaseCaptionLoader, auto_caption_loader
  from copy import copy
+ from typing import Union

- class T2IFolderClassSource(Text2ImageAttMapSource):
+ from rainbowneko.data.label_loader import BaseLabelLoader, auto_label_loader

-     def get_image_list(self) -> List[Tuple[str, "T2IFolderClassSource"]]:
-         sub_folders = [os.path.join(self.img_root, x) for x in os.listdir(self.img_root)]
-         class_imgs = []
-         for class_folder in sub_folders:
-             class_name = os.path.basename(class_folder)
-             imgs = [(os.path.join(class_folder, x), self) for x in os.listdir(class_folder) if get_file_ext(x) in types_support]
-             class_imgs.extend(imgs*self.repeat[class_name])
-         return class_imgs
+ from .text2img import Text2ImageLossMapSource

-     def load_captions(self, caption_file: Union[str, BaseCaptionLoader]):
-         if caption_file is None:
+ class T2IFolderClassSource(Text2ImageLossMapSource):
+     def _load_label_data(self, label_file: Union[str, BaseLabelLoader]):
+         ''' {class_name/image.ext: label} '''
+         if label_file is None:
              return {}
-         elif isinstance(caption_file, str):
+         elif isinstance(label_file, str):
              captions = {}
-             caption_loader = auto_caption_loader(caption_file)
-             for class_name in os.listdir(caption_loader.path):
-                 class_folder = os.path.join(caption_loader.path, class_name)
+             caption_loader = auto_label_loader(label_file)
+             for class_folder in caption_loader.path.iterdir():
                  caption_loader_class = copy(caption_loader)
                  caption_loader_class.path = class_folder
-                 captions_class = {f'{class_name}/{name}':caption for name, caption in caption_loader_class.load().item()}
+                 captions_class = {f'{class_folder.name}/{name}':caption for name, caption in caption_loader_class.load().item()}
                  captions.update(captions_class)
              return captions
          else:
-             return caption_file.load()
-
-     def get_image_name(self, path: str) -> str:
-         img_root, img_name = os.path.split(path)
-         img_name = img_name.rsplit('.')[0]
-         img_root, class_name = os.path.split(img_root)
-         return f'{class_name}/{img_name}'
+             return label_file.load()
hcpdiff/data/source/text2img.py
@@ -1,13 +1,11 @@
- from .base import DataSource
- from hcpdiff.data.caption_loader import BaseCaptionLoader, auto_caption_loader
- from typing import Union, Any
  import os
- from hcpdiff.utils.utils import get_file_name, get_file_ext
- from hcpdiff.utils.img_size_tool import types_support
- from typing import Dict, List, Tuple
- from PIL import Image
- import numpy as np
  import random
+ from pathlib import Path
+ from typing import Any
+ from typing import Dict
+
+ from rainbowneko.data import ImageLabelSource
+ from rainbowneko.utils.utils import is_image_file
  from torchvision.transforms import transforms

  default_image_transforms = transforms.Compose([
@@ -15,77 +13,41 @@ default_image_transforms = transforms.Compose([
      transforms.Normalize([0.5], [0.5])
  ])

- class Text2ImageSource(DataSource):
-     def __init__(self, img_root, caption_file, prompt_template, text_transforms, image_transforms=default_image_transforms,
-                  bg_color=(255,255,255), repeat=1, **kwargs):
-         super(Text2ImageSource, self).__init__(img_root, repeat=repeat)
+ class Text2ImageSource(ImageLabelSource):
+     def __init__(self, img_root, label_file, prompt_template, repeat=1, **kwargs):
+         super().__init__(img_root, label_file, repeat=repeat)

-         self.caption_dict = self.load_captions(caption_file)
          self.prompt_template = self.load_template(prompt_template)
-         self.image_transforms = image_transforms
-         self.text_transforms = text_transforms
-         self.bg_color = tuple(bg_color)
-
-     def load_captions(self, caption_file: Union[str, BaseCaptionLoader]):
-         if caption_file is None:
-             return {}
-         elif isinstance(caption_file, str):
-             return auto_caption_loader(caption_file).load()
-         else:
-             return caption_file.load()

      def load_template(self, template_file):
          with open(template_file, 'r', encoding='utf-8') as f:
              return f.read().strip().split('\n')

-     def get_image_list(self) -> List[Tuple[str, DataSource]]:
-         imgs = [(os.path.join(self.img_root, x), self) for x in os.listdir(self.img_root) if get_file_ext(x) in types_support]
-         return imgs*self.repeat
-
-     def procees_image(self, image):
-         return self.image_transforms(image)
-
-     def process_text(self, text_dict):
-         return self.text_transforms(text_dict)
-
-     def load_image(self, path) -> Dict[str, Any]:
-         image = Image.open(path)
-         if image.mode == 'RGBA':
-             x, y = image.size
-             canvas = Image.new('RGBA', image.size, self.bg_color)
-             canvas.paste(image, (0, 0, x, y), image)
-             image = canvas
-         return {'image': image.convert("RGB")}
-
-     def load_caption(self, img_name) -> str:
-         caption_ist = self.caption_dict.get(img_name, None)
-         prompt_template = random.choice(self.prompt_template)
-         prompt_ist = self.process_text({'prompt':prompt_template, 'caption':caption_ist})['prompt']
-         return prompt_ist
-
- class Text2ImageAttMapSource(Text2ImageSource):
-     def __init__(self, img_root, caption_file, prompt_template, text_transforms, image_transforms=default_image_transforms, att_mask=None,
-                  bg_color=(255, 255, 255), repeat=1, **kwargs):
-         super().__init__(img_root, caption_file, prompt_template, image_transforms=image_transforms, text_transforms=text_transforms,
-                          bg_color=bg_color, repeat=repeat)
-
-         if att_mask is None:
-             self.att_mask = {}
+     def __getitem__(self, index) -> Dict[str, Any]:
+         img_name = self.img_ids[index]
+         path = os.path.join(self.img_root, img_name)
+
+         return {
+             'id':img_name,
+             'image':path,
+             'prompt':{
+                 'template':random.choice(self.prompt_template),
+                 'caption':self.label_dict.get(img_name, None),
+             }
+         }
+
+ class Text2ImageLossMapSource(Text2ImageSource):
+     def __init__(self, img_root, caption_file, prompt_template, loss_map=None, repeat=1, **kwargs):
+         super().__init__(img_root, caption_file, prompt_template, repeat=repeat)
+
+         if loss_map is None:
+             self.loss_map = {}
          else:
-             self.att_mask = {get_file_name(file):os.path.join(att_mask, file)
-                              for file in os.listdir(att_mask) if get_file_ext(file) in types_support}
-
-     def get_att_mask(self, img_name):
-         if img_name not in self.att_mask:
-             return None
-         att_mask = Image.open(self.att_mask[img_name]).convert("L")
-         np_mask = np.array(att_mask).astype(float)
-         np_mask[np_mask<=127+0.1] = (np_mask[np_mask<=127+0.1]/127.)
-         np_mask[np_mask>127] = ((np_mask[np_mask>127]-127)/128.)*4+1
-         return np_mask
-
-     def load_image(self, path) -> Dict[str, Any]:
-         img_root, img_name = os.path.split(path)
-         image_dict = super().load_image(path)
-         image_dict['att_mask'] = self.get_att_mask(get_file_name(img_name))
-         return image_dict
+             loss_map = Path(loss_map)
+             self.loss_map = {file.stem:loss_map/file for file in loss_map.iterdir() if is_image_file(file)}
+
+     def __getitem__(self, index) -> Dict[str, Any]:
+         data = super().__getitem__(index)
+         img_name = self.img_ids[index]
+         data['loss_map'] = self.loss_map[Path(img_name).stem]
+         return data
hcpdiff/data/source/text2img_cond.py
@@ -1,22 +1,16 @@
  import os
  from typing import Dict, Any

- from PIL import Image
- from torchvision import transforms
+ from .text2img import Text2ImageSource

- from .text2img import Text2ImageAttMapSource, default_image_transforms
-
- class Text2ImageCondSource(Text2ImageAttMapSource):
-     def __init__(self, img_root, caption_file, prompt_template, text_transforms, image_transforms=default_image_transforms,
-                  bg_color=(255, 255, 255), repeat=1, cond_dir=None, **kwargs):
-         super().__init__(img_root, caption_file, prompt_template, image_transforms=image_transforms, text_transforms=text_transforms,
-                          bg_color=bg_color, repeat=repeat)
-         self.cond_transform = transforms.ToTensor()
+ class Text2ImageCondSource(Text2ImageSource):
+     def __init__(self, img_root, caption_file, prompt_template, repeat=1, cond_dir=None, **kwargs):
+         super().__init__(img_root, caption_file, prompt_template, repeat=repeat)
          self.cond_dir = cond_dir

-     def load_image(self, path) -> Dict[str, Any]:
-         img_root, img_name = os.path.split(path)
-         image_dict = super().load_image(path)
+     def __getitem__(self, index) -> Dict[str, Any]:
+         data = super().__getitem__(index)
+         img_name = self.img_ids[index]
          cond_path = os.path.join(self.cond_dir, img_name)
-         image_dict['cond'] = Image.open(cond_path).convert("RGB")
-         return image_dict
+         data['cond'] = cond_path
+         return data
hcpdiff/diffusion/__init__.py (file without changes)
hcpdiff/diffusion/noise/__init__.py
@@ -0,0 +1,2 @@
+ from .pyramid_noise import PyramidNoiseSampler
+ from .zero_terminal import ZeroTerminalSampler
hcpdiff/diffusion/noise/pyramid_noise.py
@@ -0,0 +1,42 @@
+ import random
+
+ import torch
+ from torch.nn import functional as F
+
+ from hcpdiff.diffusion.sampler import BaseSampler
+
+ class PyramidNoiseSampler:
+     def __init__(self, level: int = 6, discount: float = 0.4, step_size: float = 2., resize_mode: str = 'bilinear'):
+         self.level = level
+         self.step_size = step_size
+         self.resize_mode = resize_mode
+         self.discount = discount
+
+     def make_nosie(self, shape, device='cuda', dtype=torch.float32):
+         noise = torch.randn(shape, device=device, dtype=dtype)
+         with torch.no_grad():
+             b, c, h, w = noise.shape
+             for i in range(1, self.level):
+                 r = random.random()*2+self.step_size
+                 wn, hn = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
+                 noise += F.interpolate(torch.randn(b, c, hn, wn).to(noise), (h, w), None, self.resize_mode)*(self.discount**i)
+                 if wn == 1 or hn == 1:
+                     break
+         noise = noise/noise.std()
+         return noise
+
+     @classmethod
+     def patch(cls, base_sampler: BaseSampler, level: int = 6, discount: float = 0.4, step_size: float = 2., resize_mode: str = 'bilinear'):
+         patcher = cls(level, discount, step_size, resize_mode)
+         base_sampler.make_nosie = patcher.make_nosie
+         return base_sampler
+
+ if __name__ == '__main__':
+     from hcpdiff.diffusion.sampler import EDM_DDPMSampler, DDPMContinuousSigmaScheduler
+     from matplotlib import pyplot as plt
+
+     sampler = PyramidNoiseSampler.patch(EDM_DDPMSampler(DDPMContinuousSigmaScheduler()))
+     noise = sampler.make_nosie((1,3,512,512), device='cpu')
+     plt.figure()
+     plt.imshow(noise[0].permute(1,2,0))
+     plt.show()
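Roughly, the construction above sums geometrically discounted Gaussian noise at progressively coarser resolutions and renormalizes (a gloss; note r is re-drawn each level, so the pyramid spacing is itself random):

\tilde n = \varepsilon_0 + \sum_{i=1}^{L-1} d^{\,i}\, \mathrm{up}_{h \times w}(\varepsilon_i), \qquad n = \tilde n / \operatorname{std}(\tilde n)

where \varepsilon_i is standard Gaussian noise at resolution (h/r^i, w/r^i), d = discount, L = level, and r \approx \mathrm{step\_size} + U(0, 2).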
hcpdiff/diffusion/noise/zero_terminal.py
@@ -0,0 +1,39 @@
+ import torch
+ from ..sampler.sigma_scheduler import DDPMDiscreteSigmaScheduler
+
+ class ZeroTerminalSampler:
+
+     @classmethod
+     def patch(cls, base_sampler):
+         assert isinstance(base_sampler.sigma_scheduler, DDPMDiscreteSigmaScheduler), "ZeroTerminalScheduler only works with DDPM SigmaScheduler"
+
+         alphas_cumprod = base_sampler.sigma_scheduler.alphas_cumprod
+         base_sampler.sigma_scheduler.alphas_cumprod = cls.rescale_zero_terminal_snr(alphas_cumprod)
+         base_sampler.sigma_scheduler.sigmas = ((1-alphas_cumprod)/alphas_cumprod).sqrt()
+
+
+     @staticmethod
+     def rescale_zero_terminal_snr(alphas_cumprod, thr=1e-4):
+         """
+         Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+         Args:
+             alphas_cumprod (`torch.FloatTensor`)
+         Returns:
+             `torch.FloatTensor`: rescaled betas with zero terminal SNR
+         """
+         alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+         # Store old values.
+         alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+         alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+         # Shift so the last timestep is zero.
+         alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+         # Scale so the first timestep is back to the old value.
+         alphas_bar_sqrt *= alphas_bar_sqrt_0/(alphas_bar_sqrt_0-alphas_bar_sqrt_T)
+         alphas_bar_sqrt[-1] = thr # avoid nan sigma
+
+         # Convert alphas_bar_sqrt to betas
+         alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+         return alphas_bar
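For reference, with \bar\alpha_t = alphas_cumprod[t], the scheduler's sigma and signal-to-noise ratio are

\sigma_t = \sqrt{\frac{1-\bar\alpha_t}{\bar\alpha_t}}, \qquad \mathrm{SNR}(t) = \frac{\bar\alpha_t}{1-\bar\alpha_t} = \sigma_t^{-2}

so forcing zero terminal SNR (\bar\alpha_T \to 0) sends \sigma_T \to \infty, which is why the code clamps the last value to thr rather than exactly zero.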
hcpdiff/diffusion/sampler/__init__.py
@@ -0,0 +1,5 @@
+ from .sigma_scheduler import *
+ from .base import BaseSampler
+ from .ddpm import DDPMSampler
+ from .edm import EDMSampler
+ from .diffusers import DiffusersSampler
hcpdiff/diffusion/sampler/base.py
@@ -0,0 +1,72 @@
+ from typing import Tuple
+ import torch
+ from .sigma_scheduler import SigmaScheduler
+ from diffusers import DDPMScheduler
+
+ class BaseSampler:
+     def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator = None):
+         self.sigma_scheduler = sigma_scheduler
+         self.generator = generator
+
+     def c_in(self, sigma):
+         return 1
+
+     def c_out(self, sigma):
+         return 1
+
+     def c_skip(self, sigma):
+         return 1
+
+     @property
+     def num_timesteps(self):
+         return getattr(self.sigma_scheduler, 'num_timesteps', 1000.)
+
+     def get_timesteps(self, N_steps, device='cuda'):
+         return torch.linspace(0, self.num_timesteps, N_steps, device=device)
+
+     def make_nosie(self, shape, device='cuda', dtype=torch.float32):
+         return torch.randn(shape, generator=self.generator, device=device, dtype=dtype)
+
+     def init_noise(self, shape, device='cuda', dtype=torch.float32):
+         sigma = self.sigma_scheduler.sigma_max
+         return self.make_nosie(shape, device, dtype)*sigma
+
+     def add_noise(self, x, sigma) -> Tuple[torch.Tensor, torch.Tensor]:
+         noise = self.make_nosie(x.shape, device=x.device)
+         noisy_x = (x.to(dtype=torch.float32)-self.c_out(sigma)*noise)/self.c_skip(sigma)
+         return noisy_x.to(dtype=x.dtype), noise.to(dtype=x.dtype)
+
+     def add_noise_rand_t(self, x):
+         bs = x.shape[0]
+         # timesteps: [0, 1]
+         sigma, timesteps = self.sigma_scheduler.sample_sigma(shape=(bs,))
+         sigma = sigma.view(-1, 1, 1, 1).to(x.device)
+         timesteps = timesteps.to(x.device)
+         noisy_x, noise = self.add_noise(x, sigma)
+
+         # Sample a random timestep for each image
+         timesteps = timesteps*(self.num_timesteps-1)
+         return noisy_x, noise, sigma, timesteps
+
+     def denoise(self, x, sigma, eps=None, generator=None):
+         raise NotImplementedError
+
+     def eps_to_x0(self, eps, x_t, sigma):
+         return self.c_skip(sigma)*x_t+self.c_out(sigma)*eps
+
+     def velocity_to_eps(self, v_pred, x_t, sigma):
+         alpha = 1/(sigma**2+1)
+         sqrt_alpha = alpha.sqrt()
+         one_sqrt_alpha = (1-alpha).sqrt()
+         return sqrt_alpha*v_pred + one_sqrt_alpha*(x_t*sqrt_alpha)
+
+     def eps_to_velocity(self, eps, x_t, sigma):
+         alpha = 1/(sigma**2+1)
+         sqrt_alpha = alpha.sqrt()
+         one_sqrt_alpha = (1-alpha).sqrt()
+         return eps/sqrt_alpha - one_sqrt_alpha*x_t
+
+     def velocity_to_x0(self, v_pred, x_t, sigma):
+         alpha = 1/(sigma**2+1)
+         one_sqrt_alpha = (1-alpha).sqrt()
+         return alpha*x_t - one_sqrt_alpha*v_pred
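The velocity conversion helpers above match the standard v-prediction identities. Writing \alpha = 1/(\sigma^2+1) and the unit-variance sample \hat x_t = \sqrt{\alpha}\,x_t (the scaling a DDPM-style c_in applies):

\hat x_t = \sqrt{\alpha}\,x_0 + \sqrt{1-\alpha}\,\epsilon, \qquad v = \sqrt{\alpha}\,\epsilon - \sqrt{1-\alpha}\,x_0

from which

\epsilon = \sqrt{\alpha}\,v + \sqrt{1-\alpha}\,\sqrt{\alpha}\,x_t, \qquad x_0 = \alpha\,x_t - \sqrt{1-\alpha}\,v, \qquad v = \frac{\epsilon}{\sqrt{\alpha}} - \sqrt{1-\alpha}\,x_t

i.e. exactly velocity_to_eps, velocity_to_x0, and eps_to_velocity.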
hcpdiff/diffusion/sampler/ddpm.py
@@ -0,0 +1,20 @@
+ import torch
+
+ from .base import BaseSampler
+ from .sigma_scheduler import SigmaScheduler
+
+ class DDPMSampler(BaseSampler):
+     def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator=None):
+         super().__init__(sigma_scheduler, generator)
+
+     def c_in(self, sigma):
+         return 1./(sigma**2+1).sqrt()
+
+     def c_out(self, sigma):
+         return -sigma
+
+     def c_skip(self, sigma):
+         return 1.
+
+     def denoise(self, x, sigma, eps=None, generator=None):
+         raise NotImplementedError
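This preconditioning recovers the familiar DDPM forward process: with \sigma_t = \sqrt{(1-\bar\alpha_t)/\bar\alpha_t},

c_{in}(\sigma_t) = \frac{1}{\sqrt{\sigma_t^2+1}} = \sqrt{\bar\alpha_t}

and add_noise with c_out = -\sigma, c_skip = 1 yields x_t = x_0 + \sigma_t\,\epsilon, so the network input is

c_{in}(\sigma_t)\,x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon.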
hcpdiff/diffusion/sampler/diffusers.py
@@ -0,0 +1,66 @@
+ import torch
+ import inspect
+ from diffusers import SchedulerMixin, DDPMScheduler
+
+ try:
+     from diffusers.utils import randn_tensor
+ except:
+     # new version of diffusers
+     from diffusers.utils.torch_utils import randn_tensor
+
+ from .base import BaseSampler
+ from .sigma_scheduler import TimeSigmaScheduler
+
+ class DiffusersSampler(BaseSampler):
+     def __init__(self, scheduler: SchedulerMixin, eta=0.0, generator: torch.Generator=None):
+         sigma_scheduler = TimeSigmaScheduler()
+         super().__init__(sigma_scheduler, generator)
+         self.scheduler = scheduler
+         self.eta = eta
+
+     def c_in(self, sigma):
+         one = torch.ones_like(sigma)
+         if hasattr(self.scheduler, '_step_index'):
+             self.scheduler._step_index = None
+         return self.scheduler.scale_model_input(one, sigma)
+
+     def c_out(self, sigma):
+         return -sigma
+
+     def c_skip(self, sigma):
+         if self.c_in(sigma) == 1.: # DDPM model
+             return (sigma**2+1).sqrt() # 1/sqrt(alpha_)
+         else: # EDM model
+             return 1.
+
+     def get_timesteps(self, N_steps, device='cuda'):
+         self.scheduler.set_timesteps(N_steps, device=device)
+         return self.scheduler.timesteps
+
+     def init_noise(self, shape, device='cuda', dtype=torch.float32):
+         return randn_tensor(shape, generator=self.generator, device=device, dtype=dtype)*self.scheduler.init_noise_sigma
+
+     def add_noise(self, x, sigma):
+         noise = randn_tensor(x.shape, generator=self.generator, device=x.device, dtype=x.dtype)
+         return self.scheduler.add_noise(x, noise, sigma), noise
+
+     def prepare_extra_step_kwargs(self, scheduler, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     def denoise(self, x_t, sigma, eps=None, generator=None):
+         extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator, self.eta)
+         return self.scheduler.step(eps, sigma, x_t, **extra_step_kwargs).prev_sample
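A minimal usage sketch for DiffusersSampler, assuming a stock diffusers DDIMScheduler and a Stable-Diffusion-style 4-channel latent space (the model id is illustrative, not prescribed by this package):

import torch
from diffusers import DDIMScheduler
from hcpdiff.diffusion.sampler import DiffusersSampler

# wrap a stock diffusers scheduler
scheduler = DDIMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler')
sampler = DiffusersSampler(scheduler, eta=0.0, generator=torch.Generator().manual_seed(42))

timesteps = sampler.get_timesteps(30, device='cpu')   # delegates to scheduler.set_timesteps
x = sampler.init_noise((1, 4, 64, 64), device='cpu')  # scaled by scheduler.init_noise_sigma
# in a denoising loop, each model prediction eps_pred is stepped with:
#   x = sampler.denoise(x, t, eps=eps_pred)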