x-transformers 1.30.23__py3-none-any.whl → 1.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +55 -13
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.0.dist-info}/METADATA +1 -1
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.0.dist-info}/RECORD +6 -6
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.0.dist-info}/WHEEL +1 -1
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.23.dist-info → x_transformers-1.31.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -93,6 +93,9 @@ def l2norm(t, groups = 1):
|
|
93
93
|
t = F.normalize(t, p = 2, dim = -1)
|
94
94
|
return rearrange(t, '... g d -> ... (g d)')
|
95
95
|
|
96
|
+
def softclamp(t, value):
|
97
|
+
return (t / value).tanh() * value
|
98
|
+
|
96
99
|
def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
|
97
100
|
if pad == (0, 0):
|
98
101
|
return t
|
@@ -560,22 +563,29 @@ class Scale(Module):
|
|
560
563
|
return (scale_fn(out[0]), *out[1:])
|
561
564
|
|
562
565
|
class LayerNorm(Module):
|
563
|
-
def __init__(
|
566
|
+
def __init__(
|
567
|
+
self,
|
568
|
+
dim,
|
569
|
+
unit_offset = 0.
|
570
|
+
):
|
564
571
|
"""
|
565
572
|
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
566
573
|
"""
|
567
574
|
super().__init__()
|
575
|
+
self.unit_offset = unit_offset
|
568
576
|
self.gamma = nn.Parameter(torch.ones(dim))
|
569
|
-
self.register_buffer(
|
577
|
+
self.register_buffer('beta', torch.zeros(dim), persistent = False)
|
570
578
|
|
571
579
|
def forward(self, x):
|
572
|
-
|
573
|
-
|
574
|
-
if version.parse(torch.__version__) >= version.parse('2.1.0'):
|
575
|
-
LayerNorm = partial(nn.LayerNorm, bias = False)
|
580
|
+
gamma = self.gamma + self.unit_offset
|
581
|
+
return F.layer_norm(x, x.shape[-1:], gamma, self.beta)
|
576
582
|
|
577
583
|
class AdaptiveLayerNorm(Module):
|
578
|
-
def __init__(
|
584
|
+
def __init__(
|
585
|
+
self,
|
586
|
+
dim,
|
587
|
+
dim_condition = None
|
588
|
+
):
|
579
589
|
super().__init__()
|
580
590
|
dim_condition = default(dim_condition, dim)
|
581
591
|
|
@@ -590,25 +600,39 @@ class AdaptiveLayerNorm(Module):
|
|
590
600
|
return normed * (gamma + 1.)
|
591
601
|
|
592
602
|
class ScaleNorm(Module):
|
593
|
-
def __init__(
|
603
|
+
def __init__(
|
604
|
+
self,
|
605
|
+
dim,
|
606
|
+
unit_offset = 0.
|
607
|
+
):
|
594
608
|
super().__init__()
|
609
|
+
self.unit_offset = unit_offset
|
595
610
|
self.scale = dim ** 0.5
|
596
611
|
self.g = nn.Parameter(torch.ones(1))
|
597
612
|
|
598
613
|
def forward(self, x):
|
599
|
-
return F.normalize(x, dim = -1) * self.scale * self.g
|
614
|
+
return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
|
600
615
|
|
601
616
|
class RMSNorm(Module):
|
602
|
-
def __init__(
|
617
|
+
def __init__(
|
618
|
+
self,
|
619
|
+
dim,
|
620
|
+
unit_offset = 0.
|
621
|
+
):
|
603
622
|
super().__init__()
|
623
|
+
self.unit_offset = unit_offset
|
604
624
|
self.scale = dim ** 0.5
|
605
625
|
self.g = nn.Parameter(torch.ones(dim))
|
606
626
|
|
607
627
|
def forward(self, x):
|
608
|
-
return F.normalize(x, dim = -1) * self.scale * self.g
|
628
|
+
return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
|
609
629
|
|
610
630
|
class AdaptiveRMSNorm(Module):
|
611
|
-
def __init__(
|
631
|
+
def __init__(
|
632
|
+
self,
|
633
|
+
dim,
|
634
|
+
dim_condition = None
|
635
|
+
):
|
612
636
|
super().__init__()
|
613
637
|
self.scale = dim ** 0.5
|
614
638
|
dim_condition = default(dim_condition, dim)
|
@@ -622,7 +646,11 @@ class AdaptiveRMSNorm(Module):
|
|
622
646
|
return normed * self.scale * (gamma + 1.)
|
623
647
|
|
624
648
|
class SimpleRMSNorm(Module):
|
625
|
-
def __init__(
|
649
|
+
def __init__(
|
650
|
+
self,
|
651
|
+
dim,
|
652
|
+
**kwargs
|
653
|
+
):
|
626
654
|
super().__init__()
|
627
655
|
self.scale = dim ** 0.5
|
628
656
|
|
@@ -1182,6 +1210,7 @@ class AttentionLayers(Module):
|
|
1182
1210
|
use_adaptive_layernorm = False,
|
1183
1211
|
use_adaptive_rmsnorm = False,
|
1184
1212
|
use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
|
1213
|
+
norm_add_unit_offset = False,
|
1185
1214
|
dim_condition = None,
|
1186
1215
|
adaptive_condition_mlp = False,
|
1187
1216
|
adaptive_condition_mlp_expansion = 4,
|
@@ -1215,6 +1244,7 @@ class AttentionLayers(Module):
|
|
1215
1244
|
scale_residual_constant = 1.,
|
1216
1245
|
shift_tokens = 0,
|
1217
1246
|
sandwich_norm = False,
|
1247
|
+
softclamp_output_value: float | None = None,
|
1218
1248
|
resi_dual = False,
|
1219
1249
|
resi_dual_scale = 1.,
|
1220
1250
|
zero_init_branch_output = False,
|
@@ -1315,6 +1345,10 @@ class AttentionLayers(Module):
|
|
1315
1345
|
|
1316
1346
|
norm_fn = partial(norm_class, dim)
|
1317
1347
|
|
1348
|
+
if not norm_need_condition and norm_add_unit_offset:
|
1349
|
+
# research Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
|
1350
|
+
norm_fn = partial(norm_fn, unit_offset = 1.)
|
1351
|
+
|
1318
1352
|
self.norm_need_condition = norm_need_condition
|
1319
1353
|
self.dim_condition = dim_condition
|
1320
1354
|
|
@@ -1421,6 +1455,11 @@ class AttentionLayers(Module):
|
|
1421
1455
|
|
1422
1456
|
shift_tokens = cast_tuple(shift_tokens, len(layer_types))
|
1423
1457
|
|
1458
|
+
# optional soft clamping just before the final norm
|
1459
|
+
# used in gemma 2
|
1460
|
+
|
1461
|
+
self.softclamp_output_value = softclamp_output_value
|
1462
|
+
|
1424
1463
|
# whether it has post norm
|
1425
1464
|
|
1426
1465
|
self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
|
@@ -1652,6 +1691,9 @@ class AttentionLayers(Module):
|
|
1652
1691
|
if return_hiddens:
|
1653
1692
|
layer_hiddens.append(x)
|
1654
1693
|
|
1694
|
+
if exists(self.softclamp_output_value):
|
1695
|
+
x = softclamp(x, self.softclamp_output_value)
|
1696
|
+
|
1655
1697
|
final_norm = self.final_norm
|
1656
1698
|
|
1657
1699
|
if self.need_condition:
|
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
|
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
7
|
+
x_transformers/x_transformers.py,sha256=Lbfp0VHMCMkZboYficEe8d-X7u55uOh4lQbvIz_db4c,75233
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.
|
11
|
-
x_transformers-1.
|
12
|
-
x_transformers-1.
|
13
|
-
x_transformers-1.
|
14
|
-
x_transformers-1.
|
10
|
+
x_transformers-1.31.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.31.0.dist-info/METADATA,sha256=0-TZQ6Y-FpgOjF3z3GIO2nPRX5__9lhe5N6Qct5U79Y,661
|
12
|
+
x_transformers-1.31.0.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
|
13
|
+
x_transformers-1.31.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.31.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|