x-transformers 2.2.11__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +56 -14
- x_transformers/x_transformers.py +19 -2
- {x_transformers-2.2.11.dist-info → x_transformers-2.3.0.dist-info}/METADATA +1 -1
- {x_transformers-2.2.11.dist-info → x_transformers-2.3.0.dist-info}/RECORD +6 -6
- {x_transformers-2.2.11.dist-info → x_transformers-2.3.0.dist-info}/WHEEL +0 -0
- {x_transformers-2.2.11.dist-info → x_transformers-2.3.0.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -1,9 +1,13 @@
+from __future__ import annotations
+
 import torch
-from torch import nn
+from torch import nn, cat, stack
+from torch.nn import Module
 import torch.nn.functional as F
+from torch.distributions import Normal
 
 import einx
-from einops import reduce, pack, repeat, unpack
+from einops import rearrange, reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
     AttentionLayers,
@@ -23,7 +27,7 @@ def exists(val):
 def default(val, d):
     if exists(val):
         return val
-    return d() if callable(d) else d
+    return d() if not isinstance(d, Module) and callable(d) else d
 
 def masked_mean(t, mask):
     t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
@@ -34,9 +38,17 @@ def masked_mean(t, mask):
     masked_average = num / den.clamp(min = 1.)
     return masked_average
 
+# probabilistic loss fn
+
+class GaussianNLL(Module):
+    def forward(self, pred, target):
+        mean, var = pred
+        dist = Normal(mean, var)
+        return -dist.log_prob(target)
+
 # main classes
 
-class ContinuousTransformerWrapper(nn.Module):
+class ContinuousTransformerWrapper(Module):
     def __init__(
         self,
         *,
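The new `GaussianNLL` module returns an elementwise negative log-likelihood, mirroring the `reduction = 'none'` behaviour of the MSE loss used elsewhere in the wrapper; note that `torch.distributions.Normal` interprets its second argument as a standard deviation. A minimal sketch of the computation, with purely illustrative tensor shapes:

```python
import math
import torch
from torch.distributions import Normal

# illustrative shapes only
mean   = torch.zeros(2, 4, 8)            # predicted mean
scale  = torch.full((2, 4, 8), 0.5)      # second element of the predicted pair (std for Normal)
target = torch.randn(2, 4, 8)

# what GaussianNLL.forward computes: elementwise -log p(target | mean, scale)
nll = -Normal(mean, scale).log_prob(target)

# the same quantity in closed form
closed_form = 0.5 * ((target - mean) / scale) ** 2 + scale.log() + 0.5 * math.log(2 * math.pi)
assert torch.allclose(nll, closed_form, atol = 1e-5)
```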
@@ -51,7 +63,8 @@ class ContinuousTransformerWrapper(nn.Module):
         emb_dropout = 0.,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
-        average_pool_embed = False
+        average_pool_embed = False,
+        probabilistic = False
     ):
         super().__init__()
         dim = attn_layers.dim
@@ -91,7 +104,12 @@ class ContinuousTransformerWrapper(nn.Module):
         # project in and out
 
         self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
-        self.project_out = nn.Linear(dim, dim_out, bias = False) if exists(dim_out) else nn.Identity()
+
+        # output is multipled by 2 for outputting mean and log variance
+
+        self.probabilistic = probabilistic
+
+        self.project_out = nn.Linear(dim, dim_out * (2 if probabilistic else 1), bias = False) if exists(dim_out) else nn.Identity()
 
     def forward(
         self,
@@ -132,13 +150,13 @@ class ContinuousTransformerWrapper(nn.Module):
 
             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
 
-            x = torch.cat((prepend_embeds, x), dim = -2)
+            x = cat((prepend_embeds, x), dim = -2)
 
         if exists(prepend_mask) or exists(mask):
             mask = default(mask, lambda: torch.ones((batch, seq), device = device, dtype = torch.bool))
             prepend_mask = default(prepend_mask, lambda: torch.ones((batch, prepend_seq), device = device, dtype = torch.bool))
 
-            mask = torch.cat((prepend_mask, mask), dim = -1)
+            mask = cat((prepend_mask, mask), dim = -1)
 
         x = self.emb_dropout(x)
 
@@ -159,6 +177,11 @@ class ContinuousTransformerWrapper(nn.Module):
 
         out = self.project_out(x) if not return_embeddings else x
 
+        if not return_embeddings and self.probabilistic:
+            mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2)
+            variance = log_var.exp()
+            return stack((mean, variance))
+
         if return_intermediates:
             return out, intermediates
 
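With `probabilistic = True`, `project_out` doubles its width and `forward` returns a stacked `(mean, variance)` pair instead of a single prediction. A sketch of what a caller sees, following the usual continuous-wrapper construction; all sizes are illustrative:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

# sketch; dimensions are illustrative
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    probabilistic = True,                                    # new in 2.3.0
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

x = torch.randn(1, 100, 32)

mean, variance = model(x)    # stacked (mean, variance), each of shape (1, 100, 32)
```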
@@ -173,24 +196,34 @@ class ContinuousTransformerWrapper(nn.Module):
 
         return out
 
-class ContinuousAutoregressiveWrapper(nn.Module):
+class ContinuousAutoregressiveWrapper(Module):
     def __init__(
         self,
         net: ContinuousTransformerWrapper,
         ignore_index = -100,
         pad_value = 0,
-        loss_fn = nn.MSELoss(reduction = 'none'),
+        loss_fn: Module | None = None,
         equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
     ):
         super().__init__()
         self.net = net
         self.max_seq_len = net.max_seq_len
 
+        probabilistic = net.probabilistic
+        self.probabilistic = probabilistic
+
+        loss_fn = default(loss_fn, nn.MSELoss(reduction = 'none') if not probabilistic else GaussianNLL())
+
         self.loss_fn = loss_fn
         self.equal_loss_weight_batch = equal_loss_weight_batch
 
     @torch.no_grad()
-    def generate(self, start_tokens, seq_len, **kwargs):
+    def generate(
+        self,
+        start_tokens,
+        seq_len,
+        **kwargs
+    ):
         device = start_tokens.device
         was_training = self.net.training
         num_dims = len(start_tokens.shape)
@@ -208,8 +241,13 @@ class ContinuousAutoregressiveWrapper(nn.Module):
         for _ in range(seq_len):
             x = out[:, -self.max_seq_len:]
 
-            last = self.net(x, **kwargs)[:, -1:]
-            out = torch.cat((out, last), dim = -2)
+            last_output = self.net(x, **kwargs)[..., -1:, :]
+
+            if self.probabilistic:
+                mean, var = last_output
+                last_output = torch.normal(mean, var)
+
+            out = cat((out, last_output), dim = -2)
 
         out = out[:, t:]
 
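During generation the probabilistic path draws each next frame with `torch.normal(mean, var)`; like `Normal`, `torch.normal` treats its second tensor argument as a standard deviation. The per-step sampling pattern in isolation, with illustrative shapes:

```python
import torch

mean = torch.zeros(1, 1, 32)
std  = torch.full((1, 1, 32), 0.1)

next_frame = torch.normal(mean, std)   # one stochastic sample per element
```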
@@ -219,7 +257,11 @@ class ContinuousAutoregressiveWrapper(nn.Module):
         self.net.train(was_training)
         return out
 
-    def forward(self, x, **kwargs):
+    def forward(
+        self,
+        x,
+        **kwargs
+    ):
         inp, target = x[:, :-1], x[:, 1:]
 
         assert 'prepend_embeds' not in kwargs
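Putting the pieces together: the autoregressive wrapper now picks `GaussianNLL` as its default loss whenever the wrapped network is probabilistic, and samples a frame per step during `generate`. A hedged end-to-end sketch, with illustrative sizes and the start/sequence shapes from the project README:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    probabilistic = True,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

wrapper = ContinuousAutoregressiveWrapper(model)   # loss_fn defaults to GaussianNLL() here

seq = torch.randn(2, 100, 32)
loss = wrapper(seq)                                # Gaussian NLL on next-frame prediction
loss.backward()

start = torch.randn(1, 32)                         # a single starting frame
sampled = wrapper.generate(start, 16)              # (16, 32), one frame sampled per step
```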
x_transformers/x_transformers.py
CHANGED
@@ -62,6 +62,9 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def identity(t, *args, **kwargs):
+    return t
+
 def first(it, default = None):
     return it[0] if len(it) > 0 else default
 
@@ -74,7 +77,10 @@ def cast_tuple(val, depth = 1):
 def divisible_by(num, den):
     return (num % den) == 0
 
-def maybe(fn):
+def maybe(fn = None):
+    if not exists(fn):
+        fn = identity
+
     @wraps(fn)
     def inner(x, *args, **kwargs):
         if not exists(x):
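`maybe` can now be called with `None` and falls back to the new `identity` helper, which is what lets call sites write `maybe(self.sublayer_dropout)(out)` without an explicit `None` check. A self-contained restatement of the helpers for illustration (the body of `inner` follows the existing definition):

```python
from functools import wraps

import torch
from torch import nn

def exists(val):
    return val is not None

def identity(t, *args, **kwargs):
    return t

def maybe(fn = None):
    if not exists(fn):
        fn = identity

    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

x = torch.randn(2, 8)

assert torch.equal(maybe(None)(x), x)   # no function given -> identity
out = maybe(nn.Dropout(0.5))(x)         # function (or Module) given -> applied to x
```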
@@ -1199,6 +1205,7 @@ class FeedForward(Module):
         custom_activation = None,
         post_act_ln = False,
         dropout = 0.,
+        sublayer_dropout = 0.,
         no_bias = False,
         zero_init_output = False
     ):
@@ -1227,7 +1234,8 @@
             project_in,
             LayerNorm(inner_dim) if post_act_ln else None,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out, bias = not no_bias)
+            nn.Linear(inner_dim, dim_out, bias = not no_bias),
+            nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
         )
 
         # init last linear layer to 0
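`FeedForward` gains a `sublayer_dropout` applied after the down-projection, in addition to the existing inner `dropout`. A sketch of constructing it directly, with illustrative hyperparameters:

```python
import torch
from x_transformers.x_transformers import FeedForward

ff = FeedForward(dim = 128, mult = 4, dropout = 0.1, sublayer_dropout = 0.1)

x = torch.randn(2, 16, 128)
out = ff(x)   # (2, 16, 128); the trailing Dropout only exists when sublayer_dropout > 0
```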
@@ -1256,6 +1264,7 @@ class Attention(Module):
         sparse_topk_straight_through = False,
         num_mem_kv = 0,
         dropout = 0.,
+        sublayer_dropout = 0.,
         on_attn = False,
         gate_value_heads = False,
         swiglu_values = False,
@@ -1534,6 +1543,10 @@ class Attention(Module):
         dim_out = default(dim_out, dim)
         self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)
 
+        # sublayer dropout
+
+        self.sublayer_dropout = nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
+
         # the number of attention heads to rotate, for decoupled rope in multi-latent attention
 
         rotate_num_heads = default(rotate_num_heads, heads)
@@ -1871,6 +1884,10 @@ class Attention(Module):
 
         out = self.to_out(out)
 
+        # maybe sublayer dropout
+
+        out = maybe(self.sublayer_dropout)(out)
+
         if exists(mask):
             out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
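`Attention` applies the same idea via `maybe(self.sublayer_dropout)` on the output of `to_out`. Assuming the usual routing of `attn_*` / `ff_*` keyword arguments from `AttentionLayers` down to `Attention` and `FeedForward`, the new knob should also be reachable from a `Decoder`; the prefixed kwarg names below are an assumption based on that convention, not something stated in this diff:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(
        dim = 128,
        depth = 2,
        heads = 4,
        attn_sublayer_dropout = 0.1,   # assumed to route to Attention(sublayer_dropout = 0.1)
        ff_sublayer_dropout = 0.1      # assumed to route to FeedForward(sublayer_dropout = 0.1)
    )
)

logits = model(torch.randint(0, 256, (1, 128)))   # (1, 128, 256)
```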
{x_transformers-2.2.11.dist-info → x_transformers-2.3.0.dist-info}/RECORD
CHANGED
@@ -2,16 +2,16 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=F7qGfpjJbPNvF49Dur0EqKE_KV6avSsO7CSvYC8pDh8,8216
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.3.0.dist-info/METADATA,sha256=ppv8n1fXmiOkiMvry_5WgmhP8zAHZoxIOjEVbHKDTtE,88686
+x_transformers-2.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.0.dist-info/RECORD,,

{x_transformers-2.2.11.dist-info → x_transformers-2.3.0.dist-info}/WHEEL
File without changes

{x_transformers-2.2.11.dist-info → x_transformers-2.3.0.dist-info}/licenses/LICENSE
File without changes