x-transformers 1.30.6__py3-none-any.whl → 1.30.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
x_transformers/attend.py CHANGED
@@ -4,6 +4,7 @@ from functools import partial
 from typing import Tuple
 
 import torch
+from torch.nn import Module
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
 
@@ -82,6 +83,7 @@ class Attend(nn.Module):
         flash = False,
         logit_softclamp_value = None,
         add_zero_kv = False,
+        cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -127,6 +129,10 @@ class Attend(nn.Module):
 
         self.logit_softclamp_value = logit_softclamp_value
 
+        # contextual positional encoding
+
+        self.cope = cope
+
         # flash attention
 
         self.flash = flash
@@ -318,6 +324,9 @@ class Attend(nn.Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
+        if exists(self.cope):
+            sim = sim + self.cope(q, sim)
+
         pre_softmax_attn = sim.clone()
 
         if exists(self.logit_softclamp_value):
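
For orientation: the new hook simply adds a data-dependent bias to the pre-softmax attention logits. A minimal sketch of the pattern (the tensor shapes and the bias_fn stand-in are illustrative, not part of the package):

    import torch

    # attend.py convention: q is (batch, heads, i, dim_head),
    # sim holds the pre-softmax logits with shape (batch, heads, i, j)
    q   = torch.randn(1, 8, 128, 64)
    sim = torch.randn(1, 8, 128, 128)

    def bias_fn(q, sim):
        # stand-in for self.cope(q, sim); must return a bias with sim's shape
        return torch.zeros_like(sim)

    sim = sim + bias_fn(q, sim)   # applied only when a cope module was passed in
    attn = sim.softmax(dim = -1)
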

x_transformers/x_transformers.py CHANGED
@@ -304,6 +304,57 @@ class RelativePositionBias(Module):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale
 
+class CoPE(Module):
+    """
+    Appendix B of https://arxiv.org/abs/2405.18719
+    """
+    def __init__(
+        self,
+        dim,
+        max_pos,
+        soft_onehot = True,
+        reverse = True
+    ):
+        super().__init__()
+        self.reverse = reverse
+        self.max_pos = max_pos
+        self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
+
+        self.soft_onehot = soft_onehot
+
+        if soft_onehot:
+            self.register_buffer('positions', torch.arange(max_pos))
+
+    def forward(self, query, attn_logits, temp = 5e-2):
+        # compute positions
+
+        gates = attn_logits.sigmoid()
+
+        if self.reverse:
+            pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
+        else:
+            pos = gates.cumsum(dim = -1)
+
+        pos = pos.clamp(max = self.max_pos - 1)
+
+        logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
+
+        if self.soft_onehot:
+            diff_pos = (pos[..., None] - self.positions).abs()
+            soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
+            cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
+        else:
+            # interpolate from integer positions
+            pos_ceil = pos.ceil().long()
+            pos_floor = pos.floor().long()
+            logits_ceil = logits_int.gather(-1, pos_ceil)
+            logits_floor = logits_int.gather(-1, pos_floor)
+
+            w = pos - pos_floor
+            cope_pos_emb = logits_ceil * w + logits_floor * (1 - w)
+
+        return cope_pos_emb
+
 class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
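
The module above can be sanity-checked in isolation. A quick sketch, assuming CoPE is imported from x_transformers.x_transformers (shapes are arbitrary, chosen only for illustration):

    import torch
    from x_transformers.x_transformers import CoPE

    dim_head, max_pos = 64, 16
    cope = CoPE(dim_head, max_pos)

    q   = torch.randn(1, 8, 32, dim_head)    # (batch, heads, seq, dim_head)
    sim = torch.randn(1, 8, 32, 32)          # pre-softmax attention logits

    bias = cope(q, sim)
    print(bias.shape)                        # torch.Size([1, 8, 32, 32]), same as sim

The returned bias has the same shape as the logits, which is what lets Attend add it directly via sim = sim + self.cope(q, sim).
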
@@ -722,6 +773,8 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        use_cope = False,
+        cope_max_pos = 16,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -753,13 +806,16 @@ class Attention(Module):
         self.to_k = nn.Linear(dim_kv, k_dim, bias = False)
 
         # shared key / values, for further memory savings during inference
+
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None
 
         # relations projection from tp-attention
+
         self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
 
         # add GLU gating for aggregated values, from alphafold2
+
         self.to_v_gate = None
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
@@ -768,6 +824,7 @@ class Attention(Module):
             nn.init.constant_(self.to_v_gate.bias, 10)
 
         # add per head gating of the output values, from 'Attend to nothing' paper
+
         self.to_v_head_gate = None
         if gate_value_heads:
             self.to_v_head_gate = nn.Linear(dim, heads)
@@ -775,11 +832,13 @@ class Attention(Module):
             nn.init.constant_(self.to_v_head_gate.bias, 10)
 
         # cosine sim attention
+
         self.qk_norm = qk_norm
         self.qk_norm_groups = qk_norm_groups
         self.qk_norm_scale = qk_norm_scale
 
         # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
+
         self.qk_norm_dim_scale = qk_norm_dim_scale
 
         self.qk_norm_q_scale = self.qk_norm_k_scale = 1
@@ -790,6 +849,17 @@ class Attention(Module):
         assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
         assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
 
+        # contextual positional encoding
+        # https://arxiv.org/html/2405.18719v2
+
+        cope = None
+
+        if use_cope:
+            assert causal, 'CoPE was designed for causal attention'
+            assert not flash, 'CoPE is not flash attention compatible'
+
+            cope = CoPE(dim_head, cope_max_pos)
+
         # attend class - includes core attention algorithm + talking heads
 
         self.attend = Attend(
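
At the model-building level, the two new hyperparameters can presumably be reached through the library's usual attn_-prefixed kwarg routing. A hedged end-to-end sketch (the routing convention and the chosen values are assumptions based on how other Attention options are exposed, not something this diff shows):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_use_cope = True,      # forwarded to Attention(use_cope = True)
            attn_cope_max_pos = 16     # forwarded to Attention(cope_max_pos = 16)
        )
    )

    x = torch.randint(0, 256, (1, 1024))
    logits = model(x)                  # (1, 1024, 256)

Note the guards in the hunk above: use_cope asserts causal attention (satisfied by Decoder) and is incompatible with flash = True.
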
@@ -803,31 +873,38 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             logit_softclamp_value = logit_softclamp_value,
+            cope = cope,
             onnxable = onnxable
         )
 
         # head scaling
+
         self.head_scale = head_scale
         if head_scale:
             self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
 
         # explicit topk sparse attention
+
         self.sparse_topk = sparse_topk
 
         # add memory key / values
+
         self.num_mem_kv = num_mem_kv
         if num_mem_kv > 0:
             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
 
         # attention on attention
+
         self.attn_on_attn = on_attn
         self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
 
         # whether to rotate positions into values, for absolute positions in addition to relative
+
         self.rotary_embed_values = rotary_embed_values
 
         # init output projection 0
+
         if zero_init_output:
             init_zero_(self.to_out)
 

x_transformers-1.30.6.dist-info/METADATA → x_transformers-1.30.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.6
+Version: 1.30.8
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

x_transformers-1.30.6.dist-info/RECORD → x_transformers-1.30.8.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=8opOeCQddi440WcH73B_wB5vtL0jaEQwBL-DIWq2lCs,10713
+x_transformers/attend.py,sha256=ap2QkD-bRadFE9ZFQP84Lo1P2DpLOXPam24Jq9ybpPY,10903
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=pXckFcDL6kTghYEUjIamZiR5H8dV6aIEPQTIYAGgqxA,66388
+x_transformers/x_transformers.py,sha256=5cof7yvOAfFviLh-luafmhtTJDemCPoy9rHHYjWxLu4,68338
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.6.dist-info/METADATA,sha256=1Xq9oSctaCQ5TOpdM3j6lJENYStLuda5VqEzOtq1B0c,661
-x_transformers-1.30.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.6.dist-info/RECORD,,
+x_transformers-1.30.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.8.dist-info/METADATA,sha256=2L0SfGhrbLMjpRKLwTp1_YH1Amu3g_j1nEuWIuGNqrQ,661
+x_transformers-1.30.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.8.dist-info/RECORD,,