flaxdiff 0.1.37.3__py3-none-any.whl → 0.1.37.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/predictors/__init__.py

@@ -1,6 +1,6 @@
  from typing import Union
  import jax.numpy as jnp
- from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
+ from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler, get_coeff_shapes_tuple

  ############################################################################################################
  # Prediction Transforms
@@ -11,7 +11,7 @@ class DiffusionPredictionTransform():
          return preds

      def __call__(self, x_t, preds, current_step, noise_schedule:NoiseScheduler) -> Union[jnp.ndarray, jnp.ndarray]:
-         rates = noise_schedule.get_rates(current_step)
+         rates = noise_schedule.get_rates(current_step, shape=get_coeff_shapes_tuple(x_t))
          preds = self.pred_transform(x_t, preds, rates)
          x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
          return x_0, epsilon
@@ -81,16 +81,16 @@ class KarrasPredictionTransform(DiffusionPredictionTransform):
          epsilon = (x_t - x_0 * signal_rate) / noise_rate
          return x_0, epsilon

-     def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+     def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
          _, sigma = rates
-         c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
-         c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
-         c_out = c_out.reshape((-1, 1, 1, 1))
-         c_skip = c_skip.reshape((-1, 1, 1, 1))
+         c_out = sigma * self.sigma_data / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
+         c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2 + epsilon)
+         c_out = c_out.reshape(get_coeff_shapes_tuple(preds))
+         c_skip = c_skip.reshape(get_coeff_shapes_tuple(x_t))
          x_0 = c_out * preds + c_skip * x_t
          return x_0

-     def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+     def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray], epsilon=1e-8) -> jnp.ndarray:
          _, sigma = rates
-         c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
+         c_in = 1 / (jnp.sqrt(self.sigma_data ** 2 + sigma ** 2) + epsilon)
          return c_in
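The `epsilon` guard above protects the Karras preconditioning coefficients against division by zero when `sigma` is exactly 0. A minimal standalone sketch of the same coefficients (the helper name and the `sigma_data=0.5` default are illustrative, not part of the package):

    import jax.numpy as jnp

    def karras_precondition(x_t, preds, sigma, sigma_data=0.5, epsilon=1e-8):
        # Per-sample sigma broadcast over the remaining axes of the batch.
        sigma = sigma.reshape((-1,) + (1,) * (x_t.ndim - 1))
        denom = sigma_data ** 2 + sigma ** 2
        c_out = sigma * sigma_data / (jnp.sqrt(denom) + epsilon)  # scales the raw network output
        c_skip = sigma_data ** 2 / (denom + epsilon)              # scales the skip path from x_t
        return c_out * preds + c_skip * x_t                       # x_0 estimate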
flaxdiff/samplers/common.py

@@ -67,7 +67,7 @@ class DiffusionSampler():
      # Used to sample from the diffusion model
      def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
          # First clip the noisy images
-         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
+         step_ones = jnp.ones((len(current_samples), ), dtype=jnp.int32)
          current_step = step_ones * current_step
          next_step = step_ones * next_step
          pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
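For a JAX array, `len(current_samples)` equals `current_samples.shape[0]`, so the change is behavior-preserving while also accepting any sequence-like batch. A quick check:

    import jax.numpy as jnp

    batch = jnp.zeros((4, 32, 32, 3))
    steps_a = jnp.ones((batch.shape[0],), dtype=jnp.int32) * 7
    steps_b = jnp.ones((len(batch),), dtype=jnp.int32) * 7
    assert bool((steps_a == steps_b).all())  # both are [7, 7, 7, 7]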
@@ -133,6 +133,7 @@ class DiffusionSampler():

          params = params if params is not None else self.params

+         @jax.jit
          def sample_model_fn(x_t, t, *additional_inputs):
              return self.sample_model(params, x_t, t, *additional_inputs)

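Decorating the closure with `@jax.jit` compiles the per-step model call once per sampling session; the captured `params` are staged out at trace time, so building a new closure (e.g. with different params) triggers a fresh compile. The same pattern as a hypothetical standalone sketch:

    import jax

    def make_sample_model_fn(params, apply_fn):
        # params is captured by the closure and baked in when tracing.
        @jax.jit
        def sample_model_fn(x_t, t, *additional_inputs):
            return apply_fn(params, x_t, t, *additional_inputs)
        return sample_model_fn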
flaxdiff/schedulers/common.py

@@ -3,6 +3,16 @@ import jax.numpy as jnp
  from typing import Union
  from ..utils import RandomMarkovState

+ def get_coeff_shapes_tuple(array):
+     shape_tuple = (-1,) + (1,) * (array.ndim - 1)
+     return shape_tuple
+
+ def reshape_rates(rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
+     signal_rates, noise_rates = rates
+     signal_rates = jnp.reshape(signal_rates, shape)
+     noise_rates = jnp.reshape(noise_rates, shape)
+     return signal_rates, noise_rates
+
  class NoiseScheduler():
      def __init__(self, timesteps,
                   dtype=jnp.float32,
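`get_coeff_shapes_tuple` simply keeps the batch axis and collapses every other axis to 1, producing a broadcast shape that works for inputs of any rank (images, latents, videos):

    import jax.numpy as jnp

    def get_coeff_shapes_tuple(array):
        return (-1,) + (1,) * (array.ndim - 1)

    images = jnp.zeros((8, 64, 64, 3))      # 4D image batch
    videos = jnp.zeros((8, 16, 64, 64, 3))  # 5D video batch
    print(get_coeff_shapes_tuple(images))   # (-1, 1, 1, 1)
    print(get_coeff_shapes_tuple(videos))   # (-1, 1, 1, 1, 1)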
@@ -24,24 +34,18 @@ class NoiseScheduler():
          timesteps = self.timestep_generator(rng, batch_size, self.max_timesteps)
          return timesteps, state

-     def get_weights(self, steps):
+     def get_weights(self, steps, shape=(-1, 1, 1, 1)):
          raise NotImplementedError

-     def reshape_rates(self, rates:tuple[jnp.ndarray, jnp.ndarray], shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
-         signal_rates, noise_rates = rates
-         signal_rates = jnp.reshape(signal_rates, shape)
-         noise_rates = jnp.reshape(noise_rates, shape)
-         return signal_rates, noise_rates
-
      def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
          raise NotImplementedError

      def add_noise(self, images, noise, steps) -> jnp.ndarray:
-         signal_rates, noise_rates = self.get_rates(steps)
+         signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(images))
          return signal_rates * images + noise_rates * noise

      def remove_all_noise(self, noisy_images, noise, steps, clip_denoised=True, rates=None):
-         signal_rates, noise_rates = self.get_rates(steps)
+         signal_rates, noise_rates = self.get_rates(steps, shape=get_coeff_shapes_tuple(noisy_images))
          x_0 = (noisy_images - noise * noise_rates) / signal_rates
          return x_0

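`add_noise` and `remove_all_noise` are exact inverses given the same rates; a self-contained round trip with illustrative constant rates shows the broadcasting these shape arguments enable:

    import jax
    import jax.numpy as jnp

    x_0 = jax.random.normal(jax.random.PRNGKey(0), (4, 32, 32, 3))
    noise = jax.random.normal(jax.random.PRNGKey(1), x_0.shape)
    signal_rates = jnp.full((4, 1, 1, 1), 0.8)  # per-sample, broadcast over H, W, C
    noise_rates = jnp.full((4, 1, 1, 1), 0.6)

    x_t = signal_rates * x_0 + noise_rates * noise        # add_noise
    x_0_rec = (x_t - noise * noise_rates) / signal_rates  # remove_all_noise
    assert jnp.allclose(x_0, x_0_rec, atol=1e-5)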
@@ -54,8 +58,8 @@ class NoiseScheduler():
      def get_posterior_variance(self, steps, shape=(-1, 1, 1, 1)):
          raise NotImplementedError

-     def get_max_variance(self):
-         alpha_n, sigma_n = self.get_rates(self.max_timesteps)
+     def get_max_variance(self, shape=(-1, 1, 1, 1)):
+         alpha_n, sigma_n = self.get_rates(self.max_timesteps, shape=shape)
          variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
          return variance

@@ -82,9 +86,9 @@ class GeneralizedNoiseScheduler(NoiseScheduler):

      def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
          sigmas = self.get_sigmas(steps)
-         signal_rates = 1
+         signal_rates = jnp.ones_like(sigmas)
          noise_rates = sigmas
-         return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+         return reshape_rates((signal_rates, noise_rates), shape=shape)

      def transform_inputs(self, x, steps, num_discrete_chunks=1000):
          sigmas_discrete = (steps / self.max_timesteps) * num_discrete_chunks
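Replacing the scalar `1` with `jnp.ones_like(sigmas)` means the reshaped signal rates carry the same per-step leading dimension as the noise rates, rather than a size-1 array that merely broadcasts:

    import jax.numpy as jnp

    sigmas = jnp.array([0.1, 0.5, 1.0, 2.0])
    scalar = jnp.reshape(jnp.asarray(1), (-1, 1, 1, 1))           # shape (1, 1, 1, 1)
    per_step = jnp.reshape(jnp.ones_like(sigmas), (-1, 1, 1, 1))  # shape (4, 1, 1, 1)
    print(scalar.shape, per_step.shape)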
flaxdiff/schedulers/cosine.py

@@ -3,7 +3,7 @@ import numpy as np
  import jax.numpy as jnp
  from .discrete import DiscreteNoiseScheduler
  from .continuous import ContinuousNoiseScheduler
- from .common import GeneralizedNoiseScheduler
+ from .common import GeneralizedNoiseScheduler, reshape_rates

  def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
      ts = np.linspace(0, 1, timesteps + 1, dtype=np.float64)
@@ -32,9 +32,9 @@ class CosineContinuousNoiseScheduler(ContinuousNoiseScheduler):
      def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
          signal_rates = jnp.cos((jnp.pi * steps) / (2 * self.max_timesteps))
          noise_rates = jnp.sin((jnp.pi * steps) / (2 * self.max_timesteps))
-         return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+         return reshape_rates((signal_rates, noise_rates), shape=shape)

-     def get_weights(self, steps):
-         alpha, sigma = self.get_rates(steps, shape=())
+     def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
+         alpha, sigma = self.get_rates(steps, shape=shape)
          return 1 / (1 + (alpha ** 2 / sigma ** 2))

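Since the cosine schedule satisfies alpha² + sigma² = 1, the returned weight 1 / (1 + alpha² / sigma²) reduces to sigma², i.e. SNR-based loss weighting; a quick numerical check:

    import jax.numpy as jnp

    max_timesteps = 1000
    steps = jnp.linspace(1.0, max_timesteps, 5)  # start at 1 to avoid sigma == 0
    alpha = jnp.cos((jnp.pi * steps) / (2 * max_timesteps))
    sigma = jnp.sin((jnp.pi * steps) / (2 * max_timesteps))
    assert jnp.allclose(alpha ** 2 + sigma ** 2, 1.0)
    weights = 1 / (1 + (alpha ** 2 / sigma ** 2))
    assert jnp.allclose(weights, sigma ** 2, atol=1e-6)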
flaxdiff/schedulers/discrete.py

@@ -2,7 +2,7 @@ import jax
  import jax.numpy as jnp
  from typing import Union
  from ..utils import RandomMarkovState
- from .common import NoiseScheduler
+ from .common import NoiseScheduler, reshape_rates, get_coeff_shapes_tuple

  class DiscreteNoiseScheduler(NoiseScheduler):
      """
@@ -53,17 +53,15 @@ class DiscreteNoiseScheduler(NoiseScheduler):

      def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
          steps = jnp.int16(steps)
-         signal_rate = self.sqrt_alpha_cumprod[steps]
-         noise_rate = self.sqrt_one_minus_alpha_cumprod[steps]
-         signal_rate = jnp.reshape(signal_rate, shape)
-         noise_rate = jnp.reshape(noise_rate, shape)
-         return signal_rate, noise_rate
+         signal_rates = self.sqrt_alpha_cumprod[steps]
+         noise_rates = self.sqrt_one_minus_alpha_cumprod[steps]
+         return reshape_rates((signal_rates, noise_rates), shape=shape)

      def get_posterior_mean(self, x_0, x_t, steps):
          steps = jnp.int16(steps)
          x_0_coeff = self.posterior_mean_coef1[steps]
          x_t_coeff = self.posterior_mean_coef2[steps]
-         x_0_coeff, x_t_coeff = self.reshape_rates((x_0_coeff, x_t_coeff))
+         x_0_coeff, x_t_coeff = reshape_rates((x_0_coeff, x_t_coeff), shape=get_coeff_shapes_tuple(x_0))
          mean = x_0_coeff * x_0 + x_t_coeff * x_t
          return mean

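The refactored `get_rates` is now a table lookup followed by the shared `reshape_rates` helper; a standalone sketch with hypothetical DDPM-style cumulative-product tables:

    import jax.numpy as jnp

    def reshape_rates(rates, shape=(-1, 1, 1, 1)):
        signal_rates, noise_rates = rates
        return jnp.reshape(signal_rates, shape), jnp.reshape(noise_rates, shape)

    betas = jnp.linspace(1e-4, 0.02, 1000)  # illustrative linear schedule
    alpha_cumprod = jnp.cumprod(1.0 - betas)
    sqrt_alpha_cumprod = jnp.sqrt(alpha_cumprod)
    sqrt_one_minus_alpha_cumprod = jnp.sqrt(1.0 - alpha_cumprod)

    steps = jnp.array([0, 499, 999])
    signal, noise = reshape_rates(
        (sqrt_alpha_cumprod[steps], sqrt_one_minus_alpha_cumprod[steps]))
    print(signal.shape, noise.shape)  # (3, 1, 1, 1) (3, 1, 1, 1)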
flaxdiff/schedulers/linear.py

@@ -5,8 +5,7 @@ def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
      scale = 1000 / timesteps
      beta_start = scale * beta_start
      beta_end = scale * beta_end
-     betas = np.linspace(
-         beta_start, beta_end, timesteps, dtype=np.float64)
+     betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
      return betas

  class LinearNoiseSchedule(DiscreteNoiseScheduler):
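The `scale = 1000 / timesteps` factor rescales the betas so that training with fewer discrete steps injects roughly the same total noise; for example:

    import numpy as np

    def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
        scale = 1000 / timesteps
        return np.linspace(scale * beta_start, scale * beta_end, timesteps, dtype=np.float64)

    print(linear_beta_schedule(1000)[[0, -1]])  # [0.0001 0.02]
    print(linear_beta_schedule(500)[[0, -1]])   # [0.0002 0.04]: doubled betas over half the steps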
flaxdiff/schedulers/sqrt.py

@@ -2,9 +2,10 @@ import numpy as np
  import jax.numpy as jnp
  from .discrete import DiscreteNoiseScheduler
  from .continuous import ContinuousNoiseScheduler
+ from .common import reshape_rates

  class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
      def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
          signal_rates = jnp.sqrt(1 - steps)
          noise_rates = jnp.sqrt(steps)
-         return self.reshape_rates((signal_rates, noise_rates), shape=shape)
+         return reshape_rates((signal_rates, noise_rates), shape=shape)
flaxdiff/trainer/diffusion_trainer.py

@@ -11,7 +11,7 @@ from jax.sharding import Mesh, PartitionSpec as P
  from jax.experimental.shard_map import shard_map
  from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type

- from ..schedulers import NoiseScheduler
+ from ..schedulers import NoiseScheduler, get_coeff_shapes_tuple
  from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
  from ..samplers.common import DiffusionSampler
  from ..samplers.ddim import DDIMSampler
@@ -144,6 +144,8 @@ class DiffusionTrainer(SimpleTrainer):

              images = batch['image']

+             local_batch_size = images.shape[0]
+
              # First get the standard deviation of the images
              # std = jnp.std(images, axis=(1, 2, 3))
              # is_non_zero = (std > 0)
@@ -164,22 +166,23 @@ class DiffusionTrainer(SimpleTrainer):
              label_seq = jnp.concat(
                  [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

-             noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
+             noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)

              local_rng_state, rngs = local_rng_state.get_random_key()
-             noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+             noise: jax.Array = jax.random.normal(rngs, shape=images.shape, dtype=jnp.float32)
+
+             # Make sure image is also float32
+             images = images.astype(jnp.float32)

-             rates = noise_schedule.get_rates(noise_level)
-             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-                 images, noise, rates)
+             rates = noise_schedule.get_rates(noise_level, get_coeff_shapes_tuple(images))
+             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)

              def model_loss(params):
                  preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-                 preds = model_output_transform.pred_transform(
-                     noisy_images, preds, rates)
+                 preds = model_output_transform.pred_transform(noisy_images, preds, rates)
                  nloss = loss_fn(preds, expected_output)
                  # Ignore the loss contribution of images with zero standard deviation
-                 nloss *= noise_schedule.get_weights(noise_level)
+                 nloss *= noise_schedule.get_weights(noise_level, get_coeff_shapes_tuple(nloss))
                  nloss = jnp.mean(nloss)
                  loss = nloss
                  return loss
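Passing `get_coeff_shapes_tuple(nloss)` lets the per-timestep weights broadcast against whatever shape `loss_fn` returns, per-pixel or per-sample alike. A minimal sketch, assuming a per-pixel squared error:

    import jax.numpy as jnp

    preds = jnp.zeros((4, 32, 32, 3))
    target = jnp.ones_like(preds)
    nloss = (preds - target) ** 2              # per-pixel loss, shape (4, 32, 32, 3)
    weights = jnp.array([0.1, 0.5, 0.9, 1.0])  # one weight per sampled timestep
    nloss = nloss * weights.reshape((-1,) + (1,) * (nloss.ndim - 1))
    loss = jnp.mean(nloss)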
@@ -197,7 +200,22 @@ class DiffusionTrainer(SimpleTrainer):
              loss, grads = grad_fn(train_state.params)
              if distributed_training:
                  grads = jax.lax.pmean(grads, "data")
+
+             # # check gradients for NaN/Inf
+             # has_nan_or_inf = jax.tree_util.tree_reduce(
+             #     lambda acc, x: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
+             #     grads,
+             #     initializer=False
+             # )

+             # # Only apply gradients if they're valid
+             # new_state = jax.lax.cond(
+             #     has_nan_or_inf,
+             #     lambda _: train_state,  # Skip gradient update
+             #     lambda _: train_state.apply_gradients(grads=grads),
+             #     operand=None
+             # )
+
              new_state = train_state.apply_gradients(grads=grads)

              if train_state.dynamic_scale is not None:
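The commented-out guard above folds a NaN/Inf test over the gradient pytree and would skip the optimizer step when any leaf is invalid. Made runnable as a standalone sketch (`train_state` being any Flax `TrainState`-like object with `apply_gradients`):

    import jax
    import jax.numpy as jnp

    def apply_if_finite(train_state, grads):
        # True if any gradient leaf contains NaN or Inf.
        has_nan_or_inf = jax.tree_util.tree_reduce(
            lambda acc, x: jnp.logical_or(
                acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
            grads,
            initializer=False,
        )
        # Conditionally skip the update; both branches return the same pytree structure.
        return jax.lax.cond(
            has_nan_or_inf,
            lambda _: train_state,                              # skip gradient update
            lambda _: train_state.apply_gradients(grads=grads),
            None,
        )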
@@ -220,9 +238,16 @@ class DiffusionTrainer(SimpleTrainer):
              return train_state, loss, rng_state

          if distributed_training:
-             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
-                                    out_specs=(P(), P(), P()))
-             train_step = jax.jit(train_step)
+             train_step = shard_map(
+                 train_step,
+                 mesh=self.mesh,
+                 in_specs=(P(), P(), P('data'), P('data')),
+                 out_specs=(P(), P(), P()),
+             )
+             train_step = jax.jit(
+                 train_step,
+                 donate_argnums=(2)
+             )

          return train_step

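Note that `donate_argnums=(2)` is an int (a parenthesized `2`, not a tuple), which `jax.jit` also accepts; it donates the buffers of the argument at index 2 so XLA may reuse that memory for the outputs. The pattern in isolation:

    import jax
    import jax.numpy as jnp

    # Donating the input buffer lets XLA write the output in place;
    # the donated array must not be read after the call.
    step = jax.jit(lambda x: x + 1, donate_argnums=0)
    x = jnp.ones((1024, 1024))
    y = step(x)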
@@ -235,12 +260,21 @@ class DiffusionTrainer(SimpleTrainer):
          null_labels_full = null_labels_full.astype(jnp.float16)
          # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)

+         if 'image' in self.input_shapes:
+             image_size = self.input_shapes['image'][1]
+         elif 'x' in self.input_shapes:
+             image_size = self.input_shapes['x'][1]
+         elif 'sample' in self.input_shapes:
+             image_size = self.input_shapes['sample'][1]
+         else:
+             raise ValueError("No image input shape found in input shapes")
+
          sampler = sampler_class(
              model=model,
              params=None,
              noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
              model_output_transform=self.model_output_transform,
-             image_size=self.input_shapes['x'][0],
+             image_size=image_size,
              null_labels_seq=null_labels_full,
              autoencoder=autoencoder,
              guidance_scale=3.0,
@@ -291,7 +325,7 @@ class DiffusionTrainer(SimpleTrainer):
          )

          # Put each sample on wandb
-         if self.wandb:
+         if getattr(self, 'wandb', None) is not None and self.wandb:
              import numpy as np
              from wandb import Image as wandbImage
              wandb_images = []
flaxdiff/trainer/simple_trainer.py

@@ -403,7 +403,6 @@ class SimpleTrainer:
          rng_state
      ):
          global_device_count = jax.device_count()
-         local_device_count = jax.local_device_count()
          process_index = jax.process_index()
          if self.distributed_training:
              global_device_indexes = jnp.arange(global_device_count)
@@ -434,11 +433,16 @@ class SimpleTrainer:
              # loss = jax.experimental.multihost_utils.process_allgather(loss)
              loss = jnp.mean(loss) # Just to make sure its a scaler value

-             if loss <= 1e-6:
+             if loss <= 1e-8:
                  # If the loss is too low, we can assume the model has diverged
                  print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
                  # Reset the model to the old state
-                 exit(1)
+                 if self.best_state is not None:
+                     print(colored(f"Resetting model to best state", 'red'))
+                     train_state = self.best_state
+                     loss = self.best_loss
+                 else:
+                     exit(1)

              epoch_loss += loss
              current_step += 1
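Instead of hard-exiting on a suspiciously low loss, the trainer now rolls back to the best checkpoint when one exists. The same guard as a small sketch, assuming the trainer tracks `best_state` and `best_loss` as the hunk's attribute names suggest:

    def guard_divergence(train_state, loss, best_state, best_loss):
        # A near-zero loss here is treated as a divergence symptom, not success.
        if loss <= 1e-8:
            if best_state is not None:
                return best_state, best_loss  # roll back to the best checkpoint
            raise SystemExit(1)               # nothing to fall back to
        return train_state, loss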
flaxdiff-0.1.37.6.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.1.37.3
+ Version: 0.1.37.6
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
flaxdiff-0.1.37.6.dist-info/RECORD

@@ -20,9 +20,9 @@ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0
  flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
  flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
  flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
- flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
+ flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
  flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
- flaxdiff/samplers/common.py,sha256=7gKNY4mWVnLjtcioGLFD_Vwmxg9zJovUb8EcYWlc_GE,8833
+ flaxdiff/samplers/common.py,sha256=wn8tryC3B0KE0V98zMiH_X2x-Tc1NbM5iV27hn5p8Aw,8846
  flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
  flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
  flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
@@ -30,20 +30,20 @@ flaxdiff/samplers/heun_sampler.py,sha256=EvR3hy4t_D47ZOH4luzRFqPmv2v4z78P_JhqBGE
  flaxdiff/samplers/multistep_dpm.py,sha256=2M4Abb93-GUVN1f0_ZHBeA6lF0eF15Hi6QOgOu2K45s,2752
  flaxdiff/samplers/rk4_sampler.py,sha256=vcQefFhOUZdNOQGBdzNkb2NgmTC2KWd_nhUhyLtt3yI,2026
  flaxdiff/schedulers/__init__.py,sha256=EIva9gBz3DKHORuGmv1LQCKTtRqCRavFOXMNqxAR_ks,131
- flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
+ flaxdiff/schedulers/common.py,sha256=PDeje2NmN7X3J5qKGauE0jYPpxjgEX44f_evJHRIG3E,4382
  flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
- flaxdiff/schedulers/cosine.py,sha256=EtU3SjJaP9R9ULHNiYrX9jBLSsAGKPGteHiwOzWNzYo,2006
- flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
+ flaxdiff/schedulers/cosine.py,sha256=E5pODAmINfdyC4kSYOJSPAvq3GNlKPpKEn3X82vYMz0,2055
+ flaxdiff/schedulers/discrete.py,sha256=m1q3bAgeAxU3gTj5di3XFWDm4yLfMKAFJPlYdozLE2Y,3316
  flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
  flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
- flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
- flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
+ flaxdiff/schedulers/linear.py,sha256=pBDTXSQcOS4Z03JTh6S0f9E2qLcTQzF2E-pGoQnRoy0,572
+ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,474
  flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
  flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
- flaxdiff/trainer/diffusion_trainer.py,sha256=KVeXJ9ZQKcvD-O_hCJnxro0dQRuQe5ZVGGMEL4Lgm9k,12814
- flaxdiff/trainer/simple_trainer.py,sha256=lmRo8N0bMupIyS3ejPvPtxoskY_3GLC8iyJE6u4TIWc,21990
+ flaxdiff/trainer/diffusion_trainer.py,sha256=kEulMnk6ZkKhQRSVr3UtDdCmXR4cWphJ3XNuk7VIAUY,14189
+ flaxdiff/trainer/simple_trainer.py,sha256=LScHQZCy5ksSC7n0GC0tjOXK-zptxpMJsC6Udf-nz18,22178
  flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
- flaxdiff-0.1.37.3.dist-info/METADATA,sha256=7U7SINGO_ZzsUeuTPi2LTMqxrj93Cvglyh1Q7D39zRM,23985
- flaxdiff-0.1.37.3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- flaxdiff-0.1.37.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
- flaxdiff-0.1.37.3.dist-info/RECORD,,
+ flaxdiff-0.1.37.6.dist-info/METADATA,sha256=SujaCKk29ECrfSEIdchYvAl-nf0L270t2of7oeX5kgk,23985
+ flaxdiff-0.1.37.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ flaxdiff-0.1.37.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+ flaxdiff-0.1.37.6.dist-info/RECORD,,