flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +140 -162
- flaxdiff/models/autoencoder/__init__.py +2 -0
- flaxdiff/models/autoencoder/autoencoder.py +19 -0
- flaxdiff/models/autoencoder/diffusers.py +91 -0
- flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
- flaxdiff/models/common.py +322 -0
- flaxdiff/models/simple_unet.py +21 -327
- flaxdiff/trainer/__init__.py +2 -201
- flaxdiff/trainer/autoencoder_trainer.py +182 -0
- flaxdiff/trainer/diffusion_trainer.py +202 -0
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/METADATA +12 -2
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/RECORD +15 -9
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/top_level.txt +0 -0
flaxdiff/models/common.py
CHANGED
@@ -1,7 +1,329 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
|
+
import jax
|
2
3
|
from flax import linen as nn
|
4
|
+
from typing import Optional, Any, Callable, Sequence, Union
|
5
|
+
from flax.typing import Dtype, PrecisionLike
|
6
|
+
from typing import Dict, Callable, Sequence, Any, Union
|
7
|
+
import einops
|
3
8
|
|
4
9
|
# Kernel initializer to use
def kernel_init(scale, dtype=jnp.float32):
    """Build a fan-avg, truncated-normal variance-scaling initializer.

    Args:
        scale: Requested variance scale. Values smaller than 1e-10 are
            floored to 1e-10.
        dtype: Default dtype for arrays produced by the initializer.

    Returns:
        The initializer returned by ``nn.initializers.variance_scaling``.
    """
    # Same selection rule as max(scale, 1e-10): keep `scale` unless the floor
    # is strictly larger.
    floored_scale = 1e-10 if scale < 1e-10 else scale
    return nn.initializers.variance_scaling(
        scale=floored_scale,
        mode="fan_avg",
        distribution="truncated_normal",
        dtype=dtype,
    )
|
13
|
+
|
14
|
+
|
15
|
+
class WeightStandardizedConv(nn.Module):
    """
    apply weight standardization https://arxiv.org/abs/1903.10520

    The kernel is stored as an ordinary ``kernel`` parameter. On every call it
    is standardized (mean subtracted, divided by sqrt(var + eps)) over all
    kernel axes except the last one. The result is then used by an unbound
    ``nn.Conv`` through ``apply``.

    Attributes:
        features: Number of output features.
        kernel_size: Convolution window size (forwarded to ``nn.Conv``).
        strides: Convolution strides (forwarded to ``nn.Conv``).
        padding: Padding spec (forwarded to ``nn.Conv``).
        dtype: Computation dtype. When None, the input is not cast and
            ``nn.Conv`` performs its own dtype inference.
        precision: Numerical precision forwarded to ``nn.Conv``.
        param_dtype: Parameter dtype forwarded to ``nn.Conv``.
    """
    features: int
    kernel_size: Sequence[int] = 3
    strides: Union[None, int, Sequence[int]] = 1
    padding: Any = 1
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    param_dtype: Optional[Dtype] = None

    @nn.compact
    def __call__(self, x):
        """
        Applies a weight standardized convolution to the inputs.

        Args:
          x: input data with dimensions (batch, spatial_dims..., features).

        Returns:
          The convolved data.
        """
        # Cast only when a dtype was explicitly requested. With dtype=None,
        # astype would coerce the input to the default float dtype instead of
        # leaving it to nn.Conv's dtype inference.
        if self.dtype is not None:
            x = x.astype(self.dtype)

        # Unbound Conv (parent=None) used purely as a function. Its parameters
        # are owned by this module (created via self.param below) and fed in
        # through conv.apply.
        conv = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            # Previously this field was accepted but never forwarded.
            precision=self.precision,
            parent=None,
        )

        def init_kernel(rng, inputs):
            return conv.init(rng, inputs)['params']['kernel']

        def init_bias(rng, inputs):
            return conv.init(rng, inputs)['params']['bias']

        # standardize kernel
        kernel = self.param('kernel', init_kernel, x)
        eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
        # Statistics are taken per output feature: reduce over every kernel
        # axis except the last.
        redux = tuple(range(kernel.ndim - 1))
        mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
        var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
        standardized_kernel = (kernel - mean) / jnp.sqrt(var + eps)

        bias = self.param('bias', init_bias, x)

        return conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}}, x)
