flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/samplers/common.py
CHANGED
@@ -1,148 +1,368 @@
+from typing import Union, Type
+
 import jax
 import jax.numpy as jnp
 import tqdm
+from flax import linen as nn
+from typing import List, Tuple, Dict, Any, Optional
+
+from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 from ..schedulers import NoiseScheduler
 from ..utils import RandomMarkovState, MarkovState, clip_images
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, PartitionSpec as P
+from flaxdiff.models.autoencoder import AutoEncoder
+from flaxdiff.inputs import DiffusionInputConfig
 
-class DiffusionSampler():
+class DiffusionSampler:
+    """Base class for diffusion samplers."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        noise_schedule: NoiseScheduler,
+        model_output_transform: DiffusionPredictionTransform,
+        input_config: DiffusionInputConfig,
+        guidance_scale: float = 0.0,
+        autoencoder: AutoEncoder = None,
+        timestep_spacing: str = 'linear',
+    ):
+        """Initialize the diffusion sampler.
+
+        Args:
+            model: Neural network model
+            params: Model parameters
+            noise_schedule: Noise scheduler
+            model_output_transform: Transform for model predictions
+            guidance_scale: Scale for classifier-free guidance (0.0 means disabled)
+            autoencoder: Optional autoencoder for latent diffusion
+            timestep_spacing: Strategy for timestep spacing in sampling
+                'linear' - Default equal spacing
+                'quadratic' - Emphasizes early steps
+                'karras' - Based on EDM paper, better with fewer steps
+                'exponential' - Concentrates steps near the end
+        """
         self.model = model
         self.noise_schedule = noise_schedule
-        self.params = params
         self.model_output_transform = model_output_transform
         self.guidance_scale = guidance_scale
-        self.image_size = image_size
-        self.autoenc_scale_reduction = autoenc_scale_reduction
         self.autoencoder = autoencoder
+        self.timestep_spacing = timestep_spacing
+        self.input_config = input_config
+
+        unconditionals = input_config.get_unconditionals()
+
+        # For Karras spacing if needed
+        if hasattr(noise_schedule, 'min_inv_rho') and hasattr(noise_schedule, 'max_inv_rho'):
+            self.min_inv_rho = noise_schedule.min_inv_rho
+            self.max_inv_rho = noise_schedule.max_inv_rho
 
         if self.guidance_scale > 0:
             # Classifier free guidance
-            assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
             print("Using classifier-free guidance")
+
+            def sample_model(params, x_t, t, *conditioning_inputs):
                 # Concatenate unconditional and conditional inputs
                 x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
                 t_cat = jnp.concatenate([t] * 2, axis=0)
                 rates_cat = self.noise_schedule.get_rates(t_cat)
                 c_in_cat = self.model_output_transform.get_input_scale(rates_cat)
 
+                final_conditionals = []
+                for conditional, unconditional in zip(conditioning_inputs, unconditionals):
+                    final = jnp.concatenate([
+                        conditional,
+                        jnp.broadcast_to(unconditional, conditional.shape)
+                    ], axis=0)
+                    final_conditionals.append(final)
+                final_conditionals = tuple(final_conditionals)
+
+                model_output = self.model.apply(
+                    params,
+                    *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat),
+                    *final_conditionals
+                )
+
                 # Split model output into unconditional and conditional parts
                 model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
                 model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
+
                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
                 return x_0, eps, model_output
         else:
             # Unconditional sampling
-            def sample_model(params, x_t, t, *
+            def sample_model(params, x_t, t, *conditioning_inputs):
                 rates = self.noise_schedule.get_rates(t)
                 c_in = self.model_output_transform.get_input_scale(rates)
-                model_output = self.model.apply(
+                model_output = self.model.apply(
+                    params,
+                    *self.noise_schedule.transform_inputs(x_t * c_in, t),
+                    *conditioning_inputs
+                )
                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
                 return x_0, eps, model_output
 
+        # JIT compile the sampling function for better performance
+        def post_process(samples: jnp.ndarray):
+            """Post-process the generated samples."""
+            if autoencoder is not None:
+                samples = autoencoder.decode(samples)
+
+            samples = clip_images(samples)
+            return samples
+
+        self.sample_model = jax.jit(sample_model)
+        self.post_process = jax.jit(post_process)
+
+    def sample_step(
+        self,
+        sample_model_fn,
+        current_samples: jnp.ndarray,
+        current_step,
+        model_conditioning_inputs,
+        next_step=None,
+        state: RandomMarkovState = None
+    ) -> tuple[jnp.ndarray, RandomMarkovState]:
+        """Perform a single sampling step in the diffusion process.
+
+        Args:
+            sample_model_fn: Function to sample from model
+            current_samples: Current noisy samples
+            current_step: Current diffusion timestep
+            model_conditioning_inputs: Conditioning inputs for the model
+            next_step: Next diffusion timestep
+            state: Current Markov state
+
+        Returns:
+            Tuple of (new samples, updated state)
+        """
+        step_ones = jnp.ones((len(current_samples),), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
+
+        pred_images, pred_noise, _ = sample_model_fn(
+            current_samples, current_step, *model_conditioning_inputs
+        )
+
+        new_samples, state = self.take_next_step(
+            current_samples=current_samples,
+            reconstructed_samples=pred_images,
+            pred_noise=pred_noise,
+            current_step=current_step,
+            next_step=next_step,
+            state=state,
+            model_conditioning_inputs=model_conditioning_inputs,
+            sample_model_fn=sample_model_fn,
+        )
         return new_samples, state
 
+
+    def take_next_step(
+        self,
+        current_samples,
+        reconstructed_samples,
+        model_conditioning_inputs,
+        pred_noise,
+        current_step,
+        state: RandomMarkovState,
+        sample_model_fn,
+        next_step=1,
+    ) -> tuple[jnp.ndarray, RandomMarkovState]:
+        """Take the next step in the diffusion process.
+
+        This method needs to be implemented by subclasses.
+        """
+        raise NotImplementedError("Subclasses must implement take_next_step method")
+
+
     def scale_steps(self, steps):
+        """Scale timesteps to match the noise schedule's range."""
         scale_factor = self.noise_schedule.max_timesteps / 1000
         return steps * scale_factor
 
+
     def get_steps(self, start_step, end_step, diffusion_steps):
+        """Get the sequence of timesteps for the diffusion process.
+
+        Args:
+            start_step: Starting timestep (typically the max)
+            end_step: Ending timestep (typically 0)
+            diffusion_steps: Number of steps to use
+
+        Returns:
+            Array of timesteps for sampling
+        """
         step_range = start_step - end_step
         if diffusion_steps is None or diffusion_steps == 0:
-            diffusion_steps =
+            diffusion_steps = step_range
         diffusion_steps = min(diffusion_steps, step_range)
+
+        # Linear spacing (default)
+        if getattr(self, 'timestep_spacing', 'linear') == 'linear':
+            steps = jnp.linspace(
+                end_step, start_step,
+                diffusion_steps, dtype=jnp.int16
+            )[::-1]
+
+        # Quadratic spacing (emphasizes early steps)
+        elif self.timestep_spacing == 'quadratic':
+            steps = jnp.linspace(0, 1, diffusion_steps) ** 2
+            steps = (start_step - end_step) * steps + end_step
+            steps = jnp.asarray(steps, dtype=jnp.int16)[::-1]
+
+        # Karras spacing from the EDM paper - often gives better results with fewer steps
+        elif self.timestep_spacing == 'karras':
+            # Implementation based on the EDM paper's recommendations
+            sigma_min = end_step / start_step
+            sigma_max = 1.0
+            rho = 7.0  # Karras paper default, controls the distribution
+
+            # Create log-spaced steps in sigma space
+            sigmas = jnp.exp(jnp.linspace(
+                jnp.log(sigma_max), jnp.log(sigma_min), diffusion_steps
+            ))
+            steps = jnp.clip(
+                (sigmas ** (1 / rho) - self.min_inv_rho) /
+                (self.max_inv_rho - self.min_inv_rho),
+                0, 1
+            ) * start_step
+            steps = jnp.asarray(steps, dtype=jnp.int16)
+
+        # Exponential spacing (concentrates steps near the end)
+        elif self.timestep_spacing == 'exponential':
+            steps = jnp.linspace(0, 1, diffusion_steps)
+            steps = jnp.exp(steps * jnp.log((start_step + 1) / (end_step + 1))) * (end_step + 1) - 1
+            steps = jnp.clip(steps, end_step, start_step)
+            steps = jnp.asarray(steps, dtype=jnp.int16)[::-1]
+
+        # Fallback to linear spacing
+        else:
+            steps = jnp.linspace(
+                end_step, start_step,
+                diffusion_steps, dtype=jnp.int16
+            )[::-1]
+
         return steps
+
+
+    def generate_samples(
+        self,
+        params: dict,
+        num_samples: int,
+        resolution: int,
+        sequence_length: int = None,
+        diffusion_steps: int = 1000,
+        start_step: int = None,
+        end_step: int = 0,
+        steps_override=None,
+        priors=None,
+        rngstate: RandomMarkovState = None,
+        conditioning: List[Union[Tuple, Dict]] = None,
+        model_conditioning_inputs: Tuple = None,
+    ) -> jnp.ndarray:
+        """Generate samples using the diffusion model.
+
+        Provides a unified interface for generating both images and videos.
+        For images, just specify batch_size.
+        For videos, specify both batch_size and sequence_length.
+
+        Args:
+            params: Model parameters (uses self.params if None)
+            num_samples: Number of samples to generate (videos or images)
+            resolution: Resolution of the generated samples (H, W)
+            sequence_length: Length of each sequence (for videos/audio/etc)
+                If None, generates regular images
+            diffusion_steps: Number of diffusion steps to perform
+            start_step: Starting timestep (defaults to max)
+            end_step: Ending timestep
+            steps_override: Override default timestep sequence
+            priors: Prior samples to start from instead of noise
+            rngstate: Random state for reproducibility
+            conditioning: (Optional) List of conditioning inputs for the model
+            model_conditioning_inputs: (Optional) Pre-processed conditioning inputs
+
+        Returns:
+            Generated samples as a JAX array:
+            - For images: shape [batch_size, H, W, C]
+            - For videos: shape [batch_size, sequence_length, H, W, C]
+        """
         if rngstate is None:
             rngstate = RandomMarkovState(jax.random.PRNGKey(42))
+
+        if start_step is None:
+            start_step = self.noise_schedule.max_timesteps
+
         if priors is None:
+            # Determine if we're generating videos or images based on sequence_length
+            is_video = sequence_length is not None
+
             rngstate, newrngs = rngstate.get_random_key()
+
+            # Get sample shape based on whether we're generating video or images
+            if is_video:
+                samples = self._get_initial_sequence_samples(
+                    resolution, num_samples, sequence_length, newrngs, start_step
+                )
+            else:
+                samples = self._get_initial_samples(resolution, num_samples, newrngs, start_step)
         else:
             print("Using priors")
             if self.autoencoder is not None:
+                # Let the autoencoder handle both image and video priors
                 priors = self.autoencoder.encode(priors)
             samples = priors
-        params = params if params is not None else self.params
 
+        if conditioning is not None:
+            if model_conditioning_inputs is not None:
+                raise ValueError("Cannot provide both conditioning and model_conditioning_inputs")
+            print("Processing raw conditioning inputs to generate model conditioning inputs")
+            separated: Dict[str, List] = {}
+            for cond in self.input_config.conditions:
+                separated[cond.encoder.key] = []
+            # Separate the conditioning inputs, one for each condition
+            for vals in conditioning:
+                if isinstance(vals, tuple) or isinstance(vals, list):
+                    # If its a tuple, assume that the ordering aligns with the ordering of the conditions
+                    # Thus, use the conditioning encoder key as the key
+                    for cond, val in zip(self.input_config.conditions, vals):
+                        separated[cond.encoder.key].append(val)
+                elif isinstance(vals, dict):
+                    # If its a dict, use the encoder key as the key
+                    for cond in self.input_config.conditions:
+                        if cond.encoder.key in vals:
+                            separated[cond.encoder.key].append(vals[cond.encoder.key])
+                        else:
+                            raise ValueError(f"Conditioning input {cond.encoder.key} not found in provided dictionary")
+                else:
+                    # If its a single value, use the encoder key as the key
+                    for cond in self.input_config.conditions:
+                        separated[cond.encoder.key].append(vals)
+
+            # Now we have a dictionary of lists, one for each condition, encode them
+            finals = []
+            for cond in self.input_config.conditions:
+                # Get the encoder for the condition
+                encoder = cond.encoder
+                encoded = encoder(separated[encoder.key])
+                finals.append(encoded)
+
+            model_conditioning_inputs = tuple(finals)
+
+        if model_conditioning_inputs is None:
+            model_conditioning_inputs = []
+
         def sample_model_fn(x_t, t, *additional_inputs):
             return self.sample_model(params, x_t, t, *additional_inputs)
 
+        def sample_step(sample_model_fn, state: RandomMarkovState, samples, current_step, next_step):
+            samples, state = self.sample_step(
+                sample_model_fn=sample_model_fn,
+                current_samples=samples,
+                current_step=current_step,
+                model_conditioning_inputs=model_conditioning_inputs,
+                state=state,
+                next_step=next_step
+            )
             return samples, state
 
         if start_step is None:
@@ -153,19 +373,61 @@ class DiffusionSampler():
         else:
             steps = self.get_steps(start_step, end_step, diffusion_steps)
 
-        # print("Sampling steps", steps)
         for i in tqdm.tqdm(range(0, len(steps))):
             current_step = self.scale_steps(steps[i])
             next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0)
+
             if i != len(steps) - 1:
+                samples, rngstate = sample_step(
+                    sample_model_fn, rngstate, samples, current_step, next_step
+                )
             else:
+                step_ones = jnp.ones((samples.shape[0],), dtype=jnp.int32)
+                samples, _, _ = sample_model_fn(
+                    samples, current_step * step_ones, *model_conditioning_inputs
+                )
+        return self.post_process(samples)
+
+    def _get_noise_parameters(self, resolution, start_step):
+        """Calculate common noise parameters for sample generation.
+
+        Args:
+            start_step: Starting timestep for noise generation
+
+        Returns:
+            Tuple of (variance, image_size, image_channels)
+        """
+        start_step = self.scale_steps(start_step)
+        alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
+        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
+
+        image_size = resolution
+        image_channels = 3
         if self.autoencoder is not None:
+            image_size = image_size // self.autoencoder.downscale_factor
+            image_channels = self.autoencoder.latent_channels
+
+        return variance, image_size, image_channels
+
+    def _get_initial_samples(self, resolution, batch_size, rngs: jax.random.PRNGKey, start_step):
+        """Generate initial noisy samples for image generation."""
+        variance, image_size, image_channels = self._get_noise_parameters(resolution, start_step)
+
+        # Standard image generation
+        return jax.random.normal(
+            rngs,
+            (batch_size, image_size, image_size, image_channels)
+        ) * variance
+
+    def _get_initial_sequence_samples(self, resolution, batch_size, sequence_length, rngs: jax.random.PRNGKey, start_step):
+        """Generate initial noisy samples for sequence data (video/audio)."""
+        variance, image_size, image_channels = self._get_noise_parameters(resolution, start_step)
+
+        # Generate sequence data (like video)
+        return jax.random.normal(
+            rngs,
+            (batch_size, sequence_length, image_size, image_size, image_channels)
+        ) * variance
+
+    # Alias for backward compatibility
+    generate_images = generate_samples
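
Note: the rewritten DiffusionSampler is configured with a DiffusionInputConfig and a timestep_spacing strategy instead of the old explicit params/image_size arguments, and sampling goes through the new generate_samples entry point (generate_images remains as an alias). Below is a minimal usage sketch based only on the argument names visible in this diff; the construction of the model, schedule, prediction transform, input config, and params is not shown here, so those variables are placeholders rather than documented API:

    from flaxdiff.samplers.ddim import DDIMSampler

    # All of `unet`, `schedule`, `transform`, `input_config`, and `params` are
    # assumed to be built elsewhere with the rest of the flaxdiff 0.2.0 API.
    sampler = DDIMSampler(
        model=unet,                        # flax linen module
        noise_schedule=schedule,           # NoiseScheduler
        model_output_transform=transform,  # DiffusionPredictionTransform
        input_config=input_config,         # DiffusionInputConfig (new in 0.2.0)
        guidance_scale=3.0,                # > 0 enables classifier-free guidance
        autoencoder=None,                  # pass an AutoEncoder for latent diffusion
        timestep_spacing='karras',         # 'linear' | 'quadratic' | 'karras' | 'exponential'
    )

    # Images come back as [num_samples, H, W, C]; passing sequence_length instead
    # yields video-shaped noise [num_samples, sequence_length, H, W, C].
    images = sampler.generate_samples(
        params=params,
        num_samples=4,
        resolution=64,
        diffusion_steps=50,
        conditioning=["a photo of a cat"] * 4,  # raw inputs, encoded via input_config
    )
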
flaxdiff/samplers/ddim.py
CHANGED
@@ -1,10 +1,49 @@
 import jax.numpy as jnp
 from .common import DiffusionSampler
 from ..utils import MarkovState, RandomMarkovState
+import jax
+from flaxdiff.schedulers import get_coeff_shapes_tuple
 
 class DDIMSampler(DiffusionSampler):
+    def __init__(self, *args, eta=0.0, **kwargs):
+        """Initialize DDIM sampler with customizable noise level.
+
+        Args:
+            eta: Controls the stochasticity of the sampler.
+                0.0 = deterministic (DDIM), 1.0 = DDPM-like.
+        """
+        super().__init__(*args, **kwargs)
+        self.eta = eta
+
+    def take_next_step(
+        self,
+        current_samples,
+        reconstructed_samples,
+        model_conditioning_inputs,
+        pred_noise,
+        current_step,
+        state: RandomMarkovState,
+        sample_model_fn,
+        next_step=1
+    ) -> tuple[jnp.ndarray, RandomMarkovState]:
+        # Get diffusion coefficients for current and next timesteps
+        alpha_t, sigma_t = self.noise_schedule.get_rates(current_step, get_coeff_shapes_tuple(current_samples))
+        alpha_next, sigma_next = self.noise_schedule.get_rates(next_step, get_coeff_shapes_tuple(current_samples))
+
+        # Extract random noise if needed for stochastic sampling
+        if self.eta > 0:
+            # For DDIM, we need to compute the variance coefficient
+            # This is based on the original DDIM paper's formula
+            # When eta=0, it's deterministic DDIM, when eta=1.0 it approaches DDPM
+            sigma_tilde = self.eta * sigma_next * (1 - alpha_t**2 / alpha_next**2).sqrt() / (1 - alpha_t**2).sqrt()
+            state, noise_key = state.get_random_key()
+            noise = jax.random.normal(noise_key, current_samples.shape)
+            # Add the stochastic component
+            stochastic_term = sigma_tilde * noise
+        else:
+            stochastic_term = 0
+
+        # Direct DDIM update formula
+        new_samples = alpha_next * reconstructed_samples + sigma_next * pred_noise + stochastic_term
+
+        return new_samples, state
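
Note: take_next_step above is the standard DDIM update; the predicted clean sample and predicted noise are recombined using the next step's (alpha, sigma) coefficients, plus an optional eta-scaled stochastic term (eta = 0 is deterministic DDIM, eta = 1.0 approaches DDPM). A standalone sketch of the deterministic case, with hypothetical coefficient values rather than a real schedule:

    import jax.numpy as jnp

    # Hypothetical coefficients for the next timestep; in the sampler these come
    # from noise_schedule.get_rates(next_step, ...).
    alpha_next, sigma_next = 0.98, 0.20

    def ddim_step(x0_pred: jnp.ndarray, eps_pred: jnp.ndarray) -> jnp.ndarray:
        # Deterministic DDIM update (eta = 0): re-noise the predicted clean
        # sample down to the next, lower noise level.
        return alpha_next * x0_pred + sigma_next * eps_pred
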
flaxdiff/schedulers/karras.py
CHANGED
@@ -5,35 +5,43 @@ import jax
 from ..utils import RandomMarkovState
 
 class KarrasVENoiseScheduler(GeneralizedNoiseScheduler):
-    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
+    def __init__(self, timesteps=1.0, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
         super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
         self.min_inv_rho = sigma_min ** (1 / rho)
         self.max_inv_rho = sigma_max ** (1 / rho)
         self.rho = rho
 
     def get_sigmas(self, steps) -> jnp.ndarray:
-        ramp = 1 - steps / self.max_timesteps
+        # Ensure steps are properly normalized and clamped to avoid edge cases
+        ramp = jnp.clip(1 - steps / self.max_timesteps, 0.0, 1.0)
         sigmas = (self.max_inv_rho + ramp * (self.min_inv_rho - self.max_inv_rho)) ** self.rho
         return sigmas
 
     def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
         sigma = self.get_sigmas(steps)
+        # Add epsilon for numerical stability
+        epsilon = 1e-6
+        weights = ((sigma ** 2 + self.sigma_data ** 2) / ((sigma * self.sigma_data) ** 2 + epsilon))
         return weights.reshape(shape)
 
     def transform_inputs(self, x, steps, num_discrete_chunks=1000) -> tuple[jnp.ndarray, jnp.ndarray]:
         sigmas = self.get_sigmas(steps)
+        # Avoid log(0) by adding a small epsilon
+        epsilon = 1e-12
+        sigmas = jnp.log(sigmas + epsilon) / 4
         return x, sigmas
 
     def get_timesteps(self, sigmas:jnp.ndarray) -> jnp.ndarray:
         sigmas = sigmas.reshape(-1)
+        # Add epsilon for numerical stability
+        epsilon = 1e-12
+        inv_rho = (sigmas + epsilon) ** (1 / self.rho)
+        # Ensure proper clamping to avoid numerical issues
+        denominator = (self.min_inv_rho - self.max_inv_rho)
+        if abs(denominator) < 1e-7:
+            denominator = jnp.sign(denominator) * 1e-7
+        ramp = jnp.clip((inv_rho - self.max_inv_rho) / denominator, 0.0, 1.0)
+        steps = jnp.clip(1 - ramp, 0.0, 1.0) * self.max_timesteps
         return steps
 
     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
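
Note: get_sigmas keeps the EDM/Karras rho-interpolation and now clamps the normalized ramp. As a standalone illustration of the sigma ladder this formula produces with the class defaults (sigma_min=0.002, sigma_max=80, rho=7); this is the same expression as in get_sigmas above, not a call into the package:

    import jax.numpy as jnp

    sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)

    # ramp runs 0 -> 1 as the normalized timestep goes from max_timesteps down to 0,
    # mirroring get_sigmas(); 10 evenly spaced points shown here.
    ramp = jnp.linspace(0.0, 1.0, 10)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    # sigmas decays from sigma_max (80) to sigma_min (0.002), with most of the
    # points concentrated at the small-noise end for rho = 7.
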
flaxdiff/trainer/__init__.py
CHANGED
@@ -1,2 +1,3 @@
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-from .diffusion_trainer import DiffusionTrainer, TrainState
+from .diffusion_trainer import DiffusionTrainer, TrainState
+from .general_diffusion_trainer import GeneralDiffusionTrainer, ConditionalInputConfig
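
Note: with the added export, the new 0.2.0 trainer and its conditioning config can be imported directly from the trainer package. Only the import surface is shown here; trainer construction is outside this diff:

    from flaxdiff.trainer import (
        SimpleTrainer,
        DiffusionTrainer,
        GeneralDiffusionTrainer,   # new in 0.2.0
        ConditionalInputConfig,    # new in 0.2.0
    )
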
flaxdiff/trainer/autoencoder_trainer.py
CHANGED
@@ -114,8 +114,7 @@ class AutoEncoderTrainer(SimpleTrainer):
         # normalize image
         images = (images - 127.5) / 127.5
 
-        output = text_embedder(
-            input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+        output = text_embedder.encode_from_tokens(batch['text'])
         label_seq = output.last_hidden_state
 
         # Generate random probabilities to decide how much of this batch will be unconditional