hcpdiff 2.2__py3-none-any.whl → 2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. hcpdiff/ckpt_manager/__init__.py +1 -1
  2. hcpdiff/ckpt_manager/ckpt.py +21 -17
  3. hcpdiff/ckpt_manager/format/diffusers.py +4 -4
  4. hcpdiff/ckpt_manager/format/sd_single.py +3 -3
  5. hcpdiff/ckpt_manager/loader.py +11 -4
  6. hcpdiff/diffusion/noise/__init__.py +0 -1
  7. hcpdiff/diffusion/sampler/VP.py +27 -0
  8. hcpdiff/diffusion/sampler/__init__.py +2 -3
  9. hcpdiff/diffusion/sampler/base.py +106 -44
  10. hcpdiff/diffusion/sampler/diffusers.py +11 -17
  11. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
  12. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
  13. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
  14. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
  15. hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
  16. hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
  17. hcpdiff/easy/cfg/sd15_train.py +35 -24
  18. hcpdiff/easy/cfg/sdxl_train.py +34 -25
  19. hcpdiff/evaluate/__init__.py +3 -1
  20. hcpdiff/evaluate/evaluator.py +76 -0
  21. hcpdiff/evaluate/metrics/__init__.py +1 -0
  22. hcpdiff/evaluate/metrics/clip_score.py +23 -0
  23. hcpdiff/evaluate/previewer.py +29 -12
  24. hcpdiff/loss/base.py +9 -26
  25. hcpdiff/loss/weighting.py +36 -18
  26. hcpdiff/models/lora_base_patch.py +26 -0
  27. hcpdiff/models/text_emb_ex.py +4 -0
  28. hcpdiff/models/wrapper/sd.py +17 -19
  29. hcpdiff/trainer_ac.py +7 -12
  30. hcpdiff/trainer_ac_single.py +1 -6
  31. hcpdiff/trainer_deepspeed.py +47 -0
  32. hcpdiff/utils/__init__.py +2 -1
  33. hcpdiff/utils/torch_utils.py +25 -0
  34. hcpdiff/workflow/__init__.py +1 -1
  35. hcpdiff/workflow/diffusion.py +27 -7
  36. hcpdiff/workflow/io.py +20 -3
  37. hcpdiff/workflow/text.py +6 -1
  38. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/METADATA +8 -4
  39. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/RECORD +43 -39
  40. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/WHEEL +1 -1
  41. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/entry_points.txt +1 -0
  42. hcpdiff/diffusion/noise/zero_terminal.py +0 -39
  43. hcpdiff/diffusion/sampler/ddpm.py +0 -20
  44. hcpdiff/diffusion/sampler/edm.py +0 -22
  45. hcpdiff/train_deepspeed.py +0 -69
  46. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/licenses/LICENSE +0 -0
  47. {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ from ..cfg_context import CFGContext
17
17
 
18
18
  class SD15Wrapper(BaseWrapper):
19
19
  def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
20
- pred_type='eps', TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
20
+ TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
21
21
  super().__init__()
22
22
  self.key_mapper_in = self.build_mapper(key_map_in, None, (
23
23
  'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
@@ -31,8 +31,6 @@ class SD15Wrapper(BaseWrapper):
31
31
  self.tokenizer = tokenizer
32
32
  self.min_attnmask = min_attnmask
33
33
 
34
- self.pred_type = pred_type
35
-
36
34
  self.TE_hook_cfg = TEHookCFG.create(TE_hook_cfg)
37
35
  self.cfg_context = cfg_context
38
36
  self.tokenizer.N_repeats = self.TE_hook_cfg.tokenizer_repeats
@@ -93,8 +91,9 @@ class SD15Wrapper(BaseWrapper):
93
91
  plugin_input={}, **kwargs):
94
92
  # input prepare
95
93
  x_0 = self.get_latents(image)
96
- x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
97
- x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
94
+ x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
95
+ x_t_in = x_t*self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype).view(-1,1,1,1)
96
+ t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
98
97
 
99
98
  if neg_prompt_ids:
100
99
  prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
@@ -104,15 +103,14 @@ class SD15Wrapper(BaseWrapper):
104
103
  position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
105
104
 
106
105
  # model forward
107
- x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
108
- encoder_hidden_states = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
106
+ x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
107
+ encoder_hidden_states = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
109
108
  plugin_input=plugin_input, **kwargs)
110
- model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, attn_mask=attn_mask, position_ids=position_ids,
109
+ model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, attn_mask=attn_mask, position_ids=position_ids,
111
110
  plugin_input=plugin_input, **kwargs)
112
111
  model_pred = self.cfg_context.post(model_pred)
113
112
 
114
- return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
115
- noise_sampler=self.noise_sampler)
113
+ return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
116
114
 
117
115
  def forward(self, ds_name=None, **kwargs):
118
116
  model_args, model_kwargs = self.get_map_data(self.key_mapper_in, kwargs, ds_name)
@@ -156,8 +154,8 @@ class SD15Wrapper(BaseWrapper):
156
154
 
157
155
  class SDXLWrapper(SD15Wrapper):
158
156
  def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
159
- pred_type='eps', TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
160
- super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, pred_type, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
157
+ TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
158
+ super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
161
159
  self.key_mapper_in = self.build_mapper(key_map_in, None, (
162
160
  'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
163
161
  'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input', 'coord -> crop_info'))
@@ -195,8 +193,9 @@ class SDXLWrapper(SD15Wrapper):
195
193
  crop_info=None, plugin_input={}):
196
194
  # input prepare
197
195
  x_0 = self.get_latents(image)
198
- x_t, noise, sigma, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
199
- x_t_in = x_t*self.noise_sampler.c_in(sigma).to(dtype=x_t.dtype)
196
+ x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
197
+ x_t_in = x_t*self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype).view(-1,1,1,1)
198
+ t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
200
199
 
201
200
  if neg_prompt_ids:
202
201
  prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
@@ -206,13 +205,12 @@ class SDXLWrapper(SD15Wrapper):
206
205
  position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
207
206
 
208
207
  # model forward
