flaxdiff 0.1.37.4__py3-none-any.whl → 0.1.37.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/predictors/__init__.py +4 -4
- flaxdiff/samplers/common.py +2 -1
- flaxdiff/schedulers/common.py +17 -13
- flaxdiff/schedulers/cosine.py +4 -4
- flaxdiff/schedulers/discrete.py +5 -7
- flaxdiff/schedulers/linear.py +1 -2
- flaxdiff/schedulers/sqrt.py +2 -1
- flaxdiff/trainer/diffusion_trainer.py +30 -14
- {flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/RECORD +12 -12
- {flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/top_level.txt +0 -0
 
    
flaxdiff/predictors/__init__.py
CHANGED

@@ -1,6 +1,6 @@
 from typing import Union
 import jax.numpy as jnp
-from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
+from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler, get_coeff_shapes_tuple
 
 ############################################################################################################
 # Prediction Transforms
@@ -11,7 +11,7 @@ class DiffusionPredictionTransform():
         return preds
 
     def __call__(self, x_t, preds, current_step, noise_schedule:NoiseScheduler) -> Union[jnp.ndarray, jnp.ndarray]:
-        rates = noise_schedule.get_rates(current_step)
+        rates = noise_schedule.get_rates(current_step, shape=get_coeff_shapes_tuple(x_t))
         preds = self.pred_transform(x_t, preds, rates)
         x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
         return x_0, epsilon
@@ -85,8 +85,8 @@ class KarrasPredictionTransform(DiffusionPredictionTransform):
         _, sigma = rates
         c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
         c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
-        c_out = c_out.reshape((
-        c_skip = c_skip.reshape((
+        c_out = c_out.reshape(get_coeff_shapes_tuple(preds))
+        c_skip = c_skip.reshape(get_coeff_shapes_tuple(x_t))
         x_0 = c_out * preds + c_skip * x_t
         return x_0
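
For reference, a standalone numeric sketch of the Karras-style preconditioning that the hunk above now reshapes with get_coeff_shapes_tuple; the sigma_data value and the batch shapes below are made-up assumptions, not flaxdiff defaults.

import jax.numpy as jnp

def get_coeff_shapes_tuple(array):
    # Same helper the new version imports from flaxdiff.schedulers.common.
    return (-1,) + (1,) * (array.ndim - 1)

sigma_data = 0.5                          # assumed value, for illustration only
epsilon = 1e-8
x_t = jnp.ones((4, 32, 32, 3))            # noisy batch (hypothetical shape)
preds = jnp.ones((4, 32, 32, 3))          # raw network output
sigma = jnp.array([0.1, 1.0, 10.0, 80.0]) # one noise level per sample

c_out = sigma * sigma_data / (jnp.sqrt(sigma_data ** 2 + sigma ** 2) + epsilon)
c_skip = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2 + epsilon)
c_out = c_out.reshape(get_coeff_shapes_tuple(preds))   # (4, 1, 1, 1)
c_skip = c_skip.reshape(get_coeff_shapes_tuple(x_t))   # (4, 1, 1, 1)
x_0 = c_out * preds + c_skip * x_t                     # broadcasts over H, W, C
print(x_0.shape)                                       # (4, 32, 32, 3)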
    
flaxdiff/samplers/common.py
CHANGED

@@ -67,7 +67,7 @@ class DiffusionSampler():
     # Used to sample from the diffusion model
     def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         # First clip the noisy images
-        step_ones = jnp.ones((current_samples
+        step_ones = jnp.ones((len(current_samples), ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
         pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
@@ -133,6 +133,7 @@ class DiffusionSampler():
 
         params = params if params is not None else self.params
 
+        @jax.jit
         def sample_model_fn(x_t, t, *additional_inputs):
             return self.sample_model(params, x_t, t, *additional_inputs)
 
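
A minimal self-contained sketch of the two changes above (the toy model and shapes are hypothetical, not the flaxdiff sampler): the scalar timestep is broadcast to one int32 entry per sample, and the per-step model call is a closure over params wrapped with @jax.jit.

import jax
import jax.numpy as jnp

params = {'w': jnp.float32(0.5)}   # stand-in for model parameters

@jax.jit
def sample_model_fn(x_t, t):
    # Dummy "model": scales the noisy batch by a parameter; t is unused here.
    return x_t * params['w']

current_samples = jnp.ones((8, 32, 32, 3))
step_ones = jnp.ones((len(current_samples),), dtype=jnp.int32)
current_step = step_ones * 17      # shape (8,): one timestep index per sample
out = sample_model_fn(current_samples, current_step)
print(out.shape, current_step.shape)   # (8, 32, 32, 3) (8,)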
    
flaxdiff/schedulers/common.py
CHANGED

@@ -3,6 +3,16 @@ import jax.numpy as jnp
 from typing import Union
 from ..utils import RandomMarkovState
 
+def get_coeff_shapes_tuple(array):
+    shape_tuple = (-1,) + (1,) * (array.ndim - 1)
+    return shape_tuple
+
+def reshape_rates(rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
+    signal_rates, noise_rates = rates
+    signal_rates = jnp.reshape(signal_rates, shape)
+    noise_rates = jnp.reshape(noise_rates, shape)
+    return signal_rates, noise_rates
+
 class NoiseScheduler():
     def __init__(self, timesteps,
                     dtype=jnp.float32,
@@ -24,24 +34,18 @@ class NoiseScheduler():
         timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
         return timesteps, state
 
-    def get_weights(self, steps):
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
         raise NotImplementedError
 
-    def reshape_rates(self, rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-        signal_rates, noise_rates = rates
-        signal_rates = jnp.reshape(signal_rates, shape)
-        noise_rates = jnp.reshape(noise_rates, shape)
-        return signal_rates, noise_rates
-
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         raise NotImplementedError
 
     def add_noise(self, images, noise, steps) -> jnp.ndarray:
-        signal_rates, noise_rates = self.get_rates(steps)
+        signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(images))
         return signal_rates * images + noise_rates * noise
 
     def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
-        signal_rates, noise_rates = self.get_rates(steps)
+        signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(noisy_images))
         x_0 = (noisy_images - noise * noise_rates) / signal_rates
         return x_0
 
@@ -54,8 +58,8 @@ class NoiseScheduler():
     def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
         raise NotImplementedError
 
-    def get_max_variance(self):
-        alpha_n, sigma_n = self.get_rates(self.max_timesteps)
+    def get_max_variance(self, shape=(-1, 1, 1, 1)):
+        alpha_n, sigma_n = self.get_rates(self.max_timesteps, shape=shape)
         variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
         return variance
 
@@ -82,9 +86,9 @@ class GeneralizedNoiseScheduler(NoiseScheduler):
 
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         sigmas = self.get_sigmas(steps)
-        signal_rates = 
+        signal_rates = jnp.ones_like(sigmas)
         noise_rates = sigmas
-        return 
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
     def transform_inputs(self, x, steps, num_discrete_chunks=1000):
         sigmas_discrete = (steps / self.max_timesteps) * num_discrete_chunks
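
To see why the new helpers take a shape derived from the target array instead of assuming 4-D inputs, here is a standalone sketch that copies get_coeff_shapes_tuple and a simplified reshape_rates from the hunk above and applies them to an image batch and a video batch (the shapes are made up for illustration).

import jax.numpy as jnp

def get_coeff_shapes_tuple(array):
    # (-1, 1, ..., 1): batch axis kept, every other axis broadcastable.
    return (-1,) + (1,) * (array.ndim - 1)

def reshape_rates(rates, shape=(-1, 1, 1, 1)):
    signal_rates, noise_rates = rates
    return jnp.reshape(signal_rates, shape), jnp.reshape(noise_rates, shape)

images = jnp.zeros((8, 64, 64, 3))       # rank-4 image batch
videos = jnp.zeros((8, 16, 64, 64, 3))   # rank-5 video batch
print(get_coeff_shapes_tuple(images))    # (-1, 1, 1, 1)
print(get_coeff_shapes_tuple(videos))    # (-1, 1, 1, 1, 1)

steps = jnp.linspace(0.0, 1.0, 8)
signal, noise = reshape_rates((jnp.cos(steps), jnp.sin(steps)),
                              shape=get_coeff_shapes_tuple(videos))
print(signal.shape)                      # (8, 1, 1, 1, 1), broadcasts against videos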
    
flaxdiff/schedulers/cosine.py
CHANGED

@@ -3,7 +3,7 @@ import numpy as np
 import jax.numpy as jnp
 from .discrete import DiscreteNoiseScheduler
 from .continuous import ContinuousNoiseScheduler
-from .common import GeneralizedNoiseScheduler
+from .common import GeneralizedNoiseScheduler, reshape_rates
 
 def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
     ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
@@ -32,9 +32,9 @@ class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
         noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
-        return 
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
-    def get_weights(self, steps):
-        alpha, sigma = self.get_rates(steps, shape=
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
+        alpha, sigma = self.get_rates(steps, shape=shape)
         return 1 / (1 + (alpha ** 2 / sigma ** 2))
 
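
A quick numeric check (standalone, not package code) of the cosine schedule above: with alpha = cos(pi*t / 2T) and sigma = sin(pi*t / 2T), alpha**2 + sigma**2 == 1, so the loss weight 1 / (1 + alpha**2 / sigma**2) reduces to sigma**2.

import jax.numpy as jnp

max_timesteps = 1000.0   # assumed T, for illustration
steps = jnp.array([1.0, 250.0, 500.0, 999.0])
alpha = jnp.cos((jnp.pi * steps) / (2 * max_timesteps))
sigma = jnp.sin((jnp.pi * steps) / (2 * max_timesteps))
weights = 1 / (1 + (alpha ** 2 / sigma ** 2))
print(jnp.allclose(weights, sigma ** 2))   # True: the weight is just sigma^2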
    
flaxdiff/schedulers/discrete.py
CHANGED

@@ -2,7 +2,7 @@ import jax
 import jax.numpy as jnp
 from typing import Union
 from ..utils import RandomMarkovState
-from .common import NoiseScheduler
+from .common import NoiseScheduler, reshape_rates, get_coeff_shapes_tuple
 
 class DiscreteNoiseScheduler(NoiseScheduler):
     """
@@ -53,17 +53,15 @@ class DiscreteNoiseScheduler(NoiseScheduler):
 
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         steps = jnp.int16(steps)
-
-
-
-        noise_rate = jnp.reshape(noise_rate, shape)
-        return signal_rate, noise_rate
+        signal_rates = self.sqrt_alpha_cumprod[steps]
+        noise_rates = self.sqrt_one_minus_alpha_cumprod[steps]
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
     def get_posterior_mean(self, x_0, x_t, steps):
         steps = jnp.int16(steps)
         x_0_coeff = self.posterior_mean_coef1[steps]
         x_t_coeff = self.posterior_mean_coef2[steps]
-        x_0_coeff, x_t_coeff = 
+        x_0_coeff, x_t_coeff = reshape_rates((x_0_coeff, x_t_coeff), shape=get_coeff_shapes_tuple(x_0))
         mean = x_0_coeff * x_0 + x_t_coeff * x_t
         return mean
 
    
flaxdiff/schedulers/linear.py
CHANGED

@@ -5,8 +5,7 @@ def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
     scale = 1000 / timesteps
     beta_start = scale * beta_start
     beta_end = scale * beta_end
-    betas = np.linspace(
-        beta_start, beta_end, timesteps, dtype=np.float64)
+    betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
     return betas
 
 class LinearNoiseSchedule(DiscreteNoiseScheduler):
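
Putting the linear beta schedule above together with the table lookups from the discrete scheduler hunk, a self-contained sketch (the array names mirror the diff, but the scaffolding around them is hypothetical) of how integer timesteps become broadcastable signal/noise rates:

import numpy as np
import jax.numpy as jnp

timesteps = 1000
betas = np.linspace(0.0001, 0.02, timesteps, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_alpha_cumprod = jnp.asarray(np.sqrt(alphas_cumprod))
sqrt_one_minus_alpha_cumprod = jnp.asarray(np.sqrt(1.0 - alphas_cumprod))

steps = jnp.array([0, 100, 999], dtype=jnp.int16)   # integer timesteps
signal_rates = sqrt_alpha_cumprod[steps]            # table lookup, shape (3,)
noise_rates = sqrt_one_minus_alpha_cumprod[steps]
signal_rates = signal_rates.reshape((-1, 1, 1, 1))  # broadcastable against images
noise_rates = noise_rates.reshape((-1, 1, 1, 1))
print(signal_rates.shape, noise_rates.shape)        # (3, 1, 1, 1) (3, 1, 1, 1)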
    
flaxdiff/schedulers/sqrt.py
CHANGED

@@ -2,9 +2,10 @@ import numpy as np
 import jax.numpy as jnp
 from .discrete import DiscreteNoiseScheduler
 from .continuous import ContinuousNoiseScheduler
+from .common import reshape_rates
 
 class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         signal_rates = jnp.sqrt(1 - steps)
         noise_rates = jnp.sqrt(steps)
-        return 
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
flaxdiff/trainer/diffusion_trainer.py
CHANGED

@@ -11,7 +11,7 @@ from jax.sharding import Mesh, PartitionSpec as P
 from jax.experimental.shard_map import shard_map
 from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
 
-from ..schedulers import NoiseScheduler
+from ..schedulers import NoiseScheduler, get_coeff_shapes_tuple
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 from ..samplers.common import DiffusionSampler
 from ..samplers.ddim import DDIMSampler
@@ -144,6 +144,8 @@ class DiffusionTrainer(SimpleTrainer):
 
             images = batch['image']
 
+            local_batch_size = images.shape[0]
+
             # First get the standard deviation of the images
             # std = jnp.std(images, axis=(1, 2, 3))
             # is_non_zero = (std > 0)
@@ -164,7 +166,7 @@ class DiffusionTrainer(SimpleTrainer):
             label_seq = jnp.concat(
                 [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
 
-            noise_level, local_rng_state = noise_schedule.generate_timesteps(
+            noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)
 
             local_rng_state, rngs = local_rng_state.get_random_key()
             noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
@@ -172,17 +174,15 @@ class DiffusionTrainer(SimpleTrainer):
             # Make sure image is also float32
             images = images.astype(jnp.float32)
 
-            rates = noise_schedule.get_rates(noise_level)
-            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-                images, noise, rates)
+            rates = noise_schedule.get_rates(noise_level, get_coeff_shapes_tuple(images))
+            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
 
             def model_loss(params):
                 preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-                preds = model_output_transform.pred_transform(
-                    noisy_images, preds, rates)
+                preds = model_output_transform.pred_transform(noisy_images, preds, rates)
                 nloss = loss_fn(preds, expected_output)
                 # Ignore the loss contribution of images with zero standard deviation
-                nloss *= noise_schedule.get_weights(noise_level)
+                nloss *= noise_schedule.get_weights(noise_level, get_coeff_shapes_tuple(nloss))
                 nloss = jnp.mean(nloss)
                 loss = nloss
                 return loss
@@ -216,7 +216,7 @@ class DiffusionTrainer(SimpleTrainer):
             #     operand=None
             # )
 
-
+            new_state = train_state.apply_gradients(grads=grads)
 
             if train_state.dynamic_scale is not None:
                 # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
@@ -238,9 +238,16 @@ class DiffusionTrainer(SimpleTrainer):
             return train_state, loss, rng_state
 
         if distributed_training:
-            train_step = shard_map(
-
-
+            train_step = shard_map(
+                train_step,
+                mesh=self.mesh,
+                in_specs=(P(), P(), P('data'), P('data')),
+                out_specs=(P(), P(), P()),
+            )
+        train_step = jax.jit(
+            train_step,
+            donate_argnums=(2)
+        )
 
         return train_step
 
@@ -253,12 +260,21 @@ class DiffusionTrainer(SimpleTrainer):
         null_labels_full = null_labels_full.astype(jnp.float16)
         # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
 
+        if 'image' in self.input_shapes:
+            image_size = self.input_shapes['image'][1]
+        elif 'x' in self.input_shapes:
+            image_size = self.input_shapes['x'][1]
+        elif 'sample' in self.input_shapes:
+            image_size = self.input_shapes['sample'][1]
+        else:
+            raise ValueError("No image input shape found in input shapes")
+
         sampler = sampler_class(
             model=model,
             params=None,
             noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
             model_output_transform=self.model_output_transform,
-            image_size=
+            image_size=image_size,
             null_labels_seq=null_labels_full,
             autoencoder=autoencoder,
             guidance_scale=3.0,
@@ -309,7 +325,7 @@ class DiffusionTrainer(SimpleTrainer):
             )
 
             # Put each sample on wandb
-            if self.wandb:
+            if getattr(self, 'wandb', None) is not None and self.wandb:
                 import numpy as np
                 from wandb import Image as wandbImage
                 wandb_images = []
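
For context on the shard_map/jit wrapping above, a minimal runnable sketch of the same pattern with a toy step function over a 1-D 'data' mesh (the real trainer's in_specs/out_specs cover train state, rng state and batch arguments; everything named here is a simplified assumption):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
mesh = Mesh(devices, ('data',))

def toy_step(params, batch):
    # Per-shard computation: a dummy "loss" over the local slice of the batch.
    loss = jnp.mean((batch * params) ** 2)
    # Average across the 'data' axis, as a data-parallel train step would.
    return jax.lax.pmean(loss, axis_name='data')

sharded_step = shard_map(
    toy_step,
    mesh=mesh,
    in_specs=(P(), P('data')),   # params replicated, batch split along 'data'
    out_specs=P(),               # loss replicated on every device
)
sharded_step = jax.jit(sharded_step)

params = jnp.float32(2.0)
batch = jnp.arange(len(devices) * 4, dtype=jnp.float32)
print(sharded_step(params, batch))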
{flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/RECORD
CHANGED

@@ -20,9 +20,9 @@ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
 flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
-flaxdiff/predictors/__init__.py,sha256=
+flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
 flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
-flaxdiff/samplers/common.py,sha256=
+flaxdiff/samplers/common.py,sha256=wn8tryC3B0KE0V98zMiH_X2x-Tc1NbM5iV27hn5p8Aw,8846
 flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
 flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
 flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
@@ -30,20 +30,20 @@ flaxdiff/samplers/heun_sampler.py,sha256=EvR3hy4t_D47ZOH4luzRFqPmv2v4z78P_JhqBGE
 flaxdiff/samplers/multistep_dpm.py,sha256=2M4Abb93-GUVN1f0_ZHBeA6lF0eF15Hi6QOgOu2K45s,2752
 flaxdiff/samplers/rk4_sampler.py,sha256=vcQefFhOUZdNOQGBdzNkb2NgmTC2KWd_nhUhyLtt3yI,2026
 flaxdiff/schedulers/__init__.py,sha256=EIva9gBz3DKHORuGmv1LQCKTtRqCRavFOXMNqxAR_ks,131
-flaxdiff/schedulers/common.py,sha256=
+flaxdiff/schedulers/common.py,sha256=PDeje2NmN7X3J5qKGauE0jYPpxjgEX44f_evJHRIG3E,4382
 flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
-flaxdiff/schedulers/cosine.py,sha256=
-flaxdiff/schedulers/discrete.py,sha256=
+flaxdiff/schedulers/cosine.py,sha256=E5pODAmINfdyC4kSYOJSPAvq3GNlKPpKEn3X82vYMz0,2055
+flaxdiff/schedulers/discrete.py,sha256=m1q3bAgeAxU3gTj5di3XFWDm4yLfMKAFJPlYdozLE2Y,3316
 flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
 flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
-flaxdiff/schedulers/linear.py,sha256=
-flaxdiff/schedulers/sqrt.py,sha256=
+flaxdiff/schedulers/linear.py,sha256=pBDTXSQcOS4Z03JTh6S0f9E2qLcTQzF2E-pGoQnRoy0,572
+flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,474
 flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
 flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
-flaxdiff/trainer/diffusion_trainer.py,sha256=
+flaxdiff/trainer/diffusion_trainer.py,sha256=kEulMnk6ZkKhQRSVr3UtDdCmXR4cWphJ3XNuk7VIAUY,14189
 flaxdiff/trainer/simple_trainer.py,sha256=LScHQZCy5ksSC7n0GC0tjOXK-zptxpMJsC6Udf-nz18,22178
 flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
-flaxdiff-0.1.37.
-flaxdiff-0.1.37.
-flaxdiff-0.1.37.
-flaxdiff-0.1.37.
+flaxdiff-0.1.37.6.dist-info/METADATA,sha256=SujaCKk29ECrfSEIdchYvAl-nf0L270t2of7oeX5kgk,23985
+flaxdiff-0.1.37.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+flaxdiff-0.1.37.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.37.6.dist-info/RECORD,,

{flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/WHEEL
File without changes

{flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/top_level.txt
File without changes