x-transformers 1.38.2__py3-none-any.whl → 1.38.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -67,7 +67,7 @@ print_once = once(print)
 
 # selective attention
 # https://arxiv.org/abs/2410.02703 - section 3.3
-# it is a parameter-less technique to allow each token to prevent itself from being attended to by future tokens
+# it is a technique to allow each token to prevent itself from being attended to by future tokens
 # if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
 
 def selective_attn(
@@ -81,9 +81,10 @@ def selective_attn(
     gate = F.relu(sim_head_gate) # only positive
 
     if no_mask_sos:
+        gate = gate.clone()
         gate[..., -i] = 0.
 
-    eye = torch.eye(j, device = device)
+    eye = torch.eye(i, device = device)
 
     if j > i:
         eye = F.pad(eye, (j - i, 0), value = 1.)
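The hunk above carries two fixes: the gate tensor is cloned before the in-place start-of-sequence masking, and the self-masking identity is now built from the query length i rather than the key length j. A minimal, standalone sketch of the shape logic behind the second fix (illustrative code, not the library's):

import torch
import torch.nn.functional as F

# with i queries attending over j keys (j >= i, e.g. extra positions prepended
# on the key side), the self-mask needs one row per query and one column per key
i, j = 3, 5

eye = torch.eye(i)                            # (i, i)

if j > i:
    eye = F.pad(eye, (j - i, 0), value = 1.)  # left-pad the key dim -> (i, j)

print(eye.shape)  # torch.Size([3, 5]), matching attention logits of shape (..., i, j)
# the previous torch.eye(j) produced a (j, 2 * j - i) tensor after the same pad,
# which cannot broadcast against (..., i, j) once j > i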
x_transformers/x_transformers.py CHANGED
@@ -1084,6 +1084,10 @@ class Attention(Module):
 
         self.rotary_embed_values = rotary_embed_values
 
+        # whether parent can kv cache
+
+        self.can_cache_kv = not selective
+
         # init output projection 0
 
         if zero_init_output:
@@ -1634,6 +1638,10 @@ class AttentionLayers(Module):
                 residual
             ]))
 
+        # determine whether can cache kv
+
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+
     def forward(
         self,
         x,
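The line added above walks every submodule and only advertises kv-cache support if each Attention instance allows it. A small self-contained sketch of that pattern with placeholder classes (ToyAttention and ToyLayers are hypothetical, not part of x-transformers):

import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, selective = False):
        super().__init__()
        self.to_q = nn.Linear(8, 8)
        # mirrors the flag introduced in this release: selective attention opts out of kv caching
        self.can_cache_kv = not selective

class ToyLayers(nn.Module):
    def __init__(self, selective_flags):
        super().__init__()
        self.layers = nn.ModuleList([ToyAttention(s) for s in selective_flags])
        # the container is cacheable only if every attention module is
        self.can_cache_kv = all(m.can_cache_kv for m in self.modules() if isinstance(m, ToyAttention))

print(ToyLayers([False, False]).can_cache_kv)  # True
print(ToyLayers([False, True]).can_cache_kv)   # False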
@@ -2135,7 +2143,7 @@ class TransformerWrapper(Module):
 
         # whether can do cached kv decoding
 
-        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling and self.attn_layers.can_cache_kv
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
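With the change above, the wrapper's can_cache_kv flag also reflects whether any layer uses selective attention. A rough usage sketch; the attn_selective kwarg name is an assumption based on how x-transformers forwards attn_-prefixed kwargs to its Attention module:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        attn_selective = True  # assumed spelling of the selective attention flag
    )
)

# with selective attention enabled, cached kv decoding should be reported as unavailable
print(model.can_cache_kv)  # expected: False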
x_transformers-1.38.2.dist-info/METADATA → x_transformers-1.38.3.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.38.2
+Version: 1.38.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.38.2.dist-info/RECORD → x_transformers-1.38.3.dist-info/RECORD
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=HO6HZ1fowJ6a6v915PH4s8PnfNj0_q47Sq7yc9AP5YQ,15380
+x_transformers/attend.py,sha256=Z9cnbY3f9zCl1yUOMWVQT5_ee4keC0lo_NX4cd0rbKk,15393
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=Dol6GMZOoHGOFdHwe21o2SbJp6b3YKCUHoIs_AjfvTo,83963
+x_transformers/x_transformers.py,sha256=P3o1R2DY2ic71FlJJ4ie4w_z-g3jIIrkBXcbllRoXHA,84240
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.38.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.38.2.dist-info/METADATA,sha256=pThqFTEo8bihgUlSYdv3r-JFB153pbO2baJgXwMMZZs,661
-x_transformers-1.38.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.38.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.38.2.dist-info/RECORD,,
+x_transformers-1.38.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.38.3.dist-info/METADATA,sha256=XbmJ91NWqmOjO0xvyXh9uh82TIwVbU54L3eFmaFVvYs,661
+x_transformers-1.38.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.38.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.38.3.dist-info/RECORD,,