flaxdiff 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/common.py CHANGED
@@ -1,7 +1,250 @@
  import jax.numpy as jnp
+ import jax
  from flax import linen as nn
+ from typing import Dict, Callable, Sequence, Any, Union
+ import einops
 
  # Kernel initializer to use
  def kernel_init(scale, dtype=jnp.float32):
      scale = max(scale, 1e-10)
      return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
+
+
+ class WeightStandardizedConv(nn.Module):
+     """
+     apply weight standardization https://arxiv.org/abs/1903.10520
+     """
+     features: int
+     kernel_size: Sequence[int] = 3
+     strides: Union[None, int, Sequence[int]] = 1
+     padding: Any = 1
+     dtype: Any = jnp.float32
+     param_dtype: Any = jnp.float32
+
+     @nn.compact
+     def __call__(self, x):
+         """
+         Applies a weight standardized convolution to the inputs.
+
+         Args:
+             inputs: input data with dimensions (batch, spatial_dims..., features).
+
+         Returns:
+             The convolved data.
+         """
+         x = x.astype(self.dtype)
+
+         conv = nn.Conv(
+             features=self.features,
+             kernel_size=self.kernel_size,
+             strides = self.strides,
+             padding=self.padding,
+             dtype=self.dtype,
+             param_dtype = self.param_dtype,
+             parent=None)
+
+         kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
+         bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
+
+         # standardize kernel
+         kernel = self.param('kernel', kernel_init, x)
+         eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
+         # reduce over dim_out
+         redux = tuple(range(kernel.ndim - 1))
+         mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+         var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+         standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
+
+         bias = self.param('bias',bias_init, x)
+
+         return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
+
+ class PixelShuffle(nn.Module):
+     scale: int
+
+     @nn.compact
+     def __call__(self, x):
+         up = einops.rearrange(
+             x,
+             pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
+             h2=self.scale,
+             w2=self.scale,
+         )
+         return up
+
+ class TimeEmbedding(nn.Module):
+     features:int
+     nax_positions:int=10000
+
+     def setup(self):
+         half_dim = self.features // 2
+         emb = jnp.log(self.nax_positions) / (half_dim - 1)
+         emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
+         self.embeddings = emb
+
+     def __call__(self, x):
+         x = jax.lax.convert_element_type(x, jnp.float32)
+         emb = x[:, None] * self.embeddings[None, :]
+         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+         return emb
+
+ class FourierEmbedding(nn.Module):
+     features:int
+     scale:int = 16
+
+     def setup(self):
+         self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
+
+     def __call__(self, x):
+         x = jax.lax.convert_element_type(x, jnp.float32)
+         emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
+         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+         return emb
+
+ class TimeProjection(nn.Module):
+     features:int
+     activation:Callable=jax.nn.gelu
+
+     @nn.compact
+     def __call__(self, x):
+         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+         x = self.activation(x)
+         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+         x = self.activation(x)
+         return x
+
+ class SeparableConv(nn.Module):
+     features:int
+     kernel_size:tuple=(3, 3)
+     strides:tuple=(1, 1)
+     use_bias:bool=False
+     kernel_init:Callable=kernel_init(1.0)
+     padding:str="SAME"
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x):
+         in_features = x.shape[-1]
+         depthwise = nn.Conv(
+             features=in_features, kernel_size=self.kernel_size,
+             strides=self.strides, kernel_init=self.kernel_init,
+             feature_group_count=in_features, use_bias=self.use_bias,
+             padding=self.padding,
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         pointwise = nn.Conv(
+             features=self.features, kernel_size=(1, 1),
+             strides=(1, 1), kernel_init=self.kernel_init,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             precision=self.precision
+         )(depthwise)
+         return pointwise
+
+ class ConvLayer(nn.Module):
+     conv_type:str
+     features:int
+     kernel_size:tuple=(3, 3)
+     strides:tuple=(1, 1)
+     kernel_init:Callable=kernel_init(1.0)
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     def setup(self):
+         # conv_type can be "conv", "separable", "conv_transpose"
+         if self.conv_type == "conv":
+             self.conv = nn.Conv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "w_conv":
+             self.conv = WeightStandardizedConv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 padding="SAME",
+                 param_dtype=self.dtype,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "separable":
+             self.conv = SeparableConv(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+         elif self.conv_type == "conv_transpose":
+             self.conv = nn.ConvTranspose(
+                 features=self.features,
+                 kernel_size=self.kernel_size,
+                 strides=self.strides,
+                 kernel_init=self.kernel_init,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )
+
+     def __call__(self, x):
+         return self.conv(x)
+
+ class Upsample(nn.Module):
+     features:int
+     scale:int
+     activation:Callable=jax.nn.swish
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, residual=None):
+         out = x
+         # out = PixelShuffle(scale=self.scale)(out)
+         B, H, W, C = x.shape
+         out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
+         out = ConvLayer(
+             "conv",
+             features=self.features,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             dtype=self.dtype,
+             precision=self.precision
+         )(out)
+         if residual is not None:
+             out = jnp.concatenate([out, residual], axis=-1)
+         return out
+
+ class Downsample(nn.Module):
+     features:int
+     scale:int
+     activation:Callable=jax.nn.swish
+     dtype: Any = jnp.bfloat16
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, residual=None):
+         out = ConvLayer(
+             "conv",
+             features=self.features,
+             kernel_size=(3, 3),
+             strides=(2, 2),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         if residual is not None:
+             if residual.shape[1] > out.shape[1]:
+                 residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
+             out = jnp.concatenate([out, residual], axis=-1)
+         return out
+
+
+ def l2norm(t, axis=1, eps=1e-12):
+     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
+     out = t/denom
+     return (out)
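
The blocks added to flaxdiff/models/common.py above are ordinary Flax linen modules and can be exercised on their own. A minimal usage sketch, assuming the 0.1.5 module path shown in this file header; the batch size, spatial size, and timestep values are arbitrary assumptions and not part of the diff:

    # Illustrative only: exercises the modules added to flaxdiff/models/common.py above.
    import jax
    import jax.numpy as jnp
    from flaxdiff.models.common import TimeEmbedding, TimeProjection, Upsample, Downsample

    rng = jax.random.PRNGKey(0)
    x = jnp.ones((4, 32, 32, 64))               # (batch, height, width, features)
    t = jnp.array([1.0, 10.0, 100.0, 500.0])    # diffusion timesteps

    # Sinusoidal timestep embedding, then the two-layer projection MLP
    temb, _ = TimeEmbedding(features=128).init_with_output(rng, t)        # (4, 128)
    temb, _ = TimeProjection(features=512).init_with_output(rng, temb)    # (4, 512)

    # Nearest-neighbour resize + 3x3 conv, and strided-conv downsampling
    up, _ = Upsample(features=64, scale=2).init_with_output(rng, x)       # (4, 64, 64, 64)
    down, _ = Downsample(features=128, scale=2).init_with_output(rng, x)  # (4, 16, 16, 128)
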
@@ -3,248 +3,9 @@ import jax.numpy as jnp
  from flax import linen as nn
  from typing import Dict, Callable, Sequence, Any, Union
  import einops
- from .common import kernel_init
+ from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
  from .attention import TransformerBlock
 
- class WeightStandardizedConv(nn.Module):
-     """
-     apply weight standardization https://arxiv.org/abs/1903.10520
-     """
-     features: int
-     kernel_size: Sequence[int] = 3
-     strides: Union[None, int, Sequence[int]] = 1
-     padding: Any = 1
-     dtype: Any = jnp.float32
-     param_dtype: Any = jnp.float32
-
-     @nn.compact
-     def __call__(self, x):
-         """
-         Applies a weight standardized convolution to the inputs.
-
-         Args:
-             inputs: input data with dimensions (batch, spatial_dims..., features).
-
-         Returns:
-             The convolved data.
-         """
-         x = x.astype(self.dtype)
-
-         conv = nn.Conv(
-             features=self.features,
-             kernel_size=self.kernel_size,
-             strides = self.strides,
-             padding=self.padding,
-             dtype=self.dtype,
-             param_dtype = self.param_dtype,
-             parent=None)
-
-         kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
-         bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
-
-         # standardize kernel
-         kernel = self.param('kernel', kernel_init, x)
-         eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
-         # reduce over dim_out
-         redux = tuple(range(kernel.ndim - 1))
-         mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
-         var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
-         standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
-
-         bias = self.param('bias',bias_init, x)
-
-         return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
-
- class PixelShuffle(nn.Module):
-     scale: int
-
-     @nn.compact
-     def __call__(self, x):
-         up = einops.rearrange(
-             x,
-             pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
-             h2=self.scale,
-             w2=self.scale,
-         )
-         return up
-
- class TimeEmbedding(nn.Module):
-     features:int
-     nax_positions:int=10000
-
-     def setup(self):
-         half_dim = self.features // 2
-         emb = jnp.log(self.nax_positions) / (half_dim - 1)
-         emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
-         self.embeddings = emb
-
-     def __call__(self, x):
-         x = jax.lax.convert_element_type(x, jnp.float32)
-         emb = x[:, None] * self.embeddings[None, :]
-         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
-         return emb
-
- class FourierEmbedding(nn.Module):
-     features:int
-     scale:int = 16
-
-     def setup(self):
-         self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
-
-     def __call__(self, x):
-         x = jax.lax.convert_element_type(x, jnp.float32)
-         emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
-         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
-         return emb
-
- class TimeProjection(nn.Module):
-     features:int
-     activation:Callable=jax.nn.gelu
-
-     @nn.compact
-     def __call__(self, x):
-         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
-         x = self.activation(x)
-         x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
-         x = self.activation(x)
-         return x
-
- class SeparableConv(nn.Module):
-     features:int
-     kernel_size:tuple=(3, 3)
-     strides:tuple=(1, 1)
-     use_bias:bool=False
-     kernel_init:Callable=kernel_init(1.0)
-     padding:str="SAME"
-     dtype: Any = jnp.bfloat16
-     precision: Any = jax.lax.Precision.HIGH
-
-     @nn.compact
-     def __call__(self, x):
-         in_features = x.shape[-1]
-         depthwise = nn.Conv(
-             features=in_features, kernel_size=self.kernel_size,
-             strides=self.strides, kernel_init=self.kernel_init,
-             feature_group_count=in_features, use_bias=self.use_bias,
-             padding=self.padding,
-             dtype=self.dtype,
-             precision=self.precision
-         )(x)
-         pointwise = nn.Conv(
-             features=self.features, kernel_size=(1, 1),
-             strides=(1, 1), kernel_init=self.kernel_init,
-             use_bias=self.use_bias,
-             dtype=self.dtype,
-             precision=self.precision
-         )(depthwise)
-         return pointwise
-
- class ConvLayer(nn.Module):
-     conv_type:str
-     features:int
-     kernel_size:tuple=(3, 3)
-     strides:tuple=(1, 1)
-     kernel_init:Callable=kernel_init(1.0)
-     dtype: Any = jnp.bfloat16
-     precision: Any = jax.lax.Precision.HIGH
-
-     def setup(self):
-         # conv_type can be "conv", "separable", "conv_transpose"
-         if self.conv_type == "conv":
-             self.conv = nn.Conv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "w_conv":
-             self.conv = WeightStandardizedConv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 padding="SAME",
-                 param_dtype=self.dtype,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "separable":
-             self.conv = SeparableConv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "conv_transpose":
-             self.conv = nn.ConvTranspose(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-
-     def __call__(self, x):
-         return self.conv(x)
-
- class Upsample(nn.Module):
-     features:int
-     scale:int
-     activation:Callable=jax.nn.swish
-     dtype: Any = jnp.bfloat16
-     precision: Any = jax.lax.Precision.HIGH
-
-     @nn.compact
-     def __call__(self, x, residual=None):
-         out = x
-         # out = PixelShuffle(scale=self.scale)(out)
-         B, H, W, C = x.shape
-         out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
-         out = ConvLayer(
-             "conv",
-             features=self.features,
-             kernel_size=(3, 3),
-             strides=(1, 1),
-             dtype=self.dtype,
-             precision=self.precision
-         )(out)
-         if residual is not None:
-             out = jnp.concatenate([out, residual], axis=-1)
-         return out
-
- class Downsample(nn.Module):
-     features:int
-     scale:int
-     activation:Callable=jax.nn.swish
-     dtype: Any = jnp.bfloat16
-     precision: Any = jax.lax.Precision.HIGH
-
-     @nn.compact
-     def __call__(self, x, residual=None):
-         out = ConvLayer(
-             "conv",
-             features=self.features,
-             kernel_size=(3, 3),
-             strides=(2, 2),
-             dtype=self.dtype,
-             precision=self.precision
-         )(x)
-         if residual is not None:
-             if residual.shape[1] > out.shape[1]:
-                 residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
-             out = jnp.concatenate([out, residual], axis=-1)
-         return out
-
-
- def l2norm(t, axis=1, eps=1e-12):
-     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
-     out = t/denom
-     return (out)
-
  class ResidualBlock(nn.Module):
      conv_type:str
      features:int
@@ -318,6 +79,7 @@ class ResidualBlock(nn.Module):
          return out
 
  class Unet(nn.Module):
+     output_channels:int=3
      emb_features:int=64*4,
      feature_depths:list=[64, 128, 256, 512],
      attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
@@ -373,12 +135,13 @@
                  )(x, temb)
                  if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
                      x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
-                         dim_head=dim_in // attention_config['heads'],
-                         use_flash_attention=attention_config.get("flash_attention", True),
-                         use_projection=attention_config.get("use_projection", False),
-                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
-                         precision=attention_config.get("precision", self.precision),
-                         name=f"down_{i}_attention_{j}")(x, textcontext)
+                         dim_head=dim_in // attention_config['heads'],
+                         use_flash_attention=attention_config.get("flash_attention", True),
+                         use_projection=attention_config.get("use_projection", False),
+                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                         precision=attention_config.get("precision", self.precision),
+                         only_pure_attention=True,
+                         name=f"down_{i}_attention_{j}")(x, textcontext)
                  # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                  downs.append(x)
              if i != len(feature_depths) - 1:
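
For reference, the keys this block reads from each attention_configs entry are all visible in the hunk above, and 0.1.5 now passes only_pure_attention=True directly instead of taking it from the config. An illustrative entry follows; it assumes the usual import jax.numpy as jnp, and every key other than "heads" is optional because it falls back to the .get() defaults shown above:

    # Illustrative attention_configs entry (not part of the diff).
    attention_config = {
        "heads": 8,                      # required: used for heads and dim_head
        "dtype": jnp.float32,            # dtype handed to TransformerBlock
        "flash_attention": True,         # becomes use_flash_attention
        "use_projection": False,
        "use_self_and_cross": True,
        # "precision": jax.lax.Precision.HIGH,  # falls back to self.precision if omitted
    }
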
@@ -416,6 +179,7 @@ class Unet(nn.Module):
                          use_projection=middle_attention.get("use_projection", False),
                          use_self_and_cross=False,
                          precision=attention_config.get("precision", self.precision),
+                         only_pure_attention=True,
                          name=f"middle_attention_{j}")(x, textcontext)
              x = ResidualBlock(
                  middle_conv_type,
@@ -452,12 +216,13 @@ class Unet(nn.Module):
                  )(x, temb)
                  if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
                      x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
-                         dim_head=dim_out // attention_config['heads'],
-                         use_flash_attention=attention_config.get("flash_attention", True),
-                         use_projection=attention_config.get("use_projection", False),
-                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                         dim_head=dim_out // attention_config['heads'],
+                         use_flash_attention=attention_config.get("flash_attention", True),
+                         use_projection=attention_config.get("use_projection", False),
+                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                          precision=attention_config.get("precision", self.precision),
-                         name=f"up_{i}_attention_{j}")(x, textcontext)
+                         only_pure_attention=True,
+                         name=f"up_{i}_attention_{j}")(x, textcontext)
                  # print("Upscaling ", i, x.shape)
                  if i != len(feature_depths) - 1:
                      x = Upsample(
@@ -500,7 +265,7 @@ class Unet(nn.Module):
 
          noise_out = ConvLayer(
              conv_type,
-             features=3,
+             features=self.output_channels,
              kernel_size=(3, 3),
              strides=(1, 1),
              # activation=jax.nn.mish
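
Taken together, the Unet changes in 0.1.5 force every TransformerBlock to only_pure_attention=True and make the final noise_out convolution emit output_channels feature maps instead of a hard-coded 3. A hedged construction sketch follows; the module that defines Unet is not named in this diff, so the import is left as a placeholder, and only the constructor fields visible in the hunks above are set:

    # Illustrative only; field names come from the hunks above, values are assumptions.
    # from flaxdiff.models.<unet module, not shown in this diff> import Unet

    unet = Unet(
        output_channels=4,           # new in 0.1.5; 0.1.4 hard-coded features=3 in noise_out
        emb_features=256,
        feature_depths=[64, 128, 256, 512],
        attention_configs=[{"heads": 8}] * 4,
    )
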