x-transformers 1.42.8__tar.gz → 1.42.9__tar.gz
- {x_transformers-1.42.8/x_transformers.egg-info → x_transformers-1.42.9}/PKG-INFO +1 -1
- {x_transformers-1.42.8 → x_transformers-1.42.9}/setup.py +1 -1
- {x_transformers-1.42.8 → x_transformers-1.42.9}/tests/test_x_transformers.py +21 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/x_transformers.py +19 -8
- {x_transformers-1.42.8 → x_transformers-1.42.9/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.8 → x_transformers-1.42.9}/LICENSE +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/README.md +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/setup.cfg +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/attend.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.8 → x_transformers-1.42.9}/x_transformers.egg-info/top_level.txt +0 -0
tests/test_x_transformers.py

@@ -381,6 +381,7 @@ def test_neo_mlp():
     assert out.shape == (3, 7)

 def test_custom_alibi():
+
     model = TransformerWrapper(
         num_tokens = 20_000,
         max_seq_len = 1024,
@@ -398,6 +399,26 @@ def test_custom_alibi():

     logits = model(x, pos = pos)

+def test_custom_alibi_across_heads():
+
+    model = Decoder(
+        dim = 512,
+        depth = 2,
+        heads = 2,
+        alibi_pos_bias = True,
+        rel_pos_kwargs = dict(
+            slopes = [1, 1]
+        ),
+    )
+
+    x = torch.randn(2, 4, 512)
+
+    pos = torch.tensor([
+        [[0, 1, 2, 4], [1, 3, 5, 7]],
+        [[2, 3, 4, 5], [6, 8, 9, 10]]
+    ])
+
+    embed = model(x, pos = pos)

 @pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
 def test_embedder(embedder_type):
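Not part of the package diff: a minimal sketch restating the two tests above with the tensor shapes spelled out. It assumes only the public Decoder API the tests already use; pos may now be given per head as (batch, heads, seq) in addition to the shared (batch, seq) form, and rel_pos_kwargs forwards the custom slopes to the ALiBi module.

# hedged sketch, assumes only the public Decoder API shown in the tests above
import torch
from x_transformers import Decoder

model = Decoder(
    dim = 512,
    depth = 2,
    heads = 2,
    alibi_pos_bias = True,
    rel_pos_kwargs = dict(slopes = [1, 1])   # new in 1.42.9: one ALiBi slope per head
)

x = torch.randn(2, 4, 512)                   # (batch, seq, dim) of embeddings

pos = torch.tensor([
    [[0, 1, 2, 4], [1, 3, 5, 7]],            # batch 0: positions for head 0 and head 1
    [[2, 3, 4, 5], [6, 8, 9, 10]]            # batch 1
])                                           # (batch, heads, seq); a shared (batch, seq) pos is still accepted

embed = model(x, pos = pos)                  # (2, 4, 512)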
x_transformers/x_transformers.py

@@ -452,13 +452,20 @@ class DynamicPositionBias(Module):
         return bias

 class AlibiPositionalBias(Module):
-    def __init__(
+    def __init__(
+        self,
+        heads,
+        total_heads = None,
+        slopes: list[int] | None = None,
+        **kwargs
+    ):
         super().__init__()
         self.heads = heads
         self.total_heads = default(total_heads, heads)

-        slopes = Tensor(self._get_slopes(heads))
+        slopes = Tensor(default(slopes, self._get_slopes(heads)))
         slopes = rearrange(slopes, 'h -> h 1 1')
+
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)

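Not part of the diff: a minimal sketch of what the widened constructor allows, assuming AlibiPositionalBias is imported from x_transformers.x_transformers. A user-supplied slopes list now takes precedence over the internally derived geometric slopes.

# hedged sketch: the new `slopes` argument overrides the default slopes
import torch
from x_transformers.x_transformers import AlibiPositionalBias

default_bias = AlibiPositionalBias(heads = 8)                   # slopes derived internally for 8 heads
custom_bias  = AlibiPositionalBias(heads = 2, slopes = [1, 1])  # both heads get a slope of 1

print(custom_bias.slopes.shape)  # torch.Size([2, 1, 1]) after the rearrange in __init__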
@@ -487,7 +494,10 @@ class AlibiPositionalBias(Module):
         h, device = self.total_heads, self.device

         pos_j = default(pos_j, pos_i)
-        bias = -einx.subtract('... j, ... i -> ...
+        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()
+
+        if bias.ndim == 3:
+            bias = rearrange(bias, 'b i j -> b 1 i j')

         bias = bias * self.slopes
         num_heads_unalibied = h - bias.shape[-3]
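Not part of the diff: a standalone sketch that replays the two changed lines (the same einx.subtract pattern and the ndim == 3 guard) to show the shapes involved. When pos carries no head dimension, the bias picks up a singleton head axis so it can broadcast against the per-head slopes.

# hedged sketch of what the changed lines compute: a negated |i - j| distance
# matrix, lifted to (batch, 1, i, j) when `pos` has no head dimension
import torch
import einx
from einops import rearrange

pos = torch.tensor([[0, 1, 2, 4]])                                # (batch, seq), shared across heads
bias = -einx.subtract('... j, ... i -> ... i j', pos, pos).abs()
print(bias.shape)                                                 # torch.Size([1, 4, 4])

if bias.ndim == 3:                                                # same guard as the new code
    bias = rearrange(bias, 'b i j -> b 1 i j')
print(bias.shape)                                                 # torch.Size([1, 1, 4, 4])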
@@ -1531,8 +1541,9 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False,
-        add_value_residual = False,
+        reinject_input = False,   # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()
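As an aside on the two flags the new comments document, a hedged sketch of enabling them; the flag names come from the signature above, while the one-line behavior summaries in the comments paraphrase the cited papers rather than this diff.

# hedged sketch, assuming only the keyword arguments visible in the hunk above
import torch
from x_transformers import Decoder

decoder = Decoder(
    dim = 512,
    depth = 4,
    heads = 8,
    reinject_input = True,       # DEQ-style: feed the original input back in at each layer
    add_value_residual = True    # resformer: later layers mix in a residual of the first layer's values
)

x = torch.randn(2, 128, 512)
out = decoder(x)                 # (2, 128, 512)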
@@ -1573,14 +1584,14 @@ class AttentionLayers(Module):

         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
-            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
         elif dynamic_pos_bias:
             assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
-            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
+            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
         elif alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
-            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
+            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

         assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
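Not in the diff: a minimal end-to-end sketch, using only the public API shown above, of how rel_pos_kwargs is splatted into whichever relative-position module the flags select, here ALiBi with explicit slopes.

# hedged sketch: rel_pos_kwargs passes straight through to the selected rel-pos class
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20_000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 4,
        alibi_pos_bias = True,
        rel_pos_kwargs = dict(slopes = [1, 1, 2, 2])  # ends up as AlibiPositionalBias(..., slopes = [1, 1, 2, 2])
    )
)

x = torch.randint(0, 20_000, (2, 1024))
logits = model(x)   # (2, 1024, 20_000)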