209
- x_t_in, timesteps = self.cfg_context.pre(x_t_in, timesteps)
210
- encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, timesteps, attn_mask=attn_mask, position_ids=position_ids,
208
+ x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
209
+ encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
211
210
  plugin_input=plugin_input)
212
211
  added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
213
- model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, timesteps, added_cond_kwargs=added_cond_kwargs,
212
+ model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, added_cond_kwargs=added_cond_kwargs,
214
213
  attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
215
214
  model_pred = self.cfg_context.post(model_pred)
216
215
 
217
- return dict(model_pred=model_pred, noise=noise, sigma=sigma, timesteps=timesteps, x_0=x_0, x_t=x_t, pred_type=self.pred_type,
218
- noise_sampler=self.noise_sampler)
216
+ return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
hcpdiff/trainer_ac.py CHANGED
@@ -4,8 +4,8 @@ import warnings
4
4
  import torch
5
5
  from rainbowneko.parser import load_config_with_cli
6
6
  from rainbowneko.ckpt_manager import NekoSaver
7
- from rainbowneko.train import Trainer
8
- from rainbowneko.utils import xformers_available, is_dict
7
+ from rainbowneko.train.trainer import Trainer
8
+ from rainbowneko.utils import xformers_available, is_dict, weight_dtype_map
9
9
  from hcpdiff.ckpt_manager import EmbFormat
10
10
 
11
11
  class HCPTrainer(Trainer):
@@ -17,7 +17,7 @@ class HCPTrainer(Trainer):
17
17
  warnings.warn("xformers is not available. Make sure it is installed correctly")
18
18
 
19
19
  if self.model_wrapper.vae is not None:
20
- self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
20
+ self.vae_dtype = weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
21
21
  self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
22
22
 
23
23
  if self.cfgs.model.gradient_checkpointing:
@@ -42,19 +42,14 @@ class HCPTrainer(Trainer):
42
42
  def pt_trainable(self):
43
43
  return self.cfgs.emb_pt is not None
44
44
 
45
- def get_loss(self, ds_name, model_pred, inputs):
46
- loss = super().get_loss(ds_name, model_pred, inputs)
47
- # make DDP happy
48
- if len(self.train_pts)>0:
49
- loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
50
- return loss
51
-
52
45
  def save_model(self, from_raw=False):
53
46
  NekoSaver.save_all(
54
- self.model_raw,
55
- plugin_groups={**self.all_plugin, 'embs': self.train_pts},
56
47
  cfg=self.ckpt_saver,
48
+ model=self.model_raw,
49
+ plugin_groups=self.all_plugin,
50
+ embs=self.train_pts,
57
51
  model_ema=getattr(self, "ema_model", None),
52
+ optimizer=self.optimizer,
58
53
  name_template=f'{{}}-{self.real_step}',
59
54
  )
60
55
 
@@ -1,12 +1,7 @@
1
1
  import argparse
2
- import sys
3
- from functools import partial
4
-
5
- import torch
6
- from accelerate import Accelerator
7
- from loguru import logger
8
2
 
9
3
  from rainbowneko.train.trainer import TrainerSingleCard
4
+
10
5
  from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
11
6
 
12
7
  class HCPTrainerSingleCard(TrainerSingleCard, HCPTrainer):
@@ -0,0 +1,47 @@
1
+ import argparse
2
+ import warnings
3
+
4
+ import torch
5
+ from rainbowneko.ckpt_manager import NekoPluginSaver
6
+ from rainbowneko.train.trainer import TrainerDeepspeed
7
+ from rainbowneko.utils import xformers_available
8
+
9
+ from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
10
+
11
class HCPTrainerDeepspeed(TrainerDeepspeed, HCPTrainer):
    """HCP-Diffusion trainer built on ``TrainerDeepspeed`` with the ``HCPTrainer`` extensions."""

    def config_model(self):
        """
        Apply the options found in ``self.cfgs.model`` to ``self.model_wrapper``.

        Steps, in order: optionally enable xformers (warning when it is not installed),
        set the weight/VAE dtypes when the wrapper has a VAE, optionally enable gradient
        checkpointing, and on the local main process flag every ``NekoPluginSaver`` in
        ``self.ckpt_saver`` with ``plugin_from_raw = True``.
        """
        # Resolve the dtype through the module-level map from rainbowneko.utils, the same
        # way HCPTrainer.config_model does, instead of relying on a trainer attribute.
        # Imported here so the method stays self-contained.
        from rainbowneko.utils import weight_dtype_map

        if self.cfgs.model.enable_xformers:
            if xformers_available:
                self.model_wrapper.enable_xformers()
            else:
                warnings.warn("xformers is not available. Make sure it is installed correctly")

        if self.model_wrapper.vae is not None:
            # Unknown or missing 'vae_dtype' falls back to float32.
            self.vae_dtype = weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
            self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)

        if self.cfgs.model.gradient_checkpointing:
            self.model_wrapper.enable_gradient_checkpointing()

        if self.is_local_main_process:
            for saver in self.ckpt_saver.values():
                if isinstance(saver, NekoPluginSaver):
                    saver.plugin_from_raw = True
30
+
31
def hcp_train():
    """
    Console entry point: start DeepSpeed training through ``accelerate launch``.

    ``--launch_cfg`` selects the accelerate config file; every other command-line
    argument is forwarded unchanged to ``python -m hcpdiff.trainer_deepspeed``.
    A non-zero exit status of the child process raises ``subprocess.CalledProcessError``.
    """
    import subprocess

    launcher_parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
    launcher_parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/deepspeed.yaml')
    launcher_args, forwarded = launcher_parser.parse_known_args()

    # Build the full command first, then append whatever the launcher did not consume.
    command = ["accelerate", "launch", '--config_file', launcher_args.launch_cfg, "-m", "hcpdiff.trainer_deepspeed"]
    command.extend(forwarded)
    subprocess.run(command, check=True)
