flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,14 +5,60 @@ import jax
5
5
  from typing import Callable
6
6
  from dataclasses import field
7
7
  import jax.numpy as jnp
8
+ import numpy as np
9
+ from functools import partial
8
10
  from clu import metrics
9
11
  from flax.training import train_state # Useful dataclass to keep train state
10
12
  import optax
11
13
  from flax import struct # Flax dataclasses
14
+ import flax
12
15
  import time
13
16
  import os
14
17
  import orbax
15
18
  from flax.training import orbax_utils
19
+ from jax.sharding import Mesh, PartitionSpec as P
20
+ from jax.experimental import mesh_utils
21
+ from jax.experimental.shard_map import shard_map
22
+ from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
23
+ from termcolor import colored
24
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
25
+
26
+ from flaxdiff.utils import RandomMarkovState
27
+
28
+ PROCESS_COLOR_MAP = {
29
+ 0: "green",
30
+ 1: "yellow",
31
+ 2: "magenta",
32
+ 3: "cyan",
33
+ 4: "white",
34
+ 5: "light_blue",
35
+ 6: "light_red",
36
+ 7: "light_cyan"
37
+ }
38
+
39
+ def _build_global_shape_and_sharding(
40
+ local_shape: tuple[int, ...], global_mesh: Mesh
41
+ ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
42
+ sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
43
+ global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
44
+ return global_shape, sharding
45
+
46
+
47
+ def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
48
+ """Put local sharded array into local devices"""
49
+ global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
50
+ try:
51
+ local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
52
+ except ValueError as array_split_error:
53
+ raise ValueError(
54
+ f"Unable to put to devices shape {array.shape} with "
55
+ f"local device count {len(global_mesh.local_devices)} "
56
+ ) from array_split_error
57
+ local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
58
+ return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
59
+
60
+ def convert_to_global_tree(global_mesh, pytree):
61
+ return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
16
62
 
17
63
  @struct.dataclass
18
64
  class Metrics(metrics.Collection):
@@ -44,41 +90,75 @@ class SimpleTrainer:
44
90
  name: str = "Simple",
45
91
  load_from_checkpoint: bool = False,
46
92
  checkpoint_suffix: str = "",
93
+ checkpoint_id: str = None,
47
94
  loss_fn=optax.l2_loss,
48
95
  param_transforms: Callable = None,
49
96
  wandb_config: Dict[str, Any] = None,
50
97
  distributed_training: bool = None,
98
+ checkpoint_base_path: str = "./checkpoints",
51
99
  ):
52
100
  if distributed_training is None or distributed_training is True:
53
101
  # Auto-detect if we are running on multiple devices
54
102
  distributed_training = jax.device_count() > 1
103
+ self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
104
+ # self.sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('data'))
55
105
 
56
106
  self.distributed_training = distributed_training
57
107
  self.model = model
58
108
  self.name = name
59
109
  self.loss_fn = loss_fn
60
110
  self.input_shapes = input_shapes
61
-
62
- if wandb_config is not None:
111
+ self.checkpoint_base_path = checkpoint_base_path
112
+
113
+
114
+ if wandb_config is not None and jax.process_index() == 0:
115
+ import wandb
63
116
  run = wandb.init(**wandb_config)
64
117
  self.wandb = run
118
+
119
+ # define our custom x axis metric
120
+ self.wandb.define_metric("train/step")
121
+ self.wandb.define_metric("train/epoch")
122
+
123
+ self.wandb.define_metric("train/loss", step_metric="train/step")
124
+
125
+ self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
126
+ self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
127
+ self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
128
+ self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
129
+
130
+ if checkpoint_id is None:
131
+ self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
132
+ else:
133
+ self.checkpoint_id = checkpoint_id
134
+
135
+ # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
136
+ async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
65
137
 
66
- checkpointer = orbax.checkpoint.PyTreeCheckpointer()
67
138
  options = orbax.checkpoint.CheckpointManagerOptions(
68
139
  max_to_keep=4, create=True)
