x-transformers 2.3.17__py3-none-any.whl → 2.3.19__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in the public registry.
- x_transformers/attend.py +1 -0
- x_transformers/continuous.py +17 -6
- x_transformers/x_transformers.py +16 -3
- {x_transformers-2.3.17.dist-info → x_transformers-2.3.19.dist-info}/METADATA +1 -1
- {x_transformers-2.3.17.dist-info → x_transformers-2.3.19.dist-info}/RECORD +7 -7
- {x_transformers-2.3.17.dist-info → x_transformers-2.3.19.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.17.dist-info → x_transformers-2.3.19.dist-info}/licenses/LICENSE +0 -0
x_transformers/attend.py
CHANGED
@@ -25,6 +25,7 @@ class Intermediates:
     values: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
+    hybrid_hidden: Tensor | None = None
 
     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
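`Intermediates` is a dataclass of optional per-layer tensors returned by the attention forward pass; the new `hybrid_hidden` slot reserves a place for the recurrent state of a hybrid branch so it can survive across cached decoding steps. A stripped-down sketch of the pattern (field names as in the diff, everything else simplified):

from dataclasses import dataclass

from torch import Tensor

@dataclass
class Intermediates:
    values: Tensor | None = None
    cached_kv: tuple[Tensor, Tensor] | None = None
    layer_type: str | None = None
    hybrid_hidden: Tensor | None = None  # recurrent state of a hybrid branch

inter = Intermediates()
assert inter.hybrid_hidden is None       # stays None until a hybrid layer fills it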
x_transformers/continuous.py
CHANGED
@@ -32,6 +32,15 @@ def default(val, d):
         return val
     return d() if not isinstance(d, Module) and callable(d) else d
 
+def sample_from_mean_variance(
+    mean,
+    variance,
+    eps = 1e-5,
+    temperature = 1.
+):
+    std = variance.clamp(min = eps).sqrt()
+    return torch.normal(mean, std * temperature)
+
 def masked_mean(t, mask):
     t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
 
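The new helper clamps the predicted variance for numerical stability, converts it to a standard deviation, and scales that by temperature before drawing a Gaussian sample, so `temperature = 0.` collapses to the mean. Copied verbatim and exercised:

import torch

def sample_from_mean_variance(mean, variance, eps = 1e-5, temperature = 1.):
    std = variance.clamp(min = eps).sqrt()
    return torch.normal(mean, std * temperature)

mean = torch.zeros(2, 4)
variance = torch.full((2, 4), 0.25)

stochastic = sample_from_mean_variance(mean, variance)                # std 0.5 noise
greedy = sample_from_mean_variance(mean, variance, temperature = 0.)  # exactly the mean

assert torch.allclose(greedy, mean)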
@@ -274,9 +283,7 @@ class ContinuousAutoregressiveWrapper(Module):
 
         if self.probabilistic:
             mean, var = last_output
-            stddev = var.clamp(min = 1e-5).sqrt()
-
-            last_output = torch.normal(mean, stddev * temperature)
+            last_output = sample_from_mean_variance(mean, var, temperature = temperature)
 
         out = cat((out, last_output), dim = -2)
 
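In `generate`, the inline clamp-sqrt-normal sequence is thus replaced by a single call. A condensed sketch of a probabilistic decode loop built around it, with a hypothetical `net` that returns a per-position (mean, variance) pair (not the wrapper's actual signature), reusing the helper above:

import torch

def decode(net, prompt, steps, temperature = 1.):
    out = prompt                              # (batch, seq, dim) continuous tokens

    for _ in range(steps):
        mean, var = net(out)                  # assumed: per-position mean / variance
        last = sample_from_mean_variance(mean[:, -1:], var[:, -1:], temperature = temperature)
        out = torch.cat((out, last), dim = -2)

    return out[:, prompt.shape[1]:]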
@@ -298,7 +305,6 @@ class ContinuousAutoregressiveWrapper(Module):
         **kwargs
     ):
         assert rollout_steps > 1
-        assert not self.probabilistic, 'probabilistic not supported yet'
 
         steps = rollout_steps
 
@@ -369,8 +375,13 @@ class ContinuousAutoregressiveWrapper(Module):
             **kwargs
         )
 
-        last_pred = out[..., -1:, :]
-        inp = last_pred
+        last_pred = out[..., -1:, :]
+
+        if self.probabilistic:
+            mean, var = last_pred
+            inp = sample_from_mean_variance(mean, var)
+        else:
+            inp = last_pred
 
         preds.append(last_pred)
 
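With sampling factored out, rollout no longer needs to reject probabilistic models: each step's prediction is sampled before being fed back as the next input, while the raw prediction is still what gets collected. A sketch of the loop, with a hypothetical `net`, reusing the helper above:

import torch

def rollout(net, inp, steps, probabilistic = False):
    preds = []

    for _ in range(steps):
        out = net(inp)
        last_pred = out[..., -1:, :]          # newest position only

        if probabilistic:
            mean, var = last_pred             # assumed: mean / variance stacked along dim 0
            inp = sample_from_mean_variance(mean, var)
        else:
            inp = last_pred

        preds.append(last_pred)

    return torch.cat(preds, dim = -2)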
x_transformers/x_transformers.py
CHANGED
@@ -1079,10 +1079,11 @@ class FoldAxially(Module):
     def forward(
         self,
         x,
+        *args,
         **kwargs
     ):
         if self.axial_dim == 1:
-            return self.fn(x, **kwargs)
+            return self.fn(x, *args, **kwargs)
 
         seq_len, axial_dim = x.shape[1], self.axial_dim
 
@@ -1091,7 +1092,7 @@ class FoldAxially(Module):
 
         x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
 
-        out = self.fn(x, **kwargs)
+        out = self.fn(x, *args, **kwargs)
 
         (out, *rest_out), tree_spec = tree_flatten(out)
 
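These two hunks let positional arguments flow through `FoldAxially` to the wrapped function instead of being silently dropped, presumably so the hybrid hidden state, which the change below passes positionally, survives the wrapper. A toy wrapper showing the passthrough pattern:

import torch
from torch import nn

class Passthrough(nn.Module):
    # forwards every positional and keyword argument to the wrapped fn
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs)

gru = nn.GRU(4, 4, batch_first = True)
wrapped = Passthrough(gru)

x = torch.randn(1, 3, 4)
h0 = torch.zeros(1, 1, 4)

out, hn = wrapped(x, h0)   # h0 reaches the GRU; **kwargs alone could not carry it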
@@ -1857,9 +1858,17 @@ class Attention(Module):
         if not self.causal and exists(self.hybrid_mask_kwarg):
             hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
 
+        # handle maybe hybrid cache
+
+        hybrid_forward_args = ()
+
+        if exists(cache) and exists(cache.hybrid_hidden):
+            hybrid_hiddens = cache.hybrid_hidden
+            hybrid_forward_args = (hybrid_hiddens,)
+
         # hybrid forward
 
-        hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+        hybrid_outputs = self.hybrid_module(x, *hybrid_forward_args, **hybrid_forward_kwargs)
 
         # handle hybrid out
 
@@ -1870,6 +1879,10 @@ class Attention(Module):
         if hybrid_out.ndim == 3:
             hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
 
+        if len(rest_hybrid_outs) > 0:
+            hybrid_hidden = first(rest_hybrid_outs)
+            intermediates.hybrid_hidden = hybrid_hidden
+
         out_norm, hybrid_out_norm = self.hybrid_norms
 
         out = out_norm(out)
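Together, the two hunks close the loop for hybrid branches under cached decoding: the branch's extra output is stashed in `intermediates.hybrid_hidden` at one step and handed back positionally on the next, so stepping token by token matches a full-sequence pass. A self-contained demonstration of that invariant, with an `nn.GRU` standing in for the hybrid module:

import torch
from torch import nn

torch.manual_seed(0)

gru = nn.GRU(8, 8, batch_first = True)   # stand-in for a recurrent hybrid branch
seq = torch.randn(1, 4, 8)

full_out, _ = gru(seq)                   # one pass over the whole sequence

hidden = None                            # plays the role of cache.hybrid_hidden
step_outs = []

for t in range(seq.shape[1]):
    out, hidden = gru(seq[:, t:t + 1], hidden)   # hidden threaded like the hybrid cache
    step_outs.append(out)

assert torch.allclose(full_out, torch.cat(step_outs, dim = 1), atol = 1e-6)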
{x_transformers-2.3.17.dist-info → x_transformers-2.3.19.dist-info}/RECORD
CHANGED
@@ -1,17 +1,17 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
-x_transformers/attend.py,sha256=…
+x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
 x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=…
+x_transformers/continuous.py,sha256=CHta8vizKl85n220fv5278fwjSU-vrN_FBy-m831_go,12551
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=s398YQ9JtXc5n34g9qaYnUqaTVLGfRvz0GLg3sEMHLI,114558
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.17.dist-info/METADATA,sha256=…
-x_transformers-2.3.17.dist-info/WHEEL,sha256=…
-x_transformers-2.3.17.dist-info/licenses/LICENSE,sha256=…
-x_transformers-2.3.17.dist-info/RECORD,,
+x_transformers-2.3.19.dist-info/METADATA,sha256=Vn-U7mDaP7H-w-RF5YO3C5n9M5PvnDVKqFJwL3vFV0s,89897
+x_transformers-2.3.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.19.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.19.dist-info/RECORD,,
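For reference, each RECORD line is `path,sha256=<digest>,<size-in-bytes>`, where the digest is the urlsafe-base64 encoding of the file's raw SHA-256 with trailing `=` padding stripped (per the wheel spec), which is why every entry above changes whenever a file's bytes do. A small helper reproducing the format:

import base64
import hashlib

def record_entry(path: str, data: bytes) -> str:
    # urlsafe base64 of the raw sha256 digest, '=' padding stripped
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

print(record_entry('x_transformers/attend.py', b'...file bytes here...'))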
{x_transformers-2.3.17.dist-info → x_transformers-2.3.19.dist-info}/WHEEL
File without changes
{x_transformers-2.3.17.dist-info → x_transformers-2.3.19.dist-info}/licenses/LICENSE
File without changes