flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +140 -162
- flaxdiff/models/autoencoder/__init__.py +2 -0
- flaxdiff/models/autoencoder/autoencoder.py +19 -0
- flaxdiff/models/autoencoder/diffusers.py +91 -0
- flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
- flaxdiff/models/common.py +322 -0
- flaxdiff/models/simple_unet.py +21 -327
- flaxdiff/trainer/__init__.py +2 -201
- flaxdiff/trainer/autoencoder_trainer.py +182 -0
- flaxdiff/trainer/diffusion_trainer.py +202 -0
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/METADATA +12 -2
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/RECORD +15 -9
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_unet.py
CHANGED
@@ -1,323 +1,14 @@
|
|
1
1
|
import jax
|
2
2
|
import jax.numpy as jnp
|
3
3
|
from flax import linen as nn
|
4
|
-
from typing import
|
4
|
+
from flax.typing import Dtype, PrecisionLike
|
5
|
+
from typing import Dict, Callable, Sequence, Any, Union, Optional
|
5
6
|
import einops
|
6
|
-
from .common import kernel_init
|
7
|
+
from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
|
7
8
|
from .attention import TransformerBlock
|
8
9
|
|
9
|
-
class WeightStandardizedConv(nn.Module):
|
10
|
-
"""
|
11
|
-
apply weight standardization https://arxiv.org/abs/1903.10520
|
12
|
-
"""
|
13
|
-
features: int
|
14
|
-
kernel_size: Sequence[int] = 3
|
15
|
-
strides: Union[None, int, Sequence[int]] = 1
|
16
|
-
padding: Any = 1
|
17
|
-
dtype: Any = jnp.float32
|
18
|
-
param_dtype: Any = jnp.float32
|
19
|
-
|
20
|
-
@nn.compact
|
21
|
-
def __call__(self, x):
|
22
|
-
"""
|
23
|
-
Applies a weight standardized convolution to the inputs.
|
24
|
-
|
25
|
-
Args:
|
26
|
-
inputs: input data with dimensions (batch, spatial_dims..., features).
|
27
|
-
|
28
|
-
Returns:
|
29
|
-
The convolved data.
|
30
|
-
"""
|
31
|
-
x = x.astype(self.dtype)
|
32
|
-
|
33
|
-
conv = nn.Conv(
|
34
|
-
features=self.features,
|
35
|
-
kernel_size=self.kernel_size,
|
36
|
-
strides = self.strides,
|
37
|
-
padding=self.padding,
|
38
|
-
dtype=self.dtype,
|
39
|
-
param_dtype = self.param_dtype,
|
40
|
-
parent=None)
|
41
|
-
|
42
|
-
kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
|
43
|
-
bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
|
44
|
-
|
45
|
-
# standardize kernel
|
46
|
-
kernel = self.param('kernel', kernel_init, x)
|
47
|
-
eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
|
48
|
-
# reduce over dim_out
|
49
|
-
redux = tuple(range(kernel.ndim - 1))
|
50
|
-
mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
|
51
|
-
var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
|
52
|
-
standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
|
53
|
-
|
54
|
-
bias = self.param('bias',bias_init, x)
|
55
|
-
|
56
|
-
return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
|
57
|
-
|
58
|
-
class PixelShuffle(nn.Module):
|
59
|
-
scale: int
|
60
|
-
|
61
|
-
@nn.compact
|
62
|
-
def __call__(self, x):
|
63
|
-
up = einops.rearrange(
|
64
|
-
x,
|
65
|
-
pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
|
66
|
-
h2=self.scale,
|
67
|
-
w2=self.scale,
|
68
|
-
)
|
69
|
-
return up
|
70
|
-
|
71
|
-
class TimeEmbedding(nn.Module):
|
72
|
-
features:int
|
73
|
-
nax_positions:int=10000
|
74
|
-
|
75
|
-
def setup(self):
|
76
|
-
half_dim = self.features // 2
|
77
|
-
emb = jnp.log(self.nax_positions) / (half_dim - 1)
|
78
|
-
emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
|
79
|
-
self.embeddings = emb
|
80
|
-
|
81
|
-
def __call__(self, x):
|
82
|
-
x = jax.lax.convert_element_type(x, jnp.float32)
|
83
|
-
emb = x[:, None] * self.embeddings[None, :]
|
84
|
-
emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
|
85
|
-
return emb
|
86
|
-
|
87
|
-
class FourierEmbedding(nn.Module):
|
88
|
-
features:int
|
89
|
-
scale:int = 16
|
90
|
-
|
91
|
-
def setup(self):
|
92
|
-
self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
|
93
|
-
|
94
|
-
def __call__(self, x):
|
95
|
-
x = jax.lax.convert_element_type(x, jnp.float32)
|
96
|
-
emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
|
97
|
-
emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
|
98
|
-
return emb
|
99
|
-
|
100
|
-
class TimeProjection(nn.Module):
|
101
|
-
features:int
|
102
|
-
activation:Callable=jax.nn.gelu
|
103
|
-
|
104
|
-
@nn.compact
|
105
|
-
def __call__(self, x):
|
106
|
-
x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
|
107
|
-
x = self.activation(x)
|
108
|
-
x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
|
109
|
-
x = self.activation(x)
|
110
|
-
return x
|
111
|
-
|
112
|
-
class SeparableConv(nn.Module):
|
113
|
-
features:int
|
114
|
-
kernel_size:tuple=(3, 3)
|
115
|
-
strides:tuple=(1, 1)
|
116
|
-
use_bias:bool=False
|
117
|
-
kernel_init:Callable=kernel_init(1.0)
|
118
|
-
padding:str="SAME"
|
119
|
-
dtype: Any = jnp.bfloat16
|
120
|
-
precision: Any = jax.lax.Precision.HIGH
|
121
|
-
|
122
|
-
@nn.compact
|
123
|
-
def __call__(self, x):
|
124
|
-
in_features = x.shape[-1]
|
125
|
-
depthwise = nn.Conv(
|
126
|
-
features=in_features, kernel_size=self.kernel_size,
|
127
|
-
strides=self.strides, kernel_init=self.kernel_init,
|
128
|
-
feature_group_count=in_features, use_bias=self.use_bias,
|
129
|
-
padding=self.padding,
|
130
|
-
dtype=self.dtype,
|
131
|
-
precision=self.precision
|
132
|
-
)(x)
|
133
|
-
pointwise = nn.Conv(
|
134
|
-
features=self.features, kernel_size=(1, 1),
|
135
|
-
strides=(1, 1), kernel_init=self.kernel_init,
|
136
|
-
use_bias=self.use_bias,
|
137
|
-
dtype=self.dtype,
|
138
|
-
precision=self.precision
|
139
|
-
)(depthwise)
|
140
|
-
return pointwise
|
141
|
-
|
142
|
-
class ConvLayer(nn.Module):
|
143
|
-
conv_type:str
|
144
|
-
features:int
|
145
|
-
kernel_size:tuple=(3, 3)
|
146
|
-
strides:tuple=(1, 1)
|
147
|
-
kernel_init:Callable=kernel_init(1.0)
|
148
|
-
dtype: Any = jnp.bfloat16
|
149
|
-
precision: Any = jax.lax.Precision.HIGH
|
150
|
-
|
151
|
-
def setup(self):
|
152
|
-
# conv_type can be "conv", "separable", "conv_transpose"
|
153
|
-
if self.conv_type == "conv":
|
154
|
-
self.conv = nn.Conv(
|
155
|
-
features=self.features,
|
156
|
-
kernel_size=self.kernel_size,
|
157
|
-
strides=self.strides,
|
158
|
-
kernel_init=self.kernel_init,
|
159
|
-
dtype=self.dtype,
|
160
|
-
precision=self.precision
|
161
|
-
)
|
162
|
-
elif self.conv_type == "w_conv":
|
163
|
-
self.conv = WeightStandardizedConv(
|
164
|
-
features=self.features,
|
165
|
-
kernel_size=self.kernel_size,
|
166
|
-
strides=self.strides,
|
167
|
-
padding="SAME",
|
168
|
-
param_dtype=self.dtype,
|
169
|
-
dtype=self.dtype,
|
170
|
-
precision=self.precision
|
171
|
-
)
|
172
|
-
elif self.conv_type == "separable":
|
173
|
-
self.conv = SeparableConv(
|
174
|
-
features=self.features,
|
175
|
-
kernel_size=self.kernel_size,
|
176
|
-
strides=self.strides,
|
177
|
-
kernel_init=self.kernel_init,
|
178
|
-
dtype=self.dtype,
|
179
|
-
precision=self.precision
|
180
|
-
)
|
181
|
-
elif self.conv_type == "conv_transpose":
|
182
|
-
self.conv = nn.ConvTranspose(
|
183
|
-
features=self.features,
|
184
|
-
kernel_size=self.kernel_size,
|
185
|
-
strides=self.strides,
|
186
|
-
kernel_init=self.kernel_init,
|
187
|
-
dtype=self.dtype,
|
188
|
-
precision=self.precision
|
189
|
-
)
|
190
|
-
|
191
|
-
def __call__(self, x):
|
192
|
-
return self.conv(x)
|
193
|
-
|
194
|
-
class Upsample(nn.Module):
|
195
|
-
features:int
|
196
|
-
scale:int
|
197
|
-
activation:Callable=jax.nn.swish
|
198
|
-
dtype: Any = jnp.bfloat16
|
199
|
-
precision: Any = jax.lax.Precision.HIGH
|
200
|
-
|
201
|
-
@nn.compact
|
202
|
-
def __call__(self, x, residual=None):
|
203
|
-
out = x
|
204
|
-
# out = PixelShuffle(scale=self.scale)(out)
|
205
|
-
B, H, W, C = x.shape
|
206
|
-
out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
|
207
|
-
out = ConvLayer(
|
208
|
-
"conv",
|
209
|
-
features=self.features,
|
210
|
-
kernel_size=(3, 3),
|
211
|
-
strides=(1, 1),
|
212
|
-
dtype=self.dtype,
|
213
|
-
precision=self.precision
|
214
|
-
)(out)
|
215
|
-
if residual is not None:
|
216
|
-
out = jnp.concatenate([out, residual], axis=-1)
|
217
|
-
return out
|
218
|
-
|
219
|
-
class Downsample(nn.Module):
|
220
|
-
features:int
|
221
|
-
scale:int
|
222
|
-
activation:Callable=jax.nn.swish
|
223
|
-
dtype: Any = jnp.bfloat16
|
224
|
-
precision: Any = jax.lax.Precision.HIGH
|
225
|
-
|
226
|
-
@nn.compact
|
227
|
-
def __call__(self, x, residual=None):
|
228
|
-
out = ConvLayer(
|
229
|
-
"conv",
|
230
|
-
features=self.features,
|
231
|
-
kernel_size=(3, 3),
|
232
|
-
strides=(2, 2),
|
233
|
-
dtype=self.dtype,
|
234
|
-
precision=self.precision
|
235
|
-
)(x)
|
236
|
-
if residual is not None:
|
237
|
-
if residual.shape[1] > out.shape[1]:
|
238
|
-
residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
|
239
|
-
out = jnp.concatenate([out, residual], axis=-1)
|
240
|
-
return out
|
241
|
-
|
242
|
-
|
243
|
-
def l2norm(t, axis=1, eps=1e-12):
|
244
|
-
denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
|
245
|
-
out = t/denom
|
246
|
-
return (out)
|
247
|
-
|
248
|
-
class ResidualBlock(nn.Module):
|
249
|
-
conv_type:str
|
250
|
-
features:int
|
251
|
-
kernel_size:tuple=(3, 3)
|
252
|
-
strides:tuple=(1, 1)
|
253
|
-
padding:str="SAME"
|
254
|
-
activation:Callable=jax.nn.swish
|
255
|
-
direction:str=None
|
256
|
-
res:int=2
|
257
|
-
norm_groups:int=8
|
258
|
-
kernel_init:Callable=kernel_init(1.0)
|
259
|
-
dtype: Any = jnp.float32
|
260
|
-
precision: Any = jax.lax.Precision.HIGHEST
|
261
|
-
|
262
|
-
@nn.compact
|
263
|
-
def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
|
264
|
-
residual = x
|
265
|
-
out = nn.GroupNorm(self.norm_groups)(x)
|
266
|
-
out = self.activation(out)
|
267
|
-
|
268
|
-
out = ConvLayer(
|
269
|
-
self.conv_type,
|
270
|
-
features=self.features,
|
271
|
-
kernel_size=self.kernel_size,
|
272
|
-
strides=self.strides,
|
273
|
-
kernel_init=self.kernel_init,
|
274
|
-
name="conv1",
|
275
|
-
dtype=self.dtype,
|
276
|
-
precision=self.precision
|
277
|
-
)(out)
|
278
|
-
|
279
|
-
temb = nn.DenseGeneral(
|
280
|
-
features=self.features,
|
281
|
-
name="temb_projection",
|
282
|
-
dtype=self.dtype,
|
283
|
-
precision=self.precision)(temb)
|
284
|
-
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
|
285
|
-
# scale, shift = jnp.split(temb, 2, axis=-1)
|
286
|
-
# out = out * (1 + scale) + shift
|
287
|
-
out = out + temb
|
288
|
-
|
289
|
-
out = nn.GroupNorm(self.norm_groups)(out)
|
290
|
-
out = self.activation(out)
|
291
|
-
|
292
|
-
out = ConvLayer(
|
293
|
-
self.conv_type,
|
294
|
-
features=self.features,
|
295
|
-
kernel_size=self.kernel_size,
|
296
|
-
strides=self.strides,
|
297
|
-
kernel_init=self.kernel_init,
|
298
|
-
name="conv2",
|
299
|
-
dtype=self.dtype,
|
300
|
-
precision=self.precision
|
301
|
-
)(out)
|
302
|
-
|
303
|
-
if residual.shape != out.shape:
|
304
|
-
residual = ConvLayer(
|
305
|
-
self.conv_type,
|
306
|
-
features=self.features,
|
307
|
-
kernel_size=(1, 1),
|
308
|
-
strides=1,
|
309
|
-
kernel_init=self.kernel_init,
|
310
|
-
name="residual_conv",
|
311
|
-
dtype=self.dtype,
|
312
|
-
precision=self.precision
|
313
|
-
)(residual)
|
314
|
-
out = out + residual
|
315
|
-
|
316
|
-
out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
|
317
|
-
|
318
|
-
return out
|
319
|
-
|
320
10
|
class Unet(nn.Module):
|
11
|
+
output_channels:int=3
|
321
12
|
emb_features:int=64*4,
|
322
13
|
feature_depths:list=[64, 128, 256, 512],
|
323
14
|
attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
|
@@ -325,8 +16,8 @@ class Unet(nn.Module):
|
|
325
16
|
num_middle_res_blocks:int=1,
|
326
17
|
activation:Callable = jax.nn.swish
|
327
18
|
norm_groups:int=8
|
328
|
-
dtype:
|
329
|
-
precision:
|
19
|
+
dtype: Optional[Dtype] = None
|
20
|
+
precision: PrecisionLike = None
|
330
21
|
|
331
22
|
@nn.compact
|
332
23
|
def __call__(self, x, temb, textcontext):
|
@@ -373,12 +64,13 @@ class Unet(nn.Module):
|
|
373
64
|
)(x, temb)
|
374
65
|
if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
|
375
66
|
x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
67
|
+
dim_head=dim_in // attention_config['heads'],
|
68
|
+
use_flash_attention=attention_config.get("flash_attention", True),
|
69
|
+
use_projection=attention_config.get("use_projection", False),
|
70
|
+
use_self_and_cross=attention_config.get("use_self_and_cross", True),
|
71
|
+
precision=attention_config.get("precision", self.precision),
|
72
|
+
only_pure_attention=True,
|
73
|
+
name=f"down_{i}_attention_{j}")(x, textcontext)
|
382
74
|
# print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
|
383
75
|
downs.append(x)
|
384
76
|
if i != len(feature_depths) - 1:
|
@@ -416,6 +108,7 @@ class Unet(nn.Module):
|
|
416
108
|
use_projection=middle_attention.get("use_projection", False),
|
417
109
|
use_self_and_cross=False,
|
418
110
|
precision=attention_config.get("precision", self.precision),
|
111
|
+
only_pure_attention=True,
|
419
112
|
name=f"middle_attention_{j}")(x, textcontext)
|
420
113
|
x = ResidualBlock(
|
421
114
|
middle_conv_type,
|
@@ -452,12 +145,13 @@ class Unet(nn.Module):
|
|
452
145
|
)(x, temb)
|
453
146
|
if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
|
454
147
|
x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
148
|
+
dim_head=dim_out // attention_config['heads'],
|
149
|
+
use_flash_attention=attention_config.get("flash_attention", True),
|
150
|
+
use_projection=attention_config.get("use_projection", False),
|
151
|
+
use_self_and_cross=attention_config.get("use_self_and_cross", True),
|
459
152
|
precision=attention_config.get("precision", self.precision),
|
460
|
-
|
153
|
+
only_pure_attention=True,
|
154
|
+
name=f"up_{i}_attention_{j}")(x, textcontext)
|
461
155
|
# print("Upscaling ", i, x.shape)
|
462
156
|
if i != len(feature_depths) - 1:
|
463
157
|
x = Upsample(
|
@@ -500,7 +194,7 @@ class Unet(nn.Module):
|
|
500
194
|
|
501
195
|
noise_out = ConvLayer(
|
502
196
|
conv_type,
|
503
|
-
features=
|
197
|
+
features=self.output_channels,
|
504
198
|
kernel_size=(3, 3),
|
505
199
|
strides=(1, 1),
|
506
200
|
# activation=jax.nn.mish
|
flaxdiff/trainer/__init__.py
CHANGED
@@ -1,201 +1,2 @@
|
|
1
|
-
import
|
2
|
-
import
|
3
|
-
from flax import linen as nn
|
4
|
-
import jax
|
5
|
-
from typing import Callable
|
6
|
-
from dataclasses import field
|
7
|
-
import jax.numpy as jnp
|
8
|
-
from clu import metrics
|
9
|
-
from flax.training import train_state # Useful dataclass to keep train state
|
10
|
-
import optax
|
11
|
-
from flax import struct # Flax dataclasses
|
12
|
-
import time
|
13
|
-
import os
|
14
|
-
import orbax
|
15
|
-
from flax.training import orbax_utils
|
16
|
-
|
17
|
-
from ..schedulers import NoiseScheduler
|
18
|
-
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
19
|
-
|
20
|
-
from .simple_trainer import SimpleTrainer, SimpleTrainState
|
21
|
-
|
22
|
-
class TrainState(SimpleTrainState):
|
23
|
-
rngs: jax.random.PRNGKey
|
24
|
-
ema_params: dict
|
25
|
-
|
26
|
-
def get_random_key(self):
|
27
|
-
rngs, subkey = jax.random.split(self.rngs)
|
28
|
-
return self.replace(rngs=rngs), subkey
|
29
|
-
|
30
|
-
def apply_ema(self, decay: float = 0.999):
|
31
|
-
new_ema_params = jax.tree_util.tree_map(
|
32
|
-
lambda ema, param: decay * ema + (1 - decay) * param,
|
33
|
-
self.ema_params,
|
34
|
-
self.params,
|
35
|
-
)
|
36
|
-
return self.replace(ema_params=new_ema_params)
|
37
|
-
|
38
|
-
class DiffusionTrainer(SimpleTrainer):
|
39
|
-
noise_schedule: NoiseScheduler
|
40
|
-
model_output_transform: DiffusionPredictionTransform
|
41
|
-
ema_decay: float = 0.999
|
42
|
-
|
43
|
-
def __init__(self,
|
44
|
-
model: nn.Module,
|
45
|
-
input_shapes: Dict[str, Tuple[int]],
|
46
|
-
optimizer: optax.GradientTransformation,
|
47
|
-
noise_schedule: NoiseScheduler,
|
48
|
-
rngs: jax.random.PRNGKey,
|
49
|
-
unconditional_prob: float = 0.2,
|
50
|
-
name: str = "Diffusion",
|
51
|
-
model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
|
52
|
-
**kwargs
|
53
|
-
):
|
54
|
-
super().__init__(
|
55
|
-
model=model,
|
56
|
-
input_shapes=input_shapes,
|
57
|
-
optimizer=optimizer,
|
58
|
-
rngs=rngs,
|
59
|
-
name=name,
|
60
|
-
**kwargs
|
61
|
-
)
|
62
|
-
self.noise_schedule = noise_schedule
|
63
|
-
self.model_output_transform = model_output_transform
|
64
|
-
self.unconditional_prob = unconditional_prob
|
65
|
-
|
66
|
-
def __init_fn(
|
67
|
-
self,
|
68
|
-
optimizer: optax.GradientTransformation,
|
69
|
-
rngs: jax.random.PRNGKey,
|
70
|
-
existing_state: dict = None,
|
71
|
-
existing_best_state: dict = None,
|
72
|
-
model: nn.Module = None,
|
73
|
-
param_transforms: Callable = None
|
74
|
-
) -> Tuple[TrainState, TrainState]:
|
75
|
-
rngs, subkey = jax.random.split(rngs)
|
76
|
-
|
77
|
-
if existing_state == None:
|
78
|
-
input_vars = self.get_input_ones()
|
79
|
-
params = model.init(subkey, **input_vars)
|
80
|
-
new_state = {"params": params, "ema_params": params}
|
81
|
-
else:
|
82
|
-
new_state = existing_state
|
83
|
-
|
84
|
-
if param_transforms is not None:
|
85
|
-
params = param_transforms(params)
|
86
|
-
|
87
|
-
state = TrainState.create(
|
88
|
-
apply_fn=model.apply,
|
89
|
-
params=new_state['params'],
|
90
|
-
ema_params=new_state['ema_params'],
|
91
|
-
tx=optimizer,
|
92
|
-
rngs=rngs,
|
93
|
-
metrics=Metrics.empty()
|
94
|
-
)
|
95
|
-
|
96
|
-
if existing_best_state is not None:
|
97
|
-
best_state = state.replace(
|
98
|
-
params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
|
99
|
-
else:
|
100
|
-
best_state = state
|
101
|
-
|
102
|
-
return state, best_state
|
103
|
-
|
104
|
-
def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
|
105
|
-
noise_schedule = self.noise_schedule
|
106
|
-
model = self.model
|
107
|
-
model_output_transform = self.model_output_transform
|
108
|
-
loss_fn = self.loss_fn
|
109
|
-
unconditional_prob = self.unconditional_prob
|
110
|
-
|
111
|
-
# Determine the number of unconditional samples
|
112
|
-
num_unconditional = int(batch_size * unconditional_prob)
|
113
|
-
|
114
|
-
nS, nC = null_labels_seq.shape
|
115
|
-
null_labels_seq = jnp.broadcast_to(
|
116
|
-
null_labels_seq, (batch_size, nS, nC))
|
117
|
-
|
118
|
-
distributed_training = self.distributed_training
|
119
|
-
|
120
|
-
def train_step(state: TrainState, batch):
|
121
|
-
"""Train for a single step."""
|
122
|
-
images = batch['image']
|
123
|
-
# normalize image
|
124
|
-
images = (images - 127.5) / 127.5
|
125
|
-
|
126
|
-
output = text_embedder(
|
127
|
-
input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
|
128
|
-
# output = infer(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
|
129
|
-
|
130
|
-
label_seq = output.last_hidden_state
|
131
|
-
|
132
|
-
# Generate random probabilities to decide how much of this batch will be unconditional
|
133
|
-
|
134
|
-
label_seq = jnp.concat(
|
135
|
-
[null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
|
136
|
-
|
137
|
-
noise_level, state = noise_schedule.generate_timesteps(
|
138
|
-
images.shape[0], state)
|
139
|
-
state, rngs = state.get_random_key()
|
140
|
-
noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
|
141
|
-
rates = noise_schedule.get_rates(noise_level)
|
142
|
-
noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
|
143
|
-
images, noise, rates)
|
144
|
-
|
145
|
-
def model_loss(params):
|
146
|
-
preds = model.apply(
|
147
|
-
params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
|
148
|
-
preds = model_output_transform.pred_transform(
|
149
|
-
noisy_images, preds, rates)
|
150
|
-
nloss = loss_fn(preds, expected_output)
|
151
|
-
# nloss = jnp.mean(nloss, axis=1)
|
152
|
-
nloss *= noise_schedule.get_weights(noise_level)
|
153
|
-
nloss = jnp.mean(nloss)
|
154
|
-
loss = nloss
|
155
|
-
return loss
|
156
|
-
|
157
|
-
loss, grads = jax.value_and_grad(model_loss)(state.params)
|
158
|
-
if distributed_training:
|
159
|
-
grads = jax.lax.pmean(grads, "device")
|
160
|
-
state = state.apply_gradients(grads=grads)
|
161
|
-
state = state.apply_ema(self.ema_decay)
|
162
|
-
return state, loss
|
163
|
-
|
164
|
-
if distributed_training:
|
165
|
-
train_step = jax.pmap(axis_name="device")(train_step)
|
166
|
-
else:
|
167
|
-
train_step = jax.jit(train_step)
|
168
|
-
|
169
|
-
return train_step
|
170
|
-
|
171
|
-
def _define_compute_metrics(self):
|
172
|
-
@jax.jit
|
173
|
-
def compute_metrics(state: TrainState, expected, pred):
|
174
|
-
loss = jnp.mean(jnp.square(pred - expected))
|
175
|
-
metric_updates = state.metrics.single_from_model_output(loss=loss)
|
176
|
-
metrics = state.metrics.merge(metric_updates)
|
177
|
-
state = state.replace(metrics=metrics)
|
178
|
-
return state
|
179
|
-
return compute_metrics
|
180
|
-
|
181
|
-
def fit(self, data, steps_per_epoch, epochs):
|
182
|
-
null_labels_full = data['null_labels_full']
|
183
|
-
local_batch_size = data['local_batch_size']
|
184
|
-
text_embedder = data['model']
|
185
|
-
super().fit(data, steps_per_epoch, epochs, {
|
186
|
-
"batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
|
187
|
-
|
188
|
-
|
189
|
-
pbar.set_postfix(loss=f'{loss:.4f}')
|
190
|
-
pbar.update(100)
|
191
|
-
end_time = time.time()
|
192
|
-
self.state = state
|
193
|
-
total_time = end_time - start_time
|
194
|
-
avg_time_per_step = total_time / steps_per_epoch
|
195
|
-
avg_loss = epoch_loss / steps_per_epoch
|
196
|
-
if avg_loss < self.best_loss:
|
197
|
-
self.best_loss = avg_loss
|
198
|
-
self.best_state = state
|
199
|
-
self.save(epoch, best=True)
|
200
|
-
print(f"\n\tEpoch {epoch+1} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}")
|
201
|
-
return self.state
|
1
|
+
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
2
|
+
from .diffusion_trainer import DiffusionTrainer, TrainState
|