69
140
  self.checkpointer = orbax.checkpoint.CheckpointManager(
70
- self.checkpoint_path() + checkpoint_suffix, checkpointer, options)
141
+ self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
71
142
 
72
143
  if load_from_checkpoint:
73
- latest_epoch, old_state, old_best_state = self.load()
144
+ latest_epoch, old_state, old_best_state, rngstate = self.load()
74
145
  else:
75
- latest_epoch, old_state, old_best_state = 0, None, None
146
+ latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None
76
147
 
77
148
  self.latest_epoch = latest_epoch
149
+
150
+ if rngstate:
151
+ self.rngstate = RandomMarkovState(**rngstate)
152
+ else:
153
+ self.rngstate = RandomMarkovState(rngs)
154
+
155
+ self.rngstate, subkey = self.rngstate.get_random_key()
78
156
 
79
157
  if train_state == None:
80
- self.init_state(optimizer, rngs, existing_state=old_state,
81
- existing_best_state=old_best_state, model=model, param_transforms=param_transforms)
158
+ state, best_state = self.generate_states(
159
+ optimizer, subkey, old_state, old_best_state, model, param_transforms
160
+ )
161
+ self.init_state(state, best_state)
82
162
  else:
83
163
  self.state = train_state
84
164
  self.best_state = train_state
@@ -87,7 +167,7 @@ class SimpleTrainer:
87
167
  def get_input_ones(self):
88
168
  return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
89
169
 
