flaxdiff 0.1.5__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/PKG-INFO +3 -1
  2. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/README.md +2 -0
  3. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/attention.py +12 -11
  4. flaxdiff-0.1.6/flaxdiff/models/autoencoder/__init__.py +2 -0
  5. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/autoencoder/autoencoder.py +7 -2
  6. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/autoencoder/diffusers.py +3 -0
  7. flaxdiff-0.1.6/flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
  8. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/common.py +89 -10
  9. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/simple_unet.py +4 -75
  10. flaxdiff-0.1.6/flaxdiff/trainer/__init__.py +2 -0
  11. flaxdiff-0.1.6/flaxdiff/trainer/autoencoder_trainer.py +182 -0
  12. flaxdiff-0.1.5/flaxdiff/trainer/__init__.py → flaxdiff-0.1.6/flaxdiff/trainer/diffusion_trainer.py +22 -4
  13. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/PKG-INFO +3 -1
  14. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/SOURCES.txt +3 -0
  15. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/setup.py +1 -1
  16. flaxdiff-0.1.5/flaxdiff/models/autoencoder/__init__.py +0 -0
  17. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/__init__.py +0 -0
  18. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/__init__.py +0 -0
  19. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/favor_fastattn.py +0 -0
  20. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/simple_vit.py +0 -0
  21. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/predictors/__init__.py +0 -0
  22. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/__init__.py +0 -0
  23. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/common.py +0 -0
  24. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/ddim.py +0 -0
  25. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/ddpm.py +0 -0
  26. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/euler.py +0 -0
  27. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/heun_sampler.py +0 -0
  28. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/multistep_dpm.py +0 -0
  29. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/rk4_sampler.py +0 -0
  30. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/__init__.py +0 -0
  31. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/common.py +0 -0
  32. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/continuous.py +0 -0
  33. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/cosine.py +0 -0
  34. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/discrete.py +0 -0
  35. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/exp.py +0 -0
  36. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/karras.py +0 -0
  37. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/linear.py +0 -0
  38. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/sqrt.py +0 -0
  39. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/trainer/simple_trainer.py +0 -0
  40. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/utils.py +0 -0
  41. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/dependency_links.txt +0 -0
  42. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/requires.txt +0 -0
  43. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/top_level.txt +0 -0
  44. {flaxdiff-0.1.5 → flaxdiff-0.1.6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -13,6 +13,8 @@ Requires-Dist: clu
13
13
 
14
14
  # ![](images/logo.jpeg "FlaxDiff")
15
15
 
16
+ **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
17
+
16
18
  ## A Versatile and simple Diffusion Library
17
19
 
18
20
  In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
@@ -1,5 +1,7 @@
1
1
  # ![](images/logo.jpeg "FlaxDiff")
2
2
 
3
+ **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
4
+
3
5
  ## A Versatile and simple Diffusion Library
4
6
 
5
7
  In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
@@ -5,7 +5,8 @@ Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/dif
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  from flax import linen as nn
8
- from typing import Dict, Callable, Sequence, Any, Union
8
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
9
+ from flax.typing import Dtype, PrecisionLike
9
10
  import einops
10
11
  import functools
11
12
  import math
@@ -18,8 +19,8 @@ class EfficientAttention(nn.Module):
18
19
  query_dim: int
19
20
  heads: int = 4
20
21
  dim_head: int = 64
21
- dtype: Any = jnp.float32
22
- precision: Any = jax.lax.Precision.HIGHEST
22
+ dtype: Optional[Dtype] = None
23
+ precision: PrecisionLike = None
23
24
  use_bias: bool = True
24
25
  kernel_init: Callable = lambda : kernel_init(1.0)
25
26
 
@@ -109,8 +110,8 @@ class NormalAttention(nn.Module):
109
110
  query_dim: int
110
111
  heads: int = 4
111
112
  dim_head: int = 64
112
- dtype: Any = jnp.float32
113
- precision: Any = jax.lax.Precision.HIGHEST
113
+ dtype: Optional[Dtype] = None
114
+ precision: PrecisionLike = None
114
115
  use_bias: bool = True
115
116
  kernel_init: Callable = lambda : kernel_init(1.0)
116
117
 
@@ -166,8 +167,8 @@ class BasicTransformerBlock(nn.Module):
166
167
  query_dim: int
167
168
  heads: int = 4
168
169
  dim_head: int = 64
169
- dtype: Any = jnp.float32
170
- precision: Any = jax.lax.Precision.HIGHEST
170
+ dtype: Optional[Dtype] = None
171
+ precision: PrecisionLike = None
171
172
  use_bias: bool = True
172
173
  kernel_init: Callable = lambda : kernel_init(1.0)
173
174
  use_flash_attention:bool = False
@@ -286,8 +287,8 @@ class BasicTransformerBlock(nn.Module):
286
287
  query_dim: int
287
288
  heads: int = 4
288
289
  dim_head: int = 64
289
- dtype: Any = jnp.float32
290
- precision: Any = jax.lax.Precision.HIGHEST
290
+ dtype: Optional[Dtype] = None
291
+ precision: PrecisionLike = None
291
292
  use_bias: bool = True
292
293
  kernel_init: Callable = lambda : kernel_init(1.0)
293
294
  use_flash_attention:bool = False
@@ -346,8 +347,8 @@ class TransformerBlock(nn.Module):
346
347
  heads: int = 4
347
348
  dim_head: int = 32
348
349
  use_linear_attention: bool = True
349
- dtype: Any = jnp.float32
350
- precision: Any = jax.lax.Precision.HIGH
350
+ dtype: Optional[Dtype] = None
351
+ precision: PrecisionLike = None
351
352
  use_projection: bool = False
352
353
  use_flash_attention:bool = True
353
354
  use_self_and_cross:bool = False
@@ -0,0 +1,2 @@
1
+ from .autoencoder import AutoEncoder
2
+ from .diffusers import StableDiffusionVAE
@@ -6,9 +6,14 @@ import einops
6
6
  from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
7
7
 
8
8
 
9
class AutoEncoder:
    """Minimal autoencoder interface.

    Subclasses provide `encode` and `decode`; calling an instance runs both
    in sequence and returns the reconstruction.
    """

    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Map `x` to its latent representation. Must be overridden."""
        raise NotImplementedError

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Map latents `z` back to the input space. Must be overridden."""
        raise NotImplementedError

    def __call__(self, x: jnp.ndarray):
        """Return `decode(encode(x))`; no extra keyword arguments are forwarded."""
        return self.decode(self.encode(x))
@@ -5,6 +5,9 @@ from .autoencoder import AutoEncoder
5
5
 
6
6
  """
7
7
  This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
8
+ The actual model was not trained by me, but was taken from the HuggingFace model hub.
9
+ I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library.
10
+ All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
8
11
  """
9
12
 
10
13
  class StableDiffusionVAE(AutoEncoder):
@@ -0,0 +1,26 @@
1
+ from typing import Any, List, Optional, Callable
2
+ import jax
3
+ import flax.linen as nn
4
+ from jax import numpy as jnp
5
+ from flax.typing import Dtype, PrecisionLike
6
+ from .autoencoder import AutoEncoder
7
+
8
class SimpleAutoEncoder(AutoEncoder):
    """Hyper-parameter skeleton for a simple convolutional autoencoder.

    Only configuration defaults are declared here. `encode` and `decode` are
    not overridden in this class, so `__call__` dispatches to whatever the
    base class (or a further subclass) provides.
    """
    latent_channels: int
    # NOTE: the list defaults below are shared class-level objects; treat
    # them as read-only and assign a fresh list per instance to customise.
    feature_depths: List[int] = [64, 128, 256, 512]
    # No trailing comma here: a trailing comma would wrap the list in a
    # 1-tuple instead of leaving it a list of per-level attention configs.
    attention_configs: list = [{"heads": 8}, {"heads": 8}, {"heads": 8}, {"heads": 8}]
    num_res_blocks: int = 2
    # Plain int (a trailing comma would make this the tuple `(1,)`).
    num_middle_res_blocks: int = 1
    activation: Callable = jax.nn.swish
    norm_groups: int = 8
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    # TODO: implement `encode` / `decode` for this model.

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        """Encode `x` to latents, decode them, and return the reconstruction."""
        latents = self.encode(x)
        reconstructions = self.decode(latents)
        return reconstructions
@@ -1,6 +1,8 @@
1
1
  import jax.numpy as jnp
2
2
  import jax
3
3
  from flax import linen as nn
4
+ from typing import Optional, Any, Callable, Sequence, Union
5
+ from flax.typing import Dtype, PrecisionLike
4
6
  from typing import Dict, Callable, Sequence, Any, Union
5
7
  import einops
6
8
 
@@ -18,8 +20,9 @@ class WeightStandardizedConv(nn.Module):
18
20
  kernel_size: Sequence[int] = 3
19
21
  strides: Union[None, int, Sequence[int]] = 1
20
22
  padding: Any = 1
21
- dtype: Any = jnp.float32
22
- param_dtype: Any = jnp.float32
23
+ dtype: Optional[Dtype] = None
24
+ precision: PrecisionLike = None
25
+ param_dtype: Optional[Dtype] = None
23
26
 
24
27
  @nn.compact
25
28
  def __call__(self, x):
@@ -120,8 +123,8 @@ class SeparableConv(nn.Module):
120
123
  use_bias:bool=False
121
124
  kernel_init:Callable=kernel_init(1.0)
122
125
  padding:str="SAME"
123
- dtype: Any = jnp.bfloat16
124
- precision: Any = jax.lax.Precision.HIGH
126
+ dtype: Optional[Dtype] = None
127
+ precision: PrecisionLike = None
125
128
 
126
129
  @nn.compact
127
130
  def __call__(self, x):
@@ -149,8 +152,8 @@ class ConvLayer(nn.Module):
149
152
  kernel_size:tuple=(3, 3)
150
153
  strides:tuple=(1, 1)
151
154
  kernel_init:Callable=kernel_init(1.0)
152
- dtype: Any = jnp.bfloat16
153
- precision: Any = jax.lax.Precision.HIGH
155
+ dtype: Optional[Dtype] = None
156
+ precision: PrecisionLike = None
154
157
 
155
158
  def setup(self):
156
159
  # conv_type can be "conv", "separable", "conv_transpose"
@@ -199,8 +202,8 @@ class Upsample(nn.Module):
199
202
  features:int
200
203
  scale:int
201
204
  activation:Callable=jax.nn.swish
202
- dtype: Any = jnp.bfloat16
203
- precision: Any = jax.lax.Precision.HIGH
205
+ dtype: Optional[Dtype] = None
206
+ precision: PrecisionLike = None
204
207
 
205
208
  @nn.compact
206
209
  def __call__(self, x, residual=None):
@@ -224,8 +227,8 @@ class Downsample(nn.Module):
224
227
  features:int
225
228
  scale:int
226
229
  activation:Callable=jax.nn.swish
227
- dtype: Any = jnp.bfloat16
228
- precision: Any = jax.lax.Precision.HIGH
230
+ dtype: Optional[Dtype] = None
231
+ precision: PrecisionLike = None
229
232
 
230
233
  @nn.compact
231
234
  def __call__(self, x, residual=None):
@@ -248,3 +251,79 @@ def l2norm(t, axis=1, eps=1e-12):
248
251
  denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
249
252
  out = t/denom
250
253
  return (out)
254
+
255
+
256
class ResidualBlock(nn.Module):
    """Residual block: two (RMSNorm -> activation -> conv) stages with a skip path.

    A projection of the time embedding is added between the two stages. When
    the block output's shape differs from the input's, the skip path is passed
    through a 1x1 `ConvLayer` before being summed.

    `padding`, `direction`, `res` and `norm_groups` are not read by
    `__call__`; `norm_groups` belonged to the GroupNorm layers that RMSNorm
    now stands in for.
    """
    conv_type: str
    features: int
    kernel_size: tuple = (3, 3)
    strides: tuple = (1, 1)
    padding: str = "SAME"
    activation: Callable = jax.nn.swish
    direction: str = None
    res: int = 2
    norm_groups: int = 8
    kernel_init: Callable = kernel_init(1.0)
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x: jax.Array, temb: jax.Array, textemb: jax.Array = None, extra_features: jax.Array = None):
        """Apply the block.

        Args:
            x: Input feature map; also used as the skip connection.
            temb: Time embedding, projected to `features` channels and
                broadcast-added after the first convolution.
            textemb: Accepted for interface compatibility; unused here.
            extra_features: Optional array concatenated onto the output along
                the last axis.

        Returns:
            The block output, with `extra_features` appended if provided.
        """
        # Keyword arguments common to every ConvLayer created in this block.
        shared = dict(
            features=self.features,
            kernel_init=self.kernel_init,
            dtype=self.dtype,
            precision=self.precision,
        )
        skip = x

        # Stage 1: norm -> activation -> conv.
        hidden = self.activation(nn.RMSNorm()(x))
        hidden = ConvLayer(
            self.conv_type,
            kernel_size=self.kernel_size,
            strides=self.strides,
            name="conv1",
            **shared,
        )(hidden)

        # Project the time embedding and insert two singleton axes after the
        # leading axis so it broadcasts against the feature map.
        time_bias = nn.DenseGeneral(
            features=self.features,
            name="temb_projection",
            dtype=self.dtype,
            precision=self.precision,
        )(temb)
        hidden = hidden + jnp.expand_dims(time_bias, axis=(1, 2))

        # Stage 2: norm -> activation -> conv.
        hidden = self.activation(nn.RMSNorm()(hidden))
        hidden = ConvLayer(
            self.conv_type,
            kernel_size=self.kernel_size,
            strides=self.strides,
            name="conv2",
            **shared,
        )(hidden)

        # Reshape the skip path with a 1x1 conv only when shapes disagree.
        if skip.shape != hidden.shape:
            skip = ConvLayer(
                self.conv_type,
                kernel_size=(1, 1),
                strides=1,
                name="residual_conv",
                **shared,
            )(skip)
        result = hidden + skip

        if extra_features is not None:
            result = jnp.concatenate([result, extra_features], axis=-1)
        return result
