flaxdiff 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,113 @@
1
+ from flax import linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import tqdm
5
+ from typing import Union
6
+ from ..schedulers import NoiseScheduler
7
+ from ..utils import RandomMarkovState, MarkovState, clip_images
8
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
9
+
10
class DiffusionSampler():
    """Base class for diffusion samplers.

    Wraps a Flax model and its parameters together with a noise schedule and a
    prediction transform. ``generate_images`` runs the reverse process starting
    from noise (or from ``priors``); subclasses define one reverse step in
    ``take_next_step``.

    Step values handed to ``generate_images`` are multiplied by
    ``noise_schedule.max_timesteps / 1000`` (see ``scale_steps``) before they
    reach the schedule.
    """
    model: nn.Module
    noise_schedule: NoiseScheduler
    params: dict
    model_output_transform: DiffusionPredictionTransform

    def __init__(self, model: nn.Module, params: dict,
                 noise_schedule: NoiseScheduler,
                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform()):
        self.model = model
        self.noise_schedule = noise_schedule
        self.params = params
        self.model_output_transform = model_output_transform

        @jax.jit
        def sample_model(x_t, t):
            """Run the model at (x_t, t); return (x_0 estimate, eps estimate, raw model output)."""
            rates = self.noise_schedule.get_rates(t)
            # The prediction transform decides how the model input is scaled.
            c_in = self.model_output_transform.get_input_scale(rates)
            model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t))
            x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
            return x_0, eps, model_output

        self.sample_model = sample_model

    def sample_step(self, current_samples: jnp.ndarray, current_step, next_step=None, state: MarkovState = None) -> tuple[jnp.ndarray, MarkovState]:
        """Advance ``current_samples`` from ``current_step`` to ``next_step``.

        Both steps are broadcast to one value per sample (length of axis 0), the
        model is queried once, and the update itself is delegated to
        ``take_next_step``. Returns ``(new_samples, state)``.
        """
        step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
        current_step = step_ones * current_step
        next_step = step_ones * next_step
        pred_images, pred_noise, _ = self.sample_model(current_samples, current_step)
        new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
                                                 pred_noise=pred_noise, current_step=current_step,
                                                 next_step=next_step, state=state)
        return new_samples, state

    def take_next_step(self, current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Compute the samples at ``next_step`` from x_t and the model's predictions.

        ``current_samples`` is x_t, ``reconstructed_samples`` the predicted x_0 and
        ``pred_noise`` the predicted noise. Subclasses must implement this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement take_next_step")

    def scale_steps(self, steps):
        """Rescale ``steps`` by ``max_timesteps / 1000`` (so 1000 maps to ``max_timesteps``)."""
        scale_factor = self.noise_schedule.max_timesteps / 1000
        return steps * scale_factor

    def get_steps(self, start_step, end_step, diffusion_steps):
        """Return the descending int16 step sequence from ``start_step`` to ``end_step``.

        ``diffusion_steps`` of None or 0 defaults to ``start_step - end_step``;
        the count is never larger than ``start_step - end_step``.
        """
        step_range = start_step - end_step
        if diffusion_steps is None or diffusion_steps == 0:
            diffusion_steps = start_step - end_step
        diffusion_steps = min(diffusion_steps, step_range)
        steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
        return steps

    def get_initial_samples(self, num_images, rngs: jax.random.PRNGKey, start_step, image_size=64):
        """Draw ``num_images`` Gaussian images scaled by sqrt(alpha^2 + sigma^2) at ``start_step``.

        ``start_step`` is in the same (unscaled) units as ``generate_images`` and
        defaults to ``noise_schedule.max_timesteps`` when None.
        """
        if start_step is None:
            start_step = self.noise_schedule.max_timesteps
        start_step = self.scale_steps(start_step)
        alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
        return jax.random.normal(rngs, (num_images, image_size, image_size, 3)) * variance

    def generate_images(self,
                        num_images=16,
                        diffusion_steps=1000,
                        start_step: int = None,
                        end_step: int = 0,
                        steps_override=None,
                        priors=None,
                        rngstate: RandomMarkovState = RandomMarkovState(jax.random.PRNGKey(42))) -> jnp.ndarray:
        """Run the reverse diffusion loop and return clipped images.

        Args:
            num_images: batch size when starting from noise.
            diffusion_steps: number of sampling steps (see ``get_steps``).
            start_step / end_step: step range before ``scale_steps``; ``start_step``
                defaults to ``noise_schedule.max_timesteps``.
            steps_override: explicit step sequence, used instead of ``get_steps``.
            priors: optional starting samples instead of fresh noise.
            rngstate: random state threaded through the sampling steps.
        """
        # Resolve the default first: it is needed both for the initial noise level
        # and for the step schedule below.
        # NOTE(review): steps are later multiplied by max_timesteps / 1000, so this
        # default covers the full schedule only when max_timesteps == 1000 — confirm.
        if start_step is None:
            start_step = self.noise_schedule.max_timesteps

        if priors is None:
            rngstate, newrngs = rngstate.get_random_key()
            samples = self.get_initial_samples(num_images, newrngs, start_step)
        else:
            print("Using priors")
            samples = priors

        @jax.jit
        def sample_step(state: RandomMarkovState, samples, current_step, next_step):
            samples, state = self.sample_step(current_samples=samples,
                                              current_step=current_step,
                                              state=state, next_step=next_step)
            return samples, state

        if steps_override is not None:
            steps = steps_override
        else:
            steps = self.get_steps(start_step, end_step, diffusion_steps)

        for i in tqdm.tqdm(range(0, len(steps))):
            current_step = self.scale_steps(steps[i])
            next_step = self.scale_steps(steps[i + 1] if i + 1 < len(steps) else 0)
            if i != len(steps) - 1:
                samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
            else:
                # Final step: use the model's x_0 prediction directly. Size the step
                # vector from the samples so user-supplied priors of any batch work.
                step_ones = jnp.ones((samples.shape[0], ), dtype=jnp.int32)
                samples, _, _ = self.sample_model(samples, current_step * step_ones)
        samples = clip_images(samples)
        return samples
@@ -0,0 +1,10 @@
1
+ import jax.numpy as jnp
2
+ from .common import DiffusionSampler
3
+ from ..utils import MarkovState
4
+
5
class DDIMSampler(DiffusionSampler):
    """Deterministic DDIM sampler: no fresh noise is injected between steps."""

    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: MarkovState, next_step=None) -> tuple[jnp.ndarray, MarkovState]:
        # Re-noise the x_0 estimate to the next step's level with the *predicted*
        # noise: x_next = alpha_next * x_0 + sigma_next * eps.
        alpha_next, sigma_next = self.noise_schedule.get_rates(next_step)
        x_next = reconstructed_samples * alpha_next + pred_noise * sigma_next
        return x_next, state
