flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/common.py CHANGED
@@ -1,7 +1,329 @@
  import jax.numpy as jnp
+ import jax
  from flax import linen as nn
+ from typing import Optional, Any, Callable, Sequence, Union
+ from flax.typing import Dtype, PrecisionLike
+ from typing import Dict, Callable, Sequence, Any, Union
+ import einops

  # Kernel initializer to use
  def kernel_init(scale, dtype=jnp.float32):
      scale = max(scale, 1e-10)
      return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
+
+
+ class WeightStandardizedConv(nn.Module):
+     """
+     Applies weight standardization (https://arxiv.org/abs/1903.10520) to the convolution kernel.
+     """
+     features: int
+     kernel_size: Sequence[int] = 3
+     strides: Union[None, int, Sequence[int]] = 1
+     padding: Any = 1
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     param_dtype: Optional[Dtype] = None
+
+     @nn.compact
+     def __call__(self, x):
+         """
+         Applies a weight-standardized convolution to the inputs.
+
+         Args:
+             x: input data with dimensions (batch, spatial_dims..., features).
+
+         Returns:
+             The convolved data.
+         """
+         x = x.astype(self.dtype)
+
+         conv = nn.Conv(
+             features=self.features,
+             kernel_size=self.kernel_size,
+             strides=self.strides,
+             padding=self.padding,
+             dtype=self.dtype,
+             param_dtype=self.param_dtype,
+             parent=None)
+
+         kernel_init = lambda rng, x: conv.init(rng, x)['params']['kernel']
+         bias_init = lambda rng, x: conv.init(rng, x)['params']['bias']
+
+         # Standardize the kernel: zero mean and unit variance per output feature.
+         kernel = self.param('kernel', kernel_init, x)
+         eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
+         # Reduce over every axis except dim_out.
+         redux = tuple(range(kernel.ndim - 1))
+         mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+         var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+         standardized_kernel = (kernel - mean) / jnp.sqrt(var + eps)
+
+         bias = self.param('bias', bias_init, x)
+
+         return conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}}, x)
+
+ class PixelShuffle(nn.Module):
+     scale: int
+
+     @nn.compact
+     def __call__(self, x):
+         up = einops.rearrange(
+             x,
+             pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
+             h2=self.scale,
+             w2=self.scale,
+         )
+         return up
+
+ class TimeEmbedding(nn.Module):
+     features: int
+     max_positions: int = 10000
+
+     def setup(self):
+         half_dim = self.features // 2
+         emb = jnp.log(self.max_positions) / (half_dim - 1)
+         emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
+         self.embeddings = emb
+
+     def __call__(self, x):
+         x = jax.lax.convert_element_type(x, jnp.float32)
+         emb = x[:, None] * self.embeddings[None, :]
+         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+         return emb
+
+ class FourierEmbedding(nn.Module):
+     features: int
+     scale: int = 16
+
+     def setup(self):
+         self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2,), dtype=jnp.float32) * self.scale
+
+     def __call__(self, x):
+         x = jax.lax.convert_element_type(x, jnp.float32)
+         emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
+         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+         return emb
+
+ class TimeProjection(nn.Module):
+     features: int
+     activation: Callable = jax.nn.gelu
+
+     @nn.compact
+     def __call__(self, x):
+         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+         x = self.activation(x)
+         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+         x = self.activation(x)
+         return x
+
+ class SeparableConv(nn.Module):
+     features: int
+     kernel_size: tuple = (3, 3)
+     strides: tuple = (1, 1)
+     use_bias: bool = False
+     kernel_init: Callable = kernel_init(1.0)
+     padding: str = "SAME"
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+
+     @nn.compact
+     def __call__(self, x):
+         in_features = x.shape[-1]
+         depthwise = nn.Conv(
+             features=in_features, kernel_size=self.kernel_size,
+             strides=self.strides, kernel_init=self.kernel_init,
+             feature_group_count=in_features, use_bias=self.use_bias,
+             padding=self.padding,
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         pointwise = nn.Conv(
+             features=self.features, kernel_size=(1, 1),
+             strides=(1, 1), kernel_init=self.kernel_init,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             precision=self.precision
+         )(depthwise)
+         return pointwise
+
+ class ConvLayer(nn.Module):
+     conv_type: str
+     features: int
+     kernel_size: tuple = (3, 3)
+     strides: tuple = (1, 1)
+     kernel_init: Callable = kernel_init(1.0)
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+
+     def setup(self):
+         # conv_type can be "conv", "w_conv", "separable" or "conv_transpose"
+         if self.conv_type == "conv":
+             self.conv = nn.Conv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "w_conv":
+             self.conv = WeightStandardizedConv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 padding="SAME",
+                 param_dtype=self.dtype,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "separable":
+             self.conv = SeparableConv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "conv_transpose":
+             self.conv = nn.ConvTranspose(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+
+     def __call__(self, x):
+         return self.conv(x)
+
+ class Upsample(nn.Module):
+     features: int
+     scale: int
+     activation: Callable = jax.nn.swish
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+
+     @nn.compact
+     def __call__(self, x, residual=None):
+         out = x
+         # out = PixelShuffle(scale=self.scale)(out)
+         B, H, W, C = x.shape
+         out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
+         out = ConvLayer(
+             "conv",
+             features=self.features,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             dtype=self.dtype,
+             precision=self.precision
+         )(out)
+         if residual is not None:
+             out = jnp.concatenate([out, residual], axis=-1)
+         return out
+
+ class Downsample(nn.Module):
+     features: int
+     scale: int
+     activation: Callable = jax.nn.swish
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+
+     @nn.compact
+     def __call__(self, x, residual=None):
+         out = ConvLayer(
+             "conv",
+             features=self.features,
+             kernel_size=(3, 3),
+             strides=(2, 2),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         if residual is not None:
+             if residual.shape[1] > out.shape[1]:
+                 residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
+             out = jnp.concatenate([out, residual], axis=-1)
+         return out
+
+
+ def l2norm(t, axis=1, eps=1e-12):
+     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
+     out = t / denom
+     return out
+
+
+ class ResidualBlock(nn.Module):
+     conv_type: str
+     features: int
+     kernel_size: tuple = (3, 3)
+     strides: tuple = (1, 1)
+     padding: str = "SAME"
+     activation: Callable = jax.nn.swish
+     direction: str = None
+     res: int = 2
+     norm_groups: int = 8
+     kernel_init: Callable = kernel_init(1.0)
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+
+     @nn.compact
+     def __call__(self, x: jax.Array, temb: jax.Array, textemb: jax.Array = None, extra_features: jax.Array = None):
+         residual = x
+         # out = nn.GroupNorm(self.norm_groups)(x)
+         out = nn.RMSNorm()(x)
+         out = self.activation(out)
+
+         out = ConvLayer(
+             self.conv_type,
+             features=self.features,
+             kernel_size=self.kernel_size,
+             strides=self.strides,
+             kernel_init=self.kernel_init,
+             name="conv1",
+             dtype=self.dtype,
+             precision=self.precision
+         )(out)
+
+         temb = nn.DenseGeneral(
+             features=self.features,
+             name="temb_projection",
+             dtype=self.dtype,
+             precision=self.precision)(temb)
+         temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
+         # scale, shift = jnp.split(temb, 2, axis=-1)
+         # out = out * (1 + scale) + shift
+         out = out + temb
+
+         # out = nn.GroupNorm(self.norm_groups)(out)
+         out = nn.RMSNorm()(out)
+         out = self.activation(out)
+
+         out = ConvLayer(
+             self.conv_type,
+             features=self.features,
+             kernel_size=self.kernel_size,
+             strides=self.strides,
+             kernel_init=self.kernel_init,
+             name="conv2",
+             dtype=self.dtype,
+             precision=self.precision
+         )(out)
+
+         if residual.shape != out.shape:
+             residual = ConvLayer(
+                 self.conv_type,
+                 features=self.features,
+                 kernel_size=(1, 1),
+                 strides=1,
+                 kernel_init=self.kernel_init,
+                 name="residual_conv",
+                 dtype=self.dtype,
+                 precision=self.precision
+             )(residual)
+         out = out + residual
+
+         out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
+
+         return out
+
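
For context, the sketch below shows one way the newly added modules could be exercised, assuming flaxdiff 0.1.6 is installed and the file above is importable as flaxdiff.models.common. The batch size, image size, feature counts, and timestep values are illustrative assumptions, not values taken from the package.

# Minimal usage sketch (not part of the diff): time embedding + projection feeding a ResidualBlock.
import jax
import jax.numpy as jnp
from flaxdiff.models.common import TimeEmbedding, TimeProjection, ResidualBlock

rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 32, 32, 64))      # (batch, height, width, channels), assumed shape
t = jnp.array([1.0, 250.0])        # one diffusion timestep per batch element, assumed values

# Sinusoidal timestep embedding followed by the small MLP projection.
temb, _ = TimeEmbedding(features=128).init_with_output(rng, t)       # shape (2, 128)
temb, _ = TimeProjection(features=128).init_with_output(rng, temb)   # shape (2, 128)

# A residual block conditioned on the projected timestep embedding.
block = ResidualBlock(conv_type="conv", features=64)
params = block.init(rng, x, temb)
out = block.apply(params, x, temb)  # shape (2, 32, 32, 64)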