39
+
40
if __name__ == '__main__':
    # Script mode: ``--cfg`` names the training config; all remaining arguments are
    # handed to load_config_with_cli as config overrides.
    parser = argparse.ArgumentParser(description='HCP Diffusion Trainer for DeepSpeed')
    parser.add_argument("--cfg", type=str, default=None, required=True)
    parsed = parser.parse_known_args()
    args, cfg_args = parsed[0], parsed[1]

    loaded = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
    parser, conf = loaded
    trainer = HCPTrainerDeepspeed(parser, conf)
    trainer.train()
hcpdiff/utils/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
1
  from .utils import *
2
- from .net_utils import *
2
+ from .net_utils import *
3
+ from .torch_utils import invert_func
@@ -0,0 +1,25 @@
1
+ import torch
2
+
3
def invert_func(func, y, x_min=0.0, x_max=1.0, tol=1e-5, max_iter=100):
    """
    Invert ``func`` element-wise by bisection: find ``x`` such that ``func(x) ~= y``.

    The lower bound is raised whenever ``func(mid) <= y`` and the upper bound is lowered
    otherwise, so ``func`` is assumed to be non-decreasing on ``[x_min, x_max]``.

    :param func: callable evaluated on the current candidates, returns values comparable to ``y``
    :param y: target values [B]; converted to float32
    :param x_min: lower end of the search interval
    :param x_max: upper end of the search interval
    :param tol: stop as soon as ``|func(x) - y| < tol`` holds for every element
    :param max_iter: maximum number of bisection steps
    :return: x [B]
    """
    y = y.to(dtype=torch.float32)
    left = torch.full_like(y, x_min)
    right = torch.full_like(y, x_max)

    for _ in range(max_iter):
        mid = (left+right)/2
        val = func(mid)

        # All elements meet the tolerance at ``mid``: return that evaluated point.
        # Narrowing the bracket first and returning its new midpoint would hand back
        # a point that was never checked (0.75 instead of 0.5 for the identity on [0, 1]).
        if torch.all(torch.abs(val-y)<tol):
            return mid

        too_large = val>y
        left = torch.where(too_large, left, mid)
        right = torch.where(too_large, mid, right)

    # Budget exhausted (or max_iter == 0): best estimate is the centre of the bracket.
    return (left+right)/2
@@ -1,5 +1,5 @@
1
1
  from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
2
- X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter
2
+ X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter, DiffusionActions
3
3
  from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
4
4
  from .vae import EncodeAction, DecodeAction
5
5
  from .io import BuildModelsAction, SaveImageAction, LoadImageAction
@@ -6,7 +6,7 @@ import torch
6
6
  from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
7
7
  from hcpdiff.utils import prepare_seed
8
8
  from hcpdiff.utils.net_utils import get_dtype, to_cuda
9
- from rainbowneko.infer import BasicAction
9
+ from rainbowneko.infer import BasicAction, Actions
10
10
  from torch.cuda.amp import autocast
11
11
 
12
12
  try:
@@ -32,8 +32,9 @@ class SeedAction(BasicAction):
32
32
  self.seed = seed
33
33
  self.bs = bs
34
34
 
35
- def forward(self, device, seed=None, **states):
36
- bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
35
+ def forward(self, device, seed=None, bs=None, **states):
36
+ if bs is None:
37
+ bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
37
38
  seed = seed or self.seed
38
39
  if seed is None:
39
40
  seeds = [None]*bs
@@ -155,15 +156,16 @@ class DenoiseAction(BasicAction):
155
156
 
156
157
  with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
157
158
  latent_model_input = torch.cat([latents]*2) if self.guidance_scale>1 else latents
158
- latent_model_input = noise_sampler.c_in(t)*latent_model_input
159
+ latent_model_input = noise_sampler.sigma_scheduler.c_in(t)*latent_model_input
160
+ t_in = noise_sampler.sigma_scheduler.c_noise(t)
159
161
 
160
162
  if text_embeds is None:
161
- noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
163
+ noise_pred = denoiser(latent_model_input, t_in, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
162
164
  cross_attention_kwargs=cross_attention_kwargs, ).sample
163
165
  else:
164
166
  added_cond_kwargs = {"text_embeds":text_embeds, "time_ids":crop_info}
165
167
  # predict the noise residual