@@ -0,0 +1,43 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import MarkovState, RandomMarkovState
5
class DDPMSampler(DiffusionSampler):
    """Ancestral DDPM sampler drawing from the schedule's posterior q(x_{t-1} | x_t, x_0)."""

    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Sample ``mean + noise * scale`` using the schedule's posterior mean and variance helpers."""
        mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
        # NOTE: for DiscreteNoiseScheduler this returns exp(0.5 * log variance),
        # i.e. a standard deviation, which is why it multiplies unit noise directly.
        variance = self.noise_schedule.get_posterior_variance(steps=current_step)

        state, rng = state.get_random_key()
        # Now sample from the posterior
        noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)

        return mean + noise * variance, state

    def generate_images(self, num_images=16, diffusion_steps=1000, start_step: int = None, *args, **kwargs):
        # Forward positionally so that extra positional arguments bind to the
        # parameters after start_step (end_step, steps_override, ...) instead of
        # colliding with num_images / diffusion_steps.
        return super().generate_images(num_images, diffusion_steps, start_step, *args, **kwargs)
20
+
21
class SimpleDDPMSampler(DiffusionSampler):
    """DDPM ancestral step computed directly from the schedule's rates.

    For variance-preserving schedules this samples q(x_next | x_t, x_0) for an
    arbitrary pair of steps, written in terms of the predicted x_0 and noise.
    """

    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        state, key = state.get_random_key()
        fresh_noise = jax.random.normal(key, reconstructed_samples.shape, dtype=jnp.float32)

        alpha_t, sigma_t = self.noise_schedule.get_rates(current_step)
        alpha_s, sigma_s = self.noise_schedule.get_rates(next_step)

        # Coefficient on the predicted noise in the posterior mean.
        eps_coeff = ((sigma_s ** 2) * alpha_t) / (sigma_t * alpha_s)
        # Transition "beta" between the two steps and the resulting posterior std.
        beta_ts = 1 - (alpha_t ** 2) / (alpha_s ** 2)
        std = jnp.sqrt((sigma_s ** 2) / (sigma_t ** 2) * beta_ts)

        x_s = alpha_s * reconstructed_samples + eps_coeff * pred_noise + fresh_noise * std
        return x_s, state
