flaxdiff 0.1.36__py3-none-any.whl → 0.1.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +13 -10
  2. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  3. flaxdiff/data/__init__.py +0 -1
  4. flaxdiff/data/dataset_map.py +0 -71
  5. flaxdiff/data/datasets.py +0 -169
  6. flaxdiff/data/online_loader.py +0 -363
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -165
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -37
  22. flaxdiff/samplers/euler.py +0 -56
  23. flaxdiff/samplers/heun_sampler.py +0 -27
  24. flaxdiff/samplers/multistep_dpm.py +0 -59
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -326
  38. flaxdiff/trainer/simple_trainer.py +0 -538
  39. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  40. flaxdiff-0.1.36.dist-info/RECORD +0 -43
  41. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +0 -0
  42. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
@@ -1,368 +0,0 @@
- """
- Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_flax.py
- """
-
- import jax
- import jax.numpy as jnp
- from flax import linen as nn
- from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
- from flax.typing import Dtype, PrecisionLike
- import einops
- import functools
- import math
- from .common import kernel_init
- import jax.experimental.pallas.ops.tpu.flash_attention
-
- class EfficientAttention(nn.Module):
-     """
-     Based on the pallas attention implementation.
-     """
-     query_dim: int
-     heads: int = 4
-     dim_head: int = 64
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     use_bias: bool = True
-     kernel_init: Callable = kernel_init(1.0)
-     force_fp32_for_softmax: bool = True
-
-     def setup(self):
-         inner_dim = self.dim_head * self.heads
-         # Weights were exported with old names {to_q, to_k, to_v, to_out}
-         dense = functools.partial(
-             nn.Dense,
-             self.heads * self.dim_head,
-             precision=self.precision,
-             use_bias=self.use_bias,
-             kernel_init=self.kernel_init,
-             dtype=self.dtype
-         )
-         self.query = dense(name="to_q")
-         self.key = dense(name="to_k")
-         self.value = dense(name="to_v")
-
-         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
-                                          kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
-         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
-
-     def _reshape_tensor_to_head_dim(self, tensor):
-         batch_size, _, seq_len, dim = tensor.shape
-         head_size = self.heads
-         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-         tensor = jnp.transpose(tensor, (0, 2, 1, 3))
-         return tensor
-
-     def _reshape_tensor_from_head_dim(self, tensor):
-         batch_size, _, seq_len, dim = tensor.shape
-         head_size = self.heads
-         tensor = jnp.transpose(tensor, (0, 2, 1, 3))
-         tensor = tensor.reshape(batch_size, 1, seq_len, dim * head_size)
-         return tensor
-
-     @nn.compact
-     def __call__(self, x:jax.Array, context=None):
-         # print(x.shape)
-         # x has shape [B, H * W, C]
-         context = x if context is None else context
-
-         orig_x_shape = x.shape
-         if len(x.shape) == 4:
-             B, H, W, C = x.shape
-             x = x.reshape((B, 1, H * W, C))
-         else:
-             B, SEQ, C = x.shape
-             x = x.reshape((B, 1, SEQ, C))
-
-         if len(context.shape) == 4:
-             B, _H, _W, _C = context.shape
-             context = context.reshape((B, 1, _H * _W, _C))
-         else:
-             B, SEQ, _C = context.shape
-             context = context.reshape((B, 1, SEQ, _C))
-
-         query = self.query(x)
-         key = self.key(context)
-         value = self.value(context)
-
-         query = self._reshape_tensor_to_head_dim(query)
-         key = self._reshape_tensor_to_head_dim(key)
-         value = self._reshape_tensor_to_head_dim(value)
-
-         hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention(
-             query, key, value, None
-         )
-
-         hidden_states = self._reshape_tensor_from_head_dim(hidden_states)
-
-
-         # hidden_states = nn.dot_product_attention(
-         #     query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
-         # )
-
-         proj = self.proj_attn(hidden_states)
-
-         proj = proj.reshape(orig_x_shape)
-
-         return proj
-
- class NormalAttention(nn.Module):
-     """
-     Simple implementation of the normal attention.
-     """
-     query_dim: int
-     heads: int = 4
-     dim_head: int = 64
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     use_bias: bool = True
-     kernel_init: Callable = kernel_init(1.0)
-     force_fp32_for_softmax: bool = True
-
-     def setup(self):
-         inner_dim = self.dim_head * self.heads
-         dense = functools.partial(
-             nn.DenseGeneral,
-             features=[self.heads, self.dim_head],
-             axis=-1,
-             precision=self.precision,
-             use_bias=self.use_bias,
-             kernel_init=self.kernel_init,
-             dtype=self.dtype
-         )
-         self.query = dense(name="to_q")
-         self.key = dense(name="to_k")
-         self.value = dense(name="to_v")
-
-         self.proj_attn = nn.DenseGeneral(
-             self.query_dim,
-             axis=(-2, -1),
-             precision=self.precision,
-             use_bias=self.use_bias,
-             dtype=self.dtype,
-             name="to_out_0",
-             kernel_init=self.kernel_init
-             # kernel_init=jax.nn.initializers.xavier_uniform()
-         )
-
-     @nn.compact
-     def __call__(self, x, context=None):
-         # x has shape [B, H, W, C]
-         orig_x_shape = x.shape
-         if len(x.shape) == 4:
-             B, H, W, C = x.shape
-             x = x.reshape((B, H*W, C))
-         context = x if context is None else context
-         if len(context.shape) == 4:
-             context = context.reshape((B, H*W, C))
-         query = self.query(x)
-         key = self.key(context)
-         value = self.value(context)
-
-         hidden_states = nn.dot_product_attention(
-             query, key, value, dtype=self.dtype, broadcast_dropout=False,
-             dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
-             deterministic=True
-         )
-         proj = self.proj_attn(hidden_states)
-         proj = proj.reshape(orig_x_shape)
-         return proj
-
- class FlaxGEGLU(nn.Module):
-     r"""
-     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
-     https://arxiv.org/abs/2002.05202.
-
-     Parameters:
-         dim (:obj:`int`):
-             Input hidden states dimension
-         dropout (:obj:`float`, *optional*, defaults to 0.0):
-             Dropout rate
-         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-             Parameters `dtype`
-     """
-
-     dim: int
-     dropout: float = 0.0
-     dtype: jnp.dtype = jnp.float32
-     precision: Any = jax.lax.Precision.DEFAULT
-
-     def setup(self):
-         inner_dim = self.dim * 4
-         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=self.precision)
-
-     def __call__(self, hidden_states):
-         hidden_states = self.proj(hidden_states)
-         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
-         return hidden_linear * nn.gelu(hidden_gelu)
-
- class FlaxFeedForward(nn.Module):
-     r"""
-     Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
-     [`FeedForward`] class, with the following simplifications:
-     - The activation function is currently hardcoded to a gated linear unit from:
-       https://arxiv.org/abs/2002.05202
-     - `dim_out` is equal to `dim`.
-     - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
-
-     Parameters:
-         dim (:obj:`int`):
-             Inner hidden states dimension
-         dropout (:obj:`float`, *optional*, defaults to 0.0):
-             Dropout rate
-         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-             Parameters `dtype`
-     """
-
-     dim: int
-     dtype: jnp.dtype = jnp.float32
-     precision: Any = jax.lax.Precision.DEFAULT
-
-     def setup(self):
-         # The second linear layer needs to be called
-         # net_2 for now to match the index of the Sequential layer
-         self.net_0 = FlaxGEGLU(self.dim, self.dtype, precision=self.precision)
-         self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=self.precision)
-
-     def __call__(self, hidden_states):
-         hidden_states = self.net_0(hidden_states)
-         hidden_states = self.net_2(hidden_states)
-         return hidden_states
-
- class BasicTransformerBlock(nn.Module):
-     # Has self and cross attention
-     query_dim: int
-     heads: int = 4
-     dim_head: int = 64
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     use_bias: bool = True
-     kernel_init: Callable = kernel_init(1.0)
-     use_flash_attention:bool = False
-     use_cross_only:bool = False
-     only_pure_attention:bool = False
-     force_fp32_for_softmax: bool = True
-
-     def setup(self):
-         if self.use_flash_attention:
-             attenBlock = EfficientAttention
-         else:
-             attenBlock = NormalAttention
-
-         self.attention1 = attenBlock(
-             query_dim=self.query_dim,
-             heads=self.heads,
-             dim_head=self.dim_head,
-             name=f'Attention1',
-             precision=self.precision,
-             use_bias=self.use_bias,
-             dtype=self.dtype,
-             kernel_init=self.kernel_init,
-             force_fp32_for_softmax=self.force_fp32_for_softmax
-         )
-         self.attention2 = attenBlock(
-             query_dim=self.query_dim,
-             heads=self.heads,
-             dim_head=self.dim_head,
-             name=f'Attention2',
-             precision=self.precision,
-             use_bias=self.use_bias,
-             dtype=self.dtype,
-             kernel_init=self.kernel_init,
-             force_fp32_for_softmax=self.force_fp32_for_softmax
-         )
-
-         self.ff = FlaxFeedForward(dim=self.query_dim)
-         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-     @nn.compact
-     def __call__(self, hidden_states, context=None):
-         if self.only_pure_attention:
-             return self.attention2(hidden_states, context)
-
-         # self attention
-         if not self.use_cross_only:
-             hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
-
-         # cross attention
-         hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
-         # feed forward
-         hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
-
-         return hidden_states
-
- class TransformerBlock(nn.Module):
-     heads: int = 4
-     dim_head: int = 32
-     use_linear_attention: bool = True
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     use_projection: bool = False
-     use_flash_attention:bool = False
-     use_self_and_cross:bool = True
-     only_pure_attention:bool = False
-     force_fp32_for_softmax: bool = True
-     kernel_init: Callable = kernel_init(1.0)
-     norm_inputs: bool = True
-     explicitly_add_residual: bool = True
-
-     @nn.compact
-     def __call__(self, x, context=None):
-         inner_dim = self.heads * self.dim_head
-         C = x.shape[-1]
-         if self.norm_inputs:
-             x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
-         if self.use_projection == True:
-             if self.use_linear_attention:
-                 projected_x = nn.Dense(features=inner_dim,
-                                        use_bias=False, precision=self.precision,
-                                        kernel_init=self.kernel_init,
-                                        dtype=self.dtype, name=f'project_in')(x)
-             else:
-                 projected_x = nn.Conv(
-                     features=inner_dim, kernel_size=(1, 1),
-                     kernel_init=self.kernel_init,
-                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                     precision=self.precision, name=f'project_in_conv',
-                 )(x)
-         else:
-             projected_x = x
-             inner_dim = C
-
-         context = projected_x if context is None else context
-
-         projected_x = BasicTransformerBlock(
-             query_dim=inner_dim,
-             heads=self.heads,
-             dim_head=self.dim_head,
-             name=f'Attention',
-             precision=self.precision,
-             use_bias=False,
-             dtype=self.dtype,
-             use_flash_attention=self.use_flash_attention,
-             use_cross_only=(not self.use_self_and_cross),
-             only_pure_attention=self.only_pure_attention,
-             force_fp32_for_softmax=self.force_fp32_for_softmax,
-             kernel_init=self.kernel_init
-         )(projected_x, context)
-
-         if self.use_projection == True:
-             if self.use_linear_attention:
-                 projected_x = nn.Dense(features=C, precision=self.precision,
-                                        dtype=self.dtype, use_bias=False,
-                                        kernel_init=self.kernel_init,
-                                        name=f'project_out')(projected_x)
-             else:
-                 projected_x = nn.Conv(
-                     features=C, kernel_size=(1, 1),
-                     kernel_init=self.kernel_init,
-                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
-                     precision=self.precision, name=f'project_out_conv',
-                 )(projected_x)
-
-         if self.only_pure_attention or self.explicitly_add_residual:
-             projected_x = x + projected_x
-
-         out = projected_x
-         return out
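For reference, a minimal usage sketch of the TransformerBlock removed by this hunk. It is not taken from the package documentation; the import path, shapes, and argument choices are assumptions based on the 0.1.36 code above (the import no longer resolves in 0.1.36.1).

import jax
import jax.numpy as jnp
# Hypothetical usage against flaxdiff 0.1.36; this module is removed in 0.1.36.1.
from flaxdiff.models.attention import TransformerBlock

block = TransformerBlock(heads=8, dim_head=64, use_projection=True, use_flash_attention=False)
x = jnp.ones((1, 16, 16, 512))                 # assumed [B, H, W, C] feature map; 8 * 64 matches C
params = block.init(jax.random.PRNGKey(0), x)  # context=None -> self-attention over the projected input
y = block.apply(params, x)                     # output keeps the input shape; residual is added at the end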
flaxdiff/models/autoencoder/__init__.py
@@ -1,2 +0,0 @@
- from .autoencoder import AutoEncoder
- from .diffusers import StableDiffusionVAE
flaxdiff/models/autoencoder/autoencoder.py
@@ -1,19 +0,0 @@
- import jax
- import jax.numpy as jnp
- from flax import linen as nn
- from typing import Dict, Callable, Sequence, Any, Union
- import einops
- from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
-
-
- class AutoEncoder():
-     def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
-         raise NotImplementedError
-
-     def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
-         raise NotImplementedError
-
-     def __call__(self, x: jnp.ndarray):
-         latents = self.encode(x)
-         reconstructions = self.decode(latents)
-         return reconstructions
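The AutoEncoder base class above only fixes the interface: subclasses provide encode/decode, and __call__ composes them into a reconstruction. A toy illustration (hypothetical, not part of the package):

import jax.numpy as jnp
from flaxdiff.models.autoencoder import AutoEncoder  # re-exported by the 0.1.36 __init__.py shown above

class IdentityAutoEncoder(AutoEncoder):
    # Trivial subclass: the "latent" is the input itself.
    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return x

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return z

recon = IdentityAutoEncoder()(jnp.ones((2, 8, 8, 3)))  # __call__ = decode(encode(x))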
flaxdiff/models/autoencoder/diffusers.py
@@ -1,91 +0,0 @@
- import jax
- import jax.numpy as jnp
- from flax import linen as nn
- from .autoencoder import AutoEncoder
-
- """
- This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
- The actual model was not trained by me, but was taken from the HuggingFace model hub.
- I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library
- All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
- """
-
- class StableDiffusionVAE(AutoEncoder):
-     def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
-
-         from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
-         from diffusers import FlaxStableDiffusionPipeline
-
-         pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-             modelname,
-             revision=revision,
-             dtype=dtype,
-         )
-
-         vae = pipeline.vae
-
-         enc = FlaxEncoder(
-             in_channels=vae.config.in_channels,
-             out_channels=vae.config.latent_channels,
-             down_block_types=vae.config.down_block_types,
-             block_out_channels=vae.config.block_out_channels,
-             layers_per_block=vae.config.layers_per_block,
-             act_fn=vae.config.act_fn,
-             norm_num_groups=vae.config.norm_num_groups,
-             double_z=True,
-             dtype=vae.dtype,
-         )
-
-         dec = FlaxDecoder(
-             in_channels=vae.config.latent_channels,
-             out_channels=vae.config.out_channels,
-             up_block_types=vae.config.up_block_types,
-             block_out_channels=vae.config.block_out_channels,
-             layers_per_block=vae.config.layers_per_block,
-             norm_num_groups=vae.config.norm_num_groups,
-             act_fn=vae.config.act_fn,
-             dtype=vae.dtype,
-         )
-
-         quant_conv = nn.Conv(
-             2 * vae.config.latent_channels,
-             kernel_size=(1, 1),
-             strides=(1, 1),
-             padding="VALID",
-             dtype=vae.dtype,
-         )
-
-         post_quant_conv = nn.Conv(
-             vae.config.latent_channels,
-             kernel_size=(1, 1),
-             strides=(1, 1),
-             padding="VALID",
-             dtype=vae.dtype,
-         )
-
-         self.enc = enc
-         self.dec = dec
-         self.post_quant_conv = post_quant_conv
-         self.quant_conv = quant_conv
-         self.params = params
-         self.scaling_factor = vae.scaling_factor
-
-     def encode(self, images, rngkey: jax.random.PRNGKey = None):
-         latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
-         latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
-         if rngkey is not None:
-             mean, log_std = jnp.split(latents, 2, axis=-1)
-             log_std = jnp.clip(log_std, -30, 20)
-             std = jnp.exp(0.5 * log_std)
-             latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
-             # print("Sampled")
-         else:
-             # return the mean
-             latents, _ = jnp.split(latents, 2, axis=-1)
-         latents *= self.scaling_factor
-         return latents
-
-     def decode(self, latents):
-         latents = (1.0 / self.scaling_factor) * latents
-         latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
-         return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
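Again for reference, a sketch of how the removed StableDiffusionVAE wrapper was used: encode samples a latent from the VAE posterior when an RNG key is given (otherwise it returns the posterior mean), and decode maps scaled latents back to image space. The shapes, dtype, and value range below are assumptions, not documented behavior.

import jax
import jax.numpy as jnp
# Hypothetical usage against flaxdiff 0.1.36; removed in 0.1.36.1. Downloads SD v1-4 weights on first use.
from flaxdiff.models.autoencoder import StableDiffusionVAE

vae = StableDiffusionVAE()                                    # CompVis/stable-diffusion-v1-4, bf16 revision
images = jnp.zeros((1, 256, 256, 3), dtype=jnp.bfloat16)      # assumed NHWC image batch
latents = vae.encode(images, rngkey=jax.random.PRNGKey(0))    # omit rngkey to take the posterior mean
recon = vae.decode(latents)                                   # back to image space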
flaxdiff/models/autoencoder/simple_autoenc.py
@@ -1,26 +0,0 @@
- from typing import Any, List, Optional, Callable
- import jax
- import flax.linen as nn
- from jax import numpy as jnp
- from flax.typing import Dtype, PrecisionLike
- from .autoencoder import AutoEncoder
-
- class SimpleAutoEncoder(AutoEncoder):
-     latent_channels: int
-     feature_depths: List[int]=[64, 128, 256, 512]
-     attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
-     num_res_blocks: int=2
-     num_middle_res_blocks:int=1,
-     activation:Callable = jax.nn.swish
-     norm_groups:int=8
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-
-     # def encode(self, x: jnp.ndarray):
-
-
-     @nn.compact
-     def __call__(self, x: jnp.ndarray):
-         latents = self.encode(x)
-         reconstructions = self.decode(latents)
-         return reconstructions