x-transformers 2.5.6__py3-none-any.whl → 2.6.1__py3-none-any.whl

x_transformers/x_transformers.py

@@ -1617,7 +1617,9 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
-        value_residual = None
+        value_residual = None,
+        additional_key_values: tuple[Tensor, Tensor] | None = None,
+        additional_key_value_mask = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv

@@ -1787,6 +1789,26 @@ class Attention(Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

+        # maybe append additional key / values
+
+        if exists(additional_key_values):
+            seq_len = k.shape[-2]
+
+            added_k, added_v = additional_key_values
+
+            k = cat((added_k, k), dim = -2)
+            v = cat((added_v, v), dim = -2)
+
+            if (exists(input_mask) or exists(additional_key_value_mask)):
+
+                if not exists(additional_key_value_mask):
+                    added_kv_len = added_k.shape[-2]
+                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+                elif not exists(input_mask):
+                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
+                else:
+                    input_mask = cat((additional_key_value_mask, input_mask), dim = -1)
+
         # determine masking

         mask_value = max_neg_value(q)
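The added block above handles three masking cases: only the regular input mask exists, only additional_key_value_mask exists, or both. As a standalone illustration (not part of the library), the same three-way merge could be written with plain torch.nn.functional.pad in place of the library's pad_at_dim helper; merge_kv_masks is a hypothetical name, and boolean masks of shape (batch, key_len) are assumed.

import torch
import torch.nn.functional as F

def merge_kv_masks(input_mask, additional_mask, added_len, seq_len):
    # neither mask given -> nothing to build
    if input_mask is None and additional_mask is None:
        return None

    # only the original sequence is masked: appended keys are always attendable,
    # so left-pad the existing mask with True along the key dimension
    if additional_mask is None:
        return F.pad(input_mask, (added_len, 0), value = True)

    # only the appended keys are masked: the original sequence is fully attendable,
    # so right-pad the additional mask with True
    if input_mask is None:
        return F.pad(additional_mask, (0, seq_len), value = True)

    # both present: concatenate along the key dimension, appended keys first,
    # matching k = cat((added_k, k), dim = -2)
    return torch.cat((additional_mask, input_mask), dim = -1)

# example: 128 query tokens, 16 appended key / value slots
mask = torch.ones(2, 128, dtype = torch.bool)
added = torch.ones(2, 16, dtype = torch.bool)
merged = merge_kv_masks(mask, added, added_len = 16, seq_len = 128)
assert merged.shape == (2, 144)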
@@ -2411,6 +2433,8 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
+        self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+        additional_kv_mask = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None

@@ -2520,6 +2544,10 @@ class AttentionLayers(Module):

         iter_attn_cache = iter(attn_cache)

+        # additional self attn key / values
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # handle deep embeds if needed

         deep_embeds = []

@@ -2647,7 +2675,7 @@ class AttentionLayers(Module):
             # forward depending on layer type

             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
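Taken together, these changes expose a hook for injecting precomputed key / value pairs into every self attention layer of an AttentionLayers stack. Below is a minimal usage sketch, not taken from the package's documentation: the tensor shapes (one (k, v) tuple per attention layer, each already split into heads as (batch, kv_heads, added_len, dim_head), plus a (batch, added_len) boolean mask) are an assumption inferred from where the tensors are concatenated onto k and v inside Attention.forward.

import torch
from x_transformers import Encoder

batch, seq_len, dim = 2, 128, 512
depth, heads, dim_head = 6, 8, 64   # dim_head of 64 is the library default
added_len = 16

attn_layers = Encoder(dim = dim, depth = depth, heads = heads)

x = torch.randn(batch, seq_len, dim)

# assumption: one (key, value) tuple per self attention layer, shaped to match
# k / v after the heads are split, since they are concatenated along dim = -2
self_attn_additional_kv = [
    (torch.randn(batch, heads, added_len, dim_head),
     torch.randn(batch, heads, added_len, dim_head))
    for _ in range(depth)
]

# boolean mask over the appended key / value positions (True = attend)
additional_kv_mask = torch.ones(batch, added_len, dtype = torch.bool)

out = attn_layers(
    x,
    self_attn_additional_kv = self_attn_additional_kv,
    additional_kv_mask = additional_kv_mask
)

assert out.shape == (batch, seq_len, dim)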
x_transformers-2.5.6.dist-info/METADATA → x_transformers-2.6.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.5.6
+Version: 2.6.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.5.6.dist-info/RECORD → x_transformers-2.6.1.dist-info/RECORD

@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=ymq4TL2OyeCPxwbTL0ShKptzRsMWXxsPv1MKZ9MxbHY,120021
+x_transformers/x_transformers.py,sha256=O8Z4j7wDrj47-lZxmpvToHbHpoFqLy2pk199tQ4v4hI,121281
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.5.6.dist-info/METADATA,sha256=2EpDDqrjJ20hyMPfHwQeoigq4LqJhD93h9w12FX82uQ,90223
-x_transformers-2.5.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.5.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.5.6.dist-info/RECORD,,
+x_transformers-2.6.1.dist-info/METADATA,sha256=VVzitcHytmh6tNmtjSMyDWjFxfjpQu6PhR4sTFkxjpk,90223
+x_transformers-2.6.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.6.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.6.1.dist-info/RECORD,,