flaxdiff 0.1.36.1__py3-none-any.whl → 0.1.36.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +1 -0
  2. flaxdiff/data/dataset_map.py +71 -0
  3. flaxdiff/data/datasets.py +169 -0
  4. flaxdiff/data/online_loader.py +363 -0
  5. flaxdiff/data/sources/gcs.py +81 -0
  6. flaxdiff/data/sources/tfds.py +67 -0
  7. flaxdiff/metrics/inception.py +658 -0
  8. flaxdiff/metrics/utils.py +49 -0
  9. flaxdiff/models/__init__.py +1 -0
  10. flaxdiff/models/attention.py +368 -0
  11. flaxdiff/models/autoencoder/__init__.py +2 -0
  12. flaxdiff/models/autoencoder/autoencoder.py +19 -0
  13. flaxdiff/models/autoencoder/diffusers.py +91 -0
  14. flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
  15. flaxdiff/models/common.py +346 -0
  16. flaxdiff/models/favor_fastattn.py +723 -0
  17. flaxdiff/models/simple_unet.py +233 -0
  18. flaxdiff/models/simple_vit.py +180 -0
  19. flaxdiff/predictors/__init__.py +96 -0
  20. flaxdiff/samplers/__init__.py +7 -0
  21. flaxdiff/samplers/common.py +165 -0
  22. flaxdiff/samplers/ddim.py +10 -0
  23. flaxdiff/samplers/ddpm.py +37 -0
  24. flaxdiff/samplers/euler.py +56 -0
  25. flaxdiff/samplers/heun_sampler.py +27 -0
  26. flaxdiff/samplers/multistep_dpm.py +59 -0
  27. flaxdiff/samplers/rk4_sampler.py +34 -0
  28. flaxdiff/schedulers/__init__.py +6 -0
  29. flaxdiff/schedulers/common.py +98 -0
  30. flaxdiff/schedulers/continuous.py +12 -0
  31. flaxdiff/schedulers/cosine.py +40 -0
  32. flaxdiff/schedulers/discrete.py +74 -0
  33. flaxdiff/schedulers/exp.py +13 -0
  34. flaxdiff/schedulers/karras.py +69 -0
  35. flaxdiff/schedulers/linear.py +14 -0
  36. flaxdiff/schedulers/sqrt.py +10 -0
  37. flaxdiff/trainer/__init__.py +2 -0
  38. flaxdiff/trainer/autoencoder_trainer.py +182 -0
  39. flaxdiff/trainer/diffusion_trainer.py +326 -0
  40. flaxdiff/trainer/simple_trainer.py +540 -0
  41. flaxdiff/trainer/video_diffusion_trainer.py +62 -0
  42. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/METADATA +1 -1
  43. flaxdiff-0.1.36.3.dist-info/RECORD +47 -0
  44. flaxdiff-0.1.36.1.dist-info/RECORD +0 -6
  45. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/WHEEL +0 -0
  46. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_unet.py
@@ -0,0 +1,233 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from flax.typing import Dtype, PrecisionLike
+ from typing import Dict, Callable, Sequence, Any, Union, Optional
+ import einops
+ from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
+ from .attention import TransformerBlock
+ from functools import partial
+
+ class Unet(nn.Module):
+     output_channels: int = 3
+     emb_features: int = 64 * 4
+     feature_depths: Sequence[int] = (64, 128, 256, 512)
+     attention_configs: Sequence[Optional[dict]] = ({"heads": 8}, {"heads": 8}, {"heads": 8}, {"heads": 8})
+     num_res_blocks: int = 2
+     num_middle_res_blocks: int = 1
+     activation: Callable = jax.nn.swish
+     norm_groups: int = 8
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     named_norms: bool = False  # This is for backward compatibility reasons; older checkpoints have named norms
+     kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
+
+     def setup(self):
+         if self.norm_groups > 0:
+             norm = partial(nn.GroupNorm, self.norm_groups)
+             self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
+         else:
+             norm = partial(nn.RMSNorm, 1e-5)
+             self.conv_out_norm = norm()
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext):
+         # print("embedding features", self.emb_features)
+         temb = FourierEmbedding(features=self.emb_features)(temb)
+         temb = TimeProjection(features=self.emb_features)(temb)
+
+         _, TS, TC = textcontext.shape
+
+         # print("time embedding", temb.shape)
+         feature_depths = self.feature_depths
+         attention_configs = self.attention_configs
+
+         conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
+         # middle_conv_type = "separable"
+
+         x = ConvLayer(
+             conv_type,
+             features=self.feature_depths[0],
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             kernel_init=self.kernel_init(scale=1.0),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         downs = [x]
+
+         # Downscaling blocks
+         for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
+             dim_in = x.shape[-1]
+             # dim_in = dim_out
+             for j in range(self.num_res_blocks):
+                 x = ResidualBlock(
+                     down_conv_type,
+                     name=f"down_{i}_residual_{j}",
+                     features=dim_in,
+                     kernel_init=self.kernel_init(scale=1.0),
+                     kernel_size=(3, 3),
+                     strides=(1, 1),
+                     activation=self.activation,
+                     norm_groups=self.norm_groups,
+                     dtype=self.dtype,
+                     precision=self.precision,
+                     named_norms=self.named_norms
+                 )(x, temb)
+                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
+                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
+                                          dim_head=dim_in // attention_config['heads'],
+                                          use_flash_attention=attention_config.get("flash_attention", False),
+                                          use_projection=attention_config.get("use_projection", False),
+                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                                          precision=attention_config.get("precision", self.precision),
+                                          only_pure_attention=attention_config.get("only_pure_attention", True),
+                                          force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                                          norm_inputs=attention_config.get("norm_inputs", True),
+                                          explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
+                                          kernel_init=self.kernel_init(scale=1.0),
+                                          name=f"down_{i}_attention_{j}")(x, textcontext)
+                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
+                 downs.append(x)
+             if i != len(feature_depths) - 1:
+                 # print("Downsample", i, x.shape)
+                 x = Downsample(
+                     features=dim_out,
+                     scale=2,
+                     activation=self.activation,
+                     name=f"down_{i}_downsample",
+                     dtype=self.dtype,
+                     precision=self.precision
+                 )(x)
+
+         # Middle Blocks
+         middle_dim_out = self.feature_depths[-1]
+         middle_attention = self.attention_configs[-1]
+         for j in range(self.num_middle_res_blocks):
+             x = ResidualBlock(
+                 middle_conv_type,
+                 name=f"middle_res1_{j}",
+                 features=middle_dim_out,
+                 kernel_init=self.kernel_init(scale=1.0),
+                 kernel_size=(3, 3),
+                 strides=(1, 1),
+                 activation=self.activation,
+                 norm_groups=self.norm_groups,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 named_norms=self.named_norms
+             )(x, temb)
+             if middle_attention is not None and j == self.num_middle_res_blocks - 1:  # Apply attention only on the last block
+                 x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
+                                      dim_head=middle_dim_out // middle_attention['heads'],
+                                      use_flash_attention=middle_attention.get("flash_attention", False),
+                                      use_linear_attention=False,
+                                      use_projection=middle_attention.get("use_projection", False),
+                                      use_self_and_cross=False,
+                                      precision=middle_attention.get("precision", self.precision),
+                                      only_pure_attention=middle_attention.get("only_pure_attention", True),
+                                      force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                                      norm_inputs=middle_attention.get("norm_inputs", True),
+                                      explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
+                                      kernel_init=self.kernel_init(scale=1.0),
+                                      name=f"middle_attention_{j}")(x, textcontext)
+             x = ResidualBlock(
+                 middle_conv_type,
+                 name=f"middle_res2_{j}",
+                 features=middle_dim_out,
+                 kernel_init=self.kernel_init(scale=1.0),
+                 kernel_size=(3, 3),
+                 strides=(1, 1),
+                 activation=self.activation,
+                 norm_groups=self.norm_groups,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 named_norms=self.named_norms
+             )(x, temb)
+
+         # Upscaling Blocks
+         for i, (dim_out, attention_config) in enumerate(zip(reversed(feature_depths), reversed(attention_configs))):
+             # print("Upscaling", i, "features", dim_out)
+             for j in range(self.num_res_blocks):
+                 x = jnp.concatenate([x, downs.pop()], axis=-1)
+                 # print("concat==> ", i, "concat", x.shape)
+                 # kernel_size = (1 + 2 * (j + 1), 1 + 2 * (j + 1))
+                 kernel_size = (3, 3)
+                 x = ResidualBlock(
+                     up_conv_type,  # if j == 0 else "separable",
+                     name=f"up_{i}_residual_{j}",
+                     features=dim_out,
+                     kernel_init=self.kernel_init(scale=1.0),
+                     kernel_size=kernel_size,
+                     strides=(1, 1),
+                     activation=self.activation,
+                     norm_groups=self.norm_groups,
+                     dtype=self.dtype,
+                     precision=self.precision,
+                     named_norms=self.named_norms
+                 )(x, temb)
+                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
+                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
+                                          dim_head=dim_out // attention_config['heads'],
+                                          use_flash_attention=attention_config.get("flash_attention", False),
+                                          use_projection=attention_config.get("use_projection", False),
+                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
+                                          precision=attention_config.get("precision", self.precision),
+                                          only_pure_attention=attention_config.get("only_pure_attention", True),
+                                          force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                                          norm_inputs=attention_config.get("norm_inputs", True),
+                                          explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
+                                          kernel_init=self.kernel_init(scale=1.0),
+                                          name=f"up_{i}_attention_{j}")(x, textcontext)
+             # print("Upscaling ", i, x.shape)
+             if i != len(feature_depths) - 1:
+                 x = Upsample(
+                     features=feature_depths[-i],
+                     scale=2,
+                     activation=self.activation,
+                     name=f"up_{i}_upsample",
+                     dtype=self.dtype,
+                     precision=self.precision
+                 )(x)
+
+         # x = self.last_up_norm(x)
+         x = ConvLayer(
+             conv_type,
+             features=self.feature_depths[0],
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             kernel_init=self.kernel_init(scale=1.0),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+
+         x = jnp.concatenate([x, downs.pop()], axis=-1)
+
+         x = ResidualBlock(
+             conv_type,
+             name="final_residual",
+             features=self.feature_depths[0],
+             kernel_init=self.kernel_init(scale=1.0),
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             activation=self.activation,
+             norm_groups=self.norm_groups,
+             dtype=self.dtype,
+             precision=self.precision,
+             named_norms=self.named_norms
+         )(x, temb)
+
+         x = self.conv_out_norm(x)
+         x = self.activation(x)
+
+         noise_out = ConvLayer(
+             conv_type,
+             features=self.output_channels,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             # activation=jax.nn.mish
+             kernel_init=self.kernel_init(scale=0.0),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         return noise_out  # , attentions
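
For orientation, a minimal sketch of how this Unet might be instantiated and called. The input shapes here (64x64 RGB images, a batch of scalar timesteps, and a (batch, tokens, dim) text-context tensor) are illustrative assumptions, not part of the diff:

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.simple_unet import Unet

model = Unet(output_channels=3, emb_features=256,
             feature_depths=(64, 128, 256, 512))
x = jnp.zeros((2, 64, 64, 3))            # noisy images
temb = jnp.zeros((2,))                   # diffusion timesteps
textcontext = jnp.zeros((2, 77, 768))    # assumed text-conditioning shape

params = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
noise_pred = model.apply(params, x, temb, textcontext)  # same spatial shape as x
```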
flaxdiff/models/simple_vit.py
@@ -0,0 +1,180 @@
+ # simple_vit.py
+
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Callable, Any, Optional, Tuple
+ from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init, ResidualBlock
+ from .attention import TransformerBlock
+ import einops
+ from flax.typing import Dtype, PrecisionLike
+ from functools import partial
+
+ def unpatchify(x, channels=3):
+     patch_size = int((x.shape[2] // channels) ** 0.5)
+     h = w = int(x.shape[1] ** .5)
+     assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
+     x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
+     return x
+
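
A quick shape check for `unpatchify`, assuming a 16x16 grid of 4x4 RGB patches:

```python
import jax.numpy as jnp
from flaxdiff.models.simple_vit import unpatchify

tokens = jnp.zeros((2, 256, 48))        # 256 = 16*16 patches, 48 = 4*4*3
img = unpatchify(tokens, channels=3)
assert img.shape == (2, 64, 64, 3)      # (h*p1, w*p2, C) = (16*4, 16*4, 3)
```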
+ class PatchEmbedding(nn.Module):
+     patch_size: int
+     embedding_dim: int
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGH
+     kernel_init: Callable = partial(kernel_init, 1.0)
+
+     @nn.compact
+     def __call__(self, x):
+         batch, height, width, channels = x.shape
+         assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+         x = nn.Conv(features=self.embedding_dim,
+                     kernel_size=(self.patch_size, self.patch_size),
+                     strides=(self.patch_size, self.patch_size),
+                     dtype=self.dtype,
+                     kernel_init=self.kernel_init(),
+                     precision=self.precision)(x)
+         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
+         return x
+
+ class PositionalEncoding(nn.Module):
+     max_len: int
+     embedding_dim: int
+
+     @nn.compact
+     def __call__(self, x):
+         pe = self.param('pos_encoding',
+                         jax.nn.initializers.zeros,
+                         (1, self.max_len, self.embedding_dim))
+         return x + pe[:, :x.shape[1], :]
+
+ class UViT(nn.Module):
+     output_channels: int = 3
+     patch_size: int = 16
+     emb_features: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     dropout_rate: float = 0.1
+     use_projection: bool = False
+     use_flash_attention: bool = False
+     use_self_and_cross: bool = False
+     force_fp32_for_softmax: bool = True
+     activation: Callable = jax.nn.swish
+     norm_groups: int = 8
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     kernel_init: Callable = partial(kernel_init, scale=1.0)
+     add_residualblock_output: bool = False
+     norm_inputs: bool = False
+     explicitly_add_residual: bool = True
+
+     def setup(self):
+         if self.norm_groups > 0:
+             self.norm = partial(nn.GroupNorm, self.norm_groups)
+         else:
+             self.norm = partial(nn.RMSNorm, 1e-5)
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext=None):
+         # Time embedding
+         temb = FourierEmbedding(features=self.emb_features)(temb)
+         temb = TimeProjection(features=self.emb_features)(temb)
+
+         original_img = x
+
+         # Patch embedding
+         x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
+                            dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
+         num_patches = x.shape[1]
+
+         context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+                                       dtype=self.dtype, precision=self.precision)(textcontext)
+         num_text_tokens = textcontext.shape[1]
+
+         # print(f'Shape of x after patch embedding: {x.shape}, numPatches: {num_patches}, temb: {temb.shape}, context_emb: {context_emb.shape}')
+
+         # Add time embedding
+         temb = jnp.expand_dims(temb, axis=1)
+         x = jnp.concatenate([x, temb, context_emb], axis=1)
+         # print(f'Shape of x after time embedding: {x.shape}')
+
+         # Add positional encoding
+         x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)
+
+         # print(f'Shape of x after positional encoding: {x.shape}')
+
+         skips = []
+         # In blocks
+         for i in range(self.num_layers // 2):
+             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
+                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
+                                  only_pure_attention=False,
+                                  norm_inputs=self.norm_inputs,
+                                  explicitly_add_residual=self.explicitly_add_residual,
+                                  kernel_init=self.kernel_init())(x)
+             skips.append(x)
+
+         # Middle block
+         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
+                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+                              use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
+                              only_pure_attention=False,
+                              norm_inputs=self.norm_inputs,
+                              explicitly_add_residual=self.explicitly_add_residual,
+                              kernel_init=self.kernel_init())(x)
+
+         # Out blocks
+         for i in range(self.num_layers // 2):
+             x = jnp.concatenate([x, skips.pop()], axis=-1)
+             x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+                                 dtype=self.dtype, precision=self.precision)(x)
+             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
+                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
+                                  only_pure_attention=False,
+                                  norm_inputs=self.norm_inputs,
+                                  explicitly_add_residual=self.explicitly_add_residual,
+                                  kernel_init=self.kernel_init())(x)
+
+         # print(f'Shape of x after transformer blocks: {x.shape}')
+         x = self.norm()(x)
+
+         patch_dim = self.patch_size ** 2 * self.output_channels
+         x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
+         x = x[:, 1 + num_text_tokens:, :]
+         x = unpatchify(x, channels=self.output_channels)
+
+         if self.add_residualblock_output:
+             # Concatenate the original image
+             x = jnp.concatenate([original_img, x], axis=-1)
+
+             x = ConvLayer(
+                 "conv",
+                 features=64,
+                 kernel_size=(3, 3),
+                 strides=(1, 1),
+                 # activation=jax.nn.mish
+                 kernel_init=self.kernel_init(scale=0.0),
+                 dtype=self.dtype,
+                 precision=self.precision
+             )(x)
+
+             x = self.norm()(x)
+             x = self.activation(x)
+
+         x = ConvLayer(
+             "conv",
+             features=self.output_channels,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             # activation=jax.nn.mish
+             kernel_init=self.kernel_init(scale=0.0),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+         return x
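
As with the Unet above, a minimal usage sketch; the image and text-context shapes are assumptions for illustration:

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.simple_vit import UViT

model = UViT(output_channels=3, patch_size=16, emb_features=768,
             num_layers=12, num_heads=12)
x = jnp.zeros((2, 256, 256, 3))          # 256/16 = 16 patches per side
temb = jnp.zeros((2,))
textcontext = jnp.zeros((2, 77, 768))    # projected to emb_features internally

params = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
out = model.apply(params, x, temb, textcontext)  # (2, 256, 256, 3)
```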
flaxdiff/predictors/__init__.py
@@ -0,0 +1,96 @@
+ import jax.numpy as jnp
+ from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
+
+ ############################################################################################################
+ # Prediction Transforms
+ ############################################################################################################
+
+ class DiffusionPredictionTransform():
+     def pred_transform(self, x_t, preds, rates) -> jnp.ndarray:
+         return preds
+
+     def __call__(self, x_t, preds, current_step, noise_schedule: NoiseScheduler) -> tuple[jnp.ndarray, jnp.ndarray]:
+         rates = noise_schedule.get_rates(current_step)
+         preds = self.pred_transform(x_t, preds, rates)
+         x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
+         return x_0, epsilon
+
+     def forward_diffusion(self, x_0, epsilon, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+         signal_rate, noise_rate = rates
+         x_t = signal_rate * x_0 + noise_rate * epsilon
+         expected_output = self.get_target(x_0, epsilon, (signal_rate, noise_rate))
+         c_in = self.get_input_scale((signal_rate, noise_rate))
+         return x_t, c_in, expected_output
+
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         raise NotImplementedError
+
+     def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
+         return x_0
+
+     def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+         return 1
+
+ class EpsilonPredictionTransform(DiffusionPredictionTransform):
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         # preds is the predicted noise
+         epsilon = preds
+         signal_rates, noise_rates = rates
+         x_0 = (x_t - epsilon * noise_rates) / signal_rates
+         return x_0, epsilon
+
+     def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
+         return epsilon
+
+ class DirectPredictionTransform(DiffusionPredictionTransform):
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         # Here the model predicts x_0 directly
+         x_0 = preds
+         signal_rate, noise_rate = rates
+         epsilon = (x_t - x_0 * signal_rate) / noise_rate
+         return x_0, epsilon
+
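By construction, each transform's `backward_diffusion` inverts `forward_diffusion` when given the exact target; a small consistency sketch with scalar rates (values chosen arbitrarily):

```python
import jax.numpy as jnp
from flaxdiff.predictors import EpsilonPredictionTransform

tf = EpsilonPredictionTransform()
x_0 = jnp.ones((1, 8, 8, 3))
eps = 0.5 * jnp.ones((1, 8, 8, 3))
rates = (jnp.array(0.9), jnp.array(0.436))     # (signal_rate, noise_rate)

x_t, c_in, target = tf.forward_diffusion(x_0, eps, rates)   # target == eps here
x_0_rec, eps_rec = tf.backward_diffusion(x_t, target, rates)
assert jnp.allclose(x_0_rec, x_0) and jnp.allclose(eps_rec, eps)
```
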
+ class VPredictionTransform(DiffusionPredictionTransform):
+     def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
+         # Here the model outputs V = signal_rate * epsilon - noise_rate * x_0,
+         # where epsilon is the noise and x_0 is the clean sample
+         v = preds
+         signal_rate, noise_rate = rates
+         variance = signal_rate ** 2 + noise_rate ** 2
+         v = v * jnp.sqrt(variance)
+         x_0 = signal_rate * x_t - noise_rate * v
+         eps_0 = signal_rate * v + noise_rate * x_t
+         return x_0 / variance, eps_0 / variance
+
+     def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
+         signal_rate, noise_rate = rates
+         v = signal_rate * epsilon - noise_rate * x_0
+         variance = signal_rate ** 2 + noise_rate ** 2
+         return v / jnp.sqrt(variance)
+
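For a variance-preserving schedule (signal_rate² + noise_rate² = 1) these formulas reduce to the familiar x₀ = α·x_t − σ·v and ε = α·v + σ·x_t; a sketch checking the round trip against `get_target`:

```python
import jax.numpy as jnp
from flaxdiff.predictors import VPredictionTransform

tf = VPredictionTransform()
x_0 = jnp.ones((1, 4, 4, 3))
eps = 0.5 * jnp.ones((1, 4, 4, 3))
s, n = jnp.array(0.8), jnp.array(0.6)    # s**2 + n**2 == 1

x_t = s * x_0 + n * eps
v = tf.get_target(x_0, eps, (s, n))
x_0_rec, eps_rec = tf.backward_diffusion(x_t, v, (s, n))
assert jnp.allclose(x_0_rec, x_0) and jnp.allclose(eps_rec, eps)
```
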
74
+ def __init__(self, sigma_data=0.5) -> None:
75
+ super().__init__()
76
+ self.sigma_data = sigma_data
77
+
78
+ def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
79
+ x_0 = preds
80
+ signal_rate, noise_rate = rates
81
+ epsilon = (x_t - x_0 * signal_rate) / noise_rate
82
+ return x_0, epsilon
83
+
84
+ def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
85
+ _, sigma = rates
86
+ c_out = sigma * self.sigma_data / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
87
+ c_skip = self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)
88
+ c_out = c_out.reshape((-1, 1, 1, 1))
89
+ c_skip = c_skip.reshape((-1, 1, 1, 1))
90
+ x_0 = c_out * preds + c_skip * x_t
91
+ return x_0
92
+
93
+ def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
94
+ _, sigma = rates
95
+ c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
96
+ return c_in
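
The `c_in`, `c_skip`, and `c_out` factors are the preconditioning from Karras et al. (2022, "Elucidating the Design Space of Diffusion-Based Generative Models"): as σ → 0 the skip path dominates (c_skip → 1, c_out → 0), so the raw network output is progressively gated out at low noise. A sketch with assumed sigma values:

```python
import jax.numpy as jnp
from flaxdiff.predictors import KarrasPredictionTransform

tf = KarrasPredictionTransform(sigma_data=0.5)
x_t = jnp.ones((1, 4, 4, 3))
preds = jnp.zeros((1, 4, 4, 3))          # raw network output

for sigma in (0.01, 0.5, 10.0):
    rates = (None, jnp.array([sigma]))   # only sigma is used by these methods
    x_0 = tf.pred_transform(x_t, preds, rates)
    # For sigma=0.01, c_skip ~= 1 so x_0 ~= x_t; for sigma=10, c_skip ~= 0.0025
```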
flaxdiff/samplers/__init__.py
@@ -0,0 +1,7 @@
+ from .common import DiffusionSampler
+ from .ddim import DDIMSampler
+ from .ddpm import DDPMSampler, SimpleDDPMSampler
+ from .euler import EulerSampler, SimplifiedEulerSampler
+ from .heun_sampler import HeunSampler
+ from .rk4_sampler import RK4Sampler
+ from .multistep_dpm import MultiStepDPM
flaxdiff/samplers/common.py
@@ -0,0 +1,165 @@
+ from flax import linen as nn
+ import jax
+ import jax.numpy as jnp
+ import tqdm
+ from typing import Union
+ from ..schedulers import NoiseScheduler
+ from ..utils import RandomMarkovState, MarkovState, clip_images
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+
+ class DiffusionSampler():
+     model: nn.Module
+     noise_schedule: NoiseScheduler
+     params: dict
+     model_output_transform: DiffusionPredictionTransform
+
+     def __init__(self, model: nn.Module, params: dict,
+                  noise_schedule: NoiseScheduler,
+                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                  guidance_scale: float = 0.0,
+                  null_labels_seq: jax.Array = None,
+                  autoencoder=None,
+                  image_size=256,
+                  autoenc_scale_reduction=8,
+                  autoenc_latent_channels=4,
+                  ):
+         self.model = model
+         self.noise_schedule = noise_schedule
+         self.params = params
+         self.model_output_transform = model_output_transform
+         self.guidance_scale = guidance_scale
+         self.image_size = image_size
+         self.autoenc_scale_reduction = autoenc_scale_reduction
+         self.autoencoder = autoencoder
+         self.autoenc_latent_channels = autoenc_latent_channels
+
+         if self.guidance_scale > 0:
+             # Classifier-free guidance
+             assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
+             print("Using classifier-free guidance")
+
+             def sample_model(x_t, t, *additional_inputs):
+                 # Concatenate unconditional and conditional inputs
+                 x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
+                 t_cat = jnp.concatenate([t] * 2, axis=0)
+                 rates_cat = self.noise_schedule.get_rates(t_cat)
+                 c_in_cat = self.model_output_transform.get_input_scale(rates_cat)
+
+                 text_labels_seq, = additional_inputs
+                 text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
+                 model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
+                 # Split model output into conditional and unconditional parts
+                 model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
+                 model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
+
+                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
+                 return x_0, eps, model_output
+         else:
+             # Unconditional sampling
+             def sample_model(x_t, t, *additional_inputs):
+                 rates = self.noise_schedule.get_rates(t)
+                 c_in = self.model_output_transform.get_input_scale(rates)
+                 model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
+                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
+                 return x_0, eps, model_output
+
+         # if jax.device_count() > 1:
+         #     mesh = jax.sharding.Mesh(jax.devices(), 'data')
+         #     sample_model = shard_map(sample_model, mesh=mesh, in_specs=(P('data'), P('data'), P('data')),
+         #                              out_specs=(P('data'), P('data'), P('data')))
+         sample_model = jax.jit(sample_model)
+         self.sample_model = sample_model
+
+     # Used to sample from the diffusion model
+     def sample_step(self, current_samples: jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state: MarkovState = None) -> tuple[jnp.ndarray, MarkovState]:
+         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
+         current_step = step_ones * current_step
+         next_step = step_ones * next_step
+         pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs)
+         # plotImages(pred_images)
+         # pred_images = clip_images(pred_images)
+         new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
+                                                  pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
+                                                  model_conditioning_inputs=model_conditioning_inputs
+                                                  )
+         return new_samples, state
+
+     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                        pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+         # Estimate q(x_{t-1} | x_t, x_0):
+         # reconstructed_samples is x_0, current_samples is x_t, current_step is t
+         raise NotImplementedError
+
+
+     def scale_steps(self, steps):
+         scale_factor = self.noise_schedule.max_timesteps / 1000
+         return steps * scale_factor
+
+     def get_steps(self, start_step, end_step, diffusion_steps):
+         step_range = start_step - end_step
+         if diffusion_steps is None or diffusion_steps == 0:
+             diffusion_steps = start_step - end_step
+         diffusion_steps = min(diffusion_steps, step_range)
+         steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
+         return steps
+
+     def get_initial_samples(self, num_images, rngs: jax.random.PRNGKey, start_step):
+         start_step = self.scale_steps(start_step)
+         alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
+         variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
+         image_size = self.image_size
+         image_channels = 3
+         if self.autoencoder is not None:
+             image_size = image_size // self.autoenc_scale_reduction
+             image_channels = self.autoenc_latent_channels
+         return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance
+
+     def generate_images(self,
+                         num_images=16,
+                         diffusion_steps=1000,
+                         start_step: int = None,
+                         end_step: int = 0,
+                         steps_override=None,
+                         priors=None,
+                         rngstate: RandomMarkovState = RandomMarkovState(jax.random.PRNGKey(42)),
+                         model_conditioning_inputs: tuple = ()
+                         ) -> jnp.ndarray:
+         # Default the start step before it is used to draw the initial noise
+         if start_step is None:
+             start_step = self.noise_schedule.max_timesteps
+
+         if priors is None:
+             rngstate, newrngs = rngstate.get_random_key()
+             samples = self.get_initial_samples(num_images, newrngs, start_step)
+         else:
+             print("Using priors")
+             if self.autoencoder is not None:
+                 priors = self.autoencoder.encode(priors)
+             samples = priors
+
+         # @jax.jit
+         def sample_step(state: RandomMarkovState, samples, current_step, next_step):
+             samples, state = self.sample_step(current_samples=samples,
+                                               current_step=current_step,
+                                               model_conditioning_inputs=model_conditioning_inputs,
+                                               state=state, next_step=next_step)
+             return samples, state
+
+         if steps_override is not None:
+             steps = steps_override
+         else:
+             steps = self.get_steps(start_step, end_step, diffusion_steps)
+
+         # print("Sampling steps", steps)
+         for i in tqdm.tqdm(range(0, len(steps))):
+             current_step = self.scale_steps(steps[i])
+             next_step = self.scale_steps(steps[i + 1] if i + 1 < len(steps) else 0)
+             if i != len(steps) - 1:
+                 # print("normal step")
+                 samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
+             else:
+                 # print("last step")
+                 step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
+                 samples, _, _ = self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs)
+         if self.autoencoder is not None:
+             samples = self.autoencoder.decode(samples)
+         samples = clip_images(samples)
+         return samples
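
An end-to-end sketch of this sampler API, assuming a trained `model`/`params` pair and precomputed `text_embeddings`. `DDIMSampler` comes from the samplers package above; `CosineNoiseScheduler` and the exact constructor arguments are assumptions for illustration:

```python
import jax
from flaxdiff.samplers import DDIMSampler
from flaxdiff.schedulers import CosineNoiseScheduler      # assumed class name
from flaxdiff.predictors import EpsilonPredictionTransform

sampler = DDIMSampler(model, params,                      # trained model assumed
                      noise_schedule=CosineNoiseScheduler(1000),
                      model_output_transform=EpsilonPredictionTransform(),
                      image_size=64)
images = sampler.generate_images(num_images=4,
                                 diffusion_steps=50,
                                 model_conditioning_inputs=(text_embeddings,))
```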