x-transformers 1.28.4__py3-none-any.whl → 1.28.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- x_transformers/x_transformers.py +59 -32
- {x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/METADATA +1 -1
- {x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/RECORD +6 -6
- {x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -3,8 +3,9 @@ from random import random
 from packaging import version
 
 import torch
-from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList, ModuleDict
 from torch.cuda.amp import autocast
 
 from functools import partial, wraps
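Note on the import change above: torch.nn.Module, torch.nn.ModuleList, and torch.nn.ModuleDict are the same classes whether referenced through the nn namespace or imported directly, so the renames in the hunks that follow are cosmetic. A minimal illustrative sketch of the shortened style (TinyBlock is a made-up example, not part of x-transformers):

import torch
from torch import nn
from torch.nn import Module, ModuleList

assert Module is nn.Module and ModuleList is nn.ModuleList  # identical classes, shorter spelling

class TinyBlock(Module):
    def __init__(self, dim, depth = 2):
        super().__init__()
        # ModuleList registers each sub-layer's parameters, exactly like nn.ModuleList
        self.layers = ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

print(TinyBlock(8)(torch.randn(2, 8)).shape)  # torch.Size([2, 8])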
@@ -191,13 +192,13 @@ def dropout_seq(seq, mask, dropout):
 
 # activations
 
-class ReluSquared(nn.Module):
+class ReluSquared(Module):
     def forward(self, x):
         return F.relu(x) ** 2
 
 # embedding
 
-class TokenEmbedding(nn.Module):
+class TokenEmbedding(Module):
     def __init__(self, dim, num_tokens, l2norm_embed = False):
         super().__init__()
         self.l2norm_embed = l2norm_embed
@@ -209,7 +210,7 @@ class TokenEmbedding(nn.Module):
 
 # positional embeddings
 
-class AbsolutePositionalEmbedding(nn.Module):
+class AbsolutePositionalEmbedding(Module):
     def __init__(self, dim, max_seq_len, l2norm_embed = False):
         super().__init__()
         self.scale = dim ** -0.5 if not l2norm_embed else 1.
@@ -231,7 +232,7 @@ class AbsolutePositionalEmbedding(nn.Module):
         pos_emb = pos_emb * self.scale
         return l2norm(pos_emb) if self.l2norm_embed else pos_emb
 
-class ScaledSinusoidalEmbedding(nn.Module):
+class ScaledSinusoidalEmbedding(Module):
     def __init__(self, dim, theta = 10000):
         super().__init__()
         assert divisible_by(dim, 2)
@@ -255,7 +256,7 @@ class ScaledSinusoidalEmbedding(nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
         return emb * self.scale
 
-class RelativePositionBias(nn.Module):
+class RelativePositionBias(Module):
     def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
         super().__init__()
         self.scale = scale
@@ -300,13 +301,13 @@ class RelativePositionBias(nn.Module):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale
 
-class DynamicPositionBias(nn.Module):
+class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
         assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
         self.log_distance = log_distance
 
-        self.mlp = nn.ModuleList([])
+        self.mlp = ModuleList([])
 
         self.mlp.append(Sequential(
             nn.Linear(1, dim),
@@ -352,7 +353,7 @@ class DynamicPositionBias(nn.Module):
         bias = rearrange(bias, 'i j h -> h i j')
         return bias
 
-class AlibiPositionalBias(nn.Module):
+class AlibiPositionalBias(Module):
     def __init__(self, heads, total_heads, **kwargs):
         super().__init__()
         self.heads = heads
@@ -401,7 +402,7 @@ class AlibiPositionalBias(nn.Module):
 
         return self.bias
 
-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(Module):
     def __init__(
         self,
         dim,
@@ -476,7 +477,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
 
 # norms
 
-class Scale(nn.Module):
+class Scale(Module):
     def __init__(self, value, fn):
         super().__init__()
         self.value = value
@@ -491,7 +492,7 @@ class Scale(nn.Module):
 
         return (scale_fn(out[0]), *out[1:])
 
-class LayerNorm(nn.Module):
+class LayerNorm(Module):
     def __init__(self, dim):
         """
         bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -506,7 +507,7 @@ class LayerNorm(nn.Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)
 
-class ScaleNorm(nn.Module):
+class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -515,7 +516,7 @@ class ScaleNorm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g
 
-class RMSNorm(nn.Module):
+class RMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -524,7 +525,7 @@ class RMSNorm(nn.Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g
 
-class SimpleRMSNorm(nn.Module):
+class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.scale = dim ** 0.5
@@ -534,7 +535,7 @@ class SimpleRMSNorm(nn.Module):
 
 # residual and residual gates
 
-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
@@ -549,7 +550,7 @@ class Residual(nn.Module):
 
         return x + residual
 
-class GRUGating(nn.Module):
+class GRUGating(Module):
     def __init__(self, dim, scale_residual = False, **kwargs):
         super().__init__()
         self.gru = nn.GRUCell(dim, dim)
@@ -579,7 +580,7 @@ def shift(t, amount, mask = None):
 
     return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
 
-class ShiftTokens(nn.Module):
+class ShiftTokens(Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
@@ -598,7 +599,7 @@ class ShiftTokens(nn.Module):
 
 # feedforward
 
-class GLU(nn.Module):
+class GLU(Module):
    def __init__(
        self,
        dim_in,
@@ -615,7 +616,7 @@ class GLU(nn.Module):
        x, gate = self.proj(x).chunk(2, dim = -1)
        return x * self.act(gate) * self.mult_bias
 
-class FeedForward(nn.Module):
+class FeedForward(Module):
    def __init__(
        self,
        dim,
@@ -665,7 +666,7 @@ class FeedForward(nn.Module):
 
 # attention. it is all we need
 
-class Attention(nn.Module):
+class Attention(Module):
    def __init__(
        self,
        dim,
@@ -996,7 +997,7 @@ class Attention(nn.Module):
 
        return out, intermediates
 
-class AttentionLayers(nn.Module):
+class AttentionLayers(Module):
    def __init__(
        self,
        dim,
@@ -1058,7 +1059,7 @@ class AttentionLayers(nn.Module):
        self.dim = dim
        self.depth = depth
        self.causal = causal
-       self.layers = nn.ModuleList([])
+       self.layers = ModuleList([])
 
        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
 
@@ -1215,13 +1216,13 @@ class AttentionLayers(nn.Module):
            post_branch_norm = norm_fn() if sandwich_norm else None
            post_main_norm = norm_fn() if not pre_norm else None
 
-           norms = nn.ModuleList([
+           norms = ModuleList([
                pre_branch_norm,
                post_branch_norm,
                post_main_norm
            ])
 
-           self.layers.append(nn.ModuleList([
+           self.layers.append(ModuleList([
                norms,
                layer,
                residual
@@ -1427,7 +1428,7 @@ class CrossAttender(AttentionLayers):
    def __init__(self, **kwargs):
        super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
-class ViTransformerWrapper(nn.Module):
+class ViTransformerWrapper(Module):
    def __init__(
        self,
        *,
@@ -1508,7 +1509,7 @@ class ViTransformerWrapper(nn.Module):
 
        return logits, embed
 
-class TransformerWrapper(nn.Module):
+class TransformerWrapper(Module):
    def __init__(
        self,
        *,
@@ -1525,6 +1526,7 @@ class TransformerWrapper(nn.Module):
        memory_tokens_interspersed_every = None,
        tie_embedding = False,
        logits_dim = None,
+       num_output_heads = 1,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        l2norm_embed = False,
@@ -1559,7 +1561,7 @@ class TransformerWrapper(nn.Module):
        self.embeds = None
 
        if len(embed_num_tokens) > 0:
-           self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
+           self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
 
        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
 
@@ -1573,8 +1575,19 @@ class TransformerWrapper(nn.Module):
 
        self.init_()
 
+       # output head, usually to logits of num_tokens
+
        logits_dim = default(logits_dim, num_tokens)
-       self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+
+       self.has_multiple_heads = False
+
+       if tie_embedding:
+           self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
+       elif num_output_heads > 1:
+           self.has_multiple_heads = True
+           self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+       else:
+           self.to_logits = nn.Linear(dim, logits_dim, bias = False)
 
        # memory tokens (like [cls]) from Memory Transformers paper
 
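The branching above replaces what was previously a single conditional assignment: tie_embedding still reuses the token embedding weights, while the new num_output_heads > 1 path builds one independent bias-free linear projection per head. A condensed, hedged sketch of just that selection logic (the helper name and signature below are illustrative, not package API):

from torch import nn
from torch.nn import ModuleList

def build_to_logits(dim, logits_dim, num_output_heads = 1, tie_embedding = False, token_emb = None):
    # illustrative helper mirroring the constructor branch above
    if tie_embedding:
        # weight tying: project hidden states with the transposed embedding matrix
        return False, (lambda t: t @ token_emb.weight.t())
    if num_output_heads > 1:
        # one independent linear head per requested output
        return True, ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
    # default: a single linear head, as before
    return False, nn.Linear(dim, logits_dim, bias = False)

has_multiple_heads, to_logits = build_to_logits(dim = 512, logits_dim = 20000, num_output_heads = 2)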
@@ -1705,6 +1718,8 @@ class TransformerWrapper(nn.Module):
 
        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
 
+       # handle memories post-attention
+
        if has_memory_tokens:
            if exists(mem_every):
                x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
@@ -1718,12 +1733,24 @@ class TransformerWrapper(nn.Module):
 
            x = x[:, :n]
 
+       # projecting to logits
+
+       if not return_embeddings:
+           if self.has_multiple_heads:
+               logits = tuple(fn(x) for fn in self.to_logits)
+           else:
+               logits = self.to_logits(x)
+
+       # different returns
+
        if return_logits_and_embeddings:
-           out = (self.to_logits(x), x)
+           out = (logits, x)
        elif return_embeddings:
            out = x
        else:
-           out = self.to_logits(x)
+           out = logits
+
+       # aux loss
 
        if return_attn_z_loss:
            pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
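Taken together with the constructor change, the forward pass now projects once per head and returns a tuple of logits when num_output_heads > 1 (unless embeddings are requested instead). A hedged end-to-end usage sketch; the constructor arguments follow the package's usual TransformerWrapper/Decoder pattern and the sizes are illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4),
    num_output_heads = 2   # new in 1.28.5; defaults to 1 (previous behaviour)
)

tokens = torch.randint(0, 256, (1, 128))
logits_a, logits_b = model(tokens)        # tuple of per-head logits
assert logits_a.shape == logits_b.shape == (1, 128, 256)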
@@ -1749,7 +1776,7 @@ class TransformerWrapper(nn.Module):
 
        return out
 
-class XTransformer(nn.Module):
+class XTransformer(Module):
    def __init__(
        self,
        *,
{x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=w_S0zOCKJtAO2M5ZKdE7gqSUWzkqECkA87ah-vkqx0Y,64656
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.28.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.28.4.dist-info/METADATA,sha256=
-x_transformers-1.28.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.28.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.28.4.dist-info/RECORD,,
+x_transformers-1.28.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.28.5.dist-info/METADATA,sha256=jLcNekd2_ccREKevcTAtHNAnjwqnxaRmAvq90_eSdQI,661
+x_transformers-1.28.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.28.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.28.5.dist-info/RECORD,,
{x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/LICENSE
File without changes

{x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/WHEEL
File without changes

{x_transformers-1.28.4.dist-info → x_transformers-1.28.5.dist-info}/top_level.txt
File without changes