x-transformers 1.31.3__py3-none-any.whl → 1.31.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -69,7 +69,7 @@ def onnx_create_causal_mask(i, j, device):

 # main class

-class Attend(nn.Module):
+class Attend(Module):
     def __init__(
         self,
         *,
@@ -81,7 +81,8 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         add_zero_kv = False,
         cope = None,
         onnxable = False,
@@ -123,10 +124,11 @@ class Attend(nn.Module):

         # soft clamp attention logit value

-        if exists(logit_softclamp_value):
+        if softclamp_logits:
             assert not flash, 'flash attention not compatible with logit softclamp value yet'
             assert logit_softclamp_value > 0.

+        self.softclamp_logits = softclamp_logits
         self.logit_softclamp_value = logit_softclamp_value

         # contextual positional encoding
@@ -308,6 +310,9 @@ class Attend(nn.Module):
         if exists(attn_bias):
             sim = sim + attn_bias

+        if self.softclamp_logits:
+            sim = softclamp(sim, self.logit_softclamp_value)
+
         i, j, dtype = *sim.shape[-2:], sim.dtype

         mask_value = -torch.finfo(sim.dtype).max
@@ -329,9 +334,6 @@ class Attend(nn.Module):

         pre_softmax_attn = sim.clone()

-        if exists(self.logit_softclamp_value):
-            sim = softclamp(sim, self.logit_softclamp_value)
-
         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)

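For orientation: 1.31.5 replaces the implicit "enabled when logit_softclamp_value is set" behavior with an explicit softclamp_logits switch plus a separate logit_softclamp_value magnitude (defaulting to 30.), and moves the clamp so it is applied right after the attention bias, before masking, instead of just before the softmax. The softclamp helper itself is not shown in this diff, so the sketch below assumes the usual tanh-based soft clamping and is illustrative only, not code copied from the package.

# Hedged sketch of tanh-style logit softclamping; the real softclamp helper in
# x_transformers/attend.py is not shown in this diff, so this definition is an assumption.
import torch

def softclamp(t: torch.Tensor, value: float) -> torch.Tensor:
    # smoothly squashes values into (-value, value), near-identity for small magnitudes
    return (t / value).tanh() * value

sim = torch.randn(2, 8, 128, 128) * 50      # hypothetical attention logits with large magnitude
clamped = softclamp(sim, 30.)               # 30. matches the new default logit_softclamp_value
assert clamped.abs().max() < 30.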
x_transformers/x_transformers.py CHANGED
@@ -566,7 +566,7 @@ class LayerNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
     ):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -576,9 +576,7 @@ class LayerNorm(Module):

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.gamma = nn.Parameter(torch.ones(dim))
-        nn.init.constant_(self.gamma, 1. - unit_offset)
-
-        self.register_buffer('beta', torch.zeros(dim), persistent = False)
+        nn.init.constant_(self.gamma, 1. - float(unit_offset))

     def forward(self, x):
         normed = self.ln(x)
@@ -596,7 +594,6 @@ class AdaptiveLayerNorm(Module):

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
-
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -608,14 +605,14 @@ class ScaleNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5

         self.g = nn.Parameter(torch.zeros(1))
-        nn.init.constant_(self.g, 1. - unit_offset)
+        nn.init.constant_(self.g, 1. - float(unit_offset))

     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
@@ -624,14 +621,14 @@ class RMSNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5

         self.g = nn.Parameter(torch.zeros(dim))
-        nn.init.constant_(self.g, 1. - unit_offset)
+        nn.init.constant_(self.g, 1. - float(unit_offset))

     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
@@ -738,14 +735,14 @@ class LayerScale(Module):
         fn: Module,
         dim,
         init_value = 0.,
-        unit_offset = 0.
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset

         self.fn = fn
         self.gamma = nn.Parameter(torch.zeros(dim))
-        nn.init.constant_(self.gamma, init_value - unit_offset)
+        nn.init.constant_(self.gamma, init_value - float(unit_offset))

     def forward(self, x, **kwargs):
         out = self.fn(x, **kwargs)
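The common thread in these norm and LayerScale hunks: unit_offset changes from a float (0. or 1.) to a boolean, while the stored gain is still initialized to its init value minus float(unit_offset) and the offset is added back at forward time. A minimal sketch of why that matters for weight decay follows; the variable names are illustrative only, not taken from x_transformers.

# Minimal sketch of the unit-offset trick (illustrative names, not library code):
# the gain is stored as a zero-initialized residual and the unit is added back in forward,
# so an optimizer's weight decay pulls the *effective* gain toward 1 (identity), not toward 0.
import torch

dim = 4
g = torch.nn.Parameter(torch.zeros(dim))   # stored parameter, init 0 when the unit offset is on

x = torch.randn(dim)
effective_gain = g + 1.                    # what forward() multiplies by; starts at exactly 1
out = x * effective_gain

# decaying g toward 0 decays effective_gain toward 1, i.e. toward a no-op scale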
@@ -887,7 +884,8 @@ class Attention(Module):
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         onnxable = False
     ):
         super().__init__()
@@ -990,6 +988,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
@@ -1370,8 +1369,8 @@ class AttentionLayers(Module):
         norm_fn = partial(norm_class, dim)

         if not norm_need_condition and norm_add_unit_offset:
-            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas and betas, they can be subjected to weight decay safely
-            norm_fn = partial(norm_fn, unit_offset = 1.)
+            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas, they can be subjected to weight decay safely
+            norm_fn = partial(norm_fn, unit_offset = True)

         self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
@@ -1404,7 +1403,7 @@ class AttentionLayers(Module):
         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition

         if exists(post_branch_fn) and not post_branch_fn_needs_condition and norm_add_unit_offset:
-            post_branch_fn = partial(post_branch_fn, unit_offset = 1.)
+            post_branch_fn = partial(post_branch_fn, unit_offset = True)

         # setup mlp for conditioning

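Taken together with the attend.py hunks, the user-facing change is a pair of keyword arguments on Attention. A hedged usage sketch follows; dim and heads are assumed from the class's usual signature and are not part of this diff, and the flash incompatibility assertion from attend.py still applies.

# Hedged usage sketch (not from the package): exercises the new softclamp_logits flag
# on Attention directly. The dim/heads keyword names are assumed from the usual signature.
import torch
from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    softclamp_logits = True,        # new explicit switch in 1.31.5
    logit_softclamp_value = 30.,    # clamp magnitude, now defaulting to 30.
    flash = False                   # attend.py still asserts flash is incompatible with softclamping
)

x = torch.randn(1, 1024, 512)
result = attn(x)                    # return convention (output vs. (output, intermediates)) is version-dependent

Through the higher-level Encoder/Decoder wrappers, the same knobs would normally be reachable with the library's usual attn_ prefix (e.g. attn_softclamp_logits = True), though that routing is not shown in this diff.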
x_transformers-1.31.3.dist-info/METADATA → x_transformers-1.31.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.31.3
+Version: 1.31.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.31.3.dist-info/RECORD → x_transformers-1.31.5.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=ap2QkD-bRadFE9ZFQP84Lo1P2DpLOXPam24Jq9ybpPY,10903
+x_transformers/attend.py,sha256=UWq0bElvJf-_j1N2QbJ2yg28xkWlhnOrLjMJt3If3ao,10956
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=D9l3jL1D0RzHynIJcSUpxQec1n-7cHgRQZNDdDoYCFQ,75832
+x_transformers/x_transformers.py,sha256=VL9Dm8L5jnpgyt_V6DWGtLs9MeiTn7ZMCQdcHSFLVo8,75871
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.31.3.dist-info/METADATA,sha256=cS3vVeEi3fSoYqXWDkW6FJhgmX11vrfQr9Lc3kMT9MI,661
-x_transformers-1.31.3.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
-x_transformers-1.31.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.31.3.dist-info/RECORD,,
+x_transformers-1.31.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.5.dist-info/METADATA,sha256=o7z9SmDOf9BAU9FyMRuAbX4LIH8nIACA0acjGHPMz8s,661
+x_transformers-1.31.5.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.5.dist-info/RECORD,,