x-transformers 1.42.3__py3-none-any.whl → 1.42.4__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -1252,6 +1252,20 @@ class Attention(Module):
1252
1252
 
1253
1253
  k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))
1254
1254
 
1255
+ # if previous values passed in for residual, either invoke resformer or neutreno
1256
+
1257
+ orig_values = v
1258
+
1259
+ if exists(value_residual):
1260
+ if self.neutreno_value_residual:
1261
+ diff_values = (value_residual - v) * self.neutreno_alpha
1262
+ diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
1263
+ else:
1264
+ # https://arxiv.org/abs/2410.17897v1
1265
+ v = 0.5 * (v + value_residual)
1266
+
1267
+ # take care of caching
1268
+
1255
1269
  if exists(cache):
1256
1270
  ck, cv = cache.cached_kv
1257
1271
 
@@ -1363,16 +1377,6 @@ class Attention(Module):
1363
1377
  if exists(self.data_dependent_alibi):
1364
1378
  attn_bias = self.data_dependent_alibi(x)
1365
1379
 
1366
- # if previous values passed in for residual, either invoke resformer or neutreno
1367
-
1368
- if exists(value_residual):
1369
- if self.neutreno_value_residual:
1370
- diff_values = (value_residual - v) * self.neutreno_alpha
1371
- diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
1372
- else:
1373
- # https://arxiv.org/abs/2410.17897v1
1374
- v = 0.5 * (v + value_residual)
1375
-
1376
1380
  # attention is all we need
1377
1381
 
1378
1382
  out, intermediates = self.attend(
@@ -1384,7 +1388,7 @@ class Attention(Module):
1384
1388
 
1385
1389
  # store the values for resformer or Neutreno
1386
1390
 
1387
- intermediates.values = v
1391
+ intermediates.values = orig_values
1388
1392
 
1389
1393
  if exists(value_residual) and self.neutreno_value_residual:
1390
1394
  out = out + diff_values
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.42.3
3
+ Version: 1.42.4
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
8
8
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
9
- x_transformers/x_transformers.py,sha256=K8zt6n7aC8iMMrWJ-0ryNsJeq7eXR94ci5DXAElC-lY,91995
9
+ x_transformers/x_transformers.py,sha256=o_Rm-v1XJyIYU_zDcXWxbHN6whFcK8VKRHvlqTNaQTc,92062
10
10
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
11
11
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
12
- x_transformers-1.42.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
- x_transformers-1.42.3.dist-info/METADATA,sha256=A9f9ByY6e62AuBpOEMBndE_b_8TOovaau75D2fo73H4,689
14
- x_transformers-1.42.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
15
- x_transformers-1.42.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
- x_transformers-1.42.3.dist-info/RECORD,,
12
+ x_transformers-1.42.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
+ x_transformers-1.42.4.dist-info/METADATA,sha256=9KikrLFEmmDn92O9ne5Qd6pEuiztY21vkXg5KWiChhw,689
14
+ x_transformers-1.42.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
15
+ x_transformers-1.42.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
+ x_transformers-1.42.4.dist-info/RECORD,,