flaxdiff 0.1.36.1__py3-none-any.whl → 0.1.36.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +1 -0
  2. flaxdiff/data/dataset_map.py +71 -0
  3. flaxdiff/data/datasets.py +169 -0
  4. flaxdiff/data/online_loader.py +363 -0
  5. flaxdiff/data/sources/gcs.py +81 -0
  6. flaxdiff/data/sources/tfds.py +67 -0
  7. flaxdiff/metrics/inception.py +658 -0
  8. flaxdiff/metrics/utils.py +49 -0
  9. flaxdiff/models/__init__.py +1 -0
  10. flaxdiff/models/attention.py +368 -0
  11. flaxdiff/models/autoencoder/__init__.py +2 -0
  12. flaxdiff/models/autoencoder/autoencoder.py +19 -0
  13. flaxdiff/models/autoencoder/diffusers.py +91 -0
  14. flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
  15. flaxdiff/models/common.py +346 -0
  16. flaxdiff/models/favor_fastattn.py +723 -0
  17. flaxdiff/models/simple_unet.py +233 -0
  18. flaxdiff/models/simple_vit.py +180 -0
  19. flaxdiff/predictors/__init__.py +96 -0
  20. flaxdiff/samplers/__init__.py +7 -0
  21. flaxdiff/samplers/common.py +165 -0
  22. flaxdiff/samplers/ddim.py +10 -0
  23. flaxdiff/samplers/ddpm.py +37 -0
  24. flaxdiff/samplers/euler.py +56 -0
  25. flaxdiff/samplers/heun_sampler.py +27 -0
  26. flaxdiff/samplers/multistep_dpm.py +59 -0
  27. flaxdiff/samplers/rk4_sampler.py +34 -0
  28. flaxdiff/schedulers/__init__.py +6 -0
  29. flaxdiff/schedulers/common.py +98 -0
  30. flaxdiff/schedulers/continuous.py +12 -0
  31. flaxdiff/schedulers/cosine.py +40 -0
  32. flaxdiff/schedulers/discrete.py +74 -0
  33. flaxdiff/schedulers/exp.py +13 -0
  34. flaxdiff/schedulers/karras.py +69 -0
  35. flaxdiff/schedulers/linear.py +14 -0
  36. flaxdiff/schedulers/sqrt.py +10 -0
  37. flaxdiff/trainer/__init__.py +2 -0
  38. flaxdiff/trainer/autoencoder_trainer.py +182 -0
  39. flaxdiff/trainer/diffusion_trainer.py +326 -0
  40. flaxdiff/trainer/simple_trainer.py +540 -0
  41. flaxdiff/trainer/video_diffusion_trainer.py +62 -0
  42. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/METADATA +1 -1
  43. flaxdiff-0.1.36.3.dist-info/RECORD +47 -0
  44. flaxdiff-0.1.36.1.dist-info/RECORD +0 -6
  45. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/WHEEL +0 -0
  46. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,10 @@
1
+ import jax.numpy as jnp
2
+ from .common import DiffusionSampler
3
+ from ..utils import MarkovState, RandomMarkovState
4
+
5
class DDIMSampler(DiffusionSampler):
    """Deterministic DDIM update: re-noise the predicted clean sample directly to the next step."""

    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Return ``(x_next, state)`` with ``x_next = alpha_next * x0_hat + sigma_next * eps_hat``.

        No randomness is drawn, so ``state`` is returned unchanged.
        """
        alpha_next, sigma_next = self.noise_schedule.get_rates(next_step)
        x_next = reconstructed_samples * alpha_next + pred_noise * sigma_next
        return x_next, state
10
+
@@ -0,0 +1,37 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import MarkovState, RandomMarkovState
5
class DDPMSampler(DiffusionSampler):
    """Ancestral DDPM sampler: draws the next sample from the schedule's Gaussian posterior."""

    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Return ``(posterior_mean + z * posterior_scale, new_state)`` with ``z ~ N(0, I)``.

        The posterior mean is built from the model's clean estimate and the current noisy
        sample. ``get_posterior_variance`` is used as a multiplier on unit noise. In
        DiscreteNoiseScheduler it returns exp(0.5 * log_variance), i.e. a standard deviation.
        """
        schedule = self.noise_schedule
        posterior_mean = schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
        posterior_scale = schedule.get_posterior_variance(steps=current_step)

        state, noise_key = state.get_random_key()
        z = jax.random.normal(noise_key, reconstructed_samples.shape, dtype=jnp.float32)
        return posterior_mean + z * posterior_scale, state

    def generate_images(self, num_images=16, diffusion_steps=1000, start_step: int = None, *args, **kwargs):
        """Forward to the base sampler, defaulting to 1000 diffusion steps."""
        return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps,
                                       start_step=start_step, *args, **kwargs)
19
+
20
class SimpleDDPMSampler(DiffusionSampler):
    """DDPM ancestral step written in closed form from (x0_hat, eps_hat) and the schedule rates."""

    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Compute one stochastic step.

        x_next = alpha_next * x0_hat + c_eps * eps_hat + gamma * z, where z ~ N(0, I) and
          c_eps = sigma_next**2 * alpha_cur / (sigma_cur * alpha_next)
          gamma = sqrt((sigma_next**2 / sigma_cur**2) * (1 - alpha_cur**2 / alpha_next**2))
        """
        state, noise_key = state.get_random_key()
        z = jax.random.normal(noise_key, reconstructed_samples.shape, dtype=jnp.float32)

        alpha_cur, sigma_cur = self.noise_schedule.get_rates(current_step)
        alpha_next, sigma_next = self.noise_schedule.get_rates(next_step)

        eps_coeff = ((sigma_next ** 2) * alpha_cur) / (sigma_cur * alpha_next)
        variance_ratio = (sigma_next ** 2) / (sigma_cur ** 2)
        signal_ratio = (alpha_cur ** 2) / (alpha_next ** 2)
        gamma = jnp.sqrt(variance_ratio * (1 - signal_ratio))

        x_next = alpha_next * reconstructed_samples + eps_coeff * pred_noise + z * gamma
        return x_next, state
