x-transformers 1.30.4__py3-none-any.whl → 1.30.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -4,6 +4,7 @@ from functools import partial
 from typing import Tuple
 
 import torch
+from torch.nn import Module
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
 
@@ -22,6 +23,7 @@ class Intermediates:
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
+    layer_type: str | None = None
 
     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
@@ -81,6 +83,7 @@ class Attend(nn.Module):
         flash = False,
         logit_softclamp_value = None,
         add_zero_kv = False,
+        cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -126,6 +129,10 @@ class Attend(nn.Module):
 
         self.logit_softclamp_value = logit_softclamp_value
 
+        # contextual positional encoding
+
+        self.cope = cope
+
         # flash attention
 
         self.flash = flash
@@ -317,6 +324,9 @@ class Attend(nn.Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
+        if exists(self.cope):
+            sim = sim + self.cope(q, sim)
+
         pre_softmax_attn = sim.clone()
 
         if exists(self.logit_softclamp_value):
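Taken together, the attend.py changes add an optional cope module that is called with the queries and the pre-softmax logits, and whose output is added to those logits before the softmax. A minimal sketch of driving this hook directly; the Attend forward signature and the tensor shapes below are assumptions based on the existing class, not something shown in this diff:

# hedged sketch: exercising the new cope hook on Attend with toy tensors
import torch
from x_transformers.attend import Attend
from x_transformers.x_transformers import CoPE

b, h, n, d = 2, 8, 32, 64                      # batch, heads, sequence length, head dim

attend = Attend(causal = True, cope = CoPE(dim = d, max_pos = 16))   # flash stays at its False default

q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

out, intermediates = attend(q, k, v)           # cope bias is added to the logits before softmax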
x_transformers/x_transformers.py CHANGED
@@ -304,6 +304,33 @@ class RelativePositionBias(Module):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale
 
+class CoPE(Module):
+    """
+    Appendix B of https://arxiv.org/abs/2405.18719
+    """
+    def __init__(self, dim, max_pos):
+        super().__init__()
+        self.max_pos = max_pos
+        self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
+
+    def forward(self, query, attn_logits):
+        # compute positions
+
+        gates = attn_logits.sigmoid()
+        pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
+        pos = pos.clamp(max = self.max_pos - 1)
+
+        # interpolate from integer positions
+
+        pos_ceil = pos.ceil().long()
+        pos_floor = pos.floor().long()
+        logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
+        logits_ceil = logits_int.gather(-1, pos_ceil)
+        logits_floor = logits_int.gather(-1, pos_floor)
+
+        w = pos - pos_floor
+        return logits_ceil * w + logits_floor * (1 - w)
+
 class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
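The class above maps each query to a fractional, context-dependent position (a sigmoid-gated reverse cumulative sum of the attention logits, clamped to max_pos - 1) and linearly interpolates between the two nearest learned position embeddings. A toy invocation checking the input/output contract, assuming CoPE is importable from x_transformers.x_transformers where it is defined at module level:

# hedged sketch: calling the CoPE module on its own
import torch
from x_transformers.x_transformers import CoPE

b, h, n, d, max_pos = 2, 8, 16, 64, 16

cope = CoPE(dim = d, max_pos = max_pos)

query = torch.randn(b, h, n, d)            # per-head queries, as passed in by Attend
attn_logits = torch.randn(b, h, n, n)      # pre-softmax attention logits (q · k)

bias = cope(query, attn_logits)            # contextual position bias, same shape as the logits
assert bias.shape == attn_logits.shape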
@@ -722,6 +749,8 @@ class Attention(Module):
         tensor_product = False,      # https://arxiv.org/abs/2208.06061
         add_zero_kv = False,         # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        use_cope = False,
+        cope_max_pos = 16,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -753,13 +782,16 @@ class Attention(Module):
         self.to_k = nn.Linear(dim_kv, k_dim, bias = False)
 
         # shared key / values, for further memory savings during inference
+
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None
 
         # relations projection from tp-attention
+
         self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
 
         # add GLU gating for aggregated values, from alphafold2
+
         self.to_v_gate = None
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
@@ -768,6 +800,7 @@ class Attention(Module):
             nn.init.constant_(self.to_v_gate.bias, 10)
 
         # add per head gating of the output values, from 'Attend to nothing' paper
+
         self.to_v_head_gate = None
         if gate_value_heads:
             self.to_v_head_gate = nn.Linear(dim, heads)
@@ -775,11 +808,13 @@ class Attention(Module):
             nn.init.constant_(self.to_v_head_gate.bias, 10)
 
         # cosine sim attention
+
         self.qk_norm = qk_norm
         self.qk_norm_groups = qk_norm_groups
         self.qk_norm_scale = qk_norm_scale
 
         # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
+
         self.qk_norm_dim_scale = qk_norm_dim_scale
 
         self.qk_norm_q_scale = self.qk_norm_k_scale = 1
@@ -790,6 +825,17 @@ class Attention(Module):
         assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
         assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
 
+        # contextual positional encoding
+        # https://arxiv.org/html/2405.18719v2
+
+        cope = None
+
+        if use_cope:
+            assert causal, 'CoPE was designed for causal attention'
+            assert not flash, 'CoPE is not flash attention compatible'
+
+            cope = CoPE(dim_head, cope_max_pos)
+
         # attend class - includes core attention algorithm + talking heads
 
         self.attend = Attend(
@@ -803,31 +849,38 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             logit_softclamp_value = logit_softclamp_value,
+            cope = cope,
             onnxable = onnxable
         )
 
         # head scaling
+
         self.head_scale = head_scale
         if head_scale:
             self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
 
         # explicit topk sparse attention
+
         self.sparse_topk = sparse_topk
 
         # add memory key / values
+
         self.num_mem_kv = num_mem_kv
         if num_mem_kv > 0:
             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
 
         # attention on attention
+
         self.attn_on_attn = on_attn
         self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
 
         # whether to rotate positions into values, for absolute positions in addition to relative
+
         self.rotary_embed_values = rotary_embed_values
 
         # init output projection 0
+
         if zero_init_output:
             init_zero_(self.to_out)
 
@@ -1410,6 +1463,7 @@ class AttentionLayers(Module):
             x = residual_fn(out, inner_residual)
 
             if layer_type in ('a', 'c') and return_hiddens:
+                inter.layer_type = layer_type
                 intermediates.append(inter)
 
             if layer_type == 'a' and self.residual_attn:
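The new layer_type field on Intermediates tags each recorded attention intermediate as self-attention ('a') or cross-attention ('c'). A short follow-on sketch, continuing from the model in the sketch above and assuming the existing return_intermediates plumbing in TransformerWrapper, which collects one Intermediates object per attention layer:

# hedged sketch: reading the new layer_type tag from the returned intermediates
logits, intermediates = model(tokens, return_intermediates = True)

for inter in intermediates.attn_intermediates:
    print(inter.layer_type)          # 'a' for self-attention blocks, 'c' for cross-attention blocks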
x_transformers-1.30.4.dist-info/METADATA → x_transformers-1.30.7.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.4
+Version: 1.30.7
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.30.4.dist-info/RECORD → x_transformers-1.30.7.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=2SPHjXS_QAAZt04lHWGtdOypTExmo3BrbFhgcIQTk-Y,10671
+x_transformers/attend.py,sha256=ap2QkD-bRadFE9ZFQP84Lo1P2DpLOXPam24Jq9ybpPY,10903
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=P4rqlYGS9j9Gz00B4NPM7L6mhvamSYdBy5nG0ggOIMM,66342
+x_transformers/x_transformers.py,sha256=r9F_LLp5bQyAlue3bBTRwoRx02noTCh4ICF8oWCw1wE,67657
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.4.dist-info/METADATA,sha256=VwdrJaRjocQXIAxdGzq4rByPGvaA4jsogostzCysdjI,661
-x_transformers-1.30.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.4.dist-info/RECORD,,
+x_transformers-1.30.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.7.dist-info/METADATA,sha256=_TniCg2s6tlimpfzpWeMCsCMOjsoYwUObBiXFdY-JhA,661
+x_transformers-1.30.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.7.dist-info/RECORD,,