x-transformers 1.31.0__py3-none-any.whl → 1.31.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +35 -8
- {x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/METADATA +1 -1
- {x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/RECORD +6 -6
- {x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/WHEEL +0 -0
- {x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -573,12 +573,17 @@ class LayerNorm(Module):
         """
         super().__init__()
         self.unit_offset = unit_offset
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.gamma = nn.Parameter(torch.ones(dim))
+        nn.init.constant_(self.gamma, 1. - unit_offset)
+
         self.register_buffer('beta', torch.zeros(dim), persistent = False)
 
     def forward(self, x):
+        normed = self.ln(x)
         gamma = self.gamma + self.unit_offset
-        return
+        return normed * gamma
 
 class AdaptiveLayerNorm(Module):
     def __init__(
@@ -608,7 +613,9 @@ class ScaleNorm(Module):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5
-
+
+        self.g = nn.Parameter(torch.zeros(1))
+        nn.init.constant_(self.g, 1. - unit_offset)
 
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
@@ -622,7 +629,9 @@ class RMSNorm(Module):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5
-
+
+        self.g = nn.Parameter(torch.zeros(dim))
+        nn.init.constant_(self.g, 1. - unit_offset)
 
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
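`ScaleNorm` and `RMSNorm` in the two hunks above get the same treatment, differing only in the shape of the gain (a single scalar for ScaleNorm, one value per channel for RMSNorm). A rough sketch of the RMSNorm variant, with the class name chosen here for illustration:

```python
import torch
import torch.nn.functional as F
from torch import nn

class UnitOffsetRMSNorm(nn.Module):
    # illustrative sketch; ScaleNorm would use a single scalar gain instead of one per dim
    def __init__(self, dim, unit_offset = 0.):
        super().__init__()
        self.unit_offset = unit_offset
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.full((dim,), 1. - unit_offset))

    def forward(self, x):
        # l2-normalize the features, rescale by sqrt(dim), then apply the offset gain
        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)

x = torch.randn(4, 32)
norm = UnitOffsetRMSNorm(32, unit_offset = 1.)
# at initialization the effective gain is exactly 1, so this is a plain rms norm
assert torch.allclose(norm(x), F.normalize(x, dim = -1) * 32 ** 0.5)
```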
@@ -724,10 +733,19 @@ class ShiftTokens(Module):
 # post branch operator
 
 class LayerScale(Module):
-    def __init__(
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        init_value = 0.,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
+
         self.fn = fn
-        self.gamma = nn.Parameter(torch.
+        self.gamma = nn.Parameter(torch.zeros(dim))
+        nn.init.constant_(self.gamma, init_value - unit_offset)
 
     def forward(self, x, **kwargs):
         out = self.fn(x, **kwargs)
@@ -739,7 +757,13 @@ class LayerScale(Module):
         return out * self.gamma, *rest
 
 class AdaptiveLayerScale(Module):
-    def __init__(
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_condition = None,
+        init_bias_value = -2.
+    ):
         super().__init__()
         self.fn = fn
 
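`AdaptiveLayerScale` gets the same multi-line signature, keeping `init_bias_value = -2.` for the DiT-style ada-ln-zero gating mentioned in the config comment further down. The diff only shows the constructor; the gate is driven by a conditioning vector rather than a free parameter. A speculative sketch of that idea (the `to_gamma` projection and the sigmoid gate are assumptions, not taken from the diff):

```python
import torch
from torch import nn

class AdaptiveLayerScaleSketch(nn.Module):
    # speculative sketch of condition-driven gating; not the library's exact forward
    def __init__(self, fn: nn.Module, dim, dim_condition = None, init_bias_value = -2.):
        super().__init__()
        self.fn = fn
        dim_condition = dim_condition if dim_condition is not None else dim
        self.to_gamma = nn.Linear(dim_condition, dim)       # assumed projection name
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.constant_(self.to_gamma.bias, init_bias_value)

    def forward(self, x, *, condition, **kwargs):
        out = self.fn(x, **kwargs)
        # a bias of -2 puts the sigmoid gate around 0.12 at init, so the branch starts damped
        gamma = self.to_gamma(condition).sigmoid()
        return out * gamma

fn = nn.Linear(16, 16)
block = AdaptiveLayerScaleSketch(fn, dim = 16, dim_condition = 32)
x, cond = torch.randn(2, 16), torch.randn(2, 32)
out = block(x, condition = cond)   # branch output scaled by roughly 0.12 at init
```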
@@ -1210,7 +1234,7 @@ class AttentionLayers(Module):
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
         use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
-        norm_add_unit_offset =
+        norm_add_unit_offset = True,
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
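At the `AttentionLayers` level, `norm_add_unit_offset` now defaults to `True`, so the offset behavior above is active unless the caller opts out. Roughly (the other argument values here are placeholders, not defaults):

```python
from x_transformers import Encoder

# unit-offset norms are now the default; pass False to disable the offset behavior
enc = Encoder(dim = 512, depth = 6, heads = 8, norm_add_unit_offset = False)
```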
@@ -1346,7 +1370,7 @@ class AttentionLayers(Module):
         norm_fn = partial(norm_class, dim)
 
         if not norm_need_condition and norm_add_unit_offset:
-            #
+            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
             norm_fn = partial(norm_fn, unit_offset = 1.)
 
         self.norm_need_condition = norm_need_condition
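The comment explains the point of the whole change: weight decay pulls parameters toward zero, and with the unit offset the stored norm gain sits at zero while the effective gain sits at one, so decaying the stored value no longer crushes the gain toward zero. A small sketch of that reasoning with a plain optimizer (the parameter here is a stand-in, not the library's module):

```python
import torch

# stand-in for a norm gain stored with a unit offset: effective gain = stored_g + 1
stored_g = torch.zeros(8, requires_grad = True)

# weight decay can now be applied to the norm gain like any other weight,
# so the usual trick of excluding norms and biases from decay is not needed
opt = torch.optim.AdamW([stored_g], lr = 1e-3, weight_decay = 0.1)

x = torch.randn(8)
loss = (x * (stored_g + 1.)).sum()
loss.backward()
opt.step()

# decay always pulls stored_g back toward 0, i.e. the effective gain back toward 1
print(stored_g + 1.)
```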
@@ -1379,6 +1403,9 @@ class AttentionLayers(Module):
 
         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
 
+        if not post_branch_fn_needs_condition and norm_add_unit_offset:
+            post_branch_fn = partial(post_branch_fn, unit_offset = 1.)
+
         # setup mlp for conditioning
 
         self.need_condition = norm_need_condition or post_branch_fn_needs_condition
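The post-branch operator (such as `LayerScale`) gets the offset the same way the norms do: `functools.partial` pre-binds `unit_offset = 1.` into the constructor so every later instantiation picks it up. A tiny demonstration of that pattern with a hypothetical stand-in constructor (not a function in the library):

```python
from functools import partial
from torch import nn

# hypothetical stand-in for a post-branch operator constructor that accepts unit_offset
def make_post_branch_fn(dim, unit_offset = 0.):
    print(f'building with unit_offset={unit_offset}')
    return nn.Identity()

post_branch_fn = partial(make_post_branch_fn, unit_offset = 1.)  # offset baked in once
layer = post_branch_fn(dim = 512)                                # prints unit_offset=1.0
```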
{x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=aeFAk19BRxkLOS-pUgtDk3HaMS6_dsuJtz2o_KkUkOU,75805
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
+x_transformers-1.31.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.2.dist-info/METADATA,sha256=dxDZYzaZwwOCD-9I6ver-SSLlSG4diss-ovDo2H_9xE,661
+x_transformers-1.31.2.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.2.dist-info/RECORD,,
{x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/LICENSE
File without changes
{x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/WHEEL
File without changes
{x_transformers-1.31.0.dist-info → x_transformers-1.31.2.dist-info}/top_level.txt
File without changes