x-transformers 1.42.24.tar.gz → 1.42.26.tar.gz
- {x_transformers-1.42.24/x_transformers.egg-info → x_transformers-1.42.26}/PKG-INFO +1 -1
- {x_transformers-1.42.24 → x_transformers-1.42.26}/README.md +2 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/setup.py +1 -1
- {x_transformers-1.42.24 → x_transformers-1.42.26}/tests/test_x_transformers.py +33 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/x_transformers.py +33 -11
- {x_transformers-1.42.24 → x_transformers-1.42.26/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.24 → x_transformers-1.42.26}/LICENSE +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/setup.cfg +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/attend.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers.egg-info/top_level.txt +0 -0
README.md
@@ -317,6 +317,8 @@ model = TransformerWrapper(

 Update: MetaAI researchers <a href="https://arxiv.org/abs/2309.16588">have found</a> that adding memory tokens (they call them register tokens), alleviates outliers (which is suspected now to be a pathology of attention networks unable to <a href="https://arxiv.org/abs/2306.12929">attend to nothing</a>).

+Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper
+
 ### Transformers Without Tears

 <img src="./images/scalenorm.png"></img>
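For reference, memory / register tokens are enabled in x-transformers through the `num_memory_tokens` keyword of `TransformerWrapper`; a minimal sketch mirroring the README example this hunk sits in:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20,          # learned memory / register tokens prepended to every sequence
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)                    # (1, 1024, 20000)
```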
tests/test_x_transformers.py
@@ -557,3 +557,36 @@ def test_laser():
     x = torch.randint(0, 20000, (2, 1024))

     model(x)
+
+@pytest.mark.parametrize('self_attn_custom_pos', (True, False))
+@pytest.mark.parametrize('cross_attn_rotary', (True, False))
+def test_cross_attn_rotary(
+    self_attn_custom_pos: bool,
+    cross_attn_rotary: bool
+):
+
+    x = torch.randn((1, 64, 256))
+    mask = torch.ones((1, 64)).bool()
+    context = torch.randn((1, 128, 512))
+    context_mask = torch.ones((1, 128)).bool()
+
+    model = Encoder(
+        dim = 256,
+        depth = 4,
+        heads = 4,
+        rotary_pos_emb = True,
+        cross_attend = True,
+        cross_attn_dim_context = 512
+    )
+
+    pos = torch.arange(64) if self_attn_custom_pos else None
+    context_pos = torch.arange(128) if cross_attn_rotary else None
+
+    embed = model(
+        x = x,
+        mask = mask,
+        context = context,
+        pos = pos,
+        context_pos = context_pos,
+        context_mask = context_mask
+    )
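Since `context_pos` is handled inside `AttentionLayers.forward`, the same arguments should plumb through other cross attending stacks as well; a minimal, untested sketch with a causal `Decoder` (the decoder variant is an assumption here, mirroring the encoder test above):

```python
import torch
from x_transformers import Decoder

model = Decoder(
    dim = 256,
    depth = 4,
    heads = 4,
    rotary_pos_emb = True,
    cross_attend = True,
    cross_attn_dim_context = 512
)

x = torch.randn((1, 64, 256))
context = torch.randn((1, 128, 512))

out = model(
    x,
    context = context,
    pos = torch.arange(64),          # rotary positions for the self attention queries / keys
    context_pos = torch.arange(128)  # rotary positions for the cross attention keys
)
```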
x_transformers/x_transformers.py
@@ -51,8 +51,8 @@ def default(val, d):
         return val
     return d() if callable(d) else d

-def first(it):
-    return it[0]
+def first(it, default = None):
+    return it[0] if len(it) > 0 else default

 def is_empty(x):
     return len(x) == 0
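A quick behavior sketch of the updated helper, imported straight from the module being patched:

```python
from x_transformers.x_transformers import first

assert first([3, 4, 5]) == 3         # unchanged behavior on non-empty sequences
assert first([], default = 0) == 0   # empty sequences now fall back to the default
assert first([]) is None             # instead of raising an IndexError
```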
@@ -1284,6 +1284,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        context_rotary_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1355,11 +1356,18 @@
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale

-        if exists(rotary_pos_emb) and not has_context:
+        if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
+
+            if has_context:
+                # override with `context_rotary_pos_emb` if provided
+
+                freqs, xpos_scale = context_rotary_pos_emb
+                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
+
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

         if self.rotary_embed_values:
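The net effect is that, when cross attending, queries are still rotated with the self attention positions while keys are rotated with the context positions. An illustrative, self contained sketch of that idea (plain GPT-NeoX style rotary written out here, not the library's internal helpers):

```python
import torch

def rotate_half(x):
    # rotate pairs of channels by 90 degrees
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(t, pos, dim_head):
    # build rotary frequencies from the given positions and apply them
    inv_freq = 1. / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))
    freqs = torch.einsum('i , j -> i j', pos.float(), inv_freq)
    freqs = torch.cat((freqs, freqs), dim = -1)
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

q = torch.randn(1, 8, 64, 32)    # (batch, heads, query length, dim head)
k = torch.randn(1, 8, 128, 32)   # keys come from a context of a different length

q = apply_rotary(q, torch.arange(64), 32)    # queries use the query positions
k = apply_rotary(k, torch.arange(128), 32)   # keys use the context positions
```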
@@ -1848,7 +1856,6 @@ class AttentionLayers(Module):
                 layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
-                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
             elif layer_type == 'f':
@@ -1917,6 +1924,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         pos = None,
+        context_pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1976,14 +1984,28 @@

         # rotary positions

-        if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            maybe_mem = mems[0] if exists(mems) else None # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
-            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+        cross_attn_rotary_pos_emb = dict()
+
+        if exists(self.rotary_pos_emb):
+            if not exists(rotary_pos_emb):
+                maybe_mem = first(mems, None) # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+                if not exists(pos):
+                    pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
+                rotary_pos_emb = self.rotary_pos_emb(pos)
+
+            # allow for rotary positions for context if provided

-            if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if exists(context_pos):
+                assert self.cross_attend
+                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)

-            rotary_pos_emb = self.rotary_pos_emb(pos)
+                cross_attn_rotary_pos_emb.update(
+                    rotary_pos_emb = rotary_pos_emb,
+                    context_rotary_pos_emb = context_rotary_pos_emb
+                )

         # assume cached key / values

@@ -2108,7 +2130,7 @@ class AttentionLayers(Module):
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
