x-transformers 1.38.2__py3-none-any.whl → 1.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -67,7 +67,7 @@ print_once = once(print)
 
 # selective attention
 # https://arxiv.org/abs/2410.02703 - section 3.3
-# it is a parameter-less technique to allow each token to prevent itself from being attended to by future tokens
+# it is a technique to allow each token to prevent itself from being attended to by future tokens
 # if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
 
 def selective_attn(
@@ -81,9 +81,10 @@ def selective_attn(
     gate = F.relu(sim_head_gate) # only positive
 
     if no_mask_sos:
+        gate = gate.clone()
         gate[..., -i] = 0.
 
-    eye = torch.eye(j, device = device)
+    eye = torch.eye(i, device = device)
 
     if j > i:
         eye = F.pad(eye, (j - i, 0), value = 1.)
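
The hunk above fixes the shape of the identity mask used by selective attention (and clones the gate before the in-place write, presumably to avoid mutating a shared tensor). A minimal sketch of the corrected construction, assuming i is the number of queries and j the number of keys as in the surrounding function:

import torch
import torch.nn.functional as F

i, j = 3, 5  # illustrative sizes; j > i when cached/prefix keys precede the queries

eye = torch.eye(i)                             # (i, i): a token never gates itself out
if j > i:
    eye = F.pad(eye, (j - i, 0), value = 1.)   # left-pad the key dimension to (i, j)

print(eye.shape)  # torch.Size([3, 5]) -- the old torch.eye(j) produced a (j, j) base
                  # that could not broadcast against the (..., i, j) attention logits
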
@@ -127,7 +128,8 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        talking_heads = False,
+        pre_talking_heads = True,
+        post_talking_heads = True,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -178,14 +180,15 @@ class Attend(Module):
 
         # talking heads
 
-        assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'
 
-        self.talking_heads = talking_heads
-        if talking_heads:
-            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
-            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
+        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
 
+        if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
+
+        if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)
 
         # selective attention
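
For reference, the talking-heads projections being split here are 1×1 convolutions over the head dimension, so each (query, key) position gets a learned linear mix of the per-head logits, and nn.init.dirac_ makes that mix an exact identity at initialization. A small self-contained sketch (the head count and tensor shape are illustrative):

import torch
from torch import nn

heads = 8
pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
nn.init.dirac_(pre_softmax_talking_heads.weight)  # weight[h, h, 0, 0] = 1 -> identity mix

sim = torch.randn(2, heads, 4, 6)  # (batch, heads, queries, keys) attention logits
assert torch.allclose(pre_softmax_talking_heads(sim), sim)  # a no-op until trained
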
@@ -433,8 +436,8 @@ class Attend(Module):
 
         qk_similarities = sim.clone()
 
-        if self.talking_heads:
-            sim = self.pre_softmax_talking_heads(sim)
+        if exists(self.pre_softmax_talking_heads):
+            sim = sim + self.pre_softmax_talking_heads(sim)
 
         if exists(attn_bias):
             sim = sim + attn_bias
@@ -481,7 +484,7 @@ class Attend(Module):
 
         attn = self.attn_dropout(attn)
 
-        if self.talking_heads:
+        if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)
 
         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
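
Putting the two forward-path hunks together: the pre-softmax mix is now applied residually (sim + pre(sim)) and only when the corresponding module exists, while the post-softmax mix is applied as before. A minimal sketch of that flow with stand-in tensors (shapes and the plain einsum are illustrative, not the library's exact forward):

import torch
from torch import nn

heads, i, j, d = 8, 4, 6, 16
sim = torch.randn(2, heads, i, j)    # attention logits
v = torch.randn(2, heads, j, d)      # values

pre = nn.Conv2d(heads, heads, 1, bias = False)   # stands in for self.pre_softmax_talking_heads
post = nn.Conv2d(heads, heads, 1, bias = False)  # stands in for self.post_softmax_talking_heads
nn.init.dirac_(pre.weight)
nn.init.dirac_(post.weight)

if pre is not None:
    sim = sim + pre(sim)             # residual pre-softmax mix (previously sim = pre(sim))

attn = sim.softmax(dim = -1)

if post is not None:
    attn = post(attn)                # post-softmax mix, unchanged in form

out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
print(out.shape)                     # torch.Size([2, 8, 4, 16])
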
x_transformers/x_transformers.py CHANGED
@@ -907,7 +907,8 @@ class Attention(Module):
         heads = 8,
         causal = False,
         flash = False,
-        talking_heads = False,
+        pre_talking_heads = False,
+        post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -1036,7 +1037,8 @@ class Attention(Module):
         self.attend = Attend(
             heads = heads,
             causal = causal,
-            talking_heads = talking_heads,
+            pre_talking_heads = pre_talking_heads,
+            post_talking_heads = post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
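
With the old boolean gone, callers opt into each mix separately; a hypothetical construction mirroring the signatures above (dim is an assumed value, and only the renamed keywords come from this diff):

from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,                  # assumed, not shown in the diff
    heads = 8,
    pre_talking_heads = True,   # replaces the former talking_heads = True
    post_talking_heads = True,
)
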
@@ -1084,6 +1086,10 @@ class Attention(Module):
 
         self.rotary_embed_values = rotary_embed_values
 
+        # whether parent can kv cache
+
+        self.can_cache_kv = not selective
+
         # init output projection 0
 
         if zero_init_output:
@@ -1634,6 +1640,10 @@ class AttentionLayers(Module):
                 residual
             ]))
 
+        # determine whether can cache kv
+
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+
     def forward(
         self,
         x,
@@ -2135,7 +2145,7 @@ class TransformerWrapper(Module):
 
         # whether can do cached kv decoding
 
-        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling and self.attn_layers.can_cache_kv
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
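
The three can_cache_kv hunks form a chain: each Attention layer reports whether it can use cached keys/values (selective attention cannot), AttentionLayers AND-reduces those flags over its attention modules, and TransformerWrapper folds the result into its existing conditions. A standalone sketch of just that propagation logic (class names mirror the diff; everything else is simplified):

class Attention:
    def __init__(self, selective = False):
        self.can_cache_kv = not selective          # per the diff, selective attention disables caching

class AttentionLayers:
    def __init__(self, attn_modules):
        # the real code walks self.modules() and filters isinstance(module, Attention)
        self.can_cache_kv = all(m.can_cache_kv for m in attn_modules)

class TransformerWrapper:
    def __init__(self, attn_layers, num_memory_tokens = 0, recycling = False):
        self.can_cache_kv = num_memory_tokens == 0 and not recycling and attn_layers.can_cache_kv

layers = AttentionLayers([Attention(), Attention(selective = True)])
print(TransformerWrapper(layers).can_cache_kv)     # False -- one selective layer disables kv caching
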
{x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.38.2
+Version: 1.39.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/RECORD RENAMED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=HO6HZ1fowJ6a6v915PH4s8PnfNj0_q47Sq7yc9AP5YQ,15380
+x_transformers/attend.py,sha256=sTeX7DmUt6I5FhHtgcTDOIvmD1CvJ1PmVjZ_-lYO-QA,15596
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=Dol6GMZOoHGOFdHwe21o2SbJp6b3YKCUHoIs_AjfvTo,83963
+x_transformers/x_transformers.py,sha256=nfq-EOLx5HWf8tXlmVuDbkhqNfFnfqRMCEkALK3SFkA,84341
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.38.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.38.2.dist-info/METADATA,sha256=pThqFTEo8bihgUlSYdv3r-JFB153pbO2baJgXwMMZZs,661
-x_transformers-1.38.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.38.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.38.2.dist-info/RECORD,,
+x_transformers-1.39.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.0.dist-info/METADATA,sha256=-ppCMMH6ZTsmwaMJB9q4b4Yvd-nU8v95l5SdzTY17OU,661
+x_transformers-1.39.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.39.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.0.dist-info/RECORD,,