x-transformers 2.3.1__tar.gz → 2.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {x_transformers-2.3.1 → x_transformers-2.3.3}/PKG-INFO +1 -1
  2. {x_transformers-2.3.1 → x_transformers-2.3.3}/pyproject.toml +1 -1
  3. {x_transformers-2.3.1 → x_transformers-2.3.3}/tests/test_x_transformers.py +4 -2
  4. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/continuous.py +29 -2
  5. {x_transformers-2.3.1 → x_transformers-2.3.3}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.3.1 → x_transformers-2.3.3}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.3.1 → x_transformers-2.3.3}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.3.1 → x_transformers-2.3.3}/.gitignore +0 -0
  9. {x_transformers-2.3.1 → x_transformers-2.3.3}/LICENSE +0 -0
  10. {x_transformers-2.3.1 → x_transformers-2.3.3}/README.md +0 -0
  11. {x_transformers-2.3.1 → x_transformers-2.3.3}/data/README.md +0 -0
  12. {x_transformers-2.3.1 → x_transformers-2.3.3}/data/enwik8.gz +0 -0
  13. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/all-attention.png +0 -0
  14. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/deepnorm.png +0 -0
  17. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/fcm.png +0 -0
  23. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/ffglu.png +0 -0
  24. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/flash-attention.png +0 -0
  25. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/gate_values.png +0 -0
  26. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/gating.png +0 -0
  27. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/macaron-1.png +0 -0
  29. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/macaron-2.png +0 -0
  30. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/normformer.png +0 -0
  32. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/pia.png +0 -0
  33. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/resi_dual.png +0 -0
  35. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/residual_attn.png +0 -0
  36. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/rezero.png +0 -0
  37. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/rotary.png +0 -0
  38. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/sandwich.png +0 -0
  40. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/scalenorm.png +0 -0
  42. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/talking-heads.png +0 -0
  43. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/topk-attention.png +0 -0
  44. {x_transformers-2.3.1 → x_transformers-2.3.3}/images/xval.png +0 -0
  45. {x_transformers-2.3.1 → x_transformers-2.3.3}/train_belief_state.py +0 -0
  46. {x_transformers-2.3.1 → x_transformers-2.3.3}/train_copy.py +0 -0
  47. {x_transformers-2.3.1 → x_transformers-2.3.3}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.3.1 → x_transformers-2.3.3}/train_enwik8.py +0 -0
  49. {x_transformers-2.3.1 → x_transformers-2.3.3}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.3.1 → x_transformers-2.3.3}/train_parity.py +0 -0
  51. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/entropy_based_tokenizer.py +0 -0
  57. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/x_transformers.py +0 -0
  61. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  62. {x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/xval.py +0 -0
{x_transformers-2.3.1 → x_transformers-2.3.3}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.1
+Version: 2.3.3
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.3.1 → x_transformers-2.3.3}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.3.1"
+version = "2.3.3"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.3.1 → x_transformers-2.3.3}/tests/test_x_transformers.py
@@ -849,8 +849,10 @@ def test_custom_ff_activation():
     assert logits.shape == (2, 1024, 20000)
 
 @pytest.mark.parametrize('probabilistic', (False, True))
+@pytest.mark.parametrize('cache_kv', (False, True))
 def test_continuous(
-    probabilistic
+    probabilistic,
+    cache_kv
 ):
     from x_transformers import (
         ContinuousTransformerWrapper,
@@ -887,5 +889,5 @@ def test_continuous(
     # then generate
 
     start_emb = torch.randn(1, 777)
-    generated = model.generate(start_emb, 17) # (17, 777)
+    generated = model.generate(start_emb, 17, cache_kv = cache_kv) # (17, 777)
     assert generated.shape == (17, 777)
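The parametrized test now exercises generation both with and without the new kv cache. A minimal standalone sketch of the same usage, assuming illustrative constructor arguments (not copied from the test file):

import torch
from x_transformers import (
    ContinuousTransformerWrapper,
    ContinuousAutoregressiveWrapper,
    Decoder,
)

net = ContinuousTransformerWrapper(
    dim_in = 777,
    dim_out = 777,
    max_seq_len = 256,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

model = ContinuousAutoregressiveWrapper(net)

start_emb = torch.randn(1, 777)

# with kv caching (the new default) and without - both should produce (seq_len, dim_out)
fast = model.generate(start_emb, 17, cache_kv = True)   # (17, 777)
slow = model.generate(start_emb, 17, cache_kv = False)  # (17, 777)
assert fast.shape == slow.shape == (17, 777)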
{x_transformers-2.3.1 → x_transformers-2.3.3}/x_transformers/continuous.py
@@ -10,6 +10,7 @@ import einx
 from einops import rearrange, reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
+    Attention,
     AttentionLayers,
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
@@ -111,6 +112,10 @@ class ContinuousTransformerWrapper(Module):
 
         self.project_out = nn.Linear(dim, dim_out * (2 if probabilistic else 1), bias = False) if exists(dim_out) else nn.Identity()
 
+        # can cache kv
+
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
+
     def forward(
         self,
         x,
@@ -118,6 +123,7 @@ class ContinuousTransformerWrapper(Module):
         return_intermediates = False,
         return_mems = False,
         mask = None,
+        lens = None,
         return_attn = False,
         mems = None,
         mem_masks = None,
@@ -128,6 +134,16 @@ class ContinuousTransformerWrapper(Module):
     ):
         batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
 
+        # maybe seq lengths passed in
+
+        if exists(lens):
+            assert not exists(mask), 'either `mask` or `lens` passed in, but not both'
+            seq_arange = torch.arange(seq, device = device)
+
+            mask = einx.less('j, i -> i j', seq_arange, lens)
+
+        # project in + positional embedding
+
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)
 
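The added `lens` argument builds a boolean padding mask from per-sample sequence lengths via `einx.less`. A small sketch showing the construction and its plain-torch equivalent:

import torch
import einx

seq = 6
lens = torch.tensor([3, 6, 1])          # valid length of each batch element
seq_arange = torch.arange(seq)

mask_einx  = einx.less('j, i -> i j', seq_arange, lens)   # (batch, seq) bool
mask_torch = seq_arange[None, :] < lens[:, None]          # same thing in plain torch

assert bool((mask_einx == mask_torch).all())
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True,  True],
#         [ True, False, False, False, False, False]])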
@@ -180,7 +196,7 @@ class ContinuousTransformerWrapper(Module):
         if not return_embeddings and self.probabilistic:
             mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2)
             variance = log_var.exp()
-            return stack((mean, variance))
+            out = stack((mean, variance))
 
         if return_intermediates:
             return out, intermediates
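Assigning to `out` instead of returning early lets the stacked (mean, variance) prediction fall through to the `return_intermediates` / `return_mems` branches below. A hedged sketch of what a caller can now do; constructor arguments and shapes here are illustrative assumptions:

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

net = ContinuousTransformerWrapper(
    dim_in = 16,
    dim_out = 16,
    max_seq_len = 32,
    probabilistic = True,
    attn_layers = Decoder(dim = 32, depth = 1, heads = 2)
)

x = torch.randn(2, 10, 16)

# probabilistic output and intermediates in the same call
out, intermediates = net(x, return_intermediates = True)

mean, variance = out    # out stacks the (mean, variance) pair along dim 0
assert mean.shape == variance.shape == (2, 10, 16)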
@@ -223,9 +239,12 @@ class ContinuousAutoregressiveWrapper(Module):
         start_tokens,
         seq_len,
         temperature = 1.,
+        cache_kv = True,
         **kwargs
     ):
+        should_cache_kv = cache_kv and self.net.can_cache_kv
         device = start_tokens.device
+
         was_training = self.net.training
         num_dims = len(start_tokens.shape)
 
@@ -239,10 +258,14 @@ class ContinuousAutoregressiveWrapper(Module):
         self.net.eval()
         out = start_tokens
 
+        cache = None
+
         for _ in range(seq_len):
             x = out[:, -self.max_seq_len:]
 
-            last_output = self.net(x, **kwargs)[..., -1:, :]
+            net_out, new_cache = self.net(x, cache = cache, return_intermediates = True, **kwargs)
+
+            last_output = net_out[..., -1:, :]
 
             if self.probabilistic:
                 mean, var = last_output
@@ -250,6 +273,9 @@ class ContinuousAutoregressiveWrapper(Module):
 
             out = cat((out, last_output), dim = -2)
 
+            if should_cache_kv:
+                cache = new_cache
+
         out = out[:, t:]
 
         if num_dims == 2:
@@ -268,6 +294,7 @@ class ContinuousAutoregressiveWrapper(Module):
         assert 'prepend_embeds' not in kwargs
 
         mask = kwargs.get('mask', None)
+
         if exists(mask) and mask.shape[1] == x.shape[1]:
             mask = mask[:, :-1]
             kwargs['mask'] = mask