x-transformers 1.31.0__py3-none-any.whl → 1.31.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -573,12 +573,17 @@ class LayerNorm(Module):
         """
         super().__init__()
         self.unit_offset = unit_offset
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.gamma = nn.Parameter(torch.ones(dim))
+        nn.init.constant_(self.gamma, 1. - unit_offset)
+
         self.register_buffer('beta', torch.zeros(dim), persistent = False)
 
     def forward(self, x):
+        normed = self.ln(x)
         gamma = self.gamma + self.unit_offset
-        return F.layer_norm(x, x.shape[-1:], gamma, self.beta)
+        return normed * gamma
 
 class AdaptiveLayerNorm(Module):
     def __init__(
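Read as a whole, the hunk replaces the single F.layer_norm call with an affine-free nn.LayerNorm plus an explicitly stored gamma, initialized at 1. - unit_offset so the effective scale still starts at 1. A minimal reconstruction of the class after this change is sketched below; the __init__ signature, the default for unit_offset, and the imports are assumptions filled in from context rather than copied from the file.

    import torch
    from torch import nn
    from torch.nn import Module

    class LayerNorm(Module):
        def __init__(self, dim, unit_offset = 0.):   # signature assumed; only dim and unit_offset appear in the hunk
            super().__init__()
            self.unit_offset = unit_offset

            # the normalization itself carries no learned affine; scaling is done by gamma below
            self.ln = nn.LayerNorm(dim, elementwise_affine = False)
            self.gamma = nn.Parameter(torch.ones(dim))
            nn.init.constant_(self.gamma, 1. - unit_offset)

            # beta stays registered as a non-persistent zero buffer, as before
            self.register_buffer('beta', torch.zeros(dim), persistent = False)

        def forward(self, x):
            normed = self.ln(x)
            gamma = self.gamma + self.unit_offset
            return normed * gamma

With beta a zero buffer, normed * gamma computes the same thing as the removed F.layer_norm call; the substantive change is that gamma is now initialized at 1. - unit_offset, so the effective scale starts at 1. even when the unit offset is enabled.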
@@ -608,7 +613,9 @@ class ScaleNorm(Module):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5
-        self.g = nn.Parameter(torch.ones(1))
+
+        self.g = nn.Parameter(torch.zeros(1))
+        nn.init.constant_(self.g, 1. - unit_offset)
 
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
@@ -622,7 +629,9 @@ class RMSNorm(Module):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5
-        self.g = nn.Parameter(torch.ones(dim))
+
+        self.g = nn.Parameter(torch.zeros(dim))
+        nn.init.constant_(self.g, 1. - unit_offset)
 
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
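ScaleNorm and RMSNorm get the same re-initialization. Their forward (an unchanged context line above) already multiplied by self.g + self.unit_offset; what changes is only the starting value of g, now 1. - unit_offset instead of 1., so the effective initial gain is 1. whether or not the unit offset is in use. A small check of that arithmetic, with an arbitrary dim:

    import torch
    from torch import nn

    dim = 512                                  # arbitrary example value

    for unit_offset in (0., 1.):
        g = nn.Parameter(torch.zeros(dim))
        nn.init.constant_(g, 1. - unit_offset)
        effective = g + unit_offset            # what the unchanged forward multiplies by
        assert torch.allclose(effective, torch.ones(dim))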
@@ -724,10 +733,19 @@ class ShiftTokens(Module):
 # post branch operator
 
 class LayerScale(Module):
-    def __init__(self, fn: Module, dim, init_value = 0.):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        init_value = 0.,
+        unit_offset = 0.
+    ):
         super().__init__()
+        self.unit_offset = unit_offset
+
         self.fn = fn
-        self.gamma = nn.Parameter(torch.ones(dim) * init_value)
+        self.gamma = nn.Parameter(torch.zeros(dim))
+        nn.init.constant_(self.gamma, init_value - unit_offset)
 
     def forward(self, x, **kwargs):
         out = self.fn(x, **kwargs)
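LayerScale grows the same unit_offset keyword, and its gamma is now stored shifted down by the offset (init_value - unit_offset) rather than directly at init_value; the stored self.unit_offset suggests the offset is added back where gamma is consumed, which falls outside these hunks. Purely as arithmetic, with hypothetical example values:

    import torch
    from torch import nn

    dim, init_value, unit_offset = 512, 0., 1.    # hypothetical example values

    # 1.31.0: gamma stored directly at init_value
    gamma_old = nn.Parameter(torch.ones(dim) * init_value)

    # 1.31.1: gamma stored at init_value - unit_offset
    gamma_new = nn.Parameter(torch.zeros(dim))
    nn.init.constant_(gamma_new, init_value - unit_offset)

    # adding the offset back recovers the same starting point
    assert torch.allclose(gamma_old, gamma_new + unit_offset)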
@@ -739,7 +757,13 @@ class LayerScale(Module):
         return out * self.gamma, *rest
 
 class AdaptiveLayerScale(Module):
-    def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_condition = None,
+        init_bias_value = -2.
+    ):
         super().__init__()
         self.fn = fn
 
@@ -1346,7 +1370,7 @@ class AttentionLayers(Module):
         norm_fn = partial(norm_class, dim)
 
         if not norm_need_condition and norm_add_unit_offset:
-            # research Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
+            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
             norm_fn = partial(norm_fn, unit_offset = 1.)
 
         self.norm_need_condition = norm_need_condition
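This comment is the motivation for the whole set of changes: with the offset parameterization the stored gammas sit at 0. instead of 1., so L2 weight decay pulls the effective scale toward the identity value of 1. rather than toward 0. A small self-contained illustration of that effect, using plain SGD weight decay and a dummy zero gradient (the sizes, learning rate, and decay strength are arbitrary):

    import torch
    from torch import nn

    # the same effective scale of 1., stored two different ways
    g_plain  = nn.Parameter(torch.ones(4))    # effective = g_plain           (no offset)
    g_offset = nn.Parameter(torch.zeros(4))   # effective = g_offset + 1.     (unit_offset = 1.)

    opt = torch.optim.SGD([g_plain, g_offset], lr = 1e-1, weight_decay = 1e-1)

    for _ in range(200):
        opt.zero_grad()
        (g_plain.sum() * 0. + g_offset.sum() * 0.).backward()   # zero task gradient, decay only
        opt.step()

    print(g_plain.mean().item())          # effective scale has decayed far below 1.
    print((g_offset + 1.).mean().item())  # effective scale is still 1.

Under decay the plain parameter shrinks geometrically toward 0. each step, while the offset-parameterized one is already at the decay target, leaving the effective scale at 1.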
@@ -1379,6 +1403,9 @@ class AttentionLayers(Module):
 
         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
 
+        if not post_branch_fn_needs_condition and norm_add_unit_offset:
+            post_branch_fn = partial(post_branch_fn, unit_offset = 1.)
+
         # setup mlp for conditioning
 
         self.need_condition = norm_need_condition or post_branch_fn_needs_condition
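With the norms handled, AttentionLayers now also bakes unit_offset = 1. into the post-branch operator, mirroring the existing treatment of norm_fn: stacking functools.partial calls so the keyword is fixed before the module is constructed. A sketch of that pattern, using RMSNorm as the stand-in constructor; the import path and exact constructor signature are assumptions, not taken from these hunks:

    from functools import partial

    # import path and constructor signature assumed from the hunks above
    from x_transformers.x_transformers import RMSNorm

    dim = 512                                      # arbitrary example value

    norm_fn = partial(RMSNorm, dim)                # norm_fn() -> RMSNorm(dim)
    norm_fn = partial(norm_fn, unit_offset = 1.)   # norm_fn() -> RMSNorm(dim, unit_offset = 1.)

    norm = norm_fn()                               # g starts at 0.; effective gain g + unit_offset starts at 1.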
x_transformers-1.31.0.dist-info/METADATA → x_transformers-1.31.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.31.0
+Version: 1.31.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.31.0.dist-info/RECORD → x_transformers-1.31.1.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=Lbfp0VHMCMkZboYficEe8d-X7u55uOh4lQbvIz_db4c,75233
+x_transformers/x_transformers.py,sha256=EehddXNWxU7NEqD8t76ekev__cCJ-F3J1oZjAD9PGa8,75806
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.31.0.dist-info/METADATA,sha256=0-TZQ6Y-FpgOjF3z3GIO2nPRX5__9lhe5N6Qct5U79Y,661
-x_transformers-1.31.0.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
-x_transformers-1.31.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.31.0.dist-info/RECORD,,
+x_transformers-1.31.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.1.dist-info/METADATA,sha256=FWASwgSjICfN3sak8Itwh41R-g2sT_wzhBm0RA3vPzA,661
+x_transformers-1.31.1.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.1.dist-info/RECORD,,