flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,8 @@ Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/dif
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
-from typing import Dict, Callable, Sequence, Any, Union
+from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
+from flax.typing import Dtype, PrecisionLike
 import einops
 import functools
 import math
@@ -18,8 +19,8 @@ class EfficientAttention(nn.Module):
     query_dim: int
     heads: int = 4
     dim_head: int = 64
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGHEST
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
 
@@ -62,8 +63,13 @@ class EfficientAttention(nn.Module):
         # x has shape [B, H * W, C]
         context = x if context is None else context
 
-        B, H, W, C = x.shape
-        x = x.reshape((B, 1, H * W, C))
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, 1, H * W, C))
+        else:
+            B, SEQ, C = x.shape
+            x = x.reshape((B, 1, SEQ, C))
 
         if len(context.shape) == 4:
             B, _H, _W, _C = context.shape
@@ -93,7 +99,7 @@ class EfficientAttention(nn.Module):
 
         proj = self.proj_attn(hidden_states)
 
-        proj = proj.reshape((B, H, W, C))
+        proj = proj.reshape(orig_x_shape)
 
         return proj
 
@@ -104,8 +110,8 @@ class NormalAttention(nn.Module):
     query_dim: int
     heads: int = 4
     dim_head: int = 64
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGHEST
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
 
@@ -138,8 +144,10 @@ class NormalAttention(nn.Module):
     @nn.compact
     def __call__(self, x, context=None):
         # x has shape [B, H, W, C]
-        B, H, W, C = x.shape
-        x = x.reshape((B, H*W, C))
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, H*W, C))
         context = x if context is None else context
         if len(context.shape) == 4:
             context = context.reshape((B, H*W, C))
@@ -151,16 +159,16 @@ class NormalAttention(nn.Module):
             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
         )
         proj = self.proj_attn(hidden_states)
-        proj = proj.reshape((B, H, W, C))
+        proj = proj.reshape(orig_x_shape)
         return proj
-
-class AttentionBlock(nn.Module):
+
+class BasicTransformerBlock(nn.Module):
     # Has self and cross attention
     query_dim: int
     heads: int = 4
     dim_head: int = 64
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGHEST
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
     use_flash_attention:bool = False
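
Note on the hunks above: the attention modules' dtype and precision defaults move from hard-coded jnp.float32 / jax.lax.Precision.HIGHEST to None, so callers now choose them explicitly (None falls back to the Flax/JAX defaults), and inputs may be either [B, H, W, C] or [B, SEQ, C]. A minimal usage sketch follows; the flaxdiff.models.attention import path, shapes, and parameter values are assumptions for illustration, not taken from this diff.

# Hypothetical usage sketch; import path, shapes, and values are assumptions.
import jax
import jax.numpy as jnp
from flaxdiff.models.attention import NormalAttention

attn = NormalAttention(
    query_dim=64,
    heads=4,
    dim_head=16,
    dtype=jnp.bfloat16,                   # 0.1.4 fixed this to jnp.float32
    precision=jax.lax.Precision.HIGH,     # 0.1.4 fixed this to Precision.HIGHEST
)
x = jnp.ones((1, 16, 16, 64))             # [B, H, W, C]; 0.1.6 also accepts [B, SEQ, C]
params = attn.init(jax.random.PRNGKey(0), x)
y = attn.apply(params, x)                 # output is reshaped back to x's original shape
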
@@ -193,129 +201,26 @@ class AttentionBlock(nn.Module):
             kernel_init=self.kernel_init
         )
 
-        self.ff = nn.DenseGeneral(
-            features=self.query_dim,
-            use_bias=self.use_bias,
-            precision=self.precision,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init(),
-            name="ff"
-        )
+        self.ff = FlaxFeedForward(dim=self.query_dim)
         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
 
     @nn.compact
     def __call__(self, hidden_states, context=None):
         # self attention
-        residual = hidden_states
-        hidden_states = self.norm1(hidden_states)
-        if self.use_cross_only:
-            hidden_states = self.attention1(hidden_states, context)
-        else:
-            hidden_states = self.attention1(hidden_states)
-        hidden_states = hidden_states + residual
+        if not self.use_cross_only:
+            print("Using self attention")
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
 
         # cross attention
-        residual = hidden_states
-        hidden_states = self.norm2(hidden_states)
-        hidden_states = self.attention2(hidden_states, context)
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
 
         # feed forward
-        residual = hidden_states
-        hidden_states = self.norm3(hidden_states)
-        hidden_states = nn.gelu(hidden_states)
-        hidden_states = self.ff(hidden_states)
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
 
         return hidden_states
 
-class TransformerBlock(nn.Module):
-    heads: int = 4
-    dim_head: int = 32
-    use_linear_attention: bool = True
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-    use_projection: bool = False
-    use_flash_attention:bool = True
-    use_self_and_cross:bool = False
-
-    @nn.compact
-    def __call__(self, x, context=None):
-        inner_dim = self.heads * self.dim_head
-        B, H, W, C = x.shape
-        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
-        if self.use_projection == True:
-            if self.use_linear_attention:
-                projected_x = nn.Dense(features=inner_dim,
-                                       use_bias=False, precision=self.precision,
-                                       kernel_init=kernel_init(1.0),
-                                       dtype=self.dtype, name=f'project_in')(normed_x)
-            else:
-                projected_x = nn.Conv(
-                    features=inner_dim, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
-                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                    precision=self.precision, name=f'project_in_conv',
-                )(normed_x)
-        else:
-            projected_x = normed_x
-            inner_dim = C
-
-        context = projected_x if context is None else context
-
-        if self.use_self_and_cross:
-            projected_x = AttentionBlock(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-                dtype=self.dtype,
-                use_flash_attention=self.use_flash_attention,
-                use_cross_only=False
-            )(projected_x, context)
-        elif self.use_flash_attention == True:
-            projected_x = EfficientAttention(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-                dtype=self.dtype,
-            )(projected_x, context)
-        else:
-            projected_x = NormalAttention(
-                query_dim=inner_dim,
-                heads=self.heads,
-                dim_head=self.dim_head,
-                name=f'Attention',
-                precision=self.precision,
-                use_bias=False,
-            )(projected_x, context)
-
-
-        if self.use_projection == True:
-            if self.use_linear_attention:
-                projected_x = nn.Dense(features=C, precision=self.precision,
-                                       dtype=self.dtype, use_bias=False,
-                                       kernel_init=kernel_init(1.0),
-                                       name=f'project_out')(projected_x)
-            else:
-                projected_x = nn.Conv(
-                    features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
-                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                    precision=self.precision, name=f'project_out_conv',
-                )(projected_x)
-
-        out = x + projected_x
-        return out
-
 class FlaxGEGLU(nn.Module):
     r"""
     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -333,10 +238,11 @@ class FlaxGEGLU(nn.Module):
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
 
     def setup(self):
        inner_dim = self.dim * 4
-        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)
 
     def __call__(self, hidden_states):
         hidden_states = self.proj(hidden_states)
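
The only functional change to FlaxGEGLU is that the projection's precision is now configurable through the new module field. For context, GEGLU projects the hidden states to twice the inner width and gates one half with GELU; the reference sketch below illustrates that computation. The function name and the split axis are assumptions for illustration, not code from this package.

# Reference GEGLU sketch (illustration only, not the package's exact implementation).
import jax.numpy as jnp
from flax import linen as nn

def geglu(projected):
    # projected: [..., 2 * inner_dim], i.e. the output of self.proj above
    linear_part, gate = jnp.split(projected, 2, axis=-1)
    return linear_part * nn.gelu(gate)
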
@@ -362,14 +268,14 @@ class FlaxFeedForward(nn.Module):
     """
 
     dim: int
-    dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
 
     def setup(self):
         # The second linear layer needs to be called
         # net_2 for now to match the index of the Sequential layer
-        self.net_0 = FlaxGEGLU(self.dim, self.dtype)
-        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+        self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)
 
     def __call__(self, hidden_states):
         hidden_states = self.net_0(hidden_states)
@@ -377,55 +283,127 @@ class FlaxFeedForward(nn.Module):
         return hidden_states
 
 class BasicTransformerBlock(nn.Module):
+    # Has self and cross attention
     query_dim: int
-    heads: int
-    dim_head: int
-    dropout: float = 0.0
-    only_cross_attention: bool = False
-    dtype: jnp.dtype = jnp.float32
-    use_memory_efficient_attention: bool = False
-    split_head_dim: bool = False
-    precision: Any = jax.lax.Precision.DEFAULT
-
+    heads: int = 4
+    dim_head: int = 64
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_bias: bool = True
+    kernel_init: Callable = lambda : kernel_init(1.0)
+    use_flash_attention:bool = False
+    use_cross_only:bool = False
+    only_pure_attention:bool = False
+
     def setup(self):
-        # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = NormalAttention(
-            query_dim=self.query_dim,
+        if self.use_flash_attention:
+            attenBlock = EfficientAttention
+        else:
+            attenBlock = NormalAttention
+
+        self.attention1 = attenBlock(
+            query_dim=self.query_dim,
             heads=self.heads,
             dim_head=self.dim_head,
-            dtype=self.dtype,
+            name=f'Attention1',
             precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init
         )
-        # cross attention
-        self.attn2 = NormalAttention(
+        self.attention2 = attenBlock(
             query_dim=self.query_dim,
             heads=self.heads,
             dim_head=self.dim_head,
-            dtype=self.dtype,
+            name=f'Attention2',
             precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init
         )
-        self.ff = FlaxFeedForward(dim=self.query_dim, dropout=self.dropout, dtype=self.dtype)
+
+        self.ff = FlaxFeedForward(dim=self.query_dim)
         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-    def __call__(self, hidden_states, context, deterministic=True):
+
+    @nn.compact
+    def __call__(self, hidden_states, context=None):
+        if self.only_pure_attention:
+            return self.attention2(self.norm2(hidden_states), context)
+
         # self attention
-        residual = hidden_states
-        if self.only_cross_attention:
-            hidden_states = self.attn1(self.norm1(hidden_states), context)
-        else:
-            hidden_states = self.attn1(self.norm1(hidden_states))
-        hidden_states = hidden_states + residual
-
+        if not self.use_cross_only:
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
+
         # cross attention
-        residual = hidden_states
-        hidden_states = self.attn2(self.norm2(hidden_states), context)
-        hidden_states = hidden_states + residual
-
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
         # feed forward
-        residual = hidden_states
-        hidden_states = self.ff(self.norm3(hidden_states))
-        hidden_states = hidden_states + residual
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
+
+        return hidden_states
+
+class TransformerBlock(nn.Module):
+    heads: int = 4
+    dim_head: int = 32
+    use_linear_attention: bool = True
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_projection: bool = False
+    use_flash_attention:bool = True
+    use_self_and_cross:bool = False
+    only_pure_attention:bool = False
+
+    @nn.compact
+    def __call__(self, x, context=None):
+        inner_dim = self.heads * self.dim_head
+        B, H, W, C = x.shape
+        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=inner_dim,
+                                       use_bias=False, precision=self.precision,
+                                       kernel_init=kernel_init(1.0),
+                                       dtype=self.dtype, name=f'project_in')(normed_x)
+            else:
+                projected_x = nn.Conv(
+                    features=inner_dim, kernel_size=(1, 1),
+                    kernel_init=kernel_init(1.0),
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_in_conv',
+                )(normed_x)
+        else:
+            projected_x = normed_x
+            inner_dim = C
+
+        context = projected_x if context is None else context
 
-        return hidden_states
+        projected_x = BasicTransformerBlock(
+            query_dim=inner_dim,
+            heads=self.heads,
+            dim_head=self.dim_head,
+            name=f'Attention',
+            precision=self.precision,
+            use_bias=False,
+            dtype=self.dtype,
+            use_flash_attention=self.use_flash_attention,
+            use_cross_only=(not self.use_self_and_cross),
+            only_pure_attention=self.only_pure_attention
+        )(projected_x, context)
+
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=C, precision=self.precision,
+                                       dtype=self.dtype, use_bias=False,
+                                       kernel_init=kernel_init(1.0),
+                                       name=f'project_out')(projected_x)
+            else:
+                projected_x = nn.Conv(
+                    features=C, kernel_size=(1, 1),
+                    kernel_init=kernel_init(1.0),
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_out_conv',
+                )(projected_x)
+
+        out = x + projected_x
+        return out
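
The rewritten TransformerBlock above delegates every attention variant to BasicTransformerBlock, mapping use_self_and_cross to use_cross_only=(not use_self_and_cross) and forwarding the new only_pure_attention flag. A hedged usage sketch of the 0.1.6 API follows; the import path, shapes, and context width are assumptions for illustration.

# Hypothetical usage sketch; import path, shapes, and context width are assumptions.
import jax
import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock

block = TransformerBlock(
    heads=8,
    dim_head=32,
    use_projection=True,         # project C -> heads * dim_head and back
    use_flash_attention=False,   # selects NormalAttention inside BasicTransformerBlock
    use_self_and_cross=True,     # becomes use_cross_only=False internally
    only_pure_attention=False,
    dtype=jnp.float32,
)
x = jnp.ones((1, 16, 16, 256))         # [B, H, W, C]
context = jnp.ones((1, 77, 256))       # e.g. text embeddings [B, SEQ, C]
params = block.init(jax.random.PRNGKey(0), x, context)
y = block.apply(params, x, context)    # residual output, same shape as x
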
@@ -0,0 +1,2 @@
+from .autoencoder import AutoEncoder
+from .diffusers import StableDiffusionVAE
@@ -0,0 +1,19 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Dict, Callable, Sequence, Any, Union
+import einops
+from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+
+
+class AutoEncoder():
+    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
+
+    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
+
+    def __call__(self, x: jnp.ndarray):
+        latents = self.encode(x)
+        reconstructions = self.decode(latents)
+        return reconstructions
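
The new AutoEncoder base class added above is a plain Python interface rather than an nn.Module: subclasses implement encode and decode, and __call__ composes them. A minimal sketch of a subclass follows; the IdentityAutoEncoder name and the import path are hypothetical, for illustration only.

# Hypothetical subclass for illustration only; not part of the package.
import jax.numpy as jnp
from flaxdiff.models.autoencoder import AutoEncoder  # import path assumed from the new __init__ above

class IdentityAutoEncoder(AutoEncoder):
    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return x  # pass-through "latents"

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return z

ae = IdentityAutoEncoder()
reconstruction = ae(jnp.ones((1, 64, 64, 3)))  # __call__ runs encode then decode
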
@@ -0,0 +1,91 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from .autoencoder import AutoEncoder
+
+"""
+This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
+The actual model was not trained by me, but was taken from the HuggingFace model hub.
+I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library
+All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
+"""
+
+class StableDiffusionVAE(AutoEncoder):
+    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
+
+        from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
+        from diffusers import FlaxStableDiffusionPipeline
+
+        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+            modelname,
+            revision="bf16",
+            dtype=jnp.bfloat16,
+        )
+
+        vae = pipeline.vae
+
+        enc = FlaxEncoder(
+            in_channels=vae.config.in_channels,
+            out_channels=vae.config.latent_channels,
+            down_block_types=vae.config.down_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            act_fn=vae.config.act_fn,
+            norm_num_groups=vae.config.norm_num_groups,
+            double_z=True,
+            dtype=vae.dtype,
+        )
+
+        dec = FlaxDecoder(
+            in_channels=vae.config.latent_channels,
+            out_channels=vae.config.out_channels,
+            up_block_types=vae.config.up_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            norm_num_groups=vae.config.norm_num_groups,
+            act_fn=vae.config.act_fn,
+            dtype=vae.dtype,
+        )
+
+        quant_conv = nn.Conv(
+            2 * vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        post_quant_conv = nn.Conv(
+            vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        self.enc = enc
+        self.dec = dec
+        self.post_quant_conv = post_quant_conv
+        self.quant_conv = quant_conv
+        self.params = params
+        self.scaling_factor = vae.scaling_factor
+
+    def encode(self, images, rngkey: jax.random.PRNGKey = None):
+        latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
+        latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
+        if rngkey is not None:
+            mean, log_std = jnp.split(latents, 2, axis=-1)
+            log_std = jnp.clip(log_std, -30, 20)
+            std = jnp.exp(0.5 * log_std)
+            latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+            print("Sampled")
+        else:
+            # return the mean
+            latents, _ = jnp.split(latents, 2, axis=-1)
+        latents *= self.scaling_factor
+        return latents
+
+    def decode(self, latents):
+        latents = (1.0 / self.scaling_factor) * latents
+        latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
+        return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
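
Because StableDiffusionVAE assembles the encoder, decoder, and quant convolutions in __init__ from the pretrained Diffusers pipeline, use is a plain encode/decode round trip; passing a PRNG key samples the latent via the reparameterization above instead of returning the scaled mean. A hedged sketch follows; the import path and image shape are assumptions, and running it requires the diffusers dependency plus access to the pretrained weights.

# Hypothetical round-trip sketch; import path and shapes are assumptions.
import jax
import jax.numpy as jnp
from flaxdiff.models.autoencoder import StableDiffusionVAE  # path assumed from the new __init__ above

vae = StableDiffusionVAE()                               # wraps CompVis/stable-diffusion-v1-4
images = jnp.ones((1, 256, 256, 3), dtype=jnp.bfloat16)  # [B, H, W, C]

latents_mean = vae.encode(images)                        # deterministic: scaled posterior mean
latents_sample = vae.encode(images, rngkey=jax.random.PRNGKey(0))  # reparameterized sample
reconstruction = vae.decode(latents_sample)              # back to image space
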
@@ -0,0 +1,26 @@
+from typing import Any, List, Optional, Callable
+import jax
+import flax.linen as nn
+from jax import numpy as jnp
+from flax.typing import Dtype, PrecisionLike
+from .autoencoder import AutoEncoder
+
+class SimpleAutoEncoder(AutoEncoder):
+    latent_channels: int
+    feature_depths: List[int]=[64, 128, 256, 512]
+    attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
+    num_res_blocks: int=2
+    num_middle_res_blocks:int=1,
+    activation:Callable = jax.nn.swish
+    norm_groups:int=8
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+
+    # def encode(self, x: jnp.ndarray):
+
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray):
+        latents = self.encode(x)
+        reconstructions = self.decode(latents)
+        return reconstructions