flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. flaxdiff/utils.py +105 -2
  2. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
  3. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  4. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
  5. flaxdiff/data/__init__.py +0 -1
  6. flaxdiff/data/online_loader.py +0 -336
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -113
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -43
  22. flaxdiff/samplers/euler.py +0 -59
  23. flaxdiff/samplers/heun_sampler.py +0 -28
  24. flaxdiff/samplers/multistep_dpm.py +0 -60
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -234
  38. flaxdiff/trainer/simple_trainer.py +0 -442
  39. flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
  40. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
flaxdiff/samplers/ddpm.py DELETED
@@ -1,43 +0,0 @@
- import jax
- import jax.numpy as jnp
- from .common import DiffusionSampler
- from ..utils import MarkovState, RandomMarkovState
- class DDPMSampler(DiffusionSampler):
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
-         mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
-         variance = self.noise_schedule.get_posterior_variance(steps=current_step)
- 
-         state, rng = state.get_random_key()
-         # Now sample from the posterior
-         noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)
- 
-         return mean + noise * variance, state
- 
-     def generate_images(self, num_images=16, diffusion_steps=1000, start_step: int = None, *args, **kwargs):
-         return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs)
- 
- class SimpleDDPMSampler(DiffusionSampler):
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
-         state, rng = state.get_random_key()
-         noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)
- 
-         # Compute noise rates and signal rates only once
-         current_signal_rate, current_noise_rate = self.noise_schedule.get_rates(current_step)
-         next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
- 
-         pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate)
- 
-         noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2)
-         signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2)
-         betas = (1 - signal_ratio_squared)
-         gamma = jnp.sqrt(noise_ratio_squared * betas)
- 
-         next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma
-         # pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate)
-         # next_samples = (2 - jnp.sqrt(1 - betas)) * current_samples - betas * (pred_noise / current_noise_rate) + noise * gamma#jnp.sqrt(betas)
-         # next_samples = (1 / (jnp.sqrt(1 - betas) + 1.e-24)) * (current_samples - betas * (pred_noise / current_noise_rate)) + noise * gamma
-         return next_samples, state
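
Note: both samplers draw x_{t-1} from the DDPM reverse (ancestral) transition

    x_{t-1} = \tilde{\mu}_t(x_t, \hat{x}_0) + \tilde{\sigma}_t \, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I).

DDPMSampler takes \tilde{\mu}_t and \tilde{\sigma}_t from the schedule's posterior (see flaxdiff/schedulers/discrete.py below; get_posterior_variance there returns exp(½ log var), i.e. a standard deviation, so mean + noise * variance is a mean-plus-scaled-noise update), while SimpleDDPMSampler recomputes the equivalent coefficients directly from the signal/noise rates.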
flaxdiff/samplers/euler.py DELETED
@@ -1,59 +0,0 @@
- import jax
- import jax.numpy as jnp
- from .common import DiffusionSampler
- from ..utils import RandomMarkovState
- 
- class EulerSampler(DiffusionSampler):
-     # Basically a DDIM Sampler but parameterized as an ODE
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
-         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
-         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
- 
-         dt = next_sigma - current_sigma
- 
-         x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / (dt)
-         dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
-         next_samples = current_samples + dx * dt
-         return next_samples, state
- 
- class SimplifiedEulerSampler(DiffusionSampler):
-     """
-     This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
-     """
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
-         _, current_sigma = self.noise_schedule.get_rates(current_step)
-         _, next_sigma = self.noise_schedule.get_rates(next_step)
- 
-         dt = next_sigma - current_sigma
- 
-         dx = (current_samples - reconstructed_samples) / current_sigma
-         next_samples = current_samples + dx * dt
-         return next_samples, state
- 
- class EulerAncestralSampler(DiffusionSampler):
-     """
-     Similar to EulerSampler but with ancestral sampling
-     """
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
-         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
-         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
- 
-         sigma_up = (next_sigma**2 * (current_sigma**2 - next_sigma**2) / current_sigma**2) ** 0.5
-         sigma_down = (next_sigma**2 - sigma_up**2) ** 0.5
- 
-         dt = sigma_down - current_sigma
- 
-         x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / (next_sigma - current_sigma)
-         dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
- 
-         state, subkey = state.get_random_key()
-         dW = jax.random.normal(subkey, current_samples.shape) * sigma_up
- 
-         next_samples = current_samples + dx * dt + dW
-         return next_samples, state
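
Note: EulerAncestralSampler uses the standard ancestral noise split (as popularized by k-diffusion):

    \sigma_{up}^2 = \sigma_{next}^2 \, (\sigma_{cur}^2 - \sigma_{next}^2) / \sigma_{cur}^2, \qquad \sigma_{down} = \sqrt{\sigma_{next}^2 - \sigma_{up}^2}.

The deterministic Euler step is taken only down to \sigma_{down}, and the removed variance is re-injected as fresh Gaussian noise (dW), so the marginal noise level still lands on \sigma_{next}.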
flaxdiff/samplers/heun_sampler.py DELETED
@@ -1,28 +0,0 @@
- import jax
- import jax.numpy as jnp
- from .common import DiffusionSampler
- from ..utils import RandomMarkovState
- 
- class HeunSampler(DiffusionSampler):
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
-         # Get the noise and signal rates for the current and next steps
-         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
-         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
- 
-         dt = next_sigma - current_sigma
-         x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / dt
- 
-         dx_0 = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
-         next_samples_0 = current_samples + dx_0 * dt
- 
-         # Recompute x_0 and eps at the first estimate to refine the derivative
-         estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step)
- 
-         # Estimate the refined derivative using the midpoint (Heun's method)
-         dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
-         # Compute the final next samples by averaging the initial and refined derivatives
-         final_next_samples = current_samples + 0.5 * (dx_0 + dx_1) * dt
- 
-         return final_next_samples, state
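
Note: this is Heun's second-order (trapezoidal) corrector over the probability-flow derivative d(\sigma):

    x_{pred} = x + d(\sigma_{cur}) \, \Delta\sigma, \qquad x_{next} = x + \tfrac{1}{2} \big( d(\sigma_{cur}) + d(\sigma_{pred}) \big) \, \Delta\sigma.

One extra model evaluation at the Euler prediction upgrades the step from first to second order; the deterministic sampler in the EDM paper uses the same correction.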
flaxdiff/samplers/multistep_dpm.py DELETED
@@ -1,60 +0,0 @@
- import jax
- import jax.numpy as jnp
- from .common import DiffusionSampler
- from ..utils import RandomMarkovState
- 
- class MultiStepDPM(DiffusionSampler):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.history = []
- 
-     def _renoise(self,
-                  current_samples, reconstructed_samples,
-                  pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
-         # Get the noise and signal rates for the current and next steps
-         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
-         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
- 
-         dt = next_sigma - current_sigma
- 
-         def first_order(current_noise, current_sigma):
-             dx = current_noise
-             return dx
- 
-         def second_order(current_noise, current_sigma, last_noise, last_sigma):
-             dx_2 = (current_noise - last_noise) / (current_sigma - last_sigma)
-             return dx_2
- 
-         def third_order(current_noise, current_sigma, last_noise, last_sigma, second_last_noise, second_last_sigma):
-             dx_2 = second_order(current_noise, current_sigma, last_noise, last_sigma)
-             dx_2_last = second_order(last_noise, last_sigma, second_last_noise, second_last_sigma)
- 
-             dx_3 = (dx_2 - dx_2_last) / (0.5 * ((current_sigma + last_sigma) - (last_sigma + second_last_sigma)))
- 
-             return dx_3
- 
-         if len(self.history) == 0:
-             # First order only
-             dx_1 = first_order(pred_noise, current_sigma)
-             next_samples = current_samples + dx_1 * dt
-         elif len(self.history) == 1:
-             # First + Second order
-             dx_1 = first_order(pred_noise, current_sigma)
-             last_step = self.history[-1]
-             dx_2 = second_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'])
-             next_samples = current_samples + dx_1 * dt + 0.5 * dx_2 * dt**2
-         else:
-             # First + Second + Third order
-             last_step = self.history[-1]
-             second_last_step = self.history[-2]
- 
-             dx_1 = first_order(pred_noise, current_sigma)
-             dx_2 = second_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'])
-             dx_3 = third_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'], second_last_step['eps'], second_last_step['sigma'])
-             next_samples = current_samples + (dx_1 * dt) + (0.5 * dx_2 * dt**2) + ((1/6) * dx_3 * dt**3)
- 
-         self.history.append({
-             "eps": pred_noise,
-             "sigma" : current_sigma,
-         })
-         return next_samples, state
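
Note: the three branches read as a truncated Taylor expansion of the solution in \sigma,

    x(\sigma + \Delta\sigma) \approx x + x' \Delta\sigma + \tfrac{1}{2} x'' \Delta\sigma^2 + \tfrac{1}{6} x''' \Delta\sigma^3,

with x' taken as the current \epsilon-prediction and the higher derivatives estimated by finite differences over the stored (eps, sigma) history. This is the multistep idea behind DPM-Solver-style samplers: higher-order steps without extra model evaluations per step.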
flaxdiff/samplers/rk4_sampler.py DELETED
@@ -1,34 +0,0 @@
- import jax
- import jax.numpy as jnp
- from .common import DiffusionSampler
- from ..utils import RandomMarkovState
- from ..schedulers import GeneralizedNoiseScheduler
- 
- class RK4Sampler(DiffusionSampler):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"
-         @jax.jit
-         def get_derivative(x_t, sigma, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
-             t = self.noise_schedule.get_timesteps(sigma)
-             x_0, eps, _ = self.sample_model(x_t, t)
-             return eps, state
- 
-         self.get_derivative = get_derivative
- 
-     def sample_step(self, current_samples:jnp.ndarray, current_step, next_step, state:RandomMarkovState=None) -> tuple[jnp.ndarray, RandomMarkovState]:
-         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
-         current_step = step_ones * current_step
-         next_step = step_ones * next_step
-         _, current_sigma = self.noise_schedule.get_rates(current_step)
-         _, next_sigma = self.noise_schedule.get_rates(next_step)
- 
-         dt = next_sigma - current_sigma
- 
-         k1, state = self.get_derivative(current_samples, current_sigma, state)
-         k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state)
-         k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state)
-         k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state)
- 
-         next_samples = current_samples + ((k1 + 2 * k2 + 2 * k3 + k4) / 6) * dt
-         return next_samples, state
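
Note: sample_step above is the classic fourth-order Runge-Kutta rule applied to the probability-flow ODE, with sigma as the integration variable. A minimal, self-contained JAX sketch of the same integrator on a toy ODE (the names rk4_step and f are illustrative, not part of flaxdiff):

    import jax.numpy as jnp

    def rk4_step(f, x, t, dt):
        # One classic RK4 step for dx/dt = f(x, t): four slope evaluations,
        # combined with weights 1/6, 2/6, 2/6, 1/6.
        k1 = f(x, t)
        k2 = f(x + 0.5 * k1 * dt, t + 0.5 * dt)
        k3 = f(x + 0.5 * k2 * dt, t + 0.5 * dt)
        k4 = f(x + k3 * dt, t + dt)
        return x + ((k1 + 2 * k2 + 2 * k3 + k4) / 6) * dt

    # Toy check: integrate dx/dt = -x from x(0) = 1 to t = 1 in ten steps;
    # the result should be close to exp(-1) ≈ 0.3679.
    f = lambda x, t: -x
    x, t, dt = jnp.float32(1.0), 0.0, 0.1
    for _ in range(10):
        x, t = rk4_step(f, x, t, dt), t + dt

In the sampler, the derivative is the network's eps-prediction (for a variance-exploding schedule, dx/d\sigma = \epsilon), so each step costs four model evaluations.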
flaxdiff/schedulers/__init__.py DELETED
@@ -1,6 +0,0 @@
- from .discrete import DiscreteNoiseScheduler
- from .common import NoiseScheduler, GeneralizedNoiseScheduler
- from .cosine import CosineNoiseSchedule, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
- from .linear import LinearNoiseSchedule
- from .sqrt import SqrtContinuousNoiseScheduler
- from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
flaxdiff/schedulers/common.py DELETED
@@ -1,98 +0,0 @@
- import jax
- import jax.numpy as jnp
- from typing import Union
- from ..utils import RandomMarkovState
- 
- class NoiseScheduler():
-     def __init__(self, timesteps,
-                  dtype=jnp.float32,
-                  clip_min=-1.0,
-                  clip_max=1.0,
-                  *args, **kwargs):
-         self.max_timesteps = timesteps
-         self.dtype = dtype
-         self.clip_min = clip_min
-         self.clip_max = clip_max
-         if type(timesteps) == int and timesteps > 1:
-             timestep_generator = lambda rng, batch_size, max_timesteps = timesteps: jax.random.randint(rng, (batch_size,), 0, max_timesteps)
-         else:
-             timestep_generator = lambda rng, batch_size, max_timesteps = timesteps: jax.random.uniform(rng, (batch_size,), minval=0, maxval=max_timesteps)
-         self.timestep_generator = timestep_generator
- 
-     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
-         state, rng = state.get_random_key()
-         timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
-         return timesteps, state
- 
-     def get_weights(self, steps):
-         raise NotImplementedError
- 
-     def reshape_rates(self, rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-         signal_rates, noise_rates = rates
-         signal_rates = jnp.reshape(signal_rates, shape)
-         noise_rates = jnp.reshape(noise_rates, shape)
-         return signal_rates, noise_rates
- 
-     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-         raise NotImplementedError
- 
-     def add_noise(self, images, noise, steps) -> jnp.ndarray:
-         signal_rates, noise_rates = self.get_rates(steps)
-         return signal_rates * images + noise_rates * noise
- 
-     def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
-         signal_rates, noise_rates = self.get_rates(steps)
-         x_0 = (noisy_images - noise * noise_rates) / signal_rates
-         return x_0
- 
-     def transform_inputs(self, x, steps):
-         return x, steps
- 
-     def get_posterior_mean(self, x_0, x_t, steps):
-         raise NotImplementedError
- 
-     def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
-         raise NotImplementedError
- 
-     def get_max_variance(self):
-         alpha_n, sigma_n = self.get_rates(self.max_timesteps)
-         variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
-         return variance
- 
- class GeneralizedNoiseScheduler(NoiseScheduler):
-     """
-     As per the generalization presented in the paper
-     "Elucidating the Design Space of Diffusion-Based
-     Generative Models" by Tero Karras et al.
-     Basically the signal rate shall always be 1, and the model
-     input itself shall be scaled to match the noise rate
-     """
-     def __init__(self, timesteps, sigma_min=0.002, sigma_max=80.0, sigma_data=1, *args, **kwargs):
-         super().__init__(timesteps, *args, **kwargs)
-         self.sigma_min = sigma_min
-         self.sigma_max = sigma_max
-         self.sigma_data = sigma_data
- 
-     def get_weights(self, steps, shape=(-1, 1, 1, 1)):
-         sigma = self.get_sigmas(steps)
-         return (1 + (1 / (1 + ((1 - sigma ** 2)/(sigma ** 2)))) / (self.sigma_max ** 2)).reshape(shape)
- 
-     def get_sigmas(self, steps) -> jnp.ndarray:
-         raise NotImplementedError("This method should be implemented in the subclass")
- 
-     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-         sigmas = self.get_sigmas(steps)
-         signal_rates = 1
-         noise_rates = sigmas
-         return self.reshape_rates((signal_rates, noise_rates), shape=shape)
- 
-     def transform_inputs(self, x, steps, num_discrete_chunks=1000):
-         sigmas_discrete = (steps / self.max_timesteps) * num_discrete_chunks
-         sigmas_discrete = sigmas_discrete.astype(jnp.int32)
-         return x, sigmas_discrete
- 
-     def get_timesteps(self, sigmas):
-         """
-         Inverse of the get_sigmas method
-         """
-         raise NotImplementedError("This method should be implemented in the subclass")
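
Note: NoiseScheduler encodes the generic Gaussian forward process used by add_noise and remove_all_noise above,

    x_t = \alpha_t x_0 + \sigma_t \epsilon \quad\Longleftrightarrow\quad x_0 = (x_t - \sigma_t \epsilon) / \alpha_t,

where get_rates returns the pair (\alpha_t, \sigma_t). Variance-preserving schedules keep \alpha_t^2 + \sigma_t^2 = 1, while GeneralizedNoiseScheduler fixes \alpha_t = 1 (the variance-exploding convention of Karras et al., with input scaling deferred to the model).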
flaxdiff/schedulers/continuous.py DELETED
@@ -1,12 +0,0 @@
- import jax
- import jax.numpy as jnp
- from typing import Union
- from ..utils import RandomMarkovState
- from .common import NoiseScheduler
- 
- class ContinuousNoiseScheduler(NoiseScheduler):
-     """
-     General Continuous Noise Scheduler
-     """
-     def __init__(self, *args, **kwargs):
-         super().__init__(timesteps=1, *args, **kwargs)
flaxdiff/schedulers/cosine.py DELETED
@@ -1,40 +0,0 @@
- import math
- import numpy as np
- import jax.numpy as jnp
- from .discrete import DiscreteNoiseScheduler
- from .continuous import ContinuousNoiseScheduler
- from .common import GeneralizedNoiseScheduler
- 
- def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
-     ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
-     alphas_bar = np.cos((ts + start_angle) / (1 + start_angle) * np.pi /2) ** 2
-     alphas_bar = alphas_bar/alphas_bar[0]
-     betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
-     return np.clip(betas, 0, end_angle)
- 
- class CosineNoiseSchedule(DiscreteNoiseScheduler):
-     def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
-         super().__init__(timesteps, beta_start, beta_end, schedule_fn=cosine_beta_schedule, *args, **kwargs)
- 
- class CosineGeneralNoiseScheduler(GeneralizedNoiseScheduler):
-     def __init__(self, sigma_min=0.02, sigma_max=80.0, kappa=1.0, *args, **kwargs):
-         super().__init__(timesteps=1, sigma_min=sigma_min, sigma_max=sigma_max, *args, **kwargs)
-         self.kappa = kappa
-         logsnr_max = 2 * (math.log(self.kappa) - math.log(self.sigma_max))
-         self.theta_max = math.atan(math.exp(-0.5 * logsnr_max))
-         logsnr_min = 2 * (math.log(self.kappa) - math.log(self.sigma_min))
-         self.theta_min = math.atan(math.exp(-0.5 * logsnr_min))
- 
-     def get_sigmas(self, steps):
-         return jnp.tan(self.theta_min + steps * (self.theta_max - self.theta_min)) / self.kappa
- 
- class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
-     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-         signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
-         noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
-         return self.reshape_rates((signal_rates, noise_rates), shape=shape)
- 
-     def get_weights(self, steps):
-         alpha, sigma = self.get_rates(steps, shape=())
-         return 1 / (1 + (alpha ** 2 / sigma ** 2))
- 
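
Note: cosine_beta_schedule is the improved-DDPM cosine schedule of Nichol & Dhariwal,

    \bar{\alpha}(t) = \cos^2 \left( \frac{t/T + s}{1 + s} \cdot \frac{\pi}{2} \right), \qquad \beta_t = 1 - \bar{\alpha}_t / \bar{\alpha}_{t-1},

with offset s = 0.008 and betas clipped to at most 0.999 (the start_angle and end_angle arguments above).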
flaxdiff/schedulers/discrete.py DELETED
@@ -1,74 +0,0 @@
- import jax
- import jax.numpy as jnp
- from typing import Union
- from ..utils import RandomMarkovState
- from .common import NoiseScheduler
- 
- class DiscreteNoiseScheduler(NoiseScheduler):
-     """
-     Variance Preserving Noise Scheduler
-     signal_rate**2 + noise_rate**2 = 1
-     """
-     def __init__(self, timesteps,
-                  beta_start=0.0001,
-                  beta_end=0.02,
-                  schedule_fn=None,
-                  p2_loss_weight_k:float=1,
-                  p2_loss_weight_gamma:float=1,
-                  *args, **kwargs):
-         super().__init__(timesteps, *args, **kwargs)
-         betas = schedule_fn(timesteps, beta_start, beta_end)
-         alphas = 1 - betas
-         alpha_cumprod = jnp.cumprod(alphas, axis=0)
-         alpha_cumprod_prev = jnp.append(1.0, alpha_cumprod[:-1])
- 
-         self.betas = jnp.array(betas, dtype=jnp.float32)
-         self.alphas = alphas.astype(jnp.float32)
-         self.alpha_cumprod = alpha_cumprod.astype(jnp.float32)
-         self.alpha_cumprod_prev = alpha_cumprod_prev.astype(jnp.float32)
- 
-         self.sqrt_alpha_cumprod = jnp.sqrt(alpha_cumprod).astype(jnp.float32)
-         self.sqrt_one_minus_alpha_cumprod = jnp.sqrt(1 - alpha_cumprod).astype(jnp.float32)
- 
-         posterior_variance = (betas * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod))
-         self.posterior_variance = posterior_variance.astype(jnp.float32)
-         self.posterior_log_variance_clipped = (jnp.log(jnp.maximum(posterior_variance, 1e-20))).astype(jnp.float32)
- 
-         self.posterior_mean_coef1 = (betas * jnp.sqrt(alpha_cumprod_prev) / (1 - alpha_cumprod)).astype(jnp.float32)
-         self.posterior_mean_coef2 = ((1 - alpha_cumprod_prev) * jnp.sqrt(alphas) / (1 - alpha_cumprod)).astype(jnp.float32)
- 
-         self.p2_loss_weights = self.get_p2_weights(p2_loss_weight_k, p2_loss_weight_gamma)
- 
-     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
-         state, rng = state.get_random_key()
-         timesteps = jax.random.randint(rng, (batch_size,), 0, self.max_timesteps)
-         return timesteps, state
- 
-     def get_p2_weights(self, k, gamma):
-         return (k + self.alpha_cumprod / (1 - self.alpha_cumprod)) ** -gamma
- 
-     def get_weights(self, steps, shape=(-1, 1, 1, 1)):
-         steps = jnp.int16(steps)
-         return self.p2_loss_weights[steps].reshape(shape)
- 
-     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-         steps = jnp.int16(steps)
-         signal_rate = self.sqrt_alpha_cumprod[steps]
-         noise_rate = self.sqrt_one_minus_alpha_cumprod[steps]
-         signal_rate = jnp.reshape(signal_rate, shape)
-         noise_rate = jnp.reshape(noise_rate, shape)
-         return signal_rate, noise_rate
- 
-     def get_posterior_mean(self, x_0, x_t, steps):
-         steps = jnp.int16(steps)
-         x_0_coeff = self.posterior_mean_coef1[steps]
-         x_t_coeff = self.posterior_mean_coef2[steps]
-         x_0_coeff, x_t_coeff = self.reshape_rates((x_0_coeff, x_t_coeff))
-         mean = x_0_coeff * x_0 + x_t_coeff * x_t
-         return mean
- 
-     def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
-         steps = int(steps)
-         return jnp.exp(0.5 * self.posterior_log_variance_clipped[steps]).reshape(shape)
- 
- 
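
Note: posterior_mean_coef1/2 and posterior_variance are the standard DDPM posterior q(x_{t-1} | x_t, x_0) quantities,

    \tilde{\mu}_t = \frac{\beta_t \sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_t} x_0 + \frac{(1 - \bar{\alpha}_{t-1}) \sqrt{\alpha_t}}{1 - \bar{\alpha}_t} x_t, \qquad \tilde{\beta}_t = \frac{(1 - \bar{\alpha}_{t-1}) \beta_t}{1 - \bar{\alpha}_t},

and get_p2_weights implements the perception-prioritized loss weighting (k + \mathrm{SNR}_t)^{-\gamma} of Choi et al., with \mathrm{SNR}_t = \bar{\alpha}_t / (1 - \bar{\alpha}_t).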
flaxdiff/schedulers/exp.py DELETED
@@ -1,13 +0,0 @@
- import numpy as np
- from .discrete import DiscreteNoiseScheduler
- 
- def exp_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
-     ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
-     alphas_bar = np.exp(ts * -12.0)
-     alphas_bar = alphas_bar/alphas_bar[0]
-     betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
-     return np.clip(betas, 0, end_angle)
- 
- class ExpNoiseSchedule(DiscreteNoiseScheduler):
-     def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
-         super().__init__(timesteps, beta_start, beta_end, schedule_fn=exp_beta_schedule, *args, **kwargs)
flaxdiff/schedulers/karras.py DELETED
@@ -1,69 +0,0 @@
- import jax.numpy as jnp
- from .common import GeneralizedNoiseScheduler
- import math
- import jax
- from ..utils import RandomMarkovState
- 
- class KarrasVENoiseScheduler(GeneralizedNoiseScheduler):
-     def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
-         super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
-         self.min_inv_rho = sigma_min ** (1 / rho)
-         self.max_inv_rho = sigma_max ** (1 / rho)
-         self.rho = rho
- 
-     def get_sigmas(self, steps) -> jnp.ndarray:
-         # steps = jnp.int16(steps)
-         # return self.sigmas[steps]
-         ramp = 1 - steps / self.max_timesteps
-         sigmas = (self.max_inv_rho + ramp * (self.min_inv_rho - self.max_inv_rho)) ** self.rho
-         return sigmas
- 
-     def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
-         sigma = self.get_sigmas(steps)
-         weights = ((sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2)
-         return weights.reshape(shape)
- 
-     def transform_inputs(self, x, steps, num_discrete_chunks=1000) -> tuple[jnp.ndarray, jnp.ndarray]:
-         sigmas = self.get_sigmas(steps)
-         # sigmas = (sigmas / self.sigma_max) * num_discrete_chunks
-         sigmas = jnp.log(sigmas) / 4
-         return x, sigmas
- 
-     def get_timesteps(self, sigmas:jnp.ndarray) -> jnp.ndarray:
-         sigmas = sigmas.reshape(-1)
-         inv_rho = sigmas ** (1 / self.rho)
-         ramp = ((inv_rho - self.max_inv_rho) / (self.min_inv_rho - self.max_inv_rho))
-         steps = 1 - ramp * self.max_timesteps
-         return steps
- 
-     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
-         timesteps, state = super().generate_timesteps(batch_size, state)
-         timesteps = timesteps.astype(jnp.float32)
-         return timesteps, state
- 
- class SimpleExpNoiseScheduler(KarrasVENoiseScheduler):
-     def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
-         super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
-         if type(timesteps) == int and timesteps > 1:
-             n = timesteps
-         else:
-             n = 1000
-         self.sigmas = jnp.exp(jnp.linspace(math.log(sigma_min), math.log(sigma_max), n))
- 
-     def get_sigmas(self, steps) -> jnp.ndarray:
-         steps = jnp.int16(steps)
-         return self.sigmas[steps]
- 
- class EDMNoiseScheduler(KarrasVENoiseScheduler):
-     def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
-         super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
- 
-     def get_sigmas(self, steps, std=1.2, mean=-1.2) -> jnp.ndarray:
-         space = steps / self.max_timesteps
-         # space = jax.scipy.special.erfinv(self.erf_sigma_min + steps * (self.erf_sigma_max - self.erf_sigma_min))
-         return jnp.exp(space * std + mean)
- 
-     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
-         state, rng = state.get_random_key()
-         timesteps = jax.random.normal(rng, (batch_size,), dtype=jnp.float32)
-         return timesteps, state
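
Note: KarrasVENoiseScheduler.get_sigmas implements the \rho-ramp schedule from Karras et al. (EDM); with ramp r = 1 - t/T,

    \sigma(r) = \left( \sigma_{max}^{1/\rho} + r \left( \sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho} \right) \right)^{\rho},

so \rho = 7 concentrates sampling steps near \sigma_{min}. get_timesteps inverts this mapping, and get_weights is the EDM loss weighting (\sigma^2 + \sigma_{data}^2) / (\sigma \sigma_{data})^2.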
flaxdiff/schedulers/linear.py DELETED
@@ -1,14 +0,0 @@
- import numpy as np
- from .discrete import DiscreteNoiseScheduler
- 
- def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
-     scale = 1000 / timesteps
-     beta_start = scale * beta_start
-     beta_end = scale * beta_end
-     betas = np.linspace(
-         beta_start, beta_end, timesteps, dtype=np.float64)
-     return betas
- 
- class LinearNoiseSchedule(DiscreteNoiseScheduler):
-     def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, *args, **kwargs):
-         super().__init__(timesteps, beta_start, beta_end, schedule_fn=linear_beta_schedule, *args, **kwargs)
flaxdiff/schedulers/sqrt.py DELETED
@@ -1,10 +0,0 @@
- import numpy as np
- import jax.numpy as jnp
- from .discrete import DiscreteNoiseScheduler
- from .continuous import ContinuousNoiseScheduler
- 
- class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
-     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-         signal_rates = jnp.sqrt(1 - steps)
-         noise_rates = jnp.sqrt(steps)
-         return self.reshape_rates((signal_rates, noise_rates), shape=shape)
flaxdiff/trainer/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
- from .diffusion_trainer import DiffusionTrainer, TrainState