x-transformers 2.11.5__tar.gz → 2.11.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (68)
  1. {x_transformers-2.11.5 → x_transformers-2.11.7}/PKG-INFO +1 -1
  2. {x_transformers-2.11.5 → x_transformers-2.11.7}/pyproject.toml +1 -1
  3. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/free_transformer.py +26 -16
  4. {x_transformers-2.11.5 → x_transformers-2.11.7}/.github/FUNDING.yml +0 -0
  5. {x_transformers-2.11.5 → x_transformers-2.11.7}/.github/workflows/python-publish.yml +0 -0
  6. {x_transformers-2.11.5 → x_transformers-2.11.7}/.github/workflows/python-test.yaml +0 -0
  7. {x_transformers-2.11.5 → x_transformers-2.11.7}/.gitignore +0 -0
  8. {x_transformers-2.11.5 → x_transformers-2.11.7}/LICENSE +0 -0
  9. {x_transformers-2.11.5 → x_transformers-2.11.7}/README.md +0 -0
  10. {x_transformers-2.11.5 → x_transformers-2.11.7}/data/README.md +0 -0
  11. {x_transformers-2.11.5 → x_transformers-2.11.7}/data/enwik8.gz +0 -0
  12. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/all-attention.png +0 -0
  13. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/attention-on-attention.png +0 -0
  14. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/cosine-sim-attention.png +0 -0
  15. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/deepnorm.png +0 -0
  16. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/dynamic-pos-bias-linear.png +0 -0
  17. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/dynamic-pos-bias-log.png +0 -0
  18. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  19. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/dynamic-pos-bias.png +0 -0
  20. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/enhanced-recurrence.png +0 -0
  21. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/fcm.png +0 -0
  22. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/ffglu.png +0 -0
  23. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/flash-attention.png +0 -0
  24. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/gate_values.png +0 -0
  25. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/gating.png +0 -0
  26. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/length-extrapolation-scale.png +0 -0
  27. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/macaron-1.png +0 -0
  28. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/macaron-2.png +0 -0
  29. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/memory-transformer.png +0 -0
  30. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/normformer.png +0 -0
  31. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/pia.png +0 -0
  32. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/qknorm-analysis.png +0 -0
  33. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/resi_dual.png +0 -0
  34. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/residual_attn.png +0 -0
  35. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/rezero.png +0 -0
  36. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/rotary.png +0 -0
  37. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/sandwich-2.png +0 -0
  38. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/sandwich.png +0 -0
  39. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/sandwich_norm.png +0 -0
  40. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/scalenorm.png +0 -0
  41. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/talking-heads.png +0 -0
  42. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/topk-attention.png +0 -0
  43. {x_transformers-2.11.5 → x_transformers-2.11.7}/images/xval.png +0 -0
  44. {x_transformers-2.11.5 → x_transformers-2.11.7}/tests/test_x_transformers.py +0 -0
  45. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_belief_state.py +0 -0
  46. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_copy.py +0 -0
  47. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_enwik8.py +0 -0
  49. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_free.py +0 -0
  50. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_parity.py +0 -0
  53. {x_transformers-2.11.5 → x_transformers-2.11.7}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/gpt_vae.py +0 -0
  62. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/multi_input.py +0 -0
  63. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/neo_mlp.py +0 -0
  64. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
  65. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/up_wrapper.py +0 -0
  66. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/x_transformers.py +0 -0
  67. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.5
+Version: 2.11.7
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.5"
+version = "2.11.7"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
x_transformers/free_transformer.py
@@ -66,7 +66,6 @@ class BinaryMapper(Module):
 
         self.bits = bits
         self.num_codes = 2 ** bits
-        self.kl_loss_threshold = kl_loss_threshold
 
         power_two = 2 ** arange(bits)
         codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
@@ -74,13 +73,20 @@ class BinaryMapper(Module):
         self.register_buffer('power_two', power_two, persistent = False)
         self.register_buffer('codes', codes, persistent = False)
 
+        # aux loss
+
+        self.kl_loss_threshold = kl_loss_threshold
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def forward(
         self,
        logits,
        temperature = 1.,
-        straight_through = None
+        straight_through = None,
+        calc_aux_loss = None
     ):
         straight_through = default(straight_through, self.training)
+        calc_aux_loss = default(calc_aux_loss, self.training)
 
         assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
 
@@ -95,26 +101,29 @@ class BinaryMapper(Module):
 
         one_hot = F.one_hot(indices, self.num_codes).float()
 
-        # return hard one hot if not training or overridden
+        # maybe calculate aux loss
+
+        aux_kl_loss = self.zero
 
-        if not straight_through:
-            return one_hot
+        if calc_aux_loss:
+            # calculate negative entropy
 
-        # calculate negative entropy
+            kl_div = self.bits * NAT - binary_entropy(logits)
+            aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
 
-        kl_div = self.bits * NAT - binary_entropy(logits)
-        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+        # maybe straight through
 
-        # get the soft G for the gradients and do a straight through
+        if straight_through:
+            # get the soft G for the gradients and do a straight through
 
-        soft_G = (
-            einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
-            einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
-        ).exp()
+            soft_G = (
+                einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
+                einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
+            ).exp()
 
-        # straight through
+            # straight through
 
-        one_hot = one_hot + soft_G - soft_G.detach()
+            one_hot = one_hot + soft_G - soft_G.detach()
 
         return one_hot, aux_kl_loss
 
@@ -163,6 +172,7 @@ class FreeTransformer(Module):
             cross_attend = True,
             use_rmsnorm = True,
             rotary_pos_emb = True,
+            pre_norm_has_final_norm = True,
             **kwargs,
             **enc_kwargs
         )
@@ -242,7 +252,7 @@ class FreeTransformer(Module):
 
         bit_logits = self.to_latent_bit_logits(pooled)
 
-        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, straight_through = True)
+        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, calc_aux_loss = return_kl_loss)
 
         if not return_kl_loss:
             return one_hot_latents
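A hedged usage sketch of the caller-side change just above: the constructor argument names are inferred only from the attributes visible in these hunks and may not match the actual signature; the point is that the aux KL term is now computed only when the caller will return it.

import torch
from x_transformers.free_transformer import BinaryMapper  # module path taken from this diff

mapper = BinaryMapper(bits = 4, kl_loss_threshold = 0.1)   # constructor args are an assumption
bit_logits = torch.randn(2, 4)

return_kl_loss = False
one_hot_latents, kl_loss = mapper(bit_logits, calc_aux_loss = return_kl_loss)

if not return_kl_loss:
    latents = one_hot_latents   # kl_loss here is just the registered zero buffer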