flaxdiff 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +132 -155
- flaxdiff/models/autoencoder/__init__.py +0 -0
- flaxdiff/models/autoencoder/autoencoder.py +14 -0
- flaxdiff/models/autoencoder/diffusers.py +88 -0
- flaxdiff/models/common.py +243 -0
- flaxdiff/models/simple_unet.py +17 -252
- flaxdiff/trainer/__init__.py +28 -45
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/METADATA +10 -2
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/RECORD +12 -9
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
@@ -62,8 +62,13 @@ class EfficientAttention(nn.Module):
         # x has shape [B, H * W, C]
         context = x if context is None else context
 
-
-
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, 1, H * W, C))
+        else:
+            B, SEQ, C = x.shape
+            x = x.reshape((B, 1, SEQ, C))
 
         if len(context.shape) == 4:
             B, _H, _W, _C = context.shape
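Note on the hunk above: EfficientAttention now records the incoming shape, flattens the spatial axes to one sequence axis internally, and restores the original shape on output, so it accepts both [B, H, W, C] feature maps and [B, SEQ, C] token sequences. A minimal usage sketch (not part of the diff; hyperparameters and shapes are illustrative, with the channel dim matching query_dim):

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.attention import EfficientAttention

# Illustrative hyperparameters; channel dim of the inputs equals query_dim here.
attn = EfficientAttention(query_dim=64, heads=4, dim_head=16)

x_map = jnp.ones((2, 16, 16, 64))   # [B, H, W, C] feature map
x_seq = jnp.ones((2, 77, 64))       # [B, SEQ, C] token sequence

y_map = attn.apply(attn.init(jax.random.PRNGKey(0), x_map), x_map)  # -> (2, 16, 16, 64)
y_seq = attn.apply(attn.init(jax.random.PRNGKey(1), x_seq), x_seq)  # -> (2, 77, 64)
```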
@@ -93,7 +98,7 @@ class EfficientAttention(nn.Module):
 
         proj = self.proj_attn(hidden_states)
 
-        proj = proj.reshape(
+        proj = proj.reshape(orig_x_shape)
 
         return proj
 
@@ -138,8 +143,10 @@ class NormalAttention(nn.Module):
     @nn.compact
     def __call__(self, x, context=None):
         # x has shape [B, H, W, C]
-
-
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, H*W, C))
         context = x if context is None else context
         if len(context.shape) == 4:
             context = context.reshape((B, H*W, C))
@@ -151,10 +158,10 @@ class NormalAttention(nn.Module):
             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
         )
         proj = self.proj_attn(hidden_states)
-        proj = proj.reshape(
+        proj = proj.reshape(orig_x_shape)
         return proj
-
-class
+
+class BasicTransformerBlock(nn.Module):
     # Has self and cross attention
     query_dim: int
     heads: int = 4
@@ -193,129 +200,26 @@ class AttentionBlock(nn.Module):
             kernel_init=self.kernel_init
         )
 
-        self.ff =
-            features=self.query_dim,
-            use_bias=self.use_bias,
-            precision=self.precision,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init(),
-            name="ff"
-        )
+        self.ff = FlaxFeedForward(dim=self.query_dim)
         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
 
     @nn.compact
     def __call__(self, hidden_states, context=None):
         # self attention
-
-
-
-            hidden_states = self.attention1(hidden_states, context)
-        else:
-            hidden_states = self.attention1(hidden_states)
-        hidden_states = hidden_states + residual
+        if not self.use_cross_only:
+            print("Using self attention")
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
 
         # cross attention
-
-        hidden_states = self.norm2(hidden_states)
-        hidden_states = self.attention2(hidden_states, context)
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
 
         # feed forward
-
-        hidden_states = self.norm3(hidden_states)
-        hidden_states = nn.gelu(hidden_states)
-        hidden_states = self.ff(hidden_states)
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
 
         return hidden_states
 
-class TransformerBlock(nn.Module):
-    heads: int = 4
-    dim_head: int = 32
-    use_linear_attention: bool = True
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-    use_projection: bool = False
-    use_flash_attention:bool = True
-    use_self_and_cross:bool = False
-
-    @nn.compact
-    def __call__(self, x, context=None):
-        inner_dim = self.heads * self.dim_head
-        B, H, W, C = x.shape
-        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
-        if self.use_projection == True:
-            if self.use_linear_attention:
-                projected_x = nn.Dense(features=inner_dim,
-                    use_bias=False, precision=self.precision,
-                    kernel_init=kernel_init(1.0),
-                    dtype=self.dtype, name=f'project_in')(normed_x)
-            else:
-                projected_x = nn.Conv(
-                    features=inner_dim, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
-                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                    precision=self.precision, name=f'project_in_conv',
-                )(normed_x)
-        else:
-            projected_x = normed_x
-            inner_dim = C
-
-        context = projected_x if context is None else context
-
-        if self.use_self_and_cross:
-            projected_x = AttentionBlock(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-                dtype=self.dtype,
-                use_flash_attention=self.use_flash_attention,
-                use_cross_only=False
-            )(projected_x, context)
-        elif self.use_flash_attention == True:
-            projected_x = EfficientAttention(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-                dtype=self.dtype,
-            )(projected_x, context)
-        else:
-            projected_x = NormalAttention(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-            )(projected_x, context)
-
-
-        if self.use_projection == True:
-            if self.use_linear_attention:
-                projected_x = nn.Dense(features=C, precision=self.precision,
-                    dtype=self.dtype, use_bias=False,
-                    kernel_init=kernel_init(1.0),
-                    name=f'project_out')(projected_x)
-            else:
-                projected_x = nn.Conv(
-                    features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
-                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                    precision=self.precision, name=f'project_out_conv',
-                )(projected_x)
-
-        out = x + projected_x
-        return out
-
 class FlaxGEGLU(nn.Module):
     r"""
     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
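For orientation, the rewritten block above is a standard pre-norm residual transformer layer: each branch normalises, transforms, and adds back onto the running hidden state. A generic sketch of that pattern in plain Flax (illustrative only, not flaxdiff code):

```python
import jax.numpy as jnp
from flax import linen as nn

class PreNormBlockSketch(nn.Module):
    """Generic pre-norm residual layout mirroring the block above."""
    heads: int = 4

    @nn.compact
    def __call__(self, x, context=None):
        # self-attention branch: norm -> attend -> residual add
        x = x + nn.SelfAttention(num_heads=self.heads)(nn.RMSNorm()(x))
        # cross-attention branch (degenerates to self-attention without context)
        kv = x if context is None else context
        x = x + nn.MultiHeadDotProductAttention(num_heads=self.heads)(nn.RMSNorm()(x), kv)
        # feed-forward branch
        h = nn.Dense(4 * x.shape[-1])(nn.RMSNorm()(x))
        x = x + nn.Dense(x.shape[-1])(nn.gelu(h))
        return x
```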
@@ -333,10 +237,11 @@ class FlaxGEGLU(nn.Module):
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
 
     def setup(self):
         inner_dim = self.dim * 4
-        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)
 
     def __call__(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -362,14 +267,14 @@ class FlaxFeedForward(nn.Module):
     """
 
     dim: int
-    dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
 
     def setup(self):
         # The second linear layer needs to be called
         # net_2 for now to match the index of the Sequential layer
-        self.net_0 = FlaxGEGLU(self.dim, self.dtype)
-        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=
+        self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)
 
     def __call__(self, hidden_states):
         hidden_states = self.net_0(hidden_states)
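Both FlaxGEGLU and FlaxFeedForward now expose a `precision` field (defaulting to jax.lax.Precision.DEFAULT) that is forwarded to their nn.Dense layers. A small sketch of constructing the feed-forward stack with an explicit matmul precision (shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.attention import FlaxFeedForward

ff = FlaxFeedForward(dim=64, precision=jax.lax.Precision.HIGHEST)
x = jnp.ones((2, 77, 64))
y = ff.apply(ff.init(jax.random.PRNGKey(0), x), x)  # GEGLU expansion inside, output (2, 77, 64)
```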
@@ -377,55 +282,127 @@ class FlaxFeedForward(nn.Module):
         return hidden_states
 
 class BasicTransformerBlock(nn.Module):
+    # Has self and cross attention
     query_dim: int
-    heads: int
-    dim_head: int
-
-
-
-
-
-
-
+    heads: int = 4
+    dim_head: int = 64
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGHEST
+    use_bias: bool = True
+    kernel_init: Callable = lambda : kernel_init(1.0)
+    use_flash_attention:bool = False
+    use_cross_only:bool = False
+    only_pure_attention:bool = False
+
     def setup(self):
-
-
-
+        if self.use_flash_attention:
+            attenBlock = EfficientAttention
+        else:
+            attenBlock = NormalAttention
+
+        self.attention1 = attenBlock(
+            query_dim=self.query_dim,
             heads=self.heads,
             dim_head=self.dim_head,
-
+            name=f'Attention1',
             precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init
         )
-
-        self.attn2 = NormalAttention(
+        self.attention2 = attenBlock(
             query_dim=self.query_dim,
             heads=self.heads,
             dim_head=self.dim_head,
-
+            name=f'Attention2',
             precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init
         )
-
+
+        self.ff = FlaxFeedForward(dim=self.query_dim)
         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-
+
+    @nn.compact
+    def __call__(self, hidden_states, context=None):
+        if self.only_pure_attention:
+            return self.attention2(self.norm2(hidden_states), context)
+
         # self attention
-
-
-
-        else:
-            hidden_states = self.attn1(self.norm1(hidden_states))
-        hidden_states = hidden_states + residual
-
+        if not self.use_cross_only:
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
+
         # cross attention
-
-        hidden_states = self.attn2(self.norm2(hidden_states), context)
-        hidden_states = hidden_states + residual
-
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
         # feed forward
-
-
-
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
+
+        return hidden_states
 
-
+class TransformerBlock(nn.Module):
+    heads: int = 4
+    dim_head: int = 32
+    use_linear_attention: bool = True
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGH
+    use_projection: bool = False
+    use_flash_attention:bool = True
+    use_self_and_cross:bool = False
+    only_pure_attention:bool = False
+
+    @nn.compact
+    def __call__(self, x, context=None):
+        inner_dim = self.heads * self.dim_head
+        B, H, W, C = x.shape
+        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=inner_dim,
+                    use_bias=False, precision=self.precision,
+                    kernel_init=kernel_init(1.0),
+                    dtype=self.dtype, name=f'project_in')(normed_x)
+            else:
+                projected_x = nn.Conv(
+                    features=inner_dim, kernel_size=(1, 1),
+                    kernel_init=kernel_init(1.0),
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_in_conv',
+                )(normed_x)
+        else:
+            projected_x = normed_x
+            inner_dim = C
+
+        context = projected_x if context is None else context
+
+        projected_x = BasicTransformerBlock(
+            query_dim=inner_dim,
+            heads=self.heads,
+            dim_head=self.dim_head,
+            name=f'Attention',
+            precision=self.precision,
+            use_bias=False,
+            dtype=self.dtype,
+            use_flash_attention=self.use_flash_attention,
+            use_cross_only=(not self.use_self_and_cross),
+            only_pure_attention=self.only_pure_attention
+        )(projected_x, context)
+
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=C, precision=self.precision,
+                    dtype=self.dtype, use_bias=False,
+                    kernel_init=kernel_init(1.0),
+                    name=f'project_out')(projected_x)
+            else:
+                projected_x = nn.Conv(
+                    features=C, kernel_size=(1, 1),
+                    kernel_init=kernel_init(1.0),
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_out_conv',
+                )(projected_x)
+
+        out = x + projected_x
+        return out
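With the hunk above, TransformerBlock no longer branches between AttentionBlock, EfficientAttention and NormalAttention; it always wraps a single BasicTransformerBlock, translating `use_self_and_cross` into `use_cross_only` and forwarding `only_pure_attention`. A hedged usage sketch for a feature map with conditioning tokens (hyperparameters and shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock

block = TransformerBlock(heads=4, dim_head=32, use_projection=True,
                         use_flash_attention=False, use_self_and_cross=True)

x = jnp.ones((1, 16, 16, 128))   # [B, H, W, C] image features
ctx = jnp.ones((1, 77, 128))     # e.g. text-encoder tokens

params = block.init(jax.random.PRNGKey(0), x, ctx)
y = block.apply(params, x, ctx)  # residual output with the same shape as x
```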
flaxdiff/models/autoencoder/__init__.py
File without changes
flaxdiff/models/autoencoder/autoencoder.py
ADDED
@@ -0,0 +1,14 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Dict, Callable, Sequence, Any, Union
+import einops
+from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+
+
+class AutoEncoder:
+    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
+
+    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
flaxdiff/models/autoencoder/diffusers.py
ADDED
@@ -0,0 +1,88 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from .autoencoder import AutoEncoder
+
+"""
+This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
+"""
+
+class StableDiffusionVAE(AutoEncoder):
+    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
+
+        from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
+        from diffusers import FlaxStableDiffusionPipeline
+
+        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+            modelname,
+            revision="bf16",
+            dtype=jnp.bfloat16,
+        )
+
+        vae = pipeline.vae
+
+        enc = FlaxEncoder(
+            in_channels=vae.config.in_channels,
+            out_channels=vae.config.latent_channels,
+            down_block_types=vae.config.down_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            act_fn=vae.config.act_fn,
+            norm_num_groups=vae.config.norm_num_groups,
+            double_z=True,
+            dtype=vae.dtype,
+        )
+
+        dec = FlaxDecoder(
+            in_channels=vae.config.latent_channels,
+            out_channels=vae.config.out_channels,
+            up_block_types=vae.config.up_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            norm_num_groups=vae.config.norm_num_groups,
+            act_fn=vae.config.act_fn,
+            dtype=vae.dtype,
+        )
+
+        quant_conv = nn.Conv(
+            2 * vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        post_quant_conv = nn.Conv(
+            vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        self.enc = enc
+        self.dec = dec
+        self.post_quant_conv = post_quant_conv
+        self.quant_conv = quant_conv
+        self.params = params
+        self.scaling_factor = vae.scaling_factor
+
+    def encode(self, images, rngkey: jax.random.PRNGKey = None):
+        latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
+        latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
+        if rngkey is not None:
+            mean, log_std = jnp.split(latents, 2, axis=-1)
+            log_std = jnp.clip(log_std, -30, 20)
+            std = jnp.exp(0.5 * log_std)
+            latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+            print("Sampled")
+        else:
+            # return the mean
+            latents, _ = jnp.split(latents, 2, axis=-1)
+        latents *= self.scaling_factor
+        return latents
+
+    def decode(self, latents):
+        latents = (1.0 / self.scaling_factor) * latents
+        latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
+        return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)