x-transformers 1.42.3__py3-none-any.whl → 1.42.5__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- x_transformers/x_transformers.py +36 -15
- {x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/METADATA +1 -1
- {x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/RECORD +6 -6
- {x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -512,12 +512,15 @@ class DataDependentAlibi(Module):
         self,
         dim,
         heads,
+        causal = True,
         bias_init = 5.,
-        post_log_scale = 1
+        post_log_scale = 1.,
     ):
         super().__init__()
 
-        linear = nn.Linear(dim, heads)
+        self.causal = causal
+
+        linear = nn.Linear(dim, heads * (1 if causal else 2))
 
         self.to_forget_gates = nn.Sequential(
             linear,
@@ -529,9 +532,21 @@ class DataDependentAlibi(Module):
         self.post_log_scale = post_log_scale
 
     def forward(self, x):
+        bidirectional = not self.causal
+
         forget_gates = self.to_forget_gates(x) * self.post_log_scale
+
         forget_gates = forget_gates.cumsum(dim = -1)
+
+        if bidirectional:
+            forget_gates, forget_gates_reversed = forget_gates.chunk(2, dim = 1)
+
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+
+        if bidirectional:
+            forget_gates_reversed = einx.subtract('b h j, b h i -> b h i j', forget_gates_reversed, forget_gates_reversed)
+            forget_gates = forget_gates.tril() + forget_gates_reversed.triu()
+
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
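In the bidirectional path above, the doubled linear projects two gate tracks per head: the causal lower triangle of the bias comes from the forward track and the upper triangle from the reversed track. A minimal standalone sketch of that composition (random tensors and sizes are illustrative, not from the package):

# Standalone sketch of the bidirectional forget-gate bias composed above.
# Sizes and random inputs are illustrative only.
import torch
import torch.nn.functional as F
import einx

b, h, n = 1, 2, 5                                   # batch, heads, sequence length

# two gate tracks per head, as produced by the doubled nn.Linear + LogSigmoid
log_gates = F.logsigmoid(torch.randn(b, 2 * h, n))
gates = log_gates.cumsum(dim = -1)                  # running sums of log forget gates

fwd, rev = gates.chunk(2, dim = 1)                  # split into the two direction tracks

bias = einx.subtract('b h i, b h j -> b h i j', fwd, fwd)
bias_rev = einx.subtract('b h j, b h i -> b h i j', rev, rev)

# lower triangle decays with distance into the past, upper triangle into the
# future; both diagonals are zero, so adding the two triangles is safe
bias = bias.tril() + bias_rev.triu()
assert bias.shape == (b, h, n, n)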
@@ -541,10 +556,13 @@ class PerRowDataDependentAlibi(Module):
         self,
         dim,
         heads,
+        causal = True,
         dim_head = 8,
         post_log_scale = 1.
     ):
         super().__init__()
+        assert causal, 'bidirectional not supported yet'
+
         self.scale = dim_head ** -0.5
 
         linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
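The per-row variant keeps its causal-only restriction explicit rather than inheriting the old assert in Attention. An illustrative check of the new guard (assumes importing the class directly from the module, where it is defined at top level):

# Illustrative check of the new guard; constructing the per-row variant
# bidirectionally should now fail at __init__ time.
from x_transformers.x_transformers import PerRowDataDependentAlibi

try:
    PerRowDataDependentAlibi(dim = 512, heads = 8, causal = False)
except AssertionError as e:
    print(e)  # bidirectional not supported yet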
@@ -1138,10 +1156,9 @@ class Attention(Module):
         self.data_dependent_alibi = None
 
         if data_dependent_alibi:
-            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
             dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
-            dda_kwargs = dict(dim = dim, heads = heads)
+            dda_kwargs = dict(dim = dim, heads = heads, causal = causal)
 
             if data_dependent_alibi_per_row:
                 dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
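Since `causal` is now forwarded instead of asserted, the default (non-per-row) data-dependent ALiBi can be enabled on a bidirectional stack. A hypothetical usage sketch; the `attn_` prefix routing is the library's convention, but this exact flag on `Encoder` in 1.42.5 is an assumption, not a confirmed API:

# Hypothetical usage: data-dependent ALiBi on a bidirectional Encoder stack.
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Encoder(                  # Encoder sets causal = False
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True    # non-per-row variant; per-row still asserts causal
    )
)

logits = model(torch.randint(0, 256, (1, 128)))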
@@ -1252,6 +1269,20 @@ class Attention(Module):
 
         k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))
 
+        # if previous values passed in for residual, either invoke resformer or neutreno
+
+        orig_values = v
+
+        if exists(value_residual):
+            if self.neutreno_value_residual:
+                diff_values = (value_residual - v) * self.neutreno_alpha
+                diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
+            else:
+                # https://arxiv.org/abs/2410.17897v1
+                v = 0.5 * (v + value_residual)
+
+        # take care of caching
+
         if exists(cache):
             ck, cv = cache.cached_kv
 
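The relocated block snapshots `orig_values` and applies the residual mix before the cache logic runs, rather than after it as in 1.42.3. A toy sketch of the two branches on made-up tensors (`neutreno_alpha` and all shapes are illustrative):

# Toy sketch of the two value-residual branches, on made-up tensors.
import torch
from einops import repeat

b, h, kv_h, n, d = 1, 8, 4, 16, 64           # grouped-query style: more query than kv heads
neutreno_alpha = 0.4                         # illustrative value

v = torch.randn(b, kv_h, n, d)               # this layer's values
value_residual = torch.randn(b, kv_h, n, d)  # first layer's (unmixed) values

orig_values = v                              # snapshot taken before any mixing

use_neutreno = True
if use_neutreno:
    # neutreno: v is left untouched; a scaled difference is added to the
    # attention output afterwards
    diff_values = (value_residual - v) * neutreno_alpha
    diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
else:
    # resformer (https://arxiv.org/abs/2410.17897v1): average with first layer
    v = 0.5 * (v + value_residual)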
@@ -1363,16 +1394,6 @@ class Attention(Module):
         if exists(self.data_dependent_alibi):
             attn_bias = self.data_dependent_alibi(x)
 
-        # if previous values passed in for residual, either invoke resformer or neutreno
-
-        if exists(value_residual):
-            if self.neutreno_value_residual:
-                diff_values = (value_residual - v) * self.neutreno_alpha
-                diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
-            else:
-                # https://arxiv.org/abs/2410.17897v1
-                v = 0.5 * (v + value_residual)
-
         # attention is all we need
 
         out, intermediates = self.attend(
@@ -1384,7 +1405,7 @@ class Attention(Module):
 
         # store the values for resformer or Neutreno
 
-        intermediates.values = v
+        intermediates.values = orig_values
 
         if exists(value_residual) and self.neutreno_value_residual:
             out = out + diff_values
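Storing `orig_values` rather than the possibly mixed `v` means downstream layers always receive the first layer's unmixed values as their residual. A self-contained toy loop showing why (all names here are illustrative, not the package API):

# Self-contained toy showing why intermediates.values must hold the
# unmixed values: later layers mix against the first layer's originals.
import torch

def toy_attention(x, value_residual = None):
    v = x                                # stand-in for the projected values
    orig_values = v                      # what intermediates.values now stores
    if value_residual is not None:
        v = 0.5 * (v + value_residual)   # resformer-style mix
    return v, orig_values

x = torch.randn(1, 4, 8)
value_residual = None

for _ in range(3):
    x, values = toy_attention(x, value_residual)
    # keep the first layer's unmixed values as the shared residual
    value_residual = value_residual if value_residual is not None else values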
{x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=KIR7efx59xl0BVshU1e6RO0YKgz7zeYBXITDNYWJ4mQ,92506
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.3.dist-info/METADATA,sha256=
-x_transformers-1.42.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.42.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.3.dist-info/RECORD,,
+x_transformers-1.42.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.5.dist-info/METADATA,sha256=cLnay5nt6F6GKdghsaqHiZQsVmJ9dS5l-IDozZIs3ec,689
+x_transformers-1.42.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.5.dist-info/RECORD,,
{x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/LICENSE
File without changes
{x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/WHEEL
File without changes
{x_transformers-1.42.3.dist-info → x_transformers-1.42.5.dist-info}/top_level.txt
File without changes