|
64
|
+
|
65
|
+
class PixelShuffle(nn.Module):
    """Depth-to-space rearrangement.

    Uses the einops pattern ``b h w (h2 w2 c) -> b (h h2) (w w2) c`` with
    ``h2 = w2 = scale``. The last input axis must therefore be divisible by
    ``scale ** 2``.
    """
    scale: int

    @nn.compact
    def __call__(self, x):
        factor = self.scale
        return einops.rearrange(
            x,
            pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
            h2=factor,
            w2=factor,
        )
|
77
|
+
|
78
|
+
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of a vector of scalar inputs.

    Attributes:
        features: Embedding width. The output has ``2 * (features // 2)``
            channels: the sin half followed by the cos half.
        nax_positions: Base of the geometric frequency ladder. The name
            presumably means "max_positions"; it is kept for compatibility.
    """
    features:int
    nax_positions:int=10000

    def setup(self):
        half_dim = self.features // 2
        # Guard the denominator. For features in {2, 3}, half_dim - 1 == 0, and
        # the division gave inf, so the lone frequency became exp(-inf*0) = NaN.
        # With the guard, that single frequency is 1.0. Results for
        # half_dim >= 2 are unchanged.
        denom = max(half_dim - 1, 1)
        emb = jnp.log(self.nax_positions) / denom
        emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
        self.embeddings = emb

    def __call__(self, x):
        """Embed ``x``.

        ``x[:, None]`` is used, so x is expected to be 1-D (one scalar per
        batch element). The result has shape ``(len(x), 2 * (features // 2))``.
        """
        x = jax.lax.convert_element_type(x, jnp.float32)
        emb = x[:, None] * self.embeddings[None, :]
        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
        return emb
|
93
|
+
|
94
|
+
class FourierEmbedding(nn.Module):
    """Random Fourier-feature embedding of a vector of scalar inputs.

    The ``features // 2`` frequencies are drawn once from a standard normal
    using the fixed key ``PRNGKey(42)`` and multiplied by ``scale``. They are
    stored as a plain attribute, not as a trainable parameter.
    """
    features:int
    scale:int = 16

    def setup(self):
        fixed_key = jax.random.PRNGKey(42)
        unit_freqs = jax.random.normal(fixed_key, (self.features // 2, ), dtype=jnp.float32)
        self.freqs = unit_freqs * self.scale

    def __call__(self, x):
        """Return ``concat(sin(2*pi*f*x), cos(2*pi*f*x))`` along the last axis."""
        t = jax.lax.convert_element_type(x, jnp.float32)
        angular = 2 * jnp.pi * self.freqs
        phases = t[:, None] * angular[None, :]
        return jnp.concatenate([jnp.sin(phases), jnp.cos(phases)], axis=-1)
|
106
|
+
|
107
|
+
class TimeProjection(nn.Module):
    """Two (DenseGeneral -> activation) stages, both of width ``features``."""
    features:int
    activation:Callable=jax.nn.gelu

    @nn.compact
    def __call__(self, x):
        hidden = x
        # Two identical stages; creation order keeps the auto-generated names
        # DenseGeneral_0 / DenseGeneral_1.
        for _ in range(2):
            hidden = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(hidden)
            hidden = self.activation(hidden)
        return hidden
|
118
|
+
|
119
|
+
class SeparableConv(nn.Module):
    """Depthwise-separable convolution.

    A depthwise conv (``feature_group_count`` equal to the input channels,
    window ``kernel_size``, ``strides``, ``padding``) is followed by a 1x1
    pointwise conv that maps to ``features`` channels.
    """
    features:int
    kernel_size:tuple=(3, 3)
    strides:tuple=(1, 1)
    use_bias:bool=False
    kernel_init:Callable=kernel_init(1.0)
    padding:str="SAME"
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x):
        channels = x.shape[-1]
        # Settings shared by both convolutions.
        shared = dict(
            kernel_init=self.kernel_init,
            use_bias=self.use_bias,
            dtype=self.dtype,
            precision=self.precision,
        )
        spatial = nn.Conv(
            features=channels,
            kernel_size=self.kernel_size,
            strides=self.strides,
            feature_group_count=channels,
            padding=self.padding,
            **shared,
        )(x)
        return nn.Conv(
            features=self.features,
            kernel_size=(1, 1),
            strides=(1, 1),
            **shared,
        )(spatial)
|
148
|
+
|
149
|
+
class ConvLayer(nn.Module):
    """Select a convolution implementation by name.

    Attributes:
        conv_type: One of ``"conv"`` (nn.Conv), ``"w_conv"``
            (WeightStandardizedConv with "SAME" padding), ``"separable"``
            (SeparableConv) or ``"conv_transpose"`` (nn.ConvTranspose).
        features: Number of output features.
        kernel_size: Convolution window size.
        strides: Convolution strides.
        kernel_init: Kernel initializer. It is not forwarded for ``"w_conv"``,
            whose constructor has no such field.
        dtype: Computation dtype. For ``"w_conv"`` it is also the param dtype.
        precision: Numerical precision forwarded to the chosen conv.

    Raises:
        ValueError: From ``setup`` when ``conv_type`` is not one of the names
            above. Previously ``self.conv`` was left unset, which surfaced
            later as a confusing AttributeError.
    """
    conv_type:str
    features:int
    kernel_size:tuple=(3, 3)
    strides:tuple=(1, 1)
    kernel_init:Callable=kernel_init(1.0)
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    def setup(self):
        # conv_type can be "conv", "w_conv", "separable", "conv_transpose"
        if self.conv_type == "conv":
            self.conv = nn.Conv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                kernel_init=self.kernel_init,
                dtype=self.dtype,
                precision=self.precision
            )
        elif self.conv_type == "w_conv":
            self.conv = WeightStandardizedConv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                padding="SAME",
                param_dtype=self.dtype,
                dtype=self.dtype,
                precision=self.precision
            )
        elif self.conv_type == "separable":
            self.conv = SeparableConv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                kernel_init=self.kernel_init,
                dtype=self.dtype,
                precision=self.precision
            )
        elif self.conv_type == "conv_transpose":
            self.conv = nn.ConvTranspose(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                kernel_init=self.kernel_init,
                dtype=self.dtype,
                precision=self.precision
            )
        else:
            raise ValueError(
                f"Unknown conv_type {self.conv_type!r}; expected one of "
                "'conv', 'w_conv', 'separable', 'conv_transpose'."
            )

    def __call__(self, x):
        return self.conv(x)
|
200
|
+
|
201
|
+
class Upsample(nn.Module):
    """Nearest-neighbour upsampling by ``scale``, then a 3x3 conv.

    Input is unpacked as ``(B, H, W, C)`` and resized to
    ``(B, H*scale, W*scale, C)``. If ``residual`` is given, it is concatenated
    to the conv output along the last axis.

    NOTE(review): the ``activation`` field is not read by ``__call__``.
    """
    features:int
    scale:int
    activation:Callable=jax.nn.swish
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x, residual=None):
        batch, height, width, channels = x.shape
        target_shape = (batch, height * self.scale, width * self.scale, channels)
        # A PixelShuffle-based variant was used previously; nearest resize is current.
        enlarged = jax.image.resize(x, target_shape, method="nearest")
        projected = ConvLayer(
            "conv",
            features=self.features,
            kernel_size=(3, 3),
            strides=(1, 1),
            dtype=self.dtype,
            precision=self.precision
        )(enlarged)
        if residual is None:
            return projected
        return jnp.concatenate([projected, residual], axis=-1)
|
225
|
+
|
226
|
+
class Downsample(nn.Module):
    """Stride-2 3x3 conv, optionally concatenated with a (pooled) residual.

    If ``residual`` is spatially larger than the conv output along axis 1, it
    is 2x2 average-pooled once before concatenation on the last axis.

    NOTE(review): the conv stride is fixed at (2, 2), and neither ``scale`` nor
    ``activation`` is read by ``__call__``. Confirm this is intended.
    """
    features:int
    scale:int
    activation:Callable=jax.nn.swish
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x, residual=None):
        reduced = ConvLayer(
            "conv",
            features=self.features,
            kernel_size=(3, 3),
            strides=(2, 2),
            dtype=self.dtype,
            precision=self.precision
        )(x)
        if residual is None:
            return reduced
        skip = residual
        if skip.shape[1] > reduced.shape[1]:
            skip = nn.avg_pool(skip, window_shape=(2, 2), strides=(2, 2), padding="SAME")
        return jnp.concatenate([reduced, skip], axis=-1)
|
248
|
+
|
249
|
+
|
250
|
+
def l2norm(t, axis=1, eps=1e-12):
    """Scale ``t`` to unit L2 norm along ``axis``.

    The per-slice norm is floored at ``eps`` before dividing, so an all-zero
    slice maps to zeros rather than NaN.
    """
    magnitudes = jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True)
    return t / jnp.maximum(magnitudes, eps)
|
254
|
+
|
255
|
+
|
256
|
+
class ResidualBlock(nn.Module):
    """Pre-activation residual block with additive time conditioning.

    Steps, in order:
      1. RMSNorm, activation, then ``conv1``.
      2. Add the ``temb`` projection (``temb_projection``), broadcast over two
         inserted axes at position 1.
      3. RMSNorm, activation, then ``conv2``.
      4. Add the input. If its shape differs from the branch output, it first
         goes through a 1x1 ``residual_conv``.
      5. If ``extra_features`` is given, concatenate it on the last axis.

    NOTE(review): ``textemb`` and the ``padding``, ``direction``, ``res`` and
    ``norm_groups`` fields are not read by ``__call__``.
    """
    conv_type:str
    features:int
    kernel_size:tuple=(3, 3)
    strides:tuple=(1, 1)
    padding:str="SAME"
    activation:Callable=jax.nn.swish
    direction:str=None
    res:int=2
    norm_groups:int=8
    kernel_init:Callable=kernel_init(1.0)
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
        def make_conv(name, kernel_size, strides):
            # Every conv in this block shares type, width, init, dtype and precision.
            return ConvLayer(
                self.conv_type,
                features=self.features,
                kernel_size=kernel_size,
                strides=strides,
                kernel_init=self.kernel_init,
                name=name,
                dtype=self.dtype,
                precision=self.precision
            )

        shortcut = x

        # Creation order of the unnamed RMSNorm modules is kept (RMSNorm_0 here,
        # RMSNorm_1 below) so parameter names are unchanged.
        h = self.activation(nn.RMSNorm()(x))
        h = make_conv("conv1", self.kernel_size, self.strides)(h)

        time_bias = nn.DenseGeneral(
            features=self.features,
            name="temb_projection",
            dtype=self.dtype,
            precision=self.precision)(temb)
        # Insert two singleton axes after the leading axis, like two
        # expand_dims(..., 1) calls.
        h = h + time_bias[:, None, None]

        h = self.activation(nn.RMSNorm()(h))
        h = make_conv("conv2", self.kernel_size, self.strides)(h)

        if shortcut.shape != h.shape:
            shortcut = make_conv("residual_conv", (1, 1), 1)(shortcut)
        h = h + shortcut

        if extra_features is not None:
            h = jnp.concatenate([h, extra_features], axis=-1)
        return h
|
329
|
+
|