hcpdiff 2.2.1__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. hcpdiff/ckpt_manager/__init__.py +1 -1
  2. hcpdiff/ckpt_manager/ckpt.py +21 -17
  3. hcpdiff/ckpt_manager/format/diffusers.py +4 -4
  4. hcpdiff/ckpt_manager/format/sd_single.py +3 -3
  5. hcpdiff/ckpt_manager/loader.py +11 -4
  6. hcpdiff/diffusion/noise/__init__.py +0 -1
  7. hcpdiff/diffusion/sampler/VP.py +27 -0
  8. hcpdiff/diffusion/sampler/__init__.py +2 -3
  9. hcpdiff/diffusion/sampler/base.py +106 -44
  10. hcpdiff/diffusion/sampler/diffusers.py +11 -17
  11. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
  12. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
  13. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
  14. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
  15. hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
  16. hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
  17. hcpdiff/easy/cfg/sd15_train.py +33 -22
  18. hcpdiff/easy/cfg/sdxl_train.py +32 -23
  19. hcpdiff/evaluate/__init__.py +3 -1
  20. hcpdiff/evaluate/evaluator.py +76 -0
  21. hcpdiff/evaluate/metrics/__init__.py +1 -0
  22. hcpdiff/evaluate/metrics/clip_score.py +23 -0
  23. hcpdiff/evaluate/previewer.py +29 -12
  24. hcpdiff/loss/base.py +9 -26
  25. hcpdiff/loss/weighting.py +36 -18
  26. hcpdiff/models/lora_base_patch.py +26 -0
  27. hcpdiff/models/wrapper/sd.py +17 -19
  28. hcpdiff/trainer_ac.py +7 -5
  29. hcpdiff/trainer_ac_single.py +1 -6
  30. hcpdiff/utils/__init__.py +2 -1
  31. hcpdiff/utils/torch_utils.py +25 -0
  32. hcpdiff/workflow/__init__.py +1 -1
  33. hcpdiff/workflow/diffusion.py +27 -7
  34. hcpdiff/workflow/io.py +20 -3
  35. hcpdiff/workflow/text.py +6 -1
  36. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/METADATA +2 -2
  37. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/RECORD +41 -37
  38. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/WHEEL +1 -1
  39. hcpdiff/diffusion/noise/zero_terminal.py +0 -39
  40. hcpdiff/diffusion/sampler/ddpm.py +0 -20
  41. hcpdiff/diffusion/sampler/edm.py +0 -22
  42. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/entry_points.txt +0 -0
  43. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/licenses/LICENSE +0 -0
  44. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,7 @@ import torch
6
6
  from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
7
7
  from hcpdiff.utils import prepare_seed
8
8
  from hcpdiff.utils.net_utils import get_dtype, to_cuda
9
- from rainbowneko.infer import BasicAction
9
+ from rainbowneko.infer import BasicAction, Actions
10
10
  from torch.cuda.amp import autocast
11
11
 
12
12
  try:
@@ -32,8 +32,9 @@ class SeedAction(BasicAction):
32
32
  self.seed = seed
33
33
  self.bs = bs
34
34
 
35
- def forward(self, device, seed=None, **states):
36
- bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
35
+ def forward(self, device, seed=None, bs=None, **states):
36
+ if bs is None:
37
+ bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
37
38
  seed = seed or self.seed
38
39
  if seed is None:
39
40
  seeds = [None]*bs
@@ -155,15 +156,16 @@ class DenoiseAction(BasicAction):
155
156
 
156
157
  with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
157
158
  latent_model_input = torch.cat([latents]*2) if self.guidance_scale>1 else latents
158
- latent_model_input = noise_sampler.c_in(t)*latent_model_input
159
+ latent_model_input = noise_sampler.sigma_scheduler.c_in(t)*latent_model_input
160
+ t_in = noise_sampler.sigma_scheduler.c_noise(t)
159
161
 
160
162
  if text_embeds is None:
161
- noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
163
+ noise_pred = denoiser(latent_model_input, t_in, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
162
164
  cross_attention_kwargs=cross_attention_kwargs, ).sample
163
165
  else:
164
166
  added_cond_kwargs = {"text_embeds":text_embeds, "time_ids":crop_info}
165
167
  # predict the noise residual
