x-transformers 1.31.2__py3-none-any.whl → 1.31.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +12 -15
- {x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/METADATA +1 -1
- {x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/RECORD +6 -6
- {x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -566,7 +566,7 @@ class LayerNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset =
+        unit_offset = False
     ):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
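With the default now a boolean, the `float(...)` cast used in the initializers below makes the resulting value explicit. The two possible cases, in plain Python:

    # the value passed to nn.init.constant_ in the hunks that follow
    print(1. - float(False))  # 1.0 -> the usual unit gain
    print(1. - float(True))   # 0.0 -> the parameter starts at zero; the unit is added back in forward (see ScaleNorm/RMSNorm below)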
@@ -576,9 +576,7 @@ class LayerNorm(Module):

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.gamma = nn.Parameter(torch.ones(dim))
-        nn.init.constant_(self.gamma, 1. - unit_offset)
-
-        self.register_buffer('beta', torch.zeros(dim), persistent = False)
+        nn.init.constant_(self.gamma, 1. - float(unit_offset))

     def forward(self, x):
         normed = self.ln(x)
@@ -596,7 +594,6 @@ class AdaptiveLayerNorm(Module):

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
-
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -608,14 +605,14 @@ class ScaleNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset =
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5

         self.g = nn.Parameter(torch.zeros(1))
-        nn.init.constant_(self.g, 1. - unit_offset)
+        nn.init.constant_(self.g, 1. - float(unit_offset))

     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
@@ -624,14 +621,14 @@ class RMSNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset =
+        unit_offset = False
    ):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5

         self.g = nn.Parameter(torch.zeros(dim))
-        nn.init.constant_(self.g, 1. - unit_offset)
+        nn.init.constant_(self.g, 1. - float(unit_offset))

     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
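The ScaleNorm and RMSNorm hunks show the pattern these unit-offset changes serve: the gain is stored as a deviation from 1 and the 1 is added back in forward, so plain weight decay pulls the effective gain toward identity rather than toward zero (see the comment in the AttentionLayers hunk further down). A minimal, self-contained sketch of that pattern as it appears in these hunks; the class name RMSNormSketch is mine, not the library's:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class RMSNormSketch(nn.Module):
        def __init__(self, dim, unit_offset = False):
            super().__init__()
            self.unit_offset = unit_offset
            self.scale = dim ** 0.5

            # with unit_offset = True, g starts at 0 and holds only the deviation from a unit gain
            self.g = nn.Parameter(torch.zeros(dim))
            nn.init.constant_(self.g, 1. - float(unit_offset))

        def forward(self, x):
            # effective gain is (g + unit_offset), so decaying g decays the gain toward 1, not 0
            return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)

    norm = RMSNormSketch(512, unit_offset = True)
    print(norm(torch.randn(2, 512)).shape)  # torch.Size([2, 512])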
@@ -738,14 +735,14 @@ class LayerScale(Module):
         fn: Module,
         dim,
         init_value = 0.,
-        unit_offset =
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset

         self.fn = fn
         self.gamma = nn.Parameter(torch.zeros(dim))
-        nn.init.constant_(self.gamma, init_value - unit_offset)
+        nn.init.constant_(self.gamma, init_value - float(unit_offset))

     def forward(self, x, **kwargs):
         out = self.fn(x, **kwargs)
@@ -1370,8 +1367,8 @@ class AttentionLayers(Module):
         norm_fn = partial(norm_class, dim)

         if not norm_need_condition and norm_add_unit_offset:
-            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas
-            norm_fn = partial(norm_fn, unit_offset =
+            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas, they can be subjected to weight decay safely
+            norm_fn = partial(norm_fn, unit_offset = True)

         self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
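The `norm_fn = partial(norm_fn, unit_offset = True)` line layers a keyword argument onto a constructor that already had `dim` bound by the earlier `partial(norm_class, dim)`. A small sketch of that composition, using a hypothetical stand-in function rather than the real norm classes:

    from functools import partial

    # hypothetical stand-in for norm_class(dim, unit_offset = ...)
    def make_norm(dim, unit_offset = False):
        return f"norm(dim={dim}, unit_offset={unit_offset})"

    norm_fn = partial(make_norm, 512)               # dim bound first, as in partial(norm_class, dim)
    norm_fn = partial(norm_fn, unit_offset = True)  # the second partial adds the keyword on top
    print(norm_fn())  # norm(dim=512, unit_offset=True)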
@@ -1403,8 +1400,8 @@ class AttentionLayers(Module):

         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition

-        if not post_branch_fn_needs_condition and norm_add_unit_offset:
-            post_branch_fn = partial(post_branch_fn, unit_offset =
+        if exists(post_branch_fn) and not post_branch_fn_needs_condition and norm_add_unit_offset:
+            post_branch_fn = partial(post_branch_fn, unit_offset = True)

         # setup mlp for conditioning

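The added `exists(post_branch_fn)` guard matters because `functools.partial` refuses a non-callable first argument, so the old branch would have failed whenever no post-branch function was configured. A sketch of the guard, assuming the usual x-transformers-style helper `exists(val)` meaning `val is not None` (restated here rather than imported):

    from functools import partial

    def exists(val):
        # "exists" in the sense used throughout the codebase: the value is not None
        return val is not None

    post_branch_fn = None  # no post-branch wrapper configured for this stack

    # without the guard, partial(None, unit_offset = True) raises
    # TypeError: the first argument must be callable
    if exists(post_branch_fn):
        post_branch_fn = partial(post_branch_fn, unit_offset = True)

    print(post_branch_fn)  # None -> the wrapper is simply skipped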
{x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=82gdvdwWWDWZVXTQRE6B4W-RQspg4SXhD8B6BV5pOo0,75789
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
+x_transformers-1.31.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.4.dist-info/METADATA,sha256=USK7uPjCATSS4URgfekTTmPhOtvHagHC02b-R-gxwME,661
+x_transformers-1.31.4.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.4.dist-info/RECORD,,
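Each RECORD row has the form `path,sha256=<digest>,<size>`, where the digest is the urlsafe base64 of the file's SHA-256 with the trailing `=` padding stripped. A small sketch of how a row such as the new x_transformers.py entry can be recomputed for verification; the helper name is mine:

    import base64
    import hashlib

    def record_row(path):
        # RECORD format: path,sha256=<urlsafe base64 digest, no padding>,<size in bytes>
        with open(path, 'rb') as f:
            data = f.read()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
        return f"{path},sha256={digest},{len(data)}"

    # e.g. record_row('x_transformers/x_transformers.py') should reproduce the ",75789" entry above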
{x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/LICENSE
File without changes
{x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/WHEEL
File without changes
{x_transformers-1.31.2.dist-info → x_transformers-1.31.4.dist-info}/top_level.txt
File without changes