x-transformers 1.26.0__py3-none-any.whl → 1.26.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +24 -1
- {x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/METADATA +1 -1
- {x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/RECORD +6 -6
- {x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1,5 +1,6 @@
 import math
 from random import random
+from typing import Dict
 
 import torch
 from torch import nn, einsum, Tensor
|
@@ -1464,6 +1465,7 @@ class TransformerWrapper(nn.Module):
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
+        embed_num_tokens: Dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
|
@@ -1500,7 +1502,16 @@ class TransformerWrapper(nn.Module):
         else:
             self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
 
-        self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+        # additional embeddings - say type embedding from BERT
+
+        self.embeds = None
+
+        if len(embed_num_tokens) > 0:
+            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+
+        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+
+        self.emb_frac_gradient = emb_frac_gradient
 
         self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)
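For orientation, a minimal sketch of what the new embed_num_tokens option builds, using illustrative names and sizes (the token_type key, emb_dim = 512, and the tensor shapes are assumptions, not values from the diff):

import torch
from torch import nn

# illustrative stand-ins for the constructor arguments (not from the diff)
emb_dim = 512
embed_num_tokens = {'token_type': 2}

# mirrors the ModuleDict comprehension above: one embedding table per named id space,
# stored under the key '<name>_embed'
embeds = nn.ModuleDict({
    f'{name}_embed': nn.Embedding(num_tokens, emb_dim)
    for name, num_tokens in embed_num_tokens.items()
})

type_ids = torch.randint(0, 2, (1, 16))      # (batch, seq) integer ids
out = embeds['token_type_embed'](type_ids)   # (1, 16, 512), later summed onto the token embeddings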
|
@@ -1548,6 +1559,7 @@ class TransformerWrapper(nn.Module):
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
+        embed_ids: Dict[str, Tensor] = None,
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -1564,6 +1576,17 @@ class TransformerWrapper(nn.Module):
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
         x = self.token_emb(x) + pos_emb
 
+        # add additional embeddings
+
+        if exists(self.embeds) and exists(embed_ids):
+            for name, embed_id in embed_ids.items():
+                embed_key = f'{name}_embed'
+
+                assert embed_key in self.embeds
+                embed = self.embeds[embed_key](embed_id)
+
+                x = x + embed
+
         # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
 
         if exists(sum_embeds):
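Taken together, the two hunks let a caller register extra embedding tables at construction time and feed the matching ids through forward. A hedged usage sketch (the vocabulary size, model dimensions, and the token_type name are illustrative choices, not values from this release):

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(token_type = 2),   # new in 1.26.1: named auxiliary embeddings
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (2, 1024))
token_types = torch.randint(0, 2, (2, 1024))

# keys of embed_ids must match the names given in embed_num_tokens;
# each looked-up embedding is added onto the token + positional embeddings
logits = model(x, embed_ids = dict(token_type = token_types))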
|
{x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/RECORD
CHANGED
@@ -3,11 +3,11 @@ x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,1018
 x_transformers/autoregressive_wrapper.py,sha256=2mzOq_rl_vevgrxCDncBlVJOAJGS-XGm-iBKJqMjj_c,9041
 x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=k-_pU9lICUYnCumwimPv3VaaxjpKOeFKDdKEWgnvslc,61597
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.26.
-x_transformers-1.26.
-x_transformers-1.26.
-x_transformers-1.26.
-x_transformers-1.26.
+x_transformers-1.26.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.26.1.dist-info/METADATA,sha256=mJsZmH-mfNGzd7eGYglFH6GlNzGUTeugQHLcH8oDecU,661
+x_transformers-1.26.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.26.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.26.1.dist-info/RECORD,,
|
{x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/LICENSE
File without changes
{x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/WHEEL
File without changes
{x_transformers-1.26.0.dist-info → x_transformers-1.26.1.dist-info}/top_level.txt
File without changes