x-transformers 2.3.0__py3-none-any.whl → 2.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +19 -3
- {x_transformers-2.3.0.dist-info → x_transformers-2.3.2.dist-info}/METADATA +1 -1
- {x_transformers-2.3.0.dist-info → x_transformers-2.3.2.dist-info}/RECORD +5 -5
- {x_transformers-2.3.0.dist-info → x_transformers-2.3.2.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.0.dist-info → x_transformers-2.3.2.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -10,6 +10,7 @@ import einx
 from einops import rearrange, reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
+    Attention,
     AttentionLayers,
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
@@ -111,6 +112,10 @@ class ContinuousTransformerWrapper(Module):
 
         self.project_out = nn.Linear(dim, dim_out * (2 if probabilistic else 1), bias = False) if exists(dim_out) else nn.Identity()
 
+        # can cache kv
+
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
+
     def forward(
         self,
         x,
@@ -180,7 +185,7 @@ class ContinuousTransformerWrapper(Module):
         if not return_embeddings and self.probabilistic:
             mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2)
             variance = log_var.exp()
-
+            out = stack((mean, variance))
 
         if return_intermediates:
             return out, intermediates
@@ -222,9 +227,13 @@ class ContinuousAutoregressiveWrapper(Module):
         self,
         start_tokens,
         seq_len,
+        temperature = 1.,
+        cache_kv = True,
         **kwargs
     ):
+        should_cache_kv = cache_kv and self.net.can_cache_kv
         device = start_tokens.device
+
         was_training = self.net.training
         num_dims = len(start_tokens.shape)
 
@@ -238,17 +247,24 @@ class ContinuousAutoregressiveWrapper(Module):
         self.net.eval()
         out = start_tokens
 
+        cache = None
+
         for _ in range(seq_len):
             x = out[:, -self.max_seq_len:]
 
-            last_output = self.net(x, **kwargs)[..., -1:, :]
+            net_out, new_cache = self.net(x, cache = cache, return_intermediates = True, **kwargs)
+
+            last_output = net_out[..., -1:, :]
 
             if self.probabilistic:
                 mean, var = last_output
-                last_output = torch.normal(mean, var)
+                last_output = torch.normal(mean, var * temperature)
 
             out = cat((out, last_output), dim = -2)
 
+            if should_cache_kv:
+                cache = new_cache
+
         out = out[:, t:]
 
         if num_dims == 2:
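The net effect of the continuous.py changes: ContinuousAutoregressiveWrapper.generate gains a temperature argument (scales the variance fed to torch.normal in the probabilistic path) and a cache_kv flag (reuses key/value caches whenever every Attention module reports can_cache_kv). A minimal usage sketch follows, assuming the usual ContinuousTransformerWrapper / ContinuousAutoregressiveWrapper / Decoder exports; the model sizes and the probabilistic flag are illustrative choices, not taken from this diff.

import torch
from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper, Decoder

# illustrative sizes only
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    probabilistic = True,                      # output head predicts mean and variance per feature
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

wrapper = ContinuousAutoregressiveWrapper(model)

prompt = torch.randn(1, 16, 32)                # (batch, length, dim_in) seed sequence

# temperature scales the sampling variance; cache_kv = True (the default)
# reuses the kv cache when the wrapped model reports can_cache_kv
sampled = wrapper.generate(prompt, seq_len = 64, temperature = 0.8, cache_kv = True)
print(sampled.shape)                           # expected: (1, 64, 32)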
{x_transformers-2.3.0.dist-info → x_transformers-2.3.2.dist-info}/RECORD
RENAMED
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=yAE8hLyusrEd-12mkgLASDL-cFgpZQf32s93FKfez7o,8674
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.2.dist-info/METADATA,sha256=8m6hpJlMKesI-SLxth_9z0VYIHUh7bTWsJ9Am5OSni4,88686
+x_transformers-2.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.2.dist-info/RECORD,,
{x_transformers-2.3.0.dist-info → x_transformers-2.3.2.dist-info}/WHEEL
File without changes
{x_transformers-2.3.0.dist-info → x_transformers-2.3.2.dist-info}/licenses/LICENSE
File without changes