flaxdiff-0.1.11-py3-none-any.whl → flaxdiff-0.1.12-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- flaxdiff/models/attention.py +12 -12
- {flaxdiff-0.1.11.dist-info → flaxdiff-0.1.12.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.11.dist-info → flaxdiff-0.1.12.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.11.dist-info → flaxdiff-0.1.12.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.11.dist-info → flaxdiff-0.1.12.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
@@ -22,7 +22,7 @@ class EfficientAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable =
+    kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True
 
     def setup(self):
@@ -33,7 +33,7 @@ class EfficientAttention(nn.Module):
             self.heads * self.dim_head,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -41,7 +41,7 @@ class EfficientAttention(nn.Module):
         self.value = dense(name="to_v")
 
         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
-                                         kernel_init=self.kernel_init
+                                         kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
 
     def _reshape_tensor_to_head_dim(self, tensor):
@@ -114,7 +114,7 @@ class NormalAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable =
+    kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True
 
     def setup(self):
@@ -125,7 +125,7 @@ class NormalAttention(nn.Module):
             axis=-1,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -139,7 +139,7 @@ class NormalAttention(nn.Module):
             use_bias=self.use_bias,
             dtype=self.dtype,
             name="to_out_0",
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init
             # kernel_init=jax.nn.initializers.xavier_uniform()
         )
 
@@ -235,7 +235,7 @@ class BasicTransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable =
+    kernel_init: Callable = kernel_init(1.0)
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
@@ -302,7 +302,7 @@ class TransformerBlock(nn.Module):
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
     force_fp32_for_softmax: bool = True
-    kernel_init: Callable =
+    kernel_init: Callable = kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, context=None):
@@ -313,12 +313,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-                                   kernel_init=self.kernel_init
+                                   kernel_init=self.kernel_init,
                                    dtype=self.dtype, name=f'project_in')(normed_x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=self.kernel_init
+                kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(normed_x)
@@ -347,12 +347,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=C, precision=self.precision,
                                    dtype=self.dtype, use_bias=False,
-                                   kernel_init=self.kernel_init
+                                   kernel_init=self.kernel_init,
                                    name=f'project_out')(projected_x)
         else:
             projected_x = nn.Conv(
                 features=C, kernel_size=(1, 1),
-                kernel_init=self.kernel_init
+                kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_out_conv',
             )(projected_x)
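The same edit recurs throughout attention.py: kernel_init gains a concrete default of kernel_init(1.0), and a trailing comma is added so that the dtype/name keyword arguments that follow are actually passed to the projection layers. A minimal sketch of how such a default behaves in a Flax module is shown below; the kernel_init helper here is an assumed stand-in for whatever attention.py imports under that name, not the flaxdiff implementation, and TinyProjection is a hypothetical module used only for illustration.

# Rough sketch, not the flaxdiff source: a scaled-variance initializer used as the
# dataclass default for kernel_init and forwarded, with dtype and name, to DenseGeneral.
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import flax.linen as nn


def kernel_init(scale: float) -> Callable:
    # Assumed stand-in for the package's kernel_init helper.
    return nn.initializers.variance_scaling(scale, "fan_avg", "uniform")


class TinyProjection(nn.Module):
    features: int = 64
    dtype: Optional[Any] = None
    kernel_init: Callable = kernel_init(1.0)  # default is now an actual initializer

    @nn.compact
    def __call__(self, x):
        # The trailing comma in the diff lets dtype/name follow kernel_init into the layer.
        return nn.DenseGeneral(self.features, use_bias=False,
                               kernel_init=self.kernel_init,
                               dtype=self.dtype, name="to_out_0")(x)


x = jnp.ones((2, 16, 64))
params = TinyProjection().init(jax.random.PRNGKey(0), x)

Because the initializer is built once at class-definition time, every instance shares the same default while callers can still override kernel_init per instance.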
{flaxdiff-0.1.11.dist-info → flaxdiff-0.1.12.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=
+flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
 flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
@@ -32,7 +32,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff-0.1.12.dist-info/METADATA,sha256=s3rIj9jqh1Xr1NABqOJZw9XwyHdaLd01c_jKpzEMErQ,22083
+flaxdiff-0.1.12.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.12.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.12.dist-info/RECORD,,
File without changes
|
File without changes
|