flaxdiff 0.1.5__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/PKG-INFO +3 -1
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/README.md +2 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/attention.py +12 -11
- flaxdiff-0.1.6/flaxdiff/models/autoencoder/__init__.py +2 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/autoencoder/autoencoder.py +7 -2
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/autoencoder/diffusers.py +3 -0
- flaxdiff-0.1.6/flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/common.py +89 -10
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/simple_unet.py +4 -75
- flaxdiff-0.1.6/flaxdiff/trainer/__init__.py +2 -0
- flaxdiff-0.1.6/flaxdiff/trainer/autoencoder_trainer.py +182 -0
- flaxdiff-0.1.5/flaxdiff/trainer/__init__.py → flaxdiff-0.1.6/flaxdiff/trainer/diffusion_trainer.py +22 -4
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/PKG-INFO +3 -1
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/SOURCES.txt +3 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/setup.py +1 -1
- flaxdiff-0.1.5/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.5 → flaxdiff-0.1.6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: flaxdiff
|
3
|
-
Version: 0.1.5
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: A versatile and easy to understand Diffusion library
|
5
5
|
Author: Ashish Kumar Singh
|
6
6
|
Author-email: ashishkmr472@gmail.com
|
@@ -13,6 +13,8 @@ Requires-Dist: clu
|
|
13
13
|
|
14
14
|
# 
|
15
15
|
|
16
|
+
**This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
|
17
|
+
|
16
18
|
## A Versatile and simple Diffusion Library
|
17
19
|
|
18
20
|
In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
|
@@ -1,5 +1,7 @@
|
|
1
1
|
# 
|
2
2
|
|
3
|
+
**This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
|
4
|
+
|
3
5
|
## A Versatile and simple Diffusion Library
|
4
6
|
|
5
7
|
In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
|
@@ -5,7 +5,8 @@ Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/dif
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
7
|
from flax import linen as nn
|
8
|
-
from typing import Dict, Callable, Sequence, Any, Union
|
8
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
|
9
|
+
from flax.typing import Dtype, PrecisionLike
|
9
10
|
import einops
|
10
11
|
import functools
|
11
12
|
import math
|
@@ -18,8 +19,8 @@ class EfficientAttention(nn.Module):
|
|
18
19
|
query_dim: int
|
19
20
|
heads: int = 4
|
20
21
|
dim_head: int = 64
|
21
|
-
dtype:
|
22
|
-
precision:
|
22
|
+
dtype: Optional[Dtype] = None
|
23
|
+
precision: PrecisionLike = None
|
23
24
|
use_bias: bool = True
|
24
25
|
kernel_init: Callable = lambda : kernel_init(1.0)
|
25
26
|
|
@@ -109,8 +110,8 @@ class NormalAttention(nn.Module):
|
|
109
110
|
query_dim: int
|
110
111
|
heads: int = 4
|
111
112
|
dim_head: int = 64
|
112
|
-
dtype:
|
113
|
-
precision:
|
113
|
+
dtype: Optional[Dtype] = None
|
114
|
+
precision: PrecisionLike = None
|
114
115
|
use_bias: bool = True
|
115
116
|
kernel_init: Callable = lambda : kernel_init(1.0)
|
116
117
|
|
@@ -166,8 +167,8 @@ class BasicTransformerBlock(nn.Module):
|
|
166
167
|
query_dim: int
|
167
168
|
heads: int = 4
|
168
169
|
dim_head: int = 64
|
169
|
-
dtype:
|
170
|
-
precision:
|
170
|
+
dtype: Optional[Dtype] = None
|
171
|
+
precision: PrecisionLike = None
|
171
172
|
use_bias: bool = True
|
172
173
|
kernel_init: Callable = lambda : kernel_init(1.0)
|
173
174
|
use_flash_attention:bool = False
|
@@ -286,8 +287,8 @@ class BasicTransformerBlock(nn.Module):
|
|
286
287
|
query_dim: int
|
287
288
|
heads: int = 4
|
288
289
|
dim_head: int = 64
|
289
|
-
dtype:
|
290
|
-
precision:
|
290
|
+
dtype: Optional[Dtype] = None
|
291
|
+
precision: PrecisionLike = None
|
291
292
|
use_bias: bool = True
|
292
293
|
kernel_init: Callable = lambda : kernel_init(1.0)
|
293
294
|
use_flash_attention:bool = False
|
@@ -346,8 +347,8 @@ class TransformerBlock(nn.Module):
|
|
346
347
|
heads: int = 4
|
347
348
|
dim_head: int = 32
|
348
349
|
use_linear_attention: bool = True
|
349
|
-
dtype:
|
350
|
-
precision:
|
350
|
+
dtype: Optional[Dtype] = None
|
351
|
+
precision: PrecisionLike = None
|
351
352
|
use_projection: bool = False
|
352
353
|
use_flash_attention:bool = True
|
353
354
|
use_self_and_cross:bool = False
|
@@ -6,9 +6,14 @@ import einops
|
|
6
6
|
from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
|
7
7
|
|
8
8
|
|
9
|
-
class AutoEncoder:
|
9
|
+
class AutoEncoder():
|
10
10
|
def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
|
11
11
|
raise NotImplementedError
|
12
12
|
|
13
13
|
def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
|
14
|
-
raise NotImplementedError
|
14
|
+
raise NotImplementedError
|
15
|
+
|
16
|
+
def __call__(self, x: jnp.ndarray):
|
17
|
+
latents = self.encode(x)
|
18
|
+
reconstructions = self.decode(latents)
|
19
|
+
return reconstructions
|
@@ -5,6 +5,9 @@ from .autoencoder import AutoEncoder
|
|
5
5
|
|
6
6
|
"""
|
7
7
|
This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
|
8
|
+
The actual model was not trained by me, but was taken from the HuggingFace model hub.
|
9
|
+
I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library
|
10
|
+
All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
|
8
11
|
"""
|
9
12
|
|
10
13
|
class StableDiffusionVAE(AutoEncoder):
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from typing import Any, List, Optional, Callable
|
2
|
+
import jax
|
3
|
+
import flax.linen as nn
|
4
|
+
from jax import numpy as jnp
|
5
|
+
from flax.typing import Dtype, PrecisionLike
|
6
|
+
from .autoencoder import AutoEncoder
|
7
|
+
|
8
|
+
class SimpleAutoEncoder(AutoEncoder):
|
9
|
+
latent_channels: int
|
10
|
+
feature_depths: List[int]=[64, 128, 256, 512]
|
11
|
+
attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
|
12
|
+
num_res_blocks: int=2
|
13
|
+
num_middle_res_blocks:int=1,
|
14
|
+
activation:Callable = jax.nn.swish
|
15
|
+
norm_groups:int=8
|
16
|
+
dtype: Optional[Dtype] = None
|
17
|
+
precision: PrecisionLike = None
|
18
|
+
|
19
|
+
# def encode(self, x: jnp.ndarray):
|
20
|
+
|
21
|
+
|
22
|
+
@nn.compact
|
23
|
+
def __call__(self, x: jnp.ndarray):
|
24
|
+
latents = self.encode(x)
|
25
|
+
reconstructions = self.decode(latents)
|
26
|
+
return reconstructions
|
@@ -1,6 +1,8 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
2
|
import jax
|
3
3
|
from flax import linen as nn
|
4
|
+
from typing import Optional, Any, Callable, Sequence, Union
|
5
|
+
from flax.typing import Dtype, PrecisionLike
|
4
6
|
from typing import Dict, Callable, Sequence, Any, Union
|
5
7
|
import einops
|
6
8
|
|
@@ -18,8 +20,9 @@ class WeightStandardizedConv(nn.Module):
|
|
18
20
|
kernel_size: Sequence[int] = 3
|
19
21
|
strides: Union[None, int, Sequence[int]] = 1
|
20
22
|
padding: Any = 1
|
21
|
-
dtype:
|
22
|
-
|
23
|
+
dtype: Optional[Dtype] = None
|
24
|
+
precision: PrecisionLike = None
|
25
|
+
param_dtype: Optional[Dtype] = None
|
23
26
|
|
24
27
|
@nn.compact
|
25
28
|
def __call__(self, x):
|
@@ -120,8 +123,8 @@ class SeparableConv(nn.Module):
|
|
120
123
|
use_bias:bool=False
|
121
124
|
kernel_init:Callable=kernel_init(1.0)
|
122
125
|
padding:str="SAME"
|
123
|
-
dtype:
|
124
|
-
precision:
|
126
|
+
dtype: Optional[Dtype] = None
|
127
|
+
precision: PrecisionLike = None
|
125
128
|
|
126
129
|
@nn.compact
|
127
130
|
def __call__(self, x):
|
@@ -149,8 +152,8 @@ class ConvLayer(nn.Module):
|
|
149
152
|
kernel_size:tuple=(3, 3)
|
150
153
|
strides:tuple=(1, 1)
|
151
154
|
kernel_init:Callable=kernel_init(1.0)
|
152
|
-
dtype:
|
153
|
-
precision:
|
155
|
+
dtype: Optional[Dtype] = None
|
156
|
+
precision: PrecisionLike = None
|
154
157
|
|
155
158
|
def setup(self):
|
156
159
|
# conv_type can be "conv", "separable", "conv_transpose"
|
@@ -199,8 +202,8 @@ class Upsample(nn.Module):
|
|
199
202
|
features:int
|
200
203
|
scale:int
|
201
204
|
activation:Callable=jax.nn.swish
|
202
|
-
dtype:
|
203
|
-
precision:
|
205
|
+
dtype: Optional[Dtype] = None
|
206
|
+
precision: PrecisionLike = None
|
204
207
|
|
205
208
|
@nn.compact
|
206
209
|
def __call__(self, x, residual=None):
|
@@ -224,8 +227,8 @@ class Downsample(nn.Module):
|
|
224
227
|
features:int
|
225
228
|
scale:int
|
226
229
|
activation:Callable=jax.nn.swish
|
227
|
-
dtype:
|
228
|
-
precision:
|
230
|
+
dtype: Optional[Dtype] = None
|
231
|
+
precision: PrecisionLike = None
|
229
232
|
|
230
233
|
@nn.compact
|
231
234
|
def __call__(self, x, residual=None):
|
@@ -248,3 +251,79 @@ def l2norm(t, axis=1, eps=1e-12):
|
|
248
251
|
denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
|
249
252
|
out = t/denom
|
250
253
|
return (out)
|
254
|
+
|
255
|
+
|
256
|
+
class ResidualBlock(nn.Module):
|
257
|
+
conv_type:str
|
258
|
+
features:int
|
259
|
+
kernel_size:tuple=(3, 3)
|
260
|
+
strides:tuple=(1, 1)
|
261
|
+
padding:str="SAME"
|
262
|
+
activation:Callable=jax.nn.swish
|
263
|
+
direction:str=None
|
264
|
+
res:int=2
|
265
|
+
norm_groups:int=8
|
266
|
+
kernel_init:Callable=kernel_init(1.0)
|
267
|
+
dtype: Optional[Dtype] = None
|
268
|
+
precision: PrecisionLike = None
|
269
|
+
|
270
|
+
@nn.compact
|
271
|
+
def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
|
272
|
+
residual = x
|
273
|
+
# out = nn.GroupNorm(self.norm_groups)(x)
|
274
|
+
out = nn.RMSNorm()(x)
|
275
|
+
out = self.activation(out)
|
276
|
+
|
277
|
+
out = ConvLayer(
|
278
|
+
self.conv_type,
|
279
|
+
features=self.features,
|
280
|
+
kernel_size=self.kernel_size,
|
281
|
+
strides=self.strides,
|
282
|
+
kernel_init=self.kernel_init,
|
283
|
+
name="conv1",
|
284
|
+
dtype=self.dtype,
|
285
|
+
precision=self.precision
|
286
|
+
)(out)
|
287
|
+
|
288
|
+
temb = nn.DenseGeneral(
|
289
|
+
features=self.features,
|
290
|
+
name="temb_projection",
|
291
|
+
dtype=self.dtype,
|
292
|
+
precision=self.precision)(temb)
|
293
|
+
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
|
294
|
+
# scale, shift = jnp.split(temb, 2, axis=-1)
|
295
|
+
# out = out * (1 + scale) + shift
|
296
|
+
out = out + temb
|
297
|
+
|
298
|
+
# out = nn.GroupNorm(self.norm_groups)(out)
|
299
|
+
out = nn.RMSNorm()(out)
|
300
|
+
out = self.activation(out)
|
301
|
+
|
302
|
+
out = ConvLayer(
|
303
|
+
self.conv_type,
|
304
|
+
features=self.features,
|
305
|
+
kernel_size=self.kernel_size,
|
306
|
+
strides=self.strides,
|
307
|
+
kernel_init=self.kernel_init,
|
308
|
+
name="conv2",
|
309
|
+
dtype=self.dtype,
|
310
|
+
precision=self.precision
|
311
|
+
)(out)
|
312
|
+
|
313
|
+
if residual.shape != out.shape:
|
314
|
+
residual = ConvLayer(
|
315
|
+
self.conv_type,
|
316
|
+
features=self.features,
|
317
|
+
kernel_size=(1, 1),
|
318
|
+
strides=1,
|
319
|
+
kernel_init=self.kernel_init,
|
320
|
+
name="residual_conv",
|
321
|
+
dtype=self.dtype,
|
322
|
+
precision=self.precision
|
323
|
+
)(residual)
|
324
|
+
out = out + residual
|
325
|
+
|
326
|
+
out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
|
327
|
+
|
328
|
+
return out
|
329
|
+
|
@@ -1,83 +1,12 @@
|
|
1
1
|
import jax
|
2
2
|
import jax.numpy as jnp
|
3
3
|
from flax import linen as nn
|
4
|
-
from typing import
|
4
|
+
from flax.typing import Dtype, PrecisionLike
|
5
|
+
from typing import Dict, Callable, Sequence, Any, Union, Optional
|
5
6
|
import einops
|
6
7
|
from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
|
7
8
|
from .attention import TransformerBlock
|
8
9
|
|
9
|
-
class ResidualBlock(nn.Module):
|
10
|
-
conv_type:str
|
11
|
-
features:int
|
12
|
-
kernel_size:tuple=(3, 3)
|
13
|
-
strides:tuple=(1, 1)
|
14
|
-
padding:str="SAME"
|
15
|
-
activation:Callable=jax.nn.swish
|
16
|
-
direction:str=None
|
17
|
-
res:int=2
|
18
|
-
norm_groups:int=8
|
19
|
-
kernel_init:Callable=kernel_init(1.0)
|
20
|
-
dtype: Any = jnp.float32
|
21
|
-
precision: Any = jax.lax.Precision.HIGHEST
|
22
|
-
|
23
|
-
@nn.compact
|
24
|
-
def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
|
25
|
-
residual = x
|
26
|
-
out = nn.GroupNorm(self.norm_groups)(x)
|
27
|
-
out = self.activation(out)
|
28
|
-
|
29
|
-
out = ConvLayer(
|
30
|
-
self.conv_type,
|
31
|
-
features=self.features,
|
32
|
-
kernel_size=self.kernel_size,
|
33
|
-
strides=self.strides,
|
34
|
-
kernel_init=self.kernel_init,
|
35
|
-
name="conv1",
|
36
|
-
dtype=self.dtype,
|
37
|
-
precision=self.precision
|
38
|
-
)(out)
|
39
|
-
|
40
|
-
temb = nn.DenseGeneral(
|
41
|
-
features=self.features,
|
42
|
-
name="temb_projection",
|
43
|
-
dtype=self.dtype,
|
44
|
-
precision=self.precision)(temb)
|
45
|
-
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
|
46
|
-
# scale, shift = jnp.split(temb, 2, axis=-1)
|
47
|
-
# out = out * (1 + scale) + shift
|
48
|
-
out = out + temb
|
49
|
-
|
50
|
-
out = nn.GroupNorm(self.norm_groups)(out)
|
51
|
-
out = self.activation(out)
|
52
|
-
|
53
|
-
out = ConvLayer(
|
54
|
-
self.conv_type,
|
55
|
-
features=self.features,
|
56
|
-
kernel_size=self.kernel_size,
|
57
|
-
strides=self.strides,
|
58
|
-
kernel_init=self.kernel_init,
|
59
|
-
name="conv2",
|
60
|
-
dtype=self.dtype,
|
61
|
-
precision=self.precision
|
62
|
-
)(out)
|
63
|
-
|
64
|
-
if residual.shape != out.shape:
|
65
|
-
residual = ConvLayer(
|
66
|
-
self.conv_type,
|
67
|
-
features=self.features,
|
68
|
-
kernel_size=(1, 1),
|
69
|
-
strides=1,
|
70
|
-
kernel_init=self.kernel_init,
|
71
|
-
name="residual_conv",
|
72
|
-
dtype=self.dtype,
|
73
|
-
precision=self.precision
|
74
|
-
)(residual)
|
75
|
-
out = out + residual
|
76
|
-
|
77
|
-
out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
|
78
|
-
|
79
|
-
return out
|
80
|
-
|
81
10
|
class Unet(nn.Module):
|
82
11
|
output_channels:int=3
|
83
12
|
emb_features:int=64*4,
|
@@ -87,8 +16,8 @@ class Unet(nn.Module):
|
|
87
16
|
num_middle_res_blocks:int=1,
|
88
17
|
activation:Callable = jax.nn.swish
|
89
18
|
norm_groups:int=8
|
90
|
-
dtype: Any = jnp.float32
|
91
|
-
precision: Any = jax.lax.Precision.HIGHEST
|
19
|
+
dtype: Optional[Dtype] = None
|
20
|
+
precision: PrecisionLike = None
|
92
21
|
|
93
22
|
@nn.compact
|
94
23
|
def __call__(self, x, temb, textcontext):
|
@@ -0,0 +1,182 @@
|
|
1
|
+
from flax import linen as nn
|
2
|
+
import jax
|
3
|
+
from typing import Callable
|
4
|
+
from dataclasses import field
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import optax
|
7
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
8
|
+
from jax.experimental.shard_map import shard_map
|
9
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
10
|
+
|
11
|
+
from ..schedulers import NoiseScheduler
|
12
|
+
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
13
|
+
|
14
|
+
from flaxdiff.utils import RandomMarkovState
|
15
|
+
|
16
|
+
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
17
|
+
|
18
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
19
|
+
|
20
|
+
class AutoEncoderTrainer(SimpleTrainer):
|
21
|
+
def __init__(self,
|
22
|
+
model: nn.Module,
|
23
|
+
input_shape: Union[int, int, int],
|
24
|
+
latent_dim: int,
|
25
|
+
spatial_scale: int,
|
26
|
+
optimizer: optax.GradientTransformation,
|
27
|
+
rngs: jax.random.PRNGKey,
|
28
|
+
name: str = "Autoencoder",
|
29
|
+
**kwargs
|
30
|
+
):
|
31
|
+
super().__init__(
|
32
|
+
model=model,
|
33
|
+
input_shapes={"image": input_shape},
|
34
|
+
optimizer=optimizer,
|
35
|
+
rngs=rngs,
|
36
|
+
name=name,
|
37
|
+
**kwargs
|
38
|
+
)
|
39
|
+
self.latent_dim = latent_dim
|
40
|
+
self.spatial_scale = spatial_scale
|
41
|
+
|
42
|
+
|
43
|
+
def generate_states(
|
44
|
+
self,
|
45
|
+
optimizer: optax.GradientTransformation,
|
46
|
+
rngs: jax.random.PRNGKey,
|
47
|
+
existing_state: dict = None,
|
48
|
+
existing_best_state: dict = None,
|
49
|
+
model: nn.Module = None,
|
50
|
+
param_transforms: Callable = None
|
51
|
+
) -> Tuple[TrainState, TrainState]:
|
52
|
+
print("Generating states for DiffusionTrainer")
|
53
|
+
rngs, subkey = jax.random.split(rngs)
|
54
|
+
|
55
|
+
if existing_state == None:
|
56
|
+
input_vars = self.get_input_ones()
|
57
|
+
params = model.init(subkey, **input_vars)
|
58
|
+
new_state = {"params": params, "ema_params": params}
|
59
|
+
else:
|
60
|
+
new_state = existing_state
|
61
|
+
|
62
|
+
if param_transforms is not None:
|
63
|
+
params = param_transforms(params)
|
64
|
+
|
65
|
+
state = TrainState.create(
|
66
|
+
apply_fn=model.apply,
|
67
|
+
params=new_state['params'],
|
68
|
+
ema_params=new_state['ema_params'],
|
69
|
+
tx=optimizer,
|
70
|
+
rngs=rngs,
|
71
|
+
metrics=Metrics.empty()
|
72
|
+
)
|
73
|
+
|
74
|
+
if existing_best_state is not None:
|
75
|
+
best_state = state.replace(
|
76
|
+
params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
|
77
|
+
else:
|
78
|
+
best_state = state
|
79
|
+
|
80
|
+
return state, best_state
|
81
|
+
|
82
|
+
def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
|
83
|
+
noise_schedule: NoiseScheduler = self.noise_schedule
|
84
|
+
model = self.model
|
85
|
+
model_output_transform = self.model_output_transform
|
86
|
+
loss_fn = self.loss_fn
|
87
|
+
unconditional_prob = self.unconditional_prob
|
88
|
+
|
89
|
+
# Determine the number of unconditional samples
|
90
|
+
num_unconditional = int(batch_size * unconditional_prob)
|
91
|
+
|
92
|
+
nS, nC = null_labels_seq.shape
|
93
|
+
null_labels_seq = jnp.broadcast_to(
|
94
|
+
null_labels_seq, (batch_size, nS, nC))
|
95
|
+
|
96
|
+
distributed_training = self.distributed_training
|
97
|
+
|
98
|
+
autoencoder = self.autoencoder
|
99
|
+
|
100
|
+
# @jax.jit
|
101
|
+
def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
|
102
|
+
"""Train for a single step."""
|
103
|
+
rng_state, subkey = rng_state.get_random_key()
|
104
|
+
subkey = jax.random.fold_in(subkey, local_device_index.reshape())
|
105
|
+
local_rng_state = RandomMarkovState(subkey)
|
106
|
+
|
107
|
+
images = batch['image']
|
108
|
+
|
109
|
+
if autoencoder is not None:
|
110
|
+
# Convert the images to latent space
|
111
|
+
local_rng_state, rngs = local_rng_state.get_random_key()
|
112
|
+
images = autoencoder.encode(images, rngs)
|
113
|
+
else:
|
114
|
+
# normalize image
|
115
|
+
images = (images - 127.5) / 127.5
|
116
|
+
|
117
|
+
output = text_embedder(
|
118
|
+
input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
|
119
|
+
label_seq = output.last_hidden_state
|
120
|
+
|
121
|
+
# Generate random probabilities to decide how much of this batch will be unconditional
|
122
|
+
|
123
|
+
label_seq = jnp.concat(
|
124
|
+
[null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
|
125
|
+
|
126
|
+
noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
|
127
|
+
|
128
|
+
local_rng_state, rngs = local_rng_state.get_random_key()
|
129
|
+
noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
|
130
|
+
|
131
|
+
rates = noise_schedule.get_rates(noise_level)
|
132
|
+
noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
|
133
|
+
images, noise, rates)
|
134
|
+
|
135
|
+
def model_loss(params):
|
136
|
+
preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
|
137
|
+
preds = model_output_transform.pred_transform(
|
138
|
+
noisy_images, preds, rates)
|
139
|
+
nloss = loss_fn(preds, expected_output)
|
140
|
+
# nloss = jnp.mean(nloss, axis=1)
|
141
|
+
nloss *= noise_schedule.get_weights(noise_level)
|
142
|
+
nloss = jnp.mean(nloss)
|
143
|
+
loss = nloss
|
144
|
+
return loss
|
145
|
+
|
146
|
+
loss, grads = jax.value_and_grad(model_loss)(train_state.params)
|
147
|
+
if distributed_training:
|
148
|
+
grads = jax.lax.pmean(grads, "data")
|
149
|
+
loss = jax.lax.pmean(loss, "data")
|
150
|
+
train_state = train_state.apply_gradients(grads=grads)
|
151
|
+
train_state = train_state.apply_ema(self.ema_decay)
|
152
|
+
return train_state, loss, rng_state
|
153
|
+
|
154
|
+
if distributed_training:
|
155
|
+
train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
|
156
|
+
out_specs=(P(), P(), P()))
|
157
|
+
train_step = jax.jit(train_step)
|
158
|
+
|
159
|
+
return train_step
|
160
|
+
|
161
|
+
def _define_compute_metrics(self):
|
162
|
+
@jax.jit
|
163
|
+
def compute_metrics(state: TrainState, expected, pred):
|
164
|
+
loss = jnp.mean(jnp.square(pred - expected))
|
165
|
+
metric_updates = state.metrics.single_from_model_output(loss=loss)
|
166
|
+
metrics = state.metrics.merge(metric_updates)
|
167
|
+
state = state.replace(metrics=metrics)
|
168
|
+
return state
|
169
|
+
return compute_metrics
|
170
|
+
|
171
|
+
def fit(self, data, steps_per_epoch, epochs):
|
172
|
+
null_labels_full = data['null_labels_full']
|
173
|
+
local_batch_size = data['local_batch_size']
|
174
|
+
text_embedder = data['model']
|
175
|
+
super().fit(data, steps_per_epoch, epochs, {
|
176
|
+
"batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
|
177
|
+
|
178
|
+
def boolean_string(s):
|
179
|
+
if type(s) == bool:
|
180
|
+
return s
|
181
|
+
return s == 'True'
|
182
|
+
|
flaxdiff-0.1.5/flaxdiff/trainer/__init__.py → flaxdiff-0.1.6/flaxdiff/trainer/diffusion_trainer.py
RENAMED
@@ -15,6 +15,8 @@ from flaxdiff.utils import RandomMarkovState
|
|
15
15
|
|
16
16
|
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
17
17
|
|
18
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
19
|
+
|
18
20
|
class TrainState(SimpleTrainState):
|
19
21
|
rngs: jax.random.PRNGKey
|
20
22
|
ema_params: dict
|
@@ -41,6 +43,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
41
43
|
unconditional_prob: float = 0.2,
|
42
44
|
name: str = "Diffusion",
|
43
45
|
model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
|
46
|
+
autoencoder: AutoEncoder = None,
|
44
47
|
**kwargs
|
45
48
|
):
|
46
49
|
super().__init__(
|
@@ -54,6 +57,8 @@ class DiffusionTrainer(SimpleTrainer):
|
|
54
57
|
self.noise_schedule = noise_schedule
|
55
58
|
self.model_output_transform = model_output_transform
|
56
59
|
self.unconditional_prob = unconditional_prob
|
60
|
+
|
61
|
+
self.autoencoder = autoencoder
|
57
62
|
|
58
63
|
def generate_states(
|
59
64
|
self,
|
@@ -109,6 +114,8 @@ class DiffusionTrainer(SimpleTrainer):
|
|
109
114
|
null_labels_seq, (batch_size, nS, nC))
|
110
115
|
|
111
116
|
distributed_training = self.distributed_training
|
117
|
+
|
118
|
+
autoencoder = self.autoencoder
|
112
119
|
|
113
120
|
# @jax.jit
|
114
121
|
def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
|
@@ -118,8 +125,14 @@ class DiffusionTrainer(SimpleTrainer):
|
|
118
125
|
local_rng_state = RandomMarkovState(subkey)
|
119
126
|
|
120
127
|
images = batch['image']
|
121
|
-
|
122
|
-
|
128
|
+
|
129
|
+
if autoencoder is not None:
|
130
|
+
# Convert the images to latent space
|
131
|
+
local_rng_state, rngs = local_rng_state.get_random_key()
|
132
|
+
images = autoencoder.encode(images, rngs)
|
133
|
+
else:
|
134
|
+
# normalize image
|
135
|
+
images = (images - 127.5) / 127.5
|
123
136
|
|
124
137
|
output = text_embedder(
|
125
138
|
input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
|
@@ -140,8 +153,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
140
153
|
images, noise, rates)
|
141
154
|
|
142
155
|
def model_loss(params):
|
143
|
-
preds = model.apply(
|
144
|
-
params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
|
156
|
+
preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
|
145
157
|
preds = model_output_transform.pred_transform(
|
146
158
|
noisy_images, preds, rates)
|
147
159
|
nloss = loss_fn(preds, expected_output)
|
@@ -182,3 +194,9 @@ class DiffusionTrainer(SimpleTrainer):
|
|
182
194
|
text_embedder = data['model']
|
183
195
|
super().fit(data, steps_per_epoch, epochs, {
|
184
196
|
"batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
|
197
|
+
|
198
|
+
def boolean_string(s):
|
199
|
+
if type(s) == bool:
|
200
|
+
return s
|
201
|
+
return s == 'True'
|
202
|
+
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: flaxdiff
|
3
|
-
Version: 0.1.5
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: A versatile and easy to understand Diffusion library
|
5
5
|
Author: Ashish Kumar Singh
|
6
6
|
Author-email: ashishkmr472@gmail.com
|
@@ -13,6 +13,8 @@ Requires-Dist: clu
|
|
13
13
|
|
14
14
|
# 
|
15
15
|
|
16
|
+
**This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
|
17
|
+
|
16
18
|
## A Versatile and simple Diffusion Library
|
17
19
|
|
18
20
|
In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
|
@@ -16,6 +16,7 @@ flaxdiff/models/simple_vit.py
|
|
16
16
|
flaxdiff/models/autoencoder/__init__.py
|
17
17
|
flaxdiff/models/autoencoder/autoencoder.py
|
18
18
|
flaxdiff/models/autoencoder/diffusers.py
|
19
|
+
flaxdiff/models/autoencoder/simple_autoenc.py
|
19
20
|
flaxdiff/predictors/__init__.py
|
20
21
|
flaxdiff/samplers/__init__.py
|
21
22
|
flaxdiff/samplers/common.py
|
@@ -35,4 +36,6 @@ flaxdiff/schedulers/karras.py
|
|
35
36
|
flaxdiff/schedulers/linear.py
|
36
37
|
flaxdiff/schedulers/sqrt.py
|
37
38
|
flaxdiff/trainer/__init__.py
|
39
|
+
flaxdiff/trainer/autoencoder_trainer.py
|
40
|
+
flaxdiff/trainer/diffusion_trainer.py
|
38
41
|
flaxdiff/trainer/simple_trainer.py
|
@@ -11,7 +11,7 @@ required_packages=[
|
|
11
11
|
setup(
|
12
12
|
name='flaxdiff',
|
13
13
|
packages=find_packages(),
|
14
|
-
version='0.1.5',
|
14
|
+
version='0.1.6',
|
15
15
|
description='A versatile and easy to understand Diffusion library',
|
16
16
|
long_description=open('README.md').read(),
|
17
17
|
long_description_content_type='text/markdown',
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|