flaxdiff 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,519 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Dict, Callable, Sequence, Any, Union
+ import einops
+ from .common import kernel_init
+ from .attention import TransformerBlock
+
+ class WeightStandardizedConv(nn.Module):
+     """
+     Applies weight standardization (https://arxiv.org/abs/1903.10520).
+     """
+     features: int
+     kernel_size: Sequence[int] = 3
+     strides: Union[None, int, Sequence[int]] = 1
+     padding: Any = 1
+     dtype: Any = jnp.float32
+     param_dtype: Any = jnp.float32
+     # Accepted so ConvLayer can forward it (see the "w_conv" branch below).
+     precision: Any = None
+
+     @nn.compact
+     def __call__(self, x):
+         """
+         Applies a weight standardized convolution to the inputs.
+
+         Args:
+             x: input data with dimensions (batch, spatial_dims..., features).
+
+         Returns:
+             The convolved data.
+         """
+         x = x.astype(self.dtype)
+
+         conv = nn.Conv(
+             features=self.features,
+             kernel_size=self.kernel_size,
+             strides=self.strides,
+             padding=self.padding,
+             dtype=self.dtype,
+             param_dtype=self.param_dtype,
+             precision=self.precision,
+             parent=None)
+
+         kernel_init = lambda rng, x: conv.init(rng, x)['params']['kernel']
+         bias_init = lambda rng, x: conv.init(rng, x)['params']['bias']
+
+         # Standardize the kernel: zero mean, unit variance over all axes except dim_out.
+         kernel = self.param('kernel', kernel_init, x)
+         eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
+         # reduce over every axis except dim_out
+         redux = tuple(range(kernel.ndim - 1))
+         mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+         var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+         standardized_kernel = (kernel - mean) / jnp.sqrt(var + eps)
+
+         bias = self.param('bias', bias_init, x)
+
+         return conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}}, x)
+
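For illustration, a minimal usage sketch for the module above (not part of the package; it assumes jax and flax are installed, the class is in scope, and the shapes are purely illustrative):

    import jax
    import jax.numpy as jnp

    wsconv = WeightStandardizedConv(features=16, kernel_size=(3, 3), strides=1, padding="SAME")
    x = jnp.zeros((1, 32, 32, 3))                      # NHWC input
    variables = wsconv.init(jax.random.PRNGKey(0), x)  # creates the 'kernel' and 'bias' params
    y = wsconv.apply(variables, x)                     # convolved output, shape (1, 32, 32, 16)
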
+ class PixelShuffle(nn.Module):
+     scale: int
+
+     @nn.compact
+     def __call__(self, x):
+         up = einops.rearrange(
+             x,
+             pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
+             h2=self.scale,
+             w2=self.scale,
+         )
+         return up
+
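A small sketch of what the rearrange above does (illustrative, not from the package): with scale 2, each group of scale**2 channels is redistributed into a 2x2 spatial block, so the channel count must be divisible by scale**2:

    import jax.numpy as jnp

    ps = PixelShuffle(scale=2)
    x = jnp.zeros((1, 8, 8, 12))   # channels = 2 * 2 * 3
    y = ps.apply({}, x)            # no parameters, so empty variables; y.shape == (1, 16, 16, 3)
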
+ class TimeEmbedding(nn.Module):
+     features: int
+     max_positions: int = 10000
+
+     def setup(self):
+         half_dim = self.features // 2
+         emb = jnp.log(self.max_positions) / (half_dim - 1)
+         emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
+         self.embeddings = emb
+
+     def __call__(self, x):
+         x = jax.lax.convert_element_type(x, jnp.float32)
+         emb = x[:, None] * self.embeddings[None, :]
+         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+         return emb
+
+ class FourierEmbedding(nn.Module):
+     features: int
+     scale: int = 16
+
+     def setup(self):
+         self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2,), dtype=jnp.float32) * self.scale
+
+     def __call__(self, x):
+         x = jax.lax.convert_element_type(x, jnp.float32)
+         emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
+         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+         return emb
+
+ class TimeProjection(nn.Module):
+     features: int
+     activation: Callable = jax.nn.gelu
+
+     @nn.compact
+     def __call__(self, x):
+         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+         x = self.activation(x)
+         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+         x = self.activation(x)
+         return x
+
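For illustration, a minimal sketch (not from the package) of how the sinusoidal TimeEmbedding maps a batch of timesteps to a (batch, features) vector, with sine terms in the first half and cosine terms in the second:

    import jax
    import jax.numpy as jnp

    temb_mod = TimeEmbedding(features=128)
    t = jnp.array([0.0, 10.0, 100.0])                    # one timestep per sample
    variables = temb_mod.init(jax.random.PRNGKey(0), t)  # no trainable parameters here
    emb = temb_mod.apply(variables, t)                   # emb.shape == (3, 128)
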
+ class SeparableConv(nn.Module):
+     features: int
+     kernel_size: tuple = (3, 3)
+     strides: tuple = (1, 1)
+     use_bias: bool = False
+     kernel_init: Callable = kernel_init(1.0)
+     padding: str = "SAME"
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x):
+         in_features = x.shape[-1]
+         # Depthwise convolution: one filter per input channel.
+         depthwise = nn.Conv(
+             features=in_features, kernel_size=self.kernel_size,
+             strides=self.strides, kernel_init=self.kernel_init,
+             feature_group_count=in_features, use_bias=self.use_bias,
+             padding=self.padding,
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         # Pointwise 1x1 convolution mixes the channels.
+         pointwise = nn.Conv(
+             features=self.features, kernel_size=(1, 1),
+             strides=(1, 1), kernel_init=self.kernel_init,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             precision=self.precision
+         )(depthwise)
+         return pointwise
+
+ class ConvLayer(nn.Module):
+     conv_type: str
+     features: int
+     kernel_size: tuple = (3, 3)
+     strides: tuple = (1, 1)
+     kernel_init: Callable = kernel_init(1.0)
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     def setup(self):
+         # conv_type can be "conv", "w_conv", "separable" or "conv_transpose"
+         if self.conv_type == "conv":
+             self.conv = nn.Conv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "w_conv":
+             self.conv = WeightStandardizedConv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 padding="SAME",
+                 param_dtype=self.dtype,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "separable":
+             self.conv = SeparableConv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "conv_transpose":
+             self.conv = nn.ConvTranspose(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+
+     def __call__(self, x):
+         return self.conv(x)
+
+ class Upsample(nn.Module):
+     features: int
+     scale: int
+     activation: Callable = jax.nn.swish
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, residual=None):
+         out = x
+         # out = PixelShuffle(scale=self.scale)(out)
+         B, H, W, C = x.shape
+         out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
+         out = ConvLayer(
+             "conv",
+             features=self.features,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             dtype=self.dtype,
+             precision=self.precision
+         )(out)
+         if residual is not None:
+             out = jnp.concatenate([out, residual], axis=-1)
+         return out
+
+ class Downsample(nn.Module):
+     features: int
+     scale: int
+     activation: Callable = jax.nn.swish
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, residual=None):
+         out = ConvLayer(
+             "conv",
+             features=self.features,
+             kernel_size=(3, 3),
+             strides=(2, 2),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         if residual is not None:
+             if residual.shape[1] > out.shape[1]:
+                 residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
+             out = jnp.concatenate([out, residual], axis=-1)
+         return out
+
+
+ def l2norm(t, axis=1, eps=1e-12):
+     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
+     out = t / denom
+     return out
+
+ class ResidualBlock(nn.Module):
+     conv_type: str
+     features: int
+     kernel_size: tuple = (3, 3)
+     strides: tuple = (1, 1)
+     padding: str = "SAME"
+     activation: Callable = jax.nn.swish
+     direction: str = None
+     res: int = 2
+     norm_groups: int = 8
+     kernel_init: Callable = kernel_init(1.0)
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGHEST
+
+     @nn.compact
+     def __call__(self, x: jax.Array, temb: jax.Array, textemb: jax.Array = None, extra_features: jax.Array = None):
+         residual = x
+         out = nn.GroupNorm(self.norm_groups)(x)
+         out = self.activation(out)
+
+         out = ConvLayer(
+             self.conv_type,
+             features=self.features,
+             kernel_size=self.kernel_size,
+             strides=self.strides,
+             kernel_init=self.kernel_init,
+             name="conv1",
+             dtype=self.dtype,
+             precision=self.precision
+         )(out)
+
+         # Project the time embedding and add it as a per-channel bias.
+         temb = nn.DenseGeneral(
+             features=self.features,
+             name="temb_projection",
+             dtype=self.dtype,
+             precision=self.precision)(temb)
+         temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
+         # scale, shift = jnp.split(temb, 2, axis=-1)
+         # out = out * (1 + scale) + shift
+         out = out + temb
+
+         out = nn.GroupNorm(self.norm_groups)(out)
+         out = self.activation(out)
+
+         out = ConvLayer(
+             self.conv_type,
+             features=self.features,
+             kernel_size=self.kernel_size,
+             strides=self.strides,
+             kernel_init=self.kernel_init,
+             name="conv2",
+             dtype=self.dtype,
+             precision=self.precision
+         )(out)
+
+         if residual.shape != out.shape:
+             residual = ConvLayer(
+                 self.conv_type,
+                 features=self.features,
+                 kernel_size=(1, 1),
+                 strides=1,
+                 kernel_init=self.kernel_init,
+                 name="residual_conv",
+                 dtype=self.dtype,
+                 precision=self.precision
+             )(residual)
+         out = out + residual
+
+         out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
+
+         return out
+
+ class Unet(nn.Module):
+     # Defaults are immutable (tuples) so the flax.linen dataclass accepts them.
+     emb_features: int = 64 * 4
+     feature_depths: Sequence[int] = (64, 128, 256, 512)
+     attention_configs: Sequence[Dict] = ({"heads": 8}, {"heads": 8}, {"heads": 8}, {"heads": 8})
+     num_res_blocks: int = 2
+     num_middle_res_blocks: int = 1
+     activation: Callable = jax.nn.swish
+     norm_groups: int = 8
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext=None):
+         # print("embedding features", self.emb_features)
+         temb = FourierEmbedding(features=self.emb_features)(temb)
+         temb = TimeProjection(features=self.emb_features)(temb)
+
+         _, TS, TC = textcontext.shape
+
+         # print("time embedding", temb.shape)
+         feature_depths = self.feature_depths
+         attention_configs = self.attention_configs
+
+         conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
+         # middle_conv_type = "separable"
+
+         x = ConvLayer(
+             conv_type,
+             features=self.feature_depths[0],
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             kernel_init=kernel_init(1.0),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         downs = [x]
+
+         # Downscaling blocks
+         for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
+             dim_in = x.shape[-1]
+             # dim_in = dim_out
+             for j in range(self.num_res_blocks):
+                 x = ResidualBlock(
+                     down_conv_type,
+                     name=f"down_{i}_residual_{j}",
+                     features=dim_in,
+                     kernel_init=kernel_init(1.0),
+                     kernel_size=(3, 3),
+                     strides=(1, 1),
+                     activation=self.activation,
+                     norm_groups=self.norm_groups,
+                     dtype=self.dtype,
+                     precision=self.precision
+                 )(x, temb)
+                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
+                     B, H, W, _ = x.shape
+                     if H > TS:
+                         padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
+                     else:
+                         padded_context = None
+                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
+                                          dim_head=dim_in // attention_config['heads'],
+                                          use_flash_attention=attention_config.get("flash_attention", True),
+                                          use_projection=attention_config.get("use_projection", False),
+                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                                          precision=attention_config.get("precision", self.precision),
+                                          name=f"down_{i}_attention_{j}")(x, padded_context)
+                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
+                 downs.append(x)
+             if i != len(feature_depths) - 1:
+                 # print("Downsample", i, x.shape)
+                 x = Downsample(
+                     features=dim_out,
+                     scale=2,
+                     activation=self.activation,
+                     name=f"down_{i}_downsample",
+                     dtype=self.dtype,
+                     precision=self.precision
+                 )(x)
+
+         # Middle Blocks
+         middle_dim_out = self.feature_depths[-1]
+         middle_attention = self.attention_configs[-1]
+         for j in range(self.num_middle_res_blocks):
+             x = ResidualBlock(
+                 middle_conv_type,
+                 name=f"middle_res1_{j}",
+                 features=middle_dim_out,
+                 kernel_init=kernel_init(1.0),
+                 kernel_size=(3, 3),
+                 strides=(1, 1),
+                 activation=self.activation,
+                 norm_groups=self.norm_groups,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )(x, temb)
+             if middle_attention is not None and j == self.num_middle_res_blocks - 1:  # Apply attention only on the last block
+                 x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
+                                      dim_head=middle_dim_out // middle_attention['heads'],
+                                      use_flash_attention=middle_attention.get("flash_attention", True),
+                                      use_linear_attention=False,
+                                      use_projection=middle_attention.get("use_projection", False),
+                                      use_self_and_cross=False,
+                                      precision=middle_attention.get("precision", self.precision),
+                                      name=f"middle_attention_{j}")(x)
+             x = ResidualBlock(
+                 middle_conv_type,
+                 name=f"middle_res2_{j}",
+                 features=middle_dim_out,
+                 kernel_init=kernel_init(1.0),
+                 kernel_size=(3, 3),
+                 strides=(1, 1),
+                 activation=self.activation,
+                 norm_groups=self.norm_groups,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )(x, temb)
+
+         # Upscaling Blocks
+         for i, (dim_out, attention_config) in enumerate(zip(reversed(feature_depths), reversed(attention_configs))):
+             # print("Upscaling", i, "features", dim_out)
+             for j in range(self.num_res_blocks):
+                 x = jnp.concatenate([x, downs.pop()], axis=-1)
+                 # print("concat==> ", i, "concat", x.shape)
+                 # kernel_size = (1 + 2 * (j + 1), 1 + 2 * (j + 1))
+                 kernel_size = (3, 3)
+                 x = ResidualBlock(
+                     up_conv_type,  # if j == 0 else "separable",
+                     name=f"up_{i}_residual_{j}",
+                     features=dim_out,
+                     kernel_init=kernel_init(1.0),
+                     kernel_size=kernel_size,
+                     strides=(1, 1),
+                     activation=self.activation,
+                     norm_groups=self.norm_groups,
+                     dtype=self.dtype,
+                     precision=self.precision
+                 )(x, temb)
+                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
+                     B, H, W, _ = x.shape
+                     if H > TS:
+                         padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
+                     else:
+                         padded_context = None
+                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
+                                          dim_head=dim_out // attention_config['heads'],
+                                          use_flash_attention=attention_config.get("flash_attention", True),
+                                          use_projection=attention_config.get("use_projection", False),
+                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                                          precision=attention_config.get("precision", self.precision),
+                                          name=f"up_{i}_attention_{j}")(x, padded_context)
+             # print("Upscaling ", i, x.shape)
+             if i != len(feature_depths) - 1:
+                 x = Upsample(
+                     features=feature_depths[-i],
+                     scale=2,
+                     activation=self.activation,
+                     name=f"up_{i}_upsample",
+                     dtype=self.dtype,
+                     precision=self.precision
+                 )(x)
+
+         # x = nn.GroupNorm(8)(x)
+         x = ConvLayer(
+             conv_type,
+             features=self.feature_depths[0],
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             kernel_init=kernel_init(0.0),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+
+         x = jnp.concatenate([x, downs.pop()], axis=-1)
+
+         x = ResidualBlock(
+             conv_type,
+             name="final_residual",
+             features=self.feature_depths[0],
+             kernel_init=kernel_init(1.0),
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             activation=self.activation,
+             norm_groups=self.norm_groups,
+             dtype=self.dtype,
+             precision=self.precision
+         )(x, temb)
+
+         x = nn.GroupNorm(self.norm_groups)(x)
+         x = self.activation(x)
+
+         noise_out = ConvLayer(
+             conv_type,
+             features=3,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             # activation=jax.nn.mish
+             kernel_init=kernel_init(0.0),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         return noise_out  # , attentions
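To orient readers, a minimal, hypothetical initialization sketch for the Unet above (not part of the package). It assumes the class and its .common / .attention dependencies are importable; all shapes and config values are illustrative, and flash attention is disabled here since its availability depends on TransformerBlock and the hardware:

    import jax
    import jax.numpy as jnp

    unet = Unet(
        emb_features=256,
        feature_depths=(64, 128, 256, 512),
        attention_configs=({"heads": 8, "flash_attention": False},) * 4,
    )
    x = jnp.zeros((1, 256, 256, 3))        # NHWC image batch
    t = jnp.zeros((1,))                    # one diffusion timestep per sample
    textcontext = jnp.zeros((1, 16, 768))  # (batch, tokens, embedding dim) -- hypothetical sizes
    variables = unet.init(jax.random.PRNGKey(0), x, t, textcontext)
    noise_pred = unet.apply(variables, x, t, textcontext)  # same spatial shape as x, 3 channels
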
@@ -0,0 +1,96 @@
+ from typing import Union
+ import jax.numpy as jnp
+ from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
+
+ ############################################################################################################
+ # Prediction Transforms
+ ############################################################################################################
+
+ class DiffusionPredictionTransform:
+     def pred_transform(self, x_t, preds, rates) -> jnp.ndarray:
+         return preds
+
+     def __call__(self, x_t, preds, current_step, noise_schedule: NoiseScheduler) -> tuple[jnp.ndarray, jnp.ndarray]:
+         rates = noise_schedule.get_rates(current_step)
+         preds = self.pred_transform(x_t, preds, rates)
+         x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
+         return x_0, epsilon
+
+     def forward_diffusion(self, x_0, epsilon, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+         signal_rate, noise_rate = rates
+         x_t = signal_rate * x_0 + noise_rate * epsilon
+         expected_output = self.get_target(x_0, epsilon, (signal_rate, noise_rate))
+         c_in = self.get_input_scale((signal_rate, noise_rate))
+         return x_t, c_in, expected_output
+
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         raise NotImplementedError
+
+     def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
+         return x_0
+
+     def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+         return 1
+
+ class EpsilonPredictionTransform(DiffusionPredictionTransform):
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         # preds is the predicted noise
+         epsilon = preds
+         signal_rates, noise_rates = rates
+         x_0 = (x_t - epsilon * noise_rates) / signal_rates
+         return x_0, epsilon
+
+     def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
+         return epsilon
+
+ class DirectPredictionTransform(DiffusionPredictionTransform):
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         # Here the model predicts x_0 directly
+         x_0 = preds
+         signal_rate, noise_rate = rates
+         epsilon = (x_t - x_0 * signal_rate) / noise_rate
+         return x_0, epsilon
+
+ class VPredictionTransform(DiffusionPredictionTransform):
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         # Here the model outputs v = signal_rate * epsilon - noise_rate * x_0,
+         # where epsilon is the noise and x_0 is the clean sample.
+         v = preds
+         signal_rate, noise_rate = rates
+         variance = signal_rate ** 2 + noise_rate ** 2
+         v = v * jnp.sqrt(variance)
+         x_0 = signal_rate * x_t - noise_rate * v
+         eps_0 = signal_rate * v + noise_rate * x_t
+         return x_0 / variance, eps_0 / variance
+
+     def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
+         signal_rate, noise_rate = rates
+         v = signal_rate * epsilon - noise_rate * x_0
+         variance = signal_rate ** 2 + noise_rate ** 2
+         return v / jnp.sqrt(variance)
+
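As a sanity check on the algebra above, a small numeric sketch (not from the package) with unit-variance rates, where the formulas reduce to v = sr * eps - nr * x0 and the recovery in backward_diffusion is exact:

    sr, nr = 0.8, 0.6            # sr**2 + nr**2 == 1, so variance == 1
    x0, eps = 2.0, -1.0
    v = sr * eps - nr * x0       # the training target, as in get_target
    x_t = sr * x0 + nr * eps     # forward diffusion
    x0_rec = sr * x_t - nr * v   # recovers x0 (2.0, up to float rounding)
    eps_rec = sr * v + nr * x_t  # recovers eps (-1.0, up to float rounding)
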
+ class KarrasPredictionTransform(DiffusionPredictionTransform):
+     def __init__(self, sigma_data=0.5) -> None:
+         super().__init__()
+         self.sigma_data = sigma_data
+
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         x_0 = preds
+         signal_rate, noise_rate = rates
+         epsilon = (x_t - x_0 * signal_rate) / noise_rate
+         return x_0, epsilon
+
+     def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+         _, sigma = rates
+         c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
+         c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
+         c_out = c_out.reshape((-1, 1, 1, 1))
+         c_skip = c_skip.reshape((-1, 1, 1, 1))
+         x_0 = c_out * preds + c_skip * x_t
+         return x_0
+
+     def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+         _, sigma = rates
+         c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
+         return c_in
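For illustration, a minimal, hypothetical sketch (not from the package) of how a prediction transform pairs the forward and backward directions; the rates would normally come from a NoiseScheduler, so the constants below are placeholders:

    import jax.numpy as jnp

    transform = EpsilonPredictionTransform()
    x_0 = jnp.zeros((4, 32, 32, 3))            # clean samples
    epsilon = jnp.ones_like(x_0)               # noise (drawn elsewhere in practice)
    signal_rate = jnp.full((4, 1, 1, 1), 0.9)  # placeholder rates
    noise_rate = jnp.sqrt(1.0 - signal_rate ** 2)
    x_t, c_in, target = transform.forward_diffusion(x_0, epsilon, (signal_rate, noise_rate))
    # For epsilon prediction the training target is the noise itself (target == epsilon),
    # and backward_diffusion(x_t, target, (signal_rate, noise_rate)) recovers (x_0, epsilon).
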
@@ -0,0 +1,7 @@
+ from .common import DiffusionSampler
+ from .ddim import DDIMSampler
+ from .ddpm import DDPMSampler, SimpleDDPMSampler
+ from .euler import EulerSampler, SimplifiedEulerSampler
+ from .heun_sampler import HeunSampler
+ from .rk4_sampler import RK4Sampler
+ from .multistep_dpm import MultiStepDPM