flaxdiff 0.1.36.1__py3-none-any.whl → 0.1.36.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +1 -0
- flaxdiff/data/dataset_map.py +71 -0
- flaxdiff/data/datasets.py +169 -0
- flaxdiff/data/online_loader.py +363 -0
- flaxdiff/data/sources/gcs.py +81 -0
- flaxdiff/data/sources/tfds.py +67 -0
- flaxdiff/metrics/inception.py +658 -0
- flaxdiff/metrics/utils.py +49 -0
- flaxdiff/models/__init__.py +1 -0
- flaxdiff/models/attention.py +368 -0
- flaxdiff/models/autoencoder/__init__.py +2 -0
- flaxdiff/models/autoencoder/autoencoder.py +19 -0
- flaxdiff/models/autoencoder/diffusers.py +91 -0
- flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
- flaxdiff/models/common.py +346 -0
- flaxdiff/models/favor_fastattn.py +723 -0
- flaxdiff/models/simple_unet.py +233 -0
- flaxdiff/models/simple_vit.py +180 -0
- flaxdiff/predictors/__init__.py +96 -0
- flaxdiff/samplers/__init__.py +7 -0
- flaxdiff/samplers/common.py +165 -0
- flaxdiff/samplers/ddim.py +10 -0
- flaxdiff/samplers/ddpm.py +37 -0
- flaxdiff/samplers/euler.py +56 -0
- flaxdiff/samplers/heun_sampler.py +27 -0
- flaxdiff/samplers/multistep_dpm.py +59 -0
- flaxdiff/samplers/rk4_sampler.py +34 -0
- flaxdiff/schedulers/__init__.py +6 -0
- flaxdiff/schedulers/common.py +98 -0
- flaxdiff/schedulers/continuous.py +12 -0
- flaxdiff/schedulers/cosine.py +40 -0
- flaxdiff/schedulers/discrete.py +74 -0
- flaxdiff/schedulers/exp.py +13 -0
- flaxdiff/schedulers/karras.py +69 -0
- flaxdiff/schedulers/linear.py +14 -0
- flaxdiff/schedulers/sqrt.py +10 -0
- flaxdiff/trainer/__init__.py +2 -0
- flaxdiff/trainer/autoencoder_trainer.py +182 -0
- flaxdiff/trainer/diffusion_trainer.py +326 -0
- flaxdiff/trainer/simple_trainer.py +540 -0
- flaxdiff/trainer/video_diffusion_trainer.py +62 -0
- {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/METADATA +1 -1
- flaxdiff-0.1.36.3.dist-info/RECORD +47 -0
- flaxdiff-0.1.36.1.dist-info/RECORD +0 -6
- {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/top_level.txt +0 -0
flaxdiff/metrics/utils.py
@@ -0,0 +1,49 @@
+# Mostly derived from
+# https://github.com/matthias-wright/jax-fid
+
+import jax
+import flax
+import numpy as np
+from tqdm import tqdm
+import requests
+import os
+import tempfile
+
+
+def download(url, ckpt_dir=None):
+    name = url[url.rfind('/') + 1 : url.rfind('?')]
+    if ckpt_dir is None:
+        ckpt_dir = tempfile.gettempdir()
+    ckpt_dir = os.path.join(ckpt_dir, 'jax_fid')
+    ckpt_file = os.path.join(ckpt_dir, name)
+    if not os.path.exists(ckpt_file):
+        print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
+        if not os.path.exists(ckpt_dir):
+            os.makedirs(ckpt_dir)
+
+        response = requests.get(url, stream=True)
+        total_size_in_bytes = int(response.headers.get('content-length', 0))
+        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+
+        # first create temp file, in case the download fails
+        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
+        with open(ckpt_file_temp, 'wb') as file:
+            for data in response.iter_content(chunk_size=1024):
+                progress_bar.update(len(data))
+                file.write(data)
+        progress_bar.close()
+
+        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+            print('An error occured while downloading, please try again.')
+            if os.path.exists(ckpt_file_temp):
+                os.remove(ckpt_file_temp)
+        else:
+            # if download was successful, rename the temp file
+            os.rename(ckpt_file_temp, ckpt_file)
+    return ckpt_file
+
+
+def get(dictionary, key):
+    if dictionary is None or key not in dictionary:
+        return None
+    return dictionary[key]
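For orientation, the sketch below shows one way these helpers might be used. It is not taken from the package: it assumes the module is importable as flaxdiff.metrics.utils, and the checkpoint URL is a placeholder rather than a real endpoint.

import flaxdiff.metrics.utils as utils

# Download once into <tempdir>/jax_fid/<filename>; later calls reuse the cached file.
ckpt_path = utils.download("https://example.com/inception_v3_fid.pickle?download=1")  # placeholder URL

# `get` is a None-safe dictionary lookup.
config = {"resize": "bilinear"}
assert utils.get(config, "resize") == "bilinear"
assert utils.get(config, "missing") is None
assert utils.get(None, "anything") is None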
flaxdiff/models/__init__.py
@@ -0,0 +1 @@
+from .simple_unet import *
flaxdiff/models/attention.py
@@ -0,0 +1,368 @@
+"""
+Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_flax.py
+"""
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
+from flax.typing import Dtype, PrecisionLike
+import einops
+import functools
+import math
+from .common import kernel_init
+import jax.experimental.pallas.ops.tpu.flash_attention
+
+class EfficientAttention(nn.Module):
+    """
+    Based on the pallas attention implementation.
+    """
+    query_dim: int
+    heads: int = 4
+    dim_head: int = 64
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_bias: bool = True
+    kernel_init: Callable = kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
+
+    def setup(self):
+        inner_dim = self.dim_head * self.heads
+        # Weights were exported with old names {to_q, to_k, to_v, to_out}
+        dense = functools.partial(
+            nn.Dense,
+            self.heads * self.dim_head,
+            precision=self.precision,
+            use_bias=self.use_bias,
+            kernel_init=self.kernel_init,
+            dtype=self.dtype
+        )
+        self.query = dense(name="to_q")
+        self.key = dense(name="to_k")
+        self.value = dense(name="to_v")
+
+        self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
+                                         kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
+        # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
+
+    def _reshape_tensor_to_head_dim(self, tensor):
+        batch_size, _, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+        return tensor
+
+    def _reshape_tensor_from_head_dim(self, tensor):
+        batch_size, _, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+        tensor = tensor.reshape(batch_size, 1, seq_len, dim * head_size)
+        return tensor
+
+    @nn.compact
+    def __call__(self, x:jax.Array, context=None):
+        # print(x.shape)
+        # x has shape [B, H * W, C]
+        context = x if context is None else context
+
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, 1, H * W, C))
+        else:
+            B, SEQ, C = x.shape
+            x = x.reshape((B, 1, SEQ, C))
+
+        if len(context.shape) == 4:
+            B, _H, _W, _C = context.shape
+            context = context.reshape((B, 1, _H * _W, _C))
+        else:
+            B, SEQ, _C = context.shape
+            context = context.reshape((B, 1, SEQ, _C))
+
+        query = self.query(x)
+        key = self.key(context)
+        value = self.value(context)
+
+        query = self._reshape_tensor_to_head_dim(query)
+        key = self._reshape_tensor_to_head_dim(key)
+        value = self._reshape_tensor_to_head_dim(value)
+
+        hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention(
+            query, key, value, None
+        )
+
+        hidden_states = self._reshape_tensor_from_head_dim(hidden_states)
+
+
+        # hidden_states = nn.dot_product_attention(
+        #     query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
+        # )
+
+        proj = self.proj_attn(hidden_states)
+
+        proj = proj.reshape(orig_x_shape)
+
+        return proj
+
+class NormalAttention(nn.Module):
+    """
+    Simple implementation of the normal attention.
+    """
+    query_dim: int
+    heads: int = 4
+    dim_head: int = 64
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_bias: bool = True
+    kernel_init: Callable = kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
+
+    def setup(self):
+        inner_dim = self.dim_head * self.heads
+        dense = functools.partial(
+            nn.DenseGeneral,
+            features=[self.heads, self.dim_head],
+            axis=-1,
+            precision=self.precision,
+            use_bias=self.use_bias,
+            kernel_init=self.kernel_init,
+            dtype=self.dtype
+        )
+        self.query = dense(name="to_q")
+        self.key = dense(name="to_k")
+        self.value = dense(name="to_v")
+
+        self.proj_attn = nn.DenseGeneral(
+            self.query_dim,
+            axis=(-2, -1),
+            precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            name="to_out_0",
+            kernel_init=self.kernel_init
+            # kernel_init=jax.nn.initializers.xavier_uniform()
+        )
+
+    @nn.compact
+    def __call__(self, x, context=None):
+        # x has shape [B, H, W, C]
+        orig_x_shape = x.shape
+        if len(x.shape) == 4:
+            B, H, W, C = x.shape
+            x = x.reshape((B, H*W, C))
+        context = x if context is None else context
+        if len(context.shape) == 4:
+            context = context.reshape((B, H*W, C))
+        query = self.query(x)
+        key = self.key(context)
+        value = self.value(context)
+
+        hidden_states = nn.dot_product_attention(
+            query, key, value, dtype=self.dtype, broadcast_dropout=False,
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
+            deterministic=True
+        )
+        proj = self.proj_attn(hidden_states)
+        proj = proj.reshape(orig_x_shape)
+        return proj
+
+class FlaxGEGLU(nn.Module):
+    r"""
+    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
+    https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim (:obj:`int`):
+            Input hidden states dimension
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
+
+    dim: int
+    dropout: float = 0.0
+    dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
+
+    def setup(self):
+        inner_dim = self.dim * 4
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)
+
+    def __call__(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
+        return hidden_linear * nn.gelu(hidden_gelu)
+
+class FlaxFeedForward(nn.Module):
+    r"""
+    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
+    [`FeedForward`] class, with the following simplifications:
+    - The activation function is currently hardcoded to a gated linear unit from:
+    https://arxiv.org/abs/2002.05202
+    - `dim_out` is equal to `dim`.
+    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
+
+    Parameters:
+        dim (:obj:`int`):
+            Inner hidden states dimension
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
+
+    dim: int
+    dtype: jnp.dtype = jnp.float32
+    precision: Any = jax.lax.Precision.DEFAULT
+
+    def setup(self):
+        # The second linear layer needs to be called
+        # net_2 for now to match the index of the Sequential layer
+        self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)
+
+    def __call__(self, hidden_states):
+        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_2(hidden_states)
+        return hidden_states
+
+class BasicTransformerBlock(nn.Module):
+    # Has self and cross attention
+    query_dim: int
+    heads: int = 4
+    dim_head: int = 64
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_bias: bool = True
+    kernel_init: Callable = kernel_init(1.0)
+    use_flash_attention:bool = False
+    use_cross_only:bool = False
+    only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
+
+    def setup(self):
+        if self.use_flash_attention:
+            attenBlock = EfficientAttention
+        else:
+            attenBlock = NormalAttention
+
+        self.attention1 = attenBlock(
+            query_dim=self.query_dim,
+            heads=self.heads,
+            dim_head=self.dim_head,
+            name=f'Attention1',
+            precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
+        )
+        self.attention2 = attenBlock(
+            query_dim=self.query_dim,
+            heads=self.heads,
+            dim_head=self.dim_head,
+            name=f'Attention2',
+            precision=self.precision,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
+        )
+
+        self.ff = FlaxFeedForward(dim=self.query_dim)
+        self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+        self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+        self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+
+    @nn.compact
+    def __call__(self, hidden_states, context=None):
+        if self.only_pure_attention:
+            return self.attention2(hidden_states, context)
+
+        # self attention
+        if not self.use_cross_only:
+            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
+
+        # cross attention
+        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
+        # feed forward
+        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
+
+        return hidden_states
+
+class TransformerBlock(nn.Module):
+    heads: int = 4
+    dim_head: int = 32
+    use_linear_attention: bool = True
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_projection: bool = False
+    use_flash_attention:bool = False
+    use_self_and_cross:bool = True
+    only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
+    kernel_init: Callable = kernel_init(1.0)
+    norm_inputs: bool = True
+    explicitly_add_residual: bool = True
+
+    @nn.compact
+    def __call__(self, x, context=None):
+        inner_dim = self.heads * self.dim_head
+        C = x.shape[-1]
+        if self.norm_inputs:
+            x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=inner_dim,
+                                       use_bias=False, precision=self.precision,
+                                       kernel_init=self.kernel_init,
+                                       dtype=self.dtype, name=f'project_in')(x)
+            else:
+                projected_x = nn.Conv(
+                    features=inner_dim, kernel_size=(1, 1),
+                    kernel_init=self.kernel_init,
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_in_conv',
+                )(x)
+        else:
+            projected_x = x
+            inner_dim = C
+
+        context = projected_x if context is None else context
+
+        projected_x = BasicTransformerBlock(
+            query_dim=inner_dim,
+            heads=self.heads,
+            dim_head=self.dim_head,
+            name=f'Attention',
+            precision=self.precision,
+            use_bias=False,
+            dtype=self.dtype,
+            use_flash_attention=self.use_flash_attention,
+            use_cross_only=(not self.use_self_and_cross),
+            only_pure_attention=self.only_pure_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            kernel_init=self.kernel_init
+        )(projected_x, context)
+
+        if self.use_projection == True:
+            if self.use_linear_attention:
+                projected_x = nn.Dense(features=C, precision=self.precision,
+                                       dtype=self.dtype, use_bias=False,
+                                       kernel_init=self.kernel_init,
+                                       name=f'project_out')(projected_x)
+            else:
+                projected_x = nn.Conv(
+                    features=C, kernel_size=(1, 1),
+                    kernel_init=self.kernel_init,
+                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                    precision=self.precision, name=f'project_out_conv',
+                )(projected_x)
+
+        if self.only_pure_attention or self.explicitly_add_residual:
+            projected_x = x + projected_x
+
+        out = projected_x
+        return out
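As a rough, hypothetical usage sketch (not taken from the package), the snippet below applies a TransformerBlock to an image feature map with external conditioning tokens via the standard non-flash attention path. It assumes flaxdiff is installed; the shapes are illustrative.

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock

x = jnp.ones((1, 16, 16, 64))      # [B, H, W, C] feature map
context = jnp.ones((1, 77, 64))    # e.g. conditioning tokens (text embeddings)

block = TransformerBlock(heads=4, dim_head=16, use_projection=True,
                         use_flash_attention=False, use_self_and_cross=True)
params = block.init(jax.random.PRNGKey(0), x, context)
out = block.apply(params, x, context)
print(out.shape)                   # (1, 16, 16, 64): residual output, same shape as x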
flaxdiff/models/autoencoder/autoencoder.py
@@ -0,0 +1,19 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Dict, Callable, Sequence, Any, Union
+import einops
+from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+
+
+class AutoEncoder():
+    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
+
+    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        raise NotImplementedError
+
+    def __call__(self, x: jnp.ndarray):
+        latents = self.encode(x)
+        reconstructions = self.decode(latents)
+        return reconstructions
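AutoEncoder is only an interface: subclasses supply encode/decode and inherit the round-trip __call__. A trivial, hypothetical subclass just to show the contract (the real subclasses follow in the files below):

import jax.numpy as jnp
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder

class IdentityAutoEncoder(AutoEncoder):
    # Hypothetical no-op example: the "latents" are simply the inputs.
    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return x

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return z

x = jnp.ones((1, 32, 32, 3))
assert (IdentityAutoEncoder()(x) == x).all()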
flaxdiff/models/autoencoder/diffusers.py
@@ -0,0 +1,91 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from .autoencoder import AutoEncoder
+
+"""
+This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
+The actual model was not trained by me, but was taken from the HuggingFace model hub.
+I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library
+All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
+"""
+
+class StableDiffusionVAE(AutoEncoder):
+    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
+
+        from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
+        from diffusers import FlaxStableDiffusionPipeline
+
+        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+            modelname,
+            revision=revision,
+            dtype=dtype,
+        )
+
+        vae = pipeline.vae
+
+        enc = FlaxEncoder(
+            in_channels=vae.config.in_channels,
+            out_channels=vae.config.latent_channels,
+            down_block_types=vae.config.down_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            act_fn=vae.config.act_fn,
+            norm_num_groups=vae.config.norm_num_groups,
+            double_z=True,
+            dtype=vae.dtype,
+        )
+
+        dec = FlaxDecoder(
+            in_channels=vae.config.latent_channels,
+            out_channels=vae.config.out_channels,
+            up_block_types=vae.config.up_block_types,
+            block_out_channels=vae.config.block_out_channels,
+            layers_per_block=vae.config.layers_per_block,
+            norm_num_groups=vae.config.norm_num_groups,
+            act_fn=vae.config.act_fn,
+            dtype=vae.dtype,
+        )
+
+        quant_conv = nn.Conv(
+            2 * vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        post_quant_conv = nn.Conv(
+            vae.config.latent_channels,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=vae.dtype,
+        )
+
+        self.enc = enc
+        self.dec = dec
+        self.post_quant_conv = post_quant_conv
+        self.quant_conv = quant_conv
+        self.params = params
+        self.scaling_factor = vae.scaling_factor
+
+    def encode(self, images, rngkey: jax.random.PRNGKey = None):
+        latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
+        latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
+        if rngkey is not None:
+            mean, log_std = jnp.split(latents, 2, axis=-1)
+            log_std = jnp.clip(log_std, -30, 20)
+            std = jnp.exp(0.5 * log_std)
+            latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+            # print("Sampled")
+        else:
+            # return the mean
+            latents, _ = jnp.split(latents, 2, axis=-1)
+        latents *= self.scaling_factor
+        return latents
+
+    def decode(self, latents):
+        latents = (1.0 / self.scaling_factor) * latents
+        latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
+        return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
flaxdiff/models/autoencoder/simple_autoenc.py
@@ -0,0 +1,26 @@
+from typing import Any, List, Optional, Callable
+import jax
+import flax.linen as nn
+from jax import numpy as jnp
+from flax.typing import Dtype, PrecisionLike
+from .autoencoder import AutoEncoder
+
+class SimpleAutoEncoder(AutoEncoder):
+    latent_channels: int
+    feature_depths: List[int]=[64, 128, 256, 512]
+    attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
+    num_res_blocks: int=2
+    num_middle_res_blocks:int=1,
+    activation:Callable = jax.nn.swish
+    norm_groups:int=8
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+
+    # def encode(self, x: jnp.ndarray):
+
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray):
+        latents = self.encode(x)
+        reconstructions = self.decode(latents)
+        return reconstructions