x-transformers 1.25.15__py3-none-any.whl → 1.26.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -1,5 +1,6 @@
 import math
 from random import random
+from typing import Dict
 
 import torch
 from torch import nn, einsum, Tensor
@@ -197,7 +198,7 @@ class TokenEmbedding(nn.Module):
         self.emb = nn.Embedding(num_tokens, dim)
 
     def forward(self, x):
-        token_emb = self.emb(x)
+        token_emb = self.emb(x.long())
        return l2norm(token_emb) if self.l2norm_embed else token_emb
 
 # positional embeddings
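The only behavioural change here is the explicit `.long()` cast, so callers no longer have to hand the wrapper int64 token ids themselves. A minimal sketch of the effect, using a bare `nn.Embedding` as a stand-in for `TokenEmbedding` (shapes and dtypes are illustrative assumptions):

```python
import torch
from torch import nn

emb = nn.Embedding(256, 64)

# token ids arriving as int32, e.g. from a tokenizer or a memory-mapped dataset
ids = torch.randint(0, 256, (2, 16), dtype = torch.int32)

# 1.26.1 performs this cast inside TokenEmbedding.forward, so the caller-side dtype no longer matters
out = emb(ids.long())
print(out.shape)  # torch.Size([2, 16, 64])
```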
@@ -425,14 +426,15 @@ class RotaryEmbedding(nn.Module):
         self.scale_base = scale_base
         self.register_buffer('scale', scale)
 
-    @autocast(enabled = False)
-    def forward(self, seq_arange_or_len: Union[int, Tensor]):
+    def forward_from_seq_len(self, seq_len):
         device = self.inv_freq.device
 
-        if isinstance(seq_arange_or_len, int):
-            t = torch.arange(seq_arange_or_len, device = device)
-        else:
-            t = seq_arange_or_len
+        t = torch.arange(seq_len, device = device)
+        return self.forward(t)
+
+    @autocast(enabled = False)
+    def forward(self, t):
+        device = self.inv_freq.device
 
         t = t.type_as(self.inv_freq)
 
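The overloaded `forward` that accepted either an `int` or a position tensor is split into two explicit entry points: `forward_from_seq_len` builds the `0..seq_len-1` positions itself, while `forward` now always takes a position tensor. A sketch of the two call paths, assuming the rest of the class is unchanged from 1.25.x and `forward` still returns the `(freqs, scale)` pair:

```python
import torch
from x_transformers.x_transformers import RotaryEmbedding

rotary = RotaryEmbedding(dim = 32)

# convenience path: positions 0..127 are generated internally
freqs_a, scale_a = rotary.forward_from_seq_len(128)

# direct path: pass an explicit position tensor (e.g. offset positions when decoding with a cache)
freqs_b, scale_b = rotary(torch.arange(128))

assert torch.allclose(freqs_a, freqs_b)
```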
@@ -798,7 +800,8 @@ class Attention(nn.Module):
         return_intermediates = False,
         cache: Optional[Intermediates] = None,
     ):
-        b, n, _, h, kv_h, head_scale, device, has_context = *x.shape, self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
+        b, n, h, kv_h, head_scale, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
+
         kv_input = default(context, x)
 
         q_input = x
@@ -1235,7 +1238,7 @@ class AttentionLayers(nn.Module):
 
         if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
             max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
-            rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
+            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(max_rotary_emb_length)
 
         # assume cached key / values
 
@@ -1361,7 +1364,7 @@ class PrefixDecoder(AttentionLayers):
         prefix_attn_len = None,
         **kwargs
     ):
-        b, n, device = *x.shape[:2], x.device
+        b, n, device = x.shape[0], x.shape[1], x.device
         causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
 
         forwarded_mask = ~causal_mask
@@ -1462,6 +1465,7 @@ class TransformerWrapper(nn.Module):
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
+        embed_num_tokens: Dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1491,14 +1495,23 @@ class TransformerWrapper(nn.Module):
         self.l2norm_embed = l2norm_embed
         self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
 
-        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+        if max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.has_pos_emb):
             self.pos_emb = always(0)
         elif scaled_sinu_pos_emb:
             self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
         else:
             self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
 
-        self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+        # additional embeddings - say type embedding from BERT
+
+        self.embeds = None
+
+        if len(embed_num_tokens) > 0:
+            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+
+        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+
+        self.emb_frac_gradient = emb_frac_gradient
 
         self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)
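Two constructor-level changes land in this hunk: `max_seq_len = 0` now disables the absolute positional embedding outright (falling through to `always(0)`), and a non-empty `embed_num_tokens` registers one extra `nn.Embedding` per entry under the key `'{name}_embed'` (exercised in the sketch after the forward-pass hunk below). A hedged sketch of the first change, assuming rotary embeddings in the attention layers supply position information once the absolute embedding is switched off:

```python
from x_transformers import TransformerWrapper, Decoder
from x_transformers.x_transformers import always

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 0,             # new: 0 short-circuits straight to always(0)
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True    # positions then come from rotary embeddings instead
    )
)

# no AbsolutePositionalEmbedding is constructed
assert isinstance(model.pos_emb, always)
```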
@@ -1546,6 +1559,7 @@
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
+        embed_ids: Dict[str, Tensor] = None,
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -1553,7 +1567,7 @@
         cache: Optional[LayerIntermediates] = None,
         **kwargs
     ):
-        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
 
         # absolute positional embedding
@@ -1562,6 +1576,17 @@
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
         x = self.token_emb(x) + pos_emb
 
+        # add additional embeddings
+
+        if exists(self.embeds) and exists(embed_ids):
+            for name, embed_id in embed_ids.items():
+                embed_key = f'{name}_embed'
+
+                assert embed_key in self.embeds
+                embed = self.embeds[embed_key](embed_id)
+
+                x = x + embed
+
         # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
 
         if exists(sum_embeds):
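Together with the `embed_num_tokens` constructor change above, this gives the wrapper BERT-style auxiliary embeddings: each entry of `embed_ids` is looked up in the matching `'{name}_embed'` table and added onto the token plus positional embedding. A self-contained usage sketch; the `segment` name and the sizes are illustrative, not part of the library:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(segment = 2),   # builds self.embeds['segment_embed']
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 20000, (1, 1024))
segment_ids = torch.randint(0, 2, (1, 1024))

# keys of embed_ids must match the names given in embed_num_tokens
logits = model(x, embed_ids = dict(segment = segment_ids))
print(logits.shape)   # torch.Size([1, 1024, 20000])
```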
x_transformers-1.25.15.dist-info/METADATA → x_transformers-1.26.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.25.15
+Version: 1.26.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.25.15.dist-info/RECORD → x_transformers-1.26.1.dist-info/RECORD

@@ -3,11 +3,11 @@ x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,1018
 x_transformers/autoregressive_wrapper.py,sha256=2mzOq_rl_vevgrxCDncBlVJOAJGS-XGm-iBKJqMjj_c,9041
 x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=5T5fOdveqe4vwqGRrNZqJ053pmtw-dJ12AV0nNaWLRc,60814
+x_transformers/x_transformers.py,sha256=k-_pU9lICUYnCumwimPv3VaaxjpKOeFKDdKEWgnvslc,61597
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.25.15.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.25.15.dist-info/METADATA,sha256=npEGmCXBOdJv7S8odHPNOrjTOvFgSxNFjx1377QtYn8,662
-x_transformers-1.25.15.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-x_transformers-1.25.15.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.25.15.dist-info/RECORD,,
+x_transformers-1.26.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.26.1.dist-info/METADATA,sha256=mJsZmH-mfNGzd7eGYglFH6GlNzGUTeugQHLcH8oDecU,661
+x_transformers-1.26.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.26.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.26.1.dist-info/RECORD,,