flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/samplers/common.py CHANGED
@@ -1,148 +1,368 @@
-from flax import linen as nn
+from typing import Union, Type
+
 import jax
 import jax.numpy as jnp
 import tqdm
-from typing import Union, Type
+from flax import linen as nn
+from typing import List, Tuple, Dict, Any, Optional
+
+from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 from ..schedulers import NoiseScheduler
 from ..utils import RandomMarkovState, MarkovState, clip_images
-from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, PartitionSpec as P
+from flaxdiff.models.autoencoder import AutoEncoder
+from flaxdiff.inputs import DiffusionInputConfig
 
-class DiffusionSampler():
-    def __init__(self, model:nn.Module, params:dict,
-                 noise_schedule:NoiseScheduler,
-                 model_output_transform:DiffusionPredictionTransform,
-                 guidance_scale:float = 0.0,
-                 null_labels_seq:jax.Array=None,
-                 autoencoder=None,
-                 image_size=256,
-                 autoenc_scale_reduction=8,
-                 autoenc_latent_channels=4,
-                 ):
+class DiffusionSampler:
+    """Base class for diffusion samplers."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        noise_schedule: NoiseScheduler,
+        model_output_transform: DiffusionPredictionTransform,
+        input_config: DiffusionInputConfig,
+        guidance_scale: float = 0.0,
+        autoencoder: AutoEncoder = None,
+        timestep_spacing: str = 'linear',
+    ):
+        """Initialize the diffusion sampler.
+
+        Args:
+            model: Neural network model
+            noise_schedule: Noise scheduler
+            model_output_transform: Transform for model predictions
+            input_config: Configuration of the model's inputs and conditioning
+            guidance_scale: Scale for classifier-free guidance (0.0 means disabled)
+            autoencoder: Optional autoencoder for latent diffusion
+            timestep_spacing: Strategy for timestep spacing in sampling
+                'linear' - Default equal spacing
+                'quadratic' - Emphasizes early steps
+                'karras' - Based on the EDM paper, better with fewer steps
+                'exponential' - Concentrates steps near the end
+        """
         self.model = model
         self.noise_schedule = noise_schedule
-        self.params = params
         self.model_output_transform = model_output_transform
         self.guidance_scale = guidance_scale
-        self.image_size = image_size
-        self.autoenc_scale_reduction = autoenc_scale_reduction
         self.autoencoder = autoencoder
-        self.autoenc_latent_channels = autoenc_latent_channels
+        self.timestep_spacing = timestep_spacing
+        self.input_config = input_config
+
+        unconditionals = input_config.get_unconditionals()
+
+        # For Karras spacing, if the schedule provides the rho-space bounds
+        if hasattr(noise_schedule, 'min_inv_rho') and hasattr(noise_schedule, 'max_inv_rho'):
+            self.min_inv_rho = noise_schedule.min_inv_rho
+            self.max_inv_rho = noise_schedule.max_inv_rho
 
         if self.guidance_scale > 0:
             # Classifier free guidance
-            assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
             print("Using classifier-free guidance")
-            def sample_model(params, x_t, t, *additional_inputs):
+
+            def sample_model(params, x_t, t, *conditioning_inputs):
                 # Concatenate unconditional and conditional inputs
                 x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
                 t_cat = jnp.concatenate([t] * 2, axis=0)
                 rates_cat = self.noise_schedule.get_rates(t_cat)
                 c_in_cat = self.model_output_transform.get_input_scale(rates_cat)
 
-                text_labels_seq, = additional_inputs
-                text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
-                model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
+                final_conditionals = []
+                for conditional, unconditional in zip(conditioning_inputs, unconditionals):
+                    final = jnp.concatenate([
+                        conditional,
+                        jnp.broadcast_to(unconditional, conditional.shape)
+                    ], axis=0)
+                    final_conditionals.append(final)
+                final_conditionals = tuple(final_conditionals)
+
+                model_output = self.model.apply(
+                    params,
+                    *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat),
+                    *final_conditionals
+                )
+
                 # Split model output into unconditional and conditional parts
                 model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
                 model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
-
+
                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
                 return x_0, eps, model_output
         else:
             # Unconditional sampling
-            def sample_model(params, x_t, t, *additional_inputs):
+            def sample_model(params, x_t, t, *conditioning_inputs):
                 rates = self.noise_schedule.get_rates(t)
                 c_in = self.model_output_transform.get_input_scale(rates)
-                model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
+                model_output = self.model.apply(
+                    params,
+                    *self.noise_schedule.transform_inputs(x_t * c_in, t),
+                    *conditioning_inputs
+                )
                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
                 return x_0, eps, model_output
 
-        # if jax.device_count() > 1:
-        #     mesh = jax.sharding.Mesh(jax.devices(), 'data')
-        #     sample_model = shard_map(sample_model, mesh=mesh, in_specs=(P('data'), P('data'), P('data')),
-        #                              out_specs=(P('data'), P('data'), P('data')))
-        sample_model = jax.jit(sample_model)
-        self.sample_model = sample_model
-
-    # Used to sample from the diffusion model
-    def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
-        # First clip the noisy images
-        step_ones = jnp.ones((len(current_samples), ), dtype=jnp.int32)
+        # JIT compile the sampling and post-processing functions for better performance
+        def post_process(samples: jnp.ndarray):
+            """Post-process the generated samples."""
+            if autoencoder is not None:
+                samples = autoencoder.decode(samples)
+
+            samples = clip_images(samples)
+            return samples
+
+        self.sample_model = jax.jit(sample_model)
+        self.post_process = jax.jit(post_process)
+
+    def sample_step(
+        self,
+        sample_model_fn,
+        current_samples: jnp.ndarray,
+        current_step,
+        model_conditioning_inputs,
+        next_step=None,
+        state: RandomMarkovState = None
+    ) -> tuple[jnp.ndarray, RandomMarkovState]:
+        """Perform a single sampling step in the diffusion process.
+
+        Args:
+            sample_model_fn: Function to sample from the model
+            current_samples: Current noisy samples
+            current_step: Current diffusion timestep
+            model_conditioning_inputs: Conditioning inputs for the model
+            next_step: Next diffusion timestep
+            state: Current Markov state
+
+        Returns:
+            Tuple of (new samples, updated state)
+        """
+        step_ones = jnp.ones((len(current_samples),), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
-        pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
-        # plotImages(pred_images)
-        # pred_images = clip_images(pred_images)
-        new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
-                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
-                                                 model_conditioning_inputs=model_conditioning_inputs,
-                                                 sample_model_fn=sample_model_fn,
-                                                 )
+
+        pred_images, pred_noise, _ = sample_model_fn(
+            current_samples, current_step, *model_conditioning_inputs
+        )
+
+        new_samples, state = self.take_next_step(
+            current_samples=current_samples,
+            reconstructed_samples=pred_images,
+            pred_noise=pred_noise,
+            current_step=current_step,
+            next_step=next_step,
+            state=state,
+            model_conditioning_inputs=model_conditioning_inputs,
+            sample_model_fn=sample_model_fn,
+        )
        return new_samples, state
 
-    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1,) -> tuple[jnp.ndarray, RandomMarkovState]:
-        # estimate the q(x_{t-1} | x_t, x_0).
-        # pred_images is x_0, noisy_images is x_t, steps is t
-        return NotImplementedError
-
+
+    def take_next_step(
+        self,
+        current_samples,
+        reconstructed_samples,
+        model_conditioning_inputs,
+        pred_noise,
+        current_step,
+        state: RandomMarkovState,
+        sample_model_fn,
+        next_step=1,
+    ) -> tuple[jnp.ndarray, RandomMarkovState]:
+        """Take the next step in the diffusion process.
+
+        This method needs to be implemented by subclasses.
+        """
+        raise NotImplementedError("Subclasses must implement take_next_step method")
+
+
     def scale_steps(self, steps):
+        """Scale timesteps to match the noise schedule's range."""
         scale_factor = self.noise_schedule.max_timesteps / 1000
         return steps * scale_factor
 
+
     def get_steps(self, start_step, end_step, diffusion_steps):
+        """Get the sequence of timesteps for the diffusion process.
+
+        Args:
+            start_step: Starting timestep (typically the max)
+            end_step: Ending timestep (typically 0)
+            diffusion_steps: Number of steps to use
+
+        Returns:
+            Array of timesteps for sampling
+        """
         step_range = start_step - end_step
         if diffusion_steps is None or diffusion_steps == 0:
-            diffusion_steps = start_step - end_step
+            diffusion_steps = step_range
         diffusion_steps = min(diffusion_steps, step_range)
-        steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
+
+        # Linear spacing (default)
+        if getattr(self, 'timestep_spacing', 'linear') == 'linear':
+            steps = jnp.linspace(
+                end_step, start_step,
+                diffusion_steps, dtype=jnp.int16
+            )[::-1]
+
+        # Quadratic spacing (emphasizes early steps)
+        elif self.timestep_spacing == 'quadratic':
+            steps = jnp.linspace(0, 1, diffusion_steps) ** 2
+            steps = (start_step - end_step) * steps + end_step
+            steps = jnp.asarray(steps, dtype=jnp.int16)[::-1]
+
+        # Karras spacing from the EDM paper - often gives better results with fewer steps
+        elif self.timestep_spacing == 'karras':
+            # Implementation based on the EDM paper's recommendations
+            sigma_min = end_step / start_step
+            sigma_max = 1.0
+            rho = 7.0  # Karras paper default, controls the distribution
+
+            # Create log-spaced steps in sigma space
+            sigmas = jnp.exp(jnp.linspace(
+                jnp.log(sigma_max), jnp.log(sigma_min), diffusion_steps
+            ))
+            steps = jnp.clip(
+                (sigmas ** (1 / rho) - self.min_inv_rho) /
+                (self.max_inv_rho - self.min_inv_rho),
+                0, 1
+            ) * start_step
+            steps = jnp.asarray(steps, dtype=jnp.int16)
+
+        # Exponential spacing (concentrates steps near the end)
+        elif self.timestep_spacing == 'exponential':
+            steps = jnp.linspace(0, 1, diffusion_steps)
+            steps = jnp.exp(steps * jnp.log((start_step + 1) / (end_step + 1))) * (end_step + 1) - 1
+            steps = jnp.clip(steps, end_step, start_step)
+            steps = jnp.asarray(steps, dtype=jnp.int16)[::-1]
+
+        # Fallback to linear spacing
+        else:
+            steps = jnp.linspace(
+                end_step, start_step,
+                diffusion_steps, dtype=jnp.int16
+            )[::-1]
+
         return steps
-
-    def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step):
-        start_step = self.scale_steps(start_step)
-        alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
-        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
-        image_size = self.image_size
-        image_channels = 3
-        if self.autoencoder is not None:
-            image_size = image_size // self.autoenc_scale_reduction
-            image_channels = self.autoenc_latent_channels
-        return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance
-
-    def generate_images(self,
-                        params:dict=None,
-                        num_images=16,
-                        diffusion_steps=1000,
-                        start_step:int = None,
-                        end_step:int = 0,
-                        steps_override=None,
-                        priors=None,
-                        rngstate:RandomMarkovState=None,
-                        model_conditioning_inputs:tuple=()
-                        ) -> jnp.ndarray:
+
+
+    def generate_samples(
+        self,
+        params: dict,
+        num_samples: int,
+        resolution: int,
+        sequence_length: int = None,
+        diffusion_steps: int = 1000,
+        start_step: int = None,
+        end_step: int = 0,
+        steps_override=None,
+        priors=None,
+        rngstate: RandomMarkovState = None,
+        conditioning: List[Union[Tuple, Dict]] = None,
+        model_conditioning_inputs: Tuple = None,
+    ) -> jnp.ndarray:
+        """Generate samples using the diffusion model.
+
+        Provides a unified interface for generating both images and videos.
+        For images, just specify num_samples.
+        For videos, specify both num_samples and sequence_length.
+
+        Args:
+            params: Model parameters
+            num_samples: Number of samples to generate (videos or images)
+            resolution: Resolution of the generated samples (H, W)
+            sequence_length: Length of each sequence (for videos/audio/etc.)
+                If None, generates regular images
+            diffusion_steps: Number of diffusion steps to perform
+            start_step: Starting timestep (defaults to max)
+            end_step: Ending timestep
+            steps_override: Override the default timestep sequence
+            priors: Prior samples to start from instead of noise
+            rngstate: Random state for reproducibility
+            conditioning: (Optional) List of raw conditioning inputs for the model
+            model_conditioning_inputs: (Optional) Pre-processed conditioning inputs
+
+        Returns:
+            Generated samples as a JAX array:
+            - For images: shape [num_samples, H, W, C]
+            - For videos: shape [num_samples, sequence_length, H, W, C]
+        """
         if rngstate is None:
             rngstate = RandomMarkovState(jax.random.PRNGKey(42))
+
+        if start_step is None:
+            start_step = self.noise_schedule.max_timesteps
+
         if priors is None:
+            # Determine if we're generating videos or images based on sequence_length
+            is_video = sequence_length is not None
+
             rngstate, newrngs = rngstate.get_random_key()
-            samples = self.get_initial_samples(num_images, newrngs, start_step)
+
+            # Get the sample shape based on whether we're generating video or images
+            if is_video:
+                samples = self._get_initial_sequence_samples(
+                    resolution, num_samples, sequence_length, newrngs, start_step
+                )
+            else:
+                samples = self._get_initial_samples(resolution, num_samples, newrngs, start_step)
         else:
             print("Using priors")
             if self.autoencoder is not None:
+                # Let the autoencoder handle both image and video priors
                 priors = self.autoencoder.encode(priors)
             samples = priors
-
-        params = params if params is not None else self.params
 
-        # @jax.jit
+        if conditioning is not None:
+            if model_conditioning_inputs is not None:
+                raise ValueError("Cannot provide both conditioning and model_conditioning_inputs")
+            print("Processing raw conditioning inputs to generate model conditioning inputs")
+            separated: Dict[str, List] = {}
+            for cond in self.input_config.conditions:
+                separated[cond.encoder.key] = []
+            # Separate the conditioning inputs, one list per condition
+            for vals in conditioning:
+                if isinstance(vals, (tuple, list)):
+                    # If it's a tuple or list, assume the ordering aligns with the
+                    # ordering of the conditions and key by each condition's encoder key
+                    for cond, val in zip(self.input_config.conditions, vals):
+                        separated[cond.encoder.key].append(val)
+                elif isinstance(vals, dict):
+                    # If it's a dict, look up each value by its encoder key
+                    for cond in self.input_config.conditions:
+                        if cond.encoder.key in vals:
+                            separated[cond.encoder.key].append(vals[cond.encoder.key])
+                        else:
+                            raise ValueError(f"Conditioning input {cond.encoder.key} not found in provided dictionary")
+                else:
+                    # If it's a single value, use the encoder key as the key
+                    for cond in self.input_config.conditions:
+                        separated[cond.encoder.key].append(vals)
+
+            # Now we have a dictionary of lists, one per condition; encode them
+            finals = []
+            for cond in self.input_config.conditions:
+                # Get the encoder for the condition
+                encoder = cond.encoder
+                encoded = encoder(separated[encoder.key])
+                finals.append(encoded)
+
+            model_conditioning_inputs = tuple(finals)
+
+        if model_conditioning_inputs is None:
+            model_conditioning_inputs = []
+
         def sample_model_fn(x_t, t, *additional_inputs):
             return self.sample_model(params, x_t, t, *additional_inputs)
 
-        # @jax.jit
-        def sample_step(sample_model_fn, state:RandomMarkovState, samples, current_step, next_step):
-            samples, state = self.sample_step(sample_model_fn=sample_model_fn, current_samples=samples,
-                                              current_step=current_step,
-                                              model_conditioning_inputs=model_conditioning_inputs,
-                                              state=state, next_step=next_step)
+        def sample_step(sample_model_fn, state: RandomMarkovState, samples, current_step, next_step):
+            samples, state = self.sample_step(
+                sample_model_fn=sample_model_fn,
+                current_samples=samples,
+                current_step=current_step,
+                model_conditioning_inputs=model_conditioning_inputs,
+                state=state,
+                next_step=next_step
+            )
             return samples, state
 
         if start_step is None:
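The guidance branch above runs the model once on a doubled batch (conditional stacked on unconditional) and then blends the two halves. A minimal standalone sketch of that combination step, with a hypothetical helper name that is not part of flaxdiff:

    import jax.numpy as jnp

    def cfg_combine(model_output: jnp.ndarray, guidance_scale: float) -> jnp.ndarray:
        # model_output stacks the conditional batch on top of the unconditional
        # batch, in the same order sample_model builds it above.
        cond, uncond = jnp.split(model_output, 2, axis=0)
        return uncond + guidance_scale * (cond - uncond)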
@@ -153,19 +373,61 @@ class DiffusionSampler():
         else:
             steps = self.get_steps(start_step, end_step, diffusion_steps)
 
-        # print("Sampling steps", steps)
         for i in tqdm.tqdm(range(0, len(steps))):
             current_step = self.scale_steps(steps[i])
             next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0)
+
             if i != len(steps) - 1:
-                # print("normal step")
-                samples, rngstate = sample_step(sample_model_fn, rngstate, samples, current_step, next_step)
+                samples, rngstate = sample_step(
+                    sample_model_fn, rngstate, samples, current_step, next_step
+                )
             else:
-                # print("last step")
-                step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
-                samples, _, _ = sample_model_fn(samples, current_step * step_ones, *model_conditioning_inputs)
+                step_ones = jnp.ones((samples.shape[0],), dtype=jnp.int32)
+                samples, _, _ = sample_model_fn(
+                    samples, current_step * step_ones, *model_conditioning_inputs
+                )
+        return self.post_process(samples)
+
+    def _get_noise_parameters(self, resolution, start_step):
+        """Calculate common noise parameters for sample generation.
+
+        Args:
+            resolution: Target sample resolution
+            start_step: Starting timestep for noise generation
+
+        Returns:
+            Tuple of (variance, image_size, image_channels)
+        """
+        start_step = self.scale_steps(start_step)
+        alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
+        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
+
+        image_size = resolution
+        image_channels = 3
         if self.autoencoder is not None:
-            samples = self.autoencoder.decode(samples)
-        samples = clip_images(samples)
-        return samples
-
+            image_size = image_size // self.autoencoder.downscale_factor
+            image_channels = self.autoencoder.latent_channels
+
+        return variance, image_size, image_channels
+
+    def _get_initial_samples(self, resolution, batch_size, rngs: jax.random.PRNGKey, start_step):
+        """Generate initial noisy samples for image generation."""
+        variance, image_size, image_channels = self._get_noise_parameters(resolution, start_step)
+
+        # Standard image generation
+        return jax.random.normal(
+            rngs,
+            (batch_size, image_size, image_size, image_channels)
+        ) * variance
+
+    def _get_initial_sequence_samples(self, resolution, batch_size, sequence_length, rngs: jax.random.PRNGKey, start_step):
+        """Generate initial noisy samples for sequence data (video/audio)."""
+        variance, image_size, image_channels = self._get_noise_parameters(resolution, start_step)
+
+        # Generate sequence data (like video)
+        return jax.random.normal(
+            rngs,
+            (batch_size, sequence_length, image_size, image_size, image_channels)
+        ) * variance
+
+    # Alias for backward compatibility
+    generate_images = generate_samples
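Taken together, the new interface is called roughly as in the sketch below; the argument names come from the `__init__` and `generate_samples` signatures above, `DDIMSampler` (next file) is one concrete subclass, and every other object is a placeholder the caller must construct:

    # Hypothetical wiring; model, noise_schedule, prediction_transform,
    # input_config, and params are placeholders, and the prompt conditioning
    # assumes the configured encoder accepts raw strings.
    sampler = DDIMSampler(
        model=model,
        noise_schedule=noise_schedule,
        model_output_transform=prediction_transform,
        input_config=input_config,
        guidance_scale=3.0,
        timestep_spacing='karras',
    )
    images = sampler.generate_samples(
        params=params,
        num_samples=4,
        resolution=256,
        diffusion_steps=50,
        conditioning=["a photo of a cat"] * 4,
    )  # -> [4, 256, 256, 3]
    videos = sampler.generate_samples(
        params=params,
        num_samples=2,
        resolution=128,
        sequence_length=16,
        diffusion_steps=50,
    )  # -> [2, 16, 128, 128, 3]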
flaxdiff/samplers/ddim.py CHANGED
@@ -1,10 +1,49 @@
 import jax.numpy as jnp
 from .common import DiffusionSampler
 from ..utils import MarkovState, RandomMarkovState
+import jax
+from flaxdiff.schedulers import get_coeff_shapes_tuple
 
 class DDIMSampler(DiffusionSampler):
-    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
-        next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
-        return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
-
+    def __init__(self, *args, eta=0.0, **kwargs):
+        """Initialize the DDIM sampler with a customizable noise level.
+
+        Args:
+            eta: Controls the stochasticity of the sampler.
+                0.0 = deterministic (DDIM), 1.0 = DDPM-like.
+        """
+        super().__init__(*args, **kwargs)
+        self.eta = eta
+
+    def take_next_step(
+        self,
+        current_samples,
+        reconstructed_samples,
+        model_conditioning_inputs,
+        pred_noise,
+        current_step,
+        state: RandomMarkovState,
+        sample_model_fn,
+        next_step=1
+    ) -> tuple[jnp.ndarray, RandomMarkovState]:
+        # Get diffusion coefficients for the current and next timesteps
+        alpha_t, sigma_t = self.noise_schedule.get_rates(current_step, get_coeff_shapes_tuple(current_samples))
+        alpha_next, sigma_next = self.noise_schedule.get_rates(next_step, get_coeff_shapes_tuple(current_samples))
+
+        # Draw random noise if needed for stochastic sampling
+        if self.eta > 0:
+            # Variance coefficient based on the original DDIM paper's formula:
+            # eta=0 gives deterministic DDIM, eta=1.0 approaches DDPM
+            sigma_tilde = self.eta * sigma_next * jnp.sqrt(1 - alpha_t**2 / alpha_next**2) / jnp.sqrt(1 - alpha_t**2)
+            state, noise_key = state.get_random_key()
+            noise = jax.random.normal(noise_key, current_samples.shape)
+            # Add the stochastic component
+            stochastic_term = sigma_tilde * noise
+        else:
+            stochastic_term = 0
+
+        # Direct DDIM update formula
+        new_samples = alpha_next * reconstructed_samples + sigma_next * pred_noise + stochastic_term
+
+        return new_samples, state
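With `get_rates` returning the signal rate alpha and noise rate sigma, the deterministic core of the update above simply re-noises the predicted clean sample at the next step's rates. A minimal sketch, as a hypothetical standalone helper:

    import jax.numpy as jnp

    def ddim_update(x0_pred, eps_pred, alpha_next, sigma_next):
        # Deterministic DDIM step (eta = 0): x_next = alpha' * x0 + sigma' * eps
        return alpha_next * x0_pred + sigma_next * eps_pred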
flaxdiff/schedulers/karras.py CHANGED
@@ -5,35 +5,43 @@ import jax
 from ..utils import RandomMarkovState
 
 class KarrasVENoiseScheduler(GeneralizedNoiseScheduler):
-    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
+    def __init__(self, timesteps=1.0, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
         super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
         self.min_inv_rho = sigma_min ** (1 / rho)
         self.max_inv_rho = sigma_max ** (1 / rho)
         self.rho = rho
-
+
     def get_sigmas(self, steps) -> jnp.ndarray:
-        # steps = jnp.int16(steps)
-        # return self.sigmas[steps]
-        ramp = 1 - steps / self.max_timesteps
+        # Ensure steps are properly normalized and clamped to avoid edge cases
+        ramp = jnp.clip(1 - steps / self.max_timesteps, 0.0, 1.0)
         sigmas = (self.max_inv_rho + ramp * (self.min_inv_rho - self.max_inv_rho)) ** self.rho
         return sigmas
-
+
     def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
         sigma = self.get_sigmas(steps)
-        weights = ((sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2)
+        # Add epsilon for numerical stability
+        epsilon = 1e-6
+        weights = ((sigma ** 2 + self.sigma_data ** 2) / ((sigma * self.sigma_data) ** 2 + epsilon))
         return weights.reshape(shape)
 
     def transform_inputs(self, x, steps, num_discrete_chunks=1000) -> tuple[jnp.ndarray, jnp.ndarray]:
         sigmas = self.get_sigmas(steps)
-        # sigmas = (sigmas / self.sigma_max) * num_discrete_chunks
-        sigmas = jnp.log(sigmas) / 4
+        # Avoid log(0) by adding a small epsilon
+        epsilon = 1e-12
+        sigmas = jnp.log(sigmas + epsilon) / 4
         return x, sigmas
 
     def get_timesteps(self, sigmas:jnp.ndarray) -> jnp.ndarray:
         sigmas = sigmas.reshape(-1)
-        inv_rho = sigmas ** (1 / self.rho)
-        ramp = ((inv_rho - self.max_inv_rho) / (self.min_inv_rho - self.max_inv_rho))
-        steps = 1 - ramp * self.max_timesteps
+        # Add epsilon for numerical stability
+        epsilon = 1e-12
+        inv_rho = (sigmas + epsilon) ** (1 / self.rho)
+        # Clamp the denominator to avoid division by (near-)zero
+        denominator = (self.min_inv_rho - self.max_inv_rho)
+        if abs(denominator) < 1e-7:
+            denominator = jnp.sign(denominator) * 1e-7
+        ramp = jnp.clip((inv_rho - self.max_inv_rho) / denominator, 0.0, 1.0)
+        steps = jnp.clip(1 - ramp, 0.0, 1.0) * self.max_timesteps
         return steps
 
     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
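The `get_sigmas` ramp above is the rho-spaced schedule from Karras et al. (EDM): interpolate linearly in sigma**(1/rho) space, then raise back to the rho-th power so noise levels cluster near sigma_min. An equivalent standalone sketch, as a hypothetical helper with defaults taken from the constructor above:

    import jax.numpy as jnp

    def karras_sigmas(n: int, sigma_min=0.002, sigma_max=80.0, rho=7.0) -> jnp.ndarray:
        # Interpolate in sigma**(1/rho) space, highest noise first.
        ramp = jnp.linspace(0.0, 1.0, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho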
flaxdiff/trainer/__init__.py CHANGED
@@ -1,2 +1,3 @@
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-from .diffusion_trainer import DiffusionTrainer, TrainState
+from .diffusion_trainer import DiffusionTrainer, TrainState
+from .general_diffusion_trainer import GeneralDiffusionTrainer, ConditionalInputConfig
flaxdiff/trainer/autoencoder_trainer.py CHANGED
@@ -114,8 +114,7 @@ class AutoEncoderTrainer(SimpleTrainer):
         # normalize image
         images = (images - 127.5) / 127.5
 
-        output = text_embedder(
-            input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+        output = text_embedder.encode_from_tokens(batch['text'])
         label_seq = output.last_hidden_state
 
         # Generate random probabilities to decide how much of this batch will be unconditional
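The trailing context line refers to the usual classifier-free-guidance training trick: a random fraction of each batch has its conditioning replaced by a null embedding. An illustrative sketch of such per-example dropout (hypothetical helper; the trainer's actual implementation is not part of this diff):

    import jax
    import jax.numpy as jnp

    def drop_labels(rng, label_seq, null_seq, uncond_prob=0.1):
        # With probability uncond_prob, swap an example's conditioning
        # sequence [B, L, D] for the null embedding [L, D].
        mask = jax.random.bernoulli(rng, uncond_prob, (label_seq.shape[0],))
        return jnp.where(mask[:, None, None], null_seq, label_seq)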