flaxdiff 0.1.4__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.4/flaxdiff.egg-info → flaxdiff-0.1.6}/PKG-INFO +12 -2
- flaxdiff-0.1.4/PKG-INFO → flaxdiff-0.1.6/README.md +11 -14
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/models/attention.py +140 -162
- flaxdiff-0.1.6/flaxdiff/models/autoencoder/__init__.py +2 -0
- flaxdiff-0.1.6/flaxdiff/models/autoencoder/autoencoder.py +19 -0
- flaxdiff-0.1.6/flaxdiff/models/autoencoder/diffusers.py +91 -0
- flaxdiff-0.1.6/flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
- flaxdiff-0.1.4/flaxdiff/models/simple_unet.py → flaxdiff-0.1.6/flaxdiff/models/common.py +28 -210
- flaxdiff-0.1.6/flaxdiff/models/simple_unet.py +205 -0
- flaxdiff-0.1.6/flaxdiff/trainer/__init__.py +2 -0
- flaxdiff-0.1.6/flaxdiff/trainer/autoencoder_trainer.py +182 -0
- flaxdiff-0.1.4/flaxdiff/trainer/__init__.py → flaxdiff-0.1.6/flaxdiff/trainer/diffusion_trainer.py +48 -47
- flaxdiff-0.1.6/flaxdiff/trainer/simple_trainer.py +418 -0
- flaxdiff-0.1.4/README.md → flaxdiff-0.1.6/flaxdiff.egg-info/PKG-INFO +24 -1
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff.egg-info/SOURCES.txt +6 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/setup.py +1 -1
- flaxdiff-0.1.4/flaxdiff/models/common.py +0 -7
- flaxdiff-0.1.4/flaxdiff/trainer/simple_trainer.py +0 -323
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.4 → flaxdiff-0.1.6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: flaxdiff
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: A versatile and easy to understand Diffusion library
|
5
5
|
Author: Ashish Kumar Singh
|
6
6
|
Author-email: ashishkmr472@gmail.com
|
@@ -13,6 +13,8 @@ Requires-Dist: clu
|
|
13
13
|
|
14
14
|
# 
|
15
15
|
|
16
|
+
**This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
|
17
|
+
|
16
18
|
## A Versatile and simple Diffusion Library
|
17
19
|
|
18
20
|
In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
|
@@ -27,7 +29,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.
|
|
27
29
|
|
28
30
|
In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.
|
29
31
|
|
30
|
-
### Available Notebooks
|
32
|
+
### Available Notebooks and Resources
|
31
33
|
|
32
34
|
- **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**
|
33
35
|
|
@@ -46,6 +48,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var
|
|
46
48
|
|
47
49
|
These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.
|
48
50
|
|
51
|
+
#### Other resources
|
52
|
+
|
53
|
+
- **[Multi-host Data parallel training script in JAX](./training.py)**
|
54
|
+
- Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
|
55
|
+
|
56
|
+
- **[TPU utilities for making life easier](./tpu-tools/)**
|
57
|
+
- A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
|
58
|
+
|
49
59
|
## Disclaimer (and About Me)
|
50
60
|
|
51
61
|
I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
|
@@ -1,18 +1,7 @@
|
|
1
|
-
Metadata-Version: 2.1
|
2
|
-
Name: flaxdiff
|
3
|
-
Version: 0.1.4
|
4
|
-
Summary: A versatile and easy to understand Diffusion library
|
5
|
-
Author: Ashish Kumar Singh
|
6
|
-
Author-email: ashishkmr472@gmail.com
|
7
|
-
Description-Content-Type: text/markdown
|
8
|
-
Requires-Dist: flax>=0.8.4
|
9
|
-
Requires-Dist: optax>=0.2.2
|
10
|
-
Requires-Dist: jax>=0.4.28
|
11
|
-
Requires-Dist: orbax
|
12
|
-
Requires-Dist: clu
|
13
|
-
|
14
1
|
# 
|
15
2
|
|
3
|
+
**This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
|
4
|
+
|
16
5
|
## A Versatile and simple Diffusion Library
|
17
6
|
|
18
7
|
In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
|
@@ -27,7 +16,7 @@ The `Diffusion_flax_linen.ipynb` notebook is my main workspace for experiments.
|
|
27
16
|
|
28
17
|
In the `example notebooks` folder, you will find comprehensive notebooks for various diffusion techniques, written entirely from scratch and are independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them invaluable resources for learning and understanding diffusion models.
|
29
18
|
|
30
|
-
### Available Notebooks
|
19
|
+
### Available Notebooks and Resources
|
31
20
|
|
32
21
|
- **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**
|
33
22
|
|
@@ -46,6 +35,14 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var
|
|
46
35
|
|
47
36
|
These notebooks aim to provide a very easy to understand and step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, and thus although they may not adhere to the exact formulations and implementations of the original papers to make them more understandable and generalizable, I have tried my best to keep them as accurate as possible. If you find any mistakes or have any suggestions, please feel free to open an issue or a pull request.
|
48
37
|
|
38
|
+
#### Other resources
|
39
|
+
|
40
|
+
- **[Multi-host Data parallel training script in JAX](./training.py)**
|
41
|
+
- Training script for multi-host data parallel training in JAX, to serve as a reference for training large models on multiple GPUs/TPUs across multiple hosts. A full-fledged tutorial notebook is in the works.
|
42
|
+
|
43
|
+
- **[TPU utilities for making life easier](./tpu-tools/)**
|
44
|
+
- A collection of utilities and scripts to make working with TPUs easier, such as cli to create/start/stop/setup TPUs, script to setup TPU VMs (install everything you need), mounting gcs datasets etc.
|
45
|
+
|
49
46
|
## Disclaimer (and About Me)
|
50
47
|
|
51
48
|
I worked as a Machine Learning Researcher at Hyperverge from 2019-2021, focusing on computer vision, specifically facial anti-spoofing and facial detection & recognition. Since switching to my current job in 2021, I haven't engaged in as much R&D work, leading me to start this pet project to revisit and relearn the fundamentals and get familiar with the state-of-the-art. My current role involves primarily Golang system engineering with some applied ML work just sprinkled in. Therefore, the code may reflect my learning journey. Please forgive any mistakes and do open an issue to let me know.
|
@@ -5,7 +5,8 @@ Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/dif
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
7
|
from flax import linen as nn
|
8
|
-
from typing import Dict, Callable, Sequence, Any, Union
|
8
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
|
9
|
+
from flax.typing import Dtype, PrecisionLike
|
9
10
|
import einops
|
10
11
|
import functools
|
11
12
|
import math
|
@@ -18,8 +19,8 @@ class EfficientAttention(nn.Module):
|
|
18
19
|
query_dim: int
|
19
20
|
heads: int = 4
|
20
21
|
dim_head: int = 64
|
21
|
-
dtype:
|
22
|
-
precision:
|
22
|
+
dtype: Optional[Dtype] = None
|
23
|
+
precision: PrecisionLike = None
|
23
24
|
use_bias: bool = True
|
24
25
|
kernel_init: Callable = lambda : kernel_init(1.0)
|
25
26
|
|
@@ -62,8 +63,13 @@ class EfficientAttention(nn.Module):
|
|
62
63
|
# x has shape [B, H * W, C]
|
63
64
|
context = x if context is None else context
|
64
65
|
|
65
|
-
|
66
|
-
|
66
|
+
orig_x_shape = x.shape
|
67
|
+
if len(x.shape) == 4:
|
68
|
+
B, H, W, C = x.shape
|
69
|
+
x = x.reshape((B, 1, H * W, C))
|
70
|
+
else:
|
71
|
+
B, SEQ, C = x.shape
|
72
|
+
x = x.reshape((B, 1, SEQ, C))
|
67
73
|
|
68
74
|
if len(context.shape) == 4:
|
69
75
|
B, _H, _W, _C = context.shape
|
@@ -93,7 +99,7 @@ class EfficientAttention(nn.Module):
|
|
93
99
|
|
94
100
|
proj = self.proj_attn(hidden_states)
|
95
101
|
|
96
|
-
proj = proj.reshape(
|
102
|
+
proj = proj.reshape(orig_x_shape)
|
97
103
|
|
98
104
|
return proj
|
99
105
|
|
@@ -104,8 +110,8 @@ class NormalAttention(nn.Module):
|
|
104
110
|
query_dim: int
|
105
111
|
heads: int = 4
|
106
112
|
dim_head: int = 64
|
107
|
-
dtype:
|
108
|
-
precision:
|
113
|
+
dtype: Optional[Dtype] = None
|
114
|
+
precision: PrecisionLike = None
|
109
115
|
use_bias: bool = True
|
110
116
|
kernel_init: Callable = lambda : kernel_init(1.0)
|
111
117
|
|
@@ -138,8 +144,10 @@ class NormalAttention(nn.Module):
|
|
138
144
|
@nn.compact
|
139
145
|
def __call__(self, x, context=None):
|
140
146
|
# x has shape [B, H, W, C]
|
141
|
-
|
142
|
-
|
147
|
+
orig_x_shape = x.shape
|
148
|
+
if len(x.shape) == 4:
|
149
|
+
B, H, W, C = x.shape
|
150
|
+
x = x.reshape((B, H*W, C))
|
143
151
|
context = x if context is None else context
|
144
152
|
if len(context.shape) == 4:
|
145
153
|
context = context.reshape((B, H*W, C))
|
@@ -151,16 +159,16 @@ class NormalAttention(nn.Module):
|
|
151
159
|
query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
|
152
160
|
)
|
153
161
|
proj = self.proj_attn(hidden_states)
|
154
|
-
proj = proj.reshape(
|
162
|
+
proj = proj.reshape(orig_x_shape)
|
155
163
|
return proj
|
156
|
-
|
157
|
-
class
|
164
|
+
|
165
|
+
class BasicTransformerBlock(nn.Module):
|
158
166
|
# Has self and cross attention
|
159
167
|
query_dim: int
|
160
168
|
heads: int = 4
|
161
169
|
dim_head: int = 64
|
162
|
-
dtype:
|
163
|
-
precision:
|
170
|
+
dtype: Optional[Dtype] = None
|
171
|
+
precision: PrecisionLike = None
|
164
172
|
use_bias: bool = True
|
165
173
|
kernel_init: Callable = lambda : kernel_init(1.0)
|
166
174
|
use_flash_attention:bool = False
|
@@ -193,129 +201,26 @@ class AttentionBlock(nn.Module):
|
|
193
201
|
kernel_init=self.kernel_init
|
194
202
|
)
|
195
203
|
|
196
|
-
self.ff =
|
197
|
-
features=self.query_dim,
|
198
|
-
use_bias=self.use_bias,
|
199
|
-
precision=self.precision,
|
200
|
-
dtype=self.dtype,
|
201
|
-
kernel_init=self.kernel_init(),
|
202
|
-
name="ff"
|
203
|
-
)
|
204
|
+
self.ff = FlaxFeedForward(dim=self.query_dim)
|
204
205
|
self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
|
205
206
|
self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
|
206
207
|
self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
|
207
|
-
self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
|
208
208
|
|
209
209
|
@nn.compact
|
210
210
|
def __call__(self, hidden_states, context=None):
|
211
211
|
# self attention
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
hidden_states = self.attention1(hidden_states, context)
|
216
|
-
else:
|
217
|
-
hidden_states = self.attention1(hidden_states)
|
218
|
-
hidden_states = hidden_states + residual
|
212
|
+
if not self.use_cross_only:
|
213
|
+
print("Using self attention")
|
214
|
+
hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
|
219
215
|
|
220
216
|
# cross attention
|
221
|
-
|
222
|
-
hidden_states = self.norm2(hidden_states)
|
223
|
-
hidden_states = self.attention2(hidden_states, context)
|
224
|
-
hidden_states = hidden_states + residual
|
217
|
+
hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
|
225
218
|
|
226
219
|
# feed forward
|
227
|
-
|
228
|
-
hidden_states = self.norm3(hidden_states)
|
229
|
-
hidden_states = nn.gelu(hidden_states)
|
230
|
-
hidden_states = self.ff(hidden_states)
|
231
|
-
hidden_states = hidden_states + residual
|
220
|
+
hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
|
232
221
|
|
233
222
|
return hidden_states
|
234
223
|
|
235
|
-
class TransformerBlock(nn.Module):
|
236
|
-
heads: int = 4
|
237
|
-
dim_head: int = 32
|
238
|
-
use_linear_attention: bool = True
|
239
|
-
dtype: Any = jnp.float32
|
240
|
-
precision: Any = jax.lax.Precision.HIGH
|
241
|
-
use_projection: bool = False
|
242
|
-
use_flash_attention:bool = True
|
243
|
-
use_self_and_cross:bool = False
|
244
|
-
|
245
|
-
@nn.compact
|
246
|
-
def __call__(self, x, context=None):
|
247
|
-
inner_dim = self.heads * self.dim_head
|
248
|
-
B, H, W, C = x.shape
|
249
|
-
normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
|
250
|
-
if self.use_projection == True:
|
251
|
-
if self.use_linear_attention:
|
252
|
-
projected_x = nn.Dense(features=inner_dim,
|
253
|
-
use_bias=False, precision=self.precision,
|
254
|
-
kernel_init=kernel_init(1.0),
|
255
|
-
dtype=self.dtype, name=f'project_in')(normed_x)
|
256
|
-
else:
|
257
|
-
projected_x = nn.Conv(
|
258
|
-
features=inner_dim, kernel_size=(1, 1),
|
259
|
-
kernel_init=kernel_init(1.0),
|
260
|
-
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
|
261
|
-
precision=self.precision, name=f'project_in_conv',
|
262
|
-
)(normed_x)
|
263
|
-
else:
|
264
|
-
projected_x = normed_x
|
265
|
-
inner_dim = C
|
266
|
-
|
267
|
-
context = projected_x if context is None else context
|
268
|
-
|
269
|
-
if self.use_self_and_cross:
|
270
|
-
projected_x = AttentionBlock(
|
271
|
-
query_dim=inner_dim,
|
272
|
-
heads=self.heads,
|
273
|
-
dim_head=self.dim_head,
|
274
|
-
name=f'Attention',
|
275
|
-
precision=self.precision,
|
276
|
-
use_bias=False,
|
277
|
-
dtype=self.dtype,
|
278
|
-
use_flash_attention=self.use_flash_attention,
|
279
|
-
use_cross_only=False
|
280
|
-
)(projected_x, context)
|
281
|
-
elif self.use_flash_attention == True:
|
282
|
-
projected_x = EfficientAttention(
|
283
|
-
query_dim=inner_dim,
|
284
|
-
heads=self.heads,
|
285
|
-
dim_head=self.dim_head,
|
286
|
-
name=f'Attention',
|
287
|
-
precision=self.precision,
|
288
|
-
use_bias=False,
|
289
|
-
dtype=self.dtype,
|
290
|
-
)(projected_x, context)
|
291
|
-
else:
|
292
|
-
projected_x = NormalAttention(
|
293
|
-
query_dim=inner_dim,
|
294
|
-
heads=self.heads,
|
295
|
-
dim_head=self.dim_head,
|
296
|
-
name=f'Attention',
|
297
|
-
precision=self.precision,
|
298
|
-
use_bias=False,
|
299
|
-
)(projected_x, context)
|
300
|
-
|
301
|
-
|
302
|
-
if self.use_projection == True:
|
303
|
-
if self.use_linear_attention:
|
304
|
-
projected_x = nn.Dense(features=C, precision=self.precision,
|
305
|
-
dtype=self.dtype, use_bias=False,
|
306
|
-
kernel_init=kernel_init(1.0),
|
307
|
-
name=f'project_out')(projected_x)
|
308
|
-
else:
|
309
|
-
projected_x = nn.Conv(
|
310
|
-
features=C, kernel_size=(1, 1),
|
311
|
-
kernel_init=kernel_init(1.0),
|
312
|
-
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
|
313
|
-
precision=self.precision, name=f'project_out_conv',
|
314
|
-
)(projected_x)
|
315
|
-
|
316
|
-
out = x + projected_x
|
317
|
-
return out
|
318
|
-
|
319
224
|
class FlaxGEGLU(nn.Module):
|
320
225
|
r"""
|
321
226
|
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
|
@@ -333,10 +238,11 @@ class FlaxGEGLU(nn.Module):
|
|
333
238
|
dim: int
|
334
239
|
dropout: float = 0.0
|
335
240
|
dtype: jnp.dtype = jnp.float32
|
241
|
+
precision: Any = jax.lax.Precision.DEFAULT
|
336
242
|
|
337
243
|
def setup(self):
|
338
244
|
inner_dim = self.dim * 4
|
339
|
-
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=
|
245
|
+
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)
|
340
246
|
|
341
247
|
def __call__(self, hidden_states):
|
342
248
|
hidden_states = self.proj(hidden_states)
|
@@ -362,14 +268,14 @@ class FlaxFeedForward(nn.Module):
|
|
362
268
|
"""
|
363
269
|
|
364
270
|
dim: int
|
365
|
-
dropout: float = 0.0
|
366
271
|
dtype: jnp.dtype = jnp.float32
|
272
|
+
precision: Any = jax.lax.Precision.DEFAULT
|
367
273
|
|
368
274
|
def setup(self):
|
369
275
|
# The second linear layer needs to be called
|
370
276
|
# net_2 for now to match the index of the Sequential layer
|
371
|
-
self.net_0 = FlaxGEGLU(self.dim, self.dtype)
|
372
|
-
self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=
|
277
|
+
self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
|
278
|
+
self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)
|
373
279
|
|
374
280
|
def __call__(self, hidden_states):
|
375
281
|
hidden_states = self.net_0(hidden_states)
|
@@ -377,55 +283,127 @@ class FlaxFeedForward(nn.Module):
|
|
377
283
|
return hidden_states
|
378
284
|
|
379
285
|
class BasicTransformerBlock(nn.Module):
|
286
|
+
# Has self and cross attention
|
380
287
|
query_dim: int
|
381
|
-
heads: int
|
382
|
-
dim_head: int
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
288
|
+
heads: int = 4
|
289
|
+
dim_head: int = 64
|
290
|
+
dtype: Optional[Dtype] = None
|
291
|
+
precision: PrecisionLike = None
|
292
|
+
use_bias: bool = True
|
293
|
+
kernel_init: Callable = lambda : kernel_init(1.0)
|
294
|
+
use_flash_attention:bool = False
|
295
|
+
use_cross_only:bool = False
|
296
|
+
only_pure_attention:bool = False
|
297
|
+
|
390
298
|
def setup(self):
|
391
|
-
|
392
|
-
|
393
|
-
|
299
|
+
if self.use_flash_attention:
|
300
|
+
attenBlock = EfficientAttention
|
301
|
+
else:
|
302
|
+
attenBlock = NormalAttention
|
303
|
+
|
304
|
+
self.attention1 = attenBlock(
|
305
|
+
query_dim=self.query_dim,
|
394
306
|
heads=self.heads,
|
395
307
|
dim_head=self.dim_head,
|
396
|
-
|
308
|
+
name=f'Attention1',
|
397
309
|
precision=self.precision,
|
310
|
+
use_bias=self.use_bias,
|
311
|
+
dtype=self.dtype,
|
312
|
+
kernel_init=self.kernel_init
|
398
313
|
)
|
399
|
-
|
400
|
-
self.attn2 = NormalAttention(
|
314
|
+
self.attention2 = attenBlock(
|
401
315
|
query_dim=self.query_dim,
|
402
316
|
heads=self.heads,
|
403
317
|
dim_head=self.dim_head,
|
404
|
-
|
318
|
+
name=f'Attention2',
|
405
319
|
precision=self.precision,
|
320
|
+
use_bias=self.use_bias,
|
321
|
+
dtype=self.dtype,
|
322
|
+
kernel_init=self.kernel_init
|
406
323
|
)
|
407
|
-
|
324
|
+
|
325
|
+
self.ff = FlaxFeedForward(dim=self.query_dim)
|
408
326
|
self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
|
409
327
|
self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
|
410
328
|
self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
|
411
|
-
|
412
|
-
|
329
|
+
|
330
|
+
@nn.compact
|
331
|
+
def __call__(self, hidden_states, context=None):
|
332
|
+
if self.only_pure_attention:
|
333
|
+
return self.attention2(self.norm2(hidden_states), context)
|
334
|
+
|
413
335
|
# self attention
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
else:
|
418
|
-
hidden_states = self.attn1(self.norm1(hidden_states))
|
419
|
-
hidden_states = hidden_states + residual
|
420
|
-
|
336
|
+
if not self.use_cross_only:
|
337
|
+
hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
|
338
|
+
|
421
339
|
# cross attention
|
422
|
-
|
423
|
-
hidden_states = self.attn2(self.norm2(hidden_states), context)
|
424
|
-
hidden_states = hidden_states + residual
|
425
|
-
|
340
|
+
hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
|
426
341
|
# feed forward
|
427
|
-
|
428
|
-
|
429
|
-
|
342
|
+
hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
|
343
|
+
|
344
|
+
return hidden_states
|
345
|
+
|
346
|
+
class TransformerBlock(nn.Module):
|
347
|
+
heads: int = 4
|
348
|
+
dim_head: int = 32
|
349
|
+
use_linear_attention: bool = True
|
350
|
+
dtype: Optional[Dtype] = None
|
351
|
+
precision: PrecisionLike = None
|
352
|
+
use_projection: bool = False
|
353
|
+
use_flash_attention:bool = True
|
354
|
+
use_self_and_cross:bool = False
|
355
|
+
only_pure_attention:bool = False
|
356
|
+
|
357
|
+
@nn.compact
|
358
|
+
def __call__(self, x, context=None):
|
359
|
+
inner_dim = self.heads * self.dim_head
|
360
|
+
B, H, W, C = x.shape
|
361
|
+
normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
|
362
|
+
if self.use_projection == True:
|
363
|
+
if self.use_linear_attention:
|
364
|
+
projected_x = nn.Dense(features=inner_dim,
|
365
|
+
use_bias=False, precision=self.precision,
|
366
|
+
kernel_init=kernel_init(1.0),
|
367
|
+
dtype=self.dtype, name=f'project_in')(normed_x)
|
368
|
+
else:
|
369
|
+
projected_x = nn.Conv(
|
370
|
+
features=inner_dim, kernel_size=(1, 1),
|
371
|
+
kernel_init=kernel_init(1.0),
|
372
|
+
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
|
373
|
+
precision=self.precision, name=f'project_in_conv',
|
374
|
+
)(normed_x)
|
375
|
+
else:
|
376
|
+
projected_x = normed_x
|
377
|
+
inner_dim = C
|
378
|
+
|
379
|
+
context = projected_x if context is None else context
|
430
380
|
|
431
|
-
|
381
|
+
projected_x = BasicTransformerBlock(
|
382
|
+
query_dim=inner_dim,
|
383
|
+
heads=self.heads,
|
384
|
+
dim_head=self.dim_head,
|
385
|
+
name=f'Attention',
|
386
|
+
precision=self.precision,
|
387
|
+
use_bias=False,
|
388
|
+
dtype=self.dtype,
|
389
|
+
use_flash_attention=self.use_flash_attention,
|
390
|
+
use_cross_only=(not self.use_self_and_cross),
|
391
|
+
only_pure_attention=self.only_pure_attention
|
392
|
+
)(projected_x, context)
|
393
|
+
|
394
|
+
if self.use_projection == True:
|
395
|
+
if self.use_linear_attention:
|
396
|
+
projected_x = nn.Dense(features=C, precision=self.precision,
|
397
|
+
dtype=self.dtype, use_bias=False,
|
398
|
+
kernel_init=kernel_init(1.0),
|
399
|
+
name=f'project_out')(projected_x)
|
400
|
+
else:
|
401
|
+
projected_x = nn.Conv(
|
402
|
+
features=C, kernel_size=(1, 1),
|
403
|
+
kernel_init=kernel_init(1.0),
|
404
|
+
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
|
405
|
+
precision=self.precision, name=f'project_out_conv',
|
406
|
+
)(projected_x)
|
407
|
+
|
408
|
+
out = x + projected_x
|
409
|
+
return out
|
@@ -0,0 +1,19 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
from flax import linen as nn
|
4
|
+
from typing import Dict, Callable, Sequence, Any, Union
|
5
|
+
import einops
|
6
|
+
from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
|
7
|
+
|
8
|
+
|
9
|
+
class AutoEncoder():
|
10
|
+
def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
|
11
|
+
raise NotImplementedError
|
12
|
+
|
13
|
+
def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
|
14
|
+
raise NotImplementedError
|
15
|
+
|
16
|
+
def __call__(self, x: jnp.ndarray):
|
17
|
+
latents = self.encode(x)
|
18
|
+
reconstructions = self.decode(latents)
|
19
|
+
return reconstructions
|
@@ -0,0 +1,91 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
from flax import linen as nn
|
4
|
+
from .autoencoder import AutoEncoder
|
5
|
+
|
6
|
+
"""
|
7
|
+
This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
|
8
|
+
The actual model was not trained by me, but was taken from the HuggingFace model hub.
|
9
|
+
I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library
|
10
|
+
All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
|
11
|
+
"""
|
12
|
+
|
13
|
+
class StableDiffusionVAE(AutoEncoder):
|
14
|
+
def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
|
15
|
+
|
16
|
+
from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
|
17
|
+
from diffusers import FlaxStableDiffusionPipeline
|
18
|
+
|
19
|
+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
20
|
+
modelname,
|
21
|
+
revision="bf16",
|
22
|
+
dtype=jnp.bfloat16,
|
23
|
+
)
|
24
|
+
|
25
|
+
vae = pipeline.vae
|
26
|
+
|
27
|
+
enc = FlaxEncoder(
|
28
|
+
in_channels=vae.config.in_channels,
|
29
|
+
out_channels=vae.config.latent_channels,
|
30
|
+
down_block_types=vae.config.down_block_types,
|
31
|
+
block_out_channels=vae.config.block_out_channels,
|
32
|
+
layers_per_block=vae.config.layers_per_block,
|
33
|
+
act_fn=vae.config.act_fn,
|
34
|
+
norm_num_groups=vae.config.norm_num_groups,
|
35
|
+
double_z=True,
|
36
|
+
dtype=vae.dtype,
|
37
|
+
)
|
38
|
+
|
39
|
+
dec = FlaxDecoder(
|
40
|
+
in_channels=vae.config.latent_channels,
|
41
|
+
out_channels=vae.config.out_channels,
|
42
|
+
up_block_types=vae.config.up_block_types,
|
43
|
+
block_out_channels=vae.config.block_out_channels,
|
44
|
+
layers_per_block=vae.config.layers_per_block,
|
45
|
+
norm_num_groups=vae.config.norm_num_groups,
|
46
|
+
act_fn=vae.config.act_fn,
|
47
|
+
dtype=vae.dtype,
|
48
|
+
)
|
49
|
+
|
50
|
+
quant_conv = nn.Conv(
|
51
|
+
2 * vae.config.latent_channels,
|
52
|
+
kernel_size=(1, 1),
|
53
|
+
strides=(1, 1),
|
54
|
+
padding="VALID",
|
55
|
+
dtype=vae.dtype,
|
56
|
+
)
|
57
|
+
|
58
|
+
post_quant_conv = nn.Conv(
|
59
|
+
vae.config.latent_channels,
|
60
|
+
kernel_size=(1, 1),
|
61
|
+
strides=(1, 1),
|
62
|
+
padding="VALID",
|
63
|
+
dtype=vae.dtype,
|
64
|
+
)
|
65
|
+
|
66
|
+
self.enc = enc
|
67
|
+
self.dec = dec
|
68
|
+
self.post_quant_conv = post_quant_conv
|
69
|
+
self.quant_conv = quant_conv
|
70
|
+
self.params = params
|
71
|
+
self.scaling_factor = vae.scaling_factor
|
72
|
+
|
73
|
+
def encode(self, images, rngkey: jax.random.PRNGKey = None):
|
74
|
+
latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
|
75
|
+
latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
|
76
|
+
if rngkey is not None:
|
77
|
+
mean, log_std = jnp.split(latents, 2, axis=-1)
|
78
|
+
log_std = jnp.clip(log_std, -30, 20)
|
79
|
+
std = jnp.exp(0.5 * log_std)
|
80
|
+
latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
|
81
|
+
print("Sampled")
|
82
|
+
else:
|
83
|
+
# return the mean
|
84
|
+
latents, _ = jnp.split(latents, 2, axis=-1)
|
85
|
+
latents *= self.scaling_factor
|
86
|
+
return latents
|
87
|
+
|
88
|
+
def decode(self, latents):
|
89
|
+
latents = (1.0 / self.scaling_factor) * latents
|
90
|
+
latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
|
91
|
+
return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from typing import Any, List, Optional, Callable
|
2
|
+
import jax
|
3
|
+
import flax.linen as nn
|
4
|
+
from jax import numpy as jnp
|
5
|
+
from flax.typing import Dtype, PrecisionLike
|
6
|
+
from .autoencoder import AutoEncoder
|
7
|
+
|
8
|
+
class SimpleAutoEncoder(AutoEncoder):
|
9
|
+
latent_channels: int
|
10
|
+
feature_depths: List[int]=[64, 128, 256, 512]
|
11
|
+
attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
|
12
|
+
num_res_blocks: int=2
|
13
|
+
num_middle_res_blocks:int=1,
|
14
|
+
activation:Callable = jax.nn.swish
|
15
|
+
norm_groups:int=8
|
16
|
+
dtype: Optional[Dtype] = None
|
17
|
+
precision: PrecisionLike = None
|
18
|
+
|
19
|
+
# def encode(self, x: jnp.ndarray):
|
20
|
+
|
21
|
+
|
22
|
+
@nn.compact
|
23
|
+
def __call__(self, x: jnp.ndarray):
|
24
|
+
latents = self.encode(x)
|
25
|
+
reconstructions = self.decode(latents)
|
26
|
+
return reconstructions
|