flaxdiff 0.1.35.5__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataset_map.py +71 -0
- flaxdiff/data/datasets.py +169 -0
- flaxdiff/data/online_loader.py +69 -42
- flaxdiff/models/attention.py +1 -0
- flaxdiff/models/simple_unet.py +11 -11
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/samplers/common.py +72 -20
- flaxdiff/samplers/ddim.py +5 -5
- flaxdiff/samplers/ddpm.py +5 -11
- flaxdiff/samplers/euler.py +7 -10
- flaxdiff/samplers/heun_sampler.py +3 -4
- flaxdiff/samplers/multistep_dpm.py +2 -3
- flaxdiff/samplers/rk4_sampler.py +9 -9
- flaxdiff/trainer/autoencoder_trainer.py +1 -1
- flaxdiff/trainer/diffusion_trainer.py +124 -32
- flaxdiff/trainer/simple_trainer.py +187 -91
- flaxdiff/trainer/video_diffusion_trainer.py +62 -0
- flaxdiff/utils.py +105 -2
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/METADATA +11 -5
- flaxdiff-0.1.36.dist-info/RECORD +43 -0
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/WHEEL +1 -1
- flaxdiff-0.1.35.5.dist-info/RECORD +0 -40
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/top_level.txt +0 -0
 
flaxdiff/trainer/simple_trainer.py CHANGED

@@ -24,6 +24,7 @@ from termcolor import colored
 from typing import Dict, Callable, Sequence, Any, Union, Tuple
 from flax.training.dynamic_scale import DynamicScale
 from flaxdiff.utils import RandomMarkovState
+from flax.training import dynamic_scale as dynamic_scale_lib
 
 PROCESS_COLOR_MAP = {
     0: "green",
@@ -63,12 +64,12 @@ def convert_to_global_tree(global_mesh, pytree):
 @struct.dataclass
 class Metrics(metrics.Collection):
     accuracy: metrics.Accuracy
-    loss: metrics.Average
+    loss: metrics.Average#.from_output('loss')
 
 # Define the TrainState
 class SimpleTrainState(train_state.TrainState):
     metrics: Metrics
-    dynamic_scale: DynamicScale
+    dynamic_scale: dynamic_scale_lib.DynamicScale
 
 class SimpleTrainer:
     state: SimpleTrainState
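
The SimpleTrainState.dynamic_scale field now goes through the flax.training.dynamic_scale module alias imported above. For context, a minimal sketch of how Flax's loss scaler is typically threaded through a float16 train step; loss_fn, params, and batch here are illustrative stand-ins, not flaxdiff APIs:

# Illustrative sketch (not from the package): driving DynamicScale
# in a mixed-precision train step.
import jax.numpy as jnp
from flax.training import dynamic_scale as dynamic_scale_lib

def train_step(dyn_scale, params, batch):
    def loss_fn(p):
        preds = p['w'] * batch['x']  # stand-in for model.apply
        return jnp.mean((preds - batch['y']) ** 2)

    # value_and_grad scales the loss up before differentiating and
    # unscales the grads afterwards, reporting whether they stayed finite.
    grad_fn = dyn_scale.value_and_grad(loss_fn)
    dyn_scale, is_finite, loss, grads = grad_fn(params)
    # A real step would skip the optimizer update when is_finite is False.
    return dyn_scale, is_finite, loss, grads

dyn_scale = dynamic_scale_lib.DynamicScale()
params = {'w': jnp.ones(())}
batch = {'x': jnp.ones((4,)), 'y': jnp.zeros((4,))}
dyn_scale, is_finite, loss, grads = train_step(dyn_scale, params, batch)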
@@ -110,6 +111,7 @@ class SimpleTrainer:
 
 
         if wandb_config is not None and jax.process_index() == 0:
+            import wandb
             run = wandb.init(**wandb_config)
             self.wandb = run
 
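
Moving import wandb inside the guarded branch makes the dependency optional: only processes that actually configure logging need wandb installed. The same deferred-import pattern in isolation (init_logging is a hypothetical helper, not part of the package):

def init_logging(wandb_config=None):
    if wandb_config is not None:
        import wandb  # imported lazily, only when logging is configured
        return wandb.init(**wandb_config)
    return None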
@@ -177,16 +179,13 @@ class SimpleTrainer:
             params = model.init(subkey, **input_vars)
         else:
             params = existing_state['params']
-
-        if param_transforms is not None:
-            params = param_transforms(params)
 
         state = SimpleTrainState.create(
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
             metrics=Metrics.empty(),
-            dynamic_scale = DynamicScale() if use_dynamic_scale else None
+            dynamic_scale = dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
         )
         if existing_best_state is not None:
             best_state = state.replace(
@@ -207,16 +206,16 @@ class SimpleTrainer:
         self.best_state = best_state
 
     def get_state(self):
-
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.state)
+        return self.get_np_tree(self.state)
 
     def get_best_state(self):
-
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
+        return self.get_np_tree(self.best_state)
 
     def get_rngstate(self):
-
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
+        return self.get_np_tree(self.rngstate)
+
+    def get_np_tree(self, pytree):
+        return jax.tree_util.tree_map(lambda x : np.array(x), pytree)
 
     def checkpoint_path(self):
         path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
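
The three getters now funnel through a single get_np_tree helper. The underlying idiom is jax.tree_util.tree_map pulling every array leaf back to host memory as numpy, which makes the tree safe to serialize. A standalone sketch with toy values:

import jax
import jax.numpy as jnp
import numpy as np

def get_np_tree(pytree):
    # Convert every array leaf of the pytree to a host numpy array.
    return jax.tree_util.tree_map(lambda x: np.array(x), pytree)

state = {'params': {'w': jnp.ones((2, 2))}, 'step': jnp.asarray(7)}
np_state = get_np_tree(state)
print(type(np_state['params']['w']))  # <class 'numpy.ndarray'>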
@@ -253,29 +252,35 @@ class SimpleTrainer:
         rngstate = ckpt['rngs']
         # Convert the state to a TrainState
         self.best_loss = ckpt['best_loss']
+        if self.best_loss == 0:
+            # It cant be zero as that must have been some problem
+            self.best_loss = 1e9
         current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
         print(
             f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
         return current_epoch, step, state, best_state, rngstate
 
-    def save(self, epoch=0, step=0):
+    def save(self, epoch=0, step=0, state=None, rngstate=None):
         print(f"Saving model at epoch {epoch} step {step}")
-        ckpt = {
-            # 'model': self.model,
-            'rngs': self.get_rngstate(),
-            'state': self.get_state(),
-            'best_state': self.get_best_state(),
-            'best_loss': np.array(self.best_loss),
-            'epoch': epoch,
-        }
         try:
-            save_args = orbax_utils.save_args_from_target(ckpt)
-            self.checkpointer.save(step, ckpt, save_kwargs={
-                                'save_args': save_args}, force=True)
-            self.checkpointer.wait_until_finished()
-            pass
+            ckpt = {
+                # 'model': self.model,
+                'rngs': self.get_rngstate() if rngstate is None else self.get_np_tree(rngstate),
+                'state': self.get_state() if state is None else self.get_np_tree(state),
+                'best_state': self.get_best_state(),
+                'best_loss': np.array(self.best_loss),
+                'epoch': epoch,
+            }
+            try:
+                save_args = orbax_utils.save_args_from_target(ckpt)
+                self.checkpointer.save(step, ckpt, save_kwargs={
+                                    'save_args': save_args}, force=True)
+                self.checkpointer.wait_until_finished()
+                pass
+            except Exception as e:
+                print("Error saving checkpoint", e)
         except Exception as e:
-            print("Error saving checkpoint", e)
+            print("Error saving checkpoint outer", e)
 
     def _define_train_step(self, **kwargs):
         model = self.model
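
self.checkpointer is an Orbax CheckpointManager constructed elsewhere in the trainer; the nested try lets a failed serialization be reported without killing the run. A minimal sketch of the save pattern, assuming the legacy orbax.checkpoint API with save_kwargs that this code targets; the directory path is a placeholder:

import numpy as np
import orbax.checkpoint
from flax.training import orbax_utils

ckpt = {'state': {'w': np.ones((2, 2))}, 'best_loss': np.array(1e9), 'epoch': 0}

checkpointer = orbax.checkpoint.CheckpointManager(
    '/tmp/flaxdiff_ckpts',  # placeholder directory
    orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler()),
    orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True))

save_args = orbax_utils.save_args_from_target(ckpt)  # per-leaf save settings
checkpointer.save(0, ckpt, save_kwargs={'save_args': save_args}, force=True)
checkpointer.wait_until_finished()  # block until the async write lands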
@@ -304,21 +309,26 @@ class SimpleTrainer:
             train_step = jax.pmap(train_step)
         return train_step
 
-    def _define_compute_metrics(self):
+    def _define_vaidation_step(self):
         model = self.model
         loss_fn = self.loss_fn
+        distributed_training = self.distributed_training
 
-        @jax.jit
-        def compute_metrics(state: SimpleTrainState, batch):
+        def validation_step(state: SimpleTrainState, batch):
             preds = model.apply(state.params, batch['image'])
             expected_output = batch['label']
             loss = jnp.mean(loss_fn(preds, expected_output))
+            if distributed_training:
+                loss = jax.lax.pmean(loss, "data")
             metric_updates = state.metrics.single_from_model_output(
                 loss=loss, logits=preds, labels=expected_output)
             metrics = state.metrics.merge(metric_updates)
             state = state.replace(metrics=metrics)
             return state
-        return compute_metrics
+        if distributed_training:
+            validation_step = shard_map(validation_step, mesh=self.mesh, in_specs=(P(), P('data')), out_specs=(P()))
+            validation_step = jax.pmap(validation_step)
+        return validation_step
 
     def summary(self):
         input_vars = self.get_input_ones()
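
The validation step is now sharded like the train step: shard_map hands each device its slice of the batch and jax.lax.pmean averages the loss over the 'data' mesh axis. A self-contained sketch of that mechanism (toy loss; the batch's leading axis must divide evenly across devices):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('data',))

def per_shard_loss(batch):
    loss = jnp.mean(batch ** 2)          # each device sees only its shard
    return jax.lax.pmean(loss, 'data')   # average across the mesh axis

sharded_loss = shard_map(per_shard_loss, mesh=mesh,
                         in_specs=P('data'), out_specs=P())
print(sharded_loss(jnp.arange(8.0)))     # 8 must be divisible by the device count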
@@ -343,17 +353,53 @@ class SimpleTrainer:
             "batch_size": batch_size
         })
         return summary_writer
+
+    def validation_loop(
+        self,
+        val_state: SimpleTrainState,
+        val_step_fn: Callable,
+        val_ds,
+        val_steps_per_epoch,
+        current_step,
+    ):
+        global_device_count = jax.device_count()
+        local_device_count = jax.local_device_count()
+        process_index = jax.process_index()
+
+        val_ds = iter(val_ds()) if val_ds else None
+        # Evaluation step
+        try:
+            for i in range(val_steps_per_epoch):
+                if val_ds is None:
+                    batch = None
+                else:
+                    batch = next(val_ds)
+                    if self.distributed_training and global_device_count > 1:
+                        batch = convert_to_global_tree(self.mesh, batch)
+                if i == 0:
+                    print(f"Evaluation started for process index {process_index}")
+                metrics = val_step_fn(val_state, batch)
+                if self.wandb is not None:
+                    # metrics is a dict of metrics
+                    if metrics and type(metrics) == dict:
+                        for key, value in metrics.items():
+                            if isinstance(value, jnp.ndarray):
+                                value = np.array(value)
+                            self.wandb.log({
+                                f"val/{key}": value,
+                            }, step=current_step)
+        except Exception as e:
+            print("Error logging images to wandb", e)
 
-    def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
-        train_ds = iter(data['train']())
-        train_step = self._define_train_step(**train_step_args)
-        compute_metrics = self._define_compute_metrics()
-        train_state = self.state
-        rng_state = self.rngstate
+    def train_loop(
+        self,
+        train_state: SimpleTrainState,
+        train_step_fn: Callable,
+        train_ds,
+        train_steps_per_epoch,
+        current_step,
+        rng_state
+    ):
         global_device_count = jax.device_count()
         local_device_count = jax.local_device_count()
         process_index = jax.process_index()
@@ -361,67 +407,105 @@ class SimpleTrainer:
             global_device_indexes = jnp.arange(global_device_count)
         else:
             global_device_indexes = 0
+
+        epoch_loss = 0
+        current_epoch = current_step // train_steps_per_epoch
+        last_save_time = time.time()
+
+        if process_index == 0:
+            pbar = tqdm.tqdm(total=train_steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step')
+
+        for i in range(train_steps_per_epoch):
+            batch = next(train_ds)
+            if i == 0:
+                print(f"First batch loaded at step {current_step}")
+
+            if self.distributed_training and global_device_count > 1:
+            #     # Convert the local device batches to a unified global jax.Array 
+                batch = convert_to_global_tree(self.mesh, batch)
+            train_state, loss, rng_state = train_step_fn(train_state, rng_state, batch, global_device_indexes)
 
-
-
-
-
-
-
-                if self.distributed_training and global_device_count > 1:
-                    # Convert the local device batches to a unified global jax.Array 
-                    batch = convert_to_global_tree(self.mesh, batch)
-                train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
-
-                if self.distributed_training:
-                    loss = jax.experimental.multihost_utils.process_allgather(loss)
-                    loss = jnp.mean(loss) # Just to make sure its a scaler value
+            if i == 0:
+                print(f"Training started for process index {process_index} at step {current_step}")
+
+            if self.distributed_training:
+                # loss = jax.experimental.multihost_utils.process_allgather(loss)
+                loss = jnp.mean(loss) # Just to make sure its a scaler value
 
-
-
-
-
-
+            if loss <= 1e-6:
+                # If the loss is too low, we can assume the model has diverged
+                print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
+                # Reset the model to the old state
+                exit(1)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            epoch_loss += loss
+            current_step += 1
+            if i % 100 == 0:
+                if pbar is not None:
+                    pbar.set_postfix(loss=f'{loss:.4f}')
+                    pbar.update(100)
+                    if self.wandb is not None:
+                        self.wandb.log({
+                            "train/step" : current_step,
+                            "train/loss": loss,
+                        }, step=current_step)
+            # Save the model every few steps
+            if i % 10000 == 0 and i > 0:
+                print(f"Saving model after 10000 step {current_step}")
+                print(f"Devices: {len(jax.devices())}") # To sync the devices
+                self.save(current_epoch, current_step, train_state, rng_state)
+                print(f"Saving done by process index {process_index}")
+                last_save_time = time.time()
+        print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/train_steps_per_epoch}", 'green'))
+        if pbar is not None:
+            pbar.close()
+        return epoch_loss, current_step, train_state, rng_state
+
+
+    def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
+        train_ds = iter(data['train']())
+        train_step = self._define_train_step(**train_step_args)
+        val_step = self._define_vaidation_step(**validation_step_args)
+        train_state = self.state
+        rng_state = self.rngstate
+        process_index = jax.process_index()
+
+        if val_steps_per_epoch > 0:
+            # We should first run a validation step to make sure the model is working
+            print(f"Validation run for sanity check for process index {process_index}")
+            # Validation step
+            self.validation_loop(
+                train_state,
+                val_step,
+                data.get('test', data.get('val', None)),
+                val_steps_per_epoch,
+                self.latest_step,
+            )
+            print(colored(f"Sanity Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
+        while self.latest_step < epochs * train_steps_per_epoch:
+            current_epoch = self.latest_step // train_steps_per_epoch
             print(f"\nEpoch {current_epoch}/{epochs}")
             start_time = time.time()
             epoch_loss = 0
-
-
-
-
-
-
-
+
+            epoch_loss, current_step, train_state, rng_state = self.train_loop(
+                train_state,
+                train_step,
+                train_ds,
+                train_steps_per_epoch,
+                self.latest_step,
+                rng_state,
+            )
+            print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
 
             self.latest_step = current_step
             end_time = time.time()
             self.state = train_state
             self.rngstate = rng_state
             total_time = end_time - start_time
-            avg_time_per_step = total_time / steps_per_epoch
-            avg_loss = epoch_loss / steps_per_epoch
+            avg_time_per_step = total_time / train_steps_per_epoch
+            avg_loss = epoch_loss / train_steps_per_epoch
             if avg_loss < self.best_loss:
                 self.best_loss = avg_loss
                 self.best_state = train_state
@@ -437,6 +521,18 @@ class SimpleTrainer:
                     "train/epoch": current_epoch,
                 }, step=current_step)
             print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
-
+
+            if val_steps_per_epoch > 0:
+                print(f"Validation started for process index {process_index}")
+                # Validation step
+                self.validation_loop(
+                    train_state,
+                    val_step,
+                    data.get('test', None),
+                    val_steps_per_epoch,
+                    current_step,
+                )
+                print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
         self.save(epochs)
         return self.state
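
Taken together, the rewritten fit() is driven by a global step counter rather than an epoch counter: latest_step // train_steps_per_epoch recovers the epoch, so checkpoints keyed by step can resume mid-epoch. The control flow in miniature (toy stand-ins, no trainer state):

def train_loop(current_step, steps_per_epoch):
    # Stand-in for the real train_loop: one optimizer update per step.
    for _ in range(steps_per_epoch):
        current_step += 1
    return current_step

latest_step, epochs, steps_per_epoch = 0, 3, 100
while latest_step < epochs * steps_per_epoch:
    current_epoch = latest_step // steps_per_epoch  # derived, never stored
    latest_step = train_loop(latest_step, steps_per_epoch)
print(latest_step)  # 300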
flaxdiff/trainer/video_diffusion_trainer.py ADDED

@@ -0,0 +1,62 @@
+import flax
+from flax import linen as nn
+import jax
+from typing import Callable
+from dataclasses import field
+import jax.numpy as jnp
+import optax
+import functools
+from jax.sharding import Mesh, PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+from typing import Dict, Callable, Sequence, Any, Union, Tuple
+
+from ..schedulers import NoiseScheduler
+from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+
+from flaxdiff.utils import RandomMarkovState
+
+from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+
+from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+from flax.training import dynamic_scale as dynamic_scale_lib
+
+class TrainState(SimpleTrainState):
+    rngs: jax.random.PRNGKey
+    ema_params: dict
+
+    def apply_ema(self, decay: float = 0.999):
+        new_ema_params = jax.tree_util.tree_map(
+            lambda ema, param: decay * ema + (1 - decay) * param,
+            self.ema_params,
+            self.params,
+        )
+        return self.replace(ema_params=new_ema_params)
+
+from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer
+
+class SimpleVideoDiffusionTrainer(DiffusionTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 input_shapes: Dict[str, Tuple[int]],
+                 optimizer: optax.GradientTransformation,
+                 noise_schedule: NoiseScheduler,
+                 rngs: jax.random.PRNGKey,
+                 unconditional_prob: float = 0.12,
+                 name: str = "SimpleVideoDiffusion",
+                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                 autoencoder: AutoEncoder = None,
+                 **kwargs
+                 ):
+        super().__init__(
+            model=model,
+            input_shapes=input_shapes,
+            optimizer=optimizer,
+            noise_schedule=noise_schedule,
+            unconditional_prob=unconditional_prob,
+            autoencoder=autoencoder,
+            model_output_transform=model_output_transform,
+            rngs=rngs,
+            name=name,
+            **kwargs
+        )
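
The new TrainState.apply_ema is the standard exponential-moving-average update, ema = decay * ema + (1 - decay) * param, applied leaf-wise over the parameter tree. The same update in isolation:

import jax
import jax.numpy as jnp

def apply_ema(ema_params, params, decay=0.999):
    # Blend each EMA leaf toward the corresponding live parameter.
    return jax.tree_util.tree_map(
        lambda ema, p: decay * ema + (1 - decay) * p,
        ema_params, params)

params = {'w': jnp.full((2,), 2.0)}
ema = {'w': jnp.zeros((2,))}
ema = apply_ema(ema, params, decay=0.9)
print(ema['w'])  # 0.9 * 0 + 0.1 * 2 -> [0.2 0.2]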
    
flaxdiff/utils.py CHANGED

@@ -2,7 +2,12 @@ import jax
 import jax.numpy as jnp
 import flax.struct as struct
 import flax.linen as nn
-from typing import Any
+from typing import Any, Callable
+from dataclasses import dataclass
+from functools import partial
+import numpy as np
+from jax.sharding import Mesh, PartitionSpec as P
+from abc import ABC, abstractmethod
 
 class MarkovState(struct.PyTreeNode):
     pass
@@ -17,6 +22,30 @@ class RandomMarkovState(MarkovState):
 17  22 |   def clip_images(images, clip_min=-1, clip_max=1):
 18  23 |       return jnp.clip(images, clip_min, clip_max)
 19  24 |
     25 | + def _build_global_shape_and_sharding(
     26 | +     local_shape: tuple[int, ...], global_mesh: Mesh
     27 | + ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
     28 | +   sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
     29 | +   global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
     30 | +   return global_shape, sharding
     31 | +
     32 | +
     33 | + def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
     34 | +   """Put local sharded array into local devices"""
     35 | +   global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
     36 | +   try:
     37 | +     local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
     38 | +   except ValueError as array_split_error:
     39 | +     raise ValueError(
     40 | +         f"Unable to put to devices shape {array.shape} with "
     41 | +         f"local device count {len(global_mesh.local_devices)} "
     42 | +     ) from array_split_error
     43 | +   local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
     44 | +   return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
     45 | +
     46 | + def convert_to_global_tree(global_mesh, pytree):
     47 | +     return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
     48 | +
 20  49 |   class RMSNorm(nn.Module):
 21  50 |       """
 22  51 |       From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
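The new _build_global_shape_and_sharding / form_global_array / convert_to_global_tree helpers assemble each process's local numpy batch into globally sharded jax.Arrays along the leading (batch) dimension. A minimal usage sketch, under assumptions not stated in the diff (a single-axis mesh named "data" and an 8-sample local batch that divides evenly across local devices):

import jax
import numpy as np
from jax.sharding import Mesh
from flaxdiff.utils import convert_to_global_tree

# 1-D device mesh over all local+remote devices; the axis name "data" is illustrative.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Each process supplies its local shard of the batch as plain numpy arrays.
# The local batch size (8 here) must split evenly across local devices.
local_batch = {
    "images": np.zeros((8, 64, 64, 3), dtype=np.float32),
    "input_ids": np.zeros((8, 77), dtype=np.int32),
}

# Every leaf becomes a jax.Array whose global leading dimension is
# jax.process_count() * 8, sharded across the mesh.
global_batch = convert_to_global_tree(mesh, local_batch)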
@@ -86,4 +115,78 @@ class RMSNorm(nn.Module):
 86 115 |               ).reshape(feature_shape)
 87 116 |               mul *= scale
 88 117 |           y = mul * x
 89     | -         return jnp.asarray(y, dtype)
    118 | +         return jnp.asarray(y, dtype)
    119 | +
    120 | + @dataclass
    121 | + class ConditioningEncoder(ABC):
    122 | +     model: nn.Module
    123 | +     tokenizer: Callable
    124 | +
    125 | +     def __call__(self, data):
    126 | +         tokens = self.tokenize(data)
    127 | +         outputs = self.encode_from_tokens(tokens)
    128 | +         return outputs
    129 | +
    130 | +     def encode_from_tokens(self, tokens):
    131 | +         outputs = self.model(input_ids=tokens['input_ids'],
    132 | +                         attention_mask=tokens['attention_mask'])
    133 | +         last_hidden_state = outputs.last_hidden_state
    134 | +         return last_hidden_state
    135 | +
    136 | +     def tokenize(self, data):
    137 | +         tokens = self.tokenizer(data, padding="max_length",
    138 | +                         max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
    139 | +         return tokens
    140 | +
    141 | + @dataclass
    142 | + class TextEncoder(ConditioningEncoder):
    143 | +     # def __call__(self, data):
    144 | +     #     tokens = self.tokenize(data)
    145 | +     #     outputs = self.encode_from_tokens(tokens)
    146 | +     #     return outputs
    147 | +
    148 | +     # def encode_from_tokens(self, tokens):
    149 | +     #     outputs = self.model(input_ids=tokens['input_ids'],
    150 | +     #                     attention_mask=tokens['attention_mask'])
    151 | +     #     last_hidden_state = outputs.last_hidden_state
    152 | +     #     # pooler_output = outputs.pooler_output  # pooled (EOS token) states
    153 | +     #     # embed_pooled = pooler_output  # .astype(jnp.float16)
    154 | +     #     embed_labels_full = last_hidden_state  # .astype(jnp.float16)
    155 | +
    156 | +     #     return embed_labels_full
    157 | +     pass
    158 | +
    159 | + class AutoTextTokenizer:
    160 | +     def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
    161 | +         from transformers import AutoTokenizer
    162 | +         self.tokenizer = AutoTokenizer.from_pretrained(modelname)
    163 | +         self.tensor_type = tensor_type
    164 | +
    165 | +     def __call__(self, inputs):
    166 | +         # print(caption)
    167 | +         tokens = self.tokenizer(inputs, padding="max_length", max_length=self.tokenizer.model_max_length,
    168 | +                                 truncation=True, return_tensors=self.tensor_type)
    169 | +         # print(tokens.keys())
    170 | +         return {
    171 | +             "input_ids": tokens["input_ids"],
    172 | +             "attention_mask": tokens["attention_mask"],
    173 | +             "caption": inputs,
    174 | +         }
    175 | +
    176 | +     def __repr__(self):
    177 | +         return self.__class__.__name__ + '()'
    178 | +
    179 | + def defaultTextEncodeModel(backend="jax"):
    180 | +     from transformers import (
    181 | +         CLIPTextModel,
    182 | +         FlaxCLIPTextModel,
    183 | +         AutoTokenizer,
    184 | +     )
    185 | +     modelname = "openai/clip-vit-large-patch14"
    186 | +     if backend == "jax":
    187 | +         model = FlaxCLIPTextModel.from_pretrained(
    188 | +             modelname, dtype=jnp.bfloat16)
    189 | +     else:
    190 | +         model = CLIPTextModel.from_pretrained(modelname)
    191 | +     tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
    192 | +     return TextEncoder(model, tokenizer)
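The conditioning-encoder additions wrap a Hugging Face CLIP text model behind a small callable: __call__ tokenizes raw strings and returns the model's last hidden state for use as conditioning. A minimal sketch of how the pieces compose; the prompt strings are placeholders, and the ability to download openai/clip-vit-large-patch14 through transformers is assumed:

from flaxdiff.utils import defaultTextEncodeModel, AutoTextTokenizer

# Default encoder: FlaxCLIPTextModel plus its tokenizer (CLIP ViT-L/14).
text_encoder = defaultTextEncodeModel(backend="jax")

prompts = ["a photo of a cat wearing a space suit", "an oil painting of mountains"]

# Tokenizes with padding to model_max_length and returns the last hidden
# state, shaped (batch, sequence_length, hidden_dim).
conditioning = text_encoder(prompts)

# The standalone tokenizer can also run inside a data pipeline; with
# tensor_type="np" it returns numpy input_ids / attention_mask plus captions.
tokenize = AutoTextTokenizer(tensor_type="np")
tokens = tokenize(prompts)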
flaxdiff-0.1.36.dist-info/METADATA CHANGED

@@ -1,15 +1,21 @@
  1     | - Metadata-Version: 2.
      1 | + Metadata-Version: 2.4
  2   2 |   Name: flaxdiff
  3     | - Version: 0.1.
      3 | + Version: 0.1.36
  4   4 |   Summary: A versatile and easy to understand Diffusion library
  5   5 |   Author: Ashish Kumar Singh
  6   6 |   Author-email: ashishkmr472@gmail.com
  7   7 |   Description-Content-Type: text/markdown
  8     | - Requires-Dist: flax
  9     | - Requires-Dist: optax
 10     | - Requires-Dist: jax
      8 | + Requires-Dist: flax>=0.8.4
      9 | + Requires-Dist: optax>=0.2.2
     10 | + Requires-Dist: jax>=0.4.28
 11  11 |   Requires-Dist: orbax
 12  12 |   Requires-Dist: clu
     13 | + Dynamic: author
     14 | + Dynamic: author-email
     15 | + Dynamic: description
     16 | + Dynamic: description-content-type
     17 | + Dynamic: requires-dist
     18 | + Dynamic: summary
 13  19 |
 14  20 |   #
 15  21 |
flaxdiff-0.1.36.dist-info/RECORD ADDED

@@ -0,0 +1,43 @@
      1 | + flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
      2 | + flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
      3 | + flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
      4 | + flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,3093
      5 | + flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
      6 | + flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
      7 | + flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
      8 | + flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
      9 | + flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
     10 | + flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
     11 | + flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
     12 | + flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
     13 | + flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
     14 | + flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
     15 | + flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
     16 | + flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
     17 | + flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
     18 | + flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
     19 | + flaxdiff/samplers/common.py,sha256=ZA08VyovxegpRx4wOQq9LSwZi0gSCz2lrbS5oVYOEYg,8488
     20 | + flaxdiff/samplers/ddim.py,sha256=pB8Kod8ZLJ3GXev4uM3cOj1Uy6ibR0jsaZa-VE0fyJM,552
     21 | + flaxdiff/samplers/ddpm.py,sha256=u1OchQu0XPhc_6w9JXoaFp2wo4y-zXyQNtGAIJwxNLg,2209
     22 | + flaxdiff/samplers/euler.py,sha256=Htb-IJeu7jSgY6mvgYr9yl9pUnos49vijlVk5IQsRps,2740
     23 | + flaxdiff/samplers/heun_sampler.py,sha256=UyI-hSlyWvt-7VEUJj27zjgyzKkGVl8fDUHV-YpSOCc,1421
     24 | + flaxdiff/samplers/multistep_dpm.py,sha256=3Wu3MrMLYaBb1ObraTbWrJmtEtU0adl1dDbz5fPJ4Gs,2735
     25 | + flaxdiff/samplers/rk4_sampler.py,sha256=1j1pES_Q2QiaURvEWeedbbT1LHmkc3jsu0GgH83qBL0,1926
     26 | + flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
     27 | + flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
     28 | + flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
     29 | + flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
     30 | + flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
     31 | + flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
     32 | + flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
     33 | + flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
     34 | + flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
     35 | + flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
     36 | + flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
     37 | + flaxdiff/trainer/diffusion_trainer.py,sha256=ajOWBgFFwXP_VQScUjcuPoaB4Gk02aF0Ls5LNlA8wqA,12691
     38 | + flaxdiff/trainer/simple_trainer.py,sha256=jCD9-qCwX0SC0rN3GrXUBfRrndWNqUI0HmbOAbmYBMM,21906
     39 | + flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
     40 | + flaxdiff-0.1.36.dist-info/METADATA,sha256=7fO1e_icIEK6dmSopv538Hm2fQnhnkOAE2Ab9inpcNE,22213
     41 | + flaxdiff-0.1.36.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
     42 | + flaxdiff-0.1.36.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
     43 | + flaxdiff-0.1.36.dist-info/RECORD,,