x-transformers 1.42.24__py3-none-any.whl → 1.42.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +33 -11
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/METADATA +1 -1
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/RECORD +6 -6
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -51,8 +51,8 @@ def default(val, d):
         return val
     return d() if callable(d) else d

-def first(it):
-    return it[0]
+def first(it, default = None):
+    return it[0] if len(it) > 0 else default

 def is_empty(x):
     return len(x) == 0
@@ -1284,6 +1284,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        context_rotary_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1355,11 +1356,18 @@
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale

-        if exists(rotary_pos_emb) and not has_context:
+        if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
+
+            if has_context:
+                # override with `context_rotary_pos_emb` if provided
+
+                freqs, xpos_scale = context_rotary_pos_emb
+                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
+
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

             if self.rotary_embed_values:
@@ -1848,7 +1856,6 @@ class AttentionLayers(Module):
                 layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
-                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
             elif layer_type == 'f':
@@ -1917,6 +1924,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         pos = None,
+        context_pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1976,14 +1984,28 @@ class AttentionLayers(Module):

         # rotary positions

-        (removed: previous rotary position handling; the old lines are truncated in the source rendering)
+        cross_attn_rotary_pos_emb = dict()
+
+        if exists(self.rotary_pos_emb):
+            if not exists(rotary_pos_emb):
+                maybe_mem = first(mems, None) # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+                if not exists(pos):
+                    pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
+                rotary_pos_emb = self.rotary_pos_emb(pos)
+
+            # allow for rotary positions for context if provided

+            if exists(context_pos):
+                assert self.cross_attend
+                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)

+                cross_attn_rotary_pos_emb.update(
+                    rotary_pos_emb = rotary_pos_emb,
+                    context_rotary_pos_emb = context_rotary_pos_emb
+                )

         # assume cached key / values

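Schematically, `cross_attn_rotary_pos_emb` stays empty unless `context_pos` is supplied, in which case it carries both embeddings down to the cross attention call in the layer loop further below. A rough summary of the two cases (illustrative, not from the package):

    # no context_pos: the cross attention blocks see no rotary kwargs at all
    cross_attn_rotary_pos_emb = dict()

    # context_pos given (requires cross_attend = True):
    # queries use the self positions, keys use the context positions
    cross_attn_rotary_pos_emb = dict(
        rotary_pos_emb = self.rotary_pos_emb(pos),                  # (freqs, xpos_scale) for the queries
        context_rotary_pos_emb = self.rotary_pos_emb(context_pos)   # (freqs, xpos_scale) for the keys
    )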
@@ -2108,7 +2130,7 @@
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)

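Taken together, the additions let a caller rotate the cross attention keys by explicit context positions. A minimal end to end sketch, assuming a `Decoder` stack (an `AttentionLayers` subclass) built with `cross_attend = True` and `rotary_pos_emb = True` accepts the forward keyword arguments introduced in this diff (illustrative, untested):

    import torch
    from x_transformers import Decoder

    layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        cross_attend = True,
        rotary_pos_emb = True
    )

    x = torch.randn(1, 10, 512)
    context = torch.randn(1, 20, 512)

    # positions of the context tokens, e.g. if they were cropped out of a longer sequence
    context_pos = torch.arange(20)

    out = layers(x, context = context, context_pos = context_pos)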
{x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=iocWSfj6h0GvmvXBIRBqzZm8l4IZDf4so0pOcIy3jFg,96696
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.26.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.26.dist-info/METADATA,sha256=cf7t334edZKkTEpnwlhTQNYRgLDU6SSBPxtsA58ubxY,739
+x_transformers-1.42.26.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.26.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.26.dist-info/RECORD,,
{x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/LICENSE
File without changes
{x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/WHEEL
File without changes
{x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/top_level.txt
File without changes