x-transformers 1.30.23__py3-none-any.whl → 1.31.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -93,6 +93,9 @@ def l2norm(t, groups = 1):
     t = F.normalize(t, p = 2, dim = -1)
     return rearrange(t, '... g d -> ... (g d)')
 
+def softclamp(t, value):
+    return (t / value).tanh() * value
+
 def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t
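The new softclamp helper smoothly bounds its input to roughly (-value, value): near zero it behaves like the identity, and large magnitudes saturate at ±value. A minimal sketch of the behavior (inputs chosen for illustration):

import torch

def softclamp(t, value):
    # same definition as the helper added above: tanh squashing rescaled to +/- value
    return (t / value).tanh() * value

x = torch.tensor([-100., -5., 0., 5., 100.])
print(softclamp(x, 50.))  # approximately [-48.20, -4.98, 0.00, 4.98, 48.20]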
@@ -560,22 +563,34 @@ class Scale(Module):
         return (scale_fn(out[0]), *out[1:])
 
 class LayerNorm(Module):
-    def __init__(self, dim):
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
         """
         super().__init__()
+        self.unit_offset = unit_offset
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.gamma = nn.Parameter(torch.ones(dim))
-        self.register_buffer("beta", torch.zeros(dim))
+        nn.init.constant_(self.gamma, 1. - unit_offset)
 
-    def forward(self, x):
-        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+        self.register_buffer('beta', torch.zeros(dim), persistent = False)
 
-if version.parse(torch.__version__) >= version.parse('2.1.0'):
-    LayerNorm = partial(nn.LayerNorm, bias = False)
+    def forward(self, x):
+        normed = self.ln(x)
+        gamma = self.gamma + self.unit_offset
+        return normed * gamma
 
 class AdaptiveLayerNorm(Module):
-    def __init__(self, dim, dim_condition = None):
+    def __init__(
+        self,
+        dim,
+        dim_condition = None
+    ):
         super().__init__()
         dim_condition = default(dim_condition, dim)
 
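With the new unit_offset argument set to 1., LayerNorm stores its gain as a residual around one: the raw gamma parameter is initialized to 0. and the forward pass adds the offset back, so the effective gain at initialization is unchanged. A quick sketch checking that equivalence, assuming the class as defined above is importable from x_transformers.x_transformers:

import torch
from x_transformers.x_transformers import LayerNorm  # module path assumed from this diff

ln_plain  = LayerNorm(512)                    # gamma initialized to 1., no offset
ln_offset = LayerNorm(512, unit_offset = 1.)  # gamma initialized to 0., offset of 1. added in forward

x = torch.randn(2, 16, 512)

# identical outputs at initialization
assert torch.allclose(ln_plain(x), ln_offset(x), atol = 1e-6)

# but the offset version keeps its raw parameter at zero
print(ln_plain.gamma.abs().mean().item())   # 1.0
print(ln_offset.gamma.abs().mean().item())  # 0.0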
@@ -590,25 +605,43 @@ class AdaptiveLayerNorm(Module):
         return normed * (gamma + 1.)
 
 class ScaleNorm(Module):
-    def __init__(self, dim):
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
         self.scale = dim ** 0.5
-        self.g = nn.Parameter(torch.ones(1))
+
+        self.g = nn.Parameter(torch.zeros(1))
+        nn.init.constant_(self.g, 1. - unit_offset)
 
     def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.g
+        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
 
 class RMSNorm(Module):
-    def __init__(self, dim):
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
         self.scale = dim ** 0.5
-        self.g = nn.Parameter(torch.ones(dim))
+
+        self.g = nn.Parameter(torch.zeros(dim))
+        nn.init.constant_(self.g, 1. - unit_offset)
 
     def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.g
+        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
 
 class AdaptiveRMSNorm(Module):
-    def __init__(self, dim, dim_condition = None):
+    def __init__(
+        self,
+        dim,
+        dim_condition = None
+    ):
         super().__init__()
         self.scale = dim ** 0.5
         dim_condition = default(dim_condition, dim)
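ScaleNorm and RMSNorm receive the same reparameterization: the gain g is stored as (effective gain - unit_offset), initialized through nn.init.constant_ to 1. - unit_offset, and the offset is added back in forward. A short check under that reading, assuming RMSNorm is importable as below:

import torch
from x_transformers.x_transformers import RMSNorm  # module path assumed from this diff

norm = RMSNorm(256, unit_offset = 1.)
default_norm = RMSNorm(256)  # unit_offset = 0., g initialized to 1.

# the stored parameter starts at zero ...
assert torch.allclose(norm.g, torch.zeros(256))

# ... yet the effective gain at init is still 1, so outputs match the default parameterization
x = torch.randn(4, 256)
assert torch.allclose(norm(x), default_norm(x), atol = 1e-6)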
@@ -622,7 +655,11 @@ class AdaptiveRMSNorm(Module):
         return normed * self.scale * (gamma + 1.)
 
 class SimpleRMSNorm(Module):
-    def __init__(self, dim):
+    def __init__(
+        self,
+        dim,
+        **kwargs
+    ):
         super().__init__()
         self.scale = dim ** 0.5
 
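SimpleRMSNorm has no learnable gain, so the only change here is that it now accepts and ignores extra keyword arguments; presumably this lets it be constructed through the same norm_fn partial even when unit_offset is passed along (see the norm_add_unit_offset wiring further down).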
@@ -696,10 +733,19 @@ class ShiftTokens(Module):
 # post branch operator
 
 class LayerScale(Module):
-    def __init__(self, fn: Module, dim, init_value = 0.):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        init_value = 0.,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
+
         self.fn = fn
-        self.gamma = nn.Parameter(torch.ones(dim) * init_value)
+        self.gamma = nn.Parameter(torch.zeros(dim))
+        nn.init.constant_(self.gamma, init_value - unit_offset)
 
     def forward(self, x, **kwargs):
         out = self.fn(x, **kwargs)
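LayerScale follows the same convention on the parameter side: gamma is now stored as init_value - unit_offset (with the defaults nothing changes; a hypothetical LayerScale(fn, dim, unit_offset = 1.) would store gamma at -1.). Note that the forward context visible in the next hunk still multiplies by self.gamma directly, so whether the offset is added back inside forward is not shown in this diff.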
@@ -711,7 +757,13 @@ class LayerScale(Module):
         return out * self.gamma, *rest
 
 class AdaptiveLayerScale(Module):
-    def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_condition = None,
+        init_bias_value = -2.
+    ):
         super().__init__()
         self.fn = fn
 
@@ -1182,6 +1234,7 @@ class AttentionLayers(Module):
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
         use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
+        norm_add_unit_offset = False,
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
@@ -1215,6 +1268,7 @@ class AttentionLayers(Module):
         scale_residual_constant = 1.,
         shift_tokens = 0,
         sandwich_norm = False,
+        softclamp_output_value: float | None = None,
         resi_dual = False,
         resi_dual_scale = 1.,
         zero_init_branch_output = False,
@@ -1315,6 +1369,10 @@ class AttentionLayers(Module):
 
         norm_fn = partial(norm_class, dim)
 
+        if not norm_need_condition and norm_add_unit_offset:
+            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
+            norm_fn = partial(norm_fn, unit_offset = 1.)
+
         self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
 
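The comment above carries the motivation for all of the unit-offset plumbing: when the gain is stored as (g - 1), plain weight decay pulls the stored parameter toward 0, which corresponds to an effective gain of 1 (identity) rather than a gain of 0. A hedged sketch of that effect using the RMSNorm above and a decay-only optimizer (learning rate and decay values are illustrative):

import torch
from torch.optim import SGD
from x_transformers.x_transformers import RMSNorm  # module path assumed from this diff

plain  = RMSNorm(64)                    # g stored at 1.
offset = RMSNorm(64, unit_offset = 1.)  # g stored at 0., offset of 1. added in forward

opts = [SGD(m.parameters(), lr = 0.1, weight_decay = 0.1) for m in (plain, offset)]

# pure weight decay: zero task gradient for a number of steps
for _ in range(100):
    for m, opt in zip((plain, offset), opts):
        m.g.grad = torch.zeros_like(m.g)
        opt.step()

print(plain.g.mean().item())                          # ~0.37: effective gain has decayed away from 1
print((offset.g + offset.unit_offset).mean().item())  # 1.0: effective gain stays at identity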
@@ -1345,6 +1403,9 @@ class AttentionLayers(Module):
 
         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
 
+        if not post_branch_fn_needs_condition and norm_add_unit_offset:
+            post_branch_fn = partial(post_branch_fn, unit_offset = 1.)
+
         # setup mlp for conditioning
 
         self.need_condition = norm_need_condition or post_branch_fn_needs_condition
@@ -1421,6 +1482,11 @@ class AttentionLayers(Module):
 
         shift_tokens = cast_tuple(shift_tokens, len(layer_types))
 
+        # optional soft clamping just before the final norm
+        # used in gemma 2
+
+        self.softclamp_output_value = softclamp_output_value
+
         # whether it has post norm
 
         self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
@@ -1652,6 +1718,9 @@ class AttentionLayers(Module):
             if return_hiddens:
                 layer_hiddens.append(x)
 
+        if exists(self.softclamp_output_value):
+            x = softclamp(x, self.softclamp_output_value)
+
         final_norm = self.final_norm
 
         if self.need_condition:
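Taken together, 1.31.1 exposes two new AttentionLayers options: norm_add_unit_offset, which switches the norms (and the post-branch LayerScale, when configured) to the offset parameterization above, and softclamp_output_value, which soft-caps the hidden states just before the final norm as done in Gemma 2. A hedged usage sketch for the latter (the Decoder subclass forwards its kwargs to AttentionLayers; the clamp value of 30. is illustrative, not prescribed by this diff):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        softclamp_output_value = 30.  # hidden states pass through softclamp(x, 30.) before the final norm
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # shape (1, 256, 20000)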
x_transformers-1.30.23.dist-info/METADATA → x_transformers-1.31.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.23
+Version: 1.31.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.30.23.dist-info/RECORD → x_transformers-1.31.1.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=tZixUvlsaEj3CpB49KLDOJ2BwYSPjdWotDUjB9Rbf7g,74213
+x_transformers/x_transformers.py,sha256=EehddXNWxU7NEqD8t76ekev__cCJ-F3J1oZjAD9PGa8,75806
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.23.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.23.dist-info/METADATA,sha256=LM8Y0bkOF259zCn_FE2A-Uw5Yjr8YrqCKNYuW4DqtQY,662
-x_transformers-1.30.23.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
-x_transformers-1.30.23.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.23.dist-info/RECORD,,
+x_transformers-1.31.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.1.dist-info/METADATA,sha256=FWASwgSjICfN3sak8Itwh41R-g2sT_wzhBm0RA3vPzA,661
+x_transformers-1.31.1.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.1.dist-info/RECORD,,
x_transformers-1.30.23.dist-info/WHEEL → x_transformers-1.31.1.dist-info/WHEEL

@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (70.1.0)
+Generator: setuptools (70.1.1)
 Root-Is-Purelib: true
 Tag: py3-none-any
 