flaxdiff-0.1.37.3-py3-none-any.whl → flaxdiff-0.1.37.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -81,16 +81,16 @@ class KarrasPredictionTransform(DiffusionPredictionTransform):
         epsilon = (x_t - x_0 * signal_rate) / noise_rate
         return x_0, epsilon
 
-    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
-        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
+        c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
+        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
         c_out = c_out.reshape((-1, 1, 1, 1))
         c_skip = c_skip.reshape((-1, 1, 1, 1))
         x_0 = c_out * preds + c_skip * x_t
         return x_0
 
-    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
+        c_in = 1 / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
         return c_in
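
The hunk above adds an epsilon guard to the Karras preconditioning coefficients so the divisions stay finite even in the degenerate case where both sigma_data and sigma approach zero. A minimal standalone sketch of the stabilized formulas, with illustrative values rather than the flaxdiff API:

```python
import jax.numpy as jnp

# Illustrative values; flaxdiff computes these inside
# KarrasPredictionTransform using self.sigma_data.
sigma_data = 0.5
sigma = jnp.array([0.0, 1e-4, 1.0, 80.0])
epsilon = 1e-8

denom = jnp.sqrt(sigma_data ** 2 + sigma ** 2) + epsilon
c_in = 1 / denom                     # input scaling
c_out = sigma * sigma_data / denom   # output scaling
c_skip = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2 + epsilon)

# With sigma_data > 0 the denominators are already positive; the epsilon
# only changes the result when both terms vanish simultaneously.
print(c_in, c_out, c_skip)
```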
@@ -167,7 +167,10 @@ class DiffusionTrainer(SimpleTrainer):
         noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
 
         local_rng_state, rngs = local_rng_state.get_random_key()
-        noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+        noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
+
+        # Make sure the images are also float32
+        images = images.astype(jnp.float32)
 
         rates = noise_schedule.get_rates(noise_level)
         noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
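
This hunk pins both the sampled noise and the input images to float32 before the forward diffusion step. A hedged sketch of the effect, assuming a loader that yields half-precision batches (the batch shape and dtype below are illustrative):

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
images = jnp.zeros((4, 32, 32, 3), dtype=jnp.bfloat16)  # e.g. a mixed-precision batch

# jax.random.normal already defaults to float32, but making the dtype
# explicit and upcasting the images guarantees the diffusion arithmetic
# and the loss run in full precision regardless of the loader's dtype.
noise = jax.random.normal(rng, shape=images.shape, dtype=jnp.float32)
images = images.astype(jnp.float32)

noisy = images + noise  # both operands are now float32
assert noisy.dtype == jnp.float32
```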
@@ -197,8 +200,23 @@ class DiffusionTrainer(SimpleTrainer):
         loss, grads = grad_fn(train_state.params)
         if distributed_training:
             grads = jax.lax.pmean(grads, "data")
+
+        # # check gradients for NaN/Inf
+        # has_nan_or_inf = jax.tree_util.tree_reduce(
+        #     lambda acc, x: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
+        #     grads,
+        #     initializer=False
+        # )
 
-        new_state = train_state.apply_gradients(grads=grads)
+        # # Only apply gradients if they're valid
+        # new_state = jax.lax.cond(
+        #     has_nan_or_inf,
+        #     lambda _: train_state,  # Skip gradient update
+        #     lambda _: train_state.apply_gradients(grads=grads),
+        #     operand=None
+        # )
+
+        # new_state = train_state.apply_gradients(grads=grads)
 
         if train_state.dynamic_scale is not None:
             # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
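
The guard introduced above ships commented out in this release; if enabled, it would detect non-finite gradients and skip the optimizer step for that batch. A runnable sketch of the same idea, written against a generic flax TrainState rather than flaxdiff's trainer internals:

```python
import jax
import jax.numpy as jnp

def safe_apply_gradients(train_state, grads):
    """Apply gradients only when every gradient leaf is finite."""
    has_nan_or_inf = jax.tree_util.tree_reduce(
        lambda acc, g: jnp.logical_or(
            acc, jnp.logical_or(jnp.isnan(g).any(), jnp.isinf(g).any())),
        grads,
        initializer=jnp.asarray(False),
    )
    # Under jit, both branches of lax.cond must return the same pytree
    # structure, which is why the skip branch returns train_state as-is.
    return jax.lax.cond(
        has_nan_or_inf,
        lambda _: train_state,                               # skip bad update
        lambda _: train_state.apply_gradients(grads=grads),  # normal step
        operand=None,
    )
```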
@@ -403,7 +403,6 @@ class SimpleTrainer:
         rng_state
     ):
         global_device_count = jax.device_count()
-        local_device_count = jax.local_device_count()
         process_index = jax.process_index()
         if self.distributed_training:
             global_device_indexes = jnp.arange(global_device_count)
@@ -434,11 +433,16 @@ class SimpleTrainer:
         # loss = jax.experimental.multihost_utils.process_allgather(loss)
         loss = jnp.mean(loss)  # Just to make sure it's a scalar value
 
-        if loss <= 1e-6:
+        if loss <= 1e-8:
             # If the loss is too low, we can assume the model has diverged
             print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
             # Reset the model to the old state
-            exit(1)
+            if self.best_state is not None:
+                print(colored(f"Resetting model to best state", 'red'))
+                train_state = self.best_state
+                loss = self.best_loss
+            else:
+                exit(1)
 
         epoch_loss += loss
         current_step += 1
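
This last trainer change replaces the hard `exit(1)` with a rollback: a mean loss at or below 1e-8 is treated as a divergence symptom (e.g. a collapsed model), and the trainer restores its tracked best checkpoint, aborting only when none exists. A standalone sketch of that recovery logic, where best_state and best_loss mirror the attributes SimpleTrainer is assumed to keep:

```python
def recover_from_divergence(train_state, loss, best_state, best_loss,
                            threshold=1e-8):
    """Roll back to the best-known state instead of aborting the run."""
    if loss <= threshold:  # suspiciously low loss => likely divergence
        if best_state is None:
            raise SystemExit(1)  # nothing to fall back to; stop the run
        return best_state, best_loss  # restore last good weights and loss
    return train_state, loss
```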
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.37.3
+Version: 0.1.37.4
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
@@ -20,7 +20,7 @@ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
 flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
-flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
+flaxdiff/predictors/__init__.py,sha256=S0R8_x-KST_cwaFKgBvaG4pwiMtrmgWjZCseyYfBPc4,4465
 flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
 flaxdiff/samplers/common.py,sha256=7gKNY4mWVnLjtcioGLFD_Vwmxg9zJovUb8EcYWlc_GE,8833
 flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
@@ -40,10 +40,10 @@ flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k
 flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
 flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
 flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
-flaxdiff/trainer/diffusion_trainer.py,sha256=KVeXJ9ZQKcvD-O_hCJnxro0dQRuQe5ZVGGMEL4Lgm9k,12814
-flaxdiff/trainer/simple_trainer.py,sha256=lmRo8N0bMupIyS3ejPvPtxoskY_3GLC8iyJE6u4TIWc,21990
+flaxdiff/trainer/diffusion_trainer.py,sha256=oIvDco8nOZT0HSz2ZX5b8u3Y1_UhMlOECqt0vBLPn1Q,13567
+flaxdiff/trainer/simple_trainer.py,sha256=LScHQZCy5ksSC7n0GC0tjOXK-zptxpMJsC6Udf-nz18,22178
 flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
-flaxdiff-0.1.37.3.dist-info/METADATA,sha256=7U7SINGO_ZzsUeuTPi2LTMqxrj93Cvglyh1Q7D39zRM,23985
-flaxdiff-0.1.37.3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-flaxdiff-0.1.37.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.37.3.dist-info/RECORD,,
+flaxdiff-0.1.37.4.dist-info/METADATA,sha256=rdptpmTYbuDZlIcENjmgdYlHc7NA8FUDM_NRlfhAeWU,23985
+flaxdiff-0.1.37.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+flaxdiff-0.1.37.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.37.4.dist-info/RECORD,,