flaxdiff 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/attention.py CHANGED
@@ -162,65 +162,6 @@ class NormalAttention(nn.Module):
         proj = proj.reshape(orig_x_shape)
         return proj
 
-class BasicTransformerBlock(nn.Module):
-    # Has self and cross attention
-    query_dim: int
-    heads: int = 4
-    dim_head: int = 64
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-    use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
-    use_flash_attention:bool = False
-    use_cross_only:bool = False
-
-    def setup(self):
-        if self.use_flash_attention:
-            attenBlock = EfficientAttention
-        else:
-            attenBlock = NormalAttention
-
-        self.attention1 = attenBlock(
-            query_dim=self.query_dim,
-            heads=self.heads,
-            dim_head=self.dim_head,
-            name=f'Attention1',
-            precision=self.precision,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init
-        )
-        self.attention2 = attenBlock(
-            query_dim=self.query_dim,
-            heads=self.heads,
-            dim_head=self.dim_head,
-            name=f'Attention2',
-            precision=self.precision,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init
-        )
-
-        self.ff = FlaxFeedForward(dim=self.query_dim)
-        self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-    @nn.compact
-    def __call__(self, hidden_states, context=None):
-        # self attention
-        if not self.use_cross_only:
-            print("Using self attention")
-            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
-
-        # cross attention
-        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
-
-        # feed forward
-        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
-
-        return hidden_states
-
 class FlaxGEGLU(nn.Module):
     r"""
     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -330,7 +271,7 @@ class BasicTransformerBlock(nn.Module):
     @nn.compact
     def __call__(self, hidden_states, context=None):
         if self.only_pure_attention:
-            return self.attention2(self.norm2(hidden_states), context)
+            return self.attention2(hidden_states, context)
 
         # self attention
        if not self.use_cross_only:
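Note on the second attention.py hunk: with this change, the only_pure_attention path feeds hidden_states into attention2 directly instead of normalizing them with norm2 first. Below is a minimal before/after sketch of that behaviour, written with plain flax.linen modules rather than flaxdiff's NormalAttention/EfficientAttention classes (module names, shapes, and head counts are illustrative assumptions, not flaxdiff API):

import jax
import jax.numpy as jnp
from flax import linen as nn

class PureAttention016(nn.Module):
    # 0.1.6 path: RMS-normalize the hidden states before the cross attention.
    @nn.compact
    def __call__(self, hidden_states, context=None):
        normed = nn.RMSNorm(epsilon=1e-5)(hidden_states)
        return nn.MultiHeadDotProductAttention(num_heads=4)(normed, context)

class PureAttention018(nn.Module):
    # 0.1.8 path: feed the raw hidden states straight into the cross attention.
    @nn.compact
    def __call__(self, hidden_states, context=None):
        return nn.MultiHeadDotProductAttention(num_heads=4)(hidden_states, context)

x = jnp.ones((1, 16, 64))    # (batch, tokens, features)
ctx = jnp.ones((1, 77, 64))  # e.g. text-conditioning tokens
params = PureAttention018().init(jax.random.PRNGKey(0), x, ctx)
print(PureAttention018().apply(params, x, ctx).shape)  # (1, 16, 64)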
flaxdiff/models/common.py CHANGED
@@ -270,8 +270,8 @@ class ResidualBlock(nn.Module):
     @nn.compact
     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
         residual = x
-        # out = nn.GroupNorm(self.norm_groups)(x)
-        out = nn.RMSNorm()(x)
+        out = nn.GroupNorm(self.norm_groups)(x)
+        # out = nn.RMSNorm()(x)
         out = self.activation(out)
 
         out = ConvLayer(
@@ -295,8 +295,8 @@ class ResidualBlock(nn.Module):
         # out = out * (1 + scale) + shift
         out = out + temb
 
-        # out = nn.GroupNorm(self.norm_groups)(out)
-        out = nn.RMSNorm()(out)
+        out = nn.GroupNorm(self.norm_groups)(out)
+        # out = nn.RMSNorm()(out)
         out = self.activation(out)
 
         out = ConvLayer(
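For reference, the ResidualBlock hunks above switch the normalization back from nn.RMSNorm to nn.GroupNorm(self.norm_groups). A small self-contained sketch of what the two Flax layers do to a feature map (the shape and norm_groups value are assumptions chosen for illustration):

import jax
import jax.numpy as jnp
from flax import linen as nn

x = jnp.linspace(-1.0, 1.0, 8 * 8 * 32).reshape(1, 8, 8, 32)  # NHWC feature map
norm_groups = 8  # must divide the channel count (32)

# GroupNorm: per-group mean/variance statistics with a learned scale and bias.
group_norm = nn.GroupNorm(num_groups=norm_groups)
gn_out = group_norm.apply(group_norm.init(jax.random.PRNGKey(0), x), x)

# RMSNorm: rescales by the root-mean-square over the feature axis, learned scale only.
rms_norm = nn.RMSNorm()
rms_out = rms_norm.apply(rms_norm.init(jax.random.PRNGKey(0), x), x)

print(gn_out.shape, rms_out.shape)  # both preserve the (1, 8, 8, 32) input shape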
flaxdiff/models/simple_unet.py CHANGED
@@ -4,7 +4,7 @@ from flax import linen as nn
 from flax.typing import Dtype, PrecisionLike
 from typing import Dict, Callable, Sequence, Any, Union, Optional
 import einops
-from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
+from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
 from .attention import TransformerBlock
 
 class Unet(nn.Module):
@@ -65,7 +65,7 @@ class Unet(nn.Module):
         if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
             x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                  dim_head=dim_in // attention_config['heads'],
-                                 use_flash_attention=attention_config.get("flash_attention", True),
+                                 use_flash_attention=attention_config.get("flash_attention", False),
                                  use_projection=attention_config.get("use_projection", False),
                                  use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                  precision=attention_config.get("precision", self.precision),
@@ -103,7 +103,7 @@ class Unet(nn.Module):
         if middle_attention is not None and j == self.num_middle_res_blocks - 1:  # Apply attention only on the last block
             x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
                                  dim_head=middle_dim_out // middle_attention['heads'],
-                                 use_flash_attention=middle_attention.get("flash_attention", True),
+                                 use_flash_attention=middle_attention.get("flash_attention", False),
                                  use_linear_attention=False,
                                  use_projection=middle_attention.get("use_projection", False),
                                  use_self_and_cross=False,
@@ -146,7 +146,7 @@ class Unet(nn.Module):
         if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
             x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                  dim_head=dim_out // attention_config['heads'],
-                                 use_flash_attention=attention_config.get("flash_attention", True),
+                                 use_flash_attention=attention_config.get("flash_attention", False),
                                  use_projection=attention_config.get("use_projection", False),
                                  use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                  precision=attention_config.get("precision", self.precision),
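All three simple_unet.py hunks flip the same default: when the attention config dictionary does not set "flash_attention", the TransformerBlock is now constructed with use_flash_attention=False, so the fused attention path becomes opt-in rather than opt-out. A tiny sketch of that lookup (the function and class names below are illustrative stand-ins, not flaxdiff API):

def pick_attention(attention_config: dict) -> str:
    # 0.1.8 behaviour: flash attention only if the config asks for it explicitly.
    use_flash = attention_config.get("flash_attention", False)
    # 0.1.6 behaviour was the opposite default:
    # use_flash = attention_config.get("flash_attention", True)
    return "EfficientAttention" if use_flash else "NormalAttention"

print(pick_attention({"heads": 8}))                           # NormalAttention
print(pick_attention({"heads": 8, "flash_attention": True}))  # EfficientAttention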
flaxdiff-0.1.6.dist-info/METADATA → flaxdiff-0.1.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.6
+Version: 0.1.8
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
flaxdiff-0.1.6.dist-info/RECORD → flaxdiff-0.1.8.dist-info/RECORD CHANGED
@@ -1,10 +1,10 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=OhpKQXdxWbf8K2_yotLfS0DYdHb-zNpL2p8--ql_FAg,14503
-flaxdiff/models/common.py,sha256=RYNxX9K19hvwSWaB9Wtv7MIZLhcacdugDgD9uZDh8XM,10358
+flaxdiff/models/attention.py,sha256=pDGXG2DT7znvHJWyx7_vTUx235s_D9cubwmA6FDq4qE,12526
+flaxdiff/models/common.py,sha256=lBY2ffKikNeSFlt2umsCTUUe43UgonTVMyQPgzWoAM8,10358
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=hAcz074E9NVdUtECPMi1c1Kw-52Dc6l_ME-5FqIg-n8,9255
+flaxdiff/models/simple_unet.py,sha256=lakCwUkCODEiuS5T6j45Z_sHamcQa9ZWk77NowgjZyc,9273
 flaxdiff/models/simple_vit.py,sha256=vTu2CQRoSOxetBHTrnCWddm-vxrZDkMe8EpdNxtpJMk,4015
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
@@ -32,7 +32,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=h5YxIMjBI553xDNeapzLDGF0_4y0MfGRMuHume5sPtM,7785
 flaxdiff/trainer/simple_trainer.py,sha256=f4g2KGuGM__d9v_4Ip3ng8wQubmenWZUW60VEu2ANOg,16774
-flaxdiff-0.1.6.dist-info/METADATA,sha256=sWY_oQgQhhuyW89KyRwIBrpVHBPJjRMmsk5twfgIBlo,20090
-flaxdiff-0.1.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.6.dist-info/RECORD,,
+flaxdiff-0.1.8.dist-info/METADATA,sha256=RVH7dPknslUCneKMAY_ira3uzPWfr1whKhsoZXEKiqU,20090
+flaxdiff-0.1.8.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.8.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.8.dist-info/RECORD,,