flaxdiff 0.1.37.3.tar.gz → 0.1.37.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/predictors/__init__.py +5 -5
  3. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/trainer/diffusion_trainer.py +20 -2
  4. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/trainer/simple_trainer.py +7 -3
  5. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/PKG-INFO +1 -1
  6. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/pyproject.toml +1 -1
  7. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/README.md +0 -0
  8. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/__init__.py +0 -0
  9. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/data/__init__.py +0 -0
  10. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/data/dataset_map.py +0 -0
  11. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/data/datasets.py +0 -0
  12. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/data/online_loader.py +0 -0
  13. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/data/sources/gcs.py +0 -0
  14. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/data/sources/tfds.py +0 -0
  15. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/metrics/inception.py +0 -0
  16. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/metrics/psnr.py +0 -0
  17. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/metrics/ssim.py +0 -0
  18. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/metrics/utils.py +0 -0
  19. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/__init__.py +0 -0
  20. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/attention.py +0 -0
  21. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/autoencoder/__init__.py +0 -0
  22. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  23. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  24. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  25. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/common.py +0 -0
  26. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/favor_fastattn.py +0 -0
  27. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/simple_unet.py +0 -0
  28. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/models/simple_vit.py +0 -0
  29. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/samplers/__init__.py +0 -0
  30. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/samplers/common.py +0 -0
  31. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/samplers/ddim.py +0 -0
  32. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/samplers/ddpm.py +0 -0
  33. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/samplers/euler.py +0 -0
  34. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/samplers/heun_sampler.py +0 -0
  35. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/samplers/multistep_dpm.py +0 -0
  36. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/samplers/rk4_sampler.py +0 -0
  37. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/__init__.py +0 -0
  38. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/common.py +0 -0
  39. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/continuous.py +0 -0
  40. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/cosine.py +0 -0
  41. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/discrete.py +0 -0
  42. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/exp.py +0 -0
  43. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/karras.py +0 -0
  44. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/linear.py +0 -0
  45. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/sqrt.py +0 -0
  46. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/trainer/__init__.py +0 -0
  47. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  48. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  49. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff/utils.py +0 -0
  50. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/SOURCES.txt +0 -0
  51. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/dependency_links.txt +0 -0
  52. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/requires.txt +0 -0
  53. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/top_level.txt +0 -0
  54. {flaxdiff-0.1.37.3 → flaxdiff-0.1.37.4}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.37.3
+Version: 0.1.37.4
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
flaxdiff/predictors/__init__.py
@@ -81,16 +81,16 @@ class KarrasPredictionTransform(DiffusionPredictionTransform):
         epsilon = (x_t - x_0 * signal_rate) / noise_rate
         return x_0, epsilon
 
-    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
-        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
+        c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
+        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
         c_out = c_out.reshape((-1, 1, 1, 1))
         c_skip = c_skip.reshape((-1, 1, 1, 1))
         x_0 = c_out * preds + c_skip * x_t
         return x_0
 
-    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
+        c_in = 1 / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
         return c_in
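
The change above adds a small epsilon to the denominators of the Karras preconditioning coefficients so they stay finite when sigma_data and sigma are both near zero. As a rough illustration, not part of the package diff, here is a minimal standalone sketch of the stabilized formulas; the helper name karras_coefficients and the sigma_data = 0.5 default are assumptions for the example only:

import jax.numpy as jnp

def karras_coefficients(sigma, sigma_data=0.5, epsilon=1e-8):
    # Offsetting the denominators by epsilon mirrors the change to
    # pred_transform / get_input_scale above and avoids division by zero.
    denom = sigma_data ** 2 + sigma ** 2
    c_skip = sigma_data ** 2 / (denom + epsilon)
    c_out = sigma * sigma_data / (jnp.sqrt(denom) + epsilon)
    c_in = 1 / (jnp.sqrt(denom) + epsilon)
    return c_skip, c_out, c_in

# With sigma_data = 0 the unstabilized formulas would divide by zero at sigma = 0;
# the epsilon keeps every coefficient finite.
print(karras_coefficients(jnp.array([0.0, 1.0, 80.0]), sigma_data=0.0))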
flaxdiff/trainer/diffusion_trainer.py
@@ -167,7 +167,10 @@ class DiffusionTrainer(SimpleTrainer):
             noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
 
             local_rng_state, rngs = local_rng_state.get_random_key()
-            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+            noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
+
+            # Make sure image is also float32
+            images = images.astype(jnp.float32)
 
             rates = noise_schedule.get_rates(noise_level)
             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
@@ -197,8 +200,23 @@ class DiffusionTrainer(SimpleTrainer):
             loss, grads = grad_fn(train_state.params)
             if distributed_training:
                 grads = jax.lax.pmean(grads, "data")
+
+            # # check gradients for NaN/Inf
+            # has_nan_or_inf = jax.tree_util.tree_reduce(
+            #     lambda acc, x: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
+            #     grads,
+            #     initializer=False
+            # )
 
-            new_state = train_state.apply_gradients(grads=grads)
+            # # Only apply gradients if they're valid
+            # new_state = jax.lax.cond(
+            #     has_nan_or_inf,
+            #     lambda _: train_state,  # Skip gradient update
+            #     lambda _: train_state.apply_gradients(grads=grads),
+            #     operand=None
+            # )
+
+            # new_state = train_state.apply_gradients(grads=grads)
 
             if train_state.dynamic_scale is not None:
                 # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
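
The block that 0.1.37.4 adds above is a commented-out guard that would skip an optimizer update when any gradient is NaN or Inf. For reference only, not part of the package diff, here is a hedged standalone version of that pattern; the name apply_gradients_if_finite is illustrative, and `state` is assumed to be any Flax TrainState-like object with an apply_gradients method:

import jax
import jax.numpy as jnp

def apply_gradients_if_finite(state, grads):
    # Reduce over every leaf of the gradient pytree: True if any entry is NaN or Inf.
    has_nan_or_inf = jax.tree_util.tree_reduce(
        lambda acc, g: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(g).any(), jnp.isinf(g).any())),
        grads,
        initializer=jnp.bool_(False),
    )
    # jax.lax.cond keeps the check jit-compatible: both branches return a state
    # with the same structure, so the update is simply skipped when invalid.
    return jax.lax.cond(
        has_nan_or_inf,
        lambda _: state,                              # skip the gradient update
        lambda _: state.apply_gradients(grads=grads), # apply it as usual
        operand=None,
    )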
flaxdiff/trainer/simple_trainer.py
@@ -403,7 +403,6 @@ class SimpleTrainer:
             rng_state
         ):
             global_device_count = jax.device_count()
-            local_device_count = jax.local_device_count()
             process_index = jax.process_index()
             if self.distributed_training:
                 global_device_indexes = jnp.arange(global_device_count)
@@ -434,11 +433,16 @@ class SimpleTrainer:
                 # loss = jax.experimental.multihost_utils.process_allgather(loss)
                 loss = jnp.mean(loss) # Just to make sure its a scaler value
 
-                if loss <= 1e-6:
+                if loss <= 1e-8:
                     # If the loss is too low, we can assume the model has diverged
                     print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
                     # Reset the model to the old state
-                    exit(1)
+                    if self.best_state is not None:
+                        print(colored(f"Resetting model to best state", 'red'))
+                        train_state = self.best_state
+                        loss = self.best_loss
+                    else:
+                        exit(1)
 
                 epoch_loss += loss
                 current_step += 1
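
The new recovery path relies on the trainer remembering its best checkpoint, but how best_state and best_loss get populated is not shown in this diff. A plausible sketch of that bookkeeping, written as a free-standing helper and not taken from flaxdiff itself:

def track_best(trainer, train_state, loss):
    # Remember the parameters that produced the lowest loss seen so far, so the
    # training loop can fall back to them if the loss later collapses or diverges.
    if trainer.best_state is None or loss < trainer.best_loss:
        trainer.best_state = train_state
        trainer.best_loss = loss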
flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.37.3
+Version: 0.1.37.4
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "flaxdiff"
-version = "0.1.37.3"
+version = "0.1.37.4"
 description = "A versatile and easy to understand Diffusion library"
 readme = "README.md"
 authors = [