x-transformers 1.37.10__py3-none-any.whl → 1.38.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -65,6 +65,35 @@ def once(fn):
 
 print_once = once(print)
 
+# selective attention
+# https://arxiv.org/abs/2410.02703 - section 3.3
+# it is a parameter-less technique to allow each token to prevent itself from being attended to by future tokens
+# if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
+
+def selective_attn(
+    sim,
+    sim_head_gate = None,
+    no_mask_sos = True
+):
+    i, j, device = *sim.shape[-2:], sim.device
+    sim_head_gate = default(sim_head_gate, sim[:, 0])
+
+    gate = F.relu(sim_head_gate) # only positive
+
+    if no_mask_sos:
+        gate[..., -i] = 0.
+
+    eye = torch.eye(j, device = device)
+
+    if j > i:
+        eye = F.pad(eye, (j - i, 0), value = 1.)
+
+    gate = (1. - eye) * gate
+    gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
+    gate = gate.cumsum(dim = -2)
+
+    return sim - rearrange(gate, 'b i j -> b 1 i j')
+
 # alternative distance functions
 
 def qk_l2_dist_squared(q, k):
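
In plain terms, selective_attn subtracts a cumulative penalty from the attention logits: the gating head's positive logit against a key, shifted down one row so it only affects strictly later queries, then accumulated over query rows. The snippet below is a toy illustration of that arithmetic only, not part of the package; it assumes the square i == j case (no KV cache) and inlines the gating steps with made-up tensor sizes.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    torch.manual_seed(0)

    # toy causal logits: batch 1, 2 heads, 4 queries, 4 keys (i == j, no KV cache)
    sim = torch.randn(1, 2, 4, 4)
    n = sim.shape[-1]

    gate = F.relu(sim[:, 0])                       # gate comes from the first head, positive part only
    gate[..., -n] = 0.                             # never mask the start token (no_mask_sos)
    gate = (1. - torch.eye(n)) * gate              # a token cannot mask itself
    gate = F.pad(gate, (0, 0, 1, -1), value = 0.)  # shift one row down: only strictly future queries are affected
    gate = gate.cumsum(dim = -2)                   # penalties accumulate as more tokens "spend" logits to hide a key

    selective_sim = sim - rearrange(gate, 'b i j -> b 1 i j')

    # the first query row is untouched; later rows have the flagged keys pushed down before softmax
    assert torch.allclose(selective_sim[:, :, 0], sim[:, :, 0])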
@@ -109,6 +138,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        selective = False,
         sigsoftmax = False,
         cope = None,
         onnxable = False,
@@ -155,6 +185,11 @@ class Attend(Module):
         self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
         self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
 
+        # selective attention
+
+        assert not (selective and not causal), 'selective attention is designed for autoregressive'
+        self.selective = selective
+
         # sparse topk
 
         assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
@@ -427,6 +462,9 @@ class Attend(Module):
         if exists(self.cope):
             sim = sim + self.cope(q, sim)
 
+        if self.selective:
+            sim = selective_attn(sim)
+
         pre_softmax_attn = sim
 
         if self.sigsoftmax:
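
Since Attend is also usable as a standalone module, the flag can be exercised directly. A minimal sketch, assuming Attend's usual (q, k, v) forward signature with tensors shaped (batch, heads, seq_len, dim_head) and an (out, intermediates) return value:

    import torch
    from x_transformers.attend import Attend

    # causal = True is required: the assert above rejects selective on non-causal attention
    attend = Attend(
        causal = True,
        selective = True
    )

    q = torch.randn(2, 8, 1024, 64)   # (batch, heads, seq_len, dim_head)
    k = torch.randn(2, 8, 1024, 64)
    v = torch.randn(2, 8, 1024, 64)

    out, intermediates = attend(q, k, v)
    print(out.shape)                  # torch.Size([2, 8, 1024, 64])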

x_transformers/x_transformers.py CHANGED
@@ -925,6 +925,7 @@ class Attention(Module):
         qk_norm_dim_scale = False,
         l2_distance = False,
         sigmoid = False,
+        selective = False,
         custom_attn_fn: Callable | None = None,
         one_kv_head = False,
         kv_heads = None,
@@ -1042,6 +1043,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
             sigmoid = sigmoid,
+            selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             flash = flash,
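
At the Attention level the new keyword is simply forwarded into Attend, so in a full model it would typically be switched on through the library's attn_-prefixed keyword routing on the layer stack. A hedged end-to-end sketch; the attn_selective spelling assumes that routing convention, and the assert above restricts the feature to causal (decoder) stacks:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_selective = True  # assumed spelling, routed via the attn_ prefix to Attention's new selective kwarg
        )
    )

    x = torch.randint(0, 20000, (1, 1024))
    logits = model(x)                 # (1, 1024, 20000)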

x_transformers-1.37.10.dist-info/METADATA → x_transformers-1.38.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.37.10
+Version: 1.38.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

x_transformers-1.37.10.dist-info/RECORD → x_transformers-1.38.0.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=MJmMLIt0rHx-JNNsc2auUsCjsB-69NewufaRV32ADmA,14012
+x_transformers/attend.py,sha256=tZOPsSgWwabIGgUlNdWrVZgcxOkF3sgiQQ7nrcOjl9I,15151
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=q6I6rvyYUWLgwtKOxPwF12UL1HzcIlauI8YrM8gvZac,83901
+x_transformers/x_transformers.py,sha256=Dol6GMZOoHGOFdHwe21o2SbJp6b3YKCUHoIs_AjfvTo,83963
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.37.10.dist-info/METADATA,sha256=qA7YAj5ZeaesnGamBR-cPSR_0HSwquwPBptaUmi7c3c,662
-x_transformers-1.37.10.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.37.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.37.10.dist-info/RECORD,,
+x_transformers-1.38.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.38.0.dist-info/METADATA,sha256=9KjkWIqzpGS3vs7MpWmfbFfPHPvdOJSE6hOhxQsu4XI,661
+x_transformers-1.38.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.38.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.38.0.dist-info/RECORD,,