flaxdiff 0.1.3__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {flaxdiff-0.1.3/flaxdiff.egg-info → flaxdiff-0.1.5}/PKG-INFO +10 -2
  2. flaxdiff-0.1.3/PKG-INFO → flaxdiff-0.1.5/README.md +9 -14
  3. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/models/attention.py +132 -155
  4. flaxdiff-0.1.5/flaxdiff/models/autoencoder/__init__.py +0 -0
  5. flaxdiff-0.1.5/flaxdiff/models/autoencoder/autoencoder.py +14 -0
  6. flaxdiff-0.1.5/flaxdiff/models/autoencoder/diffusers.py +88 -0
  7. flaxdiff-0.1.5/flaxdiff/models/common.py +250 -0
  8. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/models/simple_unet.py +17 -256
  9. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/trainer/__init__.py +28 -45
  10. flaxdiff-0.1.5/flaxdiff/trainer/simple_trainer.py +418 -0
  11. flaxdiff-0.1.3/README.md → flaxdiff-0.1.5/flaxdiff.egg-info/PKG-INFO +22 -1
  12. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff.egg-info/SOURCES.txt +3 -0
  13. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/setup.py +1 -1
  14. flaxdiff-0.1.3/flaxdiff/models/common.py +0 -7
  15. flaxdiff-0.1.3/flaxdiff/trainer/simple_trainer.py +0 -323
  16. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/__init__.py +0 -0
  17. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/models/__init__.py +0 -0
  18. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/models/favor_fastattn.py +0 -0
  19. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/models/simple_vit.py +0 -0
  20. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/predictors/__init__.py +0 -0
  21. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/samplers/__init__.py +0 -0
  22. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/samplers/common.py +0 -0
  23. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/samplers/ddim.py +0 -0
  24. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/samplers/ddpm.py +0 -0
  25. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/samplers/euler.py +0 -0
  26. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/samplers/heun_sampler.py +0 -0
  27. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/samplers/multistep_dpm.py +0 -0
  28. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/samplers/rk4_sampler.py +0 -0
  29. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/__init__.py +0 -0
  30. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/common.py +0 -0
  31. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/continuous.py +0 -0
  32. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/cosine.py +0 -0
  33. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/discrete.py +0 -0
  34. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/exp.py +0 -0
  35. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/karras.py +0 -0
  36. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/linear.py +0 -0
  37. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/schedulers/sqrt.py +0 -0
  38. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/utils.py +0 -0
  39. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff.egg-info/dependency_links.txt +0 -0
  40. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff.egg-info/requires.txt +0 -0
  41. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff.egg-info/top_level.txt +0 -0
  42. {flaxdiff-0.1.3 → flaxdiff-0.1.5}/setup.cfg +0 -0
{flaxdiff-0.1.3/flaxdiff.egg-info → flaxdiff-0.1.5}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.3
+ Version: 0.1.5
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
@@ -27,7 +27,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.
 
  In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.
 
- ### Available Notebooks
+ ### Available Notebooks and Resources
 
  - **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**
 
@@ -46,6 +46,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var
 
  These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.
 
+ #### Other resources
+
+ - **[Multi-host Data parallel training script in JAX](./training.py)**
+   - Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
+
+ - **[TPU utilities for making life easier](./tpu-tools/)**
+   - A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
+
  ## Disclaimer (and About Me)
 
  I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
flaxdiff-0.1.3/PKG-INFO → flaxdiff-0.1.5/README.md
@@ -1,16 +1,3 @@
- Metadata-Version: 2.1
- Name: flaxdiff
- Version: 0.1.3
- Summary: A versatile and easy to understand Diffusion library
- Author: Ashish Kumar Singh
- Author-email: ashishkmr472@gmail.com
- Description-Content-Type: text/markdown
- Requires-Dist: flax>=0.8.4
- Requires-Dist: optax>=0.2.2
- Requires-Dist: jax>=0.4.28
- Requires-Dist: orbax
- Requires-Dist: clu
-
  # ![](images/logo.jpeg "FlaxDiff")
 
  ## A Versatile and simple Diffusion Library
@@ -27,7 +14,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.
 
  In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.
 
- ### Available Notebooks
+ ### Available Notebooks and Resources
 
  - **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**
 
@@ -46,6 +33,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var
 
  These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.
 
+ #### Other resources
+
+ - **[Multi-host Data parallel training script in JAX](./training.py)**
+   - Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
+
+ - **[TPU utilities for making life easier](./tpu-tools/)**
+   - A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
+
  ## Disclaimer (and About Me)
 
  I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
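The README changes above point readers to a multi-host data parallel training script (`./training.py`), which is not part of this diff. Purely as an illustration of the pattern that bullet refers to, here is a minimal sketch of multi-host data-parallel training in JAX; the placeholder model, batch shapes, and learning rate are assumptions and the actual `training.py` may be organized quite differently.

```python
# Illustrative sketch only: not taken from flaxdiff's training.py.
import functools

import jax
import jax.numpy as jnp

# On multi-host setups (e.g. TPU pod slices), every host runs this same program.
# jax.distributed.initialize() wires the hosts together; it can be skipped (or may
# need explicit coordinator arguments) for single-host runs.
jax.distributed.initialize()

@functools.partial(jax.pmap, axis_name="data")
def train_step(params, batch):
    # Placeholder linear "model" and MSE loss, just to show the collective op.
    def loss_fn(p):
        pred = batch["x"] @ p
        return jnp.mean((pred - batch["y"]) ** 2)
    grads = jax.grad(loss_fn)(params)
    # Average gradients across every device on every host.
    grads = jax.lax.pmean(grads, axis_name="data")
    return params - 1e-3 * grads

# Each host feeds only its local devices with its own shard of the global batch.
n_local = jax.local_device_count()
params = jax.device_put_replicated(jnp.zeros((8, 1)), jax.local_devices())
batch = {
    "x": jnp.ones((n_local, 16, 8)),
    "y": jnp.ones((n_local, 16, 1)),
}
params = train_step(params, batch)
```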
{flaxdiff-0.1.3 → flaxdiff-0.1.5}/flaxdiff/models/attention.py
@@ -62,8 +62,13 @@ class EfficientAttention(nn.Module):
          # x has shape [B, H * W, C]
          context = x if context is None else context
 
-         B, H, W, C = x.shape
-         x = x.reshape((B, 1, H * W, C))
+         orig_x_shape = x.shape
+         if len(x.shape) == 4:
+             B, H, W, C = x.shape
+             x = x.reshape((B, 1, H * W, C))
+         else:
+             B, SEQ, C = x.shape
+             x = x.reshape((B, 1, SEQ, C))
 
          if len(context.shape) == 4:
              B, _H, _W, _C = context.shape
@@ -93,7 +98,7 @@ class EfficientAttention(nn.Module):
 
          proj = self.proj_attn(hidden_states)
 
-         proj = proj.reshape((B, H, W, C))
+         proj = proj.reshape(orig_x_shape)
 
          return proj
 
@@ -138,8 +143,10 @@ class NormalAttention(nn.Module):
      @nn.compact
      def __call__(self, x, context=None):
          # x has shape [B, H, W, C]
-         B, H, W, C = x.shape
-         x = x.reshape((B, H*W, C))
+         orig_x_shape = x.shape
+         if len(x.shape) == 4:
+             B, H, W, C = x.shape
+             x = x.reshape((B, H*W, C))
          context = x if context is None else context
          if len(context.shape) == 4:
              context = context.reshape((B, H*W, C))
@@ -151,10 +158,10 @@ class NormalAttention(nn.Module):
              query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
          )
          proj = self.proj_attn(hidden_states)
-         proj = proj.reshape((B, H, W, C))
+         proj = proj.reshape(orig_x_shape)
          return proj
-
- class AttentionBlock(nn.Module):
+
+ class BasicTransformerBlock(nn.Module):
      # Has self and cross attention
      query_dim: int
      heads: int = 4
@@ -193,129 +200,26 @@ class AttentionBlock(nn.Module):
              kernel_init=self.kernel_init
          )
 
-         self.ff = nn.DenseGeneral(
-             features=self.query_dim,
-             use_bias=self.use_bias,
-             precision=self.precision,
-             dtype=self.dtype,
-             kernel_init=self.kernel_init(),
-             name="ff"
-         )
+         self.ff = FlaxFeedForward(dim=self.query_dim)
          self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
          self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
          self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-         self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
 
      @nn.compact
      def __call__(self, hidden_states, context=None):
          # self attention
-         residual = hidden_states
-         hidden_states = self.norm1(hidden_states)
-         if self.use_cross_only:
-             hidden_states = self.attention1(hidden_states, context)
-         else:
-             hidden_states = self.attention1(hidden_states)
-         hidden_states = hidden_states + residual
+         if not self.use_cross_only:
+             print("Using self attention")
+             hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
 
          # cross attention
-         residual = hidden_states
-         hidden_states = self.norm2(hidden_states)
-         hidden_states = self.attention2(hidden_states, context)
-         hidden_states = hidden_states + residual
+         hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
 
          # feed forward
-         residual = hidden_states
-         hidden_states = self.norm3(hidden_states)
-         hidden_states = nn.gelu(hidden_states)
-         hidden_states = self.ff(hidden_states)
-         hidden_states = hidden_states + residual
+         hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
 
          return hidden_states
 
- class TransformerBlock(nn.Module):
-     heads: int = 4
-     dim_head: int = 32
-     use_linear_attention: bool = True
-     dtype: Any = jnp.float32
-     precision: Any = jax.lax.Precision.HIGH
-     use_projection: bool = False
-     use_flash_attention:bool = True
-     use_self_and_cross:bool = False
-
-     @nn.compact
-     def __call__(self, x, context=None):
-         inner_dim = self.heads * self.dim_head
-         B, H, W, C = x.shape
-         normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
-         if self.use_projection == True:
-             if self.use_linear_attention:
-                 projected_x = nn.Dense(features=inner_dim,
-                                        use_bias=False, precision=self.precision,
-                                        kernel_init=kernel_init(1.0),
-                                        dtype=self.dtype, name=f'project_in')(normed_x)
-             else:
-                 projected_x = nn.Conv(
-                     features=inner_dim, kernel_size=(1, 1),
-                     kernel_init=kernel_init(1.0),
-                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                     precision=self.precision, name=f'project_in_conv',
-                 )(normed_x)
-         else:
-             projected_x = normed_x
-             inner_dim = C
-
-         context = projected_x if context is None else context
-
-         if self.use_self_and_cross:
-             projected_x = AttentionBlock(
-                 query_dim=inner_dim,
-                 heads=self.heads,
-                 dim_head=self.dim_head,
-                 name=f'Attention',
-                 precision=self.precision,
-                 use_bias=False,
-                 dtype=self.dtype,
-                 use_flash_attention=self.use_flash_attention,
-                 use_cross_only=False
-             )(projected_x, context)
-         elif self.use_flash_attention == True:
-             projected_x = EfficientAttention(
-                 query_dim=inner_dim,
-                 heads=self.heads,
-                 dim_head=self.dim_head,
-                 name=f'Attention',
-                 precision=self.precision,
-                 use_bias=False,
-                 dtype=self.dtype,
-             )(projected_x, context)
-         else:
-             projected_x = NormalAttention(
-                 query_dim=inner_dim,
-                 heads=self.heads,
-                 dim_head=self.dim_head,
-                 name=f'Attention',
-                 precision=self.precision,
-                 use_bias=False,
-             )(projected_x, context)
-
-
-         if self.use_projection == True:
-             if self.use_linear_attention:
-                 projected_x = nn.Dense(features=C, precision=self.precision,
-                                        dtype=self.dtype, use_bias=False,
-                                        kernel_init=kernel_init(1.0),
-                                        name=f'project_out')(projected_x)
-             else:
-                 projected_x = nn.Conv(
-                     features=C, kernel_size=(1, 1),
-                     kernel_init=kernel_init(1.0),
-                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                     precision=self.precision, name=f'project_out_conv',
-                 )(projected_x)
-
-         out = x + projected_x
-         return out
-
  class FlaxGEGLU(nn.Module):
      r"""
      Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -333,10 +237,11 @@ class FlaxGEGLU(nn.Module):
      dim: int
      dropout: float = 0.0
      dtype: jnp.dtype = jnp.float32
+     precision: Any = jax.lax.Precision.DEFAULT
 
      def setup(self):
          inner_dim = self.dim * 4
-         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)
 
      def __call__(self, hidden_states):
          hidden_states = self.proj(hidden_states)
@@ -362,14 +267,14 @@ class FlaxFeedForward(nn.Module):
      """
 
      dim: int
-     dropout: float = 0.0
      dtype: jnp.dtype = jnp.float32
+     precision: Any = jax.lax.Precision.DEFAULT
 
      def setup(self):
          # The second linear layer needs to be called
          # net_2 for now to match the index of the Sequential layer
-         self.net_0 = FlaxGEGLU(self.dim, self.dtype)
-         self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+         self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
+         self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)
 
      def __call__(self, hidden_states):
          hidden_states = self.net_0(hidden_states)
@@ -377,55 +282,127 @@ class FlaxFeedForward(nn.Module):
          return hidden_states
 
  class BasicTransformerBlock(nn.Module):
+     # Has self and cross attention
      query_dim: int
-     heads: int
-     dim_head: int
-     dropout: float = 0.0
-     only_cross_attention: bool = False
-     dtype: jnp.dtype = jnp.float32
-     use_memory_efficient_attention: bool = False
-     split_head_dim: bool = False
-     precision: Any = jax.lax.Precision.DEFAULT
-
+     heads: int = 4
+     dim_head: int = 64
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGHEST
+     use_bias: bool = True
+     kernel_init: Callable = lambda : kernel_init(1.0)
+     use_flash_attention:bool = False
+     use_cross_only:bool = False
+     only_pure_attention:bool = False
+
      def setup(self):
-         # self attention (or cross_attention if only_cross_attention is True)
-         self.attn1 = NormalAttention(
-             query_dim=self.query_dim,
+         if self.use_flash_attention:
+             attenBlock = EfficientAttention
+         else:
+             attenBlock = NormalAttention
+
+         self.attention1 = attenBlock(
+             query_dim=self.query_dim,
              heads=self.heads,
              dim_head=self.dim_head,
-             dtype=self.dtype,
+             name=f'Attention1',
              precision=self.precision,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             kernel_init=self.kernel_init
          )
-         # cross attention
-         self.attn2 = NormalAttention(
+         self.attention2 = attenBlock(
              query_dim=self.query_dim,
              heads=self.heads,
              dim_head=self.dim_head,
-             dtype=self.dtype,
+             name=f'Attention2',
              precision=self.precision,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             kernel_init=self.kernel_init
          )
-         self.ff = FlaxFeedForward(dim=self.query_dim, dropout=self.dropout, dtype=self.dtype)
+
+         self.ff = FlaxFeedForward(dim=self.query_dim)
          self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
          self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
          self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-     def __call__(self, hidden_states, context, deterministic=True):
+
+     @nn.compact
+     def __call__(self, hidden_states, context=None):
+         if self.only_pure_attention:
+             return self.attention2(self.norm2(hidden_states), context)
+
          # self attention
-         residual = hidden_states
-         if self.only_cross_attention:
-             hidden_states = self.attn1(self.norm1(hidden_states), context)
-         else:
-             hidden_states = self.attn1(self.norm1(hidden_states))
-         hidden_states = hidden_states + residual
-
+         if not self.use_cross_only:
+             hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
+
          # cross attention
-         residual = hidden_states
-         hidden_states = self.attn2(self.norm2(hidden_states), context)
-         hidden_states = hidden_states + residual
-
+         hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
          # feed forward
-         residual = hidden_states
-         hidden_states = self.ff(self.norm3(hidden_states))
-         hidden_states = hidden_states + residual
+         hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
+
+         return hidden_states
 
-         return hidden_states
+ class TransformerBlock(nn.Module):
+     heads: int = 4
+     dim_head: int = 32
+     use_linear_attention: bool = True
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGH
+     use_projection: bool = False
+     use_flash_attention:bool = True
+     use_self_and_cross:bool = False
+     only_pure_attention:bool = False
+
+     @nn.compact
+     def __call__(self, x, context=None):
+         inner_dim = self.heads * self.dim_head
+         B, H, W, C = x.shape
+         normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+         if self.use_projection == True:
+             if self.use_linear_attention:
+                 projected_x = nn.Dense(features=inner_dim,
+                                        use_bias=False, precision=self.precision,
+                                        kernel_init=kernel_init(1.0),
+                                        dtype=self.dtype, name=f'project_in')(normed_x)
+             else:
+                 projected_x = nn.Conv(
+                     features=inner_dim, kernel_size=(1, 1),
+                     kernel_init=kernel_init(1.0),
+                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                     precision=self.precision, name=f'project_in_conv',
+                 )(normed_x)
+         else:
+             projected_x = normed_x
+             inner_dim = C
+
+         context = projected_x if context is None else context
+
+         projected_x = BasicTransformerBlock(
+             query_dim=inner_dim,
+             heads=self.heads,
+             dim_head=self.dim_head,
+             name=f'Attention',
+             precision=self.precision,
+             use_bias=False,
+             dtype=self.dtype,
+             use_flash_attention=self.use_flash_attention,
+             use_cross_only=(not self.use_self_and_cross),
+             only_pure_attention=self.only_pure_attention
+         )(projected_x, context)
+
+         if self.use_projection == True:
+             if self.use_linear_attention:
+                 projected_x = nn.Dense(features=C, precision=self.precision,
+                                        dtype=self.dtype, use_bias=False,
+                                        kernel_init=kernel_init(1.0),
+                                        name=f'project_out')(projected_x)
+             else:
+                 projected_x = nn.Conv(
+                     features=C, kernel_size=(1, 1),
+                     kernel_init=kernel_init(1.0),
+                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                     precision=self.precision, name=f'project_out_conv',
+                 )(projected_x)
+
+         out = x + projected_x
+         return out
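To illustrate what the attention refactor above means for callers: `TransformerBlock` now always routes through `BasicTransformerBlock` (self plus cross attention), picking `EfficientAttention` or `NormalAttention` via `use_flash_attention`, and the attention classes accept either `[B, H, W, C]` feature maps or `[B, SEQ, C]` sequences. Below is a rough usage sketch; the import path, shapes, and flag values are assumptions, not taken from the diff.

```python
# Hypothetical usage sketch for the reworked attention blocks (illustrative only).
import jax
import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock  # assumed import path

block = TransformerBlock(
    heads=4,
    dim_head=32,
    use_projection=True,
    use_flash_attention=False,   # selects NormalAttention under the hood
    use_self_and_cross=True,     # self attention followed by cross attention
    only_pure_attention=False,
)

x = jnp.ones((2, 16, 16, 64))      # [B, H, W, C] feature map
context = jnp.ones((2, 77, 768))   # [B, SEQ, C] conditioning sequence
params = block.init(jax.random.PRNGKey(0), x, context)
y = block.apply(params, x, context)  # same shape as x, residual added
```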
flaxdiff-0.1.5/flaxdiff/models/autoencoder/__init__.py: File without changes
flaxdiff-0.1.5/flaxdiff/models/autoencoder/autoencoder.py (new file)
@@ -0,0 +1,14 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Dict, Callable, Sequence, Any, Union
+ import einops
+ from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+
+
+ class AutoEncoder:
+     def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+         raise NotImplementedError
+
+     def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+         raise NotImplementedError
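The new `AutoEncoder` base class above is just an interface: latent-diffusion code can target `encode`/`decode` without caring which VAE backs them. As a minimal sketch (purely illustrative, with an assumed import path), a subclass only has to fill in those two methods:

```python
# Illustrative only: a trivial AutoEncoder subclass that passes arrays through unchanged.
import jax.numpy as jnp
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder  # assumed import path

class IdentityAutoEncoder(AutoEncoder):
    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return x  # no compression; stands in for a real VAE encoder

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return z
```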
flaxdiff-0.1.5/flaxdiff/models/autoencoder/diffusers.py (new file)
@@ -0,0 +1,88 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from .autoencoder import AutoEncoder
+
+ """
+ This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
+ """
+
+ class StableDiffusionVAE(AutoEncoder):
+     def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
+
+         from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
+         from diffusers import FlaxStableDiffusionPipeline
+
+         pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+             modelname,
+             revision="bf16",
+             dtype=jnp.bfloat16,
+         )
+
+         vae = pipeline.vae
+
+         enc = FlaxEncoder(
+             in_channels=vae.config.in_channels,
+             out_channels=vae.config.latent_channels,
+             down_block_types=vae.config.down_block_types,
+             block_out_channels=vae.config.block_out_channels,
+             layers_per_block=vae.config.layers_per_block,
+             act_fn=vae.config.act_fn,
+             norm_num_groups=vae.config.norm_num_groups,
+             double_z=True,
+             dtype=vae.dtype,
+         )
+
+         dec = FlaxDecoder(
+             in_channels=vae.config.latent_channels,
+             out_channels=vae.config.out_channels,
+             up_block_types=vae.config.up_block_types,
+             block_out_channels=vae.config.block_out_channels,
+             layers_per_block=vae.config.layers_per_block,
+             norm_num_groups=vae.config.norm_num_groups,
+             act_fn=vae.config.act_fn,
+             dtype=vae.dtype,
+         )
+
+         quant_conv = nn.Conv(
+             2 * vae.config.latent_channels,
+             kernel_size=(1, 1),
+             strides=(1, 1),
+             padding="VALID",
+             dtype=vae.dtype,
+         )
+
+         post_quant_conv = nn.Conv(
+             vae.config.latent_channels,
+             kernel_size=(1, 1),
+             strides=(1, 1),
+             padding="VALID",
+             dtype=vae.dtype,
+         )
+
+         self.enc = enc
+         self.dec = dec
+         self.post_quant_conv = post_quant_conv
+         self.quant_conv = quant_conv
+         self.params = params
+         self.scaling_factor = vae.scaling_factor
+
+     def encode(self, images, rngkey: jax.random.PRNGKey = None):
+         latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
+         latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
+         if rngkey is not None:
+             mean, log_std = jnp.split(latents, 2, axis=-1)
+             log_std = jnp.clip(log_std, -30, 20)
+             std = jnp.exp(0.5 * log_std)
+             latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+             print("Sampled")
+         else:
+             # return the mean
+             latents, _ = jnp.split(latents, 2, axis=-1)
+         latents *= self.scaling_factor
+         return latents
+
+     def decode(self, latents):
+         latents = (1.0 / self.scaling_factor) * latents
+         latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
+         return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
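Going by the code above, `encode` maps NHWC images to scaled latents (sampling via the reparameterization trick `mean + std * noise` when a PRNG key is passed, otherwise returning the mean), and `decode` undoes the scaling before running the Flax decoder. A hedged usage sketch, assuming the `diffusers` dependency is installed and the class is importable from this module path; shapes are approximate:

```python
# Usage sketch for StableDiffusionVAE (assumed import path; weights download from HuggingFace).
import jax
import jax.numpy as jnp
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE

vae = StableDiffusionVAE()  # loads CompVis/stable-diffusion-v1-4 in bfloat16
images = jnp.ones((1, 256, 256, 3), dtype=jnp.bfloat16)

# Stochastic encoding: pass a PRNG key to sample latents from the posterior.
latents = vae.encode(images, rngkey=jax.random.PRNGKey(0))  # roughly [1, 32, 32, 4]

# Deterministic encoding: omit the key to get the scaled posterior mean instead.
mean_latents = vae.encode(images)

recon = vae.decode(latents)  # back to image space, roughly [1, 256, 256, 3]
```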