flaxdiff-0.1.36-py3-none-any.whl → flaxdiff-0.1.36.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (42)
  1. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +13 -10
  2. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  3. flaxdiff/data/__init__.py +0 -1
  4. flaxdiff/data/dataset_map.py +0 -71
  5. flaxdiff/data/datasets.py +0 -169
  6. flaxdiff/data/online_loader.py +0 -363
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -165
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -37
  22. flaxdiff/samplers/euler.py +0 -56
  23. flaxdiff/samplers/heun_sampler.py +0 -27
  24. flaxdiff/samplers/multistep_dpm.py +0 -59
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -326
  38. flaxdiff/trainer/simple_trainer.py +0 -538
  39. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  40. flaxdiff-0.1.36.dist-info/RECORD +0 -43
  41. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +0 -0
  42. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
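
A file-level comparison like the one above can be reproduced locally with only the Python standard library, since a wheel is a plain zip archive. The sketch below is illustrative only; the wheel filenames are placeholders for locally downloaded copies of the two releases.

import zipfile

def wheel_contents(path):
    # A wheel is a zip archive; namelist() returns every file packaged in it.
    with zipfile.ZipFile(path) as wheel:
        return set(wheel.namelist())

old = wheel_contents("flaxdiff-0.1.36-py3-none-any.whl")
new = wheel_contents("flaxdiff-0.1.36.1-py3-none-any.whl")

print("removed:", sorted(old - new))
print("added:  ", sorted(new - old))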
flaxdiff/trainer/simple_trainer.py
@@ -1,538 +0,0 @@
- import orbax.checkpoint
- import tqdm
- from flax import linen as nn
- import jax
- from typing import Callable
- from dataclasses import field
- import jax.numpy as jnp
- import numpy as np
- from functools import partial
- from clu import metrics
- from flax.training import train_state # Useful dataclass to keep train state
- import optax
- from flax import struct # Flax dataclasses
- import flax
- import time
- import os
- import orbax
- from flax.training import orbax_utils
- from jax.sharding import Mesh, PartitionSpec as P
- from jax.experimental import mesh_utils
- from jax.experimental.shard_map import shard_map
- from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
- from termcolor import colored
- from typing import Dict, Callable, Sequence, Any, Union, Tuple
- from flax.training.dynamic_scale import DynamicScale
- from flaxdiff.utils import RandomMarkovState
- from flax.training import dynamic_scale as dynamic_scale_lib
-
- PROCESS_COLOR_MAP = {
-     0: "green",
-     1: "yellow",
-     2: "magenta",
-     3: "cyan",
-     4: "white",
-     5: "light_blue",
-     6: "light_red",
-     7: "light_cyan"
- }
-
- def _build_global_shape_and_sharding(
-     local_shape: tuple[int, ...], global_mesh: Mesh
- ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
-     sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
-     global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
-     return global_shape, sharding
-
-
- def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
-     """Put local sharded array into local devices"""
-     global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
-     try:
-         local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
-     except ValueError as array_split_error:
-         raise ValueError(
-             f"Unable to put to devices shape {array.shape} with "
-             f"local device count {len(global_mesh.local_devices)} "
-         ) from array_split_error
-     local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
-     return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
-
- def convert_to_global_tree(global_mesh, pytree):
-     return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
-
- @struct.dataclass
- class Metrics(metrics.Collection):
-     accuracy: metrics.Accuracy
-     loss: metrics.Average#.from_output('loss')
-
- # Define the TrainState
- class SimpleTrainState(train_state.TrainState):
-     metrics: Metrics
-     dynamic_scale: dynamic_scale_lib.DynamicScale
-
- class SimpleTrainer:
-     state: SimpleTrainState
-     best_state: SimpleTrainState
-     best_loss: float
-     model: nn.Module
-     ema_decay: float = 0.999
-
-     def __init__(self,
-                  model: nn.Module,
-                  input_shapes: Dict[str, Tuple[int]],
-                  optimizer: optax.GradientTransformation,
-                  rngs: jax.random.PRNGKey,
-                  train_state: SimpleTrainState = None,
-                  name: str = "Simple",
-                  load_from_checkpoint: str = None,
-                  checkpoint_suffix: str = "",
-                  loss_fn=optax.l2_loss,
-                  param_transforms: Callable = None,
-                  wandb_config: Dict[str, Any] = None,
-                  distributed_training: bool = None,
-                  checkpoint_base_path: str = "./checkpoints",
-                  checkpoint_step: int = None,
-                  use_dynamic_scale: bool = False,
-                  ):
-         if distributed_training is None or distributed_training is True:
-             # Auto-detect if we are running on multiple devices
-             distributed_training = jax.device_count() > 1
-             self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
-         else:
-             self.mesh = None
-
-         self.distributed_training = distributed_training
-         self.model = model
-         self.name = name
-         self.loss_fn = loss_fn
-         self.input_shapes = input_shapes
-         self.checkpoint_base_path = checkpoint_base_path
-
-
-         if wandb_config is not None and jax.process_index() == 0:
-             import wandb
-             run = wandb.init(**wandb_config)
-             self.wandb = run
-
-             # define our custom x axis metric
-             self.wandb.define_metric("train/step")
-             self.wandb.define_metric("train/epoch")
-
-             self.wandb.define_metric("train/loss", step_metric="train/step")
-
-             self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
-             self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
-             self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
-             self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
-
-         # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-         async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
-
-         options = orbax.checkpoint.CheckpointManagerOptions(
-             max_to_keep=4, create=True)
-         self.checkpointer = orbax.checkpoint.CheckpointManager(
-             self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
-
-         if load_from_checkpoint is not None:
-             latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
-         else:
-             latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
-
-         self.latest_step = latest_step
-
-         if rngstate:
-             self.rngstate = RandomMarkovState(**rngstate)
-         else:
-             self.rngstate = RandomMarkovState(rngs)
-
-         self.rngstate, subkey = self.rngstate.get_random_key()
-
-         if train_state == None:
-             state, best_state = self.generate_states(
-                 optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
-             )
-             self.init_state(state, best_state)
-         else:
-             self.state = train_state
-             self.best_state = train_state
-             self.best_loss = 1e9
-
-     def get_input_ones(self):
-         return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
-
-     def generate_states(
-         self,
-         optimizer: optax.GradientTransformation,
-         rngs: jax.random.PRNGKey,
-         existing_state: dict = None,
-         existing_best_state: dict = None,
-         model: nn.Module = None,
-         param_transforms: Callable = None,
-         use_dynamic_scale: bool = False
-     ) -> Tuple[SimpleTrainState, SimpleTrainState]:
-         print("Generating states for SimpleTrainer")
-         rngs, subkey = jax.random.split(rngs)
-
-         if existing_state == None:
-             input_vars = self.get_input_ones()
-             params = model.init(subkey, **input_vars)
-         else:
-             params = existing_state['params']
-
-         state = SimpleTrainState.create(
-             apply_fn=model.apply,
-             params=params,
-             tx=optimizer,
-             metrics=Metrics.empty(),
-             dynamic_scale = dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
-         )
-         if existing_best_state is not None:
-             best_state = state.replace(
-                 params=existing_best_state['params'])
-         else:
-             best_state = state
-
-         return state, best_state
-
-     def init_state(
-         self,
-         state: SimpleTrainState,
-         best_state: SimpleTrainState,
-     ):
-         self.best_loss = 1e9
-
-         self.state = state
-         self.best_state = best_state
-
-     def get_state(self):
-         return self.get_np_tree(self.state)
-
-     def get_best_state(self):
-         return self.get_np_tree(self.best_state)
-
-     def get_rngstate(self):
-         return self.get_np_tree(self.rngstate)
-
-     def get_np_tree(self, pytree):
-         return jax.tree_util.tree_map(lambda x : np.array(x), pytree)
-
-     def checkpoint_path(self):
-         path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
-         if not os.path.exists(path):
-             os.makedirs(path)
-         return path
-
-     def tensorboard_path(self):
-         experiment_name = self.name
-         path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
-         if not os.path.exists(path):
-             os.makedirs(path)
-         return path
-
-     def load(self, checkpoint_path=None, checkpoint_step=None):
-         if checkpoint_path is None:
-             checkpointer = self.checkpointer
-         else:
-             checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-             options = orbax.checkpoint.CheckpointManagerOptions(
-                 max_to_keep=4, create=False)
-             checkpointer = orbax.checkpoint.CheckpointManager(
-                 checkpoint_path, checkpointer, options)
-
-         if checkpoint_step is None:
-             step = checkpointer.latest_step()
-         else:
-             step = checkpoint_step
-
-         print("Loading model from checkpoint at step ", step)
-         ckpt = checkpointer.restore(step)
-         state = ckpt['state']
-         best_state = ckpt['best_state']
-         rngstate = ckpt['rngs']
-         # Convert the state to a TrainState
-         self.best_loss = ckpt['best_loss']
-         if self.best_loss == 0:
-             # It cant be zero as that must have been some problem
-             self.best_loss = 1e9
-         current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
-         print(
-             f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
-         return current_epoch, step, state, best_state, rngstate
-
-     def save(self, epoch=0, step=0, state=None, rngstate=None):
-         print(f"Saving model at epoch {epoch} step {step}")
-         try:
-             ckpt = {
-                 # 'model': self.model,
-                 'rngs': self.get_rngstate() if rngstate is None else self.get_np_tree(rngstate),
-                 'state': self.get_state() if state is None else self.get_np_tree(state),
-                 'best_state': self.get_best_state(),
-                 'best_loss': np.array(self.best_loss),
-                 'epoch': epoch,
-             }
-             try:
-                 save_args = orbax_utils.save_args_from_target(ckpt)
-                 self.checkpointer.save(step, ckpt, save_kwargs={
-                     'save_args': save_args}, force=True)
-                 self.checkpointer.wait_until_finished()
-                 pass
-             except Exception as e:
-                 print("Error saving checkpoint", e)
-         except Exception as e:
-             print("Error saving checkpoint outer", e)
-
-     def _define_train_step(self, **kwargs):
-         model = self.model
-         loss_fn = self.loss_fn
-         distributed_training = self.distributed_training
-
-         def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
-             """Train for a single step."""
-             images = batch['image']
-             labels = batch['label']
-
-             def model_loss(params):
-                 preds = model.apply(params, images)
-                 expected_output = labels
-                 nloss = loss_fn(preds, expected_output)
-                 loss = jnp.mean(nloss)
-                 return loss
-             loss, grads = jax.value_and_grad(model_loss)(train_state.params)
-             if distributed_training:
-                 grads = jax.lax.pmean(grads, "data")
-             train_state = train_state.apply_gradients(grads=grads)
-             return train_state, loss, rng_state
-
-         if distributed_training:
-             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
-             train_step = jax.pmap(train_step)
-         return train_step
-
-     def _define_vaidation_step(self):
-         model = self.model
-         loss_fn = self.loss_fn
-         distributed_training = self.distributed_training
-
-         def validation_step(state: SimpleTrainState, batch):
-             preds = model.apply(state.params, batch['image'])
-             expected_output = batch['label']
-             loss = jnp.mean(loss_fn(preds, expected_output))
-             if distributed_training:
-                 loss = jax.lax.pmean(loss, "data")
-             metric_updates = state.metrics.single_from_model_output(
-                 loss=loss, logits=preds, labels=expected_output)
-             metrics = state.metrics.merge(metric_updates)
-             state = state.replace(metrics=metrics)
-             return state
-         if distributed_training:
-             validation_step = shard_map(validation_step, mesh=self.mesh, in_specs=(P(), P('data')), out_specs=(P()))
-             validation_step = jax.pmap(validation_step)
-         return validation_step
-
-     def summary(self):
-         input_vars = self.get_input_ones()
-         print(self.model.tabulate(jax.random.key(0), **input_vars,
-               console_kwargs={"width": 200, "force_jupyter": True, }))
-
-     def config(self):
-         return {
-             "model": self.model,
-             "state": self.state,
-             "name": self.name,
-             "input_shapes": self.input_shapes
-         }
-
-     def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
-         from flax.metrics import tensorboard
-         summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
-         summary_writer.hparams({
-             **self.config(),
-             "steps_per_epoch": steps_per_epoch,
-             "epochs": epochs,
-             "batch_size": batch_size
-         })
-         return summary_writer
-
-     def validation_loop(
-         self,
-         val_state: SimpleTrainState,
-         val_step_fn: Callable,
-         val_ds,
-         val_steps_per_epoch,
-         current_step,
-     ):
-         global_device_count = jax.device_count()
-         local_device_count = jax.local_device_count()
-         process_index = jax.process_index()
-
-         val_ds = iter(val_ds()) if val_ds else None
-         # Evaluation step
-         try:
-             for i in range(val_steps_per_epoch):
-                 if val_ds is None:
-                     batch = None
-                 else:
-                     batch = next(val_ds)
-                     if self.distributed_training and global_device_count > 1:
-                         batch = convert_to_global_tree(self.mesh, batch)
-                 if i == 0:
-                     print(f"Evaluation started for process index {process_index}")
-                 metrics = val_step_fn(val_state, batch)
-                 if self.wandb is not None:
-                     # metrics is a dict of metrics
-                     if metrics and type(metrics) == dict:
-                         for key, value in metrics.items():
-                             if isinstance(value, jnp.ndarray):
-                                 value = np.array(value)
-                             self.wandb.log({
-                                 f"val/{key}": value,
-                             }, step=current_step)
-         except Exception as e:
-             print("Error logging images to wandb", e)
-
-     def train_loop(
-         self,
-         train_state: SimpleTrainState,
-         train_step_fn: Callable,
-         train_ds,
-         train_steps_per_epoch,
-         current_step,
-         rng_state
-     ):
-         global_device_count = jax.device_count()
-         local_device_count = jax.local_device_count()
-         process_index = jax.process_index()
-         if self.distributed_training:
-             global_device_indexes = jnp.arange(global_device_count)
-         else:
-             global_device_indexes = 0
-
-         epoch_loss = 0
-         current_epoch = current_step // train_steps_per_epoch
-         last_save_time = time.time()
-
-         if process_index == 0:
-             pbar = tqdm.tqdm(total=train_steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step')
-
-         for i in range(train_steps_per_epoch):
-             batch = next(train_ds)
-             if i == 0:
-                 print(f"First batch loaded at step {current_step}")
-
-             if self.distributed_training and global_device_count > 1:
-                 # # Convert the local device batches to a unified global jax.Array
-                 batch = convert_to_global_tree(self.mesh, batch)
-             train_state, loss, rng_state = train_step_fn(train_state, rng_state, batch, global_device_indexes)
-
-             if i == 0:
-                 print(f"Training started for process index {process_index} at step {current_step}")
-
-             if self.distributed_training:
-                 # loss = jax.experimental.multihost_utils.process_allgather(loss)
-                 loss = jnp.mean(loss) # Just to make sure its a scaler value
-
-             if loss <= 1e-6:
-                 # If the loss is too low, we can assume the model has diverged
-                 print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
-                 # Reset the model to the old state
-                 exit(1)
-
-             epoch_loss += loss
-             current_step += 1
-             if i % 100 == 0:
-                 if pbar is not None:
-                     pbar.set_postfix(loss=f'{loss:.4f}')
-                     pbar.update(100)
-                 if self.wandb is not None:
-                     self.wandb.log({
-                         "train/step" : current_step,
-                         "train/loss": loss,
-                     }, step=current_step)
-                 # Save the model every few steps
-                 if i % 10000 == 0 and i > 0:
-                     print(f"Saving model after 10000 step {current_step}")
-                     print(f"Devices: {len(jax.devices())}") # To sync the devices
-                     self.save(current_epoch, current_step, train_state, rng_state)
-                     print(f"Saving done by process index {process_index}")
-                     last_save_time = time.time()
-         print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/train_steps_per_epoch}", 'green'))
-         if pbar is not None:
-             pbar.close()
-         return epoch_loss, current_step, train_state, rng_state
-
-
-     def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
-         train_ds = iter(data['train']())
-         train_step = self._define_train_step(**train_step_args)
-         val_step = self._define_vaidation_step(**validation_step_args)
-         train_state = self.state
-         rng_state = self.rngstate
-         process_index = jax.process_index()
-
-         if val_steps_per_epoch > 0:
-             # We should first run a validation step to make sure the model is working
-             print(f"Validation run for sanity check for process index {process_index}")
-             # Validation step
-             self.validation_loop(
-                 train_state,
-                 val_step,
-                 data.get('test', data.get('val', None)),
-                 val_steps_per_epoch,
-                 self.latest_step,
-             )
-             print(colored(f"Sanity Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
-
-         while self.latest_step < epochs * train_steps_per_epoch:
-             current_epoch = self.latest_step // train_steps_per_epoch
-             print(f"\nEpoch {current_epoch}/{epochs}")
-             start_time = time.time()
-             epoch_loss = 0
-
-             epoch_loss, current_step, train_state, rng_state = self.train_loop(
-                 train_state,
-                 train_step,
-                 train_ds,
-                 train_steps_per_epoch,
-                 self.latest_step,
-                 rng_state,
-             )
-             print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
-
-             self.latest_step = current_step
-             end_time = time.time()
-             self.state = train_state
-             self.rngstate = rng_state
-             total_time = end_time - start_time
-             avg_time_per_step = total_time / train_steps_per_epoch
-             avg_loss = epoch_loss / train_steps_per_epoch
-             if avg_loss < self.best_loss:
-                 self.best_loss = avg_loss
-                 self.best_state = train_state
-                 self.save(current_epoch, current_step)
-
-             if process_index == 0:
-                 if self.wandb is not None:
-                     self.wandb.log({
-                         "train/epoch_time": total_time,
-                         "train/avg_time_per_step": avg_time_per_step,
-                         "train/avg_loss": avg_loss,
-                         "train/best_loss": self.best_loss,
-                         "train/epoch": current_epoch,
-                     }, step=current_step)
-                 print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
-
-             if val_steps_per_epoch > 0:
-                 print(f"Validation started for process index {process_index}")
-                 # Validation step
-                 self.validation_loop(
-                     train_state,
-                     val_step,
-                     data.get('test', None),
-                     val_steps_per_epoch,
-                     current_step,
-                 )
-                 print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
-
-         self.save(epochs)
-         return self.state
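
For context on what the removed simple_trainer.py provided: based solely on the constructor and fit signatures shown above, a typical use of the deleted SimpleTrainer would have looked roughly like the sketch below. ToyModel and the batch generator are hypothetical stand-ins, not part of flaxdiff, and the import refers to the module as it shipped in 0.1.36.

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flaxdiff.trainer.simple_trainer import SimpleTrainer  # module as shipped in 0.1.36

class ToyModel(nn.Module):
    # Hypothetical classifier; get_input_ones() passes each input_shapes entry as a
    # keyword argument to init, so the __call__ parameter name must match the key.
    @nn.compact
    def __call__(self, image):
        x = image.reshape((image.shape[0], -1))
        return nn.Dense(10)(x)

def train_data():
    # Hypothetical batch generator yielding the 'image'/'label' keys expected by train_step.
    while True:
        yield {"image": jnp.zeros((8, 28, 28, 1)), "label": jnp.zeros((8, 10))}

trainer = SimpleTrainer(
    model=ToyModel(),
    input_shapes={"image": (28, 28, 1)},  # a leading batch dimension of 1 is added internally
    optimizer=optax.adam(1e-3),
    rngs=jax.random.PRNGKey(0),
    name="toy_example",
)
# Note: the removed train_loop reads self.wandb unconditionally, so in practice a
# wandb_config was also passed to the constructor.
trainer.fit({"train": train_data}, train_steps_per_epoch=100, epochs=1, val_steps_per_epoch=0)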
flaxdiff/trainer/video_diffusion_trainer.py
@@ -1,62 +0,0 @@
- import flax
- from flax import linen as nn
- import jax
- from typing import Callable
- from dataclasses import field
- import jax.numpy as jnp
- import optax
- import functools
- from jax.sharding import Mesh, PartitionSpec as P
- from jax.experimental.shard_map import shard_map
- from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
- from ..schedulers import NoiseScheduler
- from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
-
- from flaxdiff.utils import RandomMarkovState
-
- from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-
- from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
- from flax.training import dynamic_scale as dynamic_scale_lib
-
- class TrainState(SimpleTrainState):
-     rngs: jax.random.PRNGKey
-     ema_params: dict
-
-     def apply_ema(self, decay: float = 0.999):
-         new_ema_params = jax.tree_util.tree_map(
-             lambda ema, param: decay * ema + (1 - decay) * param,
-             self.ema_params,
-             self.params,
-         )
-         return self.replace(ema_params=new_ema_params)
-
- from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
- from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer
-
- class SimpleVideoDiffusionTrainer(DiffusionTrainer):
-     def __init__(self,
-                  model: nn.Module,
-                  input_shapes: Dict[str, Tuple[int]],
-                  optimizer: optax.GradientTransformation,
-                  noise_schedule: NoiseScheduler,
-                  rngs: jax.random.PRNGKey,
-                  unconditional_prob: float = 0.12,
-                  name: str = "SimpleVideoDiffusion",
-                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
-                  autoencoder: AutoEncoder = None,
-                  **kwargs
-                  ):
-         super().__init__(
-             model=model,
-             input_shapes=input_shapes,
-             optimizer=optimizer,
-             noise_schedule=noise_schedule,
-             unconditional_prob=unconditional_prob,
-             autoencoder=autoencoder,
-             model_output_transform=model_output_transform,
-             rngs=rngs,
-             name=name,
-             **kwargs
-         )
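
The apply_ema method in the removed TrainState above is a standard exponential moving average over the parameter pytree. The standalone sketch below (a hypothetical two-leaf pytree, not flaxdiff code) shows the same update rule in isolation.

import jax
import jax.numpy as jnp

decay = 0.999
params = {"w": jnp.ones((2, 2)), "b": jnp.ones((2,))}
ema_params = {"w": jnp.zeros((2, 2)), "b": jnp.zeros((2,))}

# Same rule as TrainState.apply_ema: ema <- decay * ema + (1 - decay) * param
new_ema = jax.tree_util.tree_map(
    lambda ema, p: decay * ema + (1 - decay) * p, ema_params, params
)
print(new_ema["w"][0, 0])  # 0.001 after one step from a zero-initialised EMA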
flaxdiff-0.1.36.dist-info/RECORD
@@ -1,43 +0,0 @@
- flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
- flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
- flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,3093
- flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
- flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
- flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
- flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
- flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
- flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
- flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
- flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
- flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
- flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
- flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
- flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
- flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
- flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
- flaxdiff/samplers/common.py,sha256=ZA08VyovxegpRx4wOQq9LSwZi0gSCz2lrbS5oVYOEYg,8488
- flaxdiff/samplers/ddim.py,sha256=pB8Kod8ZLJ3GXev4uM3cOj1Uy6ibR0jsaZa-VE0fyJM,552
- flaxdiff/samplers/ddpm.py,sha256=u1OchQu0XPhc_6w9JXoaFp2wo4y-zXyQNtGAIJwxNLg,2209
- flaxdiff/samplers/euler.py,sha256=Htb-IJeu7jSgY6mvgYr9yl9pUnos49vijlVk5IQsRps,2740
- flaxdiff/samplers/heun_sampler.py,sha256=UyI-hSlyWvt-7VEUJj27zjgyzKkGVl8fDUHV-YpSOCc,1421
- flaxdiff/samplers/multistep_dpm.py,sha256=3Wu3MrMLYaBb1ObraTbWrJmtEtU0adl1dDbz5fPJ4Gs,2735
- flaxdiff/samplers/rk4_sampler.py,sha256=1j1pES_Q2QiaURvEWeedbbT1LHmkc3jsu0GgH83qBL0,1926
- flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
- flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
- flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
- flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
- flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
- flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
- flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
- flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
- flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
- flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
- flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
- flaxdiff/trainer/diffusion_trainer.py,sha256=ajOWBgFFwXP_VQScUjcuPoaB4Gk02aF0Ls5LNlA8wqA,12691
- flaxdiff/trainer/simple_trainer.py,sha256=jCD9-qCwX0SC0rN3GrXUBfRrndWNqUI0HmbOAbmYBMM,21906
- flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
- flaxdiff-0.1.36.dist-info/METADATA,sha256=7fO1e_icIEK6dmSopv538Hm2fQnhnkOAE2Ab9inpcNE,22213
- flaxdiff-0.1.36.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- flaxdiff-0.1.36.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
- flaxdiff-0.1.36.dist-info/RECORD,,