166
- noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
168
+ noise_pred = denoiser(latent_model_input, t_in, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
167
169
  cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
168
170
 
169
171
  # perform guidance
@@ -189,10 +191,28 @@ class DiffusionStepAction(BasicAction):
189
191
  states = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
190
192
  states = self.act_sample(**states)
191
193
  return states
194
+
195
+ class DiffusionActions(Actions):
196
+ def __init__(self, actions: List[BasicAction], clean_latent=True, seed_inc=True, key_map_in=None, key_map_out=None):
197
+ super().__init__(actions, key_map_in=key_map_in, key_map_out=key_map_out)
198
+ self.clean_latent = clean_latent
199
+ self.seed_inc = seed_inc
200
+
201
+ def forward(self, **states):
202
+ states = super().forward(**states)
203
+ if self.seed_inc and 'seed' in states:
204
+ bs = states['latents'].shape[0]
205
+ states['seed'] = states['seed'] + bs
206
+ if self.clean_latent:
207
+ states.pop('noise_pred', None)
208
+ states.pop('latents', None)
209
+ states.pop('prompt', None)
210
+ states.pop('negative_prompt', None)
211
+ return states
192
212
 
193
213
  class X0PredAction(BasicAction):
194
214
  def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
195
- latents_x0 = noise_sampler.eps_to_x0(noise_pred, latents, t)
215
+ latents_x0 = noise_sampler.pred_for_target(noise_pred, latents, t, target_type='x0')
196
216
  return {'latents_x0':latents_x0}
197
217
 
198
218
  def time_iter(timesteps, **states):
hcpdiff/workflow/io.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  from functools import partial
3
3
  from typing import List, Union
4
+ from addict import Addict
4
5
 
5
6
  import torch
6
7
  from hcpdiff.utils import to_validate_file
@@ -9,6 +10,8 @@ from rainbowneko.ckpt_manager import NekoLoader
9
10
  from rainbowneko.infer import BasicAction
10
11
  from rainbowneko.infer import LoadImageAction as Neko_LoadImageAction
11
12
  from rainbowneko.utils.img_size_tool import types_support
13
+ from rainbowneko import _share
14
+ from rainbowneko.utils import is_dict
12
15
 
13
16
  class BuildModelsAction(BasicAction):
14
17
  def __init__(self, model_loader: partial[NekoLoader.load], dtype: str=torch.float32, device='cuda', key_map_in=None, key_map_out=None):
@@ -23,6 +26,14 @@ class BuildModelsAction(BasicAction):
23
26
  else:
24
27
  model = self.model_loader(dtype=self.dtype, device=self.device)
25
28
 
29
+ # Callback for TokenizerHandler
30
+ if is_dict(model):
31
+ model_wrapper = Addict(model)
32
+ else:
33
+ model_wrapper = model
34
+ for callback in _share.model_callbacks:
35
+ callback(model_wrapper)
36
+
26
37
  if isinstance(model, dict):
27
38
  return model
28
39
  else:
@@ -33,12 +44,13 @@ class LoadImageAction(Neko_LoadImageAction):
33
44
  super().__init__(image_paths, image_transforms, key_map_in, key_map_out)
34
45
 
35
46
  class SaveImageAction(BasicAction):
36
- def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, key_map_in=None, key_map_out=None):
47
+ def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, save_txt=False, key_map_in=None, key_map_out=None):
37
48
  super().__init__(key_map_in, key_map_out)
38
49
  self.save_root = save_root
39
50
  self.image_type = image_type
40
51
  self.quality = quality
41
52
  self.save_cfg = save_cfg
53
+ self.save_txt = save_txt
42
54
 
43
55
  os.makedirs(save_root, exist_ok=True)
44
56
 
@@ -47,10 +59,15 @@ class SaveImageAction(BasicAction):
47
59
  num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(save_root) if x.rsplit('.', 1)[-1] in types_support])+1
48
60
 
49
61
  for bid, (p, pn, img) in enumerate(zip(prompt, negative_prompt, images)):
50
- img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.{self.image_type}")
62
+ img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(p)}.{self.image_type}")
51
63
  img.save(img_path, quality=self.quality)
52
- num_img_exist += 1
53
64
 
54
65
  if self.save_cfg:
55
66
  cfgs.seed = seeds[bid]