@@ -0,0 +1,59 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import RandomMarkovState
5
+
6
class EulerSampler(DiffusionSampler):
    """Deterministic Euler sampler: the DDIM update written as an ODE step in sigma.

    With x_0_coeff chosen as below, ``x + dx * dt`` equals
    ``next_alpha * x_0 + next_sigma * eps`` exactly.
    """

    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

        dt = next_sigma - current_sigma
        # If both steps map to the same sigma (e.g. scaled steps truncating to the
        # same index), dt == 0 and the coefficient would be 0/0 = NaN. The correct
        # update is then the identity, which dx * dt == 0 yields once the
        # division is kept finite.
        safe_dt = jnp.where(dt == 0, 1.0, dt)
        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / safe_dt
        dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
        next_samples = current_samples + dx * dt
        return next_samples, state
20
+
21
class SimplifiedEulerSampler(DiffusionSampler):
    """
    This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t

    Only the noise rates are used: the slope is (x - x_0) / sigma and the step
    is the change in sigma.
    """

    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        sigma_now = self.noise_schedule.get_rates(current_step)[1]
        sigma_next = self.noise_schedule.get_rates(next_step)[1]
        slope = (current_samples - reconstructed_samples) / sigma_now
        return current_samples + slope * (sigma_next - sigma_now), state
36
+
37
class EulerAncestralSampler(DiffusionSampler):
    """
    Similar to EulerSampler but with ancestral sampling

    The deterministic part steps down to ``sigma_down``; fresh noise with std
    ``sigma_up`` is then added.
    """

    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

        # Both radicands are >= 0 in exact arithmetic for next_sigma <= current_sigma;
        # clamp so float round-off cannot produce NaN from a tiny negative value.
        sigma_up = jnp.maximum(next_sigma**2 * (current_sigma**2 - next_sigma**2) / current_sigma**2, 0) ** 0.5
        sigma_down = jnp.maximum(next_sigma**2 - sigma_up**2, 0) ** 0.5

        dt = sigma_down - current_sigma

        # Guard the 0/0 when both steps share a sigma; then sigma_up == 0 and
        # dt == 0, so the update correctly reduces to the identity.
        sigma_delta = next_sigma - current_sigma
        safe_delta = jnp.where(sigma_delta == 0, 1.0, sigma_delta)
        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / safe_delta
        dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma

        state, subkey = state.get_random_key()
        dW = jax.random.normal(subkey, current_samples.shape) * sigma_up

        next_samples = current_samples + dx * dt + dW
        return next_samples, state
@@ -0,0 +1,28 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import RandomMarkovState
5
+
6
class HeunSampler(DiffusionSampler):
    """Second-order (Heun) sampler: an Euler predictor plus a trapezoidal corrector.

    The corrector re-queries the model at the predicted point, so each step
    costs one extra model evaluation.
    """

    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        # Get the noise and signal rates for the current and next steps
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

        dt = next_sigma - current_sigma
        # Avoid 0/0 = NaN when both steps share a sigma (dt == 0 makes the step
        # a no-op regardless of the coefficient).
        safe_dt = jnp.where(dt == 0, 1.0, dt)
        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / safe_dt

        # Euler predictor.
        dx_0 = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
        next_samples_0 = current_samples + dx_0 * dt

        # Recompute x_0 at the first estimate to refine the derivative.
        estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step)

        # The corrector slope divides by next_sigma, which is undefined when the
        # next noise level is 0 (e.g. t = 0 of a sin-based schedule). In that case
        # keep the Euler result, as in Algorithm 1 of Karras et al. (EDM).
        sigma_is_zero = next_sigma == 0
        safe_next_sigma = jnp.where(sigma_is_zero, 1.0, next_sigma)
        dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / safe_next_sigma
        # Average the initial and refined derivatives (trapezoidal rule).
        heun_samples = current_samples + 0.5 * (dx_0 + dx_1) * dt

        final_next_samples = jnp.where(sigma_is_zero, next_samples_0, heun_samples)
        return final_next_samples, state
