x-transformers 1.28.4__py3-none-any.whl → 1.29.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- x_transformers/x_transformers.py +66 -33
- {x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/METADATA +1 -1
- {x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/RECORD +6 -6
- {x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -3,8 +3,9 @@ from random import random
 from packaging import version

 import torch
-from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList, ModuleDict
 from torch.cuda.amp import autocast

 from functools import partial, wraps
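The only import-level change: the base `Module` class and the `ModuleList` / `ModuleDict` containers are now imported directly from `torch.nn`, so the rest of the file can drop the `nn.` prefix (every `class X(nn.Module)` below becomes `class X(Module)`). A minimal sketch of the pattern, using only standard PyTorch and a made-up `Block` module that is not part of the library:

import torch
from torch import nn
from torch.nn import Module, ModuleList

class Block(Module):                      # previously: class Block(nn.Module)
    def __init__(self, dim):
        super().__init__()
        self.layers = ModuleList([        # previously: nn.ModuleList([...])
            nn.Linear(dim, dim) for _ in range(2)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x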
@@ -191,13 +192,13 @@ def dropout_seq(seq, mask, dropout):

 # activations

-class ReluSquared(nn.Module):
+class ReluSquared(Module):
     def forward(self, x):
         return F.relu(x) ** 2

 # embedding

-class TokenEmbedding(nn.Module):
+class TokenEmbedding(Module):
     def __init__(self, dim, num_tokens, l2norm_embed = False):
         super().__init__()
         self.l2norm_embed = l2norm_embed
@@ -209,7 +210,7 @@ class TokenEmbedding(nn.Module):

 # positional embeddings

-class AbsolutePositionalEmbedding(nn.Module):
+class AbsolutePositionalEmbedding(Module):
     def __init__(self, dim, max_seq_len, l2norm_embed = False):
         super().__init__()
         self.scale = dim ** -0.5 if not l2norm_embed else 1.
@@ -231,7 +232,7 @@ class AbsolutePositionalEmbedding(nn.Module):
         pos_emb = pos_emb * self.scale
         return l2norm(pos_emb) if self.l2norm_embed else pos_emb

-class ScaledSinusoidalEmbedding(nn.Module):
+class ScaledSinusoidalEmbedding(Module):
     def __init__(self, dim, theta = 10000):
         super().__init__()
         assert divisible_by(dim, 2)
@@ -255,7 +256,7 @@ class ScaledSinusoidalEmbedding(nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
         return emb * self.scale

-class RelativePositionBias(nn.Module):
+class RelativePositionBias(Module):
     def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
         super().__init__()
         self.scale = scale
@@ -300,13 +301,13 @@ class RelativePositionBias(nn.Module):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale

-class DynamicPositionBias(nn.Module):
+class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
         assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
         self.log_distance = log_distance

-        self.mlp = nn.ModuleList([])
+        self.mlp = ModuleList([])

         self.mlp.append(Sequential(
             nn.Linear(1, dim),
@@ -352,7 +353,7 @@ class DynamicPositionBias(nn.Module):
         bias = rearrange(bias, 'i j h -> h i j')
         return bias

-class AlibiPositionalBias(nn.Module):
+class AlibiPositionalBias(Module):
     def __init__(self, heads, total_heads, **kwargs):
         super().__init__()
         self.heads = heads
@@ -401,7 +402,7 @@ class AlibiPositionalBias(nn.Module):

         return self.bias

-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(Module):
     def __init__(
         self,
         dim,
@@ -476,7 +477,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):

 # norms

-class Scale(nn.Module):
+class Scale(Module):
     def __init__(self, value, fn):
         super().__init__()
         self.value = value
@@ -491,7 +492,7 @@ class Scale(nn.Module):

         return (scale_fn(out[0]), *out[1:])

-class LayerNorm(nn.Module):
+class LayerNorm(Module):
     def __init__(self, dim):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -506,7 +507,7 @@ class LayerNorm(nn.Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)

-class ScaleNorm(nn.Module):
+class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -515,7 +516,7 @@ class ScaleNorm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g

-class RMSNorm(nn.Module):
+class RMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -524,7 +525,7 @@ class RMSNorm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g

-class SimpleRMSNorm(nn.Module):
+class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
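Context for the norm classes touched above: apart from the base-class swap, `RMSNorm` and `SimpleRMSNorm` are unchanged, and both rely on the identity that `F.normalize` (division by the L2 norm) multiplied by `sqrt(dim)` equals division by the root mean square. A small self-contained check of that identity, not part of the diff itself:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 512)
dim = x.shape[-1]

# explicit RMS normalization: divide by sqrt(mean(x ** 2)) over the last dimension
rms_normed = x / torch.sqrt((x ** 2).mean(dim = -1, keepdim = True))

# what the classes above compute: L2-normalize, then scale by sqrt(dim)
via_normalize = F.normalize(x, dim = -1) * dim ** 0.5

assert torch.allclose(rms_normed, via_normalize, atol = 1e-5)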
@@ -534,7 +535,7 @@ class SimpleRMSNorm(nn.Module):

 # residual and residual gates

-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
@@ -549,7 +550,7 @@ class Residual(nn.Module):

         return x + residual

-class GRUGating(nn.Module):
+class GRUGating(Module):
     def __init__(self, dim, scale_residual = False, **kwargs):
         super().__init__()
         self.gru = nn.GRUCell(dim, dim)
@@ -579,7 +580,7 @@ def shift(t, amount, mask = None):

     return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)

-class ShiftTokens(nn.Module):
+class ShiftTokens(Module):
     def __init__(self, shifts, fn):
         super().__init__()
         self.fn = fn
@@ -598,7 +599,7 @@ class ShiftTokens(nn.Module):

 # feedforward

-class GLU(nn.Module):
+class GLU(Module):
     def __init__(
         self,
         dim_in,
@@ -615,7 +616,7 @@ class GLU(nn.Module):
         x, gate = self.proj(x).chunk(2, dim = -1)
         return x * self.act(gate) * self.mult_bias

-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(
         self,
         dim,
@@ -665,7 +666,7 @@ class FeedForward(nn.Module):

 # attention. it is all we need

-class Attention(nn.Module):
+class Attention(Module):
     def __init__(
         self,
         dim,
@@ -996,11 +997,11 @@ class Attention(nn.Module):

         return out, intermediates

-class AttentionLayers(nn.Module):
+class AttentionLayers(Module):
     def __init__(
         self,
         dim,
-        depth,
+        depth = None,
         heads = 8,
         causal = False,
         cross_attend = False,
@@ -1053,12 +1054,14 @@ class AttentionLayers(nn.Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)

+        assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
+
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

         self.dim = dim
         self.depth = depth
         self.causal = causal
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])

         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

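The new assert above makes `AttentionLayers` strict about leftover keyword arguments: after the prefixed kwargs (`attn_…`, `cross_attn_…`, and so on) are split off, anything remaining is now rejected instead of silently ignored. A sketch of the behaviour with a simplified stand-in for the library's `groupby_prefix_and_trim` helper (the real helper lives elsewhere in this file; the one below only illustrates the idea):

def group_by_prefix_and_trim(prefix, kwargs):
    # split kwargs into (prefixed entries with the prefix stripped, everything else)
    taken = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return taken, rest

kwargs = dict(attn_dim_head = 64, casual = True)   # 'casual' is a typo for 'causal'

attn_kwargs, kwargs = group_by_prefix_and_trim('attn_', kwargs)
cross_attn_kwargs, kwargs = group_by_prefix_and_trim('cross_attn_', kwargs)

# 1.28.x silently dropped the leftover dict; 1.29.0 asserts it is empty,
# so the misspelled keyword now surfaces immediately
print(kwargs)   # {'casual': True} -> would trip the new assert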
@@ -1137,9 +1140,12 @@ class AttentionLayers(nn.Module):

         # setup weight tying, which is a special case of `layer_execute_order`

+        assert not (exists(layers_execute_order) and exists(custom_layers) and exists(depth)), 'depth should not be passed in if using custom layers and custom layer execution order'
+
         assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))

         if weight_tie_layers:
+            assert exists(depth), 'depth must be passed in with `weight_tie_layers` = True'
             assert not exists(layers_execute_order)
             layers_execute_order = tuple(range(len(default_block))) * depth
             depth = 1
@@ -1163,6 +1169,7 @@ class AttentionLayers(nn.Module):
             assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
             layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
         else:
+            assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
             layer_types = default_block * depth

         self.layer_types = layer_types
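Together, the three new asserts change the contract around `depth`: it is now optional, must be omitted when `custom_layers` plus `layers_execute_order` fully describe the stack, and is still required for weight tying and for the plain `Encoder` / `Decoder` path. A hedged sketch of how that reads at the call site (assuming the usual `Decoder` entry point; the values are illustrative, not from the diff):

from x_transformers import Decoder

# plain stack: depth is still required
plain = Decoder(dim = 512, depth = 6, heads = 8)

# weight tying repeats one shared block, so it also needs an explicit depth
tied = Decoder(dim = 512, depth = 6, heads = 8, weight_tie_layers = True)

# fully custom stack: two physical layers executed in an explicit order,
# so depth must now be left out (passing it would trip the new assert)
custom = Decoder(
    dim = 512,
    heads = 8,
    custom_layers = ('a', 'f'),
    layers_execute_order = (0, 1, 0, 1)
)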
@@ -1215,13 +1222,13 @@ class AttentionLayers(nn.Module):
             post_branch_norm = norm_fn() if sandwich_norm else None
             post_main_norm = norm_fn() if not pre_norm else None

-            norms = nn.ModuleList([
+            norms = ModuleList([
                 pre_branch_norm,
                 post_branch_norm,
                 post_main_norm
             ])

-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 norms,
                 layer,
                 residual
@@ -1427,7 +1434,7 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)

-class ViTransformerWrapper(nn.Module):
+class ViTransformerWrapper(Module):
     def __init__(
         self,
         *,
@@ -1508,7 +1515,7 @@ class ViTransformerWrapper(nn.Module):

         return logits, embed

-class TransformerWrapper(nn.Module):
+class TransformerWrapper(Module):
     def __init__(
         self,
         *,
@@ -1525,6 +1532,7 @@ class TransformerWrapper(nn.Module):
         memory_tokens_interspersed_every = None,
         tie_embedding = False,
         logits_dim = None,
+        num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
@@ -1559,7 +1567,7 @@ class TransformerWrapper(nn.Module):
         self.embeds = None

         if len(embed_num_tokens) > 0:
-            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+            self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})

         # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290

@@ -1573,8 +1581,19 @@ class TransformerWrapper(nn.Module):

         self.init_()

+        # output head, usually to logits of num_tokens
+
         logits_dim = default(logits_dim, num_tokens)
-        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+
+        self.has_multiple_heads = False
+
+        if tie_embedding:
+            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
+        elif num_output_heads > 1:
+            self.has_multiple_heads = True
+            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+        else:
+            self.to_logits = nn.Linear(dim, logits_dim, bias = False)

         # memory tokens (like [cls]) from Memory Transformers paper

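This is the counterpart to the new `num_output_heads` argument: when it is greater than one, `to_logits` becomes a `ModuleList` of independent linear projections and the wrapper marks itself with `has_multiple_heads`. A hedged usage sketch (assuming the usual `TransformerWrapper` / `Decoder` entry points; the sizes are illustrative):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_output_heads = 2,                      # new in 1.29.0
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

tokens = torch.randint(0, 20000, (1, 1024))

# with multiple heads, the forward pass returns one logits tensor per head
logits_one, logits_two = model(tokens)         # each of shape (1, 1024, 20000)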
@@ -1705,6 +1724,8 @@ class TransformerWrapper(nn.Module):

         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

+        # handle memories post-attention
+
         if has_memory_tokens:
             if exists(mem_every):
                 x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
@@ -1718,12 +1739,24 @@ class TransformerWrapper(nn.Module):

             x = x[:, :n]

+        # projecting to logits
+
+        if not return_embeddings:
+            if self.has_multiple_heads:
+                logits = tuple(fn(x) for fn in self.to_logits)
+            else:
+                logits = self.to_logits(x)
+
+        # different returns
+
         if return_logits_and_embeddings:
-            out = (self.to_logits(x), x)
+            out = (logits, x)
         elif return_embeddings:
             out = x
         else:
-            out = self.to_logits(x)
+            out = logits
+
+        # aux loss

         if return_attn_z_loss:
             pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
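The forward-pass refactor above computes the logits once (per head, when there are several) and only then routes them through the existing return options: `return_embeddings` still bypasses the output head(s) entirely, while `return_logits_and_embeddings` now reuses the already-computed logits instead of projecting inside the return expression. Continuing the sketch from the previous example with the same model (illustrative, not from the diff):

# embeddings only: the output head(s) are never applied
embeds = model(tokens, return_embeddings = True)            # (1, 1024, 512)

# both logits and embeddings from a single pass; with num_output_heads = 2
# the first element is itself a tuple of per-head logits
(logits_one, logits_two), embeds = model(tokens, return_logits_and_embeddings = True)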
@@ -1749,7 +1782,7 @@ class TransformerWrapper(nn.Module):

         return out

-class XTransformer(nn.Module):
+class XTransformer(Module):
     def __init__(
         self,
         *,
{x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=ub1QXJIXfoK5Bm8poZ1oJC99hbt9QitAuKmmmfBtxUY,65111
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.28.4.dist-info/LICENSE,sha256=
-x_transformers-1.28.4.dist-info/METADATA,sha256=
-x_transformers-1.28.4.dist-info/WHEEL,sha256=
-x_transformers-1.28.4.dist-info/top_level.txt,sha256=
-x_transformers-1.28.4.dist-info/RECORD,,
+x_transformers-1.29.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.29.0.dist-info/METADATA,sha256=6ivD0nnIvXz057mJdIeHYNt2s9E0fN69eqSPGtSbcXg,661
+x_transformers-1.29.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.29.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.29.0.dist-info/RECORD,,
{x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/LICENSE
File without changes
{x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/WHEEL
File without changes
{x_transformers-1.28.4.dist-info → x_transformers-1.29.0.dist-info}/top_level.txt
File without changes