x-transformers 1.30.23.tar.gz → 1.31.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-1.30.23/x_transformers.egg-info → x_transformers-1.31.1}/PKG-INFO +1 -1
- {x_transformers-1.30.23 → x_transformers-1.31.1}/README.md +18 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/setup.py +1 -1
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/x_transformers.py +87 -18
- {x_transformers-1.30.23 → x_transformers-1.31.1/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.30.23 → x_transformers-1.31.1}/LICENSE +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/setup.cfg +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/tests/test_x_transformers.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/__init__.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/attend.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/continuous.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/dpo.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/xval.py +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.30.23 → x_transformers-1.31.1}/README.md

@@ -2169,4 +2169,22 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{Guttenberg2023,
+    author = {Ohad Rubin},
+    url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
+}
+```
+
+```bibtex
+@article{Mesnard2024GemmaOM,
+    title = {Gemma: Open Models Based on Gemini Research and Technology},
+    author = {Gemma Team Thomas Mesnard and Cassidy Hardin and Robert Dadashi and Surya Bhupatiraju and Shreya Pathak and L. Sifre and Morgane Riviere and Mihir Kale and J Christopher Love and Pouya Dehghani Tafti and L'eonard Hussenot and Aakanksha Chowdhery and Adam Roberts and Aditya Barua and Alex Botev and Alex Castro-Ros and Ambrose Slone and Am'elie H'eliou and Andrea Tacchetti and Anna Bulanova and Antonia Paterson and Beth Tsai and Bobak Shahriari and Charline Le Lan and Christopher A. Choquette-Choo and Cl'ement Crepy and Daniel Cer and Daphne Ippolito and David Reid and Elena Buchatskaya and Eric Ni and Eric Noland and Geng Yan and George Tucker and George-Christian Muraru and Grigory Rozhdestvenskiy and Henryk Michalewski and Ian Tenney and Ivan Grishchenko and Jacob Austin and James Keeling and Jane Labanowski and Jean-Baptiste Lespiau and Jeff Stanway and Jenny Brennan and Jeremy Chen and Johan Ferret and Justin Chiu and Justin Mao-Jones and Katherine Lee and Kathy Yu and Katie Millican and Lars Lowe Sjoesund and Lisa Lee and Lucas Dixon and Machel Reid and Maciej Mikula and Mateo Wirth and Michael Sharman and Nikolai Chinaev and Nithum Thain and Olivier Bachem and Oscar Chang and Oscar Wahltinez and Paige Bailey and Paul Michel and Petko Yotov and Pier Giuseppe Sessa and Rahma Chaabouni and Ramona Comanescu and Reena Jana and Rohan Anil and Ross McIlroy and Ruibo Liu and Ryan Mullins and Samuel L Smith and Sebastian Borgeaud and Sertan Girgin and Sholto Douglas and Shree Pandya and Siamak Shakeri and Soham De and Ted Klimenko and Tom Hennigan and Vladimir Feinberg and Wojciech Stokowiec and Yu-hui Chen and Zafarali Ahmed and Zhitao Gong and Tris Brian Warkentin and Ludovic Peran and Minh Giang and Cl'ement Farabet and Oriol Vinyals and Jeffrey Dean and Koray Kavukcuoglu and Demis Hassabis and Zoubin Ghahramani and Douglas Eck and Joelle Barral and Fernando Pereira and Eli Collins and Armand Joulin and Noah Fiedel and Evan Senter and Alek Andreev and Kathleen Kenealy},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2403.08295},
+    url = {https://api.semanticscholar.org/CorpusID:268379206}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

{x_transformers-1.30.23 → x_transformers-1.31.1}/x_transformers/x_transformers.py

@@ -93,6 +93,9 @@ def l2norm(t, groups = 1):
     t = F.normalize(t, p = 2, dim = -1)
     return rearrange(t, '... g d -> ... (g d)')
 
+def softclamp(t, value):
+    return (t / value).tanh() * value
+
 def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t

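The new `softclamp` helper is a tanh-based bound: outputs stay in (-value, value), and inputs that are small relative to `value` pass through nearly unchanged. A quick standalone illustration in plain PyTorch (not part of the package):

```python
import torch

def softclamp(t, value):
    # same as the helper added above: scale down, squash with tanh, scale back up
    return (t / value).tanh() * value

x = torch.tensor([-100., -5., 0., 5., 100.])
print(softclamp(x, 30.))
# ≈ tensor([-29.92, -4.95, 0.00, 4.95, 29.92]):
# large magnitudes saturate near ±30, small values are nearly untouched
```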
@@ -560,22 +563,34 @@ class Scale(Module):
         return (scale_fn(out[0]), *out[1:])
 
 class LayerNorm(Module):
-    def __init__(
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
         """
         super().__init__()
+        self.unit_offset = unit_offset
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.gamma = nn.Parameter(torch.ones(dim))
-        self.
+        nn.init.constant_(self.gamma, 1. - unit_offset)
 
-
-        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+        self.register_buffer('beta', torch.zeros(dim), persistent = False)
 
-
-
+    def forward(self, x):
+        normed = self.ln(x)
+        gamma = self.gamma + self.unit_offset
+        return normed * gamma
 
 class AdaptiveLayerNorm(Module):
-    def __init__(
+    def __init__(
+        self,
+        dim,
+        dim_condition = None
+    ):
         super().__init__()
         dim_condition = default(dim_condition, dim)
 

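The pattern introduced above (and repeated for ScaleNorm/RMSNorm in the next hunk) stores the norm gain shifted by `unit_offset` and adds the offset back in `forward`, following the reparameterization from the Ohad Rubin blog post cited in the README: with `unit_offset = 1.`, the stored parameter sits at zero for an effective gain of one, so weight decay pulls it toward gain 1 rather than gain 0. A minimal standalone sketch of the idea, not the package class itself:

```python
import torch
from torch import nn
import torch.nn.functional as F

class OffsetRMSNorm(nn.Module):
    """minimal sketch: gain stored as (g - unit_offset), so L2 weight decay pulls it toward an effective gain of unit_offset"""
    def __init__(self, dim, unit_offset = 1.):
        super().__init__()
        self.unit_offset = unit_offset
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.zeros(dim))  # zero here means an effective gain of `unit_offset`

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)

norm = OffsetRMSNorm(512)
x = torch.randn(2, 16, 512)
assert norm(x).shape == x.shape  # behaves like a standard RMSNorm initialized at gain 1
```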
@@ -590,25 +605,43 @@ class AdaptiveLayerNorm(Module):
         return normed * (gamma + 1.)
 
 class ScaleNorm(Module):
-    def __init__(
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
         self.scale = dim ** 0.5
-
+
+        self.g = nn.Parameter(torch.zeros(1))
+        nn.init.constant_(self.g, 1. - unit_offset)
 
     def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.g
+        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
 
 class RMSNorm(Module):
-    def __init__(
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
         self.scale = dim ** 0.5
-
+
+        self.g = nn.Parameter(torch.zeros(dim))
+        nn.init.constant_(self.g, 1. - unit_offset)
 
     def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.g
+        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
 
 class AdaptiveRMSNorm(Module):
-    def __init__(
+    def __init__(
+        self,
+        dim,
+        dim_condition = None
+    ):
         super().__init__()
         self.scale = dim ** 0.5
         dim_condition = default(dim_condition, dim)

@@ -622,7 +655,11 @@ class AdaptiveRMSNorm(Module):
         return normed * self.scale * (gamma + 1.)
 
 class SimpleRMSNorm(Module):
-    def __init__(
+    def __init__(
+        self,
+        dim,
+        **kwargs
+    ):
         super().__init__()
         self.scale = dim ** 0.5
 

@@ -696,10 +733,19 @@ class ShiftTokens(Module):
 # post branch operator
 
 class LayerScale(Module):
-    def __init__(
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        init_value = 0.,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
+
         self.fn = fn
-        self.gamma = nn.Parameter(torch.
+        self.gamma = nn.Parameter(torch.zeros(dim))
+        nn.init.constant_(self.gamma, init_value - unit_offset)
 
     def forward(self, x, **kwargs):
         out = self.fn(x, **kwargs)

@@ -711,7 +757,13 @@ class LayerScale(Module):
         return out * self.gamma, *rest
 
 class AdaptiveLayerScale(Module):
-    def __init__(
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_condition = None,
+        init_bias_value = -2.
+    ):
         super().__init__()
         self.fn = fn
 

@@ -1182,6 +1234,7 @@ class AttentionLayers(Module):
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
         use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
+        norm_add_unit_offset = False,
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,

@@ -1215,6 +1268,7 @@ class AttentionLayers(Module):
         scale_residual_constant = 1.,
         shift_tokens = 0,
         sandwich_norm = False,
+        softclamp_output_value: float | None = None,
         resi_dual = False,
         resi_dual_scale = 1.,
         zero_init_branch_output = False,

@@ -1315,6 +1369,10 @@ class AttentionLayers(Module):
 
         norm_fn = partial(norm_class, dim)
 
+        if not norm_need_condition and norm_add_unit_offset:
+            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
+            norm_fn = partial(norm_fn, unit_offset = 1.)
+
         self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
 

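Assuming `Decoder` forwards extra keyword arguments through to `AttentionLayers` the same way the other norm-related flags are passed (treat the exact call below as an illustration rather than documented usage), enabling the unit offset would look roughly like this:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        norm_add_unit_offset = True  # norms receive unit_offset = 1., so their gains tolerate weight decay
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```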
@@ -1345,6 +1403,9 @@ class AttentionLayers(Module):
 
         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
 
+        if not post_branch_fn_needs_condition and norm_add_unit_offset:
+            post_branch_fn = partial(post_branch_fn, unit_offset = 1.)
+
         # setup mlp for conditioning
 
         self.need_condition = norm_need_condition or post_branch_fn_needs_condition

@@ -1421,6 +1482,11 @@ class AttentionLayers(Module):
 
         shift_tokens = cast_tuple(shift_tokens, len(layer_types))
 
+        # optional soft clamping just before the final norm
+        # used in gemma 2
+
+        self.softclamp_output_value = softclamp_output_value
+
         # whether it has post norm
 
         self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()

@@ -1652,6 +1718,9 @@ class AttentionLayers(Module):
             if return_hiddens:
                 layer_hiddens.append(x)
 
+        if exists(self.softclamp_output_value):
+            x = softclamp(x, self.softclamp_output_value)
+
         final_norm = self.final_norm
 
         if self.need_condition:

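Taken together, the `softclamp_output_value` hunks store the clamp value at construction and tanh-clamp the hidden states just before the final norm, as in Gemma 2. A hedged usage sketch; the value 30. is purely illustrative, not a documented default, and the kwarg is assumed to pass through `Decoder` to `AttentionLayers`:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        softclamp_output_value = 30.  # hidden states pass through softclamp(x, 30.) before the final norm
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)
```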