329
+
@@ -1,83 +1,12 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
3
  from flax import linen as nn
4
- from typing import Dict, Callable, Sequence, Any, Union
4
+ from flax.typing import Dtype, PrecisionLike
5
+ from typing import Dict, Callable, Sequence, Any, Union, Optional
5
6
  import einops
6
7
  from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
7
8
  from .attention import TransformerBlock
8
9
 
9
- class ResidualBlock(nn.Module):
10
- conv_type:str
11
- features:int
12
- kernel_size:tuple=(3, 3)
13
- strides:tuple=(1, 1)
14
- padding:str="SAME"
15
- activation:Callable=jax.nn.swish
16
- direction:str=None
17
- res:int=2
18
- norm_groups:int=8
19
- kernel_init:Callable=kernel_init(1.0)
20
- dtype: Any = jnp.float32
21
- precision: Any = jax.lax.Precision.HIGHEST
22
-
23
- @nn.compact
24
- def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
25
- residual = x
26
- out = nn.GroupNorm(self.norm_groups)(x)
27
- out = self.activation(out)
28
-
29
- out = ConvLayer(
30
- self.conv_type,
31
- features=self.features,
32
- kernel_size=self.kernel_size,
33
- strides=self.strides,
34
- kernel_init=self.kernel_init,
35
- name="conv1",
36
- dtype=self.dtype,
37
- precision=self.precision
38
- )(out)
39
-
40
- temb = nn.DenseGeneral(
41
- features=self.features,
42
- name="temb_projection",
43
- dtype=self.dtype,
44
- precision=self.precision)(temb)
45
- temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
46
- # scale, shift = jnp.split(temb, 2, axis=-1)
47
- # out = out * (1 + scale) + shift
48
- out = out + temb
49
-
50
- out = nn.GroupNorm(self.norm_groups)(out)
51
- out = self.activation(out)
52
-
53
- out = ConvLayer(
54
- self.conv_type,
55
- features=self.features,
56
- kernel_size=self.kernel_size,
57
- strides=self.strides,
58
- kernel_init=self.kernel_init,
59
- name="conv2",
60
- dtype=self.dtype,
61
- precision=self.precision
62
- )(out)
63
-
64
- if residual.shape != out.shape:
65
- residual = ConvLayer(
66
- self.conv_type,
67
- features=self.features,
68
- kernel_size=(1, 1),
69
- strides=1,
70
- kernel_init=self.kernel_init,
71
- name="residual_conv",
72
- dtype=self.dtype,
73
- precision=self.precision
74
- )(residual)
75
- out = out + residual
76
-
77
- out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
78
-
79
- return out
80
-
81
10
  class Unet(nn.Module):
