flaxdiff 0.1.23__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/flaxdiff/models/simple_vit.py
+++ b/flaxdiff/models/simple_vit.py
@@ -58,6 +58,9 @@ class UViT(nn.Module):
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
     use_projection: bool = False
+    use_flash_attention: bool = False
+    use_self_and_cross: bool = False
+    force_fp32_for_softmax: bool = True
     activation:Callable = jax.nn.swish
     norm_groups:int=8
     dtype: Optional[Dtype] = None
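
The three new fields are ordinary Flax module attributes with defaults, so callers can toggle them at construction time. A minimal sketch of the new knobs (assuming UViT is importable from flaxdiff.models.simple_vit, as the RECORD change below suggests, and that any constructor arguments not visible in this diff keep their defaults):

    from flaxdiff.models.simple_vit import UViT

    # Hypothetical usage of the flags added in 0.1.24; every other UViT
    # attribute is left at the defaults visible in this diff.
    model = UViT(
        use_flash_attention=False,    # new: opt into flash-attention kernels
        use_self_and_cross=False,     # new: combined self- and cross-attention blocks
        force_fp32_for_softmax=True,  # new: keep the attention softmax in float32
    )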
@@ -102,7 +105,7 @@ class UViT(nn.Module):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
                                  kernel_init=self.kernel_init())(x)
             skips.append(x)
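
For context, force_fp32_for_softmax is a common mixed-precision guard: attention logits are upcast to float32 for the softmax, then cast back to the working dtype. A sketch of the general technique (illustrative only, not flaxdiff's actual attention code; softmax_fp32 is a hypothetical helper name):

    import jax
    import jax.numpy as jnp

    def softmax_fp32(logits, force_fp32=True):
        # Upcast to float32 so the exp/sum inside softmax stays numerically
        # stable under bfloat16/float16 activations, then cast back.
        out_dtype = logits.dtype
        if force_fp32:
            logits = logits.astype(jnp.float32)
        return jax.nn.softmax(logits, axis=-1).astype(out_dtype)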
@@ -110,7 +113,7 @@ class UViT(nn.Module):
         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True,
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
                              kernel_init=self.kernel_init())(x)
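Note that the middle block previously hard-coded use_self_and_cross=True; with the new module-level default of False, its attention mode changes unless use_self_and_cross=True is passed explicitly. The down- and up-block defaults (False) are unchanged.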
@@ -121,7 +124,7 @@ class UViT(nn.Module):
                              dtype=self.dtype, precision=self.precision)(skip)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
                                  kernel_init=self.kernel_init())(skip)
--- flaxdiff-0.1.23.dist-info/METADATA
+++ flaxdiff-0.1.24.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.23
+Version: 0.1.24
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
--- flaxdiff-0.1.23.dist-info/RECORD
+++ flaxdiff-0.1.24.dist-info/RECORD
@@ -7,7 +7,7 @@ flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,
 flaxdiff/models/common.py,sha256=lzfOHB-7Bjx83ZZzywXy2mjhwP2UOMKR11vdVSKvsCo,11068
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
 flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
-flaxdiff/models/simple_vit.py,sha256=tXO0WIozq2C4j2GqphSM9mMAYEwj9fKr7rfm4G6vf4A,6403
+flaxdiff/models/simple_vit.py,sha256=li5IAdTlfKRzLnNL4UXuAweMm7gCwL2LyLUip42KBck,6701
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.23.dist-info/METADATA,sha256=iTgk4DY-kuALF-Y1U-a433c7pYRhSy-bXPQrtXh6d54,22083
-flaxdiff-0.1.23.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
-flaxdiff-0.1.23.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.23.dist-info/RECORD,,
+flaxdiff-0.1.24.dist-info/METADATA,sha256=n0xavOLzyPThatgRyXTwR0Gn84UzMzvgdC2xMTqIWg0,22083
+flaxdiff-0.1.24.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
+flaxdiff-0.1.24.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.24.dist-info/RECORD,,
--- flaxdiff-0.1.23.dist-info/WHEEL
+++ flaxdiff-0.1.24.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (72.2.0)
+Generator: setuptools (73.0.1)
 Root-Is-Purelib: true
 Tag: py3-none-any