x-transformers 1.25.15__py3-none-any.whl → 1.26.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- x_transformers/x_transformers.py +38 -13
- {x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/METADATA +1 -1
- {x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/RECORD +6 -6
- {x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1,5 +1,6 @@
 import math
 from random import random
+from typing import Dict
 
 import torch
 from torch import nn, einsum, Tensor
@@ -197,7 +198,7 @@ class TokenEmbedding(nn.Module):
         self.emb = nn.Embedding(num_tokens, dim)
 
     def forward(self, x):
-        token_emb = self.emb(x)
+        token_emb = self.emb(x.long())
         return l2norm(token_emb) if self.l2norm_embed else token_emb
 
 # positional embeddings
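For context, the new `.long()` cast normalizes the index dtype before the `nn.Embedding` lookup. A minimal sketch of the call path this hunk changes; the shapes and the int32 source are illustrative, not from the release:

import torch
from torch import nn

emb = nn.Embedding(256, 8)

# token ids may arrive as a narrower integer dtype, e.g. from numpy or a memory-mapped dataset
ids = torch.randint(0, 256, (2, 4), dtype = torch.int32)

# casting to long mirrors the change in TokenEmbedding.forward
out = emb(ids.long())
print(out.shape)  # torch.Size([2, 4, 8])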
@@ -425,14 +426,15 @@ class RotaryEmbedding(nn.Module):
         self.scale_base = scale_base
         self.register_buffer('scale', scale)
 
-
-    def forward(self, seq_arange_or_len: Union[int, Tensor]):
+    def forward_from_seq_len(self, seq_len):
         device = self.inv_freq.device
 
-
-
-
-
+        t = torch.arange(seq_len, device = device)
+        return self.forward(t)
+
+    @autocast(enabled = False)
+    def forward(self, t):
+        device = self.inv_freq.device
 
         t = t.type_as(self.inv_freq)
 
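This hunk splits the rotary embedding entry points: `forward` now takes a tensor of positions, while the integer-length case moves to the convenience method `forward_from_seq_len`. A minimal usage sketch against the new interface, assuming a `RotaryEmbedding(dim = 32)` instance (the dim and sequence lengths here are illustrative):

import torch
from x_transformers.x_transformers import RotaryEmbedding

rotary = RotaryEmbedding(dim = 32)

# integer sequence length -> wrapper builds torch.arange(seq_len) internally and delegates to forward
rot_from_len = rotary.forward_from_seq_len(128)

# explicit position tensor -> call forward directly (e.g. when positions are offset by a cache)
positions = torch.arange(128)
rot_from_pos = rotary.forward(positions)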
@@ -798,7 +800,8 @@ class Attention(nn.Module):
         return_intermediates = False,
         cache: Optional[Intermediates] = None,
     ):
-        b, n,
+        b, n, h, kv_h, head_scale, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
+
         kv_input = default(context, x)
 
         q_input = x
@@ -1235,7 +1238,7 @@ class AttentionLayers(nn.Module):
 
         if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
             max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
-            rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
+            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(max_rotary_emb_length)
 
         # assume cached key / values
 
@@ -1361,7 +1364,7 @@ class PrefixDecoder(AttentionLayers):
         prefix_attn_len = None,
         **kwargs
     ):
-        b, n, device =
+        b, n, device = x.shape[0], x.shape[1], x.device
         causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
 
         forwarded_mask = ~causal_mask
@@ -1462,6 +1465,7 @@ class TransformerWrapper(nn.Module):
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
+        embed_num_tokens: Dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1491,14 +1495,23 @@ class TransformerWrapper(nn.Module):
         self.l2norm_embed = l2norm_embed
         self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
 
-        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+        if max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.has_pos_emb):
             self.pos_emb = always(0)
         elif scaled_sinu_pos_emb:
             self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
         else:
             self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
 
-
+        # additional embeddings - say type embedding from BERT
+
+        self.embeds = None
+
+        if len(embed_num_tokens) > 0:
+            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+
+        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+
+        self.emb_frac_gradient = emb_frac_gradient
 
         self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)
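The new `embed_num_tokens` argument registers extra categorical embeddings (for example BERT-style token-type ids), stored in an `nn.ModuleDict` under keys of the form `f'{name}_embed'`. A sketch of how a wrapper might be constructed with it; the `token_type` name and the dimension choices are illustrative, not part of the release:

from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(token_type = 2),  # registers self.embeds['token_type_embed'] = nn.Embedding(2, emb_dim)
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)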
@@ -1546,6 +1559,7 @@ class TransformerWrapper(nn.Module):
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
+        embed_ids: Dict[str, Tensor] = None,
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -1553,7 +1567,7 @@ class TransformerWrapper(nn.Module):
         cache: Optional[LayerIntermediates] = None,
         **kwargs
     ):
-        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient =
+        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
 
         # absolute positional embedding
@@ -1562,6 +1576,17 @@ class TransformerWrapper(nn.Module):
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
         x = self.token_emb(x) + pos_emb
 
+        # add additional embeddings
+
+        if exists(self.embeds) and exists(embed_ids):
+            for name, embed_id in embed_ids.items():
+                embed_key = f'{name}_embed'
+
+                assert embed_key in self.embeds
+                embed = self.embeds[embed_key](embed_id)
+
+                x = x + embed
+
         # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
 
         if exists(sum_embeds):
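At forward time the matching ids are passed through the new `embed_ids` keyword, one integer tensor per registered name, and each looked-up embedding is added to the token plus positional embedding. Continuing the construction sketch above (names and shapes remain illustrative):

import torch

x = torch.randint(0, 20000, (1, 1024))
token_types = torch.randint(0, 2, (1, 1024))

# the 'token_type' key must match a name given in embed_num_tokens at construction
logits = model(x, embed_ids = dict(token_type = token_types))
print(logits.shape)  # (1, 1024, 20000) with the default projection head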
{x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/RECORD
CHANGED
@@ -3,11 +3,11 @@ x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,1018
 x_transformers/autoregressive_wrapper.py,sha256=2mzOq_rl_vevgrxCDncBlVJOAJGS-XGm-iBKJqMjj_c,9041
 x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=k-_pU9lICUYnCumwimPv3VaaxjpKOeFKDdKEWgnvslc,61597
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.26.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.26.1.dist-info/METADATA,sha256=mJsZmH-mfNGzd7eGYglFH6GlNzGUTeugQHLcH8oDecU,661
+x_transformers-1.26.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.26.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.26.1.dist-info/RECORD,,
{x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/LICENSE
File without changes
{x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/WHEEL
File without changes
{x_transformers-1.25.15.dist-info → x_transformers-1.26.1.dist-info}/top_level.txt
File without changes