flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (40)
  1. flaxdiff/utils.py +105 -2
  2. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
  3. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  4. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
  5. flaxdiff/data/__init__.py +0 -1
  6. flaxdiff/data/online_loader.py +0 -336
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -113
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -43
  22. flaxdiff/samplers/euler.py +0 -59
  23. flaxdiff/samplers/heun_sampler.py +0 -28
  24. flaxdiff/samplers/multistep_dpm.py +0 -60
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -234
  38. flaxdiff/trainer/simple_trainer.py +0 -442
  39. flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
  40. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
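Reading the list above, the 0.1.36.1 wheel keeps flaxdiff/utils.py (which grows by 105 lines) but no longer ships most of the other modules, and the new RECORD gains only six entries. The snippet below is a hedged sketch, not part of either package: after upgrading, it probes a few of the module paths removed from the old wheel to confirm whether they are still importable in your environment. The module names are taken from the file list above; everything else is illustrative.

# Hedged sketch (not part of flaxdiff): check which removed submodules remain importable.
from importlib import import_module

removed_modules = [
    "flaxdiff.models.attention",
    "flaxdiff.samplers.ddpm",
    "flaxdiff.trainer.simple_trainer",
]

for name in removed_modules:
    try:
        import_module(name)
        print(f"{name}: importable")
    except ImportError as exc:
        print(f"{name}: not importable ({exc})")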
flaxdiff/trainer/simple_trainer.py
@@ -1,442 +0,0 @@
- import orbax.checkpoint
- import tqdm
- from flax import linen as nn
- import jax
- from typing import Callable
- from dataclasses import field
- import jax.numpy as jnp
- import numpy as np
- from functools import partial
- from clu import metrics
- from flax.training import train_state # Useful dataclass to keep train state
- import optax
- from flax import struct # Flax dataclasses
- import flax
- import time
- import os
- import orbax
- from flax.training import orbax_utils
- from jax.sharding import Mesh, PartitionSpec as P
- from jax.experimental import mesh_utils
- from jax.experimental.shard_map import shard_map
- from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
- from termcolor import colored
- from typing import Dict, Callable, Sequence, Any, Union, Tuple
- from flax.training.dynamic_scale import DynamicScale
- from flaxdiff.utils import RandomMarkovState
-
- PROCESS_COLOR_MAP = {
-     0: "green",
-     1: "yellow",
-     2: "magenta",
-     3: "cyan",
-     4: "white",
-     5: "light_blue",
-     6: "light_red",
-     7: "light_cyan"
- }
-
- def _build_global_shape_and_sharding(
-     local_shape: tuple[int, ...], global_mesh: Mesh
- ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
-     sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
-     global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
-     return global_shape, sharding
-
-
- def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
-     """Put local sharded array into local devices"""
-     global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
-     try:
-         local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
-     except ValueError as array_split_error:
-         raise ValueError(
-             f"Unable to put to devices shape {array.shape} with "
-             f"local device count {len(global_mesh.local_devices)} "
-         ) from array_split_error
-     local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
-     return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
-
- def convert_to_global_tree(global_mesh, pytree):
-     return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
-
- @struct.dataclass
- class Metrics(metrics.Collection):
-     accuracy: metrics.Accuracy
-     loss: metrics.Average.from_output('loss')
-
- # Define the TrainState
- class SimpleTrainState(train_state.TrainState):
-     metrics: Metrics
-     dynamic_scale: DynamicScale
-
- class SimpleTrainer:
-     state: SimpleTrainState
-     best_state: SimpleTrainState
-     best_loss: float
-     model: nn.Module
-     ema_decay: float = 0.999
-
-     def __init__(self,
-                  model: nn.Module,
-                  input_shapes: Dict[str, Tuple[int]],
-                  optimizer: optax.GradientTransformation,
-                  rngs: jax.random.PRNGKey,
-                  train_state: SimpleTrainState = None,
-                  name: str = "Simple",
-                  load_from_checkpoint: str = None,
-                  checkpoint_suffix: str = "",
-                  loss_fn=optax.l2_loss,
-                  param_transforms: Callable = None,
-                  wandb_config: Dict[str, Any] = None,
-                  distributed_training: bool = None,
-                  checkpoint_base_path: str = "./checkpoints",
-                  checkpoint_step: int = None,
-                  use_dynamic_scale: bool = False,
-                  ):
-         if distributed_training is None or distributed_training is True:
-             # Auto-detect if we are running on multiple devices
-             distributed_training = jax.device_count() > 1
-             self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
-         else:
-             self.mesh = None
-
-         self.distributed_training = distributed_training
-         self.model = model
-         self.name = name
-         self.loss_fn = loss_fn
-         self.input_shapes = input_shapes
-         self.checkpoint_base_path = checkpoint_base_path
-
-
-         if wandb_config is not None and jax.process_index() == 0:
-             run = wandb.init(**wandb_config)
-             self.wandb = run
-
-             # define our custom x axis metric
-             self.wandb.define_metric("train/step")
-             self.wandb.define_metric("train/epoch")
-
-             self.wandb.define_metric("train/loss", step_metric="train/step")
-
-             self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
-             self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
-             self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
-             self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
-
-         # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-         async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
-
-         options = orbax.checkpoint.CheckpointManagerOptions(
-             max_to_keep=4, create=True)
-         self.checkpointer = orbax.checkpoint.CheckpointManager(
-             self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
-
-         if load_from_checkpoint is not None:
-             latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
-         else:
-             latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
-
-         self.latest_step = latest_step
-
-         if rngstate:
-             self.rngstate = RandomMarkovState(**rngstate)
-         else:
-             self.rngstate = RandomMarkovState(rngs)
-
-         self.rngstate, subkey = self.rngstate.get_random_key()
-
-         if train_state == None:
-             state, best_state = self.generate_states(
-                 optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
-             )
-             self.init_state(state, best_state)
-         else:
-             self.state = train_state
-             self.best_state = train_state
-             self.best_loss = 1e9
-
-     def get_input_ones(self):
-         return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
-
-     def generate_states(
-         self,
-         optimizer: optax.GradientTransformation,
-         rngs: jax.random.PRNGKey,
-         existing_state: dict = None,
-         existing_best_state: dict = None,
-         model: nn.Module = None,
-         param_transforms: Callable = None,
-         use_dynamic_scale: bool = False
-     ) -> Tuple[SimpleTrainState, SimpleTrainState]:
-         print("Generating states for SimpleTrainer")
-         rngs, subkey = jax.random.split(rngs)
-
-         if existing_state == None:
-             input_vars = self.get_input_ones()
-             params = model.init(subkey, **input_vars)
-         else:
-             params = existing_state['params']
-
-         if param_transforms is not None:
-             params = param_transforms(params)
-
-         state = SimpleTrainState.create(
-             apply_fn=model.apply,
-             params=params,
-             tx=optimizer,
-             metrics=Metrics.empty(),
-             dynamic_scale = DynamicScale() if use_dynamic_scale else None
-         )
-         if existing_best_state is not None:
-             best_state = state.replace(
-                 params=existing_best_state['params'])
-         else:
-             best_state = state
-
-         return state, best_state
-
-     def init_state(
-         self,
-         state: SimpleTrainState,
-         best_state: SimpleTrainState,
-     ):
-         self.best_loss = 1e9
-
-         self.state = state
-         self.best_state = best_state
-
-     def get_state(self):
-         # return fully_replicated_host_local_array_to_global_array()
-         return jax.tree_util.tree_map(lambda x : np.array(x), self.state)
-
-     def get_best_state(self):
-         # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices()))
-         return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
-
-     def get_rngstate(self):
-         # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices()))
-         return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
-
-     def checkpoint_path(self):
-         path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
-         if not os.path.exists(path):
-             os.makedirs(path)
-         return path
-
-     def tensorboard_path(self):
-         experiment_name = self.name
-         path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
-         if not os.path.exists(path):
-             os.makedirs(path)
-         return path
-
-     def load(self, checkpoint_path=None, checkpoint_step=None):
-         if checkpoint_path is None:
-             checkpointer = self.checkpointer
-         else:
-             checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-             options = orbax.checkpoint.CheckpointManagerOptions(
-                 max_to_keep=4, create=False)
-             checkpointer = orbax.checkpoint.CheckpointManager(
-                 checkpoint_path, checkpointer, options)
-
-         if checkpoint_step is None:
-             step = checkpointer.latest_step()
-         else:
-             step = checkpoint_step
-
-         print("Loading model from checkpoint at step ", step)
-         ckpt = checkpointer.restore(step)
-         state = ckpt['state']
-         best_state = ckpt['best_state']
-         rngstate = ckpt['rngs']
-         # Convert the state to a TrainState
-         self.best_loss = ckpt['best_loss']
-         current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
-         print(
-             f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
-         return current_epoch, step, state, best_state, rngstate
-
-     def save(self, epoch=0, step=0):
-         print(f"Saving model at epoch {epoch} step {step}")
-         ckpt = {
-             # 'model': self.model,
-             'rngs': self.get_rngstate(),
-             'state': self.get_state(),
-             'best_state': self.get_best_state(),
-             'best_loss': np.array(self.best_loss),
-             'epoch': epoch,
-         }
-         try:
-             save_args = orbax_utils.save_args_from_target(ckpt)
-             self.checkpointer.save(step, ckpt, save_kwargs={
-                 'save_args': save_args}, force=True)
-             self.checkpointer.wait_until_finished()
-             pass
-         except Exception as e:
-             print("Error saving checkpoint", e)
-
-     def _define_train_step(self, **kwargs):
-         model = self.model
-         loss_fn = self.loss_fn
-         distributed_training = self.distributed_training
-
-         def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
-             """Train for a single step."""
-             images = batch['image']
-             labels = batch['label']
-
-             def model_loss(params):
-                 preds = model.apply(params, images)
-                 expected_output = labels
-                 nloss = loss_fn(preds, expected_output)
-                 loss = jnp.mean(nloss)
-                 return loss
-             loss, grads = jax.value_and_grad(model_loss)(train_state.params)
-             if distributed_training:
-                 grads = jax.lax.pmean(grads, "data")
-             train_state = train_state.apply_gradients(grads=grads)
-             return train_state, loss, rng_state
-
-         if distributed_training:
-             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
-             train_step = jax.pmap(train_step)
-         return train_step
-
-     def _define_compute_metrics(self):
-         model = self.model
-         loss_fn = self.loss_fn
-
-         @jax.jit
-         def compute_metrics(state: SimpleTrainState, batch):
-             preds = model.apply(state.params, batch['image'])
-             expected_output = batch['label']
-             loss = jnp.mean(loss_fn(preds, expected_output))
-             metric_updates = state.metrics.single_from_model_output(
-                 loss=loss, logits=preds, labels=expected_output)
-             metrics = state.metrics.merge(metric_updates)
-             state = state.replace(metrics=metrics)
-             return state
-         return compute_metrics
-
-     def summary(self):
-         input_vars = self.get_input_ones()
-         print(self.model.tabulate(jax.random.key(0), **input_vars,
-                                   console_kwargs={"width": 200, "force_jupyter": True, }))
-
-     def config(self):
-         return {
-             "model": self.model,
-             "state": self.state,
-             "name": self.name,
-             "input_shapes": self.input_shapes
-         }
-
-     def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
-         from flax.metrics import tensorboard
-         summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
-         summary_writer.hparams({
-             **self.config(),
-             "steps_per_epoch": steps_per_epoch,
-             "epochs": epochs,
-             "batch_size": batch_size
-         })
-         return summary_writer
-
-     def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
-         train_ds = iter(data['train']())
-         if 'test' in data:
-             test_ds = data['test']
-         else:
-             test_ds = None
-         train_step = self._define_train_step(**train_step_args)
-         compute_metrics = self._define_compute_metrics()
-         train_state = self.state
-         rng_state = self.rngstate
-         global_device_count = jax.device_count()
-         local_device_count = jax.local_device_count()
-         process_index = jax.process_index()
-         if self.distributed_training:
-             global_device_indexes = jnp.arange(global_device_count)
-         else:
-             global_device_indexes = 0
-
-         def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state):
-             epoch_loss = 0
-             current_epoch = current_step // steps_per_epoch
-             last_save_time = time.time()
-             for i in range(steps_per_epoch):
-                 batch = next(train_ds)
-                 if self.distributed_training and global_device_count > 1:
-                     # Convert the local device batches to a unified global jax.Array
-                     batch = convert_to_global_tree(self.mesh, batch)
-                 train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
-
-                 if self.distributed_training:
-                     loss = jax.experimental.multihost_utils.process_allgather(loss)
-                     loss = jnp.mean(loss) # Just to make sure its a scaler value
-
-                 if loss <= 1e-6:
-                     # If the loss is too low, we can assume the model has diverged
-                     print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
-                     # Exit the training loop
-                     exit(1)
-
-                 epoch_loss += loss
-                 current_step += 1
-                 if i % 100 == 0:
-                     if pbar is not None:
-                         pbar.set_postfix(loss=f'{loss:.4f}')
-                         pbar.update(100)
-                     if self.wandb is not None:
-                         self.wandb.log({
-                             "train/step" : current_step,
-                             "train/loss": loss,
-                         }, step=current_step)
-                     # Save the model every 40 minutes
-                     if time.time() - last_save_time > 40 * 60:
-                         print(f"Saving model after 40 minutes at step {current_step}")
-                         self.save(current_epoch, current_step)
-                         last_save_time = time.time()
-             print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
-             return epoch_loss, current_step, train_state, rng_state
-
-         while self.latest_step < epochs * steps_per_epoch:
-             current_epoch = self.latest_step // steps_per_epoch
-             print(f"\nEpoch {current_epoch}/{epochs}")
-             start_time = time.time()
-             epoch_loss = 0
-
-             if process_index == 0:
-                 with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
-                     epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, pbar, train_state, rng_state)
-             else:
-                 epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, None, train_state, rng_state)
-                 print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
-
-             self.latest_step = current_step
-             end_time = time.time()
-             self.state = train_state
-             self.rngstate = rng_state
-             total_time = end_time - start_time
-             avg_time_per_step = total_time / steps_per_epoch
-             avg_loss = epoch_loss / steps_per_epoch
-             if avg_loss < self.best_loss:
-                 self.best_loss = avg_loss
-                 self.best_state = train_state
-                 self.save(current_epoch, current_step)
-
-             if process_index == 0:
-                 if self.wandb is not None:
-                     self.wandb.log({
-                         "train/epoch_time": total_time,
-                         "train/avg_time_per_step": avg_time_per_step,
-                         "train/avg_loss": avg_loss,
-                         "train/best_loss": self.best_loss,
-                         "train/epoch": current_epoch,
-                     }, step=current_step)
-                 print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
-         print("Training done")
-         self.save(epochs)
-         return self.state
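The hunk above removes the entire SimpleTrainer class from the wheel. The sketch below is not taken from the package; it is a hedged reconstruction of how the class could be driven, based only on the signatures visible in the hunk (__init__, get_input_ones, summary, fit). It assumes flaxdiff 0.1.35.6, the last version whose wheel ships flaxdiff/trainer/simple_trainer.py; TinyClassifier, train_batches, the input shape, and the Adam learning rate are illustrative placeholders.

# Hedged usage sketch for the removed SimpleTrainer; assumes flaxdiff==0.1.35.6.
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn

from flaxdiff.trainer.simple_trainer import SimpleTrainer


class TinyClassifier(nn.Module):
    """Placeholder model; __call__'s keyword arguments must match the input_shapes keys."""

    @nn.compact
    def __call__(self, x):
        return nn.Dense(10)(jnp.mean(x, axis=(1, 2)))


def train_batches():
    """Placeholder loader; the trainer's train_step reads batch['image'] and batch['label']."""
    while True:
        yield {"image": jnp.zeros((8, 32, 32, 3)), "label": jnp.zeros((8, 10))}


trainer = SimpleTrainer(
    model=TinyClassifier(),
    input_shapes={"x": (32, 32, 3)},   # get_input_ones() prepends a batch dim and feeds these to model.init
    optimizer=optax.adam(1e-3),
    rngs=jax.random.PRNGKey(0),
    name="tiny-classifier",
)
trainer.summary()  # tabulates the model using ones of the declared input shapes

# fit() expects data['train'] to be a zero-argument callable returning a batch iterator.
# Note that the class as shown only sets self.wandb when wandb_config is passed, and the
# logging path in fit() assumes it exists, so a wandb_config would be needed in practice.
# trainer.fit({"train": train_batches}, steps_per_epoch=100, epochs=1)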
flaxdiff-0.1.35.6.dist-info/RECORD
@@ -1,40 +0,0 @@
- flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
- flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
- flaxdiff/data/online_loader.py,sha256=DoHrMZCi5gMd9tmkCpZIUU9lGxvfYtuaz58943_lCRc,11315
- flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
- flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
- flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
- flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
- flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
- flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
- flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
- flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
- flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
- flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
- flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
- flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
- flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
- flaxdiff/samplers/ddim.py,sha256=XHMBX06S5hMTnKMaGh6fmq189pQcaGkA6fnX6YPbHP0,511
- flaxdiff/samplers/ddpm.py,sha256=d_58hfVShJHsRPQf5h1-4YKrD43-HGjWz9Vd8hltZBg,2627
- flaxdiff/samplers/euler.py,sha256=Epf7LBKUky7B8b-1ZyIlLWdRMgjmP08BQraGSKmr_3I,2726
- flaxdiff/samplers/heun_sampler.py,sha256=hhWnSM26OfOIFAcsuWYa1z-2QPjASuoYTop2byLWqzE,1388
- flaxdiff/samplers/multistep_dpm.py,sha256=ocmEq2sCvsULy6oTFaD5BhTU4c8VHsge4bdg6tfxW80,2724
- flaxdiff/samplers/rk4_sampler.py,sha256=BF-dMV1KauO-SYShqrCfm3U3V-1n4clqQXBeoG8RWQo,1728
- flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
- flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
- flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
- flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
- flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
- flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
- flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
- flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
- flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
- flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
- flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
- flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
- flaxdiff/trainer/simple_trainer.py,sha256=cawm6fZNQoLLATMneAU2gQ9j7kefqHnBPHuaIj3i_a4,18237
- flaxdiff-0.1.35.6.dist-info/METADATA,sha256=NVCk5V7Zc3iq-nrWTivzO17dQa1fIjYgjJb800ZrZhQ,22085
- flaxdiff-0.1.35.6.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
- flaxdiff-0.1.35.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
- flaxdiff-0.1.35.6.dist-info/RECORD,,