@@ -0,0 +1,60 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import RandomMarkovState
5
+
6
class MultiStepDPM(DiffusionSampler):
    """Multistep sampler extrapolating in sigma with up to third-order terms.

    Each call to ``_renoise`` takes a Taylor step in sigma whose higher-order
    terms come from divided differences of the noise predictions stored in
    ``self.history``.

    NOTE(review): this class does not override ``take_next_step``, so the base
    sampling loop never calls ``_renoise``. ``self.history`` is also never cleared
    between runs. Confirm the intended wiring before use.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One {"eps": ..., "sigma": ...} record per _renoise call, oldest first.
        self.history = []

    def _renoise(self,
                 current_samples, reconstructed_samples,
                 pred_noise, current_step, state: RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
        _, sigma_now = self.noise_schedule.get_rates(current_step)
        _, sigma_next = self.noise_schedule.get_rates(next_step)
        h = sigma_next - sigma_now

        def slope_change(eps_a, sigma_a, eps_b, sigma_b):
            # First divided difference of eps with respect to sigma.
            return (eps_a - eps_b) / (sigma_a - sigma_b)

        # First-order term: eps itself is the slope dx/dsigma.
        next_samples = current_samples + (pred_noise * h)

        if self.history:
            last = self.history[-1]
            d2 = slope_change(pred_noise, sigma_now, last['eps'], last['sigma'])
            next_samples = next_samples + (0.5 * d2 * h**2)

            if len(self.history) >= 2:
                prev = self.history[-2]
                d2_prev = slope_change(last['eps'], last['sigma'], prev['eps'], prev['sigma'])
                spacing = 0.5 * ((sigma_now + last['sigma']) - (last['sigma'] + prev['sigma']))
                d3 = (d2 - d2_prev) / spacing
                next_samples = next_samples + ((1/6) * d3 * h**3)

        self.history.append({
            "eps": pred_noise,
            "sigma" : sigma_now,
        })
        return next_samples, state
@@ -0,0 +1,34 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import RandomMarkovState
5
+ from ..schedulers import GeneralizedNoiseScheduler
6
+
7
class RK4Sampler(DiffusionSampler):
    """Classic fourth-order Runge-Kutta integration in sigma.

    Requires a GeneralizedNoiseScheduler (signal rate fixed at 1). The model's
    noise prediction is used as the slope dx/dsigma at each stage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.noise_schedule, GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"

        def derivative(x_t, sigma, state: RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
            # Map sigma back to a timestep, then use the predicted noise as the slope.
            timesteps = self.noise_schedule.get_timesteps(sigma)
            _, eps, _ = self.sample_model(x_t, timesteps)
            return eps, state

        self.get_derivative = jax.jit(derivative)

    def sample_step(self, current_samples: jnp.ndarray, current_step, next_step, state: RandomMarkovState = None) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Take one RK4 step from ``current_step`` to ``next_step``."""
        per_sample = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
        _, sigma_now = self.noise_schedule.get_rates(per_sample * current_step)
        _, sigma_next = self.noise_schedule.get_rates(per_sample * next_step)

        h = sigma_next - sigma_now
        sigma_mid = sigma_now + 0.5 * h

        k1, state = self.get_derivative(current_samples, sigma_now, state)
        k2, state = self.get_derivative(current_samples + 0.5 * k1 * h, sigma_mid, state)
        k3, state = self.get_derivative(current_samples + 0.5 * k2 * h, sigma_mid, state)
        k4, state = self.get_derivative(current_samples + k3 * h, sigma_now + h, state)

        weighted_slope = (k1 + 2 * k2 + 2 * k3 + k4) / 6
        return current_samples + weighted_slope * h, state
@@ -0,0 +1,6 @@
1
+ from .discrete import DiscreteNoiseScheduler
2
+ from .common import NoiseScheduler, GeneralizedNoiseScheduler
3
+ from .cosine import CosineNoiseSchedule, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
4
+ from .linear import LinearNoiseSchedule
5
+ from .sqrt import SqrtContinuousNoiseScheduler
6
+ from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
@@ -0,0 +1,98 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from typing import Union
4
+ from ..utils import RandomMarkovState
5
+
6
class NoiseScheduler():
    """Base class for diffusion noise schedules.

    A schedule maps steps to ``(signal_rates, noise_rates)`` such that a noisy
    sample is ``signal_rates * x_0 + noise_rates * noise`` (see ``add_noise``).
    Subclasses implement ``get_rates`` and, where needed, ``get_weights``,
    ``get_posterior_mean`` and ``get_posterior_variance``.
    """

    def __init__(self, timesteps,
                 dtype=jnp.float32,
                 clip_min=-1.0,
                 clip_max=1.0,
                 *args, **kwargs):
        self.max_timesteps = timesteps
        self.dtype = dtype
        self.clip_min = clip_min
        self.clip_max = clip_max
        # An integer step count > 1 samples integer timesteps in [0, timesteps);
        # otherwise (e.g. timesteps=1 for continuous schedules) timesteps are
        # drawn uniformly from [0, timesteps).
        if isinstance(timesteps, int) and timesteps > 1:
            timestep_generator = lambda rng, batch_size, max_timesteps = timesteps: jax.random.randint(rng, (batch_size,), 0, max_timesteps)
        else:
            timestep_generator = lambda rng, batch_size, max_timesteps = timesteps: jax.random.uniform(rng, (batch_size,), minval=0, maxval=max_timesteps)
        self.timestep_generator = timestep_generator

    def generate_timesteps(self, batch_size, state: "RandomMarkovState") -> "tuple[jnp.ndarray, RandomMarkovState]":
        """Draw ``batch_size`` training timesteps, returning them with the advanced state."""
        state, rng = state.get_random_key()
        timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
        return timesteps, state

    def get_weights(self, steps):
        raise NotImplementedError

    def reshape_rates(self, rates: tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Reshape both entries of a (signal, noise) pair to ``shape`` (default: broadcast over NHWC)."""
        signal_rates, noise_rates = rates
        signal_rates = jnp.reshape(signal_rates, shape)
        noise_rates = jnp.reshape(noise_rates, shape)
        return signal_rates, noise_rates

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        raise NotImplementedError

    def add_noise(self, images, noise, steps) -> jnp.ndarray:
        """Forward process: ``signal_rate * images + noise_rate * noise`` at ``steps``."""
        signal_rates, noise_rates = self.get_rates(steps)
        return signal_rates * images + noise_rates * noise

    def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
        """Invert ``add_noise``: recover x_0 from a noisy sample and its noise.

        Args:
            rates: optional precomputed ``(signal_rates, noise_rates)``; when None
                they are looked up from ``steps``.
            clip_denoised: accepted for API compatibility; no clipping is applied.
        """
        if rates is None:
            rates = self.get_rates(steps)
        signal_rates, noise_rates = rates
        x_0 = (noisy_images - noise * noise_rates) / signal_rates
        return x_0

    def transform_inputs(self, x, steps):
        """Hook to transform (x, steps) before they reach the model; identity by default."""
        return x, steps

    def get_posterior_mean(self, x_0, x_t, steps):
        raise NotImplementedError

    def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
        raise NotImplementedError

    def get_max_variance(self):
        """Return sqrt(alpha^2 + sigma^2) evaluated at ``max_timesteps``."""
        alpha_n, sigma_n = self.get_rates(self.max_timesteps)
        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
        return variance
61
+
62
class GeneralizedNoiseScheduler(NoiseScheduler):
    """
    Noise schedule in the generalized parameterization of Karras et al.,
    "Elucidating the Design Space of Diffusion-Based Generative Models".

    The signal rate is always 1 and the noise rate is the schedule's sigma; the
    model input itself is scaled to match the noise rate. Subclasses provide
    ``get_sigmas`` and its inverse ``get_timesteps``.
    """

    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80.0, sigma_data=1, *args, **kwargs):
        super().__init__(timesteps, *args, **kwargs)
        self.sigma_min, self.sigma_max, self.sigma_data = sigma_min, sigma_max, sigma_data

    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
        """Loss weights 1 + (1 / (1 + (1 - sigma^2) / sigma^2)) / sigma_max^2, reshaped to ``shape``."""
        sigma = self.get_sigmas(steps)
        ratio = (1 - sigma ** 2) / (sigma ** 2)
        weights = 1 + (1 / (1 + ratio)) / (self.sigma_max ** 2)
        return weights.reshape(shape)

    def get_sigmas(self, steps) -> jnp.ndarray:
        raise NotImplementedError("This method should be implemented in the subclass")

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        # Constant unit signal rate; the noise rate is sigma itself.
        return self.reshape_rates((1, self.get_sigmas(steps)), shape=shape)

    def transform_inputs(self, x, steps, num_discrete_chunks=1000):
        # Bucket the continuous steps into num_discrete_chunks integer ids for the model.
        bucketed = ((steps / self.max_timesteps) * num_discrete_chunks).astype(jnp.int32)
        return x, bucketed

    def get_timesteps(self, sigmas):
        """
        Inverse of the get_sigmas method
        """
        raise NotImplementedError("This method should be implemented in the subclass")
