x-transformers 1.31.2 (py3-none-any.whl) → 1.31.4 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -566,7 +566,7 @@ class LayerNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
     ):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -576,9 +576,7 @@ class LayerNorm(Module):

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.gamma = nn.Parameter(torch.ones(dim))
-        nn.init.constant_(self.gamma, 1. - unit_offset)
-
-        self.register_buffer('beta', torch.zeros(dim), persistent = False)
+        nn.init.constant_(self.gamma, 1. - float(unit_offset))

     def forward(self, x):
         normed = self.ln(x)
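
Note on the signature change carried through all of the hunks below: `unit_offset` moves from a float default (`0.`) to a boolean (`False`), and the added `float(...)` cast keeps the initialization arithmetic identical to the old API. A quick check of that equivalence (illustrative only):

for unit_offset in (False, True):
    # same init value the previous float API produced with 0. / 1.
    print(unit_offset, 1. - float(unit_offset))

# False 1.0
# True 0.0
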
@@ -596,7 +594,6 @@ class AdaptiveLayerNorm(Module):

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
-
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -608,14 +605,14 @@ class ScaleNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5

         self.g = nn.Parameter(torch.zeros(1))
-        nn.init.constant_(self.g, 1. - unit_offset)
+        nn.init.constant_(self.g, 1. - float(unit_offset))

     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
@@ -624,14 +621,14 @@ class RMSNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5

         self.g = nn.Parameter(torch.zeros(dim))
-        nn.init.constant_(self.g, 1. - unit_offset)
+        nn.init.constant_(self.g, 1. - float(unit_offset))

     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
@@ -738,14 +735,14 @@ class LayerScale(Module):
         fn: Module,
         dim,
         init_value = 0.,
-        unit_offset = 0.
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset

         self.fn = fn
         self.gamma = nn.Parameter(torch.zeros(dim))
-        nn.init.constant_(self.gamma, init_value - unit_offset)
+        nn.init.constant_(self.gamma, init_value - float(unit_offset))

     def forward(self, x, **kwargs):
         out = self.fn(x, **kwargs)
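
Taken together, the ScaleNorm / RMSNorm / LayerScale hunks converge on one parameterization: when `unit_offset` is set, the learnable gain is stored shifted down by one and the offset is added back in the forward pass. The module below is a minimal sketch assembled from the fragments above to show the pattern end to end; it is an illustration of the idea, not the exact x-transformers class.

import torch
import torch.nn.functional as F
from torch import nn

class UnitOffsetRMSNorm(nn.Module):
    # sketch of the pattern in the hunks above (assumed simplification, not the library class)
    def __init__(self, dim, unit_offset = False):
        super().__init__()
        self.unit_offset = unit_offset
        self.scale = dim ** 0.5

        # stored parameter starts at 1. when unit_offset = False, at 0. when unit_offset = True
        self.g = nn.Parameter(torch.zeros(dim))
        nn.init.constant_(self.g, 1. - float(unit_offset))

    def forward(self, x):
        # the offset is added back here, so the effective gain is the same in both modes
        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)

With `unit_offset = True` the stored `g` sits at zero at initialization, which is what makes plain weight decay harmless; the next hunk's comment gives the rationale.
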
@@ -1370,8 +1367,8 @@ class AttentionLayers(Module):
         norm_fn = partial(norm_class, dim)

         if not norm_need_condition and norm_add_unit_offset:
-            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
-            norm_fn = partial(norm_fn, unit_offset = 1.)
+            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas, they can be subjected to weight decay safely
+            norm_fn = partial(norm_fn, unit_offset = True)

         self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
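
The comment credits a blog post by Ohad Rubin for the reasoning: with the unit offset, the stored gamma idles at zero, so weight decay regresses the effective gain toward 1 (identity) instead of toward 0. A toy demonstration of that difference under decoupled weight decay; the tensor sizes and optimizer settings here are made up for illustration:

import torch
from torch import nn

# two parameterizations of a gain whose "neutral" value is 1.0
plain_gain  = nn.Parameter(torch.ones(8))    # stores the gain directly
offset_gain = nn.Parameter(torch.zeros(8))   # stores gain - 1, offset added back at use time

opt = torch.optim.AdamW([plain_gain, offset_gain], lr = 1e-1, weight_decay = 1e-2)

for _ in range(100):
    opt.zero_grad()
    # zero task gradient: only the decoupled weight decay term moves the parameters
    (plain_gain.sum() * 0. + offset_gain.sum() * 0.).backward()
    opt.step()

print(plain_gain.mean().item())           # drifts below 1.0 (decayed toward 0)
print((offset_gain + 1.).mean().item())   # effective gain stays at 1.0
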
@@ -1403,8 +1400,8 @@ class AttentionLayers(Module):

         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition

-        if not post_branch_fn_needs_condition and norm_add_unit_offset:
-            post_branch_fn = partial(post_branch_fn, unit_offset = 1.)
+        if exists(post_branch_fn) and not post_branch_fn_needs_condition and norm_add_unit_offset:
+            post_branch_fn = partial(post_branch_fn, unit_offset = True)

         # setup mlp for conditioning

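The second AttentionLayers hunk also adds an `exists(...)` guard: when no post-branch function is configured, `post_branch_fn` is `None`, and `functools.partial(None, unit_offset = True)` raises `TypeError: the first argument must be callable`. A stripped-down reproduction of the guarded pattern, using the usual `exists` helper from the library:

from functools import partial

def exists(val):
    # the standard x-transformers helper
    return val is not None

post_branch_fn = None   # e.g. no LayerScale / post-branch wrapper configured

# without the guard, partial(None, unit_offset = True) fails at construction time
if exists(post_branch_fn):
    post_branch_fn = partial(post_branch_fn, unit_offset = True)
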
x_transformers-1.31.2.dist-info/METADATA → x_transformers-1.31.4.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.31.2
+Version: 1.31.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.31.2.dist-info/RECORD → x_transformers-1.31.4.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=aeFAk19BRxkLOS-pUgtDk3HaMS6_dsuJtz2o_KkUkOU,75805
+x_transformers/x_transformers.py,sha256=82gdvdwWWDWZVXTQRE6B4W-RQspg4SXhD8B6BV5pOo0,75789
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.31.2.dist-info/METADATA,sha256=dxDZYzaZwwOCD-9I6ver-SSLlSG4diss-ovDo2H_9xE,661
-x_transformers-1.31.2.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
-x_transformers-1.31.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.31.2.dist-info/RECORD,,
+x_transformers-1.31.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.4.dist-info/METADATA,sha256=USK7uPjCATSS4URgfekTTmPhOtvHagHC02b-R-gxwME,661
+x_transformers-1.31.4.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.4.dist-info/RECORD,,