hcpdiff 0.9.1__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211) hide show
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +252 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +90 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +3 -3
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text.py +40 -0
  21. hcpdiff/data/source/text2img.py +36 -74
  22. hcpdiff/data/source/text2img_cond.py +9 -15
  23. hcpdiff/diffusion/__init__.py +0 -0
  24. hcpdiff/diffusion/noise/__init__.py +2 -0
  25. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  26. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  27. hcpdiff/diffusion/sampler/__init__.py +5 -0
  28. hcpdiff/diffusion/sampler/base.py +72 -0
  29. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  30. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  31. hcpdiff/diffusion/sampler/edm.py +22 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  35. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  36. hcpdiff/easy/__init__.py +2 -0
  37. hcpdiff/easy/cfg/__init__.py +3 -0
  38. hcpdiff/easy/cfg/sd15_train.py +207 -0
  39. hcpdiff/easy/cfg/sdxl_train.py +147 -0
  40. hcpdiff/easy/cfg/t2i.py +228 -0
  41. hcpdiff/easy/model/__init__.py +2 -0
  42. hcpdiff/easy/model/cnet.py +31 -0
  43. hcpdiff/easy/model/loader.py +79 -0
  44. hcpdiff/easy/sampler.py +46 -0
  45. hcpdiff/evaluate/__init__.py +1 -0
  46. hcpdiff/evaluate/previewer.py +60 -0
  47. hcpdiff/loss/__init__.py +4 -1
  48. hcpdiff/loss/base.py +41 -0
  49. hcpdiff/loss/gw.py +35 -0
  50. hcpdiff/loss/ssim.py +37 -0
  51. hcpdiff/loss/vlb.py +79 -0
  52. hcpdiff/loss/weighting.py +66 -0
  53. hcpdiff/models/__init__.py +2 -2
  54. hcpdiff/models/cfg_context.py +17 -14
  55. hcpdiff/models/compose/compose_hook.py +44 -23
  56. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  57. hcpdiff/models/compose/sdxl_composer.py +4 -4
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/init_proj.py +3 -21
  74. hcpdiff/tools/lora_convert.py +18 -17
  75. hcpdiff/tools/save_model.py +12 -0
  76. hcpdiff/tools/sd2diffusers.py +1 -1
  77. hcpdiff/train_colo.py +1 -1
  78. hcpdiff/train_deepspeed.py +1 -1
  79. hcpdiff/trainer_ac.py +79 -0
  80. hcpdiff/trainer_ac_single.py +31 -0
  81. hcpdiff/utils/__init__.py +0 -2
  82. hcpdiff/utils/inpaint_pipe.py +7 -2
  83. hcpdiff/utils/net_utils.py +29 -6
  84. hcpdiff/utils/pipe_hook.py +24 -7
  85. hcpdiff/utils/utils.py +21 -4
  86. hcpdiff/workflow/__init__.py +15 -10
  87. hcpdiff/workflow/daam/__init__.py +1 -0
  88. hcpdiff/workflow/daam/act.py +66 -0
  89. hcpdiff/workflow/daam/hook.py +109 -0
  90. hcpdiff/workflow/diffusion.py +118 -128
  91. hcpdiff/workflow/fast.py +31 -0
  92. hcpdiff/workflow/flow.py +67 -0
  93. hcpdiff/workflow/io.py +36 -130
  94. hcpdiff/workflow/model.py +46 -43
  95. hcpdiff/workflow/text.py +60 -47
  96. hcpdiff/workflow/utils.py +32 -12
  97. hcpdiff/workflow/vae.py +37 -38
  98. hcpdiff-2.2.dist-info/METADATA +299 -0
  99. hcpdiff-2.2.dist-info/RECORD +115 -0
  100. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/WHEEL +1 -1
  101. hcpdiff-2.2.dist-info/entry_points.txt +5 -0
  102. hcpdiff/ckpt_manager/base.py +0 -16
  103. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  104. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  105. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
  106. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  107. hcpdiff/data/bucket.py +0 -358
  108. hcpdiff/data/caption_loader.py +0 -80
  109. hcpdiff/data/cond_dataset.py +0 -40
  110. hcpdiff/data/crop_info_dataset.py +0 -40
  111. hcpdiff/data/data_processor.py +0 -33
  112. hcpdiff/data/pair_dataset.py +0 -146
  113. hcpdiff/data/sampler.py +0 -54
  114. hcpdiff/data/source/base.py +0 -30
  115. hcpdiff/data/utils.py +0 -80
  116. hcpdiff/deprecated/__init__.py +0 -1
  117. hcpdiff/deprecated/cfg_converter.py +0 -81
  118. hcpdiff/deprecated/lora_convert.py +0 -31
  119. hcpdiff/infer_workflow.py +0 -57
  120. hcpdiff/loggers/__init__.py +0 -13
  121. hcpdiff/loggers/base_logger.py +0 -76
  122. hcpdiff/loggers/cli_logger.py +0 -40
  123. hcpdiff/loggers/preview/__init__.py +0 -1
  124. hcpdiff/loggers/preview/image_previewer.py +0 -149
  125. hcpdiff/loggers/tensorboard_logger.py +0 -30
  126. hcpdiff/loggers/wandb_logger.py +0 -31
  127. hcpdiff/loggers/webui_logger.py +0 -9
  128. hcpdiff/loss/min_snr_loss.py +0 -52
  129. hcpdiff/models/layers.py +0 -81
  130. hcpdiff/models/plugin.py +0 -348
  131. hcpdiff/models/wrapper.py +0 -75
  132. hcpdiff/noise/__init__.py +0 -3
  133. hcpdiff/noise/noise_base.py +0 -16
  134. hcpdiff/noise/pyramid_noise.py +0 -50
  135. hcpdiff/noise/zero_terminal.py +0 -44
  136. hcpdiff/train_ac.py +0 -566
  137. hcpdiff/train_ac_single.py +0 -39
  138. hcpdiff/utils/caption_tools.py +0 -105
  139. hcpdiff/utils/cfg_net_tools.py +0 -321
  140. hcpdiff/utils/cfg_resolvers.py +0 -16
  141. hcpdiff/utils/ema.py +0 -52
  142. hcpdiff/utils/img_size_tool.py +0 -248
  143. hcpdiff/vis/__init__.py +0 -3
  144. hcpdiff/vis/base_interface.py +0 -12
  145. hcpdiff/vis/disk_interface.py +0 -48
  146. hcpdiff/vis/webui_interface.py +0 -17
  147. hcpdiff/viser_fast.py +0 -138
  148. hcpdiff/visualizer.py +0 -265
  149. hcpdiff/visualizer_reloadable.py +0 -237
  150. hcpdiff/workflow/base.py +0 -59
  151. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  152. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  153. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  154. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  155. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  156. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  157. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  158. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  159. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  160. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  161. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  162. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  163. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  164. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  165. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  166. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  167. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  168. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  169. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  170. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  171. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  172. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  173. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  174. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  175. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  176. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  177. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  178. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  179. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  180. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  181. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  182. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  183. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  184. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  185. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  186. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  187. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  188. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  189. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  190. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  191. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  192. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  193. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  194. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  195. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
  196. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  197. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
  198. hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
  199. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  200. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  201. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  202. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  203. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  204. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  205. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  206. hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  207. hcpdiff-0.9.1.dist-info/METADATA +0 -199
  208. hcpdiff-0.9.1.dist-info/RECORD +0 -160
  209. hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
  210. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info/licenses}/LICENSE +0 -0
  211. {hcpdiff-0.9.1.dist-info → hcpdiff-2.2.dist-info}/top_level.txt +0 -0
