x-transformers 2.5.5__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff shows the changes between publicly released package versions, as they appear in their respective public registries, and is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -1617,7 +1617,8 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
-        value_residual = None
+        value_residual = None,
+        additional_key_values: tuple[Tensor, Tensor] | None = None
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv

@@ -1787,6 +1788,19 @@ class Attention(Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

+        # maybe append additional key / values
+
+        if exists(additional_key_values):
+
+            added_k, added_v = additional_key_values
+            added_kv_len = added_k.shape[-2]
+
+            k = cat((added_k, k), dim = -2)
+            v = cat((added_v, v), dim = -2)
+
+            if exists(input_mask):
+                input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+
         # determine masking

         mask_value = max_neg_value(q)
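
The new additional_key_values argument lets a caller prepend extra key / value pairs to an attention layer's own keys and values. For orientation, here is a standalone plain-torch sketch of the append-and-pad logic added above; the tensor names and shapes are illustrative only, chosen to mimic the per-head keys / values that Attention.forward works with internally.

    import torch
    import torch.nn.functional as F

    b, h, n, d = 2, 8, 16, 64        # batch, kv heads, sequence length, head dim
    added_len = 4                    # number of prepended key / value positions

    k = torch.randn(b, h, n, d)      # existing per-head keys
    v = torch.randn(b, h, n, d)      # existing per-head values
    added_k = torch.randn(b, h, added_len, d)
    added_v = torch.randn(b, h, added_len, d)
    input_mask = torch.ones(b, n, dtype = torch.bool)

    # prepend along the sequence dimension, mirroring cat((added_k, k), dim = -2)
    k = torch.cat((added_k, k), dim = -2)
    v = torch.cat((added_v, v), dim = -2)

    # left-pad the mask with True so the added positions are always attended to,
    # mirroring pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
    input_mask = F.pad(input_mask, (added_len, 0), value = True)

    assert k.shape[-2] == n + added_len and input_mask.shape[-1] == n + added_len
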
@@ -2267,7 +2281,7 @@ class AttentionLayers(Module):

         # whether it has post norm

-        self.final_norm = norm_fn() if pre_norm else nn.Identity()
+        self.final_norm = norm_fn() if pre_norm and pre_norm_has_final_norm else nn.Identity()

         # whether unet or not

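This change makes the final norm honor the pre_norm_has_final_norm flag instead of applying the norm whenever pre_norm is set. A minimal sketch of the expected effect, assuming the flag is accepted as a constructor keyword and forwarded by Decoder (names taken from the diff, not verified against the full source):

    import torch.nn as nn
    from x_transformers import Decoder

    # with pre_norm left at its default (True) but pre_norm_has_final_norm = False,
    # final_norm should now fall through to nn.Identity instead of a norm layer
    layers = Decoder(dim = 512, depth = 2, heads = 8, pre_norm_has_final_norm = False)
    assert isinstance(layers.final_norm, nn.Identity)
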
@@ -2411,6 +2425,7 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
+        self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None

@@ -2520,6 +2535,10 @@ class AttentionLayers(Module):

         iter_attn_cache = iter(attn_cache)

+        # additional self attn key / values
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # handle deep embeds if needed

         deep_embeds = []

@@ -2647,7 +2666,7 @@ class AttentionLayers(Module):
             # forward depending on layer type

             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
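
Taken together with the Attention change above, AttentionLayers.forward now threads one (key, value) tuple per self-attention layer from self_attn_additional_kv into the corresponding block. A hedged usage sketch follows; the per-head shape (batch, kv_heads, added_len, dim_head) and the default head settings are assumptions about what each layer expects internally:

    import torch
    from x_transformers import Encoder

    depth, heads, dim_head = 2, 8, 64
    layers = Encoder(dim = 512, depth = depth, heads = heads)   # dim_head defaults to 64

    x = torch.randn(1, 16, 512)

    added_len = 4
    extra_kv = [
        (torch.randn(1, heads, added_len, dim_head),   # key for layer i
         torch.randn(1, heads, added_len, dim_head))   # value for layer i
        for _ in range(depth)
    ]

    out = layers(x, self_attn_additional_kv = extra_kv)   # (1, 16, 512)

Layers consume the list in order; a layer whose next(iter_self_attn_kv, None) comes back None simply skips the append, since the new code only runs when exists(additional_key_values).
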
x_transformers-2.6.0.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.5.5
+Version: 2.6.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

x_transformers-2.6.0.dist-info/RECORD

@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=fW-AoomNCw4n2JFbZN9rZV3lKQvz_Tl6L4txUvac_9o,119993
+x_transformers/x_transformers.py,sha256=6zTGMOo6n9-aIpg9VfKShcjemVOqNvTm8lHMMGK8tIc,120747
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.5.5.dist-info/METADATA,sha256=Igay1acyeLzF_vDvB9BW7NWuAy_ck7G2rhITKre3Lew,90223
-x_transformers-2.5.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.5.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.5.5.dist-info/RECORD,,
+x_transformers-2.6.0.dist-info/METADATA,sha256=l4XLxjzRkTP3STtaauLDNMWD6vQEqubw56ZUXSn6ajQ,90223
+x_transformers-2.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.6.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.6.0.dist-info/RECORD,,