flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. flaxdiff/utils.py +105 -2
  2. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
  3. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  4. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
  5. flaxdiff/data/__init__.py +0 -1
  6. flaxdiff/data/online_loader.py +0 -336
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -113
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -43
  22. flaxdiff/samplers/euler.py +0 -59
  23. flaxdiff/samplers/heun_sampler.py +0 -28
  24. flaxdiff/samplers/multistep_dpm.py +0 -60
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -234
  38. flaxdiff/trainer/simple_trainer.py +0 -442
  39. flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
  40. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
@@ -1,233 +0,0 @@
1
- import jax
2
- import jax.numpy as jnp
3
- from flax import linen as nn
4
- from flax.typing import Dtype, PrecisionLike
5
- from typing import Dict, Callable, Sequence, Any, Union, Optional
6
- import einops
7
- from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
8
- from .attention import TransformerBlock
9
- from functools import partial
10
-
11
class Unet(nn.Module):
    """Conditional U-Net with optional transformer attention at every level.

    Maps an input batch ``x``, timesteps ``temb`` and a rank-3 context tensor
    ``textcontext`` to an output with ``output_channels`` channels.

    Attributes:
        output_channels: Channel count of the final convolution.
        emb_features: Width of the Fourier / time-projection embedding.
        feature_depths: Target width of each resolution level. ``feature_depths[0]``
            is the width of the stem and final convolutions.
        attention_configs: One entry per level. ``None`` disables attention at
            that level. Otherwise it is a dict with at least ``"heads"`` plus
            optional ``TransformerBlock`` overrides (``dtype``, ``flash_attention``,
            ``use_projection``, ``use_self_and_cross``, ``precision``,
            ``only_pure_attention``, ``force_fp32_for_softmax``, ``norm_inputs``,
            ``explicitly_add_residual``).
        num_res_blocks: Residual blocks per level in the down and up paths.
        num_middle_res_blocks: Residual/attention/residual groups in the bottleneck.
        activation: Nonlinearity used in residual blocks and before the output conv.
        norm_groups: GroupNorm group count. ``<= 0`` switches the output norm to RMSNorm.
        dtype / precision: Forwarded to the conv, residual and resampling layers.
        named_norms: Keep legacy explicit norm names so older checkpoints still load.
        kernel_init: Initializer factory, called as ``kernel_init(scale=...)``.
    """
    output_channels: int = 3
    # These defaults previously ended with a trailing comma, which silently made
    # them 1-tuples such as ``(256,)`` and broke the module when defaults were used.
    emb_features: int = 64 * 4
    feature_depths: Sequence[int] = (64, 128, 256, 512)
    attention_configs: Sequence[Optional[Dict[str, Any]]] = (
        {"heads": 8}, {"heads": 8}, {"heads": 8}, {"heads": 8},
    )
    num_res_blocks: int = 2
    num_middle_res_blocks: int = 1
    activation: Callable = jax.nn.swish
    norm_groups: int = 8
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    named_norms: bool = False  # This is for backward compatibility reasons; older checkpoints have named norms
    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)

    def setup(self):
        """Declare the output normalization layer."""
        if self.norm_groups > 0:
            norm = partial(nn.GroupNorm, self.norm_groups)
            # Older checkpoints stored this layer under the explicit name "GroupNorm_0".
            self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
        else:
            norm = partial(nn.RMSNorm, 1e-5)
            self.conv_out_norm = norm()

    @nn.compact
    def __call__(self, x, temb, textcontext):
        """Run the U-Net.

        Args:
            x: Input batch. Channels-last is assumed, because ``x.shape[-1]`` is
                read as the channel count. TODO confirm the spatial layout.
            temb: Timesteps, fed to ``FourierEmbedding``.
            textcontext: Rank-3 context tensor passed to every attention block.

        Returns:
            The final convolution's output with ``output_channels`` channels.

        Raises:
            ValueError: If ``textcontext`` is not rank 3.
        """
        temb = FourierEmbedding(features=self.emb_features)(temb)
        temb = TimeProjection(features=self.emb_features)(temb)

        # The context must be rank 3. The old code enforced this implicitly by
        # unpacking its shape into unused locals.
        if len(textcontext.shape) != 3:
            raise ValueError(f"textcontext must be rank 3, got shape {textcontext.shape}")

        feature_depths = self.feature_depths
        attention_configs = self.attention_configs

        conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"

        x = ConvLayer(
            conv_type,
            features=feature_depths[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            kernel_init=self.kernel_init(scale=1.0),
            dtype=self.dtype,
            precision=self.precision,
        )(x)
        # Skip connections, consumed last-in-first-out by the up path.
        downs = [x]

        # Downscaling blocks: residual blocks keep the incoming width, and
        # Downsample moves to dim_out.
        for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
            dim_in = x.shape[-1]
            for j in range(self.num_res_blocks):
                x = ResidualBlock(
                    down_conv_type,
                    name=f"down_{i}_residual_{j}",
                    features=dim_in,
                    kernel_init=self.kernel_init(scale=1.0),
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    activation=self.activation,
                    norm_groups=self.norm_groups,
                    dtype=self.dtype,
                    precision=self.precision,
                    named_norms=self.named_norms,
                )(x, temb)
                # Attention is applied only after the last residual block of a level.
                if attention_config is not None and j == self.num_res_blocks - 1:
                    x = TransformerBlock(
                        heads=attention_config['heads'],
                        dtype=attention_config.get('dtype', jnp.float32),
                        dim_head=dim_in // attention_config['heads'],
                        use_flash_attention=attention_config.get("flash_attention", False),
                        use_projection=attention_config.get("use_projection", False),
                        use_self_and_cross=attention_config.get("use_self_and_cross", True),
                        precision=attention_config.get("precision", self.precision),
                        only_pure_attention=attention_config.get("only_pure_attention", True),
                        force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
                        norm_inputs=attention_config.get("norm_inputs", True),
                        explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
                        kernel_init=self.kernel_init(scale=1.0),
                        name=f"down_{i}_attention_{j}",
                    )(x, textcontext)
                downs.append(x)
            if i != len(feature_depths) - 1:
                x = Downsample(
                    features=dim_out,
                    scale=2,
                    activation=self.activation,
                    name=f"down_{i}_downsample",
                    dtype=self.dtype,
                    precision=self.precision,
                )(x)

        # Middle blocks: residual -> (attention on the last group) -> residual.
        middle_dim_out = feature_depths[-1]
        middle_attention = attention_configs[-1]
        for j in range(self.num_middle_res_blocks):
            x = ResidualBlock(
                middle_conv_type,
                name=f"middle_res1_{j}",
                features=middle_dim_out,
                kernel_init=self.kernel_init(scale=1.0),
                kernel_size=(3, 3),
                strides=(1, 1),
                activation=self.activation,
                norm_groups=self.norm_groups,
                dtype=self.dtype,
                precision=self.precision,
                named_norms=self.named_norms,
            )(x, temb)
            if middle_attention is not None and j == self.num_middle_res_blocks - 1:
                x = TransformerBlock(
                    heads=middle_attention['heads'],
                    dtype=middle_attention.get('dtype', jnp.float32),
                    dim_head=middle_dim_out // middle_attention['heads'],
                    use_flash_attention=middle_attention.get("flash_attention", False),
                    use_linear_attention=False,
                    use_projection=middle_attention.get("use_projection", False),
                    use_self_and_cross=False,
                    precision=middle_attention.get("precision", self.precision),
                    only_pure_attention=middle_attention.get("only_pure_attention", True),
                    force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                    norm_inputs=middle_attention.get("norm_inputs", True),
                    explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
                    kernel_init=self.kernel_init(scale=1.0),
                    name=f"middle_attention_{j}",
                )(x, textcontext)
            x = ResidualBlock(
                middle_conv_type,
                name=f"middle_res2_{j}",
                features=middle_dim_out,
                kernel_init=self.kernel_init(scale=1.0),
                kernel_size=(3, 3),
                strides=(1, 1),
                activation=self.activation,
                norm_groups=self.norm_groups,
                dtype=self.dtype,
                precision=self.precision,
                named_norms=self.named_norms,
            )(x, temb)

        # Upscaling blocks: each residual block first concatenates one skip connection.
        for i, (dim_out, attention_config) in enumerate(zip(reversed(feature_depths), reversed(attention_configs))):
            for j in range(self.num_res_blocks):
                x = jnp.concatenate([x, downs.pop()], axis=-1)
                x = ResidualBlock(
                    up_conv_type,
                    name=f"up_{i}_residual_{j}",
                    features=dim_out,
                    kernel_init=self.kernel_init(scale=1.0),
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    activation=self.activation,
                    norm_groups=self.norm_groups,
                    dtype=self.dtype,
                    precision=self.precision,
                    named_norms=self.named_norms,
                )(x, temb)
                if attention_config is not None and j == self.num_res_blocks - 1:
                    x = TransformerBlock(
                        heads=attention_config['heads'],
                        dtype=attention_config.get('dtype', jnp.float32),
                        dim_head=dim_out // attention_config['heads'],
                        use_flash_attention=attention_config.get("flash_attention", False),
                        use_projection=attention_config.get("use_projection", False),
                        use_self_and_cross=attention_config.get("use_self_and_cross", True),
                        precision=attention_config.get("precision", self.precision),
                        only_pure_attention=attention_config.get("only_pure_attention", True),
                        # Previously read from `middle_attention` by mistake, which
                        # ignored this level's config and crashed if the last level
                        # had no attention.
                        force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
                        norm_inputs=attention_config.get("norm_inputs", True),
                        explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
                        kernel_init=self.kernel_init(scale=1.0),
                        name=f"up_{i}_attention_{j}",
                    )(x, textcontext)
            if i != len(feature_depths) - 1:
                x = Upsample(
                    # NOTE(review): For i == 0 this is feature_depths[0], and after
                    # that feature_depths[-i]. That looks unintended, since the next
                    # level works at feature_depths[-2 - i]. Changing it would alter
                    # parameter shapes of existing checkpoints, so it is kept as is.
                    features=feature_depths[-i],
                    scale=2,
                    activation=self.activation,
                    name=f"up_{i}_upsample",
                    dtype=self.dtype,
                    precision=self.precision,
                )(x)

        x = ConvLayer(
            conv_type,
            features=feature_depths[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            kernel_init=self.kernel_init(scale=1.0),
            dtype=self.dtype,
            precision=self.precision,
        )(x)

        # The last remaining skip is the stem convolution output.
        x = jnp.concatenate([x, downs.pop()], axis=-1)

        x = ResidualBlock(
            conv_type,
            name="final_residual",
            features=feature_depths[0],
            kernel_init=self.kernel_init(scale=1.0),
            kernel_size=(3, 3),
            strides=(1, 1),
            activation=self.activation,
            norm_groups=self.norm_groups,
            dtype=self.dtype,
            precision=self.precision,
            named_norms=self.named_norms,
        )(x, temb)

        x = self.conv_out_norm(x)
        x = self.activation(x)

        # Zero-scaled kernel init for the output projection.
        noise_out = ConvLayer(
            conv_type,
            features=self.output_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            kernel_init=self.kernel_init(scale=0.0),
            dtype=self.dtype,
            precision=self.precision,
        )(x)
        return noise_out
@@ -1,180 +0,0 @@
1
- # simple_vit.py
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- from flax import linen as nn
6
- from typing import Callable, Any, Optional, Tuple
7
- from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
8
- from .attention import TransformerBlock
9
- from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init, ResidualBlock
10
- import einops
11
- from flax.typing import Dtype, PrecisionLike
12
- from functools import partial
13
-
14
def unpatchify(x, channels=3):
    """Reassemble a sequence of flattened square patches into an image.

    Args:
        x: Array of shape ``(B, h*w, p*p*channels)`` whose token grid is square
            (``h == w``) and whose patches are square (``p x p``).
        channels: Number of image channels packed into each patch vector.

    Returns:
        Channels-last array of shape ``(B, h*p, w*p, channels)``.
    """
    h = w = int(x.shape[1] ** .5)
    patch_size = int((x.shape[2] // channels) ** 0.5)
    assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
    n_batch = x.shape[0]
    # Split tokens into the (h, w) grid and each patch vector into (p, p, C).
    grid = x.reshape((n_batch, h, w, patch_size, patch_size, channels))
    # Interleave grid rows with patch rows, and grid cols with patch cols.
    grid = grid.transpose((0, 1, 3, 2, 4, 5))
    return grid.reshape((n_batch, h * patch_size, w * patch_size, channels))
20
-
21
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size`` does
    the projection. The resulting spatial grid is then flattened into a single
    token axis of width ``embedding_dim``.
    """
    patch_size: int
    embedding_dim: int
    dtype: Any = jnp.float32
    precision: Any = jax.lax.Precision.HIGH
    kernel_init: Callable = partial(kernel_init, 1.0)

    @nn.compact
    def __call__(self, x):
        # x is unpacked as (batch, height, width, channels), so it must be rank 4.
        n_batch, img_h, img_w, _ = x.shape
        p = self.patch_size
        assert img_h % p == 0 and img_w % p == 0, "Image dimensions must be divisible by patch size"

        projector = nn.Conv(
            features=self.embedding_dim,
            kernel_size=(p, p),
            strides=(p, p),
            dtype=self.dtype,
            kernel_init=self.kernel_init(),
            precision=self.precision,
        )
        patches = projector(x)
        return jnp.reshape(patches, (n_batch, -1, self.embedding_dim))
41
-
42
class PositionalEncoding(nn.Module):
    """Learned additive positional embedding, initialised to zeros.

    Holds a ``(1, max_len, embedding_dim)`` table and adds its first
    ``x.shape[1]`` rows to ``x``.
    """
    max_len: int
    embedding_dim: int

    @nn.compact
    def __call__(self, x):
        table = self.param(
            'pos_encoding',
            jax.nn.initializers.zeros,
            (1, self.max_len, self.embedding_dim),
        )
        seq_len = x.shape[1]
        return x + table[:, :seq_len, :]
52
-
53
class UViT(nn.Module):
    """U-ViT style transformer denoiser with long skip connections.

    Image patches, one time token and (optionally) projected context tokens
    are concatenated into one sequence. The sequence passes through
    ``num_layers // 2`` "in" blocks, one middle block and ``num_layers // 2``
    "out" blocks. Each "out" block first concatenates the matching "in" block
    output and projects it back to ``emb_features``. The image-patch tokens are
    then projected to pixels and unpatchified.

    Attributes:
        output_channels: Channels of the produced image.
        patch_size: Side length of the square patches.
        emb_features: Token width.
        num_layers: Total in/out transformer blocks (the middle block is extra).
        num_heads: Attention heads per block.
        dropout_rate: Currently unused by this module.
        dtype / precision: Forwarded to sub-layers. Each was previously declared
            twice; the effective default has always been ``None``.
        add_residualblock_output: If True, concatenate the input image and
            refine with two extra convolutions.
    """
    output_channels: int = 3
    patch_size: int = 16
    # Previously `768,` (trailing comma), which made the default the tuple (768,).
    emb_features: int = 768
    num_layers: int = 12
    num_heads: int = 12
    dropout_rate: float = 0.1
    # These two fields were duplicated further down, and the later `None`
    # defaults won. They are kept here once, in their original field position.
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    use_projection: bool = False
    use_flash_attention: bool = False
    use_self_and_cross: bool = False
    force_fp32_for_softmax: bool = True
    activation: Callable = jax.nn.swish
    norm_groups: int = 8
    kernel_init: Callable = partial(kernel_init, scale=1.0)
    add_residualblock_output: bool = False
    norm_inputs: bool = False
    explicitly_add_residual: bool = True

    def setup(self):
        """Choose the normalization layer factory (GroupNorm, or RMSNorm if norm_groups <= 0)."""
        if self.norm_groups > 0:
            self.norm = partial(nn.GroupNorm, self.norm_groups)
        else:
            self.norm = partial(nn.RMSNorm, 1e-5)

    @nn.compact
    def __call__(self, x, temb, textcontext=None):
        """Denoise ``x``.

        Args:
            x: Image batch, unpacked by ``PatchEmbedding`` as (B, H, W, C).
            temb: Timesteps, fed to ``FourierEmbedding``.
            textcontext: Optional context tokens; presumably (B, T, D), TODO
                confirm. When ``None``, no context tokens are added.

        Returns:
            Image batch with ``output_channels`` channels.
        """
        temb = FourierEmbedding(features=self.emb_features)(temb)
        temb = TimeProjection(features=self.emb_features)(temb)

        original_img = x

        x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
                           dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
        num_patches = x.shape[1]

        # Token layout: [image patches | time token | context tokens].
        tokens = [x, jnp.expand_dims(temb, axis=1)]
        if textcontext is not None:
            context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
                                          dtype=self.dtype, precision=self.precision)(textcontext)
            tokens.append(context_emb)
        x = jnp.concatenate(tokens, axis=1)

        x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)

        def transformer_block(use_self_and_cross):
            # All blocks share the same configuration; only the out blocks may
            # enable self+cross attention.
            return TransformerBlock(
                heads=self.num_heads,
                dim_head=self.emb_features // self.num_heads,
                dtype=self.dtype,
                precision=self.precision,
                use_projection=self.use_projection,
                use_flash_attention=self.use_flash_attention,
                use_self_and_cross=use_self_and_cross,
                force_fp32_for_softmax=self.force_fp32_for_softmax,
                only_pure_attention=False,
                norm_inputs=self.norm_inputs,
                explicitly_add_residual=self.explicitly_add_residual,
                kernel_init=self.kernel_init(),
            )

        skips = []
        # In blocks
        for _ in range(self.num_layers // 2):
            x = transformer_block(False)(x)
            skips.append(x)

        # Middle block
        x = transformer_block(False)(x)

        # Out blocks: fuse the matching skip, project back to emb_features.
        for _ in range(self.num_layers // 2):
            x = jnp.concatenate([x, skips.pop()], axis=-1)
            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
                                dtype=self.dtype, precision=self.precision)(x)
            x = transformer_block(self.use_self_and_cross)(x)

        x = self.norm()(x)

        patch_dim = self.patch_size ** 2 * self.output_channels
        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
        # Image patches were concatenated first, so keep the leading tokens.
        # The old slice `x[:, 1 + num_text_tokens:]` dropped real patches and kept
        # the time/context tokens instead. Checkpoints trained with the old slicing
        # will not reproduce their previous outputs.
        x = x[:, :num_patches, :]
        x = unpatchify(x, channels=self.output_channels)

        if self.add_residualblock_output:
            # Concatenate the original image and refine with two convolutions.
            x = jnp.concatenate([original_img, x], axis=-1)

            x = ConvLayer(
                "conv",
                features=64,
                kernel_size=(3, 3),
                strides=(1, 1),
                kernel_init=self.kernel_init(scale=0.0),
                dtype=self.dtype,
                precision=self.precision,
            )(x)

            x = self.norm()(x)
            x = self.activation(x)

            x = ConvLayer(
                "conv",
                features=self.output_channels,
                kernel_size=(3, 3),
                strides=(1, 1),
                kernel_init=self.kernel_init(scale=0.0),
                dtype=self.dtype,
                precision=self.precision,
            )(x)
        return x
@@ -1,96 +0,0 @@
1
- from typing import Union
2
- import jax.numpy as jnp
3
- from ..schedulers import NoiseScheduler, GeneralizedNoiseScheduler
4
-
5
- ############################################################################################################
6
- # Prediction Transforms
7
- ############################################################################################################
8
-
9
class DiffusionPredictionTransform():
    """Base class translating raw model outputs into ``(x_0, epsilon)`` estimates.

    Subclasses select what the network predicts by overriding
    ``backward_diffusion`` (and usually ``get_target``). This base class uses
    ``x_0`` as the training target, an input scale of 1, and an identity
    ``pred_transform``.
    """

    def pred_transform(self, x_t, preds, rates) -> jnp.ndarray:
        """Post-process raw predictions before inversion (identity here)."""
        return preds

    def __call__(self, x_t, preds, current_step, noise_schedule: "NoiseScheduler") -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(x_0, epsilon)`` for noisy input ``x_t`` at ``current_step``.

        The rates come from ``noise_schedule.get_rates(current_step)``.
        """
        rates = noise_schedule.get_rates(current_step)
        preds = self.pred_transform(x_t, preds, rates)
        x_0, epsilon = self.backward_diffusion(x_t, preds, rates)
        return x_0, epsilon

    def forward_diffusion(self, x_0, epsilon, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Noise ``x_0`` and return ``(x_t, c_in, training_target)``.

        ``x_t = signal_rate * x_0 + noise_rate * epsilon``.
        """
        signal_rate, noise_rate = rates
        x_t = signal_rate * x_0 + noise_rate * epsilon
        expected_output = self.get_target(x_0, epsilon, (signal_rate, noise_rate))
        c_in = self.get_input_scale((signal_rate, noise_rate))
        return x_t, c_in, expected_output

    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Recover ``(x_0, epsilon)`` from ``x_t`` and predictions; must be overridden.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
        """Training target for the network; ``x_0`` in the base class."""
        return x_0

    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """Factor applied to ``x_t`` before it is fed to the model; 1 here."""
        return 1
34
-
35
class EpsilonPredictionTransform(DiffusionPredictionTransform):
    """The network predicts the noise ``epsilon``, which is also the training target."""

    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        signal_rates, noise_rates = rates
        # Solve x_t = signal * x_0 + noise * eps for x_0, with eps = preds.
        x_0 = (x_t - preds * noise_rates) / signal_rates
        return x_0, preds

    def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
        return epsilon
45
-
46
class DirectPredictionTransform(DiffusionPredictionTransform):
    """The network predicts the clean sample ``x_0`` directly."""

    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        signal_rate, noise_rate = rates
        # Solve x_t = signal * x_0 + noise * eps for eps, with x_0 = preds.
        epsilon = (x_t - preds * signal_rate) / noise_rate
        return preds, epsilon
53
-
54
class VPredictionTransform(DiffusionPredictionTransform):
    """The network predicts ``v = signal * eps - noise * x_0``, normalised by sqrt(signal^2 + noise^2)."""

    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        alpha, sigma = rates
        var = alpha ** 2 + sigma ** 2
        # Undo the 1/sqrt(var) normalisation applied in get_target.
        v = preds * jnp.sqrt(var)
        # With x_t = a*x_0 + s*eps and v = a*eps - s*x_0:
        #   a*x_t - s*v = var * x_0   and   a*v + s*x_t = var * eps
        x_0 = alpha * x_t - sigma * v
        eps = alpha * v + sigma * x_t
        return x_0 / var, eps / var

    def get_target(self, x_0, epsilon, rates) -> jnp.ndarray:
        alpha, sigma = rates
        v = alpha * epsilon - sigma * x_0
        return v / jnp.sqrt(alpha ** 2 + sigma ** 2)
72
-
73
class KarrasPredictionTransform(DiffusionPredictionTransform):
    """Preconditioning matching the c_skip / c_out / c_in formulas of Karras et al. (EDM).

    ``pred_transform`` turns the raw network output into an ``x_0`` estimate via
    ``c_out * preds + c_skip * x_t``. ``backward_diffusion`` then treats its
    ``preds`` argument as ``x_0``.
    """

    def __init__(self, sigma_data=0.5) -> None:
        super().__init__()
        self.sigma_data = sigma_data

    def backward_diffusion(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> Union[jnp.ndarray, jnp.ndarray]:
        x_0 = preds
        signal_rate, noise_rate = rates
        epsilon = (x_t - x_0 * signal_rate) / noise_rate
        return x_0, epsilon

    def pred_transform(self, x_t, preds, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """Combine raw output and noisy input into an ``x_0`` estimate.

        The per-sample coefficients are reshaped to ``(-1, 1, ..., 1)`` with the
        same rank as ``preds``. The old code hard-coded rank 4, which silently
        mis-broadcast for other ranks and failed for Python-float sigma.
        """
        _, sigma = rates
        sigma_data_sq = self.sigma_data ** 2
        c_out = sigma * self.sigma_data / jnp.sqrt(sigma_data_sq + sigma ** 2)
        c_skip = sigma_data_sq / (sigma_data_sq + sigma ** 2)
        broadcast_shape = (-1,) + (1,) * (jnp.ndim(preds) - 1)
        c_out = jnp.reshape(c_out, broadcast_shape)
        c_skip = jnp.reshape(c_skip, broadcast_shape)
        x_0 = c_out * preds + c_skip * x_t
        return x_0

    def get_input_scale(self, rates: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """Return ``c_in = 1 / sqrt(sigma_data^2 + sigma^2)``."""
        _, sigma = rates
        c_in = 1 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)
        return c_in
@@ -1,7 +0,0 @@
1
- from .common import DiffusionSampler
2
- from .ddim import DDIMSampler
3
- from .ddpm import DDPMSampler, SimpleDDPMSampler
4
- from .euler import EulerSampler, SimplifiedEulerSampler
5
- from .heun_sampler import HeunSampler
6
- from .rk4_sampler import RK4Sampler
7
- from .multistep_dpm import MultiStepDPM
@@ -1,113 +0,0 @@
1
- from flax import linen as nn
2
- import jax
3
- import jax.numpy as jnp
4
- import tqdm
5
- from typing import Union
6
- from ..schedulers import NoiseScheduler
7
- from ..utils import RandomMarkovState, MarkovState, clip_images
8
- from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
9
-
10
class DiffusionSampler():
    """Base class for reverse-diffusion samplers.

    Wraps a denoising ``model`` and its ``params`` together with a noise
    schedule and a prediction transform. Subclasses implement
    ``take_next_step`` to move samples from ``current_step`` to ``next_step``,
    given the reconstructed ``x_0`` and the predicted noise.
    """
    model: nn.Module
    noise_schedule: NoiseScheduler
    params: dict
    model_output_transform: DiffusionPredictionTransform

    def __init__(self, model: nn.Module, params: dict,
                 noise_schedule: NoiseScheduler,
                 model_output_transform: Union[DiffusionPredictionTransform, None] = None):
        """Build the sampler and its jitted model-evaluation function.

        Args:
            model: Module applied as ``model.apply(params, *inputs)``.
            params: Variables passed to ``model.apply``.
            noise_schedule: Provides ``get_rates``, ``transform_inputs`` and
                ``max_timesteps``.
            model_output_transform: Converts raw model output to ``(x_0, eps)``.
                Defaults to a fresh ``EpsilonPredictionTransform`` per sampler,
                instead of one instance shared through the default argument.
        """
        if model_output_transform is None:
            model_output_transform = EpsilonPredictionTransform()
        self.model = model
        self.noise_schedule = noise_schedule
        self.params = params
        self.model_output_transform = model_output_transform

        # NOTE(review): self.params is read while tracing, so JAX embeds it as a
        # constant. Reassigning self.params later will not affect already
        # compiled traces.
        @jax.jit
        def sample_model(x_t, t):
            rates = self.noise_schedule.get_rates(t)
            c_in = self.model_output_transform.get_input_scale(rates)
            model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t))
            x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
            return x_0, eps, model_output

        self.sample_model = sample_model

    # Used to sample from the diffusion model
    def sample_step(self, current_samples: jnp.ndarray, current_step, next_step=None, state: MarkovState = None) -> tuple[jnp.ndarray, MarkovState]:
        """Evaluate the model at ``current_step`` and advance the samples to ``next_step``.

        ``next_step`` is multiplied into a per-sample vector, so it must not be None.
        """
        step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
        current_step = step_ones * current_step
        next_step = step_ones * next_step
        pred_images, pred_noise, _ = self.sample_model(current_samples, current_step)
        new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state)
        return new_samples, state

    def take_next_step(self, current_samples, reconstructed_samples,
                       pred_noise, current_step, state: RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
        """Estimate q(x_{next} | x_t, x_0); subclasses must override.

        ``reconstructed_samples`` is x_0, ``current_samples`` is x_t.

        Raises:
            NotImplementedError: Always, in the base class. The old code
                *returned* the exception class, which callers then tried to
                unpack.
        """
        raise NotImplementedError

    def scale_steps(self, steps):
        """Map steps on a 0-1000 scale onto ``[0, noise_schedule.max_timesteps]``."""
        scale_factor = self.noise_schedule.max_timesteps / 1000
        return steps * scale_factor

    def get_steps(self, start_step, end_step, diffusion_steps):
        """Return at most ``start_step - end_step`` int16 steps, descending from start to end."""
        step_range = start_step - end_step
        if diffusion_steps is None or diffusion_steps == 0:
            diffusion_steps = start_step - end_step
        diffusion_steps = min(diffusion_steps, step_range)
        steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
        return steps

    def get_initial_samples(self, num_images, rngs: jax.random.PRNGKey, start_step, image_size=64):
        """Draw Gaussian noise of shape (num_images, image_size, image_size, 3).

        The noise is scaled by sqrt(alpha^2 + sigma^2) at ``start_step``.
        """
        start_step = self.scale_steps(start_step)
        alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
        return jax.random.normal(rngs, (num_images, image_size, image_size, 3)) * variance

    def generate_images(self,
                        num_images=16,
                        diffusion_steps=1000,
                        start_step: int = None,
                        end_step: int = 0,
                        steps_override=None,
                        priors=None,
                        rngstate: RandomMarkovState = None) -> jnp.ndarray:
        """Run the full reverse process and return clipped samples.

        Args:
            num_images: Batch size when sampling from noise (ignored with ``priors``).
            diffusion_steps: Maximum number of sampling steps.
            start_step / end_step: Step range. ``start_step`` defaults to
                ``noise_schedule.max_timesteps``.
            steps_override: Explicit step sequence, bypassing ``get_steps``.
            priors: Starting samples to use instead of fresh noise.
            rngstate: RNG state. Defaults to ``RandomMarkovState(PRNGKey(42))``,
                built per call rather than once at import.
        """
        if rngstate is None:
            rngstate = RandomMarkovState(jax.random.PRNGKey(42))
        # Resolve the default before get_initial_samples uses it. The old code
        # called scale_steps(None) and crashed whenever start_step was omitted.
        # NOTE(review): scale_steps treats steps as a 0-1000 scale, so this
        # default only coincides with "start at the end" when max_timesteps == 1000.
        if start_step is None:
            start_step = self.noise_schedule.max_timesteps

        if priors is None:
            rngstate, newrngs = rngstate.get_random_key()
            samples = self.get_initial_samples(num_images, newrngs, start_step)
        else:
            print("Using priors")
            samples = priors

        @jax.jit
        def sample_step(state: RandomMarkovState, samples, current_step, next_step):
            samples, state = self.sample_step(current_samples=samples,
                                              current_step=current_step,
                                              state=state, next_step=next_step)
            return samples, state

        if steps_override is not None:
            steps = steps_override
        else:
            steps = self.get_steps(start_step, end_step, diffusion_steps)

        for i in tqdm.tqdm(range(0, len(steps))):
            current_step = self.scale_steps(steps[i])
            next_step = self.scale_steps(steps[i + 1] if i + 1 < len(steps) else 0)
            if i != len(steps) - 1:
                samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
            else:
                # Final step: return the model's x_0 reconstruction directly. Size
                # the step vector from the samples, not num_images, so priors with
                # a different batch size work.
                step_ones = jnp.ones((samples.shape[0], ), dtype=jnp.int32)
                samples, _, _ = self.sample_model(samples, current_step * step_ones)
        return clip_images(samples)
flaxdiff/samplers/ddim.py DELETED
@@ -1,10 +0,0 @@
1
- import jax.numpy as jnp
2
- from .common import DiffusionSampler
3
- from ..utils import MarkovState
4
-
5
class DDIMSampler(DiffusionSampler):
    """Deterministic DDIM update: re-noise the predicted x_0 with the predicted noise at the next step."""

    def take_next_step(self,
                       current_samples, reconstructed_samples,
                       pred_noise, current_step, state: MarkovState, next_step=None) -> tuple[jnp.ndarray, MarkovState]:
        rates_next = self.noise_schedule.get_rates(next_step)
        alpha_next, sigma_next = rates_next
        next_samples = reconstructed_samples * alpha_next + pred_noise * sigma_next
        # The state is passed through unchanged; DDIM draws no fresh randomness here.
        return next_samples, state