hcpdiff/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .train_ac import Trainer
2
- from .train_ac_single import TrainerSingleCard
3
- from .visualizer import Visualizer
4
- from .visualizer_reloadable import VisualizerReloadable
1
+ #from .train_ac_old import Trainer
2
+ #from .train_ac_single import TrainerSingleCard
3
+ # from .visualizer import Visualizer
4
+ # from .visualizer_reloadable import VisualizerReloadable
@@ -1,5 +1,4 @@
1
- from .ckpt_pkl import CkptManagerPKL
2
- from .ckpt_safetensor import CkptManagerSafe
3
-
4
def auto_manager(ckpt_path: str):
    """Pick a checkpoint manager from the extension of ``ckpt_path``.

    Returns a ``CkptManagerSafe`` instance when the path ends with
    ``.safetensors`` and a ``CkptManagerPKL`` instance otherwise.
    """
    if ckpt_path.endswith('.safetensors'):
        return CkptManagerSafe()
    return CkptManagerPKL()
1
+ from .format import EmbFormat, DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSDXLFormat, \
2
+ OfficialSD15Format, LoraWebuiFormat
3
+ from .ckpt import EmbSaver, easy_emb_saver
4
+ from .loader import HCPLoraLoader
@@ -0,0 +1,24 @@
1
+ from rainbowneko.ckpt_manager import NekoSaver, CkptFormat, LocalCkptSource, PKLFormat
2
+ from torch import nn
3
+ from typing import Dict, Any
4
+
5
class EmbSaver(NekoSaver):
    """Saver that writes every embedding of one plugin group as its own checkpoint."""

    def __init__(self, format: CkptFormat, source: LocalCkptSource, target_key='embs', prefix=None):
        """
        :param format: checkpoint format, forwarded to ``NekoSaver``.
        :param source: checkpoint source, forwarded to ``NekoSaver``.
        :param target_key: key of ``plugin_groups`` that holds the embeddings to save.
        :param prefix: value passed as ``prefix`` to every ``self.save`` call.
        """
        super().__init__(format, source)
        self.prefix = prefix
        self.target_key = target_key

    def save_to(self, name, model: nn.Module, plugin_groups: Dict[str, Any], model_ema=None, exclude_key=None,
                name_template=None):
        """Save each embedding found under ``plugin_groups[self.target_key]``.

        Every embedding is saved once under its own name. When ``name_template``
        is given, it is saved a second time under ``name_template.format(<name>)``.
        ``name``, ``model``, ``model_ema`` and ``exclude_key`` are not used here.
        """
        embeddings = plugin_groups[self.target_key]
        for emb_name, emb in embeddings.items():
            self.save(emb_name, (emb_name, emb), prefix=self.prefix)
            if name_template is None:
                continue
            # Second copy under the templated name.
            templated_name = name_template.format(emb_name)
            self.save(templated_name, (templated_name, emb), prefix=self.prefix)
19
+
20
def easy_emb_saver():
    """Build an ``EmbSaver`` that uses a ``PKLFormat`` format and a ``LocalCkptSource`` source."""
    ckpt_format = PKLFormat()
    ckpt_source = LocalCkptSource()
    return EmbSaver(format=ckpt_format, source=ckpt_source)
@@ -0,0 +1,4 @@
1
+ from .emb import EmbFormat
2
+ from .diffusers import DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat
3
+ from .sd_single import OfficialSD15Format, OfficialSDXLFormat
4
+ from .lora_webui import LoraWebuiFormat
@@ -0,0 +1,59 @@
1
+ import torch
2
+ from diffusers import ModelMixin, AutoencoderKL, UNet2DConditionModel, PixArtTransformer2DModel
3
+ from rainbowneko.ckpt_manager.format import CkptFormat
4
+ from transformers import CLIPTextModel, AutoTokenizer, T5EncoderModel
5
+
6
+ from hcpdiff.diffusion.sampler import DDPMSampler, DDPMDiscreteSigmaScheduler
7
+ from hcpdiff.models.compose import SDXLTokenizer, SDXLTextEncoder
8
+
9
class DiffusersModelFormat(CkptFormat):
    """Checkpoint format that delegates to a builder's ``save_pretrained`` / ``from_pretrained``."""

    def __init__(self, builder: ModelMixin):
        """
        :param builder: object whose ``from_pretrained`` is used to load checkpoints.
        """
        self.builder = builder

    def save_ckpt(self, sd_model: ModelMixin, save_f: str, **kwargs):
        """Save ``sd_model`` to ``save_f`` through its ``save_pretrained`` method.

        Extra ``kwargs`` are accepted but not forwarded.
        """
        sd_model.save_pretrained(save_f)

    def load_ckpt(self, ckpt_f: str, map_location="cpu", **kwargs):
        """Load a model from ``ckpt_f`` with ``self.builder.from_pretrained``.

        :param ckpt_f: checkpoint location handed to ``from_pretrained``.
        :param map_location: accepted for interface compatibility; not used.
        :param kwargs: forwarded to ``from_pretrained``.
        :return: whatever ``from_pretrained`` returns.
        """
        # The loaded model must be returned; previously the result was discarded
        # and callers always received None.
        return self.builder.from_pretrained(ckpt_f, **kwargs)
