x-transformers 1.26.0__py3-none-any.whl → 1.26.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
x_transformers/x_transformers.py
@@ -1,5 +1,6 @@
 import math
 from random import random
+from typing import Dict

 import torch
 from torch import nn, einsum, Tensor
@@ -1464,6 +1465,7 @@ class TransformerWrapper(nn.Module):
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
+        embed_num_tokens: Dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1500,7 +1502,16 @@ class TransformerWrapper(nn.Module):
         else:
             self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)

-        self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+        # additional embeddings - say type embedding from BERT
+
+        self.embeds = None
+
+        if len(embed_num_tokens) > 0:
+            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+
+        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+
+        self.emb_frac_gradient = emb_frac_gradient

         self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)
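In the constructor, each entry of `embed_num_tokens` becomes its own `nn.Embedding` of width `emb_dim`, stored in an `nn.ModuleDict` under the key `'{name}_embed'`. A minimal construction sketch, assuming the library's usual `TransformerWrapper` + `Decoder` setup; the `token_type` name and all sizes below are illustrative, not taken from this diff:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        # one extra nn.Embedding per entry - here a BERT-style
        # token type embedding with 2 types ('token_type_embed' internally)
        embed_num_tokens = dict(token_type = 2),
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8
        )
    )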
@@ -1548,6 +1559,7 @@ class TransformerWrapper(nn.Module):
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
+        embed_ids: Dict[str, Tensor] = dict(),
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -1564,6 +1576,19 @@ class TransformerWrapper(nn.Module):
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
         x = self.token_emb(x) + pos_emb

+        # add additional embeddings
+
+        if exists(self.embeds):
+            assert len(embed_ids) == len(self.embeds)
+
+            for name, embed_id in embed_ids.items():
+                embed_key = f'{name}_embed'
+
+                assert embed_key in self.embeds
+                embed = self.embeds[embed_key](embed_id)
+
+                x = x + embed
+
         # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training

         if exists(sum_embeds):
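In the forward pass, every registered embedding must be supplied through `embed_ids` (the `len(embed_ids) == len(self.embeds)` assert enforces this), with one id per position so the lookup yields a `(batch, seq, dim)` tensor that adds directly onto `x`. A usage sketch continuing the construction above; shapes and names remain illustrative:

    x = torch.randint(0, 20000, (1, 1024))        # token ids, (batch, seq)
    token_types = torch.randint(0, 2, (1, 1024))  # per-position type ids, (batch, seq)

    # keys must match the names passed to embed_num_tokens
    logits = model(x, embed_ids = dict(token_type = token_types))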
x_transformers-1.26.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.26.0
+Version: 1.26.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.26.2.dist-info/RECORD
@@ -3,11 +3,11 @@ x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,1018
 x_transformers/autoregressive_wrapper.py,sha256=2mzOq_rl_vevgrxCDncBlVJOAJGS-XGm-iBKJqMjj_c,9041
 x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=iJLui66WnGDjfoYgT0p1Ziv0Bmqpy0ffR8RPkV5x2ck,60868
+x_transformers/x_transformers.py,sha256=gXq_IpWswCeM06r0VCxvZGTVndsxD69_uvKSJFC1544,61632
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.26.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.26.0.dist-info/METADATA,sha256=wUSqgSth94QEUQWmFjTJFCQxUBa2Cl_aS7QWKP-FTcQ,661
-x_transformers-1.26.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-x_transformers-1.26.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.26.0.dist-info/RECORD,,
+x_transformers-1.26.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.26.2.dist-info/METADATA,sha256=fWgPsKNzN3QgsdPNc-fgtU8UO6pEJTZa1FCFZ4hlbVY,661
+x_transformers-1.26.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.26.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.26.2.dist-info/RECORD,,