hcpdiff 0.9.1__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/__init__.py +4 -4
- hcpdiff/ckpt_manager/__init__.py +4 -5
- hcpdiff/ckpt_manager/ckpt.py +24 -0
- hcpdiff/ckpt_manager/format/__init__.py +4 -0
- hcpdiff/ckpt_manager/format/diffusers.py +59 -0
- hcpdiff/ckpt_manager/format/emb.py +21 -0
- hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
- hcpdiff/ckpt_manager/format/sd_single.py +41 -0
- hcpdiff/ckpt_manager/loader.py +64 -0
- hcpdiff/data/__init__.py +4 -28
- hcpdiff/data/cache/__init__.py +1 -0
- hcpdiff/data/cache/vae.py +102 -0
- hcpdiff/data/dataset.py +20 -0
- hcpdiff/data/handler/__init__.py +3 -0
- hcpdiff/data/handler/controlnet.py +18 -0
- hcpdiff/data/handler/diffusion.py +80 -0
- hcpdiff/data/handler/text.py +111 -0
- hcpdiff/data/source/__init__.py +1 -2
- hcpdiff/data/source/folder_class.py +12 -29
- hcpdiff/data/source/text2img.py +36 -74
- hcpdiff/data/source/text2img_cond.py +9 -15
- hcpdiff/diffusion/__init__.py +0 -0
- hcpdiff/diffusion/noise/__init__.py +2 -0
- hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
- hcpdiff/diffusion/noise/zero_terminal.py +39 -0
- hcpdiff/diffusion/sampler/__init__.py +5 -0
- hcpdiff/diffusion/sampler/base.py +72 -0
- hcpdiff/diffusion/sampler/ddpm.py +20 -0
- hcpdiff/diffusion/sampler/diffusers.py +66 -0
- hcpdiff/diffusion/sampler/edm.py +22 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
- hcpdiff/easy/__init__.py +2 -0
- hcpdiff/easy/cfg/__init__.py +3 -0
- hcpdiff/easy/cfg/sd15_train.py +201 -0
- hcpdiff/easy/cfg/sdxl_train.py +140 -0
- hcpdiff/easy/cfg/t2i.py +177 -0
- hcpdiff/easy/model/__init__.py +2 -0
- hcpdiff/easy/model/cnet.py +31 -0
- hcpdiff/easy/model/loader.py +79 -0
- hcpdiff/easy/sampler.py +46 -0
- hcpdiff/evaluate/__init__.py +1 -0
- hcpdiff/evaluate/previewer.py +60 -0
- hcpdiff/loss/__init__.py +4 -1
- hcpdiff/loss/base.py +41 -0
- hcpdiff/loss/gw.py +35 -0
- hcpdiff/loss/ssim.py +37 -0
- hcpdiff/loss/vlb.py +79 -0
- hcpdiff/loss/weighting.py +66 -0
- hcpdiff/models/__init__.py +2 -2
- hcpdiff/models/cfg_context.py +17 -14
- hcpdiff/models/compose/compose_hook.py +44 -23
- hcpdiff/models/compose/compose_tokenizer.py +21 -8
- hcpdiff/models/compose/sdxl_composer.py +4 -4
- hcpdiff/models/controlnet.py +16 -16
- hcpdiff/models/lora_base_patch.py +14 -25
- hcpdiff/models/lora_layers.py +3 -9
- hcpdiff/models/lora_layers_patch.py +14 -24
- hcpdiff/models/text_emb_ex.py +84 -6
- hcpdiff/models/textencoder_ex.py +54 -18
- hcpdiff/models/wrapper/__init__.py +3 -0
- hcpdiff/models/wrapper/pixart.py +19 -0
- hcpdiff/models/wrapper/sd.py +218 -0
- hcpdiff/models/wrapper/utils.py +20 -0
- hcpdiff/parser/__init__.py +1 -0
- hcpdiff/parser/embpt.py +32 -0
- hcpdiff/tools/convert_caption_txt2json.py +1 -1
- hcpdiff/tools/dataset_generator.py +94 -0
- hcpdiff/tools/download_hf_model.py +24 -0
- hcpdiff/tools/init_proj.py +3 -21
- hcpdiff/tools/lora_convert.py +18 -17
- hcpdiff/tools/save_model.py +12 -0
- hcpdiff/tools/sd2diffusers.py +1 -1
- hcpdiff/train_colo.py +1 -1
- hcpdiff/train_deepspeed.py +1 -1
- hcpdiff/trainer_ac.py +79 -0
- hcpdiff/trainer_ac_single.py +31 -0
- hcpdiff/utils/__init__.py +0 -2
- hcpdiff/utils/inpaint_pipe.py +7 -2
- hcpdiff/utils/net_utils.py +29 -6
- hcpdiff/utils/pipe_hook.py +24 -7
- hcpdiff/utils/utils.py +21 -4
- hcpdiff/workflow/__init__.py +15 -10
- hcpdiff/workflow/daam/__init__.py +1 -0
- hcpdiff/workflow/daam/act.py +66 -0
- hcpdiff/workflow/daam/hook.py +109 -0
- hcpdiff/workflow/diffusion.py +114 -125
- hcpdiff/workflow/fast.py +31 -0
- hcpdiff/workflow/flow.py +67 -0
- hcpdiff/workflow/io.py +36 -130
- hcpdiff/workflow/model.py +46 -43
- hcpdiff/workflow/text.py +78 -46
- hcpdiff/workflow/utils.py +32 -12
- hcpdiff/workflow/vae.py +37 -38
- hcpdiff-2.1.dist-info/METADATA +285 -0
- hcpdiff-2.1.dist-info/RECORD +114 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
- hcpdiff-2.1.dist-info/entry_points.txt +5 -0
- hcpdiff/ckpt_manager/base.py +0 -16
- hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
- hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
- hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -64
- hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
- hcpdiff/data/bucket.py +0 -358
- hcpdiff/data/caption_loader.py +0 -80
- hcpdiff/data/cond_dataset.py +0 -40
- hcpdiff/data/crop_info_dataset.py +0 -40
- hcpdiff/data/data_processor.py +0 -33
- hcpdiff/data/pair_dataset.py +0 -146
- hcpdiff/data/sampler.py +0 -54
- hcpdiff/data/source/base.py +0 -30
- hcpdiff/data/utils.py +0 -80
- hcpdiff/deprecated/__init__.py +0 -1
- hcpdiff/deprecated/cfg_converter.py +0 -81
- hcpdiff/deprecated/lora_convert.py +0 -31
- hcpdiff/infer_workflow.py +0 -57
- hcpdiff/loggers/__init__.py +0 -13
- hcpdiff/loggers/base_logger.py +0 -76
- hcpdiff/loggers/cli_logger.py +0 -40
- hcpdiff/loggers/preview/__init__.py +0 -1
- hcpdiff/loggers/preview/image_previewer.py +0 -149
- hcpdiff/loggers/tensorboard_logger.py +0 -30
- hcpdiff/loggers/wandb_logger.py +0 -31
- hcpdiff/loggers/webui_logger.py +0 -9
- hcpdiff/loss/min_snr_loss.py +0 -52
- hcpdiff/models/layers.py +0 -81
- hcpdiff/models/plugin.py +0 -348
- hcpdiff/models/wrapper.py +0 -75
- hcpdiff/noise/__init__.py +0 -3
- hcpdiff/noise/noise_base.py +0 -16
- hcpdiff/noise/pyramid_noise.py +0 -50
- hcpdiff/noise/zero_terminal.py +0 -44
- hcpdiff/train_ac.py +0 -566
- hcpdiff/train_ac_single.py +0 -39
- hcpdiff/utils/caption_tools.py +0 -105
- hcpdiff/utils/cfg_net_tools.py +0 -321
- hcpdiff/utils/cfg_resolvers.py +0 -16
- hcpdiff/utils/ema.py +0 -52
- hcpdiff/utils/img_size_tool.py +0 -248
- hcpdiff/vis/__init__.py +0 -3
- hcpdiff/vis/base_interface.py +0 -12
- hcpdiff/vis/disk_interface.py +0 -48
- hcpdiff/vis/webui_interface.py +0 -17
- hcpdiff/viser_fast.py +0 -138
- hcpdiff/visualizer.py +0 -265
- hcpdiff/visualizer_reloadable.py +0 -237
- hcpdiff/workflow/base.py +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -59
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero2.json +0 -32
- hcpdiff-0.9.1.data/data/hcpdiff/cfgs/zero3.json +0 -39
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
- hcpdiff-0.9.1.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
- hcpdiff-0.9.1.dist-info/METADATA +0 -199
- hcpdiff-0.9.1.dist-info/RECORD +0 -160
- hcpdiff-0.9.1.dist-info/entry_points.txt +0 -2
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
- {hcpdiff-0.9.1.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,22 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
from .base import BaseSampler
|
4
|
+
from .sigma_scheduler import SigmaScheduler
|
5
|
+
|
6
|
+
class EDMSampler(BaseSampler):
    """Sampler exposing the EDM preconditioning coefficients (Karras et al., 2022).

    With ``v = sigma**2 + sigma_data**2`` the coefficients are::

        c_in   = 1 / sqrt(v)
        c_out  = sigma * sigma_data / sqrt(v)
        c_skip = sigma_data**2 / v

    ``sigma`` is expected to be a tensor (``.sqrt()`` is called on it).
    ``denoise`` is left to concrete subclasses.
    """

    def __init__(self, sigma_scheduler: SigmaScheduler, generator: torch.Generator = None, sigma_data: float = 1.0, sigma_thr=1000):
        super().__init__(sigma_scheduler, generator)
        # sigma_data: assumed standard deviation of the data, as in the EDM formulas.
        self.sigma_data = sigma_data
        # sigma_thr is only stored here; nothing in this class reads it.
        self.sigma_thr = sigma_thr

    def _sigma_sq_total(self, sigma):
        """Return ``sigma**2 + sigma_data**2``, the term shared by all three coefficients."""
        return sigma**2+self.sigma_data**2

    def c_in(self, sigma):
        """EDM input scaling: ``1 / sqrt(sigma**2 + sigma_data**2)``."""
        return 1/self._sigma_sq_total(sigma).sqrt()

    def c_out(self, sigma):
        """EDM output scaling: ``sigma * sigma_data / sqrt(sigma**2 + sigma_data**2)``."""
        return (sigma*self.sigma_data)/self._sigma_sq_total(sigma).sqrt()

    def c_skip(self, sigma):
        """EDM skip weight: ``sigma_data**2 / (sigma**2 + sigma_data**2)``."""
        return self.sigma_data**2/self._sigma_sq_total(sigma)

    def denoise(self, x, sigma, eps=None, generator=None):
        """Not implemented here; concrete EDM samplers must override this."""
        raise NotImplementedError
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from typing import Union, Tuple
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
class SigmaScheduler:
    """Interface mapping a normalized time rate ``t`` to a noise level ``sigma``.

    Both methods must be overridden; the base implementations only raise
    ``NotImplementedError``.
    """

    def get_sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return the sigma for time rate ``t``.

        :param t: 0-1, rate of time step
        """
        raise NotImplementedError()

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw random sigmas whose rates lie between ``min_rate`` and ``max_rate``.

        Implementations return a ``(sigma, t)`` pair of tensors.
        """
        raise NotImplementedError()
|
@@ -0,0 +1,197 @@
|
|
1
|
+
import torch
|
2
|
+
import math
|
3
|
+
from typing import Union, Tuple
|
4
|
+
from hcpdiff.utils import linear_interp
|
5
|
+
from .base import SigmaScheduler
|
6
|
+
|
7
|
+
class DDPMDiscreteSigmaScheduler(SigmaScheduler):
    """Discrete DDPM noise schedule expressed as a table of sigmas.

    For step ``i``::

        sigmas[i] = sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i])

    where ``alphas_cumprod`` is the cumulative product of ``1 - betas``. The
    coefficients of the posterior q(x_{t-1} | x_t, x_0) used for VLB calculation
    are precomputed as well.

    Time is exposed as a rate ``t`` in ``[0, 1]``. It is mapped to a step index via
    ``floor(t * len(sigmas))``, clamped to the last step so that ``t == 1`` is valid.
    """

    def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, num_timesteps=1000):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.betas = self.make_betas(beta_schedule, linear_start, linear_end, num_timesteps)
        alphas = 1.0-self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.sigmas = ((1-self.alphas_cumprod)/self.alphas_cumprod).sqrt()

        # for VLB calculation: coefficients of the posterior q(x_{t-1} | x_t, x_0)
        self.alphas_cumprod_prev = torch.cat([alphas.new_tensor([1.0]), self.alphas_cumprod[:-1]])
        self.posterior_mean_coef1 = self.betas*torch.sqrt(self.alphas_cumprod_prev)/(1.0-self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0-self.alphas_cumprod_prev)*torch.sqrt(alphas)/(1.0-self.alphas_cumprod)

        self.posterior_variance = self.betas*(1.0-self.alphas_cumprod_prev)/(1.0-self.alphas_cumprod)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]]))

    @property
    def sigma_min(self):
        """Sigma of the first step."""
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """Sigma of the last step."""
        return self.sigmas[-1]

    @staticmethod
    def _rate_tensor(rate, shape):
        """Return ``rate`` unchanged if it is a tensor, else a float tensor of ``shape`` filled with it.

        Accepts ints as well as floats (an int bound previously reached ``torch.lerp`` unconverted and crashed).
        """
        if isinstance(rate, torch.Tensor):
            return rate
        return torch.full(shape, float(rate))

    def _rate_to_index(self, t):
        """Convert a rate in ``[0, 1]`` (number or tensor) to a step index clamped to ``[0, len(sigmas)-1]``.

        The clamp makes ``t == 1.0`` select the last step instead of indexing one past the end.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.as_tensor(t)
        return (t*len(self.sigmas)).long().clamp(0, len(self.sigmas)-1)

    @staticmethod
    def _gather(table, idx, device):
        """Return ``table[idx]`` reshaped to ``(-1, 1, 1, 1)`` and moved to ``device``.

        ``idx`` is moved to ``table``'s device first. Indexing a CPU table with a CUDA index raises,
        so this lets ``t`` and the operands live on any device.
        """
        return table[idx.to(table.device)].view(-1, 1, 1, 1).to(device)

    def get_sigma(self, t: Union[float, torch.Tensor]):
        """Return the sigma of the discrete step selected by rate ``t`` (0-1, inclusive)."""
        return self.sigmas[self._rate_to_index(t)]

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
        """Sample rates uniformly between ``min_rate`` and ``max_rate``; return ``(sigma, t)``.

        Scalar bounds are broadcast to ``shape``; tensor bounds are used as given.
        """
        min_rate = self._rate_tensor(min_rate, shape)
        max_rate = self._rate_tensor(max_rate, shape)

        t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
        t_scale = (t*(self.num_timesteps-1e-5)).long()  # step index in [0, num_timesteps-1]
        return self.sigmas[t_scale], t

    def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
        """Return the rate of the step whose sigma is closest to ``sigma``."""
        t = (self.sigmas-sigma).abs().argmin()
        return t/self.num_timesteps

    def get_post_mean(self, t, x_0, x_t):
        """Mean of the posterior q(x_{t-1} | x_t, x_0) for rate ``t``.

        The coefficients are reshaped to ``(B, 1, 1, 1)``, so ``x_0``/``x_t`` are presumably
        ``(B, C, H, W)`` latents (TODO confirm). The result lives on the device of ``x_0``/``x_t``.
        """
        idx = self._rate_to_index(t)
        coef1 = self._gather(self.posterior_mean_coef1, idx, x_0.device)
        coef2 = self._gather(self.posterior_mean_coef2, idx, x_t.device)
        return coef1*x_0 + coef2*x_t

    def get_post_log_var(self, t, x_t_var=None):
        """Log-variance of the posterior for rate ``t``.

        Without ``x_t_var`` the clipped posterior log-variance is returned, on the device of ``t``.
        Otherwise ``x_t_var`` (in ``[-1, 1]``) interpolates between that minimum and ``log(beta)``.
        In that case the result lives on ``x_t_var``'s device.
        """
        idx = self._rate_to_index(t)
        device = idx.device if x_t_var is None else x_t_var.device
        min_log = self._gather(self.posterior_log_variance_clipped, idx, device)
        if x_t_var is None:
            return min_log

        max_log = self._gather(self.betas, idx, device).log()
        # The model_var_values is [-1, 1] for [min_var, max_var].
        frac = (x_t_var+1)/2
        model_log_variance = frac*max_log+(1-frac)*min_log
        return model_log_variance

    @staticmethod
    def betas_for_alpha_bar(
        num_diffusion_timesteps,
        max_beta=0.999,
        alpha_transform_type="cosine",
    ):
        """
        Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
        (1-beta) over time from t = [0,1].

        Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
        to that part of the diffusion process.

        Args:
            num_diffusion_timesteps (`int`): the number of betas to produce.
            max_beta (`float`): the maximum beta to use; use values lower than 1 to
                         prevent singularities.
            alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                         Choose from `cosine` or `exp`

        Returns:
            betas (`torch.Tensor`, float32): the betas used by the scheduler to step the model outputs

        Raises:
            ValueError: if ``alpha_transform_type`` is not ``cosine`` or ``exp``.
        """
        if alpha_transform_type == "cosine":

            def alpha_bar_fn(t):
                return math.cos((t+0.008)/1.008*math.pi/2)**2

        elif alpha_transform_type == "exp":

            def alpha_bar_fn(t):
                return math.exp(t*-12.0)

        else:
            raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

        betas = []
        for i in range(num_diffusion_timesteps):
            t1 = i/num_diffusion_timesteps
            t2 = (i+1)/num_diffusion_timesteps
            betas.append(min(1-alpha_bar_fn(t2)/alpha_bar_fn(t1), max_beta))
        return torch.tensor(betas, dtype=torch.float32)

    @staticmethod
    def make_betas(beta_schedule, beta_start, beta_end, num_train_timesteps, betas=None):
        """Build a float32 beta table of length ``num_train_timesteps``.

        Explicit ``betas`` take precedence. Otherwise ``beta_schedule`` selects one of
        ``linear``, ``scaled_linear``, ``squaredcos_cap_v2`` or ``sigmoid``.

        Raises:
            NotImplementedError: for an unknown ``beta_schedule``.
        """
        if betas is not None:
            return torch.tensor(betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)**2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            return DDPMDiscreteSigmaScheduler.betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            return torch.sigmoid(betas)*(beta_end-beta_start)+beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented.")
|
132
|
+
|
133
|
+
class DDPMContinuousSigmaScheduler(DDPMDiscreteSigmaScheduler):
    """DDPM schedule whose sigmas are linearly interpolated between the discrete steps."""

    def get_sigma(self, t: Union[float, torch.Tensor]):
        # NOTE(review): ``t`` (a 0-1 rate) is passed to ``linear_interp`` unscaled here, while
        # ``sample_sigma`` passes an index-scaled position ``t*(num_timesteps-1)``. Both cannot
        # match ``linear_interp``'s contract; confirm which convention it expects.
        if isinstance(t, float):
            t = torch.tensor(t)
        return linear_interp(self.sigmas, t)

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
        """Sample rates between ``min_rate`` and ``max_rate``; return ``(interpolated sigma, t)``.

        Scalar (int or float) bounds are broadcast to ``shape``; tensor bounds are used as given.
        """
        if not isinstance(min_rate, torch.Tensor):
            min_rate = torch.full(shape, float(min_rate))
        if not isinstance(max_rate, torch.Tensor):
            max_rate = torch.full(shape, float(max_rate))

        t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
        t_scale = (t*(self.num_timesteps-1-1e-5))  # continuous position in [0, num_timesteps-1)

        return linear_interp(self.sigmas, t_scale), t

    def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
        """Invert the piecewise-linear sigma curve: return the fractional step index of ``sigma``.

        Assumes ``self.sigmas`` is strictly increasing, as it is for the positive betas built in
        ``__init__``. The bracketing interval ``[lo, lo+1]`` is found with ``searchsorted``, and the
        position inside it is interpolated linearly. Results are clamped to
        ``[0, num_timesteps-1]``. Scalars and tensors of any shape are accepted; the output has
        ``sigma``'s shape and lives on ``self.sigmas``'s device.

        NOTE(review): this returns a step index, while the discrete parent's ``sigma_to_t``
        returns a 0-1 rate — confirm which unit callers expect.
        """
        sigmas = self.sigmas
        sigma = torch.as_tensor(sigma, dtype=sigmas.dtype, device=sigmas.device)
        flat = sigma.reshape(-1)
        # first index whose sigma is >= the target, kept inside [1, N-1] so lo = hi-1 is valid
        hi = torch.searchsorted(sigmas, flat).clamp(1, len(sigmas)-1)
        lo = hi-1
        frac = ((flat-sigmas[lo])/(sigmas[hi]-sigmas[lo])).clamp(0, 1)
        return (lo+frac).reshape(sigma.shape)
|
156
|
+
|
157
|
+
class TimeSigmaScheduler(SigmaScheduler):
    """Scheduler whose "sigma" is the time itself.

    ``get_sigma`` echoes the rate back. ``sample_sigma`` returns integer step indices in
    ``[0, num_timesteps-1]`` together with the sampled rates.
    """

    def __init__(self, num_timesteps=1000):
        super().__init__()
        self.num_timesteps = num_timesteps

    def get_sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return ``t`` unchanged.

        :param t: 0-1, rate of time step
        """
        return t

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample rates between the bounds; return ``(step_index, rate)``.

        Float bounds are broadcast to ``shape``; anything else is used as given.
        """
        lo = torch.full(shape, min_rate) if isinstance(min_rate, float) else min_rate
        hi = torch.full(shape, max_rate) if isinstance(max_rate, float) else max_rate

        rate = torch.lerp(lo, hi, torch.rand_like(lo))
        steps = (rate*(self.num_timesteps-1e-5)).long()  # [0, num_timesteps-1]
        return steps, rate
|
177
|
+
|
178
|
+
if __name__ == '__main__':
    # Visual check: compare the DDPM sigma table with a rho-interpolated sigma curve
    # (rho=1 means plain linear interpolation between sigma_min and sigma_max).
    from matplotlib import pyplot as plt
    import numpy as np

    sched = DDPMDiscreteSigmaScheduler()
    print(sched.sigma_min, sched.sigma_max)

    rate = torch.linspace(0, 1, 1000)
    rho = 1.
    inv_lo = sched.sigma_min**(1/rho)
    inv_hi = sched.sigma_max**(1/rho)
    sigma_curve = (inv_lo+rate*(inv_hi-inv_lo))**rho
    # map each curve sigma back to a rate on the DDPM table (in log space)
    rate_curve = np.interp(sigma_curve.log().numpy(), sched.sigmas.log().numpy(), rate.numpy())

    # first plot in linear scale, then in log scale
    for transform in (lambda v: v, lambda v: v.log()):
        plt.figure()
        plt.plot(transform(sched.sigmas))
        plt.plot(rate_curve*1000, transform(sigma_curve))
        plt.show()
|
@@ -0,0 +1,48 @@
|
|
1
|
+
from typing import Union
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
from .base import SigmaScheduler
|
7
|
+
|
8
|
+
class EDMSigmaScheduler(SigmaScheduler):
    """Karras/EDM rho schedule.

    ``sigma(t) = (sigma_min**(1/rho) + t*(sigma_max**(1/rho) - sigma_min**(1/rho)))**rho``
    for a rate ``t`` in ``[0, 1]``.
    """

    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0, num_timesteps=1000):
        self.sigma_min = torch.tensor(sigma_min)
        self.sigma_max = torch.tensor(sigma_max)
        self.rho = rho

        self.num_timesteps = num_timesteps

    def get_sigma(self, t: Union[float, torch.Tensor]):
        """Return sigma for rate ``t`` (a float is converted to a tensor first)."""
        if isinstance(t, float):
            t = torch.tensor(t)

        exponent = 1/self.rho
        lo = self.sigma_min**exponent
        hi = self.sigma_max**exponent
        return torch.lerp(lo, hi, t)**self.rho

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
        """Sample rates between the bounds; return ``(sigma, t)``.

        Float bounds are broadcast to ``shape``; anything else is used as given.
        """
        lo = torch.full(shape, min_rate) if isinstance(min_rate, float) else min_rate
        hi = torch.full(shape, max_rate) if isinstance(max_rate, float) else max_rate

        rate = torch.lerp(lo, hi, torch.rand_like(lo))
        return self.get_sigma(rate), rate
|
32
|
+
|
33
|
+
class EDMRefSigmaScheduler(EDMSigmaScheduler):
    """EDM sigma sampling whose returned time is re-expressed on a reference schedule's time axis.

    Sigmas are drawn exactly as in ``EDMSigmaScheduler.sample_sigma``. The accompanying time is
    obtained by interpolating ``log(sigma)`` against the reference scheduler's log-sigmas, which
    are placed on an evenly spaced 0-1 grid.
    """

    def __init__(self, ref_scheduler, sigma_min=0.002, sigma_max=80.0, rho=7.0, num_timesteps=1000):
        super().__init__(sigma_min, sigma_max, rho, num_timesteps=num_timesteps)
        # log-sigmas of the reference schedule, clipped to avoid log(0).
        # np.interp requires these to be increasing.
        self.ref_sigmas = ref_scheduler.sigmas.cpu().clip(min=1e-8).log().numpy()
        self.ref_t = np.linspace(0, 1, len(self.ref_sigmas))

    def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
        """Sample EDM sigmas and return ``(sigma, t_rect)``.

        ``t_rect`` is the reference-schedule rate for each sigma. ``np.interp`` clamps sigmas
        outside the reference range to rate 0 or 1. ``t_rect`` has the same dtype as the sampled
        rates, where the previous code returned float64 from numpy.
        """
        sigma, t = super().sample_sigma(min_rate, max_rate, shape)
        log_sigma = sigma.cpu().clip(min=1e-8).log().numpy()
        # as_tensor handles both ndarray and numpy-scalar results of np.interp
        t_rect = torch.as_tensor(np.interp(log_sigma, self.ref_sigmas, self.ref_t), dtype=t.dtype)
        return sigma, t_rect
|
hcpdiff/easy/cfg/sd15_train.py
ADDED
@@ -0,0 +1,201 @@
|
|
1
|
+
import torch
|
2
|
+
from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE, plugin_saver
|
3
|
+
from rainbowneko.data import RatioBucket, FixedBucket
|
4
|
+
from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
|
5
|
+
from rainbowneko.utils import ConstantLR, Path_Like
|
6
|
+
|
7
|
+
from hcpdiff.data import TextImagePairDataset, Text2ImageSource, StableDiffusionHandler
|
8
|
+
from hcpdiff.data import VaeCache
|
9
|
+
from hcpdiff.easy import SD15_auto_loader
|
10
|
+
from hcpdiff.models import SD15Wrapper, TEHookCFG
|
11
|
+
from hcpdiff.models.lora_layers_patch import LoraLayer
|
12
|
+
|
13
|
+
@neko_cfg
def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5, clip_skip: int = 0,
                    dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SD15'):
    # Build a config that fully fine-tunes the SD1.5 denoiser (UNet).
    #
    # Starts from ``train_base`` + ``tuning_base`` and overrides mixed precision, the trainable
    # parameter group, checkpoint saving, optimizer/LR scheduler, the model wrapper and the data.
    # NOTE(review): ``@neko_cfg`` appears to rewrite the calls in this body into lazy config nodes
    # (hence ``_partial_=True``) — confirm against rainbowneko before restructuring this function.
    #
    # base_model:   checkpoint path given to ``SD15_auto_loader(ckpt_path=...)``.
    # train_steps:  ``train.train_steps``.
    # dataset:      dataset config used as ``data_train``.
    # save_step:    ``train.save_step`` (checkpoint interval).
    # lr:           learning rate of the ``denoiser`` parameter group (weight_decay 1e-2).
    # clip_skip:    forwarded to ``TEHookCFG(clip_skip=...)``.
    # dtype:        ``mixed_precision`` setting.
    # low_vram:     use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.AdamW``.
    # warmup_steps: warmup steps of the ``ConstantLR`` scheduler.
    # name:         ``model.name``.
    # Returns the config dict.
    if low_vram:
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True)
    else:
        optimizer = torch.optim.AdamW(_partial_=True)

    from cfgs.train.py import train_base, tuning_base

    return dict(
        _base_=[train_base, tuning_base],
        mixed_precision=dtype,

        model_part=CfgWDModelParser([
            dict(
                lr=lr,
                layers=['denoiser'], # train UNet
            )
        ], weight_decay=1e-2),

        # save the trainable layers of ``denoiser`` as safetensors
        ckpt_saver=dict(
            SD15=ckpt_saver(
                ckpt_type='safetensors',
                target_module='denoiser',
                layers=LAYERS_TRAINABLE,
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            ## Easy config
            wrapper=SD15Wrapper.from_pretrained(
                _partial_=True,
                models=SD15_auto_loader(ckpt_path=base_model, _partial_=True),
                TE_hook_cfg=TEHookCFG(clip_skip=clip_skip),
            ),
        ),

        data_train=dataset,
    )
|
68
|
+
|
69
|
+
@neko_cfg
def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
                    clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
                    name: str = 'SD15'):
    # Build a LoRA training config for the SD1.5 denoiser on top of the ``SD_FT`` base config.
    #
    # A ``LoraLayer`` plugin named ``lora1`` is attached to the denoiser layers matched by
    # ``lora_layers``, and only that plugin is saved (``plugin_saver``, safetensors).
    # NOTE(review): ``@neko_cfg`` appears to rewrite calls into lazy config nodes; the
    # ``with disable_neko_cfg`` section presumably runs as plain Python — confirm before restructuring.
    #
    # base_model:   checkpoint path given to ``SD15_auto_loader(ckpt_path=...)``.
    # train_steps:  ``train.train_steps``.
    # dataset:      dataset config used as ``data_train``.
    # save_step:    ``train.save_step``.
    # lr:           learning rate of the LoRA plugin (weight_decay 0.1).
    # rank:         LoRA rank.
    # alpha:        LoRA alpha; defaults to ``rank`` when None.
    # clip_skip:    forwarded to ``TEHookCFG(clip_skip=...)``.
    # with_conv:    also wrap ``resnets``/``proj_in``/``proj_out``/``conv`` modules, not only attention and FF.
    # dtype:        ``mixed_precision`` setting.
    # low_vram:     use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.AdamW`` (both with betas=(0.9, 0.99)).
    # warmup_steps: warmup steps of the ``ConstantLR`` scheduler.
    # name:         ``model.name``.
    # Returns the config dict.
    with disable_neko_cfg:
        if alpha is None:
            alpha = rank

        # ``re:``-prefixed patterns are regexes over module names under ``denoiser``
        if with_conv:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
                r're:denoiser.*\.resnets$',
                r're:denoiser.*\.proj_in$',
                r're:denoiser.*\.proj_out$',
                r're:denoiser.*\.conv$',
            ]
        else:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
            ]

    if low_vram:
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True, betas=(0.9, 0.99))
    else:
        optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))

    from cfgs.train.py.examples import SD_FT

    return dict(
        _base_=[SD_FT],
        mixed_precision=dtype,

        # NOTE(review): presumably clears the model_part inherited from SD_FT so that only the
        # LoRA plugin below is optimized — confirm.
        model_part=None,
        model_plugin=CfgWDPluginParser(cfg_plugin=dict(
            lora1=LoraLayer.wrap_model(
                _partial_=True,
                lr=lr,
                rank=rank,
                alpha=alpha,
                layers=lora_layers
            )
        ), weight_decay=0.1),

        ckpt_saver=dict(
            # presumably replaces (rather than merges with) the inherited ckpt_saver entries — confirm
            _replace_ = True,
            lora_unet=plugin_saver(
                ckpt_type='safetensors',
                target_plugin='lora1',
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            wrapper=SD15Wrapper.from_pretrained(
                _partial_=True,
                models=SD15_auto_loader(ckpt_path=base_model, _partial_=True),
                TE_hook_cfg=TEHookCFG(clip_skip=clip_skip),
            ),
        ),

        data_train=dataset,
    )
|
147
|
+
|
148
|
+
@neko_cfg
def cfg_data_SD_ARB(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', resolution: int = 512*512, num_bucket=4, word_names=None,
                    prompt_dropout: float = 0, prompt_template: Path_Like = 'prompt_template/caption.txt', loss_weight=1.0):
    # Build a text-image pair dataset config using aspect-ratio bucketing (``RatioBucket``).
    #
    # img_root:        image folder; also used as the caption ``label_file`` via ``${.img_root}``.
    # batch_size:      dataset batch size, also passed to ``VaeCache(bs=...)``.
    # trigger_word:    used as ``word_names = {'pt1': trigger_word}`` when ``word_names`` is None
    #                  (presumably fills a ``pt1`` placeholder in the prompt template — confirm).
    # resolution:      ``target_area`` of the ratio buckets (pixel count, default 512*512).
    # num_bucket:      number of ratio buckets.
    # word_names:      explicit word mapping for ``StableDiffusionHandler``; overrides ``trigger_word``.
    # prompt_dropout:  passed to the handler as ``erase``.
    # prompt_template: prompt template file for ``Text2ImageSource``.
    # loss_weight:     dataset ``loss_weight``.
    # Returns a partial ``TextImagePairDataset`` config.
    if word_names is None:
        word_names = dict(pt1=trigger_word)
    else:
        # no-op: keeps the caller-supplied mapping
        word_names = word_names

    return TextImagePairDataset(
        _partial_=True, batch_size=batch_size, loss_weight=loss_weight,
        source=dict(
            data_source1=Text2ImageSource(
                img_root=img_root,
                label_file='${.img_root}', # path to image captions (file_words)
                prompt_template=prompt_template,
            ),
        ),
        # the handler receives the bucket class itself, the dataset receives a bucket built from the files
        handler=StableDiffusionHandler(
            bucket=RatioBucket,
            word_names=word_names,
            erase=prompt_dropout,
        ),
        bucket=RatioBucket.from_files(
            target_area=resolution,
            num_bucket=num_bucket,
        ),
        # presumably caches VAE latents in batches of ``bs`` — confirm
        cache=VaeCache(bs=batch_size)
    )
|
176
|
+
|
177
|
+
@neko_cfg
def cfg_data_SD_resize_crop(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', target_size = (512, 512), word_names=None,
                    prompt_dropout: float = 0, prompt_template: Path_Like = 'prompt_template/caption.txt', loss_weight=1.0):
    # Build a text-image pair dataset config with a single fixed output size (``FixedBucket``).
    #
    # img_root:        image folder; also used as the caption ``label_file`` via ``${.img_root}``.
    # batch_size:      dataset batch size, also passed to ``VaeCache(bs=...)``.
    # trigger_word:    used as ``word_names = {'pt1': trigger_word}`` when ``word_names`` is None.
    # target_size:     ``FixedBucket(target_size=...)``, default (512, 512).
    # word_names:      explicit word mapping for ``StableDiffusionHandler``; overrides ``trigger_word``.
    # prompt_dropout:  passed to the handler as ``erase``.
    # prompt_template: prompt template file for ``Text2ImageSource``.
    # loss_weight:     dataset ``loss_weight``.
    # Returns a partial ``TextImagePairDataset`` config.
    if word_names is None:
        word_names = dict(pt1=trigger_word)
    else:
        # no-op: keeps the caller-supplied mapping
        word_names = word_names

    return TextImagePairDataset(
        _partial_=True, batch_size=batch_size, loss_weight=loss_weight,
        source=dict(
            data_source1=Text2ImageSource(
                img_root=img_root,
                label_file='${.img_root}', # path to image captions (file_words)
                prompt_template=prompt_template,
            ),
        ),
        handler=StableDiffusionHandler(
            bucket=FixedBucket,
            word_names=word_names,
            erase=prompt_dropout,
        ),
        bucket=FixedBucket(target_size=target_size),
        cache=VaeCache(bs=batch_size)
    )
|
@@ -0,0 +1,140 @@
|
|
1
|
+
import torch
|
2
|
+
from rainbowneko.ckpt_manager import ckpt_saver, plugin_saver, LAYERS_TRAINABLE
|
3
|
+
from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
|
4
|
+
from rainbowneko.utils import ConstantLR
|
5
|
+
|
6
|
+
from hcpdiff.easy import SDXL_auto_loader
|
7
|
+
from hcpdiff.models import SDXLWrapper
|
8
|
+
from hcpdiff.models.lora_layers_patch import LoraLayer
|
9
|
+
|
10
|
+
@neko_cfg
def SDXL_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5,
                    dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SDXL'):
    # Build a config that fully fine-tunes the SDXL denoiser (UNet).
    #
    # Same layout as the SD1.5 fine-tuning config, but with ``SDXLWrapper``/``SDXL_auto_loader``
    # and no ``clip_skip``/text-encoder hook option.
    # NOTE(review): ``@neko_cfg`` appears to rewrite calls into lazy config nodes — confirm before restructuring.
    #
    # base_model:   checkpoint path given to ``SDXL_auto_loader(ckpt_path=...)``.
    # train_steps:  ``train.train_steps``.
    # dataset:      dataset config used as ``data_train``.
    # save_step:    ``train.save_step``.
    # lr:           learning rate of the ``denoiser`` parameter group (weight_decay 1e-2).
    # dtype:        ``mixed_precision`` setting.
    # low_vram:     use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.AdamW``.
    # warmup_steps: warmup steps of the ``ConstantLR`` scheduler.
    # name:         ``model.name``.
    # Returns the config dict.
    if low_vram:
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True)
    else:
        optimizer = torch.optim.AdamW(_partial_=True)

    from cfgs.train.py import train_base, tuning_base

    return dict(
        _base_=[train_base, tuning_base],
        mixed_precision=dtype,

        model_part=CfgWDModelParser([
            dict(
                lr=lr,
                layers=['denoiser'], # train UNet
            )
        ], weight_decay=1e-2),

        # save the trainable layers of ``denoiser`` as safetensors
        ckpt_saver=dict(
            SDXL=ckpt_saver(
                ckpt_type='safetensors',
                target_module='denoiser',
                layers=LAYERS_TRAINABLE,
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            ## Easy config
            wrapper=SDXLWrapper.from_pretrained(
                _partial_=True,
                models=SDXL_auto_loader(ckpt_path=base_model, _partial_=True),
            ),
        ),

        data_train=dataset,
    )
|
64
|
+
|
65
|
+
@neko_cfg
def SDXL_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
                    with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SD15'):
    # Build a LoRA training config for the SDXL denoiser on top of the ``SD_FT`` base config.
    #
    # A ``LoraLayer`` plugin named ``lora1`` is attached to the denoiser layers matched by
    # ``lora_layers``, and only that plugin is saved (``plugin_saver``, safetensors).
    # NOTE(review): the default ``name`` is 'SD15', while ``SDXL_finetuning`` defaults to 'SDXL'.
    # This looks like a copy-paste slip; confirm before changing it, since it alters run naming.
    #
    # base_model:   checkpoint path given to ``SDXL_auto_loader(ckpt_path=...)``.
    # train_steps:  ``train.train_steps``.
    # dataset:      dataset config used as ``data_train``.
    # save_step:    ``train.save_step``.
    # lr:           learning rate of the LoRA plugin (weight_decay 0.1).
    # rank:         LoRA rank.
    # alpha:        LoRA alpha; defaults to ``rank`` when None.
    # with_conv:    also wrap ``resnets``/``proj_in``/``proj_out``/``conv`` modules, not only attention and FF.
    # dtype:        ``mixed_precision`` setting.
    # low_vram:     use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.AdamW`` (both with betas=(0.9, 0.99)).
    # warmup_steps: warmup steps of the ``ConstantLR`` scheduler.
    # name:         ``model.name``.
    # Returns the config dict.
    with disable_neko_cfg:
        if alpha is None:
            alpha = rank

        # ``re:``-prefixed patterns are regexes over module names under ``denoiser``
        if with_conv:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
                r're:denoiser.*\.resnets$',
                r're:denoiser.*\.proj_in$',
                r're:denoiser.*\.proj_out$',
                r're:denoiser.*\.conv$',
            ]
        else:
            lora_layers = [
                r're:denoiser.*\.attn.?$',
                r're:denoiser.*\.ff$',
            ]

    if low_vram:
        from bitsandbytes.optim import AdamW8bit
        optimizer = AdamW8bit(_partial_=True, betas=(0.9, 0.99))
    else:
        optimizer = torch.optim.AdamW(_partial_=True, betas=(0.9, 0.99))

    from cfgs.train.py.examples import SD_FT

    return dict(
        _base_=[SD_FT],
        mixed_precision=dtype,

        # NOTE(review): presumably clears the model_part inherited from SD_FT so that only the
        # LoRA plugin below is optimized — confirm.
        model_part=None,
        model_plugin=CfgWDPluginParser(cfg_plugin=dict(
            lora1=LoraLayer.wrap_model(
                _partial_=True,
                lr=lr,
                rank=rank,
                alpha=alpha,
                layers=lora_layers
            )
        ), weight_decay=0.1),

        ckpt_saver=dict(
            # presumably replaces (rather than merges with) the inherited ckpt_saver entries — confirm
            _replace_ = True,
            lora_unet=plugin_saver(
                ckpt_type='safetensors',
                target_plugin='lora1',
            )
        ),

        train=dict(
            train_steps=train_steps,
            save_step=save_step,

            optimizer=optimizer,

            scheduler=ConstantLR(
                _partial_=True,
                warmup_steps=warmup_steps,
            ),
        ),

        model=dict(
            name=name,

            wrapper=SDXLWrapper.from_pretrained(
                models=SDXL_auto_loader(ckpt_path=base_model, _partial_=True),
                _partial_=True,
            ),
        ),

        data_train=dataset,
    )
|