x-transformers 2.3.17__py3-none-any.whl → 2.3.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -25,6 +25,7 @@ class Intermediates:
25
25
  values: Tensor | None = None
26
26
  cached_kv: Tuple[Tensor, Tensor] | None = None
27
27
  layer_type: str | None = None
28
+ hybrid_hidden: Tensor | None = None
28
29
 
29
30
  def to_tuple(self):
30
31
  return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
@@ -32,6 +32,15 @@ def default(val, d):
32
32
  return val
33
33
  return d() if not isinstance(d, Module) and callable(d) else d
34
34
 
35
+ def sample_from_mean_variance(
36
+ mean,
37
+ variance,
38
+ eps = 1e-5,
39
+ temperature = 1.
40
+ ):
41
+ std = variance.clamp(min = eps).sqrt()
42
+ return torch.normal(mean, std * temperature)
43
+
35
44
  def masked_mean(t, mask):
36
45
  t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
37
46
 
@@ -274,9 +283,7 @@ class ContinuousAutoregressiveWrapper(Module):
274
283
 
275
284
  if self.probabilistic:
276
285
  mean, var = last_output
277
- stddev = var.clamp(min = 1e-5).sqrt()
278
-
279
- last_output = torch.normal(mean, stddev * temperature)
286
+ last_output = sample_from_mean_variance(mean, var, temperature = temperature)
280
287
 
281
288
  out = cat((out, last_output), dim = -2)
282
289
 
@@ -298,7 +305,6 @@ class ContinuousAutoregressiveWrapper(Module):
298
305
  **kwargs
299
306
  ):
300
307
  assert rollout_steps > 1
301
- assert not self.probabilistic, 'probabilistic not supported yet'
302
308
 
303
309
  steps = rollout_steps
304
310
 
@@ -369,8 +375,13 @@ class ContinuousAutoregressiveWrapper(Module):
369
375
  **kwargs
370
376
  )
371
377
 
372
- last_pred = out[:, -1:]
373
- inp = last_pred
378
+ last_pred = out[..., -1:, :]
379
+
380
+ if self.probabilistic:
381
+ mean, var = last_pred
382
+ inp = sample_from_mean_variance(mean, var)
383
+ else:
384
+ inp = last_pred
374
385
 
375
386
  preds.append(last_pred)
376
387
 
@@ -1079,10 +1079,11 @@ class FoldAxially(Module):
1079
1079
  def forward(
1080
1080
  self,
1081
1081
  x,
1082
+ *args,
1082
1083
  **kwargs
1083
1084
  ):
1084
1085
  if self.axial_dim == 1:
1085
- return self.fn(x, **kwargs)
1086
+ return self.fn(x, *args, **kwargs)
1086
1087
 
1087
1088
  seq_len, axial_dim = x.shape[1], self.axial_dim
1088
1089
 
@@ -1091,7 +1092,7 @@ class FoldAxially(Module):
1091
1092
 
1092
1093
  x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
1093
1094
 
1094
- out = self.fn(x, **kwargs)
1095
+ out = self.fn(x, *args, **kwargs)
1095
1096
 
1096
1097
  (out, *rest_out), tree_spec = tree_flatten(out)
1097
1098
 
@@ -1857,9 +1858,17 @@ class Attention(Module):
1857
1858
  if not self.causal and exists(self.hybrid_mask_kwarg):
1858
1859
  hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
1859
1860
 
1861
+ # handle maybe hybrid cache
1862
+
1863
+ hybrid_forward_args = ()
1864
+
1865
+ if exists(cache) and exists(cache.hybrid_hidden):
1866
+ hybrid_hiddens = cache.hybrid_hidden
1867
+ hybrid_forward_args = (hybrid_hiddens,)
1868
+
1860
1869
  # hybrid forward
1861
1870
 
1862
- hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
1871
+ hybrid_outputs = self.hybrid_module(x, *hybrid_forward_args, **hybrid_forward_kwargs)
1863
1872
 
1864
1873
  # handle hybrid out
1865
1874
 
@@ -1870,6 +1879,10 @@ class Attention(Module):
1870
1879
  if hybrid_out.ndim == 3:
1871
1880
  hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
1872
1881
 
1882
+ if len(rest_hybrid_outs) > 0:
1883
+ hybrid_hidden = first(rest_hybrid_outs)
1884
+ intermediates.hybrid_hidden = hybrid_hidden
1885
+
1873
1886
  out_norm, hybrid_out_norm = self.hybrid_norms
1874
1887
 
1875
1888
  out = out_norm(out)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.17
3
+ Version: 2.3.19
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,17 +1,17 @@
1
1
  x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
2
- x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
2
+ x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
3
3
  x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=QidhbSgBhYDpAp7FaryxrNJxWXp0-pJkeygguTpVp4k,12308
5
+ x_transformers/continuous.py,sha256=CHta8vizKl85n220fv5278fwjSU-vrN_FBy-m831_go,12551
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
9
9
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
10
10
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
11
- x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
11
+ x_transformers/x_transformers.py,sha256=s398YQ9JtXc5n34g9qaYnUqaTVLGfRvz0GLg3sEMHLI,114558
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
14
- x_transformers-2.3.17.dist-info/METADATA,sha256=u_DVXsX7KsVhnfMhpV-3KV6KGNHddmzW_SNG9om557s,89897
15
- x_transformers-2.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.17.dist-info/RECORD,,
14
+ x_transformers-2.3.19.dist-info/METADATA,sha256=Vn-U7mDaP7H-w-RF5YO3C5n9M5PvnDVKqFJwL3vFV0s,89897
15
+ x_transformers-2.3.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.19.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.19.dist-info/RECORD,,