x-transformers 1.28.4__py3-none-any.whl → 1.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -3,8 +3,9 @@ from random import random
 from packaging import version

 import torch
-from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList, ModuleDict
 from torch.cuda.amp import autocast

 from functools import partial, wraps
@@ -191,13 +192,13 @@ def dropout_seq(seq, mask, dropout):

 # activations

-class ReluSquared(nn.Module):
+class ReluSquared(Module):
     def forward(self, x):
         return F.relu(x) ** 2

 # embedding

-class TokenEmbedding(nn.Module):
+class TokenEmbedding(Module):
     def __init__(self, dim, num_tokens, l2norm_embed = False):
         super().__init__()
         self.l2norm_embed = l2norm_embed
@@ -209,7 +210,7 @@ class TokenEmbedding(nn.Module):

 # positional embeddings

-class AbsolutePositionalEmbedding(nn.Module):
+class AbsolutePositionalEmbedding(Module):
     def __init__(self, dim, max_seq_len, l2norm_embed = False):
         super().__init__()
         self.scale = dim ** -0.5 if not l2norm_embed else 1.
@@ -231,7 +232,7 @@ class AbsolutePositionalEmbedding(nn.Module):
         pos_emb = pos_emb * self.scale
         return l2norm(pos_emb) if self.l2norm_embed else pos_emb

-class ScaledSinusoidalEmbedding(nn.Module):
+class ScaledSinusoidalEmbedding(Module):
     def __init__(self, dim, theta = 10000):
         super().__init__()
         assert divisible_by(dim, 2)
@@ -255,7 +256,7 @@ class ScaledSinusoidalEmbedding(nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
         return emb * self.scale

-class RelativePositionBias(nn.Module):
+class RelativePositionBias(Module):
     def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
         super().__init__()
         self.scale = scale
@@ -300,13 +301,13 @@ class RelativePositionBias(nn.Module):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale

-class DynamicPositionBias(nn.Module):
+class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
         assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
         self.log_distance = log_distance

-        self.mlp = nn.ModuleList([])
+        self.mlp = ModuleList([])

         self.mlp.append(Sequential(
             nn.Linear(1, dim),
@@ -352,7 +353,7 @@ class DynamicPositionBias(nn.Module):
         bias = rearrange(bias, 'i j h -> h i j')
         return bias

-class AlibiPositionalBias(nn.Module):
+class AlibiPositionalBias(Module):
     def __init__(self, heads, total_heads, **kwargs):
         super().__init__()
         self.heads = heads
@@ -401,7 +402,7 @@ class AlibiPositionalBias(nn.Module):

         return self.bias

-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(Module):
     def __init__(
         self,
         dim,
@@ -476,7 +477,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):

 # norms

-class Scale(nn.Module):
+class Scale(Module):
     def __init__(self, value, fn):
         super().__init__()
         self.value = value
@@ -491,7 +492,7 @@ class Scale(nn.Module):

         return (scale_fn(out[0]), *out[1:])

-class LayerNorm(nn.Module):
+class LayerNorm(Module):
     def __init__(self, dim):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -506,7 +507,7 @@ class LayerNorm(nn.Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)

-class ScaleNorm(nn.Module):
+class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -515,7 +516,7 @@ class ScaleNorm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g

-class RMSNorm(nn.Module):
+class RMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -524,7 +525,7 @@ class RMSNorm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g

-class SimpleRMSNorm(nn.Module):
+class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -534,7 +535,7 @@ class SimpleRMSNorm(nn.Module):

 # residual and residual gates

-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
@@ -549,7 +550,7 @@ class Residual(nn.Module):

         return x + residual

-class GRUGating(nn.Module):
+class GRUGating(Module):
     def __init__(self, dim, scale_residual = False, **kwargs):
         super().__init__()
         self.gru = nn.GRUCell(dim, dim)
@@ -579,7 +580,7 @@ def shift(t, amount, mask = None):

     return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)

-class ShiftTokens(nn.Module):
+class ShiftTokens(Module):
     def __init__(self, shifts, fn):
         super().__init__()
         self.fn = fn
@@ -598,7 +599,7 @@ class ShiftTokens(nn.Module):

 # feedforward

-class GLU(nn.Module):
+class GLU(Module):
     def __init__(
         self,
         dim_in,
@@ -615,7 +616,7 @@ class GLU(nn.Module):
         x, gate = self.proj(x).chunk(2, dim = -1)
         return x * self.act(gate) * self.mult_bias

-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(
         self,
         dim,
@@ -665,7 +666,7 @@ class FeedForward(nn.Module):

 # attention. it is all we need

-class Attention(nn.Module):
+class Attention(Module):
     def __init__(
         self,
         dim,
@@ -996,11 +997,11 @@ class Attention(nn.Module):

         return out, intermediates

-class AttentionLayers(nn.Module):
+class AttentionLayers(Module):
     def __init__(
         self,
         dim,
-        depth,
+        depth = None,
         heads = 8,
         causal = False,
         cross_attend = False,
@@ -1053,12 +1054,14 @@ class AttentionLayers(nn.Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)

+        assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
+
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

         self.dim = dim
         self.depth = depth
         self.causal = causal
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])

         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

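The new `assert len(kwargs) == 0` makes `AttentionLayers` reject keyword arguments it does not recognize instead of silently dropping them. A minimal sketch of the new behavior, assuming the usual `Decoder` convenience class and the existing `attn_flash` kwarg; the misspelled name is purely illustrative:

from x_transformers import Decoder

# correctly prefixed kwargs are still grouped and routed to the attention layers
dec = Decoder(dim = 512, depth = 6, heads = 8, attn_flash = True)

# a typo that was previously ignored now trips the new assert
try:
    Decoder(dim = 512, depth = 6, heads = 8, atn_flash = True)
except AssertionError as err:
    print(err)  # unrecognized kwargs passed in dict_keys(['atn_flash'])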
@@ -1137,9 +1140,12 @@ class AttentionLayers(nn.Module):

         # setup weight tying, which is a special case of `layer_execute_order`

+        assert not (exists(layers_execute_order) and exists(custom_layers) and exists(depth)), 'depth should not be passed in if using custom layers and custom layer execution order'
+
         assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))

         if weight_tie_layers:
+            assert exists(depth), 'depth must be passed in with `weight_tie_layers` = True'
             assert not exists(layers_execute_order)
             layers_execute_order = tuple(range(len(default_block))) * depth
             depth = 1
@@ -1163,6 +1169,7 @@ class AttentionLayers(nn.Module):
             assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
             layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
         else:
+            assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
             layer_types = default_block * depth

         self.layer_types = layer_types
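With `depth` now defaulting to `None`, the asserts above spell out when it is still required: the standard `Encoder`/`Decoder` layouts and `weight_tie_layers` must receive it, while a fully custom layout takes its length from the layer list. A hedged sketch, assuming the existing `custom_layers` and `layers_execute_order` kwargs; the particular order shown is illustrative:

from x_transformers import Decoder

# standard layout: depth remains mandatory
dec = Decoder(dim = 512, depth = 6, heads = 8)

# custom layout: the attention ('a') / feedforward ('f') blocks and their execution
# order are given explicitly, so `depth` is omitted (passing it as well now asserts)
dec_custom = Decoder(
    dim = 512,
    heads = 8,
    custom_layers = ('a', 'f', 'a', 'f'),
    layers_execute_order = (0, 1, 2, 3, 2, 3)
)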
@@ -1215,13 +1222,13 @@ class AttentionLayers(nn.Module):
             post_branch_norm = norm_fn() if sandwich_norm else None
             post_main_norm = norm_fn() if not pre_norm else None

-            norms = nn.ModuleList([
+            norms = ModuleList([
                 pre_branch_norm,
                 post_branch_norm,
                 post_main_norm
             ])

-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 norms,
                 layer,
                 residual
@@ -1427,7 +1434,7 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)

-class ViTransformerWrapper(nn.Module):
+class ViTransformerWrapper(Module):
     def __init__(
         self,
         *,
@@ -1508,7 +1515,7 @@ class ViTransformerWrapper(nn.Module):

         return logits, embed

-class TransformerWrapper(nn.Module):
+class TransformerWrapper(Module):
     def __init__(
         self,
         *,
@@ -1525,6 +1532,7 @@ class TransformerWrapper(nn.Module):
         memory_tokens_interspersed_every = None,
         tie_embedding = False,
         logits_dim = None,
+        num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
@@ -1559,7 +1567,7 @@ class TransformerWrapper(nn.Module):
         self.embeds = None

         if len(embed_num_tokens) > 0:
-            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+            self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})

         # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290

@@ -1573,8 +1581,19 @@ class TransformerWrapper(nn.Module):

         self.init_()

+        # output head, usually to logits of num_tokens
+
         logits_dim = default(logits_dim, num_tokens)
-        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+
+        self.has_multiple_heads = False
+
+        if tie_embedding:
+            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
+        elif num_output_heads > 1:
+            self.has_multiple_heads = True
+            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+        else:
+            self.to_logits = nn.Linear(dim, logits_dim, bias = False)

         # memory tokens (like [cls]) from Memory Transformers paper

@@ -1705,6 +1724,8 @@ class TransformerWrapper(nn.Module):

         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

+        # handle memories post-attention
+
         if has_memory_tokens:
             if exists(mem_every):
                 x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
@@ -1718,12 +1739,24 @@ class TransformerWrapper(nn.Module):

             x = x[:, :n]

+        # projecting to logits
+
+        if not return_embeddings:
+            if self.has_multiple_heads:
+                logits = tuple(fn(x) for fn in self.to_logits)
+            else:
+                logits = self.to_logits(x)
+
+        # different returns
+
         if return_logits_and_embeddings:
-            out = (self.to_logits(x), x)
+            out = (logits, x)
         elif return_embeddings:
             out = x
         else:
-            out = self.to_logits(x)
+            out = logits
+
+        # aux loss

         if return_attn_z_loss:
             pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
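Together with the constructor change above, `num_output_heads > 1` creates one linear projection per head and the forward pass returns a tuple with one logits tensor per head (embedding-only calls are unchanged). A minimal usage sketch; the sizes are arbitrary:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_output_heads = 2,  # new in 1.29.0
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 20000, (1, 1024))

# with two heads, a tuple of two (1, 1024, 20000) logit tensors is returned
logits_one, logits_two = model(x)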
@@ -1749,7 +1782,7 @@ class TransformerWrapper(nn.Module):

         return out

-class XTransformer(nn.Module):
+class XTransformer(Module):
     def __init__(
         self,
         *,
--- x_transformers-1.28.4.dist-info/METADATA
+++ x_transformers-1.29.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.28.4
+Version: 1.29.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.28.4.dist-info/RECORD
+++ x_transformers-1.29.0.dist-info/RECORD
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=GhhRfzxOQoUAqEeT8VnSAtW7wIJ6aW_5DF4LnsqozdQ,64018
+x_transformers/x_transformers.py,sha256=ub1QXJIXfoK5Bm8poZ1oJC99hbt9QitAuKmmmfBtxUY,65111
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.28.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.28.4.dist-info/METADATA,sha256=JKCQN6QEaSe9M63vpez9hdan0f67zSiu5okyl9GDNKU,661
-x_transformers-1.28.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.28.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.28.4.dist-info/RECORD,,
+x_transformers-1.29.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.29.0.dist-info/METADATA,sha256=6ivD0nnIvXz057mJdIeHYNt2s9E0fN69eqSPGtSbcXg,661
+x_transformers-1.29.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.29.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.29.0.dist-info/RECORD,,