flaxdiff 0.1.23__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/simple_vit.py +6 -3
- {flaxdiff-0.1.23.dist-info → flaxdiff-0.1.24.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.23.dist-info → flaxdiff-0.1.24.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.23.dist-info → flaxdiff-0.1.24.dist-info}/WHEEL +1 -1
- {flaxdiff-0.1.23.dist-info → flaxdiff-0.1.24.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_vit.py
CHANGED
@@ -58,6 +58,9 @@ class UViT(nn.Module):
|
|
58
58
|
dtype: Any = jnp.float32
|
59
59
|
precision: Any = jax.lax.Precision.HIGH
|
60
60
|
use_projection: bool = False
|
61
|
+
use_flash_attention: bool = False
|
62
|
+
use_self_and_cross: bool = False
|
63
|
+
force_fp32_for_softmax: bool = True
|
61
64
|
activation:Callable = jax.nn.swish
|
62
65
|
norm_groups:int=8
|
63
66
|
dtype: Optional[Dtype] = None
|
@@ -102,7 +105,7 @@ class UViT(nn.Module):
|
|
102
105
|
for i in range(self.num_layers // 2):
|
103
106
|
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
104
107
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
105
|
-
use_flash_attention=
|
108
|
+
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
106
109
|
only_pure_attention=False,
|
107
110
|
kernel_init=self.kernel_init())(x)
|
108
111
|
skips.append(x)
|
@@ -110,7 +113,7 @@ class UViT(nn.Module):
|
|
110
113
|
# Middle block
|
111
114
|
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
112
115
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
113
|
-
use_flash_attention=
|
116
|
+
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
114
117
|
only_pure_attention=False,
|
115
118
|
kernel_init=self.kernel_init())(x)
|
116
119
|
|
@@ -121,7 +124,7 @@ class UViT(nn.Module):
|
|
121
124
|
dtype=self.dtype, precision=self.precision)(skip)
|
122
125
|
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
123
126
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
124
|
-
use_flash_attention=
|
127
|
+
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
125
128
|
only_pure_attention=False,
|
126
129
|
kernel_init=self.kernel_init())(skip)
|
127
130
|
|
@@ -7,7 +7,7 @@ flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,
|
|
7
7
|
flaxdiff/models/common.py,sha256=lzfOHB-7Bjx83ZZzywXy2mjhwP2UOMKR11vdVSKvsCo,11068
|
8
8
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
9
9
|
flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
|
10
|
-
flaxdiff/models/simple_vit.py,sha256=
|
10
|
+
flaxdiff/models/simple_vit.py,sha256=li5IAdTlfKRzLnNL4UXuAweMm7gCwL2LyLUip42KBck,6701
|
11
11
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
12
12
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
13
13
|
flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.24.dist-info/METADATA,sha256=n0xavOLzyPThatgRyXTwR0Gn84UzMzvgdC2xMTqIWg0,22083
|
38
|
+
flaxdiff-0.1.24.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
|
39
|
+
flaxdiff-0.1.24.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.24.dist-info/RECORD,,
|
File without changes
|