x-transformers 1.28.2__py3-none-any.whl → 1.28.5__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
x_transformers/__init__.py

@@ -5,6 +5,8 @@ from x_transformers.x_transformers import (
     PrefixDecoder,
     CrossAttender,
     Attention,
+    FeedForward,
+    RMSNorm,
     TransformerWrapper,
     ViTransformerWrapper
 )
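
With this release, FeedForward and RMSNorm are re-exported from the package root alongside the wrappers that were already there. A minimal usage sketch, assuming x-transformers 1.28.5 is installed (the dimensions below are illustrative, not taken from the diff):

    # the two newly re-exported building blocks, importable from the package root
    # (previously they had to be pulled from x_transformers.x_transformers)
    import torch
    from x_transformers import FeedForward, RMSNorm

    norm = RMSNorm(512)                    # illustrative dim
    ff = FeedForward(dim = 512, mult = 4)  # illustrative dim / mult

    x = torch.randn(2, 1024, 512)
    out = ff(norm(x))                      # pre-norm feedforward block; shape is preserved
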
x_transformers/x_transformers.py

@@ -3,8 +3,9 @@ from random import random
 from packaging import version
 
 import torch
-from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList, ModuleDict
 from torch.cuda.amp import autocast
 
 from functools import partial, wraps
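
The only functional content here is an import reshuffle: Module, ModuleList and ModuleDict are now imported from torch.nn directly, and the hunks that follow are the matching mechanical substitution of nn.Module, nn.ModuleList and nn.ModuleDict throughout x_transformers.py. A quick sanity check of the equivalence (plain PyTorch, nothing package-specific):

    # the direct imports refer to the very same classes reached via the nn namespace
    from torch import nn
    from torch.nn import Module, ModuleList, ModuleDict

    assert Module is nn.Module
    assert ModuleList is nn.ModuleList
    assert ModuleDict is nn.ModuleDict
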
@@ -191,13 +192,13 @@ def dropout_seq(seq, mask, dropout):
 
 # activations
 
-class ReluSquared(nn.Module):
+class ReluSquared(Module):
     def forward(self, x):
         return F.relu(x) ** 2
 
 # embedding
 
-class TokenEmbedding(nn.Module):
+class TokenEmbedding(Module):
     def __init__(self, dim, num_tokens, l2norm_embed = False):
         super().__init__()
         self.l2norm_embed = l2norm_embed
@@ -209,7 +210,7 @@ class TokenEmbedding(nn.Module):
 
 # positional embeddings
 
-class AbsolutePositionalEmbedding(nn.Module):
+class AbsolutePositionalEmbedding(Module):
     def __init__(self, dim, max_seq_len, l2norm_embed = False):
         super().__init__()
         self.scale = dim ** -0.5 if not l2norm_embed else 1.
@@ -231,7 +232,7 @@ class AbsolutePositionalEmbedding(nn.Module):
         pos_emb = pos_emb * self.scale
         return l2norm(pos_emb) if self.l2norm_embed else pos_emb
 
-class ScaledSinusoidalEmbedding(nn.Module):
+class ScaledSinusoidalEmbedding(Module):
     def __init__(self, dim, theta = 10000):
         super().__init__()
         assert divisible_by(dim, 2)
@@ -255,7 +256,7 @@ class ScaledSinusoidalEmbedding(nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
         return emb * self.scale
 
-class RelativePositionBias(nn.Module):
+class RelativePositionBias(Module):
     def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
         super().__init__()
         self.scale = scale
@@ -300,13 +301,13 @@ class RelativePositionBias(nn.Module):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale
 
-class DynamicPositionBias(nn.Module):
+class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
         assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
         self.log_distance = log_distance
 
-        self.mlp = nn.ModuleList([])
+        self.mlp = ModuleList([])
 
         self.mlp.append(Sequential(
             nn.Linear(1, dim),
@@ -352,7 +353,7 @@ class DynamicPositionBias(nn.Module):
         bias = rearrange(bias, 'i j h -> h i j')
         return bias
 
-class AlibiPositionalBias(nn.Module):
+class AlibiPositionalBias(Module):
     def __init__(self, heads, total_heads, **kwargs):
         super().__init__()
         self.heads = heads
@@ -401,7 +402,7 @@ class AlibiPositionalBias(nn.Module):
 
         return self.bias
 
-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(Module):
     def __init__(
         self,
         dim,
@@ -476,7 +477,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
 
 # norms
 
-class Scale(nn.Module):
+class Scale(Module):
     def __init__(self, value, fn):
         super().__init__()
         self.value = value
@@ -491,7 +492,7 @@ class Scale(nn.Module):
 
         return (scale_fn(out[0]), *out[1:])
 
-class LayerNorm(nn.Module):
+class LayerNorm(Module):
     def __init__(self, dim):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -506,7 +507,7 @@ class LayerNorm(nn.Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)
 
-class ScaleNorm(nn.Module):
+class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -515,7 +516,7 @@ class ScaleNorm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g
 
-class RMSNorm(nn.Module):
+class RMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -524,7 +525,7 @@ class RMSNorm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g
 
-class SimpleRMSNorm(nn.Module):
+class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -534,7 +535,7 @@ class SimpleRMSNorm(nn.Module):
 
 # residual and residual gates
 
-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
@@ -549,7 +550,7 @@ class Residual(nn.Module):
 
         return x + residual
 
-class GRUGating(nn.Module):
+class GRUGating(Module):
     def __init__(self, dim, scale_residual = False, **kwargs):
         super().__init__()
         self.gru = nn.GRUCell(dim, dim)
@@ -579,7 +580,7 @@ def shift(t, amount, mask = None):
 
     return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
 
-class ShiftTokens(nn.Module):
+class ShiftTokens(Module):
     def __init__(self, shifts, fn):
         super().__init__()
         self.fn = fn
@@ -598,7 +599,7 @@ class ShiftTokens(nn.Module):
 
 # feedforward
 
-class GLU(nn.Module):
+class GLU(Module):
     def __init__(
         self,
         dim_in,
@@ -615,7 +616,7 @@ class GLU(nn.Module):
         x, gate = self.proj(x).chunk(2, dim = -1)
         return x * self.act(gate) * self.mult_bias
 
-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(
         self,
         dim,
@@ -665,7 +666,7 @@ class FeedForward(nn.Module):
 
 # attention. it is all we need
 
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(
         self,
         dim,
@@ -996,7 +997,7 @@ class Attention(nn.Module):
 
         return out, intermediates
 
-class AttentionLayers(nn.Module):
+class AttentionLayers(Module):
     def __init__(
         self,
         dim,
@@ -1058,7 +1059,7 @@ class AttentionLayers(nn.Module):
         self.dim = dim
         self.depth = depth
         self.causal = causal
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
 
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
 
@@ -1215,13 +1216,13 @@ class AttentionLayers(nn.Module):
             post_branch_norm = norm_fn() if sandwich_norm else None
             post_main_norm = norm_fn() if not pre_norm else None
 
-            norms = nn.ModuleList([
+            norms = ModuleList([
                 pre_branch_norm,
                 post_branch_norm,
                 post_main_norm
             ])
 
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 norms,
                 layer,
                 residual
@@ -1427,7 +1428,7 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
-class ViTransformerWrapper(nn.Module):
+class ViTransformerWrapper(Module):
     def __init__(
         self,
         *,
@@ -1508,7 +1509,7 @@ class ViTransformerWrapper(nn.Module):
 
         return logits, embed
 
-class TransformerWrapper(nn.Module):
+class TransformerWrapper(Module):
     def __init__(
         self,
         *,
@@ -1525,6 +1526,7 @@ class TransformerWrapper(nn.Module):
         memory_tokens_interspersed_every = None,
         tie_embedding = False,
         logits_dim = None,
+        num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
@@ -1559,7 +1561,7 @@ class TransformerWrapper(nn.Module):
         self.embeds = None
 
         if len(embed_num_tokens) > 0:
-            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+            self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
 
         # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
 
@@ -1573,8 +1575,19 @@ class TransformerWrapper(nn.Module):
 
         self.init_()
 
+        # output head, usually to logits of num_tokens
+
         logits_dim = default(logits_dim, num_tokens)
-        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+
+        self.has_multiple_heads = False
+
+        if tie_embedding:
+            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
+        elif num_output_heads > 1:
+            self.has_multiple_heads = True
+            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+        else:
+            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
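
The output-head setup is rebuilt around the new num_output_heads argument: the tie_embedding branch is checked first (so a tied embedding takes precedence over multiple heads), num_output_heads > 1 yields a ModuleList of linear heads, and the previous single nn.Linear remains the default. The least obvious branch is the weight-tying lambda; here is a standalone sketch of what it computes, with a plain nn.Embedding standing in for self.token_emb.emb (sizes are illustrative):

    # weight tying: project hidden states back through the token embedding matrix
    import torch
    from torch import nn

    emb = nn.Embedding(20000, 512)              # stands in for self.token_emb.emb
    to_logits = lambda t: t @ emb.weight.t()    # same expression as the tie_embedding branch

    h = torch.randn(1, 1024, 512)
    assert to_logits(h).shape == (1, 1024, 20000)
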
@@ -1705,6 +1718,8 @@ class TransformerWrapper(nn.Module):
 
         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
 
+        # handle memories post-attention
+
         if has_memory_tokens:
             if exists(mem_every):
                 x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
@@ -1718,12 +1733,24 @@ class TransformerWrapper(nn.Module):
 
             x = x[:, :n]
 
+        # projecting to logits
+
+        if not return_embeddings:
+            if self.has_multiple_heads:
+                logits = tuple(fn(x) for fn in self.to_logits)
+            else:
+                logits = self.to_logits(x)
+
+        # different returns
+
         if return_logits_and_embeddings:
-            out = (self.to_logits(x), x)
+            out = (logits, x)
         elif return_embeddings:
             out = x
         else:
-            out = self.to_logits(x)
+            out = logits
+
+        # aux loss
 
         if return_attn_z_loss:
             pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
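
Together with the constructor hunk above, this gives TransformerWrapper an optional multi-head output: when num_output_heads > 1 and the embedding is not tied, to_logits is a ModuleList and the forward pass returns a tuple with one logits tensor per head instead of a single tensor. A hedged usage sketch (hyperparameters are illustrative, not taken from the diff):

    # two output heads; the model now returns a tuple of logits, one per head
    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        num_output_heads = 2,          # new in 1.28.5
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    x = torch.randint(0, 20000, (1, 1024))
    logits_a, logits_b = model(x)      # each head projects the same hidden states
    assert logits_a.shape == logits_b.shape == (1, 1024, 20000)
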
@@ -1749,7 +1776,7 @@ class TransformerWrapper(nn.Module):
 
         return out
 
-class XTransformer(nn.Module):
+class XTransformer(Module):
     def __init__(
         self,
         *,

METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.28.2
+Version: 1.28.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

RECORD

@@ -1,14 +1,14 @@
-x_transformers/__init__.py,sha256=0-2m0LtLpZiZYGwO-6OMYXofx5hbFb_FJOHMxIBqQr4,673
+x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
 x_transformers/attend.py,sha256=L7vctHJ0PnECohu4cUu8yvY8cUrVyJxHmMFR0RGL0z4,10163
 x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=GhhRfzxOQoUAqEeT8VnSAtW7wIJ6aW_5DF4LnsqozdQ,64018
+x_transformers/x_transformers.py,sha256=w_S0zOCKJtAO2M5ZKdE7gqSUWzkqECkA87ah-vkqx0Y,64656
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.28.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.28.2.dist-info/METADATA,sha256=oT95hrc_XiI7dMKF9ATWyUwir3cfSfeD1PFTZF2zpy4,661
-x_transformers-1.28.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.28.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.28.2.dist-info/RECORD,,
+x_transformers-1.28.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.28.5.dist-info/METADATA,sha256=jLcNekd2_ccREKevcTAtHNAnjwqnxaRmAvq90_eSdQI,661
+x_transformers-1.28.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.28.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.28.5.dist-info/RECORD,,