flaxdiff 0.1.8__tar.gz → 0.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/PKG-INFO +18 -1
  2. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/README.md +17 -0
  3. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/attention.py +7 -5
  4. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/autoencoder/diffusers.py +1 -1
  5. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/common.py +12 -2
  6. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/simple_unet.py +17 -7
  7. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/simple_vit.py +13 -16
  8. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/trainer/diffusion_trainer.py +41 -11
  9. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/trainer/simple_trainer.py +80 -60
  10. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff.egg-info/PKG-INFO +18 -1
  11. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/setup.py +1 -1
  12. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/__init__.py +0 -0
  13. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/__init__.py +0 -0
  14. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/autoencoder/__init__.py +0 -0
  15. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  16. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  17. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/favor_fastattn.py +0 -0
  18. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/utils.py +0 -0
  39. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff.egg-info/SOURCES.txt +0 -0
  40. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff.egg-info/dependency_links.txt +0 -0
  41. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff.egg-info/requires.txt +0 -0
  42. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff.egg-info/top_level.txt +0 -0
  43. {flaxdiff-0.1.8 → flaxdiff-0.1.9}/setup.cfg +0 -0
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.8
+Version: 0.1.9
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
@@ -234,6 +234,23 @@ plotImages(samples, dpi=300)
 
 ## Gallery
 
+### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
+Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
+`a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden`
+
+**Params**:
+`Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M`
+`Batch size: 256`
+`Image Size: 128`
+`Training Epochs: 5`
+`Steps per epoch: 74573`
+`Model Configurations: feature_depths=[128, 256, 512, 1024]`
+
+`Training Noise Schedule: EDMNoiseScheduler`
+`Inference Noise Schedule: KarrasEDMPredictor`
+
+![EulerA with CFG](images/medium_epoch5.png)
+
 ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
 Images generated by the following prompts using classifier free guidance with guidance factor = 2:
 `'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/README.md
@@ -221,6 +221,23 @@ plotImages(samples, dpi=300)
 
 ## Gallery
 
+### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
+Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
+`a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden`
+
+**Params**:
+`Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M`
+`Batch size: 256`
+`Image Size: 128`
+`Training Epochs: 5`
+`Steps per epoch: 74573`
+`Model Configurations: feature_depths=[128, 256, 512, 1024]`
+
+`Training Noise Schedule: EDMNoiseScheduler`
+`Inference Noise Schedule: KarrasEDMPredictor`
+
+![EulerA with CFG](images/medium_epoch5.png)
+
 ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
 Images generated by the following prompts using classifier free guidance with guidance factor = 2:
 `'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/attention.py
@@ -156,7 +156,9 @@ class NormalAttention(nn.Module):
         value = self.value(context)
 
         hidden_states = nn.dot_product_attention(
-            query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
+            query, key, value, dtype=self.dtype, broadcast_dropout=False,
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
+            deterministic=True
         )
         proj = self.proj_attn(hidden_states)
         proj = proj.reshape(orig_x_shape)
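The notable addition here is `force_fp32_for_softmax=True`: when the module runs in bf16/fp16, the softmax over the attention logits is computed in float32 to avoid overflow, then cast back. A minimal, self-contained sketch of the same call, assuming a Flax release recent enough to expose this flag on `nn.dot_product_attention`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Toy q/k/v in bfloat16, shaped (batch, seq_len, heads, head_dim).
q, k, v = (
    jax.random.normal(jax.random.PRNGKey(i), (1, 16, 8, 64), dtype=jnp.bfloat16)
    for i in range(3)
)

# The softmax runs in float32 internally and the result is cast back to
# bfloat16, which keeps mixed-precision attention numerically stable.
out = nn.dot_product_attention(
    q, k, v,
    dtype=jnp.bfloat16,
    broadcast_dropout=False,
    dropout_rng=None,
    deterministic=True,
    force_fp32_for_softmax=True,
)
print(out.shape, out.dtype)  # (1, 16, 8, 64) bfloat16
```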
@@ -187,7 +189,7 @@ class FlaxGEGLU(nn.Module):
 
     def __call__(self, hidden_states):
         hidden_states = self.proj(hidden_states)
-        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=3)
+        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
         return hidden_linear * nn.gelu(hidden_gelu)
 
 class FlaxFeedForward(nn.Module):
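For reference, the GEGLU gating this one-character fix generalizes: splitting on the last axis works for both 4-D `(B, H, W, 2*D)` feature maps and 3-D `(B, T, 2*D)` token sequences, whereas the hard-coded `axis=3` assumed 4-D input. A standalone sketch:

```python
import jax.numpy as jnp
from flax import linen as nn

def geglu(hidden_states: jnp.ndarray) -> jnp.ndarray:
    # The projection carries 2*D features; half is the value, half the gate.
    hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
    return hidden_linear * nn.gelu(hidden_gelu)

tokens = jnp.ones((2, 77, 512))   # (B, T, 2*D) — would break with axis=3
print(geglu(tokens).shape)        # (2, 77, 256)
```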
@@ -291,14 +293,14 @@ class TransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_projection: bool = False
-    use_flash_attention:bool = True
-    use_self_and_cross:bool = False
+    use_flash_attention:bool = False
+    use_self_and_cross:bool = True
     only_pure_attention:bool = False
 
     @nn.compact
     def __call__(self, x, context=None):
         inner_dim = self.heads * self.dim_head
-        B, H, W, C = x.shape
+        C = x.shape[-1]
         normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
         if self.use_projection == True:
             if self.use_linear_attention:
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/autoencoder/diffusers.py
@@ -78,7 +78,7 @@ class StableDiffusionVAE(AutoEncoder):
             log_std = jnp.clip(log_std, -30, 20)
             std = jnp.exp(0.5 * log_std)
             latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
-            print("Sampled")
+            # print("Sampled")
         else:
             # return the mean
             latents, _ = jnp.split(latents, 2, axis=-1)
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/common.py
@@ -5,6 +5,7 @@ from typing import Optional, Any, Callable, Sequence, Union
 from flax.typing import Dtype, PrecisionLike
 from typing import Dict, Callable, Sequence, Any, Union
 import einops
+from functools import partial
 
 # Kernel initializer to use
 def kernel_init(scale, dtype=jnp.float32):
@@ -266,11 +267,20 @@ class ResidualBlock(nn.Module):
     kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+
+    def setup(self):
+        if self.norm_groups > 0:
+            norm = partial(nn.GroupNorm, self.norm_groups)
+        else:
+            norm = partial(nn.RMSNorm, 1e-5)
+
+        self.norm1 = norm()
+        self.norm2 = norm()
 
     @nn.compact
     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
         residual = x
-        out = nn.GroupNorm(self.norm_groups)(x)
+        out = self.norm1(x)
         # out = nn.RMSNorm()(x)
         out = self.activation(out)
@@ -295,7 +305,7 @@ class ResidualBlock(nn.Module):
         # out = out * (1 + scale) + shift
         out = out + temb
 
-        out = nn.GroupNorm(self.norm_groups)(out)
+        out = self.norm2(out)
         # out = nn.RMSNorm()(out)
         out = self.activation(out)
 
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/simple_unet.py
@@ -6,6 +6,7 @@ from typing import Dict, Callable, Sequence, Any, Union, Optional
 import einops
 from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
 from .attention import TransformerBlock
+from functools import partial
 
 class Unet(nn.Module):
     output_channels:int=3
@@ -19,6 +20,15 @@ class Unet(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
 
+    def setup(self):
+        if self.norm_groups > 0:
+            norm = partial(nn.GroupNorm, self.norm_groups)
+        else:
+            norm = partial(nn.RMSNorm, 1e-5)
+
+        # self.last_up_norm = norm()
+        self.conv_out_norm = norm()
+
     @nn.compact
     def __call__(self, x, temb, textcontext):
         # print("embedding features", self.emb_features)
@@ -69,7 +79,7 @@ class Unet(nn.Module):
                     use_projection=attention_config.get("use_projection", False),
                     use_self_and_cross=attention_config.get("use_self_and_cross", True),
                     precision=attention_config.get("precision", self.precision),
-                    only_pure_attention=True,
+                    only_pure_attention=attention_config.get("only_pure_attention", True),
                     name=f"down_{i}_attention_{j}")(x, textcontext)
             # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
             downs.append(x)
@@ -107,8 +117,8 @@ class Unet(nn.Module):
                 use_linear_attention=False,
                 use_projection=middle_attention.get("use_projection", False),
                 use_self_and_cross=False,
-                precision=attention_config.get("precision", self.precision),
-                only_pure_attention=True,
+                precision=middle_attention.get("precision", self.precision),
+                only_pure_attention=middle_attention.get("only_pure_attention", True),
                 name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
@@ -150,7 +160,7 @@ class Unet(nn.Module):
                     use_projection=attention_config.get("use_projection", False),
                     use_self_and_cross=attention_config.get("use_self_and_cross", True),
                     precision=attention_config.get("precision", self.precision),
-                    only_pure_attention=True,
+                    only_pure_attention=attention_config.get("only_pure_attention", True),
                     name=f"up_{i}_attention_{j}")(x, textcontext)
             # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
@@ -163,13 +173,13 @@ class Unet(nn.Module):
                     precision=self.precision
                 )(x)
 
-        # x = nn.GroupNorm(8)(x)
+        # x = self.last_up_norm(x)
         x = ConvLayer(
             conv_type,
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(0.0),
+            kernel_init=kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -189,7 +199,7 @@ class Unet(nn.Module):
             precision=self.precision
         )(x, temb)
 
-        x = nn.GroupNorm(self.norm_groups)(x)
+        x = self.conv_out_norm(x)
         x = self.activation(x)
 
         noise_out = ConvLayer(
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/models/simple_vit.py
@@ -4,7 +4,7 @@ import jax
 import jax.numpy as jnp
 from flax import linen as nn
 from typing import Callable, Any
-from .simply_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
+from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
 from .attention import TransformerBlock
 
 class PatchEmbedding(nn.Module):
@@ -40,22 +40,23 @@ class PositionalEncoding(nn.Module):
 class TransformerEncoder(nn.Module):
     num_layers: int
     num_heads: int
-    mlp_dim: int
     dropout_rate: float = 0.1
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
+    use_projection: bool = False
 
     @nn.compact
-    def __call__(self, x, training=True):
+    def __call__(self, x, context=None):
         for _ in range(self.num_layers):
             x = TransformerBlock(
                 heads=self.num_heads,
                 dim_head=x.shape[-1] // self.num_heads,
-                mlp_dim=self.mlp_dim,
                 dropout_rate=self.dropout_rate,
                 dtype=self.dtype,
-                precision=self.precision
-            )(x)
+                precision=self.precision,
+                use_self_and_cross=True,
+                use_projection=self.use_projection,
+            )(x, context)
         return x
 
 class VisionTransformer(nn.Module):
@@ -63,11 +64,11 @@ class VisionTransformer(nn.Module):
     embedding_dim: int = 768
     num_layers: int = 12
     num_heads: int = 12
-    mlp_dim: int = 3072
     emb_features: int = 256
     dropout_rate: float = 0.1
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
+    use_projection: bool = False
 
     @nn.compact
     def __call__(self, x, temb, textcontext=None):
@@ -81,27 +82,23 @@ class VisionTransformer(nn.Module):
 
         # Add positional encoding
         x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x)
+
+        num_patches = x.shape[1]
 
         # Add time embedding
         temb = jnp.expand_dims(temb, axis=1)
         x = jnp.concatenate([x, temb], axis=1)
 
-        # Add text context
-        if textcontext is not None:
-            x = jnp.concatenate([x, textcontext], axis=1)
-
         # Transformer encoder
         x = TransformerEncoder(
            num_layers=self.num_layers,
            num_heads=self.num_heads,
-           mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
-           precision=self.precision
-        )(x)
+           precision=self.precision,
+           use_projection=self.use_projection
+        )(x, textcontext)
 
-        # Extract the image tokens (exclude time and text embeddings)
-        num_patches = (x.shape[1] - 1 - (0 if textcontext is None else textcontext.shape[1]))
        x = x[:, :num_patches, :]
 
         # Reshape to image dimensions
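Two related simplifications in this hunk: the text context is now fed to the encoder as cross-attention context instead of being concatenated as extra tokens, and the patch count is recorded before the time token is appended, so slicing the image tokens back out no longer has to account for a text-token offset. A toy illustration of that bookkeeping:

```python
import jax.numpy as jnp

patches = jnp.ones((2, 64, 768))   # (B, num_patches, D) after patch embedding
temb = jnp.ones((2, 1, 768))       # time embedding appended as one extra token
num_patches = patches.shape[1]     # recorded BEFORE the concat

x = jnp.concatenate([patches, temb], axis=1)  # (2, 65, 768)
# ... transformer encoder runs here; text goes in via cross-attention ...
x = x[:, :num_patches, :]                     # back to (2, 64, 768)
print(x.shape)
```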
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/trainer/diffusion_trainer.py
@@ -29,6 +29,8 @@ class TrainState(SimpleTrainState):
         )
         return self.replace(ema_params=new_ema_params)
 
+from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+
 class DiffusionTrainer(SimpleTrainer):
     noise_schedule: NoiseScheduler
     model_output_transform: DiffusionPredictionTransform
@@ -40,7 +42,7 @@ class DiffusionTrainer(SimpleTrainer):
                  optimizer: optax.GradientTransformation,
                  noise_schedule: NoiseScheduler,
                  rngs: jax.random.PRNGKey,
-                 unconditional_prob: float = 0.2,
+                 unconditional_prob: float = 0.12,
                  name: str = "Diffusion",
                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
                  autoencoder: AutoEncoder = None,
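`unconditional_prob` is the fraction of training samples whose text conditioning is dropped (replaced with the null/empty embedding), which is what enables classifier-free guidance at sampling time; the default tightens from 0.2 to 0.12. An illustrative sketch of the kind of dropout mask such a probability drives (names here are hypothetical, not the trainer's internals):

```python
import jax

def drop_conditioning_mask(rng, batch_size, unconditional_prob=0.12):
    # True -> replace that sample's text embedding with the null embedding.
    return jax.random.bernoulli(rng, p=unconditional_prob, shape=(batch_size,))

mask = drop_conditioning_mask(jax.random.PRNGKey(0), 256)
print(mask.mean())  # ~0.12 in expectation
```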
@@ -67,7 +69,8 @@ class DiffusionTrainer(SimpleTrainer):
                         existing_state: dict = None,
                         existing_best_state: dict = None,
                         model: nn.Module = None,
-                        param_transforms: Callable = None
+                        param_transforms: Callable = None,
+                        use_dynamic_scale: bool = False
                         ) -> Tuple[TrainState, TrainState]:
         print("Generating states for DiffusionTrainer")
         rngs, subkey = jax.random.split(rngs)
@@ -88,7 +91,8 @@ class DiffusionTrainer(SimpleTrainer):
             ema_params=new_state['ema_params'],
             tx=optimizer,
             rngs=rngs,
-            metrics=Metrics.empty()
+            metrics=Metrics.empty(),
+            dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
         )
 
         if existing_best_state is not None:
@@ -125,14 +129,14 @@ class DiffusionTrainer(SimpleTrainer):
             local_rng_state = RandomMarkovState(subkey)
 
             images = batch['image']
+            images = jnp.array(images, dtype=jnp.float32)
+            # normalize image
+            images = (images - 127.5) / 127.5
 
             if autoencoder is not None:
                 # Convert the images to latent space
                 local_rng_state, rngs = local_rng_state.get_random_key()
                 images = autoencoder.encode(images, rngs)
-            else:
-                # normalize image
-                images = (images - 127.5) / 127.5
 
             output = text_embedder(
                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
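Normalization now happens unconditionally before the (optional) autoencoder encode, so latent-space and pixel-space training both see the same input range. A quick check of the mapping (uint8 [0, 255] to float32 [-1, 1]):

```python
import jax.numpy as jnp

images = jnp.array([0.0, 127.5, 255.0])
print((images - 127.5) / 127.5)  # [-1.  0.  1.]
```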
@@ -163,12 +167,39 @@ class DiffusionTrainer(SimpleTrainer):
                 loss = nloss
                 return loss
 
-            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
+
+            if train_state.dynamic_scale is not None:
+                # dynamic scale takes care of averaging gradients across replicas
+                grad_fn = train_state.dynamic_scale.value_and_grad(
+                    model_loss, axis_name="data"
+                )
+                dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
+                train_state = train_state.replace(dynamic_scale=dynamic_scale)
+            else:
+                grad_fn = jax.value_and_grad(model_loss)
+                loss, grads = grad_fn(train_state.params)
+                if distributed_training:
+                    grads = jax.lax.pmean(grads, "data")
+
+            new_state = train_state.apply_gradients(grads=grads)
+
+            if train_state.dynamic_scale:
+                # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
+                # params should be restored (= skip this step).
+                select_fn = functools.partial(jnp.where, is_fin)
+                new_state = train_state.replace(
+                    opt_state=jax.tree_util.tree_map(
+                        select_fn, new_state.opt_state, train_state.opt_state
+                    ),
+                    params=jax.tree_util.tree_map(
+                        select_fn, new_state.params, train_state.params
+                    ),
+                )
+
+            train_state = new_state.apply_ema(self.ema_decay)
+
             if distributed_training:
-                grads = jax.lax.pmean(grads, "data")
                 loss = jax.lax.pmean(loss, "data")
-            train_state = train_state.apply_gradients(grads=grads)
-            train_state = train_state.apply_ema(self.ema_decay)
             return train_state, loss, rng_state
 
         if distributed_training:
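The rewritten step adds optional mixed-precision loss scaling via `flax.training.dynamic_scale.DynamicScale`, committing the parameter/optimizer update only when every gradient entry is finite. A self-contained sketch of the same pattern with a plain optax optimizer (names are illustrative, not the trainer's API):

```python
import functools

import jax
import jax.numpy as jnp
import optax
from flax.training import dynamic_scale as dynamic_scale_lib

def loss_fn(params, x, y):
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((4, 1))}
opt = optax.adam(1e-3)
opt_state = opt.init(params)
dyn_scale = dynamic_scale_lib.DynamicScale()

x, y = jnp.ones((8, 4)), jnp.zeros((8, 1))

# value_and_grad scales the loss up before differentiation, scales the
# grads back down, and reports whether every gradient entry is finite.
dyn_scale, is_fin, loss, grads = dyn_scale.value_and_grad(loss_fn)(params, x, y)

updates, new_opt_state = opt.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)

# Keep the update only if the grads were finite; otherwise retry the step
# at the lower scale DynamicScale just chose.
select = functools.partial(jnp.where, is_fin)
params = jax.tree_util.tree_map(select, new_params, params)
opt_state = jax.tree_util.tree_map(select, new_opt_state, opt_state)
```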
@@ -199,4 +230,3 @@ def boolean_string(s):
     if type(s) == bool:
         return s
     return s == 'True'
-
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff/trainer/simple_trainer.py
@@ -39,23 +39,23 @@ PROCESS_COLOR_MAP = {
 def _build_global_shape_and_sharding(
     local_shape: tuple[int, ...], global_mesh: Mesh
 ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
-  sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
-  global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
-  return global_shape, sharding
+    sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+    global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+    return global_shape, sharding
 
 
 def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
-  """Put local sharded array into local devices"""
-  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
-  try:
-    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
-  except ValueError as array_split_error:
-    raise ValueError(
-      f"Unable to put to devices shape {array.shape} with "
-      f"local device count {len(global_mesh.local_devices)} "
-    ) from array_split_error
-  local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
-  return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
+    """Put local sharded array into local devices"""
+    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+    try:
+        local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+    except ValueError as array_split_error:
+        raise ValueError(
+            f"Unable to put to devices shape {array.shape} with "
+            f"local device count {len(global_mesh.local_devices)} "
+        ) from array_split_error
+    local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+    return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
 
 def convert_to_global_tree(global_mesh, pytree):
     return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
@@ -67,12 +67,8 @@ class Metrics(metrics.Collection):
 
 # Define the TrainState
 class SimpleTrainState(train_state.TrainState):
-    rngs: jax.random.PRNGKey
     metrics: Metrics
-
-    def get_random_key(self):
-        rngs, subkey = jax.random.split(self.rngs)
-        return self.replace(rngs=rngs), subkey
+    dynamic_scale: flax.training.dynamic_scale.DynamicScale
 
 class SimpleTrainer:
     state: SimpleTrainState
@@ -88,20 +84,22 @@ class SimpleTrainer:
                  rngs: jax.random.PRNGKey,
                  train_state: SimpleTrainState = None,
                  name: str = "Simple",
-                 load_from_checkpoint: bool = False,
+                 load_from_checkpoint: str = None,
                  checkpoint_suffix: str = "",
-                 checkpoint_id: str = None,
                  loss_fn=optax.l2_loss,
                  param_transforms: Callable = None,
                  wandb_config: Dict[str, Any] = None,
                  distributed_training: bool = None,
                  checkpoint_base_path: str = "./checkpoints",
+                 checkpoint_step: int = None,
+                 use_dynamic_scale: bool = False,
                  ):
         if distributed_training is None or distributed_training is True:
             # Auto-detect if we are running on multiple devices
             distributed_training = jax.device_count() > 1
             self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
-            # self.sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('data'))
+        else:
+            self.mesh = None
 
         self.distributed_training = distributed_training
         self.model = model
@@ -112,7 +110,6 @@ class SimpleTrainer:
 
 
         if wandb_config is not None and jax.process_index() == 0:
-            import wandb
             run = wandb.init(**wandb_config)
             self.wandb = run
 
@@ -126,11 +123,6 @@ class SimpleTrainer:
             self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
             self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
             self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
-
-        if checkpoint_id is None:
-            self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
-        else:
-            self.checkpoint_id = checkpoint_id
 
         # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
         async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
@@ -140,12 +132,12 @@ class SimpleTrainer:
         self.checkpointer = orbax.checkpoint.CheckpointManager(
             self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
 
-        if load_from_checkpoint:
-            latest_epoch, old_state, old_best_state, rngstate = self.load()
+        if load_from_checkpoint is not None:
+            latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
         else:
-            latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None
+            latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
 
-        self.latest_epoch = latest_epoch
+        self.latest_step = latest_step
 
         if rngstate:
             self.rngstate = RandomMarkovState(**rngstate)
@@ -156,7 +148,7 @@ class SimpleTrainer:
 
         if train_state == None:
             state, best_state = self.generate_states(
-                optimizer, subkey, old_state, old_best_state, model, param_transforms
+                optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
             )
             self.init_state(state, best_state)
         else:
@@ -174,7 +166,8 @@ class SimpleTrainer:
                         existing_state: dict = None,
                         existing_best_state: dict = None,
                         model: nn.Module = None,
-                        param_transforms: Callable = None
+                        param_transforms: Callable = None,
+                        use_dynamic_scale: bool = False
                         ) -> Tuple[SimpleTrainState, SimpleTrainState]:
         print("Generating states for SimpleTrainer")
         rngs, subkey = jax.random.split(rngs)
@@ -189,7 +182,8 @@ class SimpleTrainer:
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
-            metrics=Metrics.empty()
+            metrics=Metrics.empty(),
+            dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
         )
         if existing_best_state is not None:
             best_state = state.replace(
@@ -222,7 +216,7 @@ class SimpleTrainer:
         return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
 
     def checkpoint_path(self):
-        path = os.path.join(self.checkpoint_base_path, self.checkpoint_id)
+        path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
         if not os.path.exists(path):
             os.makedirs(path)
         return path
@@ -234,31 +228,46 @@ class SimpleTrainer:
             os.makedirs(path)
         return path
 
-    def load(self):
-        epoch = self.checkpointer.latest_step()
-        print("Loading model from checkpoint", epoch)
-        ckpt = self.checkpointer.restore(epoch)
+    def load(self, checkpoint_path=None, checkpoint_step=None):
+        if checkpoint_path is None:
+            checkpointer = self.checkpointer
+        else:
+            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+            options = orbax.checkpoint.CheckpointManagerOptions(
+                max_to_keep=4, create=False)
+            checkpointer = orbax.checkpoint.CheckpointManager(
+                checkpoint_path, checkpointer, options)
+
+        if checkpoint_step is None:
+            step = checkpointer.latest_step()
+        else:
+            step = checkpoint_step
+
+        print("Loading model from checkpoint at step ", step)
+        ckpt = checkpointer.restore(step)
         state = ckpt['state']
         best_state = ckpt['best_state']
         rngstate = ckpt['rngs']
         # Convert the state to a TrainState
         self.best_loss = ckpt['best_loss']
+        current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
         print(
-            f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
-        return epoch, state, best_state, rngstate
+            f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
+        return current_epoch, step, state, best_state, rngstate
 
-    def save(self, epoch=0):
-        print(f"Saving model at epoch {epoch}")
+    def save(self, epoch=0, step=0):
+        print(f"Saving model at epoch {epoch} step {step}")
         ckpt = {
             # 'model': self.model,
             'rngs': self.get_rngstate(),
             'state': self.get_state(),
             'best_state': self.get_best_state(),
             'best_loss': np.array(self.best_loss),
+            'epoch': epoch,
         }
         try:
             save_args = orbax_utils.save_args_from_target(ckpt)
-            self.checkpointer.save(epoch, ckpt, save_kwargs={
+            self.checkpointer.save(step, ckpt, save_kwargs={
                 'save_args': save_args}, force=True)
             self.checkpointer.wait_until_finished()
             pass
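`load()` can now open an arbitrary checkpoint directory read-only (`create=False`) and restore either the latest step or a pinned `checkpoint_step`, falling back to the step number when the checkpoint predates the stored `'epoch'` key. A minimal sketch against the legacy `orbax.checkpoint` API used above (`/tmp/ckpts` is a placeholder path):

```python
import orbax.checkpoint

# Open an existing checkpoint directory read-only and restore from it.
manager = orbax.checkpoint.CheckpointManager(
    "/tmp/ckpts",  # placeholder
    orbax.checkpoint.PyTreeCheckpointer(),
    orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=False),
)

step = manager.latest_step()  # or an explicit int to pin a specific step
ckpt = manager.restore(step)

# Checkpoints written before this change stored no 'epoch' key, hence
# the ckpt.get('epoch', step) fallback in load().
epoch = ckpt.get("epoch", step)
```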
@@ -350,9 +359,10 @@ class SimpleTrainer:
         else:
             global_device_indexes = 0
 
-        def train_loop(current_epoch, pbar: tqdm.tqdm, train_state, rng_state):
+        def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state):
             epoch_loss = 0
-            current_step = 0
+            current_epoch = current_step // steps_per_epoch
+            last_save_time = time.time()
             for i in range(steps_per_epoch):
                 batch = next(train_ds)
                 if self.distributed_training and global_device_count > 1:
@@ -363,36 +373,46 @@ class SimpleTrainer:
                 if self.distributed_training:
                     loss = jax.experimental.multihost_utils.process_allgather(loss)
                     loss = jnp.mean(loss) # Just to make sure its a scaler value
+
+                if loss <= 1e-6:
+                    # If the loss is too low, we can assume the model has diverged
+                    print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
+                    # Exit the training loop
+                    exit(1)
 
                 epoch_loss += loss
-
-                if pbar is not None:
-                    if i % 100 == 0:
+                current_step += 1
+                if i % 100 == 0:
+                    if pbar is not None:
                         pbar.set_postfix(loss=f'{loss:.4f}')
                         pbar.update(100)
-                        current_step = current_epoch*steps_per_epoch + i
                         if self.wandb is not None:
                             self.wandb.log({
                                 "train/step" : current_step,
                                 "train/loss": loss,
                             }, step=current_step)
+                        # Save the model every 40 minutes
+                        if time.time() - last_save_time > 40 * 60:
+                            print(f"Saving model after 40 minutes at step {current_step}")
+                            self.save(current_epoch, current_step)
+                            last_save_time = time.time()
             print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
             return epoch_loss, current_step, train_state, rng_state
 
-        while self.latest_epoch < epochs:
-            current_epoch = self.latest_epoch
-            self.latest_epoch += 1
+        while self.latest_step < epochs * steps_per_epoch:
+            current_epoch = self.latest_step // steps_per_epoch
             print(f"\nEpoch {current_epoch}/{epochs}")
             start_time = time.time()
             epoch_loss = 0
 
             if process_index == 0:
                 with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
-                    epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, pbar, train_state, rng_state)
+                    epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, pbar, train_state, rng_state)
             else:
-                epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, None, train_state, rng_state)
-            print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP.get(process_index, 'white')))
-
+                epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, None, train_state, rng_state)
+                print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
+            self.latest_step = current_step
             end_time = time.time()
             self.state = train_state
             self.rngstate = rng_state
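The loop now carries one global step counter across epochs (so resumed runs continue from the right step) and checkpoints on wall-clock time every 40 minutes instead of only at epoch boundaries. A stripped-down sketch of that control flow (`save` is a stand-in for the trainer's `save(epoch, step)`):

```python
import time

def save(epoch: int, step: int) -> None:
    print(f"checkpoint at epoch {epoch}, step {step}")

epochs, steps_per_epoch = 3, 100
latest_step = 0                  # restored from a checkpoint when resuming
last_save_time = time.time()

while latest_step < epochs * steps_per_epoch:
    current_epoch = latest_step // steps_per_epoch  # derived, not counted
    for _ in range(steps_per_epoch):
        latest_step += 1                            # global, survives epochs
        if time.time() - last_save_time > 40 * 60:  # wall-clock checkpointing
            save(current_epoch, latest_step)
            last_save_time = time.time()
```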
@@ -402,7 +422,7 @@ class SimpleTrainer:
             if avg_loss < self.best_loss:
                 self.best_loss = avg_loss
                 self.best_state = train_state
-                self.save(current_epoch)
+                self.save(current_epoch, current_step)
 
             if process_index == 0:
                 if self.wandb is not None:
@@ -415,4 +435,4 @@ class SimpleTrainer:
                     }, step=current_step)
                 print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
         self.save(epochs)
-        return self.state
\ No newline at end of file
+        return self.state
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.8
+Version: 0.1.9
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
@@ -234,6 +234,23 @@ plotImages(samples, dpi=300)
 
 ## Gallery
 
+### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
+Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
+`a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden`
+
+**Params**:
+`Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M`
+`Batch size: 256`
+`Image Size: 128`
+`Training Epochs: 5`
+`Steps per epoch: 74573`
+`Model Configurations: feature_depths=[128, 256, 512, 1024]`
+
+`Training Noise Schedule: EDMNoiseScheduler`
+`Inference Noise Schedule: KarrasEDMPredictor`
+
+![EulerA with CFG](images/medium_epoch5.png)
+
 ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
 Images generated by the following prompts using classifier free guidance with guidance factor = 2:
 `'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`
{flaxdiff-0.1.8 → flaxdiff-0.1.9}/setup.py
@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.8',
+    version='0.1.9',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',