flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/utils.py +105 -2
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
- flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
- flaxdiff/data/__init__.py +0 -1
- flaxdiff/data/online_loader.py +0 -336
- flaxdiff/models/__init__.py +0 -1
- flaxdiff/models/attention.py +0 -368
- flaxdiff/models/autoencoder/__init__.py +0 -2
- flaxdiff/models/autoencoder/autoencoder.py +0 -19
- flaxdiff/models/autoencoder/diffusers.py +0 -91
- flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
- flaxdiff/models/common.py +0 -346
- flaxdiff/models/favor_fastattn.py +0 -723
- flaxdiff/models/simple_unet.py +0 -233
- flaxdiff/models/simple_vit.py +0 -180
- flaxdiff/predictors/__init__.py +0 -96
- flaxdiff/samplers/__init__.py +0 -7
- flaxdiff/samplers/common.py +0 -113
- flaxdiff/samplers/ddim.py +0 -10
- flaxdiff/samplers/ddpm.py +0 -43
- flaxdiff/samplers/euler.py +0 -59
- flaxdiff/samplers/heun_sampler.py +0 -28
- flaxdiff/samplers/multistep_dpm.py +0 -60
- flaxdiff/samplers/rk4_sampler.py +0 -34
- flaxdiff/schedulers/__init__.py +0 -6
- flaxdiff/schedulers/common.py +0 -98
- flaxdiff/schedulers/continuous.py +0 -12
- flaxdiff/schedulers/cosine.py +0 -40
- flaxdiff/schedulers/discrete.py +0 -74
- flaxdiff/schedulers/exp.py +0 -13
- flaxdiff/schedulers/karras.py +0 -69
- flaxdiff/schedulers/linear.py +0 -14
- flaxdiff/schedulers/sqrt.py +0 -10
- flaxdiff/trainer/__init__.py +0 -2
- flaxdiff/trainer/autoencoder_trainer.py +0 -182
- flaxdiff/trainer/diffusion_trainer.py +0 -234
- flaxdiff/trainer/simple_trainer.py +0 -442
- flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_unet.py
DELETED
@@ -1,233 +0,0 @@
-import jax
-import jax.numpy as jnp
-from flax import linen as nn
-from flax.typing import Dtype, PrecisionLike
-from typing import Dict, Callable, Sequence, Any, Union, Optional
-import einops
-from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
-from .attention import TransformerBlock
-from functools import partial
-
-class Unet(nn.Module):
-    output_channels:int=3
-    emb_features:int=64*4,
-    feature_depths:list=[64, 128, 256, 512],
-    attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
-    num_res_blocks:int=2,
-    num_middle_res_blocks:int=1,
-    activation:Callable = jax.nn.swish
-    norm_groups:int=8
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-    named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
-    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
-
-    def setup(self):
-        if self.norm_groups > 0:
-            norm = partial(nn.GroupNorm, self.norm_groups)
-            self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
-        else:
-            norm = partial(nn.RMSNorm, 1e-5)
-            self.conv_out_norm = norm()
-
-    @nn.compact
-    def __call__(self, x, temb, textcontext):
-        # print("embedding features", self.emb_features)
-        temb = FourierEmbedding(features=self.emb_features)(temb)
-        temb = TimeProjection(features=self.emb_features)(temb)
-
-        _, TS, TC = textcontext.shape
-
-        # print("time embedding", temb.shape)
-        feature_depths = self.feature_depths
-        attention_configs = self.attention_configs
-
-        conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
-        # middle_conv_type = "separable"
-
-        x = ConvLayer(
-            conv_type,
-            features=self.feature_depths[0],
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            kernel_init=self.kernel_init(scale=1.0),
-            dtype=self.dtype,
-            precision=self.precision
-        )(x)
-        downs = [x]
-
-        # Downscaling blocks
-        for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
-            dim_in = x.shape[-1]
-            # dim_in = dim_out
-            for j in range(self.num_res_blocks):
-                x = ResidualBlock(
-                    down_conv_type,
-                    name=f"down_{i}_residual_{j}",
-                    features=dim_in,
-                    kernel_init=self.kernel_init(scale=1.0),
-                    kernel_size=(3, 3),
-                    strides=(1, 1),
-                    activation=self.activation,
-                    norm_groups=self.norm_groups,
-                    dtype=self.dtype,
-                    precision=self.precision,
-                    named_norms=self.named_norms
-                )(x, temb)
-                if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
-                    x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
-                                         dim_head=dim_in // attention_config['heads'],
-                                         use_flash_attention=attention_config.get("flash_attention", False),
-                                         use_projection=attention_config.get("use_projection", False),
-                                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
-                                         precision=attention_config.get("precision", self.precision),
-                                         only_pure_attention=attention_config.get("only_pure_attention", True),
-                                         force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
-                                         norm_inputs=attention_config.get("norm_inputs", True),
-                                         explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                                         kernel_init=self.kernel_init(scale=1.0),
-                                         name=f"down_{i}_attention_{j}")(x, textcontext)
-                # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
-                downs.append(x)
-            if i != len(feature_depths) - 1:
-                # print("Downsample", i, x.shape)
-                x = Downsample(
-                    features=dim_out,
-                    scale=2,
-                    activation=self.activation,
-                    name=f"down_{i}_downsample",
-                    dtype=self.dtype,
-                    precision=self.precision
-                )(x)
-
-        # Middle Blocks
-        middle_dim_out = self.feature_depths[-1]
-        middle_attention = self.attention_configs[-1]
-        for j in range(self.num_middle_res_blocks):
-            x = ResidualBlock(
-                middle_conv_type,
-                name=f"middle_res1_{j}",
-                features=middle_dim_out,
-                kernel_init=self.kernel_init(scale=1.0),
-                kernel_size=(3, 3),
-                strides=(1, 1),
-                activation=self.activation,
-                norm_groups=self.norm_groups,
-                dtype=self.dtype,
-                precision=self.precision,
-                named_norms=self.named_norms
-            )(x, temb)
-            if middle_attention is not None and j == self.num_middle_res_blocks - 1:  # Apply attention only on the last block
-                x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
-                                     dim_head=middle_dim_out // middle_attention['heads'],
-                                     use_flash_attention=middle_attention.get("flash_attention", False),
-                                     use_linear_attention=False,
-                                     use_projection=middle_attention.get("use_projection", False),
-                                     use_self_and_cross=False,
-                                     precision=middle_attention.get("precision", self.precision),
-                                     only_pure_attention=middle_attention.get("only_pure_attention", True),
-                                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
-                                     norm_inputs=middle_attention.get("norm_inputs", True),
-                                     explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
-                                     kernel_init=self.kernel_init(scale=1.0),
-                                     name=f"middle_attention_{j}")(x, textcontext)
-            x = ResidualBlock(
-                middle_conv_type,
-                name=f"middle_res2_{j}",
-                features=middle_dim_out,
-                kernel_init=self.kernel_init(scale=1.0),
-                kernel_size=(3, 3),
-                strides=(1, 1),
-                activation=self.activation,
-                norm_groups=self.norm_groups,
-                dtype=self.dtype,
-                precision=self.precision,
-                named_norms=self.named_norms
-            )(x, temb)
-
-        # Upscaling Blocks
-        for i, (dim_out, attention_config) in enumerate(zip(reversed(feature_depths), reversed(attention_configs))):
-            # print("Upscaling", i, "features", dim_out)
-            for j in range(self.num_res_blocks):
-                x = jnp.concatenate([x, downs.pop()], axis=-1)
-                # print("concat==> ", i, "concat", x.shape)
-                # kernel_size = (1 + 2 * (j + 1), 1 + 2 * (j + 1))
-                kernel_size = (3, 3)
-                x = ResidualBlock(
-                    up_conv_type,# if j == 0 else "separable",
-                    name=f"up_{i}_residual_{j}",
-                    features=dim_out,
-                    kernel_init=self.kernel_init(scale=1.0),
-                    kernel_size=kernel_size,
-                    strides=(1, 1),
-                    activation=self.activation,
-                    norm_groups=self.norm_groups,
-                    dtype=self.dtype,
-                    precision=self.precision,
-                    named_norms=self.named_norms
-                )(x, temb)
-                if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
-                    x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
-                                         dim_head=dim_out // attention_config['heads'],
-                                         use_flash_attention=attention_config.get("flash_attention", False),
-                                         use_projection=attention_config.get("use_projection", False),
-                                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
-                                         precision=attention_config.get("precision", self.precision),
-                                         only_pure_attention=attention_config.get("only_pure_attention", True),
-                                         force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
-                                         norm_inputs=attention_config.get("norm_inputs", True),
-                                         explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                                         kernel_init=self.kernel_init(scale=1.0),
-                                         name=f"up_{i}_attention_{j}")(x, textcontext)
-            # print("Upscaling ", i, x.shape)
-            if i != len(feature_depths) - 1:
-                x = Upsample(
-                    features=feature_depths[-i],
-                    scale=2,
-                    activation=self.activation,
-                    name=f"up_{i}_upsample",
-                    dtype=self.dtype,
-                    precision=self.precision
-                )(x)
-
-        # x = self.last_up_norm(x)
-        x = ConvLayer(
-            conv_type,
-            features=self.feature_depths[0],
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            kernel_init=self.kernel_init(scale=1.0),
-            dtype=self.dtype,
-            precision=self.precision
-        )(x)
-
-        x = jnp.concatenate([x, downs.pop()], axis=-1)
-
-        x = ResidualBlock(
-            conv_type,
-            name="final_residual",
-            features=self.feature_depths[0],
-            kernel_init=self.kernel_init(scale=1.0),
-            kernel_size=(3,3),
-            strides=(1, 1),
-            activation=self.activation,
-            norm_groups=self.norm_groups,
-            dtype=self.dtype,
-            precision=self.precision,
-            named_norms=self.named_norms
-        )(x, temb)
-
-        x = self.conv_out_norm(x)
-        x = self.activation(x)
-
-        noise_out = ConvLayer(
-            conv_type,
-            features=self.output_channels,
-            kernel_size=(3, 3),
-            strides=(1, 1),
-            # activation=jax.nn.mish
-            kernel_init=self.kernel_init(scale=0.0),
-            dtype=self.dtype,
-            precision=self.precision
-        )(x)
-        return noise_out#, attentions
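For reference while migrating off this module, here is a minimal sketch of constructing and initializing the deleted `Unet`. The import path is the pre-0.1.36 layout shown in this diff; the batch, image, and text-context shapes are illustrative assumptions, not values taken from the package:

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_unet import Unet  # module removed in 0.1.36.1

# Assumed shapes: 64x64 RGB images, one timestep scalar per sample,
# and a CLIP-like text context of 77 tokens x 768 dims.
model = Unet(
    output_channels=3,
    emb_features=256,
    feature_depths=[64, 128, 256, 512],
    attention_configs=[{"heads": 8}] * 4,
    num_res_blocks=2,
    num_middle_res_blocks=1,
)
x = jnp.zeros((1, 64, 64, 3))
temb = jnp.zeros((1,))
textcontext = jnp.zeros((1, 77, 768))
params = model.init(jax.random.PRNGKey(0), x, temb, textcontext)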
flaxdiff/models/simple_vit.py
DELETED
@@ -1,180 +0,0 @@
-# simple_vit.py
-
-import jax
-import jax.numpy as jnp
-from flax import linen as nn
-from typing import Callable, Any, Optional, Tuple
-from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
-from .attention import TransformerBlock
-from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init, ResidualBlock
-import einops
-from flax.typing import Dtype, PrecisionLike
-from functools import partial
-
-def unpatchify(x, channels=3):
-    patch_size = int((x.shape[2] // channels) ** 0.5)
-    h = w = int(x.shape[1] ** .5)
-    assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
-    x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
-    return x
-
-class PatchEmbedding(nn.Module):
-    patch_size: int
-    embedding_dim: int
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-    kernel_init: Callable = partial(kernel_init, 1.0)
-
-    @nn.compact
-    def __call__(self, x):
-        batch, height, width, channels = x.shape
-        assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
-
-        x = nn.Conv(features=self.embedding_dim,
-                    kernel_size=(self.patch_size, self.patch_size),
-                    strides=(self.patch_size, self.patch_size),
-                    dtype=self.dtype,
-                    kernel_init=self.kernel_init(),
-                    precision=self.precision)(x)
-        x = jnp.reshape(x, (batch, -1, self.embedding_dim))
-        return x
-
-class PositionalEncoding(nn.Module):
-    max_len: int
-    embedding_dim: int
-
-    @nn.compact
-    def __call__(self, x):
-        pe = self.param('pos_encoding',
-                        jax.nn.initializers.zeros,
-                        (1, self.max_len, self.embedding_dim))
-        return x + pe[:, :x.shape[1], :]
-
-class UViT(nn.Module):
-    output_channels:int=3
-    patch_size: int = 16
-    emb_features:int=768,
-    num_layers: int = 12
-    num_heads: int = 12
-    dropout_rate: float = 0.1
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-    use_projection: bool = False
-    use_flash_attention: bool = False
-    use_self_and_cross: bool = False
-    force_fp32_for_softmax: bool = True
-    activation:Callable = jax.nn.swish
-    norm_groups:int=8
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-    kernel_init: Callable = partial(kernel_init, scale=1.0)
-    add_residualblock_output: bool = False
-    norm_inputs: bool = False
-    explicitly_add_residual: bool = True
-
-    def setup(self):
-        if self.norm_groups > 0:
-            self.norm = partial(nn.GroupNorm, self.norm_groups)
-        else:
-            self.norm = partial(nn.RMSNorm, 1e-5)
-
-    @nn.compact
-    def __call__(self, x, temb, textcontext=None):
-        # Time embedding
-        temb = FourierEmbedding(features=self.emb_features)(temb)
-        temb = TimeProjection(features=self.emb_features)(temb)
-
-        original_img = x
-
-        # Patch embedding
-        x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-                           dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
-        num_patches = x.shape[1]
-
-        context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
-                                      dtype=self.dtype, precision=self.precision)(textcontext)
-        num_text_tokens = textcontext.shape[1]
-
-        # print(f'Shape of x after patch embedding: {x.shape}, numPatches: {num_patches}, temb: {temb.shape}, context_emb: {context_emb.shape}')
-
-        # Add time embedding
-        temb = jnp.expand_dims(temb, axis=1)
-        x = jnp.concatenate([x, temb, context_emb], axis=1)
-        # print(f'Shape of x after time embedding: {x.shape}')
-
-        # Add positional encoding
-        x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)
-
-        # print(f'Shape of x after positional encoding: {x.shape}')
-
-        skips = []
-        # In blocks
-        for i in range(self.num_layers // 2):
-            x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
-                                 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
-                                 only_pure_attention=False,
-                                 norm_inputs=self.norm_inputs,
-                                 explicitly_add_residual=self.explicitly_add_residual,
-                                 kernel_init=self.kernel_init())(x)
-            skips.append(x)
-
-        # Middle block
-        x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
-                             dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
-                             only_pure_attention=False,
-                             norm_inputs=self.norm_inputs,
-                             explicitly_add_residual=self.explicitly_add_residual,
-                             kernel_init=self.kernel_init())(x)
-
-        # # Out blocks
-        for i in range(self.num_layers // 2):
-            x = jnp.concatenate([x, skips.pop()], axis=-1)
-            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
-                                dtype=self.dtype, precision=self.precision)(x)
-            x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
-                                 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
-                                 only_pure_attention=False,
-                                 norm_inputs=self.norm_inputs,
-                                 explicitly_add_residual=self.explicitly_add_residual,
-                                 kernel_init=self.kernel_init())(x)
-
-        # print(f'Shape of x after transformer blocks: {x.shape}')
-        x = self.norm()(x)
-
-        patch_dim = self.patch_size ** 2 * self.output_channels
-        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
-        x = x[:, 1 + num_text_tokens:, :]
-        x = unpatchify(x, channels=self.output_channels)
-
-        if self.add_residualblock_output:
-            # Concatenate the original image
-            x = jnp.concatenate([original_img, x], axis=-1)
-
-            x = ConvLayer(
-                "conv",
-                features=64,
-                kernel_size=(3, 3),
-                strides=(1, 1),
-                # activation=jax.nn.mish
-                kernel_init=self.kernel_init(scale=0.0),
-                dtype=self.dtype,
-                precision=self.precision
-            )(x)
-
-            x = self.norm()(x)
-            x = self.activation(x)
-
-            x = ConvLayer(
-                "conv",
-                features=self.output_channels,
-                kernel_size=(3, 3),
-                strides=(1, 1),
-                # activation=jax.nn.mish
-                kernel_init=self.kernel_init(scale=0.0),
-                dtype=self.dtype,
-                precision=self.precision
-            )(x)
-        return x
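To make the patch bookkeeping in `unpatchify` concrete: a 64x64 RGB image with `patch_size=16` becomes (64/16)^2 = 16 tokens, each of dimension 16*16*3 = 768, and the einops pattern inverts that exactly. A standalone sanity check (shapes chosen here for illustration):

import jax.numpy as jnp
import einops

tokens = jnp.zeros((2, 16, 768))  # (batch, h*w patches, p1*p2*C)
img = einops.rearrange(tokens, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C',
                       h=4, p1=16, p2=16)
assert img.shape == (2, 64, 64, 3)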
flaxdiff/predictors/__init__.py
DELETED
@@ -1,96 +0,0 @@
-from typing import Union
-import jax.numpy as jnp
-from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
-
-############################################################################################################
-# Prediction Transforms
-############################################################################################################
-
-class DiffusionPredictionTransform():
-    def pred_transform(self, x_t, preds, rates) -> jnp.ndarray:
-        return preds
-
-    def __call__(self, x_t, preds, current_step, noise_schedule:NoiseScheduler) -> Union[jnp.ndarray, jnp.ndarray]:
-        rates = noise_schedule.get_rates(current_step)
-        preds = self.pred_transform(x_t, preds, rates)
-        x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
-        return x_0, epsilon
-
-    def forward_diffusion(self, x_0, epsilon, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
-        signal_rate, noise_rate = rates
-        x_t = signal_rate * x_0 + noise_rate * epsilon
-        expected_output = self.get_target(x_0, epsilon, (signal_rate, noise_rate))
-        c_in = self.get_input_scale((signal_rate, noise_rate))
-        return x_t, c_in, expected_output
-
-    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
-        raise NotImplementedError
-
-    def get_target(self, x_0, epsilon, rates) ->jnp.ndarray:
-        return x_0
-
-    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
-        return 1
-
-class EpsilonPredictionTransform(DiffusionPredictionTransform):
-    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
-        # preds is the predicted noise
-        epsilon = preds
-        signal_rates, noise_rates = rates
-        x_0 = (x_t - epsilon * noise_rates) / signal_rates
-        return x_0, epsilon
-
-    def get_target(self, x_0, epsilon, rates) ->jnp.ndarray:
-        return epsilon
-
-class DirectPredictionTransform(DiffusionPredictionTransform):
-    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
-        # Here the model predicts x_0 directly
-        x_0 = preds
-        signal_rate, noise_rate = rates
-        epsilon = (x_t - x_0 * signal_rate) / noise_rate
-        return x_0, epsilon
-
-class VPredictionTransform(DiffusionPredictionTransform):
-    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
-        # here the model output's V = sqrt_alpha_t * epsilon - sqrt_one_minus_alpha_t * x_0
-        # where epsilon is the noise
-        # x_0 is the current sample
-        v = preds
-        signal_rate, noise_rate = rates
-        variance = signal_rate ** 2 + noise_rate ** 2
-        v = v * jnp.sqrt(variance)
-        x_0 = signal_rate * x_t - noise_rate * v
-        eps_0 = signal_rate * v + noise_rate * x_t
-        return x_0 / variance, eps_0 / variance
-
-    def get_target(self, x_0, epsilon, rates) ->jnp.ndarray:
-        signal_rate, noise_rate = rates
-        v = signal_rate * epsilon - noise_rate * x_0
-        variance = signal_rate**2 + noise_rate**2
-        return v / jnp.sqrt(variance)
-
-class KarrasPredictionTransform(DiffusionPredictionTransform):
-    def __init__(self, sigma_data=0.5) -> None:
-        super().__init__()
-        self.sigma_data = sigma_data
-
-    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
-        x_0 = preds
-        signal_rate, noise_rate = rates
-        epsilon = (x_t - x_0 * signal_rate) / noise_rate
-        return x_0, epsilon
-
-    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
-        _, sigma = rates
-        c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
-        c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
-        c_out = c_out.reshape((-1, 1, 1, 1))
-        c_skip = c_skip.reshape((-1, 1, 1, 1))
-        x_0 = c_out * preds + c_skip * x_t
-        return x_0
-
-    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
-        _, sigma = rates
-        c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
-        return c_in
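As a sanity check of the v-prediction algebra in `VPredictionTransform`: with rates chosen so that signal_rate^2 + noise_rate^2 = 1, predicting v = signal_rate * epsilon - noise_rate * x_0 lets both x_0 and epsilon be recovered exactly. A toy round trip (scalar values are ours, purely illustrative):

import jax.numpy as jnp

signal_rate, noise_rate = 0.8, 0.6              # variance = 0.64 + 0.36 = 1.0
x_0, epsilon = jnp.array(2.0), jnp.array(-1.0)
x_t = signal_rate * x_0 + noise_rate * epsilon  # forward diffusion
v = signal_rate * epsilon - noise_rate * x_0    # get_target at unit variance
x_0_rec = signal_rate * x_t - noise_rate * v    # backward_diffusion recovery
eps_rec = signal_rate * v + noise_rate * x_t
assert jnp.allclose(x_0_rec, x_0) and jnp.allclose(eps_rec, epsilon)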
flaxdiff/samplers/__init__.py
DELETED
@@ -1,7 +0,0 @@
-from .common import DiffusionSampler
-from .ddim import DDIMSampler
-from .ddpm import DDPMSampler, SimpleDDPMSampler
-from .euler import EulerSampler, SimplifiedEulerSampler
-from .heun_sampler import HeunSampler
-from .rk4_sampler import RK4Sampler
-from .multistep_dpm import MultiStepDPM
flaxdiff/samplers/common.py
DELETED
@@ -1,113 +0,0 @@
-from flax import linen as nn
-import jax
-import jax.numpy as jnp
-import tqdm
-from typing import Union
-from ..schedulers import NoiseScheduler
-from ..utils import RandomMarkovState, MarkovState, clip_images
-from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
-
-class DiffusionSampler():
-    model:nn.Module
-    noise_schedule:NoiseScheduler
-    params:dict
-    model_output_transform:DiffusionPredictionTransform
-
-    def __init__(self, model:nn.Module, params:dict,
-                 noise_schedule:NoiseScheduler,
-                 model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform()):
-        self.model = model
-        self.noise_schedule = noise_schedule
-        self.params = params
-        self.model_output_transform = model_output_transform
-
-        @jax.jit
-        def sample_model(x_t, t):
-            rates = self.noise_schedule.get_rates(t)
-            c_in = self.model_output_transform.get_input_scale(rates)
-            model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t))
-            x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
-            return x_0, eps, model_output
-
-        self.sample_model = sample_model
-
-    # Used to sample from the diffusion model
-    def sample_step(self, current_samples:jnp.ndarray, current_step, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
-        # First clip the noisy images
-        # pred_images = clip_images(pred_images)
-        step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
-        current_step = step_ones * current_step
-        next_step = step_ones * next_step
-        pred_images, pred_noise, _ = self.sample_model(current_samples, current_step)
-        # plotImages(pred_images)
-        new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
-                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state)
-        return new_samples, state
-
-    def take_next_step(self, current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
-        # estimate the q(x_{t-1} | x_t, x_0).
-        # pred_images is x_0, noisy_images is x_t, steps is t
-        return NotImplementedError
-
-    def scale_steps(self, steps):
-        scale_factor = self.noise_schedule.max_timesteps / 1000
-        return steps * scale_factor
-
-    def get_steps(self, start_step, end_step, diffusion_steps):
-        step_range = start_step - end_step
-        if diffusion_steps is None or diffusion_steps == 0:
-            diffusion_steps = start_step - end_step
-        diffusion_steps = min(diffusion_steps, step_range)
-        steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
-        return steps
-
-    def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step, image_size=64):
-        start_step = self.scale_steps(start_step)
-        alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
-        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
-        return jax.random.normal(rngs, (num_images, image_size, image_size, 3)) * variance
-
-    def generate_images(self,
-                        num_images=16,
-                        diffusion_steps=1000,
-                        start_step:int = None,
-                        end_step:int = 0,
-                        steps_override=None,
-                        priors=None,
-                        rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42))) -> jnp.ndarray:
-        if priors is None:
-            rngstate, newrngs = rngstate.get_random_key()
-            samples = self.get_initial_samples(num_images, newrngs, start_step)
-        else:
-            print("Using priors")
-            samples = priors
-
-        @jax.jit
-        def sample_step(state:RandomMarkovState, samples, current_step, next_step):
-            samples, state = self.sample_step(current_samples=samples,
-                                              current_step=current_step,
-                                              state=state, next_step=next_step)
-            return samples, state
-
-        if start_step is None:
-            start_step = self.noise_schedule.max_timesteps
-
-        if steps_override is not None:
-            steps = steps_override
-        else:
-            steps = self.get_steps(start_step, end_step, diffusion_steps)
-
-        # print("Sampling steps", steps)
-        for i in tqdm.tqdm(range(0, len(steps))):
-            current_step = self.scale_steps(steps[i])
-            next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0)
-            if i != len(steps) - 1:
-                # print("normal step")
-                samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
-            else:
-                # print("last step")
-                step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
-                samples, _, _ = self.sample_model(samples, current_step * step_ones)
-                samples = clip_images(samples)
-        return samples
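The generate_images loop above is the whole sampling surface: draw initial noise, call the jitted `sample_model` each step, and delegate the actual transition to a subclass's `take_next_step`. A minimal driver is sketched below; `model` and `params` are assumed to come from a prior training run, and `CosineNoiseScheduler` is an assumed export of this package's schedulers module, not verified against this exact version:

from flaxdiff.samplers import DDIMSampler
from flaxdiff.schedulers import CosineNoiseScheduler  # assumed scheduler name
from flaxdiff.predictors import EpsilonPredictionTransform

# `model` and `params` are placeholders for a trained network and its weights.
sampler = DDIMSampler(
    model, params,
    noise_schedule=CosineNoiseScheduler(1000),
    model_output_transform=EpsilonPredictionTransform(),
)
images = sampler.generate_images(num_images=4, diffusion_steps=50)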
flaxdiff/samplers/ddim.py
DELETED
@@ -1,10 +0,0 @@
-import jax.numpy as jnp
-from .common import DiffusionSampler
-from ..utils import MarkovState
-
-class DDIMSampler(DiffusionSampler):
-    def take_next_step(self,
-                       current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:MarkovState, next_step=None) -> tuple[jnp.ndarray, MarkovState]:
-        next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
-        return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
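In effect, `DDIMSampler.take_next_step` is the deterministic DDIM update: it re-noises the predicted x_0 with the predicted noise at the next step's rates, x_{t-1} = signal_rate_{t-1} * x_0_hat + noise_rate_{t-1} * eps_hat, injecting no fresh randomness, which is what allows DDIM to skip steps without resampling.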