x-transformers 1.30.23__py3-none-any.whl → 1.31.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +87 -18
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.1.dist-info}/METADATA +1 -1
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.1.dist-info}/RECORD +6 -6
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.1.dist-info}/WHEEL +1 -1
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.1.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -93,6 +93,9 @@ def l2norm(t, groups = 1):
|
|
93
93
|
t = F.normalize(t, p = 2, dim = -1)
|
94
94
|
return rearrange(t, '... g d -> ... (g d)')
|
95
95
|
|
96
|
+
def softclamp(t, value):
|
97
|
+
return (t / value).tanh() * value
|
98
|
+
|
96
99
|
def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
|
97
100
|
if pad == (0, 0):
|
98
101
|
return t
|
@@ -560,22 +563,34 @@ class Scale(Module):
|
|
560
563
|
return (scale_fn(out[0]), *out[1:])
|
561
564
|
|
562
565
|
class LayerNorm(Module):
|
563
|
-
def __init__(
|
566
|
+
def __init__(
|
567
|
+
self,
|
568
|
+
dim,
|
569
|
+
unit_offset = 0.
|
570
|
+
):
|
564
571
|
"""
|
565
572
|
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
566
573
|
"""
|
567
574
|
super().__init__()
|
575
|
+
self.unit_offset = unit_offset
|
576
|
+
|
577
|
+
self.ln = nn.LayerNorm(dim, elementwise_affine = False)
|
568
578
|
self.gamma = nn.Parameter(torch.ones(dim))
|
569
|
-
self.
|
579
|
+
nn.init.constant_(self.gamma, 1. - unit_offset)
|
570
580
|
|
571
|
-
|
572
|
-
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
581
|
+
self.register_buffer('beta', torch.zeros(dim), persistent = False)
|
573
582
|
|
574
|
-
|
575
|
-
|
583
|
+
def forward(self, x):
|
584
|
+
normed = self.ln(x)
|
585
|
+
gamma = self.gamma + self.unit_offset
|
586
|
+
return normed * gamma
|
576
587
|
|
577
588
|
class AdaptiveLayerNorm(Module):
|
578
|
-
def __init__(
|
589
|
+
def __init__(
|
590
|
+
self,
|
591
|
+
dim,
|
592
|
+
dim_condition = None
|
593
|
+
):
|
579
594
|
super().__init__()
|
580
595
|
dim_condition = default(dim_condition, dim)
|
581
596
|
|
@@ -590,25 +605,43 @@ class AdaptiveLayerNorm(Module):
|
|
590
605
|
return normed * (gamma + 1.)
|
591
606
|
|
592
607
|
class ScaleNorm(Module):
|
593
|
-
def __init__(
|
608
|
+
def __init__(
|
609
|
+
self,
|
610
|
+
dim,
|
611
|
+
unit_offset = 0.
|
612
|
+
):
|
594
613
|
super().__init__()
|
614
|
+
self.unit_offset = unit_offset
|
595
615
|
self.scale = dim ** 0.5
|
596
|
-
|
616
|
+
|
617
|
+
self.g = nn.Parameter(torch.zeros(1))
|
618
|
+
nn.init.constant_(self.g, 1. - unit_offset)
|
597
619
|
|
598
620
|
def forward(self, x):
|
599
|
-
return F.normalize(x, dim = -1) * self.scale * self.g
|
621
|
+
return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
|
600
622
|
|
601
623
|
class RMSNorm(Module):
|
602
|
-
def __init__(
|
624
|
+
def __init__(
|
625
|
+
self,
|
626
|
+
dim,
|
627
|
+
unit_offset = 0.
|
628
|
+
):
|
603
629
|
super().__init__()
|
630
|
+
self.unit_offset = unit_offset
|
604
631
|
self.scale = dim ** 0.5
|
605
|
-
|
632
|
+
|
633
|
+
self.g = nn.Parameter(torch.zeros(dim))
|
634
|
+
nn.init.constant_(self.g, 1. - unit_offset)
|
606
635
|
|
607
636
|
def forward(self, x):
|
608
|
-
return F.normalize(x, dim = -1) * self.scale * self.g
|
637
|
+
return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
|
609
638
|
|
610
639
|
class AdaptiveRMSNorm(Module):
|
611
|
-
def __init__(
|
640
|
+
def __init__(
|
641
|
+
self,
|
642
|
+
dim,
|
643
|
+
dim_condition = None
|
644
|
+
):
|
612
645
|
super().__init__()
|
613
646
|
self.scale = dim ** 0.5
|
614
647
|
dim_condition = default(dim_condition, dim)
|
@@ -622,7 +655,11 @@ class AdaptiveRMSNorm(Module):
|
|
622
655
|
return normed * self.scale * (gamma + 1.)
|
623
656
|
|
624
657
|
class SimpleRMSNorm(Module):
|
625
|
-
def __init__(
|
658
|
+
def __init__(
|
659
|
+
self,
|
660
|
+
dim,
|
661
|
+
**kwargs
|
662
|
+
):
|
626
663
|
super().__init__()
|
627
664
|
self.scale = dim ** 0.5
|
628
665
|
|
@@ -696,10 +733,19 @@ class ShiftTokens(Module):
|
|
696
733
|
# post branch operator
|
697
734
|
|
698
735
|
class LayerScale(Module):
|
699
|
-
def __init__(
|
736
|
+
def __init__(
|
737
|
+
self,
|
738
|
+
fn: Module,
|
739
|
+
dim,
|
740
|
+
init_value = 0.,
|
741
|
+
unit_offset = 0.
|
742
|
+
):
|
700
743
|
super().__init__()
|
744
|
+
self.unit_offset = unit_offset
|
745
|
+
|
701
746
|
self.fn = fn
|
702
|
-
self.gamma = nn.Parameter(torch.
|
747
|
+
self.gamma = nn.Parameter(torch.zeros(dim))
|
748
|
+
nn.init.constant_(self.gamma, init_value - unit_offset)
|
703
749
|
|
704
750
|
def forward(self, x, **kwargs):
|
705
751
|
out = self.fn(x, **kwargs)
|
@@ -711,7 +757,13 @@ class LayerScale(Module):
|
|
711
757
|
return out * self.gamma, *rest
|
712
758
|
|
713
759
|
class AdaptiveLayerScale(Module):
|
714
|
-
def __init__(
|
760
|
+
def __init__(
|
761
|
+
self,
|
762
|
+
fn: Module,
|
763
|
+
dim,
|
764
|
+
dim_condition = None,
|
765
|
+
init_bias_value = -2.
|
766
|
+
):
|
715
767
|
super().__init__()
|
716
768
|
self.fn = fn
|
717
769
|
|
@@ -1182,6 +1234,7 @@ class AttentionLayers(Module):
|
|
1182
1234
|
use_adaptive_layernorm = False,
|
1183
1235
|
use_adaptive_rmsnorm = False,
|
1184
1236
|
use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
|
1237
|
+
norm_add_unit_offset = False,
|
1185
1238
|
dim_condition = None,
|
1186
1239
|
adaptive_condition_mlp = False,
|
1187
1240
|
adaptive_condition_mlp_expansion = 4,
|
@@ -1215,6 +1268,7 @@ class AttentionLayers(Module):
|
|
1215
1268
|
scale_residual_constant = 1.,
|
1216
1269
|
shift_tokens = 0,
|
1217
1270
|
sandwich_norm = False,
|
1271
|
+
softclamp_output_value: float | None = None,
|
1218
1272
|
resi_dual = False,
|
1219
1273
|
resi_dual_scale = 1.,
|
1220
1274
|
zero_init_branch_output = False,
|
@@ -1315,6 +1369,10 @@ class AttentionLayers(Module):
|
|
1315
1369
|
|
1316
1370
|
norm_fn = partial(norm_class, dim)
|
1317
1371
|
|
1372
|
+
if not norm_need_condition and norm_add_unit_offset:
|
1373
|
+
# researcher Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
|
1374
|
+
norm_fn = partial(norm_fn, unit_offset = 1.)
|
1375
|
+
|
1318
1376
|
self.norm_need_condition = norm_need_condition
|
1319
1377
|
self.dim_condition = dim_condition
|
1320
1378
|
|
@@ -1345,6 +1403,9 @@ class AttentionLayers(Module):
|
|
1345
1403
|
|
1346
1404
|
self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
|
1347
1405
|
|
1406
|
+
if not post_branch_fn_needs_condition and norm_add_unit_offset:
|
1407
|
+
post_branch_fn = partial(post_branch_fn, unit_offset = 1.)
|
1408
|
+
|
1348
1409
|
# setup mlp for conditioning
|
1349
1410
|
|
1350
1411
|
self.need_condition = norm_need_condition or post_branch_fn_needs_condition
|
@@ -1421,6 +1482,11 @@ class AttentionLayers(Module):
|
|
1421
1482
|
|
1422
1483
|
shift_tokens = cast_tuple(shift_tokens, len(layer_types))
|
1423
1484
|
|
1485
|
+
# optional soft clamping just before the final norm
|
1486
|
+
# used in gemma 2
|
1487
|
+
|
1488
|
+
self.softclamp_output_value = softclamp_output_value
|
1489
|
+
|
1424
1490
|
# whether it has post norm
|
1425
1491
|
|
1426
1492
|
self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
|
@@ -1652,6 +1718,9 @@ class AttentionLayers(Module):
|
|
1652
1718
|
if return_hiddens:
|
1653
1719
|
layer_hiddens.append(x)
|
1654
1720
|
|
1721
|
+
if exists(self.softclamp_output_value):
|
1722
|
+
x = softclamp(x, self.softclamp_output_value)
|
1723
|
+
|
1655
1724
|
final_norm = self.final_norm
|
1656
1725
|
|
1657
1726
|
if self.need_condition:
|
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
|
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
7
|
+
x_transformers/x_transformers.py,sha256=EehddXNWxU7NEqD8t76ekev__cCJ-F3J1oZjAD9PGa8,75806
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.
|
11
|
-
x_transformers-1.
|
12
|
-
x_transformers-1.
|
13
|
-
x_transformers-1.
|
14
|
-
x_transformers-1.
|
10
|
+
x_transformers-1.31.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.31.1.dist-info/METADATA,sha256=FWASwgSjICfN3sak8Itwh41R-g2sT_wzhBm0RA3vPzA,661
|
12
|
+
x_transformers-1.31.1.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
|
13
|
+
x_transformers-1.31.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.31.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|