flaxdiff-0.1.1-py3-none-any.whl
This diff represents the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- flaxdiff/__init__.py +0 -0
- flaxdiff/models/__init__.py +1 -0
- flaxdiff/models/attention.py +489 -0
- flaxdiff/models/common.py +7 -0
- flaxdiff/models/favor_fastattn.py +723 -0
- flaxdiff/models/simple_unet.py +519 -0
- flaxdiff/predictors/__init__.py +96 -0
- flaxdiff/samplers/__init__.py +7 -0
- flaxdiff/samplers/common.py +113 -0
- flaxdiff/samplers/ddim.py +10 -0
- flaxdiff/samplers/ddpm.py +43 -0
- flaxdiff/samplers/euler.py +59 -0
- flaxdiff/samplers/heun_sampler.py +28 -0
- flaxdiff/samplers/multistep_dpm.py +60 -0
- flaxdiff/samplers/rk4_sampler.py +34 -0
- flaxdiff/schedulers/__init__.py +6 -0
- flaxdiff/schedulers/common.py +98 -0
- flaxdiff/schedulers/continuous.py +12 -0
- flaxdiff/schedulers/cosine.py +40 -0
- flaxdiff/schedulers/discrete.py +74 -0
- flaxdiff/schedulers/exp.py +13 -0
- flaxdiff/schedulers/karras.py +69 -0
- flaxdiff/schedulers/linear.py +14 -0
- flaxdiff/schedulers/sqrt.py +10 -0
- flaxdiff/trainer/__init__.py +216 -0
- flaxdiff/utils.py +89 -0
- flaxdiff-0.1.1.dist-info/METADATA +326 -0
- flaxdiff-0.1.1.dist-info/RECORD +30 -0
- flaxdiff-0.1.1.dist-info/WHEEL +5 -0
- flaxdiff-0.1.1.dist-info/top_level.txt +1 -0
flaxdiff/models/simple_unet.py
@@ -0,0 +1,519 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Dict, Callable, Sequence, Any, Union
import einops
from .common import kernel_init
from .attention import TransformerBlock

class WeightStandardizedConv(nn.Module):
    """
    apply weight standardization https://arxiv.org/abs/1903.10520
    """
    features: int
    kernel_size: Sequence[int] = 3
    strides: Union[None, int, Sequence[int]] = 1
    padding: Any = 1
    dtype: Any = jnp.float32
    param_dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        """
        Applies a weight standardized convolution to the inputs.

        Args:
          inputs: input data with dimensions (batch, spatial_dims..., features).

        Returns:
          The convolved data.
        """
        x = x.astype(self.dtype)

        conv = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            parent=None)

        kernel_init = lambda rng, x: conv.init(rng, x)['params']['kernel']
        bias_init = lambda rng, x: conv.init(rng, x)['params']['bias']

        # standardize kernel
        kernel = self.param('kernel', kernel_init, x)
        eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
        # reduce over dim_out
        redux = tuple(range(kernel.ndim - 1))
        mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
        var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
        standardized_kernel = (kernel - mean) / jnp.sqrt(var + eps)

        bias = self.param('bias', bias_init, x)

        return conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}}, x)

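A minimal usage sketch for WeightStandardizedConv (illustrative only, not part of the packaged file; the import path and shapes are assumptions based on the file layout above):

# Hypothetical usage example: initialize and apply the module on a dummy image.
import jax
import jax.numpy as jnp
from flaxdiff.models.simple_unet import WeightStandardizedConv

conv = WeightStandardizedConv(features=32, kernel_size=(3, 3), strides=1, padding=1)
x = jnp.ones((1, 16, 16, 3))                  # (batch, height, width, channels)
params = conv.init(jax.random.PRNGKey(0), x)  # standard Flax init
y = conv.apply(params, x)                     # -> (1, 16, 16, 32)
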
class PixelShuffle(nn.Module):
    scale: int

    @nn.compact
    def __call__(self, x):
        up = einops.rearrange(
            x,
            pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
            h2=self.scale,
            w2=self.scale,
        )
        return up

class TimeEmbedding(nn.Module):
    features: int
    nax_positions: int = 10000

    def setup(self):
        half_dim = self.features // 2
        emb = jnp.log(self.nax_positions) / (half_dim - 1)
        emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
        self.embeddings = emb

    def __call__(self, x):
        x = jax.lax.convert_element_type(x, jnp.float32)
        emb = x[:, None] * self.embeddings[None, :]
        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
        return emb

class FourierEmbedding(nn.Module):
    features: int
    scale: int = 16

    def setup(self):
        self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale

    def __call__(self, x):
        x = jax.lax.convert_element_type(x, jnp.float32)
        emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
        return emb

class TimeProjection(nn.Module):
    features: int
    activation: Callable = jax.nn.gelu

    @nn.compact
    def __call__(self, x):
        x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
        x = self.activation(x)
        x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
        x = self.activation(x)
        return x

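A short sketch of how these embedding modules fit together (illustrative only; the wrapper module below is hypothetical and not part of the package):

# Hypothetical example: Fourier-embed scalar timesteps, then project them,
# mirroring what Unet.__call__ does with temb.
import jax
import jax.numpy as jnp
from flax import linen as nn
from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection

class TimeEncoder(nn.Module):  # hypothetical helper for this example only
    features: int = 256

    @nn.compact
    def __call__(self, t):
        emb = FourierEmbedding(features=self.features)(t)    # (B, features)
        return TimeProjection(features=self.features)(emb)   # (B, features)

t = jnp.array([0.1, 0.5, 0.9])                   # a batch of three timesteps
encoder = TimeEncoder()
params = encoder.init(jax.random.PRNGKey(0), t)
temb = encoder.apply(params, t)                  # -> (3, 256)
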
class SeparableConv(nn.Module):
    features: int
    kernel_size: tuple = (3, 3)
    strides: tuple = (1, 1)
    use_bias: bool = False
    kernel_init: Callable = kernel_init(1.0)
    padding: str = "SAME"
    dtype: Any = jnp.bfloat16
    precision: Any = jax.lax.Precision.HIGH

    @nn.compact
    def __call__(self, x):
        in_features = x.shape[-1]
        depthwise = nn.Conv(
            features=in_features, kernel_size=self.kernel_size,
            strides=self.strides, kernel_init=self.kernel_init,
            feature_group_count=in_features, use_bias=self.use_bias,
            padding=self.padding,
            dtype=self.dtype,
            precision=self.precision
        )(x)
        pointwise = nn.Conv(
            features=self.features, kernel_size=(1, 1),
            strides=(1, 1), kernel_init=self.kernel_init,
            use_bias=self.use_bias,
            dtype=self.dtype,
            precision=self.precision
        )(depthwise)
        return pointwise

class ConvLayer(nn.Module):
    conv_type: str
    features: int
    kernel_size: tuple = (3, 3)
    strides: tuple = (1, 1)
    kernel_init: Callable = kernel_init(1.0)
    dtype: Any = jnp.bfloat16
    precision: Any = jax.lax.Precision.HIGH

    def setup(self):
        # conv_type can be "conv", "separable", "conv_transpose"
        if self.conv_type == "conv":
            self.conv = nn.Conv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                kernel_init=self.kernel_init,
                dtype=self.dtype,
                precision=self.precision
            )
        elif self.conv_type == "w_conv":
            self.conv = WeightStandardizedConv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                padding="SAME",
                param_dtype=self.dtype,
                dtype=self.dtype,
                precision=self.precision
            )
        elif self.conv_type == "separable":
            self.conv = SeparableConv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                kernel_init=self.kernel_init,
                dtype=self.dtype,
                precision=self.precision
            )
        elif self.conv_type == "conv_transpose":
            self.conv = nn.ConvTranspose(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                kernel_init=self.kernel_init,
                dtype=self.dtype,
                precision=self.precision
            )

    def __call__(self, x):
        return self.conv(x)

class Upsample(nn.Module):
    features: int
    scale: int
    activation: Callable = jax.nn.swish
    dtype: Any = jnp.bfloat16
    precision: Any = jax.lax.Precision.HIGH

    @nn.compact
    def __call__(self, x, residual=None):
        out = x
        # out = PixelShuffle(scale=self.scale)(out)
        B, H, W, C = x.shape
        out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
        out = ConvLayer(
            "conv",
            features=self.features,
            kernel_size=(3, 3),
            strides=(1, 1),
            dtype=self.dtype,
            precision=self.precision
        )(out)
        if residual is not None:
            out = jnp.concatenate([out, residual], axis=-1)
        return out

class Downsample(nn.Module):
    features: int
    scale: int
    activation: Callable = jax.nn.swish
    dtype: Any = jnp.bfloat16
    precision: Any = jax.lax.Precision.HIGH

    @nn.compact
    def __call__(self, x, residual=None):
        out = ConvLayer(
            "conv",
            features=self.features,
            kernel_size=(3, 3),
            strides=(2, 2),
            dtype=self.dtype,
            precision=self.precision
        )(x)
        if residual is not None:
            if residual.shape[1] > out.shape[1]:
                residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
            out = jnp.concatenate([out, residual], axis=-1)
        return out


def l2norm(t, axis=1, eps=1e-12):
    denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
    out = t / denom
    return out

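A small sketch of the convolution wrappers above (illustrative shapes and values; not part of the packaged file):

# Hypothetical example: ConvLayer dispatches on conv_type; Upsample doubles the
# spatial resolution with a nearest-neighbour resize followed by a 3x3 conv.
import jax
import jax.numpy as jnp
from flaxdiff.models.simple_unet import ConvLayer, Upsample

x = jnp.ones((1, 8, 8, 16))

sep = ConvLayer("separable", features=32)                # depthwise + pointwise conv
y = sep.apply(sep.init(jax.random.PRNGKey(0), x), x)     # -> (1, 8, 8, 32)

up = Upsample(features=32, scale=2)
z = up.apply(up.init(jax.random.PRNGKey(1), x), x)       # -> (1, 16, 16, 32)
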
class ResidualBlock(nn.Module):
    conv_type: str
    features: int
    kernel_size: tuple = (3, 3)
    strides: tuple = (1, 1)
    padding: str = "SAME"
    activation: Callable = jax.nn.swish
    direction: str = None
    res: int = 2
    norm_groups: int = 8
    kernel_init: Callable = kernel_init(1.0)
    dtype: Any = jnp.float32
    precision: Any = jax.lax.Precision.HIGHEST

    @nn.compact
    def __call__(self, x: jax.Array, temb: jax.Array, textemb: jax.Array = None, extra_features: jax.Array = None):
        residual = x
        out = nn.GroupNorm(self.norm_groups)(x)
        out = self.activation(out)

        out = ConvLayer(
            self.conv_type,
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            kernel_init=self.kernel_init,
            name="conv1",
            dtype=self.dtype,
            precision=self.precision
        )(out)

        temb = nn.DenseGeneral(
            features=self.features,
            name="temb_projection",
            dtype=self.dtype,
            precision=self.precision)(temb)
        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
        # scale, shift = jnp.split(temb, 2, axis=-1)
        # out = out * (1 + scale) + shift
        out = out + temb

        out = nn.GroupNorm(self.norm_groups)(out)
        out = self.activation(out)

        out = ConvLayer(
            self.conv_type,
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            kernel_init=self.kernel_init,
            name="conv2",
            dtype=self.dtype,
            precision=self.precision
        )(out)

        if residual.shape != out.shape:
            residual = ConvLayer(
                self.conv_type,
                features=self.features,
                kernel_size=(1, 1),
                strides=1,
                kernel_init=self.kernel_init,
                name="residual_conv",
                dtype=self.dtype,
                precision=self.precision
            )(residual)
        out = out + residual

        out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out

        return out

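A brief sketch of driving a ResidualBlock with a feature map and a time embedding (illustrative only):

# Hypothetical example: the time embedding is projected to `features` and added
# channel-wise between the two convolutions of the block.
import jax
import jax.numpy as jnp
from flaxdiff.models.simple_unet import ResidualBlock

x = jnp.ones((1, 32, 32, 64))      # feature map (NHWC)
temb = jnp.ones((1, 256))          # per-sample time embedding

block = ResidualBlock("conv", features=64)
params = block.init(jax.random.PRNGKey(0), x, temb)
y = block.apply(params, x, temb)   # -> (1, 32, 32, 64)
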
class Unet(nn.Module):
    emb_features: int = 64 * 4,
    feature_depths: list = [64, 128, 256, 512],
    attention_configs: list = [{"heads": 8}, {"heads": 8}, {"heads": 8}, {"heads": 8}],
    num_res_blocks: int = 2,
    num_middle_res_blocks: int = 1,
    activation: Callable = jax.nn.swish
    norm_groups: int = 8
    dtype: Any = jnp.bfloat16
    precision: Any = jax.lax.Precision.HIGH

    @nn.compact
    def __call__(self, x, temb, textcontext=None):
        # print("embedding features", self.emb_features)
        temb = FourierEmbedding(features=self.emb_features)(temb)
        temb = TimeProjection(features=self.emb_features)(temb)

        _, TS, TC = textcontext.shape

        # print("time embedding", temb.shape)
        feature_depths = self.feature_depths
        attention_configs = self.attention_configs

        conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
        # middle_conv_type = "separable"

        x = ConvLayer(
            conv_type,
            features=self.feature_depths[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            kernel_init=kernel_init(1.0),
            dtype=self.dtype,
            precision=self.precision
        )(x)
        downs = [x]

        # Downscaling blocks
        for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
            dim_in = x.shape[-1]
            # dim_in = dim_out
            for j in range(self.num_res_blocks):
                x = ResidualBlock(
                    down_conv_type,
                    name=f"down_{i}_residual_{j}",
                    features=dim_in,
                    kernel_init=kernel_init(1.0),
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    activation=self.activation,
                    norm_groups=self.norm_groups,
                    dtype=self.dtype,
                    precision=self.precision
                )(x, temb)
                if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
                    B, H, W, _ = x.shape
                    if H > TS:
                        padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
                    else:
                        padded_context = None
                    x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                         dim_head=dim_in // attention_config['heads'],
                                         use_flash_attention=attention_config.get("flash_attention", True),
                                         use_projection=attention_config.get("use_projection", False),
                                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                         precision=attention_config.get("precision", self.precision),
                                         name=f"down_{i}_attention_{j}")(x, padded_context)
                # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                downs.append(x)
            if i != len(feature_depths) - 1:
                # print("Downsample", i, x.shape)
                x = Downsample(
                    features=dim_out,
                    scale=2,
                    activation=self.activation,
                    name=f"down_{i}_downsample",
                    dtype=self.dtype,
                    precision=self.precision
                )(x)

        # Middle Blocks
        middle_dim_out = self.feature_depths[-1]
        middle_attention = self.attention_configs[-1]
        for j in range(self.num_middle_res_blocks):
            x = ResidualBlock(
                middle_conv_type,
                name=f"middle_res1_{j}",
                features=middle_dim_out,
                kernel_init=kernel_init(1.0),
                kernel_size=(3, 3),
                strides=(1, 1),
                activation=self.activation,
                norm_groups=self.norm_groups,
                dtype=self.dtype,
                precision=self.precision
            )(x, temb)
            if middle_attention is not None and j == self.num_middle_res_blocks - 1:  # Apply attention only on the last block
                x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
                                     dim_head=middle_dim_out // middle_attention['heads'],
                                     use_flash_attention=middle_attention.get("flash_attention", True),
                                     use_linear_attention=False,
                                     use_projection=middle_attention.get("use_projection", False),
                                     use_self_and_cross=False,
                                     precision=attention_config.get("precision", self.precision),
                                     name=f"middle_attention_{j}")(x)
            x = ResidualBlock(
                middle_conv_type,
                name=f"middle_res2_{j}",
                features=middle_dim_out,
                kernel_init=kernel_init(1.0),
                kernel_size=(3, 3),
                strides=(1, 1),
                activation=self.activation,
                norm_groups=self.norm_groups,
                dtype=self.dtype,
                precision=self.precision
            )(x, temb)

        # Upscaling Blocks
        for i, (dim_out, attention_config) in enumerate(zip(reversed(feature_depths), reversed(attention_configs))):
            # print("Upscaling", i, "features", dim_out)
            for j in range(self.num_res_blocks):
                x = jnp.concatenate([x, downs.pop()], axis=-1)
                # print("concat==> ", i, "concat", x.shape)
                # kernel_size = (1 + 2 * (j + 1), 1 + 2 * (j + 1))
                kernel_size = (3, 3)
                x = ResidualBlock(
                    up_conv_type,  # if j == 0 else "separable",
                    name=f"up_{i}_residual_{j}",
                    features=dim_out,
                    kernel_init=kernel_init(1.0),
                    kernel_size=kernel_size,
                    strides=(1, 1),
                    activation=self.activation,
                    norm_groups=self.norm_groups,
                    dtype=self.dtype,
                    precision=self.precision
                )(x, temb)
                if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
                    B, H, W, _ = x.shape
                    if H > TS:
                        padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
                    else:
                        padded_context = None
                    x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                         dim_head=dim_out // attention_config['heads'],
                                         use_flash_attention=attention_config.get("flash_attention", True),
                                         use_projection=attention_config.get("use_projection", False),
                                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                         precision=attention_config.get("precision", self.precision),
                                         name=f"up_{i}_attention_{j}")(x, padded_context)
            # print("Upscaling ", i, x.shape)
            if i != len(feature_depths) - 1:
                x = Upsample(
                    features=feature_depths[-i],
                    scale=2,
                    activation=self.activation,
                    name=f"up_{i}_upsample",
                    dtype=self.dtype,
                    precision=self.precision
                )(x)

        # x = nn.GroupNorm(8)(x)
        x = ConvLayer(
            conv_type,
            features=self.feature_depths[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            kernel_init=kernel_init(0.0),
            dtype=self.dtype,
            precision=self.precision
        )(x)

        x = jnp.concatenate([x, downs.pop()], axis=-1)

        x = ResidualBlock(
            conv_type,
            name="final_residual",
            features=self.feature_depths[0],
            kernel_init=kernel_init(1.0),
            kernel_size=(3, 3),
            strides=(1, 1),
            activation=self.activation,
            norm_groups=self.norm_groups,
            dtype=self.dtype,
            precision=self.precision
        )(x, temb)

        x = nn.GroupNorm(self.norm_groups)(x)
        x = self.activation(x)

        noise_out = ConvLayer(
            conv_type,
            features=3,
            kernel_size=(3, 3),
            strides=(1, 1),
            # activation=jax.nn.mish
            kernel_init=kernel_init(0.0),
            dtype=self.dtype,
            precision=self.precision
        )(x)
        return noise_out  # , attentions

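A minimal configuration sketch for the Unet above (illustrative values; running a forward pass also requires flaxdiff.models.attention, which is not shown in this diff):

# Hypothetical example: constructing the Unet. The forward signature is
# model(x, temb, textcontext); textcontext is required because __call__ reads
# its shape unconditionally.
import jax.numpy as jnp
from flaxdiff.models.simple_unet import Unet

model = Unet(
    emb_features=256,
    feature_depths=[64, 128, 256, 512],
    attention_configs=[{"heads": 8}, {"heads": 8}, {"heads": 8}, {"heads": 8}],
    num_res_blocks=2,
    num_middle_res_blocks=1,
    norm_groups=8,
)
# Expected call shapes (example values):
#   x           : (B, H, W, 3) noisy image batch, NHWC
#   temb        : (B,) scalar timesteps
#   textcontext : (B, seq_len, context_dim) text-conditioning embeddings
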
flaxdiff/predictors/__init__.py
@@ -0,0 +1,96 @@
from typing import Union
import jax.numpy as jnp
from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler

############################################################################################################
# Prediction Transforms
############################################################################################################

class DiffusionPredictionTransform():
    def pred_transform(self, x_t, preds, rates) -> jnp.ndarray:
        return preds

    def __call__(self, x_t, preds, current_step, noise_schedule: NoiseScheduler) -> Union[jnp.ndarray, jnp.ndarray]:
        rates = noise_schedule.get_rates(current_step)
        preds = self.pred_transform(x_t, preds, rates)
        x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
        return x_0, epsilon

    def forward_diffusion(self, x_0, epsilon, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        signal_rate, noise_rate = rates
        x_t = signal_rate * x_0 + noise_rate * epsilon
        expected_output = self.get_target(x_0, epsilon, (signal_rate, noise_rate))
        c_in = self.get_input_scale((signal_rate, noise_rate))
        return x_t, c_in, expected_output

    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        raise NotImplementedError

    def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
        return x_0

    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        return 1

class EpsilonPredictionTransform(DiffusionPredictionTransform):
    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        # preds is the predicted noise
        epsilon = preds
        signal_rates, noise_rates = rates
        x_0 = (x_t - epsilon * noise_rates) / signal_rates
        return x_0, epsilon

    def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
        return epsilon

class DirectPredictionTransform(DiffusionPredictionTransform):
    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        # Here the model predicts x_0 directly
        x_0 = preds
        signal_rate, noise_rate = rates
        epsilon = (x_t - x_0 * signal_rate) / noise_rate
        return x_0, epsilon

class VPredictionTransform(DiffusionPredictionTransform):
    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        # here the model output's V = sqrt_alpha_t * epsilon - sqrt_one_minus_alpha_t * x_0
        # where epsilon is the noise
        # x_0 is the current sample
        v = preds
        signal_rate, noise_rate = rates
        variance = signal_rate ** 2 + noise_rate ** 2
        v = v * jnp.sqrt(variance)
        x_0 = signal_rate * x_t - noise_rate * v
        eps_0 = signal_rate * v + noise_rate * x_t
        return x_0 / variance, eps_0 / variance

    def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
        signal_rate, noise_rate = rates
        v = signal_rate * epsilon - noise_rate * x_0
        variance = signal_rate**2 + noise_rate**2
        return v / jnp.sqrt(variance)

class KarrasPredictionTransform(DiffusionPredictionTransform):
    def __init__(self, sigma_data=0.5) -> None:
        super().__init__()
        self.sigma_data = sigma_data

    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        x_0 = preds
        signal_rate, noise_rate = rates
        epsilon = (x_t - x_0 * signal_rate) / noise_rate
        return x_0, epsilon

    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        _, sigma = rates
        c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
        c_out = c_out.reshape((-1, 1, 1, 1))
        c_skip = c_skip.reshape((-1, 1, 1, 1))
        x_0 = c_out * preds + c_skip * x_t
        return x_0

    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        _, sigma = rates
        c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
        return c_in

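A small sketch of the forward/backward round trip defined above, with rates supplied by hand instead of a scheduler (illustrative shapes only):

# Example: epsilon-prediction round trip. forward_diffusion mixes the clean
# sample with noise and returns the training target (the noise itself);
# backward_diffusion inverts it given a noise prediction.
import jax.numpy as jnp
from flaxdiff.predictors import EpsilonPredictionTransform

transform = EpsilonPredictionTransform()

x_0 = jnp.zeros((1, 64, 64, 3))            # clean sample
epsilon = jnp.ones((1, 64, 64, 3))         # noise
rates = (jnp.array(0.8), jnp.array(0.6))   # (signal_rate, noise_rate)

x_t, c_in, target = transform.forward_diffusion(x_0, epsilon, rates)
x_0_rec, eps_rec = transform.backward_diffusion(x_t, epsilon, rates)  # recovers x_0
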
flaxdiff/samplers/__init__.py
@@ -0,0 +1,7 @@
from .common import DiffusionSampler
from .ddim import DDIMSampler
from .ddpm import DDPMSampler, SimpleDDPMSampler
from .euler import EulerSampler, SimplifiedEulerSampler
from .heun_sampler import HeunSampler
from .rk4_sampler import RK4Sampler
from .multistep_dpm import MultiStepDPM