flaxdiff 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- flaxdiff/models/attention.py +24 -15
- flaxdiff/models/simple_unet.py +15 -8
- {flaxdiff-0.1.10.dist-info → flaxdiff-0.1.12.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.10.dist-info → flaxdiff-0.1.12.dist-info}/RECORD +6 -6
- {flaxdiff-0.1.10.dist-info → flaxdiff-0.1.12.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.10.dist-info → flaxdiff-0.1.12.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
@@ -22,7 +22,8 @@ class EfficientAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable =
+    kernel_init: Callable = kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -32,7 +33,7 @@ class EfficientAttention(nn.Module):
             self.heads * self.dim_head,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -40,7 +41,7 @@ class EfficientAttention(nn.Module):
         self.value = dense(name="to_v")
 
         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
-                                         kernel_init=self.kernel_init
+                                         kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
 
     def _reshape_tensor_to_head_dim(self, tensor):
@@ -113,7 +114,8 @@ class NormalAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable =
+    kernel_init: Callable = kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -123,7 +125,7 @@ class NormalAttention(nn.Module):
             axis=-1,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -137,7 +139,7 @@ class NormalAttention(nn.Module):
             use_bias=self.use_bias,
             dtype=self.dtype,
             name="to_out_0",
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init
             # kernel_init=jax.nn.initializers.xavier_uniform()
         )
 
@@ -157,7 +159,7 @@ class NormalAttention(nn.Module):
 
         hidden_states = nn.dot_product_attention(
             query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
             deterministic=True
         )
         proj = self.proj_attn(hidden_states)
@@ -233,10 +235,11 @@ class BasicTransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
    use_bias: bool = True
-    kernel_init: Callable =
+    kernel_init: Callable = kernel_init(1.0)
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         if self.use_flash_attention:
@@ -252,7 +255,8 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
         self.attention2 = attenBlock(
             query_dim=self.query_dim,
@@ -262,7 +266,8 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
 
         self.ff = FlaxFeedForward(dim=self.query_dim)
@@ -296,6 +301,8 @@ class TransformerBlock(nn.Module):
     use_flash_attention:bool = False
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
+    kernel_init: Callable = kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, context=None):
@@ -306,12 +313,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-                                   kernel_init=kernel_init
+                                   kernel_init=self.kernel_init,
                                    dtype=self.dtype, name=f'project_in')(normed_x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=kernel_init
+                kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(normed_x)
@@ -331,19 +338,21 @@ class TransformerBlock(nn.Module):
             dtype=self.dtype,
             use_flash_attention=self.use_flash_attention,
             use_cross_only=(not self.use_self_and_cross),
-            only_pure_attention=self.only_pure_attention
+            only_pure_attention=self.only_pure_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            kernel_init=self.kernel_init
         )(projected_x, context)
 
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=C, precision=self.precision,
                                        dtype=self.dtype, use_bias=False,
-                                       kernel_init=kernel_init
+                                       kernel_init=self.kernel_init,
                                        name=f'project_out')(projected_x)
             else:
                 projected_x = nn.Conv(
                     features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init
+                    kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
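The recurring change across these hunks is the new force_fp32_for_softmax field, threaded from TransformerBlock through BasicTransformerBlock into the attention modules and finally into nn.dot_product_attention (see the NormalAttention hunk above), alongside kernel_init becoming a configurable field. The following standalone sketch is not part of the package; it only illustrates, with made-up shapes, what the Flax flag does at the nn.dot_product_attention call that the diff modifies, assuming a Flax version that exposes this keyword (as the diffed code itself does).

import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative shapes: (batch, sequence length, heads, head dim).
B, L, H, D = 1, 16, 4, 32
q = jax.random.normal(jax.random.PRNGKey(0), (B, L, H, D), dtype=jnp.bfloat16)
k, v = q, q

# Mirrors the call in NormalAttention above: inputs stay in bfloat16, but the
# softmax over the attention logits is computed in float32, the usual guard
# against overflow and precision loss in mixed-precision attention.
out = nn.dot_product_attention(
    q, k, v,
    dtype=jnp.bfloat16,
    broadcast_dropout=False,
    dropout_rng=None,
    deterministic=True,
    force_fp32_for_softmax=True,
)
print(out.shape)  # (1, 16, 4, 32)

Keeping the softmax in float32 while the rest of the computation stays in a lower precision is a common mixed-precision safeguard, which is presumably why the new fields default to True in the attention modules above.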
flaxdiff/models/simple_unet.py
CHANGED
@@ -20,6 +20,7 @@ class Unet(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
+    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
 
     def setup(self):
         if self.norm_groups > 0:
@@ -49,7 +50,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -64,7 +65,7 @@ class Unet(nn.Module):
                     down_conv_type,
                     name=f"down_{i}_residual_{j}",
                     features=dim_in,
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(1.0),
                     kernel_size=(3, 3),
                     strides=(1, 1),
                     activation=self.activation,
@@ -81,6 +82,8 @@ class Unet(nn.Module):
                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
+                        force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                        kernel_init=self.kernel_init(1.0),
                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                 downs.append(x)
@@ -103,7 +106,7 @@ class Unet(nn.Module):
                 middle_conv_type,
                 name=f"middle_res1_{j}",
                 features=middle_dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -121,12 +124,14 @@ class Unet(nn.Module):
                     use_self_and_cross=False,
                     precision=middle_attention.get("precision", self.precision),
                     only_pure_attention=middle_attention.get("only_pure_attention", True),
+                    force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                    kernel_init=self.kernel_init(1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
                 features=middle_dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -148,7 +153,7 @@ class Unet(nn.Module):
                     up_conv_type,# if j == 0 else "separable",
                     name=f"up_{i}_residual_{j}",
                     features=dim_out,
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(1.0),
                     kernel_size=kernel_size,
                     strides=(1, 1),
                     activation=self.activation,
@@ -165,6 +170,8 @@ class Unet(nn.Module):
                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
+                        force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                        kernel_init=self.kernel_init(1.0),
                         name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
                 if i != len(feature_depths) - 1:
@@ -183,7 +190,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -194,7 +201,7 @@ class Unet(nn.Module):
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
@@ -213,7 +220,7 @@ class Unet(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=kernel_init(0.0),
+            kernel_init=self.kernel_init(0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
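In simple_unet.py the module-level kernel_init(...) calls become self.kernel_init(...), with the factory exposed as a Unet field defaulting to partial(kernel_init, dtype=jnp.float32), and the per-level attention configs gain a force_fp32_for_softmax key read via attention_config.get(..., False). Below is a hedged sketch of how a caller might set these; feature_depths is a field visible in the diff, but attention_configs, the example values, and the import path of kernel_init are assumptions for illustration and may not match the actual signature.

from functools import partial
import jax.numpy as jnp
from flaxdiff.models.simple_unet import Unet
from flaxdiff.models.common import kernel_init  # import path assumed

# Per-level attention config dict: the new key is the one read above via
# attention_config.get("force_fp32_for_softmax", False).
attention_config = {"force_fp32_for_softmax": True}

unet = Unet(
    feature_depths=(64, 128, 256),                                 # field seen in the diff; values illustrative
    attention_configs=[None, attention_config, attention_config],  # field name assumed
    dtype=jnp.bfloat16,
    # New field: parameters stay float32-initialized under bf16 compute,
    # matching the default partial(kernel_init, dtype=jnp.float32).
    kernel_init=partial(kernel_init, dtype=jnp.float32),
)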
{flaxdiff-0.1.10.dist-info → flaxdiff-0.1.12.dist-info}/RECORD
CHANGED
@@ -1,10 +1,10 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=
+flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=
+flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
 flaxdiff/models/simple_vit.py,sha256=xD23i1b7WEvoH4tUMsLyCe9ebDcv-PpaV0Nso38Jlb8,3887
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
@@ -32,7 +32,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff-0.1.12.dist-info/METADATA,sha256=s3rIj9jqh1Xr1NABqOJZw9XwyHdaLd01c_jKpzEMErQ,22083
+flaxdiff-0.1.12.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.12.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.12.dist-info/RECORD,,
{flaxdiff-0.1.10.dist-info → flaxdiff-0.1.12.dist-info}/WHEEL
File without changes
{flaxdiff-0.1.10.dist-info → flaxdiff-0.1.12.dist-info}/top_level.txt
File without changes