x-transformers 2.3.18__py3-none-any.whl → 2.3.20__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- x_transformers/attend.py +1 -0
- x_transformers/continuous.py +11 -5
- x_transformers/x_transformers.py +64 -12
- {x_transformers-2.3.18.dist-info → x_transformers-2.3.20.dist-info}/METADATA +1 -1
- {x_transformers-2.3.18.dist-info → x_transformers-2.3.20.dist-info}/RECORD +7 -7
- {x_transformers-2.3.18.dist-info → x_transformers-2.3.20.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.18.dist-info → x_transformers-2.3.20.dist-info}/licenses/LICENSE +0 -0
x_transformers/attend.py
CHANGED
@@ -25,6 +25,7 @@ class Intermediates:
     values: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
+    hybrid_hidden: Tensor | None = None

     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
x_transformers/continuous.py
CHANGED
@@ -32,6 +32,15 @@ def default(val, d):
         return val
     return d() if not isinstance(d, Module) and callable(d) else d

+def sample_from_mean_variance(
+    mean,
+    variance,
+    eps = 1e-5,
+    temperature = 1.
+):
+    std = variance.clamp(min = eps).sqrt()
+    return torch.normal(mean, std * temperature)
+
 def masked_mean(t, mask):
     t = einx.where('b n, b n d, -> b n d', mask, t, 0.)

@@ -274,9 +283,7 @@ class ContinuousAutoregressiveWrapper(Module):

             if self.probabilistic:
                 mean, var = last_output
-
-
-                last_output = torch.normal(mean, stddev * temperature)
+                last_output = sample_from_mean_variance(mean, var, temperature = temperature)

             out = cat((out, last_output), dim = -2)

@@ -372,8 +379,7 @@ class ContinuousAutoregressiveWrapper(Module):

             if self.probabilistic:
                 mean, var = last_pred
-
-                inp = torch.normal(mean, std)
+                inp = sample_from_mean_variance(mean, var)
             else:
                 inp = last_pred

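The new `sample_from_mean_variance` helper above centralizes the clamp-then-sample step used by the probabilistic `ContinuousAutoregressiveWrapper`. A minimal standalone sketch of the same logic (the tensor shapes are illustrative, not taken from the package):

```python
import torch

# standalone re-implementation of the helper shown in the diff, for illustration;
# the variance is clamped so the standard deviation never collapses to zero
def sample_from_mean_variance(mean, variance, eps = 1e-5, temperature = 1.):
    std = variance.clamp(min = eps).sqrt()
    return torch.normal(mean, std * temperature)

mean = torch.zeros(2, 4, 8)   # (batch, seq, dim) predicted means
var  = torch.rand(2, 4, 8)    # predicted variances, non-negative

sample = sample_from_mean_variance(mean, var, temperature = 0.8)
assert sample.shape == mean.shape
```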
x_transformers/x_transformers.py
CHANGED
@@ -49,6 +49,7 @@ class LayerIntermediates:
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    cache_length: int = 0

 LinearNoBias = partial(nn.Linear, bias = False)

@@ -282,12 +283,18 @@ class AbsolutePositionalEmbedding(Module):
         self.l2norm_embed = l2norm_embed
         self.emb = nn.Embedding(max_seq_len, dim)

-    def forward(self, x, pos = None, seq_start_pos = None):
+    def forward(
+        self,
+        x,
+        pos = None,
+        seq_start_pos = None,
+        offset = 0
+    ):
         seq_len, device = x.shape[1], x.device
         assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

         if not exists(pos):
-            pos = arange(seq_len, device = device)
+            pos = arange(seq_len, device = device) + offset

         if exists(seq_start_pos):
             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -307,11 +314,17 @@ class ScaledSinusoidalEmbedding(Module):
         inv_freq = theta ** -freq_seq
         self.register_buffer('inv_freq', inv_freq, persistent = False)

-    def forward(self, x, pos = None, seq_start_pos = None):
+    def forward(
+        self,
+        x,
+        pos = None,
+        seq_start_pos = None,
+        offset = 0
+    ):
         seq_len, device = x.shape[1], x.device

         if not exists(pos):
-            pos = arange(seq_len, device = device)
+            pos = arange(seq_len, device = device) + offset

         if exists(seq_start_pos):
             pos = pos - seq_start_pos[..., None]
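The `offset` argument added to both positional embedding classes above lets a caller embed only the un-cached suffix of a sequence while keeping its positions aligned with the cached prefix. A toy illustration of the indexing (the concrete lengths are made up):

```python
import torch

cache_length = 10   # tokens already processed and sitting in the cache (hypothetical)
new_tokens   = 3    # only this suffix is passed on the current forward call

# without an offset the suffix would be embedded at positions 0..2
pos_plain  = torch.arange(new_tokens)                  # tensor([0, 1, 2])

# with offset = cache_length it lands at its true positions 10..12
pos_offset = torch.arange(new_tokens) + cache_length   # tensor([10, 11, 12])

print(pos_plain, pos_offset)
```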
@@ -676,7 +689,7 @@ class RotaryEmbedding(Module):
         return self.forward(t)

     @autocast('cuda', enabled = False)
-    def forward(self, t):
+    def forward(self, t, offset = 0):
         max_pos = t.max() + 1

         if t.ndim == 1:
@@ -1079,10 +1092,11 @@ class FoldAxially(Module):
     def forward(
         self,
         x,
+        *args,
         **kwargs
     ):
         if self.axial_dim == 1:
-            return self.fn(x, **kwargs)
+            return self.fn(x, *args, **kwargs)

         seq_len, axial_dim = x.shape[1], self.axial_dim

@@ -1091,7 +1105,7 @@ class FoldAxially(Module):

         x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)

-        out = self.fn(x, **kwargs)
+        out = self.fn(x, *args, **kwargs)

         (out, *rest_out), tree_spec = tree_flatten(out)

@@ -1857,9 +1871,17 @@ class Attention(Module):
             if not self.causal and exists(self.hybrid_mask_kwarg):
                 hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}

+            # handle maybe hybrid cache
+
+            hybrid_forward_args = ()
+
+            if exists(cache) and exists(cache.hybrid_hidden):
+                hybrid_hiddens = cache.hybrid_hidden
+                hybrid_forward_args = (hybrid_hiddens,)
+
             # hybrid forward

-            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+            hybrid_outputs = self.hybrid_module(x, *hybrid_forward_args, **hybrid_forward_kwargs)

             # handle hybrid out

@@ -1870,6 +1892,10 @@ class Attention(Module):
             if hybrid_out.ndim == 3:
                 hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)

+            if len(rest_hybrid_outs) > 0:
+                hybrid_hidden = first(rest_hybrid_outs)
+                intermediates.hybrid_hidden = hybrid_hidden
+
             out_norm, hybrid_out_norm = self.hybrid_norms

             out = out_norm(out)
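Together with the new `hybrid_hidden` field on `Intermediates` in `attend.py`, the two `Attention` hunks above let a hybrid module return an extra hidden state, have it stored in the attention intermediates, and receive it back as a positional argument on the next cached step. A rough sketch of a module that follows that (output, hidden) convention; this illustrates the interface implied by the diff and is not code from the package:

```python
import torch
from torch import nn

class ToyRecurrentHybrid(nn.Module):
    """A GRU wrapped to follow the (output, hidden) convention the diff relies on:
    an optional previous hidden state comes in as a positional argument and the
    new hidden state is returned alongside the output."""

    def __init__(self, dim):
        super().__init__()
        self.rnn = nn.GRU(dim, dim, batch_first = True)

    def forward(self, x, prev_hidden = None):
        out, hidden = self.rnn(x, prev_hidden)
        return out, hidden   # `hidden` is what would be cached as `hybrid_hidden`

hybrid = ToyRecurrentHybrid(dim = 64)

x_full = torch.randn(1, 8, 64)             # first pass over the prompt, no cache yet
out, hidden = hybrid(x_full)

x_step = torch.randn(1, 1, 64)             # next decoding step, a single new token
out_step, hidden = hybrid(x_step, hidden)  # previously cached hidden passed back in
```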
@@ -2360,7 +2386,9 @@ class AttentionLayers(Module):
         mems = None,
         mem_masks = None,
         seq_start_pos: Tensor | None = None,
+        seq_pos_offset: int = 0,
         cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
@@ -2434,7 +2462,7 @@ class AttentionLayers(Module):
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

             if not exists(pos):
-                pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len
+                pos = arange(x.shape[1] + mem_len + seq_pos_offset, device = x.device) - mem_len

             rotary_pos_emb = self.rotary_pos_emb(pos)

@@ -2451,11 +2479,15 @@ class AttentionLayers(Module):

         # assume cached key / values

+        prev_cache_length = 0
+
         attn_cache = []

         if exists(cache):
             assert self.causal and not any([*map(exists, (mask, attn_mask))])

+            prev_cache_length = cache.cache_length
+
             if exists(context):
                 context = context[:, :0]

@@ -2469,6 +2501,8 @@ class AttentionLayers(Module):

             attn_cache = cache.attn_intermediates

+        next_cache_length = x.shape[1]
+
         iter_attn_cache = iter(attn_cache)

         # handle deep embeds if needed
@@ -2655,6 +2689,7 @@ class AttentionLayers(Module):
                 last_hidden = x,
                 attn_intermediates = intermediates,
                 layer_hiddens = layer_hiddens,
+                cache_length = next_cache_length + prev_cache_length
             )

         return x, intermediates
@@ -2989,6 +3024,7 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         token_emb_kwargs = dict(),
         to_logits_kwargs = dict(),
         **kwargs,
@@ -3007,10 +3043,17 @@ class TransformerWrapper(Module):
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
         return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates

+        # take care of position embedding offsets in the presence of cache and sequence is less than cache length (not full sequence)
+
+        seq_pos_offset = 0
+
+        if exists(cache) and input_not_include_cache:
+            seq_pos_offset = cache.cache_length
+
         # absolute positional embedding

         external_pos_emb = exists(pos) and pos.dtype != torch.long
-        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
+        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos, offset = seq_pos_offset) if not external_pos_emb else pos
         x = self.token_emb(x, **token_emb_kwargs) + pos_emb

         # add additional embeddings
@@ -3109,6 +3152,15 @@ class TransformerWrapper(Module):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]

+        # attn layers kwargs
+
+        kwargs = dict(
+            **kwargs,
+            seq_pos_offset = seq_pos_offset,
+            seq_start_pos = seq_start_pos,
+            input_not_include_cache = input_not_include_cache
+        )
+
         # attention layers

         if not self.recycling:
@@ -3116,7 +3168,7 @@ class TransformerWrapper(Module):

             # regular

-            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, **kwargs)

         else:
             # recycling
@@ -3133,7 +3185,7 @@ class TransformerWrapper(Module):
                 with context():
                     maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.

-                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, **kwargs)

             x = attended

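The net effect of `cache_length`, `seq_pos_offset`, and `input_not_include_cache` is that cached decoding can feed only the newly generated tokens while the positional embeddings stay aligned with the cached prefix. A usage sketch; the constructor arguments, shapes, and greedy next-token pick below are illustrative rather than taken from the package's examples:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

prompt = torch.randint(0, 256, (1, 16))

# full pass over the prompt, keeping the intermediates around as the decode cache
logits, cache = model(prompt, return_intermediates = True)

# later steps pass only the newly sampled token; the flag tells the wrapper that the
# input excludes the cached prefix, so positions are offset by cache.cache_length
next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
logits, cache = model(
    next_token,
    cache = cache,
    input_not_include_cache = True,
    return_intermediates = True
)
```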
{x_transformers-2.3.18.dist-info → x_transformers-2.3.20.dist-info}/RECORD
CHANGED
@@ -1,17 +1,17 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
 x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=CHta8vizKl85n220fv5278fwjSU-vrN_FBy-m831_go,12551
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=l2p-r0iJNlYHUB3vM4lb6ptzNCx9HgA7UfgieEcQT6w,115521
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.20.dist-info/METADATA,sha256=ygWyfnlIh2Mw6bd12gJjjZJyM9vfnXmvvOLyrd2El2k,89897
+x_transformers-2.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.20.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.20.dist-info/RECORD,,
{x_transformers-2.3.18.dist-info → x_transformers-2.3.20.dist-info}/WHEEL
File without changes
{x_transformers-2.3.18.dist-info → x_transformers-2.3.20.dist-info}/licenses/LICENSE
File without changes