flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (44)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
22
22
 
23
23
  from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
24
24
  from flax.training import dynamic_scale as dynamic_scale_lib
25
- from flaxdiff.utils import TextEncoder, ConditioningEncoder
25
+ from flaxdiff.inputs import TextEncoder, ConditioningEncoder
26
26
 
27
27
  class TrainState(SimpleTrainState):
28
28
  rngs: jax.random.PRNGKey
@@ -42,6 +42,7 @@ class DiffusionTrainer(SimpleTrainer):
42
42
  noise_schedule: NoiseScheduler
43
43
  model_output_transform: DiffusionPredictionTransform
44
44
  ema_decay: float = 0.999
45
+ native_resolution: int = None
45
46
 
46
47
  def __init__(self,
47
48
  model: nn.Module,
@@ -54,6 +55,7 @@ class DiffusionTrainer(SimpleTrainer):
54
55
  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
55
56
  autoencoder: AutoEncoder = None,
56
57
  encoder: ConditioningEncoder = None,
58
+ native_resolution: int = None,
57
59
  **kwargs
58
60
  ):
59
61
  super().__init__(
@@ -68,6 +70,20 @@ class DiffusionTrainer(SimpleTrainer):
68
70
  self.model_output_transform = model_output_transform
69
71
  self.unconditional_prob = unconditional_prob
70
72
 
73
+ if native_resolution is None:
74
+ if 'image' in input_shapes:
75
+ native_resolution = input_shapes['image'][1]
76
+ elif 'x' in input_shapes:
77
+ native_resolution = input_shapes['x'][1]
78
+ elif 'sample' in input_shapes:
79
+ native_resolution = input_shapes['sample'][1]
80
+ else:
81
+ raise ValueError("No image input shape found in input shapes")
82
+ if autoencoder is not None:
83
+ native_resolution = native_resolution * 8
84
+
85
+ self.native_resolution = native_resolution
86
+
71
87
  self.autoencoder = autoencoder
72
88
  self.encoder = encoder
73
89
 
@@ -118,9 +134,6 @@ class DiffusionTrainer(SimpleTrainer):
118
134
  model_output_transform = self.model_output_transform
119
135
  loss_fn = self.loss_fn
120
136
  unconditional_prob = self.unconditional_prob
121
-
122
- # Determine the number of unconditional samples
123
- num_unconditional = int(batch_size * unconditional_prob)
124
137
 
125
138
  null_labels_full = self.encoder([""])
126
139
  null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
@@ -159,12 +172,19 @@ class DiffusionTrainer(SimpleTrainer):
159
172
  local_rng_state, rngs = local_rng_state.get_random_key()
160
173
  images = autoencoder.encode(images, rngs)
161
174
 
162
- label_seq = conditioning_encoder.encode_from_tokens(batch)
175
+ label_seq = conditioning_encoder.encode_from_tokens(batch['text'])
163
176
 
164
177
  # Generate random probabilities to decide how much of this batch will be unconditional
178
+ local_rng_state, uncond_key = local_rng_state.get_random_key()
179
+ # Efficient way to determine unconditional samples for JIT compatibility
180
+ uncond_mask = jax.random.bernoulli(
181
+ uncond_key,
182
+ shape=(local_batch_size,),
183
+ p=unconditional_prob
184
+ )
185
+ num_unconditional = jnp.sum(uncond_mask).astype(jnp.int32)
165
186
 
166
- label_seq = jnp.concat(
167
- [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
187
+ label_seq = jnp.concatenate([null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
168
188
 
169
189
  noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)
170
190
 
@@ -200,21 +220,6 @@ class DiffusionTrainer(SimpleTrainer):
200
220
  loss, grads = grad_fn(train_state.params)
201
221
  if distributed_training:
202
222
  grads = jax.lax.pmean(grads, "data")
203
-
204
- # # check gradients for NaN/Inf
205
- # has_nan_or_inf = jax.tree_util.tree_reduce(
206
- # lambda acc, x: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
207
- # grads,
208
- # initializer=False
209
- # )
210
-
211
- # # Only apply gradients if they're valid
212
- # new_state = jax.lax.cond(
213
- # has_nan_or_inf,
214
- # lambda _: train_state, # Skip gradient update
215
- # lambda _: train_state.apply_gradients(grads=grads),
216
- # operand=None
217
- # )
218
223
 
219
224
  new_state = train_state.apply_gradients(grads=grads)
220
225
 
@@ -251,7 +256,7 @@ class DiffusionTrainer(SimpleTrainer):
251
256
 
252
257
  return train_step
253
258
 
254
- def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
259
+ def _define_validation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
255
260
  model = self.model
256
261
  encoder = self.encoder
257
262
  autoencoder = self.autoencoder
@@ -260,7 +265,9 @@ class DiffusionTrainer(SimpleTrainer):
260
265
  null_labels_full = null_labels_full.astype(jnp.float16)
261
266
  # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
262
267
 
263
- if 'image' in self.input_shapes:
268
+ if self.native_resolution is not None:
269
+ image_size = self.native_resolution
270
+ elif 'image' in self.input_shapes:
264
271
  image_size = self.input_shapes['image'][1]
265
272
  elif 'x' in self.input_shapes:
266
273
  image_size = self.input_shapes['x'][1]
@@ -271,10 +278,8 @@ class DiffusionTrainer(SimpleTrainer):
271
278
 
272
279
  sampler = sampler_class(
273
280
  model=model,
274
- params=None,
275
281
  noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
276
282
  model_output_transform=self.model_output_transform,
277
- image_size=image_size,
278
283
  null_labels_seq=null_labels_full,
279
284
  autoencoder=autoencoder,
280
285
  guidance_scale=3.0,
@@ -290,7 +295,8 @@ class DiffusionTrainer(SimpleTrainer):
290
295
  labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
291
296
  samples = sampler.generate_images(
292
297
  params=val_state.ema_params,
293
- num_images=len(labels_seq),
298
+ resolution=image_size,
299
+ num_samples=len(labels_seq),
294
300
  diffusion_steps=diffusion_steps,
295
301
  start_step=1000,
296
302
  end_step=0,