90
- def __init_fn(
170
+ def generate_states(
91
171
  self,
92
172
  optimizer: optax.GradientTransformation,
93
173
  rngs: jax.random.PRNGKey,
@@ -96,17 +176,19 @@ class SimpleTrainer:
96
176
  model: nn.Module = None,
97
177
  param_transforms: Callable = None
98
178
  ) -> Tuple[SimpleTrainState, SimpleTrainState]:
179
+ print("Generating states for SimpleTrainer")
99
180
  rngs, subkey = jax.random.split(rngs)
100
181
 
101
182
  if existing_state == None:
102
183
  input_vars = self.get_input_ones()
103
184
  params = model.init(subkey, **input_vars)
185
+ else:
186
+ params = existing_state['params']
104
187
 
105
188
  state = SimpleTrainState.create(
106
189
  apply_fn=model.apply,
107
190
  params=params,
108
191
  tx=optimizer,
109
- rngs=rngs,
110
192
  metrics=Metrics.empty()
111
193
  )
112
194
  if existing_best_state is not None:
@@ -119,40 +201,28 @@ class SimpleTrainer:
119
201
 
120
202
  def init_state(
121
203
  self,
122
- optimizer: optax.GradientTransformation,
123
- rngs: jax.random.PRNGKey,
124
- existing_state: dict = None,
125
- existing_best_state: dict = None,
126
- model: nn.Module = None,
127
- param_transforms: Callable = None
204
+ state: SimpleTrainState,
205
+ best_state: SimpleTrainState,
128
206
  ):
129
-
130
- state, best_state = self.__init_fn(
131
- optimizer, rngs, existing_state, existing_best_state, model, param_transforms
132
- )
133
207
  self.best_loss = 1e9
134
208
 
135
- if self.distributed_training:
136
- devices = jax.local_devices()
137
- if len(devices) > 1:
138
- print("Replicating state across devices ", devices)
139
- state = flax.jax_utils.replicate(state, devices)
140
- best_state = flax.jax_utils.replicate(best_state, devices)
141
- else:
142
- print("Not replicating any state, Only single device connected to the process")
143
-
144
209
  self.state = state
145
210
  self.best_state = best_state
146
211
 
147
212
  def get_state(self):
148
- return flax.jax_utils.unreplicate(self.state)
213
+ # return fully_replicated_host_local_array_to_global_array()
214
+ return jax.tree_util.tree_map(lambda x : np.array(x), self.state)
149
215
 
150
216
  def get_best_state(self):
151
- return flax.jax_utils.unreplicate(self.best_state)
217
+ # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices()))
218
+ return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
219
+
220
+ def get_rngstate(self):
221
+ # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices()))
222
+ return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
152
223
 
153
224
  def checkpoint_path(self):
154
- experiment_name = self.name
155
- path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
225
+ path = os.path.join(self.checkpoint_base_path, self.checkpoint_id)
156
226
  if not os.path.exists(path):
157
227
  os.makedirs(path)
158
228
  return path
@@ -170,24 +240,27 @@ class SimpleTrainer:
170
240
  ckpt = self.checkpointer.restore(epoch)
171
241
  state = ckpt['state']
172
242
  best_state = ckpt['best_state']
243
+ rngstate = ckpt['rngs']
173
244
  # Convert the state to a TrainState
174
245
  self.best_loss = ckpt['best_loss']
175
246
  print(
176
247
  f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
177
- return epoch, state, best_state
248
+ return epoch, state, best_state, rngstate
178
249
 
179
250
  def save(self, epoch=0):
180
251
  print(f"Saving model at epoch {epoch}")
181
252
  ckpt = {
182
253
  # 'model': self.model,
254
+ 'rngs': self.get_rngstate(),
183
255
  'state': self.get_state(),
184
256
  'best_state': self.get_best_state(),
185
- 'best_loss': self.best_loss
257
+ 'best_loss': np.array(self.best_loss),
186
258
  }
187
259
  try:
188
260
  save_args = orbax_utils.save_args_from_target(ckpt)
189
261
  self.checkpointer.save(epoch, ckpt, save_kwargs={
190
262
  'save_args': save_args}, force=True)
263
+ self.checkpointer.wait_until_finished()
191
264
  pass
192
265
  except Exception as e:
193
266
  print("Error saving checkpoint", e)
@@ -197,7 +270,7 @@ class SimpleTrainer:
197
270
  loss_fn = self.loss_fn
198
271
  distributed_training = self.distributed_training
199
272
 
200
- def train_step(state: SimpleTrainState, batch):
273
+ def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
201
274
  """Train for a single step."""
202
275
  images = batch['image']
203
276
  labels = batch['label']
@@ -208,17 +281,15 @@ class SimpleTrainer:
208
281
  nloss = loss_fn(preds, expected_output)
209
282
  loss = jnp.mean(nloss)
210
283
  return loss
211
- loss, grads = jax.value_and_grad(model_loss)(state.params)
284
+ loss, grads = jax.value_and_grad(model_loss)(train_state.params)
212
285
  if distributed_training:
213
- grads = jax.lax.pmean(grads, "device")
214
- state = state.apply_gradients(grads=grads)
215
- return state, loss
286
+ grads = jax.lax.pmean(grads, "data")
287
+ train_state = train_state.apply_gradients(grads=grads)
288
+ return train_state, loss, rng_state
216
289
 
217
290
  if distributed_training:
218
- train_step = jax.pmap(axis_name="device")(train_step)
219
- else:
220
- train_step = jax.jit(train_step)
221
-
291
+ train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
292
+ train_step = jax.pmap(train_step)
222
293
  return train_step
223
294
 
224
295
  def _define_compute_metrics(self):
@@ -251,6 +322,7 @@ class SimpleTrainer:
251
322
  }
252
323
 
253
324
  def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
325
+ from flax.metrics import tensorboard
254
326
  summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
255
327
  summary_writer.hparams({
256
328
  **self.config(),
@@ -268,56 +340,79 @@ class SimpleTrainer:
268
340
  test_ds = None
269
341
  train_step = self._define_train_step(**train_step_args)
270
342
  compute_metrics = self._define_compute_metrics()
271
- state = self.state
272
- device_count = jax.local_device_count()
273
- # train_ds = flax.jax_utils.prefetch_to_device(train_ds, jax.devices())
274
-
275
- summary_writer = self.init_tensorboard(
276
- data['global_batch_size'], steps_per_epoch, epochs)
343
+ train_state = self.state
344
+ rng_state = self.rngstate
345
+ global_device_count = jax.device_count()
346
+ local_device_count = jax.local_device_count()
347
+ process_index = jax.process_index()
348
+ if self.distributed_training:
349
+ global_device_indexes = jnp.arange(global_device_count)
350
+ else:
351
+ global_device_indexes = 0
277
352
 
278
- while self.latest_epoch <= epochs:
279
- self.latest_epoch += 1
280
- current_epoch = self.latest_epoch
281
- print(f"\nEpoch {current_epoch}/{epochs}")
282
- start_time = time.time()
353
+ def train_loop(current_epoch, pbar: tqdm.tqdm, train_state, rng_state):
283
354
  epoch_loss = 0
284
-
285
- with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
286
- for i in range(steps_per_epoch):
287
- batch = next(train_ds)
288
- if self.distributed_training and device_count > 1:
289
- batch = jax.tree.map(lambda x: x.reshape(
290
- (device_count, -1, *x.shape[1:])), batch)
355
+ current_step = 0
356
+ for i in range(steps_per_epoch):
357
+ batch = next(train_ds)
358
+ if self.distributed_training and global_device_count > 1:
359
+ # Convert the local device batches to a unified global jax.Array
360
+ batch = convert_to_global_tree(self.mesh, batch)
361
+ train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
362
+
363
+ if self.distributed_training:
364
+ loss = jax.experimental.multihost_utils.process_allgather(loss)
365
+ loss = jnp.mean(loss) # Just to make sure its a scaler value
366
+
367
+ epoch_loss += loss
291
368
 
292
- state, loss = train_step(state, batch)
293
- loss = jnp.mean(loss)
294
-
295
- epoch_loss += loss
369
+ if pbar is not None:
296
370
  if i % 100 == 0:
297
371
  pbar.set_postfix(loss=f'{loss:.4f}')
298
372
  pbar.update(100)
299
373
  current_step = current_epoch*steps_per_epoch + i
300
- summary_writer.scalar(
301
- 'Train Loss', loss, step=current_step)
302
374
  if self.wandb is not None:
303
- self.wandb.log({"train/loss": loss})
375
+ self.wandb.log({
376
+ "train/step" : current_step,
377
+ "train/loss": loss,
378
+ }, step=current_step)
379
+ print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
380
+ return epoch_loss, current_step, train_state, rng_state
381
+
382
+ while self.latest_epoch < epochs:
383
+ current_epoch = self.latest_epoch
384
+ self.latest_epoch += 1
385
+ print(f"\nEpoch {current_epoch}/{epochs}")
386
+ start_time = time.time()
387
+ epoch_loss = 0
304
388
 
305
- print(f"\n\tEpoch done")
389
+ if process_index == 0:
390
+ with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
391
+ epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, pbar, train_state, rng_state)
392
+ else:
393
+ epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, None, train_state, rng_state)
394
+ print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP.get(process_index, 'white')))
395
+
306
396
  end_time = time.time()
