flaxdiff 0.1.37.2__tar.gz → 0.1.37.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/PKG-INFO +1 -1
- flaxdiff-0.1.37.4/flaxdiff/metrics/psnr.py +0 -0
- flaxdiff-0.1.37.4/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/predictors/__init__.py +5 -5
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/common.py +7 -9
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/trainer/diffusion_trainer.py +20 -2
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/trainer/simple_trainer.py +7 -3
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/SOURCES.txt +2 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/pyproject.toml +1 -1
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/README.md +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/data/datasets.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/data/sources/gcs.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/data/sources/tfds.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/setup.cfg +0 -0
{flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/predictors/__init__.py

@@ -81,16 +81,16 @@ class KarrasPredictionTransform(DiffusionPredictionTransform):
         epsilon = (x_t - x_0 * signal_rate) / noise_rate
         return x_0, epsilon
 
-    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
-        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
+        c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
+        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
         c_out = c_out.reshape((-1, 1, 1, 1))
         c_skip = c_skip.reshape((-1, 1, 1, 1))
         x_0 = c_out * preds + c_skip * x_t
         return x_0
 
-    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
+        c_in = 1 / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
         return c_in
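The denominators `sqrt(sigma_data² + sigma²)` are strictly positive whenever `sigma_data > 0`, so the added `epsilon=1e-8` reads as a defensive guard against degenerate configurations (e.g. `sigma_data == 0`) or low-precision underflow rather than a behavioral change at normal noise levels. A minimal standalone sketch, not flaxdiff's module — `sigma_data = 0.5` is the usual Karras default and is assumed here:

```python
import jax.numpy as jnp

sigma_data = 0.5   # Karras et al. default; an assumption for this sketch
eps = 1e-8

# A batch of noise levels; sigma == 0.0 is the edge the guard defends.
sigma = jnp.array([0.0, 0.002, 1.0, 80.0])
denom = jnp.sqrt(sigma_data ** 2 + sigma ** 2)

# Guarded Karras preconditioning coefficients, mirroring the 0.1.37.4 change.
c_in = 1 / (denom + eps)
c_out = sigma * sigma_data / (denom + eps)
c_skip = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2 + eps)

print(c_in, c_out, c_skip)  # all finite for every sigma, including 0.0
```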
{flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/samplers/common.py

@@ -2,20 +2,15 @@ from flax import linen as nn
 import jax
 import jax.numpy as jnp
 import tqdm
-from typing import Union
+from typing import Union, Type
 from ..schedulers import NoiseScheduler
 from ..utils import RandomMarkovState, MarkovState, clip_images
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 
 class DiffusionSampler():
-    model:nn.Module
-    noise_schedule:NoiseScheduler
-    params:dict
-    model_output_transform:DiffusionPredictionTransform
-
     def __init__(self, model:nn.Module, params:dict,
                  noise_schedule:NoiseScheduler,
-                 model_output_transform:DiffusionPredictionTransform
+                 model_output_transform:DiffusionPredictionTransform,
                  guidance_scale:float = 0.0,
                  null_labels_seq:jax.Array=None,
                  autoencoder=None,
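Dropping the class-level attribute annotations is a no-op at runtime: on a plain Python class (no dataclass machinery involved) they only populate `__annotations__`, while the real attributes are set in `__init__`. A quick illustration, not flaxdiff code:

```python
class A:
    x: int              # annotation only: A.x does not exist

    def __init__(self, x):
        self.x = x      # the actual instance attribute

assert not hasattr(A, "x")               # the bare annotation created no attribute
assert A.__annotations__ == {"x": int}   # it only landed in __annotations__
assert A(3).x == 3
```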
@@ -122,9 +117,11 @@ class DiffusionSampler():
                  end_step:int = 0,
                  steps_override=None,
                  priors=None,
-                 rngstate:RandomMarkovState=
+                 rngstate:RandomMarkovState=None,
                  model_conditioning_inputs:tuple=()
                  ) -> jnp.ndarray:
+        if rngstate is None:
+            rngstate = RandomMarkovState(jax.random.PRNGKey(42))
         if priors is None:
             rngstate, newrngs = rngstate.get_random_key()
             samples = self.get_initial_samples(num_images, newrngs, start_step)
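Moving the `rngstate` default out of the signature matters because Python evaluates default arguments once, at definition time: a `RandomMarkovState(...)` default would be a single object shared by every call (and would build a JAX key at import time), whereas the `None`-plus-in-body pattern creates a fresh state per call. A hypothetical minimal illustration, with the class simplified from flaxdiff's `RandomMarkovState`:

```python
import jax

class RandomState:
    """Simplified stand-in for flaxdiff's RandomMarkovState."""
    def __init__(self, key):
        self.key = key

    def get_random_key(self):
        key, subkey = jax.random.split(self.key)
        return RandomState(key), subkey

# Anti-pattern: the default is built once and shared by every call.
def sample_shared(rngstate=RandomState(jax.random.PRNGKey(42))):
    return rngstate.get_random_key()

# Pattern used in 0.1.37.4: a fresh default state per call.
def sample_fresh(rngstate=None):
    if rngstate is None:
        rngstate = RandomState(jax.random.PRNGKey(42))
    return rngstate.get_random_key()
```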
@@ -169,4 +166,5 @@ class DiffusionSampler():
         if self.autoencoder is not None:
             samples = self.autoencoder.decode(samples)
         samples = clip_images(samples)
-        return samples
+        return samples
+
{flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/trainer/diffusion_trainer.py

@@ -167,7 +167,10 @@ class DiffusionTrainer(SimpleTrainer):
             noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
 
             local_rng_state, rngs = local_rng_state.get_random_key()
-            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+            noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
+
+            # Make sure image is also float32
+            images = images.astype(jnp.float32)
 
             rates = noise_schedule.get_rates(noise_level)
             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
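Note that `jax.random.normal` already defaults to float32, so the explicit `dtype` is belt-and-braces; the substantive change is the `images` cast, which promotes low-precision pipeline output before the forward-diffusion arithmetic and the loss. A hedged sketch of the pattern outside the trainer, with placeholder shapes and rates:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
images = jnp.ones((4, 32, 32, 3), dtype=jnp.bfloat16)  # e.g. a low-precision pipeline

noise = jax.random.normal(key, shape=images.shape, dtype=jnp.float32)
images = images.astype(jnp.float32)  # mirror the trainer's explicit cast

signal_rate, noise_rate = 0.9, 0.1   # placeholder scalar rates for illustration
noisy_images = signal_rate * images + noise_rate * noise
assert noisy_images.dtype == jnp.float32
```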
@@ -197,8 +200,23 @@ class DiffusionTrainer(SimpleTrainer):
             loss, grads = grad_fn(train_state.params)
             if distributed_training:
                 grads = jax.lax.pmean(grads, "data")
+
+            # # check gradients for NaN/Inf
+            # has_nan_or_inf = jax.tree_util.tree_reduce(
+            #     lambda acc, x: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
+            #     grads,
+            #     initializer=False
+            # )
 
-
+            # # Only apply gradients if they're valid
+            # new_state = jax.lax.cond(
+            #     has_nan_or_inf,
+            #     lambda _: train_state,  # Skip gradient update
+            #     lambda _: train_state.apply_gradients(grads=grads),
+            #     operand=None
+            # )
+
+            # new_state = train_state.apply_gradients(grads=grads)
 
             if train_state.dynamic_scale is not None:
                 # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
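The NaN/Inf gradient guard ships commented out in 0.1.37.4; the `dynamic_scale` loss-scaling path below remains the active safeguard. For reference only, an enabled version of the same idea could look like this sketch, where `train_state` is assumed to be a Flax `TrainState`, `grads` a matching pytree, and `None` is passed positionally to `lax.cond`:

```python
import jax
import jax.numpy as jnp

def apply_if_finite(train_state, grads):
    # True if any gradient leaf contains NaN or Inf.
    has_nan_or_inf = jax.tree_util.tree_reduce(
        lambda acc, x: jnp.logical_or(
            acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
        grads,
        initializer=jnp.asarray(False),
    )
    # Both branches must be traceable under jit, hence lax.cond over if/else.
    return jax.lax.cond(
        has_nan_or_inf,
        lambda _: train_state,                               # skip the update
        lambda _: train_state.apply_gradients(grads=grads),  # normal step
        None,
    )
```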
{flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff/trainer/simple_trainer.py

@@ -403,7 +403,6 @@ class SimpleTrainer:
             rng_state
         ):
             global_device_count = jax.device_count()
-            local_device_count = jax.local_device_count()
             process_index = jax.process_index()
             if self.distributed_training:
                 global_device_indexes = jnp.arange(global_device_count)
@@ -434,11 +433,16 @@ class SimpleTrainer:
                 # loss = jax.experimental.multihost_utils.process_allgather(loss)
                 loss = jnp.mean(loss) # Just to make sure its a scaler value
 
-                if loss <= 1e-
+                if loss <= 1e-8:
                     # If the loss is too low, we can assume the model has diverged
                     print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
                     # Reset the model to the old state
-
+                    if self.best_state is not None:
+                        print(colored(f"Resetting model to best state", 'red'))
+                        train_state = self.best_state
+                        loss = self.best_loss
+                    else:
+                        exit(1)
 
                 epoch_loss += loss
                 current_step += 1
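The divergence guard now rolls back to a remembered good checkpoint instead of only logging; this presumes `SimpleTrainer` tracks `best_state` and `best_loss` elsewhere as training progresses, which the diff does not show. A standalone sketch of that assumed tracking side, with names chosen to match:

```python
def track_best(self, train_state, loss):
    """Remember the lowest-loss state so the divergence guard can roll back."""
    if self.best_state is None or loss < self.best_loss:
        self.best_state = train_state
        self.best_loss = loss
```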
{flaxdiff-0.1.37.2 → flaxdiff-0.1.37.4}/flaxdiff.egg-info/SOURCES.txt

@@ -14,6 +14,8 @@ flaxdiff/data/online_loader.py
 flaxdiff/data/sources/gcs.py
 flaxdiff/data/sources/tfds.py
 flaxdiff/metrics/inception.py
+flaxdiff/metrics/psnr.py
+flaxdiff/metrics/ssim.py
 flaxdiff/metrics/utils.py
 flaxdiff/models/__init__.py
 flaxdiff/models/attention.py
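The two new metrics modules registered here, `flaxdiff/metrics/psnr.py` and `flaxdiff/metrics/ssim.py`, ship empty in 0.1.37.4 (+0 -0 in the file list above). For context only, a conventional PSNR in JAX looks like this sketch; it is not the package's implementation:

```python
import jax.numpy as jnp

def psnr(pred: jnp.ndarray, target: jnp.ndarray, max_val: float = 1.0) -> jnp.ndarray:
    """Peak signal-to-noise ratio in dB: 20*log10(max_val) - 10*log10(MSE)."""
    mse = jnp.mean((pred - target) ** 2)
    return 20.0 * jnp.log10(max_val) - 10.0 * jnp.log10(mse)
```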