56
67
  parser.save_configs(cfgs, os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-info"))
68
+
69
+ if self.save_txt:
70
+ txt_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.txt")
71
+ with open(txt_path, 'w') as f:
72
+ f.write(p)
73
+ num_img_exist += 1
hcpdiff/workflow/text.py CHANGED
@@ -38,7 +38,7 @@ class TextHookAction(BasicAction):
38
38
  return {'te_hook':te_hook, 'emb_hook':emb_hook, 'token_ex':token_ex}
39
39
 
40
40
  class TextEncodeAction(BasicAction):
41
- def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, key_map_in=None, key_map_out=None):
41
+ def __init__(self, prompt: List|str|None, negative_prompt: List|str|None, bs: int = None, key_map_in=None, key_map_out=None):
42
42
  super().__init__(key_map_in, key_map_out)
43
43
  if isinstance(prompt, str) and bs is not None:
44
44
  prompt = [prompt]*bs
@@ -73,6 +73,11 @@ class AttnMultTextEncodeAction(TextEncodeAction):
73
73
  prompt = prompt or self.prompt
74
74
  negative_prompt = negative_prompt or self.negative_prompt
75
75
 
76
+ if isinstance(negative_prompt, str) and isinstance(prompt, (list, tuple)):
77
+ negative_prompt = [negative_prompt]*len(prompt)
78
+ if isinstance(prompt, str) and isinstance(negative_prompt, (list, tuple)):
79
+ prompt = [prompt]*len(negative_prompt)
80
+
76
81
  if model_offload:
77
82
  to_cuda(TE)
78
83
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hcpdiff
3
- Version: 2.2.1
3
+ Version: 2.3.1
4
4
  Summary: A universal Diffusion toolbox
5
5
  Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
6
6
  Author: Ziyi Dong
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
17
  Requires-Python: >=3.8
18
18
  Description-Content-Type: text/markdown
19
19
  License-File: LICENSE
20
- Requires-Dist: rainbowneko==1.6
20
+ Requires-Dist: rainbowneko>=1.9
21
21
  Requires-Dist: diffusers
22
22
  Requires-Dist: matplotlib
23
23
  Requires-Dist: pyarrow
@@ -1,16 +1,16 @@
1
1
  hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
2
2
  hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
3
- hcpdiff/trainer_ac.py,sha256=scH3FU0onCQtwLiy0-pcrhuowTZob3fLQqRP52iwY0c,2717
4
- hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
3
+ hcpdiff/trainer_ac.py,sha256=-owV-3_bvPxuQsZS2WaajBDh58HpftRtnx0GJkswqaY,2787
4
+ hcpdiff/trainer_ac_single.py,sha256=zyZVrutLUbIJYW1HzUnQ_RnmIcDhbC7M_CT833PJH5w,993
5
5
  hcpdiff/trainer_deepspeed.py,sha256=7lGsiAstWuIlmhRMwWTcJCkoxzUaakVxBngKDnJdSJk,1947
6
- hcpdiff/ckpt_manager/__init__.py,sha256=Mn_5KOC4xbf2GcN6OXg_XdbF5wO9zWeER_1ZO_prKAI,256
7
- hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
8
- hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
6
+ hcpdiff/ckpt_manager/__init__.py,sha256=r_sgjZWCLtdJrRkqqU6aPdfubXSYfPh2Z_Vf_XpZXXs,240
7
+ hcpdiff/ckpt_manager/ckpt.py,sha256=2A093lT03M1ZsJIMWl376V165eh0TZwOgiGrz3LM73Q,1248
8
+ hcpdiff/ckpt_manager/loader.py,sha256=6iZDUj-Vfc5T9eGdWfFMQw4n1GqyLqaLBolgAtgqPq8,3640
9
9
  hcpdiff/ckpt_manager/format/__init__.py,sha256=a3cdKkOTDgdVbDQwSC4mlxOigjX2hBvRb5_X7E3TQWs,237
10
- hcpdiff/ckpt_manager/format/diffusers.py,sha256=T81WN95Nj1il9DfQp9iioVn0uqFEWOlmdIYs2beNOFU,3769
10
+ hcpdiff/ckpt_manager/format/diffusers.py,sha256=qhGbrKAaeLyjFzY-Lj4sL1THHFNrta41JGGMoXT-bCE,3761
11
11
  hcpdiff/ckpt_manager/format/emb.py,sha256=FrqfTfJ8H7f0Zw17NTWCP2AJtpsJI5oXR5IAd4NekhU,680