307
- self.state = state
397
+ self.state = train_state
398
+ self.rngstate = rng_state
308
399
  total_time = end_time - start_time
309
400
  avg_time_per_step = total_time / steps_per_epoch
310
401
  avg_loss = epoch_loss / steps_per_epoch
311
402
  if avg_loss < self.best_loss:
312
403
  self.best_loss = avg_loss
313
- self.best_state = state
404
+ self.best_state = train_state
314
405
  self.save(current_epoch)
315
-
316
- # Compute Metrics
317
- metrics_str = ''
318
-
319
- print(
320
- f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss} {metrics_str}")
321
-
406
+
407
+ if process_index == 0:
408
+ if self.wandb is not None:
409
+ self.wandb.log({
410
+ "train/epoch_time": total_time,
411
+ "train/avg_time_per_step": avg_time_per_step,
412
+ "train/avg_loss": avg_loss,
413
+ "train/best_loss": self.best_loss,
414
+ "train/epoch": current_epoch,
415
+ }, step=current_step)
416
+ print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
322
417
  self.save(epochs)
323
- return self.state
418
+ return self.state
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -13,6 +13,8 @@ Requires-Dist: clu
13
13
 
14
14
  # ![](images/logo.jpeg "FlaxDiff")
15
15
 
16
+ **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
17
+
16
18
  ## A Versatile and simple Diffusion Library
