flaxdiff 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +7 -5
- flaxdiff/models/autoencoder/diffusers.py +1 -1
- flaxdiff/models/common.py +14 -2
- flaxdiff/models/simple_unet.py +27 -12
- flaxdiff/models/simple_vit.py +13 -16
- flaxdiff/trainer/diffusion_trainer.py +44 -12
- flaxdiff/trainer/simple_trainer.py +84 -61
- {flaxdiff-0.1.8.dist-info → flaxdiff-0.1.10.dist-info}/METADATA +18 -1
- {flaxdiff-0.1.8.dist-info → flaxdiff-0.1.10.dist-info}/RECORD +11 -11
- {flaxdiff-0.1.8.dist-info → flaxdiff-0.1.10.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.8.dist-info → flaxdiff-0.1.10.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
@@ -156,7 +156,9 @@ class NormalAttention(nn.Module):
|
|
156
156
|
value = self.value(context)
|
157
157
|
|
158
158
|
hidden_states = nn.dot_product_attention(
|
159
|
-
query, key, value, dtype=self.dtype, broadcast_dropout=False,
|
159
|
+
query, key, value, dtype=self.dtype, broadcast_dropout=False,
|
160
|
+
dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
|
161
|
+
deterministic=True
|
160
162
|
)
|
161
163
|
proj = self.proj_attn(hidden_states)
|
162
164
|
proj = proj.reshape(orig_x_shape)
|
@@ -187,7 +189,7 @@ class FlaxGEGLU(nn.Module):
|
|
187
189
|
|
188
190
|
def __call__(self, hidden_states):
|
189
191
|
hidden_states = self.proj(hidden_states)
|
190
|
-
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis
|
192
|
+
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
|
191
193
|
return hidden_linear * nn.gelu(hidden_gelu)
|
192
194
|
|
193
195
|
class FlaxFeedForward(nn.Module):
|
@@ -291,14 +293,14 @@ class TransformerBlock(nn.Module):
|
|
291
293
|
dtype: Optional[Dtype] = None
|
292
294
|
precision: PrecisionLike = None
|
293
295
|
use_projection: bool = False
|
294
|
-
use_flash_attention:bool =
|
295
|
-
use_self_and_cross:bool =
|
296
|
+
use_flash_attention:bool = False
|
297
|
+
use_self_and_cross:bool = True
|
296
298
|
only_pure_attention:bool = False
|
297
299
|
|
298
300
|
@nn.compact
|
299
301
|
def __call__(self, x, context=None):
|
300
302
|
inner_dim = self.heads * self.dim_head
|
301
|
-
|
303
|
+
C = x.shape[-1]
|
302
304
|
normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
|
303
305
|
if self.use_projection == True:
|
304
306
|
if self.use_linear_attention:
|
@@ -78,7 +78,7 @@ class StableDiffusionVAE(AutoEncoder):
|
|
78
78
|
log_std = jnp.clip(log_std, -30, 20)
|
79
79
|
std = jnp.exp(0.5 * log_std)
|
80
80
|
latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
|
81
|
-
print("Sampled")
|
81
|
+
# print("Sampled")
|
82
82
|
else:
|
83
83
|
# return the mean
|
84
84
|
latents, _ = jnp.split(latents, 2, axis=-1)
|
flaxdiff/models/common.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Optional, Any, Callable, Sequence, Union
|
|
5
5
|
from flax.typing import Dtype, PrecisionLike
|
6
6
|
from typing import Dict, Callable, Sequence, Any, Union
|
7
7
|
import einops
|
8
|
+
from functools import partial
|
8
9
|
|
9
10
|
# Kernel initializer to use
|
10
11
|
def kernel_init(scale, dtype=jnp.float32):
|
@@ -266,11 +267,22 @@ class ResidualBlock(nn.Module):
|
|
266
267
|
kernel_init:Callable=kernel_init(1.0)
|
267
268
|
dtype: Optional[Dtype] = None
|
268
269
|
precision: PrecisionLike = None
|
270
|
+
named_norms:bool=False
|
271
|
+
|
272
|
+
def setup(self):
|
273
|
+
if self.norm_groups > 0:
|
274
|
+
norm = partial(nn.GroupNorm, self.norm_groups)
|
275
|
+
self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
|
276
|
+
self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
|
277
|
+
else:
|
278
|
+
norm = partial(nn.RMSNorm, 1e-5)
|
279
|
+
self.norm1 = norm()
|
280
|
+
self.norm2 = norm()
|
269
281
|
|
270
282
|
@nn.compact
|
271
283
|
def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
|
272
284
|
residual = x
|
273
|
-
out =
|
285
|
+
out = self.norm1(x)
|
274
286
|
# out = nn.RMSNorm()(x)
|
275
287
|
out = self.activation(out)
|
276
288
|
|
@@ -295,7 +307,7 @@ class ResidualBlock(nn.Module):
|
|
295
307
|
# out = out * (1 + scale) + shift
|
296
308
|
out = out + temb
|
297
309
|
|
298
|
-
out =
|
310
|
+
out = self.norm2(out)
|
299
311
|
# out = nn.RMSNorm()(out)
|
300
312
|
out = self.activation(out)
|
301
313
|
|
flaxdiff/models/simple_unet.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Dict, Callable, Sequence, Any, Union, Optional
|
|
6
6
|
import einops
|
7
7
|
from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
|
8
8
|
from .attention import TransformerBlock
|
9
|
+
from functools import partial
|
9
10
|
|
10
11
|
class Unet(nn.Module):
|
11
12
|
output_channels:int=3
|
@@ -18,7 +19,16 @@ class Unet(nn.Module):
|
|
18
19
|
norm_groups:int=8
|
19
20
|
dtype: Optional[Dtype] = None
|
20
21
|
precision: PrecisionLike = None
|
22
|
+
named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
|
21
23
|
|
24
|
+
def setup(self):
|
25
|
+
if self.norm_groups > 0:
|
26
|
+
norm = partial(nn.GroupNorm, self.norm_groups)
|
27
|
+
self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
|
28
|
+
else:
|
29
|
+
norm = partial(nn.RMSNorm, 1e-5)
|
30
|
+
self.conv_out_norm = norm()
|
31
|
+
|
22
32
|
@nn.compact
|
23
33
|
def __call__(self, x, temb, textcontext):
|
24
34
|
# print("embedding features", self.emb_features)
|
@@ -60,7 +70,8 @@ class Unet(nn.Module):
|
|
60
70
|
activation=self.activation,
|
61
71
|
norm_groups=self.norm_groups,
|
62
72
|
dtype=self.dtype,
|
63
|
-
precision=self.precision
|
73
|
+
precision=self.precision,
|
74
|
+
named_norms=self.named_norms
|
64
75
|
)(x, temb)
|
65
76
|
if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
|
66
77
|
x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
|
@@ -69,7 +80,7 @@ class Unet(nn.Module):
|
|
69
80
|
use_projection=attention_config.get("use_projection", False),
|
70
81
|
use_self_and_cross=attention_config.get("use_self_and_cross", True),
|
71
82
|
precision=attention_config.get("precision", self.precision),
|
72
|
-
only_pure_attention=True,
|
83
|
+
only_pure_attention=attention_config.get("only_pure_attention", True),
|
73
84
|
name=f"down_{i}_attention_{j}")(x, textcontext)
|
74
85
|
# print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
|
75
86
|
downs.append(x)
|
@@ -98,7 +109,8 @@ class Unet(nn.Module):
|
|
98
109
|
activation=self.activation,
|
99
110
|
norm_groups=self.norm_groups,
|
100
111
|
dtype=self.dtype,
|
101
|
-
precision=self.precision
|
112
|
+
precision=self.precision,
|
113
|
+
named_norms=self.named_norms
|
102
114
|
)(x, temb)
|
103
115
|
if middle_attention is not None and j == self.num_middle_res_blocks - 1: # Apply attention only on the last block
|
104
116
|
x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
|
@@ -107,8 +119,8 @@ class Unet(nn.Module):
|
|
107
119
|
use_linear_attention=False,
|
108
120
|
use_projection=middle_attention.get("use_projection", False),
|
109
121
|
use_self_and_cross=False,
|
110
|
-
precision=
|
111
|
-
only_pure_attention=True,
|
122
|
+
precision=middle_attention.get("precision", self.precision),
|
123
|
+
only_pure_attention=middle_attention.get("only_pure_attention", True),
|
112
124
|
name=f"middle_attention_{j}")(x, textcontext)
|
113
125
|
x = ResidualBlock(
|
114
126
|
middle_conv_type,
|
@@ -120,7 +132,8 @@ class Unet(nn.Module):
|
|
120
132
|
activation=self.activation,
|
121
133
|
norm_groups=self.norm_groups,
|
122
134
|
dtype=self.dtype,
|
123
|
-
precision=self.precision
|
135
|
+
precision=self.precision,
|
136
|
+
named_norms=self.named_norms
|
124
137
|
)(x, temb)
|
125
138
|
|
126
139
|
# Upscaling Blocks
|
@@ -141,7 +154,8 @@ class Unet(nn.Module):
|
|
141
154
|
activation=self.activation,
|
142
155
|
norm_groups=self.norm_groups,
|
143
156
|
dtype=self.dtype,
|
144
|
-
precision=self.precision
|
157
|
+
precision=self.precision,
|
158
|
+
named_norms=self.named_norms
|
145
159
|
)(x, temb)
|
146
160
|
if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
|
147
161
|
x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
|
@@ -150,7 +164,7 @@ class Unet(nn.Module):
|
|
150
164
|
use_projection=attention_config.get("use_projection", False),
|
151
165
|
use_self_and_cross=attention_config.get("use_self_and_cross", True),
|
152
166
|
precision=attention_config.get("precision", self.precision),
|
153
|
-
only_pure_attention=True,
|
167
|
+
only_pure_attention=attention_config.get("only_pure_attention", True),
|
154
168
|
name=f"up_{i}_attention_{j}")(x, textcontext)
|
155
169
|
# print("Upscaling ", i, x.shape)
|
156
170
|
if i != len(feature_depths) - 1:
|
@@ -163,13 +177,13 @@ class Unet(nn.Module):
|
|
163
177
|
precision=self.precision
|
164
178
|
)(x)
|
165
179
|
|
166
|
-
# x =
|
180
|
+
# x = self.last_up_norm(x)
|
167
181
|
x = ConvLayer(
|
168
182
|
conv_type,
|
169
183
|
features=self.feature_depths[0],
|
170
184
|
kernel_size=(3, 3),
|
171
185
|
strides=(1, 1),
|
172
|
-
kernel_init=kernel_init(
|
186
|
+
kernel_init=kernel_init(1.0),
|
173
187
|
dtype=self.dtype,
|
174
188
|
precision=self.precision
|
175
189
|
)(x)
|
@@ -186,10 +200,11 @@ class Unet(nn.Module):
|
|
186
200
|
activation=self.activation,
|
187
201
|
norm_groups=self.norm_groups,
|
188
202
|
dtype=self.dtype,
|
189
|
-
precision=self.precision
|
203
|
+
precision=self.precision,
|
204
|
+
named_norms=self.named_norms
|
190
205
|
)(x, temb)
|
191
206
|
|
192
|
-
x =
|
207
|
+
x = self.conv_out_norm(x)
|
193
208
|
x = self.activation(x)
|
194
209
|
|
195
210
|
noise_out = ConvLayer(
|
flaxdiff/models/simple_vit.py
CHANGED
@@ -4,7 +4,7 @@ import jax
|
|
4
4
|
import jax.numpy as jnp
|
5
5
|
from flax import linen as nn
|
6
6
|
from typing import Callable, Any
|
7
|
-
from .
|
7
|
+
from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
|
8
8
|
from .attention import TransformerBlock
|
9
9
|
|
10
10
|
class PatchEmbedding(nn.Module):
|
@@ -40,22 +40,23 @@ class PositionalEncoding(nn.Module):
|
|
40
40
|
class TransformerEncoder(nn.Module):
|
41
41
|
num_layers: int
|
42
42
|
num_heads: int
|
43
|
-
mlp_dim: int
|
44
43
|
dropout_rate: float = 0.1
|
45
44
|
dtype: Any = jnp.float32
|
46
45
|
precision: Any = jax.lax.Precision.HIGH
|
46
|
+
use_projection: bool = False
|
47
47
|
|
48
48
|
@nn.compact
|
49
|
-
def __call__(self, x,
|
49
|
+
def __call__(self, x, context=None):
|
50
50
|
for _ in range(self.num_layers):
|
51
51
|
x = TransformerBlock(
|
52
52
|
heads=self.num_heads,
|
53
53
|
dim_head=x.shape[-1] // self.num_heads,
|
54
|
-
mlp_dim=self.mlp_dim,
|
55
54
|
dropout_rate=self.dropout_rate,
|
56
55
|
dtype=self.dtype,
|
57
|
-
precision=self.precision
|
58
|
-
|
56
|
+
precision=self.precision,
|
57
|
+
use_self_and_cross=True,
|
58
|
+
use_projection=self.use_projection,
|
59
|
+
)(x, context)
|
59
60
|
return x
|
60
61
|
|
61
62
|
class VisionTransformer(nn.Module):
|
@@ -63,11 +64,11 @@ class VisionTransformer(nn.Module):
|
|
63
64
|
embedding_dim: int = 768
|
64
65
|
num_layers: int = 12
|
65
66
|
num_heads: int = 12
|
66
|
-
mlp_dim: int = 3072
|
67
67
|
emb_features: int = 256
|
68
68
|
dropout_rate: float = 0.1
|
69
69
|
dtype: Any = jnp.float32
|
70
70
|
precision: Any = jax.lax.Precision.HIGH
|
71
|
+
use_projection: bool = False
|
71
72
|
|
72
73
|
@nn.compact
|
73
74
|
def __call__(self, x, temb, textcontext=None):
|
@@ -81,27 +82,23 @@ class VisionTransformer(nn.Module):
|
|
81
82
|
|
82
83
|
# Add positional encoding
|
83
84
|
x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x)
|
85
|
+
|
86
|
+
num_patches = x.shape[1]
|
84
87
|
|
85
88
|
# Add time embedding
|
86
89
|
temb = jnp.expand_dims(temb, axis=1)
|
87
90
|
x = jnp.concatenate([x, temb], axis=1)
|
88
91
|
|
89
|
-
# Add text context
|
90
|
-
if textcontext is not None:
|
91
|
-
x = jnp.concatenate([x, textcontext], axis=1)
|
92
|
-
|
93
92
|
# Transformer encoder
|
94
93
|
x = TransformerEncoder(
|
95
94
|
num_layers=self.num_layers,
|
96
95
|
num_heads=self.num_heads,
|
97
|
-
mlp_dim=self.mlp_dim,
|
98
96
|
dropout_rate=self.dropout_rate,
|
99
97
|
dtype=self.dtype,
|
100
|
-
precision=self.precision
|
101
|
-
|
98
|
+
precision=self.precision,
|
99
|
+
use_projection=self.use_projection
|
100
|
+
)(x, textcontext)
|
102
101
|
|
103
|
-
# Extract the image tokens (exclude time and text embeddings)
|
104
|
-
num_patches = (x.shape[1] - 1 - (0 if textcontext is None else textcontext.shape[1]))
|
105
102
|
x = x[:, :num_patches, :]
|
106
103
|
|
107
104
|
# Reshape to image dimensions
|
@@ -16,6 +16,7 @@ from flaxdiff.utils import RandomMarkovState
|
|
16
16
|
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
17
17
|
|
18
18
|
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
19
|
+
from flax.training.dynamic_scale import DynamicScale
|
19
20
|
|
20
21
|
class TrainState(SimpleTrainState):
|
21
22
|
rngs: jax.random.PRNGKey
|
@@ -29,6 +30,8 @@ class TrainState(SimpleTrainState):
|
|
29
30
|
)
|
30
31
|
return self.replace(ema_params=new_ema_params)
|
31
32
|
|
33
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
34
|
+
|
32
35
|
class DiffusionTrainer(SimpleTrainer):
|
33
36
|
noise_schedule: NoiseScheduler
|
34
37
|
model_output_transform: DiffusionPredictionTransform
|
@@ -40,7 +43,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
40
43
|
optimizer: optax.GradientTransformation,
|
41
44
|
noise_schedule: NoiseScheduler,
|
42
45
|
rngs: jax.random.PRNGKey,
|
43
|
-
unconditional_prob: float = 0.
|
46
|
+
unconditional_prob: float = 0.12,
|
44
47
|
name: str = "Diffusion",
|
45
48
|
model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
|
46
49
|
autoencoder: AutoEncoder = None,
|
@@ -67,7 +70,8 @@ class DiffusionTrainer(SimpleTrainer):
|
|
67
70
|
existing_state: dict = None,
|
68
71
|
existing_best_state: dict = None,
|
69
72
|
model: nn.Module = None,
|
70
|
-
param_transforms: Callable = None
|
73
|
+
param_transforms: Callable = None,
|
74
|
+
use_dynamic_scale: bool = False
|
71
75
|
) -> Tuple[TrainState, TrainState]:
|
72
76
|
print("Generating states for DiffusionTrainer")
|
73
77
|
rngs, subkey = jax.random.split(rngs)
|
@@ -80,7 +84,8 @@ class DiffusionTrainer(SimpleTrainer):
|
|
80
84
|
new_state = existing_state
|
81
85
|
|
82
86
|
if param_transforms is not None:
|
83
|
-
params = param_transforms(params)
|
87
|
+
new_state['params'] = param_transforms(new_state['params'])
|
88
|
+
new_state['ema_params'] = param_transforms(new_state['ema_params'])
|
84
89
|
|
85
90
|
state = TrainState.create(
|
86
91
|
apply_fn=model.apply,
|
@@ -88,7 +93,8 @@ class DiffusionTrainer(SimpleTrainer):
|
|
88
93
|
ema_params=new_state['ema_params'],
|
89
94
|
tx=optimizer,
|
90
95
|
rngs=rngs,
|
91
|
-
metrics=Metrics.empty()
|
96
|
+
metrics=Metrics.empty(),
|
97
|
+
dynamic_scale = DynamicScale() if use_dynamic_scale else None
|
92
98
|
)
|
93
99
|
|
94
100
|
if existing_best_state is not None:
|
@@ -125,14 +131,14 @@ class DiffusionTrainer(SimpleTrainer):
|
|
125
131
|
local_rng_state = RandomMarkovState(subkey)
|
126
132
|
|
127
133
|
images = batch['image']
|
134
|
+
images = jnp.array(images, dtype=jnp.float32)
|
135
|
+
# normalize image
|
136
|
+
images = (images - 127.5) / 127.5
|
128
137
|
|
129
138
|
if autoencoder is not None:
|
130
139
|
# Convert the images to latent space
|
131
140
|
local_rng_state, rngs = local_rng_state.get_random_key()
|
132
141
|
images = autoencoder.encode(images, rngs)
|
133
|
-
else:
|
134
|
-
# normalize image
|
135
|
-
images = (images - 127.5) / 127.5
|
136
142
|
|
137
143
|
output = text_embedder(
|
138
144
|
input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
|
@@ -163,12 +169,39 @@ class DiffusionTrainer(SimpleTrainer):
|
|
163
169
|
loss = nloss
|
164
170
|
return loss
|
165
171
|
|
166
|
-
|
172
|
+
|
173
|
+
if train_state.dynamic_scale is not None:
|
174
|
+
# dynamic scale takes care of averaging gradients across replicas
|
175
|
+
grad_fn = train_state.dynamic_scale.value_and_grad(
|
176
|
+
model_loss, axis_name="data"
|
177
|
+
)
|
178
|
+
dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
|
179
|
+
train_state = train_state.replace(dynamic_scale=dynamic_scale)
|
180
|
+
else:
|
181
|
+
grad_fn = jax.value_and_grad(model_loss)
|
182
|
+
loss, grads = grad_fn(train_state.params)
|
183
|
+
if distributed_training:
|
184
|
+
grads = jax.lax.pmean(grads, "data")
|
185
|
+
|
186
|
+
new_state = train_state.apply_gradients(grads=grads)
|
187
|
+
|
188
|
+
if train_state.dynamic_scale:
|
189
|
+
# if is_fin == False the gradients contain Inf/NaNs and optimizer state and
|
190
|
+
# params should be restored (= skip this step).
|
191
|
+
select_fn = functools.partial(jnp.where, is_fin)
|
192
|
+
new_state = train_state.replace(
|
193
|
+
opt_state=jax.tree_util.tree_map(
|
194
|
+
select_fn, new_state.opt_state, train_state.opt_state
|
195
|
+
),
|
196
|
+
params=jax.tree_util.tree_map(
|
197
|
+
select_fn, new_state.params, train_state.params
|
198
|
+
),
|
199
|
+
)
|
200
|
+
|
201
|
+
train_state = new_state.apply_ema(self.ema_decay)
|
202
|
+
|
167
203
|
if distributed_training:
|
168
|
-
grads = jax.lax.pmean(grads, "data")
|
169
204
|
loss = jax.lax.pmean(loss, "data")
|
170
|
-
train_state = train_state.apply_gradients(grads=grads)
|
171
|
-
train_state = train_state.apply_ema(self.ema_decay)
|
172
205
|
return train_state, loss, rng_state
|
173
206
|
|
174
207
|
if distributed_training:
|
@@ -199,4 +232,3 @@ def boolean_string(s):
|
|
199
232
|
if type(s) == bool:
|
200
233
|
return s
|
201
234
|
return s == 'True'
|
202
|
-
|
@@ -22,7 +22,7 @@ from jax.experimental.shard_map import shard_map
|
|
22
22
|
from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
|
23
23
|
from termcolor import colored
|
24
24
|
from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
25
|
-
|
25
|
+
from flax.training.dynamic_scale import DynamicScale
|
26
26
|
from flaxdiff.utils import RandomMarkovState
|
27
27
|
|
28
28
|
PROCESS_COLOR_MAP = {
|
@@ -39,23 +39,23 @@ PROCESS_COLOR_MAP = {
|
|
39
39
|
def _build_global_shape_and_sharding(
|
40
40
|
local_shape: tuple[int, ...], global_mesh: Mesh
|
41
41
|
) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
|
42
|
-
|
43
|
-
|
44
|
-
|
42
|
+
sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
|
43
|
+
global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
|
44
|
+
return global_shape, sharding
|
45
45
|
|
46
46
|
|
47
47
|
def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
48
|
+
"""Put local sharded array into local devices"""
|
49
|
+
global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
|
50
|
+
try:
|
51
|
+
local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
|
52
|
+
except ValueError as array_split_error:
|
53
|
+
raise ValueError(
|
54
|
+
f"Unable to put to devices shape {array.shape} with "
|
55
|
+
f"local device count {len(global_mesh.local_devices)} "
|
56
|
+
) from array_split_error
|
57
|
+
local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
|
58
|
+
return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
|
59
59
|
|
60
60
|
def convert_to_global_tree(global_mesh, pytree):
|
61
61
|
return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
|
@@ -67,12 +67,8 @@ class Metrics(metrics.Collection):
|
|
67
67
|
|
68
68
|
# Define the TrainState
|
69
69
|
class SimpleTrainState(train_state.TrainState):
|
70
|
-
rngs: jax.random.PRNGKey
|
71
70
|
metrics: Metrics
|
72
|
-
|
73
|
-
def get_random_key(self):
|
74
|
-
rngs, subkey = jax.random.split(self.rngs)
|
75
|
-
return self.replace(rngs=rngs), subkey
|
71
|
+
dynamic_scale: DynamicScale
|
76
72
|
|
77
73
|
class SimpleTrainer:
|
78
74
|
state: SimpleTrainState
|
@@ -88,20 +84,22 @@ class SimpleTrainer:
|
|
88
84
|
rngs: jax.random.PRNGKey,
|
89
85
|
train_state: SimpleTrainState = None,
|
90
86
|
name: str = "Simple",
|
91
|
-
load_from_checkpoint:
|
87
|
+
load_from_checkpoint: str = None,
|
92
88
|
checkpoint_suffix: str = "",
|
93
|
-
checkpoint_id: str = None,
|
94
89
|
loss_fn=optax.l2_loss,
|
95
90
|
param_transforms: Callable = None,
|
96
91
|
wandb_config: Dict[str, Any] = None,
|
97
92
|
distributed_training: bool = None,
|
98
93
|
checkpoint_base_path: str = "./checkpoints",
|
94
|
+
checkpoint_step: int = None,
|
95
|
+
use_dynamic_scale: bool = False,
|
99
96
|
):
|
100
97
|
if distributed_training is None or distributed_training is True:
|
101
98
|
# Auto-detect if we are running on multiple devices
|
102
99
|
distributed_training = jax.device_count() > 1
|
103
100
|
self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
|
104
|
-
|
101
|
+
else:
|
102
|
+
self.mesh = None
|
105
103
|
|
106
104
|
self.distributed_training = distributed_training
|
107
105
|
self.model = model
|
@@ -112,7 +110,6 @@ class SimpleTrainer:
|
|
112
110
|
|
113
111
|
|
114
112
|
if wandb_config is not None and jax.process_index() == 0:
|
115
|
-
import wandb
|
116
113
|
run = wandb.init(**wandb_config)
|
117
114
|
self.wandb = run
|
118
115
|
|
@@ -126,11 +123,6 @@ class SimpleTrainer:
|
|
126
123
|
self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
|
127
124
|
self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
|
128
125
|
self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
|
129
|
-
|
130
|
-
if checkpoint_id is None:
|
131
|
-
self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
|
132
|
-
else:
|
133
|
-
self.checkpoint_id = checkpoint_id
|
134
126
|
|
135
127
|
# checkpointer = orbax.checkpoint.PyTreeCheckpointer()
|
136
128
|
async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
|
@@ -140,12 +132,12 @@ class SimpleTrainer:
|
|
140
132
|
self.checkpointer = orbax.checkpoint.CheckpointManager(
|
141
133
|
self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
|
142
134
|
|
143
|
-
if load_from_checkpoint:
|
144
|
-
latest_epoch, old_state, old_best_state, rngstate = self.load()
|
135
|
+
if load_from_checkpoint is not None:
|
136
|
+
latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
|
145
137
|
else:
|
146
|
-
latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None
|
138
|
+
latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
|
147
139
|
|
148
|
-
self.
|
140
|
+
self.latest_step = latest_step
|
149
141
|
|
150
142
|
if rngstate:
|
151
143
|
self.rngstate = RandomMarkovState(**rngstate)
|
@@ -156,7 +148,7 @@ class SimpleTrainer:
|
|
156
148
|
|
157
149
|
if train_state == None:
|
158
150
|
state, best_state = self.generate_states(
|
159
|
-
optimizer, subkey, old_state, old_best_state, model, param_transforms
|
151
|
+
optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
|
160
152
|
)
|
161
153
|
self.init_state(state, best_state)
|
162
154
|
else:
|
@@ -174,7 +166,8 @@ class SimpleTrainer:
|
|
174
166
|
existing_state: dict = None,
|
175
167
|
existing_best_state: dict = None,
|
176
168
|
model: nn.Module = None,
|
177
|
-
param_transforms: Callable = None
|
169
|
+
param_transforms: Callable = None,
|
170
|
+
use_dynamic_scale: bool = False
|
178
171
|
) -> Tuple[SimpleTrainState, SimpleTrainState]:
|
179
172
|
print("Generating states for SimpleTrainer")
|
180
173
|
rngs, subkey = jax.random.split(rngs)
|
@@ -184,12 +177,16 @@ class SimpleTrainer:
|
|
184
177
|
params = model.init(subkey, **input_vars)
|
185
178
|
else:
|
186
179
|
params = existing_state['params']
|
180
|
+
|
181
|
+
if param_transforms is not None:
|
182
|
+
params = param_transforms(params)
|
187
183
|
|
188
184
|
state = SimpleTrainState.create(
|
189
185
|
apply_fn=model.apply,
|
190
186
|
params=params,
|
191
187
|
tx=optimizer,
|
192
|
-
metrics=Metrics.empty()
|
188
|
+
metrics=Metrics.empty(),
|
189
|
+
dynamic_scale = DynamicScale() if use_dynamic_scale else None
|
193
190
|
)
|
194
191
|
if existing_best_state is not None:
|
195
192
|
best_state = state.replace(
|
@@ -222,7 +219,7 @@ class SimpleTrainer:
|
|
222
219
|
return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
|
223
220
|
|
224
221
|
def checkpoint_path(self):
|
225
|
-
path = os.path.join(self.checkpoint_base_path, self.
|
222
|
+
path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
|
226
223
|
if not os.path.exists(path):
|
227
224
|
os.makedirs(path)
|
228
225
|
return path
|
@@ -234,31 +231,46 @@ class SimpleTrainer:
|
|
234
231
|
os.makedirs(path)
|
235
232
|
return path
|
236
233
|
|
237
|
-
def load(self):
|
238
|
-
|
239
|
-
|
240
|
-
|
234
|
+
def load(self, checkpoint_path=None, checkpoint_step=None):
|
235
|
+
if checkpoint_path is None:
|
236
|
+
checkpointer = self.checkpointer
|
237
|
+
else:
|
238
|
+
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
|
239
|
+
options = orbax.checkpoint.CheckpointManagerOptions(
|
240
|
+
max_to_keep=4, create=False)
|
241
|
+
checkpointer = orbax.checkpoint.CheckpointManager(
|
242
|
+
checkpoint_path, checkpointer, options)
|
243
|
+
|
244
|
+
if checkpoint_step is None:
|
245
|
+
step = checkpointer.latest_step()
|
246
|
+
else:
|
247
|
+
step = checkpoint_step
|
248
|
+
|
249
|
+
print("Loading model from checkpoint at step ", step)
|
250
|
+
ckpt = checkpointer.restore(step)
|
241
251
|
state = ckpt['state']
|
242
252
|
best_state = ckpt['best_state']
|
243
253
|
rngstate = ckpt['rngs']
|
244
254
|
# Convert the state to a TrainState
|
245
255
|
self.best_loss = ckpt['best_loss']
|
256
|
+
current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
|
246
257
|
print(
|
247
|
-
f"Loaded model from checkpoint at epoch {
|
248
|
-
return
|
258
|
+
f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
|
259
|
+
return current_epoch, step, state, best_state, rngstate
|
249
260
|
|
250
|
-
def save(self, epoch=0):
|
251
|
-
print(f"Saving model at epoch {epoch}")
|
261
|
+
def save(self, epoch=0, step=0):
|
262
|
+
print(f"Saving model at epoch {epoch} step {step}")
|
252
263
|
ckpt = {
|
253
264
|
# 'model': self.model,
|
254
265
|
'rngs': self.get_rngstate(),
|
255
266
|
'state': self.get_state(),
|
256
267
|
'best_state': self.get_best_state(),
|
257
268
|
'best_loss': np.array(self.best_loss),
|
269
|
+
'epoch': epoch,
|
258
270
|
}
|
259
271
|
try:
|
260
272
|
save_args = orbax_utils.save_args_from_target(ckpt)
|
261
|
-
self.checkpointer.save(
|
273
|
+
self.checkpointer.save(step, ckpt, save_kwargs={
|
262
274
|
'save_args': save_args}, force=True)
|
263
275
|
self.checkpointer.wait_until_finished()
|
264
276
|
pass
|
@@ -350,9 +362,10 @@ class SimpleTrainer:
|
|
350
362
|
else:
|
351
363
|
global_device_indexes = 0
|
352
364
|
|
353
|
-
def train_loop(
|
365
|
+
def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state):
|
354
366
|
epoch_loss = 0
|
355
|
-
|
367
|
+
current_epoch = current_step // steps_per_epoch
|
368
|
+
last_save_time = time.time()
|
356
369
|
for i in range(steps_per_epoch):
|
357
370
|
batch = next(train_ds)
|
358
371
|
if self.distributed_training and global_device_count > 1:
|
@@ -363,36 +376,46 @@ class SimpleTrainer:
|
|
363
376
|
if self.distributed_training:
|
364
377
|
loss = jax.experimental.multihost_utils.process_allgather(loss)
|
365
378
|
loss = jnp.mean(loss) # Just to make sure its a scaler value
|
379
|
+
|
380
|
+
if loss <= 1e-6:
|
381
|
+
# If the loss is too low, we can assume the model has diverged
|
382
|
+
print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
|
383
|
+
# Exit the training loop
|
384
|
+
exit(1)
|
366
385
|
|
367
386
|
epoch_loss += loss
|
368
|
-
|
369
|
-
if
|
370
|
-
if
|
387
|
+
current_step += 1
|
388
|
+
if i % 100 == 0:
|
389
|
+
if pbar is not None:
|
371
390
|
pbar.set_postfix(loss=f'{loss:.4f}')
|
372
391
|
pbar.update(100)
|
373
|
-
current_step = current_epoch*steps_per_epoch + i
|
374
392
|
if self.wandb is not None:
|
375
393
|
self.wandb.log({
|
376
394
|
"train/step" : current_step,
|
377
395
|
"train/loss": loss,
|
378
396
|
}, step=current_step)
|
397
|
+
# Save the model every 40 minutes
|
398
|
+
if time.time() - last_save_time > 40 * 60:
|
399
|
+
print(f"Saving model after 40 minutes at step {current_step}")
|
400
|
+
self.save(current_epoch, current_step)
|
401
|
+
last_save_time = time.time()
|
379
402
|
print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
|
380
403
|
return epoch_loss, current_step, train_state, rng_state
|
381
404
|
|
382
|
-
while self.
|
383
|
-
current_epoch = self.
|
384
|
-
self.latest_epoch += 1
|
405
|
+
while self.latest_step < epochs * steps_per_epoch:
|
406
|
+
current_epoch = self.latest_step // steps_per_epoch
|
385
407
|
print(f"\nEpoch {current_epoch}/{epochs}")
|
386
408
|
start_time = time.time()
|
387
409
|
epoch_loss = 0
|
388
410
|
|
389
411
|
if process_index == 0:
|
390
412
|
with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
|
391
|
-
epoch_loss, current_step, train_state, rng_state = train_loop(
|
413
|
+
epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, pbar, train_state, rng_state)
|
392
414
|
else:
|
393
|
-
epoch_loss, current_step, train_state, rng_state = train_loop(
|
394
|
-
print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP
|
395
|
-
|
415
|
+
epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, None, train_state, rng_state)
|
416
|
+
print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
|
417
|
+
|
418
|
+
self.latest_step = current_step
|
396
419
|
end_time = time.time()
|
397
420
|
self.state = train_state
|
398
421
|
self.rngstate = rng_state
|
@@ -402,7 +425,7 @@ class SimpleTrainer:
|
|
402
425
|
if avg_loss < self.best_loss:
|
403
426
|
self.best_loss = avg_loss
|
404
427
|
self.best_state = train_state
|
405
|
-
self.save(current_epoch)
|
428
|
+
self.save(current_epoch, current_step)
|
406
429
|
|
407
430
|
if process_index == 0:
|
408
431
|
if self.wandb is not None:
|
@@ -415,4 +438,4 @@ class SimpleTrainer:
|
|
415
438
|
}, step=current_step)
|
416
439
|
print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
|
417
440
|
self.save(epochs)
|
418
|
-
return self.state
|
441
|
+
return self.state
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: flaxdiff
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.10
|
4
4
|
Summary: A versatile and easy to understand Diffusion library
|
5
5
|
Author: Ashish Kumar Singh
|
6
6
|
Author-email: ashishkmr472@gmail.com
|
@@ -234,6 +234,23 @@ plotImages(samples, dpi=300)
|
|
234
234
|
|
235
235
|
## Gallery
|
236
236
|
|
237
|
+
### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
|
238
|
+
Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
|
239
|
+
`a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden`
|
240
|
+
|
241
|
+
**Params**:
|
242
|
+
`Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M`
|
243
|
+
`Batch size: 256`
|
244
|
+
`Image Size: 128`
|
245
|
+
`Training Epochs: 5`
|
246
|
+
`Steps per epoch: 74573`
|
247
|
+
`Model Configurations: feature_depths=[128, 256, 512, 1024]`
|
248
|
+
|
249
|
+
`Training Noise Schedule: EDMNoiseScheduler`
|
250
|
+
`Inference Noise Schedule: KarrasEDMPredictor`
|
251
|
+
|
252
|
+

|
253
|
+
|
237
254
|
### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
|
238
255
|
Images generated by the following prompts using classifier free guidance with guidance factor = 2:
|
239
256
|
`'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`
|
@@ -1,14 +1,14 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
4
|
-
flaxdiff/models/attention.py,sha256=
|
5
|
-
flaxdiff/models/common.py,sha256=
|
4
|
+
flaxdiff/models/attention.py,sha256=YyVI3dTAMB8cS8VWHgtIigr2YY-MYfFTlaNDfjNJOCk,12596
|
5
|
+
flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
|
6
6
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
7
|
-
flaxdiff/models/simple_unet.py,sha256=
|
8
|
-
flaxdiff/models/simple_vit.py,sha256=
|
7
|
+
flaxdiff/models/simple_unet.py,sha256=H67Pfy8BqKHvhdw_K3lBiFdruNQFBMElw8SDZdvg9Ec,10084
|
8
|
+
flaxdiff/models/simple_vit.py,sha256=xD23i1b7WEvoH4tUMsLyCe9ebDcv-PpaV0Nso38Jlb8,3887
|
9
9
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
10
10
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
11
|
-
flaxdiff/models/autoencoder/diffusers.py,sha256=
|
11
|
+
flaxdiff/models/autoencoder/diffusers.py,sha256=l4teVksXd9XCCQWcVn9eB820xJyLT8hpg1CXQ_aHZ6M,3611
|
12
12
|
flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
|
13
13
|
flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
|
14
14
|
flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
|
@@ -30,9 +30,9 @@ flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k
|
|
30
30
|
flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
|
31
31
|
flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
|
32
32
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
33
|
-
flaxdiff/trainer/diffusion_trainer.py,sha256=
|
34
|
-
flaxdiff/trainer/simple_trainer.py,sha256=
|
35
|
-
flaxdiff-0.1.
|
36
|
-
flaxdiff-0.1.
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
33
|
+
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
34
|
+
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
35
|
+
flaxdiff-0.1.10.dist-info/METADATA,sha256=q9O56jlhtuznnbmlHeKa9-gLFtWXge0bwBU6g9_P8Jk,22083
|
36
|
+
flaxdiff-0.1.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
37
|
+
flaxdiff-0.1.10.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
38
|
+
flaxdiff-0.1.10.dist-info/RECORD,,
|
File without changes
|
File without changes
|