x-transformers 2.3.19__py3-none-any.whl → 2.3.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +48 -9
- {x_transformers-2.3.19.dist-info → x_transformers-2.3.20.dist-info}/METADATA +1 -1
- {x_transformers-2.3.19.dist-info → x_transformers-2.3.20.dist-info}/RECORD +5 -5
- {x_transformers-2.3.19.dist-info → x_transformers-2.3.20.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.19.dist-info → x_transformers-2.3.20.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -49,6 +49,7 @@ class LayerIntermediates:
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    cache_length: int = 0

 LinearNoBias = partial(nn.Linear, bias = False)

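The new cache_length field gives downstream callers a running count of how many positions the cached key / values already cover. A minimal sketch of that bookkeeping, using a hypothetical stand-in class rather than the real LayerIntermediates:

    # hypothetical stand-in for LayerIntermediates, only to illustrate the counter
    from dataclasses import dataclass

    @dataclass
    class FakeIntermediates:
        cache_length: int = 0   # positions already covered by cached key / values

    state = FakeIntermediates()

    for chunk_len in (16, 1, 1, 1):        # a 16 token prompt, then one token per decode step
        offset = state.cache_length        # where the new chunk's positions should start
        state.cache_length += chunk_len    # what the 2.3.20 forward pass now records

    assert state.cache_length == 19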
@@ -282,12 +283,18 @@ class AbsolutePositionalEmbedding(Module):
         self.l2norm_embed = l2norm_embed
         self.emb = nn.Embedding(max_seq_len, dim)

-    def forward(self, x, pos = None, seq_start_pos = None):
+    def forward(
+        self,
+        x,
+        pos = None,
+        seq_start_pos = None,
+        offset = 0
+    ):
         seq_len, device = x.shape[1], x.device
         assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

         if not exists(pos):
-            pos = arange(seq_len, device = device)
+            pos = arange(seq_len, device = device) + offset

         if exists(seq_start_pos):
             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -307,11 +314,17 @@ class ScaledSinusoidalEmbedding(Module):
         inv_freq = theta ** -freq_seq
         self.register_buffer('inv_freq', inv_freq, persistent = False)

-    def forward(self, x, pos = None, seq_start_pos = None):
+    def forward(
+        self,
+        x,
+        pos = None,
+        seq_start_pos = None,
+        offset = 0
+    ):
         seq_len, device = x.shape[1], x.device

         if not exists(pos):
-            pos = arange(seq_len, device = device)
+            pos = arange(seq_len, device = device) + offset

         if exists(seq_start_pos):
             pos = pos - seq_start_pos[..., None]
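Both AbsolutePositionalEmbedding and ScaledSinusoidalEmbedding now accept an offset, so a partial input served from a key / value cache can be embedded at its true positions rather than restarting at zero. A minimal sketch of the arithmetic (plain torch, not package code):

    import torch

    cache_len, new_len = 16, 1                       # 16 tokens already cached, 1 new token
    pos_old = torch.arange(new_len)                  # tensor([0])  - would reuse position 0
    pos_new = torch.arange(new_len) + cache_len      # tensor([16]) - continues the sequence
    assert pos_new.item() == 16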
@@ -676,7 +689,7 @@ class RotaryEmbedding(Module):
         return self.forward(t)

     @autocast('cuda', enabled = False)
-    def forward(self, t):
+    def forward(self, t, offset = 0):
         max_pos = t.max() + 1

         if t.ndim == 1:
@@ -2373,7 +2386,9 @@ class AttentionLayers(Module):
         mems = None,
         mem_masks = None,
         seq_start_pos: Tensor | None = None,
+        seq_pos_offset: int = 0,
         cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
@@ -2447,7 +2462,7 @@ class AttentionLayers(Module):
            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

            if not exists(pos):
-                pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len
+                pos = arange(x.shape[1] + mem_len + seq_pos_offset, device = x.device) - mem_len

            rotary_pos_emb = self.rotary_pos_emb(pos)

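With the cache excluded from the input, the arange used to build rotary positions is lengthened by seq_pos_offset, so the rotary frequencies still span the cached prefix as well as the new tokens. A small sketch of the position range this produces (hypothetical values, plain torch):

    import torch

    new_len, mem_len, seq_pos_offset = 1, 0, 16
    pos = torch.arange(new_len + mem_len + seq_pos_offset) - mem_len
    # pos runs 0..16: the 16 cached positions plus the single new token
    assert pos[-1].item() == 16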
@@ -2464,11 +2479,15 @@ class AttentionLayers(Module):

         # assume cached key / values

+        prev_cache_length = 0
+
         attn_cache = []

         if exists(cache):
             assert self.causal and not any([*map(exists, (mask, attn_mask))])

+            prev_cache_length = cache.cache_length
+
             if exists(context):
                 context = context[:, :0]

@@ -2482,6 +2501,8 @@ class AttentionLayers(Module):

             attn_cache = cache.attn_intermediates

+        next_cache_length = x.shape[1]
+
         iter_attn_cache = iter(attn_cache)

         # handle deep embeds if needed
@@ -2668,6 +2689,7 @@ class AttentionLayers(Module):
            last_hidden = x,
            attn_intermediates = intermediates,
            layer_hiddens = layer_hiddens,
+            cache_length = next_cache_length + prev_cache_length
        )

        return x, intermediates
@@ -3002,6 +3024,7 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         token_emb_kwargs = dict(),
         to_logits_kwargs = dict(),
         **kwargs,
@@ -3020,10 +3043,17 @@ class TransformerWrapper(Module):
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
         return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates

+        # take care of position embedding offsets in the presence of cache and sequence is less than cache length (not full sequence)
+
+        seq_pos_offset = 0
+
+        if exists(cache) and input_not_include_cache:
+            seq_pos_offset = cache.cache_length
+
         # absolute positional embedding

         external_pos_emb = exists(pos) and pos.dtype != torch.long
-        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
+        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos, offset = seq_pos_offset) if not external_pos_emb else pos
         x = self.token_emb(x, **token_emb_kwargs) + pos_emb

         # add additional embeddings
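The wrapper only applies an offset when a cache is present and the caller has declared that the input excludes the cached tokens; otherwise behaviour is unchanged. A minimal sketch of that guard with a hypothetical cache object:

    class FakeCache:              # stand-in for LayerIntermediates
        cache_length = 16

    def pick_offset(cache, input_not_include_cache):
        if cache is not None and input_not_include_cache:
            return cache.cache_length
        return 0

    assert pick_offset(None, True) == 0           # no cache, nothing to offset
    assert pick_offset(FakeCache(), False) == 0   # full sequence passed in, old behaviour
    assert pick_offset(FakeCache(), True) == 16   # partial input, offset by cached length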
@@ -3122,6 +3152,15 @@ class TransformerWrapper(Module):
            mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
            mems = [*mems_r, *mems_l]

+        # attn layers kwargs
+
+        kwargs = dict(
+            **kwargs,
+            seq_pos_offset = seq_pos_offset,
+            seq_start_pos = seq_start_pos,
+            input_not_include_cache = input_not_include_cache
+        )
+
         # attention layers

         if not self.recycling:
@@ -3129,7 +3168,7 @@ class TransformerWrapper(Module):

             # regular

-            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True,
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, **kwargs)

         else:
             # recycling
@@ -3146,7 +3185,7 @@ class TransformerWrapper(Module):
            with context():
                maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.

-                attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True,
+                attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, **kwargs)

         x = attended

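Taken together, these changes let incremental decoding feed only the newly sampled tokens while keeping positional embeddings consistent with the cached prefix. A minimal usage sketch, assuming the new flag composes with the existing cache / return_intermediates API the same way AutoregressiveWrapper already uses it (model sizes here are arbitrary):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
    )

    prompt = torch.randint(0, 256, (1, 16))

    # prime the cache with the full prompt
    logits, cache = model(prompt, return_intermediates = True)
    token = logits[:, -1].argmax(dim = -1, keepdim = True)

    for _ in range(8):
        # pass only the latest token; positions are offset by cache.cache_length
        logits, cache = model(
            token,
            cache = cache,
            input_not_include_cache = True,
            return_intermediates = True
        )
        token = logits[:, -1].argmax(dim = -1, keepdim = True)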
{x_transformers-2.3.19.dist-info → x_transformers-2.3.20.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=l2p-r0iJNlYHUB3vM4lb6ptzNCx9HgA7UfgieEcQT6w,115521
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.20.dist-info/METADATA,sha256=ygWyfnlIh2Mw6bd12gJjjZJyM9vfnXmvvOLyrd2El2k,89897
+x_transformers-2.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.20.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.20.dist-info/RECORD,,
{x_transformers-2.3.19.dist-info → x_transformers-2.3.20.dist-info}/WHEEL
File without changes

{x_transformers-2.3.19.dist-info → x_transformers-2.3.20.dist-info}/licenses/LICENSE
File without changes