x-transformers 1.30.23__py3-none-any.whl → 1.31.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -93,6 +93,9 @@ def l2norm(t, groups = 1):
     t = F.normalize(t, p = 2, dim = -1)
     return rearrange(t, '... g d -> ... (g d)')
 
+def softclamp(t, value):
+    return (t / value).tanh() * value
+
 def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t
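The added `softclamp` helper is a tanh-based soft clamp: scaling by `1/value`, applying `tanh`, then rescaling keeps outputs strictly inside (-value, value) while staying close to the identity for small inputs. Restated below as a small runnable snippet for illustration:

```python
import torch

def softclamp(t, value):
    # smoothly bounds t to the open interval (-value, value)
    return (t / value).tanh() * value

x = torch.tensor([-100., -5., 0., 5., 100.])
print(softclamp(x, 30.))  # large magnitudes are squashed toward ±30, small ones barely change
```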
@@ -560,22 +563,29 @@ class Scale(Module):
         return (scale_fn(out[0]), *out[1:])
 
 class LayerNorm(Module):
-    def __init__(self, dim):
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
         """
         super().__init__()
+        self.unit_offset = unit_offset
         self.gamma = nn.Parameter(torch.ones(dim))
-        self.register_buffer("beta", torch.zeros(dim))
+        self.register_buffer('beta', torch.zeros(dim), persistent = False)
 
     def forward(self, x):
-        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
-
-if version.parse(torch.__version__) >= version.parse('2.1.0'):
-    LayerNorm = partial(nn.LayerNorm, bias = False)
+        gamma = self.gamma + self.unit_offset
+        return F.layer_norm(x, x.shape[-1:], gamma, self.beta)
 
 class AdaptiveLayerNorm(Module):
-    def __init__(self, dim, dim_condition = None):
+    def __init__(
+        self,
+        dim,
+        dim_condition = None
+    ):
         super().__init__()
         dim_condition = default(dim_condition, dim)
 
@@ -590,25 +600,39 @@ class AdaptiveLayerNorm(Module):
         return normed * (gamma + 1.)
 
 class ScaleNorm(Module):
-    def __init__(self, dim):
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
         self.scale = dim ** 0.5
         self.g = nn.Parameter(torch.ones(1))
 
     def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.g
+        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
 
 class RMSNorm(Module):
-    def __init__(self, dim):
+    def __init__(
+        self,
+        dim,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
         self.scale = dim ** 0.5
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.g
+        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
 
 class AdaptiveRMSNorm(Module):
-    def __init__(self, dim, dim_condition = None):
+    def __init__(
+        self,
+        dim,
+        dim_condition = None
+    ):
         super().__init__()
         self.scale = dim ** 0.5
         dim_condition = default(dim_condition, dim)
@@ -622,7 +646,11 @@ class AdaptiveRMSNorm(Module):
         return normed * self.scale * (gamma + 1.)
 
 class SimpleRMSNorm(Module):
-    def __init__(self, dim):
+    def __init__(
+        self,
+        dim,
+        **kwargs
+    ):
         super().__init__()
         self.scale = dim ** 0.5
 
@@ -1182,6 +1210,7 @@ class AttentionLayers(Module):
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
         use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
+        norm_add_unit_offset = False,
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
@@ -1215,6 +1244,7 @@ class AttentionLayers(Module):
         scale_residual_constant = 1.,
         shift_tokens = 0,
         sandwich_norm = False,
+        softclamp_output_value: float | None = None,
         resi_dual = False,
         resi_dual_scale = 1.,
         zero_init_branch_output = False,
@@ -1315,6 +1345,10 @@ class AttentionLayers(Module):
 
         norm_fn = partial(norm_class, dim)
 
+        if not norm_need_condition and norm_add_unit_offset:
+            # research Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
+            norm_fn = partial(norm_fn, unit_offset = 1.)
+
         self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
 
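The `norm_add_unit_offset` flag re-parameterizes the norm gain as `unit_offset + g` (here `unit_offset = 1.`), so that, per the comment's reference to Ohad Rubin's note, weight decay pulls the learned tensor toward zero rather than pulling the effective gain toward zero. A rough, hypothetical sketch of that parameterization, not a copy of the library's classes (in particular, the zero initialization below is an assumption made for illustration):

```python
import torch
from torch import nn
import torch.nn.functional as F

class UnitOffsetRMSNorm(nn.Module):
    # hypothetical illustration: effective gain = unit_offset + g,
    # so decaying g toward 0 drives the gain toward the offset, not toward 0
    def __init__(self, dim, unit_offset = 1.):
        super().__init__()
        self.unit_offset = unit_offset
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.zeros(dim))  # assumption: start at zero so the initial gain equals the offset

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
```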
@@ -1421,6 +1455,11 @@ class AttentionLayers(Module):
 
         shift_tokens = cast_tuple(shift_tokens, len(layer_types))
 
+        # optional soft clamping just before the final norm
+        # used in gemma 2
+
+        self.softclamp_output_value = softclamp_output_value
+
         # whether it has post norm
 
         self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
@@ -1652,6 +1691,9 @@ class AttentionLayers(Module):
         if return_hiddens:
             layer_hiddens.append(x)
 
+        if exists(self.softclamp_output_value):
+            x = softclamp(x, self.softclamp_output_value)
+
         final_norm = self.final_norm
 
         if self.need_condition:
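A hedged sketch of how the two options added in this release might be passed through the existing constructor (the clamp value `30.` is only an illustrative choice in the spirit of Gemma 2's soft-capping, not a documented default):

```python
import torch
from x_transformers import Decoder

# usage sketch of the new kwargs introduced in 1.31.0
layers = Decoder(
    dim = 512,
    depth = 6,
    heads = 8,
    norm_add_unit_offset = True,    # adds 1. to the norm gains so they tolerate weight decay
    softclamp_output_value = 30.    # tanh soft-clamp of hidden states just before the final norm
)

x = torch.randn(1, 128, 512)
out = layers(x)  # (1, 128, 512)
```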
x_transformers-1.30.23.dist-info/METADATA → x_transformers-1.31.0.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.23
+Version: 1.31.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.30.23.dist-info/RECORD → x_transformers-1.31.0.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=tZixUvlsaEj3CpB49KLDOJ2BwYSPjdWotDUjB9Rbf7g,74213
+x_transformers/x_transformers.py,sha256=Lbfp0VHMCMkZboYficEe8d-X7u55uOh4lQbvIz_db4c,75233
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.23.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.23.dist-info/METADATA,sha256=LM8Y0bkOF259zCn_FE2A-Uw5Yjr8YrqCKNYuW4DqtQY,662
-x_transformers-1.30.23.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
-x_transformers-1.30.23.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.23.dist-info/RECORD,,
+x_transformers-1.31.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.0.dist-info/METADATA,sha256=0-TZQ6Y-FpgOjF3z3GIO2nPRX5__9lhe5N6Qct5U79Y,661
+x_transformers-1.31.0.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.0.dist-info/RECORD,,
x_transformers-1.30.23.dist-info/WHEEL → x_transformers-1.31.0.dist-info/WHEEL

@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (70.1.0)
+Generator: setuptools (70.1.1)
 Root-Is-Purelib: true
 Tag: py3-none-any
 