flaxdiff 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/__init__.py +0 -0
- flaxdiff/models/__init__.py +1 -0
- flaxdiff/models/attention.py +489 -0
- flaxdiff/models/common.py +7 -0
- flaxdiff/models/favor_fastattn.py +723 -0
- flaxdiff/models/simple_unet.py +519 -0
- flaxdiff/predictors/__init__.py +96 -0
- flaxdiff/samplers/__init__.py +7 -0
- flaxdiff/samplers/common.py +113 -0
- flaxdiff/samplers/ddim.py +10 -0
- flaxdiff/samplers/ddpm.py +43 -0
- flaxdiff/samplers/euler.py +59 -0
- flaxdiff/samplers/heun_sampler.py +28 -0
- flaxdiff/samplers/multistep_dpm.py +60 -0
- flaxdiff/samplers/rk4_sampler.py +34 -0
- flaxdiff/schedulers/__init__.py +6 -0
- flaxdiff/schedulers/common.py +98 -0
- flaxdiff/schedulers/continuous.py +12 -0
- flaxdiff/schedulers/cosine.py +40 -0
- flaxdiff/schedulers/discrete.py +74 -0
- flaxdiff/schedulers/exp.py +13 -0
- flaxdiff/schedulers/karras.py +69 -0
- flaxdiff/schedulers/linear.py +14 -0
- flaxdiff/schedulers/sqrt.py +10 -0
- flaxdiff/trainer/__init__.py +216 -0
- flaxdiff/utils.py +89 -0
- flaxdiff-0.1.1.dist-info/METADATA +326 -0
- flaxdiff-0.1.1.dist-info/RECORD +30 -0
- flaxdiff-0.1.1.dist-info/WHEEL +5 -0
- flaxdiff-0.1.1.dist-info/top_level.txt +1 -0
flaxdiff/samplers/common.py
@@ -0,0 +1,113 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
import tqdm
from typing import Union
from ..schedulers import NoiseScheduler
from ..utils import RandomMarkovState, MarkovState, clip_images
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform

class DiffusionSampler():
    model: nn.Module
    noise_schedule: NoiseScheduler
    params: dict
    model_output_transform: DiffusionPredictionTransform

    def __init__(self, model: nn.Module, params: dict,
                 noise_schedule: NoiseScheduler,
                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform()):
        self.model = model
        self.noise_schedule = noise_schedule
        self.params = params
        self.model_output_transform = model_output_transform

        @jax.jit
        def sample_model(x_t, t):
            rates = self.noise_schedule.get_rates(t)
            c_in = self.model_output_transform.get_input_scale(rates)
            model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t))
            x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
            return x_0, eps, model_output

        self.sample_model = sample_model

    # Used to sample from the diffusion model
    def sample_step(self, current_samples: jnp.ndarray, current_step, next_step=None, state: MarkovState = None) -> tuple[jnp.ndarray, MarkovState]:
        step_ones = jnp.ones((current_samples.shape[0],), dtype=jnp.int32)
        current_step = step_ones * current_step
        next_step = step_ones * next_step
        pred_images, pred_noise, _ = self.sample_model(current_samples, current_step)
        new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state)
        return new_samples, state

    def take_next_step(self, current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        # Estimate q(x_{t-1} | x_t, x_0).
        # reconstructed_samples is x_0, current_samples is x_t, steps is t.
        raise NotImplementedError  # was `return NotImplementedError`, which silently returned the exception class

    def scale_steps(self, steps):
        scale_factor = self.noise_schedule.max_timesteps / 1000
        return steps * scale_factor

    def get_steps(self, start_step, end_step, diffusion_steps):
        step_range = start_step - end_step
        if diffusion_steps is None or diffusion_steps == 0:
            diffusion_steps = start_step - end_step
        diffusion_steps = min(diffusion_steps, step_range)
        steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
        return steps

    def get_initial_samples(self, num_images, rngs: jax.random.PRNGKey, start_step, image_size=64):
        start_step = self.scale_steps(start_step)
        alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
        return jax.random.normal(rngs, (num_images, image_size, image_size, 3)) * variance

    def generate_images(self,
                        num_images=16,
                        diffusion_steps=1000,
                        start_step: int = None,
                        end_step: int = 0,
                        steps_override=None,
                        priors=None,
                        rngstate: RandomMarkovState = RandomMarkovState(jax.random.PRNGKey(42))) -> jnp.ndarray:
        # Resolve the default start step before it is used; the original code
        # did this after get_initial_samples, which crashed when both priors
        # and start_step were left as None.
        if start_step is None:
            start_step = self.noise_schedule.max_timesteps

        if priors is None:
            rngstate, newrngs = rngstate.get_random_key()
            samples = self.get_initial_samples(num_images, newrngs, start_step)
        else:
            print("Using priors")
            samples = priors

        @jax.jit
        def sample_step(state: RandomMarkovState, samples, current_step, next_step):
            samples, state = self.sample_step(current_samples=samples,
                                              current_step=current_step,
                                              state=state, next_step=next_step)
            return samples, state

        if steps_override is not None:
            steps = steps_override
        else:
            steps = self.get_steps(start_step, end_step, diffusion_steps)

        for i in tqdm.tqdm(range(0, len(steps))):
            current_step = self.scale_steps(steps[i])
            next_step = self.scale_steps(steps[i + 1] if i + 1 < len(steps) else 0)
            if i != len(steps) - 1:
                samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
            else:
                # Last step: return the model's denoised estimate directly.
                step_ones = jnp.ones((num_images,), dtype=jnp.int32)
                samples, _, _ = self.sample_model(samples, current_step * step_ones)
        samples = clip_images(samples)
        return samples
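How the pieces fit together, as a sketch: a trained Flax module and its parameters are wrapped by a concrete sampler subclass together with a noise schedule and a prediction transform. The `unet` and `params` names below are placeholders for a trained model and its variables, not part of this package.

import jax
from flaxdiff.samplers.ddim import DDIMSampler
from flaxdiff.schedulers import CosineNoiseSchedule
from flaxdiff.predictors import EpsilonPredictionTransform
from flaxdiff.utils import RandomMarkovState

# `unet` / `params` are placeholders for a trained Flax module and its variables.
schedule = CosineNoiseSchedule(timesteps=1000)
sampler = DDIMSampler(model=unet, params=params,
                      noise_schedule=schedule,
                      model_output_transform=EpsilonPredictionTransform())
images = sampler.generate_images(num_images=4, diffusion_steps=50,
                                 rngstate=RandomMarkovState(jax.random.PRNGKey(0)))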
flaxdiff/samplers/ddim.py
@@ -0,0 +1,10 @@
import jax.numpy as jnp
from .common import DiffusionSampler
from ..utils import MarkovState

class DDIMSampler(DiffusionSampler):
    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: MarkovState, next_step=None) -> tuple[jnp.ndarray, MarkovState]:
        # Deterministic DDIM update: re-noise the predicted x_0 to the next noise level.
        next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
        return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
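The update above is the deterministic DDIM rule: given x_t = alpha_t * x_0 + sigma_t * eps, the predicted x_0 and eps are recombined at the next noise level. A small self-contained check (plain jnp, independent of the package) that a perfect prediction lands exactly on the true sample at the next step:

import jax
import jax.numpy as jnp

# With exact x_0 / eps predictions, one DDIM step reproduces the true x
# at the next (alpha, sigma) level.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x_0 = jax.random.normal(k1, (8,))
eps = jax.random.normal(k2, (8,))

alpha_t, sigma_t = 0.6, 0.8                    # current signal / noise rates
alpha_s, sigma_s = 0.9, jnp.sqrt(1 - 0.9**2)   # next (less noisy) rates

x_t = alpha_t * x_0 + sigma_t * eps        # forward diffusion to step t
x_0_hat = (x_t - sigma_t * eps) / alpha_t  # what a perfect model recovers
x_s = x_0_hat * alpha_s + eps * sigma_s    # the take_next_step formula
assert jnp.allclose(x_s, alpha_s * x_0 + sigma_s * eps, atol=1e-5)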
flaxdiff/samplers/ddpm.py
@@ -0,0 +1,43 @@
import jax
import jax.numpy as jnp
from .common import DiffusionSampler
from ..utils import MarkovState, RandomMarkovState

class DDPMSampler(DiffusionSampler):
    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
        # Note: get_posterior_variance returns exp(0.5 * log var), i.e. the
        # posterior standard deviation, so scaling the noise by it is correct.
        variance = self.noise_schedule.get_posterior_variance(steps=current_step)

        state, rng = state.get_random_key()
        # Now sample from the posterior
        noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)

        return mean + noise * variance, state

    def generate_images(self, num_images=16, diffusion_steps=1000, start_step: int = None, *args, **kwargs):
        return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs)

class SimpleDDPMSampler(DiffusionSampler):
    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        state, rng = state.get_random_key()
        noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)

        # Compute noise rates and signal rates only once
        current_signal_rate, current_noise_rate = self.noise_schedule.get_rates(current_step)
        next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)

        pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate)

        noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2)
        signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2)
        betas = 1 - signal_ratio_squared
        gamma = jnp.sqrt(noise_ratio_squared * betas)

        next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma
        return next_samples, state
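The coefficients in SimpleDDPMSampler are the classic DDPM posterior mean rewritten in terms of signal/noise rates. A self-contained check for the variance-preserving case (where alpha_bar_t = signal_rate**2 and 1 - alpha_bar_t = noise_rate**2) that the two parameterizations agree:

import jax
import jax.numpy as jnp

# The SimpleDDPMSampler mean equals the textbook DDPM posterior mean
# coef1 * x_0 + coef2 * x_t (noise term omitted from both sides).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x_0 = jax.random.normal(k1, (8,))
eps = jax.random.normal(k2, (8,))

a_t, a_s = 0.6, 0.9                       # current / next signal rates
s_t = jnp.sqrt(1 - a_t**2)                # current noise rate
s_s = jnp.sqrt(1 - a_s**2)                # next noise rate
x_t = a_t * x_0 + s_t * eps

# SimpleDDPMSampler parameterization
pred_noise_coeff = (s_s**2 * a_t) / (s_t * a_s)
mean_simple = a_s * x_0 + pred_noise_coeff * eps

# Classic DDPM posterior mean, with alpha_bar = a**2 and beta_t = 1 - a_t^2 / a_s^2
beta_t = 1 - a_t**2 / a_s**2
coef1 = beta_t * a_s / s_t**2
coef2 = (1 - a_s**2) * jnp.sqrt(a_t**2 / a_s**2) / s_t**2
mean_classic = coef1 * x_0 + coef2 * x_t

assert jnp.allclose(mean_simple, mean_classic, atol=1e-5)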
flaxdiff/samplers/euler.py
@@ -0,0 +1,59 @@
import jax
import jax.numpy as jnp
from .common import DiffusionSampler
from ..utils import RandomMarkovState

class EulerSampler(DiffusionSampler):
    # Essentially a DDIM sampler, but parameterized as an ODE and integrated with Euler steps.
    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

        dt = next_sigma - current_sigma

        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / dt
        dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
        next_samples = current_samples + dx * dt
        return next_samples, state

class SimplifiedEulerSampler(DiffusionSampler):
    """
    For networks with a forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t.
    """
    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        _, current_sigma = self.noise_schedule.get_rates(current_step)
        _, next_sigma = self.noise_schedule.get_rates(next_step)

        dt = next_sigma - current_sigma

        dx = (current_samples - reconstructed_samples) / current_sigma
        next_samples = current_samples + dx * dt
        return next_samples, state

class EulerAncestralSampler(DiffusionSampler):
    """
    Like EulerSampler, but with ancestral sampling: the step is split into a
    deterministic ODE move down to sigma_down plus fresh noise of scale sigma_up.
    """
    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

        sigma_up = (next_sigma**2 * (current_sigma**2 - next_sigma**2) / current_sigma**2) ** 0.5
        sigma_down = (next_sigma**2 - sigma_up**2) ** 0.5

        dt = sigma_down - current_sigma

        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / (next_sigma - current_sigma)
        dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma

        state, subkey = state.get_random_key()
        dW = jax.random.normal(subkey, current_samples.shape) * sigma_up

        next_samples = current_samples + dx * dt + dW
        return next_samples, state
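For the simplified variance-exploding form x_sigma = x_0 + sigma * eps, the probability-flow derivative dx/dsigma = (x - x_0) / sigma equals eps exactly, a constant, so a single Euler step of SimplifiedEulerSampler is exact when the reconstruction is perfect. A toy check:

import jax
import jax.numpy as jnp

# One Euler step from sigma = 10 to sigma = 5 lands exactly on the true
# x at the lower noise level, because dx/dsigma is constant here.
x_0 = jax.random.normal(jax.random.PRNGKey(0), (8,))
eps = jax.random.normal(jax.random.PRNGKey(1), (8,))

sigma_t, sigma_s = 10.0, 5.0
x_t = x_0 + sigma_t * eps

dx = (x_t - x_0) / sigma_t          # reconstructed_samples == x_0 here
x_s = x_t + dx * (sigma_s - sigma_t)
assert jnp.allclose(x_s, x_0 + sigma_s * eps, atol=1e-5)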
flaxdiff/samplers/heun_sampler.py
@@ -0,0 +1,28 @@
import jax
import jax.numpy as jnp
from .common import DiffusionSampler
from ..utils import RandomMarkovState

class HeunSampler(DiffusionSampler):
    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        # Get the noise and signal rates for the current and next steps
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

        dt = next_sigma - current_sigma
        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / dt

        # Predictor: plain Euler step
        dx_0 = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
        next_samples_0 = current_samples + dx_0 * dt

        # Recompute x_0 at the first estimate to refine the derivative
        estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step)

        # Corrector: derivative at the predicted endpoint (Heun's method),
        # then average the initial and refined derivatives
        dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
        final_next_samples = current_samples + 0.5 * (dx_0 + dx_1) * dt

        return final_next_samples, state
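Heun's predictor-corrector averages the derivative at the start and at the Euler-predicted endpoint, which buys second-order accuracy. A toy comparison on dx/dsigma = sigma (solution x = sigma**2 / 2), where plain Euler overshoots and Heun is exact:

import jax.numpy as jnp

# Integrate dx/dsigma = sigma from sigma = 1.0 down to 0.0 in one step.
sigma_t, sigma_s = 1.0, 0.0
x_t = 0.5 * sigma_t**2          # true solution x(sigma) = sigma^2 / 2
dt = sigma_s - sigma_t

d_start = sigma_t                       # derivative at the start only
x_euler = x_t + d_start * dt            # -> -0.5, true answer is 0.0

d_end = sigma_s                         # derivative re-evaluated at the end
x_heun = x_t + 0.5 * (d_start + d_end) * dt   # -> 0.0, exact for quadratics
assert jnp.isclose(x_heun, 0.5 * sigma_s**2)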
flaxdiff/samplers/multistep_dpm.py
@@ -0,0 +1,60 @@
import jax
import jax.numpy as jnp
from .common import DiffusionSampler
from ..utils import RandomMarkovState

class MultiStepDPM(DiffusionSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.history = []

    # Renamed from `_renoise`: the base class dispatches through
    # `take_next_step`, so the override must use that name to be reached.
    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        # Get the noise and signal rates for the current and next steps
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

        dt = next_sigma - current_sigma

        def first_order(current_noise, current_sigma):
            dx = current_noise
            return dx

        def second_order(current_noise, current_sigma, last_noise, last_sigma):
            dx_2 = (current_noise - last_noise) / (current_sigma - last_sigma)
            return dx_2

        def third_order(current_noise, current_sigma, last_noise, last_sigma, second_last_noise, second_last_sigma):
            dx_2 = second_order(current_noise, current_sigma, last_noise, last_sigma)
            dx_2_last = second_order(last_noise, last_sigma, second_last_noise, second_last_sigma)
            dx_3 = (dx_2 - dx_2_last) / (0.5 * ((current_sigma + last_sigma) - (last_sigma + second_last_sigma)))
            return dx_3

        if len(self.history) == 0:
            # First order only
            dx_1 = first_order(pred_noise, current_sigma)
            next_samples = current_samples + dx_1 * dt
        elif len(self.history) == 1:
            # First + second order
            dx_1 = first_order(pred_noise, current_sigma)
            last_step = self.history[-1]
            dx_2 = second_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'])
            next_samples = current_samples + dx_1 * dt + 0.5 * dx_2 * dt**2
        else:
            # First + second + third order
            last_step = self.history[-1]
            second_last_step = self.history[-2]

            dx_1 = first_order(pred_noise, current_sigma)
            dx_2 = second_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'])
            dx_3 = third_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'], second_last_step['eps'], second_last_step['sigma'])
            next_samples = current_samples + (dx_1 * dt) + (0.5 * dx_2 * dt**2) + ((1/6) * dx_3 * dt**3)

        self.history.append({
            "eps": pred_noise,
            "sigma": current_sigma,
        })
        return next_samples, state
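A usage caveat, sketched with placeholder `unet`/`params`/`schedule` names: the sampler accumulates (eps, sigma) pairs in a plain Python list, so the history should be cleared between independent runs (and, since the base class jits its sampling loop, the stored values may be traced rather than concrete arrays).

# `unet`, `params` and `schedule` are placeholders for a trained model
# and a configured noise schedule.
sampler = MultiStepDPM(model=unet, params=params, noise_schedule=schedule)
images_a = sampler.generate_images(num_images=4, diffusion_steps=30)
sampler.history = []   # reset the derivative history between runs
images_b = sampler.generate_images(num_images=4, diffusion_steps=30)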
flaxdiff/samplers/rk4_sampler.py
@@ -0,0 +1,34 @@
import jax
import jax.numpy as jnp
from .common import DiffusionSampler
from ..utils import RandomMarkovState
from ..schedulers import GeneralizedNoiseScheduler

class RK4Sampler(DiffusionSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.noise_schedule, GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"

        @jax.jit
        def get_derivative(x_t, sigma, state: RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
            t = self.noise_schedule.get_timesteps(sigma)
            x_0, eps, _ = self.sample_model(x_t, t)
            return eps, state

        self.get_derivative = get_derivative

    def sample_step(self, current_samples: jnp.ndarray, current_step, next_step, state: RandomMarkovState = None) -> tuple[jnp.ndarray, RandomMarkovState]:
        step_ones = jnp.ones((current_samples.shape[0],), dtype=jnp.int32)
        current_step = step_ones * current_step
        next_step = step_ones * next_step
        _, current_sigma = self.noise_schedule.get_rates(current_step)
        _, next_sigma = self.noise_schedule.get_rates(next_step)

        dt = next_sigma - current_sigma

        # Classic fourth-order Runge-Kutta update
        k1, state = self.get_derivative(current_samples, current_sigma, state)
        k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state)
        k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state)
        k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state)

        next_samples = current_samples + ((k1 + 2 * k2 + 2 * k3 + k4) / 6) * dt
        return next_samples, state
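The k1..k4 combination above is the classic fourth-order Runge-Kutta tableau. A self-contained sanity check on dx/dt = x (exact solution e^t), where one RK4 step matches the Taylor series through fourth order:

import jax.numpy as jnp

# One RK4 step of size 0.1 on dx/dt = x agrees with exp(0.1) to ~1e-7.
def rk4_step(x, dt, f):
    k1 = f(x)
    k2 = f(x + 0.5 * k1 * dt)
    k3 = f(x + 0.5 * k2 * dt)
    k4 = f(x + k3 * dt)
    return x + ((k1 + 2 * k2 + 2 * k3 + k4) / 6) * dt

x = rk4_step(jnp.array(1.0), 0.1, lambda x: x)
assert jnp.allclose(x, jnp.exp(0.1), atol=1e-7)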
flaxdiff/schedulers/__init__.py
@@ -0,0 +1,6 @@
from .discrete import DiscreteNoiseScheduler
from .common import NoiseScheduler, GeneralizedNoiseScheduler
from .cosine import CosineNoiseSchedule, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
from .linear import LinearNoiseSchedule
from .sqrt import SqrtContinuousNoiseScheduler
from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
flaxdiff/schedulers/common.py
@@ -0,0 +1,98 @@
import jax
import jax.numpy as jnp
from typing import Union
from ..utils import RandomMarkovState

class NoiseScheduler():
    def __init__(self, timesteps,
                 dtype=jnp.float32,
                 clip_min=-1.0,
                 clip_max=1.0,
                 *args, **kwargs):
        self.max_timesteps = timesteps
        self.dtype = dtype
        self.clip_min = clip_min
        self.clip_max = clip_max
        if isinstance(timesteps, int) and timesteps > 1:
            # Discrete schedule: sample integer timesteps uniformly.
            timestep_generator = lambda rng, batch_size, max_timesteps=timesteps: jax.random.randint(rng, (batch_size,), 0, max_timesteps)
        else:
            # Continuous schedule: sample real-valued timesteps uniformly.
            timestep_generator = lambda rng, batch_size, max_timesteps=timesteps: jax.random.uniform(rng, (batch_size,), minval=0, maxval=max_timesteps)
        self.timestep_generator = timestep_generator

    def generate_timesteps(self, batch_size, state: RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
        state, rng = state.get_random_key()
        timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
        return timesteps, state

    def get_weights(self, steps):
        raise NotImplementedError

    def reshape_rates(self, rates: tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        signal_rates, noise_rates = rates
        signal_rates = jnp.reshape(signal_rates, shape)
        noise_rates = jnp.reshape(noise_rates, shape)
        return signal_rates, noise_rates

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        raise NotImplementedError

    def add_noise(self, images, noise, steps) -> jnp.ndarray:
        signal_rates, noise_rates = self.get_rates(steps)
        return signal_rates * images + noise_rates * noise

    def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
        signal_rates, noise_rates = self.get_rates(steps)
        x_0 = (noisy_images - noise * noise_rates) / signal_rates
        return x_0

    def transform_inputs(self, x, steps):
        return x, steps

    def get_posterior_mean(self, x_0, x_t, steps):
        raise NotImplementedError

    def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
        raise NotImplementedError

    def get_max_variance(self):
        alpha_n, sigma_n = self.get_rates(self.max_timesteps)
        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
        return variance

class GeneralizedNoiseScheduler(NoiseScheduler):
    """
    The generalization presented in "Elucidating the Design Space of
    Diffusion-Based Generative Models" by Tero Karras et al.: the signal
    rate is always 1, and the model input is instead scaled to match the
    noise rate.
    """
    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80.0, sigma_data=1, *args, **kwargs):
        super().__init__(timesteps, *args, **kwargs)
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
        sigma = self.get_sigmas(steps)
        return (1 + (1 / (1 + ((1 - sigma ** 2) / (sigma ** 2)))) / (self.sigma_max ** 2)).reshape(shape)

    def get_sigmas(self, steps) -> jnp.ndarray:
        raise NotImplementedError("This method should be implemented in the subclass")

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        sigmas = self.get_sigmas(steps)
        signal_rates = 1
        noise_rates = sigmas
        return self.reshape_rates((signal_rates, noise_rates), shape=shape)

    def transform_inputs(self, x, steps, num_discrete_chunks=1000):
        sigmas_discrete = (steps / self.max_timesteps) * num_discrete_chunks
        sigmas_discrete = sigmas_discrete.astype(jnp.int32)
        return x, sigmas_discrete

    def get_timesteps(self, sigmas):
        """
        Inverse of the get_sigmas method.
        """
        raise NotImplementedError("This method should be implemented in the subclass")
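What a concrete subclass has to supply, shown with a hypothetical linear sigma schedule (not part of the package): get_sigmas maps timesteps to noise levels, and get_timesteps must be its inverse, which RK4Sampler relies on.

import jax.numpy as jnp
from flaxdiff.schedulers import GeneralizedNoiseScheduler

# A minimal (hypothetical) subclass illustrating the get_sigmas /
# get_timesteps contract: the two methods are inverses of each other.
class LinearSigmaScheduler(GeneralizedNoiseScheduler):
    def get_sigmas(self, steps) -> jnp.ndarray:
        frac = steps / self.max_timesteps
        return self.sigma_min + frac * (self.sigma_max - self.sigma_min)

    def get_timesteps(self, sigmas):
        frac = (sigmas - self.sigma_min) / (self.sigma_max - self.sigma_min)
        return frac * self.max_timesteps

sched = LinearSigmaScheduler(timesteps=1000)
sigmas = sched.get_sigmas(jnp.array([0, 500, 1000]))
assert jnp.allclose(sched.get_timesteps(sigmas), jnp.array([0., 500., 1000.]))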
flaxdiff/schedulers/continuous.py
@@ -0,0 +1,12 @@
import jax
import jax.numpy as jnp
from typing import Union
from ..utils import RandomMarkovState
from .common import NoiseScheduler

class ContinuousNoiseScheduler(NoiseScheduler):
    """
    General continuous noise scheduler.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(timesteps=1, *args, **kwargs)
flaxdiff/schedulers/cosine.py
@@ -0,0 +1,40 @@
import math
import numpy as np
import jax.numpy as jnp
from .discrete import DiscreteNoiseScheduler
from .continuous import ContinuousNoiseScheduler
from .common import GeneralizedNoiseScheduler

def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
    ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
    alphas_bar = np.cos((ts + start_angle) / (1 + start_angle) * np.pi / 2) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
    return np.clip(betas, 0, end_angle)

class CosineNoiseSchedule(DiscreteNoiseScheduler):
    def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
        super().__init__(timesteps, beta_start, beta_end, schedule_fn=cosine_beta_schedule, *args, **kwargs)

class CosineGeneralNoiseScheduler(GeneralizedNoiseScheduler):
    def __init__(self, sigma_min=0.02, sigma_max=80.0, kappa=1.0, *args, **kwargs):
        super().__init__(timesteps=1, sigma_min=sigma_min, sigma_max=sigma_max, *args, **kwargs)
        self.kappa = kappa
        logsnr_max = 2 * (math.log(self.kappa) - math.log(self.sigma_max))
        self.theta_max = math.atan(math.exp(-0.5 * logsnr_max))
        logsnr_min = 2 * (math.log(self.kappa) - math.log(self.sigma_min))
        self.theta_min = math.atan(math.exp(-0.5 * logsnr_min))

    def get_sigmas(self, steps):
        return jnp.tan(self.theta_min + steps * (self.theta_max - self.theta_min)) / self.kappa

class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
        noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
        return self.reshape_rates((signal_rates, noise_rates), shape=shape)

    def get_weights(self, steps):
        alpha, sigma = self.get_rates(steps, shape=())
        return 1 / (1 + (alpha ** 2 / sigma ** 2))
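A quick numeric check of cosine_beta_schedule: the cumulative product of (1 - beta) reproduces the normalized cos^2 curve it was derived from (the final few betas are clipped at 0.999, where alpha_bar is already near zero):

import numpy as np
from flaxdiff.schedulers.cosine import cosine_beta_schedule

betas = cosine_beta_schedule(1000)
alpha_bar = np.cumprod(1 - betas)

ts = np.linspace(0, 1, 1001)
target = np.cos((ts + 0.008) / 1.008 * np.pi / 2) ** 2
target = target / target[0]

assert betas.min() >= 0 and betas.max() <= 0.999
assert np.allclose(alpha_bar, target[1:], atol=1e-6)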
flaxdiff/schedulers/discrete.py
@@ -0,0 +1,74 @@
import jax
import jax.numpy as jnp
from typing import Union
from ..utils import RandomMarkovState
from .common import NoiseScheduler

class DiscreteNoiseScheduler(NoiseScheduler):
    """
    Variance-preserving noise scheduler:
    signal_rate**2 + noise_rate**2 = 1
    """
    def __init__(self, timesteps,
                 beta_start=0.0001,
                 beta_end=0.02,
                 schedule_fn=None,
                 p2_loss_weight_k: float = 1,
                 p2_loss_weight_gamma: float = 1,
                 *args, **kwargs):
        super().__init__(timesteps, *args, **kwargs)
        betas = schedule_fn(timesteps, beta_start, beta_end)
        alphas = 1 - betas
        alpha_cumprod = jnp.cumprod(alphas, axis=0)
        alpha_cumprod_prev = jnp.append(1.0, alpha_cumprod[:-1])

        self.betas = jnp.array(betas, dtype=jnp.float32)
        self.alphas = alphas.astype(jnp.float32)
        self.alpha_cumprod = alpha_cumprod.astype(jnp.float32)
        self.alpha_cumprod_prev = alpha_cumprod_prev.astype(jnp.float32)

        self.sqrt_alpha_cumprod = jnp.sqrt(alpha_cumprod).astype(jnp.float32)
        self.sqrt_one_minus_alpha_cumprod = jnp.sqrt(1 - alpha_cumprod).astype(jnp.float32)

        posterior_variance = betas * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod)
        self.posterior_variance = posterior_variance.astype(jnp.float32)
        self.posterior_log_variance_clipped = jnp.log(jnp.maximum(posterior_variance, 1e-20)).astype(jnp.float32)

        self.posterior_mean_coef1 = (betas * jnp.sqrt(alpha_cumprod_prev) / (1 - alpha_cumprod)).astype(jnp.float32)
        self.posterior_mean_coef2 = ((1 - alpha_cumprod_prev) * jnp.sqrt(alphas) / (1 - alpha_cumprod)).astype(jnp.float32)

        self.p2_loss_weights = self.get_p2_weights(p2_loss_weight_k, p2_loss_weight_gamma)

    def generate_timesteps(self, batch_size, state: RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
        state, rng = state.get_random_key()
        timesteps = jax.random.randint(rng, (batch_size,), 0, self.max_timesteps)
        return timesteps, state

    def get_p2_weights(self, k, gamma):
        return (k + self.alpha_cumprod / (1 - self.alpha_cumprod)) ** -gamma

    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
        steps = jnp.int16(steps)
        return self.p2_loss_weights[steps].reshape(shape)

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        steps = jnp.int16(steps)
        signal_rate = self.sqrt_alpha_cumprod[steps]
        noise_rate = self.sqrt_one_minus_alpha_cumprod[steps]
        signal_rate = jnp.reshape(signal_rate, shape)
        noise_rate = jnp.reshape(noise_rate, shape)
        return signal_rate, noise_rate

    def get_posterior_mean(self, x_0, x_t, steps):
        steps = jnp.int16(steps)
        x_0_coeff = self.posterior_mean_coef1[steps]
        x_t_coeff = self.posterior_mean_coef2[steps]
        x_0_coeff, x_t_coeff = self.reshape_rates((x_0_coeff, x_t_coeff))
        mean = x_0_coeff * x_0 + x_t_coeff * x_t
        return mean

    def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
        # Cast with jnp so this also works on batched/traced steps (the
        # original `int(steps)` fails for arrays); note the return value is
        # exp(0.5 * log var), i.e. the posterior standard deviation.
        steps = jnp.int16(steps)
        return jnp.exp(0.5 * self.posterior_log_variance_clipped[steps]).reshape(shape)
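A self-contained check of DiscreteNoiseScheduler with a hypothetical linear beta function (the package's own LinearNoiseSchedule lives in linear.py, not shown in this section): the rates satisfy the variance-preserving identity, and remove_all_noise inverts add_noise.

import numpy as np
import jax.numpy as jnp
from flaxdiff.schedulers.discrete import DiscreteNoiseScheduler

# Hypothetical linear beta schedule, just for the check.
linear_betas = lambda T, b0, b1: np.linspace(b0, b1, T)
sched = DiscreteNoiseScheduler(1000, schedule_fn=linear_betas)

# Variance-preserving invariant: signal_rate^2 + noise_rate^2 = 1.
signal, noise = sched.get_rates(jnp.arange(1000))
assert jnp.allclose(signal**2 + noise**2, 1.0, atol=1e-5)

# add_noise / remove_all_noise round-trip.
x_0 = jnp.ones((2, 8, 8, 3))
eps = jnp.full_like(x_0, 0.5)
steps = jnp.array([10, 500])
x_t = sched.add_noise(x_0, eps, steps)
assert jnp.allclose(sched.remove_all_noise(x_t, eps, steps), x_0, atol=1e-4)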
flaxdiff/schedulers/exp.py
@@ -0,0 +1,13 @@
import numpy as np
from .discrete import DiscreteNoiseScheduler

def exp_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
    # start_angle is unused here; it is kept for signature compatibility
    # with the other beta-schedule functions.
    ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
    alphas_bar = np.exp(ts * -12.0)
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
    return np.clip(betas, 0, end_angle)

class ExpNoiseSchedule(DiscreteNoiseScheduler):
    def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
        super().__init__(timesteps, beta_start, beta_end, schedule_fn=exp_beta_schedule, *args, **kwargs)