hcpdiff 2.3.1__py3-none-any.whl → 2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/format/__init__.py +2 -2
- hcpdiff/ckpt_manager/format/diffusers.py +19 -4
- hcpdiff/ckpt_manager/format/emb.py +8 -3
- hcpdiff/ckpt_manager/format/lora_webui.py +1 -1
- hcpdiff/ckpt_manager/format/sd_single.py +28 -5
- hcpdiff/data/cache/vae.py +10 -2
- hcpdiff/data/handler/text.py +15 -14
- hcpdiff/diffusion/sampler/__init__.py +2 -1
- hcpdiff/diffusion/sampler/base.py +17 -6
- hcpdiff/diffusion/sampler/diffusers.py +4 -3
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +5 -14
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +7 -6
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +4 -4
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +3 -3
- hcpdiff/diffusion/sampler/timer/__init__.py +2 -0
- hcpdiff/diffusion/sampler/timer/base.py +26 -0
- hcpdiff/diffusion/sampler/timer/shift.py +49 -0
- hcpdiff/easy/__init__.py +2 -1
- hcpdiff/easy/cfg/sd15_train.py +1 -3
- hcpdiff/easy/model/__init__.py +1 -1
- hcpdiff/easy/model/loader.py +33 -11
- hcpdiff/easy/sampler.py +8 -1
- hcpdiff/loss/__init__.py +4 -3
- hcpdiff/loss/charbonnier.py +17 -0
- hcpdiff/loss/vlb.py +2 -2
- hcpdiff/loss/weighting.py +29 -11
- hcpdiff/models/__init__.py +1 -1
- hcpdiff/models/cfg_context.py +5 -3
- hcpdiff/models/compose/__init__.py +2 -1
- hcpdiff/models/compose/compose_hook.py +69 -67
- hcpdiff/models/compose/compose_textencoder.py +59 -45
- hcpdiff/models/compose/compose_tokenizer.py +48 -11
- hcpdiff/models/compose/flux.py +75 -0
- hcpdiff/models/compose/sdxl.py +86 -0
- hcpdiff/models/text_emb_ex.py +13 -9
- hcpdiff/models/textencoder_ex.py +8 -38
- hcpdiff/models/wrapper/__init__.py +2 -1
- hcpdiff/models/wrapper/flux.py +75 -0
- hcpdiff/models/wrapper/pixart.py +13 -1
- hcpdiff/models/wrapper/sd.py +17 -8
- hcpdiff/parser/embpt.py +7 -7
- hcpdiff/utils/net_utils.py +22 -12
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +145 -18
- hcpdiff/workflow/text.py +49 -18
- hcpdiff/workflow/vae.py +10 -2
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/METADATA +1 -1
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/RECORD +53 -49
- hcpdiff/models/compose/sdxl_composer.py +0 -39
- hcpdiff/utils/inpaint_pipe.py +0 -790
- hcpdiff/utils/pipe_hook.py +0 -656
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/WHEEL +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/entry_points.txt +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/top_level.txt +0 -0
hcpdiff/ckpt_manager/__init__.py
CHANGED
@@ -1,4 +1,4 @@
 from .format import EmbFormat, DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat, OfficialSDXLFormat, \
-    OfficialSD15Format, LoraWebuiFormat
+    OfficialSD15Format, LoraWebuiFormat, DiffusersFluxFormat, OneFileFluxFormat
 from .ckpt import EmbSaver
 from .loader import HCPLoraLoader
hcpdiff/ckpt_manager/format/__init__.py
CHANGED
@@ -1,4 +1,4 @@
 from .emb import EmbFormat
-from .diffusers import DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat
-from .sd_single import OfficialSD15Format, OfficialSDXLFormat
+from .diffusers import DiffusersSD15Format, DiffusersModelFormat, DiffusersSDXLFormat, DiffusersPixArtFormat, DiffusersFluxFormat
+from .sd_single import OfficialSD15Format, OfficialSDXLFormat, OneFileFluxFormat
 from .lora_webui import LoraWebuiFormat
hcpdiff/ckpt_manager/format/diffusers.py
CHANGED
@@ -1,10 +1,10 @@
 import torch
-from diffusers import ModelMixin, AutoencoderKL, UNet2DConditionModel, PixArtTransformer2DModel
+from diffusers import ModelMixin, AutoencoderKL, UNet2DConditionModel, PixArtTransformer2DModel, FluxTransformer2DModel
 from rainbowneko.ckpt_manager.format import CkptFormat
 from transformers import CLIPTextModel, AutoTokenizer, T5EncoderModel
 
-from hcpdiff.diffusion.sampler import VPSampler, DDPMDiscreteSigmaScheduler
-from hcpdiff.models.compose import SDXLTokenizer, SDXLTextEncoder
+from hcpdiff.diffusion.sampler import VPSampler, DDPMDiscreteSigmaScheduler, FlowSigmaScheduler, Sampler, FluxShiftTimeSampler
+from hcpdiff.models.compose import SDXLTokenizer, SDXLTextEncoder, FluxTokenizer, FluxTextEncoder
 
 class DiffusersModelFormat(CkptFormat):
     def __init__(self, builder: ModelMixin):
@@ -51,9 +51,24 @@ class DiffusersPixArtFormat(CkptFormat):
             pretrained_model, subfolder="transformer", revision=revision, torch_dtype=dtype
         )
         vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
-        noise_sampler = noise_sampler or VPSampler(DDPMDiscreteSigmaScheduler())
+        noise_sampler = noise_sampler or VPSampler(DDPMDiscreteSigmaScheduler(linear_start=0.0001, linear_end=0.02, beta_schedule='linear'))
 
         TE = TE or T5EncoderModel.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
         tokenizer = tokenizer or AutoTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
+        tokenizer.model_max_length = 300
 
         return dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
+
+class DiffusersFluxFormat(CkptFormat):
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        denoiser = denoiser or FluxTransformer2DModel.from_pretrained(
+            pretrained_model, subfolder="transformer", revision=revision, torch_dtype=dtype
+        )
+        vae = vae or AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=revision, torch_dtype=dtype)
+        noise_sampler = noise_sampler or Sampler(FlowSigmaScheduler(), t_sampler=FluxShiftTimeSampler())
+
+        TE = TE or FluxTextEncoder.from_pretrained(pretrained_model, subfolder="text_encoder", revision=revision, torch_dtype=dtype)
+        tokenizer = tokenizer or FluxTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer", revision=revision, use_fast=False)
+
+        return dict(denoiser=denoiser, TE=TE, vae=vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
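The new DiffusersFluxFormat pairs the Flux transformer with a flow-matching sampler (FlowSigmaScheduler driven by FluxShiftTimeSampler) instead of the DDPM setup used by the SD formats. A minimal usage sketch, assuming a diffusers-layout Flux repository; the repo id is a placeholder:

```python
# Hedged sketch, not from the package docs: load a diffusers-layout Flux
# checkpoint via the new format and unpack the returned parts.
import torch
from hcpdiff.ckpt_manager import DiffusersFluxFormat

parts = DiffusersFluxFormat().load_ckpt('black-forest-labs/FLUX.1-dev', dtype=torch.bfloat16)
denoiser, TE, vae = parts['denoiser'], parts['TE'], parts['vae']
noise_sampler, tokenizer = parts['noise_sampler'], parts['tokenizer']
```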
hcpdiff/ckpt_manager/format/emb.py
CHANGED
@@ -2,13 +2,18 @@ from typing import Tuple
 
 import torch
 from rainbowneko.ckpt_manager.format import CkptFormat
-from torch
+from torch import nn, Tensor
+from rainbowneko.utils import FILE_LIKE
 
 class EmbFormat(CkptFormat):
     EXT = 'pt'
 
-    def save_ckpt(self, sd_model: Tuple[str,
+    def save_ckpt(self, sd_model: Tuple[str, Tensor | nn.Parameter | nn.ParameterDict], save_f: FILE_LIKE):
         name, emb = sd_model
+        if hasattr(emb, 'named_parameters'):
+            emb = dict(emb.named_parameters())
+        elif isinstance(emb, nn.Parameter):
+            emb = emb.data
         torch.save({'string_to_param':{'*':emb}, 'name':name}, save_f)
 
     def load_ckpt(self, ckpt_f: FILE_LIKE, map_location="cpu"):
@@ -18,4 +23,4 @@ class EmbFormat(CkptFormat):
         else:
             emb = state['emb_params']
         emb.requires_grad_(False)
-        return emb
+        return emb
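save_ckpt now normalizes its input before writing: anything exposing named_parameters() (e.g. an nn.ParameterDict) is flattened to a plain dict, and a bare nn.Parameter is unwrapped to its .data tensor. The on-disk layout is unchanged; a sketch of reading it back, with a placeholder filename:

```python
# Sketch of the .pt layout written by EmbFormat.save_ckpt ('my_emb.pt' is a placeholder).
import torch

state = torch.load('my_emb.pt', map_location='cpu')
name = state['name']                  # embedding name
emb = state['string_to_param']['*']   # tensor, or dict of tensors for module-style embeddings
```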
hcpdiff/ckpt_manager/format/lora_webui.py
CHANGED
@@ -3,7 +3,7 @@ import re
 from typing import List, Dict, Any
 
 from rainbowneko.ckpt_manager.format import CkptFormat, SafeTensorFormat
-from
+from rainbowneko.utils import FILE_LIKE
 
 class LoraConverter:
     com_name_unet = ['down_blocks', 'up_blocks', 'mid_block', 'transformer_blocks', 'to_q', 'to_k', 'to_v', 'to_out', 'proj_in', 'proj_out',
hcpdiff/ckpt_manager/format/sd_single.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from diffusers import AutoencoderKL, StableDiffusionPipeline, StableDiffusionXLPipeline
+from diffusers import AutoencoderKL, StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline
 from rainbowneko.ckpt_manager.format import CkptFormat
 
-from hcpdiff.diffusion.sampler import VPSampler, DDPMDiscreteSigmaScheduler
-from hcpdiff.models.compose import SDXLTextEncoder, SDXLTokenizer
+from hcpdiff.diffusion.sampler import VPSampler, DDPMDiscreteSigmaScheduler, FlowSigmaScheduler, Sampler, FluxShiftTimeSampler
+from hcpdiff.models.compose import SDXLTextEncoder, SDXLTokenizer, FluxTextEncoder, FluxTokenizer
 
 class OfficialSD15Format(CkptFormat):
     # Single file format
@@ -35,7 +35,30 @@ class OfficialSDXLFormat(CkptFormat):
         )
 
         noise_sampler = noise_sampler or VPSampler(DDPMDiscreteSigmaScheduler())
-        TE = SDXLTextEncoder(
-        tokenizer = SDXLTokenizer(
+        TE = SDXLTextEncoder({'clip_L': pipe.text_encoder, 'clip_bigG': pipe.text_encoder_2})
+        tokenizer = SDXLTokenizer({'clip_L': pipe.tokenizer, 'clip_bigG': pipe.tokenizer_2})
 
         return dict(denoiser=pipe.unet, TE=TE, vae=pipe.vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
+
+class OneFileFluxFormat(CkptFormat):
+    # Single file format
+    def load_ckpt(self, pretrained_model: str, map_location="cpu", denoiser=None, TE=None, vae: AutoencoderKL = None, noise_sampler=None,
+                  tokenizer=None, revision=None, dtype=torch.float32, **kwargs):
+        pipe_args = dict(unet=denoiser, vae=vae)
+        if TE is not None:
+            pipe_args['text_encoder'] = TE.clip
+            pipe_args['text_encoder_2'] = TE.T5
+        if tokenizer is not None:
+            pipe_args['tokenizer'] = tokenizer.clip
+            pipe_args['tokenizer_2'] = tokenizer.T5
+
+        pipe_args = {k:v for k,v in pipe_args.items() if v is not None}
+        pipe = FluxPipeline.from_single_file(
+            pretrained_model, revision=revision, torch_dtype=dtype, **pipe_args
+        )
+
+        noise_sampler = noise_sampler or Sampler(FlowSigmaScheduler(), t_sampler=FluxShiftTimeSampler())
+        TE = FluxTextEncoder({'clip': pipe.text_encoder, 'T5': pipe.text_encoder_2})
+        tokenizer = FluxTokenizer({'clip': pipe.tokenizer, 'T5': pipe.tokenizer_2})
+
+        return dict(denoiser=pipe.unet, TE=TE, vae=pipe.vae, noise_sampler=noise_sampler, tokenizer=tokenizer)
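OneFileFluxFormat mirrors the SDXL single-file path: caller-supplied modules are threaded into FluxPipeline.from_single_file (None values filtered out so diffusers loads the missing parts itself), and the pipeline's paired text encoders and tokenizers are re-wrapped into the compose classes. A usage sketch with a placeholder checkpoint path:

```python
# Hedged sketch: load a single-file Flux checkpoint via the new format.
from hcpdiff.ckpt_manager import OneFileFluxFormat

parts = OneFileFluxFormat().load_ckpt('flux1-dev.safetensors')
TE = parts['TE']                # FluxTextEncoder wrapping {'clip': ..., 'T5': ...}
tokenizer = parts['tokenizer']  # FluxTokenizer with matching keys
```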
hcpdiff/data/cache/vae.py
CHANGED
@@ -73,7 +73,11 @@ class VaeCache(DataCache):
         for data in tqdm(loader):
             image = data['image'].to(device=_share.device, dtype=vae.dtype)
             latents = model.vae.encode(image).latent_dist.sample()
-
+            if shift_factor := getattr(vae.config, 'shift_factor', None) is not None:
+                latents = (latents-shift_factor)*vae.config.scaling_factor
+            else:
+                latents = latents*vae.config.scaling_factor
+            latents = latents.cpu()
 
             for img_id, latent, coord in zip(data['id'], latents, data['coord']):
                 data_cache = {'latent': latent, 'coord': coord}
@@ -89,7 +93,11 @@ class VaeCache(DataCache):
             img_id = data['id']
             image = data['image'].to(device=_share.device, dtype=vae.dtype)
             latents = model.vae.encode(image).latent_dist.sample()
-
+            if shift_factor := getattr(vae.config, 'shift_factor', None) is not None:
+                latents = (latents-shift_factor)*vae.config.scaling_factor
+            else:
+                latents = latents*vae.config.scaling_factor
+            latents = latents.cpu()
             for img_id, latent, coord in zip(data['id'], latents, data['coord']):
                 self.cache[img_id] = {'latent': latent, 'coord': coord}
 
CHANGED
@@ -1,10 +1,12 @@
|
|
1
1
|
import random
|
2
|
-
from
|
2
|
+
from string import Formatter
|
3
|
+
from typing import Dict, Union
|
3
4
|
|
4
5
|
import numpy as np
|
5
|
-
from string import Formatter
|
6
|
-
from rainbowneko.data import DataHandler
|
7
6
|
from rainbowneko._share import register_model_callback
|
7
|
+
from rainbowneko.data import DataHandler
|
8
|
+
|
9
|
+
from hcpdiff.models.compose import ComposeTokenizer
|
8
10
|
|
9
11
|
class TagShuffleHandler(DataHandler):
|
10
12
|
def __init__(self, key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
|
@@ -58,7 +60,6 @@ class TagEraseHandler(DataHandler):
|
|
58
60
|
def __repr__(self):
|
59
61
|
return f'TagEraseHandler(p={self.p})'
|
60
62
|
|
61
|
-
|
62
63
|
class TemplateFillHandler(DataHandler):
|
63
64
|
def __init__(self, word_names: Dict[str, str], key_map_in=('prompt -> prompt',), key_map_out=('prompt -> prompt',)):
|
64
65
|
super().__init__(key_map_in, key_map_out)
|
@@ -68,7 +69,7 @@ class TemplateFillHandler(DataHandler):
|
|
68
69
|
template, caption = prompt['template'], prompt['caption']
|
69
70
|
|
70
71
|
keys_need = {i[1] for i in Formatter().parse(template) if i[1] is not None}
|
71
|
-
fill_dict = {k:
|
72
|
+
fill_dict = {k:v for k, v in self.word_names.items() if k in keys_need}
|
72
73
|
|
73
74
|
if (caption is not None) and ('caption' in keys_need):
|
74
75
|
fill_dict.update(caption=fill_dict.get('caption', None) or caption)
|
@@ -96,16 +97,16 @@ class TokenizeHandler(DataHandler):
|
|
96
97
|
self.tokenizer = model_wrapper.tokenizer
|
97
98
|
|
98
99
|
def handle(self, prompt):
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
100
|
+
# Tokenizer: {'input_ids':Tensor, 'attention_mask':Tensor, 'position_ids':Tensor, ...}
|
101
|
+
# ComposeTokenizer: {'input_ids':{'model1':Tensor, 'model2':Tensor}, ...}
|
102
|
+
token_info = ComposeTokenizer.tokenize_ex(self.tokenizer, prompt, truncation=True, padding="max_length",
|
103
|
+
return_tensors="pt", squeeze=True)
|
104
|
+
data = {'prompt':token_info.input_ids}
|
105
|
+
if 'attention_mask' in data:
|
106
|
+
data['attn_mask'] = data['attention_mask']
|
105
107
|
if 'position_ids' in token_info:
|
106
|
-
data['position_ids'] = token_info
|
107
|
-
|
108
|
+
data['position_ids'] = token_info['position_ids']
|
108
109
|
return data
|
109
110
|
|
110
111
|
def __repr__(self):
|
111
|
-
return f'TokenizeHandler(\nencoder_attention_mask={self.encoder_attention_mask}, tokenizer={self.tokenizer}\n)'
|
112
|
+
return f'TokenizeHandler(\nencoder_attention_mask={self.encoder_attention_mask}, tokenizer={self.tokenizer}\n)'
|
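Note that the rewritten handle() checks `'attention_mask' in data` immediately after creating `data` with only a 'prompt' key, so the attn_mask branch can never fire. A sketch of the presumed intent (an assumption, not the released code):

```python
def route_token_info(token_info):
    # Presumed intent: key the checks off token_info; as released,
    # "'attention_mask' in data" is always False at that point.
    data = {'prompt': token_info['input_ids']}
    if 'attention_mask' in token_info:
        data['attn_mask'] = token_info['attention_mask']
    if 'position_ids' in token_info:
        data['position_ids'] = token_info['position_ids']
    return data
```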
hcpdiff/diffusion/sampler/base.py
CHANGED
@@ -1,8 +1,10 @@
 from typing import Tuple
 
 import torch
+from rainbowneko.utils import add_dims
 
 from .sigma_scheduler import SigmaScheduler
+from .timer import TimeSampler
 
 try:
     from diffusers.utils import randn_tensor
@@ -11,7 +13,8 @@ except:
     from diffusers.utils.torch_utils import randn_tensor
 
 class BaseSampler:
-    def __init__(self, sigma_scheduler: SigmaScheduler, pred_type='eps', target_type='eps',
+    def __init__(self, sigma_scheduler: SigmaScheduler, t_sampler:TimeSampler = None, pred_type='eps', target_type='eps',
+                 generator: torch.Generator = None):
         '''
         Some losses can only be calculated in a specific space. Such as SSIM in x0 space.
         The model pred need convert to target space.
@@ -19,6 +22,9 @@ class BaseSampler:
         :param pred_type: ['x0', 'eps', 'velocity', ..., None] The output space of the model
         :param target_type: ['x0', 'eps', 'velocity', ..., None] The space to calculate the loss
         '''
+        if t_sampler is None:
+            t_sampler = TimeSampler()
+        self.t_sampler = t_sampler
 
         self.sigma_scheduler = sigma_scheduler
         self.generator = generator
@@ -38,15 +44,20 @@ class BaseSampler:
 
     def add_noise(self, x, t) -> Tuple[torch.Tensor, torch.Tensor]:
         noise = self.make_nosie(x.shape, device=x.device)
-        alpha = self.sigma_scheduler.alpha(t).
-        sigma = self.sigma_scheduler.sigma(t).
+        alpha = add_dims(self.sigma_scheduler.alpha(t), x.ndim-1).to(x.device)
+        sigma = add_dims(self.sigma_scheduler.sigma(t), x.ndim-1).to(x.device)
         noisy_x = alpha*x+sigma*noise
         return noisy_x.to(dtype=x.dtype), noise.to(dtype=x.dtype)
 
-    def add_noise_rand_t(self, x):
-
+    def add_noise_rand_t(self, x, reso=None):
+        if x.ndim == 3:
+            B,L,C = x.shape
+            reso = L if reso is None else reso
+        else:
+            B,C,H,W = x.shape
+            reso = H*W if reso is None else reso
         # timesteps: [0, 1]
-        timesteps = self.
+        timesteps = self.t_sampler.sample(shape=(B,), reso=reso)
         timesteps = timesteps.to(x.device)
         noisy_x, noise = self.add_noise(x, timesteps)
 
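add_noise now broadcasts the per-sample alpha/sigma over the remaining dimensions of x via rainbowneko's add_dims, and add_noise_rand_t infers a resolution (sequence length for 3D latents, H*W for 4D) to feed the new resolution-aware time samplers. A minimal stand-in for add_dims, inferred from how it is used here (an assumption about the helper, not its actual source):

```python
import torch

def add_dims(t: torch.Tensor, n: int) -> torch.Tensor:
    # Append n singleton dims so a (B,) tensor broadcasts over a (B, ...) batch.
    return t.reshape(t.shape + (1,) * n)

alpha = add_dims(torch.rand(4), 3)  # shape (4, 1, 1, 1): broadcasts over a BCHW tensor
```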
hcpdiff/diffusion/sampler/diffusers.py
CHANGED
@@ -22,8 +22,8 @@ class DiffusersSampler(BaseSampler):
 
     def c_in(self, t):
         one = torch.ones_like(t)
-        if hasattr(self.scheduler, '_step_index'):
-            self.scheduler._step_index = None
+        # if hasattr(self.scheduler, '_step_index'):
+        #     self.scheduler._step_index = None
         return self.scheduler.scale_model_input(one, t)
 
     def get_timesteps(self, N_steps, device='cuda'):
@@ -35,7 +35,8 @@ class DiffusersSampler(BaseSampler):
 
     def add_noise(self, x, t):
         noise = randn_tensor(x.shape, generator=self.generator, device=x.device, dtype=x.dtype)
-
+        t_in = self.sigma_scheduler.c_noise(t)
+        return self.scheduler.add_noise(x, noise, t_in), noise
 
     def prepare_extra_step_kwargs(self, scheduler, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
hcpdiff/diffusion/sampler/sigma_scheduler/base.py
CHANGED
@@ -7,21 +7,21 @@ class SigmaScheduler:
         return t
 
     def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
-        '''
+        r'''
         x(t) = \alpha(t)*x(0) + \sigma(t)*eps
         :param t: 0-1, rate of time step
         '''
         raise NotImplementedError
 
     def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
-        '''
+        r'''
         x(t) = \alpha(t)*x(0) + \sigma(t)*eps
         :param t: 0-1, rate of time step
         '''
         raise NotImplementedError
 
     def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=True) -> Tuple[torch.Tensor, torch.Tensor]:
-        '''
+        r'''
         v(t) = dx(t)/dt = d\alpha(t)/dt * x(0) + d\sigma(t)/dt *eps
         :param t: 0-1, rate of time step
         :return: d\alpha(t)/dt, d\sigma(t)/dt
@@ -63,14 +63,14 @@ class SigmaScheduler:
         return torch.ones_like(t, dtype=torch.float32)
 
     def c_skip(self, t: Union[float, torch.Tensor]):
-        '''
+        r'''
         \hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
         :param t: 0-1, rate of time step
         '''
         return 1./self.alpha(t)
 
     def c_out(self, t: Union[float, torch.Tensor]):
-        '''
+        r'''
         \hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
         :param t: 0-1, rate of time step
         '''
@@ -78,12 +78,3 @@ class SigmaScheduler:
 
     def c_noise(self, t: Union[float, torch.Tensor]):
         return t
-
-    def sample(self, min_t=0.0, max_t=1.0, shape=(1,)) -> torch.Tensor:
-        if isinstance(min_t, float):
-            min_t = torch.full(shape, min_t)
-        if isinstance(max_t, float):
-            max_t = torch.full(shape, max_t)
-
-        t = torch.lerp(min_t, max_t, torch.rand_like(min_t))
-        return t
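The docstring edits above only add the r-prefix so LaTeX-style escapes like \alpha and \sigma survive as literal backslashes instead of being parsed as string escapes (\a is the ASCII bell, and unrecognized escapes trigger SyntaxWarning on newer CPython); the removed sample() helper reappears on the new TimeSampler hierarchy in timer/base.py below. A quick illustration of the r-prefix point:

```python
plain = 'x(t) = \alpha(t)*x(0)'   # '\a' becomes the bell character '\x07'
raw = r'x(t) = \alpha(t)*x(0)'    # backslash preserved verbatim
assert '\x07' in plain and '\\alpha' in raw
```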
hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py
CHANGED
@@ -2,8 +2,9 @@ import math
 from typing import Union, Tuple, Callable
 
 import torch
-
 from hcpdiff.utils import invert_func
+from rainbowneko.utils import add_dims
+
 from .base import SigmaScheduler
 
 class DDPMDiscreteSigmaScheduler(SigmaScheduler):
@@ -93,15 +94,15 @@ class DDPMDiscreteSigmaScheduler(SigmaScheduler):
 
     def get_post_mean(self, t, x_0, x_t):
         t = (t*len(self.sigmas)).long()
-        return self.posterior_mean_coef1[t].
+        return add_dims(self.posterior_mean_coef1[t].to(t.device), x_0.ndim-1)*x_0+add_dims(self.posterior_mean_coef2[t].to(t.device), x_t.ndim-1)*x_t
 
-    def get_post_log_var(self, t, x_t_var=None):
+    def get_post_log_var(self, t, ndim, x_t_var=None):
         t = (t*len(self.sigmas)).long()
-        min_log = self.posterior_log_variance_clipped[t].
+        min_log = add_dims(self.posterior_log_variance_clipped[t].to(t.device), ndim-1)
         if x_t_var is None:
             return min_log
         else:
-            max_log = self.betas.log()[t].
+            max_log = add_dims(self.betas.log()[t].to(t.device), ndim-1)
             # The model_var_values is [-1, 1] for [min_var, max_var].
             frac = (x_t_var+1)/2
             model_log_variance = frac*max_log+(1-frac)*min_log
@@ -201,7 +202,7 @@ class DDPMContinuousSigmaScheduler(SigmaScheduler):
         B = 1-beta_s
         B_At = B-A*t
 
-        #
+        # eps for stable
         eps = 1e-12
         B = torch.clamp(B, min=eps)
         B_At = torch.clamp(B_At, min=eps)
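get_post_log_var now takes the target ndim explicitly so the per-timestep scalars can be broadcast without inspecting x. The interpolation itself is the improved-DDPM learned-variance parameterization: the model output v in [-1, 1] mixes log beta_t (max) against the clipped posterior log-variance (min). A worked example with illustrative numbers:

```python
# Learned-variance interpolation as in get_post_log_var (values illustrative).
import math

min_log, max_log = math.log(1e-4), math.log(2e-2)  # log beta_tilde_t, log beta_t
v = 0.0                                            # model variance output in [-1, 1]
frac = (v + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
print(math.exp(model_log_variance))                # geometric mean, ~1.41e-3
```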
hcpdiff/diffusion/sampler/sigma_scheduler/edm.py
CHANGED
@@ -21,7 +21,7 @@ class EDMSigmaScheduler(SigmaScheduler):
         return torch.lerp(min_inv_rho, max_inv_rho, t)**self.rho
 
     def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
-        '''
+        r'''
         x_t = c_in(t) * (x(0) + \sigma(t)*eps), eps~N(0,I)
         '''
         if isinstance(t, float):
@@ -31,7 +31,7 @@ class EDMSigmaScheduler(SigmaScheduler):
         return sigma_edm/torch.sqrt(sigma_edm**2+self.sigma_data**2)
 
     def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
-        '''
+        r'''
         x_t = c_in(t) * (x(0) + \sigma(t)*eps), eps~N(0,I)
         '''
         if isinstance(t, float):
@@ -41,7 +41,7 @@ class EDMSigmaScheduler(SigmaScheduler):
         return 1./torch.sqrt(sigma_edm**2+self.sigma_data**2)
 
     def c_skip(self, t: Union[float, torch.Tensor]):
-        '''
+        r'''
         \hat{x}(0) = c_skip(t)*(x(t)/c_in(t)) + c_out(t)*f(x(t))
         :param t: 0-1, rate of time step
         '''
@@ -49,7 +49,7 @@ class EDMSigmaScheduler(SigmaScheduler):
         return self.sigma_data**2/torch.sqrt(sigma_edm**2+self.sigma_data**2)
 
     def c_out(self, t: Union[float, torch.Tensor]):
-        '''
+        r'''
         \hat{x}(0) = c_skip(t)*(x(t)/c_in(t)) + c_out(t)*f(x(t))
         :param t: 0-1, rate of time step
         '''
hcpdiff/diffusion/sampler/sigma_scheduler/flow.py
CHANGED
@@ -23,7 +23,7 @@ class FlowSigmaScheduler(SigmaScheduler):
         return 1-t
 
     def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=False) -> Tuple[torch.Tensor, torch.Tensor]:
-        '''
+        r'''
         v(t) = dx(t)/dt = d\alpha(t)/dt * x(0) + d\sigma(t)/dt *eps
         :param t: 0-1, rate of time step
         :return: d\alpha(t)/dt, d\sigma(t)/dt
@@ -59,14 +59,14 @@ class FlowSigmaScheduler(SigmaScheduler):
         return 1-sigma
 
     def c_skip(self, t: Union[float, torch.Tensor]):
-        '''
+        r'''
         \hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
         :param t: 0-1, rate of time step
         '''
         return 1.
 
     def c_out(self, t: Union[float, torch.Tensor]):
-        '''
+        r'''
         \hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
         :param t: 0-1, rate of time step
         '''
hcpdiff/diffusion/sampler/timer/base.py
ADDED
@@ -0,0 +1,26 @@
+import torch
+
+class TimeSampler:
+    def sample(self, min_t=0.0, max_t=1.0, shape=(1,), reso=0) -> torch.Tensor:
+        if isinstance(min_t, float):
+            min_t = torch.full(shape, min_t)
+        if isinstance(max_t, float):
+            max_t = torch.full(shape, max_t)
+
+        t = torch.lerp(min_t, max_t, torch.rand_like(min_t))
+        return t
+
+class LogitNormalSampler(TimeSampler):
+    def __init__(self, mean=0.0, std=1.0):
+        self.mean = mean
+        self.std = std
+
+    def sample(self, min_t=0.0, max_t=1.0, shape=(1,), reso=0) -> torch.Tensor:
+        if isinstance(min_t, float):
+            min_t = torch.full(shape, min_t)
+        if isinstance(max_t, float):
+            max_t = torch.full(shape, max_t)
+
+        t = torch.sigmoid(torch.normal(mean=self.mean, std=self.std, size=shape))
+        t = torch.lerp(min_t, max_t, t)
+        return t
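LogitNormalSampler draws u ~ N(mean, std) and maps it through a sigmoid (the logit-normal timestep schedule popularized by the SD3 report), concentrating samples at mid-range noise levels before lerping into [min_t, max_t]. A quick empirical check:

```python
import torch

t = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=(100_000,)))
print(t.mean().item())                                   # ~0.5
print(((t > 0.25) & (t < 0.75)).float().mean().item())   # most mass is mid-range
```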
hcpdiff/diffusion/sampler/timer/shift.py
ADDED
@@ -0,0 +1,49 @@
+import torch
+import math
+from torch import Tensor
+
+from .base import TimeSampler
+
+class ShiftTimeSampler(TimeSampler):
+    def __init__(self, t_sampler: TimeSampler = None, base_reso=1024*1024):
+        self.t_sampler = t_sampler
+        self.base_reso = base_reso
+
+    def sample(self, min_t=0.0, max_t=1.0, shape=(1,), reso=0) -> torch.Tensor:
+        t = self.t_sampler.sample(min_t, max_t, shape)
+        shift = math.sqrt(self.base_reso/(reso))
+        t = (t*shift)/(1+(shift-1)*t)
+        return t
+
+class FluxShiftTimeSampler(TimeSampler):
+    def __init__(self, t_sampler: TimeSampler = None, base_shift: float = 0.5, max_shift: float = 1.15, base_reso=256, max_reso=4096):
+        self.t_sampler = t_sampler
+        self.base_shift = base_shift
+        self.max_shift = max_shift
+        self.base_reso = base_reso
+        self.max_reso = max_reso
+
+    def time_shift(self, mu: float|Tensor, sigma: float, t: Tensor):
+        if torch.is_tensor(mu):
+            mu = mu.to(t.device)
+            return torch.exp(mu)/(torch.exp(mu)+(1/t-1)**sigma)
+        else:
+            return math.exp(mu)/(math.exp(mu)+(1/t-1)**sigma)
+
+    def get_lin_function(self, xi, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
+        '''
+        ^
+        |        .(x2,y2)
+        |       /
+        |  . (x1,y1)
+        |_________>
+        '''
+        m = (y2-y1)/(x2-x1)
+        b = y1-m*x1
+        return m*xi+b
+
+    def sample(self, min_t=0.0, max_t=1.0, shape=(1,), reso=0) -> torch.Tensor:
+        mu = self.get_lin_function(reso, x1=self.base_reso, y1=self.base_shift, x2=self.max_reso, y2=self.max_shift)
+        t = self.t_sampler.sample(min_t, max_t, shape)
+        t = self.time_shift(mu, 1.0, t)
+        return t
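ShiftTimeSampler warps a base sample with t' = s*t / (1 + (s-1)*t), s = sqrt(base_reso/reso), while FluxShiftTimeSampler first interpolates mu linearly between (base_reso, base_shift) and (max_reso, max_shift) and then applies exp(mu) / (exp(mu) + (1/t - 1)^sigma), matching the resolution-dependent schedule shifting used by Flux-style pipelines. A small numeric check of the warp (note both classes expect a wrapped t_sampler and a nonzero reso at sample time):

```python
import math

def shift(t: float, s: float) -> float:
    # Same warp as ShiftTimeSampler.sample: monotone on [0, 1], fixing 0 and 1.
    return (t * s) / (1 + (s - 1) * t)

print(shift(0.5, 2.0))  # 0.666...: s > 1 pushes mid timesteps upward
print(shift(0.5, 0.5))  # 0.333...: s < 1 pulls them downward
```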
hcpdiff/easy/__init__.py
CHANGED
@@ -1,2 +1,3 @@
-from .model import SD15_auto_loader, SDXL_auto_loader, PixArt_auto_loader, ControlNet_SD15, make_controlnet_handler
+from .model import (SD15_auto_loader, SDXL_auto_loader, PixArt_auto_loader, ControlNet_SD15, make_controlnet_handler, Flux_auto_loader,
+                    auto_load_wrapper)
 from .sampler import Diffusers_SD
hcpdiff/easy/cfg/sd15_train.py
CHANGED
hcpdiff/easy/model/__init__.py
CHANGED
@@ -1,2 +1,2 @@
-from .loader import SD15_auto_loader, SDXL_auto_loader, PixArt_auto_loader
+from .loader import SD15_auto_loader, SDXL_auto_loader, PixArt_auto_loader, Flux_auto_loader, auto_load_wrapper
 from .cnet import ControlNet_SD15, make_controlnet_handler