flaxdiff 0.1.36.1__py3-none-any.whl → 0.1.36.3__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +1 -0
  2. flaxdiff/data/dataset_map.py +71 -0
  3. flaxdiff/data/datasets.py +169 -0
  4. flaxdiff/data/online_loader.py +363 -0
  5. flaxdiff/data/sources/gcs.py +81 -0
  6. flaxdiff/data/sources/tfds.py +67 -0
  7. flaxdiff/metrics/inception.py +658 -0
  8. flaxdiff/metrics/utils.py +49 -0
  9. flaxdiff/models/__init__.py +1 -0
  10. flaxdiff/models/attention.py +368 -0
  11. flaxdiff/models/autoencoder/__init__.py +2 -0
  12. flaxdiff/models/autoencoder/autoencoder.py +19 -0
  13. flaxdiff/models/autoencoder/diffusers.py +91 -0
  14. flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
  15. flaxdiff/models/common.py +346 -0
  16. flaxdiff/models/favor_fastattn.py +723 -0
  17. flaxdiff/models/simple_unet.py +233 -0
  18. flaxdiff/models/simple_vit.py +180 -0
  19. flaxdiff/predictors/__init__.py +96 -0
  20. flaxdiff/samplers/__init__.py +7 -0
  21. flaxdiff/samplers/common.py +165 -0
  22. flaxdiff/samplers/ddim.py +10 -0
  23. flaxdiff/samplers/ddpm.py +37 -0
  24. flaxdiff/samplers/euler.py +56 -0
  25. flaxdiff/samplers/heun_sampler.py +27 -0
  26. flaxdiff/samplers/multistep_dpm.py +59 -0
  27. flaxdiff/samplers/rk4_sampler.py +34 -0
  28. flaxdiff/schedulers/__init__.py +6 -0
  29. flaxdiff/schedulers/common.py +98 -0
  30. flaxdiff/schedulers/continuous.py +12 -0
  31. flaxdiff/schedulers/cosine.py +40 -0
  32. flaxdiff/schedulers/discrete.py +74 -0
  33. flaxdiff/schedulers/exp.py +13 -0
  34. flaxdiff/schedulers/karras.py +69 -0
  35. flaxdiff/schedulers/linear.py +14 -0
  36. flaxdiff/schedulers/sqrt.py +10 -0
  37. flaxdiff/trainer/__init__.py +2 -0
  38. flaxdiff/trainer/autoencoder_trainer.py +182 -0
  39. flaxdiff/trainer/diffusion_trainer.py +326 -0
  40. flaxdiff/trainer/simple_trainer.py +540 -0
  41. flaxdiff/trainer/video_diffusion_trainer.py +62 -0
  42. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/METADATA +1 -1
  43. flaxdiff-0.1.36.3.dist-info/RECORD +47 -0
  44. flaxdiff-0.1.36.1.dist-info/RECORD +0 -6
  45. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/WHEEL +0 -0
  46. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/top_level.txt +0 -0
flaxdiff/trainer/simple_trainer.py
@@ -0,0 +1,540 @@
+ import orbax.checkpoint
+ import tqdm
+ from flax import linen as nn
+ import jax
+ from typing import Callable
+ from dataclasses import field
+ import jax.numpy as jnp
+ import numpy as np
+ from functools import partial
+ from clu import metrics
+ from flax.training import train_state # Useful dataclass to keep train state
+ import optax
+ from flax import struct # Flax dataclasses
+ import flax
+ import time
+ import os
+ import orbax
+ from flax.training import orbax_utils
+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental import mesh_utils
+ from jax.experimental.shard_map import shard_map
+ from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
+ from termcolor import colored
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
+ from flax.training.dynamic_scale import DynamicScale
+ from flaxdiff.utils import RandomMarkovState
+ from flax.training import dynamic_scale as dynamic_scale_lib
+
+ PROCESS_COLOR_MAP = {
+     0: "green",
+     1: "yellow",
+     2: "magenta",
+     3: "cyan",
+     4: "white",
+     5: "light_blue",
+     6: "light_red",
+     7: "light_cyan"
+ }
+
+ def _build_global_shape_and_sharding(
+     local_shape: tuple[int, ...], global_mesh: Mesh
+ ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
+     sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+     global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+     return global_shape, sharding
+
+
+ def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
+     """Put local sharded array into local devices"""
+     global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+     try:
+         local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+     except ValueError as array_split_error:
+         raise ValueError(
+             f"Unable to put to devices shape {array.shape} with "
+             f"local device count {len(global_mesh.local_devices)} "
+         ) from array_split_error
+     local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+     return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
+
+ def convert_to_global_tree(global_mesh, pytree):
+     return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
+
+ @struct.dataclass
+ class Metrics(metrics.Collection):
+     accuracy: metrics.Accuracy
+     loss: metrics.Average#.from_output('loss')
+
+ # Define the TrainState
+ class SimpleTrainState(train_state.TrainState):
+     metrics: Metrics
+     dynamic_scale: dynamic_scale_lib.DynamicScale
+
+ class SimpleTrainer:
+     state: SimpleTrainState
+     best_state: SimpleTrainState
+     best_loss: float
+     model: nn.Module
+     ema_decay: float = 0.999
+
+     def __init__(self,
+                  model: nn.Module,
+                  input_shapes: Dict[str, Tuple[int]],
+                  optimizer: optax.GradientTransformation,
+                  rngs: jax.random.PRNGKey,
+                  train_state: SimpleTrainState = None,
+                  name: str = "Simple",
+                  load_from_checkpoint: str = None,
+                  checkpoint_suffix: str = "",
+                  loss_fn=optax.l2_loss,
+                  param_transforms: Callable = None,
+                  wandb_config: Dict[str, Any] = None,
+                  distributed_training: bool = None,
+                  checkpoint_base_path: str = "./checkpoints",
+                  checkpoint_step: int = None,
+                  use_dynamic_scale: bool = False,
+                  ):
+         if distributed_training is None or distributed_training is True:
+             # Auto-detect if we are running on multiple devices
+             distributed_training = jax.device_count() > 1
+             self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
+         else:
+             self.mesh = None
+
+         self.distributed_training = distributed_training
+         self.model = model
+         self.name = name
+         self.loss_fn = loss_fn
+         self.input_shapes = input_shapes
+         self.checkpoint_base_path = checkpoint_base_path
+
+
+         if wandb_config is not None and jax.process_index() == 0:
+             import wandb
+             run = wandb.init(**wandb_config)
+             self.wandb = run
+
+             # define our custom x axis metric
+             self.wandb.define_metric("train/step")
+             self.wandb.define_metric("train/epoch")
+
+             self.wandb.define_metric("train/loss", step_metric="train/step")
+
+             self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
+             self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
+             self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
+             self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
+
+         # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+         async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
+
+         options = orbax.checkpoint.CheckpointManagerOptions(
+             max_to_keep=4, create=True)
+         self.checkpointer = orbax.checkpoint.CheckpointManager(
+             self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
+
+         if load_from_checkpoint is not None:
+             latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
+         else:
+             latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
+
+         self.latest_step = latest_step
+
+         if rngstate:
+             self.rngstate = RandomMarkovState(**rngstate)
+         else:
+             self.rngstate = RandomMarkovState(rngs)
+
+         self.rngstate, subkey = self.rngstate.get_random_key()
+
+         if train_state == None:
+             state, best_state = self.generate_states(
+                 optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
+             )
+             self.init_state(state, best_state)
+         else:
+             self.state = train_state
+             self.best_state = train_state
+             self.best_loss = 1e9
+
+     def get_input_ones(self):
+         return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
+
+     def generate_states(
+         self,
+         optimizer: optax.GradientTransformation,
+         rngs: jax.random.PRNGKey,
+         existing_state: dict = None,
+         existing_best_state: dict = None,
+         model: nn.Module = None,
+         param_transforms: Callable = None,
+         use_dynamic_scale: bool = False
+     ) -> Tuple[SimpleTrainState, SimpleTrainState]:
+         print("Generating states for SimpleTrainer")
+         rngs, subkey = jax.random.split(rngs)
+
+         if existing_state == None:
+             input_vars = self.get_input_ones()
+             params = model.init(subkey, **input_vars)
+         else:
+             params = existing_state['params']
+
+         state = SimpleTrainState.create(
+             apply_fn=model.apply,
+             params=params,
+             tx=optimizer,
+             metrics=Metrics.empty(),
+             dynamic_scale = dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
+         )
+         if existing_best_state is not None:
+             best_state = state.replace(
+                 params=existing_best_state['params'])
+         else:
+             best_state = state
+
+         return state, best_state
+
+     def init_state(
+         self,
+         state: SimpleTrainState,
+         best_state: SimpleTrainState,
+     ):
+         self.best_loss = 1e9
+
+         self.state = state
+         self.best_state = best_state
+
+     def get_state(self):
+         return self.get_np_tree(self.state)
+
+     def get_best_state(self):
+         return self.get_np_tree(self.best_state)
+
+     def get_rngstate(self):
+         return self.get_np_tree(self.rngstate)
+
+     def get_np_tree(self, pytree):
+         return jax.tree_util.tree_map(lambda x : np.array(x), pytree)
+
+     def checkpoint_path(self):
+         path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
+         # Convert the path to an absolute path
+         path = os.path.abspath(path)
+         if not os.path.exists(path):
+             os.makedirs(path)
+         return path
+
+     def tensorboard_path(self):
+         experiment_name = self.name
+         path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
+         if not os.path.exists(path):
+             os.makedirs(path)
+         return path
+
+     def load(self, checkpoint_path=None, checkpoint_step=None):
+         if checkpoint_path is None:
+             checkpointer = self.checkpointer
+         else:
+             checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+             options = orbax.checkpoint.CheckpointManagerOptions(
+                 max_to_keep=4, create=False)
+             checkpointer = orbax.checkpoint.CheckpointManager(
+                 checkpoint_path, checkpointer, options)
+
+         if checkpoint_step is None:
+             step = checkpointer.latest_step()
+         else:
+             step = checkpoint_step
+
+         print("Loading model from checkpoint at step ", step)
+         ckpt = checkpointer.restore(step)
+         state = ckpt['state']
+         best_state = ckpt['best_state']
+         rngstate = ckpt['rngs']
+         # Convert the state to a TrainState
+         self.best_loss = ckpt['best_loss']
+         if self.best_loss == 0:
+             # It cant be zero as that must have been some problem
+             self.best_loss = 1e9
+         current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
+         print(
+             f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
+         return current_epoch, step, state, best_state, rngstate
+
+     def save(self, epoch=0, step=0, state=None, rngstate=None):
+         print(f"Saving model at epoch {epoch} step {step}")
+         try:
+             ckpt = {
+                 # 'model': self.model,
+                 'rngs': self.get_rngstate() if rngstate is None else self.get_np_tree(rngstate),
+                 'state': self.get_state() if state is None else self.get_np_tree(state),
+                 'best_state': self.get_best_state(),
+                 'best_loss': np.array(self.best_loss),
+                 'epoch': epoch,
+             }
+             try:
+                 save_args = orbax_utils.save_args_from_target(ckpt)
+                 self.checkpointer.save(step, ckpt, save_kwargs={
+                     'save_args': save_args}, force=True)
+                 self.checkpointer.wait_until_finished()
+                 pass
+             except Exception as e:
+                 print("Error saving checkpoint", e)
+         except Exception as e:
+             print("Error saving checkpoint outer", e)
+
+     def _define_train_step(self, **kwargs):
+         model = self.model
+         loss_fn = self.loss_fn
+         distributed_training = self.distributed_training
+
+         def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
+             """Train for a single step."""
+             images = batch['image']
+             labels = batch['label']
+
+             def model_loss(params):
+                 preds = model.apply(params, images)
+                 expected_output = labels
+                 nloss = loss_fn(preds, expected_output)
+                 loss = jnp.mean(nloss)
+                 return loss
+             loss, grads = jax.value_and_grad(model_loss)(train_state.params)
+             if distributed_training:
+                 grads = jax.lax.pmean(grads, "data")
+             train_state = train_state.apply_gradients(grads=grads)
+             return train_state, loss, rng_state
+
+         if distributed_training:
+             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
+             train_step = jax.pmap(train_step)
+         return train_step
+
+     def _define_vaidation_step(self):
+         model = self.model
+         loss_fn = self.loss_fn
+         distributed_training = self.distributed_training
+
+         def validation_step(state: SimpleTrainState, batch):
+             preds = model.apply(state.params, batch['image'])
+             expected_output = batch['label']
+             loss = jnp.mean(loss_fn(preds, expected_output))
+             if distributed_training:
+                 loss = jax.lax.pmean(loss, "data")
+             metric_updates = state.metrics.single_from_model_output(
+                 loss=loss, logits=preds, labels=expected_output)
+             metrics = state.metrics.merge(metric_updates)
+             state = state.replace(metrics=metrics)
+             return state
+         if distributed_training:
+             validation_step = shard_map(validation_step, mesh=self.mesh, in_specs=(P(), P('data')), out_specs=(P()))
+             validation_step = jax.pmap(validation_step)
+         return validation_step
+
+     def summary(self):
+         input_vars = self.get_input_ones()
+         print(self.model.tabulate(jax.random.key(0), **input_vars,
+                                   console_kwargs={"width": 200, "force_jupyter": True, }))
+
+     def config(self):
+         return {
+             "model": self.model,
+             "state": self.state,
+             "name": self.name,
+             "input_shapes": self.input_shapes
+         }
+
+     def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
+         from flax.metrics import tensorboard
+         summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
+         summary_writer.hparams({
+             **self.config(),
+             "steps_per_epoch": steps_per_epoch,
+             "epochs": epochs,
+             "batch_size": batch_size
+         })
+         return summary_writer
+
+     def validation_loop(
+         self,
+         val_state: SimpleTrainState,
+         val_step_fn: Callable,
+         val_ds,
+         val_steps_per_epoch,
+         current_step,
+     ):
+         global_device_count = jax.device_count()
+         local_device_count = jax.local_device_count()
+         process_index = jax.process_index()
+
+         val_ds = iter(val_ds()) if val_ds else None
+         # Evaluation step
+         try:
+             for i in range(val_steps_per_epoch):
+                 if val_ds is None:
+                     batch = None
+                 else:
+                     batch = next(val_ds)
+                     if self.distributed_training and global_device_count > 1:
+                         batch = convert_to_global_tree(self.mesh, batch)
+                 if i == 0:
+                     print(f"Evaluation started for process index {process_index}")
+                 metrics = val_step_fn(val_state, batch)
+                 if self.wandb is not None:
+                     # metrics is a dict of metrics
+                     if metrics and type(metrics) == dict:
+                         for key, value in metrics.items():
+                             if isinstance(value, jnp.ndarray):
+                                 value = np.array(value)
+                             self.wandb.log({
+                                 f"val/{key}": value,
+                             }, step=current_step)
+         except Exception as e:
+             print("Error logging images to wandb", e)
+
+     def train_loop(
+         self,
+         train_state: SimpleTrainState,
+         train_step_fn: Callable,
+         train_ds,
+         train_steps_per_epoch,
+         current_step,
+         rng_state
+     ):
+         global_device_count = jax.device_count()
+         local_device_count = jax.local_device_count()
+         process_index = jax.process_index()
+         if self.distributed_training:
+             global_device_indexes = jnp.arange(global_device_count)
+         else:
+             global_device_indexes = 0
+
+         epoch_loss = 0
+         current_epoch = current_step // train_steps_per_epoch
+         last_save_time = time.time()
+
+         if process_index == 0:
+             pbar = tqdm.tqdm(total=train_steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step')
+
+         for i in range(train_steps_per_epoch):
+             batch = next(train_ds)
+             if i == 0:
+                 print(f"First batch loaded at step {current_step}")
+
+             if self.distributed_training and global_device_count > 1:
+                 # # Convert the local device batches to a unified global jax.Array
+                 batch = convert_to_global_tree(self.mesh, batch)
+             train_state, loss, rng_state = train_step_fn(train_state, rng_state, batch, global_device_indexes)
+
+             if i == 0:
+                 print(f"Training started for process index {process_index} at step {current_step}")
+
+             if self.distributed_training:
+                 # loss = jax.experimental.multihost_utils.process_allgather(loss)
+                 loss = jnp.mean(loss) # Just to make sure its a scaler value
+
+             if loss <= 1e-6:
+                 # If the loss is too low, we can assume the model has diverged
+                 print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
+                 # Reset the model to the old state
+                 exit(1)
+
+             epoch_loss += loss
+             current_step += 1
+             if i % 100 == 0:
+                 if pbar is not None:
+                     pbar.set_postfix(loss=f'{loss:.4f}')
+                     pbar.update(100)
+                 if self.wandb is not None:
+                     self.wandb.log({
+                         "train/step" : current_step,
+                         "train/loss": loss,
+                     }, step=current_step)
+             # Save the model every few steps
+             if i % 10000 == 0 and i > 0:
+                 print(f"Saving model after 10000 step {current_step}")
+                 print(f"Devices: {len(jax.devices())}") # To sync the devices
+                 self.save(current_epoch, current_step, train_state, rng_state)
+                 print(f"Saving done by process index {process_index}")
+                 last_save_time = time.time()
+         print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/train_steps_per_epoch}", 'green'))
+         if pbar is not None:
+             pbar.close()
+         return epoch_loss, current_step, train_state, rng_state
+
+
+     def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
+         train_ds = iter(data['train']())
+         train_step = self._define_train_step(**train_step_args)
+         val_step = self._define_vaidation_step(**validation_step_args)
+         train_state = self.state
+         rng_state = self.rngstate
+         process_index = jax.process_index()
+
+         if val_steps_per_epoch > 0:
+             # We should first run a validation step to make sure the model is working
+             print(f"Validation run for sanity check for process index {process_index}")
+             # Validation step
+             self.validation_loop(
+                 train_state,
+                 val_step,
+                 data.get('test', data.get('val', None)),
+                 val_steps_per_epoch,
+                 self.latest_step,
+             )
+             print(colored(f"Sanity Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
+         while self.latest_step < epochs * train_steps_per_epoch:
+             current_epoch = self.latest_step // train_steps_per_epoch
+             print(f"\nEpoch {current_epoch}/{epochs}")
+             start_time = time.time()
+             epoch_loss = 0
+
+             epoch_loss, current_step, train_state, rng_state = self.train_loop(
+                 train_state,
+                 train_step,
+                 train_ds,
+                 train_steps_per_epoch,
+                 self.latest_step,
+                 rng_state,
+             )
+             print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
+             self.latest_step = current_step
+             end_time = time.time()
+             self.state = train_state
+             self.rngstate = rng_state
+             total_time = end_time - start_time
+             avg_time_per_step = total_time / train_steps_per_epoch
+             avg_loss = epoch_loss / train_steps_per_epoch
+             if avg_loss < self.best_loss:
+                 self.best_loss = avg_loss
+                 self.best_state = train_state
+                 self.save(current_epoch, current_step)
+
+             if process_index == 0:
+                 if self.wandb is not None:
+                     self.wandb.log({
+                         "train/epoch_time": total_time,
+                         "train/avg_time_per_step": avg_time_per_step,
+                         "train/avg_loss": avg_loss,
+                         "train/best_loss": self.best_loss,
+                         "train/epoch": current_epoch,
+                     }, step=current_step)
+                 print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
+
+             if val_steps_per_epoch > 0:
+                 print(f"Validation started for process index {process_index}")
+                 # Validation step
+                 self.validation_loop(
+                     train_state,
+                     val_step,
+                     data.get('test', None),
+                     val_steps_per_epoch,
+                     current_step,
+                 )
+                 print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
+         self.save(epochs)
+         return self.state
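
Editor's note (not part of the diff): a minimal sketch of how the SimpleTrainer added above might be constructed, based only on the signatures visible in this hunk. The TinyClassifier model, input shape, and optimizer settings are hypothetical placeholders; the fit() call is left commented out because its data pipeline (a callable returning an iterator of {'image', 'label'} batches) is environment-specific.

import jax
import optax
from flax import linen as nn
from flaxdiff.trainer.simple_trainer import SimpleTrainer

class TinyClassifier(nn.Module):
    """Hypothetical toy model, used only for this sketch."""
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(10)(x)

trainer = SimpleTrainer(
    model=TinyClassifier(),
    input_shapes={"x": (28, 28, 1)},  # get_input_ones() prepends a batch dimension of 1
    optimizer=optax.adam(1e-3),       # any optax.GradientTransformation works here
    rngs=jax.random.PRNGKey(0),
    name="tiny-example",
)
trainer.summary()  # tabulates the model using dummy inputs built from input_shapes

# fit() consumes data['train']() as an iterator of {'image': ..., 'label': ...} batches
# (see train_step above), e.g.:
# trainer.fit({'train': make_train_iter}, train_steps_per_epoch=100, epochs=1)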
flaxdiff/trainer/video_diffusion_trainer.py
@@ -0,0 +1,62 @@
+ import flax
+ from flax import linen as nn
+ import jax
+ from typing import Callable
+ from dataclasses import field
+ import jax.numpy as jnp
+ import optax
+ import functools
+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental.shard_map import shard_map
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
+
+ from ..schedulers import NoiseScheduler
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+
+ from flaxdiff.utils import RandomMarkovState
+
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+ from flax.training import dynamic_scale as dynamic_scale_lib
+
+ class TrainState(SimpleTrainState):
+     rngs: jax.random.PRNGKey
+     ema_params: dict
+
+     def apply_ema(self, decay: float = 0.999):
+         new_ema_params = jax.tree_util.tree_map(
+             lambda ema, param: decay * ema + (1 - decay) * param,
+             self.ema_params,
+             self.params,
+         )
+         return self.replace(ema_params=new_ema_params)
+
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+ from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer
+
+ class SimpleVideoDiffusionTrainer(DiffusionTrainer):
+     def __init__(self,
+                  model: nn.Module,
+                  input_shapes: Dict[str, Tuple[int]],
+                  optimizer: optax.GradientTransformation,
+                  noise_schedule: NoiseScheduler,
+                  rngs: jax.random.PRNGKey,
+                  unconditional_prob: float = 0.12,
+                  name: str = "SimpleVideoDiffusion",
+                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                  autoencoder: AutoEncoder = None,
+                  **kwargs
+                  ):
+         super().__init__(
+             model=model,
+             input_shapes=input_shapes,
+             optimizer=optimizer,
+             noise_schedule=noise_schedule,
+             unconditional_prob=unconditional_prob,
+             autoencoder=autoencoder,
+             model_output_transform=model_output_transform,
+             rngs=rngs,
+             name=name,
+             **kwargs
+         )
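
Editor's note (not part of the diff): apply_ema above is the standard exponential-moving-average update, ema <- decay * ema + (1 - decay) * param, applied leaf-wise over the parameter pytree. A standalone sketch of the same rule on a made-up pytree:

import jax
import jax.numpy as jnp

decay = 0.999
ema_params = {"dense": {"kernel": jnp.zeros((2, 2))}}  # hypothetical EMA copy
params = {"dense": {"kernel": jnp.ones((2, 2))}}       # hypothetical current params

new_ema_params = jax.tree_util.tree_map(
    lambda ema, param: decay * ema + (1 - decay) * param,
    ema_params,
    params,
)
# Each EMA leaf moves 0.1% of the way toward the current parameter value,
# so new_ema_params["dense"]["kernel"] is filled with 0.001.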
{flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.1.36.1
+ Version: 0.1.36.3
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
flaxdiff-0.1.36.3.dist-info/RECORD
@@ -0,0 +1,47 @@
+ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
+ flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
+ flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,3093
+ flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
+ flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
+ flaxdiff/data/sources/gcs.py,sha256=11ZuQhvMyJRLg21DgVdzO5qEuae7zgzTXGNOskF-cbs,3380
+ flaxdiff/data/sources/tfds.py,sha256=WA3h9lyR4yotCNEmJON2noIN-2HNcqhf6zigx1XXsMI,2481
+ flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
+ flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
+ flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
+ flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
+ flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
+ flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
+ flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
+ flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
+ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
+ flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
+ flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
+ flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
+ flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
+ flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
+ flaxdiff/samplers/common.py,sha256=ZA08VyovxegpRx4wOQq9LSwZi0gSCz2lrbS5oVYOEYg,8488
+ flaxdiff/samplers/ddim.py,sha256=pB8Kod8ZLJ3GXev4uM3cOj1Uy6ibR0jsaZa-VE0fyJM,552
+ flaxdiff/samplers/ddpm.py,sha256=u1OchQu0XPhc_6w9JXoaFp2wo4y-zXyQNtGAIJwxNLg,2209
+ flaxdiff/samplers/euler.py,sha256=Htb-IJeu7jSgY6mvgYr9yl9pUnos49vijlVk5IQsRps,2740
+ flaxdiff/samplers/heun_sampler.py,sha256=UyI-hSlyWvt-7VEUJj27zjgyzKkGVl8fDUHV-YpSOCc,1421
+ flaxdiff/samplers/multistep_dpm.py,sha256=3Wu3MrMLYaBb1ObraTbWrJmtEtU0adl1dDbz5fPJ4Gs,2735
+ flaxdiff/samplers/rk4_sampler.py,sha256=1j1pES_Q2QiaURvEWeedbbT1LHmkc3jsu0GgH83qBL0,1926
+ flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
+ flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
+ flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
+ flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
+ flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
+ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
+ flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
+ flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
+ flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
+ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
+ flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
+ flaxdiff/trainer/diffusion_trainer.py,sha256=ajOWBgFFwXP_VQScUjcuPoaB4Gk02aF0Ls5LNlA8wqA,12691
+ flaxdiff/trainer/simple_trainer.py,sha256=lmRo8N0bMupIyS3ejPvPtxoskY_3GLC8iyJE6u4TIWc,21990
+ flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
+ flaxdiff-0.1.36.3.dist-info/METADATA,sha256=9XaZMJ6SMFP7OUn2tp9v5FQveMGoxvuiyxdJ8SmMd8w,22310
+ flaxdiff-0.1.36.3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ flaxdiff-0.1.36.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+ flaxdiff-0.1.36.3.dist-info/RECORD,,
flaxdiff-0.1.36.1.dist-info/RECORD
@@ -1,6 +0,0 @@
- flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
- flaxdiff-0.1.36.1.dist-info/METADATA,sha256=Fl9tlGh_BgRnT-f8k4cEYnFj7G03VecUNOX_1zbJrmE,22310
- flaxdiff-0.1.36.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- flaxdiff-0.1.36.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
- flaxdiff-0.1.36.1.dist-info/RECORD,,