hcpdiff 2.2-py3-none-any.whl → 2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/ckpt.py +21 -17
- hcpdiff/ckpt_manager/format/diffusers.py +4 -4
- hcpdiff/ckpt_manager/format/sd_single.py +3 -3
- hcpdiff/ckpt_manager/loader.py +11 -4
- hcpdiff/diffusion/noise/__init__.py +0 -1
- hcpdiff/diffusion/sampler/VP.py +27 -0
- hcpdiff/diffusion/sampler/__init__.py +2 -3
- hcpdiff/diffusion/sampler/base.py +106 -44
- hcpdiff/diffusion/sampler/diffusers.py +11 -17
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
- hcpdiff/easy/cfg/sd15_train.py +35 -24
- hcpdiff/easy/cfg/sdxl_train.py +34 -25
- hcpdiff/evaluate/__init__.py +3 -1
- hcpdiff/evaluate/evaluator.py +76 -0
- hcpdiff/evaluate/metrics/__init__.py +1 -0
- hcpdiff/evaluate/metrics/clip_score.py +23 -0
- hcpdiff/evaluate/previewer.py +29 -12
- hcpdiff/loss/base.py +9 -26
- hcpdiff/loss/weighting.py +36 -18
- hcpdiff/models/lora_base_patch.py +26 -0
- hcpdiff/models/text_emb_ex.py +4 -0
- hcpdiff/models/wrapper/sd.py +17 -19
- hcpdiff/trainer_ac.py +7 -12
- hcpdiff/trainer_ac_single.py +1 -6
- hcpdiff/trainer_deepspeed.py +47 -0
- hcpdiff/utils/__init__.py +2 -1
- hcpdiff/utils/torch_utils.py +25 -0
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +27 -7
- hcpdiff/workflow/io.py +20 -3
- hcpdiff/workflow/text.py +6 -1
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/METADATA +8 -4
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/RECORD +43 -39
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/WHEEL +1 -1
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/entry_points.txt +1 -0
- hcpdiff/diffusion/noise/zero_terminal.py +0 -39
- hcpdiff/diffusion/sampler/ddpm.py +0 -20
- hcpdiff/diffusion/sampler/edm.py +0 -22
- hcpdiff/train_deepspeed.py +0 -69
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/top_level.txt +0 -0
hcpdiff/models/wrapper/sd.py
CHANGED
@@ -17,7 +17,7 @@ from ..cfg_context import CFGContext
 
 class SD15Wrapper(BaseWrapper):
     def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
-
+                 TE_hook_cfg:TEHookCFG=SD15_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
         super().__init__()
         self.key_mapper_in = self.build_mapper(key_map_in, None, (
             'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
@@ -31,8 +31,6 @@ class SD15Wrapper(BaseWrapper):
         self.tokenizer = tokenizer
         self.min_attnmask = min_attnmask
 
-        self.pred_type = pred_type
-
         self.TE_hook_cfg = TEHookCFG.create(TE_hook_cfg)
         self.cfg_context = cfg_context
         self.tokenizer.N_repeats = self.TE_hook_cfg.tokenizer_repeats
@@ -93,8 +91,9 @@ class SD15Wrapper(BaseWrapper):
                       plugin_input={}, **kwargs):
         # input prepare
         x_0 = self.get_latents(image)
-        x_t, noise,
-        x_t_in = x_t*self.noise_sampler.c_in(
+        x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype).view(-1,1,1,1)
+        t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
 
         if neg_prompt_ids:
             prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
@@ -104,15 +103,14 @@ class SD15Wrapper(BaseWrapper):
             position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
 
         # model forward
-        x_t_in,
-        encoder_hidden_states = self.forward_TE(prompt_ids,
+        x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
+        encoder_hidden_states = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
                                                 plugin_input=plugin_input, **kwargs)
-        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states,
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, attn_mask=attn_mask, position_ids=position_ids,
                                            plugin_input=plugin_input, **kwargs)
         model_pred = self.cfg_context.post(model_pred)
 
-        return dict(model_pred=model_pred, noise=noise,
-                    noise_sampler=self.noise_sampler)
+        return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
 
     def forward(self, ds_name=None, **kwargs):
         model_args, model_kwargs = self.get_map_data(self.key_mapper_in, kwargs, ds_name)
@@ -156,8 +154,8 @@ class SD15Wrapper(BaseWrapper):
 
 class SDXLWrapper(SD15Wrapper):
     def __init__(self, denoiser: UNet2DConditionModel, TE, vae: AutoencoderKL, noise_sampler: BaseSampler, tokenizer, min_attnmask=0,
-
-        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask,
+                 TE_hook_cfg:TEHookCFG=SDXL_TEHookCFG, cfg_context=CFGContext(), key_map_in=None, key_map_out=None):
+        super().__init__(denoiser, TE, vae, noise_sampler, tokenizer, min_attnmask, TE_hook_cfg, cfg_context, key_map_in, key_map_out)
         self.key_mapper_in = self.build_mapper(key_map_in, None, (
             'prompt -> prompt_ids', 'image -> image', 'attn_mask -> attn_mask', 'position_ids -> position_ids', 'neg_prompt -> neg_prompt_ids',
             'neg_attn_mask -> neg_attn_mask', 'neg_position_ids -> neg_position_ids', 'plugin_input -> plugin_input', 'coord -> crop_info'))
@@ -195,8 +193,9 @@ class SDXLWrapper(SD15Wrapper):
                       crop_info=None, plugin_input={}):
         # input prepare
         x_0 = self.get_latents(image)
-        x_t, noise,
-        x_t_in = x_t*self.noise_sampler.c_in(
+        x_t, noise, timesteps = self.noise_sampler.add_noise_rand_t(x_0)
+        x_t_in = x_t*self.noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype).view(-1,1,1,1)
+        t_in = self.noise_sampler.sigma_scheduler.c_noise(timesteps)
 
         if neg_prompt_ids:
             prompt_ids = torch.cat([neg_prompt_ids, prompt_ids], dim=0)
@@ -206,13 +205,12 @@ class SDXLWrapper(SD15Wrapper):
             position_ids = torch.cat([neg_position_ids, position_ids], dim=0)
 
         # model forward
-        x_t_in,
-        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids,
+        x_t_in, t_in = self.cfg_context.pre(x_t_in, t_in)
+        encoder_hidden_states, pooled_output = self.forward_TE(prompt_ids, t_in, attn_mask=attn_mask, position_ids=position_ids,
                                                                plugin_input=plugin_input)
         added_cond_kwargs = {"text_embeds":pooled_output[-1], "time_ids":crop_info}
-        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states,
+        model_pred = self.forward_denoiser(x_t_in, prompt_ids, encoder_hidden_states, t_in, added_cond_kwargs=added_cond_kwargs,
                                            attn_mask=attn_mask, position_ids=position_ids, plugin_input=plugin_input)
         model_pred = self.cfg_context.post(model_pred)
 
-        return dict(model_pred=model_pred, noise=noise,
-                    noise_sampler=self.noise_sampler)
+        return dict(model_pred=model_pred, noise=noise, timesteps=timesteps, x_0=x_0, x_t=x_t, noise_sampler=self.noise_sampler)
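The training forward now pulls every preconditioner from the sampler's sigma scheduler instead of the sampler itself, and returns the full `(x_0, x_t, timesteps)` context, presumably so the reworked losses (`loss/base.py`, `loss/weighting.py`) can weight per timestep. A minimal sketch of the new call pattern, assuming `add_noise_rand_t` returns `(x_t, noise, timesteps)` and `c_in`/`c_noise` are the usual EDM-style input/noise-conditioning scalers (names taken from the diff; the helper function itself is ours, not part of the package):

```python
import torch

def make_denoiser_inputs(noise_sampler, x_0: torch.Tensor):
    # Noise the clean latents at a randomly drawn timestep per sample.
    x_t, noise, timesteps = noise_sampler.add_noise_rand_t(x_0)

    # Scale the model input and timestep with the scheduler's preconditioners.
    c_in = noise_sampler.sigma_scheduler.c_in(timesteps).to(dtype=x_t.dtype)
    x_t_in = x_t * c_in.view(-1, 1, 1, 1)
    t_in = noise_sampler.sigma_scheduler.c_noise(timesteps)
    return x_t_in, t_in, noise, timesteps
```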
hcpdiff/trainer_ac.py
CHANGED
@@ -4,8 +4,8 @@ import warnings
 import torch
 from rainbowneko.parser import load_config_with_cli
 from rainbowneko.ckpt_manager import NekoSaver
-from rainbowneko.train import Trainer
-from rainbowneko.utils import xformers_available, is_dict
+from rainbowneko.train.trainer import Trainer
+from rainbowneko.utils import xformers_available, is_dict, weight_dtype_map
 from hcpdiff.ckpt_manager import EmbFormat
 
 class HCPTrainer(Trainer):
@@ -17,7 +17,7 @@ class HCPTrainer(Trainer):
             warnings.warn("xformers is not available. Make sure it is installed correctly")
 
         if self.model_wrapper.vae is not None:
-            self.vae_dtype =
+            self.vae_dtype = weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
             self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
 
         if self.cfgs.model.gradient_checkpointing:
@@ -42,19 +42,14 @@ class HCPTrainer(Trainer):
     def pt_trainable(self):
         return self.cfgs.emb_pt is not None
 
-    def get_loss(self, ds_name, model_pred, inputs):
-        loss = super().get_loss(ds_name, model_pred, inputs)
-        # make DDP happy
-        if len(self.train_pts)>0:
-            loss = loss+0*sum([emb.mean() for emb in self.train_pts.values()])
-        return loss
-
     def save_model(self, from_raw=False):
         NekoSaver.save_all(
-            self.model_raw,
-            plugin_groups={**self.all_plugin, 'embs': self.train_pts},
             cfg=self.ckpt_saver,
+            model=self.model_raw,
+            plugin_groups=self.all_plugin,
+            embs=self.train_pts,
             model_ema=getattr(self, "ema_model", None),
+            optimizer=self.optimizer,
            name_template=f'{{}}-{self.real_step}',
         )
 
hcpdiff/trainer_ac_single.py
CHANGED
@@ -1,12 +1,7 @@
 import argparse
-import sys
-from functools import partial
-
-import torch
-from accelerate import Accelerator
-from loguru import logger
 
 from rainbowneko.train.trainer import TrainerSingleCard
+
 from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
 
 class HCPTrainerSingleCard(TrainerSingleCard, HCPTrainer):
hcpdiff/trainer_deepspeed.py
ADDED
@@ -0,0 +1,47 @@
+import argparse
+import warnings
+
+import torch
+from rainbowneko.ckpt_manager import NekoPluginSaver
+from rainbowneko.train.trainer import TrainerDeepspeed
+from rainbowneko.utils import xformers_available
+
+from hcpdiff.trainer_ac import HCPTrainer, load_config_with_cli
+
+class HCPTrainerDeepspeed(TrainerDeepspeed, HCPTrainer):
+    def config_model(self):
+        if self.cfgs.model.enable_xformers:
+            if xformers_available:
+                self.model_wrapper.enable_xformers()
+            else:
+                warnings.warn("xformers is not available. Make sure it is installed correctly")
+
+        if self.model_wrapper.vae is not None:
+            self.vae_dtype = self.weight_dtype_map.get(self.cfgs.model.get('vae_dtype', None), torch.float32)
+            self.model_wrapper.set_dtype(self.weight_dtype, self.vae_dtype)
+
+        if self.cfgs.model.gradient_checkpointing:
+            self.model_wrapper.enable_gradient_checkpointing()
+
+        if self.is_local_main_process:
+            for saver in self.ckpt_saver.values():
+                if isinstance(saver, NekoPluginSaver):
+                    saver.plugin_from_raw = True
+
+def hcp_train():
+    import subprocess
+    parser = argparse.ArgumentParser(description='HCP-Diffusion Launcher')
+    parser.add_argument('--launch_cfg', type=str, default='cfgs/launcher/deepspeed.yaml')
+    args, train_args = parser.parse_known_args()
+
+    subprocess.run(["accelerate", "launch", '--config_file', args.launch_cfg, "-m",
+                    "hcpdiff.trainer_deepspeed"]+train_args, check=True)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='HCP Diffusion Trainer for DeepSpeed')
+    parser.add_argument("--cfg", type=str, default=None, required=True)
+    args, cfg_args = parser.parse_known_args()
+
+    parser, conf = load_config_with_cli(args.cfg, args_list=cfg_args)  # skip --cfg
+    trainer = HCPTrainerDeepspeed(parser, conf)
+    trainer.train()
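This file replaces the deleted hcpdiff/train_deepspeed.py (+0 -69 above), rebasing DeepSpeed training on the trainer classes instead of a standalone script. `entry_points.txt` gains one line (see the file list), presumably exposing `hcp_train()` as a console script; the script name is not visible in this diff. Judging from the launcher body, the equivalent manual invocation would be `accelerate launch --config_file cfgs/launcher/deepspeed.yaml -m hcpdiff.trainer_deepspeed --cfg <your_config>` plus any config overrides.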
hcpdiff/utils/__init__.py
CHANGED
hcpdiff/utils/torch_utils.py
ADDED
@@ -0,0 +1,25 @@
+import torch
+
+def invert_func(func, y, x_min=0.0, x_max=1.0, tol=1e-5, max_iter=100):
+    """
+    y: [B]
+    :return: x [B]
+    """
+    y = y.to(dtype=torch.float32)
+    left = torch.full_like(y, x_min)
+    right = torch.full_like(y, x_max)
+
+    for _ in range(max_iter):
+        mid = (left+right)/2
+        val = func(mid)
+
+        too_large = val>y
+        too_small = ~too_large
+
+        left = torch.where(too_small, mid, left)
+        right = torch.where(too_large, mid, right)
+
+        if torch.all(torch.abs(val-y)<tol):
+            break
+
+    return (left+right)/2
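Note: the hunk above belongs to the new hcpdiff/utils/torch_utils.py (+25 per the file list); the small utils/__init__.py change (+2 -1) likely re-exports it, though that hunk is not shown in this extract. `invert_func` is a batched bisection solver: given `func` monotonically increasing on `[x_min, x_max]`, it finds `x` with `func(x) ≈ y` elementwise, which is useful for inverting sigma/SNR schedules with no closed-form inverse. A quick sketch (the quadratic example is ours):

```python
import torch
from hcpdiff.utils.torch_utils import invert_func

# Invert y = x**2 on [0, 1]; the exact answer is sqrt(y).
y = torch.tensor([0.25, 0.49, 0.81])
x = invert_func(lambda v: v*v, y)
print(x)  # ≈ tensor([0.5000, 0.7000, 0.9000])
```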
hcpdiff/workflow/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 from .diffusion import InputFeederAction, MakeLatentAction, DenoiseAction, SampleAction, DiffusionStepAction, \
-    X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter
+    X0PredAction, SeedAction, MakeTimestepsAction, PrepareDiffusionAction, time_iter, DiffusionActions
 from .text import TextEncodeAction, TextHookAction, AttnMultTextEncodeAction
 from .vae import EncodeAction, DecodeAction
 from .io import BuildModelsAction, SaveImageAction, LoadImageAction
hcpdiff/workflow/diffusion.py
CHANGED
@@ -6,7 +6,7 @@ import torch
 from hcpdiff.diffusion.sampler import BaseSampler, DiffusersSampler
 from hcpdiff.utils import prepare_seed
 from hcpdiff.utils.net_utils import get_dtype, to_cuda
-from rainbowneko.infer import BasicAction
+from rainbowneko.infer import BasicAction, Actions
 from torch.cuda.amp import autocast
 
 try:
@@ -32,8 +32,9 @@ class SeedAction(BasicAction):
         self.seed = seed
         self.bs = bs
 
-    def forward(self, device, seed=None, **states):
-
+    def forward(self, device, seed=None, bs=None, **states):
+        if bs is None:
+            bs = states['prompt_embeds'].shape[0]//2 if 'prompt_embeds' in states else self.bs
         seed = seed or self.seed
         if seed is None:
             seeds = [None]*bs
@@ -155,15 +156,16 @@ class DenoiseAction(BasicAction):
 
         with autocast(enabled=amp is not None, dtype=get_dtype(amp)):
             latent_model_input = torch.cat([latents]*2) if self.guidance_scale>1 else latents
-            latent_model_input = noise_sampler.c_in(t)*latent_model_input
+            latent_model_input = noise_sampler.sigma_scheduler.c_in(t)*latent_model_input
+            t_in = noise_sampler.sigma_scheduler.c_noise(t)
 
             if text_embeds is None:
-                noise_pred = denoiser(latent_model_input,
+                noise_pred = denoiser(latent_model_input, t_in, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
                                       cross_attention_kwargs=cross_attention_kwargs, ).sample
             else:
                 added_cond_kwargs = {"text_embeds":text_embeds, "time_ids":crop_info}
                 # predict the noise residual
-                noise_pred = denoiser(latent_model_input,
+                noise_pred = denoiser(latent_model_input, t_in, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
                                       cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
 
             # perform guidance
@@ -189,10 +191,28 @@ class DiffusionStepAction(BasicAction):
         states = self.act_noise_pred(denoiser=denoiser, noise_sampler=noise_sampler, **states)
         states = self.act_sample(**states)
         return states
+
+class DiffusionActions(Actions):
+    def __init__(self, actions: List[BasicAction], clean_latent=True, seed_inc=True, key_map_in=None, key_map_out=None):
+        super().__init__(actions, key_map_in=key_map_in, key_map_out=key_map_out)
+        self.clean_latent = clean_latent
+        self.seed_inc = seed_inc
+
+    def forward(self, **states):
+        states = super().forward(**states)
+        if self.seed_inc and 'seed' in states:
+            bs = states['latents'].shape[0]
+            states['seed'] = states['seed'] + bs
+        if self.clean_latent:
+            states.pop('noise_pred', None)
+            states.pop('latents', None)
+            states.pop('prompt', None)
+            states.pop('negative_prompt', None)
+        return states
 
 class X0PredAction(BasicAction):
     def forward(self, latents, noise_sampler: BaseSampler, t, noise_pred, **states):
-        latents_x0 = noise_sampler.
+        latents_x0 = noise_sampler.pred_for_target(noise_pred, latents, t, target_type='x0')
         return {'latents_x0':latents_x0}
 
 def time_iter(timesteps, **states):
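`DiffusionActions` is a thin `Actions` container for repeated generation runs: after each pass it advances the seed by the batch size (so reruns don't repeat images) and drops bulky per-run state. A hedged sketch, assuming `Actions.forward` merges each action's returned dict into the running state (the toy action below is ours, standing in for the real Make*/Denoise*/Sample* actions above):

```python
import torch
from rainbowneko.infer import BasicAction
from hcpdiff.workflow import DiffusionActions

class FakeGenerate(BasicAction):
    # Toy stand-in that just produces a batch of 4 "latents".
    def forward(self, **states):
        return {'latents': torch.zeros(4, 4, 64, 64)}

run = DiffusionActions([FakeGenerate()], clean_latent=True, seed_inc=True)
states = run.forward(seed=0)
print(states['seed'])        # 4: advanced by the batch size
print('latents' in states)   # False: dropped by clean_latent
```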
hcpdiff/workflow/io.py
CHANGED
@@ -1,6 +1,7 @@
 import os
 from functools import partial
 from typing import List, Union
+from addict import Addict
 
 import torch
 from hcpdiff.utils import to_validate_file
@@ -9,6 +10,8 @@ from rainbowneko.ckpt_manager import NekoLoader
 from rainbowneko.infer import BasicAction
 from rainbowneko.infer import LoadImageAction as Neko_LoadImageAction
 from rainbowneko.utils.img_size_tool import types_support
+from rainbowneko import _share
+from rainbowneko.utils import is_dict
 
 class BuildModelsAction(BasicAction):
     def __init__(self, model_loader: partial[NekoLoader.load], dtype: str=torch.float32, device='cuda', key_map_in=None, key_map_out=None):
@@ -23,6 +26,14 @@ class BuildModelsAction(BasicAction):
         else:
             model = self.model_loader(dtype=self.dtype, device=self.device)
 
+        # Callback for TokenizerHandler
+        if is_dict(model):
+            model_wrapper = Addict(model)
+        else:
+            model_wrapper = model
+        for callback in _share.model_callbacks:
+            callback(model_wrapper)
+
         if isinstance(model, dict):
             return model
         else:
@@ -33,12 +44,13 @@ class LoadImageAction(Neko_LoadImageAction):
         super().__init__(image_paths, image_transforms, key_map_in, key_map_out)
 
 class SaveImageAction(BasicAction):
-    def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, key_map_in=None, key_map_out=None):
+    def __init__(self, save_root: str, image_type: str = 'png', quality: int = 95, save_cfg=True, save_txt=False, key_map_in=None, key_map_out=None):
         super().__init__(key_map_in, key_map_out)
         self.save_root = save_root
         self.image_type = image_type
         self.quality = quality
         self.save_cfg = save_cfg
+        self.save_txt = save_txt
 
         os.makedirs(save_root, exist_ok=True)
 
@@ -47,10 +59,15 @@ class SaveImageAction(BasicAction):
         num_img_exist = max([0]+[int(x.split('-', 1)[0]) for x in os.listdir(save_root) if x.rsplit('.', 1)[-1] in types_support])+1
 
         for bid, (p, pn, img) in enumerate(zip(prompt, negative_prompt, images)):
-            img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(
+            img_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(p)}.{self.image_type}")
             img.save(img_path, quality=self.quality)
-            num_img_exist += 1
 
             if self.save_cfg:
                 cfgs.seed = seeds[bid]
                 parser.save_configs(cfgs, os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-info"))
+
+            if self.save_txt:
+                txt_path = os.path.join(save_root, f"{preview_step or num_img_exist}-{seeds[bid]}-{to_validate_file(prompt[0])}.txt")
+                with open(txt_path, 'w') as f:
+                    f.write(p)
+            num_img_exist += 1
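With `save_txt=True`, each image now gets a sibling caption file, so outputs arrive as `<n>-<seed>-<prompt>.png` / `.txt` pairs (note the text filename is built from `prompt[0]` in this version, while the image name uses the per-sample `p`). A usage sketch, assuming the action is wired into a workflow the usual way:

```python
from hcpdiff.workflow import SaveImageAction

# Saves images plus matching .txt caption files under output/.
save = SaveImageAction(save_root='output/', image_type='png',
                       quality=95, save_cfg=False, save_txt=True)
```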
hcpdiff/workflow/text.py
CHANGED
@@ -38,7 +38,7 @@ class TextHookAction(BasicAction):
         return {'te_hook':te_hook, 'emb_hook':emb_hook, 'token_ex':token_ex}
 
 class TextEncodeAction(BasicAction):
-    def __init__(self, prompt:
+    def __init__(self, prompt: List|str|None, negative_prompt: List|str|None, bs: int = None, key_map_in=None, key_map_out=None):
         super().__init__(key_map_in, key_map_out)
         if isinstance(prompt, str) and bs is not None:
             prompt = [prompt]*bs
@@ -73,6 +73,11 @@ class AttnMultTextEncodeAction(TextEncodeAction):
         prompt = prompt or self.prompt
         negative_prompt = negative_prompt or self.negative_prompt
 
+        if isinstance(negative_prompt, str) and isinstance(prompt, (list, tuple)):
+            negative_prompt = [negative_prompt]*len(prompt)
+        if isinstance(prompt, str) and isinstance(negative_prompt, (list, tuple)):
+            prompt = [prompt]*len(negative_prompt)
+
         if model_offload:
             to_cuda(TE)
 
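The new branch broadcasts a lone string against a list on the other side before encoding, so mixed inputs end up with matching batch sizes:

```python
# Effect of the new broadcasting in AttnMultTextEncodeAction.forward:
prompt = ['a cat', 'a dog']   # list of 2
negative_prompt = 'lowres'    # single string
# -> negative_prompt becomes ['lowres', 'lowres'] to match len(prompt)
```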
{hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hcpdiff
-Version: 2.2
+Version: 2.3
 Summary: A universal Diffusion toolbox
 Home-page: https://github.com/IrisRainbowNeko/HCP-Diffusion
 Author: Ziyi Dong
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: rainbowneko
+Requires-Dist: rainbowneko>=1.9
 Requires-Dist: diffusers
 Requires-Dist: matplotlib
 Requires-Dist: pyarrow
@@ -262,9 +262,13 @@ hcp_run --cfg cfgs/workflow/text2img_cli.py \
     seed=42
 ```
 
-### Tutorials
+### 📚 Tutorials
 
-
++ 🧠 [Model Training Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/train.html)
++ 🔧 [LoRA Training Tutorial](https://hcpdiff.readthedocs.io/en/latest/tutorial/lora.html)
++ 🎨 [Image Generation Guide](https://hcpdiff.readthedocs.io/en/latest/user_guides/workflow.html)
++ ⚙️ [Configuration File Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/cfg.html)
++ 🧩 [Model Format Explanation](https://hcpdiff.readthedocs.io/en/latest/user_guides/model_format.html)
 
 ---
 
{hcpdiff-2.2.dist-info → hcpdiff-2.3.dist-info}/RECORD
CHANGED
@@ -1,16 +1,16 @@
 hcpdiff/__init__.py,sha256=dwNwrEgvG4g60fGMG6b50K3q3AWD1XCfzlIgbxkSUpE,177
 hcpdiff/train_colo.py,sha256=EsuNSzLBvGTZWU_LEk0JpP-F5eNW0lwkawIRAX38jmE,9250
-hcpdiff/
-hcpdiff/
-hcpdiff/
-hcpdiff/ckpt_manager/__init__.py,sha256=
-hcpdiff/ckpt_manager/ckpt.py,sha256=
-hcpdiff/ckpt_manager/loader.py,sha256=
+hcpdiff/trainer_ac.py,sha256=-owV-3_bvPxuQsZS2WaajBDh58HpftRtnx0GJkswqaY,2787
+hcpdiff/trainer_ac_single.py,sha256=zyZVrutLUbIJYW1HzUnQ_RnmIcDhbC7M_CT833PJH5w,993
+hcpdiff/trainer_deepspeed.py,sha256=7lGsiAstWuIlmhRMwWTcJCkoxzUaakVxBngKDnJdSJk,1947
+hcpdiff/ckpt_manager/__init__.py,sha256=r_sgjZWCLtdJrRkqqU6aPdfubXSYfPh2Z_Vf_XpZXXs,240
+hcpdiff/ckpt_manager/ckpt.py,sha256=2A093lT03M1ZsJIMWl376V165eh0TZwOgiGrz3LM73Q,1248
+hcpdiff/ckpt_manager/loader.py,sha256=6iZDUj-Vfc5T9eGdWfFMQw4n1GqyLqaLBolgAtgqPq8,3640
 hcpdiff/ckpt_manager/format/__init__.py,sha256=a3cdKkOTDgdVbDQwSC4mlxOigjX2hBvRb5_X7E3TQWs,237
-hcpdiff/ckpt_manager/format/diffusers.py,sha256=
+hcpdiff/ckpt_manager/format/diffusers.py,sha256=qhGbrKAaeLyjFzY-Lj4sL1THHFNrta41JGGMoXT-bCE,3761
 hcpdiff/ckpt_manager/format/emb.py,sha256=FrqfTfJ8H7f0Zw17NTWCP2AJtpsJI5oXR5IAd4NekhU,680
 hcpdiff/ckpt_manager/format/lora_webui.py,sha256=4y_T9RdmFTxWzsXd8guNjCiukmyILa5j4MPrhVIL4Qk,10017
-hcpdiff/ckpt_manager/format/sd_single.py,sha256=
+hcpdiff/ckpt_manager/format/sd_single.py,sha256=4DZLAl1RNC_nPxuW-lmrBlIMFUhpSTa7HGHgu7Yx8qk,2322
 hcpdiff/data/__init__.py,sha256=ZFKtanOoMo3G3eKUJPhysnHXnr8BNARERkcMB6B897U,292
 hcpdiff/data/dataset.py,sha256=1k4GldW13eVyqK_9hrQniqr3_XYAapnWF7iXl_1GXGg,877
 hcpdiff/data/cache/__init__.py,sha256=ToCmokYH6DghlSwm7HJFirPRIWJ0LkgzqVOYlgoAkQw,25
@@ -25,44 +25,47 @@ hcpdiff/data/source/text.py,sha256=VgI5Ouq986Yy1jwD2fZ9iBlsRciPCeARZmOPEZIcaQY,1
 hcpdiff/data/source/text2img.py,sha256=acYdolQhZUEpkd7tUAdNkCTVnPc1SMJOVTmGqFt9ZpE,1813
 hcpdiff/data/source/text2img_cond.py,sha256=yj1KpARA2rkjENutnnzC4uDkcU2Rye21FL2VdC25Hac,585
 hcpdiff/diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-hcpdiff/diffusion/noise/__init__.py,sha256=
+hcpdiff/diffusion/noise/__init__.py,sha256=D83EZ6bnc6Ucu4AZwE6rpmXCtwYfHHumeVq97brbnIE,47
 hcpdiff/diffusion/noise/pyramid_noise.py,sha256=KbpyMT1BHNIaAa7g5eECDkTttOMoMWVFmbP-ekBsuEY,1693
-hcpdiff/diffusion/
-hcpdiff/diffusion/sampler/__init__.py,sha256=
-hcpdiff/diffusion/sampler/base.py,sha256=
-hcpdiff/diffusion/sampler/
-hcpdiff/diffusion/sampler/
-hcpdiff/diffusion/sampler/
-hcpdiff/diffusion/sampler/sigma_scheduler/
-hcpdiff/diffusion/sampler/sigma_scheduler/
-hcpdiff/diffusion/sampler/sigma_scheduler/
-hcpdiff/diffusion/sampler/sigma_scheduler/
+hcpdiff/diffusion/sampler/VP.py,sha256=r0Q_RROEIeNNw93XrOD5htW78rfuoSxy1WBQEoQL83s,958
+hcpdiff/diffusion/sampler/__init__.py,sha256=Lrwg1us8qo943T7mdIXFDRXfKvnLhrzwmi6DrIKIiUA,135
+hcpdiff/diffusion/sampler/base.py,sha256=UbE_AmtvLg-Hr2bkYz8PvNWB63tvtacUvCIDm_W6opA,5484
+hcpdiff/diffusion/sampler/diffusers.py,sha256=wIMs8n3kdci1On0FUCV0si324ZE9zeRw_CxaHP8rdcs,2586
+hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py,sha256=eiSmMBkXI_LfxnNrXj5XptcF0dGcPas--vWvqhFGlv8,273
+hcpdiff/diffusion/sampler/sigma_scheduler/base.py,sha256=UT4tbjFf80KYfU08y0hJf8h_Cl80a5MUhK5FsKLsqbY,2521
+hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py,sha256=SA8lXT6lucJot_rpJ84Wz-_uc5dXfb2QPoQgJHSKOj4,12999
+hcpdiff/diffusion/sampler/sigma_scheduler/edm.py,sha256=m1YlIyn61zfDjxLMcHvWs0nzULbHgXeB7WGKmTiaGSU,4127
+hcpdiff/diffusion/sampler/sigma_scheduler/flow.py,sha256=FtWpesUtSmFuiIGkrrVhYJweB7INZiw0atC64tc0Nk4,2020
+hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py,sha256=CCqQLkGo4omkxzFovYdZQzdZVwIxK3PiOitZFww8MHs,859
 hcpdiff/easy/__init__.py,sha256=-emoyCOZlLCu3KNMI8L4qapUEtEYFSoiGU6-rKv1at4,149
 hcpdiff/easy/sampler.py,sha256=dQSBkeGh71O0DAmZLhTHTbk1bY7XzyUCeW1oJO14A4I,1250
 hcpdiff/easy/cfg/__init__.py,sha256=SxHMWG6T2CXhX3dP0xizSMd9vFWPaZQDc4Gj4CF__yQ,253
-hcpdiff/easy/cfg/sd15_train.py,sha256=
-hcpdiff/easy/cfg/sdxl_train.py,sha256=
+hcpdiff/easy/cfg/sd15_train.py,sha256=NtgsQLg1sd5JFmHU4nqMPOrvP7zmwo2x0MCspjVNQEY,7000
+hcpdiff/easy/cfg/sdxl_train.py,sha256=rVLLKVMKB_PHuum3dKQcBqL0uR8QhzmdRllM-pYnbK4,4534
 hcpdiff/easy/cfg/t2i.py,sha256=SnjFjZAKd9orjJr3RW5_N2_EIlW2Ree7JMvdNUAR9gc,9507
 hcpdiff/easy/model/__init__.py,sha256=CA-7r3R2Jgweekk1XNByFYttLolbWyUV2bCnXygcD8w,133
 hcpdiff/easy/model/cnet.py,sha256=m0NTH9V1kLzb5GybwBrSNT0KvTcRpPfGkzUeMz9jZZQ,1084
 hcpdiff/easy/model/loader.py,sha256=Tdx-lhQEYf2NYjVM1A5B8x6ZZpJKcXUkFIPIbr7h7XM,3456
-hcpdiff/evaluate/__init__.py,sha256=
-hcpdiff/evaluate/
+hcpdiff/evaluate/__init__.py,sha256=qWxV0D8Ho5uBj2YbaC_QFDnT49PSKPfh44m4ivkNbMM,108
+hcpdiff/evaluate/evaluator.py,sha256=9BZQBeC-N7p-ICx6Giw9v-2Tb9volMTDmeDfhj0nXJ0,2940
+hcpdiff/evaluate/previewer.py,sha256=-vE0YXVfos70CQMo9ZInw7xu3d88DlTfVLs4BzzkxfM,3140
+hcpdiff/evaluate/metrics/__init__.py,sha256=vE0nSvBtDBu9SomANvWcm2UHX56PhCYwhgrcmm_mKyo,39
+hcpdiff/evaluate/metrics/clip_score.py,sha256=rQgweu5QcqW3fPI3EXcNbrH2QCcSAekE3lpYk45P2M4,900
 hcpdiff/loss/__init__.py,sha256=2dwPczSiv3rB5fzOeYbl5ZHpMU-qXOQlXeOiXdxcxwM,173
-hcpdiff/loss/base.py,sha256=
+hcpdiff/loss/base.py,sha256=Vvpm-KZGH4n-gYIlnVAtPl1B799c7v0dJXJ5BBh3yO0,1112
 hcpdiff/loss/gw.py,sha256=0yi1kozuII3xZA6FnjOhINtvScWt1MyBZLBtMKmgojM,1224
 hcpdiff/loss/ssim.py,sha256=YofadvBkc6sklxBUx1p3ADw5OHOZPK3kaHz8FH5a6m4,1281
 hcpdiff/loss/vlb.py,sha256=s78iBnXUiDWfGf7mYmhUnHqxqea5gSByKOoqBrX6bzU,3222
-hcpdiff/loss/weighting.py,sha256=
+hcpdiff/loss/weighting.py,sha256=9qzMnvCb6b5qx0p08GDSlkxmYEqQcNt79XdRBvfHmiI,2914
 hcpdiff/models/__init__.py,sha256=eQS7DPiGLiE1MFRkZj_17IY3IsfDUVcYpcOmhHb5B9o,472
 hcpdiff/models/cfg_context.py,sha256=e2B3K1KwJhzbD6xdJUOyNtl_XgQ0296XI3FHw3gvZF4,1502
 hcpdiff/models/container.py,sha256=z3p5TmQhxdzXSIfofz55_bmEhSsgUJsy1o9EcDs8Oeo,696
 hcpdiff/models/controlnet.py,sha256=VIkUzJCVpCqqQOtRSLQPfbcDy9CsXutxLeZB6PdZfA0,7809
 hcpdiff/models/lora_base.py,sha256=LGwBD9KP6qf4pgTx24i5-JLo4rDBQ6jFfterQKBjTbE,6758
-hcpdiff/models/lora_base_patch.py,sha256=
+hcpdiff/models/lora_base_patch.py,sha256=Tdb_b3TN_K-04nlUvcfBh6flPcbL9M4iP7jOVyb1jXQ,7271
 hcpdiff/models/lora_layers.py,sha256=O9W_Ue71lHj7Y_GbpioF4Hc3h2-z_zOqck93VYUra6s,7777
 hcpdiff/models/lora_layers_patch.py,sha256=GYFYsJD2VSLZfdnLma9CmQEHz09HROFJcc4wc_gs9f0,8198
-hcpdiff/models/text_emb_ex.py,sha256=
+hcpdiff/models/text_emb_ex.py,sha256=O0XZqid01OrB0dHY7hCiBvdU2026SvZ38yfQaF2TWrs,8018
 hcpdiff/models/textencoder_ex.py,sha256=JrTQ30Avx8tPbdr-Q6K5BvEWCEdsu8Z7eSOzMqpUuzg,8270
 hcpdiff/models/tokenizer_ex.py,sha256=zKUn4BY7b3yXwK9PWkZtQKJPyKYwUc07E-hwB9NQybs,2446
 hcpdiff/models/compose/__init__.py,sha256=lTNFTGg5csqvUuys22RqgjmWlk_7Okw6ZTsnTi1pqCg,217
@@ -72,7 +75,7 @@ hcpdiff/models/compose/compose_tokenizer.py,sha256=g3l0pOFv6p7Iigxm6Pqt_iTUXBlO1
 hcpdiff/models/compose/sdxl_composer.py,sha256=NtMGaFGZTfKsPJSVi2yT-UM6K1WKWtk99XxVmTcKlk8,2164
 hcpdiff/models/wrapper/__init__.py,sha256=HbGQmFnfccr-dtvZKjEv-pmR4cCnF4fwGLKS3tuG_OY,135
 hcpdiff/models/wrapper/pixart.py,sha256=nRUvHSHn4TYg_smC0xpeW-GtUgXss-MuaVPTHpMozDE,1147
-hcpdiff/models/wrapper/sd.py,sha256=
+hcpdiff/models/wrapper/sd.py,sha256=EywmVU2QzR74M_4eH_uXVW8HJNauyjwcZPU7rRAQ7eI,11666
 hcpdiff/models/wrapper/utils.py,sha256=NyebMoAPnrgcTHbiIocSD-eGdGdD-V1G_TQuWsRWufw,665
 hcpdiff/parser/__init__.py,sha256=-2dDZ2Ii4zoGQqDTme94q4PpJbBiV6HS5BsDASz4Xbo,33
 hcpdiff/parser/embpt.py,sha256=LgwZ0f0tLn3DrTo5ZpSCsZcA5330UpiW_sK96yEPmOM,1307
@@ -89,27 +92,28 @@ hcpdiff/tools/init_proj.py,sha256=XrXxxhIaItywG7HsrloJo-x8w9suZiY35daelzZvjrg,19
 hcpdiff/tools/lora_convert.py,sha256=So14WvSVIm6rU4m1XCajFXDnhq7abpZS95SLbaoyBFU,10058
 hcpdiff/tools/save_model.py,sha256=gbfYi_EfEBZEUcDjle6MDHA19sQWY0zA8_y_LMzHQ7M,428
 hcpdiff/tools/sd2diffusers.py,sha256=vB6OnBLw60sJkdpVZcYEPtKAZW1h8ErbSGSRq0uAiIk,16855
-hcpdiff/utils/__init__.py,sha256=
+hcpdiff/utils/__init__.py,sha256=28K9Ui0uur-vHuUdlSyIBYijgu2b7rGOPXN2ogJu1z8,82
 hcpdiff/utils/colo_utils.py,sha256=JyLUvVnISa48CnryNLrgVxMo-jxu2UhBq70eYPrkjuI,837
 hcpdiff/utils/inpaint_pipe.py,sha256=CRy1MUlPmHifCAbZnKOP0qbLp2grn7ZbVeaB2qIA4ig,42862
 hcpdiff/utils/net_utils.py,sha256=gdwLYDNKV2t3SP0jBIO3d0HtY6E7jRaf_rmPT8gKZZE,9762
 hcpdiff/utils/pipe_hook.py,sha256=-UDX3FtZGl-bxSk13gdbPXc1OvtbCcpk_fvKxLQo3Ag,31987
+hcpdiff/utils/torch_utils.py,sha256=gBZCcDKZc0NGDQx6QeHuQePoZ82kQRhaL7oEdZIYGvU,573
 hcpdiff/utils/utils.py,sha256=hZnZP1IETgVpScxES0yIuRfc34TnzvAqmgOTK_56ssw,4976
-hcpdiff/workflow/__init__.py,sha256=
-hcpdiff/workflow/diffusion.py,sha256=
+hcpdiff/workflow/__init__.py,sha256=i5s7QXo6wK9607KL0KTW4suE1c-HGJ5_EgnCdVLl3WM,885
+hcpdiff/workflow/diffusion.py,sha256=hKefBrVP6-025MhdrKOQMUhHxLaGqjpUKhR6WahYwh0,9549
 hcpdiff/workflow/fast.py,sha256=kZt7bKrvpFInSn7GzbkTkpoCSM0Z6IbDjgaDvcbFYf8,1024
 hcpdiff/workflow/flow.py,sha256=FFbFFOAXT4c31L5bHBEB_qeVGuBQDLYhq8kTD1chGNo,2548
-hcpdiff/workflow/io.py,sha256=
+hcpdiff/workflow/io.py,sha256=4oiE_PS3sOVYT8M6PDwvT5h9XzoKDMQR0n_4-Ktttys,3284
 hcpdiff/workflow/model.py,sha256=1gj5yOTefYTnGXVR6JPAfxIwuB69YwN6E-BontRcuyQ,2913
-hcpdiff/workflow/text.py,sha256=
+hcpdiff/workflow/text.py,sha256=XQvN4zzK7VaGxy4FDgSDeWh2jjk7UZU24moeRKAWXRE,4608
 hcpdiff/workflow/utils.py,sha256=xojaMG4lHsymslc8df5uiVXmmBVWpn_Phqka8qzJEWw,2226
 hcpdiff/workflow/vae.py,sha256=cingDPkIOc4qGpOwwhXJK4EQbGoIxO583pm6gGov5t8,3118
 hcpdiff/workflow/daam/__init__.py,sha256=ySIDaxloN-D3qM7OuVaG1BR3D-CibDoXYpoTgw0zUhU,59
 hcpdiff/workflow/daam/act.py,sha256=tHbsFWTYYU4bvcZOo1Bpi_z6ofpJatRYccl4vvf8wIA,2756
 hcpdiff/workflow/daam/hook.py,sha256=z9f9mBjKW21xuUZ-iQxQ0HbWOBXtZrisFB0VNMq6d0U,4383
-hcpdiff-2.
-hcpdiff-2.
-hcpdiff-2.
-hcpdiff-2.
-hcpdiff-2.
-hcpdiff-2.
+hcpdiff-2.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+hcpdiff-2.3.dist-info/METADATA,sha256=1aqZH8IwB7WjDBot_fADwyfLDVNiRZuZdub2-zezFck,10321
+hcpdiff-2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+hcpdiff-2.3.dist-info/entry_points.txt,sha256=_4VRsEsEWOhHfzBDu9bx8Wh_S8Wi4ZTHpI0n6rU0J-I,258
+hcpdiff-2.3.dist-info/top_level.txt,sha256=shyf78x-HVgykYpsmY22mKG0xIc7Qk30fDMdavdYWQ8,8
+hcpdiff-2.3.dist-info/RECORD,,
hcpdiff/diffusion/noise/zero_terminal.py
DELETED
@@ -1,39 +0,0 @@
-import torch
-from ..sampler.sigma_scheduler import DDPMDiscreteSigmaScheduler
-
-class ZeroTerminalSampler:
-
-    @classmethod
-    def patch(cls, base_sampler):
-        assert isinstance(base_sampler.sigma_scheduler, DDPMDiscreteSigmaScheduler), "ZeroTerminalScheduler only works with DDPM SigmaScheduler"
-
-        alphas_cumprod = base_sampler.sigma_scheduler.alphas_cumprod
-        base_sampler.sigma_scheduler.alphas_cumprod = cls.rescale_zero_terminal_snr(alphas_cumprod)
-        base_sampler.sigma_scheduler.sigmas = ((1-alphas_cumprod)/alphas_cumprod).sqrt()
-
-
-    @staticmethod
-    def rescale_zero_terminal_snr(alphas_cumprod, thr=1e-4):
-        """
-        Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-        Args:
-            alphas_cumprod (`torch.FloatTensor`)
-        Returns:
-            `torch.FloatTensor`: rescaled betas with zero terminal SNR
-        """
-        alphas_bar_sqrt = alphas_cumprod.sqrt()
-
-        # Store old values.
-        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-
-        # Shift so the last timestep is zero.
-        alphas_bar_sqrt -= alphas_bar_sqrt_T
-
-        # Scale so the first timestep is back to the old value.
-        alphas_bar_sqrt *= alphas_bar_sqrt_0/(alphas_bar_sqrt_0-alphas_bar_sqrt_T)
-        alphas_bar_sqrt[-1] = thr # avoid nan sigma
-
-        # Convert alphas_bar_sqrt to betas
-        alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
-        return alphas_bar
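Note that zero-terminal-SNR support is not dropped: per the file list above it reappears as hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py (+22), moving the rescale from a noise-sampler patch into the sigma-scheduler hierarchy alongside the new ddpm/edm/flow schedulers.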