12
12
  hcpdiff/ckpt_manager/format/lora_webui.py,sha256=4y_T9RdmFTxWzsXd8guNjCiukmyILa5j4MPrhVIL4Qk,10017
13
- hcpdiff/ckpt_manager/format/sd_single.py,sha256=LpCAL_7nAVooCHTFznVVsNMku1G3C77NBORxxr8GDtQ,2328
13
+ hcpdiff/ckpt_manager/format/sd_single.py,sha256=4DZLAl1RNC_nPxuW-lmrBlIMFUhpSTa7HGHgu7Yx8qk,2322
14
14
  hcpdiff/data/__init__.py,sha256=ZFKtanOoMo3G3eKUJPhysnHXnr8BNARERkcMB6B897U,292
15
15
  hcpdiff/data/dataset.py,sha256=1k4GldW13eVyqK_9hrQniqr3_XYAapnWF7iXl_1GXGg,877
16
16
  hcpdiff/data/cache/__init__.py,sha256=ToCmokYH6DghlSwm7HJFirPRIWJ0LkgzqVOYlgoAkQw,25
@@ -25,41 +25,44 @@ hcpdiff/data/source/text.py,sha256=VgI5Ouq986Yy1jwD2fZ9iBlsRciPCeARZmOPEZIcaQY,1
25
25
  hcpdiff/data/source/text2img.py,sha256=acYdolQhZUEpkd7tUAdNkCTVnPc1SMJOVTmGqFt9ZpE,1813
26
26
  hcpdiff/data/source/text2img_cond.py,sha256=yj1KpARA2rkjENutnnzC4uDkcU2Rye21FL2VdC25Hac,585
27
27
  hcpdiff/diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
- hcpdiff/diffusion/noise/__init__.py,sha256=seBpOtd0YsU53PqMn7Nyl_RtwoC-ONEIOX7v2XLGpZQ,93
28
+ hcpdiff/diffusion/noise/__init__.py,sha256=D83EZ6bnc6Ucu4AZwE6rpmXCtwYfHHumeVq97brbnIE,47
29
29
  hcpdiff/diffusion/noise/pyramid_noise.py,sha256=KbpyMT1BHNIaAa7g5eECDkTttOMoMWVFmbP-ekBsuEY,1693
30
- hcpdiff/diffusion/noise/zero_terminal.py,sha256=EfVOaqrTCfw11AolDBl0LIOey3uQT1bDw2XKr2Bm434,1532
31
- hcpdiff/diffusion/sampler/__init__.py,sha256=pSHsKpLjscY5yLbdzHeBUeK9nFDuVeMIIeA_k6FQFdY,158
32
- hcpdiff/diffusion/sampler/base.py,sha256=2AuPVT2ZSXYt2etZmHMyNKuGlT5zn6KIkoMz4m5PGcs,2577
33
- hcpdiff/diffusion/sampler/ddpm.py,sha256=raqSuKsEPN1AEqRVCuBdMAOnKDoeJTRO17wtLBNJCf4,523
34
- hcpdiff/diffusion/sampler/diffusers.py,sha256=XIu-oIlT4LnAYI2-yyIoNeIMeSQe5YpeEvn-FkGVFnE,2684
35
- hcpdiff/diffusion/sampler/edm.py,sha256=5W4pv8hxQsPpJGiFBgElZxR3C9S8kWAhzGKejEVwq3I,753
36
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py,sha256=kmIoWgsWqij6b7KYon3UOCSC07sRJCo-DPR6qkJwUd0,184
37
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py,sha256=9-JI-jwf7xZoQUtrU0qfbjkhNZT8a_tmapLtwVbFUx0,381
38
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=2PMIpg2K6CVoxew1y1pIqvCHbdggC-m3amrOYk15OdQ,8107
39
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc8X_wQ6mmUcbkKhnoQjs,1900
30
+ hcpdiff/diffusion/sampler/VP.py,sha256=r0Q_RROEIeNNw93XrOD5htW78rfuoSxy1WBQEoQL83s,958
31
+ hcpdiff/diffusion/sampler/__init__.py,sha256=Lrwg1us8qo943T7mdIXFDRXfKvnLhrzwmi6DrIKIiUA,135
32
+ hcpdiff/diffusion/sampler/base.py,sha256=UbE_AmtvLg-Hr2bkYz8PvNWB63tvtacUvCIDm_W6opA,5484
33
+ hcpdiff/diffusion/sampler/diffusers.py,sha256=wIMs8n3kdci1On0FUCV0si324ZE9zeRw_CxaHP8rdcs,2586
34
+ hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py,sha256=eiSmMBkXI_LfxnNrXj5XptcF0dGcPas--vWvqhFGlv8,273
35
+ hcpdiff/diffusion/sampler/sigma_scheduler/base.py,sha256=UT4tbjFf80KYfU08y0hJf8h_Cl80a5MUhK5FsKLsqbY,2521
36
+ hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=SA8lXT6lucJot_rpJ84Wz-_uc5dXfb2QPoQgJHSKOj4,12999
37
+ hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=m1YlIyn61zfDjxLMcHvWs0nzULbHgXeB7WGKmTiaGSU,4127
38
+ hcpdiff/diffusion/sampler/sigma_scheduler/flow.py,sha256=FtWpesUtSmFuiIGkrrVhYJweB7INZiw0atC64tc0Nk4,2020
39
+ hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py,sha256=CCqQLkGo4omkxzFovYdZQzdZVwIxK3PiOitZFww8MHs,859
40
40
  hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