18
+
19
class DiffusersSD15Format(CkptFormat):
    """Loads the parts of a model from a checkpoint laid out in ``unet`` / ``vae`` / ``text_encoder`` / ``tokenizer`` subfolders."""

    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
        """Return a dict with ``denoiser``, ``TE``, ``vae``, ``noise_sampler`` and ``tokenizer``.

        Each component passed in is used as-is when truthy; otherwise it is built
        from ``pretrained_model`` (or, for the sampler, from a default
        ``DDPMSampler``). ``map_location`` and ``kwargs`` are not used.
        """
        weight_args = dict(revision=revision, torch_dtype=dtype)

        if not denoiser:
            denoiser = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet", **weight_args)
        if not vae:
            vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", **weight_args)
        if not noise_sampler:
            noise_sampler = DDPMSampler(DDPMDiscreteSigmaScheduler())

        if not TE:
            TE = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder", **weight_args)
        if not tokenizer:
            tokenizer = AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)

        return {
            'denoiser': denoiser,
            'TE': TE,
            'vae': vae,
            'noise_sampler': noise_sampler,
            'tokenizer': tokenizer,
        }
32
+
33
class DiffusersSDXLFormat(CkptFormat):
    """Same loading scheme as the subfolder-based formats, using ``SDXLTextEncoder`` and ``SDXLTokenizer``."""

    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
        """Return a dict with ``denoiser``, ``TE``, ``vae``, ``noise_sampler`` and ``tokenizer``.

        Each component passed in is used as-is when truthy; otherwise it is built
        from ``pretrained_model`` (or, for the sampler, from a default
        ``DDPMSampler``). ``map_location`` and ``kwargs`` are not used.
        """
        weight_args = dict(revision=revision, torch_dtype=dtype)

        if not denoiser:
            denoiser = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet", **weight_args)
        if not vae:
            vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", **weight_args)
        if not noise_sampler:
            noise_sampler = DDPMSampler(DDPMDiscreteSigmaScheduler())

        if not TE:
            TE = SDXLTextEncoder.from_pretrained(pretrained_model, subfolder="text_encoder", **weight_args)
        if not tokenizer:
            tokenizer = SDXLTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)

        return {
            'denoiser': denoiser,
            'TE': TE,
            'vae': vae,
            'noise_sampler': noise_sampler,
            'tokenizer': tokenizer,
        }
46
+
47
class DiffusersPixArtFormat(CkptFormat):
    """Subfolder-based loader using ``PixArtTransformer2DModel`` (``transformer`` subfolder) and ``T5EncoderModel``."""

    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
        """Return a dict with ``denoiser``, ``TE``, ``vae``, ``noise_sampler`` and ``tokenizer``.

        Each component passed in is used as-is when truthy; otherwise it is built
        from ``pretrained_model`` (or, for the sampler, from a default
        ``DDPMSampler``). ``map_location`` and ``kwargs`` are not used.
        """
        weight_args = dict(revision=revision, torch_dtype=dtype)

        if not denoiser:
            denoiser = PixArtTransformer2DModel.from_pretrained(pretrained_model, subfolder="transformer", **weight_args)
        if not vae:
            vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", **weight_args)
        if not noise_sampler:
            noise_sampler = DDPMSampler(DDPMDiscreteSigmaScheduler())

        if not TE:
            TE = T5EncoderModel.from_pretrained(pretrained_model, subfolder="text_encoder", **weight_args)
        if not tokenizer:
            tokenizer = AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)

        return {
            'denoiser': denoiser,
            'TE': TE,
            'vae': vae,
            'noise_sampler': noise_sampler,
            'tokenizer': tokenizer,
        }
@@ -0,0 +1,21 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from rainbowneko.ckpt_manager.format import CkptFormat
5
+ from torch.serialization import FILE_LIKE
6
+
7
class EmbFormat(CkptFormat):
    """Checkpoint format for a single named embedding stored with ``torch.save``."""

    EXT = 'pt'

    def save_ckpt(self, sd_model: Tuple[str, torch.Tensor], save_f: FILE_LIKE):
        """Write a ``(name, embedding)`` pair to ``save_f``.

        The file holds a dict with the embedding under ``string_to_param['*']``
        and the name under ``name``.
        """
        name, emb = sd_model
        payload = {'string_to_param': {'*': emb}, 'name': name}
        torch.save(payload, save_f)

    def load_ckpt(self, ckpt_f: FILE_LIKE, map_location="cpu"):
        """Read an embedding from ``ckpt_f`` and return it with gradients disabled.

        The embedding is taken from ``string_to_param['*']`` when that key is
        present, otherwise from ``emb_params``.
        """
        # NOTE(review): torch.load may unpickle arbitrary objects depending on the
        # torch version's defaults - only load checkpoints from trusted sources.
        state = torch.load(ckpt_f, map_location=map_location)
        emb = state['string_to_param']['*'] if 'string_to_param' in state else state['emb_params']
        emb.requires_grad_(False)
        return emb