166
- noise_pred = denoiser(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
168
+ noise_pred = denoiser(latent_model_input, t_in, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
167
169
  cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
168
170
 
169
171
  # perform guidance
@@ -189,10 +191,28 @@ class DiffusionStepAction(BasicAction):
189
191
  states = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
190
192
  states = self.act_sample(**states)
191
193
  return states
194
+
195
class DiffusionActions(Actions):
    """
    ``Actions`` group that post-processes the state dict after its actions have run.

    :param actions: actions executed by the parent ``Actions.forward``
    :param clean_latent: when true, drop 'noise_pred', 'latents', 'prompt' and
        'negative_prompt' from the returned states
    :param seed_inc: when true, advance ``states['seed']`` by the batch size of
        ``states['latents']`` after each run
    :param key_map_in: forwarded to ``Actions``
    :param key_map_out: forwarded to ``Actions``
    """

    def __init__(self, actions: List[BasicAction], clean_latent=True, seed_inc=True, key_map_in=None, key_map_out=None):
        super().__init__(actions, key_map_in=key_map_in, key_map_out=key_map_out)
        self.clean_latent = clean_latent
        self.seed_inc = seed_inc

    def forward(self, **states):
        """
        Run the wrapped actions, then bump the seed and remove per-run entries.

        :return: the updated states
        """
        states = super().forward(**states)

        # A seed of None is a valid "random seed" request (SeedAction handles it),
        # so there is nothing to increment in that case.
        if self.seed_inc and 'seed' in states and states['seed'] is not None:
            bs = states['latents'].shape[0]
            states['seed'] = states['seed'] + bs

        if self.clean_latent:
            for key in ('noise_pred', 'latents', 'prompt', 'negative_prompt'):
                states.pop(key, None)
        return states
192
212
 
193
213
  class X0PredAction(BasicAction):
194
214
  def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
195
- latents_x0 = noise_sampler.eps_to_x0(noise_pred, latents, t)
215
+ latents_x0 = noise_sampler.pred_for_target(noise_pred, latents, t, target_type='x0')
196
216
  return {'latents_x0':latents_x0}
197
217
 
198
218
  def time_iter(timesteps, **states):
hcpdiff/workflow/io.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  from functools import partial
3
3
  from typing import List, Union
4
+ from addict import Addict
4
5
 
5
6
  import torch
6
7
  from hcpdiff.utils import to_validate_file
@@ -9,6 +10,8 @@ from rainbowneko.ckpt_manager import NekoLoader
9
10
  from rainbowneko.infer import BasicAction
10
11
  from rainbowneko.infer import LoadImageAction as Neko_LoadImageAction
11
12
  from rainbowneko.utils.img_size_tool import types_support
13
+ from rainbowneko import _share
14
+ from rainbowneko.utils import is_dict
12
15
 
13
16
  class BuildModelsAction(BasicAction):
14
17
  def __init__(self, model_loader: partial[NekoLoader.load], dtype: str=torch.float32, device='cuda', key_map_in=None, key_map_out=None):
@@ -23,6 +26,14 @@ class BuildModelsAction(BasicAction):
23
26
  else:
24
27
  model = self.model_loader(dtype=self.dtype, device=self.device)
25
28
 
29
+ # Callback for TokenizerHandler
30
+ if is_dict(model):
31
+ model_wrapper = Addict(model)
32
+ else:
33
+ model_wrapper = model
34
+ for callback in _share.model_callbacks:
35
+ callback(model_wrapper)
36
+
26
37
  if isinstance(model, dict):
27
38
  return model
28
39
  else:
@@ -33,12 +44,13 @@ class LoadImageAction(Neko_LoadImageAction):
33
44
  super().__init__(image_paths, image_transforms, key_map_in, key_map_out)
34
45
 
35
46
  class SaveImageAction(BasicAction):
36
- def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, key_map_in=None, key_map_out=None):
47
+ def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, save_txt=False, key_map_in=None, key_map_out=None):
37
48
  super().__init__(key_map_in, key_map_out)
38
49
  self.save_root = save_root
39
50
  self.image_type = image_type
40
51
  self.quality = quality
41
52
  self.save_cfg = save_cfg
53
+ self.save_txt = save_txt
42
54
 
43
55
  os.makedirs(save_root, exist_ok=True)
44
56
 
@@ -47,10 +59,15 @@ class SaveImageAction(BasicAction):
47
59
  num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(save_root) if x.rsplit('.', 1)[-1] in types_support])+1
48
60
 
49
61
  for bid, (p, pn, img) in enumerate(zip(prompt, negative_prompt, images)):
50
- img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.{self.image_type}")
62
+ img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(p)}.{self.image_type}")
51
63
  img.save(img_path, quality=self.quality)
52
- num_img_exist += 1
53
64
 
54
65
  if self.save_cfg:
55
66
  cfgs.seed = seeds[bid]