@@ -0,0 +1,12 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from typing import Union
4
+ from ..utils import RandomMarkovState
5
+ from .common import NoiseScheduler
6
+
7
class ContinuousNoiseScheduler(NoiseScheduler):
    """
    General Continuous Noise Scheduler

    ``max_timesteps`` is fixed to 1, so training timesteps are drawn uniformly
    from [0, 1) by the base class.
    """

    def __init__(self, *args, **kwargs):
        # Pass timesteps positionally so extra positional args bind to the
        # following parameters (dtype, clip_min, ...) instead of colliding with a
        # `timesteps=` keyword.
        super().__init__(1, *args, **kwargs)
@@ -0,0 +1,40 @@
1
+ import math
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ from .discrete import DiscreteNoiseScheduler
5
+ from .continuous import ContinuousNoiseScheduler
6
+ from .common import GeneralizedNoiseScheduler
7
+
8
def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
    """Return ``timesteps`` betas from a squared-cosine alpha_bar curve.

    alpha_bar(t) = cos((t + start_angle) / (1 + start_angle) * pi / 2) ** 2 on
    ``timesteps + 1`` evenly spaced t in [0, 1], normalized to alpha_bar(0) = 1.
    Betas are 1 - alpha_bar[i+1] / alpha_bar[i], clipped to [0, end_angle].
    """
    grid = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
    cumulative = np.cos((grid + start_angle) / (1 + start_angle) * np.pi /2) ** 2
    cumulative = cumulative / cumulative[0]
    step_ratios = cumulative[1:] / cumulative[:-1]
    return np.clip(1 - step_ratios, 0, end_angle)