@@ -0,0 +1,252 @@
1
+ import math
2
+ import re
3
+ from typing import List, Dict, Any
4
+
5
+ from rainbowneko.ckpt_manager.format import CkptFormat, SafeTensorFormat
6
+ from torch.serialization import FILE_LIKE
7
+
8
class LoraConverter:
    """Translate LoRA state-dict keys between the webui naming scheme and the
    ``<module path>.___.<param>`` scheme used by this package.

    Webui keys look like ``<prefix><module path with '_'>.<lora param>``; the
    converted keys use ``.`` as the module separator and ``.___.`` between the
    module path and the parameter name.
    """

    # Multi-word module names whose inner '_' must survive the '_' -> '.' rewrite.
    com_name_unet = ['down_blocks', 'up_blocks', 'mid_block', 'transformer_blocks', 'to_q', 'to_k', 'to_v', 'to_out', 'proj_in', 'proj_out',
                     'input_blocks', 'middle_block', 'output_blocks']
    com_name_TE = ['self_attn', 'q_proj', 'v_proj', 'k_proj', 'out_proj', 'text_model']
    prefix_unet = 'lora_unet_'
    prefix_TE = 'lora_te_'
    prefix_TE_xl_clip_L = 'lora_te1_'
    prefix_TE_xl_clip_bigG = 'lora_te2_'

    lora_w_map = {'lora_down.weight':'W_down', 'lora_up.weight':'W_up'}

    # SDXL unet key patterns, compiled once instead of once per key.
    _xl_input_re = re.compile(r'input_blocks.(\d+).1.(.+)')
    _xl_middle_re = re.compile(r'middle_block.1.(.+)')
    _xl_output_re = re.compile(r'output_blocks.(\d+).(\d+).(.+)')

    # input_blocks index -> (down_blocks index, attentions index): idx = 1 + 3*block + attn
    _xl_down_map = {
        '4':[1, 0],
        '5':[1, 1],
        '7':[2, 0],
        '8':[2, 1],
    }
    # output_blocks index -> (up_blocks index, attentions index): idx = 3*block + attn
    _xl_up_map = {
        '0':[0, 0],
        '1':[0, 1],
        '2':[0, 2],
        '3':[1, 0],
        '4':[1, 1],
        '5':[1, 2],
    }

    def __init__(self):
        # '%' temporarily stands in for the '_' inside protected names.
        self.com_name_unet_tmp = [x.replace('_', '%') for x in self.com_name_unet]
        self.com_name_TE_tmp = [x.replace('_', '%') for x in self.com_name_TE]

    def convert_from_webui(self, state, auto_scale_alpha=False, sdxl=False):
        """Convert a webui LoRA state dict.

        :param state: mapping of webui keys to values.
        :param auto_scale_alpha: when true, rescale the weights with
            ``alpha_scale_from_webui``.
        :param sdxl: when true, use the SDXL key layout (``lora_te1_`` /
            ``lora_te2_`` text encoders and ``input/middle/output`` unet blocks).
        :return: ``{'base': {...}}`` where unet keys are prefixed with
            ``denoiser.`` and text-encoder keys with ``TE.``.
        """
        if not sdxl:
            sd_unet = self.convert_from_webui_(state, prefix=self.prefix_unet, com_name=self.com_name_unet, com_name_tmp=self.com_name_unet_tmp)
            sd_TE = self.convert_from_webui_(state, prefix=self.prefix_TE, com_name=self.com_name_TE, com_name_tmp=self.com_name_TE_tmp)
        else:
            sd_unet = self.convert_from_webui_xl_unet_(state, prefix=self.prefix_unet, com_name=self.com_name_unet,
                                                       com_name_tmp=self.com_name_unet_tmp)
            sd_TE = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_L, com_name=self.com_name_TE,
                                                   com_name_tmp=self.com_name_TE_tmp)
            sd_TE2 = self.convert_from_webui_xl_te_(state, prefix=self.prefix_TE_xl_clip_bigG, com_name=self.com_name_TE,
                                                    com_name_tmp=self.com_name_TE_tmp)
            sd_TE.update(sd_TE2)

        if auto_scale_alpha:
            sd_unet = self.alpha_scale_from_webui(sd_unet)
            sd_TE = self.alpha_scale_from_webui(sd_TE)

        sd = {
            **{f'denoiser.{k}':v for k, v in sd_unet.items()},
            **{f'TE.{k}':v for k, v in sd_TE.items()},
        }
        return {'base': sd}

    def convert_to_webui(self, sd_unet, sd_TE, auto_scale_alpha=False, sdxl=False):
        """Convert unet and text-encoder state dicts to one webui state dict.

        :param sd_unet: unet entries keyed as ``<module>.___.<param>``.
        :param sd_TE: text-encoder entries keyed as ``<module>.___.<param>``.
        :param auto_scale_alpha: when true, rescale with ``alpha_scale_to_webui``.
        :param sdxl: when true, map ``clip_L`` / ``clip_bigG`` text-encoder keys
            to the ``lora_te1_`` / ``lora_te2_`` prefixes.
        :return: the merged webui state dict.
        """
        sd_unet = self.convert_to_webui_(sd_unet, prefix=self.prefix_unet)
        if sdxl:
            sd_TE = self.convert_to_webui_xl_(sd_TE, prefix=self.prefix_TE)
        else:
            sd_TE = self.convert_to_webui_(sd_TE, prefix=self.prefix_TE)
        sd_unet.update(sd_TE)
        if auto_scale_alpha:
            sd_unet = self.alpha_scale_to_webui(sd_unet)
        return sd_unet

    def _module_key_from_webui(self, webui_module, com_name, com_name_tmp):
        """Turn a '_'-separated webui module name into a '.'-separated module path."""
        protected = self.replace_all(webui_module, com_name, com_name_tmp)
        return protected.replace('_', '.').replace('%', '_')

    def _store_converted(self, target, model_k, lora_k, value):
        """Store one converted entry; ``alpha`` keeps its name, weights go under ``layer.``."""
        if lora_k == 'alpha':
            target[f'{model_k}.___.{lora_k}'] = value
        else:
            target[f'{model_k}.___.layer.{self.lora_w_map[lora_k]}'] = value

    @staticmethod
    def _split_for_webui(key):
        """Split ``<module>.___.<param>`` and map ``W_down`` / ``W_up`` to their webui names."""
        model_k, lora_k = key.split('.___.', 1)
        if key.endswith('W_down'):
            lora_k = 'lora_down.weight'
        elif key.endswith('W_up'):
            lora_k = 'lora_up.weight'
        return model_k, lora_k

    def convert_from_webui_(self, state, prefix, com_name, com_name_tmp):
        """Convert all entries of ``state`` whose key starts with ``prefix``."""
        prefix_len = len(prefix)
        sd_covert = {}
        for k, v in state.items():
            if not k.startswith(prefix):
                continue
            model_k, lora_k = k[prefix_len:].split('.', 1)
            model_k = self._module_key_from_webui(model_k, com_name, com_name_tmp)
            self._store_converted(sd_covert, model_k, lora_k, v)
        return sd_covert

    def convert_to_webui_(self, state, prefix):
        """Convert every entry of ``state`` to a webui key starting with ``prefix``."""
        sd_covert = {}
        for k, v in state.items():
            model_k, lora_k = self._split_for_webui(k)
            sd_covert[f"{prefix}{model_k.replace('.', '_')}.{lora_k}"] = v
        return sd_covert

    def convert_to_webui_xl_(self, state, prefix):
        """Like ``convert_to_webui_``, additionally folding ``_clip_L`` / ``_clip_bigG``
        into the ``1`` / ``2`` suffix of the text-encoder prefix."""
        sd_convert = {}
        for k, v in state.items():
            model_k, lora_k = self._split_for_webui(k)
            new_k = f"{prefix}{model_k.replace('.', '_')}.{lora_k}"
            if 'clip' in new_k:
                new_k = new_k.replace('_clip_L', '1') if 'clip_L' in new_k else new_k.replace('_clip_bigG', '2')
            sd_convert[new_k] = v
        return sd_convert

    def convert_from_webui_xl_te_(self, state, prefix, com_name, com_name_tmp):
        """Convert SDXL text-encoder entries starting with ``prefix``.

        Module paths are placed under ``clip_L.`` when ``prefix`` is the
        ``clip_L`` prefix and under ``clip_bigG.`` otherwise.
        """
        # Compare against the class constant rather than a repeated literal.
        encoder_name = 'clip_L' if prefix == self.prefix_TE_xl_clip_L else 'clip_bigG'
        prefix_len = len(prefix)
        sd_covert = {}
        for k, v in state.items():
            if not k.startswith(prefix):
                continue
            model_k, lora_k = k[prefix_len:].split('.', 1)
            model_k = self._module_key_from_webui(model_k, com_name, com_name_tmp)
            self._store_converted(sd_covert, f'{encoder_name}.{model_k}', lora_k, v)
        return sd_covert

    def convert_from_webui_xl_unet_(self, state, prefix, com_name, com_name_tmp):
        """Convert SDXL unet entries starting with ``prefix``.

        ``input_blocks`` / ``middle_block`` / ``output_blocks`` module paths are
        rewritten to ``down_blocks`` / ``mid_block`` / ``up_blocks`` attention
        paths using the index tables on the class.

        :raises NotImplementedError: when a key matches none of the three layouts.
        :raises KeyError: when a block index is not in the index tables.
        """
        prefix_len = len(prefix)
        sd_covert = {}
        for k, v in state.items():
            if not k.startswith(prefix):
                continue
            model_k, lora_k = k[prefix_len:].split('.', 1)
            model_k = self._module_key_from_webui(model_k, com_name, com_name_tmp)

            found = self._xl_input_re.match(model_k)
            if found:
                block_idx, tail = found.groups()
                block, attn = self._xl_down_map[block_idx]
                new_k = f'down_blocks.{block}.attentions.{attn}.{tail}'
            else:
                found = self._xl_middle_re.match(model_k)
                if found:
                    new_k = f'mid_block.attentions.0.{found.group(1)}'
                else:
                    found = self._xl_output_re.match(model_k)
                    if not found:
                        # Name the offending key so the failure is diagnosable.
                        raise NotImplementedError(f'Unsupported SDXL unet LoRA key: {k}')
                    block_idx, _, tail = found.groups()
                    block, attn = self._xl_up_map[block_idx]
                    new_k = f'up_blocks.{block}.attentions.{attn}.{tail}'

            self._store_converted(sd_covert, new_k, lora_k, v)

        return sd_covert

    @staticmethod
    def replace_all(data: str, srcs: List[str], dsts: List[str]):
        """Apply ``data.replace(src, dst)`` for each ``(src, dst)`` pair, in order."""
        for src, dst in zip(srcs, dsts):
            data = data.replace(src, dst)
        return data

    @staticmethod
    def alpha_scale_from_webui(state):
        """Scale ``W_up`` by ``sqrt(shape[1])`` and ``W_down`` by ``sqrt(shape[0])`` in place.

        Returns the same ``state`` mapping.
        """
        # Apply to "lora_down" and "lora_up" respectively to prevent overflow
        for k, v in state.items():
            if 'W_up' in k:
                state[k] = v*math.sqrt(v.shape[1])
            elif 'W_down' in k:
                state[k] = v*math.sqrt(v.shape[0])
        return state

    @staticmethod
    def alpha_scale_to_webui(state):
        """Scale ``lora_up`` by ``sqrt(shape[1])`` and ``lora_down`` by ``sqrt(shape[0])`` in place.

        Returns the same ``state`` mapping.
        """
        # NOTE(review): this multiplies by the same factor as alpha_scale_from_webui
        # rather than applying its inverse - confirm that this is intended.
        for k, v in state.items():
            if 'lora_up' in k:
                state[k] = v*math.sqrt(v.shape[1])
            elif 'lora_down' in k:
                state[k] = v*math.sqrt(v.shape[0])
        return state
