flaxdiff 0.1.37.3__py3-none-any.whl → 0.1.37.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/predictors/__init__.py +9 -9
- flaxdiff/samplers/common.py +2 -1
- flaxdiff/schedulers/common.py +17 -13
- flaxdiff/schedulers/cosine.py +4 -4
- flaxdiff/schedulers/discrete.py +5 -7
- flaxdiff/schedulers/linear.py +1 -2
- flaxdiff/schedulers/sqrt.py +2 -1
- flaxdiff/trainer/diffusion_trainer.py +48 -14
- flaxdiff/trainer/simple_trainer.py +7 -3
- {flaxdiff-0.1.37.3.dist-info → flaxdiff-0.1.37.6.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.37.3.dist-info → flaxdiff-0.1.37.6.dist-info}/RECORD +13 -13
- {flaxdiff-0.1.37.3.dist-info → flaxdiff-0.1.37.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.37.3.dist-info → flaxdiff-0.1.37.6.dist-info}/top_level.txt +0 -0
flaxdiff/predictors/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from typing import Union
 import jax.numpy as jnp
-from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
+from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler, get_coeff_shapes_tuple
 
 ############################################################################################################
 # Prediction Transforms
@@ -11,7 +11,7 @@ class DiffusionPredictionTransform():
         return preds
 
     def __call__(self, x_t, preds, current_step, noise_schedule:NoiseScheduler) -> Union[jnp.ndarray, jnp.ndarray]:
-        rates = noise_schedule.get_rates(current_step)
+        rates = noise_schedule.get_rates(current_step, shape=get_coeff_shapes_tuple(x_t))
         preds = self.pred_transform(x_t, preds, rates)
         x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
         return x_0, epsilon
@@ -81,16 +81,16 @@ class KarrasPredictionTransform(DiffusionPredictionTransform):
         epsilon = (x_t - x_0 * signal_rate) / noise_rate
         return x_0, epsilon
 
-    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
-        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
-        c_out = c_out.reshape((
-        c_skip = c_skip.reshape((
+        c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
+        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
+        c_out = c_out.reshape(get_coeff_shapes_tuple(preds))
+        c_skip = c_skip.reshape(get_coeff_shapes_tuple(x_t))
         x_0 = c_out * preds + c_skip * x_t
         return x_0
 
-    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
         _, sigma = rates
-        c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
+        c_in = 1 / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
         return c_in
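The Karras transform above now adds a small epsilon to each denominator and reshapes c_out/c_skip with get_coeff_shapes_tuple so they broadcast against the prediction batch. A minimal standalone sketch of that coefficient math, with an illustrative sigma_data value (0.5 is an assumption, not taken from this diff):

import jax.numpy as jnp

sigma_data = 0.5                        # illustrative value, not from the diff
sigma = jnp.array([0.0, 1.0, 10.0])     # per-sample noise levels
eps = 1e-8

denom = jnp.sqrt(sigma_data ** 2 + sigma ** 2) + eps
c_out = sigma * sigma_data / denom                                 # scales the network output
c_skip = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2 + eps)    # skip weight on x_t
c_in = 1.0 / denom                                                 # input scaling

# At sigma = 0 the epsilon keeps the divisions finite even if sigma_data were 0.
print(c_out, c_skip, c_in)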
flaxdiff/samplers/common.py
CHANGED
@@ -67,7 +67,7 @@ class DiffusionSampler():
     # Used to sample from the diffusion model
     def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         # First clip the noisy images
-        step_ones = jnp.ones((current_samples
+        step_ones = jnp.ones((len(current_samples), ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
         pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
@@ -133,6 +133,7 @@ class DiffusionSampler():
 
         params = params if params is not None else self.params
 
+        @jax.jit
         def sample_model_fn(x_t, t, *additional_inputs):
             return self.sample_model(params, x_t, t, *additional_inputs)
 
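The sampler change above builds a per-sample step vector with jnp.ones((len(current_samples),), dtype=jnp.int32) and JIT-compiles the closure that calls the model. A small self-contained sketch of the same jit-over-a-closure pattern (the model function here is a stand-in, not the flaxdiff sampler):

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((4, 4))}  # stand-in parameters closed over by the function below

def apply_model(params, x_t, t):
    # stand-in for self.sample_model(params, x_t, t, ...)
    return x_t @ params["w"] + t[:, None, None]

@jax.jit
def sample_model_fn(x_t, t):
    # params is captured from the enclosing scope, so jit traces it as a constant
    return apply_model(params, x_t, t)

x = jnp.zeros((2, 3, 4))
step_ones = jnp.ones((len(x),), dtype=jnp.int32)   # per-sample timestep vector
print(sample_model_fn(x, step_ones * 5).shape)     # (2, 3, 4)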
flaxdiff/schedulers/common.py
CHANGED
@@ -3,6 +3,16 @@ import jax.numpy as jnp
 from typing import Union
 from ..utils import RandomMarkovState
 
+def get_coeff_shapes_tuple(array):
+    shape_tuple = (-1,) + (1,) * (array.ndim - 1)
+    return shape_tuple
+
+def reshape_rates(rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
+    signal_rates, noise_rates = rates
+    signal_rates = jnp.reshape(signal_rates, shape)
+    noise_rates = jnp.reshape(noise_rates, shape)
+    return signal_rates, noise_rates
+
 class NoiseScheduler():
     def __init__(self, timesteps,
                  dtype=jnp.float32,
@@ -24,24 +34,18 @@ class NoiseScheduler():
         timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
         return timesteps, state
 
-    def get_weights(self, steps):
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)):
         raise NotImplementedError
 
-    def reshape_rates(self, rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-        signal_rates, noise_rates = rates
-        signal_rates = jnp.reshape(signal_rates, shape)
-        noise_rates = jnp.reshape(noise_rates, shape)
-        return signal_rates, noise_rates
-
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         raise NotImplementedError
 
     def add_noise(self, images, noise, steps) -> jnp.ndarray:
-        signal_rates, noise_rates = self.get_rates(steps)
+        signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(images))
         return signal_rates * images + noise_rates * noise
 
     def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
-        signal_rates, noise_rates = self.get_rates(steps)
+        signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(noisy_images))
         x_0 = (noisy_images - noise * noise_rates) / signal_rates
         return x_0
 
@@ -54,8 +58,8 @@ class NoiseScheduler():
     def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
         raise NotImplementedError
 
-    def get_max_variance(self):
-        alpha_n, sigma_n = self.get_rates(self.max_timesteps)
+    def get_max_variance(self, shape=(-1, 1, 1, 1)):
+        alpha_n, sigma_n = self.get_rates(self.max_timesteps, shape=shape)
         variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
         return variance
 
@@ -82,9 +86,9 @@ class GeneralizedNoiseScheduler(NoiseScheduler):
 
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         sigmas = self.get_sigmas(steps)
-        signal_rates =
+        signal_rates = jnp.ones_like(sigmas)
         noise_rates = sigmas
-        return
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
     def transform_inputs(self, x, steps, num_discrete_chunks=1000):
         sigmas_discrete = (steps / self.max_timesteps) * num_discrete_chunks
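get_coeff_shapes_tuple and reshape_rates, now module-level helpers, produce a (-1, 1, ..., 1) shape so per-sample rates broadcast against a batch of any rank. A short sketch of the expected behaviour (the helper body is copied from the diff; the arrays are illustrative):

import jax.numpy as jnp

def get_coeff_shapes_tuple(array):
    # one leading batch axis, singleton dims everywhere else
    return (-1,) + (1,) * (array.ndim - 1)

images = jnp.zeros((8, 32, 32, 3))      # NHWC image batch
videos = jnp.zeros((8, 16, 32, 32, 3))  # a 5-D batch works the same way
print(get_coeff_shapes_tuple(images))   # (-1, 1, 1, 1)
print(get_coeff_shapes_tuple(videos))   # (-1, 1, 1, 1, 1)

# add_noise-style broadcasting: per-sample rates applied over all trailing axes
signal_rates = jnp.linspace(1.0, 0.1, 8)
noise_rates = jnp.sqrt(1.0 - signal_rates ** 2)
shape = get_coeff_shapes_tuple(images)
noisy = jnp.reshape(signal_rates, shape) * images + jnp.reshape(noise_rates, shape) * jnp.ones_like(images)
print(noisy.shape)  # (8, 32, 32, 3)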
flaxdiff/schedulers/cosine.py
CHANGED
@@ -3,7 +3,7 @@ import numpy as np
 import jax.numpy as jnp
 from .discrete import DiscreteNoiseScheduler
 from .continuous import ContinuousNoiseScheduler
-from .common import GeneralizedNoiseScheduler
+from .common import GeneralizedNoiseScheduler, reshape_rates
 
 def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
     ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
@@ -32,9 +32,9 @@ class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
         noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
-        return
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
-    def get_weights(self, steps):
-        alpha, sigma = self.get_rates(steps, shape=
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
+        alpha, sigma = self.get_rates(steps, shape=shape)
         return 1 / (1 + (alpha ** 2 / sigma ** 2))
 
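For the cosine continuous schedule, the rates are alpha = cos(pi*t / (2*T)) and sigma = sin(pi*t / (2*T)), so alpha^2 + sigma^2 = 1 and the weight 1 / (1 + alpha^2 / sigma^2) reduces to sigma^2. A quick numeric check (max_timesteps = 1.0 is an assumed normalisation for the continuous case):

import jax.numpy as jnp

max_timesteps = 1.0                     # assumed: continuous steps normalised to [0, 1]
steps = jnp.array([0.1, 0.5, 0.9])

alpha = jnp.cos((jnp.pi * steps) / (2 * max_timesteps))  # signal rates
sigma = jnp.sin((jnp.pi * steps) / (2 * max_timesteps))  # noise rates

weights = 1 / (1 + (alpha ** 2 / sigma ** 2))
print(weights)       # matches sigma ** 2 up to float error
print(sigma ** 2)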
flaxdiff/schedulers/discrete.py
CHANGED
@@ -2,7 +2,7 @@ import jax
 import jax.numpy as jnp
 from typing import Union
 from ..utils import RandomMarkovState
-from .common import NoiseScheduler
+from .common import NoiseScheduler, reshape_rates, get_coeff_shapes_tuple
 
 class DiscreteNoiseScheduler(NoiseScheduler):
     """
@@ -53,17 +53,15 @@ class DiscreteNoiseScheduler(NoiseScheduler):
 
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         steps = jnp.int16(steps)
-
-
-
-        noise_rate = jnp.reshape(noise_rate, shape)
-        return signal_rate, noise_rate
+        signal_rates = self.sqrt_alpha_cumprod[steps]
+        noise_rates = self.sqrt_one_minus_alpha_cumprod[steps]
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
 
     def get_posterior_mean(self, x_0, x_t, steps):
         steps = jnp.int16(steps)
         x_0_coeff = self.posterior_mean_coef1[steps]
         x_t_coeff = self.posterior_mean_coef2[steps]
-        x_0_coeff, x_t_coeff =
+        x_0_coeff, x_t_coeff = reshape_rates((x_0_coeff, x_t_coeff), shape=get_coeff_shapes_tuple(x_0))
         mean = x_0_coeff * x_0 + x_t_coeff * x_t
         return mean
 
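The discrete scheduler now looks up precomputed sqrt(alpha_cumprod) tables and reshapes both rates in one reshape_rates call. A self-contained sketch of that lookup, using an illustrative linear beta schedule (with timesteps = 1000 the scale factor in linear.py is 1):

import numpy as np
import jax.numpy as jnp

timesteps = 1000
betas = np.linspace(0.0001, 0.02, timesteps, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_alpha_cumprod = jnp.asarray(np.sqrt(alphas_cumprod))
sqrt_one_minus_alpha_cumprod = jnp.asarray(np.sqrt(1.0 - alphas_cumprod))

def get_rates(steps, shape=(-1, 1, 1, 1)):
    steps = jnp.int16(steps)
    signal_rates = sqrt_alpha_cumprod[steps]            # table lookup per sample
    noise_rates = sqrt_one_minus_alpha_cumprod[steps]
    return jnp.reshape(signal_rates, shape), jnp.reshape(noise_rates, shape)

sr, nr = get_rates(jnp.array([0, 500, 999]))
print(sr.shape, nr.shape)               # (3, 1, 1, 1) (3, 1, 1, 1)
print(sr[0, 0, 0, 0], nr[0, 0, 0, 0])   # near 1 and near 0 at step 0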
flaxdiff/schedulers/linear.py
CHANGED
@@ -5,8 +5,7 @@ def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
     scale = 1000 / timesteps
     beta_start = scale * beta_start
     beta_end = scale * beta_end
-    betas = np.linspace(
-        beta_start, beta_end, timesteps, dtype=np.float64)
+    betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
     return betas
 
 class LinearNoiseSchedule(DiscreteNoiseScheduler):
flaxdiff/schedulers/sqrt.py
CHANGED
@@ -2,9 +2,10 @@ import numpy as np
 import jax.numpy as jnp
 from .discrete import DiscreteNoiseScheduler
 from .continuous import ContinuousNoiseScheduler
+from .common import reshape_rates
 
 class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
         signal_rates = jnp.sqrt(1 - steps)
         noise_rates = jnp.sqrt(steps)
-        return
+        return reshape_rates((signal_rates, noise_rates), shape=shape)
flaxdiff/trainer/diffusion_trainer.py
CHANGED
@@ -11,7 +11,7 @@ from jax.sharding import Mesh, PartitionSpec as P
 from jax.experimental.shard_map import shard_map
 from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
 
-from ..schedulers import NoiseScheduler
+from ..schedulers import NoiseScheduler, get_coeff_shapes_tuple
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 from ..samplers.common import DiffusionSampler
 from ..samplers.ddim import DDIMSampler
@@ -144,6 +144,8 @@ class DiffusionTrainer(SimpleTrainer):
 
             images = batch['image']
 
+            local_batch_size = images.shape[0]
+
             # First get the standard deviation of the images
             # std = jnp.std(images, axis=(1, 2, 3))
             # is_non_zero = (std > 0)
@@ -164,22 +166,23 @@ class DiffusionTrainer(SimpleTrainer):
             label_seq = jnp.concat(
                 [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
 
-            noise_level, local_rng_state = noise_schedule.generate_timesteps(
+            noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)
 
             local_rng_state, rngs = local_rng_state.get_random_key()
-            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+            noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
+
+            # Make sure image is also float32
+            images = images.astype(jnp.float32)
 
-            rates = noise_schedule.get_rates(noise_level)
-            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-                images, noise, rates)
+            rates = noise_schedule.get_rates(noise_level, get_coeff_shapes_tuple(images))
+            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
 
             def model_loss(params):
                 preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-                preds = model_output_transform.pred_transform(
-                    noisy_images, preds, rates)
+                preds = model_output_transform.pred_transform(noisy_images, preds, rates)
                 nloss = loss_fn(preds, expected_output)
                 # Ignore the loss contribution of images with zero standard deviation
-                nloss *= noise_schedule.get_weights(noise_level)
+                nloss *= noise_schedule.get_weights(noise_level, get_coeff_shapes_tuple(nloss))
                 nloss = jnp.mean(nloss)
                 loss = nloss
                 return loss
@@ -197,7 +200,22 @@ class DiffusionTrainer(SimpleTrainer):
             loss, grads = grad_fn(train_state.params)
             if distributed_training:
                 grads = jax.lax.pmean(grads, "data")
+
+            # # check gradients for NaN/Inf
+            # has_nan_or_inf = jax.tree_util.tree_reduce(
+            #     lambda acc, x: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
+            #     grads,
+            #     initializer=False
+            # )
 
+            # # Only apply gradients if they're valid
+            # new_state = jax.lax.cond(
+            #     has_nan_or_inf,
+            #     lambda _: train_state,  # Skip gradient update
+            #     lambda _: train_state.apply_gradients(grads=grads),
+            #     operand=None
+            # )
+
             new_state = train_state.apply_gradients(grads=grads)
 
             if train_state.dynamic_scale is not None:
@@ -220,9 +238,16 @@ class DiffusionTrainer(SimpleTrainer):
             return train_state, loss, rng_state
 
         if distributed_training:
-            train_step = shard_map(
-
-
+            train_step = shard_map(
+                train_step,
+                mesh=self.mesh,
+                in_specs=(P(), P(), P('data'), P('data')),
+                out_specs=(P(), P(), P()),
+            )
+            train_step = jax.jit(
+                train_step,
+                donate_argnums=(2)
+            )
 
         return train_step
 
@@ -235,12 +260,21 @@ class DiffusionTrainer(SimpleTrainer):
         null_labels_full = null_labels_full.astype(jnp.float16)
         # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
 
+        if 'image' in self.input_shapes:
+            image_size = self.input_shapes['image'][1]
+        elif 'x' in self.input_shapes:
+            image_size = self.input_shapes['x'][1]
+        elif 'sample' in self.input_shapes:
+            image_size = self.input_shapes['sample'][1]
+        else:
+            raise ValueError("No image input shape found in input shapes")
+
         sampler = sampler_class(
             model=model,
            params=None,
            noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
            model_output_transform=self.model_output_transform,
-            image_size=
+            image_size=image_size,
            null_labels_seq=null_labels_full,
            autoencoder=autoencoder,
            guidance_scale=3.0,
@@ -291,7 +325,7 @@ class DiffusionTrainer(SimpleTrainer):
         )
 
         # Put each sample on wandb
-        if self.wandb:
+        if getattr(self, 'wandb', None) is not None and self.wandb:
             import numpy as np
             from wandb import Image as wandbImage
             wandb_images = []
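The distributed path above wraps train_step with shard_map over a "data" mesh axis (batch inputs sharded, scalars replicated) and then jit-compiles it with buffer donation. A self-contained sketch of that wrapping pattern with toy shapes and a stand-in step function (not the flaxdiff trainer itself):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# 1-D "data" mesh over whatever devices are available (works with a single CPU device)
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def train_step(params, step, batch, labels):
    # per-shard loss, then averaged across the data axis, like jax.lax.pmean(grads, "data")
    loss = jnp.mean((batch * params - labels) ** 2)
    loss = jax.lax.pmean(loss, "data")
    new_params = params - 0.1 * loss   # toy update, replicated on every shard
    return new_params, loss, step + 1

train_step = shard_map(
    train_step,
    mesh=mesh,
    in_specs=(P(), P(), P("data"), P("data")),  # only batch/labels are sharded
    out_specs=(P(), P(), P()),                  # all outputs replicated
)
train_step = jax.jit(train_step, donate_argnums=(2,))  # donate the batch buffer

batch = jnp.ones((len(jax.devices()) * 4, 8))
labels = jnp.zeros_like(batch)
params, loss, step = train_step(jnp.float32(0.5), jnp.int32(0), batch, labels)
print(params, loss, step)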
flaxdiff/trainer/simple_trainer.py
CHANGED
@@ -403,7 +403,6 @@ class SimpleTrainer:
         rng_state
     ):
         global_device_count = jax.device_count()
-        local_device_count = jax.local_device_count()
         process_index = jax.process_index()
         if self.distributed_training:
             global_device_indexes = jnp.arange(global_device_count)
@@ -434,11 +433,16 @@ class SimpleTrainer:
                 # loss = jax.experimental.multihost_utils.process_allgather(loss)
                 loss = jnp.mean(loss) # Just to make sure its a scaler value
 
-                if loss <= 1e-
+                if loss <= 1e-8:
                     # If the loss is too low, we can assume the model has diverged
                     print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
                     # Reset the model to the old state
-
+                    if self.best_state is not None:
+                        print(colored(f"Resetting model to best state", 'red'))
+                        train_state = self.best_state
+                        loss = self.best_loss
+                    else:
+                        exit(1)
 
                 epoch_loss += loss
                 current_step += 1
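The guard above treats a near-zero loss (<= 1e-8) as a sign the model has diverged and, when a best checkpoint exists, rolls the train state back to it instead of continuing. A minimal sketch of that control flow (best_state and best_loss stand in for the trainer's attributes):

def guard_divergence(train_state, loss, step, best_state=None, best_loss=None, threshold=1e-8):
    # A loss this small usually means the model output has collapsed
    if loss <= threshold:
        print(f"Loss too low at step {step} => {loss}")
        if best_state is not None:
            print("Resetting model to best state")
            return best_state, best_loss
        raise SystemExit(1)   # no known-good state to fall back to
    return train_state, loss

state, loss = guard_divergence({"params": "current"}, 0.0, step=1200,
                               best_state={"params": "best"}, best_loss=0.137)
print(loss)  # 0.137: rolled back to the best recorded loss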
{flaxdiff-0.1.37.3.dist-info → flaxdiff-0.1.37.6.dist-info}/RECORD
CHANGED
@@ -20,9 +20,9 @@ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
 flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
-flaxdiff/predictors/__init__.py,sha256=
+flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
 flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
-flaxdiff/samplers/common.py,sha256=
+flaxdiff/samplers/common.py,sha256=wn8tryC3B0KE0V98zMiH_X2x-Tc1NbM5iV27hn5p8Aw,8846
 flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
 flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
 flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
@@ -30,20 +30,20 @@ flaxdiff/samplers/heun_sampler.py,sha256=EvR3hy4t_D47ZOH4luzRFqPmv2v4z78P_JhqBGE
 flaxdiff/samplers/multistep_dpm.py,sha256=2M4Abb93-GUVN1f0_ZHBeA6lF0eF15Hi6QOgOu2K45s,2752
 flaxdiff/samplers/rk4_sampler.py,sha256=vcQefFhOUZdNOQGBdzNkb2NgmTC2KWd_nhUhyLtt3yI,2026
 flaxdiff/schedulers/__init__.py,sha256=EIva9gBz3DKHORuGmv1LQCKTtRqCRavFOXMNqxAR_ks,131
-flaxdiff/schedulers/common.py,sha256=
+flaxdiff/schedulers/common.py,sha256=PDeje2NmN7X3J5qKGauE0jYPpxjgEX44f_evJHRIG3E,4382
 flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
-flaxdiff/schedulers/cosine.py,sha256=
-flaxdiff/schedulers/discrete.py,sha256=
+flaxdiff/schedulers/cosine.py,sha256=E5pODAmINfdyC4kSYOJSPAvq3GNlKPpKEn3X82vYMz0,2055
+flaxdiff/schedulers/discrete.py,sha256=m1q3bAgeAxU3gTj5di3XFWDm4yLfMKAFJPlYdozLE2Y,3316
 flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
 flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
-flaxdiff/schedulers/linear.py,sha256=
-flaxdiff/schedulers/sqrt.py,sha256=
+flaxdiff/schedulers/linear.py,sha256=pBDTXSQcOS4Z03JTh6S0f9E2qLcTQzF2E-pGoQnRoy0,572
+flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,474
 flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
 flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
-flaxdiff/trainer/diffusion_trainer.py,sha256=
-flaxdiff/trainer/simple_trainer.py,sha256=
+flaxdiff/trainer/diffusion_trainer.py,sha256=kEulMnk6ZkKhQRSVr3UtDdCmXR4cWphJ3XNuk7VIAUY,14189
+flaxdiff/trainer/simple_trainer.py,sha256=LScHQZCy5ksSC7n0GC0tjOXK-zptxpMJsC6Udf-nz18,22178
 flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
-flaxdiff-0.1.37.
-flaxdiff-0.1.37.
-flaxdiff-0.1.37.
-flaxdiff-0.1.37.
+flaxdiff-0.1.37.6.dist-info/METADATA,sha256=SujaCKk29ECrfSEIdchYvAl-nf0L270t2of7oeX5kgk,23985
+flaxdiff-0.1.37.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+flaxdiff-0.1.37.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.37.6.dist-info/RECORD,,
{flaxdiff-0.1.37.3.dist-info → flaxdiff-0.1.37.6.dist-info}/WHEEL
File without changes
{flaxdiff-0.1.37.3.dist-info → flaxdiff-0.1.37.6.dist-info}/top_level.txt
File without changes