56
67
  parser.save_configs(cfgs, os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-info"))
68
+
69
+ if self.save_txt:
70
+ txt_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.txt")
71
+ with open(txt_path, 'w') as f:
72
+ f.write(p)
73
+ num_img_exist += 1
hcpdiff/workflow/text.py CHANGED
@@ -38,7 +38,7 @@ class TextHookAction(BasicAction):
38
38
  return {'te_hook':te_hook, 'emb_hook':emb_hook, 'token_ex':token_ex}
39
39
 
40
40
  class TextEncodeAction(BasicAction):
41
- def __init__(self, prompt: Union[List, str], negative_prompt: Union[List, str], bs: int = None, key_map_in=None, key_map_out=None):
41
+ def __init__(self, prompt: List|str|None, negative_prompt: List|str|None, bs: int = None, key_map_in=None, key_map_out=None):
42
42
  super().__init__(key_map_in, key_map_out)
43
43
  if isinstance(prompt, str) and bs is not None:
44
44
  prompt = [prompt]*bs
@@ -73,6 +73,11 @@ class AttnMultTextEncodeAction(TextEncodeAction):
73
73
  prompt = prompt or self.prompt
74
74
  negative_prompt = negative_prompt or self.negative_prompt
75
75
 
76
+ if isinstance(negative_prompt, str) and isinstance(prompt, (list, tuple)):
77
+ negative_prompt = [negative_prompt]*len(prompt)
78
+ if isinstance(prompt, str) and isinstance(negative_prompt, (list, tuple)):
79
+ prompt = [prompt]*len(negative_prompt)
80
+
76
81
  if model_offload:
77
82
  to_cuda(TE)
78
83
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hcpdiff
3
- Version: 2.2
3
+ Version: 2.3
4
4
  Summary: A universal Diffusion toolbox
5
5
  Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
6
6
  Author: Ziyi Dong
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
17
  Requires-Python: >=3.8
18
18
  Description-Content-Type: text/markdown
19
19
  License-File: LICENSE
20
- Requires-Dist: rainbowneko
20
+ Requires-Dist: rainbowneko>=1.9
21
21
  Requires-Dist: diffusers
22
22
  Requires-Dist: matplotlib
23
23
  Requires-Dist: pyarrow
@@ -262,9 +262,13 @@ hcp_run --cfg cfgs/workflow/text2img_cli.py \
262
262
  seed=42
263
263
  ```
264
264
 
265
- ### Tutorials
265
+ ### 📚 Tutorials
266
266
 
267
- 🚧 In Development
267
+ + 🧠 [Model Training Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/train.html)
268
+ + 🔧 [LoRA Training Tutorial](https://hcpdiff.readthedocs.io/en/latest/tutorial/lora.html)
269
+ + 🎨 [Image Generation Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/workflow.html)
270
+ + ⚙️ [Configuration File Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/cfg.html)
271
+ + 🧩 [Model Format Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/model_format.html)
268
272
 
269
273
  ---
270
274
 
@@ -1,16 +1,16 @@
1
1
  hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
2
2
  hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
3
- hcpdiff/train_deepspeed.py,sha256=PwyNukWi0of6TXy_VRDgBQSMLCZBhipO5g3Lq0nCYNk,2988
4
- hcpdiff/trainer_ac.py,sha256=6KAzo54in7ZRHud_rHjJdwRRZ4uWtc0B4SxVCxgcrmM,2990
5
- hcpdiff/trainer_ac_single.py,sha256=0PIC5EScqcxp49EaeIWq4KS5K_09OZfKajqbFu-hUb8,1108
6
- hcpdiff/ckpt_manager/__init__.py,sha256=Mn_5KOC4xbf2GcN6OXg_XdbF5wO9zWeER_1ZO_prKAI,256
7
- hcpdiff/ckpt_manager/ckpt.py,sha256=Pa3uXQbCi2T99mpV5fYddQ-OGHcpk8r1ll-0lmP_WXk,965
8
- hcpdiff/ckpt_manager/loader.py,sha256=Ch1xsZmseq4nyPhpox9-nebN-dZB4k0rqBEHos-ZLso,3245
3
+ hcpdiff/trainer_ac.py,sha256=-owV-3_bvPxuQsZS2WaajBDh58HpftRtnx0GJkswqaY,2787
4
+ hcpdiff/trainer_ac_single.py,sha256=zyZVrutLUbIJYW1HzUnQ_RnmIcDhbC7M_CT833PJH5w,993
5
+ hcpdiff/trainer_deepspeed.py,sha256=7lGsiAstWuIlmhRMwWTcJCkoxzUaakVxBngKDnJdSJk,1947
6
+ hcpdiff/ckpt_manager/__init__.py,sha256=r_sgjZWCLtdJrRkqqU6aPdfubXSYfPh2Z_Vf_XpZXXs,240
7
+ hcpdiff/ckpt_manager/ckpt.py,sha256=2A093lT03M1ZsJIMWl376V165eh0TZwOgiGrz3LM73Q,1248
8
+ hcpdiff/ckpt_manager/loader.py,sha256=6iZDUj-Vfc5T9eGdWfFMQw4n1GqyLqaLBolgAtgqPq8,3640
9
9
  hcpdiff/ckpt_manager/format/__init__.py,sha256=a3cdKkOTDgdVbDQwSC4mlxOigjX2hBvRb5_X7E3TQWs,237
10
- hcpdiff/ckpt_manager/format/diffusers.py,sha256=T81WN95Nj1il9DfQp9iioVn0uqFEWOlmdIYs2beNOFU,3769
10
+ hcpdiff/ckpt_manager/format/diffusers.py,sha256=qhGbrKAaeLyjFzY-Lj4sL1THHFNrta41JGGMoXT-bCE,3761
11
11
  hcpdiff/ckpt_manager/format/emb.py,sha256=FrqfTfJ8H7f0Zw17NTWCP2AJtpsJI5oXR5IAd4NekhU,680
12
12
  hcpdiff/ckpt_manager/format/lora_webui.py,sha256=4y_T9RdmFTxWzsXd8guNjCiukmyILa5j4MPrhVIL4Qk,10017
13
- hcpdiff/ckpt_manager/format/sd_single.py,sha256=LpCAL_7nAVooCHTFznVVsNMku1G3C77NBORxxr8GDtQ,2328
13
+ hcpdiff/ckpt_manager/format/sd_single.py,sha256=4DZLAl1RNC_nPxuW-lmrBlIMFUhpSTa7HGHgu7Yx8qk,2322
14
14
  hcpdiff/data/__init__.py,sha256=ZFKtanOoMo3G3eKUJPhysnHXnr8BNARERkcMB6B897U,292
15
15
  hcpdiff/data/dataset.py,sha256=1k4GldW13eVyqK_9hrQniqr3_XYAapnWF7iXl_1GXGg,877
16
16
  hcpdiff/data/cache/__init__.py,sha256=ToCmokYH6DghlSwm7HJFirPRIWJ0LkgzqVOYlgoAkQw,25
@@ -25,44 +25,47 @@ hcpdiff/data/source/text.py,sha256=VgI5Ouq986Yy1jwD2fZ9iBlsRciPCeARZmOPEZIcaQY,1
25
25
  hcpdiff/data/source/text2img.py,sha256=acYdolQhZUEpkd7tUAdNkCTVnPc1SMJOVTmGqFt9ZpE,1813
26
26
  hcpdiff/data/source/text2img_cond.py,sha256=yj1KpARA2rkjENutnnzC4uDkcU2Rye21FL2VdC25Hac,585
27
27
  hcpdiff/diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
- hcpdiff/diffusion/noise/__init__.py,sha256=seBpOtd0YsU53PqMn7Nyl_RtwoC-ONEIOX7v2XLGpZQ,93
28
+ hcpdiff/diffusion/noise/__init__.py,sha256=D83EZ6bnc6Ucu4AZwE6rpmXCtwYfHHumeVq97brbnIE,47
29
29
  hcpdiff/diffusion/noise/pyramid_noise.py,sha256=KbpyMT1BHNIaAa7g5eECDkTttOMoMWVFmbP-ekBsuEY,1693
30
- hcpdiff/diffusion/noise/zero_terminal.py,sha256=EfVOaqrTCfw11AolDBl0LIOey3uQT1bDw2XKr2Bm434,1532
31
- hcpdiff/diffusion/sampler/__init__.py,sha256=pSHsKpLjscY5yLbdzHeBUeK9nFDuVeMIIeA_k6FQFdY,158
32
- hcpdiff/diffusion/sampler/base.py,sha256=2AuPVT2ZSXYt2etZmHMyNKuGlT5zn6KIkoMz4m5PGcs,2577
33
- hcpdiff/diffusion/sampler/ddpm.py,sha256=raqSuKsEPN1AEqRVCuBdMAOnKDoeJTRO17wtLBNJCf4,523
34
- hcpdiff/diffusion/sampler/diffusers.py,sha256=XIu-oIlT4LnAYI2-yyIoNeIMeSQe5YpeEvn-FkGVFnE,2684
35
- hcpdiff/diffusion/sampler/edm.py,sha256=5W4pv8hxQsPpJGiFBgElZxR3C9S8kWAhzGKejEVwq3I,753
36
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py,sha256=kmIoWgsWqij6b7KYon3UOCSC07sRJCo-DPR6qkJwUd0,184
37
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py,sha256=9-JI-jwf7xZoQUtrU0qfbjkhNZT8a_tmapLtwVbFUx0,381
38
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=2PMIpg2K6CVoxew1y1pIqvCHbdggC-m3amrOYk15OdQ,8107
39
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=fOPB3lgnS9uVo4oW26Fur_nc8X_wQ6mmUcbkKhnoQjs,1900
30
+ hcpdiff/diffusion/sampler/VP.py,sha256=r0Q_RROEIeNNw93XrOD5htW78rfuoSxy1WBQEoQL83s,958
31
+ hcpdiff/diffusion/sampler/__init__.py,sha256=Lrwg1us8qo943T7mdIXFDRXfKvnLhrzwmi6DrIKIiUA,135
32
+ hcpdiff/diffusion/sampler/base.py,sha256=UbE_AmtvLg-Hr2bkYz8PvNWB63tvtacUvCIDm_W6opA,5484
33
+ hcpdiff/diffusion/sampler/diffusers.py,sha256=wIMs8n3kdci1On0FUCV0si324ZE9zeRw_CxaHP8rdcs,2586
34
+ hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py,sha256=eiSmMBkXI_LfxnNrXj5XptcF0dGcPas--vWvqhFGlv8,273
35
+ hcpdiff/diffusion/sampler/sigma_scheduler/base.py,sha256=UT4tbjFf80KYfU08y0hJf8h_Cl80a5MUhK5FsKLsqbY,2521
36
+ hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=SA8lXT6lucJot_rpJ84Wz-_uc5dXfb2QPoQgJHSKOj4,12999
37
+ hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=m1YlIyn61zfDjxLMcHvWs0nzULbHgXeB7WGKmTiaGSU,4127
38
+ hcpdiff/diffusion/sampler/sigma_scheduler/flow.py,sha256=FtWpesUtSmFuiIGkrrVhYJweB7INZiw0atC64tc0Nk4,2020
39
+ hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py,sha256=CCqQLkGo4omkxzFovYdZQzdZVwIxK3PiOitZFww8MHs,859
40
40
  hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
41
41
  hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
42
42
  hcpdiff/easy/cfg/__init__.py,sha256=SxHMWG6T2CXhX3dP0xizSMd9vFWPaZQDc4Gj4CF__yQ,253
43
- hcpdiff/easy/cfg/sd15_train.py,sha256=KtplqN-OhzdZjsX2s60J3XR6o7tRJ-QDx7Eqza_eDkM,6704
44
- hcpdiff/easy/cfg/sdxl_train.py,sha256=ZKfJ19IvR2dZqDNXULmhZEmqjE7qV4QYxSTvEhI7efQ,4269
43
+ hcpdiff/easy/cfg/sd15_train.py,sha256=NtgsQLg1sd5JFmHU4nqMPOrvP7zmwo2x0MCspjVNQEY,7000
44
+ hcpdiff/easy/cfg/sdxl_train.py,sha256=rVLLKVMKB_PHuum3dKQcBqL0uR8QhzmdRllM-pYnbK4,4534
45
45
  hcpdiff/easy/cfg/t2i.py,sha256=SnjFjZAKd9orjJr3RW5_N2_EIlW2Ree7JMvdNUAR9gc,9507
46
46
  hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
47
47
  hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
48
48
  hcpdiff/easy/model/loader.py,sha256=Tdx-lhQEYf2NYjVM1A5B8x6ZZpJKcXUkFIPIbr7h7XM,3456
49
- hcpdiff/evaluate/__init__.py,sha256=CtNzi8xdUWZBDBcP5TZTDMcRyykaOJhBIxTJgVuMabo,35
50
- hcpdiff/evaluate/previewer.py,sha256=QiYYiEJBKP06uL3wKLhnpIGUZuAkr9BuHxTE29obpXI,2432
49
+ hcpdiff/evaluate/__init__.py,sha256=qWxV0D8Ho5uBj2YbaC_QFDnT49PSKPfh44m4ivkNbMM,108
50
+ hcpdiff/evaluate/evaluator.py,sha256=9BZQBeC-N7p-ICx6Giw9v-2Tb9volMTDmeDfhj0nXJ0,2940
51
+ hcpdiff/evaluate/previewer.py,sha256=-vE0YXVfos70CQMo9ZInw7xu3d88DlTfVLs4BzzkxfM,3140
52
+ hcpdiff/evaluate/metrics/__init__.py,sha256=vE0nSvBtDBu9SomANvWcm2UHX56PhCYwhgrcmm_mKyo,39
53
+ hcpdiff/evaluate/metrics/clip_score.py,sha256=rQgweu5QcqW3fPI3EXcNbrH2QCcSAekE3lpYk45P2M4,900
51
54
  hcpdiff/loss/__init__.py,sha256=2dwPczSiv3rB5fzOeYbl5ZHpMU-qXOQlXeOiXdxcxwM,173
52
- hcpdiff/loss/base.py,sha256=3bvgMbwyPOEA9iSkv0hRHw4VnKjkUCZAENNnDMFilYM,1780
55
+ hcpdiff/loss/base.py,sha256=Vvpm-KZGH4n-gYIlnVAtPl1B799c7v0dJXJ5BBh3yO0,1112
53
56
  hcpdiff/loss/gw.py,sha256=0yi1kozuII3xZA6FnjOhINtvScWt1MyBZLBtMKmgojM,1224
54
57
  hcpdiff/loss/ssim.py,sha256=YofadvBkc6sklxBUx1p3ADw5OHOZPK3kaHz8FH5a6m4,1281
55
58
  hcpdiff/loss/vlb.py,sha256=s78iBnXUiDWfGf7mYmhUnHqxqea5gSByKOoqBrX6bzU,3222
56
- hcpdiff/loss/weighting.py,sha256=UR2PyZ1JTNOydXMw4e1Fh52XmtwKaviPvddcmVKCTlI,2242
59
+ hcpdiff/loss/weighting.py,sha256=9qzMnvCb6b5qx0p08GDSlkxmYEqQcNt79XdRBvfHmiI,2914
57
60
  hcpdiff/models/__init__.py,sha256=eQS7DPiGLiE1MFRkZj_17IY3IsfDUVcYpcOmhHb5B9o,472
58
61
  hcpdiff/models/cfg_context.py,sha256=e2B3K1KwJhzbD6xdJUOyNtl_XgQ0296XI3FHw3gvZF4,1502
59
62
  hcpdiff/models/container.py,sha256=z3p5TmQhxdzXSIfofz55_bmEhSsgUJsy1o9EcDs8Oeo,696
60
63
  hcpdiff/models/controlnet.py,sha256=VIkUzJCVpCqqQOtRSLQPfbcDy9CsXutxLeZB6PdZfA0,7809
61
64
  hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6758
62
- hcpdiff/models/lora_base_patch.py,sha256=WW3CULnROTxKXyynJiqirhHYCKN5JtxLhVpT5b7AUQg,6532
65
+ hcpdiff/models/lora_base_patch.py,sha256=Tdb_b3TN_K-04nlUvcfBh6flPcbL9M4iP7jOVyb1jXQ,7271
63
66
  hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
64
67
  hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
65
- hcpdiff/models/text_emb_ex.py,sha256=a5QImxzvj0zWR12qXOPP9kmpESl8J9VLabA0W9D_i_c,7867
68
+ hcpdiff/models/text_emb_ex.py,sha256=O0XZqid01OrB0dHY7hCiBvdU2026SvZ38yfQaF2TWrs,8018
66
69
  hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
67
70
  hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
68
71
  hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
@@ -72,7 +75,7 @@ hcpdiff/models/compose/compose_tokenizer.py,sha256=g3l0pOFv6p7Iigxm6Pqt_iTUXBlO1
72
75
  hcpdiff/models/compose/sdxl_composer.py,sha256=NtMGaFGZTfKsPJSVi2yT-UM6K1WKWtk99XxVmTcKlk8,2164
73
76
  hcpdiff/models/wrapper/__init__.py,sha256=HbGQmFnfccr-dtvZKjEv-pmR4cCnF4fwGLKS3tuG_OY,135
74
77
  hcpdiff/models/wrapper/pixart.py,sha256=nRUvHSHn4TYg_smC0xpeW-GtUgXss-MuaVPTHpMozDE,1147
75
- hcpdiff/models/wrapper/sd.py,sha256=D7VDI4OmbLTk9mzYta-C4LJjWfZmuBiDub4t8v1-M9o,11711
78
+ hcpdiff/models/wrapper/sd.py,sha256=EywmVU2QzR74M_4eH_uXVW8HJNauyjwcZPU7rRAQ7eI,11666
76
79
  hcpdiff/models/wrapper/utils.py,sha256=NyebMoAPnrgcTHbiIocSD-eGdGdD-V1G_TQuWsRWufw,665
77
80
  hcpdiff/parser/__init__.py,sha256=-2dDZ2Ii4zoGQqDTme94q4PpJbBiV6HS5BsDASz4Xbo,33
78
81
  hcpdiff/parser/embpt.py,sha256=LgwZ0f0tLn3DrTo5ZpSCsZcA5330UpiW_sK96yEPmOM,1307
@@ -89,27 +92,28 @@ hcpdiff/tools/init_proj.py,sha256=XrXxxhIaItywG7HsrloJo-x8w9suZiY35daelzZvjrg,19
89
92
  hcpdiff/tools/lora_convert.py,sha256=So14WvSVIm6rU4m1XCajFXDnhq7abpZS95SLbaoyBFU,10058
90
93
  hcpdiff/tools/save_model.py,sha256=gbfYi_EfEBZEUcDjle6MDHA19sQWY0zA8_y_LMzHQ7M,428
91
94
  hcpdiff/tools/sd2diffusers.py,sha256=vB6OnBLw60sJkdpVZcYEPtKAZW1h8ErbSGSRq0uAiIk,16855
92
- hcpdiff/utils/__init__.py,sha256=VOLhdNq2mRyqmWxrssIWSZtR_PQ8rFwo2u0uq6GbLHA,45
95
+ hcpdiff/utils/__init__.py,sha256=28K9Ui0uur-vHuUdlSyIBYijgu2b7rGOPXN2ogJu1z8,82
93
96
  hcpdiff/utils/colo_utils.py,sha256=JyLUvVnISa48CnryNLrgVxMo-jxu2UhBq70eYPrkjuI,837
94
97
  hcpdiff/utils/inpaint_pipe.py,sha256=CRy1MUlPmHifCAbZnKOP0qbLp2grn7ZbVeaB2qIA4ig,42862
95
98
  hcpdiff/utils/net_utils.py,sha256=gdwLYDNKV2t3SP0jBIO3d0HtY6E7jRaf_rmPT8gKZZE,9762
96
99
  hcpdiff/utils/pipe_hook.py,sha256=-UDX3FtZGl-bxSk13gdbPXc1OvtbCcpk_fvKxLQo3Ag,31987
100
+ hcpdiff/utils/torch_utils.py,sha256=gBZCcDKZc0NGDQx6QeHuQePoZ82kQRhaL7oEdZIYGvU,573
97
101
  hcpdiff/utils/utils.py,sha256=hZnZP1IETgVpScxES0yIuRfc34TnzvAqmgOTK_56ssw,4976
98
- hcpdiff/workflow/__init__.py,sha256=t7Zyc0XFORdNvcwHp9AsCtEkhJ3l7Hm41ugngIL0Sag,867
99
- hcpdiff/workflow/diffusion.py,sha256=yzhqKA3019OPu1RKggrLoytMgm919qf6j9S85PYOwjQ,8644
102
+ hcpdiff/workflow/__init__.py,sha256=i5s7QXo6wK9607KL0KTW4suE1c-HGJ5_EgnCdVLl3WM,885
103
+ hcpdiff/workflow/diffusion.py,sha256=hKefBrVP6-025MhdrKOQMUhHxLaGqjpUKhR6WahYwh0,9549
100
104
  hcpdiff/workflow/fast.py,sha256=kZt7bKrvpFInSn7GzbkTkpoCSM0Z6IbDjgaDvcbFYf8,1024
101
105
  hcpdiff/workflow/flow.py,sha256=FFbFFOAXT4c31L5bHBEB_qeVGuBQDLYhq8kTD1chGNo,2548
102
- hcpdiff/workflow/io.py,sha256=aTrMR3s44apVJpnSyvZIabW2Op0tslk_Z9JFJl5svm0,2635
106
+ hcpdiff/workflow/io.py,sha256=4oiE_PS3sOVYT8M6PDwvT5h9XzoKDMQR0n_4-Ktttys,3284
103
107
  hcpdiff/workflow/model.py,sha256=1gj5yOTefYTnGXVR6JPAfxIwuB69YwN6E-BontRcuyQ,2913
104
- hcpdiff/workflow/text.py,sha256=Z__SJHZyuaKyzkYJ6rbiAzOGRiYcCjwCGeqfpP1Jo7o,4336
108
+ hcpdiff/workflow/text.py,sha256=XQvN4zzK7VaGxy4FDgSDeWh2jjk7UZU24moeRKAWXRE,4608
105
109
  hcpdiff/workflow/utils.py,sha256=xojaMG4lHsymslc8df5uiVXmmBVWpn_Phqka8qzJEWw,2226
106
110
  hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
107
111
  hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
108
112
  hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
109
113
  hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
110
- hcpdiff-2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
111
- hcpdiff-2.2.dist-info/METADATA,sha256=u52mZtA0hI2P_fObmJZRUkZZfnKFYg5c24f4p0trH0o,9833
112
- hcpdiff-2.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
113
- hcpdiff-2.2.dist-info/entry_points.txt,sha256=86wPOMzsfWWflTJ-sQPLc7WG5Vtu0kGYBH9C_vR3ur8,207
114
- hcpdiff-2.2.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
115
- hcpdiff-2.2.dist-info/RECORD,,
114
+ hcpdiff-2.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
115
+ hcpdiff-2.3.dist-info/METADATA,sha256=1aqZH8IwB7WjDBot_fADwyfLDVNiRZuZdub2-zezFck,10321
116
+ hcpdiff-2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
117
+ hcpdiff-2.3.dist-info/entry_points.txt,sha256=_4VRsEsEWOhHfzBDu9bx8Wh_S8Wi4ZTHpI0n6rU0J-I,258
118
+ hcpdiff-2.3.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
119
+ hcpdiff-2.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -2,4 +2,5 @@
2
2
  hcp_run = rainbowneko.infer.infer_workflow:run_workflow
3
3
  hcp_train = hcpdiff.trainer_ac:hcp_train
4
4
  hcp_train_1gpu = hcpdiff.trainer_ac_single:hcp_train
5
+ hcp_train_ds = hcpdiff.trainer_deepspeed:hcp_train
5
6
  hcpinit = hcpdiff.tools.init_proj:main
@@ -1,39 +0,0 @@
1
- import torch
2
- from ..sampler.sigma_scheduler import DDPMDiscreteSigmaScheduler
3
-
4
class ZeroTerminalSampler:
    """Patch a DDPM-based sampler so its noise schedule has (near) zero terminal SNR.

    Implements the schedule rescaling of Algorithm 1 in
    https://arxiv.org/pdf/2305.08891.pdf, applied to the cumulative alphas
    held by the sampler's sigma scheduler.
    """

    @classmethod
    def patch(cls, base_sampler):
        """Rescale the schedule of ``base_sampler.sigma_scheduler`` in place.

        Both ``alphas_cumprod`` and ``sigmas`` on the scheduler are replaced
        with values derived from the rescaled schedule.

        Args:
            base_sampler: object exposing a ``sigma_scheduler`` attribute that
                is a ``DDPMDiscreteSigmaScheduler`` with an ``alphas_cumprod``
                tensor.

        Raises:
            AssertionError: if the sigma scheduler is not a
                ``DDPMDiscreteSigmaScheduler``.
        """
        scheduler = base_sampler.sigma_scheduler
        assert isinstance(scheduler, DDPMDiscreteSigmaScheduler), "ZeroTerminalScheduler only works with DDPM SigmaScheduler"

        rescaled = cls.rescale_zero_terminal_snr(scheduler.alphas_cumprod)
        scheduler.alphas_cumprod = rescaled
        # sigmas must be derived from the *rescaled* schedule; the input tensor
        # is left untouched by rescale_zero_terminal_snr, so reusing it here
        # would leave sigmas and alphas_cumprod inconsistent.
        scheduler.sigmas = ((1 - rescaled) / rescaled).sqrt()

    @staticmethod
    def rescale_zero_terminal_snr(alphas_cumprod, thr=1e-4):
        """Rescale cumulative alphas to have (near) zero terminal SNR.

        Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

        Args:
            alphas_cumprod (`torch.FloatTensor`): cumulative product of alphas,
                ordered from the first to the last timestep. Not modified.
            thr (float): value assigned to ``sqrt(alpha_bar)`` at the last
                timestep instead of an exact zero, so that the derived sigma
                ``sqrt((1 - a) / a)`` stays finite.

        Returns:
            `torch.FloatTensor`: rescaled cumulative alphas (``alpha_bar``).
            The first entry keeps its original value and the last entry is
            ``thr ** 2``.
        """
        # .sqrt() allocates a new tensor, so the in-place ops below never
        # touch the caller's alphas_cumprod.
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        # Remember the end points before shifting.
        first = alphas_bar_sqrt[0].clone()
        last = alphas_bar_sqrt[-1].clone()

        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= last

        # Scale so the first timestep is back to its old value.
        alphas_bar_sqrt *= first / (first - last)

        # Floor the terminal value to avoid a division by zero in sigma.
        alphas_bar_sqrt[-1] = thr

        # Undo the square root to get back to alpha_bar.
        return alphas_bar_sqrt ** 2