x-transformers 2.3.1__py3-none-any.whl → 2.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +29 -2
- {x_transformers-2.3.1.dist-info → x_transformers-2.3.3.dist-info}/METADATA +1 -1
- {x_transformers-2.3.1.dist-info → x_transformers-2.3.3.dist-info}/RECORD +5 -5
- {x_transformers-2.3.1.dist-info → x_transformers-2.3.3.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.1.dist-info → x_transformers-2.3.3.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -10,6 +10,7 @@ import einx
 from einops import rearrange, reduce, pack, repeat, unpack

 from x_transformers.x_transformers import (
+    Attention,
     AttentionLayers,
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
@@ -111,6 +112,10 @@ class ContinuousTransformerWrapper(Module):

         self.project_out = nn.Linear(dim, dim_out * (2 if probabilistic else 1), bias = False) if exists(dim_out) else nn.Identity()

+        # can cache kv
+
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
+
     def forward(
         self,
         x,
@@ -118,6 +123,7 @@ class ContinuousTransformerWrapper(Module):
         return_intermediates = False,
         return_mems = False,
         mask = None,
+        lens = None,
         return_attn = False,
         mems = None,
         mem_masks = None,
@@ -128,6 +134,16 @@ class ContinuousTransformerWrapper(Module):
     ):
         batch, seq, orig_mask, device = *x.shape[:2], mask, x.device

+        # maybe seq lengths passed in
+
+        if exists(lens):
+            assert not exists(mask), 'either `mask` or `lens` passed in, but not both'
+            seq_arange = torch.arange(seq, device = device)
+
+            mask = einx.less('j, i -> i j', seq_arange, lens)
+
+        # project in + positional embedding
+
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)

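The new `lens` argument gives callers a simpler alternative to building a boolean attention mask by hand: per-sample sequence lengths are converted internally into the equivalent mask via `einx.less`. A minimal usage sketch, assuming the usual wrapper construction from the x-transformers README (the dimensions and layer settings below are illustrative, not part of this diff):

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

# illustrative model config, not taken from this diff
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

x = torch.randn(2, 100, 32)        # (batch, seq, dim_in)
lens = torch.tensor([80, 100])     # valid length per sample

# equivalent to passing mask = torch.arange(100) < lens[:, None]
out = model(x, lens = lens)

Passing both `mask` and `lens` raises an assertion error, as enforced in the hunk above.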
@@ -180,7 +196,7 @@ class ContinuousTransformerWrapper(Module):
         if not return_embeddings and self.probabilistic:
             mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2)
             variance = log_var.exp()
-
+            out = stack((mean, variance))

         if return_intermediates:
             return out, intermediates
@@ -223,9 +239,12 @@ class ContinuousAutoregressiveWrapper(Module):
         start_tokens,
         seq_len,
         temperature = 1.,
+        cache_kv = True,
         **kwargs
     ):
+        should_cache_kv = cache_kv and self.net.can_cache_kv
         device = start_tokens.device
+
         was_training = self.net.training
         num_dims = len(start_tokens.shape)

@@ -239,10 +258,14 @@ class ContinuousAutoregressiveWrapper(Module):
         self.net.eval()
         out = start_tokens

+        cache = None
+
         for _ in range(seq_len):
             x = out[:, -self.max_seq_len:]

-
+            net_out, new_cache = self.net(x, cache = cache, return_intermediates = True, **kwargs)
+
+            last_output = net_out[..., -1:, :]

             if self.probabilistic:
                 mean, var = last_output
@@ -250,6 +273,9 @@ class ContinuousAutoregressiveWrapper(Module):

             out = cat((out, last_output), dim = -2)

+            if should_cache_kv:
+                cache = new_cache
+
         out = out[:, t:]

         if num_dims == 2:
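Taken together, the hunks above let `ContinuousAutoregressiveWrapper.generate` reuse key/value caches across decoding steps whenever every `Attention` module reports `can_cache_kv`; passing `cache_kv = False` falls back to a full forward pass per step. A hedged usage sketch (the model construction is again illustrative, not taken from this diff):

import torch
from x_transformers import (
    ContinuousTransformerWrapper,
    ContinuousAutoregressiveWrapper,
    Decoder
)

# illustrative model, not part of this diff
net = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

wrapper = ContinuousAutoregressiveWrapper(net)

prime = torch.randn(1, 4, 32)   # (batch, prime_len, dim_in)

# kv caching is on by default and only used when net.can_cache_kv is True
fast = wrapper.generate(prime, 16)
slow = wrapper.generate(prime, 16, cache_kv = False)   # pre-2.3.3 behaviour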
@@ -268,6 +294,7 @@ class ContinuousAutoregressiveWrapper(Module):
         assert 'prepend_embeds' not in kwargs

         mask = kwargs.get('mask', None)
+
         if exists(mask) and mask.shape[1] == x.shape[1]:
             mask = mask[:, :-1]
         kwargs['mask'] = mask
{x_transformers-2.3.1.dist-info → x_transformers-2.3.3.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=663KriM8j3VQPe732kTklB4RPheS4Zj30sB8F9c8ySg,9016
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.3.dist-info/METADATA,sha256=1Ly7M8Jc_tL2auBw7N8Q4wpCSsp8zq-Td7AHqHzkZmk,88686
+x_transformers-2.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.3.dist-info/RECORD,,
{x_transformers-2.3.1.dist-info → x_transformers-2.3.3.dist-info}/WHEEL
File without changes

{x_transformers-2.3.1.dist-info → x_transformers-2.3.3.dist-info}/licenses/LICENSE
File without changes