x-transformers 2.3.19__py3-none-any.whl → 2.3.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -49,6 +49,7 @@ class LayerIntermediates:
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    cache_length: int = 0
 
 LinearNoBias = partial(nn.Linear, bias = False)
 
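The only change to LayerIntermediates is the new cache_length field, which lets a returned cache report how many positions it already covers. A minimal sketch, assuming the remaining fields keep their existing defaults so the dataclass can still be constructed with no arguments:

import torch
from x_transformers.x_transformers import LayerIntermediates

# New in 2.3.20: cache_length defaults to 0, so a freshly built (empty)
# intermediates object reports that no positions have been cached yet.
cache = LayerIntermediates()
assert cache.cache_length == 0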
@@ -282,12 +283,18 @@ class AbsolutePositionalEmbedding(Module):
         self.l2norm_embed = l2norm_embed
         self.emb = nn.Embedding(max_seq_len, dim)
 
-    def forward(self, x, pos = None, seq_start_pos = None):
+    def forward(
+        self,
+        x,
+        pos = None,
+        seq_start_pos = None,
+        offset = 0
+    ):
         seq_len, device = x.shape[1], x.device
         assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
 
         if not exists(pos):
-            pos = arange(seq_len, device = device)
+            pos = arange(seq_len, device = device) + offset
 
         if exists(seq_start_pos):
             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
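The effect of the new offset argument: when only the uncached suffix of a sequence is fed, its absolute positions must continue from the cached prefix rather than restart at zero. A toy illustration of that arithmetic (standalone tensors, not the library's code):

import torch

cached_len = 5                             # positions already covered by the KV cache
new_len = 3                                # tokens passed in this forward call

pos = torch.arange(new_len) + cached_len   # offset = cached_len
print(pos)                                 # tensor([5, 6, 7]) -- continues the prefix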
@@ -307,11 +314,17 @@ class ScaledSinusoidalEmbedding(Module):
         inv_freq = theta ** -freq_seq
         self.register_buffer('inv_freq', inv_freq, persistent = False)
 
-    def forward(self, x, pos = None, seq_start_pos = None):
+    def forward(
+        self,
+        x,
+        pos = None,
+        seq_start_pos = None,
+        offset = 0
+    ):
         seq_len, device = x.shape[1], x.device
 
         if not exists(pos):
-            pos = arange(seq_len, device = device)
+            pos = arange(seq_len, device = device) + offset
 
         if exists(seq_start_pos):
             pos = pos - seq_start_pos[..., None]
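Same idea for the sinusoidal variant: embedding the offset positions of the new tokens should reproduce the tail of a full-sequence pass. A standalone check of that invariant with a plain sin/cos construction (not the library's ScaledSinusoidalEmbedding, whose learned scale is omitted here):

import torch

dim, theta = 8, 10000
freq_seq = torch.arange(dim // 2).float() / (dim // 2)
inv_freq = theta ** -freq_seq

def sin_cos(pos):
    freqs = pos.float()[:, None] * inv_freq
    return torch.cat((freqs.sin(), freqs.cos()), dim = -1)

full = sin_cos(torch.arange(8))            # full 8-token pass
tail = sin_cos(torch.arange(3) + 5)        # 3 new tokens with offset = 5
assert torch.allclose(full[5:], tail)      # offset pass matches the full pass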
@@ -676,7 +689,7 @@ class RotaryEmbedding(Module):
         return self.forward(t)
 
     @autocast('cuda', enabled = False)
-    def forward(self, t):
+    def forward(self, t, offset = 0):
         max_pos = t.max() + 1
 
         if t.ndim == 1:
@@ -2373,7 +2386,9 @@ class AttentionLayers(Module):
         mems = None,
         mem_masks = None,
         seq_start_pos: Tensor | None = None,
+        seq_pos_offset: int = 0,
         cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
@@ -2447,7 +2462,7 @@ class AttentionLayers(Module):
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
             if not exists(pos):
-                pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len
+                pos = arange(x.shape[1] + mem_len + seq_pos_offset, device = x.device) - mem_len
 
             rotary_pos_emb = self.rotary_pos_emb(pos)
 
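What the extended arange produces: the position vector now spans the memory tokens (negative indices), the cached prefix, and the tokens in the current call, so rotary frequencies cover every key the new queries will attend to. A toy check with hypothetical sizes (not library code):

import torch

new_len, mem_len, seq_pos_offset = 1, 2, 8   # one new token, two memory tokens, eight cached

pos = torch.arange(new_len + mem_len + seq_pos_offset) - mem_len
print(pos)   # tensor([-2, -1,  0,  1,  2,  3,  4,  5,  6,  7,  8])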
@@ -2464,11 +2479,15 @@ class AttentionLayers(Module):
 
         # assume cached key / values
 
+        prev_cache_length = 0
+
         attn_cache = []
 
         if exists(cache):
             assert self.causal and not any([*map(exists, (mask, attn_mask))])
 
+            prev_cache_length = cache.cache_length
+
             if exists(context):
                 context = context[:, :0]
 
@@ -2482,6 +2501,8 @@ class AttentionLayers(Module):
 
             attn_cache = cache.attn_intermediates
 
+        next_cache_length = x.shape[1]
+
         iter_attn_cache = iter(attn_cache)
 
         # handle deep embeds if needed
@@ -2668,6 +2689,7 @@ class AttentionLayers(Module):
                 last_hidden = x,
                 attn_intermediates = intermediates,
                 layer_hiddens = layer_hiddens,
+                cache_length = next_cache_length + prev_cache_length
             )
 
         return x, intermediates
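The returned intermediates now carry a running total: each cached forward pass adds the length of its input (next_cache_length = x.shape[1]) to whatever the incoming cache already reported (prev_cache_length). The bookkeeping in isolation, with hypothetical step lengths:

# Mirrors the diff's accounting, outside the library: an 8-token prompt
# followed by three single-token decoding steps.
cache_length = 0

for step_len in (8, 1, 1, 1):
    prev_cache_length = cache_length
    next_cache_length = step_len        # x.shape[1] for that call
    cache_length = next_cache_length + prev_cache_length

print(cache_length)                     # 11 positions covered by the cache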
@@ -3002,6 +3024,7 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         token_emb_kwargs = dict(),
         to_logits_kwargs = dict(),
         **kwargs,
@@ -3020,10 +3043,17 @@ class TransformerWrapper(Module):
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
         return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
 
+        # take care of position embedding offsets in the presence of cache and sequence is less than cache length (not full sequence)
+
+        seq_pos_offset = 0
+
+        if exists(cache) and input_not_include_cache:
+            seq_pos_offset = cache.cache_length
+
         # absolute positional embedding
 
         external_pos_emb = exists(pos) and pos.dtype != torch.long
-        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
+        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos, offset = seq_pos_offset) if not external_pos_emb else pos
         x = self.token_emb(x, **token_emb_kwargs) + pos_emb
 
         # add additional embeddings
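Putting it together at the TransformerWrapper level: when a cache is supplied and input_not_include_cache = True, positional embeddings for the newly fed tokens are offset by cache.cache_length. A minimal decoding sketch, assuming the usual TransformerWrapper/Decoder constructor and the existing return_intermediates = True convention for getting the cache back (model sizes are arbitrary):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

prompt = torch.randint(0, 256, (1, 8))
logits, cache = model(prompt, return_intermediates = True)   # cache.cache_length == 8

# feed only the newly sampled token; its position is offset by cache.cache_length
next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
logits, cache = model(
    next_token,
    cache = cache,
    input_not_include_cache = True,
    return_intermediates = True
)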
@@ -3122,6 +3152,15 @@ class TransformerWrapper(Module):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]
 
+        # attn layers kwargs
+
+        kwargs = dict(
+            **kwargs,
+            seq_pos_offset = seq_pos_offset,
+            seq_start_pos = seq_start_pos,
+            input_not_include_cache = input_not_include_cache
+        )
+
         # attention layers
 
         if not self.recycling:
@@ -3129,7 +3168,7 @@ class TransformerWrapper(Module):
 
             # regular
 
-            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, **kwargs)
 
         else:
             # recycling
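Why the explicit seq_start_pos = seq_start_pos disappears from the call sites above and below: it is now folded into kwargs, and passing it both explicitly and through kwargs would raise a duplicate-keyword error. A quick standalone demonstration:

def forward_like(**kwargs):
    return kwargs

merged = dict(seq_start_pos = None, seq_pos_offset = 4)

try:
    forward_like(seq_start_pos = None, **merged)    # keyword also present in merged
except TypeError as err:
    print(err)   # got multiple values for keyword argument 'seq_start_pos'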
@@ -3146,7 +3185,7 @@ class TransformerWrapper(Module):
                 with context():
                     maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.
 
-                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, **kwargs)
 
         x = attended
 
x_transformers-2.3.20.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.19
+Version: 2.3.20
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.3.20.dist-info/RECORD

@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=s398YQ9JtXc5n34g9qaYnUqaTVLGfRvz0GLg3sEMHLI,114558
+x_transformers/x_transformers.py,sha256=l2p-r0iJNlYHUB3vM4lb6ptzNCx9HgA7UfgieEcQT6w,115521
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.19.dist-info/METADATA,sha256=Vn-U7mDaP7H-w-RF5YO3C5n9M5PvnDVKqFJwL3vFV0s,89897
-x_transformers-2.3.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.19.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.19.dist-info/RECORD,,
+x_transformers-2.3.20.dist-info/METADATA,sha256=ygWyfnlIh2Mw6bd12gJjjZJyM9vfnXmvvOLyrd2El2k,89897
+x_transformers-2.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.20.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.20.dist-info/RECORD,,