@@ -0,0 +1,56 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import RandomMarkovState
5
+
6
class EulerSampler(DiffusionSampler):
    """One Euler step of the sampling ODE, integrated in sigma (a DDIM-like update)."""

    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Advance ``current_samples`` from sigma_cur to sigma_next along the estimated slope.

        When the signal rate is 1 at both steps, ``x0_scale`` is 1 and the slope reduces to
        ``(x - x0_hat) / sigma_cur``. Divides by ``sigma_next - sigma_cur`` and by ``sigma_cur``.
        ``state`` is returned unchanged.
        """
        alpha_cur, sigma_cur = self.noise_schedule.get_rates(current_step)
        alpha_next, sigma_next = self.noise_schedule.get_rates(next_step)

        d_sigma = sigma_next - sigma_cur
        x0_scale = (alpha_cur * sigma_next - alpha_next * sigma_cur) / d_sigma
        slope = (current_samples - x0_scale * reconstructed_samples) / sigma_cur
        return current_samples + slope * d_sigma, state
19
+
20
class SimplifiedEulerSampler(DiffusionSampler):
    """Euler step for forward processes of the form x_{t+1} = x_t + sigma_t * epsilon_t.

    Only the noise rates are used; the slope is ``(x - x0_hat) / sigma_cur``.
    """

    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Move from sigma_cur to sigma_next; ``state`` is passed through unchanged."""
        sigma_cur = self.noise_schedule.get_rates(current_step)[1]
        sigma_next = self.noise_schedule.get_rates(next_step)[1]

        d_sigma = sigma_next - sigma_cur
        slope = (current_samples - reconstructed_samples) / sigma_cur
        return current_samples + slope * d_sigma, state
34
+
35
class EulerAncestralSampler(DiffusionSampler):
    """Euler step with ancestral noise injection (stochastic variant of EulerSampler)."""

    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Deterministically step down to ``sigma_down``, then add fresh noise of scale ``sigma_up``.

        By construction ``sigma_down**2 + sigma_up**2 == sigma_next**2``. Consumes one random key.
        """
        alpha_cur, sigma_cur = self.noise_schedule.get_rates(current_step)
        alpha_next, sigma_next = self.noise_schedule.get_rates(next_step)

        sigma_up = (sigma_next**2 * (sigma_cur**2 - sigma_next**2) / sigma_cur**2) ** 0.5
        sigma_down = (sigma_next**2 - sigma_up**2) ** 0.5

        x0_scale = (alpha_cur * sigma_next - alpha_next * sigma_cur) / (sigma_next - sigma_cur)
        slope = (current_samples - x0_scale * reconstructed_samples) / sigma_cur
        drift = slope * (sigma_down - sigma_cur)

        state, noise_key = state.get_random_key()
        kick = jax.random.normal(noise_key, current_samples.shape) * sigma_up

        return current_samples + drift + kick, state
@@ -0,0 +1,27 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import RandomMarkovState
5
+
6
class HeunSampler(DiffusionSampler):
    """Second-order Heun step: an Euler predictor followed by a trapezoidal corrector.

    Costs one extra ``self.sample_model`` evaluation per step (at ``next_step``).
    """

    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Return ``(x_next, state)``; ``state`` is passed through unchanged."""
        alpha_cur, sigma_cur = self.noise_schedule.get_rates(current_step)
        alpha_next, sigma_next = self.noise_schedule.get_rates(next_step)

        d_sigma = sigma_next - sigma_cur
        x0_scale = (alpha_cur * sigma_next - alpha_next * sigma_cur) / d_sigma

        # Predictor: plain Euler step using the model's current clean-sample estimate.
        slope_start = (current_samples - x0_scale * reconstructed_samples) / sigma_cur
        x_euler = current_samples + slope_start * d_sigma

        # Corrector: re-query the model at the predicted point and average both slopes
        # (trapezoidal rule). NOTE(review): divides by sigma_next, so a step ending at
        # sigma_next == 0 would yield inf/nan; confirm the base sampler never requests it.
        x0_at_end, _, _ = self.sample_model(x_euler, next_step, *model_conditioning_inputs)
        slope_end = (x_euler - x0_scale * x0_at_end) / sigma_next

        return current_samples + 0.5 * (slope_start + slope_end) * d_sigma, state
