flaxdiff 0.1.8__tar.gz → 0.1.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/PKG-INFO +18 -1
  2. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/README.md +17 -0
  3. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/attention.py +7 -5
  4. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/autoencoder/diffusers.py +1 -1
  5. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/common.py +14 -2
  6. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/simple_unet.py +27 -12
  7. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/simple_vit.py +13 -16
  8. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/trainer/diffusion_trainer.py +44 -12
  9. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/trainer/simple_trainer.py +84 -61
  10. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff.egg-info/PKG-INFO +18 -1
  11. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/setup.py +1 -1
  12. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/__init__.py +0 -0
  13. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/__init__.py +0 -0
  14. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/autoencoder/__init__.py +0 -0
  15. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  16. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  17. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/models/favor_fastattn.py +0 -0
  18. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff/utils.py +0 -0
  39. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff.egg-info/SOURCES.txt +0 -0
  40. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff.egg-info/dependency_links.txt +0 -0
  41. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff.egg-info/requires.txt +0 -0
  42. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/flaxdiff.egg-info/top_level.txt +0 -0
  43. {flaxdiff-0.1.8 → flaxdiff-0.1.10}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.8
+ Version: 0.1.10
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
@@ -234,6 +234,23 @@ plotImages(samples, dpi=300)
 
  ## Gallery
 
+ ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
+ Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
+ `a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden`
+
+ **Params**:
+ `Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M`
+ `Batch size: 256`
+ `Image Size: 128`
+ `Training Epochs: 5`
+ `Steps per epoch: 74573`
+ `Model Configurations: feature_depths=[128, 256, 512, 1024]`
+
+ `Training Noise Schedule: EDMNoiseScheduler`
+ `Inference Noise Schedule: KarrasEDMPredictor`
+
+ ![EulerA with CFG](images/medium_epoch5.png)
+
  ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
  Images generated by the following prompts using classifier free guidance with guidance factor = 2:
  `'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`
@@ -221,6 +221,23 @@ plotImages(samples, dpi=300)
 
  ## Gallery
 
+ ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
+ Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
+ `a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden`
+
+ **Params**:
+ `Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M`
+ `Batch size: 256`
+ `Image Size: 128`
+ `Training Epochs: 5`
+ `Steps per epoch: 74573`
+ `Model Configurations: feature_depths=[128, 256, 512, 1024]`
+
+ `Training Noise Schedule: EDMNoiseScheduler`
+ `Inference Noise Schedule: KarrasEDMPredictor`
+
+ ![EulerA with CFG](images/medium_epoch5.png)
+
  ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
  Images generated by the following prompts using classifier free guidance with guidance factor = 2:
  `'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`
@@ -156,7 +156,9 @@ class NormalAttention(nn.Module):
  value = self.value(context)
 
  hidden_states = nn.dot_product_attention(
-     query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
+     query, key, value, dtype=self.dtype, broadcast_dropout=False,
+     dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
+     deterministic=True
  )
  proj = self.proj_attn(hidden_states)
  proj = proj.reshape(orig_x_shape)
@@ -187,7 +189,7 @@ class FlaxGEGLU(nn.Module):
 
  def __call__(self, hidden_states):
      hidden_states = self.proj(hidden_states)
-     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=3)
+     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
      return hidden_linear * nn.gelu(hidden_gelu)
 
  class FlaxFeedForward(nn.Module):
@@ -291,14 +293,14 @@ class TransformerBlock(nn.Module):
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
  use_projection: bool = False
- use_flash_attention:bool = True
- use_self_and_cross:bool = False
+ use_flash_attention:bool = False
+ use_self_and_cross:bool = True
  only_pure_attention:bool = False
 
  @nn.compact
  def __call__(self, x, context=None):
      inner_dim = self.heads * self.dim_head
-     B, H, W, C = x.shape
+     C = x.shape[-1]
      normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
      if self.use_projection == True:
          if self.use_linear_attention:
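Note: two behavioral changes here. The defaults now favor self+cross attention without flash attention, and reading only `x.shape[-1]` (instead of unpacking `B, H, W, C`) means the block no longer assumes a 4D U-Net feature map, which is what lets the ViT path below reuse it on token sequences. A trivial illustration of the shape-agnostic read:

```python
import jax.numpy as jnp

def channels(x) -> int:
    # Works for (B, H, W, C) feature maps and (B, L, C) sequences alike;
    # `B, H, W, C = x.shape` would raise on the 3D case.
    return x.shape[-1]

print(channels(jnp.ones((2, 32, 32, 64))))  # 64
print(channels(jnp.ones((2, 256, 64))))     # 64
```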
@@ -78,7 +78,7 @@ class StableDiffusionVAE(AutoEncoder):
  log_std = jnp.clip(log_std, -30, 20)
  std = jnp.exp(0.5 * log_std)
  latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
- print("Sampled")
+ # print("Sampled")
  else:
      # return the mean
      latents, _ = jnp.split(latents, 2, axis=-1)
@@ -5,6 +5,7 @@ from typing import Optional, Any, Callable, Sequence, Union
  from flax.typing import Dtype, PrecisionLike
  from typing import Dict, Callable, Sequence, Any, Union
  import einops
+ from functools import partial
 
  # Kernel initializer to use
  def kernel_init(scale, dtype=jnp.float32):
@@ -266,11 +267,22 @@ class ResidualBlock(nn.Module):
  kernel_init:Callable=kernel_init(1.0)
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
+ named_norms:bool=False
+
+ def setup(self):
+     if self.norm_groups > 0:
+         norm = partial(nn.GroupNorm, self.norm_groups)
+         self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
+         self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
+     else:
+         norm = partial(nn.RMSNorm, 1e-5)
+         self.norm1 = norm()
+         self.norm2 = norm()
 
  @nn.compact
  def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
      residual = x
-     out = nn.GroupNorm(self.norm_groups)(x)
+     out = self.norm1(x)
      # out = nn.RMSNorm()(x)
      out = self.activation(out)
@@ -295,7 +307,7 @@ class ResidualBlock(nn.Module):
  # out = out * (1 + scale) + shift
  out = out + temb
 
- out = nn.GroupNorm(self.norm_groups)(out)
+ out = self.norm2(out)
  # out = nn.RMSNorm()(out)
  out = self.activation(out)
@@ -6,6 +6,7 @@ from typing import Dict, Callable, Sequence, Any, Union, Optional
  import einops
  from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
  from .attention import TransformerBlock
+ from functools import partial
 
  class Unet(nn.Module):
      output_channels:int=3
  output_channels:int=3
@@ -18,7 +19,16 @@ class Unet(nn.Module):
18
19
  norm_groups:int=8
19
20
  dtype: Optional[Dtype] = None
20
21
  precision: PrecisionLike = None
22
+ named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
21
23
 
24
+ def setup(self):
25
+ if self.norm_groups > 0:
26
+ norm = partial(nn.GroupNorm, self.norm_groups)
27
+ self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
28
+ else:
29
+ norm = partial(nn.RMSNorm, 1e-5)
30
+ self.conv_out_norm = norm()
31
+
22
32
  @nn.compact
23
33
  def __call__(self, x, temb, textcontext):
24
34
  # print("embedding features", self.emb_features)
@@ -60,7 +70,8 @@ class Unet(nn.Module):
  activation=self.activation,
  norm_groups=self.norm_groups,
  dtype=self.dtype,
- precision=self.precision
+ precision=self.precision,
+ named_norms=self.named_norms
  )(x, temb)
  if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
      x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -69,7 +80,7 @@ class Unet(nn.Module):
  use_projection=attention_config.get("use_projection", False),
  use_self_and_cross=attention_config.get("use_self_and_cross", True),
  precision=attention_config.get("precision", self.precision),
- only_pure_attention=True,
+ only_pure_attention=attention_config.get("only_pure_attention", True),
  name=f"down_{i}_attention_{j}")(x, textcontext)
  # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
  downs.append(x)
@@ -98,7 +109,8 @@ class Unet(nn.Module):
  activation=self.activation,
  norm_groups=self.norm_groups,
  dtype=self.dtype,
- precision=self.precision
+ precision=self.precision,
+ named_norms=self.named_norms
  )(x, temb)
  if middle_attention is not None and j == self.num_middle_res_blocks - 1:   # Apply attention only on the last block
      x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
@@ -107,8 +119,8 @@ class Unet(nn.Module):
  use_linear_attention=False,
  use_projection=middle_attention.get("use_projection", False),
  use_self_and_cross=False,
- precision=attention_config.get("precision", self.precision),
- only_pure_attention=True,
+ precision=middle_attention.get("precision", self.precision),
+ only_pure_attention=middle_attention.get("only_pure_attention", True),
  name=f"middle_attention_{j}")(x, textcontext)
  x = ResidualBlock(
      middle_conv_type,
@@ -120,7 +132,8 @@ class Unet(nn.Module):
  activation=self.activation,
  norm_groups=self.norm_groups,
  dtype=self.dtype,
- precision=self.precision
+ precision=self.precision,
+ named_norms=self.named_norms
  )(x, temb)
 
  # Upscaling Blocks
@@ -141,7 +154,8 @@ class Unet(nn.Module):
  activation=self.activation,
  norm_groups=self.norm_groups,
  dtype=self.dtype,
- precision=self.precision
+ precision=self.precision,
+ named_norms=self.named_norms
  )(x, temb)
  if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
      x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -150,7 +164,7 @@ class Unet(nn.Module):
  use_projection=attention_config.get("use_projection", False),
  use_self_and_cross=attention_config.get("use_self_and_cross", True),
  precision=attention_config.get("precision", self.precision),
- only_pure_attention=True,
+ only_pure_attention=attention_config.get("only_pure_attention", True),
  name=f"up_{i}_attention_{j}")(x, textcontext)
  # print("Upscaling ", i, x.shape)
  if i != len(feature_depths) - 1:
@@ -163,13 +177,13 @@ class Unet(nn.Module):
      precision=self.precision
  )(x)
 
- # x = nn.GroupNorm(8)(x)
+ # x = self.last_up_norm(x)
  x = ConvLayer(
      conv_type,
      features=self.feature_depths[0],
      kernel_size=(3, 3),
      strides=(1, 1),
-     kernel_init=kernel_init(0.0),
+     kernel_init=kernel_init(1.0),
      dtype=self.dtype,
      precision=self.precision
  )(x)
@@ -186,10 +200,11 @@ class Unet(nn.Module):
  activation=self.activation,
  norm_groups=self.norm_groups,
  dtype=self.dtype,
- precision=self.precision
+ precision=self.precision,
+ named_norms=self.named_norms
  )(x, temb)
 
- x = nn.GroupNorm(self.norm_groups)(x)
+ x = self.conv_out_norm(x)
  x = self.activation(x)
 
  noise_out = ConvLayer(
@@ -4,7 +4,7 @@ import jax
  import jax.numpy as jnp
  from flax import linen as nn
  from typing import Callable, Any
- from .simply_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
+ from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
  from .attention import TransformerBlock
 
  class PatchEmbedding(nn.Module):
@@ -40,22 +40,23 @@ class PositionalEncoding(nn.Module):
  class TransformerEncoder(nn.Module):
      num_layers: int
      num_heads: int
-     mlp_dim: int
      dropout_rate: float = 0.1
      dtype: Any = jnp.float32
      precision: Any = jax.lax.Precision.HIGH
+     use_projection: bool = False
 
      @nn.compact
-     def __call__(self, x, training=True):
+     def __call__(self, x, context=None):
          for _ in range(self.num_layers):
              x = TransformerBlock(
                  heads=self.num_heads,
                  dim_head=x.shape[-1] // self.num_heads,
-                 mlp_dim=self.mlp_dim,
                  dropout_rate=self.dropout_rate,
                  dtype=self.dtype,
-                 precision=self.precision
-             )(x)
+                 precision=self.precision,
+                 use_self_and_cross=True,
+                 use_projection=self.use_projection,
+             )(x, context)
          return x
 
  class VisionTransformer(nn.Module):
@@ -63,11 +64,11 @@ class VisionTransformer(nn.Module):
  embedding_dim: int = 768
  num_layers: int = 12
  num_heads: int = 12
- mlp_dim: int = 3072
  emb_features: int = 256
  dropout_rate: float = 0.1
  dtype: Any = jnp.float32
  precision: Any = jax.lax.Precision.HIGH
+ use_projection: bool = False
 
  @nn.compact
  def __call__(self, x, temb, textcontext=None):
@@ -81,27 +82,23 @@ class VisionTransformer(nn.Module):
 
  # Add positional encoding
  x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x)
+
+ num_patches = x.shape[1]
 
  # Add time embedding
  temb = jnp.expand_dims(temb, axis=1)
  x = jnp.concatenate([x, temb], axis=1)
 
- # Add text context
- if textcontext is not None:
-     x = jnp.concatenate([x, textcontext], axis=1)
-
  # Transformer encoder
  x = TransformerEncoder(
      num_layers=self.num_layers,
      num_heads=self.num_heads,
-     mlp_dim=self.mlp_dim,
      dropout_rate=self.dropout_rate,
      dtype=self.dtype,
-     precision=self.precision
- )(x)
+     precision=self.precision,
+     use_projection=self.use_projection
+ )(x, textcontext)
 
- # Extract the image tokens (exclude time and text embeddings)
- num_patches = (x.shape[1] - 1 - (0 if textcontext is None else textcontext.shape[1]))
  x = x[:, :num_patches, :]
 
  # Reshape to image dimensions
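Note: the ViT no longer concatenates text tokens into the sequence; text now conditions the model through cross-attention inside each TransformerBlock (`use_self_and_cross=True`), so only the time token is appended and `num_patches` is recorded before it. A schematic comparison of the two conditioning styles, with made-up shapes:

```python
import jax.numpy as jnp

patches = jnp.ones((2, 64, 256))   # (B, num_patches, D) image tokens
temb    = jnp.ones((2, 1, 256))    # (B, 1, D) time token
text    = jnp.ones((2, 77, 256))   # (B, T, D) text tokens

# Old style: conditioning by concatenation; the sequence grows and the
# image tokens must be sliced back out with text-length bookkeeping.
seq_old = jnp.concatenate([patches, temb, text], axis=1)   # (2, 142, 256)

# New style: only the time token joins the sequence; text enters each
# block via cross-attention (queries from seq, keys/values from text).
num_patches = patches.shape[1]
seq_new = jnp.concatenate([patches, temb], axis=1)         # (2, 65, 256)
image_tokens = seq_new[:, :num_patches, :]                 # (2, 64, 256)
```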
@@ -16,6 +16,7 @@ from flaxdiff.utils import RandomMarkovState
  from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
 
  from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+ from flax.training.dynamic_scale import DynamicScale
 
  class TrainState(SimpleTrainState):
      rngs: jax.random.PRNGKey
@@ -29,6 +30,8 @@ class TrainState(SimpleTrainState):
  )
  return self.replace(ema_params=new_ema_params)
 
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+
  class DiffusionTrainer(SimpleTrainer):
      noise_schedule: NoiseScheduler
      model_output_transform: DiffusionPredictionTransform
@@ -40,7 +43,7 @@ class DiffusionTrainer(SimpleTrainer):
  optimizer: optax.GradientTransformation,
  noise_schedule: NoiseScheduler,
  rngs: jax.random.PRNGKey,
- unconditional_prob: float = 0.2,
+ unconditional_prob: float = 0.12,
  name: str = "Diffusion",
  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
  autoencoder: AutoEncoder = None,
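Note: `unconditional_prob` is the fraction of training examples whose text conditioning is dropped so that one network also learns the unconditional score needed for classifier-free guidance; 10-20% is the typical range in the CFG literature, and 0.12 sits inside it. A sketch of the dropout mechanic with hypothetical names (`drop_text_condition` and `null_emb` are illustrative, not flaxdiff APIs):

```python
import jax
import jax.numpy as jnp

def drop_text_condition(rng, text_emb, null_emb, unconditional_prob=0.12):
    """Hypothetical helper: swap each example's text embedding for a null
    embedding with probability `unconditional_prob`."""
    batch = text_emb.shape[0]
    drop = jax.random.bernoulli(rng, unconditional_prob, (batch, 1, 1))
    return jnp.where(drop, null_emb, text_emb)

rng = jax.random.PRNGKey(0)
text_emb = jnp.ones((8, 77, 768))      # (B, T, D) conditioning
null_emb = jnp.zeros((1, 77, 768))     # broadcast over the batch
print(drop_text_condition(rng, text_emb, null_emb).shape)  # (8, 77, 768)
```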
@@ -67,7 +70,8 @@ class DiffusionTrainer(SimpleTrainer):
  existing_state: dict = None,
  existing_best_state: dict = None,
  model: nn.Module = None,
- param_transforms: Callable = None
+ param_transforms: Callable = None,
+ use_dynamic_scale: bool = False
  ) -> Tuple[TrainState, TrainState]:
      print("Generating states for DiffusionTrainer")
      rngs, subkey = jax.random.split(rngs)
@@ -80,7 +84,8 @@ class DiffusionTrainer(SimpleTrainer):
  new_state = existing_state
 
  if param_transforms is not None:
-     params = param_transforms(params)
+     new_state['params'] = param_transforms(new_state['params'])
+     new_state['ema_params'] = param_transforms(new_state['ema_params'])
 
  state = TrainState.create(
      apply_fn=model.apply,
@@ -88,7 +93,8 @@ class DiffusionTrainer(SimpleTrainer):
  ema_params=new_state['ema_params'],
  tx=optimizer,
  rngs=rngs,
- metrics=Metrics.empty()
+ metrics=Metrics.empty(),
+ dynamic_scale = DynamicScale() if use_dynamic_scale else None
  )
 
  if existing_best_state is not None:
@@ -125,14 +131,14 @@ class DiffusionTrainer(SimpleTrainer):
  local_rng_state = RandomMarkovState(subkey)
 
  images = batch['image']
+ images = jnp.array(images, dtype=jnp.float32)
+ # normalize image
+ images = (images - 127.5) / 127.5
 
  if autoencoder is not None:
      # Convert the images to latent space
      local_rng_state, rngs = local_rng_state.get_random_key()
      images = autoencoder.encode(images, rngs)
- else:
-     # normalize image
-     images = (images - 127.5) / 127.5
 
  output = text_embedder(
      input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
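Note: normalization now runs unconditionally before the optional autoencoder encode, so the VAE sees [-1, 1] inputs instead of raw uint8 pixels (Stable Diffusion's VAE expects that range). The mapping is just:

```python
import jax.numpy as jnp

images = jnp.asarray([[0.0, 127.5, 255.0]])
print((images - 127.5) / 127.5)  # [[-1.  0.  1.]] -- uint8 range -> [-1, 1]
```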
@@ -163,12 +169,39 @@ class DiffusionTrainer(SimpleTrainer):
      loss = nloss
      return loss
 
- loss, grads = jax.value_and_grad(model_loss)(train_state.params)
+
+ if train_state.dynamic_scale is not None:
+     # dynamic scale takes care of averaging gradients across replicas
+     grad_fn = train_state.dynamic_scale.value_and_grad(
+         model_loss, axis_name="data"
+     )
+     dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
+     train_state = train_state.replace(dynamic_scale=dynamic_scale)
+ else:
+     grad_fn = jax.value_and_grad(model_loss)
+     loss, grads = grad_fn(train_state.params)
+     if distributed_training:
+         grads = jax.lax.pmean(grads, "data")
+
+ new_state = train_state.apply_gradients(grads=grads)
+
+ if train_state.dynamic_scale:
+     # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
+     # params should be restored (= skip this step).
+     select_fn = functools.partial(jnp.where, is_fin)
+     new_state = train_state.replace(
+         opt_state=jax.tree_util.tree_map(
+             select_fn, new_state.opt_state, train_state.opt_state
+         ),
+         params=jax.tree_util.tree_map(
+             select_fn, new_state.params, train_state.params
+         ),
+     )
+
+ train_state = new_state.apply_ema(self.ema_decay)
+
  if distributed_training:
-     grads = jax.lax.pmean(grads, "data")
      loss = jax.lax.pmean(loss, "data")
- train_state = train_state.apply_gradients(grads=grads)
- train_state = train_state.apply_ema(self.ema_decay)
  return train_state, loss, rng_state
 
  if distributed_training:
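Note: this is the standard Flax mixed-precision pattern. `DynamicScale` scales the loss before differentiation, unscales the gradients, reports whether they are all finite, and adapts the scale; keeping the pre-update state around and selecting per-leaf with `jnp.where(is_fin, new, old)` is the canonical skip-step trick from the Flax examples. A self-contained sketch (no pmap, so `axis_name` is omitted):

```python
import jax.numpy as jnp
from flax.training.dynamic_scale import DynamicScale

def loss_fn(params, x):
    return jnp.mean((params["w"] * x) ** 2)

params = {"w": jnp.float32(2.0)}
x = jnp.ones((4,))
ds = DynamicScale()  # starts from a large initial loss scale

# Scales the loss, differentiates, unscales the grads, checks finiteness,
# and returns an updated scale.
ds, is_fin, loss, grads = ds.value_and_grad(loss_fn)(params, x)

# Skip the step when gradients overflowed: keep the old parameters.
new_w = jnp.where(is_fin, params["w"] - 0.1 * grads["w"], params["w"])
print(loss, is_fin, new_w)
```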
@@ -199,4 +232,3 @@ def boolean_string(s):
  if type(s) == bool:
      return s
  return s == 'True'
-
@@ -22,7 +22,7 @@ from jax.experimental.shard_map import shard_map
  from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
  from termcolor import colored
  from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
+ from flax.training.dynamic_scale import DynamicScale
  from flaxdiff.utils import RandomMarkovState
 
  PROCESS_COLOR_MAP = {
@@ -39,23 +39,23 @@ PROCESS_COLOR_MAP = {
  def _build_global_shape_and_sharding(
      local_shape: tuple[int, ...], global_mesh: Mesh
  ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
-     sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
-     global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
-     return global_shape, sharding
+     sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+     global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+     return global_shape, sharding
 
 
  def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
-     """Put local sharded array into local devices"""
-     global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
-     try:
-         local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
-     except ValueError as array_split_error:
-         raise ValueError(
-             f"Unable to put to devices shape {array.shape} with "
-             f"local device count {len(global_mesh.local_devices)} "
-         ) from array_split_error
-     local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
-     return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
+     """Put local sharded array into local devices"""
+     global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+     try:
+         local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+     except ValueError as array_split_error:
+         raise ValueError(
+             f"Unable to put to devices shape {array.shape} with "
+             f"local device count {len(global_mesh.local_devices)} "
+         ) from array_split_error
+     local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+     return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
 
  def convert_to_global_tree(global_mesh, pytree):
      return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
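Note: `form_global_array` (a whitespace-only change above) is the standard multi-host input pattern: each process splits its local batch across its local devices, and `jax.make_array_from_single_device_arrays` stitches the shards into one `jax.Array` whose leading dimension spans all processes. A single-host sketch of the same API (with one process, the "global" batch equals the local one):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
local_batch = np.arange(16 * 3, dtype=np.float32).reshape(16, 3)

global_shape = (jax.process_count() * local_batch.shape[0],) + local_batch.shape[1:]
sharding = NamedSharding(mesh, P("data"))

# One shard per local device, placed explicitly, then assembled globally.
shards = np.split(local_batch, len(mesh.local_devices), axis=0)
buffers = jax.device_put(shards, mesh.local_devices)
global_array = jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)
print(global_array.shape)  # (16, 3) when running on a single host
```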
@@ -67,12 +67,8 @@ class Metrics(metrics.Collection):
 
  # Define the TrainState
  class SimpleTrainState(train_state.TrainState):
-     rngs: jax.random.PRNGKey
      metrics: Metrics
-
-     def get_random_key(self):
-         rngs, subkey = jax.random.split(self.rngs)
-         return self.replace(rngs=rngs), subkey
+     dynamic_scale: DynamicScale
 
  class SimpleTrainer:
      state: SimpleTrainState
@@ -88,20 +84,22 @@ class SimpleTrainer:
  rngs: jax.random.PRNGKey,
  train_state: SimpleTrainState = None,
  name: str = "Simple",
- load_from_checkpoint: bool = False,
+ load_from_checkpoint: str = None,
  checkpoint_suffix: str = "",
- checkpoint_id: str = None,
  loss_fn=optax.l2_loss,
  param_transforms: Callable = None,
  wandb_config: Dict[str, Any] = None,
  distributed_training: bool = None,
  checkpoint_base_path: str = "./checkpoints",
+ checkpoint_step: int = None,
+ use_dynamic_scale: bool = False,
  ):
      if distributed_training is None or distributed_training is True:
          # Auto-detect if we are running on multiple devices
          distributed_training = jax.device_count() > 1
          self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
-         # self.sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('data'))
+     else:
+         self.mesh = None
 
      self.distributed_training = distributed_training
      self.model = model
@@ -112,7 +110,6 @@ class SimpleTrainer:
 
  if wandb_config is not None and jax.process_index() == 0:
-     import wandb
      run = wandb.init(**wandb_config)
      self.wandb = run
@@ -126,11 +123,6 @@ class SimpleTrainer:
  self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
  self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
  self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
-
- if checkpoint_id is None:
-     self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
- else:
-     self.checkpoint_id = checkpoint_id
 
  # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
@@ -140,12 +132,12 @@ class SimpleTrainer:
  self.checkpointer = orbax.checkpoint.CheckpointManager(
      self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
 
- if load_from_checkpoint:
-     latest_epoch, old_state, old_best_state, rngstate = self.load()
+ if load_from_checkpoint is not None:
+     latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
  else:
-     latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None
+     latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
 
- self.latest_epoch = latest_epoch
+ self.latest_step = latest_step
 
  if rngstate:
      self.rngstate = RandomMarkovState(**rngstate)
@@ -156,7 +148,7 @@ class SimpleTrainer:
 
  if train_state == None:
      state, best_state = self.generate_states(
-         optimizer, subkey, old_state, old_best_state, model, param_transforms
+         optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
      )
      self.init_state(state, best_state)
  else:
@@ -174,7 +166,8 @@ class SimpleTrainer:
  existing_state: dict = None,
  existing_best_state: dict = None,
  model: nn.Module = None,
- param_transforms: Callable = None
+ param_transforms: Callable = None,
+ use_dynamic_scale: bool = False
  ) -> Tuple[SimpleTrainState, SimpleTrainState]:
      print("Generating states for SimpleTrainer")
      rngs, subkey = jax.random.split(rngs)
@@ -184,12 +177,16 @@ class SimpleTrainer:
      params = model.init(subkey, **input_vars)
  else:
      params = existing_state['params']
+
+ if param_transforms is not None:
+     params = param_transforms(params)
 
  state = SimpleTrainState.create(
      apply_fn=model.apply,
      params=params,
      tx=optimizer,
-     metrics=Metrics.empty()
+     metrics=Metrics.empty(),
+     dynamic_scale = DynamicScale() if use_dynamic_scale else None
  )
  if existing_best_state is not None:
      best_state = state.replace(
@@ -222,7 +219,7 @@ class SimpleTrainer:
  return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
 
  def checkpoint_path(self):
-     path = os.path.join(self.checkpoint_base_path, self.checkpoint_id)
+     path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
      if not os.path.exists(path):
          os.makedirs(path)
      return path
@@ -234,31 +231,46 @@ class SimpleTrainer:
      os.makedirs(path)
  return path
 
- def load(self):
-     epoch = self.checkpointer.latest_step()
-     print("Loading model from checkpoint", epoch)
-     ckpt = self.checkpointer.restore(epoch)
+ def load(self, checkpoint_path=None, checkpoint_step=None):
+     if checkpoint_path is None:
+         checkpointer = self.checkpointer
+     else:
+         checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+         options = orbax.checkpoint.CheckpointManagerOptions(
+             max_to_keep=4, create=False)
+         checkpointer = orbax.checkpoint.CheckpointManager(
+             checkpoint_path, checkpointer, options)
+
+     if checkpoint_step is None:
+         step = checkpointer.latest_step()
+     else:
+         step = checkpoint_step
+
+     print("Loading model from checkpoint at step ", step)
+     ckpt = checkpointer.restore(step)
      state = ckpt['state']
      best_state = ckpt['best_state']
      rngstate = ckpt['rngs']
      # Convert the state to a TrainState
      self.best_loss = ckpt['best_loss']
+     current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
      print(
-         f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
-     return epoch, state, best_state, rngstate
+         f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
+     return current_epoch, step, state, best_state, rngstate
 
- def save(self, epoch=0):
-     print(f"Saving model at epoch {epoch}")
+ def save(self, epoch=0, step=0):
+     print(f"Saving model at epoch {epoch} step {step}")
      ckpt = {
          # 'model': self.model,
          'rngs': self.get_rngstate(),
          'state': self.get_state(),
          'best_state': self.get_best_state(),
          'best_loss': np.array(self.best_loss),
+         'epoch': epoch,
      }
      try:
          save_args = orbax_utils.save_args_from_target(ckpt)
-         self.checkpointer.save(epoch, ckpt, save_kwargs={
+         self.checkpointer.save(step, ckpt, save_kwargs={
              'save_args': save_args}, force=True)
          self.checkpointer.wait_until_finished()
          pass
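Note: `load` now keys checkpoints by global step rather than epoch and can read from an arbitrary directory; this is plain orbax `CheckpointManager` usage. A minimal sketch of restoring the latest (or a pinned) step, assuming the same PyTree checkpoint layout the trainer saves above:

```python
import orbax.checkpoint

def restore(checkpoint_path: str, checkpoint_step: int = None):
    manager = orbax.checkpoint.CheckpointManager(
        checkpoint_path,
        orbax.checkpoint.PyTreeCheckpointer(),
        orbax.checkpoint.CheckpointManagerOptions(create=False),
    )
    # Fall back to the newest saved step when none is pinned.
    step = checkpoint_step if checkpoint_step is not None else manager.latest_step()
    ckpt = manager.restore(step)
    # Older flaxdiff checkpoints stored no 'epoch' key; treat step as epoch.
    return ckpt.get('epoch', step), step, ckpt['state'], ckpt['best_state']
```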
@@ -350,9 +362,10 @@ class SimpleTrainer:
  else:
      global_device_indexes = 0
 
- def train_loop(current_epoch, pbar: tqdm.tqdm, train_state, rng_state):
+ def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state):
      epoch_loss = 0
-     current_step = 0
+     current_epoch = current_step // steps_per_epoch
+     last_save_time = time.time()
      for i in range(steps_per_epoch):
          batch = next(train_ds)
          if self.distributed_training and global_device_count > 1:
@@ -363,36 +376,46 @@ class SimpleTrainer:
  if self.distributed_training:
      loss = jax.experimental.multihost_utils.process_allgather(loss)
      loss = jnp.mean(loss) # Just to make sure its a scaler value
+
+ if loss <= 1e-6:
+     # If the loss is too low, we can assume the model has diverged
+     print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
+     # Exit the training loop
+     exit(1)
 
  epoch_loss += loss
-
- if pbar is not None:
-     if i % 100 == 0:
+ current_step += 1
+ if i % 100 == 0:
+     if pbar is not None:
          pbar.set_postfix(loss=f'{loss:.4f}')
          pbar.update(100)
-         current_step = current_epoch*steps_per_epoch + i
      if self.wandb is not None:
          self.wandb.log({
              "train/step" : current_step,
              "train/loss": loss,
          }, step=current_step)
+     # Save the model every 40 minutes
+     if time.time() - last_save_time > 40 * 60:
+         print(f"Saving model after 40 minutes at step {current_step}")
+         self.save(current_epoch, current_step)
+         last_save_time = time.time()
  print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
  return epoch_loss, current_step, train_state, rng_state
 
- while self.latest_epoch < epochs:
-     current_epoch = self.latest_epoch
-     self.latest_epoch += 1
+ while self.latest_step < epochs * steps_per_epoch:
+     current_epoch = self.latest_step // steps_per_epoch
      print(f"\nEpoch {current_epoch}/{epochs}")
      start_time = time.time()
      epoch_loss = 0
 
      if process_index == 0:
          with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
-             epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, pbar, train_state, rng_state)
+             epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, pbar, train_state, rng_state)
      else:
-         epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, None, train_state, rng_state)
-         print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP.get(process_index, 'white')))
-
+         epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, None, train_state, rng_state)
+         print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
+     self.latest_step = current_step
      end_time = time.time()
      self.state = train_state
      self.rngstate = rng_state
@@ -402,7 +425,7 @@ class SimpleTrainer:
  if avg_loss < self.best_loss:
      self.best_loss = avg_loss
      self.best_state = train_state
-     self.save(current_epoch)
+     self.save(current_epoch, current_step)
 
  if process_index == 0:
      if self.wandb is not None:
@@ -415,4 +438,4 @@ class SimpleTrainer:
  }, step=current_step)
  print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
  self.save(epochs)
- return self.state
+ return self.state
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.8
+ Version: 0.1.10
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
@@ -234,6 +234,23 @@ plotImages(samples, dpi=300)
 
  ## Gallery
 
+ ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
+ Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
+ `a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden`
+
+ **Params**:
+ `Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M`
+ `Batch size: 256`
+ `Image Size: 128`
+ `Training Epochs: 5`
+ `Steps per epoch: 74573`
+ `Model Configurations: feature_depths=[128, 256, 512, 1024]`
+
+ `Training Noise Schedule: EDMNoiseScheduler`
+ `Inference Noise Schedule: KarrasEDMPredictor`
+
+ ![EulerA with CFG](images/medium_epoch5.png)
+
  ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
  Images generated by the following prompts using classifier free guidance with guidance factor = 2:
  `'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`
@@ -11,7 +11,7 @@ required_packages=[
  setup(
      name='flaxdiff',
      packages=find_packages(),
-     version='0.1.8',
+     version='0.1.10',
      description='A versatile and easy to understand Diffusion library',
      long_description=open('README.md').read(),
      long_description_content_type='text/markdown',