flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/utils.py +105 -2
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
- flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
- flaxdiff/data/__init__.py +0 -1
- flaxdiff/data/online_loader.py +0 -336
- flaxdiff/models/__init__.py +0 -1
- flaxdiff/models/attention.py +0 -368
- flaxdiff/models/autoencoder/__init__.py +0 -2
- flaxdiff/models/autoencoder/autoencoder.py +0 -19
- flaxdiff/models/autoencoder/diffusers.py +0 -91
- flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
- flaxdiff/models/common.py +0 -346
- flaxdiff/models/favor_fastattn.py +0 -723
- flaxdiff/models/simple_unet.py +0 -233
- flaxdiff/models/simple_vit.py +0 -180
- flaxdiff/predictors/__init__.py +0 -96
- flaxdiff/samplers/__init__.py +0 -7
- flaxdiff/samplers/common.py +0 -113
- flaxdiff/samplers/ddim.py +0 -10
- flaxdiff/samplers/ddpm.py +0 -43
- flaxdiff/samplers/euler.py +0 -59
- flaxdiff/samplers/heun_sampler.py +0 -28
- flaxdiff/samplers/multistep_dpm.py +0 -60
- flaxdiff/samplers/rk4_sampler.py +0 -34
- flaxdiff/schedulers/__init__.py +0 -6
- flaxdiff/schedulers/common.py +0 -98
- flaxdiff/schedulers/continuous.py +0 -12
- flaxdiff/schedulers/cosine.py +0 -40
- flaxdiff/schedulers/discrete.py +0 -74
- flaxdiff/schedulers/exp.py +0 -13
- flaxdiff/schedulers/karras.py +0 -69
- flaxdiff/schedulers/linear.py +0 -14
- flaxdiff/schedulers/sqrt.py +0 -10
- flaxdiff/trainer/__init__.py +0 -2
- flaxdiff/trainer/autoencoder_trainer.py +0 -182
- flaxdiff/trainer/diffusion_trainer.py +0 -234
- flaxdiff/trainer/simple_trainer.py +0 -442
- flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
flaxdiff/trainer/autoencoder_trainer.py +0 -182
@@ -1,182 +0,0 @@
-from flax import linen as nn
-import jax
-from typing import Callable
-from dataclasses import field
-import jax.numpy as jnp
-import optax
-from jax.sharding import Mesh, PartitionSpec as P
-from jax.experimental.shard_map import shard_map
-from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
-from ..schedulers import NoiseScheduler
-from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
-
-from flaxdiff.utils import RandomMarkovState
-
-from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-
-from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
-
-class AutoEncoderTrainer(SimpleTrainer):
-    def __init__(self,
-                 model: nn.Module,
-                 input_shape: Union[int, int, int],
-                 latent_dim: int,
-                 spatial_scale: int,
-                 optimizer: optax.GradientTransformation,
-                 rngs: jax.random.PRNGKey,
-                 name: str = "Autoencoder",
-                 **kwargs
-                 ):
-        super().__init__(
-            model=model,
-            input_shapes={"image": input_shape},
-            optimizer=optimizer,
-            rngs=rngs,
-            name=name,
-            **kwargs
-        )
-        self.latent_dim = latent_dim
-        self.spatial_scale = spatial_scale
-
-
-    def generate_states(
-        self,
-        optimizer: optax.GradientTransformation,
-        rngs: jax.random.PRNGKey,
-        existing_state: dict = None,
-        existing_best_state: dict = None,
-        model: nn.Module = None,
-        param_transforms: Callable = None
-    ) -> Tuple[TrainState, TrainState]:
-        print("Generating states for DiffusionTrainer")
-        rngs, subkey = jax.random.split(rngs)
-
-        if existing_state == None:
-            input_vars = self.get_input_ones()
-            params = model.init(subkey, **input_vars)
-            new_state = {"params": params, "ema_params": params}
-        else:
-            new_state = existing_state
-
-        if param_transforms is not None:
-            params = param_transforms(params)
-
-        state = TrainState.create(
-            apply_fn=model.apply,
-            params=new_state['params'],
-            ema_params=new_state['ema_params'],
-            tx=optimizer,
-            rngs=rngs,
-            metrics=Metrics.empty()
-        )
-
-        if existing_best_state is not None:
-            best_state = state.replace(
-                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
-        else:
-            best_state = state
-
-        return state, best_state
-
-    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
-        noise_schedule: NoiseScheduler = self.noise_schedule
-        model = self.model
-        model_output_transform = self.model_output_transform
-        loss_fn = self.loss_fn
-        unconditional_prob = self.unconditional_prob
-
-        # Determine the number of unconditional samples
-        num_unconditional = int(batch_size * unconditional_prob)
-
-        nS, nC = null_labels_seq.shape
-        null_labels_seq = jnp.broadcast_to(
-            null_labels_seq, (batch_size, nS, nC))
-
-        distributed_training = self.distributed_training
-
-        autoencoder = self.autoencoder
-
-        # @jax.jit
-        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
-            """Train for a single step."""
-            rng_state, subkey = rng_state.get_random_key()
-            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
-            local_rng_state = RandomMarkovState(subkey)
-
-            images = batch['image']
-
-            if autoencoder is not None:
-                # Convert the images to latent space
-                local_rng_state, rngs = local_rng_state.get_random_key()
-                images = autoencoder.encode(images, rngs)
-            else:
-                # normalize image
-                images = (images - 127.5) / 127.5
-
-            output = text_embedder(
-                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-            label_seq = output.last_hidden_state
-
-            # Generate random probabilities to decide how much of this batch will be unconditional
-
-            label_seq = jnp.concat(
-                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
-
-            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
-
-            local_rng_state, rngs = local_rng_state.get_random_key()
-            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
-
-            rates = noise_schedule.get_rates(noise_level)
-            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-                images, noise, rates)
-
-            def model_loss(params):
-                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-                preds = model_output_transform.pred_transform(
-                    noisy_images, preds, rates)
-                nloss = loss_fn(preds, expected_output)
-                # nloss = jnp.mean(nloss, axis=1)
-                nloss *= noise_schedule.get_weights(noise_level)
-                nloss = jnp.mean(nloss)
-                loss = nloss
-                return loss
-
-            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
-            if distributed_training:
-                grads = jax.lax.pmean(grads, "data")
-                loss = jax.lax.pmean(loss, "data")
-            train_state = train_state.apply_gradients(grads=grads)
-            train_state = train_state.apply_ema(self.ema_decay)
-            return train_state, loss, rng_state
-
-        if distributed_training:
-            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
-                                   out_specs=(P(), P(), P()))
-            train_step = jax.jit(train_step)
-
-        return train_step
-
-    def _define_compute_metrics(self):
-        @jax.jit
-        def compute_metrics(state: TrainState, expected, pred):
-            loss = jnp.mean(jnp.square(pred - expected))
-            metric_updates = state.metrics.single_from_model_output(loss=loss)
-            metrics = state.metrics.merge(metric_updates)
-            state = state.replace(metrics=metrics)
-            return state
-        return compute_metrics
-
-    def fit(self, data, steps_per_epoch, epochs):
-        null_labels_full = data['null_labels_full']
-        local_batch_size = data['local_batch_size']
-        text_embedder = data['model']
-        super().fit(data, steps_per_epoch, epochs, {
-            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
-
-def boolean_string(s):
-    if type(s) == bool:
-        return s
-    return s == 'True'
-
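For context on the removed train step above: it mixes a broadcast null-text embedding into the first int(batch_size * unconditional_prob) slots of the conditioning batch, the usual classifier-free-guidance training setup. A minimal sketch of that pattern follows; the helper name mix_null_labels is illustrative only and not an API of flaxdiff.

import jax.numpy as jnp

def mix_null_labels(label_seq, null_labels_seq, unconditional_prob):
    # label_seq: (batch, seq_len, channels) text embeddings for the batch
    # null_labels_seq: (seq_len, channels) embedding of the empty prompt
    batch_size = label_seq.shape[0]
    num_unconditional = int(batch_size * unconditional_prob)
    null = jnp.broadcast_to(
        null_labels_seq, (batch_size,) + null_labels_seq.shape)
    # Replace the first num_unconditional conditioning sequences with the
    # null embedding so a fixed fraction of the batch trains unconditionally.
    return jnp.concatenate(
        [null[:num_unconditional], label_seq[num_unconditional:]], axis=0)
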
flaxdiff/trainer/diffusion_trainer.py +0 -234
@@ -1,234 +0,0 @@
-from flax import linen as nn
-import jax
-from typing import Callable
-from dataclasses import field
-import jax.numpy as jnp
-import optax
-from jax.sharding import Mesh, PartitionSpec as P
-from jax.experimental.shard_map import shard_map
-from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
-from ..schedulers import NoiseScheduler
-from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
-
-from flaxdiff.utils import RandomMarkovState
-
-from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-
-from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
-from flax.training.dynamic_scale import DynamicScale
-
-class TrainState(SimpleTrainState):
-    rngs: jax.random.PRNGKey
-    ema_params: dict
-
-    def apply_ema(self, decay: float = 0.999):
-        new_ema_params = jax.tree_util.tree_map(
-            lambda ema, param: decay * ema + (1 - decay) * param,
-            self.ema_params,
-            self.params,
-        )
-        return self.replace(ema_params=new_ema_params)
-
-from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
-
-class DiffusionTrainer(SimpleTrainer):
-    noise_schedule: NoiseScheduler
-    model_output_transform: DiffusionPredictionTransform
-    ema_decay: float = 0.999
-
-    def __init__(self,
-                 model: nn.Module,
-                 input_shapes: Dict[str, Tuple[int]],
-                 optimizer: optax.GradientTransformation,
-                 noise_schedule: NoiseScheduler,
-                 rngs: jax.random.PRNGKey,
-                 unconditional_prob: float = 0.12,
-                 name: str = "Diffusion",
-                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
-                 autoencoder: AutoEncoder = None,
-                 **kwargs
-                 ):
-        super().__init__(
-            model=model,
-            input_shapes=input_shapes,
-            optimizer=optimizer,
-            rngs=rngs,
-            name=name,
-            **kwargs
-        )
-        self.noise_schedule = noise_schedule
-        self.model_output_transform = model_output_transform
-        self.unconditional_prob = unconditional_prob
-
-        self.autoencoder = autoencoder
-
-    def generate_states(
-        self,
-        optimizer: optax.GradientTransformation,
-        rngs: jax.random.PRNGKey,
-        existing_state: dict = None,
-        existing_best_state: dict = None,
-        model: nn.Module = None,
-        param_transforms: Callable = None,
-        use_dynamic_scale: bool = False
-    ) -> Tuple[TrainState, TrainState]:
-        print("Generating states for DiffusionTrainer")
-        rngs, subkey = jax.random.split(rngs)
-
-        if existing_state == None:
-            input_vars = self.get_input_ones()
-            params = model.init(subkey, **input_vars)
-            new_state = {"params": params, "ema_params": params}
-        else:
-            new_state = existing_state
-
-        if param_transforms is not None:
-            new_state['params'] = param_transforms(new_state['params'])
-            new_state['ema_params'] = param_transforms(new_state['ema_params'])
-
-        state = TrainState.create(
-            apply_fn=model.apply,
-            params=new_state['params'],
-            ema_params=new_state['ema_params'],
-            tx=optimizer,
-            rngs=rngs,
-            metrics=Metrics.empty(),
-            dynamic_scale = DynamicScale() if use_dynamic_scale else None
-        )
-
-        if existing_best_state is not None:
-            best_state = state.replace(
-                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
-        else:
-            best_state = state
-
-        return state, best_state
-
-    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
-        noise_schedule: NoiseScheduler = self.noise_schedule
-        model = self.model
-        model_output_transform = self.model_output_transform
-        loss_fn = self.loss_fn
-        unconditional_prob = self.unconditional_prob
-
-        # Determine the number of unconditional samples
-        num_unconditional = int(batch_size * unconditional_prob)
-
-        nS, nC = null_labels_seq.shape
-        null_labels_seq = jnp.broadcast_to(
-            null_labels_seq, (batch_size, nS, nC))
-
-        distributed_training = self.distributed_training
-
-        autoencoder = self.autoencoder
-
-        # @jax.jit
-        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
-            """Train for a single step."""
-            rng_state, subkey = rng_state.get_random_key()
-            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
-            local_rng_state = RandomMarkovState(subkey)
-
-            images = batch['image']
-            images = jnp.array(images, dtype=jnp.float32)
-            # normalize image
-            images = (images - 127.5) / 127.5
-
-            if autoencoder is not None:
-                # Convert the images to latent space
-                local_rng_state, rngs = local_rng_state.get_random_key()
-                images = autoencoder.encode(images, rngs)
-
-            output = text_embedder(
-                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-            label_seq = output.last_hidden_state
-
-            # Generate random probabilities to decide how much of this batch will be unconditional
-
-            label_seq = jnp.concat(
-                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
-
-            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
-
-            local_rng_state, rngs = local_rng_state.get_random_key()
-            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
-
-            rates = noise_schedule.get_rates(noise_level)
-            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-                images, noise, rates)
-
-            def model_loss(params):
-                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-                preds = model_output_transform.pred_transform(
-                    noisy_images, preds, rates)
-                nloss = loss_fn(preds, expected_output)
-                # nloss = jnp.mean(nloss, axis=1)
-                nloss *= noise_schedule.get_weights(noise_level)
-                nloss = jnp.mean(nloss)
-                loss = nloss
-                return loss
-
-
-            if train_state.dynamic_scale is not None:
-                # dynamic scale takes care of averaging gradients across replicas
-                grad_fn = train_state.dynamic_scale.value_and_grad(
-                    model_loss, axis_name="data"
-                )
-                dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
-                train_state = train_state.replace(dynamic_scale=dynamic_scale)
-            else:
-                grad_fn = jax.value_and_grad(model_loss)
-                loss, grads = grad_fn(train_state.params)
-                if distributed_training:
-                    grads = jax.lax.pmean(grads, "data")
-
-            new_state = train_state.apply_gradients(grads=grads)
-
-            if train_state.dynamic_scale:
-                # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
-                # params should be restored (= skip this step).
-                select_fn = functools.partial(jnp.where, is_fin)
-                new_state = train_state.replace(
-                    opt_state=jax.tree_util.tree_map(
-                        select_fn, new_state.opt_state, train_state.opt_state
-                    ),
-                    params=jax.tree_util.tree_map(
-                        select_fn, new_state.params, train_state.params
-                    ),
-                )
-
-            train_state = new_state.apply_ema(self.ema_decay)
-
-            if distributed_training:
-                loss = jax.lax.pmean(loss, "data")
-            return train_state, loss, rng_state
-
-        if distributed_training:
-            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
-                                   out_specs=(P(), P(), P()))
-            train_step = jax.jit(train_step)
-
-        return train_step
-
-    def _define_compute_metrics(self):
-        @jax.jit
-        def compute_metrics(state: TrainState, expected, pred):
-            loss = jnp.mean(jnp.square(pred - expected))
-            metric_updates = state.metrics.single_from_model_output(loss=loss)
-            metrics = state.metrics.merge(metric_updates)
-            state = state.replace(metrics=metrics)
-            return state
-        return compute_metrics
-
-    def fit(self, data, steps_per_epoch, epochs):
-        null_labels_full = data['null_labels_full']
-        local_batch_size = data['local_batch_size']
-        text_embedder = data['model']
-        super().fit(data, steps_per_epoch, epochs, {
-            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
-
-def boolean_string(s):
-    if type(s) == bool:
-        return s
-    return s == 'True'
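
For context on the removed TrainState above: after every optimizer step it maintains an exponential moving average of the parameters via apply_ema. A minimal sketch of the same leaf-wise update follows; the helper name ema_update is illustrative only and not an API of flaxdiff.

import jax
import jax.numpy as jnp

def ema_update(ema_params, params, decay=0.999):
    # ema <- decay * ema + (1 - decay) * param, applied leaf-wise over the pytree
    return jax.tree_util.tree_map(
        lambda ema, p: decay * ema + (1 - decay) * p, ema_params, params)

# Toy usage with dict pytrees standing in for model parameters
ema = {"w": jnp.zeros((2,)), "b": jnp.zeros(())}
new = {"w": jnp.ones((2,)), "b": jnp.ones(())}
print(ema_update(ema, new))  # every leaf moves 0.1% of the way toward the new params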