flaxdiff 0.1.35.4__py3-none-any.whl → 0.1.35.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +13 -6
- flaxdiff/models/simple_unet.py +17 -11
- flaxdiff/models/simple_vit.py +10 -2
- {flaxdiff-0.1.35.4.dist-info → flaxdiff-0.1.35.6.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.35.4.dist-info → flaxdiff-0.1.35.6.dist-info}/RECORD +7 -7
- {flaxdiff-0.1.35.4.dist-info → flaxdiff-0.1.35.6.dist-info}/WHEEL +1 -1
- {flaxdiff-0.1.35.4.dist-info → flaxdiff-0.1.35.6.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
@@ -11,6 +11,7 @@ import einops
 import functools
 import math
 from .common import kernel_init
+import jax.experimental.pallas.ops.tpu.flash_attention

 class EfficientAttention(nn.Module):
     """
@@ -303,27 +304,30 @@ class TransformerBlock(nn.Module):
     only_pure_attention:bool = False
     force_fp32_for_softmax: bool = True
     kernel_init: Callable = kernel_init(1.0)
+    norm_inputs: bool = True
+    explicitly_add_residual: bool = True

     @nn.compact
     def __call__(self, x, context=None):
         inner_dim = self.heads * self.dim_head
         C = x.shape[-1]
-
+        if self.norm_inputs:
+            x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=inner_dim,
                                        use_bias=False, precision=self.precision,
                                        kernel_init=self.kernel_init,
-                                       dtype=self.dtype, name=f'project_in')(
+                                       dtype=self.dtype, name=f'project_in')(x)
             else:
                 projected_x = nn.Conv(
                     features=inner_dim, kernel_size=(1, 1),
                     kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_in_conv',
-                )(
+                )(x)
         else:
-            projected_x =
+            projected_x = x
             inner_dim = C

         context = projected_x if context is None else context
@@ -356,6 +360,9 @@ class TransformerBlock(nn.Module):
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_out_conv',
             )(projected_x)
-
-
+
+        if self.only_pure_attention or self.explicitly_add_residual:
+            projected_x = x + projected_x
+
+        out = projected_x
         return out
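Taken together, the attention.py changes add two new TransformerBlock fields, norm_inputs (RMSNorm on the block input) and explicitly_add_residual (add x back onto the projected output). A minimal usage sketch follows; only the two new flags and the (x, context) call signature come from this diff, while the other constructor values and the standard Flax init/apply flow are assumptions for illustration:

    # Hypothetical sketch: heads, dim_head, shapes and dtype are illustrative,
    # not taken from the package; norm_inputs / explicitly_add_residual are the
    # fields introduced in 0.1.35.6.
    import jax
    import jax.numpy as jnp
    from flaxdiff.models.attention import TransformerBlock

    x = jnp.ones((1, 16 * 16, 256))   # (batch, tokens, channels)
    context = jnp.ones((1, 77, 256))  # e.g. text-conditioning tokens

    block = TransformerBlock(
        heads=8,
        dim_head=32,                    # heads * dim_head matches the channel dim
        dtype=jnp.float32,
        use_projection=False,
        only_pure_attention=False,
        norm_inputs=True,               # new: RMSNorm the input before attention
        explicitly_add_residual=True,   # new: add the residual connection explicitly
    )
    params = block.init(jax.random.PRNGKey(0), x, context)
    out = block.apply(params, x, context)  # same shape as x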
flaxdiff/models/simple_unet.py
CHANGED
@@ -50,7 +50,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=self.kernel_init(1.0),
+            kernel_init=self.kernel_init(scale=1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -65,7 +65,7 @@ class Unet(nn.Module):
                     down_conv_type,
                     name=f"down_{i}_residual_{j}",
                     features=dim_in,
-                    kernel_init=self.kernel_init(1.0),
+                    kernel_init=self.kernel_init(scale=1.0),
                     kernel_size=(3, 3),
                     strides=(1, 1),
                     activation=self.activation,
@@ -83,7 +83,9 @@ class Unet(nn.Module):
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
                         force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
-
+                        norm_inputs=attention_config.get("norm_inputs", True),
+                        explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
+                        kernel_init=self.kernel_init(scale=1.0),
                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
             downs.append(x)
@@ -106,7 +108,7 @@ class Unet(nn.Module):
                 middle_conv_type,
                 name=f"middle_res1_{j}",
                 features=middle_dim_out,
-                kernel_init=self.kernel_init(1.0),
+                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -125,13 +127,15 @@ class Unet(nn.Module):
                     precision=middle_attention.get("precision", self.precision),
                     only_pure_attention=middle_attention.get("only_pure_attention", True),
                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
-
+                    norm_inputs=middle_attention.get("norm_inputs", True),
+                    explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
+                    kernel_init=self.kernel_init(scale=1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
                 features=middle_dim_out,
-                kernel_init=self.kernel_init(1.0),
+                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -153,7 +157,7 @@ class Unet(nn.Module):
                     up_conv_type,# if j == 0 else "separable",
                     name=f"up_{i}_residual_{j}",
                     features=dim_out,
-                    kernel_init=self.kernel_init(1.0),
+                    kernel_init=self.kernel_init(scale=1.0),
                     kernel_size=kernel_size,
                     strides=(1, 1),
                     activation=self.activation,
@@ -171,7 +175,9 @@ class Unet(nn.Module):
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
                         force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
-
+                        norm_inputs=attention_config.get("norm_inputs", True),
+                        explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
+                        kernel_init=self.kernel_init(scale=1.0),
                         name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
@@ -190,7 +196,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=self.kernel_init(1.0),
+            kernel_init=self.kernel_init(scale=1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -201,7 +207,7 @@ class Unet(nn.Module):
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=self.kernel_init(1.0),
+            kernel_init=self.kernel_init(scale=1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
@@ -220,7 +226,7 @@ class Unet(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=self.kernel_init(0.0),
+            kernel_init=self.kernel_init(scale=0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
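Besides switching every kernel_init call to the keyword form kernel_init(scale=...), the U-Net now forwards two extra keys from each attention config dict into TransformerBlock. A hedged sketch of such a dict is below; it lists only keys this diff reads via .get(), and how the dict is wired into Unet (for example through an attention_configs argument) is an assumption, not confirmed here:

    # Hypothetical attention config; unspecified keys fall back to the
    # defaults shown in the .get() calls above.
    attention_config = {
        "precision": None,
        "only_pure_attention": True,
        "force_fp32_for_softmax": False,
        "norm_inputs": True,               # new key read in 0.1.35.6
        "explicitly_add_residual": True,   # new key read in 0.1.35.6
    }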
flaxdiff/models/simple_vit.py
CHANGED
@@ -69,6 +69,8 @@ class UViT(nn.Module):
     precision: PrecisionLike = None
     kernel_init: Callable = partial(kernel_init, scale=1.0)
     add_residualblock_output: bool = False
+    norm_inputs: bool = False
+    explicitly_add_residual: bool = True

     def setup(self):
         if self.norm_groups > 0:
@@ -110,16 +112,20 @@ class UViT(nn.Module):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
+                                 norm_inputs=self.norm_inputs,
+                                 explicitly_add_residual=self.explicitly_add_residual,
                                  kernel_init=self.kernel_init())(x)
             skips.append(x)

         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=self.use_flash_attention, use_self_and_cross=
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
+                             norm_inputs=self.norm_inputs,
+                             explicitly_add_residual=self.explicitly_add_residual,
                              kernel_init=self.kernel_init())(x)

         # # Out blocks
@@ -131,6 +137,8 @@ class UViT(nn.Module):
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
+                                 norm_inputs=self.norm_inputs,
+                                 explicitly_add_residual=self.explicitly_add_residual,
                                  kernel_init=self.kernel_init())(x)

             # print(f'Shape of x after transformer blocks: {x.shape}')
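UViT gains the matching norm_inputs and explicitly_add_residual fields and threads them into every TransformerBlock it builds. A hedged construction sketch follows; only the two new fields come from this diff, while the remaining field names and values are assumed from earlier flaxdiff releases and are illustrative only:

    # Hypothetical UViT construction; values are placeholders.
    from flaxdiff.models.simple_vit import UViT

    model = UViT(
        emb_features=256,
        num_layers=12,
        num_heads=8,
        norm_inputs=False,              # new field, defaults to False
        explicitly_add_residual=True,   # new field, defaults to True
    )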
{flaxdiff-0.1.35.4.dist-info → flaxdiff-0.1.35.6.dist-info}/RECORD
CHANGED
@@ -3,11 +3,11 @@ flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
 flaxdiff/data/online_loader.py,sha256=DoHrMZCi5gMd9tmkCpZIUU9lGxvfYtuaz58943_lCRc,11315
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=
+flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
 flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=
-flaxdiff/models/simple_vit.py,sha256=
+flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
+flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=cawm6fZNQoLLATMneAU2gQ9j7kefqHnBPHuaIj3i_a4,18237
-flaxdiff-0.1.35.
-flaxdiff-0.1.35.
-flaxdiff-0.1.35.
-flaxdiff-0.1.35.
+flaxdiff-0.1.35.6.dist-info/METADATA,sha256=NVCk5V7Zc3iq-nrWTivzO17dQa1fIjYgjJb800ZrZhQ,22085
+flaxdiff-0.1.35.6.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+flaxdiff-0.1.35.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.35.6.dist-info/RECORD,,
{flaxdiff-0.1.35.4.dist-info → flaxdiff-0.1.35.6.dist-info}/top_level.txt
File without changes