diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from .base_prompter import BasePrompter, tokenize_long_prompt
|
|
2
|
+
from ..models.model_manager import ModelManager
|
|
3
|
+
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
|
4
|
+
from transformers import CLIPTokenizer
|
|
5
|
+
import torch, os
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SDXLPrompter(BasePrompter):
    """Prompt encoder that drives two CLIP tokenizers and two SDXL text encoders.

    The tokenizers are loaded eagerly in the constructor; the text encoders
    are attached afterwards through `fetch_models`.
    """

    def __init__(
        self,
        tokenizer_path=None,
        tokenizer_2_path=None
    ):
        """Load both tokenizers.

        Args:
            tokenizer_path: Directory of the first CLIP tokenizer. When None,
                the configuration bundled two directories above this file is used.
            tokenizer_2_path: Directory of the second CLIP tokenizer. When None,
                the bundled configuration is used.
        """
        # Bundled tokenizer configs live relative to the parent of this file's directory.
        package_root = os.path.dirname(os.path.dirname(__file__))
        if tokenizer_path is None:
            tokenizer_path = os.path.join(package_root, "tokenizer_configs/stable_diffusion/tokenizer")
        if tokenizer_2_path is None:
            tokenizer_2_path = os.path.join(package_root, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")

        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
        # The encoders stay unset until `fetch_models` is called.
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None


    def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_encoder_2: SDXLTextEncoder2 = None):
        """Attach the two text encoders used by `encode_prompt`."""
        self.text_encoder, self.text_encoder_2 = text_encoder, text_encoder_2


    def encode_prompt(
        self,
        prompt,
        clip_skip=1,
        clip_skip_2=2,
        positive=True,
        device="cuda"
    ):
        """Encode `prompt` with both text encoders and merge the results.

        Args:
            prompt: Prompt passed through `process_prompt` before tokenization.
            clip_skip: Forwarded to the first text encoder.
            clip_skip_2: Forwarded to the second text encoder.
            positive: Forwarded to `process_prompt`.
            device: Device the token ids are moved to before encoding.

        Returns:
            Tuple `(add_text_embeds, prompt_emb)`: the first row of the second
            encoder's first output, and the two encoders' embeddings
            concatenated on the last axis and flattened to a leading dim of 1.
        """
        prompt = self.process_prompt(prompt, positive=positive)

        # First encoder.
        ids_1 = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        emb_1 = self.text_encoder(ids_1, clip_skip=clip_skip)

        # Second encoder; it also yields the extra text embedding.
        ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
        add_text_embeds, emb_2 = self.text_encoder_2(ids_2, clip_skip=clip_skip_2)

        # The two tokenizers may produce a different number of rows;
        # keep only the rows both encoders produced before concatenating.
        rows_1, rows_2 = emb_1.shape[0], emb_2.shape[0]
        if rows_1 != rows_2:
            shared_rows = min(rows_1, rows_2)
            emb_1, emb_2 = emb_1[: shared_rows], emb_2[: shared_rows]
        merged = torch.concatenate([emb_1, emb_2], dim=-1)

        # For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
        add_text_embeds = add_text_embeds[0:1]

        # Fold the first two axes together so the result has a leading dim of 1.
        row_count, row_len = merged.shape[0], merged.shape[1]
        merged = merged.reshape((1, row_count*row_len, -1))
        return add_text_embeds, merged
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class ContinuousODEScheduler():
    """Scheduler working on a continuous sigma schedule.

    `self.sigmas` is built as
    `(sigma_max^(1/rho) + ramp * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho`
    and `self.timesteps` as `0.25 * log(sigma)`. Every method maps an incoming
    timestep to the closest entry of `self.timesteps`.
    """

    def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
        """Store the schedule parameters and build the default schedule."""
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = rho
        self.set_timesteps(num_inference_steps)


    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
        """Rebuild `self.sigmas` and `self.timesteps`.

        Args:
            num_inference_steps: Number of schedule entries.
            denoising_strength: The interpolation ramp starts at
                `1 - denoising_strength`, so 1.0 starts from `sigma_max`.
        """
        ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps)
        min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
        max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
        self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
        self.timesteps = torch.log(self.sigmas) * 0.25


    def _nearest_timestep_id(self, timestep):
        """Return the index of the schedule entry closest to `timestep`."""
        # `self.timesteps` is created on the CPU, so a tensor timestep coming
        # from another device is moved there before the comparison.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        return torch.argmin((self.timesteps - timestep).abs())


    def step(self, model_output, timestep, sample, to_final=False):
        """Advance `sample` by one step of the schedule.

        Args:
            model_output: Model prediction for `sample` at `timestep`.
            timestep: Current timestep (matched to the nearest schedule entry).
            sample: Current sample; it is NOT modified in place.
            to_final: When true, return the fully denoised estimate directly.

        Returns:
            The sample for the next schedule entry, or the denoised estimate
            when `to_final` is set or `timestep` is the last entry.
        """
        timestep_id = self._nearest_timestep_id(timestep)
        sigma = self.sigmas[timestep_id]
        # Out-of-place rescale: the caller's tensor must stay untouched.
        sample = sample * (sigma*sigma + 1).sqrt()
        estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample
        if to_final or timestep_id + 1 >= len(self.timesteps):
            return estimated_sample
        sigma_ = self.sigmas[timestep_id + 1]
        derivative = 1 / sigma * (sample - estimated_sample)
        prev_sample = sample + derivative * (sigma_ - sigma)
        # Undo the rescaling using the next sigma.
        return prev_sample / (sigma_*sigma_ + 1).sqrt()


    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Unsupported by this scheduler; always returns None."""
        # This scheduler doesn't support this function.
        pass


    def add_noise(self, original_samples, noise, timestep):
        """Return `(original_samples + noise * sigma) / sqrt(sigma^2 + 1)` for the nearest sigma."""
        sigma = self.sigmas[self._nearest_timestep_id(timestep)]
        sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
        return sample


    def training_target(self, sample, noise, timestep):
        """Return the training target built from `sample` and `noise` at the nearest sigma."""
        sigma = self.sigmas[self._nearest_timestep_id(timestep)]
        target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
        return target


    def training_weight(self, timestep):
        """Return the loss weight `sqrt(1 + sigma^2) / sigma` for the nearest sigma."""
        sigma = self.sigmas[self._nearest_timestep_id(timestep)]
        weight = (1 + sigma*sigma).sqrt() / sigma
        return weight
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import torch, math
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class EnhancedDDIMScheduler():
    """DDIM-style scheduler with `epsilon` and `v_prediction` update rules.

    `self.alphas_cumprod` is a plain Python list of cumulative products of
    `1 - beta`, indexed by integer training timestep. Timesteps may be passed
    as tensors (the first element is used) or as plain numbers.
    """

    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"):
        """Build the beta schedule and a default 10-step inference schedule.

        Raises:
            NotImplementedError: If `beta_schedule` is neither
                "scaled_linear" nor "linear".
        """
        self.num_train_timesteps = num_train_timesteps
        if beta_schedule == "scaled_linear":
            # Linear in sqrt(beta), then squared.
            betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
        elif beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
        self.set_timesteps(10)
        self.prediction_type = prediction_type


    def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
        """Rebuild `self.timesteps` as a descending float tensor ending at 0.

        Args:
            num_inference_steps: Requested number of steps; capped at the
                number of available training timesteps.
            denoising_strength: Fraction of the training range to start from.
        """
        # The timesteps are aligned to 999...0, which is different from other implementations,
        # but I think this implementation is more reasonable in theory.
        max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            self.timesteps = torch.Tensor([max_timestep])
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])


    def _alpha_prod(self, timestep):
        """Return the cumulative alpha product at `timestep`.

        A tensor timestep contributes its first element; a plain number is
        used directly. In both cases the value is truncated to an int index.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.flatten().tolist()[0]
        return self.alphas_cumprod[int(timestep)]


    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        """Combine `sample` and `model_output` into the previous-step sample.

        Both prediction types return `sample * weight_x + model_output * weight_e`;
        only the weights differ.

        Raises:
            NotImplementedError: If `self.prediction_type` is unsupported.
        """
        if self.prediction_type == "epsilon":
            weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
            weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
            prev_sample = sample * weight_x + model_output * weight_e
        elif self.prediction_type == "v_prediction":
            weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
            weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
            prev_sample = sample * weight_x + model_output * weight_e
        else:
            raise NotImplementedError(f"{self.prediction_type} is not implemented")
        return prev_sample


    def step(self, model_output, timestep, sample, to_final=False):
        """Advance `sample` from `timestep` to the next schedule entry.

        Args:
            model_output: Model prediction for `sample` at `timestep`.
            timestep: Current timestep, as a tensor or a plain number.
            sample: Current sample.
            to_final: When true, jump straight to the clean sample
                (previous alpha product fixed to 1.0).
        """
        alpha_prod_t = self._alpha_prod(timestep)
        # `self.timesteps` lives on the CPU; bring tensor timesteps there too.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        if to_final or timestep_id + 1 >= len(self.timesteps):
            alpha_prod_t_prev = 1.0
        else:
            timestep_prev = int(self.timesteps[timestep_id + 1])
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)


    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Return the noise implied by `sample` and `sample_stablized` at `timestep`."""
        alpha_prod_t = self._alpha_prod(timestep)
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
        return noise_pred


    def add_noise(self, original_samples, noise, timestep):
        """Return `sqrt(a) * original_samples + sqrt(1 - a) * noise` with `a` the alpha product at `timestep`."""
        alpha_prod_t = self._alpha_prod(timestep)
        sqrt_alpha_prod = math.sqrt(alpha_prod_t)
        sqrt_one_minus_alpha_prod = math.sqrt(1 - alpha_prod_t)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples


    def training_target(self, sample, noise, timestep):
        """Return the training target for the configured prediction type.

        For "epsilon" the target is `noise` itself; any other prediction type
        gets `sqrt(a) * noise - sqrt(1 - a) * sample`.
        """
        if self.prediction_type == "epsilon":
            return noise
        alpha_prod_t = self._alpha_prod(timestep)
        sqrt_alpha_prod = math.sqrt(alpha_prod_t)
        sqrt_one_minus_alpha_prod = math.sqrt(1 - alpha_prod_t)
        target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return target
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class FlowMatchScheduler():
    """Scheduler that interpolates linearly between sample and noise via sigma.

    `self.sigmas` is a linearly spaced ramp passed through the shift mapping
    `shift * s / (1 + (shift - 1) * s)`; `self.timesteps` is that ramp scaled
    by `num_train_timesteps`.
    """

    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002):
        """Store the schedule parameters and build the default schedule."""
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.shift = shift
        self.num_train_timesteps = num_train_timesteps
        self.set_timesteps(num_inference_steps)


    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
        """Rebuild `self.sigmas` and `self.timesteps`.

        Args:
            num_inference_steps: Number of schedule entries.
            denoising_strength: Fraction of the `[sigma_min, sigma_max]` range
                the schedule starts from.
        """
        first_sigma = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        linear_sigmas = torch.linspace(first_sigma, self.sigma_min, num_inference_steps)
        # Apply the shift mapping on top of the linear ramp.
        self.sigmas = self.shift * linear_sigmas / (1 + (self.shift - 1) * linear_sigmas)
        self.timesteps = self.sigmas * self.num_train_timesteps


    def _closest_index(self, timestep):
        """Return the index of the schedule entry nearest to `timestep`."""
        # Tensor timesteps are moved to the CPU before the comparison.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        distances = (self.timesteps - timestep).abs()
        return torch.argmin(distances)


    def step(self, model_output, timestep, sample, to_final=False):
        """Move `sample` along `model_output` by the sigma difference to the next entry.

        When `to_final` is set, or `timestep` is the last schedule entry, the
        target sigma is 0.
        """
        index = self._closest_index(timestep)
        current_sigma = self.sigmas[index]
        next_sigma = 0 if (to_final or index + 1 >= len(self.timesteps)) else self.sigmas[index + 1]
        return sample + model_output * (next_sigma - current_sigma)


    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Unsupported by this scheduler; always returns None."""
        # This scheduler doesn't support this function.
        return None


    def add_noise(self, original_samples, noise, timestep):
        """Return `(1 - sigma) * original_samples + sigma * noise` for the nearest sigma."""
        level = self.sigmas[self._closest_index(timestep)]
        return (1 - level) * original_samples + level * noise


    def training_target(self, sample, noise, timestep):
        """Return the training target `noise - sample`; `timestep` is unused."""
        return noise - sample
|
|
File without changes
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
{
|
|
2
|
+
"cls_token": "[CLS]",
|
|
3
|
+
"do_basic_tokenize": true,
|
|
4
|
+
"do_lower_case": true,
|
|
5
|
+
"mask_token": "[MASK]",
|
|
6
|
+
"name_or_path": "hfl/chinese-roberta-wwm-ext",
|
|
7
|
+
"never_split": null,
|
|
8
|
+
"pad_token": "[PAD]",
|
|
9
|
+
"sep_token": "[SEP]",
|
|
10
|
+
"special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
|
|
11
|
+
"strip_accents": null,
|
|
12
|
+
"tokenize_chinese_chars": true,
|
|
13
|
+
"tokenizer_class": "BertTokenizer",
|
|
14
|
+
"unk_token": "[UNK]",
|
|
15
|
+
"model_max_length": 77
|
|
16
|
+
}
|