flaxdiff 0.1.37.4__py3-none-any.whl → 0.1.37.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/predictors/__init__.py +4 -4
- flaxdiff/samplers/common.py +2 -1
- flaxdiff/schedulers/common.py +17 -13
- flaxdiff/schedulers/cosine.py +4 -4
- flaxdiff/schedulers/discrete.py +5 -7
- flaxdiff/schedulers/linear.py +1 -2
- flaxdiff/schedulers/sqrt.py +2 -1
- flaxdiff/trainer/diffusion_trainer.py +30 -14
- {flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/RECORD +12 -12
- {flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.37.4.dist-info → flaxdiff-0.1.37.6.dist-info}/top_level.txt +0 -0
flaxdiff/predictors/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import Union
|
2
2
|
import jax.numpy as jnp
|
3
|
-
from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
|
3
|
+
from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler, get_coeff_shapes_tuple
|
4
4
|
|
5
5
|
############################################################################################################
|
6
6
|
# Prediction Transforms
|
@@ -11,7 +11,7 @@ class DiffusionPredictionTransform():
|
|
11
11
|
return preds
|
12
12
|
|
13
13
|
def __call__(self, x_t, preds, current_step, noise_schedule:NoiseScheduler) -> Union[jnp.ndarray, jnp.ndarray]:
|
14
|
-
rates = noise_schedule.get_rates(current_step)
|
14
|
+
rates = noise_schedule.get_rates(current_step, shape=get_coeff_shapes_tuple(x_t))
|
15
15
|
preds = self.pred_transform(x_t, preds, rates)
|
16
16
|
x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
|
17
17
|
return x_0, epsilon
|
@@ -85,8 +85,8 @@ class KarrasPredictionTransform(DiffusionPredictionTransform):
|
|
85
85
|
_, sigma = rates
|
86
86
|
c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
|
87
87
|
c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
|
88
|
-
c_out = c_out.reshape((
|
89
|
-
c_skip = c_skip.reshape((
|
88
|
+
c_out = c_out.reshape(get_coeff_shapes_tuple(preds))
|
89
|
+
c_skip = c_skip.reshape(get_coeff_shapes_tuple(x_t))
|
90
90
|
x_0 = c_out * preds + c_skip * x_t
|
91
91
|
return x_0
|
92
92
|
|
flaxdiff/samplers/common.py
CHANGED
@@ -67,7 +67,7 @@ class DiffusionSampler():
|
|
67
67
|
# Used to sample from the diffusion model
|
68
68
|
def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
|
69
69
|
# First clip the noisy images
|
70
|
-
step_ones = jnp.ones((current_samples
|
70
|
+
step_ones = jnp.ones((len(current_samples), ), dtype=jnp.int32)
|
71
71
|
current_step = step_ones * current_step
|
72
72
|
next_step = step_ones * next_step
|
73
73
|
pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
|
@@ -133,6 +133,7 @@ class DiffusionSampler():
|
|
133
133
|
|
134
134
|
params = params if params is not None else self.params
|
135
135
|
|
136
|
+
@jax.jit
|
136
137
|
def sample_model_fn(x_t, t, *additional_inputs):
|
137
138
|
return self.sample_model(params, x_t, t, *additional_inputs)
|
138
139
|
|
flaxdiff/schedulers/common.py
CHANGED
@@ -3,6 +3,16 @@ import jax.numpy as jnp
|
|
3
3
|
from typing import Union
|
4
4
|
from ..utils import RandomMarkovState
|
5
5
|
|
6
|
+
def get_coeff_shapes_tuple(array):
|
7
|
+
shape_tuple = (-1,) + (1,) * (array.ndim - 1)
|
8
|
+
return shape_tuple
|
9
|
+
|
10
|
+
def reshape_rates(rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
|
11
|
+
signal_rates, noise_rates = rates
|
12
|
+
signal_rates = jnp.reshape(signal_rates, shape)
|
13
|
+
noise_rates = jnp.reshape(noise_rates, shape)
|
14
|
+
return signal_rates, noise_rates
|
15
|
+
|
6
16
|
class NoiseScheduler():
|
7
17
|
def __init__(self, timesteps,
|
8
18
|
dtype=jnp.float32,
|
@@ -24,24 +34,18 @@ class NoiseScheduler():
|
|
24
34
|
timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
|
25
35
|
return timesteps, state
|
26
36
|
|
27
|
-
def get_weights(self, steps):
|
37
|
+
def get_weights(self, steps, shape=(-1, 1, 1, 1)):
|
28
38
|
raise NotImplementedError
|
29
39
|
|
30
|
-
def reshape_rates(self, rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
|
31
|
-
signal_rates, noise_rates = rates
|
32
|
-
signal_rates = jnp.reshape(signal_rates, shape)
|
33
|
-
noise_rates = jnp.reshape(noise_rates, shape)
|
34
|
-
return signal_rates, noise_rates
|
35
|
-
|
36
40
|
def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
|
37
41
|
raise NotImplementedError
|
38
42
|
|
39
43
|
def add_noise(self, images, noise, steps) -> jnp.ndarray:
|
40
|
-
signal_rates, noise_rates = self.get_rates(steps)
|
44
|
+
signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(images))
|
41
45
|
return signal_rates * images + noise_rates * noise
|
42
46
|
|
43
47
|
def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
|
44
|
-
signal_rates, noise_rates = self.get_rates(steps)
|
48
|
+
signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(noisy_images))
|
45
49
|
x_0 = (noisy_images - noise * noise_rates) / signal_rates
|
46
50
|
return x_0
|
47
51
|
|
@@ -54,8 +58,8 @@ class NoiseScheduler():
|
|
54
58
|
def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
|
55
59
|
raise NotImplementedError
|
56
60
|
|
57
|
-
def get_max_variance(self):
|
58
|
-
alpha_n, sigma_n = self.get_rates(self.max_timesteps)
|
61
|
+
def get_max_variance(self, shape=(-1, 1, 1, 1)):
|
62
|
+
alpha_n, sigma_n = self.get_rates(self.max_timesteps, shape=shape)
|
59
63
|
variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
|
60
64
|
return variance
|
61
65
|
|
@@ -82,9 +86,9 @@ class GeneralizedNoiseScheduler(NoiseScheduler):
|
|
82
86
|
|
83
87
|
def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
|
84
88
|
sigmas = self.get_sigmas(steps)
|
85
|
-
signal_rates =
|
89
|
+
signal_rates = jnp.ones_like(sigmas)
|
86
90
|
noise_rates = sigmas
|
87
|
-
return
|
91
|
+
return reshape_rates((signal_rates, noise_rates), shape=shape)
|
88
92
|
|
89
93
|
def transform_inputs(self, x, steps, num_discrete_chunks=1000):
|
90
94
|
sigmas_discrete = (steps / self.max_timesteps) * num_discrete_chunks
|
flaxdiff/schedulers/cosine.py
CHANGED
@@ -3,7 +3,7 @@ import numpy as np
|
|
3
3
|
import jax.numpy as jnp
|
4
4
|
from .discrete import DiscreteNoiseScheduler
|
5
5
|
from .continuous import ContinuousNoiseScheduler
|
6
|
-
from .common import GeneralizedNoiseScheduler
|
6
|
+
from .common import GeneralizedNoiseScheduler, reshape_rates
|
7
7
|
|
8
8
|
def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
|
9
9
|
ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
|
@@ -32,9 +32,9 @@ class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
|
|
32
32
|
def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
|
33
33
|
signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
|
34
34
|
noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
|
35
|
-
return
|
35
|
+
return reshape_rates((signal_rates, noise_rates), shape=shape)
|
36
36
|
|
37
|
-
def get_weights(self, steps):
|
38
|
-
alpha, sigma = self.get_rates(steps, shape=
|
37
|
+
def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
|
38
|
+
alpha, sigma = self.get_rates(steps, shape=shape)
|
39
39
|
return 1 / (1 + (alpha ** 2 / sigma ** 2))
|
40
40
|
|
flaxdiff/schedulers/discrete.py
CHANGED
@@ -2,7 +2,7 @@ import jax
|
|
2
2
|
import jax.numpy as jnp
|
3
3
|
from typing import Union
|
4
4
|
from ..utils import RandomMarkovState
|
5
|
-
from .common import NoiseScheduler
|
5
|
+
from .common import NoiseScheduler, reshape_rates, get_coeff_shapes_tuple
|
6
6
|
|
7
7
|
class DiscreteNoiseScheduler(NoiseScheduler):
|
8
8
|
"""
|
@@ -53,17 +53,15 @@ class DiscreteNoiseScheduler(NoiseScheduler):
|
|
53
53
|
|
54
54
|
def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
|
55
55
|
steps = jnp.int16(steps)
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
noise_rate = jnp.reshape(noise_rate, shape)
|
60
|
-
return signal_rate, noise_rate
|
56
|
+
signal_rates = self.sqrt_alpha_cumprod[steps]
|
57
|
+
noise_rates = self.sqrt_one_minus_alpha_cumprod[steps]
|
58
|
+
return reshape_rates((signal_rates, noise_rates), shape=shape)
|
61
59
|
|
62
60
|
def get_posterior_mean(self, x_0, x_t, steps):
|
63
61
|
steps = jnp.int16(steps)
|
64
62
|
x_0_coeff = self.posterior_mean_coef1[steps]
|
65
63
|
x_t_coeff = self.posterior_mean_coef2[steps]
|
66
|
-
x_0_coeff, x_t_coeff =
|
64
|
+
x_0_coeff, x_t_coeff = reshape_rates((x_0_coeff, x_t_coeff), shape=get_coeff_shapes_tuple(x_0))
|
67
65
|
mean = x_0_coeff * x_0 + x_t_coeff * x_t
|
68
66
|
return mean
|
69
67
|
|
flaxdiff/schedulers/linear.py
CHANGED
@@ -5,8 +5,7 @@ def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
|
|
5
5
|
scale = 1000 / timesteps
|
6
6
|
beta_start = scale * beta_start
|
7
7
|
beta_end = scale * beta_end
|
8
|
-
betas = np.linspace(
|
9
|
-
beta_start, beta_end, timesteps, dtype=np.float64)
|
8
|
+
betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
|
10
9
|
return betas
|
11
10
|
|
12
11
|
class LinearNoiseSchedule(DiscreteNoiseScheduler):
|
flaxdiff/schedulers/sqrt.py
CHANGED
@@ -2,9 +2,10 @@ import numpy as np
|
|
2
2
|
import jax.numpy as jnp
|
3
3
|
from .discrete import DiscreteNoiseScheduler
|
4
4
|
from .continuous import ContinuousNoiseScheduler
|
5
|
+
from .common import reshape_rates
|
5
6
|
|
6
7
|
class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
|
7
8
|
def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
|
8
9
|
signal_rates = jnp.sqrt(1 - steps)
|
9
10
|
noise_rates = jnp.sqrt(steps)
|
10
|
-
return
|
11
|
+
return reshape_rates((signal_rates, noise_rates), shape=shape)
|
@@ -11,7 +11,7 @@ from jax.sharding import Mesh, PartitionSpec as P
|
|
11
11
|
from jax.experimental.shard_map import shard_map
|
12
12
|
from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
|
13
13
|
|
14
|
-
from ..schedulers import NoiseScheduler
|
14
|
+
from ..schedulers import NoiseScheduler, get_coeff_shapes_tuple
|
15
15
|
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
16
16
|
from ..samplers.common import DiffusionSampler
|
17
17
|
from ..samplers.ddim import DDIMSampler
|
@@ -144,6 +144,8 @@ class DiffusionTrainer(SimpleTrainer):
|
|
144
144
|
|
145
145
|
images = batch['image']
|
146
146
|
|
147
|
+
local_batch_size = images.shape[0]
|
148
|
+
|
147
149
|
# First get the standard deviation of the images
|
148
150
|
# std = jnp.std(images, axis=(1, 2, 3))
|
149
151
|
# is_non_zero = (std > 0)
|
@@ -164,7 +166,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
164
166
|
label_seq = jnp.concat(
|
165
167
|
[null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
|
166
168
|
|
167
|
-
noise_level, local_rng_state = noise_schedule.generate_timesteps(
|
169
|
+
noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)
|
168
170
|
|
169
171
|
local_rng_state, rngs = local_rng_state.get_random_key()
|
170
172
|
noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
|
@@ -172,17 +174,15 @@ class DiffusionTrainer(SimpleTrainer):
|
|
172
174
|
# Make sure image is also float32
|
173
175
|
images = images.astype(jnp.float32)
|
174
176
|
|
175
|
-
rates = noise_schedule.get_rates(noise_level)
|
176
|
-
noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
|
177
|
-
images, noise, rates)
|
177
|
+
rates = noise_schedule.get_rates(noise_level, get_coeff_shapes_tuple(images))
|
178
|
+
noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
|
178
179
|
|
179
180
|
def model_loss(params):
|
180
181
|
preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
|
181
|
-
preds = model_output_transform.pred_transform(
|
182
|
-
noisy_images, preds, rates)
|
182
|
+
preds = model_output_transform.pred_transform(noisy_images, preds, rates)
|
183
183
|
nloss = loss_fn(preds, expected_output)
|
184
184
|
# Ignore the loss contribution of images with zero standard deviation
|
185
|
-
nloss *= noise_schedule.get_weights(noise_level)
|
185
|
+
nloss *= noise_schedule.get_weights(noise_level, get_coeff_shapes_tuple(nloss))
|
186
186
|
nloss = jnp.mean(nloss)
|
187
187
|
loss = nloss
|
188
188
|
return loss
|
@@ -216,7 +216,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
216
216
|
# operand=None
|
217
217
|
# )
|
218
218
|
|
219
|
-
|
219
|
+
new_state = train_state.apply_gradients(grads=grads)
|
220
220
|
|
221
221
|
if train_state.dynamic_scale is not None:
|
222
222
|
# if is_fin == False the gradients contain Inf/NaNs and optimizer state and
|
@@ -238,9 +238,16 @@ class DiffusionTrainer(SimpleTrainer):
|
|
238
238
|
return train_state, loss, rng_state
|
239
239
|
|
240
240
|
if distributed_training:
|
241
|
-
train_step = shard_map(
|
242
|
-
|
243
|
-
|
241
|
+
train_step = shard_map(
|
242
|
+
train_step,
|
243
|
+
mesh=self.mesh,
|
244
|
+
in_specs=(P(), P(), P('data'), P('data')),
|
245
|
+
out_specs=(P(), P(), P()),
|
246
|
+
)
|
247
|
+
train_step = jax.jit(
|
248
|
+
train_step,
|
249
|
+
donate_argnums=(2)
|
250
|
+
)
|
244
251
|
|
245
252
|
return train_step
|
246
253
|
|
@@ -253,12 +260,21 @@ class DiffusionTrainer(SimpleTrainer):
|
|
253
260
|
null_labels_full = null_labels_full.astype(jnp.float16)
|
254
261
|
# null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
|
255
262
|
|
263
|
+
if 'image' in self.input_shapes:
|
264
|
+
image_size = self.input_shapes['image'][1]
|
265
|
+
elif 'x' in self.input_shapes:
|
266
|
+
image_size = self.input_shapes['x'][1]
|
267
|
+
elif 'sample' in self.input_shapes:
|
268
|
+
image_size = self.input_shapes['sample'][1]
|
269
|
+
else:
|
270
|
+
raise ValueError("No image input shape found in input shapes")
|
271
|
+
|
256
272
|
sampler = sampler_class(
|
257
273
|
model=model,
|
258
274
|
params=None,
|
259
275
|
noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
|
260
276
|
model_output_transform=self.model_output_transform,
|
261
|
-
image_size=
|
277
|
+
image_size=image_size,
|
262
278
|
null_labels_seq=null_labels_full,
|
263
279
|
autoencoder=autoencoder,
|
264
280
|
guidance_scale=3.0,
|
@@ -309,7 +325,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
309
325
|
)
|
310
326
|
|
311
327
|
# Put each sample on wandb
|
312
|
-
if self.wandb:
|
328
|
+
if getattr(self, 'wandb', None) is not None and self.wandb:
|
313
329
|
import numpy as np
|
314
330
|
from wandb import Image as wandbImage
|
315
331
|
wandb_images = []
|
@@ -20,9 +20,9 @@ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0
|
|
20
20
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
21
21
|
flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
|
22
22
|
flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
|
23
|
-
flaxdiff/predictors/__init__.py,sha256=
|
23
|
+
flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
|
24
24
|
flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
|
25
|
-
flaxdiff/samplers/common.py,sha256=
|
25
|
+
flaxdiff/samplers/common.py,sha256=wn8tryC3B0KE0V98zMiH_X2x-Tc1NbM5iV27hn5p8Aw,8846
|
26
26
|
flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
|
27
27
|
flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
|
28
28
|
flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
|
@@ -30,20 +30,20 @@ flaxdiff/samplers/heun_sampler.py,sha256=EvR3hy4t_D47ZOH4luzRFqPmv2v4z78P_JhqBGE
|
|
30
30
|
flaxdiff/samplers/multistep_dpm.py,sha256=2M4Abb93-GUVN1f0_ZHBeA6lF0eF15Hi6QOgOu2K45s,2752
|
31
31
|
flaxdiff/samplers/rk4_sampler.py,sha256=vcQefFhOUZdNOQGBdzNkb2NgmTC2KWd_nhUhyLtt3yI,2026
|
32
32
|
flaxdiff/schedulers/__init__.py,sha256=EIva9gBz3DKHORuGmv1LQCKTtRqCRavFOXMNqxAR_ks,131
|
33
|
-
flaxdiff/schedulers/common.py,sha256=
|
33
|
+
flaxdiff/schedulers/common.py,sha256=PDeje2NmN7X3J5qKGauE0jYPpxjgEX44f_evJHRIG3E,4382
|
34
34
|
flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
|
35
|
-
flaxdiff/schedulers/cosine.py,sha256=
|
36
|
-
flaxdiff/schedulers/discrete.py,sha256=
|
35
|
+
flaxdiff/schedulers/cosine.py,sha256=E5pODAmINfdyC4kSYOJSPAvq3GNlKPpKEn3X82vYMz0,2055
|
36
|
+
flaxdiff/schedulers/discrete.py,sha256=m1q3bAgeAxU3gTj5di3XFWDm4yLfMKAFJPlYdozLE2Y,3316
|
37
37
|
flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
|
38
38
|
flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
|
39
|
-
flaxdiff/schedulers/linear.py,sha256=
|
40
|
-
flaxdiff/schedulers/sqrt.py,sha256=
|
39
|
+
flaxdiff/schedulers/linear.py,sha256=pBDTXSQcOS4Z03JTh6S0f9E2qLcTQzF2E-pGoQnRoy0,572
|
40
|
+
flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,474
|
41
41
|
flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
|
42
42
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
|
43
|
-
flaxdiff/trainer/diffusion_trainer.py,sha256=
|
43
|
+
flaxdiff/trainer/diffusion_trainer.py,sha256=kEulMnk6ZkKhQRSVr3UtDdCmXR4cWphJ3XNuk7VIAUY,14189
|
44
44
|
flaxdiff/trainer/simple_trainer.py,sha256=LScHQZCy5ksSC7n0GC0tjOXK-zptxpMJsC6Udf-nz18,22178
|
45
45
|
flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
|
46
|
-
flaxdiff-0.1.37.
|
47
|
-
flaxdiff-0.1.37.
|
48
|
-
flaxdiff-0.1.37.
|
49
|
-
flaxdiff-0.1.37.
|
46
|
+
flaxdiff-0.1.37.6.dist-info/METADATA,sha256=SujaCKk29ECrfSEIdchYvAl-nf0L270t2of7oeX5kgk,23985
|
47
|
+
flaxdiff-0.1.37.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
48
|
+
flaxdiff-0.1.37.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
49
|
+
flaxdiff-0.1.37.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|