flaxdiff 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.4/flaxdiff.egg-info → flaxdiff-0.1.6}/PKG-INFO +12 -2
  2. flaxdiff-0.1.4/PKG-INFO → flaxdiff-0.1.6/README.md +11 -14
  3. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/models/attention.py +140 -162
  4. flaxdiff-0.1.6/flaxdiff/models/autoencoder/__init__.py +2 -0
  5. flaxdiff-0.1.6/flaxdiff/models/autoencoder/autoencoder.py +19 -0
  6. flaxdiff-0.1.6/flaxdiff/models/autoencoder/diffusers.py +91 -0
  7. flaxdiff-0.1.6/flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
  8. flaxdiff-0.1.4/flaxdiff/models/simple_unet.py → flaxdiff-0.1.6/flaxdiff/models/common.py +28 -210
  9. flaxdiff-0.1.6/flaxdiff/models/simple_unet.py +205 -0
  10. flaxdiff-0.1.6/flaxdiff/trainer/__init__.py +2 -0
  11. flaxdiff-0.1.6/flaxdiff/trainer/autoencoder_trainer.py +182 -0
  12. flaxdiff-0.1.4/flaxdiff/trainer/__init__.py → flaxdiff-0.1.6/flaxdiff/trainer/diffusion_trainer.py +48 -47
  13. flaxdiff-0.1.6/flaxdiff/trainer/simple_trainer.py +418 -0
  14. flaxdiff-0.1.4/README.md → flaxdiff-0.1.6/flaxdiff.egg-info/PKG-INFO +24 -1
  15. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff.egg-info/SOURCES.txt +6 -0
  16. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/setup.py +1 -1
  17. flaxdiff-0.1.4/flaxdiff/models/common.py +0 -7
  18. flaxdiff-0.1.4/flaxdiff/trainer/simple_trainer.py +0 -323
  19. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/__init__.py +0 -0
  20. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/models/__init__.py +0 -0
  21. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/models/favor_fastattn.py +0 -0
  22. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/models/simple_vit.py +0 -0
  23. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/predictors/__init__.py +0 -0
  24. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/__init__.py +0 -0
  25. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/common.py +0 -0
  26. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/ddim.py +0 -0
  27. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/ddpm.py +0 -0
  28. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/euler.py +0 -0
  29. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/heun_sampler.py +0 -0
  30. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/multistep_dpm.py +0 -0
  31. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/rk4_sampler.py +0 -0
  32. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/__init__.py +0 -0
  33. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/common.py +0 -0
  34. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/continuous.py +0 -0
  35. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/cosine.py +0 -0
  36. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/discrete.py +0 -0
  37. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/exp.py +0 -0
  38. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/karras.py +0 -0
  39. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/linear.py +0 -0
  40. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/sqrt.py +0 -0
  41. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/utils.py +0 -0
  42. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.4 → flaxdiff-0.1.6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.4
+ Version: 0.1.6
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
@@ -13,6 +13,8 @@ Requires-Dist: clu

  # ![](images/logo.jpeg "FlaxDiff")

+ **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
+
  ## A Versatile and simple Diffusion Library

  In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
@@ -27,7 +29,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.

  In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.

- ### Available Notebooks
+ ### Available Notebooks and Resources

  - **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**

@@ -46,6 +48,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var

  These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.

+ #### Other resources
+
+ - **[Multi-host Data parallel training script in JAX](./training.py)**
+ - Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
+
+ - **[TPU utilities for making life easier](./tpu-tools/)**
+ - A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
+
  ## Disclaimer (and About Me)

  I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
@@ -1,18 +1,7 @@
- Metadata-Version: 2.1
- Name: flaxdiff
- Version: 0.1.4
- Summary: A versatile and easy to understand Diffusion library
- Author: Ashish Kumar Singh
- Author-email: ashishkmr472@gmail.com
- Description-Content-Type: text/markdown
- Requires-Dist: flax>=0.8.4
- Requires-Dist: optax>=0.2.2
- Requires-Dist: jax>=0.4.28
- Requires-Dist: orbax
- Requires-Dist: clu
-
  # ![](images/logo.jpeg "FlaxDiff")

+ **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
+
  ## A Versatile and simple Diffusion Library

  In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
@@ -27,7 +16,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.

  In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.

- ### Available Notebooks
+ ### Available Notebooks and Resources

  - **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**

@@ -46,6 +35,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var

  These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.

+ #### Other resources
+
+ - **[Multi-host Data parallel training script in JAX](./training.py)**
+ - Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
+
+ - **[TPU utilities for making life easier](./tpu-tools/)**
+ - A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
+
  ## Disclaimer (and About Me)

  I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
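Note (not part of the diff): the `training.py` script referenced under "Other resources" above is not included in this package diff. For context, here is a minimal sketch of the standard multi-host data-parallel setup in JAX that such a script typically builds on; only standard `jax` APIs are used, and the model, dataset, and checkpointing wiring are intentionally omitted.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One call per host/process, before any other JAX work in a multi-host job.
jax.distributed.initialize()

# A 1-D "data" mesh spanning every device visible across all hosts.
mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, PartitionSpec("data"))

# Each host feeds its local shard of the global batch; jitted train steps
# then run data-parallel across the whole mesh.
print(f"process {jax.process_index()}/{jax.process_count()}, "
      f"local devices: {jax.local_device_count()}")
```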
@@ -5,7 +5,8 @@ Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/dif
  import jax
  import jax.numpy as jnp
  from flax import linen as nn
- from typing import Dict, Callable, Sequence, Any, Union
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
+ from flax.typing import Dtype, PrecisionLike
  import einops
  import functools
  import math
@@ -18,8 +19,8 @@ class EfficientAttention(nn.Module):
  query_dim: int
  heads: int = 4
  dim_head: int = 64
- dtype: Any = jnp.float32
- precision: Any = jax.lax.Precision.HIGHEST
+ dtype: Optional[Dtype] = None
+ precision: PrecisionLike = None
  use_bias: bool = True
  kernel_init: Callable = lambda : kernel_init(1.0)

@@ -62,8 +63,13 @@ class EfficientAttention(nn.Module):
  # x has shape [B, H * W, C]
  context = x if context is None else context

- B, H, W, C = x.shape
- x = x.reshape((B, 1, H * W, C))
+ orig_x_shape = x.shape
+ if len(x.shape) == 4:
+ B, H, W, C = x.shape
+ x = x.reshape((B, 1, H * W, C))
+ else:
+ B, SEQ, C = x.shape
+ x = x.reshape((B, 1, SEQ, C))

  if len(context.shape) == 4:
  B, _H, _W, _C = context.shape
@@ -93,7 +99,7 @@ class EfficientAttention(nn.Module):

  proj = self.proj_attn(hidden_states)

- proj = proj.reshape((B, H, W, C))
+ proj = proj.reshape(orig_x_shape)

  return proj

@@ -104,8 +110,8 @@ class NormalAttention(nn.Module):
  query_dim: int
  heads: int = 4
  dim_head: int = 64
- dtype: Any = jnp.float32
- precision: Any = jax.lax.Precision.HIGHEST
+ dtype: Optional[Dtype] = None
+ precision: PrecisionLike = None
  use_bias: bool = True
  kernel_init: Callable = lambda : kernel_init(1.0)

@@ -138,8 +144,10 @@ class NormalAttention(nn.Module):
  @nn.compact
  def __call__(self, x, context=None):
  # x has shape [B, H, W, C]
- B, H, W, C = x.shape
- x = x.reshape((B, H*W, C))
+ orig_x_shape = x.shape
+ if len(x.shape) == 4:
+ B, H, W, C = x.shape
+ x = x.reshape((B, H*W, C))
  context = x if context is None else context
  if len(context.shape) == 4:
  context = context.reshape((B, H*W, C))
@@ -151,16 +159,16 @@ class NormalAttention(nn.Module):
  query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
  )
  proj = self.proj_attn(hidden_states)
- proj = proj.reshape((B, H, W, C))
+ proj = proj.reshape(orig_x_shape)
  return proj
-
- class AttentionBlock(nn.Module):
+
+ class BasicTransformerBlock(nn.Module):
  # Has self and cross attention
  query_dim: int
  heads: int = 4
  dim_head: int = 64
- dtype: Any = jnp.float32
- precision: Any = jax.lax.Precision.HIGHEST
+ dtype: Optional[Dtype] = None
+ precision: PrecisionLike = None
  use_bias: bool = True
  kernel_init: Callable = lambda : kernel_init(1.0)
  use_flash_attention:bool = False
@@ -193,129 +201,26 @@ class AttentionBlock(nn.Module):
  kernel_init=self.kernel_init
  )

- self.ff = nn.DenseGeneral(
- features=self.query_dim,
- use_bias=self.use_bias,
- precision=self.precision,
- dtype=self.dtype,
- kernel_init=self.kernel_init(),
- name="ff"
- )
+ self.ff = FlaxFeedForward(dim=self.query_dim)
  self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
  self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
  self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
- self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)

  @nn.compact
  def __call__(self, hidden_states, context=None):
  # self attention
- residual = hidden_states
- hidden_states = self.norm1(hidden_states)
- if self.use_cross_only:
- hidden_states = self.attention1(hidden_states, context)
- else:
- hidden_states = self.attention1(hidden_states)
- hidden_states = hidden_states + residual
+ if not self.use_cross_only:
+ print("Using self attention")
+ hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))

  # cross attention
- residual = hidden_states
- hidden_states = self.norm2(hidden_states)
- hidden_states = self.attention2(hidden_states, context)
- hidden_states = hidden_states + residual
+ hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)

  # feed forward
- residual = hidden_states
- hidden_states = self.norm3(hidden_states)
- hidden_states = nn.gelu(hidden_states)
- hidden_states = self.ff(hidden_states)
- hidden_states = hidden_states + residual
+ hidden_states = hidden_states + self.ff(self.norm3(hidden_states))

  return hidden_states

- class TransformerBlock(nn.Module):
- heads: int = 4
- dim_head: int = 32
- use_linear_attention: bool = True
- dtype: Any = jnp.float32
- precision: Any = jax.lax.Precision.HIGH
- use_projection: bool = False
- use_flash_attention:bool = True
- use_self_and_cross:bool = False
-
- @nn.compact
- def __call__(self, x, context=None):
- inner_dim = self.heads * self.dim_head
- B, H, W, C = x.shape
- normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
- if self.use_projection == True:
- if self.use_linear_attention:
- projected_x = nn.Dense(features=inner_dim,
- use_bias=False, precision=self.precision,
- kernel_init=kernel_init(1.0),
- dtype=self.dtype, name=f'project_in')(normed_x)
- else:
- projected_x = nn.Conv(
- features=inner_dim, kernel_size=(1, 1),
- kernel_init=kernel_init(1.0),
- strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
- precision=self.precision, name=f'project_in_conv',
- )(normed_x)
- else:
- projected_x = normed_x
- inner_dim = C
-
- context = projected_x if context is None else context
-
- if self.use_self_and_cross:
- projected_x = AttentionBlock(
- query_dim=inner_dim,
- heads=self.heads,
- dim_head=self.dim_head,
- name=f'Attention',
- precision=self.precision,
- use_bias=False,
- dtype=self.dtype,
- use_flash_attention=self.use_flash_attention,
- use_cross_only=False
- )(projected_x, context)
- elif self.use_flash_attention == True:
- projected_x = EfficientAttention(
- query_dim=inner_dim,
- heads=self.heads,
- dim_head=self.dim_head,
- name=f'Attention',
- precision=self.precision,
- use_bias=False,
- dtype=self.dtype,
- )(projected_x, context)
- else:
- projected_x = NormalAttention(
- query_dim=inner_dim,
- heads=self.heads,
- dim_head=self.dim_head,
- name=f'Attention',
- precision=self.precision,
- use_bias=False,
- )(projected_x, context)
-
-
- if self.use_projection == True:
- if self.use_linear_attention:
- projected_x = nn.Dense(features=C, precision=self.precision,
- dtype=self.dtype, use_bias=False,
- kernel_init=kernel_init(1.0),
- name=f'project_out')(projected_x)
- else:
- projected_x = nn.Conv(
- features=C, kernel_size=(1, 1),
- kernel_init=kernel_init(1.0),
- strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
- precision=self.precision, name=f'project_out_conv',
- )(projected_x)
-
- out = x + projected_x
- return out
-
  class FlaxGEGLU(nn.Module):
  r"""
  Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -333,10 +238,11 @@ class FlaxGEGLU(nn.Module):
  dim: int
  dropout: float = 0.0
  dtype: jnp.dtype = jnp.float32
+ precision: Any = jax.lax.Precision.DEFAULT

  def setup(self):
  inner_dim = self.dim * 4
- self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)

  def __call__(self, hidden_states):
  hidden_states = self.proj(hidden_states)
@@ -362,14 +268,14 @@ class FlaxFeedForward(nn.Module):
  """

  dim: int
- dropout: float = 0.0
  dtype: jnp.dtype = jnp.float32
+ precision: Any = jax.lax.Precision.DEFAULT

  def setup(self):
  # The second linear layer needs to be called
  # net_2 for now to match the index of the Sequential layer
- self.net_0 = FlaxGEGLU(self.dim, self.dtype)
- self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+ self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)

  def __call__(self, hidden_states):
  hidden_states = self.net_0(hidden_states)
@@ -377,55 +283,127 @@ class FlaxFeedForward(nn.Module):
  return hidden_states

  class BasicTransformerBlock(nn.Module):
+ # Has self and cross attention
  query_dim: int
- heads: int
- dim_head: int
- dropout: float = 0.0
- only_cross_attention: bool = False
- dtype: jnp.dtype = jnp.float32
- use_memory_efficient_attention: bool = False
- split_head_dim: bool = False
- precision: Any = jax.lax.Precision.DEFAULT
-
+ heads: int = 4
+ dim_head: int = 64
+ dtype: Optional[Dtype] = None
+ precision: PrecisionLike = None
+ use_bias: bool = True
+ kernel_init: Callable = lambda : kernel_init(1.0)
+ use_flash_attention:bool = False
+ use_cross_only:bool = False
+ only_pure_attention:bool = False
+
  def setup(self):
- # self attention (or cross_attention if only_cross_attention is True)
- self.attn1 = NormalAttention(
- query_dim=self.query_dim,
+ if self.use_flash_attention:
+ attenBlock = EfficientAttention
+ else:
+ attenBlock = NormalAttention
+
+ self.attention1 = attenBlock(
+ query_dim=self.query_dim,
  heads=self.heads,
  dim_head=self.dim_head,
- dtype=self.dtype,
+ name=f'Attention1',
  precision=self.precision,
+ use_bias=self.use_bias,
+ dtype=self.dtype,
+ kernel_init=self.kernel_init
  )
- # cross attention
- self.attn2 = NormalAttention(
+ self.attention2 = attenBlock(
  query_dim=self.query_dim,
  heads=self.heads,
  dim_head=self.dim_head,
- dtype=self.dtype,
+ name=f'Attention2',
  precision=self.precision,
+ use_bias=self.use_bias,
+ dtype=self.dtype,
+ kernel_init=self.kernel_init
  )
- self.ff = FlaxFeedForward(dim=self.query_dim, dropout=self.dropout, dtype=self.dtype)
+
+ self.ff = FlaxFeedForward(dim=self.query_dim)
  self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
  self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
  self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
- def __call__(self, hidden_states, context, deterministic=True):
+
+ @nn.compact
+ def __call__(self, hidden_states, context=None):
+ if self.only_pure_attention:
+ return self.attention2(self.norm2(hidden_states), context)
+
  # self attention
- residual = hidden_states
- if self.only_cross_attention:
- hidden_states = self.attn1(self.norm1(hidden_states), context)
- else:
- hidden_states = self.attn1(self.norm1(hidden_states))
- hidden_states = hidden_states + residual
-
+ if not self.use_cross_only:
+ hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
+
  # cross attention
- residual = hidden_states
- hidden_states = self.attn2(self.norm2(hidden_states), context)
- hidden_states = hidden_states + residual
-
+ hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
  # feed forward
- residual = hidden_states
- hidden_states = self.ff(self.norm3(hidden_states))
- hidden_states = hidden_states + residual
+ hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
+
+ return hidden_states
+
+ class TransformerBlock(nn.Module):
+ heads: int = 4
+ dim_head: int = 32
+ use_linear_attention: bool = True
+ dtype: Optional[Dtype] = None
+ precision: PrecisionLike = None
+ use_projection: bool = False
+ use_flash_attention:bool = True
+ use_self_and_cross:bool = False
+ only_pure_attention:bool = False
+
+ @nn.compact
+ def __call__(self, x, context=None):
+ inner_dim = self.heads * self.dim_head
+ B, H, W, C = x.shape
+ normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+ if self.use_projection == True:
+ if self.use_linear_attention:
+ projected_x = nn.Dense(features=inner_dim,
+ use_bias=False, precision=self.precision,
+ kernel_init=kernel_init(1.0),
+ dtype=self.dtype, name=f'project_in')(normed_x)
+ else:
+ projected_x = nn.Conv(
+ features=inner_dim, kernel_size=(1, 1),
+ kernel_init=kernel_init(1.0),
+ strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+ precision=self.precision, name=f'project_in_conv',
+ )(normed_x)
+ else:
+ projected_x = normed_x
+ inner_dim = C
+
+ context = projected_x if context is None else context

- return hidden_states
+ projected_x = BasicTransformerBlock(
+ query_dim=inner_dim,
+ heads=self.heads,
+ dim_head=self.dim_head,
+ name=f'Attention',
+ precision=self.precision,
+ use_bias=False,
+ dtype=self.dtype,
+ use_flash_attention=self.use_flash_attention,
+ use_cross_only=(not self.use_self_and_cross),
+ only_pure_attention=self.only_pure_attention
+ )(projected_x, context)
+
+ if self.use_projection == True:
+ if self.use_linear_attention:
+ projected_x = nn.Dense(features=C, precision=self.precision,
+ dtype=self.dtype, use_bias=False,
+ kernel_init=kernel_init(1.0),
+ name=f'project_out')(projected_x)
+ else:
+ projected_x = nn.Conv(
+ features=C, kernel_size=(1, 1),
+ kernel_init=kernel_init(1.0),
+ strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+ precision=self.precision, name=f'project_out_conv',
+ )(projected_x)
+
+ out = x + projected_x
+ return out
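Commentary on the refactor above (not part of the diff): the old `AttentionBlock`/`TransformerBlock` pair is consolidated into `BasicTransformerBlock` plus a thin `TransformerBlock` wrapper, and the attention modules now accept both `[B, H, W, C]` feature maps and `[B, SEQ, C]` token sequences. A minimal usage sketch, assuming the module is importable as `flaxdiff.models.attention` and that the default linear projection path is used:

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock  # assumed import path

x = jnp.ones((2, 16, 16, 128))    # [B, H, W, C] feature map
context = jnp.ones((2, 77, 128))  # e.g. text-encoder tokens for cross attention

block = TransformerBlock(
    heads=4,
    dim_head=32,                # inner_dim = heads * dim_head = 128
    use_projection=True,        # Dense project_in / project_out around attention
    use_flash_attention=False,  # NormalAttention path
    use_self_and_cross=True,    # self attention followed by cross attention
)
params = block.init(jax.random.PRNGKey(0), x, context)
out = block.apply(params, x, context)  # residual output, same shape as x
```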
@@ -0,0 +1,2 @@
+ from .autoencoder import AutoEncoder
+ from .diffusers import StableDiffusionVAE
@@ -0,0 +1,19 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Dict, Callable, Sequence, Any, Union
+ import einops
+ from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+
+
+ class AutoEncoder():
+ def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+ raise NotImplementedError
+
+ def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+ raise NotImplementedError
+
+ def __call__(self, x: jnp.ndarray):
+ latents = self.encode(x)
+ reconstructions = self.decode(latents)
+ return reconstructions
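Commentary (not part of the diff): `AutoEncoder` is a plain Python interface, so a latent-space codec only has to provide `encode` and `decode`; `__call__` composes them. A toy sketch of a conforming implementation, with an identity codec standing in for a real model:

```python
import jax.numpy as jnp
from flaxdiff.models.autoencoder import AutoEncoder  # assumed import path

class IdentityAutoEncoder(AutoEncoder):
    """Illustrative only: the 'latent' is the image itself."""

    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return x

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return z

codec = IdentityAutoEncoder()
images = jnp.ones((4, 64, 64, 3))
assert codec(images).shape == images.shape  # __call__ runs decode(encode(x))
```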
@@ -0,0 +1,91 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from .autoencoder import AutoEncoder
+
+ """
+ This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
+ The actual model was not trained by me, but was taken from the HuggingFace model hub.
+ I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library
+ All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
+ """
+
+ class StableDiffusionVAE(AutoEncoder):
+ def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
+
+ from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
+ from diffusers import FlaxStableDiffusionPipeline
+
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ modelname,
+ revision="bf16",
+ dtype=jnp.bfloat16,
+ )
+
+ vae = pipeline.vae
+
+ enc = FlaxEncoder(
+ in_channels=vae.config.in_channels,
+ out_channels=vae.config.latent_channels,
+ down_block_types=vae.config.down_block_types,
+ block_out_channels=vae.config.block_out_channels,
+ layers_per_block=vae.config.layers_per_block,
+ act_fn=vae.config.act_fn,
+ norm_num_groups=vae.config.norm_num_groups,
+ double_z=True,
+ dtype=vae.dtype,
+ )
+
+ dec = FlaxDecoder(
+ in_channels=vae.config.latent_channels,
+ out_channels=vae.config.out_channels,
+ up_block_types=vae.config.up_block_types,
+ block_out_channels=vae.config.block_out_channels,
+ layers_per_block=vae.config.layers_per_block,
+ norm_num_groups=vae.config.norm_num_groups,
+ act_fn=vae.config.act_fn,
+ dtype=vae.dtype,
+ )
+
+ quant_conv = nn.Conv(
+ 2 * vae.config.latent_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=vae.dtype,
+ )
+
+ post_quant_conv = nn.Conv(
+ vae.config.latent_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=vae.dtype,
+ )
+
+ self.enc = enc
+ self.dec = dec
+ self.post_quant_conv = post_quant_conv
+ self.quant_conv = quant_conv
+ self.params = params
+ self.scaling_factor = vae.scaling_factor
+
+ def encode(self, images, rngkey: jax.random.PRNGKey = None):
+ latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
+ latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
+ if rngkey is not None:
+ mean, log_std = jnp.split(latents, 2, axis=-1)
+ log_std = jnp.clip(log_std, -30, 20)
+ std = jnp.exp(0.5 * log_std)
+ latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+ print("Sampled")
+ else:
+ # return the mean
+ latents, _ = jnp.split(latents, 2, axis=-1)
+ latents *= self.scaling_factor
+ return latents
+
+ def decode(self, latents):
+ latents = (1.0 / self.scaling_factor) * latents
+ latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
+ return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
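A usage sketch for the `StableDiffusionVAE` wrapper above (not part of the diff). It assumes the pretrained `diffusers` weights can be downloaded, that images are NHWC floats scaled to roughly [-1, 1], and that spatial dimensions are multiples of 8:

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.autoencoder import StableDiffusionVAE  # assumed import path

vae = StableDiffusionVAE()  # loads CompVis/stable-diffusion-v1-4 in bfloat16

images = jnp.zeros((4, 256, 256, 3), dtype=jnp.bfloat16)

# Stochastic encode (samples the posterior) vs. deterministic mean encode.
latents = vae.encode(images, rngkey=jax.random.PRNGKey(0))
latents_mean = vae.encode(images)

recon = vae.decode(latents)  # back to image space, 8x the latent resolution
```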
@@ -0,0 +1,26 @@
+ from typing import Any, List, Optional, Callable
+ import jax
+ import flax.linen as nn
+ from jax import numpy as jnp
+ from flax.typing import Dtype, PrecisionLike
+ from .autoencoder import AutoEncoder
+
+ class SimpleAutoEncoder(AutoEncoder):
+ latent_channels: int
+ feature_depths: List[int]=[64, 128, 256, 512]
+ attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
+ num_res_blocks: int=2
+ num_middle_res_blocks:int=1,
+ activation:Callable = jax.nn.swish
+ norm_groups:int=8
+ dtype: Optional[Dtype] = None
+ precision: PrecisionLike = None
+
+ # def encode(self, x: jnp.ndarray):
+
+
+ @nn.compact
+ def __call__(self, x: jnp.ndarray):
+ latents = self.encode(x)
+ reconstructions = self.decode(latents)
+ return reconstructions