flaxdiff 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
flaxdiff/models/common.py CHANGED
@@ -1,7 +1,250 @@
  import jax.numpy as jnp
+ import jax
  from flax import linen as nn
+ from typing import Dict, Callable, Sequence, Any, Union
+ import einops

  # Kernel initializer to use
  def kernel_init(scale, dtype=jnp.float32):
      scale = max(scale, 1e-10)
      return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
+
+
+ class WeightStandardizedConv(nn.Module):
+     """
+     apply weight standardization https://arxiv.org/abs/1903.10520
+     """
+     features: int
+     kernel_size: Sequence[int] = 3
+     strides: Union[None, int, Sequence[int]] = 1
+     padding: Any = 1
+     dtype: Any = jnp.float32
+     param_dtype: Any = jnp.float32
+
+     @nn.compact
+     def __call__(self, x):
+         """
+         Applies a weight standardized convolution to the inputs.
+
+         Args:
+             inputs: input data with dimensions (batch, spatial_dims..., features).
+
+         Returns:
+             The convolved data.
+         """
+         x = x.astype(self.dtype)
+
+         conv = nn.Conv(
+             features=self.features,
+             kernel_size=self.kernel_size,
+             strides = self.strides,
+             padding=self.padding,
+             dtype=self.dtype,
+             param_dtype = self.param_dtype,
+             parent=None)
+
+         kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
+         bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
+
+         # standardize kernel
+         kernel = self.param('kernel', kernel_init, x)
+         eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
+         # reduce over dim_out
+         redux = tuple(range(kernel.ndim - 1))
+         mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+         var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+         standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
+
+         bias = self.param('bias',bias_init, x)
+
+         return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
+
+ class PixelShuffle(nn.Module):
+     scale: int
+
+     @nn.compact
+     def __call__(self, x):
+         up = einops.rearrange(
+             x,
+             pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
+             h2=self.scale,
+             w2=self.scale,
+         )
+         return up
+
+ class TimeEmbedding(nn.Module):
+     features:int
+     nax_positions:int=10000
+
+     def setup(self):
+         half_dim = self.features // 2
+         emb = jnp.log(self.nax_positions) / (half_dim - 1)
+         emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
+         self.embeddings = emb
+
+     def __call__(self, x):
+         x = jax.lax.convert_element_type(x, jnp.float32)
+         emb = x[:, None] * self.embeddings[None, :]
+         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+         return emb
+
+ class FourierEmbedding(nn.Module):
+     features:int
+     scale:int = 16
+
+     def setup(self):
+         self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
+
+     def __call__(self, x):
+         x = jax.lax.convert_element_type(x, jnp.float32)
+         emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
+         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+         return emb
+
+ class TimeProjection(nn.Module):
+     features:int
+     activation:Callable=jax.nn.gelu
+
+     @nn.compact
+     def __call__(self, x):
+         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+         x = self.activation(x)
+         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+         x = self.activation(x)
+         return x
+
+ class SeparableConv(nn.Module):
+     features:int
+     kernel_size:tuple=(3, 3)
+     strides:tuple=(1, 1)
+     use_bias:bool=False
+     kernel_init:Callable=kernel_init(1.0)
+     padding:str="SAME"
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x):
+         in_features = x.shape[-1]
+         depthwise = nn.Conv(
+             features=in_features, kernel_size=self.kernel_size,
+             strides=self.strides, kernel_init=self.kernel_init,
+             feature_group_count=in_features, use_bias=self.use_bias,
+             padding=self.padding,
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         pointwise = nn.Conv(
+             features=self.features, kernel_size=(1, 1),
+             strides=(1, 1), kernel_init=self.kernel_init,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             precision=self.precision
+         )(depthwise)
+         return pointwise
+
+ class ConvLayer(nn.Module):
+     conv_type:str
+     features:int
+     kernel_size:tuple=(3, 3)
+     strides:tuple=(1, 1)
+     kernel_init:Callable=kernel_init(1.0)
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     def setup(self):
+         # conv_type can be "conv", "separable", "conv_transpose"
+         if self.conv_type == "conv":
+             self.conv = nn.Conv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "w_conv":
+             self.conv = WeightStandardizedConv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 padding="SAME",
+                 param_dtype=self.dtype,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "separable":
+             self.conv = SeparableConv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "conv_transpose":
+             self.conv = nn.ConvTranspose(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+
+     def __call__(self, x):
+         return self.conv(x)
+
+ class Upsample(nn.Module):
+     features:int
+     scale:int
+     activation:Callable=jax.nn.swish
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, residual=None):
+         out = x
+         # out = PixelShuffle(scale=self.scale)(out)
+         B, H, W, C = x.shape
+         out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
+         out = ConvLayer(
+             "conv",
+             features=self.features,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             dtype=self.dtype,
+             precision=self.precision
+         )(out)
+         if residual is not None:
+             out = jnp.concatenate([out, residual], axis=-1)
+         return out
+
+ class Downsample(nn.Module):
+     features:int
+     scale:int
+     activation:Callable=jax.nn.swish
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, residual=None):
+         out = ConvLayer(
+             "conv",
+             features=self.features,
+             kernel_size=(3, 3),
+             strides=(2, 2),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         if residual is not None:
+             if residual.shape[1] > out.shape[1]:
+                 residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
+             out = jnp.concatenate([out, residual], axis=-1)
+         return out
+
+
+ def l2norm(t, axis=1, eps=1e-12):
+     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
+     out = t/denom
+     return (out)
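
The additions above turn common.py into the shared toolbox for the UNet: weight-standardized and separable convolutions behind a single ConvLayer dispatch, sinusoidal and Fourier timestep embeddings (note the `nax_positions` spelling of the positional-scale field in this release), and resize-based up/downsampling. The snippet below is a usage sketch only, not part of the diff; it assumes flaxdiff 0.1.5 is installed and that the modules keep the signatures shown above.

# Usage sketch (not part of the diff): exercising the new helpers in
# flaxdiff.models.common with plain Flax init/apply.
import jax
import jax.numpy as jnp
from flaxdiff.models.common import (
    ConvLayer, Downsample, Upsample, TimeEmbedding, TimeProjection,
)

rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 32, 32, 3))     # NHWC image batch
t = jnp.array([0.5])             # diffusion timesteps, shape (B,)

# Sinusoidal timestep features, then a small MLP projection.
embed = TimeEmbedding(features=64)
temb = embed.apply(embed.init(rng, t), t)         # (1, 64)
proj = TimeProjection(features=256)
temb = proj.apply(proj.init(rng, temb), temb)     # (1, 256)

# ConvLayer dispatches on conv_type: "conv", "w_conv", "separable", "conv_transpose".
conv = ConvLayer("conv", features=16, kernel_size=(3, 3), strides=(1, 1))
y = conv.apply(conv.init(rng, x), x)              # (1, 32, 32, 16)

down = Downsample(features=32, scale=2)
y_down = down.apply(down.init(rng, y), y)         # (1, 16, 16, 32)

up = Upsample(features=16, scale=2)
y_up = up.apply(up.init(rng, y_down), y_down)     # (1, 32, 32, 16)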
@@ -3,248 +3,9 @@ import jax.numpy as jnp
  from flax import linen as nn
  from typing import Dict, Callable, Sequence, Any, Union
  import einops
- from .common import kernel_init
+ from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
  from .attention import TransformerBlock

- class WeightStandardizedConv(nn.Module):
-     """
-     apply weight standardization https://arxiv.org/abs/1903.10520
-     """
-     features: int
-     kernel_size: Sequence[int] = 3
-     strides: Union[None, int, Sequence[int]] = 1
-     padding: Any = 1
-     dtype: Any = jnp.float32
-     param_dtype: Any = jnp.float32
-
-     @nn.compact
-     def __call__(self, x):
-         """
-         Applies a weight standardized convolution to the inputs.
-
-         Args:
-             inputs: input data with dimensions (batch, spatial_dims..., features).
-
-         Returns:
-             The convolved data.
-         """
-         x = x.astype(self.dtype)
-
-         conv = nn.Conv(
-             features=self.features,
-             kernel_size=self.kernel_size,
-             strides = self.strides,
-             padding=self.padding,
-             dtype=self.dtype,
-             param_dtype = self.param_dtype,
-             parent=None)
-
-         kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
-         bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
-
-         # standardize kernel
-         kernel = self.param('kernel', kernel_init, x)
-         eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
-         # reduce over dim_out
-         redux = tuple(range(kernel.ndim - 1))
-         mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
-         var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
-         standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
-
-         bias = self.param('bias',bias_init, x)
-
-         return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
-
- class PixelShuffle(nn.Module):
-     scale: int
-
-     @nn.compact
-     def __call__(self, x):
-         up = einops.rearrange(
-             x,
-             pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
-             h2=self.scale,
-             w2=self.scale,
-         )
-         return up
-
- class TimeEmbedding(nn.Module):
-     features:int
-     nax_positions:int=10000
-
-     def setup(self):
-         half_dim = self.features // 2
-         emb = jnp.log(self.nax_positions) / (half_dim - 1)
-         emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
-         self.embeddings = emb
-
-     def __call__(self, x):
-         x = jax.lax.convert_element_type(x, jnp.float32)
-         emb = x[:, None] * self.embeddings[None, :]
-         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
-         return emb
-
- class FourierEmbedding(nn.Module):
-     features:int
-     scale:int = 16
-
-     def setup(self):
-         self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
-
-     def __call__(self, x):
-         x = jax.lax.convert_element_type(x, jnp.float32)
-         emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
-         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
-         return emb
-
- class TimeProjection(nn.Module):
-     features:int
-     activation:Callable=jax.nn.gelu
-
-     @nn.compact
-     def __call__(self, x):
-         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
-         x = self.activation(x)
-         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
-         x = self.activation(x)
-         return x
-
- class SeparableConv(nn.Module):
-     features:int
-     kernel_size:tuple=(3, 3)
-     strides:tuple=(1, 1)
-     use_bias:bool=False
-     kernel_init:Callable=kernel_init(1.0)
-     padding:str="SAME"
-     dtype: Any = jnp.bfloat16
-     precision: Any = jax.lax.Precision.HIGH
-
-     @nn.compact
-     def __call__(self, x):
-         in_features = x.shape[-1]
-         depthwise = nn.Conv(
-             features=in_features, kernel_size=self.kernel_size,
-             strides=self.strides, kernel_init=self.kernel_init,
-             feature_group_count=in_features, use_bias=self.use_bias,
-             padding=self.padding,
-             dtype=self.dtype,
-             precision=self.precision
-         )(x)
-         pointwise = nn.Conv(
-             features=self.features, kernel_size=(1, 1),
-             strides=(1, 1), kernel_init=self.kernel_init,
-             use_bias=self.use_bias,
-             dtype=self.dtype,
-             precision=self.precision
-         )(depthwise)
-         return pointwise
-
- class ConvLayer(nn.Module):
-     conv_type:str
-     features:int
-     kernel_size:tuple=(3, 3)
-     strides:tuple=(1, 1)
-     kernel_init:Callable=kernel_init(1.0)
-     dtype: Any = jnp.bfloat16
-     precision: Any = jax.lax.Precision.HIGH
-
-     def setup(self):
-         # conv_type can be "conv", "separable", "conv_transpose"
-         if self.conv_type == "conv":
-             self.conv = nn.Conv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "w_conv":
-             self.conv = WeightStandardizedConv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 padding="SAME",
-                 param_dtype=self.dtype,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "separable":
-             self.conv = SeparableConv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "conv_transpose":
-             self.conv = nn.ConvTranspose(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-
-     def __call__(self, x):
-         return self.conv(x)
-
- class Upsample(nn.Module):
-     features:int
-     scale:int
-     activation:Callable=jax.nn.swish
-     dtype: Any = jnp.bfloat16
-     precision: Any = jax.lax.Precision.HIGH
-
-     @nn.compact
-     def __call__(self, x, residual=None):
-         out = x
-         # out = PixelShuffle(scale=self.scale)(out)
-         B, H, W, C = x.shape
-         out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
-         out = ConvLayer(
-             "conv",
-             features=self.features,
-             kernel_size=(3, 3),
-             strides=(1, 1),
-             dtype=self.dtype,
-             precision=self.precision
-         )(out)
-         if residual is not None:
-             out = jnp.concatenate([out, residual], axis=-1)
-         return out
-
- class Downsample(nn.Module):
-     features:int
-     scale:int
-     activation:Callable=jax.nn.swish
-     dtype: Any = jnp.bfloat16
-     precision: Any = jax.lax.Precision.HIGH
-
-     @nn.compact
-     def __call__(self, x, residual=None):
-         out = ConvLayer(
-             "conv",
-             features=self.features,
-             kernel_size=(3, 3),
-             strides=(2, 2),
-             dtype=self.dtype,
-             precision=self.precision
-         )(x)
-         if residual is not None:
-             if residual.shape[1] > out.shape[1]:
-                 residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
-             out = jnp.concatenate([out, residual], axis=-1)
-         return out
-
-
- def l2norm(t, axis=1, eps=1e-12):
-     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
-     out = t/denom
-     return (out)
-
  class ResidualBlock(nn.Module):
      conv_type:str
      features:int
@@ -318,6 +79,7 @@ class ResidualBlock(nn.Module):
          return out

  class Unet(nn.Module):
+     output_channels:int=3
      emb_features:int=64*4,
      feature_depths:list=[64, 128, 256, 512],
      attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
@@ -342,8 +104,6 @@ class Unet(nn.Module):

          conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
          # middle_conv_type = "separable"
-
-         print(f"input shape: {x.shape}")

          x = ConvLayer(
              conv_type,
@@ -355,8 +115,6 @@ class Unet(nn.Module):
              precision=self.precision
          )(x)
          downs = [x]
-
-         print(f"x shape: {x.shape}")

          # Downscaling blocks
          for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
@@ -377,12 +135,13 @@ class Unet(nn.Module):
                  )(x, temb)
                  if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
                      x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
-                         dim_head=dim_in // attention_config['heads'],
-                         use_flash_attention=attention_config.get("flash_attention", True),
-                         use_projection=attention_config.get("use_projection", False),
-                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
-                         precision=attention_config.get("precision", self.precision),
-                         name=f"down_{i}_attention_{j}")(x, textcontext)
+                         dim_head=dim_in // attention_config['heads'],
+                         use_flash_attention=attention_config.get("flash_attention", True),
+                         use_projection=attention_config.get("use_projection", False),
+                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                         precision=attention_config.get("precision", self.precision),
+                         only_pure_attention=True,
+                         name=f"down_{i}_attention_{j}")(x, textcontext)
                  # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                  downs.append(x)
              if i != len(feature_depths) - 1:
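
For reference, the `attention_config` entries consumed here are plain dicts; the only required key is "heads", and every other key falls back to the `.get()` default visible in the call above. As of this release, `only_pure_attention=True` is passed unconditionally at the call sites rather than read from the config. The list below is a sketch of per-level configs built only from the keys read above, not part of the diff.

import jax
import jax.numpy as jnp

# Hypothetical attention_configs list, one entry per feature level.
# Defaults shown mirror the .get() fallbacks in the hunk above.
attention_configs = [
    None,                                    # no attention at this level
    {"heads": 8},                            # rely on the defaults
    {
        "heads": 8,
        "dtype": jnp.float32,                # attention compute dtype
        "flash_attention": True,             # -> use_flash_attention
        "use_projection": False,
        "use_self_and_cross": True,
        "precision": jax.lax.Precision.HIGH,
    },
    {"heads": 8},
]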
@@ -420,6 +179,7 @@ class Unet(nn.Module):
                      use_projection=middle_attention.get("use_projection", False),
                      use_self_and_cross=False,
                      precision=attention_config.get("precision", self.precision),
+                     only_pure_attention=True,
                      name=f"middle_attention_{j}")(x, textcontext)
              x = ResidualBlock(
                  middle_conv_type,
@@ -456,12 +216,13 @@ class Unet(nn.Module):
                  )(x, temb)
                  if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
                      x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
-                         dim_head=dim_out // attention_config['heads'],
-                         use_flash_attention=attention_config.get("flash_attention", True),
-                         use_projection=attention_config.get("use_projection", False),
-                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                         dim_head=dim_out // attention_config['heads'],
+                         use_flash_attention=attention_config.get("flash_attention", True),
+                         use_projection=attention_config.get("use_projection", False),
+                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                          precision=attention_config.get("precision", self.precision),
-                         name=f"up_{i}_attention_{j}")(x, textcontext)
+                         only_pure_attention=True,
+                         name=f"up_{i}_attention_{j}")(x, textcontext)
              # print("Upscaling ", i, x.shape)
              if i != len(feature_depths) - 1:
                  x = Upsample(
@@ -504,7 +265,7 @@ class Unet(nn.Module):

          noise_out = ConvLayer(
              conv_type,
-             features=3,
+             features=self.output_channels,
              kernel_size=(3, 3),
              strides=(1, 1),
              # activation=jax.nn.mish
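
The last hunk replaces the hard-coded 3-channel output convolution with the new `output_channels` field, so the same UNet definition can predict outputs with a different channel count (for example, 4-channel latents). Below is a construction sketch only, not part of the diff; the defining module's path is not shown in this diff, so the import is a placeholder, and the remaining constructor fields are left at the defaults visible above.

from flaxdiff.models import Unet   # placeholder import path (assumption)

# output_channels is new in 0.1.5; earlier releases always emitted 3 channels.
unet = Unet(
    output_channels=4,                       # e.g. 4-channel latent prediction
    feature_depths=[64, 128, 256, 512],
    attention_configs=[{"heads": 8}] * 4,
)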