x-transformers 1.23.4__py3-none-any.whl → 1.23.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +15 -9
- x_transformers/continuous_autoregressive_wrapper.py +137 -1
- x_transformers/x_transformers.py +6 -123
- {x_transformers-1.23.4.dist-info → x_transformers-1.23.6.dist-info}/METADATA +1 -1
- x_transformers-1.23.6.dist-info/RECORD +12 -0
- x_transformers-1.23.4.dist-info/RECORD +0 -12
- {x_transformers-1.23.4.dist-info → x_transformers-1.23.6.dist-info}/LICENSE +0 -0
- {x_transformers-1.23.4.dist-info → x_transformers-1.23.6.dist-info}/WHEEL +0 -0
- {x_transformers-1.23.4.dist-info → x_transformers-1.23.6.dist-info}/top_level.txt +0 -0
x_transformers/__init__.py
CHANGED
@@ -1,13 +1,19 @@
-import
-
-
-
-
-
-
-
+from x_transformers.x_transformers import (
+    XTransformer,
+    Encoder,
+    Decoder,
+    CrossAttender,
+    Attention,
+    TransformerWrapper,
+    ViTransformerWrapper
+)
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
-
+
+from x_transformers.continuous_autoregressive_wrapper import (
+    ContinuousTransformerWrapper,
+    ContinuousAutoregressiveWrapper
+)
+
 from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
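The net effect of this `__init__.py` change is that `ContinuousTransformerWrapper` and `ContinuousAutoregressiveWrapper` are now re-exported from their new home in `continuous_autoregressive_wrapper.py`, so existing top-level imports keep working. A minimal sketch of the 1.23.6 import surface (the `_SameClass` alias exists only for the identity check here):

```python
# Quick check of the names re-exported by the 1.23.6 __init__.py.
from x_transformers import (
    TransformerWrapper,
    Decoder,
    ContinuousTransformerWrapper,
    ContinuousAutoregressiveWrapper,
)

# The same class is importable from its defining module as well;
# the package-level name is just a re-export.
from x_transformers.continuous_autoregressive_wrapper import (
    ContinuousTransformerWrapper as _SameClass,
)

assert ContinuousTransformerWrapper is _SameClass
```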
x_transformers/continuous_autoregressive_wrapper.py
CHANGED
@@ -2,11 +2,147 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from x_transformers.x_transformers import (
+    AttentionLayers,
+    ScaledSinusoidalEmbedding,
+    AbsolutePositionalEmbedding
+)
+
+# helper functions
+
 def exists(val):
     return val is not None
 
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+# main classes
+
+class ContinuousTransformerWrapper(nn.Module):
+    def __init__(
+        self,
+        *,
+        max_seq_len,
+        attn_layers: AttentionLayers,
+        dim_in = None,
+        dim_out = None,
+        emb_dim = None,
+        max_mem_len = 0,
+        num_memory_tokens = None,
+        post_emb_norm = False,
+        emb_dropout = 0.,
+        use_abs_pos_emb = True,
+        scaled_sinu_pos_emb = False
+    ):
+        super().__init__()
+        dim = attn_layers.dim
+
+        self.max_seq_len = max_seq_len
+
+        self.max_mem_len = max_mem_len
+
+        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+            self.pos_emb = always(0)
+        elif scaled_sinu_pos_emb:
+            self.pos_emb = ScaledSinusoidalEmbedding(dim)
+        else:
+            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
+
+        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+        self.emb_dropout = nn.Dropout(emb_dropout)
+
+        # memory tokens
+
+        num_memory_tokens = default(num_memory_tokens, 0)
+        self.has_memory_tokens = num_memory_tokens > 0
+
+        if num_memory_tokens > 0:
+            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+        # attention layers
+
+        self.attn_layers = attn_layers
+
+        # project in and out
+
+        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+
+    def forward(
+        self,
+        x,
+        return_embeddings = False,
+        return_intermediates = False,
+        return_mems = False,
+        mask = None,
+        return_attn = False,
+        mems = None,
+        pos = None,
+        prepend_embeds = None,
+        **kwargs
+    ):
+        batch = x.shape[0]
+
+        x = self.project_in(x)
+        x = x + self.pos_emb(x, pos = pos)
+
+        x = self.post_emb_norm(x)
+
+        # memory tokens
+
+        if self.has_memory_tokens:
+            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
+            x, mem_ps = pack([m, x], 'b * d')
+
+            if exists(mask):
+                num_mems = m.shape[-2]
+                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
+
+        # whether to append embeds, as in PaLI, for image embeddings
+
+        if exists(prepend_embeds):
+            _, prepend_dim = prepend_embeds.shape[1:]
+            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
+
+            x = torch.cat((prepend_embeds, x), dim = -2)
+
+        x = self.emb_dropout(x)
+
+        # attention layers
+
+        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+
+        # splice out memory tokens
+
+        if self.has_memory_tokens:
+            m, x = unpack(x, mem_ps, 'b * d')
+            intermediates.memory_tokens = m
+
+        out = self.project_out(x) if not return_embeddings else x
+
+        if return_intermediates:
+            return out, intermediates
+
+        if return_mems:
+            hiddens = intermediates.hiddens
+            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
+            return out, new_mems
+
+        if return_attn:
+            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            return out, attn_maps
+
+        return out
+
 class ContinuousAutoregressiveWrapper(nn.Module):
-    def __init__(self, net, ignore_index = -100, pad_value = 0):
+    def __init__(
+        self,
+        net: ContinuousTransformerWrapper,
+        ignore_index = -100,
+        pad_value = 0
+    ):
         super().__init__()
         self.net = net
         self.max_seq_len = net.max_seq_len
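For orientation, a hedged usage sketch of the relocated wrapper, following the pattern from the project README; the dimensions and `Decoder` hyperparameters below are illustrative and not taken from this diff. The new `num_memory_tokens` argument and the `intermediates.memory_tokens` field added above are not exercised here.

```python
import torch
from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper, Decoder

# Continuous (non-token) inputs: project 32-dim vectors into and out of a 512-dim decoder.
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

# Wrap it for autoregressive training on continuous sequences (MSE on next-step prediction).
trainer = ContinuousAutoregressiveWrapper(model)

x = torch.randn(1, 1024, 32)
mask = torch.ones(1, 1024).bool()

loss = trainer(x, mask = mask)
loss.backward()
```

Per the hunk above, constructing the wrapper with `num_memory_tokens > 0` and calling it with `return_intermediates = True` additionally records the spliced-out memory tokens on `intermediates.memory_tokens`.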
x_transformers/x_transformers.py
CHANGED
@@ -6,7 +6,6 @@ from torch import nn, einsum, Tensor
 import torch.nn.functional as F
 
 from functools import partial, wraps
-from inspect import isfunction
 from collections import namedtuple
 from dataclasses import dataclass
 from typing import List, Callable, Optional
@@ -28,6 +27,7 @@ class LayerIntermediates:
     layer_hiddens: Optional[List[Tensor]] = None
     attn_z_loss: Optional[Tensor] = None
     mems: Optional[Tensor] = None
+    memory_tokens: Optional[Tensor] = None
 
 # helpers
 
@@ -37,7 +37,7 @@ def exists(val):
 def default(val, d):
     if exists(val):
         return val
-    return d() if isfunction(d) else d
+    return d() if callable(d) else d
 
 def cast_tuple(val, depth):
     return val if isinstance(val, tuple) else (val,) * depth
@@ -1339,7 +1339,7 @@ class ViTransformerWrapper(nn.Module):
         *,
         image_size,
         patch_size,
-        attn_layers,
+        attn_layers: Encoder,
         channels = 3,
         num_classes = None,
         post_emb_norm = False,
@@ -1347,7 +1347,6 @@ class ViTransformerWrapper(nn.Module):
         emb_dropout = 0.
     ):
         super().__init__()
-        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
         assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
         dim = attn_layers.dim
         num_patches = (image_size // patch_size) ** 2
@@ -1413,7 +1412,7 @@ class TransformerWrapper(nn.Module):
         *,
         num_tokens,
         max_seq_len,
-        attn_layers,
+        attn_layers: AttentionLayers,
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1430,7 +1429,6 @@ class TransformerWrapper(nn.Module):
         attn_z_loss_weight = 1e-4,
     ):
         super().__init__()
-        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
 
         dim = attn_layers.dim
         emb_dim = default(emb_dim, dim)
@@ -1576,6 +1574,8 @@ class TransformerWrapper(nn.Module):
 
             mem, x = unpack(x, mem_packed_shape, 'b * d')
 
+            intermediates.memory_tokens = mem
+
             if exists(mem_every):
                 x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
 
@@ -1612,123 +1612,6 @@ class TransformerWrapper(nn.Module):
 
         return out
 
-class ContinuousTransformerWrapper(nn.Module):
-    def __init__(
-        self,
-        *,
-        max_seq_len,
-        attn_layers,
-        dim_in = None,
-        dim_out = None,
-        emb_dim = None,
-        max_mem_len = 0,
-        num_memory_tokens = None,
-        post_emb_norm = False,
-        emb_dropout = 0.,
-        use_abs_pos_emb = True,
-        scaled_sinu_pos_emb = False
-    ):
-        super().__init__()
-        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
-
-        dim = attn_layers.dim
-
-        self.max_seq_len = max_seq_len
-
-        self.max_mem_len = max_mem_len
-
-        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
-            self.pos_emb = always(0)
-        elif scaled_sinu_pos_emb:
-            self.pos_emb = ScaledSinusoidalEmbedding(dim)
-        else:
-            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
-
-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
-        self.emb_dropout = nn.Dropout(emb_dropout)
-
-        # memory tokens
-
-        num_memory_tokens = default(num_memory_tokens, 0)
-        self.has_memory_tokens = num_memory_tokens > 0
-
-        if num_memory_tokens > 0:
-            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
-
-        # attention layers
-
-        self.attn_layers = attn_layers
-
-        # project in and out
-
-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
-
-    def forward(
-        self,
-        x,
-        return_embeddings = False,
-        return_intermediates = False,
-        return_mems = False,
-        mask = None,
-        return_attn = False,
-        mems = None,
-        pos = None,
-        prepend_embeds = None,
-        **kwargs
-    ):
-        batch = x.shape[0]
-
-        x = self.project_in(x)
-        x = x + self.pos_emb(x, pos = pos)
-
-        x = self.post_emb_norm(x)
-
-        # memory tokens
-
-        if self.has_memory_tokens:
-            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
-            x, mem_ps = pack([m, x], 'b * d')
-
-            if exists(mask):
-                num_mems = m.shape[-2]
-                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
-
-        # whether to append embeds, as in PaLI, for image embeddings
-
-        if exists(prepend_embeds):
-            _, prepend_dim = prepend_embeds.shape[1:]
-            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
-
-            x = torch.cat((prepend_embeds, x), dim = -2)
-
-        x = self.emb_dropout(x)
-
-        # attention layers
-
-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
-
-        # splice out memory tokens
-
-        if self.has_memory_tokens:
-            m, x = unpack(x, mem_ps, 'b * d')
-
-        out = self.project_out(x) if not return_embeddings else x
-
-        if return_intermediates:
-            return out, intermediates
-
-        if return_mems:
-            hiddens = intermediates.hiddens
-            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
-            return out, new_mems
-
-        if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-            return out, attn_maps
-
-        return out
-
 class XTransformer(nn.Module):
     def __init__(
         self,
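The `default` helper now treats any callable as a lazy factory rather than only plain functions. The removed line is truncated in this extract; the sketch below assumes it used `isfunction`, consistent with the dropped `from inspect import isfunction`. A self-contained before/after illustration, not imported from the library:

```python
from inspect import isfunction

def exists(val):
    return val is not None

# 1.23.4 behaviour (assumed): only plain functions are invoked lazily.
def default_old(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# 1.23.6 behaviour, as shown in the hunk: any callable is invoked lazily.
def default_new(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

print(default_old(None, dict))  # <class 'dict'> -- the class object leaks through
print(default_new(None, dict))  # {}             -- classes, partials, lambdas are all called
print(default_new(None, 0))     # 0              -- non-callables still pass through unchanged
```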
x_transformers-1.23.6.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
+x_transformers/__init__.py,sha256=LLbGkiUKu4nR4YOISh-1gPEZYs8R3I2KMwMZHMQ2YkU,538
+x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
+x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
+x_transformers/continuous_autoregressive_wrapper.py,sha256=lGqE5vFaDeuLFc7b-dAQ0hx3H4dHFK_yD4-tZQZ7vqQ,5337
+x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+x_transformers/x_transformers.py,sha256=8zfU1iqrd6AwT-L23jkK7tpunxfjLi4HKgbZ6evBWKU,58416
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers-1.23.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.6.dist-info/METADATA,sha256=UnmCziZTo4xYZJjAdgtRyWLytboEAR0JmyK2i0hwMvc,661
+x_transformers-1.23.6.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.6.dist-info/RECORD,,
x_transformers-1.23.4.dist-info/RECORD
DELETED
@@ -1,12 +0,0 @@
-x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
-x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
-x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
-x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=jjuh7MLIlV4pHsJHONFlQy96QTdI5VonR9cc8FVs4J8,62009
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.23.4.dist-info/METADATA,sha256=51dyPgX1be0-AH9PktXW_fme4GejI0YgZXbviy97j0Q,661
-x_transformers-1.23.4.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.23.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.23.4.dist-info/RECORD,,
{x_transformers-1.23.4.dist-info → x_transformers-1.23.6.dist-info}/LICENSE
File without changes
{x_transformers-1.23.4.dist-info → x_transformers-1.23.6.dist-info}/WHEEL
File without changes
{x_transformers-1.23.4.dist-info → x_transformers-1.23.6.dist-info}/top_level.txt
File without changes