17
19
 
18
20
  In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
@@ -27,7 +29,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.
27
29
 
28
30
  In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.
29
31
 
30
- ### Available Notebooks
32
+ ### Available Notebooks and Resources
31
33
 
32
34
  - **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**
33
35
 
@@ -46,6 +48,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var
46
48
 
47
49
  These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.
48
50
 
51
+ #### Other resources
52
+
53
+ - **[Multi-host Data parallel training script in JAX](./training.py)**
54
+ - Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
55
+
56
+ - **[TPU utilities for making life easier](./tpu-tools/)**
57
+ - A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
58
+
49
59
  ## Disclaimer (and About Me)
50
60
 
51
61
  I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
@@ -1,11 +1,15 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
4
- flaxdiff/models/attention.py,sha256=SL9cvINjmabW1LPvXLAFZNHv-FF1Ez_d3J7n5uHBTyQ,15301
5
- flaxdiff/models/common.py,sha256=CjC4iRLjkF3oQ0f6rAqfiLaiHllZGtCOwN3rXDUndbE,274
4
+ flaxdiff/models/attention.py,sha256=OhpKQXdxWbf8K2_yotLfS0DYdHb-zNpL2p8--ql_FAg,14503
5
+ flaxdiff/models/common.py,sha256=RYNxX9K19hvwSWaB9Wtv7MIZLhcacdugDgD9uZDh8XM,10358
6
6
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
7
- flaxdiff/models/simple_unet.py,sha256=WlLry6v18syHBzcN8zAJ-zIVtq6ItMEIBWbeCcX0MLU,18693
7
+ flaxdiff/models/simple_unet.py,sha256=hAcz074E9NVdUtECPMi1c1Kw-52Dc6l_ME-5FqIg-n8,9255
8
8
  flaxdiff/models/simple_vit.py,sha256=vTu2CQRoSOxetBHTrnCWddm-vxrZDkMe8EpdNxtpJMk,4015
9
+ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
10
+ flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
11
+ flaxdiff/models/autoencoder/diffusers.py,sha256=kwlKwHBSAegtTiEkGju_1Trltegj-e47hXFN9jCKmgY,3609
12
+ flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
9
13
  flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
10
14
  flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
11
15
  flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
@@ -24,9 +28,11 @@ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,60
24
28
  flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
25
29
  flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
26
30
  flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
27
- flaxdiff/trainer/__init__.py,sha256=kwzkm-BD97hffFIXZUP1Hb3_D85fZ4SRNO7bviEwHU8,7591
28
- flaxdiff/trainer/simple_trainer.py,sha256=jafxr-yZ6FXn0Qi-iTSnlf275QWnIO4GnSvNAeB3H-Q,11651
29
- flaxdiff-0.1.4.dist-info/METADATA,sha256=G8OijdrrYWuKyAfCNtD_dKwdfBmdME56vpR-EYIZKXg,19229
30
- flaxdiff-0.1.4.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
31
- flaxdiff-0.1.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
32
- flaxdiff-0.1.4.dist-info/RECORD,,
31
+ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
32
+ flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
33
+ flaxdiff/trainer/diffusion_trainer.py,sha256=h5YxIMjBI553xDNeapzLDGF0_4y0MfGRMuHume5sPtM,7785
34
+ flaxdiff/trainer/simple_trainer.py,sha256=f4g2KGuGM__d9v_4Ip3ng8wQubmenWZUW60VEu2ANOg,16774
35
+ flaxdiff-0.1.6.dist-info/METADATA,sha256=sWY_oQgQhhuyW89KyRwIBrpVHBPJjRMmsk5twfgIBlo,20090
36
+ flaxdiff-0.1.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
37
+ flaxdiff-0.1.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
38
+ flaxdiff-0.1.6.dist-info/RECORD,,