82
11
  output_channels:int=3
83
12
  emb_features:int=64*4,
@@ -87,8 +16,8 @@ class Unet(nn.Module):
87
16
  num_middle_res_blocks:int=1,
88
17
  activation:Callable = jax.nn.swish
89
18
  norm_groups:int=8
90
- dtype: Any = jnp.bfloat16
91
- precision: Any = jax.lax.Precision.HIGH
19
+ dtype: Optional[Dtype] = None
20
+ precision: PrecisionLike = None
92
21
 
93
22
  @nn.compact
94
23
  def __call__(self, x, temb, textcontext):
@@ -0,0 +1,2 @@
1
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
2
+ from .diffusion_trainer import DiffusionTrainer, TrainState
@@ -0,0 +1,182 @@
1
+ from flax import linen as nn
2
+ import jax
3
+ from typing import Callable
4
+ from dataclasses import field
5
+ import jax.numpy as jnp
6
+ import optax
7
+ from jax.sharding import Mesh, PartitionSpec as P
8
+ from jax.experimental.shard_map import shard_map
9
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
10
+
11
+ from ..schedulers import NoiseScheduler
12
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
13
+
14
+ from flaxdiff.utils import RandomMarkovState
15
+
16
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
17
+
18
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
19
+
20
class AutoEncoderTrainer(SimpleTrainer):
    """Trainer scaffold for autoencoder models, built on `SimpleTrainer`.

    NOTE(review): `_define_train_step` and `fit` mirror the text-conditional
    diffusion training loop. They read `self.noise_schedule`,
    `self.model_output_transform`, `self.unconditional_prob` and
    `self.autoencoder`, none of which are assigned in this class's
    `__init__` -- confirm whether the base class supplies them or whether
    the step still has to be specialised for reconstruction training.
    """

    def __init__(self,
                 model: nn.Module,
                 input_shape: Tuple[int, int, int],
                 latent_dim: int,
                 spatial_scale: int,
                 optimizer: optax.GradientTransformation,
                 rngs: jax.random.PRNGKey,
                 name: str = "Autoencoder",
                 **kwargs
                 ):
        """Set up the base trainer and record the latent configuration.

        Args:
            model: Module to train.
            input_shape: Shape registered with the base trainer under the
                "image" input key.
            latent_dim: Stored on the instance as `latent_dim`.
            spatial_scale: Stored on the instance as `spatial_scale`.
            optimizer: Optax gradient transformation.
            rngs: PRNG key handed to the base trainer.
            name: Trainer name handed to the base trainer.
            **kwargs: Forwarded unchanged to `SimpleTrainer.__init__`.
        """
        super().__init__(
            model=model,
            input_shapes={"image": input_shape},
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        self.latent_dim = latent_dim
        self.spatial_scale = spatial_scale

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None
    ) -> "Tuple[TrainState, TrainState]":
        """Build the current and best train states.

        Args:
            optimizer: Optax gradient transformation for the train state.
            rngs: PRNG key; split once so a subkey can seed `model.init`.
            existing_state: Optional dict with "params" and "ema_params" to
                resume from; when None the model is freshly initialised.
            existing_best_state: Optional dict with "params" and
                "ema_params" used for the best state.
            model: Module whose `init` / `apply` are used.
            param_transforms: Optional callable applied to the parameters
                before the train state is created.

        Returns:
            `(state, best_state)`; both are the same object when
            `existing_best_state` is None.
        """
        # TrainState is defined in the diffusion trainer module and is not
        # part of this module's imports, so bring it into scope locally.
        from .diffusion_trainer import TrainState

        print("Generating states for AutoEncoderTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state

        if param_transforms is not None:
            # Transform whichever parameters are in use (fresh or restored)
            # and make sure the result is what the train state receives.
            # NOTE(review): the EMA copy is left untransformed -- confirm
            # whether it should be transformed as well.
            new_state = {**new_state, "params": param_transforms(new_state["params"])}

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty()
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
        """Return the (optionally sharded) jitted single-step training function.

        Args:
            batch_size: Per-call batch size; used to size the unconditional
                portion and to broadcast `null_labels_seq`.
            null_labels_seq: 2-D array of null-label embeddings, broadcast to
                `(batch_size, nS, nC)`.
            text_embedder: Callable taking `input_ids` and `attention_mask`
                and returning an object with `last_hidden_state`.
        """
        # Needed for the annotation on the nested `train_step` below.
        from .diffusion_trainer import TrainState

        noise_schedule: NoiseScheduler = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        unconditional_prob = self.unconditional_prob

        # Determine the number of unconditional samples
        num_unconditional = int(batch_size * unconditional_prob)

        nS, nC = null_labels_seq.shape
        null_labels_seq = jnp.broadcast_to(
            null_labels_seq, (batch_size, nS, nC))

        distributed_training = self.distributed_training

        autoencoder = self.autoencoder

        # @jax.jit
        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Train for a single step."""
            rng_state, subkey = rng_state.get_random_key()
            # Fold the device index into the key so each shard draws
            # different randomness from the same global state.
            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
            local_rng_state = RandomMarkovState(subkey)

            images = batch['image']

            if autoencoder is not None:
                # Convert the images to latent space
                local_rng_state, rngs = local_rng_state.get_random_key()
                images = autoencoder.encode(images, rngs)
            else:
                # normalize image
                images = (images - 127.5) / 127.5

            output = text_embedder(
                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            label_seq = output.last_hidden_state

            # The first `num_unconditional` samples of the batch get the
            # null labels; the remainder keep their text embeddings.
            label_seq = jnp.concat(
                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)

            local_rng_state, rngs = local_rng_state.get_random_key()
            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)

            rates = noise_schedule.get_rates(noise_level)
            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                images, noise, rates)

            def model_loss(params):
                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
                preds = model_output_transform.pred_transform(
                    noisy_images, preds, rates)
                nloss = loss_fn(preds, expected_output)
                # nloss = jnp.mean(nloss, axis=1)
                nloss *= noise_schedule.get_weights(noise_level)
                nloss = jnp.mean(nloss)
                loss = nloss
                return loss

            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
            if distributed_training:
                # Average gradients and loss across the "data" mesh axis.
                grads = jax.lax.pmean(grads, "data")
                loss = jax.lax.pmean(loss, "data")
            train_state = train_state.apply_gradients(grads=grads)
            train_state = train_state.apply_ema(self.ema_decay)
            return train_state, loss, rng_state

        if distributed_training:
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
            train_step = jax.jit(train_step)

        return train_step

    def _define_compute_metrics(self):
        """Return a jitted function that merges an MSE loss into the state's metrics."""
        # Needed for the annotation on the nested `compute_metrics` below.
        from .diffusion_trainer import TrainState

        @jax.jit
        def compute_metrics(state: TrainState, expected, pred):
            loss = jnp.mean(jnp.square(pred - expected))
            metric_updates = state.metrics.single_from_model_output(loss=loss)
            metrics = state.metrics.merge(metric_updates)
            state = state.replace(metrics=metrics)
            return state
        return compute_metrics

    def fit(self, data, steps_per_epoch, epochs):
        """Run training via `SimpleTrainer.fit`.

        `data` must provide the keys 'null_labels_full', 'local_batch_size'
        and 'model' (the text embedder); they are forwarded as the
        train-step arguments.
        """
        null_labels_full = data['null_labels_full']
        local_batch_size = data['local_batch_size']
        text_embedder = data['model']
        super().fit(data, steps_per_epoch, epochs, {
            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
177
+
178
def boolean_string(s):
    """Interpret `s` as a boolean flag.

    A `bool` is returned unchanged. Any other value is compared against the
    exact string 'True' (case-sensitive), so 'true', '1' and similar all
    yield False.
    """
    if isinstance(s, bool):
        return s
    return s == 'True'
182
+
@@ -15,6 +15,8 @@ from flaxdiff.utils import RandomMarkovState
15
15
 
16
16
  from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
17
17
 
18
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
19
+
18
20
  class TrainState(SimpleTrainState):
19
21
  rngs: jax.random.PRNGKey
20
22
  ema_params: dict
@@ -41,6 +43,7 @@ class DiffusionTrainer(SimpleTrainer):
41
43
  unconditional_prob: float = 0.2,
42
44
  name: str = "Diffusion",
43
45
  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
46
+ autoencoder: AutoEncoder = None,
44
47
  **kwargs
45
48
  ):
46
49
  super().__init__(
@@ -54,6 +57,8 @@ class DiffusionTrainer(SimpleTrainer):
54
57
  self.noise_schedule = noise_schedule
55
58
  self.model_output_transform = model_output_transform
56
59
  self.unconditional_prob = unconditional_prob
60
+
61
+ self.autoencoder = autoencoder
57
62
 
58
63
  def generate_states(
59
64
  self,
@@ -109,6 +114,8 @@ class DiffusionTrainer(SimpleTrainer):
109
114
  null_labels_seq, (batch_size, nS, nC))
110
115
 
111
116
  distributed_training = self.distributed_training
117
+
118
+ autoencoder = self.autoencoder
112
119
 
113
120
  # @jax.jit
114
121
  def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
@@ -118,8 +125,14 @@ class DiffusionTrainer(SimpleTrainer):
118
125
  local_rng_state = RandomMarkovState(subkey)
119
126
 
120
127
  images = batch['image']
121
- # normalize image
122
- images = (images - 127.5) / 127.5
128
+
129
+ if autoencoder is not None:
130
+ # Convert the images to latent space
131
+ local_rng_state, rngs = local_rng_state.get_random_key()
132
+ images = autoencoder.encode(images, rngs)
133
+ else:
134
+ # normalize image
135
+ images = (images - 127.5) / 127.5
123
136
 
124
137
  output = text_embedder(
125
138
  input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
@@ -140,8 +153,7 @@ class DiffusionTrainer(SimpleTrainer):
140
153
  images, noise, rates)
141
154
 
142
155
  def model_loss(params):
143
- preds = model.apply(
144
- params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
156
+ preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
145
157
  preds = model_output_transform.pred_transform(
146
158
  noisy_images, preds, rates)
147
159
  nloss = loss_fn(preds, expected_output)
@@ -182,3 +194,9 @@ class DiffusionTrainer(SimpleTrainer):
182
194
  text_embedder = data['model']
183
195
  super().fit(data, steps_per_epoch, epochs, {
184
196
  "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
197
+
198
def boolean_string(s):
    """Interpret `s` as a boolean flag.

    A `bool` is returned unchanged. Any other value is compared against the
    exact string 'True' (case-sensitive), so 'true', '1' and similar all
    yield False.
    """
    if isinstance(s, bool):
        return s
    return s == 'True'
202
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -13,6 +13,8 @@ Requires-Dist: clu
13
13
 
14
14
  # ![](images/logo.jpeg "FlaxDiff")
15
15
 
16
+ **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
17
+
16
18
  ## A Versatile and simple Diffusion Library
17
19
 
18
20
  In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
@@ -16,6 +16,7 @@ flaxdiff/models/simple_vit.py
16
16
  flaxdiff/models/autoencoder/__init__.py
17
17
  flaxdiff/models/autoencoder/autoencoder.py
18
18
  flaxdiff/models/autoencoder/diffusers.py
19
+ flaxdiff/models/autoencoder/simple_autoenc.py
19
20
  flaxdiff/predictors/__init__.py
20
21
  flaxdiff/samplers/__init__.py
21
22
  flaxdiff/samplers/common.py
@@ -35,4 +36,6 @@ flaxdiff/schedulers/karras.py
35
36
  flaxdiff/schedulers/linear.py
36
37
  flaxdiff/schedulers/sqrt.py
37
38
  flaxdiff/trainer/__init__.py
39
+ flaxdiff/trainer/autoencoder_trainer.py
40
+ flaxdiff/trainer/diffusion_trainer.py
38
41
  flaxdiff/trainer/simple_trainer.py
@@ -11,7 +11,7 @@ required_packages=[
11
11
  setup(
12
12
  name='flaxdiff',
13
13
  packages=find_packages(),
14
- version='0.1.5',
14
+ version='0.1.6',
15
15
  description='A versatile and easy to understand Diffusion library',
16
16
  long_description=open('README.md').read(),
17
17
  long_description_content_type='text/markdown',
File without changes
File without changes
File without changes
File without changes