14
+
15
class CosineNoiseSchedule(DiscreteNoiseScheduler):
    """Discrete variance-preserving schedule built from ``cosine_beta_schedule``.

    ``beta_start`` and ``beta_end`` are forwarded to ``cosine_beta_schedule``
    as its offset ``start_angle`` and its maximum beta ``end_angle``.
    """

    def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
        # schedule_fn is passed positionally so extra positional args bind to the
        # parameters after it rather than colliding with a `schedule_fn=` keyword.
        super().__init__(timesteps, beta_start, beta_end, cosine_beta_schedule, *args, **kwargs)
18
+
19
class CosineGeneralNoiseScheduler(GeneralizedNoiseScheduler):
    """Generalized (unit signal rate) schedule interpolating linearly in angle.

    With theta = arctan(sigma / kappa), steps in [0, 1] map linearly from
    theta_min (sigma_min) to theta_max (sigma_max), so ``get_sigmas(0)`` is
    ``sigma_min`` and ``get_sigmas(1)`` is ``sigma_max`` for any ``kappa``.
    """

    def __init__(self, sigma_min=0.02, sigma_max=80.0, kappa=1.0, *args, **kwargs):
        # Positional forwarding keeps extra positional args from colliding with timesteps.
        super().__init__(1, sigma_min, sigma_max, *args, **kwargs)
        self.kappa = kappa
        # exp(-0.5 * logsnr) with logsnr = 2 * log(kappa / sigma) equals sigma / kappa.
        logsnr_max = 2 * (math.log(self.kappa) - math.log(self.sigma_max))
        self.theta_max = math.atan(math.exp(-0.5 * logsnr_max))
        logsnr_min = 2 * (math.log(self.kappa) - math.log(self.sigma_min))
        self.theta_min = math.atan(math.exp(-0.5 * logsnr_min))

    def get_sigmas(self, steps):
        # theta = arctan(sigma / kappa)  =>  sigma = kappa * tan(theta).
        return jnp.tan(self.theta_min + steps * (self.theta_max - self.theta_min)) * self.kappa

    def get_timesteps(self, sigmas):
        """Inverse of ``get_sigmas``: map sigma values back to steps in [0, 1]."""
        theta = jnp.arctan(sigmas / self.kappa)
        return (theta - self.theta_min) / (self.theta_max - self.theta_min)