41
41
  hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
42
42
  hcpdiff/easy/cfg/__init__.py,sha256=SxHMWG6T2CXhX3dP0xizSMd9vFWPaZQDc4Gj4CF__yQ,253
43
- hcpdiff/easy/cfg/sd15_train.py,sha256=kKdESVqAxNlBhhz12PvwrpHJBea80OUFzDDMHwiulVs,6710
44
- hcpdiff/easy/cfg/sdxl_train.py,sha256=FUWE_hRJdQc9Qd9J6730jAyK0H4EIKS7-3BSufCItXU,4275
43
+ hcpdiff/easy/cfg/sd15_train.py,sha256=NtgsQLg1sd5JFmHU4nqMPOrvP7zmwo2x0MCspjVNQEY,7000
44
+ hcpdiff/easy/cfg/sdxl_train.py,sha256=rVLLKVMKB_PHuum3dKQcBqL0uR8QhzmdRllM-pYnbK4,4534
45
45
  hcpdiff/easy/cfg/t2i.py,sha256=SnjFjZAKd9orjJr3RW5_N2_EIlW2Ree7JMvdNUAR9gc,9507
46
46
  hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
47
47
  hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
48
48
  hcpdiff/easy/model/loader.py,sha256=Tdx-lhQEYf2NYjVM1A5B8x6ZZpJKcXUkFIPIbr7h7XM,3456
49
- hcpdiff/evaluate/__init__.py,sha256=CtNzi8xdUWZBDBcP5TZTDMcRyykaOJhBIxTJgVuMabo,35
50
- hcpdiff/evaluate/previewer.py,sha256=QiYYiEJBKP06uL3wKLhnpIGUZuAkr9BuHxTE29obpXI,2432
49
+ hcpdiff/evaluate/__init__.py,sha256=qWxV0D8Ho5uBj2YbaC_QFDnT49PSKPfh44m4ivkNbMM,108
50
+ hcpdiff/evaluate/evaluator.py,sha256=9BZQBeC-N7p-ICx6Giw9v-2Tb9volMTDmeDfhj0nXJ0,2940
51
+ hcpdiff/evaluate/previewer.py,sha256=-vE0YXVfos70CQMo9ZInw7xu3d88DlTfVLs4BzzkxfM,3140
52
+ hcpdiff/evaluate/metrics/__init__.py,sha256=vE0nSvBtDBu9SomANvWcm2UHX56PhCYwhgrcmm_mKyo,39
53
+ hcpdiff/evaluate/metrics/clip_score.py,sha256=rQgweu5QcqW3fPI3EXcNbrH2QCcSAekE3lpYk45P2M4,900
51
54
  hcpdiff/loss/__init__.py,sha256=2dwPczSiv3rB5fzOeYbl5ZHpMU-qXOQlXeOiXdxcxwM,173
