flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +140 -162
- flaxdiff/models/autoencoder/__init__.py +2 -0
- flaxdiff/models/autoencoder/autoencoder.py +19 -0
- flaxdiff/models/autoencoder/diffusers.py +91 -0
- flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
- flaxdiff/models/common.py +322 -0
- flaxdiff/models/simple_unet.py +21 -327
- flaxdiff/trainer/__init__.py +2 -201
- flaxdiff/trainer/autoencoder_trainer.py +182 -0
- flaxdiff/trainer/diffusion_trainer.py +202 -0
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/METADATA +12 -2
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/RECORD +15 -9
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
@@ -5,7 +5,8 @@ Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/dif
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
-from typing import Dict, Callable, Sequence, Any, Union
+from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
+from flax.typing import Dtype, PrecisionLike
 import einops
 import functools
 import math
@@ -18,8 +19,8 @@ class EfficientAttention(nn.Module):
     query_dim: int
     heads: int = 4
     dim_head: int = 64
-    dtype:
-    precision:
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
 
@@ -62,8 +63,13 @@ class EfficientAttention(nn.Module):
         # x has shape [B, H * W, C]
         context = x if context is None else context
 
-
-
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, 1, H * W, C))
+        else:
+            B, SEQ, C = x.shape
+            x = x.reshape((B, 1, SEQ, C))
 
         if len(context.shape) == 4:
             B, _H, _W, _C = context.shape
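The hunk above is the core behavioural change in `EfficientAttention`: the input is no longer assumed to be pre-flattened; its original shape is remembered and restored after the output projection (see `proj.reshape(orig_x_shape)` in the next hunk). A standalone sketch of that shape handling, using only `jax.numpy`; the helper name is ours, not part of the library:

    import jax.numpy as jnp

    def flatten_for_attention(x):
        # Mirrors the reshape logic added above; the helper name is ours.
        orig_shape = x.shape
        if x.ndim == 4:
            B, H, W, C = x.shape
            x = x.reshape((B, 1, H * W, C))
        else:
            B, SEQ, C = x.shape
            x = x.reshape((B, 1, SEQ, C))
        return x, orig_shape

    flat_img, img_shape = flatten_for_attention(jnp.ones((2, 8, 8, 32)))   # (2, 1, 64, 32)
    flat_seq, seq_shape = flatten_for_attention(jnp.ones((2, 77, 32)))     # (2, 1, 77, 32)
    # After attention, the output is restored with proj.reshape(orig_shape),
    # which is what the proj.reshape(orig_x_shape) change in the next hunk does.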
@@ -93,7 +99,7 @@ class EfficientAttention(nn.Module):
 
         proj = self.proj_attn(hidden_states)
 
-        proj = proj.reshape(
+        proj = proj.reshape(orig_x_shape)
 
         return proj
 
@@ -104,8 +110,8 @@ class NormalAttention(nn.Module):
     query_dim: int
     heads: int = 4
     dim_head: int = 64
-    dtype:
-    precision:
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
 
@@ -138,8 +144,10 @@ class NormalAttention(nn.Module):
     @nn.compact
     def __call__(self, x, context=None):
         # x has shape [B, H, W, C]
-
-
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, H*W, C))
         context = x if context is None else context
         if len(context.shape) == 4:
             context = context.reshape((B, H*W, C))
@@ -151,16 +159,16 @@ class NormalAttention(nn.Module):
             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
         )
         proj = self.proj_attn(hidden_states)
-        proj = proj.reshape(
+        proj = proj.reshape(orig_x_shape)
         return proj
-
-class AttentionBlock(nn.Module):
+
+class BasicTransformerBlock(nn.Module):
     # Has self and cross attention
     query_dim: int
     heads: int = 4
     dim_head: int = 64
-    dtype:
-    precision:
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
     use_flash_attention:bool = False
@@ -193,129 +201,26 @@ class AttentionBlock(nn.Module):
             kernel_init=self.kernel_init
         )
 
-        self.ff =
-            features=self.query_dim,
-            use_bias=self.use_bias,
-            precision=self.precision,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init(),
-            name="ff"
-        )
+        self.ff = FlaxFeedForward(dim=self.query_dim)
         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
 
     @nn.compact
     def __call__(self, hidden_states, context=None):
         # self attention
-
-
-
-            hidden_states = self.attention1(hidden_states, context)
-        else:
-            hidden_states = self.attention1(hidden_states)
-        hidden_states = hidden_states + residual
+        if not self.use_cross_only:
+            print("Using self attention")
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
 
         # cross attention
-
-        hidden_states = self.norm2(hidden_states)
-        hidden_states = self.attention2(hidden_states, context)
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
 
         # feed forward
-
-        hidden_states = self.norm3(hidden_states)
-        hidden_states = nn.gelu(hidden_states)
-        hidden_states = self.ff(hidden_states)
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
 
         return hidden_states
 
-class TransformerBlock(nn.Module):
-    heads: int = 4
-    dim_head: int = 32
-    use_linear_attention: bool = True
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-    use_projection: bool = False
-    use_flash_attention:bool = True
-    use_self_and_cross:bool = False
-
-    @nn.compact
-    def __call__(self, x, context=None):
-        inner_dim = self.heads * self.dim_head
-        B, H, W, C = x.shape
-        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
-        if self.use_projection == True:
-            if self.use_linear_attention:
-                projected_x = nn.Dense(features=inner_dim,
-                                       use_bias=False, precision=self.precision,
-                                       kernel_init=kernel_init(1.0),
-                                       dtype=self.dtype, name=f'project_in')(normed_x)
-            else:
-                projected_x = nn.Conv(
-                    features=inner_dim, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
-                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                    precision=self.precision, name=f'project_in_conv',
-                )(normed_x)
-        else:
-            projected_x = normed_x
-            inner_dim = C
-
-        context = projected_x if context is None else context
-
-        if self.use_self_and_cross:
-            projected_x = AttentionBlock(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-                dtype=self.dtype,
-                use_flash_attention=self.use_flash_attention,
-                use_cross_only=False
-            )(projected_x, context)
-        elif self.use_flash_attention == True:
-            projected_x = EfficientAttention(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-                dtype=self.dtype,
-            )(projected_x, context)
-        else:
-            projected_x = NormalAttention(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-            )(projected_x, context)
-
-
-        if self.use_projection == True:
-            if self.use_linear_attention:
-                projected_x = nn.Dense(features=C, precision=self.precision,
-                                       dtype=self.dtype, use_bias=False,
-                                       kernel_init=kernel_init(1.0),
-                                       name=f'project_out')(projected_x)
-            else:
-                projected_x = nn.Conv(
-                    features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
-                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                    precision=self.precision, name=f'project_out_conv',
-                )(projected_x)
-
-        out = x + projected_x
-        return out
-
 class FlaxGEGLU(nn.Module):
     r"""
     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -333,10 +238,11 @@ class FlaxGEGLU(nn.Module):
     dim: int
    dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
 
     def setup(self):
         inner_dim = self.dim * 4
-        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)
 
     def __call__(self, hidden_states):
         hidden_states = self.proj(hidden_states)
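For context, `FlaxGEGLU` projects to twice the inner width and gates one half with GELU; the hunk above only threads `precision` through that projection. A self-contained sketch of the underlying GEGLU computation (our own helper, following the standard definition):

    import jax
    import jax.numpy as jnp

    def geglu(hidden, w):
        # hidden: [..., dim]; w: [dim, 2 * inner_dim]; one half gates the other through GELU.
        proj = hidden @ w
        x, gate = jnp.split(proj, 2, axis=-1)
        return x * jax.nn.gelu(gate)

    h = jnp.ones((2, 16, 8))
    w = jnp.ones((8, 64))        # inner_dim = 4 * dim = 32, projected to 2 * inner_dim
    out = geglu(h, w)            # (2, 16, 32)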
@@ -362,14 +268,14 @@ class FlaxFeedForward(nn.Module):
     """
 
     dim: int
-    dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
 
     def setup(self):
         # The second linear layer needs to be called
         # net_2 for now to match the index of the Sequential layer
-        self.net_0 = FlaxGEGLU(self.dim, self.dtype)
-        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=
+        self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)
 
     def __call__(self, hidden_states):
         hidden_states = self.net_0(hidden_states)
@@ -377,55 +283,127 @@ class FlaxFeedForward(nn.Module):
         return hidden_states
 
 class BasicTransformerBlock(nn.Module):
+    # Has self and cross attention
     query_dim: int
-    heads: int
-    dim_head: int
-
-
-
-
-
-
-
+    heads: int = 4
+    dim_head: int = 64
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_bias: bool = True
+    kernel_init: Callable = lambda : kernel_init(1.0)
+    use_flash_attention:bool = False
+    use_cross_only:bool = False
+    only_pure_attention:bool = False
+
     def setup(self):
-
-
-
+        if self.use_flash_attention:
+            attenBlock = EfficientAttention
+        else:
+            attenBlock = NormalAttention
+
+        self.attention1 = attenBlock(
+            query_dim=self.query_dim,
             heads=self.heads,
             dim_head=self.dim_head,
-
+            name=f'Attention1',
             precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init
         )
-
-        self.attn2 = NormalAttention(
+        self.attention2 = attenBlock(
             query_dim=self.query_dim,
             heads=self.heads,
             dim_head=self.dim_head,
-
+            name=f'Attention2',
             precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init
         )
-
+
+        self.ff = FlaxFeedForward(dim=self.query_dim)
         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-
+
+    @nn.compact
+    def __call__(self, hidden_states, context=None):
+        if self.only_pure_attention:
+            return self.attention2(self.norm2(hidden_states), context)
+
         # self attention
-
-
-
-        else:
-            hidden_states = self.attn1(self.norm1(hidden_states))
-        hidden_states = hidden_states + residual
-
+        if not self.use_cross_only:
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
+
         # cross attention
-
-        hidden_states = self.attn2(self.norm2(hidden_states), context)
-        hidden_states = hidden_states + residual
-
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
         # feed forward
-
-
-
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
+
+        return hidden_states
+
+class TransformerBlock(nn.Module):
+    heads: int = 4
+    dim_head: int = 32
+    use_linear_attention: bool = True
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_projection: bool = False
+    use_flash_attention:bool = True
+    use_self_and_cross:bool = False
+    only_pure_attention:bool = False
+
+    @nn.compact
+    def __call__(self, x, context=None):
+        inner_dim = self.heads * self.dim_head
+        B, H, W, C = x.shape
+        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=inner_dim,
+                                       use_bias=False, precision=self.precision,
+                                       kernel_init=kernel_init(1.0),
+                                       dtype=self.dtype, name=f'project_in')(normed_x)
+            else:
+                projected_x = nn.Conv(
+                    features=inner_dim, kernel_size=(1, 1),
+                    kernel_init=kernel_init(1.0),
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_in_conv',
+                )(normed_x)
+        else:
+            projected_x = normed_x
+            inner_dim = C
+
+        context = projected_x if context is None else context
 
-
+        projected_x = BasicTransformerBlock(
+            query_dim=inner_dim,
+            heads=self.heads,
+            dim_head=self.dim_head,
+            name=f'Attention',
+            precision=self.precision,
+            use_bias=False,
+            dtype=self.dtype,
+            use_flash_attention=self.use_flash_attention,
+            use_cross_only=(not self.use_self_and_cross),
+            only_pure_attention=self.only_pure_attention
+        )(projected_x, context)
+
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=C, precision=self.precision,
+                                       dtype=self.dtype, use_bias=False,
+                                       kernel_init=kernel_init(1.0),
+                                       name=f'project_out')(projected_x)
+            else:
+                projected_x = nn.Conv(
+                    features=C, kernel_size=(1, 1),
+                    kernel_init=kernel_init(1.0),
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_out_conv',
+                )(projected_x)
+
+        out = x + projected_x
+        return out
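Taken together, the rewrite folds the old `AttentionBlock`/`TransformerBlock` pair into `BasicTransformerBlock` plus a thin `TransformerBlock` wrapper with pre-norm residuals and the new `use_cross_only`/`only_pure_attention` switches. A hedged usage sketch against the 0.1.6 module; the argument values are illustrative and the flash-attention path is switched off so the example stays backend-agnostic:

    # Sketch only: assumes flaxdiff 0.1.6 is importable; shapes and arguments are illustrative.
    import jax
    import jax.numpy as jnp
    from flaxdiff.models.attention import TransformerBlock

    x = jnp.ones((1, 16, 16, 64))                    # [B, H, W, C] feature map
    block = TransformerBlock(
        heads=4,
        dim_head=16,
        use_flash_attention=False,                   # use the NormalAttention path
        use_projection=False,                        # keep inner_dim == C
        use_self_and_cross=True,                     # self attention followed by cross attention
        only_pure_attention=False,
    )
    params = block.init(jax.random.PRNGKey(0), x)    # context defaults to the (projected) input
    out = block.apply(params, x)                     # residual output, same shape as x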
flaxdiff/models/autoencoder/autoencoder.py
ADDED
@@ -0,0 +1,19 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Dict, Callable, Sequence, Any, Union
+import einops
+from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+
+
+class AutoEncoder():
+    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
+
+    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
+
+    def __call__(self, x: jnp.ndarray):
+        latents = self.encode(x)
+        reconstructions = self.decode(latents)
+        return reconstructions
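`AutoEncoder` is a plain interface: subclasses provide `encode`/`decode` and inherit the encode-then-decode `__call__`. A minimal hypothetical subclass (an identity codec, purely to show the contract, assuming flaxdiff 0.1.6 is installed):

    import jax.numpy as jnp
    from flaxdiff.models.autoencoder.autoencoder import AutoEncoder

    class IdentityAutoEncoder(AutoEncoder):
        """Toy codec: the 'latent' is simply the input itself."""
        def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
            return x

        def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
            return z

    codec = IdentityAutoEncoder()
    x = jnp.ones((1, 64, 64, 3))
    assert codec(x).shape == x.shape                 # __call__ is decode(encode(x))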
flaxdiff/models/autoencoder/diffusers.py
ADDED
@@ -0,0 +1,91 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from .autoencoder import AutoEncoder
+
+"""
+This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
+The actual model was not trained by me, but was taken from the HuggingFace model hub.
+I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library
+All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
+"""
+
+class StableDiffusionVAE(AutoEncoder):
+    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
+
+        from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
+        from diffusers import FlaxStableDiffusionPipeline
+
+        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+            modelname,
+            revision="bf16",
+            dtype=jnp.bfloat16,
+        )
+
+        vae = pipeline.vae
+
+        enc = FlaxEncoder(
+            in_channels=vae.config.in_channels,
+            out_channels=vae.config.latent_channels,
+            down_block_types=vae.config.down_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            act_fn=vae.config.act_fn,
+            norm_num_groups=vae.config.norm_num_groups,
+            double_z=True,
+            dtype=vae.dtype,
+        )
+
+        dec = FlaxDecoder(
+            in_channels=vae.config.latent_channels,
+            out_channels=vae.config.out_channels,
+            up_block_types=vae.config.up_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            norm_num_groups=vae.config.norm_num_groups,
+            act_fn=vae.config.act_fn,
+            dtype=vae.dtype,
+        )
+
+        quant_conv = nn.Conv(
+            2 * vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        post_quant_conv = nn.Conv(
+            vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        self.enc = enc
+        self.dec = dec
+        self.post_quant_conv = post_quant_conv
+        self.quant_conv = quant_conv
+        self.params = params
+        self.scaling_factor = vae.scaling_factor
+
+    def encode(self, images, rngkey: jax.random.PRNGKey = None):
+        latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
+        latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
+        if rngkey is not None:
+            mean, log_std = jnp.split(latents, 2, axis=-1)
+            log_std = jnp.clip(log_std, -30, 20)
+            std = jnp.exp(0.5 * log_std)
+            latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+            print("Sampled")
+        else:
+            # return the mean
+            latents, _ = jnp.split(latents, 2, axis=-1)
+        latents *= self.scaling_factor
+        return latents
+
+    def decode(self, latents):
+        latents = (1.0 / self.scaling_factor) * latents
+        latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
+        return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
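A hedged usage sketch of the wrapper above: it requires the `diffusers` package, downloads the pretrained VAE weights on first use, and (judging from the direct `FlaxEncoder` call) appears to expect channels-last NHWC images scaled to the usual Stable Diffusion [-1, 1] range:

    import jax
    import jax.numpy as jnp
    from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE

    vae = StableDiffusionVAE()                                  # "CompVis/stable-diffusion-v1-4" by default
    images = jnp.zeros((1, 256, 256, 3), dtype=jnp.bfloat16)    # NHWC batch of images
    latents = vae.encode(images, rngkey=jax.random.PRNGKey(0))  # stochastic sample from the VAE posterior
    recon = vae.decode(latents)                                 # back to image space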
flaxdiff/models/autoencoder/simple_autoenc.py
ADDED
@@ -0,0 +1,26 @@
+from typing import Any, List, Optional, Callable
+import jax
+import flax.linen as nn
+from jax import numpy as jnp
+from flax.typing import Dtype, PrecisionLike
+from .autoencoder import AutoEncoder
+
+class SimpleAutoEncoder(AutoEncoder):
+    latent_channels: int
+    feature_depths: List[int]=[64, 128, 256, 512]
+    attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
+    num_res_blocks: int=2
+    num_middle_res_blocks:int=1,
+    activation:Callable = jax.nn.swish
+    norm_groups:int=8
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+
+    # def encode(self, x: jnp.ndarray):
+
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray):
+        latents = self.encode(x)
+        reconstructions = self.decode(latents)
+        return reconstructions