213
+
214
class LoraWebuiFormat(CkptFormat):
    """Checkpoint format that stores LoRA weights with webui key names.

    Key translation is done by ``LoraConverter``; the actual file I/O is
    delegated to an inner format (``SafeTensorFormat`` by default).
    """

    # Raw webui unet keys that use the input/middle/output block layout. The
    # pattern includes the 'lora_unet_' prefix and '_' separators because it is
    # applied to unconverted keys.
    _sdxl_unet_key = re.compile(r'lora_unet_(?:input_blocks_\d+_1_|middle_block_1_|output_blocks_\d+_\d+_).+')

    def __init__(self, format=None, auto_scale_alpha=False):
        """
        :param format: inner format used to read/write the file; a
            ``SafeTensorFormat`` is created when ``None``.
        :param auto_scale_alpha: forwarded to the converter on save and load.
        """
        self.converter = LoraConverter()
        self.auto_scale_alpha = auto_scale_alpha

        if format is None:
            format = SafeTensorFormat()
        self.format = format

    def save_ckpt(self, sd_model: Dict[str, Any], save_f: FILE_LIKE):
        """Convert ``sd_model['base']`` to webui keys and save it to ``save_f``.

        Entries prefixed with ``denoiser.`` / ``TE.`` are split into unet and
        text-encoder parts. When neither prefix occurs, the whole dict is treated
        as unet entries. The SDXL layout is used when any text-encoder key
        contains ``clip_L`` or ``clip_bigG``.
        """
        base = sd_model['base']
        sd_denoiser = {k.removeprefix('denoiser.'):v for k, v in base.items() if k.startswith('denoiser.')}
        sd_TE = {k.removeprefix('TE.'):v for k, v in base.items() if k.startswith('TE.')}

        if len(sd_denoiser)>0 or len(sd_TE)>0:
            sdxl = any('clip_L' in k or 'clip_bigG' in k for k in sd_TE)
            sd_webui = self.converter.convert_to_webui(sd_denoiser, sd_TE, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
        else:
            sd_webui = self.converter.convert_to_webui(base, {}, auto_scale_alpha=self.auto_scale_alpha)

        self.format.save_ckpt(sd_webui, save_f)

    def load_ckpt(self, ckpt_f: str, map_location="cpu", **kwargs):
        """Load a webui LoRA file and return it converted by ``LoraConverter``.

        The SDXL layout is selected when any key contains ``lora_te1_`` or
        ``lora_te2_``, or is a unet key using the input/middle/output block
        naming.
        """
        sd_webui = self.format.load_ckpt(ckpt_f, map_location=map_location, **kwargs)

        # The previous check ran re.match with patterns starting at
        # 'input_blocks' etc. against keys that start with 'lora_unet_', so the
        # unet part of the detection could never succeed.
        sdxl = any(
            'lora_te1_' in k or 'lora_te2_' in k or self._sdxl_unet_key.match(k)
            for k in sd_webui.keys()
        )

        sd_all = self.converter.convert_from_webui(sd_webui, auto_scale_alpha=self.auto_scale_alpha, sdxl=sdxl)
        return sd_all
@@ -0,0 +1,41 @@
1
+ import torch
2
+ from diffusers import AutoencoderKL, StableDiffusionPipeline, StableDiffusionXLPipeline
3
+ from rainbowneko.ckpt_manager.format import CkptFormat
4
+
5
+ from hcpdiff.diffusion.sampler import DDPMSampler, DDPMDiscreteSigmaScheduler
6
+ from hcpdiff.models.compose import SDXLTextEncoder, SDXLTokenizer
7
+
8
class OfficialSD15Format(CkptFormat):
    """Loads model components from a single-file checkpoint via ``StableDiffusionPipeline.from_single_file``."""

    # Single file format
    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
        """Return a dict with ``denoiser``, ``TE``, ``vae``, ``noise_sampler`` and ``tokenizer``.

        Components that are not ``None`` are handed to the pipeline constructor;
        the returned components are read back from the pipeline. A default
        ``DDPMSampler`` is created when ``noise_sampler`` is falsy.
        ``map_location`` and ``kwargs`` are not used.
        """
        candidates = (
            ('unet', denoiser),
            ('vae', vae),
            ('text_encoder', TE),
            ('tokenizer', tokenizer),
        )
        # Only pass the components the caller actually supplied.
        overrides = {key: component for key, component in candidates if component is not None}
        pipe = StableDiffusionPipeline.from_single_file(
            pretrained_model, revision=revision, torch_dtype=dtype, **overrides
        )
        if not noise_sampler:
            noise_sampler = DDPMSampler(DDPMDiscreteSigmaScheduler())
        return {
            'denoiser': pipe.unet,
            'TE': pipe.text_encoder,
            'vae': pipe.vae,
            'noise_sampler': noise_sampler,
            'tokenizer': pipe.tokenizer,
        }
19
+
20
class OfficialSDXLFormat(CkptFormat):
    """Loads model components from a single-file checkpoint via ``StableDiffusionXLPipeline.from_single_file``."""

    # Single file format
    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
        """Return a dict with ``denoiser``, ``TE``, ``vae``, ``noise_sampler`` and ``tokenizer``.

        A supplied ``TE`` / ``tokenizer`` is split into its ``clip_L`` and
        ``clip_bigG`` parts before being handed to the pipeline. The two text
        encoders and tokenizers of the loaded pipeline are then wrapped into an
        ``SDXLTextEncoder`` and an ``SDXLTokenizer``. ``map_location`` and
        ``kwargs`` are not used.
        """
        overrides = {'unet': denoiser, 'vae': vae}
        if TE is not None:
            overrides['text_encoder'] = TE.clip_L
            overrides['text_encoder_2'] = TE.clip_bigG
        if tokenizer is not None:
            overrides['tokenizer'] = tokenizer.clip_L
            overrides['tokenizer_2'] = tokenizer.clip_bigG

        # Drop everything the caller did not supply.
        overrides = {key: component for key, component in overrides.items() if component is not None}
        pipe = StableDiffusionXLPipeline.from_single_file(
            pretrained_model, revision=revision, torch_dtype=dtype, **overrides
        )

        if not noise_sampler:
            noise_sampler = DDPMSampler(DDPMDiscreteSigmaScheduler())
        TE = SDXLTextEncoder([('clip_L', pipe.text_encoder), ('clip_bigG', pipe.text_encoder_2)])
        tokenizer = SDXLTokenizer([('clip_L', pipe.tokenizer), ('clip_bigG', pipe.tokenizer_2)])

        return {
            'denoiser': pipe.unet,
            'TE': TE,
            'vae': pipe.vae,
            'noise_sampler': noise_sampler,
            'tokenizer': tokenizer,
        }
@@ -0,0 +1,64 @@
1
+ from hcpdiff.models.lora_layers_patch import LoraLayer
2
+ from torch import nn
3
+ from hcpdiff.utils.net_utils import split_module_name
4
+ from rainbowneko.ckpt_manager import NekoPluginLoader, LocalCkptSource, CkptFormat
5
+ from rainbowneko.ckpt_manager.locator import get_match_layers
6
+ from rainbowneko.models.plugin import PluginGroup
7
+
8
def get_lora_rank_and_cls(lora_state):
    """Return ``(LoraLayer, rank)`` for one layer's LoRA state.

    The rank is the first dimension of ``lora_state['layer.W_down']``.

    :raises ValueError: when ``lora_state`` has no ``layer.W_down`` entry.
    """
    if 'layer.W_down' not in lora_state:
        raise ValueError('Unknown lora format.')
    rank = lora_state['layer.W_down'].shape[0]
    return LoraLayer, rank
14
+
15
class HCPLoraLoader(NekoPluginLoader):
    """Plugin loader that wraps host layers with LoRA blocks and loads their weights."""

    def __init__(self, format: CkptFormat=None, source: LocalCkptSource=None, path: str = None, layers='all', target_plugin=None,
                 state_prefix=None, base_model_alpha=0.0, load_ema=False, module_to_load='', **plugin_kwargs):
        """
        All arguments except ``module_to_load`` are forwarded to ``NekoPluginLoader``.

        :param module_to_load: attribute expression, relative to the model, of the
            sub-module to load into; empty string means the model itself.
        """
        super().__init__(format, source, path=path, layers=layers, target_plugin=target_plugin, state_prefix=state_prefix,
                         base_model_alpha=base_model_alpha, load_ema=load_ema, **plugin_kwargs)
        self.module_to_load = module_to_load

    def load_to(self, name, model):
        """Load the LoRA checkpoint at ``self.path`` into ``model``.

        :param name: plugin name passed to ``wrap_layer`` and used in log output.
        :param model: model (or parent of the target sub-module) to patch.
        :return: a ``PluginGroup`` of the created LoRA blocks keyed by layer name.
        """
        # Resolve the module that receives the plugin.
        # NOTE(review): eval on a configured string - module_to_load must come
        # from a trusted config.
        if self.module_to_load != '':
            model = eval(f"model.{self.module_to_load}")

        named_modules = dict(model.named_modules())
        plugin_state = self.load(self.path, map_location='cpu')['base_ema' if self.load_ema else 'base']

        # Keep only the entries that belong to the requested layers.
        if self.layers != 'all':
            match_blocks = get_match_layers(self.layers, named_modules)
            selected = {}
            for blk in match_blocks:
                for key, value in plugin_state.items():
                    if key.startswith(blk):
                        selected[key] = value
            plugin_state = selected

        # Keep only entries under state_prefix and strip it from their keys.
        if self.state_prefix:
            cut = len(self.state_prefix)
            plugin_state = {key[cut:]: value for key, value in plugin_state.items() if key.startswith(self.state_prefix)}

        # Group parameters by host layer: '<layer>.___.<param>' -> {layer: {param: value}}.
        lora_block_state = {}
        for pname, param in plugin_state.items():
            layer_key, param_key = pname.split('.___.', 1)
            lora_block_state.setdefault(layer_key, {})[param_key] = param

        # Wrap each host layer with a LoRA block and load its weights.
        lora_blocks = {}
        for layer_name, lora_state in lora_block_state.items():
            lora_layer_cls, rank = get_lora_rank_and_cls(lora_state)

            if 'alpha' in lora_state:
                lora_state['alpha'] *= self.plugin_kwargs.get('alpha', 1.0)

            parent_name, host_name = split_module_name(layer_name)

            has_bias = 'layer.bias' in lora_state
            lora_block = lora_layer_cls.wrap_layer(name, named_modules[layer_name], rank=rank, bias=has_bias,
                                                   parent_block=named_modules[parent_name], host_name=host_name)
            lora_block.set_hyper_params(**self.plugin_kwargs)
            lora_blocks[layer_name] = lora_block
            load_info = lora_block.load_state_dict(lora_state, strict=False)
            if load_info.unexpected_keys:
                print(name, 'unexpected_keys', load_info.unexpected_keys)
        return PluginGroup(lora_blocks)
hcpdiff/data/__init__.py CHANGED
@@ -1,28 +1,4 @@
1
- from .pair_dataset import TextImagePairDataset
2
- from .cond_dataset import TextImageCondPairDataset
3
- from .crop_info_dataset import CropInfoPairDataset
4
- from .bucket import BaseBucket, FixedBucket, RatioBucket, SizeBucket, RatioSizeBucket, LongEdgeBucket
5
- from .utils import CycleData
6
- from .caption_loader import JsonCaptionLoader, TXTCaptionLoader
7
- from .sampler import DistributedCycleSampler, get_sampler
8
-
9
class DataGroup:
    """Bundle several data loaders, each paired with a loss weight, and step them in lockstep."""

    def __init__(self, loader_list, loss_weights):
        # loader_list[i] and loss_weights[i] are looked up with the same index.
        self.loader_list, self.loss_weights = loader_list, loss_weights

    def __iter__(self):
        # Wrap every loader in CycleData and start a fresh iterator over each.
        iterators = []
        for loader in self.loader_list:
            iterators.append(iter(CycleData(loader)))
        self.data_iter_list = iterators
        return self

    def __next__(self):
        # Pull one item from every iterator, preserving loader order.
        batch = []
        for data_iter in self.data_iter_list:
            batch.append(next(data_iter))
        return batch

    def __len__(self):
        # Number of loaders in the group (not the number of batches).
        return len(self.loader_list)

    def get_dataset(self, idx):
        """Return the ``dataset`` attribute of the loader at position ``idx``."""
        loader = self.loader_list[idx]
        return loader.dataset

    def get_loss_weights(self, idx):
        """Return the loss weight stored at position ``idx``."""
        return self.loss_weights[idx]
1
+ from .dataset import TextImagePairDataset
2
+ from .source import Text2ImageSource, Text2ImageLossMapSource, Text2ImageCondSource, T2IFolderClassSource, TextSource
3
+ from .handler import StableDiffusionHandler, LossMapHandler, DiffusionImageHandler, DiffusionTextHandler
4
+ from .cache import VaeCache
@@ -0,0 +1 @@
1
+ from .vae import VaeCache
@@ -0,0 +1,102 @@
1
+ from io import BytesIO
2
+ from pathlib import Path
3
+ from typing import Dict, Any
4
+
5
+ import lmdb
6
+ import torch
7
+ from hcpdiff.models.wrapper import SD15Wrapper
8
+ from rainbowneko import _share
9
+ from rainbowneko.data import DataCache, CacheableDataset
10
+ from rainbowneko.utils import Path_Like
11
+ from torch.utils.data import DataLoader
12
+ from torch.utils.data.distributed import DistributedSampler
13
+ from tqdm import tqdm
14
+
15
class VaeCache(DataCache):
    """Cache of VAE-encoded image latents, optionally persisted to an LMDB file.

    Each cache entry is a dict ``{'latent': ..., 'coord': ...}`` stored under the
    image id. With ``lazy=True`` entries are read from LMDB on demand; otherwise
    they are held in ``self.cache``.

    NOTE(review): ``self.cache`` and ``self.pre_build`` are presumably initialised
    by ``DataCache.__init__`` -- confirm against rainbowneko.

    Args:
        pre_build: Path of the LMDB file used to persist the cache; falsy disables persistence.
        lazy: If true, read entries from LMDB on access instead of keeping them in memory.
        bs: Batch size used while encoding images in ``build``.
    """

    def __init__(self, pre_build: Path_Like = None, lazy=False, bs=1):
        super().__init__(pre_build)
        self.lazy = lazy
        self.bs = bs

    def load_latent(self, id):
        """Return the cached entry for ``id``.

        Raises:
            KeyError: If ``id`` is not in the cache (both lazy and in-memory modes).
        """
        if self.lazy:
            with self.env.begin() as txn:
                byte_tensor = txn.get(str(id).encode())
            if byte_tensor is None:
                # Match the in-memory path, which raises KeyError for unknown ids,
                # instead of failing inside torch.load on an empty stream.
                raise KeyError(id)
            return torch.load(BytesIO(byte_tensor))
        else:
            return self.cache[id]

    def before_handler(self, index: int, data: Dict[str, Any]):
        """Replace ``data['image']`` with the cached latent and set ``data['coord']``."""
        cached_data = self.load_latent(data['id'])
        data['image'] = cached_data['latent']
        data['coord'] = cached_data['coord']
        return data

    def on_finish(self, index, data):
        """Return ``data`` unchanged."""
        return data

    def load(self, path):
        """Open the LMDB file at ``path``.

        In lazy mode the environment is kept open on ``self.env`` and an empty dict
        is returned. Otherwise the already-populated ``self.cache`` is returned if
        non-empty, else every entry is deserialised into a new dict.
        """
        if self.lazy:
            self.env = lmdb.open(path, subdir=False, readonly=True, lock=False, readahead=False, meminit=False)
            return {}
        elif len(self.cache) > 0:
            return self.cache
        else:
            env = lmdb.open(path, subdir=False, readonly=True, lock=False, readahead=False, meminit=False)
            try:
                with env.begin() as txn:
                    cache = {k.decode(): torch.load(BytesIO(v)) for k, v in txn.cursor()}
            finally:
                env.close()
            return cache

    def _encode_entries(self, vae, data):
        """Encode one collated batch and yield ``(image id, cache entry)`` pairs."""
        # The latents are only stored, never back-propagated through.
        with torch.no_grad():
            image = data['image'].to(device=_share.device, dtype=vae.dtype)
            latents = vae.encode(image).latent_dist.sample()
            latents = (latents*vae.config.scaling_factor).cpu()

        for img_id, latent, coord in zip(data['id'], latents, data['coord']):
            # clone(): a row of ``latents`` is a view sharing the batch storage, and
            # torch.save would serialise that whole storage for every single image.
            yield img_id, {'latent': latent.clone(), 'coord': coord}

    def build(self, dataset: CacheableDataset, model: SD15Wrapper, all_gather):
        """Encode every image of ``dataset`` with ``model.vae`` and fill the cache.

        If the ``pre_build`` file already exists, or the in-memory cache is already
        populated, nothing is encoded and ``model.vae`` is set to ``None``.

        Args:
            dataset: Dataset to encode; iterated with its cache disabled.
            model: Wrapper whose ``vae`` attribute performs the encoding.
            all_gather: Callable that takes this process's cache dict and returns an
                iterable of cache dicts (presumably one per rank -- confirm with caller).
        """
        if (self.pre_build and Path(self.pre_build).exists()) or len(self.cache) > 0:
            model.vae = None
            return

        vae = model.vae.to(_share.device)
        with dataset.disable_cache():
            dataset.bucket.rest(0)

            loader = DataLoader(
                dataset,
                batch_size=self.bs,
                num_workers=0,
                sampler=DistributedSampler(dataset, num_replicas=_share.world_size, rank=_share.local_rank, shuffle=False),
                collate_fn=dataset.collate_fn,
                drop_last=False,
            )

            if self.pre_build:
                Path(self.pre_build).parent.mkdir(parents=True, exist_ok=True)
                # subdir=False so the file layout matches what load() opens.
                env = lmdb.open(str(self.pre_build), subdir=False, map_size=1099511627776)
                try:
                    with env.begin(write=True) as txn:
                        for data in tqdm(loader):
                            for img_id, data_cache in self._encode_entries(vae, data):
                                byte_stream = BytesIO()
                                torch.save(data_cache, byte_stream)
                                txn.put(str(img_id).encode(), byte_stream.getvalue())
                                if not self.lazy:
                                    self.cache[img_id] = data_cache
                finally:
                    # Always release the environment, even if encoding fails midway.
                    env.close()
            else:
                for data in tqdm(loader):
                    for img_id, data_cache in self._encode_entries(vae, data):
                        self.cache[img_id] = data_cache

        model.vae.to('cpu')
        torch.cuda.empty_cache()

        # Merge the entries produced by the other processes into this one.
        cache_all = all_gather(self.cache)
        for cache in cache_all:
            self.cache.update(cache)
@@ -0,0 +1,20 @@
1
+ """
2
+ dataset.py
3
+ ====================
4
+ :Name: text-image pair dataset
5
+ :Author: Dong Ziyi
6
+ :Affiliation: HCP Lab, SYSU
7
+ :Created: 10/03/2023
8
+ :Licence: Apache-2.0
9
+ """
10
+
11
+ from typing import Union, Dict
12
+
13
+ from rainbowneko.data import CacheableDataset, BaseDataset, BaseBucket, DataSource, DataHandler, DataCache
14
+
15
def TextImagePairDataset(bucket: BaseBucket = None, source: Dict[str, DataSource] = None, handler: DataHandler = None,
                         batch_handler: DataHandler = None, cache: DataCache = None, **kwargs) -> Union[CacheableDataset, BaseDataset]:
    """Build a text-image pair dataset, with or without a data cache.

    Returns a ``CacheableDataset`` when ``cache`` is given and a plain
    ``BaseDataset`` when it is ``None``. All other arguments, including any
    extra ``kwargs``, are forwarded unchanged to the chosen class.
    """
    shared = dict(bucket=bucket, source=source, handler=handler, batch_handler=batch_handler)
    if cache is not None:
        return CacheableDataset(**shared, cache=cache, **kwargs)
    return BaseDataset(**shared, **kwargs)
@@ -0,0 +1,3 @@
1
+ from .diffusion import StableDiffusionHandler, DiffusionImageHandler, LossMapHandler, DiffusionTextHandler
2
+ from .text import TokenizeHandler, TagEraseHandler, TagDropoutHandler, TagShuffleHandler, TemplateFillHandler
3
+ from .controlnet import ControlNetHandler