52
- hcpdiff/loss/base.py,sha256=3bvgMbwyPOEA9iSkv0hRHw4VnKjkUCZAENNnDMFilYM,1780
55
+ hcpdiff/loss/base.py,sha256=Vvpm-KZGH4n-gYIlnVAtPl1B799c7v0dJXJ5BBh3yO0,1112
53
56
  hcpdiff/loss/gw.py,sha256=0yi1kozuII3xZA6FnjOhINtvScWt1MyBZLBtMKmgojM,1224
54
57
  hcpdiff/loss/ssim.py,sha256=YofadvBkc6sklxBUx1p3ADw5OHOZPK3kaHz8FH5a6m4,1281
55
58
  hcpdiff/loss/vlb.py,sha256=s78iBnXUiDWfGf7mYmhUnHqxqea5gSByKOoqBrX6bzU,3222
56
- hcpdiff/loss/weighting.py,sha256=UR2PyZ1JTNOydXMw4e1Fh52XmtwKaviPvddcmVKCTlI,2242
59
+ hcpdiff/loss/weighting.py,sha256=9qzMnvCb6b5qx0p08GDSlkxmYEqQcNt79XdRBvfHmiI,2914
57
60
  hcpdiff/models/__init__.py,sha256=eQS7DPiGLiE1MFRkZj_17IY3IsfDUVcYpcOmhHb5B9o,472
58
61
  hcpdiff/models/cfg_context.py,sha256=e2B3K1KwJhzbD6xdJUOyNtl_XgQ0296XI3FHw3gvZF4,1502
59
62
  hcpdiff/models/container.py,sha256=z3p5TmQhxdzXSIfofz55_bmEhSsgUJsy1o9EcDs8Oeo,696
60
63
  hcpdiff/models/controlnet.py,sha256=VIkUzJCVpCqqQOtRSLQPfbcDy9CsXutxLeZB6PdZfA0,7809
61
64
  hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6758
62
- hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
65
+ hcpdiff/models/lora_base_patch.py,sha256=Tdb_b3TN_K-04nlUvcfBh6flPcbL9M4iP7jOVyb1jXQ,7271
63
66
  hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
64
67
  hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
65
68
  hcpdiff/models/text_emb_ex.py,sha256=O0XZqid01OrB0dHY7hCiBvdU2026SvZ38yfQaF2TWrs,8018
@@ -72,7 +75,7 @@ hcpdiff/models/compose/compose_tokenizer.py,sha256=g3l0pOFv6p7Iigxm6Pqt_iTUXBlO1
72
75
  hcpdiff/models/compose/sdxl_composer.py,sha256=NtMGaFGZTfKsPJSVi2yT-UM6K1WKWtk99XxVmTcKlk8,2164
73
76
  hcpdiff/models/wrapper/__init__.py,sha256=HbGQmFnfccr-dtvZKjEv-pmR4cCnF4fwGLKS3tuG_OY,135
74
77
  hcpdiff/models/wrapper/pixart.py,sha256=nRUvHSHn4TYg_smC0xpeW-GtUgXss-MuaVPTHpMozDE,1147
75
- hcpdiff/models/wrapper/sd.py,sha256=D7VDI4OmbLTk9mzYta-C4LJjWfZmuBiDub4t8v1-M9o,11711
78
+ hcpdiff/models/wrapper/sd.py,sha256=EywmVU2QzR74M_4eH_uXVW8HJNauyjwcZPU7rRAQ7eI,11666
76
79
  hcpdiff/models/wrapper/utils.py,sha256=NyebMoAPnrgcTHbiIocSD-eGdGdD-V1G_TQuWsRWufw,665
77
80
  hcpdiff/parser/__init__.py,sha256=-2dDZ2Ii4zoGQqDTme94q4PpJbBiV6HS5BsDASz4Xbo,33
78
81
  hcpdiff/parser/embpt.py,sha256=LgwZ0f0tLn3DrTo5ZpSCsZcA5330UpiW_sK96yEPmOM,1307
@@ -89,27 +92,28 @@ hcpdiff/tools/init_proj.py,sha256=XrXxxhIaItywG7HsrloJo-x8w9suZiY35daelzZvjrg,19
89
92
  hcpdiff/tools/lora_convert.py,sha256=So14WvSVIm6rU4m1XCajFXDnhq7abpZS95SLbaoyBFU,10058
