x-transformers 1.23.5__py3-none-any.whl → 1.23.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +15 -9
- x_transformers/continuous_autoregressive_wrapper.py +137 -1
- x_transformers/x_transformers.py +3 -124
- {x_transformers-1.23.5.dist-info → x_transformers-1.23.6.dist-info}/METADATA +1 -1
- x_transformers-1.23.6.dist-info/RECORD +12 -0
- x_transformers-1.23.5.dist-info/RECORD +0 -12
- {x_transformers-1.23.5.dist-info → x_transformers-1.23.6.dist-info}/LICENSE +0 -0
- {x_transformers-1.23.5.dist-info → x_transformers-1.23.6.dist-info}/WHEEL +0 -0
- {x_transformers-1.23.5.dist-info → x_transformers-1.23.6.dist-info}/top_level.txt +0 -0
x_transformers/__init__.py
CHANGED
@@ -1,13 +1,19 @@
-import
-
-
-
-
-
-
-
+from x_transformers.x_transformers import (
+    XTransformer,
+    Encoder,
+    Decoder,
+    CrossAttender,
+    Attention,
+    TransformerWrapper,
+    ViTransformerWrapper
+)

 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
-
+
+from x_transformers.continuous_autoregressive_wrapper import (
+    ContinuousTransformerWrapper,
+    ContinuousAutoregressiveWrapper
+)
+
 from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
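The regrouping of `__init__.py` keeps the public surface the same: every name is still re-exported from the package root, with `ContinuousTransformerWrapper` now sourced from `continuous_autoregressive_wrapper` instead of `x_transformers.x_transformers`. A minimal sketch of downstream usage, using only the names visible in the imports above:

```python
# Downstream imports are unaffected by the reshuffle: everything the
# new __init__.py imports remains available from the package root.
from x_transformers import (
    TransformerWrapper,              # re-exported from x_transformers.x_transformers
    Decoder,                         # re-exported from x_transformers.x_transformers
    ContinuousTransformerWrapper,    # now re-exported via continuous_autoregressive_wrapper
    ContinuousAutoregressiveWrapper,
)
```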
x_transformers/continuous_autoregressive_wrapper.py
CHANGED
@@ -2,11 +2,147 @@ import torch
 from torch import nn
 import torch.nn.functional as F

+from x_transformers.x_transformers import (
+    AttentionLayers,
+    ScaledSinusoidalEmbedding,
+    AbsolutePositionalEmbedding
+)
+
+# helper functions
+
 def exists(val):
     return val is not None

+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+# main classes
+
+class ContinuousTransformerWrapper(nn.Module):
+    def __init__(
+        self,
+        *,
+        max_seq_len,
+        attn_layers: AttentionLayers,
+        dim_in = None,
+        dim_out = None,
+        emb_dim = None,
+        max_mem_len = 0,
+        num_memory_tokens = None,
+        post_emb_norm = False,
+        emb_dropout = 0.,
+        use_abs_pos_emb = True,
+        scaled_sinu_pos_emb = False
+    ):
+        super().__init__()
+        dim = attn_layers.dim
+
+        self.max_seq_len = max_seq_len
+
+        self.max_mem_len = max_mem_len
+
+        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+            self.pos_emb = always(0)
+        elif scaled_sinu_pos_emb:
+            self.pos_emb = ScaledSinusoidalEmbedding(dim)
+        else:
+            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
+
+        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+        self.emb_dropout = nn.Dropout(emb_dropout)
+
+        # memory tokens
+
+        num_memory_tokens = default(num_memory_tokens, 0)
+        self.has_memory_tokens = num_memory_tokens > 0
+
+        if num_memory_tokens > 0:
+            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+        # attention layers
+
+        self.attn_layers = attn_layers
+
+        # project in and out
+
+        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+
+    def forward(
+        self,
+        x,
+        return_embeddings = False,
+        return_intermediates = False,
+        return_mems = False,
+        mask = None,
+        return_attn = False,
+        mems = None,
+        pos = None,
+        prepend_embeds = None,
+        **kwargs
+    ):
+        batch = x.shape[0]
+
+        x = self.project_in(x)
+        x = x + self.pos_emb(x, pos = pos)
+
+        x = self.post_emb_norm(x)
+
+        # memory tokens
+
+        if self.has_memory_tokens:
+            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
+            x, mem_ps = pack([m, x], 'b * d')
+
+            if exists(mask):
+                num_mems = m.shape[-2]
+                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
+
+        # whether to append embeds, as in PaLI, for image embeddings
+
+        if exists(prepend_embeds):
+            _, prepend_dim = prepend_embeds.shape[1:]
+            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
+
+            x = torch.cat((prepend_embeds, x), dim = -2)
+
+        x = self.emb_dropout(x)
+
+        # attention layers
+
+        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+
+        # splice out memory tokens
+
+        if self.has_memory_tokens:
+            m, x = unpack(x, mem_ps, 'b * d')
+            intermediates.memory_tokens = m
+
+        out = self.project_out(x) if not return_embeddings else x
+
+        if return_intermediates:
+            return out, intermediates
+
+        if return_mems:
+            hiddens = intermediates.hiddens
+            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
+            return out, new_mems
+
+        if return_attn:
+            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            return out, attn_maps
+
+        return out
+
 class ContinuousAutoregressiveWrapper(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        net: ContinuousTransformerWrapper,
+        ignore_index = -100,
+        pad_value = 0
+    ):
         super().__init__()
         self.net = net
         self.max_seq_len = net.max_seq_len
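The class added above is the same `ContinuousTransformerWrapper` that is deleted from `x_transformers.py` in the next section; only its home module and the `attn_layers: AttentionLayers` annotation change. A rough usage sketch, assuming the `Decoder(dim, depth, heads)` constructor documented in the project README, with hypothetical sizes chosen purely for illustration:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

# Hypothetical dimensions for illustration only.
model = ContinuousTransformerWrapper(
    dim_in = 32,               # continuous inputs are projected from 32 dims up to the model dim
    dim_out = 32,              # ... and projected back out to 32 dims
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randn(1, 256, 32)          # (batch, seq, dim_in) continuous features
mask = torch.ones(1, 256).bool()     # optional padding mask

out = model(x, mask = mask)          # -> (1, 256, 32) after project_out
```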
x_transformers/x_transformers.py
CHANGED
@@ -6,7 +6,6 @@ from torch import nn, einsum, Tensor
 import torch.nn.functional as F

 from functools import partial, wraps
-from inspect import isfunction
 from collections import namedtuple
 from dataclasses import dataclass
 from typing import List, Callable, Optional
@@ -38,7 +37,7 @@ def exists(val):
 def default(val, d):
     if exists(val):
         return val
-    return d() if isfunction(d) else d
+    return d() if callable(d) else d

 def cast_tuple(val, depth):
     return val if isinstance(val, tuple) else (val,) * depth
@@ -1340,7 +1339,7 @@ class ViTransformerWrapper(nn.Module):
         *,
         image_size,
         patch_size,
-        attn_layers,
+        attn_layers: Encoder,
         channels = 3,
         num_classes = None,
         post_emb_norm = False,
@@ -1348,7 +1347,6 @@ class ViTransformerWrapper(nn.Module):
         emb_dropout = 0.
     ):
         super().__init__()
-        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
         assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
         dim = attn_layers.dim
         num_patches = (image_size // patch_size) ** 2
@@ -1414,7 +1412,7 @@ class TransformerWrapper(nn.Module):
         *,
         num_tokens,
         max_seq_len,
-        attn_layers,
+        attn_layers: AttentionLayers,
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1431,7 +1429,6 @@ class TransformerWrapper(nn.Module):
         attn_z_loss_weight = 1e-4,
     ):
         super().__init__()
-        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

         dim = attn_layers.dim
         emb_dim = default(emb_dim, dim)
@@ -1615,124 +1612,6 @@ class TransformerWrapper(nn.Module):

         return out

-class ContinuousTransformerWrapper(nn.Module):
-    def __init__(
-        self,
-        *,
-        max_seq_len,
-        attn_layers,
-        dim_in = None,
-        dim_out = None,
-        emb_dim = None,
-        max_mem_len = 0,
-        num_memory_tokens = None,
-        post_emb_norm = False,
-        emb_dropout = 0.,
-        use_abs_pos_emb = True,
-        scaled_sinu_pos_emb = False
-    ):
-        super().__init__()
-        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
-
-        dim = attn_layers.dim
-
-        self.max_seq_len = max_seq_len
-
-        self.max_mem_len = max_mem_len
-
-        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
-            self.pos_emb = always(0)
-        elif scaled_sinu_pos_emb:
-            self.pos_emb = ScaledSinusoidalEmbedding(dim)
-        else:
-            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
-
-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
-        self.emb_dropout = nn.Dropout(emb_dropout)
-
-        # memory tokens
-
-        num_memory_tokens = default(num_memory_tokens, 0)
-        self.has_memory_tokens = num_memory_tokens > 0
-
-        if num_memory_tokens > 0:
-            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
-
-        # attention layers
-
-        self.attn_layers = attn_layers
-
-        # project in and out
-
-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
-
-    def forward(
-        self,
-        x,
-        return_embeddings = False,
-        return_intermediates = False,
-        return_mems = False,
-        mask = None,
-        return_attn = False,
-        mems = None,
-        pos = None,
-        prepend_embeds = None,
-        **kwargs
-    ):
-        batch = x.shape[0]
-
-        x = self.project_in(x)
-        x = x + self.pos_emb(x, pos = pos)
-
-        x = self.post_emb_norm(x)
-
-        # memory tokens
-
-        if self.has_memory_tokens:
-            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
-            x, mem_ps = pack([m, x], 'b * d')
-
-            if exists(mask):
-                num_mems = m.shape[-2]
-                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
-
-        # whether to append embeds, as in PaLI, for image embeddings
-
-        if exists(prepend_embeds):
-            _, prepend_dim = prepend_embeds.shape[1:]
-            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
-
-            x = torch.cat((prepend_embeds, x), dim = -2)
-
-        x = self.emb_dropout(x)
-
-        # attention layers
-
-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
-
-        # splice out memory tokens
-
-        if self.has_memory_tokens:
-            m, x = unpack(x, mem_ps, 'b * d')
-            intermediates.memory_tokens = m
-
-        out = self.project_out(x) if not return_embeddings else x
-
-        if return_intermediates:
-            return out, intermediates
-
-        if return_mems:
-            hiddens = intermediates.hiddens
-            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
-            return out, new_mems
-
-        if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-            return out, attn_maps
-
-        return out
-
 class XTransformer(nn.Module):
     def __init__(
         self,
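Two themes run through this file: the `isinstance` asserts on `attn_layers` are replaced by type annotations (`Encoder`, `AttentionLayers`), and the `default` helper now tests `callable(d)` instead of `inspect.isfunction(d)`. The snippet below reproduces the helper as it appears after the change and illustrates the behavioral difference for callables that are not plain functions; the example values are hypothetical:

```python
from functools import partial

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d      # 1.23.6: callable() instead of inspect.isfunction()

print(default(3, lambda: []))             # 3        -- an existing value always wins
print(default(None, 5))                   # 5        -- non-callable fallback is returned as-is
print(default(None, lambda: []))          # []       -- lambdas were already invoked under isfunction()
print(default(None, partial(dict, a=1)))  # {'a': 1} -- partials (and classes) are now invoked too;
                                          #             isfunction() would have returned the partial itself
```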
x_transformers-1.23.6.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
+x_transformers/__init__.py,sha256=LLbGkiUKu4nR4YOISh-1gPEZYs8R3I2KMwMZHMQ2YkU,538
+x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
+x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
+x_transformers/continuous_autoregressive_wrapper.py,sha256=lGqE5vFaDeuLFc7b-dAQ0hx3H4dHFK_yD4-tZQZ7vqQ,5337
+x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+x_transformers/x_transformers.py,sha256=8zfU1iqrd6AwT-L23jkK7tpunxfjLi4HKgbZ6evBWKU,58416
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers-1.23.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.6.dist-info/METADATA,sha256=UnmCziZTo4xYZJjAdgtRyWLytboEAR0JmyK2i0hwMvc,661
+x_transformers-1.23.6.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.6.dist-info/RECORD,,
x_transformers-1.23.5.dist-info/RECORD
REMOVED
@@ -1,12 +0,0 @@
-x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
-x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
-x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
-x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=xnHu4himfvXeH8HVnFm3JTAGcs22D8H-kzLN9Qrm9-c,62143
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.23.5.dist-info/METADATA,sha256=F1WPg9cX9USpJmgvN_NzOqp0dsNn3WcchI-cDgoGKfk,661
-x_transformers-1.23.5.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.23.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.23.5.dist-info/RECORD,,
{x_transformers-1.23.5.dist-info → x_transformers-1.23.6.dist-info}/LICENSE
File without changes

{x_transformers-1.23.5.dist-info → x_transformers-1.23.6.dist-info}/WHEEL
File without changes

{x_transformers-1.23.5.dist-info → x_transformers-1.23.6.dist-info}/top_level.txt
File without changes