flaxdiff 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -62,8 +62,13 @@ class EfficientAttention(nn.Module):
         # x has shape [B, H * W, C]
         context = x if context is None else context
 
-        B, H, W, C = x.shape
-        x = x.reshape((B, 1, H * W, C))
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, 1, H * W, C))
+        else:
+            B, SEQ, C = x.shape
+            x = x.reshape((B, 1, SEQ, C))
 
         if len(context.shape) == 4:
             B, _H, _W, _C = context.shape
@@ -93,7 +98,7 @@ class EfficientAttention(nn.Module):
 
         proj = self.proj_attn(hidden_states)
 
-        proj = proj.reshape((B, H, W, C))
+        proj = proj.reshape(orig_x_shape)
 
         return proj
 
@@ -138,8 +143,10 @@ class NormalAttention(nn.Module):
     @nn.compact
     def __call__(self, x, context=None):
         # x has shape [B, H, W, C]
-        B, H, W, C = x.shape
-        x = x.reshape((B, H*W, C))
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, H*W, C))
         context = x if context is None else context
         if len(context.shape) == 4:
             context = context.reshape((B, H*W, C))
@@ -151,10 +158,10 @@ class NormalAttention(nn.Module):
             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
         )
         proj = self.proj_attn(hidden_states)
-        proj = proj.reshape((B, H, W, C))
+        proj = proj.reshape(orig_x_shape)
         return proj
-
-class AttentionBlock(nn.Module):
+
+class BasicTransformerBlock(nn.Module):
     # Has self and cross attention
     query_dim: int
     heads: int = 4
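
Both EfficientAttention and NormalAttention now record orig_x_shape on entry and restore it on the way out, so the same module accepts either a [B, H, W, C] feature map or a [B, SEQ, C] token sequence. A minimal usage sketch with NormalAttention; the import path, sizes, and initializer calls below are illustrative assumptions, not taken from the package's tests:

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import NormalAttention  # module path assumed

attn = NormalAttention(query_dim=64, heads=4, dim_head=16)

x_spatial = jnp.ones((1, 8, 8, 64))                    # [B, H, W, C] feature map
params = attn.init(jax.random.PRNGKey(0), x_spatial)
print(attn.apply(params, x_spatial).shape)             # (1, 8, 8, 64)

x_tokens = jnp.ones((1, 77, 64))                       # [B, SEQ, C] token sequence
params = attn.init(jax.random.PRNGKey(0), x_tokens)
print(attn.apply(params, x_tokens).shape)              # (1, 77, 64)
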
@@ -193,129 +200,26 @@ class AttentionBlock(nn.Module):
             kernel_init=self.kernel_init
         )
 
-        self.ff = nn.DenseGeneral(
-            features=self.query_dim,
-            use_bias=self.use_bias,
-            precision=self.precision,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init(),
-            name="ff"
-        )
+        self.ff = FlaxFeedForward(dim=self.query_dim)
         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
 
     @nn.compact
     def __call__(self, hidden_states, context=None):
         # self attention
-        residual = hidden_states
-        hidden_states = self.norm1(hidden_states)
-        if self.use_cross_only:
-            hidden_states = self.attention1(hidden_states, context)
-        else:
-            hidden_states = self.attention1(hidden_states)
-        hidden_states = hidden_states + residual
+        if not self.use_cross_only:
+            print("Using self attention")
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
 
         # cross attention
-        residual = hidden_states
-        hidden_states = self.norm2(hidden_states)
-        hidden_states = self.attention2(hidden_states, context)
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
 
         # feed forward
-        residual = hidden_states
-        hidden_states = self.norm3(hidden_states)
-        hidden_states = nn.gelu(hidden_states)
-        hidden_states = self.ff(hidden_states)
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
 
         return hidden_states
 
-class TransformerBlock(nn.Module):
-    heads: int = 4
-    dim_head: int = 32
-    use_linear_attention: bool = True
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-    use_projection: bool = False
-    use_flash_attention:bool = True
-    use_self_and_cross:bool = False
-
-    @nn.compact
-    def __call__(self, x, context=None):
-        inner_dim = self.heads * self.dim_head
-        B, H, W, C = x.shape
-        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
-        if self.use_projection == True:
-            if self.use_linear_attention:
-                projected_x = nn.Dense(features=inner_dim,
-                    use_bias=False, precision=self.precision,
-                    kernel_init=kernel_init(1.0),
-                    dtype=self.dtype, name=f'project_in')(normed_x)
-            else:
-                projected_x = nn.Conv(
-                    features=inner_dim, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
-                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                    precision=self.precision, name=f'project_in_conv',
-                )(normed_x)
-        else:
-            projected_x = normed_x
-            inner_dim = C
-
-        context = projected_x if context is None else context
-
-        if self.use_self_and_cross:
-            projected_x = AttentionBlock(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-                dtype=self.dtype,
-                use_flash_attention=self.use_flash_attention,
-                use_cross_only=False
-            )(projected_x, context)
-        elif self.use_flash_attention == True:
-            projected_x = EfficientAttention(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-                dtype=self.dtype,
-            )(projected_x, context)
-        else:
-            projected_x = NormalAttention(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-            )(projected_x, context)
-
-
-        if self.use_projection == True:
-            if self.use_linear_attention:
-                projected_x = nn.Dense(features=C, precision=self.precision,
-                    dtype=self.dtype, use_bias=False,
-                    kernel_init=kernel_init(1.0),
-                    name=f'project_out')(projected_x)
-            else:
-                projected_x = nn.Conv(
-                    features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
-                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                    precision=self.precision, name=f'project_out_conv',
-                )(projected_x)
-
-        out = x + projected_x
-        return out
-
 class FlaxGEGLU(nn.Module):
     r"""
     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -333,10 +237,11 @@ class FlaxGEGLU(nn.Module):
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
 
     def setup(self):
         inner_dim = self.dim * 4
-        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)
 
     def __call__(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -362,14 +267,14 @@ class FlaxFeedForward(nn.Module):
     """
 
     dim: int
-    dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
 
     def setup(self):
         # The second linear layer needs to be called
         # net_2 for now to match the index of the Sequential layer
-        self.net_0 = FlaxGEGLU(self.dim, self.dtype)
-        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+        self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)
 
     def __call__(self, hidden_states):
         hidden_states = self.net_0(hidden_states)
@@ -377,55 +282,127 @@ class FlaxFeedForward(nn.Module):
         return hidden_states
 
 class BasicTransformerBlock(nn.Module):
+    # Has self and cross attention
     query_dim: int
-    heads: int
-    dim_head: int
-    dropout: float = 0.0
-    only_cross_attention: bool = False
-    dtype: jnp.dtype = jnp.float32
-    use_memory_efficient_attention: bool = False
-    split_head_dim: bool = False
-    precision: Any = jax.lax.Precision.DEFAULT
-
+    heads: int = 4
+    dim_head: int = 64
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGHEST
+    use_bias: bool = True
+    kernel_init: Callable = lambda : kernel_init(1.0)
+    use_flash_attention:bool = False
+    use_cross_only:bool = False
+    only_pure_attention:bool = False
+
     def setup(self):
-        # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = NormalAttention(
-            query_dim=self.query_dim,
+        if self.use_flash_attention:
+            attenBlock = EfficientAttention
+        else:
+            attenBlock = NormalAttention
+
+        self.attention1 = attenBlock(
+            query_dim=self.query_dim,
             heads=self.heads,
             dim_head=self.dim_head,
-            dtype=self.dtype,
+            name=f'Attention1',
             precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init
         )
-        # cross attention
-        self.attn2 = NormalAttention(
+        self.attention2 = attenBlock(
            query_dim=self.query_dim,
            heads=self.heads,
            dim_head=self.dim_head,
-           dtype=self.dtype,
+           name=f'Attention2',
           precision=self.precision,
+           use_bias=self.use_bias,
+           dtype=self.dtype,
+           kernel_init=self.kernel_init
        )
-        self.ff = FlaxFeedForward(dim=self.query_dim, dropout=self.dropout, dtype=self.dtype)
+
+        self.ff = FlaxFeedForward(dim=self.query_dim)
        self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-    def __call__(self, hidden_states, context, deterministic=True):
+
+    @nn.compact
+    def __call__(self, hidden_states, context=None):
+        if self.only_pure_attention:
+            return self.attention2(self.norm2(hidden_states), context)
+
        # self attention
-        residual = hidden_states
-        if self.only_cross_attention:
-            hidden_states = self.attn1(self.norm1(hidden_states), context)
-        else:
-            hidden_states = self.attn1(self.norm1(hidden_states))
-        hidden_states = hidden_states + residual
-
+        if not self.use_cross_only:
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
+
        # cross attention
-        residual = hidden_states
-        hidden_states = self.attn2(self.norm2(hidden_states), context)
-        hidden_states = hidden_states + residual
-
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
        # feed forward
-        residual = hidden_states
-        hidden_states = self.ff(self.norm3(hidden_states))
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
+
+        return hidden_states
 
-        return hidden_states
+class TransformerBlock(nn.Module):
+    heads: int = 4
+    dim_head: int = 32
+    use_linear_attention: bool = True
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGH
+    use_projection: bool = False
+    use_flash_attention:bool = True
+    use_self_and_cross:bool = False
+    only_pure_attention:bool = False
+
+    @nn.compact
+    def __call__(self, x, context=None):
+        inner_dim = self.heads * self.dim_head
+        B, H, W, C = x.shape
+        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=inner_dim,
+                    use_bias=False, precision=self.precision,
+                    kernel_init=kernel_init(1.0),
+                    dtype=self.dtype, name=f'project_in')(normed_x)
+            else:
+                projected_x = nn.Conv(
+                    features=inner_dim, kernel_size=(1, 1),
+                    kernel_init=kernel_init(1.0),
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_in_conv',
+                )(normed_x)
+        else:
+            projected_x = normed_x
+            inner_dim = C
+
+        context = projected_x if context is None else context
+
+        projected_x = BasicTransformerBlock(
+            query_dim=inner_dim,
+            heads=self.heads,
+            dim_head=self.dim_head,
+            name=f'Attention',
+            precision=self.precision,
+            use_bias=False,
+            dtype=self.dtype,
+            use_flash_attention=self.use_flash_attention,
+            use_cross_only=(not self.use_self_and_cross),
+            only_pure_attention=self.only_pure_attention
+        )(projected_x, context)
+
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=C, precision=self.precision,
+                    dtype=self.dtype, use_bias=False,
+                    kernel_init=kernel_init(1.0),
+                    name=f'project_out')(projected_x)
+            else:
+                projected_x = nn.Conv(
+                    features=C, kernel_size=(1, 1),
+                    kernel_init=kernel_init(1.0),
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_out_conv',
+                )(projected_x)
+
+        out = x + projected_x
+        return out
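
Net effect of this file's changes: the old AttentionBlock/TransformerBlock pair is replaced by a pre-norm BasicTransformerBlock (residual self attention, cross attention, and a FlaxFeedForward/GEGLU block) wrapped by a leaner TransformerBlock. A minimal usage sketch; the import path and sizes are illustrative, and use_flash_attention=False selects NormalAttention so the sketch does not depend on an efficient-attention backend:

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock  # module path assumed

block = TransformerBlock(heads=4, dim_head=16,
                         use_projection=True,        # project C=128 down to heads*dim_head=64
                         use_flash_attention=False,  # NormalAttention inside BasicTransformerBlock
                         use_self_and_cross=True)    # self attention followed by cross attention

x = jnp.ones((1, 16, 16, 128))                       # [B, H, W, C] feature map; context=None
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)                           # residual output, same shape as x
print(y.shape)                                       # (1, 16, 16, 128)
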
File without changes
@@ -0,0 +1,14 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Dict, Callable, Sequence, Any, Union
+import einops
+from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+
+
+class AutoEncoder:
+    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
+
+    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
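
The new AutoEncoder base class only fixes the encode/decode interface that the rest of the library can program against; a concrete subclass supplies the actual transforms. A hypothetical identity codec is enough to show the contract:

import jax.numpy as jnp

class IdentityAutoEncoder(AutoEncoder):
    """Toy example: 'encodes' by returning the input unchanged."""

    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return x

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return z
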
@@ -0,0 +1,88 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from .autoencoder import AutoEncoder
+
+"""
+This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
+"""
+
+class StableDiffusionVAE(AutoEncoder):
+    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
+
+        from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
+        from diffusers import FlaxStableDiffusionPipeline
+
+        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+            modelname,
+            revision="bf16",
+            dtype=jnp.bfloat16,
+        )
+
+        vae = pipeline.vae
+
+        enc = FlaxEncoder(
+            in_channels=vae.config.in_channels,
+            out_channels=vae.config.latent_channels,
+            down_block_types=vae.config.down_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            act_fn=vae.config.act_fn,
+            norm_num_groups=vae.config.norm_num_groups,
+            double_z=True,
+            dtype=vae.dtype,
+        )
+
+        dec = FlaxDecoder(
+            in_channels=vae.config.latent_channels,
+            out_channels=vae.config.out_channels,
+            up_block_types=vae.config.up_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            norm_num_groups=vae.config.norm_num_groups,
+            act_fn=vae.config.act_fn,
+            dtype=vae.dtype,
+        )
+
+        quant_conv = nn.Conv(
+            2 * vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        post_quant_conv = nn.Conv(
+            vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        self.enc = enc
+        self.dec = dec
+        self.post_quant_conv = post_quant_conv
+        self.quant_conv = quant_conv
+        self.params = params
+        self.scaling_factor = vae.scaling_factor
+
+    def encode(self, images, rngkey: jax.random.PRNGKey = None):
+        latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
+        latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
+        if rngkey is not None:
+            mean, log_std = jnp.split(latents, 2, axis=-1)
+            log_std = jnp.clip(log_std, -30, 20)
+            std = jnp.exp(0.5 * log_std)
+            latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+            print("Sampled")
+        else:
+            # return the mean
+            latents, _ = jnp.split(latents, 2, axis=-1)
+        latents *= self.scaling_factor
+        return latents
+
+    def decode(self, latents):
+        latents = (1.0 / self.scaling_factor) * latents
+        latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
+        return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
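
An illustrative round trip through the wrapper defined above. Running it downloads the CompVis/stable-diffusion-v1-4 bf16 Flax weights via diffusers; the 8x spatial downsampling and 4 latent channels in the shape comments are assumptions about that pretrained VAE config:

import jax
import jax.numpy as jnp

vae = StableDiffusionVAE()                                   # pulls the pretrained bf16 Flax weights

images = jnp.zeros((1, 512, 512, 3), dtype=jnp.bfloat16)     # NHWC image batch, typically scaled to [-1, 1]

# Pass a PRNG key to sample from the latent posterior; omit it to get the mean latent.
latents = vae.encode(images, rngkey=jax.random.PRNGKey(0))   # ~ (1, 64, 64, 4), already scaled
recon = vae.decode(latents)                                  # back to ~ (1, 512, 512, 3)
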