90
93
  hcpdiff/tools/save_model.py,sha256=gbfYi_EfEBZEUcDjle6MDHA19sQWY0zA8_y_LMzHQ7M,428
91
94
  hcpdiff/tools/sd2diffusers.py,sha256=vB6OnBLw60sJkdpVZcYEPtKAZW1h8ErbSGSRq0uAiIk,16855
92
- hcpdiff/utils/__init__.py,sha256=VOLhdNq2mRyqmWxrssIWSZtR_PQ8rFwo2u0uq6GbLHA,45
95
+ hcpdiff/utils/__init__.py,sha256=28K9Ui0uur-vHuUdlSyIBYijgu2b7rGOPXN2ogJu1z8,82
93
96
  hcpdiff/utils/colo_utils.py,sha256=JyLUvVnISa48CnryNLrgVxMo-jxu2UhBq70eYPrkjuI,837
94
97
  hcpdiff/utils/inpaint_pipe.py,sha256=CRy1MUlPmHifCAbZnKOP0qbLp2grn7ZbVeaB2qIA4ig,42862
95
98
  hcpdiff/utils/net_utils.py,sha256=gdwLYDNKV2t3SP0jBIO3d0HtY6E7jRaf_rmPT8gKZZE,9762
96
99
  hcpdiff/utils/pipe_hook.py,sha256=-UDX3FtZGl-bxSk13gdbPXc1OvtbCcpk_fvKxLQo3Ag,31987
100
+ hcpdiff/utils/torch_utils.py,sha256=gBZCcDKZc0NGDQx6QeHuQePoZ82kQRhaL7oEdZIYGvU,573
97
101
  hcpdiff/utils/utils.py,sha256=hZnZP1IETgVpScxES0yIuRfc34TnzvAqmgOTK_56ssw,4976
98
- hcpdiff/workflow/__init__.py,sha256=t7Zyc0XFORdNvcwHp9AsCtEkhJ3l7Hm41ugngIL0Sag,867
99
- hcpdiff/workflow/diffusion.py,sha256=yzhqKA3019OPu1RKggrLoytMgm919qf6j9S85PYOwjQ,8644
102
+ hcpdiff/workflow/__init__.py,sha256=i5s7QXo6wK9607KL0KTW4suE1c-HGJ5_EgnCdVLl3WM,885
103
+ hcpdiff/workflow/diffusion.py,sha256=hKefBrVP6-025MhdrKOQMUhHxLaGqjpUKhR6WahYwh0,9549
100
104
  hcpdiff/workflow/fast.py,sha256=kZt7bKrvpFInSn7GzbkTkpoCSM0Z6IbDjgaDvcbFYf8,1024
101
105
  hcpdiff/workflow/flow.py,sha256=FFbFFOAXT4c31L5bHBEB_qeVGuBQDLYhq8kTD1chGNo,2548
102
- hcpdiff/workflow/io.py,sha256=aTrMR3s44apVJpnSyvZIabW2Op0tslk_Z9JFJl5svm0,2635
106
+ hcpdiff/workflow/io.py,sha256=4oiE_PS3sOVYT8M6PDwvT5h9XzoKDMQR0n_4-Ktttys,3284
103
107
  hcpdiff/workflow/model.py,sha256=1gj5yOTefYTnGXVR6JPAfxIwuB69YwN6E-BontRcuyQ,2913
104
- hcpdiff/workflow/text.py,sha256=Z__SJHZyuaKyzkYJ6rbiAzOGRiYcCjwCGeqfpP1Jo7o,4336
108
+ hcpdiff/workflow/text.py,sha256=XQvN4zzK7VaGxy4FDgSDeWh2jjk7UZU24moeRKAWXRE,4608
105
109
  hcpdiff/workflow/utils.py,sha256=xojaMG4lHsymslc8df5uiVXmmBVWpn_Phqka8qzJEWw,2226
106
110
  hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
107
111
  hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
108
112
  hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
109
113
  hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
