flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. flaxdiff/utils.py +105 -2
  2. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
  3. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  4. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
  5. flaxdiff/data/__init__.py +0 -1
  6. flaxdiff/data/online_loader.py +0 -336
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -113
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -43
  22. flaxdiff/samplers/euler.py +0 -59
  23. flaxdiff/samplers/heun_sampler.py +0 -28
  24. flaxdiff/samplers/multistep_dpm.py +0 -60
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -234
  38. flaxdiff/trainer/simple_trainer.py +0 -442
  39. flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
  40. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
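Judging from the table, every Python module except flaxdiff/utils.py (plus the packaging metadata) is removed in 0.1.36.1. A file-level listing like the one above can be checked locally by comparing the contents of the two wheels; the sketch below is one way to do that with the standard-library zipfile module. It assumes both .whl files have already been downloaded, and the local filenames used here are placeholders.

# Hypothetical verification sketch: list which files were dropped between
# the two wheels. The wheel paths below are placeholders for locally
# downloaded copies of the packages being diffed.
import zipfile

def wheel_files(path):
    # A wheel is a zip archive; namelist() returns every packaged file.
    with zipfile.ZipFile(path) as whl:
        return set(whl.namelist())

old = wheel_files("flaxdiff-0.1.35.6-py3-none-any.whl")
new = wheel_files("flaxdiff-0.1.36.1-py3-none-any.whl")

print("removed:", sorted(old - new))  # expected to include flaxdiff/models/common.py
print("added:  ", sorted(new - old))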
flaxdiff/models/common.py DELETED
@@ -1,346 +0,0 @@
- import jax.numpy as jnp
- import jax
- from flax import linen as nn
- from typing import Optional, Any, Callable, Sequence, Union
- from flax.typing import Dtype, PrecisionLike
- from typing import Dict, Callable, Sequence, Any, Union
- import einops
- from functools import partial
-
- # Kernel initializer to use
- def kernel_init(scale=1.0, dtype=jnp.float32):
-     scale = max(scale, 1e-10)
-     return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
-
-
- class WeightStandardizedConv(nn.Module):
-     """
-     apply weight standardization https://arxiv.org/abs/1903.10520
-     """
-     features: int
-     kernel_size: Sequence[int] = 3
-     strides: Union[None, int, Sequence[int]] = 1
-     padding: Any = 1
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     param_dtype: Optional[Dtype] = None
-
-     @nn.compact
-     def __call__(self, x):
-         """
-         Applies a weight standardized convolution to the inputs.
-
-         Args:
-             inputs: input data with dimensions (batch, spatial_dims..., features).
-
-         Returns:
-             The convolved data.
-         """
-         x = x.astype(self.dtype)
-
-         conv = nn.Conv(
-             features=self.features,
-             kernel_size=self.kernel_size,
-             strides = self.strides,
-             padding=self.padding,
-             dtype=self.dtype,
-             param_dtype = self.param_dtype,
-             parent=None)
-
-         kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
-         bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
-
-         # standardize kernel
-         kernel = self.param('kernel', kernel_init, x)
-         eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
-         # reduce over dim_out
-         redux = tuple(range(kernel.ndim - 1))
-         mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
-         var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
-         standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
-
-         bias = self.param('bias',bias_init, x)
-
-         return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
-
- class PixelShuffle(nn.Module):
-     scale: int
-
-     @nn.compact
-     def __call__(self, x):
-         up = einops.rearrange(
-             x,
-             pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
-             h2=self.scale,
-             w2=self.scale,
-         )
-         return up
-
- class TimeEmbedding(nn.Module):
-     features:int
-     nax_positions:int=10000
-
-     def setup(self):
-         half_dim = self.features // 2
-         emb = jnp.log(self.nax_positions) / (half_dim - 1)
-         emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
-         self.embeddings = emb
-
-     def __call__(self, x):
-         x = jax.lax.convert_element_type(x, jnp.float32)
-         emb = x[:, None] * self.embeddings[None, :]
-         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
-         return emb
-
- class FourierEmbedding(nn.Module):
-     features:int
-     scale:int = 16
-
-     def setup(self):
-         self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
-
-     def __call__(self, x):
-         x = jax.lax.convert_element_type(x, jnp.float32)
-         emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
-         emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
-         return emb
-
- class TimeProjection(nn.Module):
-     features:int
-     activation:Callable=jax.nn.gelu
-     kernel_init:Callable=kernel_init(1.0)
-
-     @nn.compact
-     def __call__(self, x):
-         x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
-         x = self.activation(x)
-         x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
-         x = self.activation(x)
-         return x
-
- class SeparableConv(nn.Module):
-     features:int
-     kernel_size:tuple=(3, 3)
-     strides:tuple=(1, 1)
-     use_bias:bool=False
-     kernel_init:Callable=kernel_init(1.0)
-     padding:str="SAME"
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-
-     @nn.compact
-     def __call__(self, x):
-         in_features = x.shape[-1]
-         depthwise = nn.Conv(
-             features=in_features, kernel_size=self.kernel_size,
-             strides=self.strides, kernel_init=self.kernel_init,
-             feature_group_count=in_features, use_bias=self.use_bias,
-             padding=self.padding,
-             dtype=self.dtype,
-             precision=self.precision
-         )(x)
-         pointwise = nn.Conv(
-             features=self.features, kernel_size=(1, 1),
-             strides=(1, 1), kernel_init=self.kernel_init,
-             use_bias=self.use_bias,
-             dtype=self.dtype,
-             precision=self.precision
-         )(depthwise)
-         return pointwise
-
- class ConvLayer(nn.Module):
-     conv_type:str
-     features:int
-     kernel_size:tuple=(3, 3)
-     strides:tuple=(1, 1)
-     kernel_init:Callable=kernel_init(1.0)
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-
-     def setup(self):
-         # conv_type can be "conv", "separable", "conv_transpose"
-         if self.conv_type == "conv":
-             self.conv = nn.Conv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "w_conv":
-             self.conv = WeightStandardizedConv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 padding="SAME",
-                 param_dtype=self.dtype,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "separable":
-             self.conv = SeparableConv(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-         elif self.conv_type == "conv_transpose":
-             self.conv = nn.ConvTranspose(
-                 features=self.features,
-                 kernel_size=self.kernel_size,
-                 strides=self.strides,
-                 kernel_init=self.kernel_init,
-                 dtype=self.dtype,
-                 precision=self.precision
-             )
-
-     def __call__(self, x):
-         return self.conv(x)
-
- class Upsample(nn.Module):
-     features:int
-     scale:int
-     activation:Callable=jax.nn.swish
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     kernel_init:Callable=kernel_init(1.0)
-
-     @nn.compact
-     def __call__(self, x, residual=None):
-         out = x
-         # out = PixelShuffle(scale=self.scale)(out)
-         B, H, W, C = x.shape
-         out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
-         out = ConvLayer(
-             "conv",
-             features=self.features,
-             kernel_size=(3, 3),
-             strides=(1, 1),
-             dtype=self.dtype,
-             precision=self.precision,
-             kernel_init=self.kernel_init
-         )(out)
-         if residual is not None:
-             out = jnp.concatenate([out, residual], axis=-1)
-         return out
-
- class Downsample(nn.Module):
-     features:int
-     scale:int
-     activation:Callable=jax.nn.swish
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     kernel_init:Callable=kernel_init(1.0)
-
-     @nn.compact
-     def __call__(self, x, residual=None):
-         out = ConvLayer(
-             "conv",
-             features=self.features,
-             kernel_size=(3, 3),
-             strides=(2, 2),
-             dtype=self.dtype,
-             precision=self.precision,
-             kernel_init=self.kernel_init
-         )(x)
-         if residual is not None:
-             if residual.shape[1] > out.shape[1]:
-                 residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
-             out = jnp.concatenate([out, residual], axis=-1)
-         return out
-
-
- def l2norm(t, axis=1, eps=1e-12):
-     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
-     out = t/denom
-     return (out)
-
-
- class ResidualBlock(nn.Module):
-     conv_type:str
-     features:int
-     kernel_size:tuple=(3, 3)
-     strides:tuple=(1, 1)
-     padding:str="SAME"
-     activation:Callable=jax.nn.swish
-     direction:str=None
-     res:int=2
-     norm_groups:int=8
-     kernel_init:Callable=kernel_init(1.0)
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     named_norms:bool=False
-
-     def setup(self):
-         if self.norm_groups > 0:
-             norm = partial(nn.GroupNorm, self.norm_groups)
-             self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
-             self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
-         else:
-             norm = partial(nn.RMSNorm, 1e-5)
-             self.norm1 = norm()
-             self.norm2 = norm()
-
-     @nn.compact
-     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
-         residual = x
-         out = self.norm1(x)
-         # out = nn.RMSNorm()(x)
-         out = self.activation(out)
-
-         out = ConvLayer(
-             self.conv_type,
-             features=self.features,
-             kernel_size=self.kernel_size,
-             strides=self.strides,
-             kernel_init=self.kernel_init,
-             name="conv1",
-             dtype=self.dtype,
-             precision=self.precision
-         )(out)
-
-         temb = nn.DenseGeneral(
-             features=self.features,
-             name="temb_projection",
-             dtype=self.dtype,
-             precision=self.precision)(temb)
-         temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
-         # scale, shift = jnp.split(temb, 2, axis=-1)
-         # out = out * (1 + scale) + shift
-         out = out + temb
-
-         out = self.norm2(out)
-         # out = nn.RMSNorm()(out)
-         out = self.activation(out)
-
-         out = ConvLayer(
-             self.conv_type,
-             features=self.features,
-             kernel_size=self.kernel_size,
-             strides=self.strides,
-             kernel_init=self.kernel_init,
-             name="conv2",
-             dtype=self.dtype,
-             precision=self.precision
-         )(out)
-
-         if residual.shape != out.shape:
-             residual = ConvLayer(
-                 self.conv_type,
-                 features=self.features,
-                 kernel_size=(1, 1),
-                 strides=1,
-                 kernel_init=self.kernel_init,
-                 name="residual_conv",
-                 dtype=self.dtype,
-                 precision=self.precision
-             )(residual)
-         out = out + residual
-
-         out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
-
-         return out
-
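For context on what this deletion removes: common.py held the shared building blocks (sinusoidal and Fourier time embeddings, convolution wrappers, up/down-sampling layers, and the time-conditioned ResidualBlock) composed by the UNet and ViT models elsewhere in the package. The sketch below shows roughly how those pieces fit together under the 0.1.35.x API shown above; TinyBlock is a hypothetical wrapper written only for illustration, and the shapes are examples.

# Minimal sketch against the deleted 0.1.35.x module shown above.
# TinyBlock is a hypothetical composition, not part of flaxdiff itself.
import jax
import jax.numpy as jnp
from flax import linen as nn
from flaxdiff.models.common import (
    ConvLayer, Downsample, ResidualBlock, TimeEmbedding, TimeProjection,
)

class TinyBlock(nn.Module):
    features: int = 64

    @nn.compact
    def __call__(self, x, t):
        # Sinusoidal timestep embedding followed by a small MLP projection.
        temb = TimeEmbedding(features=self.features)(t)
        temb = TimeProjection(features=self.features)(temb)
        # Lift the input to `features` channels so GroupNorm inside the
        # residual block sees a channel count divisible by norm_groups.
        x = ConvLayer("conv", features=self.features)(x)
        # One residual block conditioned on the projected timestep.
        x = ResidualBlock("conv", features=self.features)(x, temb)
        # Strided conv halves the spatial resolution.
        return Downsample(features=self.features, scale=2)(x)

x = jnp.zeros((1, 32, 32, 3))        # NHWC image batch
t = jnp.array([10])                  # one integer timestep per batch element
params = TinyBlock().init(jax.random.PRNGKey(0), x, t)
y = TinyBlock().apply(params, x, t)  # (1, 16, 16, 64)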