flaxdiff 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +132 -155
- flaxdiff/models/autoencoder/__init__.py +0 -0
- flaxdiff/models/autoencoder/autoencoder.py +14 -0
- flaxdiff/models/autoencoder/diffusers.py +88 -0
- flaxdiff/models/common.py +243 -0
- flaxdiff/models/simple_unet.py +17 -252
- flaxdiff/trainer/__init__.py +28 -45
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/METADATA +10 -2
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/RECORD +12 -9
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.5.dist-info}/top_level.txt +0 -0
flaxdiff/models/common.py
CHANGED
@@ -1,7 +1,250 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
|
+
import jax
|
2
3
|
from flax import linen as nn
|
4
|
+
from typing import Dict, Callable, Sequence, Any, Union
|
5
|
+
import einops
|
3
6
|
|
4
7
|
# Kernel initializer to use
|
5
8
|
def kernel_init(scale, dtype=jnp.float32):
|
6
9
|
scale = max(scale, 1e-10)
|
7
10
|
return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
|
11
|
+
|
12
|
+
|
13
|
+
class WeightStandardizedConv(nn.Module):
|
14
|
+
"""
|
15
|
+
apply weight standardization https://arxiv.org/abs/1903.10520
|
16
|
+
"""
|
17
|
+
features: int
|
18
|
+
kernel_size: Sequence[int] = 3
|
19
|
+
strides: Union[None, int, Sequence[int]] = 1
|
20
|
+
padding: Any = 1
|
21
|
+
dtype: Any = jnp.float32
|
22
|
+
param_dtype: Any = jnp.float32
|
23
|
+
|
24
|
+
@nn.compact
|
25
|
+
def __call__(self, x):
|
26
|
+
"""
|
27
|
+
Applies a weight standardized convolution to the inputs.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
inputs: input data with dimensions (batch, spatial_dims..., features).
|
31
|
+
|
32
|
+
Returns:
|
33
|
+
The convolved data.
|
34
|
+
"""
|
35
|
+
x = x.astype(self.dtype)
|
36
|
+
|
37
|
+
conv = nn.Conv(
|
38
|
+
features=self.features,
|
39
|
+
kernel_size=self.kernel_size,
|
40
|
+
strides = self.strides,
|
41
|
+
padding=self.padding,
|
42
|
+
dtype=self.dtype,
|
43
|
+
param_dtype = self.param_dtype,
|
44
|
+
parent=None)
|
45
|
+
|
46
|
+
kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
|
47
|
+
bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
|
48
|
+
|
49
|
+
# standardize kernel
|
50
|
+
kernel = self.param('kernel', kernel_init, x)
|
51
|
+
eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
|
52
|
+
# reduce over dim_out
|
53
|
+
redux = tuple(range(kernel.ndim - 1))
|
54
|
+
mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
|
55
|
+
var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
|
56
|
+
standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
|
57
|
+
|
58
|
+
bias = self.param('bias',bias_init, x)
|
59
|
+
|
60
|
+
return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
|
61
|
+
|
62
|
+
class PixelShuffle(nn.Module):
|
63
|
+
scale: int
|
64
|
+
|
65
|
+
@nn.compact
|
66
|
+
def __call__(self, x):
|
67
|
+
up = einops.rearrange(
|
68
|
+
x,
|
69
|
+
pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
|
70
|
+
h2=self.scale,
|
71
|
+
w2=self.scale,
|
72
|
+
)
|
73
|
+
return up
|
74
|
+
|
75
|
+
class TimeEmbedding(nn.Module):
|
76
|
+
features:int
|
77
|
+
nax_positions:int=10000
|
78
|
+
|
79
|
+
def setup(self):
|
80
|
+
half_dim = self.features // 2
|
81
|
+
emb = jnp.log(self.nax_positions) / (half_dim - 1)
|
82
|
+
emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
|
83
|
+
self.embeddings = emb
|
84
|
+
|
85
|
+
def __call__(self, x):
|
86
|
+
x = jax.lax.convert_element_type(x, jnp.float32)
|
87
|
+
emb = x[:, None] * self.embeddings[None, :]
|
88
|
+
emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
|
89
|
+
return emb
|
90
|
+
|
91
|
+
class FourierEmbedding(nn.Module):
|
92
|
+
features:int
|
93
|
+
scale:int = 16
|
94
|
+
|
95
|
+
def setup(self):
|
96
|
+
self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
|
97
|
+
|
98
|
+
def __call__(self, x):
|
99
|
+
x = jax.lax.convert_element_type(x, jnp.float32)
|
100
|
+
emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
|
101
|
+
emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
|
102
|
+
return emb
|
103
|
+
|
104
|
+
class TimeProjection(nn.Module):
|
105
|
+
features:int
|
106
|
+
activation:Callable=jax.nn.gelu
|
107
|
+
|
108
|
+
@nn.compact
|
109
|
+
def __call__(self, x):
|
110
|
+
x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
|
111
|
+
x = self.activation(x)
|
112
|
+
x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
|
113
|
+
x = self.activation(x)
|
114
|
+
return x
|
115
|
+
|
116
|
+
class SeparableConv(nn.Module):
|
117
|
+
features:int
|
118
|
+
kernel_size:tuple=(3, 3)
|
119
|
+
strides:tuple=(1, 1)
|
120
|
+
use_bias:bool=False
|
121
|
+
kernel_init:Callable=kernel_init(1.0)
|
122
|
+
padding:str="SAME"
|
123
|
+
dtype: Any = jnp.bfloat16
|
124
|
+
precision: Any = jax.lax.Precision.HIGH
|
125
|
+
|
126
|
+
@nn.compact
|
127
|
+
def __call__(self, x):
|
128
|
+
in_features = x.shape[-1]
|
129
|
+
depthwise = nn.Conv(
|
130
|
+
features=in_features, kernel_size=self.kernel_size,
|
131
|
+
strides=self.strides, kernel_init=self.kernel_init,
|
132
|
+
feature_group_count=in_features, use_bias=self.use_bias,
|
133
|
+
padding=self.padding,
|
134
|
+
dtype=self.dtype,
|
135
|
+
precision=self.precision
|
136
|
+
)(x)
|
137
|
+
pointwise = nn.Conv(
|
138
|
+
features=self.features, kernel_size=(1, 1),
|
139
|
+
strides=(1, 1), kernel_init=self.kernel_init,
|
140
|
+
use_bias=self.use_bias,
|
141
|
+
dtype=self.dtype,
|
142
|
+
precision=self.precision
|
143
|
+
)(depthwise)
|
144
|
+
return pointwise
|
145
|
+
|
146
|
+
class ConvLayer(nn.Module):
|
147
|
+
conv_type:str
|
148
|
+
features:int
|
149
|
+
kernel_size:tuple=(3, 3)
|
150
|
+
strides:tuple=(1, 1)
|
151
|
+
kernel_init:Callable=kernel_init(1.0)
|
152
|
+
dtype: Any = jnp.bfloat16
|
153
|
+
precision: Any = jax.lax.Precision.HIGH
|
154
|
+
|
155
|
+
def setup(self):
|
156
|
+
# conv_type can be "conv", "separable", "conv_transpose"
|
157
|
+
if self.conv_type == "conv":
|
158
|
+
self.conv = nn.Conv(
|
159
|
+
features=self.features,
|
160
|
+
kernel_size=self.kernel_size,
|
161
|
+
strides=self.strides,
|
162
|
+
kernel_init=self.kernel_init,
|
163
|
+
dtype=self.dtype,
|
164
|
+
precision=self.precision
|
165
|
+
)
|
166
|
+
elif self.conv_type == "w_conv":
|
167
|
+
self.conv = WeightStandardizedConv(
|
168
|
+
features=self.features,
|
169
|
+
kernel_size=self.kernel_size,
|
170
|
+
strides=self.strides,
|
171
|
+
padding="SAME",
|
172
|
+
param_dtype=self.dtype,
|
173
|
+
dtype=self.dtype,
|
174
|
+
precision=self.precision
|
175
|
+
)
|
176
|
+
elif self.conv_type == "separable":
|
177
|
+
self.conv = SeparableConv(
|
178
|
+
features=self.features,
|
179
|
+
kernel_size=self.kernel_size,
|
180
|
+
strides=self.strides,
|
181
|
+
kernel_init=self.kernel_init,
|
182
|
+
dtype=self.dtype,
|
183
|
+
precision=self.precision
|
184
|
+
)
|
185
|
+
elif self.conv_type == "conv_transpose":
|
186
|
+
self.conv = nn.ConvTranspose(
|
187
|
+
features=self.features,
|
188
|
+
kernel_size=self.kernel_size,
|
189
|
+
strides=self.strides,
|
190
|
+
kernel_init=self.kernel_init,
|
191
|
+
dtype=self.dtype,
|
192
|
+
precision=self.precision
|
193
|
+
)
|
194
|
+
|
195
|
+
def __call__(self, x):
|
196
|
+
return self.conv(x)
|
197
|
+
|
198
|
+
class Upsample(nn.Module):
|
199
|
+
features:int
|
200
|
+
scale:int
|
201
|
+
activation:Callable=jax.nn.swish
|
202
|
+
dtype: Any = jnp.bfloat16
|
203
|
+
precision: Any = jax.lax.Precision.HIGH
|
204
|
+
|
205
|
+
@nn.compact
|
206
|
+
def __call__(self, x, residual=None):
|
207
|
+
out = x
|
208
|
+
# out = PixelShuffle(scale=self.scale)(out)
|
209
|
+
B, H, W, C = x.shape
|
210
|
+
out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
|
211
|
+
out = ConvLayer(
|
212
|
+
"conv",
|
213
|
+
features=self.features,
|
214
|
+
kernel_size=(3, 3),
|
215
|
+
strides=(1, 1),
|
216
|
+
dtype=self.dtype,
|
217
|
+
precision=self.precision
|
218
|
+
)(out)
|
219
|
+
if residual is not None:
|
220
|
+
out = jnp.concatenate([out, residual], axis=-1)
|
221
|
+
return out
|
222
|
+
|
223
|
+
class Downsample(nn.Module):
|
224
|
+
features:int
|
225
|
+
scale:int
|
226
|
+
activation:Callable=jax.nn.swish
|
227
|
+
dtype: Any = jnp.bfloat16
|
228
|
+
precision: Any = jax.lax.Precision.HIGH
|
229
|
+
|
230
|
+
@nn.compact
|
231
|
+
def __call__(self, x, residual=None):
|
232
|
+
out = ConvLayer(
|
233
|
+
"conv",
|
234
|
+
features=self.features,
|
235
|
+
kernel_size=(3, 3),
|
236
|
+
strides=(2, 2),
|
237
|
+
dtype=self.dtype,
|
238
|
+
precision=self.precision
|
239
|
+
)(x)
|
240
|
+
if residual is not None:
|
241
|
+
if residual.shape[1] > out.shape[1]:
|
242
|
+
residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
|
243
|
+
out = jnp.concatenate([out, residual], axis=-1)
|
244
|
+
return out
|
245
|
+
|
246
|
+
|
247
|
+
def l2norm(t, axis=1, eps=1e-12):
|
248
|
+
denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
|
249
|
+
out = t/denom
|
250
|
+
return (out)
|
flaxdiff/models/simple_unet.py
CHANGED
@@ -3,248 +3,9 @@ import jax.numpy as jnp
|
|
3
3
|
from flax import linen as nn
|
4
4
|
from typing import Dict, Callable, Sequence, Any, Union
|
5
5
|
import einops
|
6
|
-
from .common import kernel_init
|
6
|
+
from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
|
7
7
|
from .attention import TransformerBlock
|
8
8
|
|
9
|
-
class WeightStandardizedConv(nn.Module):
|
10
|
-
"""
|
11
|
-
apply weight standardization https://arxiv.org/abs/1903.10520
|
12
|
-
"""
|
13
|
-
features: int
|
14
|
-
kernel_size: Sequence[int] = 3
|
15
|
-
strides: Union[None, int, Sequence[int]] = 1
|
16
|
-
padding: Any = 1
|
17
|
-
dtype: Any = jnp.float32
|
18
|
-
param_dtype: Any = jnp.float32
|
19
|
-
|
20
|
-
@nn.compact
|
21
|
-
def __call__(self, x):
|
22
|
-
"""
|
23
|
-
Applies a weight standardized convolution to the inputs.
|
24
|
-
|
25
|
-
Args:
|
26
|
-
inputs: input data with dimensions (batch, spatial_dims..., features).
|
27
|
-
|
28
|
-
Returns:
|
29
|
-
The convolved data.
|
30
|
-
"""
|
31
|
-
x = x.astype(self.dtype)
|
32
|
-
|
33
|
-
conv = nn.Conv(
|
34
|
-
features=self.features,
|
35
|
-
kernel_size=self.kernel_size,
|
36
|
-
strides = self.strides,
|
37
|
-
padding=self.padding,
|
38
|
-
dtype=self.dtype,
|
39
|
-
param_dtype = self.param_dtype,
|
40
|
-
parent=None)
|
41
|
-
|
42
|
-
kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
|
43
|
-
bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
|
44
|
-
|
45
|
-
# standardize kernel
|
46
|
-
kernel = self.param('kernel', kernel_init, x)
|
47
|
-
eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
|
48
|
-
# reduce over dim_out
|
49
|
-
redux = tuple(range(kernel.ndim - 1))
|
50
|
-
mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
|
51
|
-
var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
|
52
|
-
standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
|
53
|
-
|
54
|
-
bias = self.param('bias',bias_init, x)
|
55
|
-
|
56
|
-
return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
|
57
|
-
|
58
|
-
class PixelShuffle(nn.Module):
|
59
|
-
scale: int
|
60
|
-
|
61
|
-
@nn.compact
|
62
|
-
def __call__(self, x):
|
63
|
-
up = einops.rearrange(
|
64
|
-
x,
|
65
|
-
pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
|
66
|
-
h2=self.scale,
|
67
|
-
w2=self.scale,
|
68
|
-
)
|
69
|
-
return up
|
70
|
-
|
71
|
-
class TimeEmbedding(nn.Module):
|
72
|
-
features:int
|
73
|
-
nax_positions:int=10000
|
74
|
-
|
75
|
-
def setup(self):
|
76
|
-
half_dim = self.features // 2
|
77
|
-
emb = jnp.log(self.nax_positions) / (half_dim - 1)
|
78
|
-
emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
|
79
|
-
self.embeddings = emb
|
80
|
-
|
81
|
-
def __call__(self, x):
|
82
|
-
x = jax.lax.convert_element_type(x, jnp.float32)
|
83
|
-
emb = x[:, None] * self.embeddings[None, :]
|
84
|
-
emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
|
85
|
-
return emb
|
86
|
-
|
87
|
-
class FourierEmbedding(nn.Module):
|
88
|
-
features:int
|
89
|
-
scale:int = 16
|
90
|
-
|
91
|
-
def setup(self):
|
92
|
-
self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
|
93
|
-
|
94
|
-
def __call__(self, x):
|
95
|
-
x = jax.lax.convert_element_type(x, jnp.float32)
|
96
|
-
emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
|
97
|
-
emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
|
98
|
-
return emb
|
99
|
-
|
100
|
-
class TimeProjection(nn.Module):
|
101
|
-
features:int
|
102
|
-
activation:Callable=jax.nn.gelu
|
103
|
-
|
104
|
-
@nn.compact
|
105
|
-
def __call__(self, x):
|
106
|
-
x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
|
107
|
-
x = self.activation(x)
|
108
|
-
x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
|
109
|
-
x = self.activation(x)
|
110
|
-
return x
|
111
|
-
|
112
|
-
class SeparableConv(nn.Module):
|
113
|
-
features:int
|
114
|
-
kernel_size:tuple=(3, 3)
|
115
|
-
strides:tuple=(1, 1)
|
116
|
-
use_bias:bool=False
|
117
|
-
kernel_init:Callable=kernel_init(1.0)
|
118
|
-
padding:str="SAME"
|
119
|
-
dtype: Any = jnp.bfloat16
|
120
|
-
precision: Any = jax.lax.Precision.HIGH
|
121
|
-
|
122
|
-
@nn.compact
|
123
|
-
def __call__(self, x):
|
124
|
-
in_features = x.shape[-1]
|
125
|
-
depthwise = nn.Conv(
|
126
|
-
features=in_features, kernel_size=self.kernel_size,
|
127
|
-
strides=self.strides, kernel_init=self.kernel_init,
|
128
|
-
feature_group_count=in_features, use_bias=self.use_bias,
|
129
|
-
padding=self.padding,
|
130
|
-
dtype=self.dtype,
|
131
|
-
precision=self.precision
|
132
|
-
)(x)
|
133
|
-
pointwise = nn.Conv(
|
134
|
-
features=self.features, kernel_size=(1, 1),
|
135
|
-
strides=(1, 1), kernel_init=self.kernel_init,
|
136
|
-
use_bias=self.use_bias,
|
137
|
-
dtype=self.dtype,
|
138
|
-
precision=self.precision
|
139
|
-
)(depthwise)
|
140
|
-
return pointwise
|
141
|
-
|
142
|
-
class ConvLayer(nn.Module):
|
143
|
-
conv_type:str
|
144
|
-
features:int
|
145
|
-
kernel_size:tuple=(3, 3)
|
146
|
-
strides:tuple=(1, 1)
|
147
|
-
kernel_init:Callable=kernel_init(1.0)
|
148
|
-
dtype: Any = jnp.bfloat16
|
149
|
-
precision: Any = jax.lax.Precision.HIGH
|
150
|
-
|
151
|
-
def setup(self):
|
152
|
-
# conv_type can be "conv", "separable", "conv_transpose"
|
153
|
-
if self.conv_type == "conv":
|
154
|
-
self.conv = nn.Conv(
|
155
|
-
features=self.features,
|
156
|
-
kernel_size=self.kernel_size,
|
157
|
-
strides=self.strides,
|
158
|
-
kernel_init=self.kernel_init,
|
159
|
-
dtype=self.dtype,
|
160
|
-
precision=self.precision
|
161
|
-
)
|
162
|
-
elif self.conv_type == "w_conv":
|
163
|
-
self.conv = WeightStandardizedConv(
|
164
|
-
features=self.features,
|
165
|
-
kernel_size=self.kernel_size,
|
166
|
-
strides=self.strides,
|
167
|
-
padding="SAME",
|
168
|
-
param_dtype=self.dtype,
|
169
|
-
dtype=self.dtype,
|
170
|
-
precision=self.precision
|
171
|
-
)
|
172
|
-
elif self.conv_type == "separable":
|
173
|
-
self.conv = SeparableConv(
|
174
|
-
features=self.features,
|
175
|
-
kernel_size=self.kernel_size,
|
176
|
-
strides=self.strides,
|
177
|
-
kernel_init=self.kernel_init,
|
178
|
-
dtype=self.dtype,
|
179
|
-
precision=self.precision
|
180
|
-
)
|
181
|
-
elif self.conv_type == "conv_transpose":
|
182
|
-
self.conv = nn.ConvTranspose(
|
183
|
-
features=self.features,
|
184
|
-
kernel_size=self.kernel_size,
|
185
|
-
strides=self.strides,
|
186
|
-
kernel_init=self.kernel_init,
|
187
|
-
dtype=self.dtype,
|
188
|
-
precision=self.precision
|
189
|
-
)
|
190
|
-
|
191
|
-
def __call__(self, x):
|
192
|
-
return self.conv(x)
|
193
|
-
|
194
|
-
class Upsample(nn.Module):
|
195
|
-
features:int
|
196
|
-
scale:int
|
197
|
-
activation:Callable=jax.nn.swish
|
198
|
-
dtype: Any = jnp.bfloat16
|
199
|
-
precision: Any = jax.lax.Precision.HIGH
|
200
|
-
|
201
|
-
@nn.compact
|
202
|
-
def __call__(self, x, residual=None):
|
203
|
-
out = x
|
204
|
-
# out = PixelShuffle(scale=self.scale)(out)
|
205
|
-
B, H, W, C = x.shape
|
206
|
-
out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
|
207
|
-
out = ConvLayer(
|
208
|
-
"conv",
|
209
|
-
features=self.features,
|
210
|
-
kernel_size=(3, 3),
|
211
|
-
strides=(1, 1),
|
212
|
-
dtype=self.dtype,
|
213
|
-
precision=self.precision
|
214
|
-
)(out)
|
215
|
-
if residual is not None:
|
216
|
-
out = jnp.concatenate([out, residual], axis=-1)
|
217
|
-
return out
|
218
|
-
|
219
|
-
class Downsample(nn.Module):
|
220
|
-
features:int
|
221
|
-
scale:int
|
222
|
-
activation:Callable=jax.nn.swish
|
223
|
-
dtype: Any = jnp.bfloat16
|
224
|
-
precision: Any = jax.lax.Precision.HIGH
|
225
|
-
|
226
|
-
@nn.compact
|
227
|
-
def __call__(self, x, residual=None):
|
228
|
-
out = ConvLayer(
|
229
|
-
"conv",
|
230
|
-
features=self.features,
|
231
|
-
kernel_size=(3, 3),
|
232
|
-
strides=(2, 2),
|
233
|
-
dtype=self.dtype,
|
234
|
-
precision=self.precision
|
235
|
-
)(x)
|
236
|
-
if residual is not None:
|
237
|
-
if residual.shape[1] > out.shape[1]:
|
238
|
-
residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
|
239
|
-
out = jnp.concatenate([out, residual], axis=-1)
|
240
|
-
return out
|
241
|
-
|
242
|
-
|
243
|
-
def l2norm(t, axis=1, eps=1e-12):
|
244
|
-
denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
|
245
|
-
out = t/denom
|
246
|
-
return (out)
|
247
|
-
|
248
9
|
class ResidualBlock(nn.Module):
|
249
10
|
conv_type:str
|
250
11
|
features:int
|
@@ -318,6 +79,7 @@ class ResidualBlock(nn.Module):
|
|
318
79
|
return out
|
319
80
|
|
320
81
|
class Unet(nn.Module):
|
82
|
+
output_channels:int=3
|
321
83
|
emb_features:int=64*4,
|
322
84
|
feature_depths:list=[64, 128, 256, 512],
|
323
85
|
attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
|
@@ -373,12 +135,13 @@ class Unet(nn.Module):
|
|
373
135
|
)(x, temb)
|
374
136
|
if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
|
375
137
|
x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
138
|
+
dim_head=dim_in // attention_config['heads'],
|
139
|
+
use_flash_attention=attention_config.get("flash_attention", True),
|
140
|
+
use_projection=attention_config.get("use_projection", False),
|
141
|
+
use_self_and_cross=attention_config.get("use_self_and_cross", True),
|
142
|
+
precision=attention_config.get("precision", self.precision),
|
143
|
+
only_pure_attention=True,
|
144
|
+
name=f"down_{i}_attention_{j}")(x, textcontext)
|
382
145
|
# print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
|
383
146
|
downs.append(x)
|
384
147
|
if i != len(feature_depths) - 1:
|
@@ -416,6 +179,7 @@ class Unet(nn.Module):
|
|
416
179
|
use_projection=middle_attention.get("use_projection", False),
|
417
180
|
use_self_and_cross=False,
|
418
181
|
precision=attention_config.get("precision", self.precision),
|
182
|
+
only_pure_attention=True,
|
419
183
|
name=f"middle_attention_{j}")(x, textcontext)
|
420
184
|
x = ResidualBlock(
|
421
185
|
middle_conv_type,
|
@@ -452,12 +216,13 @@ class Unet(nn.Module):
|
|
452
216
|
)(x, temb)
|
453
217
|
if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
|
454
218
|
x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
219
|
+
dim_head=dim_out // attention_config['heads'],
|
220
|
+
use_flash_attention=attention_config.get("flash_attention", True),
|
221
|
+
use_projection=attention_config.get("use_projection", False),
|
222
|
+
use_self_and_cross=attention_config.get("use_self_and_cross", True),
|
459
223
|
precision=attention_config.get("precision", self.precision),
|
460
|
-
|
224
|
+
only_pure_attention=True,
|
225
|
+
name=f"up_{i}_attention_{j}")(x, textcontext)
|
461
226
|
# print("Upscaling ", i, x.shape)
|
462
227
|
if i != len(feature_depths) - 1:
|
463
228
|
x = Upsample(
|
@@ -500,7 +265,7 @@ class Unet(nn.Module):
|
|
500
265
|
|
501
266
|
noise_out = ConvLayer(
|
502
267
|
conv_type,
|
503
|
-
features=
|
268
|
+
features=self.output_channels,
|
504
269
|
kernel_size=(3, 3),
|
505
270
|
strides=(1, 1),
|
506
271
|
# activation=jax.nn.mish
|