110
- hcpdiff-2.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
111
- hcpdiff-2.2.1.dist-info/METADATA,sha256=f96Tc90K5WTBbJ35wWJw60G2JR46eGpUvQSaPIysVDg,10323
112
- hcpdiff-2.2.1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
113
- hcpdiff-2.2.1.dist-info/entry_points.txt,sha256=_4VRsEsEWOhHfzBDu9bx8Wh_S8Wi4ZTHpI0n6rU0J-I,258
114
- hcpdiff-2.2.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
115
- hcpdiff-2.2.1.dist-info/RECORD,,
114
+ hcpdiff-2.3.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
115
+ hcpdiff-2.3.1.dist-info/METADATA,sha256=zaJHhKQiezDTvyv-IIoRHf4VCv0z2gU9fq0sVi9XhTg,10323
116
+ hcpdiff-2.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
117
+ hcpdiff-2.3.1.dist-info/entry_points.txt,sha256=_4VRsEsEWOhHfzBDu9bx8Wh_S8Wi4ZTHpI0n6rU0J-I,258
118
+ hcpdiff-2.3.1.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
119
+ hcpdiff-2.3.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,39 +0,0 @@
1
- import torch
2
- from ..sampler.sigma_scheduler import DDPMDiscreteSigmaScheduler
3
-
4
- class ZeroTerminalSampler:
5
-
6
- @classmethod
7
- def patch(cls, base_sampler):
8
- assert isinstance(base_sampler.sigma_scheduler, DDPMDiscreteSigmaScheduler), "ZeroTerminalScheduler only works with DDPM SigmaScheduler"
9
-
10
- alphas_cumprod = base_sampler.sigma_scheduler.alphas_cumprod
11
- base_sampler.sigma_scheduler.alphas_cumprod = cls.rescale_zero_terminal_snr(alphas_cumprod)
12
- base_sampler.sigma_scheduler.sigmas = ((1-alphas_cumprod)/alphas_cumprod).sqrt()
13
-
14
-
15
- @staticmethod
16
- def rescale_zero_terminal_snr(alphas_cumprod, thr=1e-4):
17
- """
18
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
19
- Args:
20
- alphas_cumprod (`torch.FloatTensor`)
21
- Returns:
22
- `torch.FloatTensor`: rescaled betas with zero terminal SNR
23
- """
24
- alphas_bar_sqrt = alphas_cumprod.sqrt()
25
-
26
- # Store old values.
27
- alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
28
- alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
29
-
30
- # Shift so the last timestep is zero.
31
- alphas_bar_sqrt -= alphas_bar_sqrt_T
32
-
33
- # Scale so the first timestep is back to the old value.
34
- alphas_bar_sqrt *= alphas_bar_sqrt_0/(alphas_bar_sqrt_0-alphas_bar_sqrt_T)
35
- alphas_bar_sqrt[-1] = thr # avoid nan sigma
36
-
37
- # Convert alphas_bar_sqrt to betas
38
- alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
39
- return alphas_bar
@@ -1,20 +0,0 @@
1
- import torch
2
-
3
- from .base import BaseSampler
4
- from .sigma_scheduler import SigmaScheduler
5
-
6
- class DDPMSampler(BaseSampler):
7
- def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator=None):
8
- super().__init__(sigma_scheduler, generator)
9
-
10
- def c_in(self, sigma):
11
- return 1./(sigma**2+1).sqrt()
12
-
13
- def c_out(self, sigma):
14
- return -sigma
15
-
16
- def c_skip(self, sigma):
17
- return 1.
18
-
19
- def denoise(self, x, sigma, eps=None, generator=None):
20
- raise NotImplementedError
@@ -1,22 +0,0 @@
1
- import torch
2
-
3
- from .base import BaseSampler
4
- from .sigma_scheduler import SigmaScheduler
5
-
6
- class EDMSampler(BaseSampler):
7
- def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator = None, sigma_data: float = 1.0, sigma_thr=1000):
8
- super().__init__(sigma_scheduler, generator)
9
- self.sigma_data = sigma_data
10
- self.sigma_thr = sigma_thr
11
-
12
- def c_in(self, sigma):
13
- return 1/(sigma**2+self.sigma_data**2).sqrt()
14
-
15
- def c_out(self, sigma):
16
- return (sigma*self.sigma_data)/(sigma**2+self.sigma_data**2).sqrt()
17
-
18
- def c_skip(self, sigma):
19
- return self.sigma_data**2/(sigma**2+self.sigma_data**2)
20
-
21
- def denoise(self, x, sigma, eps=None, generator=None):
22
- raise NotImplementedError