x-transformers 2.11.5__tar.gz → 2.11.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (68)
  1. {x_transformers-2.11.5 → x_transformers-2.11.8}/PKG-INFO +1 -1
  2. {x_transformers-2.11.5 → x_transformers-2.11.8}/pyproject.toml +1 -1
  3. {x_transformers-2.11.5 → x_transformers-2.11.8}/tests/test_x_transformers.py +17 -0
  4. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/free_transformer.py +28 -21
  5. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/x_transformers.py +29 -3
  6. {x_transformers-2.11.5 → x_transformers-2.11.8}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.11.5 → x_transformers-2.11.8}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.11.5 → x_transformers-2.11.8}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.11.5 → x_transformers-2.11.8}/.gitignore +0 -0
  10. {x_transformers-2.11.5 → x_transformers-2.11.8}/LICENSE +0 -0
  11. {x_transformers-2.11.5 → x_transformers-2.11.8}/README.md +0 -0
  12. {x_transformers-2.11.5 → x_transformers-2.11.8}/data/README.md +0 -0
  13. {x_transformers-2.11.5 → x_transformers-2.11.8}/data/enwik8.gz +0 -0
  14. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/all-attention.png +0 -0
  15. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/deepnorm.png +0 -0
  18. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/fcm.png +0 -0
  24. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/ffglu.png +0 -0
  25. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/flash-attention.png +0 -0
  26. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/gate_values.png +0 -0
  27. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/gating.png +0 -0
  28. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/macaron-1.png +0 -0
  30. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/macaron-2.png +0 -0
  31. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/normformer.png +0 -0
  33. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/pia.png +0 -0
  34. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/resi_dual.png +0 -0
  36. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/residual_attn.png +0 -0
  37. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/rezero.png +0 -0
  38. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/rotary.png +0 -0
  39. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/sandwich.png +0 -0
  41. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/scalenorm.png +0 -0
  43. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/talking-heads.png +0 -0
  44. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/topk-attention.png +0 -0
  45. {x_transformers-2.11.5 → x_transformers-2.11.8}/images/xval.png +0 -0
  46. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_belief_state.py +0 -0
  47. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_copy.py +0 -0
  48. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_enwik8.py +0 -0
  50. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_free.py +0 -0
  51. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_gpt_vae.py +0 -0
  52. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_length_extrapolate.py +0 -0
  53. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_parity.py +0 -0
  54. {x_transformers-2.11.5 → x_transformers-2.11.8}/train_with_muon.py +0 -0
  55. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/__init__.py +0 -0
  56. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/attend.py +0 -0
  57. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/autoregressive_wrapper.py +0 -0
  58. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/belief_state_wrapper.py +0 -0
  59. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/continuous.py +0 -0
  60. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/dpo.py +0 -0
  61. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/entropy_based_tokenizer.py +0 -0
  62. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/gpt_vae.py +0 -0
  63. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/multi_input.py +0 -0
  64. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/neo_mlp.py +0 -0
  65. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/nonautoregressive_wrapper.py +0 -0
  66. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/up_wrapper.py +0 -0
  67. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/xval.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.5
+Version: 2.11.8
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.5"
+version = "2.11.8"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1434,3 +1434,20 @@ def test_free(
     loss.backward()
 
     assert aux_loss.numel() == 1
+
+def test_kv_input_residual():
+    attn = Decoder(
+        dim = 256,
+        depth = 2,
+        heads = 4,
+        cross_attend = True
+    )
+
+    tokens = torch.randn(3, 32, 256)
+    context = torch.randn(3, 64, 256)
+
+    condition = torch.randn(2, 3, 64, 256)
+
+    out = attn(tokens, context = context, cross_attn_kv_residuals = condition)
+
+    assert tokens.shape == out.shape
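The new forward arguments also cover self attention. Below is a minimal sketch (not part of the package's test suite), assuming self_attn_kv_residuals passes through a plain Decoder the same way cross_attn_kv_residuals does in the test above:

import torch
from x_transformers import Decoder

attn = Decoder(
    dim = 256,
    depth = 2,
    heads = 4
)

tokens = torch.randn(3, 32, 256)

# one residual per self attention layer (leading dim matches depth = 2)
residuals = torch.randn(2, 3, 32, 256)

out = attn(tokens, self_attn_kv_residuals = residuals)
assert out.shape == tokens.shape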
@@ -66,7 +66,6 @@ class BinaryMapper(Module):
 
         self.bits = bits
         self.num_codes = 2 ** bits
-        self.kl_loss_threshold = kl_loss_threshold
 
         power_two = 2 ** arange(bits)
         codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
@@ -74,13 +73,20 @@ class BinaryMapper(Module):
         self.register_buffer('power_two', power_two, persistent = False)
         self.register_buffer('codes', codes, persistent = False)
 
+        # aux loss
+
+        self.kl_loss_threshold = kl_loss_threshold
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def forward(
         self,
         logits,
         temperature = 1.,
-        straight_through = None
+        straight_through = None,
+        calc_aux_loss = None
     ):
         straight_through = default(straight_through, self.training)
+        calc_aux_loss = default(calc_aux_loss, self.training)
 
         assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
 
@@ -95,26 +101,29 @@ class BinaryMapper(Module):
 
         one_hot = F.one_hot(indices, self.num_codes).float()
 
-        # return hard one hot if not training or overridden
+        # maybe calculate aux loss
 
-        if not straight_through:
-            return one_hot
+        aux_kl_loss = self.zero
 
-        # calculate negative entropy
+        if calc_aux_loss:
+            # calculate negative entropy
 
-        kl_div = self.bits * NAT - binary_entropy(logits)
-        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+            kl_div = self.bits * NAT - binary_entropy(logits)
+            aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
 
-        # get the soft G for the gradients and do a straight through
+        # maybe straight through
 
-        soft_G = (
-            einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
-            einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
-        ).exp()
+        if straight_through:
+            # get the soft G for the gradients and do a straight through
 
-        # straight through
+            soft_G = (
+                einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
+                einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
+            ).exp()
 
-        one_hot = one_hot + soft_G - soft_G.detach()
+            # straight through
+
+            one_hot = one_hot + soft_G - soft_G.detach()
 
         return one_hot, aux_kl_loss
 
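With the early return removed, BinaryMapper.forward now always returns a (one_hot, aux_kl_loss) pair, and both straight_through and calc_aux_loss fall back to self.training. A minimal usage sketch, assuming BinaryMapper is importable from x_transformers.free_transformer and that kl_loss_threshold has a default value:

import torch
from x_transformers.free_transformer import BinaryMapper

mapper = BinaryMapper(bits = 4)           # assumption: kl_loss_threshold defaults if omitted

logits = torch.randn(2, 16, 4)            # last dim must equal bits

mapper.train()
one_hot, aux_kl_loss = mapper(logits)     # straight through estimator + KL aux loss

mapper.eval()
one_hot, aux_kl_loss = mapper(logits)     # hard one hot, aux_kl_loss is the registered zero buffer

# either behavior can be forced explicitly
one_hot, aux_kl_loss = mapper(logits, straight_through = True, calc_aux_loss = False)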
@@ -163,6 +172,7 @@ class FreeTransformer(Module):
             cross_attend = True,
             use_rmsnorm = True,
             rotary_pos_emb = True,
+            pre_norm_has_final_norm = True,
             **kwargs,
             **enc_kwargs
         )
@@ -242,7 +252,7 @@ class FreeTransformer(Module):
 
         bit_logits = self.to_latent_bit_logits(pooled)
 
-        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, straight_through = True)
+        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, calc_aux_loss = return_kl_loss)
 
         if not return_kl_loss:
             return one_hot_latents
@@ -286,10 +296,7 @@ class FreeTransformer(Module):
 
         head_embed = self.decoder_head(tokens)
 
-        if exists(condition):
-            head_embed = head_embed + condition
-
-        tail_embed = self.decoder_tail(head_embed)
+        tail_embed = self.decoder_tail(head_embed, self_attn_kv_residuals = condition)
 
         tail_embed = tail_embed[:, -1]
 
@@ -329,7 +336,7 @@ class FreeTransformer(Module):
 
         # decoder tail
 
-        tokens = self.decoder_tail(tokens + condition)
+        tokens = self.decoder_tail(tokens, self_attn_kv_residuals = condition)
 
         # cross entropy loss
 
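The two decoder_tail hunks above change how the latent condition reaches the decoder: rather than being added once to the hidden stream ahead of the tail, it is now passed as self_attn_kv_residuals, so it is injected into the key/value inputs of the tail's self attention layers (see the Attention and AttentionLayers hunks below).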
@@ -1686,6 +1686,7 @@ class Attention(Module):
         value_residual = None,
         additional_key_values: tuple[Tensor, Tensor] | None = None,
         additional_key_value_mask = None,
+        kv_input_residual = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
@@ -1702,6 +1703,12 @@ class Attention(Module):
         kv_input = default(context, x)
         q_input, k_input, v_input = x, kv_input, kv_input
 
+        # done for free transformer
+
+        if exists(kv_input_residual):
+            k_input = k_input + kv_input_residual
+            v_input = v_input + kv_input_residual
+
         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
             v_input, _ = pack([mem, v_input], 'b * d')
@@ -2543,7 +2550,9 @@ class AttentionLayers(Module):
         route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
-        layers_execute_order: tuple[int, ...] | None = None
+        layers_execute_order: tuple[int, ...] | None = None,
+        self_attn_kv_residuals: Tensor | None = None,
+        cross_attn_kv_residuals: Tensor | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -2721,6 +2730,23 @@ class AttentionLayers(Module):
 
         skip_hiddens = []
 
+        # for residuals to key value inputs for self and cross attention
+
+        self_attn_kv_residuals_iter = iter((None,))
+        cross_attn_kv_residuals_iter = iter((None,))
+
+        if exists(self_attn_kv_residuals):
+            if self_attn_kv_residuals.ndim == 3:
+                self_attn_kv_residuals = rearrange(self_attn_kv_residuals, '... -> 1 ...')
+
+            self_attn_kv_residuals_iter = iter(self_attn_kv_residuals)
+
+        if exists(cross_attn_kv_residuals):
+            if cross_attn_kv_residuals.ndim == 3:
+                cross_attn_kv_residuals = rearrange(cross_attn_kv_residuals, '... -> 1 ...')
+
+            cross_attn_kv_residuals_iter = iter(cross_attn_kv_residuals)
+
         # for value residuals
 
         first_self_attn_inter = None
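The iterators above define the shape convention for the new arguments: a tensor with a leading layer dimension supplies one residual per attention layer in order (remaining layers receive None), while a 3-dimensional tensor gets a singleton leading dimension and is therefore consumed by the first matching attention layer only. A minimal sketch of both forms, reusing the Decoder setup from the new test:

import torch
from x_transformers import Decoder

attn = Decoder(dim = 256, depth = 2, heads = 4, cross_attend = True)

tokens = torch.randn(3, 32, 256)
context = torch.randn(3, 64, 256)

# per-layer residuals: leading dim of 2 matches the two cross attention layers at depth = 2
per_layer = torch.randn(2, 3, 64, 256)
out = attn(tokens, context = context, cross_attn_kv_residuals = per_layer)

# 3-dim residual: only the first cross attention layer receives it
single = torch.randn(3, 64, 256)
out = attn(tokens, context = context, cross_attn_kv_residuals = single)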
@@ -2794,9 +2820,9 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x, deep_embed = next(deep_embeds_iter, None))
 