x-transformers 1.42.18.tar.gz → 1.42.20.tar.gz
- {x_transformers-1.42.18/x_transformers.egg-info → x_transformers-1.42.20}/PKG-INFO +1 -1
- {x_transformers-1.42.18 → x_transformers-1.42.20}/README.md +9 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/setup.py +1 -1
- {x_transformers-1.42.18 → x_transformers-1.42.20}/tests/test_x_transformers.py +23 -3
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/attend.py +1 -1
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/x_transformers.py +16 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.18 → x_transformers-1.42.20}/LICENSE +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/setup.cfg +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.18 → x_transformers-1.42.20}/README.md

@@ -2352,4 +2352,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Duvvuri2024LASERAW,
+    title   = {LASER: Attention with Exponential Transformation},
+    author  = {Sai Surya Duvvuri and Inderjit S. Dhillon},
+    year    = {2024},
+    url     = {https://api.semanticscholar.org/CorpusID:273849947}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
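The citation added to the README above is for LASER. As a hedged summary (my paraphrase of the cited paper, consistent with the implementation further down in this diff): the softmax-weighted sum is taken over exponentiated values and mapped back with a log, with a max-shift for numerical stability.

```latex
% LASER attention, hedged paraphrase of the cited paper (m is the max of V,
% subtracted before exponentiation and added back after the log for stability)
\mathrm{LASER}(Q, K, V)
  = \log\!\Big( \mathrm{softmax}\!\big( Q K^{\top} / \sqrt{d} \big)\, e^{\,V - m} \Big) + m,
\qquad m = \max(V)
```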
{x_transformers-1.42.18 → x_transformers-1.42.20}/tests/test_x_transformers.py

@@ -388,7 +388,8 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-def test_custom_alibi():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(
         num_tokens = 20_000,
@@ -397,7 +398,8 @@ def test_custom_alibi():
             dim = 512,
             depth = 2,
             heads = 8,
-            alibi_pos_bias = True
+            alibi_pos_bias = True,
+            attn_flash = flash
         )
     )
 
@@ -407,7 +409,8 @@ def test_custom_alibi():
 
     logits = model(x, pos = pos)
 
-def test_custom_alibi_across_heads():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi_across_heads(flash: bool):
 
     model = Decoder(
         dim = 512,
@@ -417,6 +420,7 @@ def test_custom_alibi_across_heads():
         rel_pos_kwargs = dict(
             slopes = [1, 1]
         ),
+        attn_flash = flash
     )
 
     x = torch.randn(2, 4, 512)
@@ -516,3 +520,19 @@ def test_to_logits(to_logits):
     output = model(x, to_logits_kwargs=to_logits_kwargs)
 
     assert output.shape == (2, 1024, 20000)
+
+def test_laser():
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_laser = True
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    model(x)
{x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/attend.py

@@ -370,7 +370,7 @@ class Attend(Module):
         # convert from bool to float
 
         if exists(attn_bias):
-            attn_bias =
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)
 
         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
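For context on the one-line `attend.py` fix above: when a float attention bias (e.g. custom ALiBi slopes) is handed to a fused scaled-dot-product-attention kernel, it has to broadcast against the `(batch, heads, q_len, k_len)` score tensor, so expanding it to that full shape up front is the safe option. A minimal sketch of the pattern, assuming the bias feeds `torch.nn.functional.scaled_dot_product_attention`; the tensor names and shapes are illustrative, not the library's internals:

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, k_len, dim_head = 2, 8, 4, 4, 64

q = torch.randn(batch, heads, q_len, dim_head)
k = torch.randn(batch, heads, k_len, dim_head)
v = torch.randn(batch, heads, k_len, dim_head)

# a per-head additive bias, e.g. ALiBi with custom slopes, shaped (heads, q_len, k_len)
attn_bias = torch.randn(heads, q_len, k_len)

# expand to (batch, heads, q_len, k_len) so it lines up with the attention scores
attn_bias = attn_bias.expand(batch, heads, -1, -1)

# SDPA adds a float attn_mask directly onto the scores before softmax
out = F.scaled_dot_product_attention(q, k, v, attn_mask = attn_bias)
assert out.shape == (batch, heads, q_len, dim_head)
```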
{x_transformers-1.42.18 → x_transformers-1.42.20}/x_transformers/x_transformers.py

@@ -1075,6 +1075,7 @@ class Attention(Module):
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
         learned_value_residual_mix = False,
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1114,6 +1115,11 @@ class Attention(Module):
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
+        # enhancing gradients to attention through exponentiated values
+        # todo - compare it to `attn = attn * large_value + attn.detach() * (1. - large_value)`
+
+        self.laser = laser
+
         # relations projection from tp-attention
 
         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
@@ -1439,6 +1445,11 @@ class Attention(Module):
 
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))
 
+        if self.laser:
+            values_max = v.amax(dim = -2, keepdim = True).detach() # numerical stability
+            v = v - values_max
+            v = v.exp()
+
         # attention is all we need
 
         out, intermediates = self.attend(
@@ -1448,6 +1459,11 @@ class Attention(Module):
             prev_attn = prev_attn
         )
 
+        # laser
+
+        if self.laser:
+            out = out.log() + values_max
+
         # store the values for resformer or Neutreno
 
         intermediates.values = orig_values
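Taken together, the two `forward` additions above show that LASER only wraps the existing attention call: exponentiate the values (with a max-shift) before attending, take a log afterwards. A minimal self-contained sketch with plain non-causal softmax attention; this is an illustration of the idea, not the library's exact code path:

```python
import torch

def laser_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, dim_head)
    scale = q.shape[-1] ** -0.5

    # exponentiate values, shifting by the max over the sequence for numerical stability
    values_max = v.amax(dim = -2, keepdim = True).detach()
    v_exp = (v - values_max).exp()

    # ordinary softmax attention, but over the exponentiated values
    sim = q @ k.transpose(-2, -1) * scale
    attn = sim.softmax(dim = -1)
    out = attn @ v_exp

    # softmax weights are strictly positive, so the output is positive and log is safe
    return out.log() + values_max

q = k = v = torch.randn(2, 8, 16, 64)
assert laser_attention(q, k, v).shape == (2, 8, 16, 64)
```

In the library itself the feature is switched on with `attn_laser = True` in the attention layer kwargs, exactly as the new `test_laser` above does.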