30
+
31
class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
    """Continuous variance-preserving schedule.

    alpha = cos(pi * t / (2 * max_timesteps)) and sigma = sin(pi * t / (2 * max_timesteps)).
    """

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
        noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
        return self.reshape_rates((signal_rates, noise_rates), shape=shape)

    def get_weights(self, steps, shape=None):
        """Loss weights 1 / (1 + alpha^2 / sigma^2), i.e. sigma^2 for this schedule.

        Args:
            shape: output shape; defaults to the shape of ``steps`` so scalar and
                batched timesteps both work. Pass e.g. ``(-1, 1, 1, 1)`` to
                broadcast over images.
        """
        if shape is None:
            shape = jnp.shape(steps)
        alpha, sigma = self.get_rates(steps, shape=shape)
        return 1 / (1 + (alpha ** 2 / sigma ** 2))
40
+
@@ -0,0 +1,74 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from typing import Union
4
+ from ..utils import RandomMarkovState
5
+ from .common import NoiseScheduler
6
+
7
class DiscreteNoiseScheduler(NoiseScheduler):
    """
    Variance Preserving Noise Scheduler
    signal_rate**2 + noise_rate**2 = 1

    Per-step quantities are precomputed as float32 tables and looked up by
    integer timestep (steps are truncated to int16 before indexing).
    """

    def __init__(self, timesteps,
                 beta_start=0.0001,
                 beta_end=0.02,
                 schedule_fn=None,
                 p2_loss_weight_k: float = 1,
                 p2_loss_weight_gamma: float = 1,
                 *args, **kwargs):
        """
        Args:
            timesteps: number of discrete diffusion steps.
            beta_start, beta_end: forwarded to ``schedule_fn``.
            schedule_fn: callable ``(timesteps, beta_start, beta_end) -> betas``.
                When None, betas are spaced linearly from ``beta_start`` to
                ``beta_end``.
            p2_loss_weight_k, p2_loss_weight_gamma: parameters of the weights
                (k + SNR) ** -gamma returned by ``get_weights``.
        """
        super().__init__(timesteps, *args, **kwargs)
        if schedule_fn is None:
            betas = jnp.linspace(beta_start, beta_end, timesteps, dtype=jnp.float32)
        else:
            betas = schedule_fn(timesteps, beta_start, beta_end)
        alphas = 1 - betas
        alpha_cumprod = jnp.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        alpha_cumprod_prev = jnp.append(1.0, alpha_cumprod[:-1])

        self.betas = jnp.array(betas, dtype=jnp.float32)
        self.alphas = alphas.astype(jnp.float32)
        self.alpha_cumprod = alpha_cumprod.astype(jnp.float32)
        self.alpha_cumprod_prev = alpha_cumprod_prev.astype(jnp.float32)

        self.sqrt_alpha_cumprod = jnp.sqrt(alpha_cumprod).astype(jnp.float32)
        self.sqrt_one_minus_alpha_cumprod = jnp.sqrt(1 - alpha_cumprod).astype(jnp.float32)

        # Variance and mean coefficients of the posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = (betas * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod))
        self.posterior_variance = posterior_variance.astype(jnp.float32)
        # Clamp before log: the t = 0 posterior variance is 0.
        self.posterior_log_variance_clipped = (jnp.log(jnp.maximum(posterior_variance, 1e-20))).astype(jnp.float32)

        self.posterior_mean_coef1 = (betas * jnp.sqrt(alpha_cumprod_prev) / (1 - alpha_cumprod)).astype(jnp.float32)
        self.posterior_mean_coef2 = ((1 - alpha_cumprod_prev) * jnp.sqrt(alphas) / (1 - alpha_cumprod)).astype(jnp.float32)

        self.p2_loss_weights = self.get_p2_weights(p2_loss_weight_k, p2_loss_weight_gamma)

    def generate_timesteps(self, batch_size, state: RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Draw ``batch_size`` integer timesteps in [0, max_timesteps)."""
        state, rng = state.get_random_key()
        timesteps = jax.random.randint(rng, (batch_size,), 0, self.max_timesteps)
        return timesteps, state

    def get_p2_weights(self, k, gamma):
        """Per-step weights (k + alpha_bar / (1 - alpha_bar)) ** -gamma."""
        return (k + self.alpha_cumprod / (1 - self.alpha_cumprod)) ** -gamma

    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
        steps = jnp.int16(steps)
        return self.p2_loss_weights[steps].reshape(shape)

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return (sqrt(alpha_bar), sqrt(1 - alpha_bar)) at ``steps``, reshaped to ``shape``."""
        steps = jnp.int16(steps)
        signal_rate = self.sqrt_alpha_cumprod[steps]
        noise_rate = self.sqrt_one_minus_alpha_cumprod[steps]
        signal_rate = jnp.reshape(signal_rate, shape)
        noise_rate = jnp.reshape(noise_rate, shape)
        return signal_rate, noise_rate

    def get_posterior_mean(self, x_0, x_t, steps):
        """Mean of q(x_{t-1} | x_t, x_0): coef1 * x_0 + coef2 * x_t."""
        steps = jnp.int16(steps)
        x_0_coeff = self.posterior_mean_coef1[steps]
        x_t_coeff = self.posterior_mean_coef2[steps]
        x_0_coeff, x_t_coeff = self.reshape_rates((x_0_coeff, x_t_coeff))
        mean = x_0_coeff * x_0 + x_t_coeff * x_t
        return mean

    def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
        """Return exp(0.5 * clipped log posterior variance), i.e. the posterior std.

        Steps are converted with ``jnp.int16`` like the other lookups, so batched
        steps and traced values under ``jax.jit`` work (``int()`` handles neither).
        """
        steps = jnp.int16(steps)
        return jnp.exp(0.5 * self.posterior_log_variance_clipped[steps]).reshape(shape)
73
+
74
+
@@ -0,0 +1,13 @@
1
+ import numpy as np
2
+ from .discrete import DiscreteNoiseScheduler
3
+
4
def exp_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
    """Return ``timesteps`` betas for alpha_bar(t) = exp(-12 t) on t in [0, 1].

    ``start_angle`` is accepted but unused; betas are clipped to [0, end_angle].
    """
    grid = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
    cumulative = np.exp(grid * -12.0)
    cumulative = cumulative / cumulative[0]
    return np.clip(1 - (cumulative[1:] / cumulative[:-1]), 0, end_angle)
10
+
11
class ExpNoiseSchedule(DiscreteNoiseScheduler):
    """Discrete variance-preserving schedule built from ``exp_beta_schedule``.

    ``beta_start`` is passed as ``start_angle``, which exp_beta_schedule ignores.
    ``beta_end`` is passed as the maximum beta.
    """

    def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
        # schedule_fn is passed positionally so extra positional args bind to the
        # parameters after it rather than colliding with a `schedule_fn=` keyword.
        super().__init__(timesteps, beta_start, beta_end, exp_beta_schedule, *args, **kwargs)