@@ -0,0 +1,59 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import RandomMarkovState
5
+
6
class MultiStepDPM(DiffusionSampler):
    """Multistep sampler that reuses eps predictions from previous steps.

    The update is a Taylor expansion in sigma. d(eps)/d(sigma) and its derivative are
    estimated by finite differences over the stored (eps, sigma) history. The first step of
    a run is first order, the second is second order, and every later step is third order.
    """

    # take_next_step only ever reads the two most recent (eps, sigma) entries.
    _HISTORY_LENGTH = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.history = []

    def generate_images(self, *args, **kwargs):
        """Clear the multistep history, then delegate to the base sampler.

        Without this reset, eps/sigma pairs from a previous sampling run would be treated as
        the "previous steps" of the new run and corrupt its higher-order terms.
        """
        self.history = []
        return super().generate_images(*args, **kwargs)

    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Advance from sigma(current_step) to sigma(next_step).

        Appends this step's (eps, sigma) to ``self.history``, keeping only the last two.
        ``state`` is returned unchanged.
        """
        _, current_sigma = self.noise_schedule.get_rates(current_step)
        _, next_sigma = self.noise_schedule.get_rates(next_step)
        dt = next_sigma - current_sigma

        def divided_difference(eps_a, sigma_a, eps_b, sigma_b):
            # First-order finite-difference estimate of d(eps)/d(sigma).
            return (eps_a - eps_b) / (sigma_a - sigma_b)

        recent = self.history[-self._HISTORY_LENGTH:]
        if not recent:
            # First order only.
            next_samples = current_samples + pred_noise * dt
        elif len(recent) == 1:
            # First + second order.
            last_step = recent[-1]
            d_eps = divided_difference(pred_noise, current_sigma, last_step['eps'], last_step['sigma'])
            next_samples = current_samples + pred_noise * dt + 0.5 * d_eps * dt**2
        else:
            # First + second + third order.
            last_step, second_last_step = recent[-1], recent[-2]
            d_eps = divided_difference(pred_noise, current_sigma, last_step['eps'], last_step['sigma'])
            d_eps_last = divided_difference(last_step['eps'], last_step['sigma'],
                                            second_last_step['eps'], second_last_step['sigma'])
            # Difference of slopes over the distance between the two slope midpoints.
            d2_eps = (d_eps - d_eps_last) / (0.5 * ((current_sigma + last_step['sigma'])
                                                     - (last_step['sigma'] + second_last_step['sigma'])))
            next_samples = current_samples + (pred_noise * dt) + (0.5 * d_eps * dt**2) + ((1/6) * d2_eps * dt**3)

        self.history.append({
            "eps": pred_noise,
            "sigma": current_sigma,
        })
        # Bound memory: older entries are never read.
        del self.history[:-self._HISTORY_LENGTH]
        return next_samples, state
@@ -0,0 +1,34 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from .common import DiffusionSampler
4
+ from ..utils import RandomMarkovState, MarkovState
5
+ from ..schedulers import GeneralizedNoiseScheduler
6
+
7
class RK4Sampler(DiffusionSampler):
    """Classic fourth-order Runge-Kutta integration of dx/dsigma = eps in sigma space.

    Requires a GeneralizedNoiseScheduler (signal rate fixed to 1). For that scheduler the
    model's eps prediction is used directly as the derivative, and ``get_timesteps`` maps
    each intermediate sigma back to a model timestep.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"

        @jax.jit
        def derivative(x_t, sigma, state: RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
            # Evaluate the model at the timestep corresponding to ``sigma``; its eps output is dx/dsigma.
            timesteps = self.noise_schedule.get_timesteps(sigma)
            _, eps, _ = self.sample_model(x_t, timesteps, *model_conditioning_inputs)
            return eps, state

        self.get_derivative = derivative

    def sample_step(self, current_samples: jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state: MarkovState = None) -> tuple[jnp.ndarray, MarkovState]:
        """Integrate from sigma(current_step) to sigma(next_step) with four model evaluations."""
        ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
        _, sigma_cur = self.noise_schedule.get_rates(ones * current_step)
        _, sigma_next = self.noise_schedule.get_rates(ones * next_step)
        h = sigma_next - sigma_cur

        derivative = self.get_derivative
        k1, state = derivative(current_samples, sigma_cur, state, model_conditioning_inputs)
        k2, state = derivative(current_samples + 0.5 * k1 * h, sigma_cur + 0.5 * h, state, model_conditioning_inputs)
        k3, state = derivative(current_samples + 0.5 * k2 * h, sigma_cur + 0.5 * h, state, model_conditioning_inputs)
        k4, state = derivative(current_samples + k3 * h, sigma_cur + h, state, model_conditioning_inputs)

        weighted_slope = k1 + 2 * k2 + 2 * k3 + k4
        return current_samples + ((weighted_slope * h) / 6), state
@@ -0,0 +1,6 @@
1
+ from .discrete import DiscreteNoiseScheduler
2
+ from .common import NoiseScheduler, GeneralizedNoiseScheduler
3
+ from .cosine import CosineNoiseSchedule, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
4
+ from .linear import LinearNoiseSchedule
5
+ from .sqrt import SqrtContinuousNoiseScheduler
6
+ from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
@@ -0,0 +1,98 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from typing import Union
4
+ from ..utils import RandomMarkovState
5
+
6
class NoiseScheduler():
    """Base class for diffusion noise schedules.

    A schedule maps a step (integer index or continuous time) to a pair
    ``(signal_rate, noise_rate)`` with ``x_t = signal_rate * x_0 + noise_rate * eps``.
    Subclasses implement ``get_rates`` and, where they are used, ``get_weights``,
    ``get_posterior_mean`` and ``get_posterior_variance``.
    """

    def __init__(self, timesteps,
                 dtype=jnp.float32,
                 clip_min=-1.0,
                 clip_max=1.0,
                 *args, **kwargs):
        """
        Args:
            timesteps: an integer greater than 1 selects integer steps drawn uniformly from
                ``[0, timesteps)``. Any other value selects float steps drawn uniformly from
                ``[0, timesteps)``.
            dtype: stored as ``self.dtype``.
            clip_min, clip_max: stored as ``self.clip_min`` / ``self.clip_max``.
        """
        import numbers

        self.max_timesteps = timesteps
        self.dtype = dtype
        self.clip_min = clip_min
        self.clip_max = clip_max
        # numbers.Integral also covers numpy integer scalars (e.g. np.int64), which an exact
        # `type(...) == int` check would route to the continuous branch.
        if isinstance(timesteps, numbers.Integral) and timesteps > 1:
            def timestep_generator(rng, batch_size, max_timesteps=timesteps):
                return jax.random.randint(rng, (batch_size,), 0, max_timesteps)
        else:
            def timestep_generator(rng, batch_size, max_timesteps=timesteps):
                return jax.random.uniform(rng, (batch_size,), minval=0, maxval=max_timesteps)
        self.timestep_generator = timestep_generator

    def generate_timesteps(self, batch_size, state: "RandomMarkovState") -> tuple[jnp.ndarray, "RandomMarkovState"]:
        """Draw ``batch_size`` training timesteps and return them with the advanced state."""
        state, rng = state.get_random_key()
        timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
        return timesteps, state

    def get_weights(self, steps):
        raise NotImplementedError

    def reshape_rates(self, rates: tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Reshape both members of ``(signal_rates, noise_rates)`` to ``shape``."""
        signal_rates, noise_rates = rates
        signal_rates = jnp.reshape(signal_rates, shape)
        noise_rates = jnp.reshape(noise_rates, shape)
        return signal_rates, noise_rates

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        raise NotImplementedError

    def add_noise(self, images, noise, steps) -> jnp.ndarray:
        """Return ``signal_rate * images + noise_rate * noise`` at ``steps``."""
        signal_rates, noise_rates = self.get_rates(steps)
        return signal_rates * images + noise_rates * noise

    def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
        """Invert ``add_noise``: ``x_0 = (x_t - noise_rate * noise) / signal_rate``.

        Args:
            rates: optional precomputed ``(signal_rates, noise_rates)``. When given, it is used
                instead of calling ``get_rates(steps)``.
            clip_denoised: accepted for API compatibility; no clipping is applied here.
        """
        signal_rates, noise_rates = self.get_rates(steps) if rates is None else rates
        x_0 = (noisy_images - noise * noise_rates) / signal_rates
        return x_0

    def transform_inputs(self, x, steps):
        """Hook for subclasses to transform model inputs; identity by default."""
        return x, steps

    def get_posterior_mean(self, x_0, x_t, steps):
        raise NotImplementedError

    def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
        raise NotImplementedError

    def get_max_variance(self):
        """Return ``sqrt(alpha**2 + sigma**2)`` at ``max_timesteps``.

        Despite the name, this is a standard-deviation-like magnitude, not a variance.
        """
        alpha_n, sigma_n = self.get_rates(self.max_timesteps)
        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
        return variance
61
+
62
class GeneralizedNoiseScheduler(NoiseScheduler):
    """Noise schedule in the generalized form of Karras et al., "Elucidating the Design
    Space of Diffusion-Based Generative Models".

    The signal rate is fixed to 1 and the noise rate is sigma(step). Subclasses provide
    ``get_sigmas`` and its inverse ``get_timesteps``.
    """

    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80.0, sigma_data=1, *args, **kwargs):
        super().__init__(timesteps, *args, **kwargs)
        self.sigma_min, self.sigma_max, self.sigma_data = sigma_min, sigma_max, sigma_data

    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
        """Loss weights; algebraically this equals ``1 + sigma**2 / sigma_max**2``."""
        sigma = self.get_sigmas(steps)
        inverse_snr = (1 - sigma ** 2) / (sigma ** 2)
        weights = 1 + (1 / (1 + inverse_snr)) / (self.sigma_max ** 2)
        return weights.reshape(shape)

    def get_sigmas(self, steps) -> jnp.ndarray:
        raise NotImplementedError("This method should be implemented in the subclass")

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(1, sigma(steps))`` reshaped to ``shape``."""
        return self.reshape_rates((1, self.get_sigmas(steps)), shape=shape)

    def transform_inputs(self, x, steps, num_discrete_chunks=1000):
        """Map steps onto ``num_discrete_chunks`` integer buckets (truncated to int32)."""
        bucketed = (steps / self.max_timesteps) * num_discrete_chunks
        return x, bucketed.astype(jnp.int32)

    def get_timesteps(self, sigmas):
        """
        Inverse of the get_sigmas method
        """
        raise NotImplementedError("This method should be implemented in the subclass")
@@ -0,0 +1,12 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from typing import Union
4
+ from ..utils import RandomMarkovState
5
+ from .common import NoiseScheduler
6
+
7
class ContinuousNoiseScheduler(NoiseScheduler):
    """Base for continuous-time schedules.

    Pins ``timesteps=1``, so NoiseScheduler draws float training steps uniformly from [0, 1).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, timesteps=1, **kwargs)
@@ -0,0 +1,40 @@
1
+ import math
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ from .discrete import DiscreteNoiseScheduler
5
+ from .continuous import ContinuousNoiseScheduler
6
+ from .common import GeneralizedNoiseScheduler
7
+
8
def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
    """Return ``timesteps`` betas from the squared-cosine alpha_bar curve.

    alpha_bar(t) = cos((t + s) / (1 + s) * pi / 2) ** 2 on a uniform grid over [0, 1],
    normalised so alpha_bar(0) == 1. Here ``start_angle`` is the offset ``s``, and
    ``end_angle`` is the upper clip applied to the betas.
    """
    grid = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
    curve = np.cos((grid + start_angle) / (1 + start_angle) * np.pi / 2) ** 2
    alphas_bar = curve / curve[0]
    step_ratios = alphas_bar[1:] / alphas_bar[:-1]
    return np.clip(1 - step_ratios, 0, end_angle)
14
+
15
class CosineNoiseSchedule(DiscreteNoiseScheduler):
    """Discrete variance-preserving schedule whose betas come from ``cosine_beta_schedule``.

    ``beta_start`` is forwarded as the cosine offset and ``beta_end`` as the beta clip.
    """

    def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
        super().__init__(timesteps, beta_start, beta_end, *args, schedule_fn=cosine_beta_schedule, **kwargs)
18
+
19
class CosineGeneralNoiseScheduler(GeneralizedNoiseScheduler):
    """Generalized (signal rate 1) schedule with sigma following a tangent curve.

    With theta = atan(sigma / kappa), steps in [0, 1] map linearly onto theta in
    [theta_min, theta_max]. Hence sigma = kappa * tan(theta), giving sigma(0) == sigma_min
    and sigma(1) == sigma_max for any ``kappa``.
    """

    def __init__(self, sigma_min=0.02, sigma_max=80.0, kappa=1.0, *args, **kwargs):
        super().__init__(timesteps=1, sigma_min=sigma_min, sigma_max=sigma_max, *args, **kwargs)
        self.kappa = kappa
        # logSNR at each sigma endpoint: 2 * (log(kappa) - log(sigma)).
        # Then atan(exp(-logsnr / 2)) == atan(sigma / kappa).
        logsnr_at_sigma_max = 2 * (math.log(self.kappa) - math.log(self.sigma_max))
        self.theta_max = math.atan(math.exp(-0.5 * logsnr_at_sigma_max))
        logsnr_at_sigma_min = 2 * (math.log(self.kappa) - math.log(self.sigma_min))
        self.theta_min = math.atan(math.exp(-0.5 * logsnr_at_sigma_min))

    def get_sigmas(self, steps):
        """sigma = kappa * tan(theta), the inverse of theta = atan(sigma / kappa)."""
        return jnp.tan(self.theta_min + steps * (self.theta_max - self.theta_min)) * self.kappa

    def get_timesteps(self, sigmas):
        """Inverse of ``get_sigmas``: flatten ``sigmas`` and return the matching steps in [0, 1]."""
        sigmas = jnp.reshape(sigmas, (-1,))
        theta = jnp.arctan(sigmas / self.kappa)
        return (theta - self.theta_min) / (self.theta_max - self.theta_min)
30
+
31
class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
    """Continuous variance-preserving cosine schedule.

    alpha = cos(pi * t / (2 * T)) and sigma = sin(pi * t / (2 * T)), where T = max_timesteps.
    """

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        angle = (jnp.pi * steps) / (2 * self.max_timesteps)
        return self.reshape_rates((jnp.cos(angle), jnp.sin(angle)), shape=shape)

    def get_weights(self, steps, shape=None):
        """Return ``1 / (1 + alpha**2 / sigma**2)`` at ``steps``.

        Args:
            shape: output shape. ``None`` (default) keeps the shape of ``steps``: a scalar
                step gives a scalar, and a ``(B,)`` batch gives ``(B,)``. Pass e.g.
                ``(-1, 1, 1, 1)`` to match the other schedulers' weight layout.
        """
        # Reshaping to a fixed () only works for a single step and crashed for batches.
        if shape is None:
            shape = jnp.shape(steps)
        alpha, sigma = self.get_rates(steps, shape=shape)
        return 1 / (1 + (alpha ** 2 / sigma ** 2))
40
+
@@ -0,0 +1,74 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from typing import Union
4
+ from ..utils import RandomMarkovState
5
+ from .common import NoiseScheduler
6
+
7
class DiscreteNoiseScheduler(NoiseScheduler):
    """
    Variance Preserving Noise Scheduler over integer steps:
    signal_rate**2 + noise_rate**2 = 1

    Every per-step quantity is precomputed as a table from the betas returned by
    ``schedule_fn(timesteps, beta_start, beta_end)`` and looked up by step index.
    """

    def __init__(self, timesteps,
                 beta_start=0.0001,
                 beta_end=0.02,
                 schedule_fn=None,
                 p2_loss_weight_k: float = 1,
                 p2_loss_weight_gamma: float = 1,
                 *args, **kwargs):
        """
        Args:
            timesteps: number of discrete steps (table length).
            beta_start, beta_end: forwarded to ``schedule_fn``.
            schedule_fn: required callable ``(timesteps, beta_start, beta_end) -> betas``.
            p2_loss_weight_k, p2_loss_weight_gamma: parameters of the P2 loss weighting.

        Raises:
            TypeError: if ``schedule_fn`` is not provided.
        """
        if schedule_fn is None:
            raise TypeError("DiscreteNoiseScheduler requires schedule_fn(timesteps, beta_start, beta_end) returning the betas")
        super().__init__(timesteps, *args, **kwargs)
        betas = schedule_fn(timesteps, beta_start, beta_end)
        alphas = 1 - betas
        alpha_cumprod = jnp.cumprod(alphas, axis=0)
        alpha_cumprod_prev = jnp.append(1.0, alpha_cumprod[:-1])

        self.betas = jnp.array(betas, dtype=jnp.float32)
        self.alphas = alphas.astype(jnp.float32)
        self.alpha_cumprod = alpha_cumprod.astype(jnp.float32)
        self.alpha_cumprod_prev = alpha_cumprod_prev.astype(jnp.float32)

        self.sqrt_alpha_cumprod = jnp.sqrt(alpha_cumprod).astype(jnp.float32)
        self.sqrt_one_minus_alpha_cumprod = jnp.sqrt(1 - alpha_cumprod).astype(jnp.float32)

        # Posterior q(x_{t-1} | x_t, x_0) variance: beta_t * (1 - abar_{t-1}) / (1 - abar_t).
        posterior_variance = (betas * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod))
        self.posterior_variance = posterior_variance.astype(jnp.float32)
        # Clipped below so the log stays finite where the variance is 0 (t = 0).
        self.posterior_log_variance_clipped = (jnp.log(jnp.maximum(posterior_variance, 1e-20))).astype(jnp.float32)

        self.posterior_mean_coef1 = (betas * jnp.sqrt(alpha_cumprod_prev) / (1 - alpha_cumprod)).astype(jnp.float32)
        self.posterior_mean_coef2 = ((1 - alpha_cumprod_prev) * jnp.sqrt(alphas) / (1 - alpha_cumprod)).astype(jnp.float32)

        self.p2_loss_weights = self.get_p2_weights(p2_loss_weight_k, p2_loss_weight_gamma)

    @staticmethod
    def _step_indices(steps):
        # Shared by every table lookup. Works for scalars, batches and traced values.
        # int32 avoids int16 overflow beyond 32767 steps; float steps truncate toward zero.
        return jnp.asarray(steps).astype(jnp.int32)

    def generate_timesteps(self, batch_size, state: RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Draw integer training steps uniformly from ``[0, max_timesteps)``."""
        state, rng = state.get_random_key()
        timesteps = jax.random.randint(rng, (batch_size,), 0, self.max_timesteps)
        return timesteps, state

    def get_p2_weights(self, k, gamma):
        """P2 weighting table: ``(k + abar / (1 - abar)) ** -gamma`` per step."""
        return (k + self.alpha_cumprod / (1 - self.alpha_cumprod)) ** -gamma

    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
        return self.p2_loss_weights[self._step_indices(steps)].reshape(shape)

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(sqrt(abar_t), sqrt(1 - abar_t))`` reshaped to ``shape``."""
        steps = self._step_indices(steps)
        signal_rate = self.sqrt_alpha_cumprod[steps]
        noise_rate = self.sqrt_one_minus_alpha_cumprod[steps]
        signal_rate = jnp.reshape(signal_rate, shape)
        noise_rate = jnp.reshape(noise_rate, shape)
        return signal_rate, noise_rate

    def get_posterior_mean(self, x_0, x_t, steps):
        """Mean of q(x_{t-1} | x_t, x_0) = coef1_t * x_0 + coef2_t * x_t."""
        steps = self._step_indices(steps)
        x_0_coeff = self.posterior_mean_coef1[steps]
        x_t_coeff = self.posterior_mean_coef2[steps]
        x_0_coeff, x_t_coeff = self.reshape_rates((x_0_coeff, x_t_coeff))
        mean = x_0_coeff * x_0 + x_t_coeff * x_t
        return mean

    def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
        """Return the posterior *standard deviation* ``exp(0.5 * clipped_log_variance)``.

        Despite the name, the result is a standard deviation, so it can scale unit Gaussian
        noise directly. Accepts scalar or batched steps and can be traced under ``jit``;
        the former ``int(steps)`` conversion could do neither.
        """
        steps = self._step_indices(steps)
        return jnp.exp(0.5 * self.posterior_log_variance_clipped[steps]).reshape(shape)
73
+
74
+
@@ -0,0 +1,13 @@
1
+ import numpy as np
2
+ from .discrete import DiscreteNoiseScheduler
3
+
4
def exp_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
    """Return ``timesteps`` betas for an exponentially decaying alpha_bar(t) = exp(-12 t).

    The curve is sampled on a uniform grid over [0, 1]. ``start_angle`` is accepted for
    signature parity with the other schedule functions but is unused. ``end_angle`` is the
    upper clip for the betas.
    """
    grid = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
    curve = np.exp(grid * -12.0)
    alphas_bar = curve / curve[0]
    step_ratios = alphas_bar[1:] / alphas_bar[:-1]
    return np.clip(1 - step_ratios, 0, end_angle)
10
+
11
class ExpNoiseSchedule(DiscreteNoiseScheduler):
    """Discrete variance-preserving schedule whose betas come from ``exp_beta_schedule``.

    ``beta_end`` becomes the beta clip; ``beta_start`` is passed through but unused there.
    """

    def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
        super().__init__(timesteps, beta_start, beta_end, *args, schedule_fn=exp_beta_schedule, **kwargs)
@@ -0,0 +1,69 @@
1
+ import jax.numpy as jnp
2
+ from .common import GeneralizedNoiseScheduler
3
+ import math
4
+ import jax
5
+ from ..utils import RandomMarkovState
6
+
7
class KarrasVENoiseScheduler(GeneralizedNoiseScheduler):
    """Karras et al. (EDM) sigma schedule.

    sigma(t) = (sigma_max**(1/rho) + r * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho,
    with r = 1 - t / max_timesteps. So t = 0 gives sigma_min and t = max_timesteps gives
    sigma_max.
    """

    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
        super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
        self.min_inv_rho = sigma_min ** (1 / rho)
        self.max_inv_rho = sigma_max ** (1 / rho)
        self.rho = rho

    def get_sigmas(self, steps) -> jnp.ndarray:
        ramp = 1 - steps / self.max_timesteps
        sigmas = (self.max_inv_rho + ramp * (self.min_inv_rho - self.max_inv_rho)) ** self.rho
        return sigmas

    def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
        """EDM loss weight ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``."""
        sigma = self.get_sigmas(steps)
        weights = ((sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2)
        return weights.reshape(shape)

    def transform_inputs(self, x, steps, num_discrete_chunks=1000) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(x, log(sigma) / 4)`` as the model's noise conditioning.

        ``num_discrete_chunks`` is accepted for interface parity and unused here.
        """
        sigmas = self.get_sigmas(steps)
        return x, jnp.log(sigmas) / 4

    def get_timesteps(self, sigmas: jnp.ndarray) -> jnp.ndarray:
        """Inverse of ``get_sigmas``: flatten ``sigmas`` and return the matching steps."""
        sigmas = sigmas.reshape(-1)
        inv_rho = sigmas ** (1 / self.rho)
        ramp = ((inv_rho - self.max_inv_rho) / (self.min_inv_rho - self.max_inv_rho))
        # Solve ramp = 1 - steps / max_timesteps for steps. The parentheses matter:
        # `1 - ramp * max_timesteps` is not the inverse.
        steps = (1 - ramp) * self.max_timesteps
        return steps

    def generate_timesteps(self, batch_size, state: RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Draw training steps via the base generator and cast them to float32."""
        timesteps, state = super().generate_timesteps(batch_size, state)
        timesteps = timesteps.astype(jnp.float32)
        return timesteps, state
43
+
44
class SimpleExpNoiseScheduler(KarrasVENoiseScheduler):
    """Sigmas spaced log-uniformly from sigma_min (index 0) to sigma_max, read from a table.

    For integer ``timesteps > 1`` the table has ``timesteps`` entries indexed directly by
    step. Otherwise the table has 1000 entries, and continuous steps in
    ``[0, timesteps)`` are stretched across it.
    """

    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
        super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
        if type(timesteps) == int and timesteps > 1:
            n = timesteps
            # Integer steps index the table directly.
            self._index_scale = 1
        else:
            n = 1000
            # Continuous steps must be stretched onto the table. Without this, every step in
            # [0, 1) truncates to index 0 and sigma is stuck at sigma_min.
            self._index_scale = (n - 1) / timesteps if timesteps else 1
        self._num_sigmas = n
        self.sigmas = jnp.exp(jnp.linspace(math.log(sigma_min), math.log(sigma_max), n))

    def get_sigmas(self, steps) -> jnp.ndarray:
        # Truncate toward zero to a table index (int32 avoids int16 overflow).
        indices = jnp.asarray(steps * self._index_scale).astype(jnp.int32)
        return self.sigmas[indices]

    def get_timesteps(self, sigmas: jnp.ndarray) -> jnp.ndarray:
        """Inverse of ``get_sigmas`` (up to index truncation); flattens ``sigmas``.

        Overrides the Karras inverse, which does not describe this log-uniform table.
        """
        sigmas = sigmas.reshape(-1)
        log_min = math.log(self.sigma_min)
        log_max = math.log(self.sigma_max)
        position = (jnp.log(sigmas) - log_min) / (log_max - log_min)
        return position * (self._num_sigmas - 1) / self._index_scale
56
+
57
class EDMNoiseScheduler(KarrasVENoiseScheduler):
    """Log-normal sigma parameterization: sigma = exp(mean + std * steps / max_timesteps).

    Training steps are drawn from a standard normal, so with ``max_timesteps == 1`` the
    training sigmas satisfy ln(sigma) ~ N(mean, std**2).
    NOTE(review): for max_timesteps != 1 the normal draw is shrunk by 1 / max_timesteps.
    Also, ``get_timesteps`` is inherited from KarrasVENoiseScheduler and does not invert this
    ``get_sigmas``. Confirm both are intended.
    """

    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
        super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)

    def get_sigmas(self, steps, std=1.2, mean=-1.2) -> jnp.ndarray:
        normalized = steps / self.max_timesteps
        return jnp.exp(normalized * std + mean)

    def generate_timesteps(self, batch_size, state: RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Draw ``batch_size`` float32 steps from a standard normal."""
        state, normal_key = state.get_random_key()
        draws = jax.random.normal(normal_key, (batch_size,), dtype=jnp.float32)
        return draws, state
@@ -0,0 +1,14 @@
1
+ import numpy as np
2
+ from .discrete import DiscreteNoiseScheduler
3
+
4
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` linearly spaced betas as float64.

    The endpoints are rescaled by ``1000 / timesteps``, so ``beta_start`` / ``beta_end``
    describe a 1000-step schedule.
    """
    rescale = 1000 / timesteps
    return np.linspace(rescale * beta_start, rescale * beta_end, timesteps, dtype=np.float64)
11
+
12
class LinearNoiseSchedule(DiscreteNoiseScheduler):
    """Discrete variance-preserving schedule whose betas come from ``linear_beta_schedule``."""

    def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, *args, **kwargs):
        super().__init__(timesteps, beta_start, beta_end, *args, schedule_fn=linear_beta_schedule, **kwargs)
@@ -0,0 +1,10 @@
1
+ import numpy as np
2
+ import jax.numpy as jnp
3
+ from .discrete import DiscreteNoiseScheduler
4
+ from .continuous import ContinuousNoiseScheduler
5
+
6
class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
    """Continuous variance-preserving schedule with alpha = sqrt(1 - t) and sigma = sqrt(t)."""

    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
        rates = (jnp.sqrt(1 - steps), jnp.sqrt(steps))
        return self.reshape_rates(rates, shape=shape)
@@ -0,0 +1,2 @@
1
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
2
+ from .diffusion_trainer import DiffusionTrainer, TrainState