flaxdiff-0.1.3-py3-none-any.whl → flaxdiff-0.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +132 -155
- flaxdiff/models/autoencoder/__init__.py +0 -0
- flaxdiff/models/autoencoder/autoencoder.py +14 -0
- flaxdiff/models/autoencoder/diffusers.py +88 -0
- flaxdiff/models/common.py +243 -0
- flaxdiff/models/simple_unet.py +17 -256
- flaxdiff/trainer/__init__.py +28 -45
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.3.dist-info → flaxdiff-0.1.5.dist-info}/METADATA +10 -2
- {flaxdiff-0.1.3.dist-info → flaxdiff-0.1.5.dist-info}/RECORD +12 -9
- {flaxdiff-0.1.3.dist-info → flaxdiff-0.1.5.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.3.dist-info → flaxdiff-0.1.5.dist-info}/top_level.txt +0 -0
flaxdiff/models/common.py
CHANGED
@@ -1,7 +1,250 @@
 import jax.numpy as jnp
+import jax
 from flax import linen as nn
+from typing import Dict, Callable, Sequence, Any, Union
+import einops
 
 # Kernel initializer to use
 def kernel_init(scale, dtype=jnp.float32):
     scale = max(scale, 1e-10)
     return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
+
+
+class WeightStandardizedConv(nn.Module):
+    """
+    apply weight standardization https://arxiv.org/abs/1903.10520
+    """
+    features: int
+    kernel_size: Sequence[int] = 3
+    strides: Union[None, int, Sequence[int]] = 1
+    padding: Any = 1
+    dtype: Any = jnp.float32
+    param_dtype: Any = jnp.float32
+
+    @nn.compact
+    def __call__(self, x):
+        """
+        Applies a weight standardized convolution to the inputs.
+
+        Args:
+          inputs: input data with dimensions (batch, spatial_dims..., features).
+
+        Returns:
+          The convolved data.
+        """
+        x = x.astype(self.dtype)
+
+        conv = nn.Conv(
+            features=self.features,
+            kernel_size=self.kernel_size,
+            strides = self.strides,
+            padding=self.padding,
+            dtype=self.dtype,
+            param_dtype = self.param_dtype,
+            parent=None)
+
+        kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
+        bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
+
+        # standardize kernel
+        kernel = self.param('kernel', kernel_init, x)
+        eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
+        # reduce over dim_out
+        redux = tuple(range(kernel.ndim - 1))
+        mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+        var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
+        standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
+
+        bias = self.param('bias',bias_init, x)
+
+        return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
+
+class PixelShuffle(nn.Module):
+    scale: int
+
+    @nn.compact
+    def __call__(self, x):
+        up = einops.rearrange(
+            x,
+            pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
+            h2=self.scale,
+            w2=self.scale,
+        )
+        return up
+
+class TimeEmbedding(nn.Module):
+    features:int
+    nax_positions:int=10000
+
+    def setup(self):
+        half_dim = self.features // 2
+        emb = jnp.log(self.nax_positions) / (half_dim - 1)
+        emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
+        self.embeddings = emb
+
+    def __call__(self, x):
+        x = jax.lax.convert_element_type(x, jnp.float32)
+        emb = x[:, None] * self.embeddings[None, :]
+        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+        return emb
+
+class FourierEmbedding(nn.Module):
+    features:int
+    scale:int = 16
+
+    def setup(self):
+        self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
+
+    def __call__(self, x):
+        x = jax.lax.convert_element_type(x, jnp.float32)
+        emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
+        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
+        return emb
+
+class TimeProjection(nn.Module):
+    features:int
+    activation:Callable=jax.nn.gelu
+
+    @nn.compact
+    def __call__(self, x):
+        x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+        x = self.activation(x)
+        x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
+        x = self.activation(x)
+        return x
+
+class SeparableConv(nn.Module):
+    features:int
+    kernel_size:tuple=(3, 3)
+    strides:tuple=(1, 1)
+    use_bias:bool=False
+    kernel_init:Callable=kernel_init(1.0)
+    padding:str="SAME"
+    dtype: Any = jnp.bfloat16
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x):
+        in_features = x.shape[-1]
+        depthwise = nn.Conv(
+            features=in_features, kernel_size=self.kernel_size,
+            strides=self.strides, kernel_init=self.kernel_init,
+            feature_group_count=in_features, use_bias=self.use_bias,
+            padding=self.padding,
+            dtype=self.dtype,
+            precision=self.precision
+        )(x)
+        pointwise = nn.Conv(
+            features=self.features, kernel_size=(1, 1),
+            strides=(1, 1), kernel_init=self.kernel_init,
+            use_bias=self.use_bias,
+            dtype=self.dtype,
+            precision=self.precision
+        )(depthwise)
+        return pointwise
+
+class ConvLayer(nn.Module):
+    conv_type:str
+    features:int
+    kernel_size:tuple=(3, 3)
+    strides:tuple=(1, 1)
+    kernel_init:Callable=kernel_init(1.0)
+    dtype: Any = jnp.bfloat16
+    precision: Any = jax.lax.Precision.HIGH
+
+    def setup(self):
+        # conv_type can be "conv", "separable", "conv_transpose"
+        if self.conv_type == "conv":
+            self.conv = nn.Conv(
+                features=self.features,
+                kernel_size=self.kernel_size,
+                strides=self.strides,
+                kernel_init=self.kernel_init,
+                dtype=self.dtype,
+                precision=self.precision
+            )
+        elif self.conv_type == "w_conv":
+            self.conv = WeightStandardizedConv(
+                features=self.features,
+                kernel_size=self.kernel_size,
+                strides=self.strides,
+                padding="SAME",
+                param_dtype=self.dtype,
+                dtype=self.dtype,
+                precision=self.precision
+            )
+        elif self.conv_type == "separable":
+            self.conv = SeparableConv(
+                features=self.features,
+                kernel_size=self.kernel_size,
+                strides=self.strides,
+                kernel_init=self.kernel_init,
+                dtype=self.dtype,
+                precision=self.precision
+            )
+        elif self.conv_type == "conv_transpose":
+            self.conv = nn.ConvTranspose(
+                features=self.features,
+                kernel_size=self.kernel_size,
+                strides=self.strides,
+                kernel_init=self.kernel_init,
+                dtype=self.dtype,
+                precision=self.precision
+            )
+
+    def __call__(self, x):
+        return self.conv(x)
+
+class Upsample(nn.Module):
+    features:int
+    scale:int
+    activation:Callable=jax.nn.swish
+    dtype: Any = jnp.bfloat16
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x, residual=None):
+        out = x
+        # out = PixelShuffle(scale=self.scale)(out)
+        B, H, W, C = x.shape
+        out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
+        out = ConvLayer(
+            "conv",
+            features=self.features,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            dtype=self.dtype,
+            precision=self.precision
+        )(out)
+        if residual is not None:
+            out = jnp.concatenate([out, residual], axis=-1)
+        return out
+
+class Downsample(nn.Module):
+    features:int
+    scale:int
+    activation:Callable=jax.nn.swish
+    dtype: Any = jnp.bfloat16
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x, residual=None):
+        out = ConvLayer(
+            "conv",
+            features=self.features,
+            kernel_size=(3, 3),
+            strides=(2, 2),
+            dtype=self.dtype,
+            precision=self.precision
+        )(x)
+        if residual is not None:
+            if residual.shape[1] > out.shape[1]:
+                residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
+            out = jnp.concatenate([out, residual], axis=-1)
+        return out
+
+
+def l2norm(t, axis=1, eps=1e-12):
+    denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
+    out = t/denom
+    return (out)
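For orientation, here is a minimal usage sketch (not part of the diff) exercising the helpers that 0.1.5 consolidates into flaxdiff.models.common. It assumes flaxdiff 0.1.5, jax and flax are installed; the shapes, PRNG key and feature sizes are illustrative choices, not values used by the package.

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.common import TimeEmbedding, TimeProjection, ConvLayer, Upsample

rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 32, 32, 64))   # (batch, height, width, channels)
t = jnp.array([0.1, 0.7])       # one diffusion timestep per batch element

# Sinusoidal timestep embedding (TimeEmbedding has no learned parameters).
temb, _ = TimeEmbedding(features=128).init_with_output(rng, t)   # temb: (2, 128)

# Small MLP that projects the embedding to the conditioning width.
proj = TimeProjection(features=256)
temb = proj.apply(proj.init(rng, temb), temb)                    # (2, 256)

# ConvLayer dispatches on its first field: "conv", "w_conv", "separable" or "conv_transpose".
conv = ConvLayer("conv", features=128, kernel_size=(3, 3), strides=(1, 1))
y = conv.apply(conv.init(rng, x), x)                             # (2, 32, 32, 128)

# Upsample: nearest-neighbour resize by `scale`, then a 3x3 convolution.
up = Upsample(features=64, scale=2)
z = up.apply(up.init(rng, y), y)                                 # (2, 64, 64, 64)
```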
flaxdiff/models/simple_unet.py
CHANGED
@@ -3,248 +3,9 @@ import jax.numpy as jnp
 from flax import linen as nn
 from typing import Dict, Callable, Sequence, Any, Union
 import einops
-from .common import kernel_init
+from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
 from .attention import TransformerBlock
 
-class WeightStandardizedConv(nn.Module):
-    """
-    apply weight standardization https://arxiv.org/abs/1903.10520
-    """
-    features: int
-    kernel_size: Sequence[int] = 3
-    strides: Union[None, int, Sequence[int]] = 1
-    padding: Any = 1
-    dtype: Any = jnp.float32
-    param_dtype: Any = jnp.float32
-
-    @nn.compact
-    def __call__(self, x):
-        """
-        Applies a weight standardized convolution to the inputs.
-
-        Args:
-          inputs: input data with dimensions (batch, spatial_dims..., features).
-
-        Returns:
-          The convolved data.
-        """
-        x = x.astype(self.dtype)
-
-        conv = nn.Conv(
-            features=self.features,
-            kernel_size=self.kernel_size,
-            strides = self.strides,
-            padding=self.padding,
-            dtype=self.dtype,
-            param_dtype = self.param_dtype,
-            parent=None)
-
-        kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
-        bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
-
-        # standardize kernel
-        kernel = self.param('kernel', kernel_init, x)
-        eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
-        # reduce over dim_out
-        redux = tuple(range(kernel.ndim - 1))
-        mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
-        var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
-        standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
-
-        bias = self.param('bias',bias_init, x)
-
-        return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
-
-class PixelShuffle(nn.Module):
-    scale: int
-
-    @nn.compact
-    def __call__(self, x):
-        up = einops.rearrange(
-            x,
-            pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
-            h2=self.scale,
-            w2=self.scale,
-        )
-        return up
-
-class TimeEmbedding(nn.Module):
-    features:int
-    nax_positions:int=10000
-
-    def setup(self):
-        half_dim = self.features // 2
-        emb = jnp.log(self.nax_positions) / (half_dim - 1)
-        emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
-        self.embeddings = emb
-
-    def __call__(self, x):
-        x = jax.lax.convert_element_type(x, jnp.float32)
-        emb = x[:, None] * self.embeddings[None, :]
-        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
-        return emb
-
-class FourierEmbedding(nn.Module):
-    features:int
-    scale:int = 16
-
-    def setup(self):
-        self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
-
-    def __call__(self, x):
-        x = jax.lax.convert_element_type(x, jnp.float32)
-        emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
-        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
-        return emb
-
-class TimeProjection(nn.Module):
-    features:int
-    activation:Callable=jax.nn.gelu
-
-    @nn.compact
-    def __call__(self, x):
-        x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
-        x = self.activation(x)
-        x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
-        x = self.activation(x)
-        return x
-
-class SeparableConv(nn.Module):
-    features:int
-    kernel_size:tuple=(3, 3)
-    strides:tuple=(1, 1)
-    use_bias:bool=False
-    kernel_init:Callable=kernel_init(1.0)
-    padding:str="SAME"
-    dtype: Any = jnp.bfloat16
-    precision: Any = jax.lax.Precision.HIGH
-
-    @nn.compact
-    def __call__(self, x):
-        in_features = x.shape[-1]
-        depthwise = nn.Conv(
-            features=in_features, kernel_size=self.kernel_size,
-            strides=self.strides, kernel_init=self.kernel_init,
-            feature_group_count=in_features, use_bias=self.use_bias,
-            padding=self.padding,
-            dtype=self.dtype,
-            precision=self.precision
-        )(x)
-        pointwise = nn.Conv(
-            features=self.features, kernel_size=(1, 1),
-            strides=(1, 1), kernel_init=self.kernel_init,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            precision=self.precision
-        )(depthwise)
-        return pointwise
-
-class ConvLayer(nn.Module):
-    conv_type:str
-    features:int
-    kernel_size:tuple=(3, 3)
-    strides:tuple=(1, 1)
-    kernel_init:Callable=kernel_init(1.0)
-    dtype: Any = jnp.bfloat16
-    precision: Any = jax.lax.Precision.HIGH
-
-    def setup(self):
-        # conv_type can be "conv", "separable", "conv_transpose"
-        if self.conv_type == "conv":
-            self.conv = nn.Conv(
-                features=self.features,
-                kernel_size=self.kernel_size,
-                strides=self.strides,
-                kernel_init=self.kernel_init,
-                dtype=self.dtype,
-                precision=self.precision
-            )
-        elif self.conv_type == "w_conv":
-            self.conv = WeightStandardizedConv(
-                features=self.features,
-                kernel_size=self.kernel_size,
-                strides=self.strides,
-                padding="SAME",
-                param_dtype=self.dtype,
-                dtype=self.dtype,
-                precision=self.precision
-            )
-        elif self.conv_type == "separable":
-            self.conv = SeparableConv(
-                features=self.features,
-                kernel_size=self.kernel_size,
-                strides=self.strides,
-                kernel_init=self.kernel_init,
-                dtype=self.dtype,
-                precision=self.precision
-            )
-        elif self.conv_type == "conv_transpose":
-            self.conv = nn.ConvTranspose(
-                features=self.features,
-                kernel_size=self.kernel_size,
-                strides=self.strides,
-                kernel_init=self.kernel_init,
-                dtype=self.dtype,
-                precision=self.precision
-            )
-
-    def __call__(self, x):
-        return self.conv(x)
-
-class Upsample(nn.Module):
-    features:int
-    scale:int
-    activation:Callable=jax.nn.swish
-    dtype: Any = jnp.bfloat16
-    precision: Any = jax.lax.Precision.HIGH
-
-    @nn.compact
-    def __call__(self, x, residual=None):
-        out = x
-        # out = PixelShuffle(scale=self.scale)(out)
-        B, H, W, C = x.shape
-        out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
-        out = ConvLayer(
-            "conv",
-            features=self.features,
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            dtype=self.dtype,
-            precision=self.precision
-        )(out)
-        if residual is not None:
-            out = jnp.concatenate([out, residual], axis=-1)
-        return out
-
-class Downsample(nn.Module):
-    features:int
-    scale:int
-    activation:Callable=jax.nn.swish
-    dtype: Any = jnp.bfloat16
-    precision: Any = jax.lax.Precision.HIGH
-
-    @nn.compact
-    def __call__(self, x, residual=None):
-        out = ConvLayer(
-            "conv",
-            features=self.features,
-            kernel_size=(3, 3),
-            strides=(2, 2),
-            dtype=self.dtype,
-            precision=self.precision
-        )(x)
-        if residual is not None:
-            if residual.shape[1] > out.shape[1]:
-                residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
-            out = jnp.concatenate([out, residual], axis=-1)
-        return out
-
-
-def l2norm(t, axis=1, eps=1e-12):
-    denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
-    out = t/denom
-    return (out)
-
 class ResidualBlock(nn.Module):
     conv_type:str
     features:int
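The hunk above is the relocation itself: every block deleted here is the one added to common.py earlier in this diff, and the only functional change is the import line. As a hedged sketch (module paths taken from the file list in this diff), downstream code that imported these blocks from simple_unet would switch to the shared module:

```python
# flaxdiff 0.1.3: the building blocks were defined in (and importable from) simple_unet.
# from flaxdiff.models.simple_unet import ConvLayer, Downsample, Upsample

# flaxdiff 0.1.5: the same blocks live in the shared common module.
from flaxdiff.models.common import (
    kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection,
)
```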
@@ -318,6 +79,7 @@ class ResidualBlock(nn.Module):
         return out
 
 class Unet(nn.Module):
+    output_channels:int=3
     emb_features:int=64*4,
     feature_depths:list=[64, 128, 256, 512],
     attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
@@ -342,8 +104,6 @@ class Unet(nn.Module):
 
         conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
         # middle_conv_type = "separable"
-
-        print(f"input shape: {x.shape}")
 
         x = ConvLayer(
             conv_type,
@@ -355,8 +115,6 @@ class Unet(nn.Module):
             precision=self.precision
         )(x)
         downs = [x]
-
-        print(f"x shape: {x.shape}")
 
         # Downscaling blocks
         for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
@@ -377,12 +135,13 @@ class Unet(nn.Module):
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
-
-
-
-
-
-
+                                         dim_head=dim_in // attention_config['heads'],
+                                         use_flash_attention=attention_config.get("flash_attention", True),
+                                         use_projection=attention_config.get("use_projection", False),
+                                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                                         precision=attention_config.get("precision", self.precision),
+                                         only_pure_attention=True,
+                                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                 downs.append(x)
             if i != len(feature_depths) - 1:
@@ -420,6 +179,7 @@ class Unet(nn.Module):
                                  use_projection=middle_attention.get("use_projection", False),
                                  use_self_and_cross=False,
                                  precision=attention_config.get("precision", self.precision),
+                                 only_pure_attention=True,
                                  name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
@@ -456,12 +216,13 @@ class Unet(nn.Module):
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
-
-
-
-
+                                         dim_head=dim_out // attention_config['heads'],
+                                         use_flash_attention=attention_config.get("flash_attention", True),
+                                         use_projection=attention_config.get("use_projection", False),
+                                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                          precision=attention_config.get("precision", self.precision),
-
+                                         only_pure_attention=True,
+                                         name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
                 x = Upsample(
@@ -504,7 +265,7 @@ class Unet(nn.Module):
 
         noise_out = ConvLayer(
             conv_type,
-            features=
+            features=self.output_channels,
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
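The Unet hunks above add an output_channels field (default 3) that now feeds the final noise_out convolution, and they thread a fuller set of keyword arguments from each attention_configs entry into TransformerBlock, always with only_pure_attention=True. The sketch below is not from the package; it lists only the keys that the .get() calls in this diff actually read, with the fallbacks they use. Anything else about TransformerBlock is defined in flaxdiff/models/attention.py, not here.

```python
import jax.numpy as jnp

# One entry of Unet.attention_configs, as consumed by the 0.1.5 down/up blocks.
attention_config = {
    "heads": 8,                   # required: read as attention_config['heads']
    "dtype": jnp.float32,         # .get('dtype', jnp.float32)
    "flash_attention": True,      # .get("flash_attention", True)
    "use_projection": False,      # .get("use_projection", False)
    "use_self_and_cross": True,   # .get("use_self_and_cross", True)
    # "precision": jax.lax.Precision.HIGH,   # .get("precision", self.precision)
}
```

Note that dim_head is not a config key: the diff computes it as the feature width divided by the number of heads at each level.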